[Feature] Support for Azure AI Studio (#779)
Add support for models deployed in Azure AI Studio. This is done by combining the existing code for OpenAI models with the code provided by Azure AI Studio. Since there are a number of common test cases that need to run against multiple models, start refactoring those as well (and hook the Azure OpenAI tests into the same machinery). This doesn't reuse the mechanism used for testing local models, since remote endpoints don't run into the problem of fitting multiple LLMs onto a single machine.
1 parent 8caf911 · commit 631ff1a
Showing 7 changed files with 412 additions and 42 deletions.
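For orientation, here is a minimal usage sketch for the new model class, based on the `AzureAIStudioChat` constructor added below. The endpoint URL, deployment name, and key are placeholders, and the `models.AzureAIStudioChat` export path is an assumption mirroring how other remote chat models are exposed; treat this as illustrative rather than as the committed API surface.

# Usage sketch only: endpoint, deployment, and key are placeholders, and the
# models.AzureAIStudioChat export path is assumed rather than shown in this diff.
from guidance import assistant, gen, models, system, user

lm = models.AzureAIStudioChat(
    azureai_studio_endpoint="https://<endpoint>.<region>.inference.ml.azure.com/score",
    azureai_studio_deployment="<deployment-name>",
    azureai_studio_key="<api-key>",
)

with system():
    lm += "You are a math wiz."
with user():
    lm += "What is 1 + 1?"
with assistant():
    lm += gen(max_tokens=10, name="answer")

print(lm["answer"])

Pointing the endpoint at a `/score` path selects the plain REST branch of the engine; any other path is treated as OpenAI-compatible and requires the `openai` package to be installed.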
@@ -0,0 +1,222 @@
import hashlib
import pathlib
import urllib.parse

import diskcache as dc
import platformdirs
import requests

from ._model import Chat
from ._grammarless import GrammarlessEngine, Grammarless


try:
    import openai

    is_openai = True
except ModuleNotFoundError:
    is_openai = False


class AzureAIStudioChatEngine(GrammarlessEngine):
    def __init__(
        self,
        *,
        tokenizer,
        max_streaming_tokens: int,
        timeout: float,
        compute_log_probs: bool,
        azureai_studio_endpoint: str,
        azureai_model_deployment: str,
        azureai_studio_key: str,
        clear_cache: bool,
    ):
        endpoint_parts = urllib.parse.urlparse(azureai_studio_endpoint)
        if endpoint_parts.path == "/score":
            self._is_openai_compatible = False
            self._endpoint = azureai_studio_endpoint
        else:
            if not is_openai:
                raise ValueError(
                    "Detected OpenAI compatible model; please install openai package"
                )
            self._is_openai_compatible = True
            self._endpoint = f"{endpoint_parts.scheme}://{endpoint_parts.hostname}"
        self._deployment = azureai_model_deployment
        self._api_key = azureai_studio_key

        # There is a cache... better make sure it's specific
        # to the endpoint and deployment
        deployment_id = self._hash_prompt(self._endpoint + self._deployment)

        path = (
            pathlib.Path(platformdirs.user_cache_dir("guidance"))
            / f"azureaistudio.tokens.{deployment_id}"
        )
        self.cache = dc.Cache(path)
        if clear_cache:
            self.cache.clear()

        super().__init__(tokenizer, max_streaming_tokens, timeout, compute_log_probs)

    def _hash_prompt(self, prompt):
        # Copied from OpenAIChatEngine
        return hashlib.sha256(f"{prompt}".encode()).hexdigest()

    def _generator(self, prompt, temperature: float):
        # Initial parts of this straight up copied from OpenAIChatEngine

        # The next loop (or one like it) appears in several places,
        # and quite possibly belongs in a library function or superclass.
        # That said, I'm not completely sure that there aren't subtle
        # differences between the various versions.

        # find the role tags
        pos = 0
        role_end = b"<|im_end|>"
        messages = []
        found = True
        while found:

            # find the role text blocks
            found = False
            for role_name, start_bytes in (
                ("system", b"<|im_start|>system\n"),
                ("user", b"<|im_start|>user\n"),
                ("assistant", b"<|im_start|>assistant\n"),
            ):
                if prompt[pos:].startswith(start_bytes):
                    pos += len(start_bytes)
                    end_pos = prompt[pos:].find(role_end)
                    if end_pos < 0:
                        assert (
                            role_name == "assistant"
                        ), "Bad chat format! Last role before gen needs to be assistant!"
                        break
                    btext = prompt[pos : pos + end_pos]
                    pos += end_pos + len(role_end)
                    messages.append(
                        {"role": role_name, "content": btext.decode("utf8")}
                    )
                    found = True
                    break

        # Add nice exception if no role tags were used in the prompt.
        # TODO: Move this somewhere more general for all chat models?
        if messages == []:
            raise ValueError(
                f"The model is a Chat-based model and requires role tags in the prompt! \
                Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
                to appropriately format your guidance program for this type of model."
            )

        # Update shared data state
        self._reset_shared_data(prompt[:pos], temperature)

        # Use cache only when temperature is 0
        if temperature == 0:
            cache_key = self._hash_prompt(prompt)

            # Check if the result is already in the cache
            if cache_key in self.cache:
                for chunk in self.cache[cache_key]:
                    yield chunk
                return

        # Call the actual API and extract the next chunk
        if self._is_openai_compatible:
            client = openai.OpenAI(api_key=self._api_key, base_url=self._endpoint)
            response = client.chat.completions.create(
                model=self._deployment,
                messages=messages,  # type: ignore[arg-type]
                # max_tokens=self.max_streaming_tokens,
                n=1,
                top_p=1.0,  # TODO: this should be controllable like temp (from the grammar)
                temperature=temperature,
                # stream=True,
            )

            result = response.choices[0]
            encoded_chunk = result.message.content.encode("utf8")  # type: ignore[union-attr]
        else:
            parameters = dict(temperature=temperature)
            payload = dict(
                input_data=dict(input_string=messages, parameters=parameters)
            )

            headers = {
                "Content-Type": "application/json",
                "Authorization": ("Bearer " + self._api_key),
                "azureml-model-deployment": self._deployment,
            }
            response_score = requests.post(
                self._endpoint,
                json=payload,
                headers=headers,
            )

            result_score = response_score.json()

            encoded_chunk = result_score["output"].encode("utf8")

        # Now back to OpenAIChatEngine, with slight modifications since
        # this isn't a streaming API
        if temperature == 0:
            cached_results = []

        yield encoded_chunk

        if temperature == 0:
            cached_results.append(encoded_chunk)

        # Cache the results after the generator is exhausted
        if temperature == 0:
            self.cache[cache_key] = cached_results


class AzureAIStudioChat(Grammarless, Chat):
    def __init__(
        self,
        azureai_studio_endpoint: str,
        azureai_studio_deployment: str,
        azureai_studio_key: str,
        tokenizer=None,
        echo: bool = True,
        max_streaming_tokens: int = 1000,
        timeout: float = 0.5,
        compute_log_probs: bool = False,
        clear_cache: bool = False,
    ):
        """Create a model object for interacting with Azure AI Studio chat endpoints.

        The required information about the deployed endpoint can
        be obtained from Azure AI Studio.

        A `diskcache`-based caching system is used to speed up
        repeated calls when the temperature is specified to be
        zero.

        Parameters
        ----------
        azureai_studio_endpoint : str
            The HTTPS endpoint deployed by Azure AI Studio
        azureai_studio_deployment : str
            The specific model deployed to the endpoint
        azureai_studio_key : str
            The key required for access to the API
        clear_cache : bool
            Whether to empty the internal cache
        """
        super().__init__(
            AzureAIStudioChatEngine(
                azureai_studio_endpoint=azureai_studio_endpoint,
                azureai_model_deployment=azureai_studio_deployment,
                azureai_studio_key=azureai_studio_key,
                tokenizer=tokenizer,
                max_streaming_tokens=max_streaming_tokens,
                timeout=timeout,
                compute_log_probs=compute_log_probs,
                clear_cache=clear_cache,
            ),
            echo=echo,
        )
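To make the non-OpenAI-compatible branch concrete, the sketch below shows what the role-tag parsing in `_generator` produces for a short chat and the JSON body that is then POSTed to a `/score` endpoint. The prompt text is invented for illustration; the structure follows the `payload` construction in the code above.

# Illustration only (not part of the commit); the prompt text is made up.
prompt = (
    b"<|im_start|>system\nYou are a math wiz.<|im_end|>"
    b"<|im_start|>user\nWhat is 1 + 1?<|im_end|>"
    b"<|im_start|>assistant\n"
)

# The role-tag loop turns the closed blocks into chat messages; the trailing
# open assistant block marks where generation should continue.
messages = [
    {"role": "system", "content": "You are a math wiz."},
    {"role": "user", "content": "What is 1 + 1?"},
]

# Body POSTed to the /score endpoint (see the requests.post call above),
# here with temperature 0, so the disk cache is consulted first.
payload = {
    "input_data": {
        "input_string": messages,
        "parameters": {"temperature": 0.0},
    }
}

The response's "output" field is then encoded to UTF-8 and yielded as a single chunk, since this endpoint is not a streaming API.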
@@ -0,0 +1,78 @@
from guidance import assistant, gen, models, system, user


def smoke_chat(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=10, name="text", temperature=0.5)
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")


def longer_chat_1(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=10, name="text")
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")

    with user():
        lm += "10. Now you pick a number between 0 and 20"

    with assistant():
        lm += gen(max_tokens=2, name="number")

    print(str(lm))
    assert len(lm["number"]) > 0


def longer_chat_2(lm: models.Chat, has_system_role: bool = True):
    if has_system_role:
        with system():
            lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    # This is the new part compared to longer_chat_1
    with assistant():
        lm += "2"

    with user():
        lm += "What is 2 + 3?"

    # Resume the previous
    with assistant():
        lm += gen(max_tokens=10, name="text")
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
    assert str(lm).endswith("Pick a number: <|im_end|>")

    with user():
        lm += "10. Now you pick a number between 0 and 20"

    with assistant():
        lm += gen(max_tokens=2, name="number")

    print(str(lm))
    assert len(lm["number"]) > 0
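These helpers are meant to be shared by the test modules of several model backends. Below is a hedged sketch of how the Azure AI Studio tests might consume them; the module name `common_chat_testing`, the environment variable names, the fixture, and the skip logic are assumptions for illustration, not part of this commit.

# Sketch of a consuming test module; names and env vars are assumptions.
import os

import pytest

from guidance import models

import common_chat_testing  # hypothetical import path for the shared helpers above


@pytest.fixture(scope="module")
def azureai_studio_chat():
    endpoint = os.getenv("AZUREAI_STUDIO_ENDPOINT")
    deployment = os.getenv("AZUREAI_STUDIO_DEPLOYMENT")
    key = os.getenv("AZUREAI_STUDIO_KEY")
    if not all([endpoint, deployment, key]):
        pytest.skip("Azure AI Studio deployment not configured")
    return models.AzureAIStudioChat(
        azureai_studio_endpoint=endpoint,
        azureai_studio_deployment=deployment,
        azureai_studio_key=key,
    )


def test_smoke_chat(azureai_studio_chat):
    common_chat_testing.smoke_chat(azureai_studio_chat, has_system_role=True)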