Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(services-bff): Refresh token with polling #16872

Merged
merged 14 commits into from
Nov 29, 2024
32 changes: 6 additions & 26 deletions apps/services/bff/src/app/modules/auth/auth.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
UnauthorizedException,
} from '@nestjs/common'
import { ConfigType } from '@nestjs/config'
import { CookieOptions, Request, Response } from 'express'
import type { Request, Response } from 'express'
import jwksClient from 'jwks-rsa'
import { jwtDecode } from 'jwt-decode'

Expand All @@ -26,6 +26,7 @@ import {
CreateErrorQueryStrArgs,
createErrorQueryStr,
} from '../../utils/create-error-query-str'
import { getCookieOptions } from '../../utils/get-cookie-options'
import { validateUri } from '../../utils/validate-uri'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
Expand Down Expand Up @@ -55,17 +56,6 @@ export class AuthService {
this.baseUrl = this.config.ids.issuer
}

private getCookieOptions(): CookieOptions {
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
return {
httpOnly: true,
secure: true,
// The lax setting allows cookies to be sent on top-level navigations (such as redirects),
// while still providing some protection against CSRF attacks.
sameSite: 'lax',
path: environment.keyPath,
}
}

/**
* Creates the client base URL with the path appended.
*/
Expand Down Expand Up @@ -212,12 +202,8 @@ export class AuthService {
prompt,
})

if (parResponse.type === 'error') {
throw parResponse.data
}

searchParams = new URLSearchParams({
request_uri: parResponse.data.request_uri,
request_uri: parResponse.request_uri,
client_id: this.config.ids.clientId,
})
} else {
Expand Down Expand Up @@ -297,13 +283,7 @@ export class AuthService {
codeVerifier: loginAttemptData.codeVerifier,
})

if (tokenResponse.type === 'error') {
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
throw tokenResponse.data
}

const updatedTokenResponse = await this.updateTokenCache(
tokenResponse.data,
)
const updatedTokenResponse = await this.updateTokenCache(tokenResponse)

