Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update completion model fields immediately on save #1137

Merged
merged 5 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,10 @@ def em_provider_params(self):
@property
def completions_lm_provider_params(self):
return self._provider_params(
"completions_model_provider_id", self._lm_providers
"completions_model_provider_id", self._lm_providers, completions=True
)

def _provider_params(self, key, listing):
def _provider_params(self, key, listing, completions: bool = False):
# read config
config = self._read_config()

Expand All @@ -457,7 +457,10 @@ def _provider_params(self, key, listing):
model_id = model_uid.split(":", 1)[1]

# get config fields (e.g. base API URL, etc.)
fields = config.fields.get(model_uid, {})
if completions:
fields = config.completions_fields.get(model_uid, {})
else:
fields = config.fields.get(model_uid, {})

# get authn fields
_, Provider = get_em_provider(model_uid, listing)
Expand Down
44 changes: 32 additions & 12 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,35 @@ def configure_to_openai(cm: ConfigManager):
return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS


def configure_with_fields(cm: ConfigManager):
def configure_with_fields(cm: ConfigManager, completions: bool = False):
"""
Configures the ConfigManager with fields and API keys.
Default behavior: Configures the ConfigManager with fields and API keys.
Returns the expected result of `cm.lm_provider_params`.

If `completions` is set to `True`, this configures the ConfigManager with
completion model fields, and returns the expected result of
`cm.completions_lm_provider_params`.
"""
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
if completions:
req = UpdateConfigRequest(
completions_model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
completions_fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
else:
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
cm.update_config(req)
return {
"model_id": "gpt-4o",
Expand Down Expand Up @@ -445,7 +460,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
assert config_desc.embeddings_provider_id is None


def test_config_manager_returns_fields(cm):
def test_returns_chat_model_fields(cm):
"""
Asserts that `ConfigManager.lm_provider_params` returns model fields set by
the user.
Expand All @@ -454,6 +469,11 @@ def test_config_manager_returns_fields(cm):
assert cm.lm_provider_params == expected_model_args


def test_returns_completion_model_fields(cm):
expected_model_args = configure_with_fields(cm, completions=True)
assert cm.completions_lm_provider_params == expected_model_args


def test_config_manager_does_not_write_to_defaults(
config_file_with_model_fields, schema_path
):
Expand Down
Loading