Skip to content

Commit

Permalink
[Connection] Support aoai & oai connection from_env (#2219)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <biyi@microsoft.com>
  • Loading branch information
brynn-code authored Mar 6, 2024
1 parent f540a84 commit 7bd946e
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 5 deletions.
17 changes: 17 additions & 0 deletions docs/how-to-guides/manage-connections.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ On the VS Code primary sidebar > prompt flow pane. You can find the connections
:::
::::

## Load from environment variables
With `promptflow>=1.7.0`, user is able to load a connection object from os environment variables with `<ConnectionType>.from_env` func.
Note that the connection object will **NOT BE CREATED** to local database.

Supported types are as follows:

| Connection Type | Field | Relevant Environment Variable |
|-----------------------| --- |--------------------------------------------------|
| OpenAIConnection | api_key | OPENAI_API_KEY |
| | organization | OPENAI_ORG_ID |
| | base_url | OPENAI_BASE_URL |
| AzureOpenAIConnection | api_key | AZURE_OPENAI_API_KEY |
| | api_base | AZURE_OPENAI_ENDPOINT |
| | api_version | OPENAI_API_VERSION |

For example, with `OPENAI_API_KEY` set to environment, an `OpenAIConnection` object can be loaded with `OpenAIConnection.from_env()`.


## Next steps
- Reach more detail about [connection concepts](../../concepts/concept-connections.md).
Expand Down
2 changes: 1 addition & 1 deletion examples/flows/standard/basic/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_client():
from openai import AzureOpenAI as Client
conn.update(
azure_endpoint=os.environ["AZURE_OPENAI_API_BASE"],
api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2023-07-01-preview"),
api_version=os.environ.get("OPENAI_API_VERSION", "2023-07-01-preview"),
)
return Client(**conn)

Expand Down
2 changes: 2 additions & 0 deletions src/promptflow/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
- CLI: Support `pfazure run create --resume-from <original-run-name>` to create a run resume from another run.
- SDK: Support `pf.run(resume_from=<original-run-name>)` to create a run resume from another run.

- [SDK/CLI] Support `AzureOpenAIConnection.from_env` and `OpenAIConnection.from_env`. Reach more details [here](https://microsoft.github.io/promptflow/how-to-guides/manage-connections.html#load-from-environment-variables).

## 1.6.0 (2024.03.01)

### Features Added
Expand Down
13 changes: 13 additions & 0 deletions src/promptflow/promptflow/_sdk/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ class ConnectionNotFoundError(SDKError):
pass


class RequiredEnvironmentVariablesNotSetError(SDKError):
"""Exception raised if connection from_env required env vars not found."""

def __init__(self, env_vars: list, cls_name: str):
super().__init__(f"Required environment variables {env_vars} to build {cls_name} not set.")


class ConnectionNameNotSetError(SDKError):
"""Exception raised if connection not set when create or update."""

pass


class InvalidRunError(SDKError):
"""Exception raised if run name is not legal."""

Expand Down
46 changes: 44 additions & 2 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import abc
import importlib
import json
import os
import types
from os import PathLike
from pathlib import Path
Expand All @@ -25,7 +26,7 @@
ConnectionType,
CustomStrongTypeConnectionConfigs,
)
from promptflow._sdk._errors import SDKError, UnsecureConnectionError
from promptflow._sdk._errors import RequiredEnvironmentVariablesNotSetError, SDKError, UnsecureConnectionError
from promptflow._sdk._orm.connection import Connection as ORMConnection
from promptflow._sdk._utils import (
decrypt_secret_value,
Expand Down Expand Up @@ -76,7 +77,7 @@ class _Connection(YAMLTranslatableMixin):

def __init__(
self,
name: str = "default_connection",
name: str = None,
module: str = "promptflow.connections",
configs: Dict[str, str] = None,
secrets: Dict[str, str] = None,
Expand Down Expand Up @@ -437,6 +438,29 @@ def get_token(self):

return self._token_provider.get_token()

@classmethod
def from_env(cls, name=None):
"""
Build connection from environment variables.
Relevant environment variables:
- AZURE_OPENAI_ENDPOINT: The api base.
- AZURE_OPENAI_API_KEY: The api key.
- OPENAI_API_VERSION: Optional. The api version, default "2023-07-01-preview".
"""
# Env var name reference: https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L160
api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
api_key = os.getenv("AZURE_OPENAI_API_KEY")
# Note: Name OPENAI_API_VERSION from OpenAI.
api_version = os.getenv("OPENAI_API_VERSION")
if api_base is None or api_key is None:
raise RequiredEnvironmentVariablesNotSetError(
env_vars=["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY"], cls_name=cls.__name__
)
# Note: Do not pass api_version None when init class object as we have default version.
optional_args = {"api_version": api_version} if api_version else {}
return cls(api_base=api_base, api_key=api_key, name=name, **optional_args)


class OpenAIConnection(_StrongTypeConnection):
"""Open AI connection.
Expand Down Expand Up @@ -485,6 +509,24 @@ def base_url(self, value):
"""Set the connection api base."""
self.configs["base_url"] = value

@classmethod
def from_env(cls, name=None):
"""
Build connection from environment variables.
Relevant environment variables:
- OPENAI_API_KEY: The api key.
- OPENAI_ORG_ID: Optional. The unique identifier for your organization which can be used in API requests.
- OPENAI_BASE_URL: Optional. Specify when use customized api base, leave None to use OpenAI default api base.
"""
# Env var name reference: https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L92
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")
organization = os.getenv("OPENAI_ORG_ID")
if api_key is None:
raise RequiredEnvironmentVariablesNotSetError(env_vars=["OPENAI_API_KEY"], cls_name=cls.__name__)
return cls(api_key=api_key, organization=organization, base_url=base_url, name=name)


class ServerlessConnection(_StrongTypeConnection):
"""Serverless connection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List

from promptflow._sdk._constants import MAX_LIST_CLI_RESULTS
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._orm import Connection as ORMConnection
from promptflow._sdk._telemetry import ActivityType, TelemetryMixin, monitor_operation
from promptflow._sdk._utils import safe_parse_object_list
Expand Down Expand Up @@ -76,6 +77,8 @@ def create_or_update(self, connection: _Connection, **kwargs):
:param connection: Run object to create or update.
:type connection: ~promptflow.sdk.entities._connection._Connection
"""
if not connection.name:
raise ConnectionNameNotSetError("Name is required to create or update connection.")
orm_object = connection._to_orm_object()
now = datetime.now().isoformat()
if orm_object.createdDate is None:
Expand Down
11 changes: 10 additions & 1 deletion src/promptflow/tests/sdk_cli_test/e2etests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import uuid
from pathlib import Path

import pydash
import pytest
from mock import mock

from promptflow._sdk._constants import SCRUBBED_VALUE
from promptflow._sdk._errors import ConnectionNameNotSetError
from promptflow._sdk._pf_client import PFClient
from promptflow._sdk.entities import AzureOpenAIConnection, CustomConnection
from promptflow._sdk.entities import AzureOpenAIConnection, CustomConnection, OpenAIConnection

_client = PFClient()

Expand Down Expand Up @@ -118,3 +121,9 @@ def test_upsert_connection_from_file(self, file_name, expected_updated_item, exp
), "Assert secrets not updated failed, expected: {}, actual: {}".format(
expected_secret_item[1], result._secrets[expected_secret_item[0]]
)

def test_create_connection_no_name(self):
with mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
connection = OpenAIConnection.from_env()
with pytest.raises(ConnectionNameNotSetError):
_client.connections.create_or_update(connection)
48 changes: 47 additions & 1 deletion src/promptflow/tests/sdk_cli_test/unittests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os
from pathlib import Path
from unittest.mock import patch

import mock
import pytest

from promptflow._cli._pf._connection import validate_and_interactive_get_secrets
from promptflow._sdk._constants import SCRUBBED_VALUE, ConnectionAuthMode, CustomStrongTypeConnectionConfigs
from promptflow._sdk._errors import SDKError
from promptflow._sdk._errors import RequiredEnvironmentVariablesNotSetError, SDKError
from promptflow._sdk._load_functions import _load_env_to_connection
from promptflow._sdk.entities._connection import (
AzureContentSafetyConnection,
Expand Down Expand Up @@ -420,3 +422,47 @@ def test_convert_to_custom_strong_type(self, install_custom_tool_pkg):
match=r".*Failed to convert to custom strong type connection because of invalid module or class*",
):
connection._convert_to_custom_strong_type(module=module_name, to_class=custom_conn_type)

def test_connection_from_env(self):
with pytest.raises(RequiredEnvironmentVariablesNotSetError) as e:
AzureOpenAIConnection.from_env()
assert "to build AzureOpenAIConnection not set" in str(e.value)

with pytest.raises(RequiredEnvironmentVariablesNotSetError) as e:
OpenAIConnection.from_env()
assert "to build OpenAIConnection not set" in str(e.value)

# Happy path
# AzureOpenAI
with mock.patch.dict(
os.environ,
{
"AZURE_OPENAI_ENDPOINT": "test_endpoint",
"AZURE_OPENAI_API_KEY": "test_key",
"OPENAI_API_VERSION": "2024-01-01-preview",
},
):
connection = AzureOpenAIConnection.from_env("test_connection")
assert connection._to_dict() == {
"name": "test_connection",
"module": "promptflow.connections",
"type": "azure_open_ai",
"api_base": "test_endpoint",
"api_key": "test_key",
"api_type": "azure",
"api_version": "2024-01-01-preview",
"auth_mode": "key",
}
# OpenAI
with mock.patch.dict(
os.environ, {"OPENAI_API_KEY": "test_key", "OPENAI_BASE_URL": "test_base", "OPENAI_ORG_ID": "test_org"}
):
connection = OpenAIConnection.from_env("test_connection")
assert connection._to_dict() == {
"name": "test_connection",
"module": "promptflow.connections",
"type": "open_ai",
"api_key": "test_key",
"organization": "test_org",
"base_url": "test_base",
}

0 comments on commit 7bd946e

Please sign in to comment.