diff --git a/chat-server/src/middleware/requestOrigin.test.ts b/chat-server/src/middleware/requestOrigin.test.ts index 11b064f63..659fb5bae 100644 --- a/chat-server/src/middleware/requestOrigin.test.ts +++ b/chat-server/src/middleware/requestOrigin.test.ts @@ -9,28 +9,77 @@ const baseReq = { ip: "127.0.0.1", }; +function caseInsensitiveHeaders(headers: Record) { + // Express automatically converts all headers to lowercase but + // node-mocks-http does not. This function is a workaround for that. + return Object.entries(headers).reduce((acc, [key, value]) => { + acc[key.toLowerCase()] = value; + return acc; + }, {} as Record); +} + describe("requireRequestOrigin", () => { - it("blocks any request where the Origin header is not set", async () => { + it(`blocks any request where neither the Origin nor the X-Request-Origin header is set`, async () => { const req = createRequest(); const res = createResponse(); const next = jest.fn(); const middleware = requireRequestOrigin(); - req.body = baseReq.body - req.params = baseReq.params - req.query = baseReq.query - req.headers = baseReq.headers - req.ip = baseReq.ip + req.body = baseReq.body; + req.params = baseReq.params; + req.query = baseReq.query; + req.headers = baseReq.headers; + req.ip = baseReq.ip; await middleware(req, res, next); expect(next).toHaveBeenCalledTimes(0); expect(res.statusCode).toBe(400); expect(res._getJSONData()).toEqual({ - error: "No Origin header", + error: "You must specify either an Origin or X-Request-Origin header", + }); + }); + it(`allows any request where the Origin header is set`, async () => { + const req = createRequest(); + const res = createResponse(); + const next = jest.fn(); + + const middleware = requireRequestOrigin(); + req.body = baseReq.body; + req.params = baseReq.params; + req.query = baseReq.query; + req.headers = caseInsensitiveHeaders({ + ...baseReq.headers, + origin: "http://localhost:5173", + }); + req.ip = baseReq.ip; + + await middleware(req, res, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(req.origin).toEqual("http://localhost:5173"); + }); + it(`allows any request where the X-Request-Origin header is set`, async () => { + const req = createRequest(); + const res = createResponse(); + const next = jest.fn(); + + const middleware = requireRequestOrigin(); + req.body = baseReq.body; + req.params = baseReq.params; + req.query = baseReq.query; + req.headers = caseInsensitiveHeaders({ + ...baseReq.headers, + "X-Request-Origin": "http://localhost:5173/foo/bar", }); + req.ip = baseReq.ip; + + await middleware(req, res, next); + + expect(next).toHaveBeenCalledTimes(1); + expect(req.origin).toEqual("http://localhost:5173/foo/bar"); }); - it("allows any request where the Origin header is set", async () => { + it(`prefers X-Request-Origin over Origin when both are set`, async () => { const req = createRequest(); const res = createResponse(); const next = jest.fn(); @@ -42,11 +91,13 @@ describe("requireRequestOrigin", () => { req.headers = { ...baseReq.headers, origin: "http://localhost:5173", + "x-request-origin": "http://localhost:5173/foo/bar", }; req.ip = baseReq.ip; await middleware(req, res, next); expect(next).toHaveBeenCalledTimes(1); + expect(req.origin).toEqual("http://localhost:5173/foo/bar"); }); }); diff --git a/chat-server/src/middleware/requestOrigin.ts b/chat-server/src/middleware/requestOrigin.ts index dc5e3a6b3..2194bb2b0 100644 --- a/chat-server/src/middleware/requestOrigin.ts +++ b/chat-server/src/middleware/requestOrigin.ts @@ -1,6 +1,8 @@ import { Request, Response, NextFunction } from "express"; import { getRequestId, logRequest, sendErrorResponse } from "../utils"; +export const CUSTOM_REQUEST_ORIGIN_HEADER = "X-Request-Origin"; + declare module "express-serve-static-core" { interface Request { origin: string; @@ -10,20 +12,26 @@ declare module "express-serve-static-core" { export function requireRequestOrigin() { return async (req: Request, res: Response, next: NextFunction) => { const reqId = getRequestId(req); - const { origin } = req.headers; - if (!origin) { + + const origin = req.header("origin"); + const customOrigin = req.header(CUSTOM_REQUEST_ORIGIN_HEADER); + const requestOrigin = customOrigin || origin; + + if (!requestOrigin) { return sendErrorResponse({ reqId, res, httpStatus: 400, - errorMessage: "No Origin header", + errorMessage: `You must specify either an Origin or ${CUSTOM_REQUEST_ORIGIN_HEADER} header`, }); } + + req.origin = requestOrigin; logRequest({ reqId, - message: `Request origin ${origin} is allowed`, + message: `Request origin ${req.origin} is allowed`, }); - req.origin = origin; + return next(); }; } diff --git a/chat-server/src/routes/conversations/addMessageToConversation.test.ts b/chat-server/src/routes/conversations/addMessageToConversation.test.ts index 895aee81f..aa6cf0d79 100644 --- a/chat-server/src/routes/conversations/addMessageToConversation.test.ts +++ b/chat-server/src/routes/conversations/addMessageToConversation.test.ts @@ -160,12 +160,14 @@ describe("POST /conversations/:conversationId/messages", () => { }); }); - it("should respond 400 if the Origin header is missing", async () => { + it("should respond 400 if neither the Origin nor X-Request-Origin header is present", async () => { const res: request.Response = await request(app) .post(endpointUrl.replace(":conversationId", conversationId)) .send({ message: "howdy there" }); expect(res.statusCode).toEqual(400); - expect(res.body).toEqual({ error: "No Origin header" }); + expect(res.body).toEqual({ + error: "You must specify either an Origin or X-Request-Origin header", + }); }); it("should respond 400 for invalid request bodies", async () => { diff --git a/chat-server/src/routes/conversations/createConversation.test.ts b/chat-server/src/routes/conversations/createConversation.test.ts index 0210ce51a..a8581569e 100644 --- a/chat-server/src/routes/conversations/createConversation.test.ts +++ b/chat-server/src/routes/conversations/createConversation.test.ts @@ -35,8 +35,11 @@ describe("POST /conversations", () => { expect(count).toBe(1); }); - it("should respond 400 if the Origin header is missing", async () => { + it("should respond 400 if neither the Origin nor X-Request-Origin header is present", async () => { const res = await request(app).post(CONVERSATIONS_API_V1_PREFIX).send(); expect(res.statusCode).toEqual(400); + expect(res.body).toEqual({ + error: "You must specify either an Origin or X-Request-Origin header", + }); }); }); diff --git a/chat-ui/src/services/conversations.test.ts b/chat-ui/src/services/conversations.test.ts index 8cf303c03..2d48bf462 100644 --- a/chat-ui/src/services/conversations.test.ts +++ b/chat-ui/src/services/conversations.test.ts @@ -1,5 +1,5 @@ import { vi } from "vitest"; -import { ConversationService, formatReferences } from "./conversations"; +import { ConversationService, formatReferences, getCustomRequestOrigin } from "./conversations"; import { type References } from "mongodb-rag-core"; import * as FetchEventSource from "@microsoft/fetch-event-source"; @@ -251,3 +251,24 @@ describe("formatReferences", () => { ); }); }); + +describe("getCustomRequestOrigin", () => { + it("returns the current window location if it exists", () => { + const mockWindowLocation = "https://example.com/foo/bar"; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (global as any).window = { + ...global.window, + location: { + ...global.window.location, + href: mockWindowLocation, + }, + }; + expect(getCustomRequestOrigin()).toEqual(mockWindowLocation); + }) + + it("returns null if the current window location does not exist", () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (global as any).window = undefined; + expect(getCustomRequestOrigin()).toEqual(undefined); + }) +}); diff --git a/chat-ui/src/services/conversations.ts b/chat-ui/src/services/conversations.ts index 8bd81cde7..bc489e89b 100644 --- a/chat-ui/src/services/conversations.ts +++ b/chat-ui/src/services/conversations.ts @@ -26,6 +26,15 @@ export function formatReferences(references: References): string { return [heading, ...listOfLinks].join("\n\n"); } +export const CUSTOM_REQUEST_ORIGIN_HEADER = "X-Request-Origin"; + +export function getCustomRequestOrigin() { + if (typeof window !== "undefined") { + return window.location.href; + } + return undefined; +} + class RetriableError extends Error { retryAfter: number; data?: Data; @@ -96,6 +105,7 @@ export class ConversationService { method: "POST", headers: { "Content-Type": "application/json", + [CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "", }, }); const conversation = await resp.json(); @@ -127,6 +137,7 @@ export class ConversationService { method: "POST", headers: { "Content-Type": "application/json", + [CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "", }, body: JSON.stringify({ message }), }); @@ -192,6 +203,7 @@ export class ConversationService { method: "POST", headers: { "Content-Type": "application/json", + [CUSTOM_REQUEST_ORIGIN_HEADER]: getCustomRequestOrigin() ?? "", }, body: JSON.stringify({ message }), openWhenHidden: true,