Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth-google,auth-github): Allow passing a custom callbackUrl to … #10829

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/core/modules-sdk/src/definitions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ export const ModulesDefinition: {
label: upperCaseFirst(Modules.AUTH),
isRequired: false,
isQueryable: true,
dependencies: [ContainerRegistrationKeys.LOGGER],
dependencies: [ContainerRegistrationKeys.LOGGER, Modules.CACHE],
defaultModuleDeclaration: {
scope: MODULE_SCOPE.INTERNAL,
},
Expand Down
4 changes: 4 additions & 0 deletions packages/core/types/src/auth/common/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ export type AuthenticationInput = {

/**
* Body of the incoming authentication request.
*
* One of the arguments that is suggested to be treated in a standard manner is a `callback_url` field.
* The field specifies where the user is redirected to after a successful authentication in the case of Oauth auhentication.
* If not passed, the provider will fallback to the callback_url provided in the provider options.
*/
body?: Record<string, string>

Expand Down
3 changes: 3 additions & 0 deletions packages/core/types/src/auth/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ export interface AuthIdentityProviderService {
user_metadata?: Record<string, unknown>
}
) => Promise<AuthIdentityDTO>
// These methods are used for OAuth providers to store and retrieve state
setState: (key: string, value: Record<string, unknown>) => Promise<void>
getState: (key: string) => Promise<Record<string, unknown> | null>
}

