diff --git a/sample.config.toml b/sample.config.toml index 980e99d..9e41edb 100644 --- a/sample.config.toml +++ b/sample.config.toml @@ -19,6 +19,12 @@ API_KEY = "" API_URL = "" MODEL_NAME = "" +[MODELS.AZURE_OPENAI] +API_KEY = "" +ENDPOINT = "" +MODEL_NAME = "" +API_VERSION = "" + [MODELS.OLLAMA] API_URL = "" # Ollama API URL - http://host.docker.internal:11434 diff --git a/src/app/api/chat/route.ts b/src/app/api/chat/route.ts index e566edb..042fe9b 100644 --- a/src/app/api/chat/route.ts +++ b/src/app/api/chat/route.ts @@ -14,12 +14,18 @@ import { chats, messages as messagesSchema } from '@/lib/db/schema'; import { and, eq, gt } from 'drizzle-orm'; import { getFileDetails } from '@/lib/utils/files'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; import { getCustomOpenaiApiKey, getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '@/lib/config'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + getAzureOpenaiApiVersion, +} from '@/lib/config'; import { searchHandlers } from '@/lib/search'; export const runtime = 'nodejs'; @@ -186,6 +192,8 @@ export const POST = async (req: Request) => { const body = (await req.json()) as Body; const { message } = body; if (message.content === '') { return Response.json( { @@ -222,6 +230,7 @@ export const POST = async (req: Request) => { let embedding = embeddingModel.model; if (body.chatModel?.provider === 'custom_openai') { llm = new ChatOpenAI({ openAIApiKey: getCustomOpenaiApiKey(), modelName: getCustomOpenaiModelName(), @@ -230,6 +239,15 @@ export const POST = async (req: Request) => { baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; + } else if 
(body.chatModel?.provider === 'azure_openai') { + llm = new AzureChatOpenAI({ + openAIApiKey: getAzureOpenaiApiKey(), + deploymentName: getAzureOpenaiModelName(), + openAIBasePath: getAzureOpenaiEndpoint(), + openAIApiVersion: getAzureOpenaiApiVersion(), + temperature: 0.7 + }) as unknown as BaseChatModel; } else if (chatModelProvider && chatModel) { llm = chatModel.model; } @@ -297,7 +315,7 @@ export const POST = async (req: Request) => { }, }); } catch (err) { console.error('An error occurred while processing chat request:', err); return Response.json( { message: 'An error occurred while processing chat request' }, { status: 500 }, diff --git a/src/app/api/config/route.ts b/src/app/api/config/route.ts index 39c1f84..941df60 100644 --- a/src/app/api/config/route.ts +++ b/src/app/api/config/route.ts @@ -3,6 +3,10 @@ import { getCustomOpenaiApiKey, getCustomOpenaiApiUrl, getCustomOpenaiModelName, + getAzureOpenaiApiKey, + getAzureOpenaiApiVersion, + getAzureOpenaiModelName, + getAzureOpenaiEndpoint, getGeminiApiKey, getGroqApiKey, getOllamaApiEndpoint, @@ -58,6 +62,10 @@ export const GET = async (req: Request) => { config['customOpenaiApiUrl'] = getCustomOpenaiApiUrl(); config['customOpenaiApiKey'] = getCustomOpenaiApiKey(); config['customOpenaiModelName'] = getCustomOpenaiModelName(); + config['azureOpenaiApiKey'] = getAzureOpenaiApiKey(); + config['azureOpenaiApiVersion'] = getAzureOpenaiApiVersion(); + config['azureOpenaiModelName'] = getAzureOpenaiModelName(); + config['azureOpenaiEndpoint'] = getAzureOpenaiEndpoint(); return Response.json({ ...config }, { status: 200 }); } catch (err) { @@ -98,6 +106,12 @@ export const POST = async (req: Request) => { API_KEY: config.customOpenaiApiKey, MODEL_NAME: config.customOpenaiModelName, }, + AZURE_OPENAI: { + API_KEY: config.azureOpenaiApiKey, + MODEL_NAME: 
config.azureOpenaiModelName, + ENDPOINT: config.azureOpenaiEndpoint, + API_VERSION: config.azureOpenaiApiVersion, + }, }, }; diff --git a/src/app/api/images/route.ts b/src/app/api/images/route.ts index db39d9f..98751ba 100644 --- a/src/app/api/images/route.ts +++ b/src/app/api/images/route.ts @@ -4,10 +4,16 @@ import { getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '@/lib/config'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + getAzureOpenaiApiVersion, +} from '@/lib/config'; import { getAvailableChatModelProviders } from '@/lib/providers'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; interface ChatModel { provider: string; @@ -56,6 +62,14 @@ export const POST = async (req: Request) => { baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; + } else if (body.chatModel?.provider === 'azure_openai') { + llm = new AzureChatOpenAI({ + openAIApiKey: getAzureOpenaiApiKey(), + deploymentName: getAzureOpenaiModelName(), + openAIBasePath: getAzureOpenaiEndpoint(), + openAIApiVersion: getAzureOpenaiApiVersion(), + temperature: 0.7 + }) as unknown as BaseChatModel; } else if (chatModelProvider && chatModel) { llm = chatModel.model; } diff --git a/src/app/api/models/route.ts b/src/app/api/models/route.ts index 04a6949..f886da6 100644 --- a/src/app/api/models/route.ts +++ b/src/app/api/models/route.ts @@ -5,11 +5,15 @@ import { export const GET = async (req: Request) => { try { const [chatModelProviders, embeddingModelProviders] = await Promise.all([ getAvailableChatModelProviders(), getAvailableEmbeddingModelProviders(), ]); Object.keys(chatModelProviders).forEach((provider) => { 
Object.keys(chatModelProviders[provider]).forEach((model) => { delete (chatModelProviders[provider][model] as { model?: unknown }) @@ -17,6 +21,8 @@ export const GET = async (req: Request) => { }); }); + console.error('here ok2'); + Object.keys(embeddingModelProviders).forEach((provider) => { Object.keys(embeddingModelProviders[provider]).forEach((model) => { delete (embeddingModelProviders[provider][model] as { model?: unknown }) diff --git a/src/app/api/search/route.ts b/src/app/api/search/route.ts index 970ec42..74b2100 100644 --- a/src/app/api/search/route.ts +++ b/src/app/api/search/route.ts @@ -1,6 +1,6 @@ import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; import { getAvailableChatModelProviders, getAvailableEmbeddingModelProviders, @@ -12,6 +12,12 @@ import { getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '@/lib/config'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + getAzureOpenaiApiVersion, +} from '@/lib/config'; import { searchHandlers } from '@/lib/search'; interface chatModel { @@ -19,6 +25,10 @@ interface chatModel { name: string; customOpenAIKey?: string; customOpenAIBaseURL?: string; + azureOpenAIApiVersion?: string; + azureOpenAIApiKey?: string; + azureOpenAIApiDeploymentName?: string; + azureOpenAIEndpoint?: string; } interface embeddingModel { @@ -89,6 +99,14 @@ export const POST = async (req: Request) => { body.chatModel?.customOpenAIBaseURL || getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; + } else if (body.chatModel?.provider == 'azure_openai') { + llm = new AzureChatOpenAI({ + openAIApiKey: body.chatModel?.azureOpenAIApiKey || getAzureOpenaiApiKey(), + deploymentName: body.chatModel?.azureOpenAIApiDeploymentName || getAzureOpenaiModelName(), + openAIBasePath: 
body.chatModel?.azureOpenAIEndpoint || getAzureOpenaiEndpoint(), + openAIApiVersion: body.chatModel?.azureOpenAIApiVersion || getAzureOpenaiApiVersion(), + temperature: 0.7 + }) as unknown as BaseChatModel } else if ( chatModelProviders[chatModelProvider] && chatModelProviders[chatModelProvider][chatModel] diff --git a/src/app/api/suggestions/route.ts b/src/app/api/suggestions/route.ts index e92e5ec..d7ea6f9 100644 --- a/src/app/api/suggestions/route.ts +++ b/src/app/api/suggestions/route.ts @@ -4,10 +4,16 @@ import { getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '@/lib/config'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + getAzureOpenaiApiVersion, +} from '@/lib/config'; import { getAvailableChatModelProviders } from '@/lib/providers'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; interface ChatModel { provider: string; @@ -55,6 +61,14 @@ export const POST = async (req: Request) => { baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; + } else if (body.chatModel?.provider == 'azure_openai') { + llm = new AzureChatOpenAI({ + openAIApiKey: getAzureOpenaiApiKey(), + deploymentName: getAzureOpenaiModelName(), + openAIBasePath: getAzureOpenaiEndpoint(), + openAIApiVersion: getAzureOpenaiApiVersion(), + temperature: 0.7 + }) as unknown as BaseChatModel } else if (chatModelProvider && chatModel) { llm = chatModel.model; } diff --git a/src/app/api/videos/route.ts b/src/app/api/videos/route.ts index 34ae7fd..ea8d3a5 100644 --- a/src/app/api/videos/route.ts +++ b/src/app/api/videos/route.ts @@ -4,10 +4,16 @@ import { getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '@/lib/config'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + 
getAzureOpenaiApiVersion, +} from '@/lib/config'; import { getAvailableChatModelProviders } from '@/lib/providers'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { AIMessage, BaseMessage, HumanMessage } from '@langchain/core/messages'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; interface ChatModel { provider: string; @@ -56,6 +62,14 @@ export const POST = async (req: Request) => { baseURL: getCustomOpenaiApiUrl(), }, }) as unknown as BaseChatModel; + } else if (body.chatModel?.provider == 'azure_openai') { + llm = new AzureChatOpenAI({ + openAIApiKey: getAzureOpenaiApiKey(), + deploymentName: getAzureOpenaiModelName(), + openAIBasePath: getAzureOpenaiEndpoint(), + openAIApiVersion: getAzureOpenaiApiVersion(), + temperature: 0.7 + }) as unknown as BaseChatModel } else if (chatModelProvider && chatModel) { llm = chatModel.model; } diff --git a/src/app/settings/page.tsx b/src/app/settings/page.tsx index 8eee9a4..8755c6b 100644 --- a/src/app/settings/page.tsx +++ b/src/app/settings/page.tsx @@ -24,6 +24,10 @@ interface SettingsType { customOpenaiApiKey: string; customOpenaiApiUrl: string; customOpenaiModelName: string; + azureOpenaiModelName: string; + azureOpenaiEndpoint: string; + azureOpenaiApiKey: string; + azureOpenaiApiVersion: string; } interface InputProps extends React.InputHTMLAttributes { @@ -556,7 +560,8 @@ const Page = () => { {selectedChatModelProvider && - selectedChatModelProvider != 'custom_openai' && ( + selectedChatModelProvider != 'custom_openai' && + selectedChatModelProvider != 'azure_openai' && (

Chat Model @@ -666,6 +671,93 @@ const Page = () => {

)} + + {selectedChatModelProvider && + selectedChatModelProvider === 'azure_openai' && ( +
+
+

+ Model Name +

+ ) => { + setConfig((prev) => ({ + ...prev!, + azureOpenaiModelName: e.target.value, + })); + }} + onSave={(value) => + saveConfig('azureOpenaiModelName', value) + } + /> +
+
+

+ Azure OpenAI API Key +

+ ) => { + setConfig((prev) => ({ + ...prev!, + azureOpenaiApiKey: e.target.value, + })); + }} + onSave={(value) => + saveConfig('azureOpenaiApiKey', value) + } + /> +
+
+

+ Azure OpenAI Base URL +

+ ) => { + setConfig((prev) => ({ + ...prev!, + azureOpenaiEndpoint: e.target.value, + })); + }} + onSave={(value) => + saveConfig('azureOpenaiEndpoint', value) + } + /> +
+
+

+ Azure OpenAI Api Version +

+ ) => { + setConfig((prev) => ({ + ...prev!, + azureOpenaiApiVersion: e.target.value, + })); + }} + onSave={(value) => + saveConfig('azureOpenaiApiVersion', value) + } + /> +
+
+ )} + {config.embeddingModelProviders && (
diff --git a/src/lib/config.ts b/src/lib/config.ts index 2831214..9b7a70f 100644 --- a/src/lib/config.ts +++ b/src/lib/config.ts @@ -33,6 +33,12 @@ interface Config { API_KEY: string; MODEL_NAME: string; }; + AZURE_OPENAI: { + ENDPOINT: string; + API_KEY: string; + MODEL_NAME: string; + API_VERSION: string; + }; }; API_ENDPOINTS: { SEARXNG: string; @@ -77,6 +83,18 @@ export const getCustomOpenaiApiUrl = () => export const getCustomOpenaiModelName = () => loadConfig().MODELS.CUSTOM_OPENAI.MODEL_NAME; +export const getAzureOpenaiApiKey = () => + loadConfig().MODELS.AZURE_OPENAI.API_KEY; + +export const getAzureOpenaiEndpoint = () => + loadConfig().MODELS.AZURE_OPENAI.ENDPOINT; + +export const getAzureOpenaiModelName = () => + loadConfig().MODELS.AZURE_OPENAI.MODEL_NAME; + +export const getAzureOpenaiApiVersion = () => + loadConfig().MODELS.AZURE_OPENAI.API_VERSION; + const mergeConfigs = (current: any, update: any): any => { if (update === null || update === undefined) { return current; diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index eef212f..4eb9548 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -6,13 +6,20 @@ import { getCustomOpenaiApiUrl, getCustomOpenaiModelName, } from '../config'; -import { ChatOpenAI } from '@langchain/openai'; +import { + getAzureOpenaiApiKey, + getAzureOpenaiEndpoint, + getAzureOpenaiModelName, + getAzureOpenaiApiVersion, +} from '../config'; +import { ChatOpenAI, AzureChatOpenAI } from '@langchain/openai'; import { loadOllamaChatModels, loadOllamaEmbeddingModels } from './ollama'; import { loadGroqChatModels } from './groq'; import { loadAnthropicChatModels } from './anthropic'; import { loadGeminiChatModels, loadGeminiEmbeddingModels } from './gemini'; import { loadTransformersEmbeddingsModels } from './transformers'; import { loadDeepseekChatModels } from './deepseek'; export interface ChatModel { displayName: string; @@ -60,6 +67,11 @@ export 
const getAvailableChatModelProviders = async () => { const customOpenAiApiUrl = getCustomOpenaiApiUrl(); const customOpenAiModelName = getCustomOpenaiModelName(); + const azureOpenAiApiKey = getAzureOpenaiApiKey(); + const azureOpenAiModelName = getAzureOpenaiModelName(); + const azureOpenAiApiVersion = getAzureOpenaiApiVersion(); + const azureOpenAiEndpoint = getAzureOpenaiEndpoint(); + models['custom_openai'] = { ...(customOpenAiApiKey && customOpenAiApiUrl && customOpenAiModelName ? { @@ -78,6 +90,28 @@ export const getAvailableChatModelProviders = async () => { : {}), }; + models['azure_openai'] = { + ...(azureOpenAiApiKey && azureOpenAiEndpoint && azureOpenAiApiVersion && azureOpenAiModelName + ? { + [azureOpenAiModelName]: { + displayName: azureOpenAiModelName, + model: new AzureChatOpenAI({ + openAIApiKey: azureOpenAiApiKey, + deploymentName: azureOpenAiModelName, + openAIBasePath: azureOpenAiEndpoint, + openAIApiVersion: azureOpenAiApiVersion, + temperature: 0.7 + }) as unknown as BaseChatModel + }, + } + : {}), + }; + return models; };