From 3cd9a446100d1695b67d5b5e6e61550b92bdf779 Mon Sep 17 00:00:00 2001 From: Scott <146760070+scott-cohere@users.noreply.github.com> Date: Fri, 14 Jun 2024 13:44:44 -0400 Subject: [PATCH] [backend] make deployment field optional in API and DB (#213) * changes * saving changes * lint --- src/backend/database_models/agent.py | 4 +- src/backend/schemas/agent.py | 2 +- src/backend/tests/crud/test_agent.py | 12 --- src/backend/tests/routers/test_agent.py | 19 ----- .../src/cohere-client/generated/index.ts | 5 +- .../cohere-client/generated/models/Agent.ts | 2 + .../generated/models/CreateAgent.ts | 4 +- .../generated/models/JWTResponse.ts | 7 ++ .../models/{Auth.ts => ListAuthStrategy.ts} | 2 +- .../cohere-client/generated/models/Logout.ts | 5 ++ .../generated/models/StreamEnd.ts | 2 +- .../generated/models/ToolName.ts | 12 +++ .../generated/models/UpdateAgent.ts | 2 + .../generated/services/DefaultService.ts | 76 +++++++++++++------ 14 files changed, 93 insertions(+), 61 deletions(-) create mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/JWTResponse.ts rename src/interfaces/coral_web/src/cohere-client/generated/models/{Auth.ts => ListAuthStrategy.ts} (82%) create mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/Logout.ts create mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts diff --git a/src/backend/database_models/agent.py b/src/backend/database_models/agent.py index e2561e0aba..e842e861c8 100644 --- a/src/backend/database_models/agent.py +++ b/src/backend/database_models/agent.py @@ -41,7 +41,9 @@ class Agent(Base): Enum(AgentModel, native_enum=False), nullable=False ) deployment: Mapped[AgentDeployment] = mapped_column( - Enum(AgentDeployment, native_enum=False), nullable=False + Enum(AgentDeployment, native_enum=False), + default=AgentDeployment.COHERE_PLATFORM, + nullable=False, ) user_id: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/src/backend/schemas/agent.py b/src/backend/schemas/agent.py index d7289f6d9e..ad4827393f 100644 --- a/src/backend/schemas/agent.py +++ b/src/backend/schemas/agent.py @@ -38,7 +38,7 @@ class CreateAgent(BaseModel): preamble: Optional[str] = None temperature: Optional[float] = None model: AgentModel - deployment: AgentDeployment + deployment: Optional[AgentDeployment] = None tools: Optional[list[ToolName]] = None class Config: diff --git a/src/backend/tests/crud/test_agent.py b/src/backend/tests/crud/test_agent.py index 989477c01f..6fbeb5e330 100644 --- a/src/backend/tests/crud/test_agent.py +++ b/src/backend/tests/crud/test_agent.py @@ -50,7 +50,6 @@ def test_create_agent_empty_non_required_fields(session, user): agent_data = Agent( user_id=user.id, name="test", - deployment=AgentDeployment.COHERE_PLATFORM, model=AgentModel.COMMAND_R_PLUS, ) @@ -99,17 +98,6 @@ def test_create_agent_missing_model(session, user): _ = agent_crud.create_agent(session, agent_data) -def test_create_agent_missing_deployment(session, user): - agent_data = Agent( - user_id=user.id, - name="test", - model=AgentModel.COMMAND_R_PLUS, - ) - - with pytest.raises(IntegrityError): - _ = agent_crud.create_agent(session, agent_data) - - def test_create_agent_missing_user_id(session): agent_data = Agent( name="test", diff --git a/src/backend/tests/routers/test_agent.py b/src/backend/tests/routers/test_agent.py index 41ab7f16b3..20b7d53611 100644 --- a/src/backend/tests/routers/test_agent.py +++ b/src/backend/tests/routers/test_agent.py @@ -77,22 +77,6 @@ def test_create_agent_missing_model( assert response.status_code == 422 -def test_create_agent_missing_deployment( - session_client: TestClient, session: Session -) -> None: - request_json = { - "name": "test agent", - "description": "test description", - "preamble": "test preamble", - "temperature": 0.5, - "model": AgentModel.COMMAND_R, - } - response = session_client.post( - "/v1/agents", json=request_json, headers={"User-Id": "123"} - ) - assert response.status_code == 422 - - def test_create_agent_missing_user_id_header( session_client: TestClient, session: Session ) -> None: @@ -111,7 +95,6 @@ def test_create_agent_missing_non_required_fields( request_json = { "name": "test agent", "model": AgentModel.COMMAND_R, - "deployment": AgentDeployment.COHERE_PLATFORM, } print(request_json) @@ -128,7 +111,6 @@ def test_create_agent_missing_non_required_fields( assert response_agent["preamble"] == "" assert response_agent["temperature"] == 0.3 assert response_agent["model"] == request_json["model"] - assert response_agent["deployment"] == request_json["deployment"] agent = session.get(Agent, response_agent["id"]) assert agent is not None @@ -138,7 +120,6 @@ def test_create_agent_missing_non_required_fields( assert agent.preamble == "" assert agent.temperature == 0.3 assert agent.model == request_json["model"] - assert agent.deployment == request_json["deployment"] def test_create_agent_wrong_model_deployment_enums( diff --git a/src/interfaces/coral_web/src/cohere-client/generated/index.ts b/src/interfaces/coral_web/src/cohere-client/generated/index.ts index 845b82cc13..7c8cd12d17 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/index.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/index.ts @@ -10,7 +10,6 @@ export type { OpenAPIConfig } from './core/OpenAPI'; export type { Agent } from './models/Agent'; export { AgentDeployment } from './models/AgentDeployment'; export { AgentModel } from './models/AgentModel'; -export type { Auth } from './models/Auth'; export type { Body_upload_file_v1_conversations_upload_file_post } from './models/Body_upload_file_v1_conversations_upload_file_post'; export { Category } from './models/Category'; export type { ChatMessage } from './models/ChatMessage'; @@ -31,9 +30,12 @@ export type { Deployment } from './models/Deployment'; export type { Document } from './models/Document'; export type { File } from './models/File'; export type { HTTPValidationError } from './models/HTTPValidationError'; +export type { JWTResponse } from './models/JWTResponse'; export type { LangchainChatRequest } from './models/LangchainChatRequest'; +export type { ListAuthStrategy } from './models/ListAuthStrategy'; export type { ListFile } from './models/ListFile'; export type { Login } from './models/Login'; +export type { Logout } from './models/Logout'; export type { ManagedTool } from './models/ManagedTool'; export type { Message } from './models/Message'; export { MessageAgent } from './models/MessageAgent'; @@ -53,6 +55,7 @@ export type { StreamToolResult } from './models/StreamToolResult'; export type { Tool } from './models/Tool'; export type { ToolCall } from './models/ToolCall'; export { ToolInputType } from './models/ToolInputType'; +export { ToolName } from './models/ToolName'; export type { UpdateAgent } from './models/UpdateAgent'; export type { UpdateConversation } from './models/UpdateConversation'; export type { UpdateDeploymentEnv } from './models/UpdateDeploymentEnv'; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts index c51543700c..9cd2f76beb 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts @@ -7,6 +7,7 @@ /* eslint-disable */ import type { AgentDeployment } from './AgentDeployment'; import type { AgentModel } from './AgentModel'; +import type { ToolName } from './ToolName'; export type Agent = { user_id: string; @@ -18,6 +19,7 @@ export type Agent = { description: string | null; preamble: string | null; temperature: number; + tools: Array; model: AgentModel; deployment: AgentDeployment; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts index 592fd5b397..13870aba77 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts @@ -7,6 +7,7 @@ /* eslint-disable */ import type { AgentDeployment } from './AgentDeployment'; import type { AgentModel } from './AgentModel'; +import type { ToolName } from './ToolName'; export type CreateAgent = { name: string; @@ -15,5 +16,6 @@ export type CreateAgent = { preamble?: string | null; temperature?: number | null; model: AgentModel; - deployment: AgentDeployment; + deployment?: AgentDeployment | null; + tools?: Array | null; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/JWTResponse.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/JWTResponse.ts new file mode 100644 index 0000000000..f9cb3a7b8a --- /dev/null +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/JWTResponse.ts @@ -0,0 +1,7 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type JWTResponse = { + token: string; +}; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/Auth.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/ListAuthStrategy.ts similarity index 82% rename from src/interfaces/coral_web/src/cohere-client/generated/models/Auth.ts rename to src/interfaces/coral_web/src/cohere-client/generated/models/ListAuthStrategy.ts index 6a7b355839..94050ba192 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/Auth.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/ListAuthStrategy.ts @@ -2,6 +2,6 @@ /* istanbul ignore file */ /* tslint:disable */ /* eslint-disable */ -export type Auth = { +export type ListAuthStrategy = { strategy: string; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/Logout.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/Logout.ts new file mode 100644 index 0000000000..4a3baa9a23 --- /dev/null +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/Logout.ts @@ -0,0 +1,5 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export type Logout = {}; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/StreamEnd.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/StreamEnd.ts index 8b68ada681..24ed36c3a6 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/StreamEnd.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/StreamEnd.ts @@ -20,5 +20,5 @@ export type StreamEnd = { search_results?: Array>; search_queries?: Array; tool_calls?: Array; - finish_reason: string; + finish_reason?: string | null; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts new file mode 100644 index 0000000000..14b40f9d5e --- /dev/null +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts @@ -0,0 +1,12 @@ +/* generated using openapi-typescript-codegen -- do no edit */ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +export enum ToolName { + WIKIPEDIA = 'Wikipedia', + SEARCH_FILE = 'search_file', + READ_DOCUMENT = 'read_document', + PYTHON_INTERPRETER = 'Python_Interpreter', + CALCULATOR = 'Calculator', + INTERNET_SEARCH = 'Internet_Search', +} diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts index fb30ce9a1f..7462d81c05 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts @@ -7,6 +7,7 @@ /* eslint-disable */ import type { AgentDeployment } from './AgentDeployment'; import type { AgentModel } from './AgentModel'; +import type { ToolName } from './ToolName'; export type UpdateAgent = { name?: string | null; @@ -16,4 +17,5 @@ export type UpdateAgent = { temperature?: number | null; model?: AgentModel | null; deployment?: AgentDeployment | null; + tools?: Array | null; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/services/DefaultService.ts b/src/interfaces/coral_web/src/cohere-client/generated/services/DefaultService.ts index 82803b859f..300c941dc8 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/services/DefaultService.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/services/DefaultService.ts @@ -9,7 +9,6 @@ import type { CancelablePromise } from '../core/CancelablePromise'; import { OpenAPI } from '../core/OpenAPI'; import { request as __request } from '../core/request'; import type { Agent } from '../models/Agent'; -import type { Auth } from '../models/Auth'; import type { Body_upload_file_v1_conversations_upload_file_post } from '../models/Body_upload_file_v1_conversations_upload_file_post'; import type { ChatResponseEvent } from '../models/ChatResponseEvent'; import type { CohereChatRequest } from '../models/CohereChatRequest'; @@ -23,9 +22,12 @@ import type { DeleteFile } from '../models/DeleteFile'; import type { DeleteUser } from '../models/DeleteUser'; import type { Deployment } from '../models/Deployment'; import type { File } from '../models/File'; +import type { JWTResponse } from '../models/JWTResponse'; import type { LangchainChatRequest } from '../models/LangchainChatRequest'; +import type { ListAuthStrategy } from '../models/ListAuthStrategy'; import type { ListFile } from '../models/ListFile'; import type { Login } from '../models/Login'; +import type { Logout } from '../models/Logout'; import type { ManagedTool } from '../models/ManagedTool'; import type { NonStreamedChatResponse } from '../models/NonStreamedChatResponse'; import type { UpdateAgent } from '../models/UpdateAgent'; @@ -44,10 +46,10 @@ export class DefaultService { * * Returns: * List[dict]: List of dictionaries containing the enabled auth strategy names. - * @returns any Successful Response + * @returns ListAuthStrategy Successful Response * @throws ApiError */ - public static getStrategiesV1AuthStrategiesGet(): CancelablePromise { + public static getStrategiesV1AuthStrategiesGet(): CancelablePromise> { return __request(OpenAPI, { method: 'GET', url: '/v1/auth_strategies', @@ -74,7 +76,11 @@ export class DefaultService { * @returns any Successful Response * @throws ApiError */ - public static loginV1LoginPost({ requestBody }: { requestBody: Login }): CancelablePromise { + public static loginV1LoginPost({ + requestBody, + }: { + requestBody: Login; + }): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/v1/login', @@ -86,35 +92,47 @@ export class DefaultService { }); } /** - * Authenticate - * Authentication endpoint used for OAuth strategies. Logs the user in the redirect environment and then - * sets the current session with the user returned from the auth token. + * Google Authenticate + * Callback authentication endpoint used for Google OAuth after redirecting to + * the service's login screen. * * Args: * request (Request): current Request object. - * login (Login): Login payload. * * Returns: * RedirectResponse: On success. * * Raises: * HTTPException: If authentication fails, or strategy is invalid. - * @returns any Successful Response + * @returns JWTResponse Successful Response * @throws ApiError */ - public static authenticateV1AuthPost({ - requestBody, - }: { - requestBody: Auth; - }): CancelablePromise { + public static googleAuthenticateV1GoogleAuthGet(): CancelablePromise { return __request(OpenAPI, { - method: 'POST', - url: '/v1/auth', - body: requestBody, - mediaType: 'application/json', - errors: { - 422: `Validation Error`, - }, + method: 'GET', + url: '/v1/google/auth', + }); + } + /** + * Oidc Authenticate + * Callback authentication endpoint used for OIDC after redirecting to + * the service's login screen. + * + * Args: + * request (Request): current Request object. + * + * Returns: + * RedirectResponse: On success. + * + * Raises: + * HTTPException: If authentication fails, or strategy is invalid. + * @returns JWTResponse Successful Response + * @throws ApiError + */ + public static oidcAuthenticateV1OidcAuthGet(): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/v1/oidc/auth', }); } /** @@ -126,10 +144,10 @@ export class DefaultService { * * Returns: * dict: Empty on success - * @returns any Successful Response + * @returns Logout Successful Response * @throws ApiError */ - public static logoutV1LogoutGet(): CancelablePromise { + public static logoutV1LogoutGet(): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/v1/logout', @@ -663,10 +681,20 @@ export class DefaultService { * @returns ManagedTool Successful Response * @throws ApiError */ - public static listToolsV1ToolsGet(): CancelablePromise> { + public static listToolsV1ToolsGet({ + agentId, + }: { + agentId?: string | null; + }): CancelablePromise> { return __request(OpenAPI, { method: 'GET', url: '/v1/tools', + query: { + agent_id: agentId, + }, + errors: { + 422: `Validation Error`, + }, }); } /**