Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Connection] Unify connection classes from sdk & execution #170

Merged
merged 9 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/actions/step_sdk_setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ runs:
shell: bash -el {0}
run: |-
conda activate release-env
pip uninstall -y promptflow promptflow-sdk
pip uninstall -y promptflow promptflow-sdk promptflow-tools
- name: 'Build and install: promptflow-tools'
shell: bash -el {0}
run: |-
conda activate release-env
python ./setup.py bdist_wheel
cd dist
package=$(ls | grep '.whl')
eval "pip install '$package'"
echo "########### pip freeze (After) ###########"
pip freeze
working-directory: src/promptflow-tools
- name: 'Build and install: promptflow with extra'
if: inputs.setupType == 'promptflow_new_extra'
shell: bash -el {0}
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/tools_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@ jobs:
echo "run_tests=false" >> $GITHUB_OUTPUT
fi
# Eventually only pip install promptflow and uninstall promptflow-sdk
# Install local build promptflow package to ensure some cross-package changes are working
- name: Setup
if: steps.check_changes.outputs.run_tests == 'true'
run: |
python -m pip install --upgrade pip
pip install promptflow-sdk promptflow --extra-index-url https://azuremlsdktestpypi.azureedge.net/promptflow/
pip install promptflow-sdk --extra-index-url https://azuremlsdktestpypi.azureedge.net/promptflow/
brynn-code marked this conversation as resolved.
Show resolved Hide resolved
echo "Local build and install promptflow"
cd src/promptflow
python ./setup.py bdist_wheel
pip install './dist/promptflow-0.0.1-py3-none-any.whl'
pip install pytest
pip install pytest_mock
pip install azure-identity
pip install azure-keyvault-secrets


- name: Generate configs
if: steps.check_changes.outputs.run_tests == 'true'
run: |
Expand Down
5 changes: 2 additions & 3 deletions src/promptflow-tools/promptflow/tools/aoai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from dataclasses import asdict

import openai

Expand All @@ -14,13 +13,13 @@ class AzureOpenAI(ToolProvider):
def __init__(self, connection: AzureOpenAIConnection):
super().__init__()
self.connection = connection
self._connection_dict = asdict(self.connection)
self._connection_dict = dict(self.connection)

def calculate_cache_string_for_completion(
self,
**kwargs,
) -> str:
d = asdict(self.connection)
d = dict(self.connection)
d.pop("api_key")
d.update({**kwargs})
return json.dumps(d)
Expand Down
3 changes: 1 addition & 2 deletions src/promptflow-tools/promptflow/tools/embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import asdict
from enum import Enum
from typing import Union

