From 9878da79e3db7e66842f1e494b3ea1801d9f0b05 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 4 Dec 2024 13:25:27 -0800 Subject: [PATCH 1/4] fix completion fields and add unit test --- .../jupyter-ai/jupyter_ai/config_manager.py | 9 ++-- .../jupyter_ai/tests/test_config_manager.py | 43 +++++++++++++------ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index 71ca3f185..7b309faae 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -442,10 +442,10 @@ def em_provider_params(self): @property def completions_lm_provider_params(self): return self._provider_params( - "completions_model_provider_id", self._lm_providers + "completions_model_provider_id", self._lm_providers, completions=True ) - def _provider_params(self, key, listing): + def _provider_params(self, key, listing, completions: bool = False): # read config config = self._read_config() @@ -457,7 +457,10 @@ def _provider_params(self, key, listing): model_id = model_uid.split(":", 1)[1] # get config fields (e.g. base API URL, etc.) - fields = config.fields.get(model_uid, {}) + if completions: + fields = config.completions_fields.get(model_uid, {}) + else: + fields = config.fields.get(model_uid, {}) # get authn fields _, Provider = get_em_provider(model_uid, listing) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 168c675f7..9a9ce09be 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -194,20 +194,35 @@ def configure_to_openai(cm: ConfigManager): return LM_GID, EM_GID, LM_LID, EM_LID, API_PARAMS -def configure_with_fields(cm: ConfigManager): +def configure_with_fields(cm: ConfigManager, completions: bool = False): """ - Configures the ConfigManager with fields and API keys. + Default behavior: Configures the ConfigManager with fields and API keys. Returns the expected result of `cm.lm_provider_params`. + + If `completions` is set to `True`, this configures the ConfigManager with + completion model fields, and returns the expected result of + `cm.completions_lm_provider_params`. """ - req = UpdateConfigRequest( - model_provider_id="openai-chat:gpt-4o", - api_keys={"OPENAI_API_KEY": "foobar"}, - fields={ - "openai-chat:gpt-4o": { - "openai_api_base": "https://example.com", - } - }, - ) + if completions: + req = UpdateConfigRequest( + completions_model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + completions_fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) + else: + req = UpdateConfigRequest( + model_provider_id="openai-chat:gpt-4o", + api_keys={"OPENAI_API_KEY": "foobar"}, + fields={ + "openai-chat:gpt-4o": { + "openai_api_base": "https://example.com", + } + }, + ) cm.update_config(req) return { "model_id": "gpt-4o", @@ -445,7 +460,7 @@ def test_handle_bad_provider_ids(cm_with_bad_provider_ids): assert config_desc.embeddings_provider_id is None -def test_config_manager_returns_fields(cm): +def test_returns_chat_model_fields(cm): """ Asserts that `ConfigManager.lm_provider_params` returns model fields set by the user. @@ -453,6 +468,10 @@ def test_config_manager_returns_fields(cm): expected_model_args = configure_with_fields(cm) assert cm.lm_provider_params == expected_model_args +def test_returns_completion_model_fields(cm): + expected_model_args = configure_with_fields(cm, completions=True) + assert cm.completions_lm_provider_params == expected_model_args + def test_config_manager_does_not_write_to_defaults( config_file_with_model_fields, schema_path From 18c6c6eb8efe047d11943e50c137af6be56a67db Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 4 Dec 2024 13:35:40 -0800 Subject: [PATCH 2/4] pre-commit --- packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py index 9a9ce09be..4a739f6e5 100644 --- a/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/tests/test_config_manager.py @@ -468,6 +468,7 @@ def test_returns_chat_model_fields(cm): expected_model_args = configure_with_fields(cm) assert cm.lm_provider_params == expected_model_args + def test_returns_completion_model_fields(cm): expected_model_args = configure_with_fields(cm, completions=True) assert cm.completions_lm_provider_params == expected_model_args From 50e5bef13b3d08cc6a6f03058b9d59d6bd20ea07 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 4 Dec 2024 16:53:54 -0800 Subject: [PATCH 3/4] fix install in CI --- scripts/install.sh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scripts/install.sh b/scripts/install.sh index bfe594b03..4ca5e52b3 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -1,7 +1,15 @@ #!/bin/bash set -eux -# install core packages -pip install jupyterlab~=4.0 + +# Install JupyterLab +# +# Excludes v4.3.2 as it pins `httpx` to a very narrow range, causing `pip +# install` to stall on package resolution. +# +# See: https://github.com/jupyterlab/jupyter-ai/issues/1138 +pip install jupyterlab~=4.0,!=4.3.2 + +# Install core packages cp playground/config.example.py playground/config.py jlpm install jlpm dev-install From 02058241d2d09b2b908d9c8fec7b0f4a13d86e6d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 00:54:05 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/install.sh b/scripts/install.sh index 4ca5e52b3..7031bacb3 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -2,7 +2,7 @@ set -eux # Install JupyterLab -# +# # Excludes v4.3.2 as it pins `httpx` to a very narrow range, causing `pip # install` to stall on package resolution. #