Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[backend] make deployment field optional in API and DB #213

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/backend/database_models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Agent(Base):
Enum(AgentModel, native_enum=False), nullable=False
)
deployment: Mapped[AgentDeployment] = mapped_column(
Enum(AgentDeployment, native_enum=False), nullable=False
Enum(AgentDeployment, native_enum=False), default=AgentDeployment.COHERE_PLATFORM, nullable=False
)

user_id: Mapped[str] = mapped_column(Text, nullable=False)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CreateAgent(BaseModel):
preamble: Optional[str] = None
temperature: Optional[float] = None
model: AgentModel
deployment: AgentDeployment
deployment: Optional[AgentDeployment] = None
tools: Optional[list[ToolName]] = None

class Config:
Expand Down
12 changes: 0 additions & 12 deletions src/backend/tests/crud/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_create_agent_empty_non_required_fields(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
deployment=AgentDeployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
)

Expand Down Expand Up @@ -99,17 +98,6 @@ def test_create_agent_missing_model(session, user):
_ = agent_crud.create_agent(session, agent_data)


def test_create_agent_missing_deployment(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
model=AgentModel.COMMAND_R_PLUS,
)

with pytest.raises(IntegrityError):
_ = agent_crud.create_agent(session, agent_data)


def test_create_agent_missing_user_id(session):
agent_data = Agent(
name="test",
Expand Down
20 changes: 0 additions & 20 deletions src/backend/tests/routers/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,6 @@ def test_create_agent_missing_model(
)
assert response.status_code == 422


def test_create_agent_missing_deployment(
session_client: TestClient, session: Session
) -> None:
request_json = {
"name": "test agent",
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": AgentModel.COMMAND_R,
}
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
)
assert response.status_code == 422


