diff --git a/README.md b/README.md index 6893da2..222deff 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ console.log(prediction.output); // ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png'] ``` -To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can convert file data into a base64-encoded data URI and pass that directly: +To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can pass the data directly. ```js const fs = require("node:fs/promises"); @@ -81,18 +81,9 @@ const fs = require("node:fs/promises"); // Or when using ESM. // import fs from "node:fs/promises"; -// Read the file into a buffer -const data = await fs.readFile("path/to/image.png"); -// Convert the buffer into a base64-encoded string -const base64 = data.toString("base64"); -// Set MIME type for PNG image -const mimeType = "image/png"; -// Create the data URI -const dataURI = `data:${mimeType};base64,${base64}`; - const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b"; const input = { - image: dataURI, + image: await fs.readFile("path/to/image.png"), }; const output = await replicate.run(model, { input }); // ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png'] diff --git a/index.test.ts b/index.test.ts index ae01338..f00a7e6 100644 --- a/index.test.ts +++ b/index.test.ts @@ -221,6 +221,54 @@ describe("Replicate client", () => { expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq"); }); + test.each([ + // Skip test case if File type is not available + ...(typeof File !== "undefined" + ? [ + { + type: "file", + value: new File(["hello world"], "hello.txt", { + type: "text/plain", + }), + expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=", + }, + ] + : []), + { + type: "blob", + value: new Blob(["hello world"], { type: "text/plain" }), + expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=", + }, + { + type: "buffer", + value: Buffer.from("hello world"), + expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=", + }, + ])( + "converts a $type input into a base64 encoded string", + async ({ value: data, expected }) => { + let actual: Record | undefined; + nock(BASE_URL) + .post("/predictions") + .reply(201, (uri: string, body: Record) => { + actual = body; + return body; + }); + + await client.predictions.create({ + version: + "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + input: { + prompt: "Tell me a story", + data, + }, + stream: true, + }); + + expect(actual?.input.data).toEqual(expected); + } + ); + test("Passes stream parameter to API endpoint", async () => { nock(BASE_URL) .post("/predictions") diff --git a/lib/deployments.js b/lib/deployments.js index 8ba5ea3..3e1ceeb 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -1,3 +1,5 @@ +const { transformFileInputs } = require("./util"); + /** * Create a new prediction with a deployment * @@ -11,7 +13,7 @@ * @returns {Promise} Resolves with the created prediction data */ async function createPrediction(deployment_owner, deployment_name, options) { - const { stream, ...data } = options; + const { stream, input, ...data } = options; if (data.webhook) { try { @@ -26,7 +28,11 @@ async function createPrediction(deployment_owner, deployment_name, options) { `/deployments/${deployment_owner}/${deployment_name}/predictions`, { method: "POST", - data: { ...data, stream }, + data: { + ...data, + input: await transformFileInputs(input), + stream, + }, } ); diff --git a/lib/predictions.js b/lib/predictions.js index 294e8d9..5b0370e 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -1,3 +1,5 @@ +const { transformFileInputs } = require("./util"); + /** * Create a new prediction * @@ -11,7 +13,7 @@ * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { - const { model, version, stream, ...data } = options; + const { model, version, stream, input, ...data } = options; if (data.webhook) { try { @@ -26,12 +28,21 @@ async function createPrediction(options) { if (version) { response = await this.request("/predictions", { method: "POST", - data: { ...data, stream, version }, + data: { + ...data, + input: await transformFileInputs(input), + version, + stream, + }, }); } else if (model) { response = await this.request(`/models/${model}/predictions`, { method: "POST", - data: { ...data, stream }, + data: { + ...data, + input: await transformFileInputs(input), + stream, + }, }); } else { throw new Error("Either model or version must be specified"); diff --git a/lib/util.js b/lib/util.js index 6bd70ec..48d7563 100644 --- a/lib/util.js +++ b/lib/util.js @@ -156,4 +156,94 @@ async function withAutomaticRetries(request, options = {}) { return request(); } -module.exports = { validateWebhook, withAutomaticRetries }; +const MAX_DATA_URI_SIZE = 10_000_000; + +/** + * Walks the inputs and transforms any binary data found into a + * base64-encoded data URI. + * + * @param {object} inputs - The inputs to transform + * @returns {object} - The transformed inputs + * @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE + */ +async function transformFileInputs(inputs) { + let totalBytes = 0; + const result = await transform(inputs, async (value) => { + let buffer; + let mime; + + if (value instanceof Blob) { + // Currently we use a NodeJS only API for base64 encoding, as + // we move to support the browser we could support either using + // btoa (which does string encoding), the FileReader API or + // a JavaScript implenentation like base64-js. + // See: https://developer.mozilla.org/en-US/docs/Glossary/Base64 + // See: https://github.com/beatgammit/base64-js + buffer = Buffer.from(await value.arrayBuffer()); + mime = value.type; + } else if (Buffer.isBuffer(value)) { + buffer = value; + } else { + return value; + } + + totalBytes += buffer.byteLength; + if (totalBytes > MAX_DATA_URI_SIZE) { + throw new Error( + `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead` + ); + } + + const data = buffer.toString("base64"); + mime = mime ?? "application/octet-stream"; + + return `data:${mime};base64,${data}`; + }); + + return result; +} + +// Walk a JavaScript object and transform the leaf values. +async function transform(value, mapper) { + if (Array.isArray(value)) { + let copy = []; + for (const val of value) { + copy = await transform(val, mapper); + } + return copy; + } + + if (isPlainObject(value)) { + const copy = {}; + for (const key of Object.keys(value)) { + copy[key] = await transform(value[key], mapper); + } + return copy; + } + + return await mapper(value); +} + +// Test for a plain JS object. +// Source: lodash.isPlainObject +function isPlainObject(value) { + const isObjectLike = typeof value === "object" && value !== null; + if (!isObjectLike || String(value) !== "[object Object]") { + return false; + } + const proto = Object.getPrototypeOf(value); + if (proto === null) { + return true; + } + const Ctor = + Object.prototype.hasOwnProperty.call(proto, "constructor") && + proto.constructor; + return ( + typeof Ctor === "function" && + Ctor instanceof Ctor && + Function.prototype.toString.call(Ctor) === + Function.prototype.toString.call(Object) + ); +} + +module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries };