Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ dependencies = [
"httptools",
# Used by uvicorn for reload functionality
"watchfiles",
"azure-ai-inference",
"azure-identity",
"openai",
"aiohttp",
"python-dotenv",
"pyyaml"
Expand Down
70 changes: 33 additions & 37 deletions src/quartapp/chat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import json
import os

from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage
from azure.identity.aio import (
AzureDeveloperCliCredential,
ChainedTokenCredential,
ManagedIdentityCredential,
)
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential, get_bearer_token_provider
from openai import AsyncOpenAI
from quart import (
Blueprint,
Response,
Expand All @@ -22,38 +17,34 @@

@bp.before_app_serving
async def configure_openai():
# Use ManagedIdentityCredential with the client_id for user-assigned managed identities
user_assigned_managed_identity_credential = ManagedIdentityCredential(client_id=os.getenv("AZURE_CLIENT_ID"))

# Use AzureDeveloperCliCredential with the current tenant.
azure_dev_cli_credential = AzureDeveloperCliCredential(tenant_id=os.getenv("AZURE_TENANT_ID"), process_timeout=60)

# Create a ChainedTokenCredential with ManagedIdentityCredential and AzureDeveloperCliCredential
# - ManagedIdentityCredential is used for deployment on Azure Container Apps

# - AzureDeveloperCliCredential is used for local development
# The order of the credentials is important, as the first valid token is used
# For more information check out:

# https://learn.microsoft.com/azure/developer/python/sdk/authentication/credential-chains?tabs=ctc#chainedtokencredential-overview
azure_credential = ChainedTokenCredential(user_assigned_managed_identity_credential, azure_dev_cli_credential)
current_app.logger.info("Using Azure OpenAI with credential")

if not os.getenv("AZURE_INFERENCE_ENDPOINT"):
raise ValueError("AZURE_INFERENCE_ENDPOINT is required for Azure OpenAI")
if os.getenv("RUNNING_IN_PRODUCTION"):
client_id = os.environ["AZURE_CLIENT_ID"]
current_app.logger.info("Using Azure OpenAI with managed identity credential for client ID: %s", client_id)
bp.azure_credential = ManagedIdentityCredential(client_id=client_id)
else:
tenant_id = os.environ["AZURE_TENANT_ID"]
current_app.logger.info("Using Azure OpenAI with Azure Developer CLI credential for tenant ID: %s", tenant_id)
bp.azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id)

# Get the token provider for Azure OpenAI based on the selected Azure credential
bp.openai_token_provider = get_bearer_token_provider(
bp.azure_credential, "https://cognitiveservices.azure.com/.default"
)

# Create the Asynchronous Azure OpenAI client
bp.ai_client = ChatCompletionsClient(
endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],
credential=azure_credential,
credential_scopes=["https://cognitiveservices.azure.com/.default"],
model="DeepSeek-R1",
bp.openai_client = AsyncOpenAI(
base_url=os.environ["AZURE_INFERENCE_ENDPOINT"],
api_key=await bp.openai_token_provider(),
default_query={"api-version": "2024-05-01-preview"},
)

# Set the model name to the Azure OpenAI model deployment name
bp.openai_model = os.getenv("AZURE_DEEPSEEK_DEPLOYMENT")


@bp.after_app_serving
async def shutdown_openai():
await bp.ai_client.close()
await bp.openai_client.close()


@bp.get("/")
Expand All @@ -69,15 +60,20 @@ async def chat_handler():
async def response_stream():
# This sends all messages, so API request may exceed token limits
all_messages = [
SystemMessage(content="You are a helpful assistant."),
{"role": "system", "content": "You are a helpful assistant."},
] + request_messages

client: ChatCompletionsClient = bp.ai_client
result = await client.complete(messages=all_messages, max_tokens=2048, stream=True)
bp.openai_client.api_key = await bp.openai_token_provider()
chat_coroutine = bp.openai_client.chat.completions.create(
# Azure Open AI takes the deployment name as the model name
model=bp.openai_model,
messages=all_messages,
stream=True,
)

try:
is_thinking = False
async for update in result:
async for update in await chat_coroutine:
if update.choices:
content = update.choices[0].delta.content
if content == "<think>":
Expand All @@ -103,4 +99,4 @@ async def response_stream():
current_app.logger.error(e)
yield json.dumps({"error": str(e)}, ensure_ascii=False) + "\n"

return Response(response_stream(), mimetype="application/json")
return Response(response_stream())
52 changes: 39 additions & 13 deletions src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.11
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --output-file=requirements.txt pyproject.toml
Expand All @@ -12,24 +12,28 @@ aiohttp==3.10.11
# via quartapp (pyproject.toml)
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.6.0
# via watchfiles
# via
# httpx
# openai
# watchfiles
attrs==24.2.0
# via aiohttp
azure-ai-inference==1.0.0b8
# via quartapp (pyproject.toml)
azure-core==1.31.0
# via
# azure-ai-inference
# azure-identity
# via azure-identity
azure-identity==1.19.0
# via quartapp (pyproject.toml)
blinker==1.8.2
# via
# flask
# quart
certifi==2024.8.30
# via requests
# via
# httpcore
# httpx
# requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.0
Expand All @@ -44,6 +48,8 @@ cryptography==44.0.1
# azure-identity
# msal
# pyjwt
distro==1.9.0
# via openai
flask==3.0.3
# via quart
frozenlist==1.4.1
Expand All @@ -54,26 +60,30 @@ gunicorn==23.0.0
# via quartapp (pyproject.toml)
h11==0.14.0
# via
# httpcore
# hypercorn
# uvicorn
# wsproto
h2==4.1.0
# via hypercorn
hpack==4.0.0
# via h2
httpcore==1.0.7
# via httpx
httptools==0.6.4
# via quartapp (pyproject.toml)
httpx==0.28.1
# via openai
hypercorn==0.17.3
# via quart
hyperframe==6.0.1
# via h2
idna==3.10
# via
# anyio
# httpx
# requests
# yarl
isodate==0.7.2
# via azure-ai-inference
itsdangerous==2.2.0
# via
# flask
Expand All @@ -82,6 +92,8 @@ jinja2==3.1.5
# via
# flask
# quart
jiter==0.9.0
# via openai
markupsafe==3.0.1
# via
# jinja2
Expand All @@ -97,6 +109,8 @@ multidict==6.1.0
# via
# aiohttp
# yarl
openai==1.66.2
# via quartapp (pyproject.toml)
packaging==24.1
# via gunicorn
portalocker==2.10.1
Expand All @@ -107,8 +121,14 @@ propcache==0.2.0
# via yarl
pycparser==2.22
# via cffi
pydantic==2.10.6
# via openai
pydantic-core==2.27.2
# via pydantic
pyjwt[crypto]==2.9.0
# via msal
# via
# msal
# pyjwt
python-dotenv==1.0.1
# via quartapp (pyproject.toml)
pyyaml==6.0.2
Expand All @@ -122,12 +142,18 @@ requests==2.32.3
six==1.16.0
# via azure-core
sniffio==1.3.1
# via anyio
# via
# anyio
# openai
tqdm==4.67.1
# via openai
typing-extensions==4.12.2
# via
# azure-ai-inference
# azure-core
# azure-identity
# openai
# pydantic
# pydantic-core
urllib3==2.2.3
# via requests
uvicorn==0.32.0
Expand Down
83 changes: 50 additions & 33 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import azure.ai.inference.models
import openai
import pytest
import pytest_asyncio

Expand All @@ -13,54 +13,70 @@ class AsyncChatCompletionIterator:
def __init__(self, answer: str):
self.chunk_index = 0
self.chunks = [
azure.ai.inference.models.StreamingChatCompletionsUpdate(
id="test-123",
created=1703462735,
model="DeepSeek-R1",
choices=[
azure.ai.inference.models.StreamingChatChoiceUpdate(
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
content=None, role="assistant"
),
index=0,
finish_reason=None,
)
openai.types.chat.ChatCompletionChunk(
object="chat.completion.chunk",
choices=[],
id="",
created=0,
model="",
prompt_filter_results=[
{
"prompt_index": 0,
"content_filter_results": {
"hate": {"filtered": False, "severity": "safe"},
"self_harm": {"filtered": False, "severity": "safe"},
"sexual": {"filtered": False, "severity": "safe"},
"violence": {"filtered": False, "severity": "safe"},
},
}
],
),
)
]
answer_deltas = answer.split(" ")
for answer_index, answer_delta in enumerate(answer_deltas):
# Completion chunks include whitespace, so we need to add it back in
if answer_index > 0:
# Text completion chunks include whitespace, so we need to add it back in
if answer_index > 0 and answer_delta != "</think>":
answer_delta = " " + answer_delta
self.chunks.append(
azure.ai.inference.models.StreamingChatCompletionsUpdate(
openai.types.chat.ChatCompletionChunk(
id="test-123",
created=1703462735,
model="DeepSeek-R1",
object="chat.completion.chunk",
choices=[
azure.ai.inference.models.StreamingChatChoiceUpdate(
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(
content=answer_delta, role=None
openai.types.chat.chat_completion_chunk.Choice(
delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(
role=None, content=answer_delta
),
index=0,
finish_reason=None,
index=0,
logprobs=None,
# Only Azure includes content_filter_results
content_filter_results={
"hate": {"filtered": False, "severity": "safe"},
"self_harm": {"filtered": False, "severity": "safe"},
"sexual": {"filtered": False, "severity": "safe"},
"violence": {"filtered": False, "severity": "safe"},
},
)
],
created=1703462735,
model="DeepSeek-R1",
)
)
self.chunks.append(
azure.ai.inference.models.StreamingChatCompletionsUpdate(
openai.types.chat.ChatCompletionChunk(
id="test-123",
created=1703462735,
model="DeepSeek-R1",
object="chat.completion.chunk",
choices=[
azure.ai.inference.models.StreamingChatChoiceUpdate(
delta=azure.ai.inference.models.StreamingChatResponseMessageUpdate(content=None, role=None),
openai.types.chat.chat_completion_chunk.Choice(
delta=openai.types.chat.chat_completion_chunk.ChoiceDelta(content=None, role=None),
index=0,
finish_reason="stop",
# Only Azure includes content_filter_results
content_filter_results={},
)
],
created=1703462735,
model="DeepSeek-R1",
)
)

Expand All @@ -75,28 +91,29 @@ async def __anext__(self):
else:
raise StopAsyncIteration

async def mock_complete(*args, **kwargs):
async def mock_acreate(*args, **kwargs):
# Only mock a stream=True completion
last_message = kwargs.get("messages")[-1]["content"]
if last_message == "What is the capital of France?":
return AsyncChatCompletionIterator("The capital of France is Paris.")
return AsyncChatCompletionIterator("<think> hmm </think> The capital of France is Paris.")
elif last_message == "What is the capital of Germany?":
return AsyncChatCompletionIterator("The capital of Germany is Berlin.")
return AsyncChatCompletionIterator("<think> hmm </think> The capital of Germany is Berlin.")
else:
raise ValueError(f"Unexpected message: {last_message}")

monkeypatch.setattr("azure.ai.inference.aio.ChatCompletionsClient.complete", mock_complete)
monkeypatch.setattr("openai.resources.chat.AsyncCompletions.create", mock_acreate)


@pytest.fixture
def mock_defaultazurecredential(monkeypatch):
monkeypatch.setattr("azure.identity.aio.DefaultAzureCredential", mock_cred.MockAzureCredential)
monkeypatch.setattr("azure.identity.aio.AzureDeveloperCliCredential", mock_cred.MockAzureCredential)
monkeypatch.setattr("azure.identity.aio.ManagedIdentityCredential", mock_cred.MockAzureCredential)


@pytest_asyncio.fixture
async def client(monkeypatch, mock_openai_chatcompletion, mock_defaultazurecredential):
monkeypatch.setenv("AZURE_INFERENCE_ENDPOINT", "test-deepseek-service.ai.azure.com")
monkeypatch.setenv("AZURE_TENANT_ID", "test-tenant-id")

quart_app = quartapp.create_app(testing=True)

Expand Down
Loading
Loading