Skip to content

Commit

Permalink
Backport PR #531: Adds multi-environment variable authentication, Bai…
Browse files Browse the repository at this point in the history
…du Qianfan ERNIE-bot provider (#539)

Co-authored-by: Jason Weill <93281816+JasonWeill@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and JasonWeill authored Dec 21, 2023
1 parent f6bd114 commit ad81851
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,5 @@ dev.sh

.jupyter_ystore.db
.yarn

.conda/
4 changes: 3 additions & 1 deletion docs/source/users/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,23 @@ Jupyter AI supports a wide range of model providers and models. To use Jupyter A

Jupyter AI supports the following model providers:

| Provider | Provider ID | Environment variable | Python package(s) |
| Provider | Provider ID | Environment variable(s) | Python package(s) |
|---------------------|----------------------|----------------------------|---------------------------------|
| AI21 | `ai21` | `AI21_API_KEY` | `ai21` |
| Anthropic | `anthropic` | `ANTHROPIC_API_KEY` | `anthropic` |
| Anthropic (chat) | `anthropic-chat` | `ANTHROPIC_API_KEY` | `anthropic` |
| Bedrock | `bedrock` | N/A | `boto3` |
| Bedrock (chat) | `bedrock-chat` | N/A | `boto3` |
| Cohere | `cohere` | `COHERE_API_KEY` | `cohere` |
| ERNIE-Bot | `qianfan` | `QIANFAN_AK`, `QIANFAN_SK` | `qianfan` |
| GPT4All | `gpt4all` | N/A | `gpt4all` |
| Hugging Face Hub | `huggingface_hub` | `HUGGINGFACEHUB_API_TOKEN` | `huggingface_hub`, `ipywidgets`, `pillow` |
| OpenAI | `openai` | `OPENAI_API_KEY` | `openai` |
| OpenAI (chat) | `openai-chat` | `OPENAI_API_KEY` | `openai` |
| SageMaker | `sagemaker-endpoint` | N/A | `boto3` |

The environment variable names shown above are also the names of the settings keys used when setting up the chat interface.
If multiple variables are listed for a provider, **all** must be specified.

To use the Bedrock models, you need access to the Bedrock service. For more information, see the
[Amazon Bedrock Homepage](https://aws.amazon.com/bedrock/).
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
GPT4AllEmbeddingsProvider,
HfHubEmbeddingsProvider,
OpenAIEmbeddingsProvider,
QianfanEmbeddingsEndpointProvider,
)
from .exception import store_exception
from .magics import AiMagics
Expand All @@ -27,6 +28,7 @@
GPT4AllProvider,
HfHubProvider,
OpenAIProvider,
QianfanProvider,
SmEndpointProvider,
)

Expand Down
3 changes: 3 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
"ernie-bot": "qianfan:ERNIE-Bot",
"ernie-bot-4": "qianfan:ERNIE-Bot-4",
"titan": "bedrock:amazon.titan-tg1-large",
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
AwsAuthStrategy,
EnvAuthStrategy,
Field,
MultiEnvAuthStrategy,
)
from langchain.embeddings import (
BedrockEmbeddings,
CohereEmbeddings,
GPT4AllEmbeddings,
HuggingFaceHubEmbeddings,
OpenAIEmbeddings,
QianfanEmbeddingsEndpoint,
)
from langchain.pydantic_v1 import BaseModel, Extra

Expand Down Expand Up @@ -127,3 +129,14 @@ def __init__(self, **kwargs):
models = ["all-MiniLM-L6-v2-f16"]
model_id_key = "model_id"
pypi_package_deps = ["gpt4all"]


class QianfanEmbeddingsEndpointProvider(
BaseEmbeddingsProvider, QianfanEmbeddingsEndpoint
):
id = "qianfan"
name = "ERNIE-Bot"
models = ["ERNIE-Bot", "ERNIE-Bot-4"]
model_id_key = "model"
pypi_package_deps = ["qianfan"]
auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
97 changes: 60 additions & 37 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from IPython import get_ipython
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai_magics.aliases import MODEL_ID_ALIASES
from jupyter_ai_magics.utils import decompose_model_id, get_lm_providers
from langchain.chains import LLMChain
from langchain.schema import HumanMessage
Expand All @@ -28,14 +29,6 @@
)
from .providers import BaseProvider

