Skip to content

Commit

Permalink
🙂 Azure OpenAI (#1165)
Browse files Browse the repository at this point in the history
  • Loading branch information
awtkns authored Jul 31, 2023
1 parent 35bd681 commit e2bd61f
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from reworkd_platform.schemas.workflow.base import Block, BlockIOBase
from reworkd_platform.services.aws.s3 import SimpleStorageService
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model


class SummaryAgentInput(BlockIOBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tiktoken import Encoding, get_encoding

from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI


class TokenService:
Expand Down
17 changes: 17 additions & 0 deletions platform/reworkd_platform/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class Settings(BaseSettings):
openai_api_key: str = "<Should be updated via env>"
secondary_openai_api_key: Optional[str] = None

# Azure OpenAI
azure_openai_api_version: str = "2023-06-01-preview"
azure_openai_api_key: str = ""
azure_openai_api_base: str = ""
azure_openai_deployment_name: str = ""

replicate_api_key: Optional[str] = None
serp_api_key: Optional[str] = None
scrapingbee_api_key: Optional[str] = None
Expand Down Expand Up @@ -154,6 +160,17 @@ def kafka_enabled(self) -> bool:
]
)

@property
def azure_openai_enabled(self) -> bool:
return all(
[
self.azure_openai_api_base,
self.azure_openai_deployment_name,
self.azure_openai_api_version,
self.azure_openai_api_key,
]
)

class Config:
env_file = ".env"
env_prefix = ENV_PREFIX
Expand Down
2 changes: 1 addition & 1 deletion platform/reworkd_platform/tests/agent/create_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model


@pytest.mark.parametrize(
Expand Down
26 changes: 26 additions & 0 deletions platform/reworkd_platform/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,29 @@ def test_pusher_enabled(settings: Dict[str, Any], expected: bool):
def test_kafka_enabled(settings: Dict[str, Any], expected: bool):
settings = Settings(**settings)
assert settings.kafka_enabled == expected


@pytest.mark.parametrize(
"settings, expected",
[
(
{
"azure_openai_api_base": "123",
"azure_openai_api_key": "123",
"azure_openai_deployment_name": "123",
},
True,
),
(
{
"azure_openai_api_base": "123",
"azure_openai_api_key": "123",
},
False,
),
({}, False),
],
)
def test_azure_enabled(settings: Dict[str, Any], expected: bool):
settings = Settings(**settings)
assert settings.azure_openai_enabled == expected
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
OpenAIAgentService,
)
from reworkd_platform.web.api.agent.dependancies import get_agent_memory
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model
from reworkd_platform.web.api.dependencies import get_current_user
from reworkd_platform.web.api.memory.memory import AgentMemory


def get_agent_service(
validator: Callable[..., Coroutine[Any, Any, AgentRun]],
streaming: bool = False,
azure: bool = False, # As of 07/2023, azure does not support functions
) -> Callable[..., AgentService]:
def func(
run: AgentRun = Depends(validator),
Expand All @@ -33,7 +34,7 @@ def func(
if settings.ff_mock_mode_enabled:
return MockAgentService()

model = create_model(run.model_settings, user, streaming=streaming)
model = create_model(run.model_settings, user, streaming=streaming, azure=azure)
return OpenAIAgentService(
model,
run.model_settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
openai_error_handler,
parse_with_handling,
)
from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
from reworkd_platform.web.api.agent.prompts import (
analyze_task_prompt,
chat_prompt,
Expand Down
72 changes: 72 additions & 0 deletions platform/reworkd_platform/web/api/agent/model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any

import openai
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from pydantic import Field

from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.api_utils import rotate_keys

openai.api_base = settings.openai_api_base


class WrappedChatOpenAI(ChatOpenAI):
client: Any = Field(
default=None,
description="Meta private value but mypy will complain its missing",
)
max_tokens: int
model_name: LLM_Model = Field(alias="model")


class WrappedAzureChatOpenAI(WrappedChatOpenAI, AzureChatOpenAI):
openai_api_base: str = Field(default=settings.azure_openai_api_base)
openai_api_version: str = Field(default=settings.azure_openai_api_version)
deployment_name: str = Field(default=settings.azure_openai_deployment_name)


def create_model(
model_settings: ModelSettings,
user: UserBase,
streaming: bool = False,
azure: bool = False,
) -> WrappedChatOpenAI:
if (
not model_settings.custom_api_key
and model_settings.model == "gpt-3.5-turbo"
and azure
and settings.azure_openai_enabled
):
return _create_azure_model(model_settings, user, streaming)

api_key = model_settings.custom_api_key or rotate_keys(
gpt_3_key=settings.openai_api_key,
gpt_4_key=settings.secondary_openai_api_key,
model=model_settings.model,
)

return WrappedChatOpenAI(
openai_api_key=api_key,
temperature=model_settings.temperature,
model=model_settings.model,
max_tokens=model_settings.max_tokens,
streaming=streaming,
max_retries=5,
model_kwargs={"user": user.email},
)


def _create_azure_model(
model_settings: ModelSettings, user: UserBase, streaming: bool = False
) -> WrappedChatOpenAI:
return WrappedAzureChatOpenAI(
openai_api_key=settings.azure_openai_api_key,
temperature=model_settings.temperature,
model=model_settings.model,
max_tokens=model_settings.max_tokens,
streaming=streaming,
max_retries=5,
model_kwargs={"user": user.email},
)
36 changes: 0 additions & 36 deletions platform/reworkd_platform/web/api/agent/model_settings.py

This file was deleted.

8 changes: 6 additions & 2 deletions platform/reworkd_platform/web/api/agent/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
)
async def start_tasks(
req_body: AgentRun = Depends(agent_start_validator),
agent_service: AgentService = Depends(get_agent_service(agent_start_validator)),
agent_service: AgentService = Depends(
get_agent_service(agent_start_validator, azure=True)
),
) -> NewTasksResponse:
new_tasks = await agent_service.start_goal_agent(goal=req_body.goal)
return NewTasksResponse(newTasks=new_tasks, run_id=req_body.run_id)
Expand Down Expand Up @@ -71,7 +73,9 @@ async def execute_tasks(
@router.post("/create")
async def create_tasks(
req_body: AgentTaskCreate = Depends(agent_create_validator),
agent_service: AgentService = Depends(get_agent_service(agent_create_validator)),
agent_service: AgentService = Depends(
get_agent_service(agent_create_validator, azure=True)
),
) -> NewTasksResponse:
new_tasks = await agent_service.create_tasks_agent(
goal=req_body.goal,
Expand Down

0 comments on commit e2bd61f

Please sign in to comment.