Skip to content

Commit

Permalink
model listing added
Browse files Browse the repository at this point in the history
  • Loading branch information
sajosam committed Oct 16, 2024
1 parent 4621cd5 commit 42564c8
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 0 deletions.
27 changes: 27 additions & 0 deletions app/api/v1/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,32 @@ def update_inference(inference_id: int, inference: schemas.InferenceBaseUpdate,
data={"inference": result}
)

@inference_router.post("/get/models", response_model=resp_schemas.CommonResponse)
def get_llm_provider_models(llm_provider: schemas.LLMProviderBase):
    """
    Retrieves the models associated with the specified LLM provider.

    Args:
        llm_provider (schemas.LLMProviderBase): The provider key and the API
            key used to query that provider.

    Returns:
        CommonResponse: A response containing the list of LLM provider models
        or an error message.
    """

    data, is_error = svc.get_llm_provider_models(llm_provider)

    if is_error:
        # On error `data` carries the message produced by the service layer.
        return commons.is_error_response("LLM Provider Models Not Found", data, {"provider_models": []})

    return resp_schemas.CommonResponse(
        status=True,
        status_code=200,
        message="LLM Provider Models Found",
        error=None,
        # `data` is already a list of model dicts; wrapping it again
        # ([data]) would nest the list one level too deep.
        data={"provider_models": data}
    )



@actions.get("/list", response_model=resp_schemas.CommonResponse)
Expand Down Expand Up @@ -819,3 +845,4 @@ def delete_action(action_id: int, db: Session = Depends(get_db)):
message="Action Deleted",
error=None
)

4 changes: 4 additions & 0 deletions app/schemas/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class InferenceBaseUpdate(BaseModel):
class InferenceResponse(InferenceBase):
    # Database primary key of the stored inference record.
    id:int

class LLMProviderBase(BaseModel):
    """Request payload identifying an LLM provider and its API credential."""

    # Provider identifier, e.g. "openai" or "ai71".
    key: str
    # Secret key used to authenticate against the provider's API.
    api_key: str

class ConfigurationResponse(ConfigurationBase):
id: int
capabilities: List[CapabilitiesBase]
Expand Down
18 changes: 18 additions & 0 deletions app/services/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from app.services.connector_details import get_plugin_metadata
from fastapi import Request
from app.providers.data_preperation import SourceDocuments
from app.services.model_reader import model_reader



Expand Down Expand Up @@ -924,6 +925,23 @@ def update_inference(inference_id: int, inference: schemas.InferenceBaseUpdate,

return data, None

def get_llm_provider_models(llm_provider: "schemas.LLMProviderBase"):
    """
    Retrieves the models available for a given LLM provider.

    Args:
        llm_provider (schemas.LLMProviderBase): Provider key plus API key.

    Returns:
        Tuple: (list of models, False) on success, or
        (error message, True) on failure — the same (data, is_error)
        convention used by model_reader.
    """

    if not llm_provider.key:
        # Keep the error protocol consistent with model_reader: the first
        # element carries the message, the second is the boolean error flag
        # (the original returned (None, "message"), which leaked a string
        # into the caller's is_error slot).
        return "Missing LLM provider key", True

    return model_reader(llm_provider)


def list_actions(db:Session):

"""
Expand Down
101 changes: 101 additions & 0 deletions app/services/model_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import requests

def openai_llm_models(llm_provider):
    """
    Retrieve the model catalogue from the OpenAI API.

    Args:
        llm_provider: Object exposing an ``api_key`` attribute, sent as the
            Bearer token.

    Returns:
        Tuple: (list of {"display_name", "id"} dicts, False) on success, or
        (error message string, True) on failure.
    """
    url = "https://api.openai.com/v1/models"
    headers = {
        "Authorization": f"Bearer {llm_provider.api_key}",
    }

    try:
        # A timeout prevents the request (and the calling endpoint) from
        # hanging indefinitely if the provider is unreachable.
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            data = response.json()
            # OpenAI has no separate display name, so the model id serves
            # as both fields.
            models = [{"display_name": model["id"], "id": model["id"]} for model in data.get("data", [])]

            return models, False
        else:
            return f"Failed to retrieve OpenAI models: {response.status_code} {response.text}", True
    except requests.RequestException as e:
        return f"Error occurred: {str(e)}", True

def togetherai_llm_models(llm_provider):
    """
    Retrieve the model catalogue from the TogetherAI API.

    Args:
        llm_provider: Object exposing an ``api_key`` attribute, sent as the
            Bearer token.

    Returns:
        Tuple: (list of {"display_name", "id"} dicts, False) on success, or
        (error message string, True) on failure.
    """
    url = "https://api.together.xyz/v1/models"
    headers = {
        "Authorization": f"Bearer {llm_provider.api_key}",
    }

    try:
        # A timeout prevents the request from blocking the caller forever.
        response = requests.get(url, headers=headers, timeout=30)
        if response.status_code == 200:
            data = response.json()
            # NOTE(review): assumes the response body is a JSON list of model
            # objects each containing "display_name" and "id" — confirm
            # against the TogetherAI API docs.
            models = [{"display_name": model["display_name"], "id": model["id"]} for model in data]

            return models, False
        else:
            return f"Failed to retrieve TogetherAI models: {response.status_code} {response.text}", True
    except requests.RequestException as e:
        return f"Error occurred: {str(e)}", True

def ai71_llm_models(llm_provider):
    """
    Retrieve models from the AI71 API and reformat the response.

    Args:
        llm_provider: The LLM provider object. Its ``api_key`` is currently
            unused — the AI71 models endpoint is called without
            authentication. NOTE(review): confirm the endpoint really is
            public, or add an Authorization header.

    Returns:
        Tuple: (list of {"display_name", "id"} dicts, False) on success, or
        (error message string, True) on failure.
    """
    # The duplicate mid-file `import requests` that preceded this function
    # was removed; the module already imports requests at the top.
    url = "https://api.ai71.ai/v1/models"

    try:
        # A timeout keeps a dead endpoint from blocking the caller forever.
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            data = response.json()
            models = [{"display_name": model["name"], "id": model["id"]} for model in data.get("data", [])]
            return models, False
        else:
            return f"Failed to retrieve AI71 models: {response.status_code} {response.text}", True
    except requests.RequestException as e:
        return f"Error occurred: {str(e)}", True


def model_reader(llm_provider):
"""
Retrieves the models for a given LLM provider based on its key.
Args:
llm_provider: The LLM provider object with API key and key.
Returns:
List of model names or an error message.
"""
match llm_provider.key:
case "openai":
return openai_llm_models(llm_provider)
case "togethor":
return togetherai_llm_models(llm_provider)
case "ai71":
return ai71_llm_models(llm_provider)
case _:
return "Invalid LLM provider key", False

0 comments on commit 42564c8

Please sign in to comment.