Skip to content

Add support for validating webhooks #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ declare module "replicate" {
retry?: number;
}

export interface WebhookSecret {
key: string;
}

export default class Replicate {
constructor(options?: {
auth?: string;
Expand Down Expand Up @@ -233,5 +237,26 @@ declare module "replicate" {
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
};

webhooks: {
default: {
secret: {
get(): Promise<WebhookSecret>;
};
};
};
}

export function validateWebhook(
requestData:
| Request
| {
id?: string;
timestamp?: string;
body: string;
secret?: string;
signature?: string;
},
secret: string
): boolean;
}
12 changes: 11 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const ApiError = require("./lib/error");
const ModelVersionIdentifier = require("./lib/identifier");
const { Stream } = require("./lib/stream");
const { withAutomaticRetries } = require("./lib/util");
const { withAutomaticRetries, validateWebhook } = require("./lib/util");

const accounts = require("./lib/accounts");
const collections = require("./lib/collections");
Expand All @@ -10,6 +10,7 @@ const hardware = require("./lib/hardware");
const models = require("./lib/models");
const predictions = require("./lib/predictions");
const trainings = require("./lib/trainings");
const webhooks = require("./lib/webhooks");

const packageJSON = require("./package.json");

Expand Down Expand Up @@ -90,6 +91,14 @@ class Replicate {
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this),
};

this.webhooks = {
default: {
secret: {
get: webhooks.default.secret.get.bind(this),
},
},
};
}

/**
Expand Down Expand Up @@ -364,3 +373,4 @@ class Replicate {
}

module.exports = Replicate;
module.exports.validateWebhook = validateWebhook;
41 changes: 40 additions & 1 deletion index.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { expect, jest, test } from "@jest/globals";
import Replicate, { ApiError, Model, Prediction } from "replicate";
import Replicate, {
ApiError,
Model,
Prediction,
validateWebhook,
} from "replicate";
import nock from "nock";
import fetch from "cross-fetch";

Expand Down Expand Up @@ -996,5 +1001,39 @@ describe("Replicate client", () => {
});
});

describe("webhooks.default.secret.get", () => {
test("Calls the correct API route", async () => {
nock(BASE_URL).get("/webhooks/default/secret").reply(200, {
key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH",
});

const secret = await client.webhooks.default.secret.get();
expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH");
});

test("Can be used to validate webhook", async () => {
// Test case from https://github.com/svix/svix-webhooks/blob/b41728cd98a7e7004a6407a623f43977b82fcba4/javascript/src/webhook.test.ts#L190-L200
const request = new Request("http://test.host/webhook", {
method: "POST",
headers: {
"Content-Type": "application/json",
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
"Webhook-Timestamp": "1614265330",
"Webhook-Signature":
"v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
},
body: `{"test": 2432232314}`,
});

// This is a test secret and should not be used in production
const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw";

const isValid = await validateWebhook(request, secret);
expect(isValid).toBe(true);
});

// Add more tests for error handling, edge cases, etc.
});

// Continue with tests for other methods
});
90 changes: 89 additions & 1 deletion lib/util.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,93 @@
const crypto = require("node:crypto");

const ApiError = require("./error");

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request data
* @param {string} requestData.id - The webhook ID header from the incoming request.
* @param {string} requestData.timestamp - The webhook timestamp header from the incoming request.
* @param {string} requestData.body - The raw body of the incoming webhook request.
* @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
* @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
*/

/**
* @see {@link validateWebhook}
* @overload
* @param {object} requestData - The request object
* @param {object} requestData.headers - The request headers
* @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request
* @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request
* @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures
* @param {string} requestData.body - The raw body of the incoming webhook request
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method
*/

/**
* Validate a webhook signature
*
* @returns {boolean} - True if the signature is valid
* @throws {Error} - If the request is missing required headers, body, or secret
*/
async function validateWebhook(requestData, secret) {
let { id, timestamp, body, signature } = requestData;
const signingSecret = secret || requestData.secret;

if (requestData && requestData.headers && requestData.body) {
id = requestData.headers.get("webhook-id");
timestamp = requestData.headers.get("webhook-timestamp");
signature = requestData.headers.get("webhook-signature");
body = requestData.body;
}

if (body instanceof ReadableStream || body.readable) {
try {
const chunks = [];
for await (const chunk of body) {
chunks.push(Buffer.from(chunk));
}
body = Buffer.concat(chunks).toString("utf8");
} catch (err) {
throw new Error(`Error reading body: ${err.message}`);
}
} else if (body instanceof Buffer) {
body = body.toString("utf8");
} else if (typeof body !== "string") {
throw new Error("Invalid body type");
}

if (!id || !timestamp || !signature) {
throw new Error("Missing required webhook headers");
}

if (!body) {
throw new Error("Missing required body");
}

if (!signingSecret) {
throw new Error("Missing required secret");
}

const signedContent = `${id}.${timestamp}.${body}`;

const secretBytes = Buffer.from(signingSecret.split("_")[1], "base64");

const computedSignature = crypto
.createHmac("sha256", secretBytes)
.update(signedContent)
.digest("base64");

const expectedSignatures = signature
.split(" ")
.map((sig) => sig.split(",")[1]);

return expectedSignatures.some(
(expectedSignature) => expectedSignature === computedSignature
);
}

/**
* Automatically retry a request if it fails with an appropriate status code.
*
Expand Down Expand Up @@ -68,4 +156,4 @@ async function withAutomaticRetries(request, options = {}) {
return request();
}

module.exports = { withAutomaticRetries };
module.exports = { validateWebhook, withAutomaticRetries };
20 changes: 20 additions & 0 deletions lib/webhooks.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/**
* Get the default webhook signing secret
*
* @returns {Promise<object>} Resolves with the signing secret for the default webhook
*/
async function getDefaultWebhookSecret() {
const response = await this.request("/webhooks/default/secret", {
method: "GET",
});

return response.json();
}

module.exports = {
default: {
secret: {
get: getDefaultWebhookSecret,
},
},
};