Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Support Azure GPT-4 Vision Enhancements #1475

Merged
merged 7 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 37 additions & 28 deletions docs/my-website/docs/providers/azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,17 @@ response = completion(

```

#### Usage - with Azure Vision enhancements
### Usage - with Azure Vision enhancements

Note: **Azure requires the `base_url` to be set with `/extensions`**

Example
```python
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions"
# base_url="{azure_endpoint}/openai/deployments/{azure_deployment}/extensions"
```

**Usage**
```python
import os
from litellm import completion
Expand All @@ -126,34 +135,34 @@ os.environ["AZURE_API_KEY"] = "your-api-key"

# azure call
response = completion(
model = "azure/<your deployment name>",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What’s in this image?"
model="azure/gpt-4-vision",
timeout=5,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://avatars.githubusercontent.com/u/29436595?v=4"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
enhancements = {"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources = [
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "<your_computer_vision_endpoint>",
"key": "<your_computer_vision_key>",
},
}
],
},
],
}
],
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions",
api_key=os.getenv("AZURE_VISION_API_KEY"),
enhancements={"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources=[
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "https://gpt-4-vision-enhancement.cognitiveservices.azure.com/",
"key": os.environ["AZURE_VISION_ENHANCE_KEY"],
},
}
],
)
```

Expand Down
38 changes: 38 additions & 0 deletions litellm/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def __init__(
)


def select_azure_base_url_or_endpoint(azure_client_params: dict) -> dict:
    """Decide whether the configured Azure endpoint should be passed to the
    Azure OpenAI client as ``base_url`` or as ``azure_endpoint``.

    The openai SDK treats an ``azure_endpoint`` containing
    ``/openai/deployments`` as a fully-specified base URL (see
    https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192),
    which is required for GPT-4 Vision enhancements. In that case the key is
    renamed in place from ``azure_endpoint`` to ``base_url``.

    The dict is mutated in place and also returned for call-site convenience.
    Expected keys (all optional here): ``api_version``, ``azure_endpoint``,
    ``azure_deployment``, ``http_client``, ``max_retries``, ``timeout``.
    """
    endpoint = azure_client_params.get("azure_endpoint")
    if endpoint is not None and "/openai/deployments" in endpoint:
        # caller supplied a full base_url, not a bare azure_endpoint
        azure_client_params["base_url"] = azure_client_params.pop("azure_endpoint")
    return azure_client_params


class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -239,6 +259,9 @@ def completion(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -303,6 +326,9 @@ async def acompletion(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -364,6 +390,9 @@ def streaming(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -414,6 +443,9 @@ async def async_streaming(
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -527,6 +559,9 @@ def embedding(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -659,6 +694,9 @@ def image_generation(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down
28 changes: 16 additions & 12 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,12 +1443,22 @@ def set_client(self, model: dict):
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint

# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)

cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
Expand All @@ -1467,9 +1477,7 @@ def set_client(self, model: dict):

cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
Expand All @@ -1489,9 +1497,7 @@ def set_client(self, model: dict):
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
Expand All @@ -1510,9 +1516,7 @@ def set_client(self, model: dict):

cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
Expand Down
20 changes: 15 additions & 5 deletions litellm/tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_completion_azure_gpt4_vision():
litellm.set_verbose = True
response = completion(
model="azure/gpt-4-vision",
timeout=1,
timeout=5,
messages=[
{
"role": "user",
Expand All @@ -244,21 +244,31 @@ def test_completion_azure_gpt4_vision():
],
}
],
base_url="https://gpt-4-vision-resource.openai.azure.com/",
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions",
api_key=os.getenv("AZURE_VISION_API_KEY"),
enhancements={"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources=[
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "https://gpt-4-vision-enhancement.cognitiveservices.azure.com/",
"key": os.environ["AZURE_VISION_ENHANCE_KEY"],
},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh no :(

}
],
)
print(response)
except openai.APITimeoutError:
print("got a timeout error")
pass
except openai.RateLimitError:
print("got a rate liimt error")
except openai.RateLimitError as e:
print("got a rate liimt error", e)
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")


# test_completion_azure_gpt4_vision()
test_completion_azure_gpt4_vision()


@pytest.mark.skip(reason="this test is flaky")
Expand Down
59 changes: 58 additions & 1 deletion litellm/tests/test_router_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

def test_init_clients():
litellm.set_verbose = True
import logging
from litellm._logging import verbose_router_logger

verbose_router_logger.setLevel(logging.DEBUG)
try:
print("testing init 4 clients with diff timeouts")
model_list = [
Expand All @@ -39,7 +43,7 @@ def test_init_clients():
},
},
]
router = Router(model_list=model_list)
router = Router(model_list=model_list, set_verbose=True)
for elem in router.model_list:
model_id = elem["model_info"]["id"]
assert router.cache.get_cache(f"{model_id}_client") is not None
Expand All @@ -55,6 +59,18 @@ def test_init_clients():

assert async_client.timeout == 0.01
assert stream_async_client.timeout == 0.000_001
print(vars(async_client))
print()
print(async_client._base_url)
assert (
async_client._base_url
== "https://openai-gpt-4-test-v-1.openai.azure.com//openai/"
) # openai python adds the extra /
assert (
stream_async_client._base_url
== "https://openai-gpt-4-test-v-1.openai.azure.com//openai/"
)

print("PASSED !")

except Exception as e:
Expand Down Expand Up @@ -307,3 +323,44 @@ def test_xinference_embedding():


# test_xinference_embedding()


def test_router_init_gpt_4_vision_enhancements():
    """A deployment configured with a ``/openai/deployments`` base_url (as
    required for Azure GPT-4 Vision enhancements) must keep that base_url on
    both the router's model_list entry and the cached async Azure client."""
    try:
        # tests base_url set when any base_url with /openai/deployments passed to router
        print("Testing Azure GPT_Vision enhancements")

        vision_base_url = "https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions/"
        vision_deployment = {
            "model_name": "gpt-4-vision-enhancements",
            "litellm_params": {
                "model": "azure/gpt-4-vision",
                "api_key": os.getenv("AZURE_API_KEY"),
                "base_url": vision_base_url,
            },
        }

        router = Router(model_list=[vision_deployment])

        print(router.model_list)
        print(router.model_list[0])

        deployment = router.model_list[0]
        # set in env
        assert deployment["litellm_params"]["base_url"] == vision_base_url

        azure_client = router._get_client(
            deployment=deployment,
            kwargs={"stream": True, "model": "gpt-4-vision-enhancements"},
            client_type="async",
        )

        assert azure_client._base_url == vision_base_url
        print("passed")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")