Skip to content

Commit

Permalink
Merge pull request #1689 from BerriAI/litellm_set_organization_on_config.yaml
Browse files Browse the repository at this point in the history

[Feat] Set OpenAI organization for litellm.completion, Proxy Config
  • Loading branch information
ishaan-jaff authored Jan 30, 2024
2 parents 2686ec0 + e011c4a commit dd9c788
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 3 deletions.
8 changes: 7 additions & 1 deletion docs/my-website/docs/proxy/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ print(response)
</Tabs>


## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.

[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
Expand All @@ -210,6 +210,12 @@ model_list:
api_key: sk-123
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
temperature: 0.2
- model_name: openai-gpt-3.5
litellm_params:
model: openai/gpt-3.5-turbo
api_key: sk-123
organization: org-ikDc4ex8NB
temperature: 0.2
- model_name: mistral-7b
litellm_params:
model: ollama/mistral
Expand Down
16 changes: 15 additions & 1 deletion litellm/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def completion(
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
):
super().completion()
exception_mapping_worked = False
Expand Down Expand Up @@ -254,6 +255,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
return self.acompletion(
Expand All @@ -266,6 +268,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
elif optional_params.get("stream", False):
return self.streaming(
Expand All @@ -278,6 +281,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
if not isinstance(max_retries, int):
Expand All @@ -291,6 +295,7 @@ def completion(
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
Expand Down Expand Up @@ -358,6 +363,7 @@ async def acompletion(
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
Expand All @@ -372,6 +378,7 @@ async def acompletion(
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
Expand Down Expand Up @@ -412,6 +419,7 @@ def streaming(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
Expand All @@ -423,6 +431,7 @@ def streaming(
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
Expand Down Expand Up @@ -454,6 +463,7 @@ async def async_streaming(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
Expand All @@ -467,6 +477,7 @@ async def async_streaming(
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
Expand Down Expand Up @@ -748,8 +759,11 @@ async def ahealth_check(
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
):
client = AsyncOpenAI(api_key=api_key, timeout=timeout)
client = AsyncOpenAI(
api_key=api_key, timeout=timeout, organization=organization
)
if model is None and mode != "image_generation":
raise Exception("model is not set")

Expand Down
7 changes: 6 additions & 1 deletion litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def completion(
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
Expand Down Expand Up @@ -787,7 +788,8 @@ def completion(
or "https://api.openai.com/v1"
)
openai.organization = (
litellm.organization
organization
or litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
Expand Down Expand Up @@ -827,6 +829,7 @@ def completion(
timeout=timeout,
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
)
except Exception as e:
## LOGGING - log the original exception returned
Expand Down Expand Up @@ -3224,6 +3227,7 @@ async def ahealth_check(
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
organization = model_params.get("organization")

timeout = (
model_params.get("timeout")
Expand All @@ -3241,6 +3245,7 @@ async def ahealth_check(
mode=mode,
prompt=prompt,
input=input,
organization=organization,
)
else:
if mode == "embedding":
Expand Down
10 changes: 10 additions & 0 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,12 @@ def set_client(self, model: dict):
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries

organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization

if "azure" in model_name:
if api_base is None:
raise ValueError(
Expand Down Expand Up @@ -1610,6 +1616,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1630,6 +1637,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1651,6 +1659,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1672,6 +1681,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
Expand Down
16 changes: 16 additions & 0 deletions litellm/tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,22 @@ def test_completion_openai():
# test_completion_openai()


def test_completion_openai_organization():
    """A bogus OpenAI organization id passed to ``completion()`` must fail
    with OpenAI's "No such organization" error rather than succeed.

    NOTE(review): performs a live OpenAI API call and matches the exact
    server-side error text — presumably intentional for this integration test.
    """
    try:
        litellm.set_verbose = True
        try:
            completion(
                model="gpt-3.5-turbo",
                messages=messages,
                organization="org-ikDc4ex8NB",
            )
            pytest.fail("Request should have failed - This organization does not exist")
        except Exception as err:
            # the provider-side rejection must mention the bad org id
            assert "No such organization: org-ikDc4ex8NB" in str(err)

    except Exception as err:
        print(err)
        pytest.fail(f"Error occurred: {err}")


def test_completion_text_openai():
try:
# litellm.set_verbose = True
Expand Down
53 changes: 53 additions & 0 deletions litellm/tests/test_router_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
print("passed")
except Exception as e:
pytest.fail(f"Error occurred: {e}")


def test_openai_with_organization():
    """End-to-end check that Router threads ``litellm_params["organization"]``
    through to the underlying OpenAI client.

    - a deployment configured with a bogus org id must expose that org on its
      client and fail completions with "No such organization"
    - a deployment with no org set must complete normally
    """
    try:
        print("Testing OpenAI with organization")
        bad_org_deployment = {
            "model_name": "openai-bad-org",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "organization": "org-ikDc4ex8NB",
            },
        }
        good_org_deployment = {
            "model_name": "openai-good-org",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        }

        router = Router(model_list=[bad_org_deployment, good_org_deployment])

        print(router.model_list)
        print(router.model_list[0])

        client = router._get_client(
            deployment=router.model_list[0],
            kwargs={"input": ["hello"], "model": "openai-bad-org"},
        )
        print(vars(client))

        # the org id from litellm_params must land on the cached OpenAI client
        assert client.organization == "org-ikDc4ex8NB"

        # bad org raises error

        try:
            router.completion(
                model="openai-bad-org",
                messages=[{"role": "user", "content": "this is a test"}],
            )
            pytest.fail("Request should have failed - This organization does not exist")
        except Exception as err:
            print("Got exception: " + str(err))
            assert "No such organization: org-ikDc4ex8NB" in str(err)

        # good org works
        router.completion(
            model="openai-good-org",
            messages=[{"role": "user", "content": "this is a test"}],
            max_tokens=5,
        )

    except Exception as err:
        pytest.fail(f"Error occurred: {err}")

0 comments on commit dd9c788

Please sign in to comment.