MODEL_ID_ALIASES = {
"gpt2": "huggingface_hub:gpt2",
"gpt3": "openai:text-davinci-003",
"chatgpt": "openai-chat:gpt-3.5-turbo",
"gpt4": "openai-chat:gpt-4",
"titan": "bedrock:amazon.titan-tg1-large",
}


class TextOrMarkdown:
def __init__(self, text, markdown):
Expand Down Expand Up @@ -108,6 +101,18 @@ def _repr_mimebundle_(self, include=None, exclude=None):

AI_COMMANDS = {"delete", "error", "help", "list", "register", "update"}

# Strings for listing providers and models
# Avoid composing strings, to make localization easier in the future
ENV_NOT_SET = "You have not set this environment variable, so you cannot use this provider's models."
ENV_SET = (
"You have set this environment variable, so you can use this provider's models."
)
MULTIENV_NOT_SET = "You have not set all of these environment variables, so you cannot use this provider's models."
MULTIENV_SET = "You have set all of these environment variables, so you can use this provider's models."

ENV_REQUIRES = "Requires environment variable:"
MULTIENV_REQUIRES = "Requires environment variables:"


class FormatDict(dict):
"""Subclass of dict to be passed to str#format(). Suppresses KeyError and
Expand Down Expand Up @@ -190,44 +195,53 @@ def _ai_env_status_for_provider_markdown(self, provider_id):
):
return na_message # No emoji

try:
env_var = self.providers[provider_id].auth_strategy.name
except AttributeError: # No "name" attribute
not_set_title = ENV_NOT_SET
set_title = ENV_SET
env_status_ok = False

auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy.type == "env":
var_name = auth_strategy.name
env_var_display = f"`{var_name}`"
env_status_ok = var_name in os.environ
elif auth_strategy.type == "multienv":
# Check multiple environment variables
var_names = self.providers[provider_id].auth_strategy.names
formatted_names = [f"`{name}`" for name in var_names]
env_var_display = ", ".join(formatted_names)
env_status_ok = all(var_name in os.environ for var_name in var_names)
not_set_title = MULTIENV_NOT_SET
set_title = MULTIENV_SET
else: # No environment variables
return na_message

output = f"`{env_var}` | "
if os.getenv(env_var) == None:
output += (
'<abbr title="You have not set this environment variable, '
+ "so you cannot use this provider's models.\">❌</abbr>"
)
output = f"{env_var_display} | "
if env_status_ok:
output += f'<abbr title="{set_title}">✅</abbr>'
else:
output += (
'<abbr title="You have set this environment variable, '
+ "so you can use this provider's models.\">✅</abbr>"
)
output += f'<abbr title="{not_set_title}">❌</abbr>'

return output

def _ai_env_status_for_provider_text(self, provider_id):
if (
provider_id not in self.providers
or self.providers[provider_id].auth_strategy == None
# only handle providers with "env" or "multienv" auth strategy
auth_strategy = getattr(self.providers[provider_id], "auth_strategy", None)
if not auth_strategy or (
auth_strategy.type != "env" and auth_strategy.type != "multienv"
):
return "" # No message necessary

try:
env_var = self.providers[provider_id].auth_strategy.name
except AttributeError: # No "name" attribute
return ""

output = f"Requires environment variable {env_var} "
if os.getenv(env_var) != None:
output += "(set)"
else:
output += "(not set)"
prefix = ENV_REQUIRES if auth_strategy.type == "env" else MULTIENV_REQUIRES
envvars = (
[auth_strategy.name]
if auth_strategy.type == "env"
else auth_strategy.names[:]
)

for i in range(len(envvars)):
envvars[i] += " (set)" if envvars[i] in os.environ else " (not set)"

return output + "\n"
return prefix + " " + ", ".join(envvars) + "\n"

# Is this a name of a Python variable that can be called as a LangChain chain?
def _is_langchain_chain(self, name):
Expand Down Expand Up @@ -513,13 +527,22 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
# validate presence of authn credentials
auth_strategy = self.providers[provider_id].auth_strategy
if auth_strategy:
# TODO: handle auth strategies besides EnvAuthStrategy
if auth_strategy.type == "env" and auth_strategy.name not in os.environ:
raise OSError(
f"Authentication environment variable {auth_strategy.name} not provided.\n"
f"Authentication environment variable {auth_strategy.name} is not set.\n"
f"An authentication token is required to use models from the {Provider.name} provider.\n"
f"Please specify it via `%env {auth_strategy.name}=token`. "
) from None
if auth_strategy.type == "multienv":
# Multiple environment variables must be set
missing_vars = [
var for var in auth_strategy.names if var not in os.environ
]
raise OSError(
f"Authentication environment variables {missing_vars} are not set.\n"
f"Multiple authentication tokens are required to use models from the {Provider.name} provider.\n"
f"Please specify them all via `%env` commands. "
) from None

# configure and instantiate provider
provider_params = {"model_id": local_model_id}
Expand Down
14 changes: 13 additions & 1 deletion packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
BedrockChat,
ChatAnthropic,
ChatOpenAI,
QianfanChatEndpoint,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms import (
Expand All @@ -34,6 +35,7 @@
HuggingFaceHub,
OpenAI,
OpenAIChat,
QianfanLLMEndpoint,
SagemakerEndpoint,
)
from langchain.llms.sagemaker_endpoint import LLMContentHandler
Expand All @@ -54,7 +56,7 @@ class EnvAuthStrategy(BaseModel):
class MultiEnvAuthStrategy(BaseModel):
"""Require multiple auth tokens via multiple environment variables."""

type: Literal["file"] = "file"
type: Literal["multienv"] = "multienv"
names: List[str]


Expand Down Expand Up @@ -775,3 +777,13 @@ async def _agenerate(self, *args, **kwargs) -> Coroutine[Any, Any, LLMResult]:
@property
def allows_concurrency(self):
return not "anthropic" in self.model_id


# Baidu QianfanChat provider. temporarily living as a separate class until
class QianfanProvider(BaseProvider, QianfanChatEndpoint):
id = "qianfan"
name = "ERNIE-Bot"
models = ["ERNIE-Bot", "ERNIE-Bot-4"]
model_id_key = "model_name"
pypi_package_deps = ["qianfan"]
auth_strategy = MultiEnvAuthStrategy(names=["QIANFAN_AK", "QIANFAN_SK"])
5 changes: 4 additions & 1 deletion packages/jupyter-ai-magics/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ all = [
"ipywidgets",
"pillow",
"openai",
"boto3"
"boto3",
"qianfan"
]

[project.entry-points."jupyter_ai.model_providers"]
Expand All @@ -67,13 +68,15 @@ sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
anthropic-chat = "jupyter_ai_magics:ChatAnthropicProvider"
amazon-bedrock-chat = "jupyter_ai_magics:BedrockChatProvider"
qianfan = "jupyter_ai_magics:QianfanProvider"

[project.entry-points."jupyter_ai.embeddings_model_providers"]
bedrock = "jupyter_ai_magics:BedrockEmbeddingsProvider"
cohere = "jupyter_ai_magics:CohereEmbeddingsProvider"
gpt4all = "jupyter_ai_magics:GPT4AllEmbeddingsProvider"
huggingface_hub = "jupyter_ai_magics:HfHubEmbeddingsProvider"
openai = "jupyter_ai_magics:OpenAIEmbeddingsProvider"
qianfan = "jupyter_ai_magics:QianfanEmbeddingsEndpointProvider"

[tool.hatch.version]
source = "nodejs"
Expand Down
15 changes: 15 additions & 0 deletions packages/jupyter-ai/src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,27 @@ export function ChatSettings(): JSX.Element {
) {
newApiKeys[lmAuth.name] = '';
}
if (lmAuth?.type === 'multienv') {
lmAuth.names.forEach(apiKey => {
if (!server.config.api_keys.includes(apiKey)) {
newApiKeys[apiKey] = '';
}
});
}

if (
emAuth?.type === 'env' &&
!server.config.api_keys.includes(emAuth.name)
) {
newApiKeys[emAuth.name] = '';
}
if (emAuth?.type === 'multienv') {
emAuth.names.forEach(apiKey => {
if (!server.config.api_keys.includes(apiKey)) {
newApiKeys[apiKey] = '';
}
});
}

setApiKeys(newApiKeys);
}, [lmProvider, emProvider, server]);
Expand Down
11 changes: 10 additions & 1 deletion packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,16 @@ export namespace AiService {
type: 'aws';
};

export type AuthStrategy = EnvAuthStrategy | AwsAuthStrategy | null;
export type MultiEnvAuthStrategy = {
type: 'multienv';
names: string[];
};

export type AuthStrategy =
| AwsAuthStrategy
| EnvAuthStrategy
| MultiEnvAuthStrategy
| null;

export type TextField = {
type: 'text';
Expand Down

0 comments on commit ad81851

Please sign in to comment.