support AzureOpenAI

This commit is contained in:
nnnyt 2025-04-07 13:49:09 -04:00
parent e226645bc7
commit 28e308db01
11 changed files with 256 additions and 8 deletions

View file

@@ -14,12 +14,18 @@ import { chats, messages as messagesSchema } from '@/lib/db/schema';
import { and, eq, gt } from 'drizzle-orm';
import { getFileDetails } from '@/lib/utils/files';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
import {
getCustomOpenaiApiKey,
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiEndpoint,
getAzureOpenaiModelName,
getAzureOpenaiApiVersion,
} from '@/lib/config';
import { searchHandlers } from '@/lib/search';
export const runtime = 'nodejs';
@@ -186,6 +192,8 @@ export const POST = async (req: Request) => {
const body = (await req.json()) as Body;
const { message } = body;
if (message.content === '') {
return Response.json(
{
@@ -222,6 +230,7 @@ export const POST = async (req: Request) => {
let embedding = embeddingModel.model;
if (body.chatModel?.provider === 'custom_openai') {
llm = new ChatOpenAI({
openAIApiKey: getCustomOpenaiApiKey(),
modelName: getCustomOpenaiModelName(),
@@ -230,6 +239,15 @@ export const POST = async (req: Request) => {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (body.chatModel?.provider === 'azure_openai') {
llm = new AzureChatOpenAI({
azureOpenAIApiKey: getAzureOpenaiApiKey(),
azureOpenAIApiDeploymentName: getAzureOpenaiModelName(),
azureOpenAIEndpoint: getAzureOpenaiEndpoint(),
azureOpenAIApiVersion: getAzureOpenaiApiVersion(),
temperature: 0.7,
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}
@@ -297,7 +315,7 @@ export const POST = async (req: Request) => {
},
});
} catch (err) {
console.error('An error occurred while processing chat request:', err);
return Response.json(
{ message: 'An error occurred while processing chat request' },
{ status: 500 },
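
Note: the AzureChatOpenAI options above use LangChain's Azure-specific field names. As a minimal standalone sketch (not part of this commit; the deployment name, endpoint, and API version below are placeholders), the same client can be constructed and smoke-tested like this:

import { AzureChatOpenAI } from '@langchain/openai';

const llm = new AzureChatOpenAI({
  azureOpenAIApiKey: process.env.AZURE_OPENAI_API_KEY,
  // The deployment name on the Azure resource; this commit reuses the
  // configured "model name" for it.
  azureOpenAIApiDeploymentName: 'gpt-4o',
  azureOpenAIEndpoint: 'https://my-resource.openai.azure.com/',
  azureOpenAIApiVersion: '2024-02-15-preview',
  temperature: 0.7,
});

// Simple smoke test of the configured deployment.
const res = await llm.invoke('Say hello');
console.log(res.content);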

View file

@@ -3,6 +3,10 @@ import {
getCustomOpenaiApiKey,
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiApiVersion,
getAzureOpenaiModelName,
getAzureOpenaiEndpoint,
getGeminiApiKey,
getGroqApiKey,
getOllamaApiEndpoint,
@@ -58,6 +62,10 @@ export const GET = async (req: Request) => {
config['customOpenaiApiUrl'] = getCustomOpenaiApiUrl();
config['customOpenaiApiKey'] = getCustomOpenaiApiKey();
config['customOpenaiModelName'] = getCustomOpenaiModelName();
config['azureOpenaiApiKey'] = getAzureOpenaiApiKey();
config['azureOpenaiApiVersion'] = getAzureOpenaiApiVersion();
config['azureOpenaiModelName'] = getAzureOpenaiModelName();
config['azureOpenaiEndpoint'] = getAzureOpenaiEndpoint();
return Response.json({ ...config }, { status: 200 });
} catch (err) {
@@ -98,6 +106,12 @@ export const POST = async (req: Request) => {
API_KEY: config.customOpenaiApiKey,
MODEL_NAME: config.customOpenaiModelName,
},
AZURE_OPENAI: {
API_KEY: config.azureOpenaiApiKey,
MODEL_NAME: config.azureOpenaiModelName,
ENDPOINT: config.azureOpenaiEndpoint,
API_VERSION: config.azureOpenaiApiVersion,
},
},
};
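
Note: the routes in this commit import getAzureOpenai* getters from '@/lib/config', but that file's changes are not shown in this view. A plausible sketch of those getters, assuming the values live under an AZURE_OPENAI section like the one the POST handler above assembles, and mirroring the existing getCustomOpenai* pattern (loadConfig is a stand-in for however the app actually reads its config file):

// Hypothetical; the real config module is not shown in this diff.
declare function loadConfig(): {
  MODELS: {
    AZURE_OPENAI: {
      API_KEY: string;
      MODEL_NAME: string;
      ENDPOINT: string;
      API_VERSION: string;
    };
  };
};

export const getAzureOpenaiApiKey = () =>
  loadConfig().MODELS.AZURE_OPENAI.API_KEY;
export const getAzureOpenaiModelName = () =>
  loadConfig().MODELS.AZURE_OPENAI.MODEL_NAME;
export const getAzureOpenaiEndpoint = () =>
  loadConfig().MODELS.AZURE_OPENAI.ENDPOINT;
export const getAzureOpenaiApiVersion = () =>
  loadConfig().MODELS.AZURE_OPENAI.API_VERSION;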

View file

@@ -4,10 +4,16 @@ import {
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiEndpoint,
getAzureOpenaiModelName,
getAzureOpenaiApiVersion,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
@@ -56,6 +62,14 @@ export const POST = async (req: Request) => {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (body.chatModel?.provider === 'azure_openai') {
llm = new AzureChatOpenAI({
azureOpenAIApiKey: getAzureOpenaiApiKey(),
azureOpenAIApiDeploymentName: getAzureOpenaiModelName(),
azureOpenAIEndpoint: getAzureOpenaiEndpoint(),
azureOpenAIApiVersion: getAzureOpenaiApiVersion(),
temperature: 0.7,
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}

View file

@@ -5,11 +5,15 @@ import {
export const GET = async (req: Request) => {
try {
const [chatModelProviders, embeddingModelProviders] = await Promise.all([
getAvailableChatModelProviders(),
getAvailableEmbeddingModelProviders(),
]);
Object.keys(chatModelProviders).forEach((provider) => {
Object.keys(chatModelProviders[provider]).forEach((model) => {
delete (chatModelProviders[provider][model] as { model?: unknown })
@@ -17,6 +21,8 @@ export const GET = async (req: Request) => {
});
});
Object.keys(embeddingModelProviders).forEach((provider) => {
Object.keys(embeddingModelProviders[provider]).forEach((model) => {
delete (embeddingModelProviders[provider][model] as { model?: unknown })

View file

@@ -1,6 +1,6 @@
import type { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
import {
getAvailableChatModelProviders,
getAvailableEmbeddingModelProviders,
@@ -12,6 +12,12 @@ import {
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiEndpoint,
getAzureOpenaiModelName,
getAzureOpenaiApiVersion,
} from '@/lib/config';
import { searchHandlers } from '@/lib/search';
interface chatModel {
@@ -19,6 +25,10 @@ interface chatModel {
name: string;
customOpenAIKey?: string;
customOpenAIBaseURL?: string;
azureOpenAIApiVersion?: string;
azureOpenAIApiKey?: string;
azureOpenAIApiDeploymentName?: string;
azureOpenAIEndpoint?: string;
}
interface embeddingModel {
@@ -89,6 +99,14 @@ export const POST = async (req: Request) => {
body.chatModel?.customOpenAIBaseURL || getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (body.chatModel?.provider === 'azure_openai') {
llm = new AzureChatOpenAI({
azureOpenAIApiKey: body.chatModel?.azureOpenAIApiKey || getAzureOpenaiApiKey(),
azureOpenAIApiDeploymentName: body.chatModel?.azureOpenAIApiDeploymentName || getAzureOpenaiModelName(),
azureOpenAIEndpoint: body.chatModel?.azureOpenAIEndpoint || getAzureOpenaiEndpoint(),
azureOpenAIApiVersion: body.chatModel?.azureOpenAIApiVersion || getAzureOpenaiApiVersion(),
temperature: 0.7,
}) as unknown as BaseChatModel;
} else if (
chatModelProviders[chatModelProvider] &&
chatModelProviders[chatModelProvider][chatModel]
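
Note: this route also accepts per-request Azure credentials via the chatModel fields added to the interface above, falling back to the server config when they are omitted. A hypothetical client call (the /api/search path and the query field are assumptions, not shown in this diff; credentials and deployment values are placeholders):

const response = await fetch('/api/search', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    query: 'What is Azure OpenAI?',
    chatModel: {
      provider: 'azure_openai',
      name: 'gpt-4o', // placeholder model/deployment name
      azureOpenAIApiKey: '<key>',
      azureOpenAIApiDeploymentName: 'gpt-4o',
      azureOpenAIEndpoint: 'https://my-resource.openai.azure.com/',
      azureOpenAIApiVersion: '2024-02-15-preview',
    },
  }),
});
console.log(await response.json());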

View file

@@ -4,10 +4,16 @@ import {
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiEndpoint,
getAzureOpenaiModelName,
getAzureOpenaiApiVersion,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
@@ -55,6 +61,14 @@ export const POST = async (req: Request) => {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (body.chatModel?.provider === 'azure_openai') {
llm = new AzureChatOpenAI({
azureOpenAIApiKey: getAzureOpenaiApiKey(),
azureOpenAIApiDeploymentName: getAzureOpenaiModelName(),
azureOpenAIEndpoint: getAzureOpenaiEndpoint(),
azureOpenAIApiVersion: getAzureOpenaiApiVersion(),
temperature: 0.7,
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}

View file

@@ -4,10 +4,16 @@ import {
getCustomOpenaiApiUrl,
getCustomOpenaiModelName,
getAzureOpenaiApiKey,
getAzureOpenaiEndpoint,
getAzureOpenaiModelName,
getAzureOpenaiApiVersion,
} from '@/lib/config';
import { getAvailableChatModelProviders } from '@/lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai';
interface ChatModel {
provider: string;
@@ -56,6 +62,14 @@ export const POST = async (req: Request) => {
baseURL: getCustomOpenaiApiUrl(),
},
}) as unknown as BaseChatModel;
} else if (body.chatModel?.provider === 'azure_openai') {
llm = new AzureChatOpenAI({
azureOpenAIApiKey: getAzureOpenaiApiKey(),
azureOpenAIApiDeploymentName: getAzureOpenaiModelName(),
azureOpenAIEndpoint: getAzureOpenaiEndpoint(),
azureOpenAIApiVersion: getAzureOpenaiApiVersion(),
temperature: 0.7,
}) as unknown as BaseChatModel;
} else if (chatModelProvider && chatModel) {
llm = chatModel.model;
}