Skip to content

Automatically Base64 encode inputs #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,17 @@ console.log(prediction.output);
// ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png']
```

To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can convert file data into a base64-encoded data URI and pass that directly:
To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can pass the data directly.

```js
const fs = require("node:fs/promises");

// Or when using ESM.
// import fs from "node:fs/promises";

// Read the file into a buffer
const data = await fs.readFile("path/to/image.png");
// Convert the buffer into a base64-encoded string
const base64 = data.toString("base64");
// Set MIME type for PNG image
const mimeType = "image/png";
// Create the data URI
const dataURI = `data:${mimeType};base64,${base64}`;

const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b";
const input = {
image: dataURI,
image: await fs.readFile("path/to/image.png"),
};
const output = await replicate.run(model, { input });
// ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png']
Expand Down
48 changes: 48 additions & 0 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,54 @@ describe("Replicate client", () => {
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
});

test.each([
// Skip test case if File type is not available
...(typeof File !== "undefined"
? [
{
type: "file",
value: new File(["hello world"], "hello.txt", {
type: "text/plain",
}),
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
},
]
: []),
{
type: "blob",
value: new Blob(["hello world"], { type: "text/plain" }),
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
},
{
type: "buffer",
value: Buffer.from("hello world"),
expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
},
])(
"converts a $type input into a base64 encoded string",
async ({ value: data, expected }) => {
let actual: Record<string, any> | undefined;
nock(BASE_URL)
.post("/predictions")
.reply(201, (uri: string, body: Record<string, any>) => {
actual = body;
return body;
});

await client.predictions.create({
version:
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
input: {
prompt: "Tell me a story",
data,
},
stream: true,
});

expect(actual?.input.data).toEqual(expected);
}
);

test("Passes stream parameter to API endpoint", async () => {
nock(BASE_URL)
.post("/predictions")
Expand Down
10 changes: 8 additions & 2 deletions lib/deployments.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const { transformFileInputs } = require("./util");

/**
* Create a new prediction with a deployment
*
Expand All @@ -11,7 +13,7 @@
* @returns {Promise<object>} Resolves with the created prediction data
*/
async function createPrediction(deployment_owner, deployment_name, options) {
const { stream, ...data } = options;
const { stream, input, ...data } = options;

if (data.webhook) {
try {
Expand All @@ -26,7 +28,11 @@ async function createPrediction(deployment_owner, deployment_name, options) {
`/deployments/${deployment_owner}/${deployment_name}/predictions`,
{
method: "POST",
data: { ...data, stream },
data: {
...data,
input: await transformFileInputs(input),
stream,
},
}
);

Expand Down
17 changes: 14 additions & 3 deletions lib/predictions.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const { transformFileInputs } = require("./util");

/**
* Create a new prediction
*
Expand All @@ -11,7 +13,7 @@
* @returns {Promise<object>} Resolves with the created prediction
*/
async function createPrediction(options) {
const { model, version, stream, ...data } = options;
const { model, version, stream, input, ...data } = options;

if (data.webhook) {
try {
Expand All @@ -26,12 +28,21 @@ async function createPrediction(options) {
if (version) {
response = await this.request("/predictions", {
method: "POST",
data: { ...data, stream, version },
data: {
...data,
input: await transformFileInputs(input),
version,
stream,
},
});
} else if (model) {
response = await this.request(`/models/${model}/predictions`, {
method: "POST",
data: { ...data, stream },
data: {
...data,
input: await transformFileInputs(input),
stream,
},
});
} else {
throw new Error("Either model or version must be specified");
Expand Down
92 changes: 91 additions & 1 deletion lib/util.js
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,94 @@ async function withAutomaticRetries(request, options = {}) {
return request();
}

module.exports = { validateWebhook, withAutomaticRetries };
const MAX_DATA_URI_SIZE = 10_000_000;

/**
* Walks the inputs and transforms any binary data found into a
* base64-encoded data URI.
*
* @param {object} inputs - The inputs to transform
* @returns {object} - The transformed inputs
* @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE
*/
async function transformFileInputs(inputs) {
let totalBytes = 0;
const result = await transform(inputs, async (value) => {
let buffer;
let mime;

if (value instanceof Blob) {
// Currently we use a NodeJS only API for base64 encoding, as
// we move to support the browser we could support either using
// btoa (which does string encoding), the FileReader API or
// a JavaScript implenentation like base64-js.
// See: https://developer.mozilla.org/en-US/docs/Glossary/Base64
// See: https://github.com/beatgammit/base64-js
buffer = Buffer.from(await value.arrayBuffer());
mime = value.type;
} else if (Buffer.isBuffer(value)) {
buffer = value;
} else {
return value;
}

totalBytes += buffer.byteLength;
if (totalBytes > MAX_DATA_URI_SIZE) {
throw new Error(
`Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`
);
}

const data = buffer.toString("base64");
mime = mime ?? "application/octet-stream";

return `data:${mime};base64,${data}`;
});

return result;
}

// Walk a JavaScript object and transform the leaf values.
async function transform(value, mapper) {
if (Array.isArray(value)) {
let copy = [];
for (const val of value) {
copy = await transform(val, mapper);
}
return copy;
}

if (isPlainObject(value)) {
const copy = {};
for (const key of Object.keys(value)) {
copy[key] = await transform(value[key], mapper);
}
return copy;
}

return await mapper(value);
}

// Test for a plain JS object.
// Source: lodash.isPlainObject
function isPlainObject(value) {
const isObjectLike = typeof value === "object" && value !== null;
if (!isObjectLike || String(value) !== "[object Object]") {
return false;
}
const proto = Object.getPrototypeOf(value);
if (proto === null) {
return true;
}
const Ctor =
Object.prototype.hasOwnProperty.call(proto, "constructor") &&
proto.constructor;
return (
typeof Ctor === "function" &&
Ctor instanceof Ctor &&
Function.prototype.toString.call(Ctor) ===
Function.prototype.toString.call(Object)
);
}

module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries };