diff --git a/packages/core/modules-sdk/src/definitions.ts b/packages/core/modules-sdk/src/definitions.ts index 00fd617168d03..e429682597b35 100644 --- a/packages/core/modules-sdk/src/definitions.ts +++ b/packages/core/modules-sdk/src/definitions.ts @@ -92,7 +92,7 @@ export const ModulesDefinition: { label: upperCaseFirst(Modules.AUTH), isRequired: false, isQueryable: true, - dependencies: [ContainerRegistrationKeys.LOGGER], + dependencies: [ContainerRegistrationKeys.LOGGER, Modules.CACHE], defaultModuleDeclaration: { scope: MODULE_SCOPE.INTERNAL, }, diff --git a/packages/core/types/src/auth/common/provider.ts b/packages/core/types/src/auth/common/provider.ts index 0cc313ffb6a01..bde741e0f92aa 100644 --- a/packages/core/types/src/auth/common/provider.ts +++ b/packages/core/types/src/auth/common/provider.ts @@ -59,6 +59,10 @@ export type AuthenticationInput = { /** * Body of the incoming authentication request. + * + * One of the arguments that is suggested to be treated in a standard manner is a `callback_url` field. + * The field specifies where the user is redirected to after a successful authentication in the case of Oauth auhentication. + * If not passed, the provider will fallback to the callback_url provided in the provider options. */ body?: Record diff --git a/packages/core/types/src/auth/provider.ts b/packages/core/types/src/auth/provider.ts index 5970920521ada..c531a6215213d 100644 --- a/packages/core/types/src/auth/provider.ts +++ b/packages/core/types/src/auth/provider.ts @@ -20,6 +20,9 @@ export interface AuthIdentityProviderService { user_metadata?: Record } ) => Promise + // These methods are used for OAuth providers to store and retrieve state + setState: (key: string, value: Record) => Promise + getState: (key: string) => Promise | null> } /** diff --git a/packages/modules/auth/src/services/auth-module.ts b/packages/modules/auth/src/services/auth-module.ts index de94b9c8754ff..95545d7eaa8be 100644 --- a/packages/modules/auth/src/services/auth-module.ts +++ b/packages/modules/auth/src/services/auth-module.ts @@ -5,6 +5,7 @@ import { AuthTypes, Context, DAL, + ICacheService, InferEntityType, InternalModuleDeclaration, Logger, @@ -27,6 +28,7 @@ type InjectedDependencies = { providerIdentityService: ModulesSdkTypes.IMedusaInternalService authProviderService: AuthProviderService logger?: Logger + cache?: ICacheService } export default class AuthModuleService extends MedusaService<{ @@ -43,13 +45,14 @@ export default class AuthModuleService InferEntityType > protected readonly authProviderService_: AuthProviderService - + protected readonly cache_: ICacheService | undefined constructor( { authIdentityService, providerIdentityService, authProviderService, baseRepository, + cache, }: InjectedDependencies, protected readonly moduleDeclaration: InternalModuleDeclaration ) { @@ -60,6 +63,7 @@ export default class AuthModuleService this.authIdentityService_ = authIdentityService this.authProviderService_ = authProviderService this.providerIdentityService_ = providerIdentityService + this.cache_ = cache } __joinerConfig(): ModuleJoinerConfig { @@ -372,6 +376,27 @@ export default class AuthModuleService return serializedResponse }, + setState: async (key: string, value: Record) => { + if (!this.cache_) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + "Cache module dependency is required when using OAuth providers that require state" + ) + } + + // 20 minutes. Can be made configurable if necessary, but this is a good default. + this.cache_.set(key, value, 1200) + }, + getState: async (key: string) => { + if (!this.cache_) { + throw new MedusaError( + MedusaError.Types.INVALID_ARGUMENT, + "Cache module dependency is required when using OAuth providers that require state" + ) + } + + return await this.cache_.get(key) + }, } } } diff --git a/packages/modules/providers/auth-github/integration-tests/__tests__/services.spec.ts b/packages/modules/providers/auth-github/integration-tests/__tests__/services.spec.ts index af379d684db13..3d0f5b5c28d1c 100644 --- a/packages/modules/providers/auth-github/integration-tests/__tests__/services.spec.ts +++ b/packages/modules/providers/auth-github/integration-tests/__tests__/services.spec.ts @@ -1,4 +1,4 @@ -import { generateJwtToken, MedusaError } from "@medusajs/framework/utils" +import { MedusaError } from "@medusajs/framework/utils" import { GithubAuthService } from "../../src/services/github" import { http, HttpResponse } from "msw" import { setupServer } from "msw/node" @@ -20,6 +20,22 @@ const sampleIdPayload = { } const baseUrl = "https://someurl.com" +const callbackUrl = encodeURIComponent( + "https://someurl.com/auth/github/callback" +) + +let state = {} +const defaultSpies = { + retrieve: jest.fn(), + create: jest.fn(), + update: jest.fn(), + setState: jest.fn().mockImplementation((key, value) => { + state[key] = value + }), + getState: jest.fn().mockImplementation((key) => { + return Promise.resolve(state[key]) + }), +} // This is just a network-layer mocking, it doesn't start an actual server const server = setupServer( @@ -29,7 +45,7 @@ const server = setupServer( const url = request.url if ( url === - "https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback" + `https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}` ) { return new HttpResponse(null, { status: 401, @@ -39,7 +55,7 @@ const server = setupServer( if ( url === - "https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback" + `https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}` ) { return new HttpResponse( JSON.stringify({ @@ -91,6 +107,7 @@ describe("Github auth provider", () => { afterEach(() => { server.resetHandlers() jest.restoreAllMocks() + state = {} }) afterAll(() => server.close()) @@ -110,11 +127,27 @@ describe("Github auth provider", () => { }) it("returns a redirect URL on authenticate", async () => { - const res = await githubService.authenticate({}) + const res = await githubService.authenticate({}, defaultSpies) + expect(res).toEqual({ + success: true, + location: `https://github.com/login/oauth/authorize?redirect_uri=${callbackUrl}&client_id=test&response_type=code&state=${ + Object.keys(state)[0] + }`, + }) + }) + + it("returns a custom redirect_uri on authenticate", async () => { + const res = await githubService.authenticate( + { + body: { callback_url: "https://someotherurl.com" }, + }, + defaultSpies + ) expect(res).toEqual({ success: true, - location: - "https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback&client_id=test&response_type=code", + location: `https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&state=${ + Object.keys(state)[0] + }`, }) }) @@ -123,7 +156,7 @@ describe("Github auth provider", () => { { query: {}, }, - {} as any + defaultSpies ) expect(res).toEqual({ success: false, @@ -131,14 +164,51 @@ describe("Github auth provider", () => { }) }) + it("validate callback should return an error on missing state", async () => { + const res = await githubService.validateCallback( + { + query: { + code: "valid-code", + }, + }, + defaultSpies + ) + expect(res).toEqual({ + success: false, + error: "No state provided, or session expired", + }) + }) + + it("validate callback should return an error on expired/invalid state", async () => { + const res = await githubService.validateCallback( + { + query: { + code: "valid-code", + state: "somekey", + }, + }, + defaultSpies + ) + expect(res).toEqual({ + success: false, + error: "No state provided, or session expired", + }) + }) + it("validate callback should return on a missing access token for code", async () => { + state = { + somekey: { + callback_url: callbackUrl, + }, + } const res = await githubService.validateCallback( { query: { code: "invalid-code", + state: "somekey", }, }, - {} as any + defaultSpies ) expect(res).toEqual({ @@ -149,6 +219,7 @@ describe("Github auth provider", () => { it("validate callback should return successfully on a correct code for a new user", async () => { const authServiceSpies = { + ...defaultSpies, retrieve: jest.fn().mockImplementation(() => { throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found") }), @@ -167,10 +238,17 @@ describe("Github auth provider", () => { }), } + state = { + somekey: { + callback_url: callbackUrl, + }, + } + const res = await githubService.validateCallback( { query: { code: "valid-code", + state: "somekey", }, }, authServiceSpies @@ -191,6 +269,7 @@ describe("Github auth provider", () => { it("validate callback should return successfully on a correct code for an existing user", async () => { const authServiceSpies = { + ...defaultSpies, retrieve: jest.fn().mockImplementation(() => { return { provider_identities: [ @@ -219,10 +298,17 @@ describe("Github auth provider", () => { }), } + state = { + somekey: { + callback_url: callbackUrl, + }, + } + const res = await githubService.validateCallback( { query: { code: "valid-code", + state: "somekey", }, }, authServiceSpies diff --git a/packages/modules/providers/auth-github/src/services/github.ts b/packages/modules/providers/auth-github/src/services/github.ts index fd9113f3a0621..c174a43d25cee 100644 --- a/packages/modules/providers/auth-github/src/services/github.ts +++ b/packages/modules/providers/auth-github/src/services/github.ts @@ -1,3 +1,4 @@ +import crypto from "crypto" import { AuthenticationInput, AuthenticationResponse, @@ -16,7 +17,6 @@ type InjectedDependencies = { interface LocalServiceConfig extends GithubAuthProviderOptions {} -// TODO: Add state param that is stored in Redis, to prevent CSRF attacks export class GithubAuthService extends AbstractAuthModuleProvider { static identifier = "github" static DISPLAY_NAME = "Github Authentication" @@ -56,37 +56,53 @@ export class GithubAuthService extends AbstractAuthModuleProvider { } async authenticate( - req: AuthenticationInput + req: AuthenticationInput, + authIdentityService: AuthIdentityProviderService ): Promise { - if (req.query?.error) { + const query: Record = req.query ?? {} + const body: Record = req.body ?? {} + + if (query.error) { return { success: false, - error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, + error: `${query.error_description}, read more at: ${query.error_uri}`, } } - return this.getRedirect(this.config_) + const stateKey = crypto.randomBytes(32).toString("hex") + const state = { + callback_url: body?.callback_url ?? this.config_.callbackUrl, + } + + await authIdentityService.setState(stateKey, state) + return this.getRedirect(this.config_.clientId, state.callback_url, stateKey) } async validateCallback( req: AuthenticationInput, authIdentityService: AuthIdentityProviderService ): Promise { - if (req.query && req.query.error) { + const query: Record = req.query ?? {} + const body: Record = req.body ?? {} + + if (query.error) { return { success: false, - error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, + error: `${query.error_description}, read more at: ${query.error_uri}`, } } - const code = req.query?.code ?? req.body?.code + const code = query?.code ?? body?.code if (!code) { return { success: false, error: "No code provided" } } - const params = `client_id=${this.config_.clientId}&client_secret=${ - this.config_.clientSecret - }&code=${code}&redirect_uri=${encodeURIComponent(this.config_.callbackUrl)}` + const state = await authIdentityService.getState(query?.state as string) + if (!state) { + return { success: false, error: "No state provided, or session expired" } + } + + const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}` const exchangeTokenUrl = new URL( `https://github.com/login/oauth/access_token?${params}` @@ -192,18 +208,12 @@ export class GithubAuthService extends AbstractAuthModuleProvider { } } - private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) { - const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}` - const clientIdParam = `client_id=${clientId}` - const responseTypeParam = "response_type=code" - - const authUrl = new URL( - `https://github.com/login/oauth/authorize?${[ - redirectUrlParam, - clientIdParam, - responseTypeParam, - ].join("&")}` - ) + private getRedirect(clientId: string, callbackUrl: string, stateKey: string) { + const authUrl = new URL(`https://github.com/login/oauth/authorize`) + authUrl.searchParams.set("redirect_uri", callbackUrl) + authUrl.searchParams.set("client_id", clientId) + authUrl.searchParams.set("response_type", "code") + authUrl.searchParams.set("state", stateKey) return { success: true, location: authUrl.toString() } } diff --git a/packages/modules/providers/auth-google/integration-tests/__tests__/services.spec.ts b/packages/modules/providers/auth-google/integration-tests/__tests__/services.spec.ts index de641ab8a7162..f02eb2253b519 100644 --- a/packages/modules/providers/auth-google/integration-tests/__tests__/services.spec.ts +++ b/packages/modules/providers/auth-google/integration-tests/__tests__/services.spec.ts @@ -28,6 +28,22 @@ const encodedIdToken = generateJwtToken(sampleIdPayload, { }) const baseUrl = "https://someurl.com" +const callbackUrl = encodeURIComponent( + "https://someurl.com/auth/google/callback" +) + +let state = {} +const defaultSpies = { + retrieve: jest.fn(), + create: jest.fn(), + update: jest.fn(), + setState: jest.fn().mockImplementation((key, value) => { + state[key] = value + }), + getState: jest.fn().mockImplementation((key) => { + return Promise.resolve(state[key]) + }), +} // This is just a network-layer mocking, it doesn't start an actual server const server = setupServer( @@ -37,7 +53,7 @@ const server = setupServer( const url = request.url if ( url === - "https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&grant_type=authorization_code" + `https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}&grant_type=authorization_code` ) { return new HttpResponse(null, { status: 401, @@ -47,7 +63,7 @@ const server = setupServer( if ( url === - "https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&grant_type=authorization_code" + `https://oauth2.googleapis.com/token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}&grant_type=authorization_code` ) { return new HttpResponse( JSON.stringify({ @@ -90,6 +106,7 @@ describe("Google auth provider", () => { afterEach(() => { server.resetHandlers() jest.restoreAllMocks() + state = {} }) afterAll(() => server.close()) @@ -109,11 +126,27 @@ describe("Google auth provider", () => { }) it("returns a redirect URL on authenticate", async () => { - const res = await googleService.authenticate({}) + const res = await googleService.authenticate({}, defaultSpies) + expect(res).toEqual({ + success: true, + location: `https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=${callbackUrl}&client_id=test&response_type=code&scope=email+profile+openid&state=${ + Object.keys(state)[0] + }`, + }) + }) + + it("returns a custom redirect_uri on authenticate", async () => { + const res = await googleService.authenticate( + { + body: { callback_url: "https://someotherurl.com" }, + }, + defaultSpies + ) expect(res).toEqual({ success: true, - location: - "https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgoogle%2Fcallback&client_id=test&response_type=code&scope=email+profile+openid", + location: `https://accounts.google.com/o/oauth2/v2/auth?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&scope=email+profile+openid&state=${ + Object.keys(state)[0] + }`, }) }) @@ -122,7 +155,7 @@ describe("Google auth provider", () => { { query: {}, }, - {} as any + defaultSpies ) expect(res).toEqual({ success: false, @@ -130,14 +163,52 @@ describe("Google auth provider", () => { }) }) + it("validate callback should return an error on missing state", async () => { + const res = await googleService.validateCallback( + { + query: { + code: "valid-code", + }, + }, + defaultSpies + ) + expect(res).toEqual({ + success: false, + error: "No state provided, or session expired", + }) + }) + + it("validate callback should return an error on expired/invalid state", async () => { + const res = await googleService.validateCallback( + { + query: { + code: "valid-code", + state: "somekey", + }, + }, + defaultSpies + ) + expect(res).toEqual({ + success: false, + error: "No state provided, or session expired", + }) + }) + it("validate callback should return on a missing access token for code", async () => { + state = { + somekey: { + callback_url: callbackUrl, + }, + } + const res = await googleService.validateCallback( { query: { code: "invalid-code", + state: "somekey", }, }, - {} as any + defaultSpies ) expect(res).toEqual({ @@ -148,6 +219,7 @@ describe("Google auth provider", () => { it("validate callback should return successfully on a correct code for a new user", async () => { const authServiceSpies = { + ...defaultSpies, retrieve: jest.fn().mockImplementation(() => { throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found") }), @@ -166,10 +238,17 @@ describe("Google auth provider", () => { }), } + state = { + somekey: { + callback_url: callbackUrl, + }, + } + const res = await googleService.validateCallback( { query: { code: "valid-code", + state: "somekey", }, }, authServiceSpies @@ -190,6 +269,7 @@ describe("Google auth provider", () => { it("validate callback should return successfully on a correct code for an existing user", async () => { const authServiceSpies = { + ...defaultSpies, retrieve: jest.fn().mockImplementation(() => { return { provider_identities: [ @@ -208,10 +288,17 @@ describe("Google auth provider", () => { }), } + state = { + somekey: { + callback_url: callbackUrl, + }, + } + const res = await googleService.validateCallback( { query: { code: "valid-code", + state: "somekey", }, }, authServiceSpies diff --git a/packages/modules/providers/auth-google/src/services/google.ts b/packages/modules/providers/auth-google/src/services/google.ts index 5f65f0b17c1b2..9b35337cb5071 100644 --- a/packages/modules/providers/auth-google/src/services/google.ts +++ b/packages/modules/providers/auth-google/src/services/google.ts @@ -1,3 +1,4 @@ +import crypto from "crypto" import { AuthenticationInput, AuthenticationResponse, @@ -16,8 +17,6 @@ type InjectedDependencies = { } interface LocalServiceConfig extends GoogleAuthProviderOptions {} - -// TODO: Add state param that is stored in Redis, to prevent CSRF attacks export class GoogleAuthService extends AbstractAuthModuleProvider { static identifier = "google" static DISPLAY_NAME = "Google Authentication" @@ -57,39 +56,53 @@ export class GoogleAuthService extends AbstractAuthModuleProvider { } async authenticate( - req: AuthenticationInput + req: AuthenticationInput, + authIdentityService: AuthIdentityProviderService ): Promise { - if (req.query?.error) { + const query: Record = req.query ?? {} + const body: Record = req.body ?? {} + + if (query.error) { return { success: false, - error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, + error: `${query.error_description}, read more at: ${query.error_uri}`, } } - return this.getRedirect(this.config_) + const stateKey = crypto.randomBytes(32).toString("hex") + const state = { + callback_url: body?.callback_url ?? this.config_.callbackUrl, + } + + await authIdentityService.setState(stateKey, state) + return this.getRedirect(this.config_.clientId, state.callback_url, stateKey) } async validateCallback( req: AuthenticationInput, authIdentityService: AuthIdentityProviderService ): Promise { - if (req.query && req.query.error) { + const query: Record = req.query ?? {} + const body: Record = req.body ?? {} + + if (query.error) { return { success: false, - error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, + error: `${query.error_description}, read more at: ${query.error_uri}`, } } - const code = req.query?.code ?? req.body?.code + const code = query?.code ?? body?.code if (!code) { return { success: false, error: "No code provided" } } - const params = `client_id=${this.config_.clientId}&client_secret=${ - this.config_.clientSecret - }&code=${code}&redirect_uri=${encodeURIComponent( - this.config_.callbackUrl - )}&grant_type=authorization_code` + const state = await authIdentityService.getState(query?.state as string) + if (!state) { + return { success: false, error: "No state provided, or session expired" } + } + + const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}&grant_type=authorization_code` const exchangeTokenUrl = new URL( `https://oauth2.googleapis.com/token?${params}` ) @@ -175,20 +188,13 @@ export class GoogleAuthService extends AbstractAuthModuleProvider { } } - private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) { - const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}` - const clientIdParam = `client_id=${clientId}` - const responseTypeParam = "response_type=code" - const scopeParam = "scope=email+profile+openid" - - const authUrl = new URL( - `https://accounts.google.com/o/oauth2/v2/auth?${[ - redirectUrlParam, - clientIdParam, - responseTypeParam, - scopeParam, - ].join("&")}` - ) + private getRedirect(clientId: string, callbackUrl: string, stateKey: string) { + const authUrl = new URL(`https://accounts.google.com/o/oauth2/v2/auth`) + authUrl.searchParams.set("redirect_uri", callbackUrl) + authUrl.searchParams.set("client_id", clientId) + authUrl.searchParams.set("response_type", "code") + authUrl.searchParams.set("scope", "email profile openid") + authUrl.searchParams.set("state", stateKey) return { success: true, location: authUrl.toString() } }