diff --git a/src/agents/suggestionGeneratorAgent.ts b/src/agents/suggestionGeneratorAgent.ts
index 0efdfa9..6ba255d 100644
--- a/src/agents/suggestionGeneratorAgent.ts
+++ b/src/agents/suggestionGeneratorAgent.ts
@@ -47,7 +47,7 @@ const generateSuggestions = (
   input: SuggestionGeneratorInput,
   llm: BaseChatModel,
 ) => {
-  (llm as ChatOpenAI).temperature = 0;
+  (llm as unknown as ChatOpenAI).temperature = 0;
   const suggestionGeneratorChain = createSuggestionGeneratorChain(llm);
   return suggestionGeneratorChain.invoke(input);
 };
diff --git a/src/lib/providers/groq.ts b/src/lib/providers/groq.ts
index ecdce4d..35bd125 100644
--- a/src/lib/providers/groq.ts
+++ b/src/lib/providers/groq.ts
@@ -5,6 +5,8 @@ import logger from '../../utils/logger';
 export const loadGroqChatModels = async () => {
   const groqApiKey = getGroqApiKey();

+  if (!groqApiKey) return {};
+
   try {
     const chatModels = {
       'LLaMA3 8b': new ChatOpenAI(
diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts
index 5807f94..b1d4502 100644
--- a/src/lib/providers/index.ts
+++ b/src/lib/providers/index.ts
@@ -1,7 +1,7 @@
 import { loadGroqChatModels } from './groq';
-import { loadOllamaChatModels } from './ollama';
-import { loadOpenAIChatModels, loadOpenAIEmbeddingsModel } from './openai';
-import { loadTransformersEmbeddingsModel } from './transformers';
+import { loadOllamaChatModels, loadOllamaEmbeddingsModels } from './ollama';
+import { loadOpenAIChatModels, loadOpenAIEmbeddingsModels } from './openai';
+import { loadTransformersEmbeddingsModels } from './transformers';

 const chatModelProviders = {
   openai: loadOpenAIChatModels,
@@ -10,18 +10,23 @@ const chatModelProviders = {
 };

 const embeddingModelProviders = {
-  openai: loadOpenAIEmbeddingsModel,
-  local: loadTransformersEmbeddingsModel,
-  ollama: loadOllamaChatModels,
+  openai: loadOpenAIEmbeddingsModels,
+  local: loadTransformersEmbeddingsModels,
+  ollama: loadOllamaEmbeddingsModels,
 };

 export const getAvailableChatModelProviders = async () => {
   const models = {};

   for (const provider in chatModelProviders) {
-    models[provider] = await chatModelProviders[provider]();
+    const providerModels = await chatModelProviders[provider]();
+    if (Object.keys(providerModels).length > 0) {
+      models[provider] = providerModels;
+    }
   }

+  models['custom_openai'] = {};
+
   return models;
 };

@@ -29,7 +34,10 @@ export const getAvailableEmbeddingModelProviders = async () => {
   const models = {};

   for (const provider in embeddingModelProviders) {
-    models[provider] = await embeddingModelProviders[provider]();
+    const providerModels = await embeddingModelProviders[provider]();
+    if (Object.keys(providerModels).length > 0) {
+      models[provider] = providerModels;
+    }
   }

   return models;
diff --git a/src/lib/providers/ollama.ts b/src/lib/providers/ollama.ts
index febe5e8..b2901ff 100644
--- a/src/lib/providers/ollama.ts
+++ b/src/lib/providers/ollama.ts
@@ -6,6 +6,8 @@ import { ChatOllama } from '@langchain/community/chat_models/ollama';
 export const loadOllamaChatModels = async () => {
   const ollamaEndpoint = getOllamaApiEndpoint();

+  if (!ollamaEndpoint) return {};
+
   try {
     const response = await fetch(`${ollamaEndpoint}/api/tags`, {
       headers: {
@@ -31,9 +33,11 @@ export const loadOllamaChatModels = async () => {
   }
 };

-export const loadOpenAIEmbeddingsModel = async () => {
+export const loadOllamaEmbeddingsModels = async () => {
   const ollamaEndpoint = getOllamaApiEndpoint();

+  if (!ollamaEndpoint) return {};
+
   try {
     const response = await fetch(`${ollamaEndpoint}/api/tags`, {
       headers: {
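Note on the provider changes above: together they establish a uniform loader contract. Each load*Models function resolves to a record of models and now short-circuits with an empty object, before any network or SDK call, when its API key or endpoint is unconfigured, and the registry in index.ts keeps only non-empty results. A minimal standalone sketch of that contract; the two loader names below are illustrative stand-ins, not code from this patch:

// Sketch of the loader/registry contract introduced above.
type ModelRecord = Record<string, unknown>;
type ModelLoader = () => Promise<ModelRecord>;

const loaders: Record<string, ModelLoader> = {
  configured: async () => ({ 'model-a': {} }), // key present: models load
  unconfigured: async () => ({}),              // no key: empty object, no throw
};

const getAvailableProviders = async () => {
  const models: Record<string, ModelRecord> = {};
  for (const provider in loaders) {
    const providerModels = await loaders[provider]();
    if (Object.keys(providerModels).length > 0) {
      models[provider] = providerModels; // unconfigured providers are omitted
    }
  }
  return models; // -> { configured: { 'model-a': {} } }
};

getAvailableChatModelProviders additionally appends an always-empty custom_openai entry, presumably so the UI can still offer a bring-your-own-key option even when no server-side provider is configured.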
diff --git a/src/lib/providers/openai.ts b/src/lib/providers/openai.ts
index 705f1a4..afc7ab8 100644
--- a/src/lib/providers/openai.ts
+++ b/src/lib/providers/openai.ts
@@ -5,6 +5,8 @@ import logger from '../../utils/logger';
 export const loadOpenAIChatModels = async () => {
   const openAIApiKey = getOpenaiApiKey();

+  if (!openAIApiKey) return {};
+
   try {
     const chatModels = {
       'GPT-3.5 turbo': new ChatOpenAI({
@@ -36,9 +38,11 @@ export const loadOpenAIChatModels = async () => {
   }
 };

-export const loadOpenAIEmbeddingsModel = async () => {
+export const loadOpenAIEmbeddingsModels = async () => {
   const openAIApiKey = getOpenaiApiKey();

+  if (!openAIApiKey) return {};
+
   try {
     const embeddingModels = {
       'Text embedding 3 small': new OpenAIEmbeddings({
diff --git a/src/lib/providers/transformers.ts b/src/lib/providers/transformers.ts
index 7ef8596..0ec7052 100644
--- a/src/lib/providers/transformers.ts
+++ b/src/lib/providers/transformers.ts
@@ -1,7 +1,7 @@
 import logger from '../../utils/logger';
 import { HuggingFaceTransformersEmbeddings } from '../huggingfaceTransformer';

-export const loadTransformersEmbeddingsModel = async () => {
+export const loadTransformersEmbeddingsModels = async () => {
   try {
     const embeddingModels = {
       'BGE Small': new HuggingFaceTransformersEmbeddings({
diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts
index 5cb075b..70e20d9 100644
--- a/src/websocket/connectionManager.ts
+++ b/src/websocket/connectionManager.ts
@@ -45,7 +45,7 @@ export const handleConnection = async (
       chatModelProviders[chatModelProvider][chatModel] &&
       chatModelProvider != 'custom_openai'
     ) {
-      llm = chatModelProviders[chatModelProvider][chatModel] as
+      llm = chatModelProviders[chatModelProvider][chatModel] as unknown as
         | BaseChatModel
         | undefined;
     } else if (chatModelProvider == 'custom_openai') {
@@ -56,7 +56,7 @@ export const handleConnection = async (
         configuration: {
           baseURL: searchParams.get('openAIBaseURL'),
         },
-      });
+      }) as unknown as BaseChatModel;
     }

     if (
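Note on the two `as unknown as` double assertions (suggestionGeneratorAgent.ts above and connectionManager.ts here): TypeScript permits a direct `as` only between types with sufficient overlap. When ChatOpenAI and BaseChatModel are resolved from different copies or versions of the LangChain packages, their structural types can diverge enough that a direct assertion fails with TS2352, and routing through `unknown` is the standard escape hatch. A contrived sketch of the mechanism, using stand-in interfaces rather than the real LangChain types:

// Two structurally incompatible declarations standing in for the same model
// class resolved from two different package copies.
interface ModelFromCopyA {
  invoke(input: string): Promise<string>;
  brand: 'copy-a';
}
interface ModelFromCopyB {
  invoke(input: string): Promise<string>;
  brand: 'copy-b';
}

declare const llm: ModelFromCopyA;

// const direct = llm as ModelFromCopyB;         // TS2352: insufficient overlap
const bridged = llm as unknown as ModelFromCopyB; // compiles: bridged via unknown

The trade-off is that `unknown` disables checking entirely, so the two runtime shapes must genuinely be compatible; the compiler will no longer catch it if they drift apart.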
diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx
index 675df49..b1a87a2 100644
--- a/ui/components/ChatWindow.tsx
+++ b/ui/components/ChatWindow.tsx
@@ -83,6 +83,55 @@ const useSocket = (
         'embeddingModelProvider',
         embeddingModelProvider,
       );
+    } else {
+      const providers = await fetch(
+        `${process.env.NEXT_PUBLIC_API_URL}/models`,
+        {
+          headers: {
+            'Content-Type': 'application/json',
+          },
+        },
+      ).then(async (res) => await res.json());
+
+      const chatModelProviders = providers.chatModelProviders;
+      const embeddingModelProviders = providers.embeddingModelProviders;
+
+      if (
+        Object.keys(chatModelProviders).length > 0 &&
+        !chatModelProviders[chatModelProvider]
+      ) {
+        chatModelProvider = Object.keys(chatModelProviders)[0];
+        localStorage.setItem('chatModelProvider', chatModelProvider);
+      }
+
+      if (
+        chatModelProvider &&
+        !chatModelProviders[chatModelProvider][chatModel]
+      ) {
+        chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];
+        localStorage.setItem('chatModel', chatModel);
+      }
+
+      if (
+        Object.keys(embeddingModelProviders).length > 0 &&
+        !embeddingModelProviders[embeddingModelProvider]
+      ) {
+        embeddingModelProvider = Object.keys(embeddingModelProviders)[0];
+        localStorage.setItem(
+          'embeddingModelProvider',
+          embeddingModelProvider,
+        );
+      }
+
+      if (
+        embeddingModelProvider &&
+        !embeddingModelProviders[embeddingModelProvider][embeddingModel]
+      ) {
+        embeddingModel = Object.keys(
+          embeddingModelProviders[embeddingModelProvider],
+        )[0];
+        localStorage.setItem('embeddingModel', embeddingModel);
+      }
     }

     const wsURL = new URL(url);
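Note on the new else branch in ChatWindow.tsx: it re-validates a selection stored in localStorage against what GET /models currently reports, falling back to the first available provider and model. The second and fourth if blocks index into chatModelProviders[chatModelProvider] and embeddingModelProviders[embeddingModelProvider] without confirming those records exist, which assumes the preceding blocks always repaired the provider name; if the server reports no providers at all, the lookup can throw. A defensive sketch of the same selection as a pure helper, with a hypothetical name that is not part of this patch:

// Hypothetical helper mirroring the fallback above: keep the stored
// provider/model when the server still offers them, otherwise take the
// first available; return null when nothing is available at all.
type Providers = Record<string, Record<string, unknown>>;

const pickModelSelection = (
  providers: Providers,
  storedProvider: string | null,
  storedModel: string | null,
): { provider: string; model: string } | null => {
  const providerNames = Object.keys(providers);
  if (providerNames.length === 0) return null;

  const provider =
    storedProvider && providers[storedProvider]
      ? storedProvider
      : providerNames[0];

  const modelNames = Object.keys(providers[provider] ?? {});
  if (modelNames.length === 0) return null;

  const model =
    storedModel && modelNames.includes(storedModel)
      ? storedModel
      : modelNames[0];

  return { provider, model };
};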