Skip to content

Commit

Permalink
Remove default connection name
Browse files Browse the repository at this point in the history
Signed-off-by: Brynn Yin <biyi@microsoft.com>
  • Loading branch information
brynn-code committed Mar 6, 2024
1 parent 343d29e commit 7cbc26c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
6 changes: 6 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def __init__(self, env_vars: list, cls_name: str):
super().__init__(f"Required environment variables {env_vars} to build {cls_name} not set.")


class ConnectionNameNotSetError(SDKError):
"""Exception raised if connection not set when create or update."""

pass


class InvalidRunError(SDKError):
"""Exception raised if run name is not legal."""

Expand Down
24 changes: 19 additions & 5 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class _Connection(YAMLTranslatableMixin):

def __init__(
self,
name: str = "default_connection",
name: str = None,
module: str = "promptflow.connections",
configs: Dict[str, str] = None,
secrets: Dict[str, str] = None,
Expand Down Expand Up @@ -439,8 +439,15 @@ def get_token(self):
return self._token_provider.get_token()

@classmethod
def from_env(cls, name="default_env_connection"):
"""Build connection from environment variables."""
def from_env(cls, name=None):
"""
Build connection from environment variables.
Relevant environment variables:
- AZURE_OPENAI_ENDPOINT: The api base.
- AZURE_OPENAI_API_KEY: The api key.
- OPENAI_API_VERSION: Optional. The api version, default "2023-07-01-preview".
"""
# Env var name reference: https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L160
api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
Expand Down Expand Up @@ -503,8 +510,15 @@ def base_url(self, value):
self.configs["base_url"] = value

@classmethod
def from_env(cls, name="default_env_connection"):
"""Build connection from environment variables."""
def from_env(cls, name=None):
"""
Build connection from environment variables.
Relevant environment variables:
- OPENAI_API_KEY: The api key.
- OPENAI_ORG_ID: Optional. The unique identifier for your organization which can be used in API requests.
- OPENAI_BASE_URL: Optional. Specify when use customized api base, leave None to use OpenAI default api base.
"""
# Env var name reference: https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L92
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List

from promptflow._sdk._constants import MAX_LIST_CLI_RESULTS
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._orm import Connection as ORMConnection
from promptflow._sdk._telemetry import ActivityType, TelemetryMixin, monitor_operation
from promptflow._sdk._utils import safe_parse_object_list
Expand Down Expand Up @@ -76,6 +77,8 @@ def create_or_update(self, connection: _Connection, **kwargs):
:param connection: Run object to create or update.
:type connection: ~promptflow.sdk.entities._connection._Connection
"""
if not connection.name:
raise ConnectionNameNotSetError("Name is required to create or update connection.")
orm_object = connection._to_orm_object()
now = datetime.now().isoformat()
if orm_object.createdDate is None:
Expand Down
11 changes: 10 additions & 1 deletion src/promptflow/tests/sdk_cli_test/e2etests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import uuid
from pathlib import Path

import pydash
import pytest
from mock import mock

from promptflow._sdk._constants import SCRUBBED_VALUE
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._pf_client import PFClient
from promptflow._sdk.entities import AzureOpenAIConnection, CustomConnection
from promptflow._sdk.entities import AzureOpenAIConnection, CustomConnection, OpenAIConnection

_client = PFClient()

Expand Down Expand Up @@ -118,3 +121,9 @@ def test_upsert_connection_from_file(self, file_name, expected_updated_item, exp
), "Assert secrets not updated failed, expected: {}, actual: {}".format(
expected_secret_item[1], result._secrets[expected_secret_item[0]]
)

def test_create_connection_no_name(self):
with mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
connection = OpenAIConnection.from_env()
with pytest.raises(ConnectionNameNotSetError):
_client.connections.create_or_update(connection)

0 comments on commit 7cbc26c

Please sign in to comment.