Skip to content

Commit

Permalink
fix completion fields and add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
dlqqq committed Dec 4, 2024
1 parent 342bb7b commit 9878da7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
9 changes: 6 additions & 3 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,10 @@ def em_provider_params(self):
@property
def completions_lm_provider_params(self):
return self._provider_params(
"completions_model_provider_id", self._lm_providers
"completions_model_provider_id", self._lm_providers, completions=True
)

def _provider_params(self, key, listing):
def _provider_params(self, key, listing, completions: bool = False):
# read config
config = self._read_config()

Expand All @@ -457,7 +457,10 @@ def _provider_params(self, key, listing):
model_id = model_uid.split(":", 1)[1]

# get config fields (e.g. base API URL, etc.)
fields = config.fields.get(model_uid, {})
if completions:
fields = config.completions_fields.get(model_uid, {})
else:
fields = config.fields.get(model_uid, {})

# get authn fields
_, Provider = get_em_provider(model_uid, listing)
Expand Down
43 changes: 31 additions & 12 deletions packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,35 @@ def configure_to_openai(cm: ConfigManager):
return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS


def configure_with_fields(cm: ConfigManager):
def configure_with_fields(cm: ConfigManager, completions: bool = False):
"""
Configures the ConfigManager with fields and API keys.
Default behavior: Configures the ConfigManager with fields and API keys.
Returns the expected result of `cm.lm_provider_params`.
If `completions` is set to `True`, this configures the ConfigManager with
completion model fields, and returns the expected result of
`cm.completions_lm_provider_params`.
"""
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
if completions:
req = UpdateConfigRequest(
completions_model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
completions_fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
else:
req = UpdateConfigRequest(
model_provider_id="openai-chat:gpt-4o",
api_keys={"OPENAI_API_KEY": "foobar"},
fields={
"openai-chat:gpt-4o": {
"openai_api_base": "https://example.com",
}
},
)
cm.update_config(req)
return {
"model_id": "gpt-4o",
Expand Down Expand Up @@ -445,14 +460,18 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids):
assert config_desc.embeddings_provider_id is None


def test_config_manager_returns_fields(cm):
def test_returns_chat_model_fields(cm):
"""
Asserts that `ConfigManager.lm_provider_params` returns model fields set by
the user.
"""
expected_model_args = configure_with_fields(cm)
assert cm.lm_provider_params == expected_model_args

def test_returns_completion_model_fields(cm):
expected_model_args = configure_with_fields(cm, completions=True)
assert cm.completions_lm_provider_params == expected_model_args


def test_config_manager_does_not_write_to_defaults(
config_file_with_model_fields, schema_path
Expand Down

0 comments on commit 9878da7

Please sign in to comment.