Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🙂 Azure OpenAI #1165

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from reworkd_platform.schemas.workflow.base import Block, BlockIOBase
from reworkd_platform.services.aws.s3 import SimpleStorageService
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model


class SummaryAgentInput(BlockIOBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tiktoken import Encoding, get_encoding

from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI


class TokenService:
Expand Down
17 changes: 17 additions & 0 deletions platform/reworkd_platform/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ class Settings(BaseSettings):
openai_api_key: str = "<Should be updated via env>"
secondary_openai_api_key: Optional[str] = None

# Azure OpenAI
azure_openai_api_version: str = "2023-06-01-preview"
azure_openai_api_key: str = ""
azure_openai_api_base: str = ""
azure_openai_deployment_name: str = ""

replicate_api_key: Optional[str] = None
serp_api_key: Optional[str] = None
scrapingbee_api_key: Optional[str] = None
Expand Down Expand Up @@ -154,6 +160,17 @@ def kafka_enabled(self) -> bool:
]
)

@property
def azure_openai_enabled(self) -> bool:
return all(
[
self.azure_openai_api_base,
self.azure_openai_deployment_name,
self.azure_openai_api_version,
self.azure_openai_api_key,
]
)

class Config:
env_file = ".env"
env_prefix = ENV_PREFIX
Expand Down
2 changes: 1 addition & 1 deletion platform/reworkd_platform/tests/agent/create_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model


@pytest.mark.parametrize(
Expand Down
26 changes: 26 additions & 0 deletions platform/reworkd_platform/tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,29 @@ def test_pusher_enabled(settings: Dict[str, Any], expected: bool):
def test_kafka_enabled(settings: Dict[str, Any], expected: bool):
settings = Settings(**settings)
assert settings.kafka_enabled == expected


@pytest.mark.parametrize(
"settings, expected",
[
(
{
"azure_openai_api_base": "123",
"azure_openai_api_key": "123",
"azure_openai_deployment_name": "123",
},
True,
),
(
{
"azure_openai_api_base": "123",
"azure_openai_api_key": "123",
},
False,
),
({}, False),
],
)
def test_azure_enabled(settings: Dict[str, Any], expected: bool):
settings = Settings(**settings)
assert settings.azure_openai_enabled == expected
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
OpenAIAgentService,
)
from reworkd_platform.web.api.agent.dependancies import get_agent_memory
from reworkd_platform.web.api.agent.model_settings import create_model
from reworkd_platform.web.api.agent.model_factory import create_model
from reworkd_platform.web.api.dependencies import get_current_user
from reworkd_platform.web.api.memory.memory import AgentMemory


def get_agent_service(
validator: Callable[..., Coroutine[Any, Any, AgentRun]],
streaming: bool = False,
azure: bool = False, # As of 07/2023, azure does not support functions
) -> Callable[..., AgentService]:
def func(
run: AgentRun = Depends(validator),
Expand All @@ -33,7 +34,7 @@ def func(
if settings.ff_mock_mode_enabled:
return MockAgentService()

model = create_model(run.model_settings, user, streaming=streaming)
model = create_model(run.model_settings, user, streaming=streaming, azure=azure)
return OpenAIAgentService(
model,
run.model_settings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
openai_error_handler,
parse_with_handling,
)
from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
from reworkd_platform.web.api.agent.prompts import (
analyze_task_prompt,
chat_prompt,
Expand Down
72 changes: 72 additions & 0 deletions platform/reworkd_platform/web/api/agent/model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any

import openai
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from pydantic import Field

from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.api_utils import rotate_keys

openai.api_base = settings.openai_api_base


class WrappedChatOpenAI(ChatOpenAI):
client: Any = Field(
default=None,
description="Meta private value but mypy will complain its missing",
)
max_tokens: int
model_name: LLM_Model = Field(alias="model")


class WrappedAzureChatOpenAI(WrappedChatOpenAI, AzureChatOpenAI):
openai_api_base: str = Field(default=settings.azure_openai_api_base)
openai_api_version: str = Field(default=settings.azure_openai_api_version)
deployment_name: str = Field(default=settings.azure_openai_deployment_name)


def create_model(
model_settings: ModelSettings,
user: UserBase,
streaming: bool = False,
azure: bool = False,
) -> WrappedChatOpenAI:
if (
not model_settings.custom_api_key
and model_settings.model == "gpt-3.5-turbo"
and azure
and settings.azure_openai_enabled
):
return _create_azure_model(model_settings, user, streaming)

api_key = model_settings.custom_api_key or rotate_keys(
gpt_3_key=settings.openai_api_key,
gpt_4_key=settings.secondary_openai_api_key,
model=model_settings.model,
)

return WrappedChatOpenAI(
openai_api_key=api_key,
temperature=model_settings.temperature,
model=model_settings.model,
max_tokens=model_settings.max_tokens,
streaming=streaming,
max_retries=5,
model_kwargs={"user": user.email},
)


def _create_azure_model(
model_settings: ModelSettings, user: UserBase, streaming: bool = False
) -> WrappedChatOpenAI:
return WrappedAzureChatOpenAI(
openai_api_key=settings.azure_openai_api_key,
temperature=model_settings.temperature,
model=model_settings.model,
max_tokens=model_settings.max_tokens,
streaming=streaming,
max_retries=5,
model_kwargs={"user": user.email},
)
36 changes: 0 additions & 36 deletions platform/reworkd_platform/web/api/agent/model_settings.py

This file was deleted.

8 changes: 6 additions & 2 deletions platform/reworkd_platform/web/api/agent/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
)
async def start_tasks(
req_body: AgentRun = Depends(agent_start_validator),
agent_service: AgentService = Depends(get_agent_service(agent_start_validator)),
agent_service: AgentService = Depends(
get_agent_service(agent_start_validator, azure=True)
),
) -> NewTasksResponse:
new_tasks = await agent_service.start_goal_agent(goal=req_body.goal)
return NewTasksResponse(newTasks=new_tasks, run_id=req_body.run_id)
Expand Down Expand Up @@ -71,7 +73,9 @@ async def execute_tasks(
@router.post("/create")
async def create_tasks(
req_body: AgentTaskCreate = Depends(agent_create_validator),
agent_service: AgentService = Depends(get_agent_service(agent_create_validator)),
agent_service: AgentService = Depends(
get_agent_service(agent_create_validator, azure=True)
),
) -> NewTasksResponse:
new_tasks = await agent_service.create_tasks_agent(
goal=req_body.goal,
Expand Down