Skip to content

Commit

Permalink
Litellm code qa common config (#7113)
Browse files Browse the repository at this point in the history
* feat(base_llm): initial commit for common base config class

Addresses code qa critique andrewyng/aisuite#113 (comment)

* feat(base_llm/): add transform request/response abstract methods to base config class

* feat(cohere-+-clarifai): refactor integrations to use common base config class

* fix: fix linting errors

* refactor(anthropic/): move anthropic + vertex anthropic to use base config

* test: fix xai test

* test: fix tests

* fix: fix linting errors

* test: comment out WIP test

* fix(transformation.py): fix is pdf used check

* fix: fix linting error
  • Loading branch information
krrishdholakia committed Dec 10, 2024
1 parent 98902d6 commit 48b134a
Showing 1 changed file with 177 additions and 0 deletions.
177 changes: 177 additions & 0 deletions litellm/llms/clarifai/chat/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import json
import os
import time
import traceback
import types
from typing import Callable, List, Optional

import httpx
import requests

import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import ClarifaiError


async def async_completion(
model: str,
messages: List[AllMessageValues],
model_response: ModelResponse,
encoding,
api_key,
api_base: str,
logging_obj,
data: dict,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
):

async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.CLARIFAI,
params={"timeout": 600.0},
)
response = await async_handler.post(
url=api_base, headers=headers, data=json.dumps(data)
)

return litellm.ClarifaiConfig().transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
)


def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
custom_prompt_dict={},
acompletion=False,
logger_fn=None,
headers={},
):
headers = litellm.ClarifaiConfig().validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
data = litellm.ClarifaiConfig().transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)

## LOGGING
logging_obj.pre_call(
input=data,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": model,
},
)
if acompletion is True:
return async_completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
data=data,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
)
else:
## COMPLETION CALL
httpx_client = _get_httpx_client(
params={"timeout": 600.0},
)
response = httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
)

if response.status_code != 200:
raise ClarifaiError(status_code=response.status_code, message=response.text)

if "stream" in optional_params and optional_params["stream"] is True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="clarifai",
logging_obj=logging_obj,
)
return stream_response

else:
return litellm.ClarifaiConfig().transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
encoding=encoding,
)


class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False

# Sync iterator
def __iter__(self):
return self

def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response

# Async iterator
def __aiter__(self):
return self

async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response

0 comments on commit 48b134a

Please sign in to comment.