// Clean up the login attempt from the cache since we have a successful login.
this.cacheService
Expand All @@ -316,7 +296,7 @@ export class AuthService {
res.cookie(
SESSION_COOKIE_NAME,
updatedTokenResponse.userProfile.sid,
this.getCookieOptions(),
getCookieOptions(),
)

// Check if there is an old session cookie and clean up the cache
Expand Down Expand Up @@ -424,7 +404,7 @@ export class AuthService {
* - Delete the current login from the cache
* - Clear the session cookie
*/
res.clearCookie(SESSION_COOKIE_NAME, this.getCookieOptions())
res.clearCookie(SESSION_COOKIE_NAME, getCookieOptions())

this.cacheService
.delete(currentLoginCacheKey)
Expand Down
201 changes: 201 additions & 0 deletions apps/services/bff/src/app/modules/auth/token-refresh.service.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import { LOGGER_PROVIDER, Logger } from '@island.is/logging'
import { Inject, Injectable } from '@nestjs/common'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
import { AuthService } from './auth.service'
import { CachedTokenResponse } from './auth.types'

/**
* Service responsible for handling token refresh operations
* Provides concurrent request protection and token refresh polling
*/
@Injectable()
export class TokenRefreshService {
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
private static POLL_INTERVAL = 100 // ms
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
private static MAX_POLL_TIME = 3000 // 3 seconds
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved

constructor(
@Inject(LOGGER_PROVIDER)
private logger: Logger,

private readonly authService: AuthService,
private readonly cacheService: CacheService,
private readonly idsService: IdsService,
) {}

/**
* Creates a unique key for tracking refresh token operations in progress
* This key is used to prevent concurrent refresh token requests for the same session
*
* @param sid - Session ID
* @returns Formatted key string for refresh token tracking
*/
private createRefreshTokenKey(sid: string): string {
return `refresh_token_in_progress:${sid}`
}

/**
* Creates a key for storing token response data in cache
* This key is used to store and retrieve the current token data for a session
*
* @param sid - Session ID
* @returns Formatted key string for token response data
*/
private createTokenResponseKey(sid: string): string {
return this.cacheService.createSessionKeyType('current', sid)
}

/**
* Executes the token refresh operation and updates the cache
* This method:
* 1. Sets a flag in cache to indicate refresh is in progress
* 2. Requests new tokens from the identity server
* 3. Updates the cache with the new token data
* 4. Cleans up the refresh flag
*
* @param params.refreshTokenKey - Redis key for tracking refresh status
* @param params.encryptedRefreshToken - Encrypted refresh token for getting new tokens
*
* @returns Promise<CachedTokenResponse> Updated token data
* @throws Will throw if token refresh fails or cache operations fail
*/
private async executeTokenRefresh({
refreshTokenKey,
encryptedRefreshToken,
}: {
refreshTokenKey: string
encryptedRefreshToken: string
}) {
// Set refresh in progress
await this.cacheService.save({
key: refreshTokenKey,
value: true,
ttl: TokenRefreshService.MAX_POLL_TIME,
})

const tokenResponse = await this.idsService.refreshToken(
encryptedRefreshToken,
)

// Update cache with new token data
const updatedTokenResponse = await this.authService.updateTokenCache(
tokenResponse,
)

// Delete the refresh token key to signal that the refresh operation is complete
await this.cacheService.delete(refreshTokenKey)

return updatedTokenResponse
}
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved

/**
* Polls the cache to check if a refresh token operation has completed
* This prevents multiple concurrent refresh token requests and ensures
* all requests wait for the ongoing refresh to complete
*
* @param sid - Session ID
*
* @returns Promise that resolves when refresh is complete or rejects on timeout
* @throws Rejects if polling times out or encounters an error
*/
private async pollForRefreshCompletion(sid: string): Promise<void> {
return new Promise((resolve, reject) => {
const timeoutId = setTimeout(() => {
clearInterval(pollInterval)
reject(
new Error(
`Polling timed out for token refresh completion for session ${sid}`,
),
)
}, TokenRefreshService.MAX_POLL_TIME)

const pollInterval = setInterval(async () => {
try {
const refreshTokenInProgress = await this.cacheService.get<boolean>(
this.createRefreshTokenKey(sid),
false,
)

if (!refreshTokenInProgress) {
clearInterval(pollInterval)
clearTimeout(timeoutId)
resolve()
}
} catch (error) {
clearInterval(pollInterval)
clearTimeout(timeoutId)
reject(new Error(`Error polling for refresh completion: ${error}`))
}
}, TokenRefreshService.POLL_INTERVAL)
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
})
}

/**
* Handles the complete token refresh process with concurrent request protection
* This method:
*
* 1. Checks if a refresh is already in progress
* 2. If yes, waits for it to complete
* 3. If no, initiates a new refresh
* 4. Updates the cache with new token data
* 5. Cleans up tracking flags
*
* @param params.sid - Session ID
* @param params.encryptedRefreshToken - Encrypted refresh token to use for getting new tokens
*
* @returns Promise resolving to updated token response data
* @throws Forwards any errors from the refresh process after logging
*/
public async refreshToken({
sid,
encryptedRefreshToken,
}: {
sid: string
encryptedRefreshToken: string
}): Promise<CachedTokenResponse> {
const refreshTokenKey = this.createRefreshTokenKey(sid)
const tokenResponseKey = this.createTokenResponseKey(sid)

try {
// Check if refresh is already in progress
const refreshTokenInProgress = await this.cacheService.get<boolean>(
refreshTokenKey,
false,
)

if (refreshTokenInProgress) {
try {
// Wait for the ongoing refresh to complete
await this.pollForRefreshCompletion(sid)

// Get the updated token response from cache
const tokenResponse =
await this.cacheService.get<CachedTokenResponse>(tokenResponseKey)

return tokenResponse
} catch (error) {
this.logger.error(error)
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved

// If polling times out, then retry the refresh
const updatedTokenResponse = await this.executeTokenRefresh({
refreshTokenKey,
encryptedRefreshToken,
})

return updatedTokenResponse
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
}
}

const updatedTokenResponse = await this.executeTokenRefresh({
refreshTokenKey,
encryptedRefreshToken,
})

return updatedTokenResponse
} catch (error) {
this.logger.error(`Token refresh failed for sid: ${sid}`)
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved

throw error
}
}
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
}
77 changes: 22 additions & 55 deletions apps/services/bff/src/app/modules/ids/ids.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import { BffConfig } from '../../bff.config'
import { CryptoService } from '../../services/crypto.service'
import { ENHANCED_FETCH_PROVIDER_KEY } from '../enhancedFetch/enhanced-fetch.provider'
import {
ApiResponse,
ErrorRes,
GetLoginSearchParamsReturnValue,
ParResponse,
TokenResponse,
Expand Down Expand Up @@ -35,60 +33,28 @@ export class IdsService {
private async postRequest<T>(
endpoint: string,
body: Record<string, string>,
): Promise<ApiResponse<T>> {
try {
const response = await this.enhancedFetch(
`${this.issuerUrl}${endpoint}`,
{
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
Authorization: this.createPARAuthorizationHeader(),
},
body: new URLSearchParams(body).toString(),
},
)

const contentType = response.headers.get('content-type') || ''

if (contentType.includes('application/json')) {
const data = await response.json()

if (!response.ok) {
// If error response from Ids is not in the expected format, throw the data as is
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
if (!data.error || !data.error_description) {
throw data
}

return {
type: 'error',
data: {
error: data.error,
error_description: data.error_description,
},
} as ErrorRes
}

return {
type: 'success',
data: data as T,
}
}

// Handle plain text responses
const textResponse = await response.text()

if (!response.ok) {
throw textResponse
}

return {
type: 'success',
data: textResponse,
} as ApiResponse<T>
} catch (error) {
throw new Error(error)
): Promise<T> {
const response = await this.enhancedFetch(`${this.issuerUrl}${endpoint}`, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
Authorization: this.createPARAuthorizationHeader(),
},
body: new URLSearchParams(body).toString(),
})

const contentType = response.headers.get('content-type') || ''

if (contentType.includes('application/json')) {
const data = await response.json()

return data
}

// Handle plain text responses
const textResponse = await response.text()

return textResponse as T
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
}

public getLoginSearchParams({
Expand All @@ -103,6 +69,7 @@ export class IdsService {
prompt?: string
}): GetLoginSearchParamsReturnValue {
const { ids } = this.config

return {
client_id: ids.clientId,
redirect_uri: this.config.callbacksRedirectUris.login,
Expand Down
Loading
Loading