Skip to content

Commit

Permalink
Merge pull request #741 from stanfea/main
Browse files Browse the repository at this point in the history
Use supplied headers
  • Loading branch information
ishaan-jaff authored Nov 1, 2023
2 parents f698322 + bbc82f3 commit 6685047
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,11 @@ def completion(
get_secret("AZURE_API_KEY")
)

headers = (
headers or
litellm.headers
)

## LOAD CONFIG - if set
config=litellm.AzureOpenAIConfig.get_config()
for k, v in config.items():
Expand All @@ -368,7 +373,7 @@ def completion(
input=messages,
api_key=api_key,
additional_args={
"headers": litellm.headers,
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
Expand All @@ -377,7 +382,7 @@ def completion(
response = openai.ChatCompletion.create(
engine=model,
messages=messages,
headers=litellm.headers,
headers=headers,
api_key=api_key,
api_base=api_base,
api_version=api_version,
Expand All @@ -393,7 +398,7 @@ def completion(
api_key=api_key,
original_response=response,
additional_args={
"headers": litellm.headers,
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
Expand Down Expand Up @@ -426,6 +431,11 @@ def completion(
get_secret("OPENAI_API_KEY")
)

headers = (
headers or
litellm.headers
)

## LOAD CONFIG - if set
config=litellm.OpenAIConfig.get_config()
for k, v in config.items():
Expand All @@ -436,7 +446,7 @@ def completion(
logging.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": litellm.headers, "api_base": api_base},
additional_args={"headers": headers, "api_base": api_base},
)
## COMPLETION CALL
try:
Expand All @@ -457,7 +467,7 @@ def completion(
response = openai.ChatCompletion.create(
model=model,
messages=messages,
headers=litellm.headers, # None by default
headers=headers, # None by default
api_base=api_base, # thread safe setting base, key, api_version
api_key=api_key,
api_type="openai",
Expand All @@ -470,7 +480,7 @@ def completion(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"headers": litellm.headers},
additional_args={"headers": headers},
)
raise e

Expand All @@ -482,7 +492,7 @@ def completion(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": litellm.headers},
additional_args={"headers": headers},
)
elif (
model in litellm.open_ai_text_completion_models
Expand Down Expand Up @@ -514,6 +524,11 @@ def completion(
get_secret("OPENAI_API_KEY")
)

headers = (
headers or
litellm.headers
)

## LOAD CONFIG - if set
config=litellm.OpenAITextCompletionConfig.get_config()
for k, v in config.items():
Expand All @@ -534,7 +549,7 @@ def completion(
api_key=api_key,
additional_args={
"openai_organization": litellm.organization,
"headers": litellm.headers,
"headers": headers,
"api_base": api_base,
"api_type": openai.api_type,
},
Expand All @@ -543,7 +558,7 @@ def completion(
response = openai.Completion.create(
model=model,
prompt=prompt,
headers=litellm.headers,
headers=headers,
api_key = api_key,
api_base=api_base,
**optional_params
Expand All @@ -558,7 +573,7 @@ def completion(
original_response=response,
additional_args={
"openai_organization": litellm.organization,
"headers": litellm.headers,
"headers": headers,
"api_base": openai.api_base,
"api_type": openai.api_type,
},
Expand Down Expand Up @@ -796,6 +811,11 @@ def completion(
or "https://api.deepinfra.com/v1/openai"
)

headers = (
headers or
litellm.headers
)

## LOGGING
logging.pre_call(
input=messages,
Expand Down Expand Up @@ -828,7 +848,7 @@ def completion(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": litellm.headers},
additional_args={"headers": headers},
)
elif (
custom_llm_provider == "huggingface"
Expand Down Expand Up @@ -909,6 +929,11 @@ def completion(
"OR_API_KEY"
) or litellm.api_key

headers = (
headers or
litellm.headers
)

data = {
"model": model,
"messages": messages,
Expand All @@ -917,9 +942,9 @@ def completion(
## LOGGING
logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
## COMPLETION CALL
if litellm.headers:
if headers:
response = openai.ChatCompletion.create(
headers=litellm.headers,
headers=headers,
**data,
)
else:
Expand Down

0 comments on commit 6685047

Please sign in to comment.