/**
Expand Down
27 changes: 26 additions & 1 deletion packages/modules/auth/src/services/auth-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
AuthTypes,
Context,
DAL,
ICacheService,
InferEntityType,
InternalModuleDeclaration,
Logger,
Expand All @@ -27,6 +28,7 @@ type InjectedDependencies = {
providerIdentityService: ModulesSdkTypes.IMedusaInternalService<any>
authProviderService: AuthProviderService
logger?: Logger
cache?: ICacheService
}
export default class AuthModuleService
extends MedusaService<{
Expand All @@ -43,13 +45,14 @@ export default class AuthModuleService
InferEntityType<typeof ProviderIdentity>
>
protected readonly authProviderService_: AuthProviderService

protected readonly cache_: ICacheService | undefined
constructor(
{
authIdentityService,
providerIdentityService,
authProviderService,
baseRepository,
cache,
}: InjectedDependencies,
protected readonly moduleDeclaration: InternalModuleDeclaration
) {
Expand All @@ -60,6 +63,7 @@ export default class AuthModuleService
this.authIdentityService_ = authIdentityService
this.authProviderService_ = authProviderService
this.providerIdentityService_ = providerIdentityService
this.cache_ = cache
}

__joinerConfig(): ModuleJoinerConfig {
Expand Down Expand Up @@ -372,6 +376,27 @@ export default class AuthModuleService

return serializedResponse
},
setState: async (key: string, value: Record<string, unknown>) => {
if (!this.cache_) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
"Cache module dependency is required when using OAuth providers that require state"
)
}

// 20 minutes. Can be made configurable if necessary, but this is a good default.
this.cache_.set(key, value, 1200)
},
getState: async (key: string) => {
if (!this.cache_) {
throw new MedusaError(
MedusaError.Types.INVALID_ARGUMENT,
"Cache module dependency is required when using OAuth providers that require state"
)
}

return await this.cache_.get(key)
},
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { generateJwtToken, MedusaError } from "@medusajs/framework/utils"
import { MedusaError } from "@medusajs/framework/utils"
import { GithubAuthService } from "../../src/services/github"
import { http, HttpResponse } from "msw"
import { setupServer } from "msw/node"
Expand All @@ -20,6 +20,22 @@ const sampleIdPayload = {
}

const baseUrl = "https://someurl.com"
const callbackUrl = encodeURIComponent(
"https://someurl.com/auth/github/callback"
)

let state = {}
const defaultSpies = {
retrieve: jest.fn(),
create: jest.fn(),
update: jest.fn(),
setState: jest.fn().mockImplementation((key, value) => {
state[key] = value
}),
getState: jest.fn().mockImplementation((key) => {
return Promise.resolve(state[key])
}),
}

// This is just a network-layer mocking, it doesn't start an actual server
const server = setupServer(
Expand All @@ -29,7 +45,7 @@ const server = setupServer(
const url = request.url
if (
url ===
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=invalid-code&redirect_uri=${callbackUrl}`
) {
return new HttpResponse(null, {
status: 401,
Expand All @@ -39,7 +55,7 @@ const server = setupServer(

if (
url ===
"https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback"
`https://github.com/login/oauth/access_token?client_id=test&client_secret=test&code=valid-code&redirect_uri=${callbackUrl}`
) {
return new HttpResponse(
JSON.stringify({
Expand Down Expand Up @@ -91,6 +107,7 @@ describe("Github auth provider", () => {
afterEach(() => {
server.resetHandlers()
jest.restoreAllMocks()
state = {}
})

afterAll(() => server.close())
Expand All @@ -110,11 +127,27 @@ describe("Github auth provider", () => {
})

it("returns a redirect URL on authenticate", async () => {
const res = await githubService.authenticate({})
const res = await githubService.authenticate({}, defaultSpies)
expect(res).toEqual({
success: true,
location: `https://github.com/login/oauth/authorize?redirect_uri=${callbackUrl}&client_id=test&response_type=code&state=${
Object.keys(state)[0]
}`,
})
})

it("returns a custom redirect_uri on authenticate", async () => {
const res = await githubService.authenticate(
{
body: { callback_url: "https://someotherurl.com" },
},
defaultSpies
)
expect(res).toEqual({
success: true,
location:
"https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeurl.com%2Fauth%2Fgithub%2Fcallback&client_id=test&response_type=code",
location: `https://github.com/login/oauth/authorize?redirect_uri=https%3A%2F%2Fsomeotherurl.com&client_id=test&response_type=code&state=${
Object.keys(state)[0]
}`,
})
})

Expand All @@ -123,22 +156,59 @@ describe("Github auth provider", () => {
{
query: {},
},
{} as any
defaultSpies
)
expect(res).toEqual({
success: false,
error: "No code provided",
})
})

it("validate callback should return an error on missing state", async () => {
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
},
},
defaultSpies
)
expect(res).toEqual({
success: false,
error: "No state provided, or session expired",
})
})

it("validate callback should return an error on expired/invalid state", async () => {
const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
defaultSpies
)
expect(res).toEqual({
success: false,
error: "No state provided, or session expired",
})
})

it("validate callback should return on a missing access token for code", async () => {
state = {
somekey: {
callback_url: callbackUrl,
},
}
const res = await githubService.validateCallback(
{
query: {
code: "invalid-code",
state: "somekey",
},
},
{} as any
defaultSpies
)

expect(res).toEqual({
Expand All @@ -149,6 +219,7 @@ describe("Github auth provider", () => {

it("validate callback should return successfully on a correct code for a new user", async () => {
const authServiceSpies = {
...defaultSpies,
retrieve: jest.fn().mockImplementation(() => {
throw new MedusaError(MedusaError.Types.NOT_FOUND, "Not found")
}),
Expand All @@ -167,10 +238,17 @@ describe("Github auth provider", () => {
}),
}

state = {
somekey: {
callback_url: callbackUrl,
},
}

const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
authServiceSpies
Expand All @@ -191,6 +269,7 @@ describe("Github auth provider", () => {

it("validate callback should return successfully on a correct code for an existing user", async () => {
const authServiceSpies = {
...defaultSpies,
retrieve: jest.fn().mockImplementation(() => {
return {
provider_identities: [
Expand Down Expand Up @@ -219,10 +298,17 @@ describe("Github auth provider", () => {
}),
}

state = {
somekey: {
callback_url: callbackUrl,
},
}

const res = await githubService.validateCallback(
{
query: {
code: "valid-code",
state: "somekey",
},
},
authServiceSpies
Expand Down
56 changes: 33 additions & 23 deletions packages/modules/providers/auth-github/src/services/github.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import crypto from "crypto"
import {
AuthenticationInput,
AuthenticationResponse,
Expand All @@ -16,7 +17,6 @@ type InjectedDependencies = {

interface LocalServiceConfig extends GithubAuthProviderOptions {}

// TODO: Add state param that is stored in Redis, to prevent CSRF attacks
export class GithubAuthService extends AbstractAuthModuleProvider {
static identifier = "github"
static DISPLAY_NAME = "Github Authentication"
Expand Down Expand Up @@ -56,37 +56,53 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
}

async authenticate(
req: AuthenticationInput
req: AuthenticationInput,
authIdentityService: AuthIdentityProviderService
): Promise<AuthenticationResponse> {
if (req.query?.error) {
const query: Record<string, string> = req.query ?? {}
const body: Record<string, string> = req.body ?? {}

if (query.error) {
return {
success: false,
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
error: `${query.error_description}, read more at: ${query.error_uri}`,
}
}

return this.getRedirect(this.config_)
const stateKey = crypto.randomBytes(32).toString("hex")
const state = {
callback_url: body?.callback_url ?? this.config_.callbackUrl,
}

await authIdentityService.setState(stateKey, state)
return this.getRedirect(this.config_.clientId, state.callback_url, stateKey)
}

async validateCallback(
req: AuthenticationInput,
authIdentityService: AuthIdentityProviderService
): Promise<AuthenticationResponse> {
if (req.query && req.query.error) {
const query: Record<string, string> = req.query ?? {}
const body: Record<string, string> = req.body ?? {}

if (query.error) {
return {
success: false,
error: `${req.query.error_description}, read more at: ${req.query.error_uri}`,
error: `${query.error_description}, read more at: ${query.error_uri}`,
}
}

const code = req.query?.code ?? req.body?.code
const code = query?.code ?? body?.code
if (!code) {
return { success: false, error: "No code provided" }
}

const params = `client_id=${this.config_.clientId}&client_secret=${
this.config_.clientSecret
}&code=${code}&redirect_uri=${encodeURIComponent(this.config_.callbackUrl)}`
const state = await authIdentityService.getState(query?.state as string)
if (!state) {
return { success: false, error: "No state provided, or session expired" }
}

const params = `client_id=${this.config_.clientId}&client_secret=${this.config_.clientSecret}&code=${code}&redirect_uri=${state.callback_url}`

const exchangeTokenUrl = new URL(
`https://github.com/login/oauth/access_token?${params}`
Expand Down Expand Up @@ -192,18 +208,12 @@ export class GithubAuthService extends AbstractAuthModuleProvider {
}
}

private getRedirect({ clientId, callbackUrl }: LocalServiceConfig) {
const redirectUrlParam = `redirect_uri=${encodeURIComponent(callbackUrl)}`
const clientIdParam = `client_id=${clientId}`
const responseTypeParam = "response_type=code"

const authUrl = new URL(
`https://github.com/login/oauth/authorize?${[
redirectUrlParam,
clientIdParam,
responseTypeParam,
].join("&")}`
)
private getRedirect(clientId: string, callbackUrl: string, stateKey: string) {
const authUrl = new URL(`https://github.com/login/oauth/authorize`)
authUrl.searchParams.set("redirect_uri", callbackUrl)
authUrl.searchParams.set("client_id", clientId)
authUrl.searchParams.set("response_type", "code")
authUrl.searchParams.set("state", stateKey)

return { success: true, location: authUrl.toString() }
}
Expand Down
Loading
Loading