From ee23841cd9e7d79c2ca20b14015e5e9810dccda9 Mon Sep 17 00:00:00 2001 From: Vivek Silimkhan Date: Fri, 27 Oct 2023 10:37:29 +0530 Subject: [PATCH] Format and remove unnecessary imports --- spacy_llm/models/__init__.py | 3 + spacy_llm/models/bedrock/model.py | 100 +++++++++++++----------- spacy_llm/models/bedrock/registry.py | 42 ++++++---- usage_examples/ner_v3_titan/fewshot.cfg | 2 +- 4 files changed, 83 insertions(+), 64 deletions(-) diff --git a/spacy_llm/models/__init__.py b/spacy_llm/models/__init__.py index c1427009..7ca165f9 100644 --- a/spacy_llm/models/__init__.py +++ b/spacy_llm/models/__init__.py @@ -1,3 +1,4 @@ +from .bedrock import titan_express, titan_lite from .hf import dolly_hf, openllama_hf, stablelm_hf from .langchain import query_langchain from .rest import anthropic, cohere, noop, openai, palm @@ -12,4 +13,6 @@ "openllama_hf", "palm", "query_langchain", + "titan_lite", + "titan_express", ] diff --git a/spacy_llm/models/bedrock/model.py b/spacy_llm/models/bedrock/model.py index b4fdd0bf..2ebec5b1 100644 --- a/spacy_llm/models/bedrock/model.py +++ b/spacy_llm/models/bedrock/model.py @@ -1,46 +1,29 @@ -import os import json +import os import warnings from enum import Enum -from requests import HTTPError -from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple - -from confection import SimpleFrozenDict +from typing import Any, Dict, Iterable, List, Optional -from ...registry import registry - -try: - import boto3 - import botocore - from botocore.config import Config -except ImportError as err: - print("To use Bedrock, you need to install boto3. Use `pip install boto3` ") - raise err class Models(str, Enum): # Completion models TITAN_EXPRESS = "amazon.titan-text-express-v1" TITAN_LITE = "amazon.titan-text-lite-v1" -class Bedrock(): + +class Bedrock: def __init__( - self, - model_id: str, - region: str, - config: Dict[Any, Any], - max_retries: int = 5 + self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5 ): - self._region = region self._model_id = model_id self._config = config self._max_retries = max_retries - - # @property - def get_session(self) -> Dict[str, str]: + + def get_session_kwargs(self) -> Dict[str, Optional[str]]: # Fetch and check the credentials - profile = os.getenv("AWS_PROFILE") if not None else "" + profile = os.getenv("AWS_PROFILE") if not None else "" secret_key_id = os.getenv("AWS_ACCESS_KEY_ID") secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") session_token = os.getenv("AWS_SESSION_TOKEN") @@ -48,53 +31,78 @@ def get_session(self) -> Dict[str, str]: if profile is None: warnings.warn( "Could not find the AWS_PROFILE to access the Amazon Bedrock . Ensure you have an AWS_PROFILE " - "set up by making it available as an environment variable 'AWS_PROFILE'." - ) + "set up by making it available as an environment variable AWS_PROFILE." + ) if secret_key_id is None: warnings.warn( "Could not find the AWS_ACCESS_KEY_ID to access the Amazon Bedrock . Ensure you have an AWS_ACCESS_KEY_ID " - "set up by making it available as an environment variable 'AWS_ACCESS_KEY_ID'." + "set up by making it available as an environment variable AWS_ACCESS_KEY_ID." ) + if secret_access_key is None: warnings.warn( "Could not find the AWS_SECRET_ACCESS_KEY to access the Amazon Bedrock . Ensure you have an AWS_SECRET_ACCESS_KEY " - "set up by making it available as an environment variable 'AWS_SECRET_ACCESS_KEY'." + "set up by making it available as an environment variable AWS_SECRET_ACCESS_KEY." ) + if session_token is None: warnings.warn( "Could not find the AWS_SESSION_TOKEN to access the Amazon Bedrock . Ensure you have an AWS_SESSION_TOKEN " - "set up by making it available as an environment variable 'AWS_SESSION_TOKEN'." + "set up by making it available as an environment variable AWS_SESSION_TOKEN." ) assert secret_key_id is not None assert secret_access_key is not None assert session_token is not None - - session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token} - bedrock = boto3.Session(**session_kwargs) - return bedrock - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + session_kwargs = { + "profile_name": profile, + "region_name": self._region, + "aws_access_key_id": secret_key_id, + "aws_secret_access_key": secret_access_key, + "aws_session_token": session_token, + } + return session_kwargs + + def __call__(self, prompts: Iterable[str]) -> Iterable[str]: api_responses: List[str] = [] prompts = list(prompts) - api_config = Config(retries = dict(max_attempts = self._max_retries)) - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - session = self.get_session() - print("Session:", session) + def _request(json_data: str) -> str: + try: + import boto3 + except ImportError as err: + warnings.warn( + "To use Bedrock, you need to install boto3. Use pip install boto3 " + ) + raise err + from botocore.config import Config + + session_kwargs = self.get_session_kwargs() + session = boto3.Session(**session_kwargs) + api_config = Config(retries=dict(max_attempts=self._max_retries)) bedrock = session.client(service_name="bedrock-runtime", config=api_config) - accept = 'application/json' - contentType = 'application/json' - r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType) - responses = json.loads(r['body'].read().decode())['results'][0]['outputText'] + accept = "application/json" + contentType = "application/json" + r = bedrock.invoke_model( + body=json_data, + modelId=self._model_id, + accept=accept, + contentType=contentType, + ) + responses = json.loads(r["body"].read().decode())["results"][0][ + "outputText" + ] return responses for prompt in prompts: if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]: - responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config})) - if "error" in responses: - return responses["error"] + responses = _request( + json.dumps( + {"inputText": prompt, "textGenerationConfig": self._config} + ) + ) api_responses.append(responses) diff --git a/spacy_llm/models/bedrock/registry.py b/spacy_llm/models/bedrock/registry.py index 6e39d3c6..423279d7 100644 --- a/spacy_llm/models/bedrock/registry.py +++ b/spacy_llm/models/bedrock/registry.py @@ -1,22 +1,28 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, List from confection import SimpleFrozenDict from ...registry import registry from .model import Bedrock, Models -_DEFAULT_RETRIES = 5 -_DEFAULT_TEMPERATURE = 0.0 -_DEFAULT_MAX_TOKEN_COUNT = 512 -_DEFAULT_TOP_P = 1 -_DEFAULT_STOP_SEQUENCES = [] +_DEFAULT_RETRIES: int = 5 +_DEFAULT_TEMPERATURE: float = 0.0 +_DEFAULT_MAX_TOKEN_COUNT: int = 512 +_DEFAULT_TOP_P: int = 1 +_DEFAULT_STOP_SEQUENCES: List[str] = [] + @registry.llm_models("spacy.Bedrock.Titan.Express.v1") def titan_express( region: str, model_id: Models = Models.TITAN_EXPRESS, - config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P), - max_retries: int = _DEFAULT_RETRIES + config: Dict[Any, Any] = SimpleFrozenDict( + temperature=_DEFAULT_TEMPERATURE, + maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, + stopSequences=_DEFAULT_STOP_SEQUENCES, + topP=_DEFAULT_TOP_P, + ), + max_retries: int = _DEFAULT_RETRIES, ) -> Callable[[Iterable[str]], Iterable[str]]: """Returns Bedrock instance for 'amazon-titan-express' model using boto3 to prompt API. model_id (ModelId): ID of the deployed model (titan-express) @@ -24,18 +30,21 @@ def titan_express( config (Dict[Any, Any]): LLM config passed on to the model's initialization. """ return Bedrock( - model_id = model_id, - region = region, - config=config, - max_retries=max_retries + model_id=model_id, region=region, config=config, max_retries=max_retries ) + @registry.llm_models("spacy.Bedrock.Titan.Lite.v1") def titan_lite( region: str, model_id: Models = Models.TITAN_LITE, - config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P), - max_retries: int = _DEFAULT_RETRIES + config: Dict[Any, Any] = SimpleFrozenDict( + temperature=_DEFAULT_TEMPERATURE, + maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, + stopSequences=_DEFAULT_STOP_SEQUENCES, + topP=_DEFAULT_TOP_P, + ), + max_retries: int = _DEFAULT_RETRIES, ) -> Callable[[Iterable[str]], Iterable[str]]: """Returns Bedrock instance for 'amazon-titan-lite' model using boto3 to prompt API. region (str): Specify the AWS region for the service @@ -44,9 +53,8 @@ def titan_lite( config (Dict[Any, Any]): LLM config passed on to the model's initialization. """ return Bedrock( - model_id = model_id, - region = region, + model_id=model_id, + region=region, config=config, max_retries=max_retries, ) - diff --git a/usage_examples/ner_v3_titan/fewshot.cfg b/usage_examples/ner_v3_titan/fewshot.cfg index f49f7351..d18fb7ca 100644 --- a/usage_examples/ner_v3_titan/fewshot.cfg +++ b/usage_examples/ner_v3_titan/fewshot.cfg @@ -29,4 +29,4 @@ path = "${paths.examples}" [components.llm.model] @llm_models = "spacy.Bedrock.Titan.Express.v1" -region = us-east-1 \ No newline at end of file +region =