Skip to content

Commit

Permalink
(EAI-137) [Q&A] Custom request origins (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
nlarew authored Nov 13, 2023
1 parent a6d1521 commit 7d08d4e
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 17 deletions.
67 changes: 59 additions & 8 deletions chat-server/src/middleware/requestOrigin.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,77 @@ const baseReq = {
ip: "127.0.0.1",
};

function caseInsensitiveHeaders(headers: Record<string, string>) {
// Express automatically converts all headers to lowercase but
// node-mocks-http does not. This function is a workaround for that.
return Object.entries(headers).reduce((acc, [key, value]) => {
acc[key.toLowerCase()] = value;
return acc;
}, {} as Record<string, string>);
}

describe("requireRequestOrigin", () => {
it("blocks any request where the Origin header is not set", async () => {
it(`blocks any request where neither the Origin nor the X-Request-Origin header is set`, async () => {
const req = createRequest();
const res = createResponse();
const next = jest.fn();

const middleware = requireRequestOrigin();
req.body = baseReq.body
req.params = baseReq.params
req.query = baseReq.query
req.headers = baseReq.headers
req.ip = baseReq.ip
req.body = baseReq.body;
req.params = baseReq.params;
req.query = baseReq.query;
req.headers = baseReq.headers;
req.ip = baseReq.ip;

await middleware(req, res, next);

expect(next).toHaveBeenCalledTimes(0);
expect(res.statusCode).toBe(400);
expect(res._getJSONData()).toEqual({
error: "No Origin header",
error: "You must specify either an Origin or X-Request-Origin header",
});
});
it(`allows any request where the Origin header is set`, async () => {
const req = createRequest();
const res = createResponse();
const next = jest.fn();

const middleware = requireRequestOrigin();
req.body = baseReq.body;
req.params = baseReq.params;
req.query = baseReq.query;
req.headers = caseInsensitiveHeaders({
...baseReq.headers,
origin: "http://localhost:5173",
});
req.ip = baseReq.ip;

await middleware(req, res, next);

expect(next).toHaveBeenCalledTimes(1);
expect(req.origin).toEqual("http://localhost:5173");
});
it(`allows any request where the X-Request-Origin header is set`, async () => {
const req = createRequest();
const res = createResponse();
const next = jest.fn();

const middleware = requireRequestOrigin();
req.body = baseReq.body;
req.params = baseReq.params;
req.query = baseReq.query;
req.headers = caseInsensitiveHeaders({
...baseReq.headers,
"X-Request-Origin": "http://localhost:5173/foo/bar",
});
req.ip = baseReq.ip;

await middleware(req, res, next);

expect(next).toHaveBeenCalledTimes(1);
expect(req.origin).toEqual("http://localhost:5173/foo/bar");
});
it("allows any request where the Origin header is set", async () => {
it(`prefers X-Request-Origin over Origin when both are set`, async () => {
const req = createRequest();
const res = createResponse();
const next = jest.fn();
Expand All @@ -42,11 +91,13 @@ describe("requireRequestOrigin", () => {
req.headers = {
...baseReq.headers,
origin: "http://localhost:5173",
"x-request-origin": "http://localhost:5173/foo/bar",
};
req.ip = baseReq.ip;

await middleware(req, res, next);

expect(next).toHaveBeenCalledTimes(1);
expect(req.origin).toEqual("http://localhost:5173/foo/bar");
});
});
18 changes: 13 additions & 5 deletions chat-server/src/middleware/requestOrigin.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { getRequestId, logRequest, sendErrorResponse } from "../utils";

export const CUSTOM_REQUEST_ORIGIN_HEADER = "X-Request-Origin";

declare module "express-serve-static-core" {
interface Request {
origin: string;
Expand All @@ -10,20 +12,26 @@ declare module "express-serve-static-core" {
export function requireRequestOrigin() {
return async (req: Request, res: Response, next: NextFunction) => {
const reqId = getRequestId(req);
const { origin } = req.headers;
if (!origin) {

const origin = req.header("origin");
const customOrigin = req.header(CUSTOM_REQUEST_ORIGIN_HEADER);
const requestOrigin = customOrigin || origin;

if (!requestOrigin) {
return sendErrorResponse({
reqId,
res,
httpStatus: 400,
errorMessage: "No Origin header",
errorMessage: `You must specify either an Origin or ${CUSTOM_REQUEST_ORIGIN_HEADER} header`,
});
}

req.origin = requestOrigin;
logRequest({
reqId,
message: `Request origin ${origin} is allowed`,
message: `Request origin ${req.origin} is allowed`,
});
req.origin = origin;

return next();
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,14 @@ describe("POST /conversations/:conversationId/messages", () => {
});
});

it("should respond 400 if the Origin header is missing", async () => {
it("should respond 400 if neither the Origin nor X-Request-Origin header is present", async () => {
const res: request.Response = await request(app)
.post(endpointUrl.replace(":conversationId", conversationId))
.send({ message: "howdy there" });
expect(res.statusCode).toEqual(400);
expect(res.body).toEqual({ error: "No Origin header" });
expect(res.body).toEqual({
error: "You must specify either an Origin or X-Request-Origin header",
});
});

it("should respond 400 for invalid request bodies", async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ describe("POST /conversations", () => {
expect(count).toBe(1);
});

it("should respond 400 if the Origin header is missing", async () => {
it("should respond 400 if neither the Origin nor X-Request-Origin header is present", async () => {
const res = await request(app).post(CONVERSATIONS_API_V1_PREFIX).send();
expect(res.statusCode).toEqual(400);
expect(res.body).toEqual({
error: "You must specify either an Origin or X-Request-Origin header",
});
});
});
23 changes: 22 additions & 1 deletion chat-ui/src/services/conversations.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { vi } from "vitest";
import { ConversationService, formatReferences } from "./conversations";
import { ConversationService, formatReferences, getCustomRequestOrigin } from "./conversations";
import { type References } from "mongodb-rag-core";
import * as FetchEventSource from "@microsoft/fetch-event-source";

Expand Down Expand Up @@ -251,3 +251,24 @@ describe("formatReferences", () => {
);
});
});

describe("getCustomRequestOrigin", () => {
it("returns the current window location if it exists", () => {
const mockWindowLocation = "https://example.com/foo/bar";
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).window = {
...global.window,
location: {
...global.window.location,
href: mockWindowLocation,
},
};
expect(getCustomRequestOrigin()).toEqual(mockWindowLocation);
})

it("returns null if the current window location does not exist", () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).window = undefined;
expect(getCustomRequestOrigin()).toEqual(undefined);
})
});
12 changes: 12 additions & 0 deletions chat-ui/src/services/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ export function formatReferences(references: References): string {
return [heading, ...listOfLinks].join("\n\n");
}

export const CUSTOM_REQUEST_ORIGIN_HEADER = "X-Request-Origin";

export function getCustomRequestOrigin() {
if (typeof window !== "undefined") {
return window.location.href;
}
return undefined;
}

class RetriableError<Data extends object = object> extends Error {
retryAfter: number;
data?: Data;
Expand Down Expand Up @@ -96,6 +105,7 @@ export class ConversationService {
method: "POST",
headers: {
"Content-Type": "application/json",
[CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "",
},
});
const conversation = await resp.json();
Expand Down Expand Up @@ -127,6 +137,7 @@ export class ConversationService {
method: "POST",
headers: {
"Content-Type": "application/json",
[CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "",
},
body: JSON.stringify({ message }),
});
Expand Down Expand Up @@ -192,6 +203,7 @@ export class ConversationService {
method: "POST",
headers: {
"Content-Type": "application/json",
[CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "",
},
body: JSON.stringify({ message }),
openWhenHidden: true,
Expand Down

0 comments on commit 7d08d4e

Please sign in to comment.