From dc87e66aaa1110ae25e6153df6f9b6c0f30276b5 Mon Sep 17 00:00:00 2001
From: awtkns <32209255+awtkns@users.noreply.github.com>
Date: Sun, 30 Jul 2023 18:23:34 -0700
Subject: [PATCH] =?UTF-8?q?=F0=9F=99=82=20Azure=20OpenAI?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../schemas/workflow/blocks/summary_agent.py  |  2 +-
 .../services/tokenizer/token_service.py       |  2 +-
 platform/reworkd_platform/settings.py         | 17 +++++
 .../tests/agent/create_model_test.py          |  2 +-
 .../reworkd_platform/tests/test_settings.py   | 26 +++++++
 .../agent_service/agent_service_provider.py   |  5 +-
 .../agent_service/open_ai_agent_service.py    |  2 +-
 .../web/api/agent/model_factory.py            | 72 +++++++++++++++++++
 .../web/api/agent/model_settings.py           | 36 ----------
 .../reworkd_platform/web/api/agent/views.py   |  8 ++-
 10 files changed, 128 insertions(+), 44 deletions(-)
 create mode 100644 platform/reworkd_platform/web/api/agent/model_factory.py
 delete mode 100644 platform/reworkd_platform/web/api/agent/model_settings.py

diff --git a/platform/reworkd_platform/schemas/workflow/blocks/summary_agent.py b/platform/reworkd_platform/schemas/workflow/blocks/summary_agent.py
index d995713f65..3e96027ec6 100644
--- a/platform/reworkd_platform/schemas/workflow/blocks/summary_agent.py
+++ b/platform/reworkd_platform/schemas/workflow/blocks/summary_agent.py
@@ -18,7 +18,7 @@
 from reworkd_platform.schemas.workflow.base import Block, BlockIOBase
 from reworkd_platform.services.aws.s3 import SimpleStorageService
 from reworkd_platform.settings import settings
-from reworkd_platform.web.api.agent.model_settings import create_model
+from reworkd_platform.web.api.agent.model_factory import create_model
 
 
 class SummaryAgentInput(BlockIOBase):
diff --git a/platform/reworkd_platform/services/tokenizer/token_service.py b/platform/reworkd_platform/services/tokenizer/token_service.py
index 25fc18563f..b457c5f0a1 100644
--- a/platform/reworkd_platform/services/tokenizer/token_service.py
+++ b/platform/reworkd_platform/services/tokenizer/token_service.py
@@ -1,7 +1,7 @@
 from tiktoken import Encoding, get_encoding
 
 from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
-from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
 
 
 class TokenService:
diff --git a/platform/reworkd_platform/settings.py b/platform/reworkd_platform/settings.py
index cb39a3e087..b6f8186a38 100644
--- a/platform/reworkd_platform/settings.py
+++ b/platform/reworkd_platform/settings.py
@@ -56,6 +56,12 @@ class Settings(BaseSettings):
     openai_api_key: str = ""
     secondary_openai_api_key: Optional[str] = None
 
+    # Azure OpenAI
+    azure_openai_api_version: str = "2023-06-01-preview"
+    azure_openai_api_key: str = ""
+    azure_openai_api_base: str = ""
+    azure_openai_deployment_name: str = ""
+
     replicate_api_key: Optional[str] = None
     serp_api_key: Optional[str] = None
     scrapingbee_api_key: Optional[str] = None
@@ -154,6 +160,17 @@ def kafka_enabled(self) -> bool:
             ]
         )
 
+    @property
+    def azure_openai_enabled(self) -> bool:
+        return all(
+            [
+                self.azure_openai_api_base,
+                self.azure_openai_deployment_name,
+                self.azure_openai_api_version,
+                self.azure_openai_api_key,
+            ]
+        )
+
     class Config:
         env_file = ".env"
         env_prefix = ENV_PREFIX
diff --git a/platform/reworkd_platform/tests/agent/create_model_test.py b/platform/reworkd_platform/tests/agent/create_model_test.py
index dd8dc0078c..ffede6884b 100644
--- a/platform/reworkd_platform/tests/agent/create_model_test.py
+++ b/platform/reworkd_platform/tests/agent/create_model_test.py
@@ -4,7 +4,7 @@
 
 from reworkd_platform.schemas.agent import ModelSettings
 from reworkd_platform.schemas.user import UserBase
