Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK][Connection] Support serverless connection type #2080

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_sdk/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ class ConnectionType(str, Enum):
AZURE_CONTENT_SAFETY = "AzureContentSafety"
FORM_RECOGNIZER = "FormRecognizer"
WEAVIATE = "Weaviate"
SERVERLESS = "Serverless"
CUSTOM = "Custom"


Expand Down
38 changes: 36 additions & 2 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ConnectionType,
CustomStrongTypeConnectionConfigs,
)
from promptflow._sdk._errors import UnsecureConnectionError, SDKError
from promptflow._sdk._errors import SDKError, UnsecureConnectionError
from promptflow._sdk._orm.connection import Connection as ORMConnection
from promptflow._sdk._utils import (
decrypt_secret_value,
Expand All @@ -43,11 +43,12 @@
OpenAIConnectionSchema,
QdrantConnectionSchema,
SerpConnectionSchema,
ServerlessConnectionSchema,
WeaviateConnectionSchema,
)
from promptflow._utils.logger_utils import LoggerFactory
from promptflow.contracts.types import Secret
from promptflow.exceptions import ValidationException, UserErrorException
from promptflow.exceptions import UserErrorException, ValidationException

logger = LoggerFactory.get_logger(name=__name__)
PROMPTFLOW_CONNECTIONS = "promptflow.connections"
Expand Down Expand Up @@ -461,6 +462,39 @@ def base_url(self, value):
self.configs["base_url"] = value


class ServerlessConnection(_StrongTypeConnection):
"""Serverless connection.

:param api_key: The api key.
:type api_key: str
:param api_base: The api base.
:type api_base: str
:param name: Connection name.
:type name: str
"""

TYPE = ConnectionType.SERVERLESS

def __init__(self, api_key: str, api_base: str, **kwargs):
secrets = {"api_key": api_key}
configs = {"api_base": api_base}
super().__init__(secrets=secrets, configs=configs, **kwargs)

@classmethod
def _get_schema_cls(cls):
return ServerlessConnectionSchema

@property
def api_base(self):
"""Return the connection api base."""
return self.configs.get("api_base")

@api_base.setter
def api_base(self, value):
"""Set the connection api base."""
self.configs["api_base"] = value


class SerpConnection(_StrongTypeConnection):
"""Serp connection.

Expand Down
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_sdk/schemas/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class OpenAIConnectionSchema(ConnectionSchema):
base_url = fields.Str()


class ServerlessConnectionSchema(ConnectionSchema):
type = StringTransformedEnum(allowed_values=camel_to_snake(ConnectionType.SERVERLESS), required=True)
api_key = fields.Str(required=True)
api_base = fields.Str(required=True)


class EmbeddingStoreConnectionSchema(ConnectionSchema):
module = fields.Str(dump_default="promptflow_vectordb.connections")
api_key = fields.Str(required=True)
Expand Down
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_trace/_start_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def start_trace(*, session: typing.Optional[str] = None, **kwargs):
# honor and set attributes if user has specified
attributes: dict = kwargs.get("attributes", None)
if attributes is not None:
_logger.debug("User specified attributes: %s", attributes)
for attr_key, attr_value in attributes.items():
operation_context._add_otel_attributes(attr_key, attr_value)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from enum import Enum
from typing import Any, Dict, Union

import requests
Expand Down Expand Up @@ -31,11 +30,14 @@
FLOW_META_PREFIX = "azureml.flow."


class ConnectionCategory(str, Enum):
class ConnectionCategory:
AzureOpenAI = "AzureOpenAI"
CognitiveSearch = "CognitiveSearch"
CognitiveService = "CognitiveService"
CustomKeys = "CustomKeys"
OpenAI = "OpenAI"
Serp = "Serp"
Serverless = "Serverless"


def get_case_insensitive_key(d, key, default=None):
Expand Down Expand Up @@ -120,10 +122,14 @@ def open_url(cls, token, url, action, host="management.azure.com", method="GET",
def validate_and_fallback_connection_type(cls, name, type_name, category, metadata):
if type_name:
return type_name
if category == ConnectionCategory.AzureOpenAI:
return "AzureOpenAI"
if category == ConnectionCategory.CognitiveSearch:
return "CognitiveSearch"
# Below category has corresponding connection type in PromptFlow, so we can fall back directly.
# Note: CustomKeys may store different connection types for now, e.g. openai, serp.
if category in [
ConnectionCategory.AzureOpenAI,
ConnectionCategory.CognitiveSearch,
ConnectionCategory.Serverless,
]:
return category
if category == ConnectionCategory.CognitiveService:
kind = get_case_insensitive_key(metadata, "Kind")
if kind == "Content Safety":
Expand Down Expand Up @@ -191,6 +197,11 @@ def build_connection_dict_from_rest_object(cls, name, obj) -> dict:
"api_base": properties.target,
"api_version": get_case_insensitive_key(properties.metadata, "ApiVersion"),
}
elif properties.category == ConnectionCategory.Serverless:
value = {
"api_key": properties.credentials.key,
"api_base": properties.target,
}
elif properties.category == ConnectionCategory.CognitiveService:
value = {
"api_key": properties.credentials.key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,28 @@ def test_build_connection_unknown_category():
with pytest.raises(Exception) as e:
build_from_data_and_assert(data, {})
assert "Unknown connection mock category Unknown" in str(e.value)


@pytest.mark.unittest
def test_build_serverless_category_connection_from_rest_object():
data = {
"id": "mock_id",
"name": "test_serverless_connection",
"type": "Microsoft.MachineLearningServices/workspaces/connections",
"properties": {
"authType": "ApiKey",
"credentials": {"key": "***"},
"group": "AzureAI",
"category": "Serverless",
"expiryTime": None,
"target": "mock_base",
"sharedUserList": [],
"metadata": {},
},
}
expected = {
"type": "ServerlessConnection",
"module": "promptflow.connections",
"value": {"api_key": "***", "api_base": "mock_base"},
}
build_from_data_and_assert(data, expected)
14 changes: 14 additions & 0 deletions src/promptflow/tests/sdk_cli_test/unittests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
OpenAIConnection,
QdrantConnection,
SerpConnection,
ServerlessConnection,
WeaviateConnection,
_Connection,
)
Expand Down Expand Up @@ -170,6 +171,19 @@ class TestConnection:
"type": "weaviate",
},
),
(
"serverless_connection.yaml",
ServerlessConnection,
{
"name": "my_serverless_connection",
"api_key": "<to-be-replaced>",
"api_base": "https://mock.api.base",
},
{
"module": "promptflow.connections",
"type": "serverless",
},
),
],
)
def test_connection_load_dump(self, file_name, class_name, init_param, expected):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
$schema: https://azuremlschemas.azureedge.net/promptflow/latest/ServerlessConnection.schema.json
name: my_serverless_connection
type: serverless
api_key: "<to-be-replaced>"
api_base: "https://mock.api.base"
Loading