Expand All @@ -20,7 +19,7 @@ class EmbeddingModel(str, Enum):
@handle_openai_error()
def embedding(connection: Union[AzureOpenAIConnection, OpenAIConnection], input: str, deployment_name: str = "",
model: EmbeddingModel = EmbeddingModel.TEXT_EMBEDDING_ADA_002):
connection_dict = asdict(connection)
connection_dict = dict(connection)
if isinstance(connection, AzureOpenAIConnection):
return openai.Embedding.create(
input=input,
Expand Down
3 changes: 1 addition & 2 deletions src/promptflow-tools/promptflow/tools/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import asdict
from enum import Enum

import openai
Expand All @@ -25,7 +24,7 @@ class OpenAI(ToolProvider):
def __init__(self, connection: OpenAIConnection):
super().__init__()
self.connection = connection
self._connection_dict = asdict(self.connection)
self._connection_dict = dict(self.connection)

@tool
@handle_openai_error()
Expand Down
25 changes: 15 additions & 10 deletions src/promptflow/promptflow/_core/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,27 @@ def _build_connections(cls, _dict: Dict[str, dict]):
from promptflow.connections import CustomConnection

if connection_class is CustomConnection:
connection_value = connection_class(**value)
connection_value.__secret_keys = connection_dict.get("secret_keys", [])
# Note: CustomConnection definition can not be got, secret keys will be provided in connection dict.
setattr(connection_value, CONNECTION_SECRET_KEYS, connection_dict.get("secret_keys", []))
secret_keys = connection_dict.get("secret_keys", [])
secrets = {k: v for k, v in value.items() if k in secret_keys}
configs = {k: v for k, v in value.items() if k not in secrets}
connection_value = connection_class(configs=configs, secrets=secrets)
else:
"""
Note: Ignore non exists keys of connection class,
because there are some keys just used by UX like resource id, while not used by backend.
"""
cls_fields = {f.name: f for f in fields(connection_class)} if is_dataclass(connection_class) else {}
connection_value = connection_class(**{k: v for k, v in value.items() if k in cls_fields})
setattr(
connection_value,
CONNECTION_SECRET_KEYS,
[f.name for f in cls_fields.values() if f.type == Secret],
)
if is_dataclass(connection_class):
# Do not delete this branch, as promptflow_vectordb.connections is dataclass type.
cls_fields = {f.name: f for f in fields(connection_class)}
connection_value = connection_class(**{k: v for k, v in value.items() if k in cls_fields})
setattr(
connection_value,
CONNECTION_SECRET_KEYS,
[f.name for f in cls_fields.values() if f.type == Secret],
)
else:
connection_value = connection_class(**{k: v for k, v in value.items()})
# Use this hack to make sure serialization works
setattr(connection_value, CONNECTION_NAME_PROPERTY, key)
connections[key] = connection_value
Expand Down
95 changes: 93 additions & 2 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ConnectionType,
)
from promptflow._sdk._errors import UnsecureConnectionError
from promptflow._sdk._logger_factory import LoggerFactory
from promptflow._sdk._orm.connection import Connection as ORMConnection
from promptflow._sdk._utils import (
decrypt_secret_value,
Expand All @@ -39,13 +40,15 @@
WeaviateConnectionSchema,
)

logger = LoggerFactory.get_logger(name=__name__)


class _Connection(YAMLTranslatableMixin):
TYPE = ConnectionType._NOT_SET

