-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Litellm code qa common config (#7113)
* feat(base_llm): initial commit for common base config class Addresses code qa critique andrewyng/aisuite#113 (comment) * feat(base_llm/): add transform request/response abstract methods to base config class * feat(cohere-+-clarifai): refactor integrations to use common base config class * fix: fix linting errors * refactor(anthropic/): move anthropic + vertex anthropic to use base config * test: fix xai test * test: fix tests * fix: fix linting errors * test: comment out WIP test * fix(transformation.py): fix is pdf used check * fix: fix linting error
- Loading branch information
1 parent
98902d6
commit 48b134a
Showing
1 changed file
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import json | ||
import os | ||
import time | ||
import traceback | ||
import types | ||
from typing import Callable, List, Optional | ||
|
||
import httpx | ||
import requests | ||
|
||
import litellm | ||
from litellm.llms.custom_httpx.http_handler import ( | ||
AsyncHTTPHandler, | ||
_get_httpx_client, | ||
get_async_httpx_client, | ||
) | ||
from litellm.types.llms.openai import AllMessageValues | ||
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage | ||
|
||
from ...prompt_templates.factory import custom_prompt, prompt_factory | ||
from ..common_utils import ClarifaiError | ||
|
||
|
||
async def async_completion( | ||
model: str, | ||
messages: List[AllMessageValues], | ||
model_response: ModelResponse, | ||
encoding, | ||
api_key, | ||
api_base: str, | ||
logging_obj, | ||
data: dict, | ||
optional_params: dict, | ||
litellm_params=None, | ||
logger_fn=None, | ||
headers={}, | ||
): | ||
|
||
async_handler = get_async_httpx_client( | ||
llm_provider=litellm.LlmProviders.CLARIFAI, | ||
params={"timeout": 600.0}, | ||
) | ||
response = await async_handler.post( | ||
url=api_base, headers=headers, data=json.dumps(data) | ||
) | ||
|
||
return litellm.ClarifaiConfig().transform_response( | ||
model=model, | ||
raw_response=response, | ||
model_response=model_response, | ||
logging_obj=logging_obj, | ||
api_key=api_key, | ||
request_data=data, | ||
messages=messages, | ||
optional_params=optional_params, | ||
encoding=encoding, | ||
) | ||
|
||
|
||
def completion( | ||
model: str, | ||
messages: list, | ||
api_base: str, | ||
model_response: ModelResponse, | ||
print_verbose: Callable, | ||
encoding, | ||
api_key, | ||
logging_obj, | ||
optional_params: dict, | ||
litellm_params: dict, | ||
custom_prompt_dict={}, | ||
acompletion=False, | ||
logger_fn=None, | ||
headers={}, | ||
): | ||
headers = litellm.ClarifaiConfig().validate_environment( | ||
api_key=api_key, | ||
headers=headers, | ||
model=model, | ||
messages=messages, | ||
optional_params=optional_params, | ||
) | ||
data = litellm.ClarifaiConfig().transform_request( | ||
model=model, | ||
messages=messages, | ||
optional_params=optional_params, | ||
litellm_params=litellm_params, | ||
headers=headers, | ||
) | ||
|
||
## LOGGING | ||
logging_obj.pre_call( | ||
input=data, | ||
api_key=api_key, | ||
additional_args={ | ||
"complete_input_dict": data, | ||
"headers": headers, | ||
"api_base": model, | ||
}, | ||
) | ||
if acompletion is True: | ||
return async_completion( | ||
model=model, | ||
messages=messages, | ||
api_base=api_base, | ||
model_response=model_response, | ||
encoding=encoding, | ||
api_key=api_key, | ||
logging_obj=logging_obj, | ||
data=data, | ||
optional_params=optional_params, | ||
litellm_params=litellm_params, | ||
logger_fn=logger_fn, | ||
headers=headers, | ||
) | ||
else: | ||
## COMPLETION CALL | ||
httpx_client = _get_httpx_client( | ||
params={"timeout": 600.0}, | ||
) | ||
response = httpx_client.post( | ||
url=api_base, | ||
headers=headers, | ||
data=json.dumps(data), | ||
) | ||
|
||
if response.status_code != 200: | ||
raise ClarifaiError(status_code=response.status_code, message=response.text) | ||
|
||
if "stream" in optional_params and optional_params["stream"] is True: | ||
completion_stream = response.iter_lines() | ||
stream_response = CustomStreamWrapper( | ||
completion_stream=completion_stream, | ||
model=model, | ||
custom_llm_provider="clarifai", | ||
logging_obj=logging_obj, | ||
) | ||
return stream_response | ||
|
||
else: | ||
return litellm.ClarifaiConfig().transform_response( | ||
model=model, | ||
raw_response=response, | ||
model_response=model_response, | ||
logging_obj=logging_obj, | ||
api_key=api_key, | ||
request_data=data, | ||
messages=messages, | ||
optional_params=optional_params, | ||
encoding=encoding, | ||
) | ||
|
||
|
||
class ModelResponseIterator: | ||
def __init__(self, model_response): | ||
self.model_response = model_response | ||
self.is_done = False | ||
|
||
# Sync iterator | ||
def __iter__(self): | ||
return self | ||
|
||
def __next__(self): | ||
if self.is_done: | ||
raise StopIteration | ||
self.is_done = True | ||
return self.model_response | ||
|
||
# Async iterator | ||
def __aiter__(self): | ||
return self | ||
|
||
async def __anext__(self): | ||
if self.is_done: | ||
raise StopAsyncIteration | ||
self.is_done = True | ||
return self.model_response |