diff --git a/packages/authentication/integration-tests/__tests__/services/module/providers.spec.ts b/packages/authentication/integration-tests/__tests__/services/module/providers.spec.ts index 2f27b56a12580..63a3af08179a6 100644 --- a/packages/authentication/integration-tests/__tests__/services/module/providers.spec.ts +++ b/packages/authentication/integration-tests/__tests__/services/module/providers.spec.ts @@ -1,5 +1,6 @@ -import { IAuthenticationModuleService } from "@medusajs/types" import { MedusaModule, Modules } from "@medusajs/modules-sdk" + +import { IAuthenticationModuleService } from "@medusajs/types" import { MikroOrmWrapper } from "../../../utils" import { SqlEntityManager } from "@mikro-orm/postgresql" import { createAuthProviders } from "../../../__fixtures__/auth-provider" diff --git a/packages/authentication/integration-tests/__tests__/services/providers/username-password.spec.ts b/packages/authentication/integration-tests/__tests__/services/providers/username-password.spec.ts index e667444b50ec3..3f5ec7da6dc55 100644 --- a/packages/authentication/integration-tests/__tests__/services/providers/username-password.spec.ts +++ b/packages/authentication/integration-tests/__tests__/services/providers/username-password.spec.ts @@ -1,5 +1,6 @@ -import { IAuthenticationModuleService } from "@medusajs/types" import { MedusaModule, Modules } from "@medusajs/modules-sdk" + +import { IAuthenticationModuleService } from "@medusajs/types" import { MikroOrmWrapper } from "../../../utils" import Scrypt from "scrypt-kdf" import { SqlEntityManager } from "@mikro-orm/postgresql" diff --git a/packages/authentication/src/providers/google.ts b/packages/authentication/src/providers/google.ts index 39bc0dbe9f6a8..a1d7cd48bc3d0 100644 --- a/packages/authentication/src/providers/google.ts +++ b/packages/authentication/src/providers/google.ts @@ -43,55 +43,17 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider { this.authProviderService_ = authProviderService } - private async validateConfig(config: Partial) { - if (!config.clientID) { - throw new Error("Google clientID is required") - } - - if (!config.clientSecret) { - throw new Error("Google clientSecret is required") - } - - if (!config.callbackURL) { - throw new Error("Google callbackUrl is required") - } - } - - private originalURL(req: AuthenticationInput) { - const tls = req.connection.encrypted, - host = req.headers.host, - protocol = tls ? "https" : "http", - path = req.url || "" - return protocol + "://" + host + path - } - - async getProviderConfig(req: AuthenticationInput): Promise { - const { config } = (await this.authProviderService_.retrieve( - GoogleProvider.PROVIDER - )) as AuthProvider & { config: ProviderConfig } - - this.validateConfig(config || {}) - - const { callbackURL } = config - - const parsedCallbackUrl = !url.parse(callbackURL).protocol - ? url.resolve(this.originalURL(req), callbackURL) - : callbackURL - - return { ...config, callbackURL: parsedCallbackUrl } - } - async authenticate( req: AuthenticationInput ): Promise { - if (req.query && req.query.error) { + if (req.query?.error) { return { success: false, error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, } } - let config + let config: ProviderConfig try { config = await this.getProviderConfig(req) @@ -99,43 +61,30 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider { return { success: false, error: error.message } } - let { callbackURL, clientID, clientSecret } = config - - const meta: ProviderConfig = { - clientID, - callbackURL, - clientSecret, - } - - const code = (req.query && req.query.code) || (req.body && req.body.code) - - // Redirect to google - if (!code) { - return this.getRedirect(meta) - } - - return await this.validateCallback(code, meta) + return this.getRedirect(config) } - // abstractable - private async validateCallback( - code: string, - { clientID, callbackURL, clientSecret }: ProviderConfig - ) { - const client = this.getAuthorizationCodeHandler({ clientID, clientSecret }) - - const tokenParams = { - code, - redirect_uri: callbackURL, + async validateCallback( + req: AuthenticationInput + ): Promise { + if (req.query && req.query.error) { + return { + success: false, + error: `${req.query.error_description}, read more at: ${req.query.error_uri}`, + } } - try { - const accessToken = await client.getToken(tokenParams) + let config: ProviderConfig - return await this.verify_(accessToken.token.id_token) + try { + config = await this.getProviderConfig(req) } catch (error) { return { success: false, error: error.message } } + + const code = req.query?.code ?? req.body?.code + + return await this.validateCallbackToken(code, config) } // abstractable @@ -169,6 +118,68 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider { return { success: true, authUser } } + // abstractable + private async validateCallbackToken( + code: string, + { clientID, callbackURL, clientSecret }: ProviderConfig + ) { + const client = this.getAuthorizationCodeHandler({ clientID, clientSecret }) + + const tokenParams = { + code, + redirect_uri: callbackURL, + } + + try { + const accessToken = await client.getToken(tokenParams) + + return await this.verify_(accessToken.token.id_token) + } catch (error) { + return { success: false, error: error.message } + } + } + + private async validateConfig(config: Partial) { + if (!config.clientID) { + throw new Error("Google clientID is required") + } + + if (!config.clientSecret) { + throw new Error("Google clientSecret is required") + } + + if (!config.callbackURL) { + throw new Error("Google callbackUrl is required") + } + } + + private originalURL(req: AuthenticationInput) { + const tls = req.connection.encrypted + const host = req.headers.host + const protocol = tls ? "https" : "http" + const path = req.url || "" + + return protocol + "://" + host + path + } + + private async getProviderConfig( + req: AuthenticationInput + ): Promise { + const { config } = (await this.authProviderService_.retrieve( + GoogleProvider.PROVIDER + )) as AuthProvider & { config: ProviderConfig } + + this.validateConfig(config || {}) + + const { callbackURL } = config + + const parsedCallbackUrl = !url.parse(callbackURL).protocol + ? url.resolve(this.originalURL(req), callbackURL) + : callbackURL + + return { ...config, callbackURL: parsedCallbackUrl } + } + // Abstractable private getRedirect({ clientID, callbackURL, clientSecret }: ProviderConfig) { const client = this.getAuthorizationCodeHandler({ clientID, clientSecret }) diff --git a/packages/authentication/src/providers/username-password.ts b/packages/authentication/src/providers/username-password.ts index b19f26c5d7526..905109fb929b6 100644 --- a/packages/authentication/src/providers/username-password.ts +++ b/packages/authentication/src/providers/username-password.ts @@ -1,8 +1,8 @@ -import { AuthenticationResponse } from "@medusajs/types" +import { AbstractAuthenticationModuleProvider, isString } from "@medusajs/utils" import { AuthUserService } from "@services" +import { AuthenticationResponse } from "@medusajs/types" import Scrypt from "scrypt-kdf" -import { AbstractAuthenticationModuleProvider, isString } from "@medusajs/utils" class UsernamePasswordProvider extends AbstractAuthenticationModuleProvider { public static PROVIDER = "usernamePassword" diff --git a/packages/authentication/src/services/authentication-module.ts b/packages/authentication/src/services/authentication-module.ts index f2fdf4948cd31..be6d1216ac470 100644 --- a/packages/authentication/src/services/authentication-module.ts +++ b/packages/authentication/src/services/authentication-module.ts @@ -368,18 +368,15 @@ export default class AuthenticationModuleService< return containerProvider } - @InjectTransactionManager("baseRepository_") async authenticate( provider: string, - authenticationData: Record, - @MedusaContext() sharedContext: Context = {} + authenticationData: Record ): Promise { - let registeredProvider - try { await this.retrieveAuthProvider(provider, {}) - registeredProvider = this.getRegisteredAuthenticationProvider(provider) + const registeredProvider = + this.getRegisteredAuthenticationProvider(provider) return await registeredProvider.authenticate(authenticationData) } catch (error) { @@ -387,6 +384,22 @@ export default class AuthenticationModuleService< } } + async validateCallback( + provider: string, + authenticationData: Record + ): Promise { + try { + await this.retrieveAuthProvider(provider, {}) + + const registeredProvider = + this.getRegisteredAuthenticationProvider(provider) + + return await registeredProvider.validateCallback(authenticationData) + } catch (error) { + return { success: false, error: error.message } + } + } + private async createProvidersOnLoad() { const providersToLoad = this.__container__["auth_providers"] diff --git a/packages/types/src/authentication/service.ts b/packages/types/src/authentication/service.ts index ddd0d76c6b4a6..bd2c8728e790e 100644 --- a/packages/types/src/authentication/service.ts +++ b/packages/types/src/authentication/service.ts @@ -1,7 +1,7 @@ import { - AuthenticationResponse, AuthProviderDTO, AuthUserDTO, + AuthenticationResponse, CreateAuthProviderDTO, CreateAuthUserDTO, FilterableAuthProviderProps, @@ -20,6 +20,11 @@ export interface IAuthenticationModuleService extends IModuleService { providerData: Record ): Promise + validateCallback( + provider: string, + providerData: Record + ): Promise + retrieveAuthProvider( provider: string, config?: FindConfig, diff --git a/packages/utils/src/authentication/abstract-authentication-provider.ts b/packages/utils/src/authentication/abstract-authentication-provider.ts index f4d69ab1bd82d..bfa21be4d84b9 100644 --- a/packages/utils/src/authentication/abstract-authentication-provider.ts +++ b/packages/utils/src/authentication/abstract-authentication-provider.ts @@ -1,4 +1,4 @@ -import { AuthenticationResponse } from "@medusajs/types"; +import { AuthenticationResponse } from "@medusajs/types" export abstract class AbstractAuthenticationModuleProvider { public static PROVIDER: string @@ -16,4 +16,12 @@ export abstract class AbstractAuthenticationModuleProvider { abstract authenticate( data: Record ): Promise + + public validateCallback( + data: Record + ): Promise { + throw new Error( + `Callback authentication not implemented for provider ${this.provider}` + ) + } }