def test_create_agent_missing_user_id_header(
session_client: TestClient, session: Session
) -> None:
Expand All @@ -111,7 +94,6 @@ def test_create_agent_missing_non_required_fields(
request_json = {
"name": "test agent",
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}

print(request_json)
Expand All @@ -128,7 +110,6 @@ def test_create_agent_missing_non_required_fields(
assert response_agent["preamble"] == ""
assert response_agent["temperature"] == 0.3
assert response_agent["model"] == request_json["model"]
assert response_agent["deployment"] == request_json["deployment"]

agent = session.get(Agent, response_agent["id"])
assert agent is not None
Expand All @@ -138,7 +119,6 @@ def test_create_agent_missing_non_required_fields(
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.model == request_json["model"]
assert agent.deployment == request_json["deployment"]


def test_create_agent_wrong_model_deployment_enums(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { Agent } from './models/Agent';
export { AgentDeployment } from './models/AgentDeployment';
export { AgentModel } from './models/AgentModel';
export type { Auth } from './models/Auth';
export type { Body_upload_file_v1_conversations_upload_file_post } from './models/Body_upload_file_v1_conversations_upload_file_post';
export { Category } from './models/Category';
export type { ChatMessage } from './models/ChatMessage';
Expand All @@ -31,9 +30,12 @@ export type { Deployment } from './models/Deployment';
export type { Document } from './models/Document';
export type { File } from './models/File';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { JWTResponse } from './models/JWTResponse';
export type { LangchainChatRequest } from './models/LangchainChatRequest';
export type { ListAuthStrategy } from './models/ListAuthStrategy';
export type { ListFile } from './models/ListFile';
export type { Login } from './models/Login';
export type { Logout } from './models/Logout';
export type { ManagedTool } from './models/ManagedTool';
export type { Message } from './models/Message';
export { MessageAgent } from './models/MessageAgent';
Expand All @@ -53,6 +55,7 @@ export type { StreamToolResult } from './models/StreamToolResult';
export type { Tool } from './models/Tool';
export type { ToolCall } from './models/ToolCall';
export { ToolInputType } from './models/ToolInputType';
export { ToolName } from './models/ToolName';
export type { UpdateAgent } from './models/UpdateAgent';
export type { UpdateConversation } from './models/UpdateConversation';
export type { UpdateDeploymentEnv } from './models/UpdateDeploymentEnv';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type Agent = {
user_id: string;
Expand All @@ -18,6 +19,7 @@ export type Agent = {
description: string | null;
preamble: string | null;
temperature: number;
tools: Array<ToolName>;
model: AgentModel;
deployment: AgentDeployment;
};
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type CreateAgent = {
name: string;
Expand All @@ -15,5 +16,6 @@ export type CreateAgent = {
preamble?: string | null;
temperature?: number | null;
model: AgentModel;
deployment: AgentDeployment;
deployment?: AgentDeployment | null;
tools?: Array<ToolName> | null;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type JWTResponse = {
token: string;
};
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Auth = {
export type ListAuthStrategy = {
strategy: string;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Logout = {};
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ export type StreamEnd = {
search_results?: Array<Record<string, any>>;
search_queries?: Array<SearchQuery>;
tool_calls?: Array<ToolCall>;
finish_reason: string;
finish_reason?: string | null;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export enum ToolName {
WIKIPEDIA = 'Wikipedia',
SEARCH_FILE = 'search_file',
READ_DOCUMENT = 'read_document',
PYTHON_INTERPRETER = 'Python_Interpreter',
CALCULATOR = 'Calculator',
INTERNET_SEARCH = 'Internet_Search',
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type UpdateAgent = {
name?: string | null;
Expand All @@ -16,4 +17,5 @@ export type UpdateAgent = {
temperature?: number | null;
model?: AgentModel | null;
deployment?: AgentDeployment | null;
tools?: Array<ToolName> | null;
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import type { CancelablePromise } from '../core/CancelablePromise';
import { OpenAPI } from '../core/OpenAPI';
import { request as __request } from '../core/request';
import type { Agent } from '../models/Agent';
import type { Auth } from '../models/Auth';
import type { Body_upload_file_v1_conversations_upload_file_post } from '../models/Body_upload_file_v1_conversations_upload_file_post';
import type { ChatResponseEvent } from '../models/ChatResponseEvent';
import type { CohereChatRequest } from '../models/CohereChatRequest';
Expand All @@ -23,9 +22,12 @@ import type { DeleteFile } from '../models/DeleteFile';
import type { DeleteUser } from '../models/DeleteUser';
import type { Deployment } from '../models/Deployment';
import type { File } from '../models/File';
import type { JWTResponse } from '../models/JWTResponse';
import type { LangchainChatRequest } from '../models/LangchainChatRequest';
import type { ListAuthStrategy } from '../models/ListAuthStrategy';
import type { ListFile } from '../models/ListFile';
import type { Login } from '../models/Login';
import type { Logout } from '../models/Logout';
import type { ManagedTool } from '../models/ManagedTool';
import type { NonStreamedChatResponse } from '../models/NonStreamedChatResponse';
import type { UpdateAgent } from '../models/UpdateAgent';
Expand All @@ -44,10 +46,10 @@ export class DefaultService {
*
* Returns:
* List[dict]: List of dictionaries containing the enabled auth strategy names.
* @returns any Successful Response
* @returns ListAuthStrategy Successful Response
* @throws ApiError
*/
public static getStrategiesV1AuthStrategiesGet(): CancelablePromise<any> {
public static getStrategiesV1AuthStrategiesGet(): CancelablePromise<Array<ListAuthStrategy>> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/auth_strategies',
Expand All @@ -74,7 +76,11 @@ export class DefaultService {
* @returns any Successful Response
* @throws ApiError
*/
public static loginV1LoginPost({ requestBody }: { requestBody: Login }): CancelablePromise<any> {
public static loginV1LoginPost({
requestBody,
}: {
requestBody: Login;
}): CancelablePromise<JWTResponse | null> {
return __request(OpenAPI, {
method: 'POST',
url: '/v1/login',
Expand All @@ -86,35 +92,47 @@ export class DefaultService {
});
}
/**
* Authenticate
* Authentication endpoint used for OAuth strategies. Logs the user in the redirect environment and then
* sets the current session with the user returned from the auth token.
* Google Authenticate
* Callback authentication endpoint used for Google OAuth after redirecting to
* the service's login screen.
*
* Args:
* request (Request): current Request object.
* login (Login): Login payload.
*
* Returns:
* RedirectResponse: On success.
*
* Raises:
* HTTPException: If authentication fails, or strategy is invalid.
* @returns any Successful Response
* @returns JWTResponse Successful Response
* @throws ApiError
*/
public static authenticateV1AuthPost({
requestBody,
}: {
requestBody: Auth;
}): CancelablePromise<any> {
public static googleAuthenticateV1GoogleAuthGet(): CancelablePromise<JWTResponse> {
return __request(OpenAPI, {
method: 'POST',
url: '/v1/auth',
body: requestBody,
mediaType: 'application/json',
errors: {
422: `Validation Error`,
},
method: 'GET',
url: '/v1/google/auth',
});
}
/**
* Oidc Authenticate
* Callback authentication endpoint used for OIDC after redirecting to
* the service's login screen.
*
* Args:
* request (Request): current Request object.
*
* Returns:
* RedirectResponse: On success.
*
* Raises:
* HTTPException: If authentication fails, or strategy is invalid.
* @returns JWTResponse Successful Response
* @throws ApiError
*/
public static oidcAuthenticateV1OidcAuthGet(): CancelablePromise<JWTResponse> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/oidc/auth',
});
}
/**
Expand All @@ -126,10 +144,10 @@ export class DefaultService {
*
* Returns:
* dict: Empty on success
* @returns any Successful Response
* @returns Logout Successful Response
* @throws ApiError
*/
public static logoutV1LogoutGet(): CancelablePromise<any> {
public static logoutV1LogoutGet(): CancelablePromise<Logout> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/logout',
Expand Down Expand Up @@ -663,10 +681,20 @@ export class DefaultService {
* @returns ManagedTool Successful Response
* @throws ApiError
*/
public static listToolsV1ToolsGet(): CancelablePromise<Array<ManagedTool>> {
public static listToolsV1ToolsGet({
agentId,
}: {
agentId?: string | null;
}): CancelablePromise<Array<ManagedTool>> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/tools',
query: {
agent_id: agentId,
},
errors: {
422: `Validation Error`,
},
});
}
/**
Expand Down
Loading