-from reworkd_platform.web.api.agent.model_settings import create_model
+from reworkd_platform.web.api.agent.model_factory import create_model
 
 
 @pytest.mark.parametrize(
diff --git a/platform/reworkd_platform/tests/test_settings.py b/platform/reworkd_platform/tests/test_settings.py
index c49970c522..38cde7c1c0 100644
--- a/platform/reworkd_platform/tests/test_settings.py
+++ b/platform/reworkd_platform/tests/test_settings.py
@@ -56,3 +56,29 @@ def test_kafka_enabled(settings: Dict[str, Any], expected: bool):
     settings = Settings(**settings)
     assert settings.kafka_enabled == expected
+
+
+@pytest.mark.parametrize(
+    "settings, expected",
+    [
+        (
+            {
+                "azure_openai_api_base": "123",
+                "azure_openai_api_key": "123",
+                "azure_openai_deployment_name": "123",
+            },
+            True,
+        ),
+        (
+            {
+                "azure_openai_api_base": "123",
+                "azure_openai_api_key": "123",
+            },
+            False,
+        ),
+        ({}, False),
+    ],
+)
+def test_azure_enabled(settings: Dict[str, Any], expected: bool):
+    settings = Settings(**settings)
+    assert settings.azure_openai_enabled == expected
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
index 4241c44448..1cfe9bbc64 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/agent_service_provider.py
@@ -15,7 +15,7 @@
     OpenAIAgentService,
 )
 from reworkd_platform.web.api.agent.dependancies import get_agent_memory
-from reworkd_platform.web.api.agent.model_settings import create_model
+from reworkd_platform.web.api.agent.model_factory import create_model
 from reworkd_platform.web.api.dependencies import get_current_user
 from reworkd_platform.web.api.memory.memory import AgentMemory
 
