diff --git a/src/promptflow/promptflow/_sdk/_constants.py b/src/promptflow/promptflow/_sdk/_constants.py index 1fd4fa62f45..5884ccd1437 100644 --- a/src/promptflow/promptflow/_sdk/_constants.py +++ b/src/promptflow/promptflow/_sdk/_constants.py @@ -337,6 +337,7 @@ class ConnectionType(str, Enum): AZURE_CONTENT_SAFETY = "AzureContentSafety" FORM_RECOGNIZER = "FormRecognizer" WEAVIATE = "Weaviate" + SERVERLESS = "Serverless" CUSTOM = "Custom" diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index c28e7176a53..10705e6bab8 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -22,7 +22,7 @@ ConnectionType, CustomStrongTypeConnectionConfigs, ) -from promptflow._sdk._errors import UnsecureConnectionError, SDKError +from promptflow._sdk._errors import SDKError, UnsecureConnectionError from promptflow._sdk._orm.connection import Connection as ORMConnection from promptflow._sdk._utils import ( decrypt_secret_value, @@ -43,11 +43,12 @@ OpenAIConnectionSchema, QdrantConnectionSchema, SerpConnectionSchema, + ServerlessConnectionSchema, WeaviateConnectionSchema, ) from promptflow._utils.logger_utils import LoggerFactory from promptflow.contracts.types import Secret -from promptflow.exceptions import ValidationException, UserErrorException +from promptflow.exceptions import UserErrorException, ValidationException logger = LoggerFactory.get_logger(name=__name__) PROMPTFLOW_CONNECTIONS = "promptflow.connections" @@ -461,6 +462,39 @@ def base_url(self, value): self.configs["base_url"] = value +class ServerlessConnection(_StrongTypeConnection): + """Serverless connection. + + :param api_key: The api key. + :type api_key: str + :param api_base: The api base. + :type api_base: str + :param name: Connection name. + :type name: str + """ + + TYPE = ConnectionType.SERVERLESS + + def __init__(self, api_key: str, api_base: str, **kwargs): + secrets = {"api_key": api_key} + configs = {"api_base": api_base} + super().__init__(secrets=secrets, configs=configs, **kwargs) + + @classmethod + def _get_schema_cls(cls): + return ServerlessConnectionSchema + + @property + def api_base(self): + """Return the connection api base.""" + return self.configs.get("api_base") + + @api_base.setter + def api_base(self, value): + """Set the connection api base.""" + self.configs["api_base"] = value + + class SerpConnection(_StrongTypeConnection): """Serp connection. diff --git a/src/promptflow/promptflow/_sdk/schemas/_connection.py b/src/promptflow/promptflow/_sdk/schemas/_connection.py index a5e9e9f4b53..31c6faf38a5 100644 --- a/src/promptflow/promptflow/_sdk/schemas/_connection.py +++ b/src/promptflow/promptflow/_sdk/schemas/_connection.py @@ -61,6 +61,12 @@ class OpenAIConnectionSchema(ConnectionSchema): base_url = fields.Str() +class ServerlessConnectionSchema(ConnectionSchema): + type = StringTransformedEnum(allowed_values=camel_to_snake(ConnectionType.SERVERLESS), required=True) + api_key = fields.Str(required=True) + api_base = fields.Str(required=True) + + class EmbeddingStoreConnectionSchema(ConnectionSchema): module = fields.Str(dump_default="promptflow_vectordb.connections") api_key = fields.Str(required=True) diff --git a/src/promptflow/promptflow/_trace/_start_trace.py b/src/promptflow/promptflow/_trace/_start_trace.py index 49ac4d2cf2b..44ec2443c23 100644 --- a/src/promptflow/promptflow/_trace/_start_trace.py +++ b/src/promptflow/promptflow/_trace/_start_trace.py @@ -51,6 +51,7 @@ def start_trace(*, session: typing.Optional[str] = None, **kwargs): # honor and set attributes if user has specified attributes: dict = kwargs.get("attributes", None) if attributes is not None: + _logger.debug("User specified attributes: %s", attributes) for attr_key, attr_value in attributes.items(): operation_context._add_otel_attributes(attr_key, attr_value) diff --git a/src/promptflow/promptflow/azure/operations/_arm_connection_operations.py b/src/promptflow/promptflow/azure/operations/_arm_connection_operations.py index b96adbc0bb7..7478fcdb0e9 100644 --- a/src/promptflow/promptflow/azure/operations/_arm_connection_operations.py +++ b/src/promptflow/promptflow/azure/operations/_arm_connection_operations.py @@ -1,7 +1,6 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from enum import Enum from typing import Any, Dict, Union import requests @@ -31,11 +30,14 @@ FLOW_META_PREFIX = "azureml.flow." -class ConnectionCategory(str, Enum): +class ConnectionCategory: AzureOpenAI = "AzureOpenAI" CognitiveSearch = "CognitiveSearch" CognitiveService = "CognitiveService" CustomKeys = "CustomKeys" + OpenAI = "OpenAI" + Serp = "Serp" + Serverless = "Serverless" def get_case_insensitive_key(d, key, default=None): @@ -120,10 +122,14 @@ def open_url(cls, token, url, action, host="management.azure.com", method="GET", def validate_and_fallback_connection_type(cls, name, type_name, category, metadata): if type_name: return type_name - if category == ConnectionCategory.AzureOpenAI: - return "AzureOpenAI" - if category == ConnectionCategory.CognitiveSearch: - return "CognitiveSearch" + # Below category has corresponding connection type in PromptFlow, so we can fall back directly. + # Note: CustomKeys may store different connection types for now, e.g. openai, serp. + if category in [ + ConnectionCategory.AzureOpenAI, + ConnectionCategory.CognitiveSearch, + ConnectionCategory.Serverless, + ]: + return category if category == ConnectionCategory.CognitiveService: kind = get_case_insensitive_key(metadata, "Kind") if kind == "Content Safety": @@ -191,6 +197,11 @@ def build_connection_dict_from_rest_object(cls, name, obj) -> dict: "api_base": properties.target, "api_version": get_case_insensitive_key(properties.metadata, "ApiVersion"), } + elif properties.category == ConnectionCategory.Serverless: + value = { + "api_key": properties.credentials.key, + "api_base": properties.target, + } elif properties.category == ConnectionCategory.CognitiveService: value = { "api_key": properties.credentials.key, diff --git a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_arm_connection_build.py b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_arm_connection_build.py index 61aa0388897..545292968f5 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/unittests/test_arm_connection_build.py +++ b/src/promptflow/tests/sdk_cli_azure_test/unittests/test_arm_connection_build.py @@ -214,3 +214,28 @@ def test_build_connection_unknown_category(): with pytest.raises(Exception) as e: build_from_data_and_assert(data, {}) assert "Unknown connection mock category Unknown" in str(e.value) + + +@pytest.mark.unittest +def test_build_serverless_category_connection_from_rest_object(): + data = { + "id": "mock_id", + "name": "test_serverless_connection", + "type": "Microsoft.MachineLearningServices/workspaces/connections", + "properties": { + "authType": "ApiKey", + "credentials": {"key": "***"}, + "group": "AzureAI", + "category": "Serverless", + "expiryTime": None, + "target": "mock_base", + "sharedUserList": [], + "metadata": {}, + }, + } + expected = { + "type": "ServerlessConnection", + "module": "promptflow.connections", + "value": {"api_key": "***", "api_base": "mock_base"}, + } + build_from_data_and_assert(data, expected) diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py b/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py index 94202712cc7..2f8adc1ac82 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py @@ -18,6 +18,7 @@ OpenAIConnection, QdrantConnection, SerpConnection, + ServerlessConnection, WeaviateConnection, _Connection, ) @@ -170,6 +171,19 @@ class TestConnection: "type": "weaviate", }, ), + ( + "serverless_connection.yaml", + ServerlessConnection, + { + "name": "my_serverless_connection", + "api_key": "", + "api_base": "https://mock.api.base", + }, + { + "module": "promptflow.connections", + "type": "serverless", + }, + ), ], ) def test_connection_load_dump(self, file_name, class_name, init_param, expected): diff --git a/src/promptflow/tests/test_configs/connections/serverless_connection.yaml b/src/promptflow/tests/test_configs/connections/serverless_connection.yaml new file mode 100644 index 00000000000..f85559132ae --- /dev/null +++ b/src/promptflow/tests/test_configs/connections/serverless_connection.yaml @@ -0,0 +1,5 @@ +$schema: https://azuremlschemas.azureedge.net/promptflow/latest/ServerlessConnection.schema.json +name: my_serverless_connection +type: serverless +api_key: "" +api_base: "https://mock.api.base" \ No newline at end of file