[Feature] Support for Azure AI Studio (#779)
Add support for models deployed in Azure AI Studio. This is done by combining the existing code for OpenAI models with the code provided by Azure AI Studio.

Since a number of common test cases need to run against multiple models, start refactoring those as well (and hook the Azure OpenAI tests into this). This does not use the same mechanism as the local-model tests, since remote endpoints do not run into trouble fitting multiple LLMs on a single machine.
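For reference, a minimal usage sketch of the new class introduced by this commit; the endpoint URL, deployment name, and key below are placeholders, not values from this repository:

from guidance import assistant, gen, models, system, user

# Placeholder connection details; real values come from the deployment's page in Azure AI Studio.
lm = models.AzureAIStudioChat(
    azureai_studio_endpoint="https://my-deployment.region.inference.ml.azure.com/score",
    azureai_studio_deployment="my-model-deployment",
    azureai_studio_key="<api-key>",
)

with system():
    lm += "You are a math wiz."
with user():
    lm += "What is 1 + 1?"
with assistant():
    lm += gen(max_tokens=10, name="answer")

print(lm["answer"])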
riedgar-ms authored May 6, 2024
1 parent 8caf911 commit 631ff1a
Showing 7 changed files with 412 additions and 42 deletions.
12 changes: 11 additions & 1 deletion .github/workflows/ci_tests.yml
@@ -59,13 +59,23 @@ jobs:
          python -c "import torch; assert torch.cuda.is_available()"
      - name: Test with pytest
        env:
          # Configure endpoints
          # Configure endpoints for Azure OpenAI
          AZUREAI_CHAT_ENDPOINT: ${{ secrets.AZUREAI_CHAT_ENDPOINT }}
          AZUREAI_CHAT_KEY: ${{ secrets.AZUREAI_CHAT_KEY }}
          AZUREAI_CHAT_MODEL: ${{ secrets.AZUREAI_CHAT_MODEL }}
          AZUREAI_COMPLETION_ENDPOINT: ${{ secrets.AZUREAI_COMPLETION_ENDPOINT }}
          AZUREAI_COMPLETION_KEY: ${{ secrets.AZUREAI_COMPLETION_KEY }}
          AZUREAI_COMPLETION_MODEL: ${{ secrets.AZUREAI_COMPLETION_MODEL }}
          # Configure endpoints for Azure AI Studio
          AZURE_AI_STUDIO_PHI3_ENDPOINT: ${{ vars.AZURE_AI_STUDIO_PHI3_ENDPOINT }}
          AZURE_AI_STUDIO_PHI3_DEPLOYMENT: ${{ vars.AZURE_AI_STUDIO_PHI3_DEPLOYMENT }}
          AZURE_AI_STUDIO_PHI3_KEY: ${{ secrets.AZURE_AI_STUDIO_PHI3_KEY }}
          AZURE_AI_STUDIO_MISTRAL_CHAT_ENDPOINT: ${{ vars.AZURE_AI_STUDIO_MISTRAL_CHAT_ENDPOINT }}
          AZURE_AI_STUDIO_MISTRAL_CHAT_DEPLOYMENT: ${{ vars.AZURE_AI_STUDIO_MISTRAL_CHAT_DEPLOYMENT }}
          AZURE_AI_STUDIO_MISTRAL_CHAT_KEY: ${{ secrets.AZURE_AI_STUDIO_MISTRAL_CHAT_KEY }}
          AZURE_AI_STUDIO_LLAMA3_CHAT_ENDPOINT: ${{ vars.AZURE_AI_STUDIO_LLAMA3_CHAT_ENDPOINT }}
          AZURE_AI_STUDIO_LLAMA3_CHAT_DEPLOYMENT: ${{ vars.AZURE_AI_STUDIO_LLAMA3_CHAT_DEPLOYMENT }}
          AZURE_AI_STUDIO_LLAMA3_CHAT_KEY: ${{ secrets.AZURE_AI_STUDIO_LLAMA3_CHAT_KEY }}
        run: |
          pytest --cov=guidance --cov-report=xml --cov-report=term-missing \
            -m needs_credentials \
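The test code that consumes these new variables is not shown on this page; purely as an illustration, a conftest.py fixture for the Phi-3 deployment might look roughly like this (the fixture name and skip behaviour are assumptions, while the environment variable names come from the workflow above):

import os

import pytest

from guidance import models


@pytest.fixture(scope="module")
def azure_ai_studio_phi3():
    # Skip rather than fail when the credentials are unavailable,
    # e.g. on pull requests from forks without access to the secrets/vars.
    endpoint = os.getenv("AZURE_AI_STUDIO_PHI3_ENDPOINT")
    deployment = os.getenv("AZURE_AI_STUDIO_PHI3_DEPLOYMENT")
    key = os.getenv("AZURE_AI_STUDIO_PHI3_KEY")
    if not (endpoint and deployment and key):
        pytest.skip("Azure AI Studio Phi-3 credentials not configured")
    return models.AzureAIStudioChat(
        azureai_studio_endpoint=endpoint,
        azureai_studio_deployment=deployment,
        azureai_studio_key=key,
    )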
1 change: 1 addition & 0 deletions guidance/models/__init__.py
@@ -19,6 +19,7 @@
    AzureOpenAICompletion,
    AzureOpenAIInstruct,
)
from ._azureai_studio import AzureAIStudioChat
from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
from ._lite_llm import LiteLLM, LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
from ._cohere import Cohere, CohereCompletion, CohereInstruct
222 changes: 222 additions & 0 deletions guidance/models/_azureai_studio.py
@@ -0,0 +1,222 @@
import hashlib
import pathlib
import urllib.parse

import diskcache as dc
import platformdirs
import requests

from ._model import Chat
from ._grammarless import GrammarlessEngine, Grammarless


try:
    import openai

    is_openai = True
except ModuleNotFoundError:
    is_openai = False


class AzureAIStudioChatEngine(GrammarlessEngine):
    def __init__(
        self,
        *,
        tokenizer,
        max_streaming_tokens: int,
        timeout: float,
        compute_log_probs: bool,
        azureai_studio_endpoint: str,
        azureai_model_deployment: str,
        azureai_studio_key: str,
        clear_cache: bool,
    ):
        endpoint_parts = urllib.parse.urlparse(azureai_studio_endpoint)
        if endpoint_parts.path == "/score":
            self._is_openai_compatible = False
            self._endpoint = azureai_studio_endpoint
        else:
            if not is_openai:
                raise ValueError(
                    "Detected OpenAI compatible model; please install openai package"
                )
            self._is_openai_compatible = True
            self._endpoint = f"{endpoint_parts.scheme}://{endpoint_parts.hostname}"
        self._deployment = azureai_model_deployment
        self._api_key = azureai_studio_key

        # There is a cache... better make sure it's specific
        # to the endpoint and deployment
        deployment_id = self._hash_prompt(self._endpoint + self._deployment)

        path = (
            pathlib.Path(platformdirs.user_cache_dir("guidance"))
            / f"azureaistudio.tokens.{deployment_id}"
        )
        self.cache = dc.Cache(path)
        if clear_cache:
            self.cache.clear()

        super().__init__(tokenizer, max_streaming_tokens, timeout, compute_log_probs)

    def _hash_prompt(self, prompt):
        # Copied from OpenAIChatEngine
        return hashlib.sha256(f"{prompt}".encode()).hexdigest()

    def _generator(self, prompt, temperature: float):
        # Initial parts of this straight up copied from OpenAIChatEngine

        # The next loop (or one like it) appears in several places,
        # and quite possibly belongs in a library function or superclass
        # That said, I'm not _completely sure that there aren't subtle
        # differences between the various versions

        # find the role tags
        pos = 0
        role_end = b"<|im_end|>"
        messages = []
        found = True
        while found:

            # find the role text blocks
            found = False
            for role_name, start_bytes in (
                ("system", b"<|im_start|>system\n"),
                ("user", b"<|im_start|>user\n"),
                ("assistant", b"<|im_start|>assistant\n"),
            ):
                if prompt[pos:].startswith(start_bytes):
                    pos += len(start_bytes)
                    end_pos = prompt[pos:].find(role_end)
                    if end_pos < 0:
                        assert (
                            role_name == "assistant"
                        ), "Bad chat format! Last role before gen needs to be assistant!"
                        break
                    btext = prompt[pos : pos + end_pos]
                    pos += end_pos + len(role_end)
                    messages.append(
                        {"role": role_name, "content": btext.decode("utf8")}
                    )
                    found = True
                    break

        # Add nice exception if no role tags were used in the prompt.
        # TODO: Move this somewhere more general for all chat models?
        if messages == []:
            raise ValueError(
                f"The model is a Chat-based model and requires role tags in the prompt! \
Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
to appropriately format your guidance program for this type of model."
            )

        # Update shared data state
        self._reset_shared_data(prompt[:pos], temperature)

        # Use cache only when temperature is 0
        if temperature == 0:
            cache_key = self._hash_prompt(prompt)

            # Check if the result is already in the cache
            if cache_key in self.cache:
                for chunk in self.cache[cache_key]:
                    yield chunk
                return

        # Call the actual API and extract the next chunk
        if self._is_openai_compatible:
            client = openai.OpenAI(api_key=self._api_key, base_url=self._endpoint)
            response = client.chat.completions.create(
                model=self._deployment,
                messages=messages,  # type: ignore[arg-type]
                # max_tokens=self.max_streaming_tokens,
                n=1,
                top_p=1.0,  # TODO: this should be controllable like temp (from the grammar)
                temperature=temperature,
                # stream=True,
            )

            result = response.choices[0]
            encoded_chunk = result.message.content.encode("utf8")  # type: ignore[union-attr]
        else:
            parameters = dict(temperature=temperature)
            payload = dict(
                input_data=dict(input_string=messages, parameters=parameters)
            )

            headers = {
                "Content-Type": "application/json",
                "Authorization": ("Bearer " + self._api_key),
                "azureml-model-deployment": self._deployment,
            }
            response_score = requests.post(
                self._endpoint,
                json=payload,
                headers=headers,
            )

            result_score = response_score.json()

            encoded_chunk = result_score["output"].encode("utf8")

        # Now back to OpenAIChatEngine, with slight modifications since
        # this isn't a streaming API
        if temperature == 0:
            cached_results = []

        yield encoded_chunk

        if temperature == 0:
            cached_results.append(encoded_chunk)

        # Cache the results after the generator is exhausted
        if temperature == 0:
            self.cache[cache_key] = cached_results


class AzureAIStudioChat(Grammarless, Chat):
    def __init__(
        self,
        azureai_studio_endpoint: str,
        azureai_studio_deployment: str,
        azureai_studio_key: str,
        tokenizer=None,
        echo: bool = True,
        max_streaming_tokens: int = 1000,
        timeout: float = 0.5,
        compute_log_probs: bool = False,
        clear_cache: bool = False,
    ):
        """Create a model object for interacting with Azure AI Studio chat endpoints.

        The required information about the deployed endpoint can
        be obtained from Azure AI Studio.

        A `diskcache`-based caching system is used to speed up
        repeated calls when the temperature is specified to be
        zero.

        Parameters
        ----------
        azureai_studio_endpoint : str
            The HTTPS endpoint deployed by Azure AI Studio
        azureai_studio_deployment : str
            The specific model deployed to the endpoint
        azureai_studio_key : str
            The key required for access to the API
        clear_cache : bool
            Whether to empty the internal cache
        """
        super().__init__(
            AzureAIStudioChatEngine(
                azureai_studio_endpoint=azureai_studio_endpoint,
                azureai_model_deployment=azureai_studio_deployment,
                azureai_studio_key=azureai_studio_key,
                tokenizer=tokenizer,
                max_streaming_tokens=max_streaming_tokens,
                timeout=timeout,
                compute_log_probs=compute_log_probs,
                clear_cache=clear_cache,
            ),
            echo=echo,
        )
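For reference, the non-OpenAI ("/score") branch of the engine above boils down to a single REST call; a standalone sketch of that request, with placeholder endpoint, deployment, and key:

import requests

# Placeholders; real values come from the deployment's page in Azure AI Studio.
endpoint = "https://my-deployment.region.inference.ml.azure.com/score"
deployment = "my-model-deployment"
api_key = "<api-key>"

messages = [{"role": "user", "content": "What is 1 + 1?"}]

# Payload and headers mirror those built in AzureAIStudioChatEngine._generator
payload = {
    "input_data": {
        "input_string": messages,
        "parameters": {"temperature": 0.0},
    }
}
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + api_key,
    "azureml-model-deployment": deployment,
}

response = requests.post(endpoint, json=payload, headers=headers)
print(response.json()["output"])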
78 changes: 78 additions & 0 deletions tests/models/common_chat_testing.py
@@ -0,0 +1,78 @@
from guidance import assistant, gen, models, system, user


def smoke_chat(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=10, name="text", temperature=0.5)
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")


def longer_chat_1(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=10, name="text")
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")

    with user():
        lm += "10. Now you pick a number between 0 and 20"

    with assistant():
        lm += gen(max_tokens=2, name="number")

    print(str(lm))
    assert len(lm["number"]) > 0


def longer_chat_2(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    # This is the new part compared to longer_chat_1
    with assistant():
        lm += "2"

    with user():
        lm += "What is 2 + 3?"

    # Resume the previous
    with assistant():
        lm += gen(max_tokens=10, name="text")
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")

    with user():
        lm += "10. Now you pick a number between 0 and 20"

    with assistant():
        lm += gen(max_tokens=2, name="number")

    print(str(lm))
    assert len(lm["number"]) > 0
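A model-specific test file can then pass whatever chat model it builds into these shared helpers; an assumed wiring, reusing a fixture like the Phi-3 sketch shown after the CI workflow above (import path and test names are assumptions):

from . import common_chat_testing  # assumed relative import within tests/models


def test_azure_ai_studio_smoke(azure_ai_studio_phi3):
    common_chat_testing.smoke_chat(azure_ai_studio_phi3)


def test_azure_ai_studio_longer_chat_1(azure_ai_studio_phi3):
    common_chat_testing.longer_chat_1(azure_ai_studio_phi3)


def test_azure_ai_studio_longer_chat_2(azure_ai_studio_phi3):
    common_chat_testing.longer_chat_2(azure_ai_studio_phi3)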