diff --git a/src/lib/providers/index.ts b/src/lib/providers/index.ts index d919fd4..49f2e22 100644 --- a/src/lib/providers/index.ts +++ b/src/lib/providers/index.ts @@ -42,5 +42,7 @@ export const getAvailableEmbeddingModelProviders = async () => { } } + models['custom_openai'] = {}; + return models; }; diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts index 70e20d9..b77106c 100644 --- a/src/websocket/connectionManager.ts +++ b/src/websocket/connectionManager.ts @@ -8,7 +8,7 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Embeddings } from '@langchain/core/embeddings'; import type { IncomingMessage } from 'http'; import logger from '../utils/logger'; -import { ChatOpenAI } from '@langchain/openai'; +import { ChatOpenAI, OpenAIEmbeddings } from '@langchain/openai'; export const handleConnection = async ( ws: WebSocket, @@ -61,11 +61,20 @@ export const handleConnection = async ( if ( embeddingModelProviders[embeddingModelProvider] && - embeddingModelProviders[embeddingModelProvider][embeddingModel] + embeddingModelProviders[embeddingModelProvider][embeddingModel] && + embeddingModelProvider != 'custom_openai' ) { embeddings = embeddingModelProviders[embeddingModelProvider][ embeddingModel ] as Embeddings | undefined; + } else if (embeddingModelProvider == 'custom_openai') { + embeddings = new OpenAIEmbeddings({ + modelName: embeddingModel, + openAIApiKey: searchParams.get('openAIApiKey'), + configuration: { + baseURL: searchParams.get('openAIBaseURL'), + }, + }) as unknown as Embeddings } if (!llm || !embeddings) { diff --git a/ui/components/SettingsDialog.tsx b/ui/components/SettingsDialog.tsx index 171e812..7b84bf5 100644 --- a/ui/components/SettingsDialog.tsx +++ b/ui/components/SettingsDialog.tsx @@ -9,7 +9,7 @@ import React, { } from 'react'; import ThemeSwitcher from './theme/Switcher'; -interface InputProps extends React.InputHTMLAttributes {} +interface InputProps extends React.InputHTMLAttributes { } const Input = ({ className, ...restProps }: InputProps) => { return ( @@ -258,30 +258,30 @@ const SettingsDialog = ({ options={(() => { const chatModelProvider = config.chatModelProviders[ - selectedChatModelProvider + selectedChatModelProvider ]; return chatModelProvider ? chatModelProvider.length > 0 ? chatModelProvider.map((model) => ({ - value: model, - label: model, - })) + value: model, + label: model, + })) : [ - { - value: '', - label: 'No models available', - disabled: true, - }, - ] - : [ { value: '', - label: - 'Invalid provider, please check backend logs', + label: 'No models available', disabled: true, }, - ]; + ] + : [ + { + value: '', + label: + 'Invalid provider, please check backend logs', + disabled: true, + }, + ]; })()} /> @@ -355,7 +355,7 @@ const SettingsDialog = ({ /> )} - {selectedEmbeddingModelProvider && ( + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider != 'custom_openai' && (

Embedding Model @@ -368,34 +368,49 @@ const SettingsDialog = ({ options={(() => { const embeddingModelProvider = config.embeddingModelProviders[ - selectedEmbeddingModelProvider + selectedEmbeddingModelProvider ]; return embeddingModelProvider ? embeddingModelProvider.length > 0 ? embeddingModelProvider.map((model) => ({ - label: model, - value: model, - })) + label: model, + value: model, + })) : [ - { - label: 'No embedding models available', - value: '', - disabled: true, - }, - ] - : [ { - label: - 'Invalid provider, please check backend logs', + label: 'No embedding models available', value: '', disabled: true, }, - ]; + ] + : [ + { + label: + 'Invalid provider, please check backend logs', + value: '', + disabled: true, + }, + ]; })()} />

)} + {selectedEmbeddingModelProvider && selectedEmbeddingModelProvider === 'custom_openai' && ( +
+

+ Embedding Model +

+ + setSelectedEmbeddingModel(e.target.value) + } + /> +
+ )}

OpenAI API Key