Skip to content

Commit a258145

Browse files
authored
Port Prompty code
A port of the Prompty code from the Promtpflow repo. The focus was on expediency, rather than elegance. The core logic of the the code is similar to the original except for the following changes: - Added type annotations - Removed support for the now legacy OpenAI completions API - Removed support for functions and tools. The former relied on an insecure implementation using eval. Since none of the evaluators use this feature right now, these were cut - Reworked the way images were handled to simplify (now handled in one pass, and no more surprise calls out to the internet to unnecessarily load image bytes) - Minor obvious tweaks to the code to improve readability, and trim unnecessary code paths
1 parent cff97a2 commit a258145

File tree

15 files changed

+1585
-79
lines changed

15 files changed

+1585
-79
lines changed

.vscode/cspell.json

+1
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,7 @@
14001400
"upia",
14011401
"xpia",
14021402
"expirable",
1403+
"ralphe"
14031404
]
14041405
},
14051406
{

sdk/evaluation/azure-ai-evaluation/assets.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/evaluation/azure-ai-evaluation",
5-
"Tag": "python/evaluation/azure-ai-evaluation_ceeaf3cbb7"
5+
"Tag": "python/evaluation/azure-ai-evaluation_bef898409a"
66
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from azure.ai.evaluation.legacy.prompty._prompty import AsyncPrompty
6+
from azure.ai.evaluation.legacy.prompty._connection import Connection, OpenAIConnection, AzureOpenAIConnection
7+
from azure.ai.evaluation.legacy.prompty._exceptions import (
8+
PromptyException,
9+
MissingRequiredInputError,
10+
InvalidInputError,
11+
JinjaTemplateError,
12+
NotSupportedError,
13+
)
14+
15+
# =========================================================================================================
16+
# NOTE: All of the code here is largely copy of code from Promptflow. Generally speaking, the following
17+
# changes were made:
18+
# - Added type annotations
19+
# - Legacy or deprecated functionality has been removed (e.g. no more support for completions API)
20+
# - Reworked the way images are handled to 1) Reduce the amount of code brought over, 2) remove
21+
# the need to do two passes over the template to insert images, 3) remove the completely unnecessary
22+
# loading of image data from the internet when it is not actually needed
23+
# - Minor obvious tweaks to improve code readability, and removal of unused code paths
24+
# =========================================================================================================
25+
26+
__all__ = [
27+
"AsyncPrompty",
28+
"Connection",
29+
"AzureOpenAIConnection",
30+
"OpenAIConnection",
31+
"PromptyException",
32+
"MissingRequiredInputError",
33+
"InvalidInputError",
34+
"JinjaTemplateError",
35+
"NotSupportedError",
36+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
import os
6+
import re
7+
from abc import ABC, abstractmethod
8+
from dataclasses import dataclass
9+
from typing import Any, ClassVar, Mapping, Optional, Set, Union
10+
11+
from azure.ai.evaluation.legacy.prompty._exceptions import MissingRequiredInputError
12+
from azure.ai.evaluation.legacy.prompty._utils import dataclass_from_dict
13+
14+
15+
ENV_VAR_PATTERN = re.compile(r"^\$\{env:(.*)\}$")
16+
17+
18+
def _parse_environment_variable(value: Union[str, Any]) -> Union[str, Any]:
19+
"""Get environment variable from ${env:ENV_NAME}. If not found, return original value.
20+
21+
:param value: The value to parse.
22+
:type value: str | Any
23+
:return: The parsed value
24+
:rtype: str | Any"""
25+
if not isinstance(value, str):
26+
return value
27+
28+
result = re.match(ENV_VAR_PATTERN, value)
29+
if result:
30+
env_name = result.groups()[0]
31+
return os.environ.get(env_name, value)
32+
33+
return value
34+
35+
36+
def _is_empty_connection_config(connection_dict: Mapping[str, Any]) -> bool:
37+
ignored_fields = set(["azure_deployment", "model"])
38+
keys = {k for k, v in connection_dict.items() if v}
39+
return len(keys - ignored_fields) == 0
40+
41+
42+
@dataclass
43+
class Connection(ABC):
44+
"""Base class for all connection classes."""
45+
46+
@property
47+
@abstractmethod
48+
def type(self) -> str:
49+
"""Gets the type of the connection.
50+
51+
:return: The type of the connection.
52+
:rtype: str"""
53+
...
54+
55+
@abstractmethod
56+
def is_valid(self, missing_fields: Optional[Set[str]] = None) -> bool:
57+
"""Check if the connection is valid.
58+
59+
:param missing_fields: If set, this will be populated with the missing required fields.
60+
:type missing_fields: Set[str] | None
61+
:return: True if the connection is valid, False otherwise.
62+
:rtype: bool"""
63+
...
64+
65+
@staticmethod
66+
def parse_from_config(model_configuration: Mapping[str, Any]) -> "Connection":
67+
"""Parse a connection from a model configuration.
68+
69+
:param model_configuration: The model configuration.
70+
:type model_configuration: Mapping[str, Any]
71+
:return: The connection.
72+
:rtype: Connection
73+
"""
74+
connection: Connection
75+
connection_dict = {k: _parse_environment_variable(v) for k, v in model_configuration.items()}
76+
connection_type = connection_dict.pop("type", "")
77+
78+
if connection_type in [AzureOpenAIConnection.TYPE, "azure_openai"]:
79+
if _is_empty_connection_config(connection_dict):
80+
connection = AzureOpenAIConnection.from_env()
81+
else:
82+
connection = dataclass_from_dict(AzureOpenAIConnection, connection_dict)
83+
84+
elif connection_type in [OpenAIConnection.TYPE, "openai"]:
85+
if _is_empty_connection_config(connection_dict):
86+
connection = OpenAIConnection.from_env()
87+
else:
88+
connection = dataclass_from_dict(OpenAIConnection, connection_dict)
89+
90+
else:
91+
error_message = (
92+
f"'{connection_type}' is not a supported connection type. Valid values are "
93+
f"[{AzureOpenAIConnection.TYPE}, {OpenAIConnection.TYPE}]"
94+
)
95+
raise MissingRequiredInputError(error_message)
96+
97+
missing_fields: Set[str] = set()
98+
if not connection.is_valid(missing_fields):
99+
raise MissingRequiredInputError(
100+
f"The following required fields are missing for connection {connection.type}: "
101+
f"{', '.join(missing_fields)}"
102+
)
103+
104+
return connection
105+
106+
107+
@dataclass
108+
class OpenAIConnection(Connection):
109+
"""Connection class for OpenAI endpoints."""
110+
111+
base_url: str
112+
api_key: Optional[str] = None
113+
organization: Optional[str] = None
114+
115+
TYPE: ClassVar[str] = "openai"
116+
117+
@property
118+
def type(self) -> str:
119+
return OpenAIConnection.TYPE
120+
121+
@classmethod
122+
def from_env(cls) -> "OpenAIConnection":
123+
return cls(
124+
base_url=os.environ.get("OPENAI_BASE_URL", ""),
125+
api_key=os.environ.get("OPENAI_API_KEY"),
126+
organization=os.environ.get("OPENAI_ORG_ID"),
127+
)
128+
129+
def is_valid(self, missing_fields: Optional[Set[str]] = None) -> bool:
130+
if missing_fields is None:
131+
missing_fields = set()
132+
if not self.base_url:
133+
missing_fields.add("base_url")
134+
if not self.api_key:
135+
missing_fields.add("api_key")
136+
if not self.organization:
137+
missing_fields.add("organization")
138+
return not bool(missing_fields)
139+
140+
141+
@dataclass
142+
class AzureOpenAIConnection(Connection):
143+
"""Connection class for Azure OpenAI endpoints."""
144+
145+
azure_endpoint: str
146+
api_key: Optional[str] = None # TODO ralphe: Replace this TokenCredential to allow for more flexible authentication
147+
azure_deployment: Optional[str] = None
148+
api_version: Optional[str] = None
149+
resource_id: Optional[str] = None
150+
151+
TYPE: ClassVar[str] = "azure_openai"
152+
153+
@property
154+
def type(self) -> str:
155+
return AzureOpenAIConnection.TYPE
156+
157+
@classmethod
158+
def from_env(cls) -> "AzureOpenAIConnection":
159+
return cls(
160+
azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
161+
api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
162+
azure_deployment=os.environ.get("AZURE_OPENAI_DEPLOYMENT"),
163+
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-01"),
164+
)
165+
166+
def __post_init__(self):
167+
# set default API version
168+
if not self.api_version:
169+
self.api_version = "2024-02-01"
170+
171+
def is_valid(self, missing_fields: Optional[Set[str]] = None) -> bool:
172+
if missing_fields is None:
173+
missing_fields = set()
174+
if not self.azure_endpoint:
175+
missing_fields.add("azure_endpoint")
176+
if not self.api_key:
177+
missing_fields.add("api_key")
178+
if not self.azure_deployment:
179+
missing_fields.add("azure_deployment")
180+
if not self.api_version:
181+
missing_fields.add("api_version")
182+
return not bool(missing_fields)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from azure.ai.evaluation._exceptions import ErrorCategory, ErrorBlame, ErrorTarget, EvaluationException
6+
7+
8+
class PromptyException(EvaluationException):
9+
"""Exception class for Prompty related errors.
10+
11+
This exception is used to indicate that the error was caused by Prompty execution.
12+
13+
:param message: The error message.
14+
:type message: str
15+
"""
16+
17+
def __init__(self, message: str, **kwargs):
18+
kwargs.setdefault("category", ErrorCategory.INVALID_VALUE)
19+
kwargs.setdefault("target", ErrorTarget.UNKNOWN)
20+
kwargs.setdefault("blame", ErrorBlame.USER_ERROR)
21+
22+
super().__init__(message, **kwargs)
23+
24+
25+
class MissingRequiredInputError(PromptyException):
26+
"""Exception raised when missing required input"""
27+
28+
def __init__(self, message: str, **kwargs):
29+
kwargs.setdefault("category", ErrorCategory.MISSING_FIELD)
30+
kwargs.setdefault("target", ErrorTarget.EVAL_RUN)
31+
super().__init__(message, **kwargs)
32+
33+
34+
class InvalidInputError(PromptyException):
35+
"""Exception raised when an input is invalid, could not be loaded, or is not the expected format."""
36+
37+
def __init__(self, message: str, **kwargs):
38+
kwargs.setdefault("category", ErrorCategory.INVALID_VALUE)
39+
kwargs.setdefault("target", ErrorTarget.EVAL_RUN)
40+
super().__init__(message, **kwargs)
41+
42+
43+
class JinjaTemplateError(PromptyException):
44+
"""Exception raised when the Jinja template is invalid or could not be rendered."""
45+
46+
def __init__(self, message: str, **kwargs):
47+
kwargs.setdefault("category", ErrorCategory.INVALID_VALUE)
48+
kwargs.setdefault("target", ErrorTarget.EVAL_RUN)
49+
super().__init__(message, **kwargs)
50+
51+
52+
class NotSupportedError(PromptyException):
53+
"""Exception raised when the operation is not supported."""
54+
55+
def __init__(self, message: str, **kwargs):
56+
kwargs.setdefault("category", ErrorCategory.INVALID_VALUE)
57+
kwargs.setdefault("target", ErrorTarget.UNKNOWN)
58+
kwargs.setdefault("blame", ErrorBlame.SYSTEM_ERROR)
59+
super().__init__(message, **kwargs)

0 commit comments

Comments
 (0)