From 632c328276d5ed0846bb378900c0fd913bf8f045 Mon Sep 17 00:00:00 2001 From: Victor Elias Date: Fri, 13 Sep 2024 18:21:26 -0300 Subject: [PATCH] api: Adjust API handlers to new ai schema --- packages/api/src/controllers/generate.ts | 40 +++++++++++++---------- packages/api/src/schema/db-schema.yaml | 2 ++ packages/api/src/schema/pull-ai-schema.js | 3 +- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/packages/api/src/controllers/generate.ts b/packages/api/src/controllers/generate.ts index 43b875664..f9a02c85e 100644 --- a/packages/api/src/controllers/generate.ts +++ b/packages/api/src/controllers/generate.ts @@ -8,10 +8,11 @@ import sql from "sql-template-strings"; import { v4 as uuid } from "uuid"; import logger from "../logger"; import { authorizer, validateFormData, validatePost } from "../middleware"; +import { defaultModels } from "../schema/pull-ai-schema"; import { AiGenerateLog } from "../schema/types"; import { db } from "../store"; import { BadRequestError } from "../store/errors"; -import { fetchWithTimeout } from "../util"; +import { fetchWithTimeout, kebabToCamel } from "../util"; import { experimentSubjectsOnly } from "./experiment"; import { pathJoin2 } from "./helpers"; @@ -170,13 +171,24 @@ function logAiGenerateRequest( function registerGenerateHandler( type: AiGenerateType, - defaultModel: string, isJSONReq = false, // multipart by default ): RequestHandler { const path = `/${type}`; - const payloadParsers = isJSONReq - ? [validatePost(`${type}-payload`)] - : [multipart.any(), validateFormData(`${type}-payload`)]; + + let payloadParsers: RequestHandler[]; + let camelType = kebabToCamel(type); + camelType = camelType[0].toUpperCase() + camelType.slice(1); + if (isJSONReq) { + payloadParsers = [validatePost(`${camelType}Params`)]; + } else { + payloadParsers = [ + multipart.any(), + validateFormData(`Body_gen${camelType}`), + ]; + } + + const defaultModel = defaultModels[type]; + return app.post( path, authorizer({}), @@ -236,17 +248,11 @@ function registerGenerateHandler( ); } -registerGenerateHandler( - "text-to-image", - "SG161222/RealVisXL_V4.0_Lightning", - true, -); -registerGenerateHandler("image-to-image", "timbrooks/instruct-pix2pix"); -registerGenerateHandler( - "image-to-video", - "stabilityai/stable-video-diffusion-img2vid-xt-1-1", -); -registerGenerateHandler("upscale", "stabilityai/stable-diffusion-x4-upscaler"); -registerGenerateHandler("audio-to-text", "openai/whisper-large-v3"); +registerGenerateHandler("text-to-image", true); +registerGenerateHandler("image-to-image"); +registerGenerateHandler("image-to-video"); +registerGenerateHandler("upscale"); +registerGenerateHandler("audio-to-text"); +registerGenerateHandler("segment-anything-2"); export default app; diff --git a/packages/api/src/schema/db-schema.yaml b/packages/api/src/schema/db-schema.yaml index f79220712..6388c00a1 100644 --- a/packages/api/src/schema/db-schema.yaml +++ b/packages/api/src/schema/db-schema.yaml @@ -1453,12 +1453,14 @@ components: - image-to-image - image-to-video - upscale + - segment-anything-2 request: oneOf: - $ref: "./ai-api-schema.yaml#/components/schemas/TextToImageParams" - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToImage" - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToVideo" - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genUpscale" + - $ref: "./ai-api-schema.yaml#/components/schemas/Body_genSegmentAnything2" statusCode: type: integer description: HTTP status code received from the AI gateway diff --git a/packages/api/src/schema/pull-ai-schema.js b/packages/api/src/schema/pull-ai-schema.js index 6a5ad0907..4083e78b4 100644 --- a/packages/api/src/schema/pull-ai-schema.js +++ b/packages/api/src/schema/pull-ai-schema.js @@ -5,12 +5,13 @@ import path from "path"; // This downloads the AI schema from the AI worker repo and saves in the local // ai-api-schema.yaml file, referenced by our main api-schema.yaml file. -const defaultModels = { +export const defaultModels = { "text-to-image": "SG161222/RealVisXL_V4.0_Lightning", "image-to-image": "timbrooks/instruct-pix2pix", "image-to-video": "stabilityai/stable-video-diffusion-img2vid-xt-1-1", upscale: "stabilityai/stable-diffusion-x4-upscaler", "audio-to-text": "openai/whisper-large-v3", + "segment-anything-2": "facebook/sam2-hiera-large:", }; const schemaDir = path.resolve(__dirname, "."); const aiSchemaUrl =