diff --git a/index.d.ts b/index.d.ts index 25cd51a..4801654 100644 --- a/index.d.ts +++ b/index.d.ts @@ -8,6 +8,12 @@ declare module "replicate" { response: Response; } + export interface FileOutput extends ReadableStream { + blob(): Promise; + url(): URL; + toString(): string; + } + export interface Account { type: "user" | "organization"; username: string; @@ -137,6 +143,7 @@ declare module "replicate" { init?: RequestInit ) => Promise; fileEncodingStrategy?: FileEncodingStrategy; + useFileOutput?: boolean; }); auth: string; diff --git a/index.js b/index.js index 6b35db6..f2c3e1e 100644 --- a/index.js +++ b/index.js @@ -1,7 +1,8 @@ const ApiError = require("./lib/error"); const ModelVersionIdentifier = require("./lib/identifier"); -const { createReadableStream } = require("./lib/stream"); +const { createReadableStream, createFileOutput } = require("./lib/stream"); const { + transform, withAutomaticRetries, validateWebhook, parseProgressFromLogs, @@ -47,6 +48,7 @@ class Replicate { * @param {string} options.userAgent - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` + * @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false. * @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use */ constructor(options = {}) { @@ -57,7 +59,8 @@ class Replicate { options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; - this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default"; + this.fileEncodingStrategy = options.fileEncodingStrategy || "default"; + this.useFileOutput = options.useFileOutput || false; this.accounts = { current: accounts.current.bind(this), @@ -196,7 +199,17 @@ class Replicate { throw new Error(`Prediction failed: ${prediction.error}`); } - return prediction.output; + return transform(prediction.output, (value) => { + if ( + typeof value === "string" && + (value.startsWith("https:") || value.startsWith("data:")) + ) { + return this.useFileOutput + ? createFileOutput({ url: value, fetch: this.fetch }) + : value; + } + return value; + }); } /** diff --git a/index.test.ts b/index.test.ts index 7f9fcf2..5ca3e54 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1,13 +1,13 @@ import { expect, jest, test } from "@jest/globals"; import Replicate, { ApiError, + FileOutput, Model, Prediction, validateWebhook, parseProgressFromLogs, } from "replicate"; import nock from "nock"; -import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; let client: Replicate; @@ -1562,6 +1562,203 @@ describe("Replicate client", () => { scope.done(); }); + + test("returns FileOutput for URLs when useFileOutput is true", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "https://example.com", + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual(new URL("https://example.com")); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - acts like string", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "https://example.com", + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as unknown as string; + + expect(fetch(output).then((r) => r.text())).resolves.toEqual( + "hello world" + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - array output", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: ["https://example.com"], + logs: [].join("\n"), + }); + + nock("https://example.com") + .get("/") + .reply(200, "hello world", { "Content-Type": "text/plain" }); + + const [output] = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput[]; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual(new URL("https://example.com")); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); + + test("returns FileOutput for URLs when useFileOutput is true - data uri", async () => { + client = new Replicate({ auth: "foo", useFileOutput: true }); + + nock(BASE_URL) + .post("/predictions") + .reply(201, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "starting", + logs: null, + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "processing", + logs: [].join("\n"), + }) + .get("/predictions/ufawqhfynnddngldkgtslldrkq") + .reply(200, { + id: "ufawqhfynnddngldkgtslldrkq", + status: "succeeded", + output: "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", + logs: [].join("\n"), + }); + + const output = (await client.run( + "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + { + input: { text: "Hello, world!" }, + } + )) as FileOutput; + + expect(output).toBeInstanceOf(ReadableStream); + expect(output.url()).toEqual( + new URL("data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==") + ); + + const blob = await output.blob(); + expect(blob.type).toEqual("text/plain"); + expect(blob.arrayBuffer()).toEqual( + new Blob(["Hello, world!"]).arrayBuffer() + ); + }); }); describe("webhooks.default.secret.get", () => { diff --git a/lib/stream.js b/lib/stream.js index 2e0bbde..2f72e2c 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -98,7 +98,70 @@ function createReadableStream({ url, fetch, options = {} }) { }); } +/** + * Create a new readable stream for an output file + * created by running a Replicate model. + * + * @param {object} config + * @param {string} config.url The URL to connect to. + * @param {typeof fetch} [config.fetch] The URL to connect to. + * @returns {ReadableStream} + */ +function createFileOutput({ url, fetch }) { + let type = "application/octet-stream"; + + class FileOutput extends ReadableStream { + async blob() { + const chunks = []; + for await (const chunk of this) { + chunks.push(chunk); + } + return new Blob(chunks, { type }); + } + + url() { + return new URL(url); + } + + toString() { + return url; + } + } + + return new FileOutput({ + async start(controller) { + const response = await fetch(url); + + if (!response.ok) { + const text = await response.text(); + const request = new Request(url, init); + controller.error( + new ApiError( + `Request to ${url} failed with status ${response.status}: ${text}`, + request, + response + ) + ); + } + + if (response.headers.get("Content-Type")) { + type = response.headers.get("Content-Type"); + } + + try { + for await (const chunk of streamAsyncIterator(response.body)) { + controller.enqueue(chunk); + } + controller.close(); + } catch (err) { + controller.error(err); + } + }, + }); +} + module.exports = { + createFileOutput, createReadableStream, ServerSentEvent, }; diff --git a/lib/util.js b/lib/util.js index daecfd1..bd3c31e 100644 --- a/lib/util.js +++ b/lib/util.js @@ -318,7 +318,7 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { } const data = bytesToBase64(buffer); - mime = mime ?? "application/octet-stream"; + mime = mime || "application/octet-stream"; return `data:${mime};base64,${data}`; }); @@ -452,6 +452,7 @@ async function* streamAsyncIterator(stream) { } module.exports = { + transform, transformFileInputs, validateWebhook, withAutomaticRetries,