Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Set OpenAI organization for litellm.completion, Proxy Config #1689

Merged
merged 7 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/my-website/docs/proxy/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ print(response)
</Tabs>


## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Headers etc.)
## Save Model-specific params (API Base, API Keys, Temperature, Max Tokens, Seed, Organization, Headers etc.)
You can use the config to save model-specific information like api_base, api_key, temperature, max_tokens, etc.

[**All input params**](https://docs.litellm.ai/docs/completion/input#input-params-1)
Expand All @@ -210,6 +210,12 @@ model_list:
api_key: sk-123
api_base: https://openai-gpt-4-test-v-2.openai.azure.com/
temperature: 0.2
- model_name: openai-gpt-3.5
litellm_params:
model: openai/gpt-3.5-turbo
api_key: sk-123
organization: org-ikDc4ex8NB
temperature: 0.2
- model_name: mistral-7b
litellm_params:
model: ollama/mistral
Expand Down
16 changes: 15 additions & 1 deletion litellm/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def completion(
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
):
super().completion()
exception_mapping_worked = False
Expand Down Expand Up @@ -254,6 +255,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
return self.acompletion(
Expand All @@ -266,6 +268,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
elif optional_params.get("stream", False):
return self.streaming(
Expand All @@ -278,6 +281,7 @@ def completion(
timeout=timeout,
client=client,
max_retries=max_retries,
organization=organization,
)
else:
if not isinstance(max_retries, int):
Expand All @@ -291,6 +295,7 @@ def completion(
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
Expand Down Expand Up @@ -358,6 +363,7 @@ async def acompletion(
timeout: float,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
logging_obj=None,
Expand All @@ -372,6 +378,7 @@ async def acompletion(
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
Expand Down Expand Up @@ -412,6 +419,7 @@ def streaming(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
Expand All @@ -423,6 +431,7 @@ def streaming(
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_client = client
Expand Down Expand Up @@ -454,6 +463,7 @@ async def async_streaming(
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
organization: Optional[str] = None,
client=None,
max_retries=None,
headers=None,
Expand All @@ -467,6 +477,7 @@ async def async_streaming(
http_client=litellm.aclient_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
else:
openai_aclient = client
Expand Down Expand Up @@ -748,8 +759,11 @@ async def ahealth_check(
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
organization: Optional[str] = None,
):
client = AsyncOpenAI(api_key=api_key, timeout=timeout)
client = AsyncOpenAI(
api_key=api_key, timeout=timeout, organization=organization
)
if model is None and mode != "image_generation":
raise Exception("model is not set")

Expand Down
7 changes: 6 additions & 1 deletion litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def completion(
num_retries = kwargs.get("num_retries", None) ## deprecated
max_retries = kwargs.get("max_retries", None)
context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
organization = kwargs.get("organization", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
Expand Down Expand Up @@ -787,7 +788,8 @@ def completion(
or "https://api.openai.com/v1"
)
openai.organization = (
litellm.organization
organization
or litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
Expand Down Expand Up @@ -827,6 +829,7 @@ def completion(
timeout=timeout,
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
)
except Exception as e:
## LOGGING - log the original exception returned
Expand Down Expand Up @@ -3224,6 +3227,7 @@ async def ahealth_check(
or custom_llm_provider == "text-completion-openai"
):
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
organization = model_params.get("organization")

timeout = (
model_params.get("timeout")
Expand All @@ -3241,6 +3245,7 @@ async def ahealth_check(
mode=mode,
prompt=prompt,
input=input,
organization=organization,
)
else:
if mode == "embedding":
Expand Down
10 changes: 10 additions & 0 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,12 @@ def set_client(self, model: dict):
max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries

organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization

if "azure" in model_name:
if api_base is None:
raise ValueError(
Expand Down Expand Up @@ -1610,6 +1616,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1630,6 +1637,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1651,6 +1659,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(),
limits=httpx.Limits(
Expand All @@ -1672,6 +1681,7 @@ def set_client(self, model: dict):
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
organization=organization,
http_client=httpx.Client(
transport=CustomHTTPTransport(),
limits=httpx.Limits(
Expand Down
16 changes: 16 additions & 0 deletions litellm/tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,22 @@ def test_completion_openai():
# test_completion_openai()


def test_completion_openai_organization():
try:
litellm.set_verbose = True
try:
response = completion(
model="gpt-3.5-turbo", messages=messages, organization="org-ikDc4ex8NB"
)
pytest.fail("Request should have failed - This organization does not exist")
except Exception as e:
assert "No such organization: org-ikDc4ex8NB" in str(e)

except Exception as e:
print(e)
pytest.fail(f"Error occurred: {e}")


def test_completion_text_openai():
try:
# litellm.set_verbose = True
Expand Down
53 changes: 53 additions & 0 deletions litellm/tests/test_router_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,56 @@ def test_router_init_gpt_4_vision_enhancements():
print("passed")
except Exception as e:
pytest.fail(f"Error occurred: {e}")


def test_openai_with_organization():
try:
print("Testing OpenAI with organization")
model_list = [
{
"model_name": "openai-bad-org",
"litellm_params": {
"model": "gpt-3.5-turbo",
"organization": "org-ikDc4ex8NB",
},
},
{
"model_name": "openai-good-org",
"litellm_params": {"model": "gpt-3.5-turbo"},
},
]

router = Router(model_list=model_list)

print(router.model_list)
print(router.model_list[0])

openai_client = router._get_client(
deployment=router.model_list[0],
kwargs={"input": ["hello"], "model": "openai-bad-org"},
)
print(vars(openai_client))

assert openai_client.organization == "org-ikDc4ex8NB"

# bad org raises error

try:
response = router.completion(
model="openai-bad-org",
messages=[{"role": "user", "content": "this is a test"}],
)
pytest.fail("Request should have failed - This organization does not exist")
except Exception as e:
print("Got exception: " + str(e))
assert "No such organization: org-ikDc4ex8NB" in str(e)

# good org works
response = router.completion(
model="openai-good-org",
messages=[{"role": "user", "content": "this is a test"}],
max_tokens=5,
)

except Exception as e:
pytest.fail(f"Error occurred: {e}")