@@ -23,6 +23,7 @@
 def get_agent_service(
     validator: Callable[..., Coroutine[Any, Any, AgentRun]],
     streaming: bool = False,
+    azure: bool = False,  # As of 07/2023, azure does not support functions
 ) -> Callable[..., AgentService]:
     def func(
         run: AgentRun = Depends(validator),
@@ -33,7 +34,7 @@ def func(
         if settings.ff_mock_mode_enabled:
             return MockAgentService()
 
-        model = create_model(run.model_settings, user, streaming=streaming)
+        model = create_model(run.model_settings, user, streaming=streaming, azure=azure)
         return OpenAIAgentService(
             model,
             run.model_settings,
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
index f41d20d294..12b3b46617 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -19,7 +19,7 @@
     openai_error_handler,
     parse_with_handling,
 )
-from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
 from reworkd_platform.web.api.agent.prompts import (
     analyze_task_prompt,
     chat_prompt,
diff --git a/platform/reworkd_platform/web/api/agent/model_factory.py b/platform/reworkd_platform/web/api/agent/model_factory.py
new file mode 100644
index 0000000000..d23c26bbf3
--- /dev/null
+++ b/platform/reworkd_platform/web/api/agent/model_factory.py
@@ -0,0 +1,72 @@
+from typing import Any
+
+import openai
+from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
+from pydantic import Field
+
+from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
+from reworkd_platform.schemas.user import UserBase
+from reworkd_platform.settings import settings
+from reworkd_platform.web.api.agent.api_utils import rotate_keys
+
+openai.api_base = settings.openai_api_base
+
+
+class WrappedChatOpenAI(ChatOpenAI):
+    client: Any = Field(
+        default=None,
+        description="Meta private value but mypy will complain its missing",
+    )
+    max_tokens: int
+    model_name: LLM_Model = Field(alias="model")
+
+
+class WrappedAzureChatOpenAI(WrappedChatOpenAI, AzureChatOpenAI):
+    openai_api_base: str = Field(default=settings.azure_openai_api_base)
+    openai_api_version: str = Field(default=settings.azure_openai_api_version)
+    deployment_name: str = Field(default=settings.azure_openai_deployment_name)
+
+
+def create_model(
+    model_settings: ModelSettings,
+    user: UserBase,
+    streaming: bool = False,
+    azure: bool = False,
+) -> WrappedChatOpenAI:
+    if (
+        not model_settings.custom_api_key
+        and model_settings.model == "gpt-3.5-turbo"
+        and azure
+        and settings.azure_openai_enabled
+    ):
+        return _create_azure_model(model_settings, user, streaming)
+
+    api_key = model_settings.custom_api_key or rotate_keys(
+        gpt_3_key=settings.openai_api_key,
+        gpt_4_key=settings.secondary_openai_api_key,
+        model=model_settings.model,
+    )
+
+    return WrappedChatOpenAI(
+        openai_api_key=api_key,
+        temperature=model_settings.temperature,
+        model=model_settings.model,
+        max_tokens=model_settings.max_tokens,
+        streaming=streaming,
+        max_retries=5,
+        model_kwargs={"user": user.email},
+    )
+
+
+def _create_azure_model(
+    model_settings: ModelSettings, user: UserBase, streaming: bool = False
+) -> WrappedChatOpenAI:
+    return WrappedAzureChatOpenAI(
+        openai_api_key=settings.azure_openai_api_key,
+        temperature=model_settings.temperature,
+        model=model_settings.model,
+        max_tokens=model_settings.max_tokens,
+        streaming=streaming,
+        max_retries=5,
+        model_kwargs={"user": user.email},
+    )
diff --git a/platform/reworkd_platform/web/api/agent/model_settings.py b/platform/reworkd_platform/web/api/agent/model_settings.py
deleted file mode 100644
index 878b49d985..0000000000
--- a/platform/reworkd_platform/web/api/agent/model_settings.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import openai
-from langchain.chat_models import ChatOpenAI
-from pydantic import Field
-
-from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
-from reworkd_platform.schemas.user import UserBase
-from reworkd_platform.settings import settings
-from reworkd_platform.web.api.agent.api_utils import rotate_keys
-
-openai.api_base = settings.openai_api_base
-
-
-class WrappedChatOpenAI(ChatOpenAI):
-    max_tokens: int
-    model_name: LLM_Model = Field(alias="model")
-
-
-def create_model(
-    model_settings: ModelSettings, user: UserBase, streaming: bool = False
-) -> WrappedChatOpenAI:
-    api_key = model_settings.custom_api_key or rotate_keys(
-        gpt_3_key=settings.openai_api_key,
-        gpt_4_key=settings.secondary_openai_api_key,
-        model=model_settings.model,
-    )
-
-    return WrappedChatOpenAI(
-        client=None,  # Meta private value but mypy will complain its missing
-        openai_api_key=api_key,
-        temperature=model_settings.temperature,
-        model=model_settings.model,
-        max_tokens=model_settings.max_tokens,
-        streaming=streaming,
-        max_retries=5,
-        model_kwargs={"user": user.email},
-    )
diff --git a/platform/reworkd_platform/web/api/agent/views.py b/platform/reworkd_platform/web/api/agent/views.py
index c8cb666cd2..66fb697ab7 100644
--- a/platform/reworkd_platform/web/api/agent/views.py
+++ b/platform/reworkd_platform/web/api/agent/views.py
@@ -36,7 +36,9 @@
 )
 async def start_tasks(
     req_body: AgentRun = Depends(agent_start_validator),
-    agent_service: AgentService = Depends(get_agent_service(agent_start_validator)),
+    agent_service: AgentService = Depends(
+        get_agent_service(agent_start_validator, azure=True)
+    ),
 ) -> NewTasksResponse:
     new_tasks = await agent_service.start_goal_agent(goal=req_body.goal)
     return NewTasksResponse(newTasks=new_tasks, run_id=req_body.run_id)
@@ -71,7 +73,9 @@ async def execute_tasks(
 @router.post("/create")
 async def create_tasks(
     req_body: AgentTaskCreate = Depends(agent_create_validator),
-    agent_service: AgentService = Depends(get_agent_service(agent_create_validator)),
+    agent_service: AgentService = Depends(
+        get_agent_service(agent_create_validator, azure=True)
+    ),
 ) -> NewTasksResponse:
     new_tasks = await agent_service.create_tasks_agent(
         goal=req_body.goal,