This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

fix azure by removing model checking #495

Merged
merged 1 commit on Jan 18, 2024
11 changes: 4 additions & 7 deletions mentat/llm_api_handler.py
@@ -42,7 +42,7 @@
 from openai.types.chat.completion_create_params import ResponseFormat
 from PIL import Image
 
-from mentat.errors import MentatError, ModelError, ReturnToUser, UserError
+from mentat.errors import MentatError, ReturnToUser, UserError
 from mentat.session_context import SESSION_CONTEXT
 from mentat.utils import mentat_dir_path
 
@@ -184,9 +184,7 @@ def __init__(self, models: Dict[str, Model]):
 
     def _validate_key(self, key: str) -> str:
         """Try to match fine-tuned models to their base models."""
-        if super().__contains__(key):
-            return key
-        if key.startswith("ft:"):
+        if not super().__contains__(key) and key.startswith("ft:"):
             base_model = key.split(":")[
                 1
             ]  # e.g. "ft:gpt-3.5-turbo-1106:abante::8dsQMc4F"
@@ -200,8 +198,7 @@ def _validate_key(self, key: str) -> str:
                     key, attr.evolve(super().__getitem__(base_model), name=key)
                 )
                 return key
-            raise ModelError(f"Could not identify base model for {key}")
-        raise ModelError(f"Unrecognized model: {key}")
+        return key
 
     def __getitem__(self, key: str) -> Model:
         return super().__getitem__(self._validate_key(key))
@@ -303,7 +300,7 @@ def initialize_client(self):
 
         # ChromaDB requires a sync function for embeddings
         if azure_endpoint and azure_key:
-            ctx.stream.send("Using Azure OpenAI client.", color="cyan")
+            ctx.stream.send("Using Azure OpenAI client.", style="warning")
             self.async_client = AsyncAzureOpenAI(
                 api_key=azure_key,
                 api_version="2023-12-01-preview",
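
The substantive change is in _validate_key: unrecognized model names (for example, Azure deployment names, which rarely match OpenAI model IDs) now fall through and are returned unchanged instead of raising ModelError, while "ft:" fine-tuned names still pick up their base model's settings. Below is a minimal, self-contained sketch of that lookup pattern; the Model dataclass, the ModelRegistry name, and the use of dataclasses.replace in place of attr.evolve are illustrative stand-ins, not mentat's actual types.

# Illustrative sketch only -- Model and ModelRegistry are stand-ins for mentat's types.
from dataclasses import dataclass, replace
from typing import Dict


@dataclass
class Model:
    name: str
    context_size: int = 4096


class ModelRegistry(Dict[str, Model]):
    def _validate_key(self, key: str) -> str:
        # Fine-tuned models ("ft:<base>:...") inherit their base model's settings
        # when the base model is registered.
        if not super().__contains__(key) and key.startswith("ft:"):
            base_model = key.split(":")[1]
            if super().__contains__(base_model):
                super().__setitem__(
                    key, replace(super().__getitem__(base_model), name=key)
                )
        # Unknown names (e.g. Azure deployment names) pass through unchanged
        # rather than raising an error.
        return key

    def __getitem__(self, key: str) -> Model:
        return super().__getitem__(self._validate_key(key))


registry = ModelRegistry({"gpt-3.5-turbo-1106": Model("gpt-3.5-turbo-1106", 16385)})
ft = registry["ft:gpt-3.5-turbo-1106:abante::8dsQMc4F"]
print(ft.context_size)  # 16385 -- inherits the base model's settings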