From 9bb59d2b2dd03c0fef98e3e5bb918db3093dabdb Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 15:12:54 +0800 Subject: [PATCH 1/8] Unify connection classes Signed-off-by: Brynn Yin --- src/promptflow-tools/promptflow/tools/aoai.py | 5 +- .../promptflow/tools/embedding.py | 3 +- .../promptflow/tools/openai.py | 3 +- .../promptflow/_core/connection_manager.py | 25 +++--- .../promptflow/_sdk/entities/_connection.py | 87 ++++++++++++++++++- src/promptflow/promptflow/connections.py | 84 +++++------------- 6 files changed, 124 insertions(+), 83 deletions(-) diff --git a/src/promptflow-tools/promptflow/tools/aoai.py b/src/promptflow-tools/promptflow/tools/aoai.py index c605ae276c2..188cff66095 100644 --- a/src/promptflow-tools/promptflow/tools/aoai.py +++ b/src/promptflow-tools/promptflow/tools/aoai.py @@ -1,5 +1,4 @@ import json -from dataclasses import asdict import openai @@ -14,13 +13,13 @@ class AzureOpenAI(ToolProvider): def __init__(self, connection: AzureOpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = asdict(self.connection) + self._connection_dict = dict(self.connection.items()) def calculate_cache_string_for_completion( self, **kwargs, ) -> str: - d = asdict(self.connection) + d = dict(self.connection.items()) d.pop("api_key") d.update({**kwargs}) return json.dumps(d) diff --git a/src/promptflow-tools/promptflow/tools/embedding.py b/src/promptflow-tools/promptflow/tools/embedding.py index 193c81d3e5a..e383b760c64 100644 --- a/src/promptflow-tools/promptflow/tools/embedding.py +++ b/src/promptflow-tools/promptflow/tools/embedding.py @@ -1,4 +1,3 @@ -from dataclasses import asdict from enum import Enum from typing import Union @@ -20,7 +19,7 @@ class EmbeddingModel(str, Enum): @handle_openai_error() def embedding(connection: Union[AzureOpenAIConnection, OpenAIConnection], input: str, deployment_name: str = "", model: EmbeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002): - connection_dict = asdict(connection) + connection_dict = dict(connection.items()) if isinstance(connection, AzureOpenAIConnection): return openai.Embedding.create( input=input, diff --git a/src/promptflow-tools/promptflow/tools/openai.py b/src/promptflow-tools/promptflow/tools/openai.py index 99975cf5498..0590f6cd6a0 100644 --- a/src/promptflow-tools/promptflow/tools/openai.py +++ b/src/promptflow-tools/promptflow/tools/openai.py @@ -1,4 +1,3 @@ -from dataclasses import asdict from enum import Enum import openai @@ -25,7 +24,7 @@ class OpenAI(ToolProvider): def __init__(self, connection: OpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = asdict(self.connection) + self._connection_dict = dict(self.connection.items()) @tool @handle_openai_error() diff --git a/src/promptflow/promptflow/_core/connection_manager.py b/src/promptflow/promptflow/_core/connection_manager.py index 2c375fc172a..c266a41be1a 100644 --- a/src/promptflow/promptflow/_core/connection_manager.py +++ b/src/promptflow/promptflow/_core/connection_manager.py @@ -49,22 +49,27 @@ def _build_connections(cls, _dict: Dict[str, dict]): from promptflow.connections import CustomConnection if connection_class is CustomConnection: - connection_value = connection_class(**value) - connection_value.__secret_keys = connection_dict.get("secret_keys", []) # Note: CustomConnection definition can not be got, secret keys will be provided in connection dict. - setattr(connection_value, CONNECTION_SECRET_KEYS, connection_dict.get("secret_keys", [])) + secret_keys = connection_dict.get("secret_keys", []) + secrets = {k: v for k, v in value.items() if k in secret_keys} + configs = {k: v for k, v in value.items() if k not in secrets} + connection_value = connection_class(configs=configs, secrets=secrets) else: """ Note: Ignore non exists keys of connection class, because there are some keys just used by UX like resource id, while not used by backend. """ - cls_fields = {f.name: f for f in fields(connection_class)} if is_dataclass(connection_class) else {} - connection_value = connection_class(**{k: v for k, v in value.items() if k in cls_fields}) - setattr( - connection_value, - CONNECTION_SECRET_KEYS, - [f.name for f in cls_fields.values() if f.type == Secret], - ) + if is_dataclass(connection_class): + # Do not delete this branch, as embedding store connection is dataclass type. + cls_fields = {f.name: f for f in fields(connection_class)} + connection_value = connection_class(**{k: v for k, v in value.items() if k in cls_fields}) + setattr( + connection_value, + CONNECTION_SECRET_KEYS, + [f.name for f in cls_fields.values() if f.type == Secret], + ) + else: + connection_value = connection_class(**{k: v for k, v in value.items()}) # Use this hack to make sure serialization works setattr(connection_value, CONNECTION_NAME_PROPERTY, key) connections[key] = connection_value diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index e1319ca23ba..a94cd9618ff 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -3,6 +3,7 @@ # --------------------------------------------------------- import abc import json +from builtins import _dict_items from os import PathLike from pathlib import Path from typing import Dict, Union @@ -17,6 +18,7 @@ ConnectionType, ) from promptflow._sdk._errors import UnsecureConnectionError +from promptflow._sdk._logger_factory import LoggerFactory from promptflow._sdk._orm.connection import Connection as ORMConnection from promptflow._sdk._utils import ( decrypt_secret_value, @@ -39,13 +41,15 @@ WeaviateConnectionSchema, ) +logger = LoggerFactory.get_logger(name=__name__) + class _Connection(YAMLTranslatableMixin): TYPE = ConnectionType._NOT_SET def __init__( self, - name, + name: str = "default_connection", module: str = "promptflow.connections", configs: Dict[str, str] = None, secrets: Dict[str, str] = None, @@ -92,6 +96,9 @@ def _casting_type(cls, typ): return type_dict.get(typ) return snake_to_camel(typ) + def items(self) -> _dict_items: + return {**self.configs, **self.secrets}.items() + @classmethod def _is_scrubbed_value(cls, value): """For scrubbed value, cli will get original for update, and prompt user to input for create.""" @@ -276,6 +283,13 @@ def api_key(self, value): class AzureOpenAIConnection(_StrongTypeConnection): + """ + :param api_key: The api key. + :param api_base: The api base. + :param api_type: The api type, default "azure". + :param api_version: The api version, default "2023-07-01-preview". + :param name: Connection name. + """ TYPE = ConnectionType.AZURE_OPEN_AI def __init__( @@ -315,6 +329,11 @@ def api_version(self, value): class OpenAIConnection(_StrongTypeConnection): + """ + :param api_key: The api key. + :param organization: The organization, optional. + :param name: Connection name. + """ TYPE = ConnectionType.OPEN_AI def __init__(self, api_key: str, organization: str = None, **kwargs): @@ -336,6 +355,10 @@ def organization(self, value): class SerpConnection(_StrongTypeConnection): + """ + :param api_key: The api key. + :param name: Connection name. + """ TYPE = ConnectionType.SERP def __init__(self, api_key: str, **kwargs): @@ -365,6 +388,11 @@ def api_base(self, value): class QdrantConnection(_EmbeddingStoreConnection): + """ + :param api_key: The api key. + :param api_base: The api base. + :param name: Connection name. + """ TYPE = ConnectionType.QDRANT @classmethod @@ -373,6 +401,11 @@ def _get_schema_cls(cls): class WeaviateConnection(_EmbeddingStoreConnection): + """ + :param api_key: The api key. + :param api_base: The api base. + :param name: Connection name. + """ TYPE = ConnectionType.WEAVIATE @classmethod @@ -381,6 +414,12 @@ def _get_schema_cls(cls): class CognitiveSearchConnection(_StrongTypeConnection): + """ + :param api_key: The api key. + :param api_base: The api base. + :param api_version: The api version, default "2023-07-01-Preview". + :param name: Connection name. + """ TYPE = ConnectionType.COGNITIVE_SEARCH def __init__(self, api_key: str, api_base: str, api_version: str = "2023-07-01-Preview", **kwargs): @@ -410,6 +449,13 @@ def api_version(self, value): class AzureContentSafetyConnection(_StrongTypeConnection): + """ + :param api_key: The api key. + :param endpoint: The api endpoint. + :param api_version: The api version, default "2023-04-30-preview". + :param api_type: The api type, default "Content Safety". + :param name: Connection name. + """ TYPE = ConnectionType.AZURE_CONTENT_SAFETY def __init__( @@ -454,6 +500,13 @@ def api_type(self, value): class FormRecognizerConnection(AzureContentSafetyConnection): + """ + :param api_key: The api key. + :param endpoint: The api endpoint. + :param api_version: The api version, default "2023-07-31". + :param api_type: The api type, default "Form Recognizer". + :param name: Connection name. + """ # Note: FormRecognizer and ContentSafety are using CognitiveService type in ARM, so keys are the same. TYPE = ConnectionType.FORM_RECOGNIZER @@ -468,17 +521,47 @@ def _get_schema_cls(cls): class CustomConnection(_Connection): + """ + :param configs: The configs kv pairs. + :param secrets: The secrets kv pairs. + :param name: Connection name + """ TYPE = ConnectionType.CUSTOM def __init__(self, secrets: Dict[str, str], configs: Dict[str, str] = None, **kwargs): if not secrets: - raise ValueError("secrets is required for custom connection.") + raise ValueError( + "Secrets is required for custom connection, " + "please use CustomConnection(configs={key1: val1}, secrets={key2: val2}) " + "to initialize custom connection." + ) super().__init__(secrets=secrets, configs=configs, **kwargs) @classmethod def _get_schema_cls(cls): return CustomConnectionSchema + def __getattr__(self, item): + # Note: This is added for compatibility with promptflow.connections custom connection usage. + if item == "secrets": + # Usually obj.secrets will not reach here + # This is added to handle copy.deepcopy loop issue + return super().__getattribute__("secrets") + if item == "configs": + # Usually obj.configs will not reach here + # This is added to handle copy.deepcopy loop issue + return super().__getattribute__("configs") + if item in self.secrets: + logger.warning("Please use connection.secrets[key] to access secrets.") + return self.secrets[item] + if item in self.configs: + logger.warning("Please use connection.configs[key] to access configs.") + return self.configs[item] + return super().__getattribute__(item) + def is_secret(self, item): + # Note: This is added for compatibility with promptflow.connections custom connection usage. + return item in self.secrets + def _to_orm_object(self): # Both keys & secrets will be set in custom configs with value type specified for custom connection. custom_configs = { diff --git a/src/promptflow/promptflow/connections.py b/src/promptflow/promptflow/connections.py index f35f2f17b4b..dc017689987 100644 --- a/src/promptflow/promptflow/connections.py +++ b/src/promptflow/promptflow/connections.py @@ -4,8 +4,17 @@ from dataclasses import dataclass, is_dataclass -from promptflow._constants import CONNECTION_SECRET_KEYS from promptflow._core.tools_manager import register_connections +from promptflow._sdk.entities import ( + AzureContentSafetyConnection, + AzureOpenAIConnection, + CognitiveSearchConnection, + CustomConnection, + FormRecognizerConnection, + OpenAIConnection, + SerpConnection, +) +from promptflow._sdk.entities._connection import _Connection from promptflow.contracts.types import Secret @@ -15,67 +24,14 @@ class BingConnection: url: str = "https://api.bing.microsoft.com/v7.0/search" -@dataclass -class OpenAIConnection: - api_key: Secret - organization: str = None - - -@dataclass -class AzureOpenAIConnection: - api_key: Secret - api_base: str - api_type: str = "azure" - api_version: str = "2023-07-01-preview" - - -@dataclass -class AzureContentSafetyConnection: - api_key: Secret - endpoint: str - api_version: str = "2023-04-30-preview" - - -@dataclass -class SerpConnection: - api_key: Secret - - -@dataclass -class CognitiveSearchConnection: - api_key: Secret - api_base: str - api_version: str = "2023-07-01-Preview" - - -@dataclass -class FormRecognizerConnection: - api_key: Secret - endpoint: str - api_version: str = "2023-07-31" - - -class CustomConnection(dict): - def __init__(self, *args, **kwargs): - # record secret keys if init from local - for k, v in kwargs.items(): - if isinstance(v, Secret): - self._set_secret(k) - super().__init__(*args, **kwargs) - - def __getattr__(self, item): - if item in self: - return self.__getitem__(item) - return super().__getattribute__(item) - - def is_secret(self, item): - secret_keys = getattr(self, CONNECTION_SECRET_KEYS, []) - return item in secret_keys - - def _set_secret(self, item): - secret_keys = getattr(self, CONNECTION_SECRET_KEYS, []) - secret_keys.append(item) - setattr(self, CONNECTION_SECRET_KEYS, secret_keys) - +OpenAIConnection = OpenAIConnection +AzureOpenAIConnection = AzureOpenAIConnection +AzureContentSafetyConnection = AzureContentSafetyConnection +SerpConnection = SerpConnection +CognitiveSearchConnection = CognitiveSearchConnection +FormRecognizerConnection = FormRecognizerConnection +CustomConnection = CustomConnection -register_connections([v for v in globals().values() if is_dataclass(v) or v is CustomConnection]) +register_connections( + [v for v in globals().values() if is_dataclass(v) or (isinstance(v, type) and issubclass(v, _Connection))] +) \ No newline at end of file From 47bd8c49716957d7d1b3e5a8a203ea4ec3aba58a Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 15:17:42 +0800 Subject: [PATCH 2/8] Fix flake 8 Signed-off-by: Brynn Yin --- src/promptflow/promptflow/_sdk/entities/_connection.py | 1 + src/promptflow/promptflow/connections.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index a94cd9618ff..8894619c353 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -558,6 +558,7 @@ def __getattr__(self, item): logger.warning("Please use connection.configs[key] to access configs.") return self.configs[item] return super().__getattribute__(item) + def is_secret(self, item): # Note: This is added for compatibility with promptflow.connections custom connection usage. return item in self.secrets diff --git a/src/promptflow/promptflow/connections.py b/src/promptflow/promptflow/connections.py index dc017689987..d996647f1d0 100644 --- a/src/promptflow/promptflow/connections.py +++ b/src/promptflow/promptflow/connections.py @@ -34,4 +34,4 @@ class BingConnection: register_connections( [v for v in globals().values() if is_dataclass(v) or (isinstance(v, type) and issubclass(v, _Connection))] -) \ No newline at end of file +) From 2cdae7aba1900a4b16a6314e5b7998876269c99b Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 15:56:17 +0800 Subject: [PATCH 3/8] Update usage Signed-off-by: Brynn Yin --- src/promptflow-tools/promptflow/tools/aoai.py | 4 ++-- src/promptflow-tools/promptflow/tools/embedding.py | 2 +- src/promptflow-tools/promptflow/tools/openai.py | 2 +- .../promptflow/_sdk/entities/_connection.py | 13 ++++++++++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/promptflow-tools/promptflow/tools/aoai.py b/src/promptflow-tools/promptflow/tools/aoai.py index 188cff66095..1ac063af957 100644 --- a/src/promptflow-tools/promptflow/tools/aoai.py +++ b/src/promptflow-tools/promptflow/tools/aoai.py @@ -13,13 +13,13 @@ class AzureOpenAI(ToolProvider): def __init__(self, connection: AzureOpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = dict(self.connection.items()) + self._connection_dict = {**self.connection} def calculate_cache_string_for_completion( self, **kwargs, ) -> str: - d = dict(self.connection.items()) + d = {**self.connection} d.pop("api_key") d.update({**kwargs}) return json.dumps(d) diff --git a/src/promptflow-tools/promptflow/tools/embedding.py b/src/promptflow-tools/promptflow/tools/embedding.py index e383b760c64..33b789f00d2 100644 --- a/src/promptflow-tools/promptflow/tools/embedding.py +++ b/src/promptflow-tools/promptflow/tools/embedding.py @@ -19,7 +19,7 @@ class EmbeddingModel(str, Enum): @handle_openai_error() def embedding(connection: Union[AzureOpenAIConnection, OpenAIConnection], input: str, deployment_name: str = "", model: EmbeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002): - connection_dict = dict(connection.items()) + connection_dict = {**connection} if isinstance(connection, AzureOpenAIConnection): return openai.Embedding.create( input=input, diff --git a/src/promptflow-tools/promptflow/tools/openai.py b/src/promptflow-tools/promptflow/tools/openai.py index 0590f6cd6a0..effaa0e9070 100644 --- a/src/promptflow-tools/promptflow/tools/openai.py +++ b/src/promptflow-tools/promptflow/tools/openai.py @@ -24,7 +24,7 @@ class OpenAI(ToolProvider): def __init__(self, connection: OpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = dict(self.connection.items()) + self._connection_dict = {**self.connection} @tool @handle_openai_error() diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index 8894619c353..5c554bdf0fd 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -3,7 +3,6 @@ # --------------------------------------------------------- import abc import json -from builtins import _dict_items from os import PathLike from pathlib import Path from typing import Dict, Union @@ -96,8 +95,16 @@ def _casting_type(cls, typ): return type_dict.get(typ) return snake_to_camel(typ) - def items(self) -> _dict_items: - return {**self.configs, **self.secrets}.items() + def keys(self): + return list(self.configs.keys()) + list(self.secrets.keys()) + + def __getitem__(self, item): + # Note: This is added to allow usage **connection(). + if item in self.secrets: + return self.secrets[item] + if item in self.configs: + return self.configs[item] + raise KeyError(f"Key {item!r} not found in connection {self.name!r}.") @classmethod def _is_scrubbed_value(cls, value): From 8589a9331dad83dd1eff9f817aff19243b3f0e22 Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 16:36:02 +0800 Subject: [PATCH 4/8] Update workflow Signed-off-by: Brynn Yin --- .github/actions/step_sdk_setup/action.yml | 11 ++++++++++- .github/workflows/tools_tests.yml | 7 +++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/.github/actions/step_sdk_setup/action.yml b/.github/actions/step_sdk_setup/action.yml index 80d7be1c845..a61ede1000d 100644 --- a/.github/actions/step_sdk_setup/action.yml +++ b/.github/actions/step_sdk_setup/action.yml @@ -16,7 +16,16 @@ runs: shell: bash -el {0} run: |- conda activate release-env - pip uninstall -y promptflow promptflow-sdk + pip uninstall -y promptflow promptflow-sdk promptflow-tools + - name: 'Build and install: promptflow-tools' + shell: bash -el {0} + run: |- + conda activate release-env + python ./setup.py bdist_wheel + pip install './dist/promptflow_tools-0.0.1-py3-none-any.whl' + echo "########### pip freeze (After) ###########" + pip freeze + working-directory: src/promptflow-tools - name: 'Build and install: promptflow with extra' if: inputs.setupType == 'promptflow_new_extra' shell: bash -el {0} diff --git a/.github/workflows/tools_tests.yml b/.github/workflows/tools_tests.yml index 1bc90b6990a..1dbd66e04ee 100644 --- a/.github/workflows/tools_tests.yml +++ b/.github/workflows/tools_tests.yml @@ -34,12 +34,15 @@ jobs: if: steps.check_changes.outputs.run_tests == 'true' run: | python -m pip install --upgrade pip - pip install promptflow-sdk promptflow --extra-index-url https://azuremlsdktestpypi.azureedge.net/promptflow/ + pip install promptflow-sdk --extra-index-url https://azuremlsdktestpypi.azureedge.net/promptflow/ pip install pytest pip install pytest_mock pip install azure-identity pip install azure-keyvault-secrets - + echo "Local build and install promptflow" + cd src/promptflow + python ./setup.py bdist_wheel + pip install './dist/promptflow-0.0.1-py3-none-any.whl' - name: Generate configs if: steps.check_changes.outputs.run_tests == 'true' From 57c67c13fadcb38c15b838af78342e80405ecce1 Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 16:50:09 +0800 Subject: [PATCH 5/8] Update Signed-off-by: Brynn Yin --- .github/actions/step_sdk_setup/action.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/actions/step_sdk_setup/action.yml b/.github/actions/step_sdk_setup/action.yml index a61ede1000d..cf97a33de9b 100644 --- a/.github/actions/step_sdk_setup/action.yml +++ b/.github/actions/step_sdk_setup/action.yml @@ -22,7 +22,8 @@ runs: run: |- conda activate release-env python ./setup.py bdist_wheel - pip install './dist/promptflow_tools-0.0.1-py3-none-any.whl' + package=$(ls | grep '.whl') + pip install '$package' echo "########### pip freeze (After) ###########" pip freeze working-directory: src/promptflow-tools From a3cee552cc5edec04329e9c0b2b299637f0d800f Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 16:54:59 +0800 Subject: [PATCH 6/8] Update script Signed-off-by: Brynn Yin --- .github/actions/step_sdk_setup/action.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/actions/step_sdk_setup/action.yml b/.github/actions/step_sdk_setup/action.yml index cf97a33de9b..6a1876761b7 100644 --- a/.github/actions/step_sdk_setup/action.yml +++ b/.github/actions/step_sdk_setup/action.yml @@ -22,8 +22,9 @@ runs: run: |- conda activate release-env python ./setup.py bdist_wheel + cd dist package=$(ls | grep '.whl') - pip install '$package' + eval "pip install '$package'" echo "########### pip freeze (After) ###########" pip freeze working-directory: src/promptflow-tools From 53f2b16927c730140e7c6a41aca81c4d040c2e6a Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Fri, 25 Aug 2023 17:28:10 +0800 Subject: [PATCH 7/8] Fix test Signed-off-by: Brynn Yin --- src/promptflow-tools/promptflow/tools/aoai.py | 4 ++-- src/promptflow-tools/promptflow/tools/embedding.py | 2 +- src/promptflow-tools/promptflow/tools/openai.py | 2 +- .../tests/test_configs/flows/basic-with-connection/hello.py | 4 +--- .../tests/test_configs/flows/openai_api_flow/chat.py | 3 +-- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/promptflow-tools/promptflow/tools/aoai.py b/src/promptflow-tools/promptflow/tools/aoai.py index 1ac063af957..bc9e6f570c6 100644 --- a/src/promptflow-tools/promptflow/tools/aoai.py +++ b/src/promptflow-tools/promptflow/tools/aoai.py @@ -13,13 +13,13 @@ class AzureOpenAI(ToolProvider): def __init__(self, connection: AzureOpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = {**self.connection} + self._connection_dict = dict(self.connection) def calculate_cache_string_for_completion( self, **kwargs, ) -> str: - d = {**self.connection} + d = dict(self.connection) d.pop("api_key") d.update({**kwargs}) return json.dumps(d) diff --git a/src/promptflow-tools/promptflow/tools/embedding.py b/src/promptflow-tools/promptflow/tools/embedding.py index 33b789f00d2..ea8bb4582c9 100644 --- a/src/promptflow-tools/promptflow/tools/embedding.py +++ b/src/promptflow-tools/promptflow/tools/embedding.py @@ -19,7 +19,7 @@ class EmbeddingModel(str, Enum): @handle_openai_error() def embedding(connection: Union[AzureOpenAIConnection, OpenAIConnection], input: str, deployment_name: str = "", model: EmbeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002): - connection_dict = {**connection} + connection_dict = dict(connection) if isinstance(connection, AzureOpenAIConnection): return openai.Embedding.create( input=input, diff --git a/src/promptflow-tools/promptflow/tools/openai.py b/src/promptflow-tools/promptflow/tools/openai.py index effaa0e9070..1fd4d20f7b8 100644 --- a/src/promptflow-tools/promptflow/tools/openai.py +++ b/src/promptflow-tools/promptflow/tools/openai.py @@ -24,7 +24,7 @@ class OpenAI(ToolProvider): def __init__(self, connection: OpenAIConnection): super().__init__() self.connection = connection - self._connection_dict = {**self.connection} + self._connection_dict = dict(self.connection) @tool @handle_openai_error() diff --git a/src/promptflow/tests/test_configs/flows/basic-with-connection/hello.py b/src/promptflow/tests/test_configs/flows/basic-with-connection/hello.py index 49ac214e967..d691b75c71e 100644 --- a/src/promptflow/tests/test_configs/flows/basic-with-connection/hello.py +++ b/src/promptflow/tests/test_configs/flows/basic-with-connection/hello.py @@ -1,8 +1,6 @@ import os -from dataclasses import asdict import openai -from dotenv import load_dotenv from promptflow import tool from promptflow.connections import AzureOpenAIConnection @@ -63,7 +61,7 @@ def my_python_tool( logit_bias=logit_bias if logit_bias else {}, user=user, request_timeout=30, - **asdict(connection), + **dict(connection), ) # get first element because prompt is single. diff --git a/src/promptflow/tests/test_configs/flows/openai_api_flow/chat.py b/src/promptflow/tests/test_configs/flows/openai_api_flow/chat.py index f3ed6c00a78..6fe291663cc 100644 --- a/src/promptflow/tests/test_configs/flows/openai_api_flow/chat.py +++ b/src/promptflow/tests/test_configs/flows/openai_api_flow/chat.py @@ -1,5 +1,4 @@ import openai -from dataclasses import asdict from typing import List from promptflow import tool @@ -26,7 +25,7 @@ def chat(connection: AzureOpenAIConnection, question: str, chat_history: List) - stream=stream, stop=None, max_tokens=16, - **asdict(connection), + **dict(connection), ) if stream: From 6ee808be94a135182e5c3bf48000560b2bad3a03 Mon Sep 17 00:00:00 2001 From: Brynn Yin Date: Mon, 28 Aug 2023 11:12:54 +0800 Subject: [PATCH 8/8] Add comments about changes Signed-off-by: Brynn Yin --- .github/workflows/tools_tests.yml | 9 +++++---- src/promptflow/promptflow/_core/connection_manager.py | 2 +- src/promptflow/promptflow/connections.py | 2 ++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tools_tests.yml b/.github/workflows/tools_tests.yml index 1dbd66e04ee..919009a610a 100644 --- a/.github/workflows/tools_tests.yml +++ b/.github/workflows/tools_tests.yml @@ -30,19 +30,20 @@ jobs: echo "run_tests=false" >> $GITHUB_OUTPUT fi # Eventually only pip install promptflow and uninstall promptflow-sdk + # Install local build promptflow package to ensure some cross-package changes are working - name: Setup if: steps.check_changes.outputs.run_tests == 'true' run: | python -m pip install --upgrade pip pip install promptflow-sdk --extra-index-url https://azuremlsdktestpypi.azureedge.net/promptflow/ - pip install pytest - pip install pytest_mock - pip install azure-identity - pip install azure-keyvault-secrets echo "Local build and install promptflow" cd src/promptflow python ./setup.py bdist_wheel pip install './dist/promptflow-0.0.1-py3-none-any.whl' + pip install pytest + pip install pytest_mock + pip install azure-identity + pip install azure-keyvault-secrets - name: Generate configs if: steps.check_changes.outputs.run_tests == 'true' diff --git a/src/promptflow/promptflow/_core/connection_manager.py b/src/promptflow/promptflow/_core/connection_manager.py index c266a41be1a..2b7efada9e6 100644 --- a/src/promptflow/promptflow/_core/connection_manager.py +++ b/src/promptflow/promptflow/_core/connection_manager.py @@ -60,7 +60,7 @@ def _build_connections(cls, _dict: Dict[str, dict]): because there are some keys just used by UX like resource id, while not used by backend. """ if is_dataclass(connection_class): - # Do not delete this branch, as embedding store connection is dataclass type. + # Do not delete this branch, as promptflow_vectordb.connections is dataclass type. cls_fields = {f.name: f for f in fields(connection_class)} connection_value = connection_class(**{k: v for k, v in value.items() if k in cls_fields}) setattr( diff --git a/src/promptflow/promptflow/connections.py b/src/promptflow/promptflow/connections.py index d996647f1d0..5f8cb8a5cc2 100644 --- a/src/promptflow/promptflow/connections.py +++ b/src/promptflow/promptflow/connections.py @@ -24,6 +24,8 @@ class BingConnection: url: str = "https://api.bing.microsoft.com/v7.0/search" +# We should use unified connection class everywhere. +# Do not add new connection class definition directly here. OpenAIConnection = OpenAIConnection AzureOpenAIConnection = AzureOpenAIConnection AzureContentSafetyConnection = AzureContentSafetyConnection