Skip to content

Commit

Permalink
Add base API URL field for Ollama and OpenAI embedding models (jupyte…
Browse files Browse the repository at this point in the history
…rlab#1136)

* Base API URL added for embedding models

Jupyter AI currently allows the user to call a model at a URL (location) different from the default one by specifying a selected Base API URL. This can be done for Ollama, OpenAI provider models. However, for these providers, there is no way to change the API URL for embedding models when using the `/learn` command in RAG mode. This PR adds an extra field to make this feasible.

Tested as follows for Ollama:
[1] Start the Ollama system from port 11435 instead 11434 (the default):
`OLLAMA_HOST=127.0.0.1:11435 ollama serve`
[2] Set the Base API URL:

[3] Check that the new API URL works:

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow embedding model fields to be saved

* exclude empty str fields from config manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: David L. Qiu <david@qiu.dev>
  • Loading branch information
3 people committed Jan 6, 2025
1 parent 8fd29ed commit 40ba065
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain_ollama import ChatOllama, OllamaEmbeddings

from ..embedding_providers import BaseEmbeddingsProvider
from ..providers import BaseProvider, EnvAuthStrategy, TextField
from ..providers import BaseProvider, TextField


class OllamaProvider(BaseProvider, ChatOllama):
Expand All @@ -23,10 +23,14 @@ class OllamaEmbeddingsProvider(BaseEmbeddingsProvider, OllamaEmbeddings):
id = "ollama"
name = "Ollama"
# source: https://ollama.com/library
model_id_key = "model"
models = [
"nomic-embed-text",
"mxbai-embed-large",
"all-minilm",
"snowflake-arctic-embed",
]
model_id_key = "model"
registry = True
fields = [
TextField(key="base_url", label="Base API URL (optional)", format="text"),
]
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class OpenAIEmbeddingsProvider(BaseEmbeddingsProvider, OpenAIEmbeddings):
model_id_key = "model"
pypi_package_deps = ["langchain_openai"]
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
registry = True
fields = [
TextField(
key="openai_api_base", label="Base API URL (optional)", format="text"
),
]


class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbeddings):
Expand All @@ -122,5 +128,7 @@ class AzureOpenAIEmbeddingsProvider(BaseEmbeddingsProvider, AzureOpenAIEmbedding
auth_strategy = EnvAuthStrategy(
name="AZURE_OPENAI_API_KEY", keyword_param="openai_api_key"
)

registry = True
fields = [
TextField(key="azure_endpoint", label="Base API URL (optional)", format="text"),
]
7 changes: 7 additions & 0 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,13 @@ def _provider_params(self, key, listing, completions: bool = False):
else:
fields = config.fields.get(model_uid, {})

# exclude empty fields
# TODO: modify the config manager to never save empty fields in the
# first place.
for field_key in fields:
if isinstance(fields[field_key], str) and not len(fields[field_key]):
fields[field_key] = None

# get authn fields
_, Provider = get_em_provider(model_uid, listing)
authn_fields = {}
Expand Down
63 changes: 43 additions & 20 deletions packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
const [sendWse, setSendWse] = useState<boolean>(false);
const [fields, setFields] = useState<Record<string, any>>({});
const [embeddingModelFields, setEmbeddingModelFields] = useState<
Record<string, any>
>({});

const [isCompleterEnabled, setIsCompleterEnabled] = useState(
props.completionProvider && props.completionProvider.isEnabled()
Expand Down Expand Up @@ -188,7 +191,15 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
const currFields: Record<string, any> =
server.config.fields?.[lmGlobalId] ?? {};
setFields(currFields);
}, [server, lmProvider]);

if (!emGlobalId) {
return;
}

const initEmbeddingModelFields: Record<string, any> =
server.config.fields?.[emGlobalId] ?? {};
setEmbeddingModelFields(initEmbeddingModelFields);
}, [server, lmGlobalId, emGlobalId]);

const handleSave = async () => {
// compress fields with JSON values
Expand Down Expand Up @@ -222,6 +233,9 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
}),
...(clmGlobalId && {
[clmGlobalId]: fields
}),
...(emGlobalId && {
[emGlobalId]: embeddingModelFields
})
}
}),
Expand Down Expand Up @@ -376,26 +390,35 @@ export function ChatSettings(props: ChatSettingsProps): JSX.Element {
{/* Embedding model section */}
<h2 className="jp-ai-ChatSettings-header">Embedding model</h2>
{server.emProviders.providers.length > 0 ? (
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
<Box>
<Select
value={emGlobalId}
label="Embedding model"
onChange={e => {
const emGid = e.target.value === 'null' ? null : e.target.value;
setEmGlobalId(emGid);
}}
MenuProps={{ sx: { maxHeight: '50%', minHeight: 400 } }}
>
<MenuItem value="null">None</MenuItem>
{server.emProviders.providers.map(emp =>
emp.models
.filter(em => em !== '*') // TODO: support registry providers
.map(em => (
<MenuItem value={`${emp.id}:${em}`}>
{emp.name} :: {em}
</MenuItem>
))
)}
</Select>
{emGlobalId && (
<ModelFields
fields={emProvider?.fields}
values={embeddingModelFields}
onChange={setEmbeddingModelFields}
/>
)}
</Select>
</Box>
) : (
<p>No embedding models available.</p>
)}
Expand Down

0 comments on commit 40ba065

Please sign in to comment.