def __init__(
self,
name,
name: str = "default_connection",
module: str = "promptflow.connections",
configs: Dict[str, str] = None,
secrets: Dict[str, str] = None,
Expand Down Expand Up @@ -92,6 +95,17 @@ def _casting_type(cls, typ):
return type_dict.get(typ)
return snake_to_camel(typ)

def keys(self):
return list(self.configs.keys()) + list(self.secrets.keys())

def __getitem__(self, item):
# Note: This is added to allow usage **connection().
if item in self.secrets:
return self.secrets[item]
if item in self.configs:
return self.configs[item]
raise KeyError(f"Key {item!r} not found in connection {self.name!r}.")

@classmethod
def _is_scrubbed_value(cls, value):
"""For scrubbed value, cli will get original for update, and prompt user to input for create."""
Expand Down Expand Up @@ -276,6 +290,13 @@ def api_key(self, value):


class AzureOpenAIConnection(_StrongTypeConnection):
"""
:param api_key: The api key.
:param api_base: The api base.
:param api_type: The api type, default "azure".
:param api_version: The api version, default "2023-07-01-preview".
:param name: Connection name.
"""
TYPE = ConnectionType.AZURE_OPEN_AI

def __init__(
Expand Down Expand Up @@ -315,6 +336,11 @@ def api_version(self, value):


class OpenAIConnection(_StrongTypeConnection):
"""
:param api_key: The api key.
:param organization: The organization, optional.
:param name: Connection name.
"""
TYPE = ConnectionType.OPEN_AI

def __init__(self, api_key: str, organization: str = None, **kwargs):
Expand All @@ -336,6 +362,10 @@ def organization(self, value):


class SerpConnection(_StrongTypeConnection):
"""
brynn-code marked this conversation as resolved.
Show resolved Hide resolved
:param api_key: The api key.
:param name: Connection name.
"""
TYPE = ConnectionType.SERP

def __init__(self, api_key: str, **kwargs):
Expand Down Expand Up @@ -365,6 +395,11 @@ def api_base(self, value):


class QdrantConnection(_EmbeddingStoreConnection):
"""
:param api_key: The api key.
:param api_base: The api base.
:param name: Connection name.
"""
TYPE = ConnectionType.QDRANT

@classmethod
Expand All @@ -373,6 +408,11 @@ def _get_schema_cls(cls):


class WeaviateConnection(_EmbeddingStoreConnection):
"""
:param api_key: The api key.
:param api_base: The api base.
:param name: Connection name.
"""
TYPE = ConnectionType.WEAVIATE

@classmethod
Expand All @@ -381,6 +421,12 @@ def _get_schema_cls(cls):


class CognitiveSearchConnection(_StrongTypeConnection):
"""
:param api_key: The api key.
:param api_base: The api base.
:param api_version: The api version, default "2023-07-01-Preview".
:param name: Connection name.
"""
TYPE = ConnectionType.COGNITIVE_SEARCH

def __init__(self, api_key: str, api_base: str, api_version: str = "2023-07-01-Preview", **kwargs):
Expand Down Expand Up @@ -410,6 +456,13 @@ def api_version(self, value):


class AzureContentSafetyConnection(_StrongTypeConnection):
"""
:param api_key: The api key.
:param endpoint: The api endpoint.
:param api_version: The api version, default "2023-04-30-preview".
:param api_type: The api type, default "Content Safety".
:param name: Connection name.
"""
TYPE = ConnectionType.AZURE_CONTENT_SAFETY

def __init__(
Expand Down Expand Up @@ -454,6 +507,13 @@ def api_type(self, value):


class FormRecognizerConnection(AzureContentSafetyConnection):
"""
:param api_key: The api key.
:param endpoint: The api endpoint.
:param api_version: The api version, default "2023-07-31".
:param api_type: The api type, default "Form Recognizer".
:param name: Connection name.
"""
# Note: FormRecognizer and ContentSafety are using CognitiveService type in ARM, so keys are the same.
TYPE = ConnectionType.FORM_RECOGNIZER

Expand All @@ -468,17 +528,48 @@ def _get_schema_cls(cls):


class CustomConnection(_Connection):
"""
:param configs: The configs kv pairs.
:param secrets: The secrets kv pairs.
:param name: Connection name
"""
TYPE = ConnectionType.CUSTOM

def __init__(self, secrets: Dict[str, str], configs: Dict[str, str] = None, **kwargs):
if not secrets:
raise ValueError("secrets is required for custom connection.")
raise ValueError(
"Secrets is required for custom connection, "
"please use CustomConnection(configs={key1: val1}, secrets={key2: val2}) "
"to initialize custom connection."
)
super().__init__(secrets=secrets, configs=configs, **kwargs)

@classmethod
def _get_schema_cls(cls):
return CustomConnectionSchema

def __getattr__(self, item):
# Note: This is added for compatibility with promptflow.connections custom connection usage.
if item == "secrets":
# Usually obj.secrets will not reach here
# This is added to handle copy.deepcopy loop issue
return super().__getattribute__("secrets")
if item == "configs":
# Usually obj.configs will not reach here
# This is added to handle copy.deepcopy loop issue
return super().__getattribute__("configs")
if item in self.secrets:
logger.warning("Please use connection.secrets[key] to access secrets.")
return self.secrets[item]
if item in self.configs:
logger.warning("Please use connection.configs[key] to access configs.")
return self.configs[item]
return super().__getattribute__(item)

def is_secret(self, item):
# Note: This is added for compatibility with promptflow.connections custom connection usage.
return item in self.secrets

def _to_orm_object(self):
# Both keys & secrets will be set in custom configs with value type specified for custom connection.
custom_configs = {
Expand Down
Loading
Loading