diff --git a/README.md b/README.md index 8f4350b..dabee2e 100644 --- a/README.md +++ b/README.md @@ -455,6 +455,18 @@ const response = await replicate.models.list(); } ``` +### `replicate.models.search` + +Search for public models on Replicate. + +```js +const response = await replicate.models.search(query); +``` + +| name | type | description | +| ------- | ------ | -------------------------------------- | +| `query` | string | **Required**. The search query string. | + ### `replicate.models.create` Create a new public or private model. diff --git a/index.d.ts b/index.d.ts index 7d4ef0a..25cd51a 100644 --- a/index.d.ts +++ b/index.d.ts @@ -281,6 +281,7 @@ declare module "replicate" { version_id: string ): Promise; }; + search(query: string): Promise>; }; predictions: { diff --git a/index.js b/index.js index 21e83f9..6b35db6 100644 --- a/index.js +++ b/index.js @@ -98,6 +98,7 @@ class Replicate { list: models.versions.list.bind(this), get: models.versions.get.bind(this), }, + search: models.search.bind(this), }; this.predictions = { diff --git a/index.test.ts b/index.test.ts index c4d7e06..7f9fcf2 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,5 +1,11 @@ import { expect, jest, test } from "@jest/globals"; -import Replicate, { ApiError, Model, Prediction, validateWebhook, parseProgressFromLogs } from "replicate"; +import Replicate, { + ApiError, + Model, + Prediction, + validateWebhook, + parseProgressFromLogs, +} from "replicate"; import nock from "nock"; import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; @@ -36,7 +42,8 @@ const fileTestCases = [ describe("Replicate client", () => { let unmatched: any[] = []; - const handleNoMatch = (req: unknown, options: any, body: string) => unmatched.push({ req, options, body }); + const handleNoMatch = (req: unknown, options: any, body: string) => + unmatched.push({ req, options, body }); beforeEach(() => { client = new Replicate({ auth: "test-token" }); @@ -116,7 +123,8 @@ describe("Replicate client", () => { { name: "Super resolution", slug: "super-resolution", - description: "Upscaling models that create high-quality images from low-quality images.", + description: + "Upscaling models that create high-quality images from low-quality images.", }, { name: "Image classification", @@ -139,7 +147,8 @@ describe("Replicate client", () => { nock(BASE_URL).get("/collections/super-resolution").reply(200, { name: "Super resolution", slug: "super-resolution", - description: "Upscaling models that create high-quality images from low-quality images.", + description: + "Upscaling models that create high-quality images from low-quality images.", models: [], }); @@ -179,7 +188,9 @@ describe("Replicate client", () => { results: [{ url: "https://replicate.com/some-user/model-1" }], next: "https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get("/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") + .get( + "/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" + ) .reply(200, { results: [{ url: "https://replicate.com/some-user/model-2" }], next: null, @@ -237,10 +248,12 @@ describe("Replicate client", () => { expectedResponse: { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, input: testCase.input, created_at: "2022-04-26T22:13:06.224088Z", @@ -250,64 +263,79 @@ describe("Replicate client", () => { }, })); - test.each(predictionTestCases)("$description", async ({ input, expectedResponse }) => { - nock(BASE_URL) - .post("/predictions", { - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + test.each(predictionTestCases)( + "$description", + async ({ input, expectedResponse }) => { + nock(BASE_URL) + .post("/predictions", { + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: input as Record, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + }) + .reply(200, expectedResponse); + + const response = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: input as Record, webhook: "http://test.host/webhook", webhook_events_filter: ["output", "completed"], - }) - .reply(200, expectedResponse); + }); - const response = await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: input as Record, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }); + expect(response.input).toEqual(input); + expect(response.status).toBe(expectedResponse.status); + } + ); - expect(response.input).toEqual(input); - expect(response.status).toBe(expectedResponse.status); - }); + test.each(fileTestCases)( + "converts a $type input into a Replicate file URL", + async ({ value: data, type }) => { + const mockedFetch = jest.spyOn(client, "fetch"); - test.each(fileTestCases)("converts a $type input into a Replicate file URL", async ({ value: data, type }) => { - const mockedFetch = jest.spyOn(client, "fetch"); + nock(BASE_URL) + .post("/files") + .reply(201, { + urls: { + get: "https://replicate.com/api/files/123", + }, + }) + .post( + "/predictions", + (body) => body.input.data === "https://replicate.com/api/files/123" + ) + .reply(201, (_uri: string, body: Record) => { + return body; + }); - nock(BASE_URL) - .post("/files") - .reply(201, { - urls: { - get: "https://replicate.com/api/files/123", + const prediction = await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, }, - }) - .post("/predictions", (body) => body.input.data === "https://replicate.com/api/files/123") - .reply(201, (_uri: string, body: Record) => { - return body; }); - const prediction = await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - input: { - prompt: "Tell me a story", - data, - }, - }); - - expect(client.fetch).toHaveBeenCalledWith(new URL("https://api.replicate.com/v1/files"), { - method: "POST", - body: expect.any(FormData), - headers: expect.any(Object), - }); - const form = mockedFetch.mock.calls[0][1]?.body as FormData; - // @ts-ignore - expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); + expect(client.fetch).toHaveBeenCalledWith( + new URL("https://api.replicate.com/v1/files"), + { + method: "POST", + body: expect.any(FormData), + headers: expect.any(Object), + } + ); + const form = mockedFetch.mock.calls[0][1]?.body as FormData; + // @ts-ignore + expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`)); - expect(prediction.input).toEqual({ - prompt: "Tell me a story", - data: "https://replicate.com/api/files/123", - }); - }); + expect(prediction.input).toEqual({ + prompt: "Tell me a story", + data: "https://replicate.com/api/files/123", + }); + } + ); test.each(fileTestCases)( "converts a $type input into a base64 encoded string", @@ -323,7 +351,8 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", data, @@ -332,7 +361,7 @@ describe("Replicate client", () => { }); expect(actual?.input.data).toEqual(expected); - }, + } ); test.each(fileTestCases)( @@ -350,7 +379,8 @@ describe("Replicate client", () => { await expect(async () => { await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", data, @@ -361,9 +391,9 @@ describe("Replicate client", () => { expect.objectContaining({ name: "ApiError", message: expect.stringContaining("401"), - }), + }) ); - }, + } ); test("Passes stream parameter to API endpoint", async () => { @@ -375,7 +405,8 @@ describe("Replicate client", () => { }); await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { prompt: "Tell me a story", }, @@ -386,7 +417,8 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -402,14 +434,15 @@ describe("Replicate client", () => { status: 400, detail: "Invalid input", }, - { "Content-Type": "application/json" }, + { "Content-Type": "application/json" } ); try { expect.hasAssertions(); await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: null, }, @@ -428,14 +461,15 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" }, + { "Content-Type": "application/json", "Retry-After": "1" } ) .post("/predictions") .reply(201, { id: "ufawqhfynnddngldkgtslldrkq", }); const prediction = await client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, @@ -449,18 +483,19 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" }, + { "Content-Type": "application/json" } ); await expect( client.predictions.create({ - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input: { text: "Alice", }, - }), + }) ).rejects.toThrow( - `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.`, + `Request to https://api.replicate.com/v1/predictions failed with status 500 Internal Server Error: {"detail":"Internal server error"}.` ); }); }); @@ -472,10 +507,12 @@ describe("Replicate client", () => { .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", model: "stability-ai/stable-diffusion", - version: "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", + version: + "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e", urls: { get: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe", - cancel: "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", + cancel: + "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel", }, created_at: "2022-09-13T22:54:18.578761Z", started_at: "2022-09-13T22:54:19.438525Z", @@ -494,7 +531,9 @@ describe("Replicate client", () => { predict_time: 4.484541, }, }); - const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); + const prediction = await client.predictions.get( + "rrr4z55ocneqzikepnug6xezpe" + ); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -506,14 +545,16 @@ describe("Replicate client", () => { { detail: "Too many requests", }, - { "Content-Type": "application/json", "Retry-After": "1" }, + { "Content-Type": "application/json", "Retry-After": "1" } ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); + const prediction = await client.predictions.get( + "rrr4z55ocneqzikepnug6xezpe" + ); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); @@ -525,14 +566,16 @@ describe("Replicate client", () => { { detail: "Internal server error", }, - { "Content-Type": "application/json" }, + { "Content-Type": "application/json" } ) .get("/predictions/rrr4z55ocneqzikepnug6xezpe") .reply(200, { id: "rrr4z55ocneqzikepnug6xezpe", }); - const prediction = await client.predictions.get("rrr4z55ocneqzikepnug6xezpe"); + const prediction = await client.predictions.get( + "rrr4z55ocneqzikepnug6xezpe" + ); expect(prediction.id).toBe("rrr4z55ocneqzikepnug6xezpe"); }); }); @@ -544,10 +587,12 @@ describe("Replicate client", () => { .reply(200, { id: "ufawqhfynnddngldkgtslldrkq", model: "replicate/hello-world", - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-04-26T22:13:06.224088Z", started_at: "2022-04-26T22:13:06.224088Z", @@ -562,7 +607,9 @@ describe("Replicate client", () => { metrics: {}, }); - const prediction = await client.predictions.cancel("ufawqhfynnddngldkgtslldrkq"); + const prediction = await client.predictions.cancel( + "ufawqhfynnddngldkgtslldrkq" + ); expect(prediction.status).toBe("canceled"); }); @@ -580,10 +627,12 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/stable-diffusion", - version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: + "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku", - cancel: "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: + "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -606,7 +655,9 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get("/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") + .get( + "/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" + ) .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -616,7 +667,10 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.predictions.list)) { results.push(...batch); } - expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); + expect(results).toEqual([ + { id: "ufawqhfynnddngldkgtslldrkq" }, + { id: "rrr4z55ocneqzikepnug6xezpe" }, + ]); // Add more tests for error handling, edge cases, etc. }); @@ -625,10 +679,13 @@ describe("Replicate client", () => { describe("trainings.create", () => { test("Calls the correct API route with the correct payload", async () => { nock(BASE_URL) - .post("/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings") + .post( + "/models/owner/model/versions/632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532/trainings" + ) .reply(200, { id: "zz4ibbonubfz7carwiefibzgga", - version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", status: "starting", input: { text: "...", @@ -650,20 +707,25 @@ describe("Replicate client", () => { input: { text: "...", }, - }, + } ); expect(training.id).toBe("zz4ibbonubfz7carwiefibzgga"); }); test("Throws an error if webhook is not a valid URL", async () => { await expect( - client.trainings.create("owner", "model", "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", { - destination: "new_owner/new_model", - input: { - text: "...", - }, - webhook: "invalid-url", - }), + client.trainings.create( + "owner", + "model", + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + { + destination: "new_owner/new_model", + input: { + text: "...", + }, + webhook: "invalid-url", + } + ) ).rejects.toThrow("Invalid webhook URL"); }); @@ -723,7 +785,9 @@ describe("Replicate client", () => { completed_at: null, }); - const training = await client.trainings.cancel("zz4ibbonubfz7carwiefibzgga"); + const training = await client.trainings.cancel( + "zz4ibbonubfz7carwiefibzgga" + ); expect(training.status).toBe("canceled"); }); @@ -741,10 +805,12 @@ describe("Replicate client", () => { { id: "jpzd7hm5gfcapbfyt4mqytarku", model: "stability-ai/sdxl", - version: "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", + version: + "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05", urls: { get: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku", - cancel: "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", + cancel: + "https://api.replicate.com/v1/trainings/jpzd7hm5gfcapbfyt4mqytarku/cancel", }, created_at: "2022-04-26T20:00:40.658234Z", started_at: "2022-04-26T20:00:84.583803Z", @@ -767,7 +833,9 @@ describe("Replicate client", () => { results: [{ id: "ufawqhfynnddngldkgtslldrkq" }], next: "https://api.replicate.com/v1/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw", }) - .get("/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw") + .get( + "/trainings?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw" + ) .reply(200, { results: [{ id: "rrr4z55ocneqzikepnug6xezpe" }], next: null, @@ -777,7 +845,10 @@ describe("Replicate client", () => { for await (const batch of client.paginate(client.trainings.list)) { results.push(...batch); } - expect(results).toEqual([{ id: "ufawqhfynnddngldkgtslldrkq" }, { id: "rrr4z55ocneqzikepnug6xezpe" }]); + expect(results).toEqual([ + { id: "ufawqhfynnddngldkgtslldrkq" }, + { id: "rrr4z55ocneqzikepnug6xezpe" }, + ]); // Add more tests for error handling, edge cases, etc. }); @@ -790,10 +861,12 @@ describe("Replicate client", () => { .reply(200, { id: "mfrgcyzzme2wkmbwgzrgmntcg", model: "replicate/hello-world", - version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", urls: { get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq", - cancel: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", + cancel: + "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel", }, created_at: "2022-09-10T09:44:22.165836Z", started_at: null, @@ -807,13 +880,17 @@ describe("Replicate client", () => { logs: null, metrics: {}, }); - const prediction = await client.deployments.predictions.create("replicate", "greeter", { - input: { - text: "Alice", - }, - webhook: "http://test.host/webhook", - webhook_events_filter: ["output", "completed"], - }); + const prediction = await client.deployments.predictions.create( + "replicate", + "greeter", + { + input: { + text: "Alice", + }, + webhook: "http://test.host/webhook", + webhook_events_filter: ["output", "completed"], + } + ); expect(prediction.id).toBe("mfrgcyzzme2wkmbwgzrgmntcg"); }); // Add more tests for error handling, edge cases, etc. @@ -829,7 +906,8 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: + "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -845,7 +923,10 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.get("acme", "my-app-image-generator"); + const deployment = await client.deployments.get( + "acme", + "my-app-image-generator" + ); expect(deployment.owner).toBe("acme"); expect(deployment.name).toBe("my-app-image-generator"); @@ -864,7 +945,8 @@ describe("Replicate client", () => { current_release: { number: 1, model: "stability-ai/sdxl", - version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: + "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", created_at: "2024-02-15T16:32:57.018467Z", created_by: { type: "organization", @@ -883,7 +965,8 @@ describe("Replicate client", () => { const deployment = await client.deployments.create({ name: "my-app-image-generator", model: "stability-ai/sdxl", - version: "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + version: + "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", hardware: "gpu-t4", min_instances: 1, max_instances: 5, @@ -906,7 +989,8 @@ describe("Replicate client", () => { current_release: { number: 2, model: "stability-ai/sdxl", - version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + version: + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", created_at: "2024-02-16T08:14:22.345678Z", created_by: { type: "organization", @@ -922,18 +1006,25 @@ describe("Replicate client", () => { }, }); - const deployment = await client.deployments.update("acme", "my-app-image-generator", { - version: "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", - hardware: "gpu-a40-large", - min_instances: 3, - max_instances: 10, - }); + const deployment = await client.deployments.update( + "acme", + "my-app-image-generator", + { + version: + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + hardware: "gpu-a40-large", + min_instances: 3, + max_instances: 10, + } + ); expect(deployment.current_release.number).toBe(2); expect(deployment.current_release.version).toBe( - "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532", + "632231d0d49d34d5c4633bd838aee3d81d936e59a886fbf28524702003b4c532" + ); + expect(deployment.current_release.configuration.hardware).toBe( + "gpu-a40-large" ); - expect(deployment.current_release.configuration.hardware).toBe("gpu-a40-large"); expect(deployment.current_release.configuration.min_instances).toBe(3); expect(deployment.current_release.configuration.max_instances).toBe(10); }); @@ -942,9 +1033,14 @@ describe("Replicate client", () => { describe("deployments.delete", () => { test("Calls the correct API route with the correct payload", async () => { - nock(BASE_URL).delete("/deployments/acme/my-app-image-generator").reply(204); + nock(BASE_URL) + .delete("/deployments/acme/my-app-image-generator") + .reply(204); - const success = await client.deployments.delete("acme", "my-app-image-generator"); + const success = await client.deployments.delete( + "acme", + "my-app-image-generator" + ); expect(success).toBe(true); }); }); @@ -977,7 +1073,7 @@ describe("Replicate client", () => { describe("predictions.create with model", () => { test("Calls the correct API route with the correct payload", async () => { nock(BASE_URL) - .post("/models/meta/llama-2-70b-chat/predictions") + .post("/models/meta/meta-llama-3-70b-instruct/predictions") .reply(200, { id: "heat2o3bzn3ahtr6bjfftvbaci", model: "replicate/lifeboat-70b", @@ -990,12 +1086,13 @@ describe("Replicate client", () => { status: "starting", created_at: "2023-11-27T13:35:45.99397566Z", urls: { - cancel: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", + cancel: + "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel", get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci", }, }); const prediction = await client.predictions.create({ - model: "meta/llama-2-70b-chat", + model: "meta/meta-llama-3-70b-instruct", input: { prompt: "Please write a haiku about llamas.", }, @@ -1140,6 +1237,44 @@ describe("Replicate client", () => { }); }); + describe("models.search", () => { + test("Calls the correct API route with the correct payload", async () => { + nock(BASE_URL) + .intercept("/models", "QUERY") + .reply(200, { + results: [ + { + url: "https://replicate.com/meta/meta-llama-3-70b-instruct", + owner: "meta", + name: "meta-llama-3-70b-instruct", + description: + "Llama 2 is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters.", + visibility: "public", + github_url: null, + paper_url: + "https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/", + license_url: "https://ai.meta.com/llama/license/", + run_count: 1000000, + cover_image_url: + "https://replicate.delivery/pbxt/IJqFrnAKEDiCBnlXyndzVVxkZvfQ7kLjGVEZZPXTRXxOOPkQA/llama2.png", + default_example: null, + latest_version: null, + }, + // ... more results ... + ], + next: null, + previous: null, + }); + + const searchResults = await client.models.search("llama"); + expect(searchResults.results.length).toBeGreaterThan(0); + expect(searchResults.results[0].owner).toBe("meta"); + expect(searchResults.results[0].name).toBe("meta-llama-3-70b-instruct"); + }); + + // Add more tests for error handling, edge cases, etc. + }); + describe("run", () => { test("Calls the correct API routes", async () => { nock(BASE_URL) @@ -1200,7 +1335,7 @@ describe("Replicate client", () => { (prediction) => { const progress = parseProgressFromLogs(prediction); callback(prediction, progress); - }, + } ); expect(output).toBe("Goodbye!"); @@ -1212,7 +1347,7 @@ describe("Replicate client", () => { status: "starting", logs: null, }, - null, + null ); expect(callback).toHaveBeenNthCalledWith( @@ -1226,7 +1361,7 @@ describe("Replicate client", () => { percentage: 0.4, current: 2, total: 5, - }, + } ); expect(callback).toHaveBeenNthCalledWith( @@ -1240,7 +1375,7 @@ describe("Replicate client", () => { percentage: 0.8, current: 4, total: 5, - }, + } ); expect(callback).toHaveBeenNthCalledWith( @@ -1255,7 +1390,7 @@ describe("Replicate client", () => { percentage: 1.0, current: 5, total: 5, - }, + } ); expect(callback).toHaveBeenCalledTimes(4); @@ -1289,7 +1424,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, wait: { interval: 1 }, }, - progress, + progress ); expect(output).toBe("Goodbye!"); @@ -1332,7 +1467,9 @@ describe("Replicate client", () => { output: "foobar", }); - await expect(client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } })).resolves.not.toThrow(); + await expect( + client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } }) + ).resolves.not.toThrow(); }); test("Throws an error for invalid identifiers", async () => { @@ -1349,12 +1486,15 @@ describe("Replicate client", () => { test("Throws an error if webhook URL is invalid", async () => { await expect(async () => { - await client.run("owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", { - input: { - text: "Alice", - }, - webhook: "invalid-url", - }); + await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { + text: "Alice", + }, + webhook: "invalid-url", + } + ); }).rejects.toThrow("Invalid webhook URL"); }); @@ -1393,7 +1533,7 @@ describe("Replicate client", () => { input: { text: "Hello, world!" }, signal, }, - onProgress, + onProgress ); expect(body).toBeDefined(); @@ -1405,19 +1545,19 @@ describe("Replicate client", () => { 1, expect.objectContaining({ status: "processing", - }), + }) ); expect(onProgress).toHaveBeenNthCalledWith( 2, expect.objectContaining({ status: "processing", - }), + }) ); expect(onProgress).toHaveBeenNthCalledWith( 3, expect.objectContaining({ status: "canceled", - }), + }) ); scope.done(); @@ -1442,7 +1582,8 @@ describe("Replicate client", () => { "Content-Type": "application/json", "Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek", "Webhook-Timestamp": "1614265330", - "Webhook-Signature": "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", + "Webhook-Signature": + "v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=", }, body: `{"test": 2432232314}`, }); @@ -1485,7 +1626,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1516,7 +1657,7 @@ describe("Replicate client", () => { id: EVENT_3 data: {} - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1550,7 +1691,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1582,7 +1723,7 @@ describe("Replicate client", () => { id: EVENT_2 data: {} - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1703,7 +1844,7 @@ describe("Replicate client", () => { id: EVENT_1 data: hello world - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1725,7 +1866,7 @@ describe("Replicate client", () => { id: EVENT_2 data: An unexpected error occurred - `.replace(/^[ ]+/gm, ""), + `.replace(/^[ ]+/gm, "") ); const iterator = stream[Symbol.asyncIterator](); @@ -1733,7 +1874,9 @@ describe("Replicate client", () => { done: false, value: { event: "output", id: "EVENT_1", data: "hello world" }, }); - await expect(iterator.next()).rejects.toThrowError("An unexpected error occurred"); + await expect(iterator.next()).rejects.toThrowError( + "An unexpected error occurred" + ); expect(await iterator.next()).toEqual({ done: true }); }); @@ -1741,7 +1884,7 @@ describe("Replicate client", () => { const stream = createStream("{}", 500); const iterator = stream[Symbol.asyncIterator](); await expect(iterator.next()).rejects.toThrowError( - "Request to https://stream.replicate.com/fake_stream failed with status 500", + "Request to https://stream.replicate.com/fake_stream failed with status 500" ); expect(await iterator.next()).toEqual({ done: true }); }); diff --git a/lib/models.js b/lib/models.js index c6a02fc..272d9ed 100644 --- a/lib/models.js +++ b/lib/models.js @@ -89,9 +89,28 @@ async function createModel(model_owner, model_name, options) { return response.json(); } +/** + * Search for public models + * + * @param {string} query - The search query + * @returns {Promise} Resolves with a page of models matching the search query + */ +async function search(query) { + const response = await this.request("/models", { + method: "QUERY", + headers: { + "Content-Type": "text/plain", + }, + data: query, + }); + + return response.json(); +} + module.exports = { get: getModel, list: listModels, create: createModel, versions: { list: listModelVersions, get: getModelVersion }, + search, };