diff --git a/elm/base.py b/elm/base.py index a128e7f9..f9b9e4e2 100644 --- a/elm/base.py +++ b/elm/base.py @@ -44,6 +44,11 @@ class ApiBase(ABC): MODEL_ROLE = "You are a research assistant that answers questions." """High level model role""" + TOKENIZER_ALIASES = {'gpt-35-turbo': 'gpt-3.5-turbo', + 'gpt-4-32k': 'gpt-4-32k-0314' + } + """Optional mappings for unusual Azure names to tiktoken/openai names.""" + def __init__(self, model=None): """ Parameters @@ -338,8 +343,8 @@ def get_embedding(cls, text): return embedding - @staticmethod - def count_tokens(text, model): + @classmethod + def count_tokens(cls, text, model): """Return the number of tokens in a string. Parameters @@ -355,12 +360,7 @@ def count_tokens(text, model): Number of tokens in text """ - # Optional mappings for weird azure names to tiktoken/openai names - tokenizer_aliases = {'gpt-35-turbo': 'gpt-3.5-turbo', - 'gpt-4-32k': 'gpt-4-32k-0314' - } - - token_model = tokenizer_aliases.get(model, model) + token_model = cls.TOKENIZER_ALIASES.get(model, model) encoding = tiktoken.encoding_for_model(token_model) return len(encoding.encode(text)) diff --git a/elm/wizard.py b/elm/wizard.py index 963ed21a..689727d9 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -399,6 +399,11 @@ class EnergyWizardPostgres(EnergyWizardBase): """ EMBEDDING_MODEL = 'amazon.titan-embed-text-v1' + TOKENIZER_ALIASES = {**EnergyWizardBase.TOKENIZER_ALIASES, + 'ewiz-gpt-4': 'gpt-4' + } + """Optional mappings for weird azure names to tiktoken/openai names.""" + def __init__(self, db_host, db_port, db_name, db_schema, db_table, cursor=None, boto_client=None, model=None, token_budget=3500):