[Feature] Support for Azure AI Studio #779
@@ -0,0 +1,162 @@
import hashlib
import pathlib

import diskcache as dc
import platformdirs
import requests

from ._model import Chat
from ._grammarless import GrammarlessEngine, Grammarless


class AzureAIStudioChatEngine(GrammarlessEngine):
    def __init__(
        self,
        *,
        tokenizer,
        max_streaming_tokens: int,
        timeout: float,
        compute_log_probs: bool,
        azureai_studio_endpoint: str,
        azureai_model_deployment: str,
        azureai_studio_key: str,
    ):
        self._endpoint = azureai_studio_endpoint
        self._deployment = azureai_model_deployment
        self._api_key = azureai_studio_key

        # There is a cache... better make sure it's specific
        # to the endpoint and deployment
        deployment_id = self._hash_prompt(self._endpoint + self._deployment)

        path = (
            pathlib.Path(platformdirs.user_cache_dir("guidance"))
            / f"azureaistudio.tokens.{deployment_id}"
        )
        self.cache = dc.Cache(path)

        super().__init__(tokenizer, max_streaming_tokens, timeout, compute_log_probs)

    def _hash_prompt(self, prompt):
        # Copied from OpenAIChatEngine
        return hashlib.sha256(f"{prompt}".encode()).hexdigest()

    def _generator(self, prompt, temperature: float):
        # Initial parts of this straight up copied from OpenAIChatEngine

        # The next loop (or one like it) appears in several places,
        # and quite possibly belongs in a library function or superclass
        # That said, I'm not _completely sure that there aren't subtle
        # differences between the various versions

Review comment: Thoughts on this? This is a straight-up copy of what is in `OpenAIChatEngine`.

        # find the role tags
        pos = 0
        role_end = b"<|im_end|>"
        messages = []
        found = True
        while found:

            # find the role text blocks
            found = False
            for role_name, start_bytes in (
                ("system", b"<|im_start|>system\n"),
                ("user", b"<|im_start|>user\n"),
                ("assistant", b"<|im_start|>assistant\n"),
            ):
                if prompt[pos:].startswith(start_bytes):
                    pos += len(start_bytes)
                    end_pos = prompt[pos:].find(role_end)
                    if end_pos < 0:
                        assert (
                            role_name == "assistant"
                        ), "Bad chat format! Last role before gen needs to be assistant!"
                        break
                    btext = prompt[pos : pos + end_pos]
                    pos += end_pos + len(role_end)
                    messages.append(
                        {"role": role_name, "content": btext.decode("utf8")}
                    )
                    found = True
                    break

Comment on lines +84 to +86: Do AzureAI models uniformly use the same role tags across their models? I don't think we can hard code a check for these start_bytes in this class.

Reply: These aren't coming from the model, surely? They're coming from
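
As a sketch of the refactoring discussed in the comments above, the parsing loop could move into a shared, module-level helper with the role tags passed in rather than hard-coded. The helper name and layout here are hypothetical, not part of this PR:

```python
ROLE_STARTS = {
    "system": b"<|im_start|>system\n",
    "user": b"<|im_start|>user\n",
    "assistant": b"<|im_start|>assistant\n",
}
ROLE_END = b"<|im_end|>"


def parse_chat_prompt(prompt: bytes, role_starts=ROLE_STARTS, role_end=ROLE_END):
    """Split a role-tagged prompt into (messages, pos), mirroring the loop above."""
    pos = 0
    messages = []
    found = True
    while found:
        found = False
        for role_name, start_bytes in role_starts.items():
            if prompt[pos:].startswith(start_bytes):
                pos += len(start_bytes)
                end_pos = prompt[pos:].find(role_end)
                if end_pos < 0:
                    # The last (open) block must be the assistant turn being generated
                    assert (
                        role_name == "assistant"
                    ), "Bad chat format! Last role before gen needs to be assistant!"
                    break
                btext = prompt[pos : pos + end_pos]
                pos += end_pos + len(role_end)
                messages.append({"role": role_name, "content": btext.decode("utf8")})
                found = True
                break
    return messages, pos
```

Each engine could then call this with its own role-tag table, which would also address the question above about hard-coding the ChatML-style tags in this class.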

        # Add nice exception if no role tags were used in the prompt.
        # TODO: Move this somewhere more general for all chat models?
        if messages == []:
            raise ValueError(
                f"The model is a Chat-based model and requires role tags in the prompt! \
                Make sure you are using guidance context managers like `with system():`, `with user():` and `with assistant():` \
                to appropriately format your guidance program for this type of model."
            )

        # Update shared data state
        self._reset_shared_data(prompt[:pos], temperature)

        # Use cache only when temperature is 0
        if temperature == 0:
            cache_key = self._hash_prompt(prompt)

            # Check if the result is already in the cache
            if cache_key in self.cache:
                for chunk in self.cache[cache_key]:
                    yield chunk
                return

Review comment: This caching logic is a bit of a concern - at least for models where T=0 doesn't actually get determinism. And in general, it means that a bunch of our tests might not quite be doing what we think they're doing, because they may just be hitting the cache.

Reply: The more I think about it, the less I like the idea of a disk-based cache. In some ways it's worse on the OpenAI side, where both AzureOpenAI and OpenAI will wind up sharing the same cache. How much speed up does it really give, compared to the Heisenbug potential it represents?

Reply: I think most models have more reliable temp=0 determinism now, but agree that perhaps sharing between AzureOpenAI and OpenAI is problematic (though there shouldn't be differences between the two APIs in theory?). I do think caching is a nice feature to have in general, as production workflows often have shared inputs coming in that save time and money to reuse.

Reply: I have added a
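
One possible response to the caching concerns above, without touching the temperature == 0 logic in `_generator`, would be to make the disk cache opt-in. The `OptionalCache` wrapper and its `enabled` flag below are hypothetical, not part of this PR:

```python
import diskcache as dc


class OptionalCache:
    """A cache that can be switched off, so callers worried about stale
    temperature==0 results (or shared caches) can disable reuse entirely."""

    def __init__(self, path, enabled: bool = False):
        self._enabled = enabled
        self._cache = dc.Cache(path) if enabled else None

    def __contains__(self, key) -> bool:
        return self._enabled and key in self._cache

    def __getitem__(self, key):
        return self._cache[key]

    def __setitem__(self, key, value):
        if self._enabled:
            self._cache[key] = value
```

The engine could then build `self.cache = OptionalCache(path, enabled=enable_cache)` behind a constructor flag, and the lookup and store code above would work unchanged.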

        # Prepare for the API call (this might be model specific....)
        parameters = dict(temperature=temperature)
        payload = dict(input_data=dict(input_string=messages, parameters=parameters))

        headers = {
            "Content-Type": "application/json",
            "Authorization": ("Bearer " + self._api_key),
            "azureml-model-deployment": self._deployment,
        }

        response = requests.post(
            self._endpoint,
            json=payload,
            headers=headers,
        )

        result = response.json()

        # Now back to OpenAIChatEngine, with slight modifications since
        # this isn't a streaming API
        if temperature == 0:
            cached_results = []

        encoded_chunk = result["output"].encode("utf8")

        yield encoded_chunk

        if temperature == 0:
            cached_results.append(encoded_chunk)

        # Cache the results after the generator is exhausted
        if temperature == 0:
            self.cache[cache_key] = cached_results
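
One more note on the call to the endpoint: the status of the `requests.post` response is never checked, so an HTTP error would only surface later, when the body fails to parse or lacks the `"output"` key. A minimal, hypothetical helper (not part of this PR) that fails fast might look like this:

```python
import requests


def post_scoring_request(endpoint: str, payload: dict, headers: dict) -> dict:
    """Send the scoring request and raise immediately on failure."""
    response = requests.post(endpoint, json=payload, headers=headers)
    response.raise_for_status()  # raises requests.HTTPError on 4xx/5xx
    result = response.json()
    if "output" not in result:
        raise RuntimeError(f"Unexpected response from {endpoint}: {result}")
    return result
```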


class AzureAIStudioChat(Grammarless, Chat):
    def __init__(
        self,
        azureai_studio_endpoint: str,
        azureai_studio_deployment: str,
        azureai_studio_key: str,
        tokenizer=None,
        echo: bool = True,
        max_streaming_tokens: int = 1000,
        timeout: float = 0.5,
        compute_log_probs: bool = False,
    ):
        super().__init__(
            AzureAIStudioChatEngine(
                azureai_studio_endpoint=azureai_studio_endpoint,
                azureai_model_deployment=azureai_studio_deployment,
                azureai_studio_key=azureai_studio_key,
                tokenizer=tokenizer,
                max_streaming_tokens=max_streaming_tokens,
                timeout=timeout,
                compute_log_probs=compute_log_probs,
            ),
            echo=echo,
        )
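
For orientation, the intended call pattern for `AzureAIStudioChat` mirrors the smoke tests below; the endpoint URL, deployment name, and key here are placeholders, not real values:

```python
from guidance import assistant, gen, models, system, user

# Placeholder credentials: substitute your own AI Studio scoring endpoint,
# deployment name, and key.
lm = models.AzureAIStudioChat(
    azureai_studio_endpoint="<scoring-endpoint-url>",
    azureai_studio_deployment="<deployment-name>",
    azureai_studio_key="<api-key>",
)

with system():
    lm += "You are a helpful assistant."

with user():
    lm += "What is 1 + 1?"

with assistant():
    lm += gen(max_tokens=10, name="answer")

print(lm["answer"])
```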
@@ -0,0 +1,65 @@
import pytest

from guidance import assistant, gen, models, system, user

from ..utils import env_or_fail

# Everything in here needs credentials to work
# Mark is configured in pyproject.toml
pytestmark = pytest.mark.needs_credentials

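For context, `env_or_fail` comes from the shared test utilities; a hypothetical reconstruction (not the actual implementation) of what it presumably does:

```python
import os

import pytest


def env_or_fail(var_name: str) -> str:
    """Fetch a required environment variable, failing the test if it is unset."""
    value = os.getenv(var_name)
    if value is None:
        pytest.fail(f"Environment variable {var_name} is not set")
    return value
```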

def test_azureai_phi3_chat_smoke(rate_limiter):
    azureai_studio_endpoint = env_or_fail("AZURE_AI_STUDIO_PHI3_ENDPOINT")
    azureai_studio_deployment = env_or_fail("AZURE_AI_STUDIO_PHI3_DEPLOYMENT")
    azureai_studio_key = env_or_fail("AZURE_AI_STUDIO_PHI3_KEY")

    lm = models.AzureAIStudioChat(
        azureai_studio_endpoint=azureai_studio_endpoint,
        azureai_studio_deployment=azureai_studio_deployment,
        azureai_studio_key=azureai_studio_key,
    )
    assert isinstance(lm, models.AzureAIStudioChat)

    with system():
        lm += "You are a math wiz."

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=10, name="text", temperature=0.5)
        lm += "Pick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0


def test_azureai_mistral_chat_smoke(rate_limiter):
    azureai_studio_endpoint = env_or_fail("AZURE_AI_STUDIO_MISTRAL_CHAT_ENDPOINT")
    azureai_studio_deployment = env_or_fail("AZURE_AI_STUDIO_MISTRAL_CHAT_DEPLOYMENT")
    azureai_studio_key = env_or_fail("AZURE_AI_STUDIO_MISTRAL_CHAT_KEY")

    lm = models.AzureAIStudioChat(
        azureai_studio_endpoint=azureai_studio_endpoint,
        azureai_studio_deployment=azureai_studio_deployment,
        azureai_studio_key=azureai_studio_key,
    )
    assert isinstance(lm, models.AzureAIStudioChat)
    lm.engine.cache.clear()

    # No "system" role for Mistral?
    # with system():
    #     lm += "You are a math wiz."

Review comment: This makes me unhappy.

    with user():
        lm += "What is 1 + 1?"

    with assistant():
        lm += gen(max_tokens=15, name="text", temperature=0.5)
        lm += "\nPick a number: "

    print(str(lm))
    assert len(lm["text"]) > 0
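
Regarding the commented-out `system()` block in the Mistral test above: one common workaround (hypothetical here, not part of this PR) is to fold a leading system message into the first user turn before the payload is built, for deployments that reject the "system" role:

```python
from typing import Dict, List


def merge_system_into_first_user(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """If the first message is a system message, prepend it to the first user turn."""
    if messages and messages[0]["role"] == "system":
        system_msg, rest = messages[0], messages[1:]
        if rest and rest[0]["role"] == "user":
            merged = dict(rest[0])
            merged["content"] = system_msg["content"] + "\n\n" + merged["content"]
            return [merged] + rest[1:]
    return messages
```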

Review comment: I never try setting the tokeniser, and it appears that it eventually defaults to GPT2. I don't quite see why a remote model like this would even need it.

Reply: OK, so theoretically for token healing. However, I have a feeling that trying to figure out what tokeniser to use will be an exercise in fragility.
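
If a matching tokenizer did turn out to matter (for token healing, say), it could presumably be supplied explicitly rather than relying on the GPT-2 fallback. The sketch below assumes the engine accepts a Hugging Face tokenizer object, which this PR does not actually document, and the model name and credentials are only placeholders:

```python
from transformers import AutoTokenizer

from guidance import models

# Assumption: the underlying GrammarlessEngine can wrap a Hugging Face tokenizer.
phi3_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

lm = models.AzureAIStudioChat(
    azureai_studio_endpoint="<scoring-endpoint-url>",
    azureai_studio_deployment="<deployment-name>",
    azureai_studio_key="<api-key>",
    tokenizer=phi3_tokenizer,
)
```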