support AzureOpenAI
This commit is contained in:
parent
e226645bc7
commit
28e308db01
11 changed files with 256 additions and 8 deletions
|
|
@ -14,12 +14,18 @@ import { chats, messages as messagesSchema } from '@/lib/db/schema';
|
|||
import { and, eq, gt } from 'drizzle-orm';
|
||||
import { getFileDetails } from '@/lib/utils/files';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
||||
import {
|
||||
getCustomOpenaiApiKey,
|
||||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
} from '@/lib/config';
|
||||
import {
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiEndpoint,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiApiVersion,
|
||||
} from '@/lib/config';
|
||||
import { searchHandlers } from '@/lib/search';
|
||||
|
||||
export const runtime = 'nodejs';
|
||||
|
|
@ -186,6 +192,8 @@ export const POST = async (req: Request) => {
|
|||
const body = (await req.json()) as Body;
|
||||
const { message } = body;
|
||||
|
||||
console.error('An error occurred while processing chat request:', "here");
|
||||
|
||||
if (message.content === '') {
|
||||
return Response.json(
|
||||
{
|
||||
|
|
@ -222,6 +230,7 @@ export const POST = async (req: Request) => {
|
|||
let embedding = embeddingModel.model;
|
||||
|
||||
if (body.chatModel?.provider === 'custom_openai') {
|
||||
console.error('An error occurred while processing chat request:', "custom_openai");
|
||||
llm = new ChatOpenAI({
|
||||
openAIApiKey: getCustomOpenaiApiKey(),
|
||||
modelName: getCustomOpenaiModelName(),
|
||||
|
|
@ -230,6 +239,15 @@ export const POST = async (req: Request) => {
|
|||
baseURL: getCustomOpenaiApiUrl(),
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
} else if (body.chatModel?.provider == 'azure_openai') {
|
||||
console.error('An error occurred while processing chat request:', "azure_openai");
|
||||
llm = new AzureChatOpenAI({
|
||||
openAIApiKey: getAzureOpenaiApiKey(),
|
||||
deploymentName: getAzureOpenaiModelName(),
|
||||
openAIBasePath: getAzureOpenaiEndpoint(),
|
||||
openAIApiVersion: getAzureOpenaiApiVersion(),
|
||||
temperature: 0.7
|
||||
}) as unknown as BaseChatModel
|
||||
} else if (chatModelProvider && chatModel) {
|
||||
llm = chatModel.model;
|
||||
}
|
||||
|
|
@ -297,7 +315,7 @@ export const POST = async (req: Request) => {
|
|||
},
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('An error occurred while processing chat request:', err);
|
||||
console.error('An error occurred while processing chat request 123:', err);
|
||||
return Response.json(
|
||||
{ message: 'An error occurred while processing chat request' },
|
||||
{ status: 500 },
|
||||
|
|
|
|||
|
|
@ -3,6 +3,10 @@ import {
|
|||
getCustomOpenaiApiKey,
|
||||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiApiVersion,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiEndpoint,
|
||||
getGeminiApiKey,
|
||||
getGroqApiKey,
|
||||
getOllamaApiEndpoint,
|
||||
|
|
@ -58,6 +62,10 @@ export const GET = async (req: Request) => {
|
|||
config['customOpenaiApiUrl'] = getCustomOpenaiApiUrl();
|
||||
config['customOpenaiApiKey'] = getCustomOpenaiApiKey();
|
||||
config['customOpenaiModelName'] = getCustomOpenaiModelName();
|
||||
config['azureOpenaiApiKey'] = getAzureOpenaiApiKey();
|
||||
config['azureOpenaiApiVersion'] = getAzureOpenaiApiVersion();
|
||||
config['azureOpenaiModelName'] = getAzureOpenaiModelName();
|
||||
config['azureOpenaiEndpoint'] = getAzureOpenaiEndpoint();
|
||||
|
||||
return Response.json({ ...config }, { status: 200 });
|
||||
} catch (err) {
|
||||
|
|
@ -98,6 +106,12 @@ export const POST = async (req: Request) => {
|
|||
API_KEY: config.customOpenaiApiKey,
|
||||
MODEL_NAME: config.customOpenaiModelName,
|
||||
},
|
||||
AZURE_OPENAI: {
|
||||
API_KEY: config.azureOpenaiApiKey,
|
||||
MODEL_NAME: config.azureOpenaiModelName,
|
||||
ENDPOINT: config.azureOpenaiEndpoint,
|
||||
API_VERSION: config.azureOpenaiApiVersion,
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@ import {
|
|||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
} from '@/lib/config';
|
||||
import {
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiEndpoint,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiApiVersion,
|
||||
} from '@/lib/config';
|
||||
import { getAvailableChatModelProviders } from '@/lib/providers';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
||||
|
||||
interface ChatModel {
|
||||
provider: string;
|
||||
|
|
@ -56,6 +62,14 @@ export const POST = async (req: Request) => {
|
|||
baseURL: getCustomOpenaiApiUrl(),
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
} else if (body.chatModel?.provider == 'azure_openai') {
|
||||
llm = new AzureChatOpenAI({
|
||||
openAIApiKey: getAzureOpenaiApiKey(),
|
||||
deploymentName: getAzureOpenaiModelName(),
|
||||
openAIBasePath: getAzureOpenaiEndpoint(),
|
||||
openAIApiVersion: getAzureOpenaiApiVersion(),
|
||||
temperature: 0.7
|
||||
}) as unknown as BaseChatModel
|
||||
} else if (chatModelProvider && chatModel) {
|
||||
llm = chatModel.model;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,11 +5,15 @@ import {
|
|||
|
||||
export const GET = async (req: Request) => {
|
||||
try {
|
||||
console.error('here ok0');
|
||||
const [chatModelProviders, embeddingModelProviders] = await Promise.all([
|
||||
getAvailableChatModelProviders(),
|
||||
getAvailableEmbeddingModelProviders(),
|
||||
]);
|
||||
|
||||
|
||||
console.error('here ok1');
|
||||
|
||||
Object.keys(chatModelProviders).forEach((provider) => {
|
||||
Object.keys(chatModelProviders[provider]).forEach((model) => {
|
||||
delete (chatModelProviders[provider][model] as { model?: unknown })
|
||||
|
|
@ -17,6 +21,8 @@ export const GET = async (req: Request) => {
|
|||
});
|
||||
});
|
||||
|
||||
console.error('here ok2');
|
||||
|
||||
Object.keys(embeddingModelProviders).forEach((provider) => {
|
||||
Object.keys(embeddingModelProviders[provider]).forEach((model) => {
|
||||
delete (embeddingModelProviders[provider][model] as { model?: unknown })
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import type { Embeddings } from '@langchain/core/embeddings';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
||||
import {
|
||||
getAvailableChatModelProviders,
|
||||
getAvailableEmbeddingModelProviders,
|
||||
|
|
@ -12,6 +12,12 @@ import {
|
|||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
} from '@/lib/config';
|
||||
import {
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiEndpoint,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiApiVersion,
|
||||
} from '@/lib/config';
|
||||
import { searchHandlers } from '@/lib/search';
|
||||
|
||||
interface chatModel {
|
||||
|
|
@ -19,6 +25,10 @@ interface chatModel {
|
|||
name: string;
|
||||
customOpenAIKey?: string;
|
||||
customOpenAIBaseURL?: string;
|
||||
azureOpenAIApiVersion?: string;
|
||||
azureOpenAIApiKey?: string;
|
||||
azureOpenAIApiDeploymentName?: string;
|
||||
azureOpenAIEndpoint?: string;
|
||||
}
|
||||
|
||||
interface embeddingModel {
|
||||
|
|
@ -89,6 +99,14 @@ export const POST = async (req: Request) => {
|
|||
body.chatModel?.customOpenAIBaseURL || getCustomOpenaiApiUrl(),
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
} else if (body.chatModel?.provider == 'azure_openai') {
|
||||
llm = new AzureChatOpenAI({
|
||||
openAIApiKey: body.chatModel?.azureOpenAIApiKey || getAzureOpenaiApiKey(),
|
||||
deploymentName: body.chatModel?.azureOpenAIApiDeploymentName || getAzureOpenaiModelName(),
|
||||
openAIBasePath: body.chatModel?.azureOpenAIEndpoint || getAzureOpenaiEndpoint(),
|
||||
openAIApiVersion: body.chatModel?.azureOpenAIApiVersion || getAzureOpenaiApiVersion(),
|
||||
temperature: 0.7
|
||||
}) as unknown as BaseChatModel
|
||||
} else if (
|
||||
chatModelProviders[chatModelProvider] &&
|
||||
chatModelProviders[chatModelProvider][chatModel]
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@ import {
|
|||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
} from '@/lib/config';
|
||||
import {
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiEndpoint,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiApiVersion,
|
||||
} from '@/lib/config';
|
||||
import { getAvailableChatModelProviders } from '@/lib/providers';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
||||
|
||||
interface ChatModel {
|
||||
provider: string;
|
||||
|
|
@ -55,6 +61,14 @@ export const POST = async (req: Request) => {
|
|||
baseURL: getCustomOpenaiApiUrl(),
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
} else if (body.chatModel?.provider == 'azure_openai') {
|
||||
llm = new AzureChatOpenAI({
|
||||
openAIApiKey: getAzureOpenaiApiKey(),
|
||||
deploymentName: getAzureOpenaiModelName(),
|
||||
openAIBasePath: getAzureOpenaiEndpoint(),
|
||||
openAIApiVersion: getAzureOpenaiApiVersion(),
|
||||
temperature: 0.7
|
||||
}) as unknown as BaseChatModel
|
||||
} else if (chatModelProvider && chatModel) {
|
||||
llm = chatModel.model;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@ import {
|
|||
getCustomOpenaiApiUrl,
|
||||
getCustomOpenaiModelName,
|
||||
} from '@/lib/config';
|
||||
import {
|
||||
getAzureOpenaiApiKey,
|
||||
getAzureOpenaiEndpoint,
|
||||
getAzureOpenaiModelName,
|
||||
getAzureOpenaiApiVersion,
|
||||
} from '@/lib/config';
|
||||
import { getAvailableChatModelProviders } from '@/lib/providers';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
|
||||
import { ChatOpenAI } from '@langchain/openai';
|
||||
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
|
||||
|
||||
interface ChatModel {
|
||||
provider: string;
|
||||
|
|
@ -56,6 +62,14 @@ export const POST = async (req: Request) => {
|
|||
baseURL: getCustomOpenaiApiUrl(),
|
||||
},
|
||||
}) as unknown as BaseChatModel;
|
||||
} else if (body.chatModel?.provider == 'azure_openai') {
|
||||
llm = new AzureChatOpenAI({
|
||||
openAIApiKey: getAzureOpenaiApiKey(),
|
||||
deploymentName: getAzureOpenaiModelName(),
|
||||
openAIBasePath: getAzureOpenaiEndpoint(),
|
||||
openAIApiVersion: getAzureOpenaiApiVersion(),
|
||||
temperature: 0.7
|
||||
}) as unknown as BaseChatModel
|
||||
} else if (chatModelProvider && chatModel) {
|
||||
llm = chatModel.model;
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue