Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(services-bff): Refresh token with polling #16872

Merged
merged 14 commits into from
Nov 29, 2024
23 changes: 8 additions & 15 deletions apps/services/bff/src/app/modules/auth/auth.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import jwt from 'jsonwebtoken'
import request from 'supertest'
import { setupTestServer } from '../../../../test/setupTestServer'
import {
mockedTokensResponse as tokensResponse,
SID_VALUE,
SESSION_COOKIE_NAME,
ALGORITM_TYPE,
SESSION_COOKIE_NAME,
SID_VALUE,
getLoginSearchParmsFn,
mockedTokensResponse as tokensResponse,
} from '../../../../test/sharedConstants'
import { environment } from '../../../environment'
import { BffConfig } from '../../bff.config'
import { IdsService } from '../ids/ids.service'
import { ParResponse } from '../ids/ids.types'
Expand Down Expand Up @@ -58,17 +59,9 @@ const parResponse: ParResponse = {
const allowedTargetLinkUri = 'http://test-client.com/testclient'

const mockIdsService = {
getPar: jest.fn().mockResolvedValue({
type: 'success',
data: parResponse,
}),
getTokens: jest.fn().mockResolvedValue({
type: 'success',
data: tokensResponse,
}),
revokeToken: jest.fn().mockResolvedValue({
type: 'success',
}),
getPar: jest.fn().mockResolvedValue(parResponse),
getTokens: jest.fn().mockResolvedValue(tokensResponse),
revokeToken: jest.fn().mockResolvedValue(undefined),
getLoginSearchParams: jest.fn().mockImplementation(getLoginSearchParmsFn),
}

Expand All @@ -89,7 +82,7 @@ describe('AuthController', () => {
})

mockConfig = app.get<ConfigType<typeof BffConfig>>(BffConfig.KEY)
baseUrlWithKey = `${mockConfig.clientBaseUrl}${process.env.BFF_CLIENT_KEY_PATH}`
baseUrlWithKey = `${mockConfig.clientBaseUrl}${environment.keyPath}`

server = request(app.getHttpServer())
})
Expand Down
36 changes: 10 additions & 26 deletions apps/services/bff/src/app/modules/auth/auth.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
UnauthorizedException,
} from '@nestjs/common'
import { ConfigType } from '@nestjs/config'
import { CookieOptions, Request, Response } from 'express'
import type { Request, Response } from 'express'
import jwksClient from 'jwks-rsa'
import { jwtDecode } from 'jwt-decode'

Expand All @@ -26,6 +26,7 @@ import {
CreateErrorQueryStrArgs,
createErrorQueryStr,
} from '../../utils/create-error-query-str'
import { getCookieOptions } from '../../utils/get-cookie-options'
import { validateUri } from '../../utils/validate-uri'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
Expand Down Expand Up @@ -55,17 +56,6 @@ export class AuthService {
this.baseUrl = this.config.ids.issuer
}

private getCookieOptions(): CookieOptions {
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
return {
httpOnly: true,
secure: true,
// The lax setting allows cookies to be sent on top-level navigations (such as redirects),
// while still providing some protection against CSRF attacks.
sameSite: 'lax',
path: environment.keyPath,
}
}

/**
* Creates the client base URL with the path appended.
*/
Expand Down Expand Up @@ -212,12 +202,8 @@ export class AuthService {
prompt,
})

if (parResponse.type === 'error') {
throw parResponse.data
}

searchParams = new URLSearchParams({
request_uri: parResponse.data.request_uri,
request_uri: parResponse.request_uri,
client_id: this.config.ids.clientId,
})
} else {
Expand Down Expand Up @@ -297,13 +283,7 @@ export class AuthService {
codeVerifier: loginAttemptData.codeVerifier,
})

if (tokenResponse.type === 'error') {
snaerseljan marked this conversation as resolved.
Show resolved Hide resolved
throw tokenResponse.data
}

const updatedTokenResponse = await this.updateTokenCache(
tokenResponse.data,
)
const updatedTokenResponse = await this.updateTokenCache(tokenResponse)

// Clean up the login attempt from the cache since we have a successful login.
this.cacheService
Expand All @@ -312,11 +292,15 @@ export class AuthService {
this.logger.warn(err)
})

// Clear any existing session cookie first
// This prevents multiple session cookies being set.
res.clearCookie(SESSION_COOKIE_NAME, getCookieOptions())

// Create session cookie with successful login session id
res.cookie(
SESSION_COOKIE_NAME,
updatedTokenResponse.userProfile.sid,
this.getCookieOptions(),
getCookieOptions(),
)

// Check if there is an old session cookie and clean up the cache
Expand Down Expand Up @@ -424,7 +408,7 @@ export class AuthService {
* - Delete the current login from the cache
* - Clear the session cookie
*/
res.clearCookie(SESSION_COOKIE_NAME, this.getCookieOptions())
res.clearCookie(SESSION_COOKIE_NAME, getCookieOptions())

this.cacheService
.delete(currentLoginCacheKey)
Expand Down
258 changes: 258 additions & 0 deletions apps/services/bff/src/app/modules/auth/token-refresh.service.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import { LOGGER_PROVIDER } from '@island.is/logging'
import { Test } from '@nestjs/testing'
import { CacheService } from '../cache/cache.service'
import { IdsService } from '../ids/ids.service'
import { TokenResponse } from '../ids/ids.types'
import { AuthService } from './auth.service'
import { CachedTokenResponse } from './auth.types'
import { TokenRefreshService } from './token-refresh.service'

const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))

jest.mock('uuid', () => ({
v4: jest.fn().mockReturnValue('fake_uuid'),
}))

const mockLogger = {
error: jest.fn(),
warn: jest.fn(),
}

const mockCacheStore = new Map()

const mockTokenResponse: CachedTokenResponse = {
id_token: 'mock.id.token',
expires_in: 3600,
token_type: 'Bearer',
scope: 'openid profile offline_access',
scopes: ['openid', 'profile', 'offline_access'],
userProfile: {
sid: 'test-session-id',
nationalId: '1234567890',
name: 'Test User',
idp: 'test-idp',
subjectType: 'person',
delegationType: [],
locale: 'is',
birthdate: '1990-01-01',
},
accessTokenExp: Date.now() + 3600000, // Current time + 1 hour in milliseconds
encryptedAccessToken: 'encrypted.access.token',
encryptedRefreshToken: 'encrypted.refresh.token',
}

// When mocking IdsService.refreshToken response, we need TokenResponse type:
const mockIdsTokenResponse: TokenResponse = {
id_token: 'mock.id.token',
access_token: 'mock.access.token',
refresh_token: 'mock.refresh.token',
expires_in: 3600,
token_type: 'Bearer',
scope: 'openid profile offline_access',
}

describe('TokenRefreshService', () => {
let service: TokenRefreshService
let authService: AuthService
let idsService: IdsService
let cacheService: CacheService
const testSid = 'test-sid'
const testRefreshToken = 'test-refresh-token'
const refreshInProgressPrefix = 'refresh_token_in_progress'
const refreshInProgressKey = `${refreshInProgressPrefix}:${testSid}`

beforeEach(async () => {
const module = await Test.createTestingModule({
providers: [
TokenRefreshService,
{
provide: LOGGER_PROVIDER,
useValue: mockLogger,
},
{
provide: AuthService,
useValue: {
updateTokenCache: jest.fn().mockResolvedValue(mockTokenResponse),
},
},
{
provide: IdsService,
useValue: {
refreshToken: jest.fn().mockResolvedValue(mockTokenResponse),
},
},
{
provide: CacheService,
useValue: {
save: jest.fn().mockImplementation(async ({ key, value }) => {
mockCacheStore.set(key, value)
}),
get: jest
.fn()
.mockImplementation(async (key) => mockCacheStore.get(key)),
delete: jest
.fn()
.mockImplementation(async (key) => mockCacheStore.delete(key)),
createSessionKeyType: jest.fn((type, sid) => `${type}_${sid}`),
},
},
],
}).compile()

service = module.get<TokenRefreshService>(TokenRefreshService)
authService = module.get<AuthService>(AuthService)
idsService = module.get<IdsService>(IdsService)
cacheService = module.get<CacheService>(CacheService)
})

afterEach(() => {
mockCacheStore.clear()
jest.clearAllMocks()
})

describe('refreshToken', () => {
it('should successfully refresh token when no refresh is in progress', async () => {
// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken)
expect(authService.updateTokenCache).toHaveBeenCalledWith(
mockTokenResponse,
)
expect(result).toEqual(mockTokenResponse)
})

it('should wait for ongoing refresh and return cached result', async () => {
// Arrange
await cacheService.save({
key: refreshInProgressKey,
value: true,
ttl: 3000,
})

// Simulate another service updating the token while we wait
setTimeout(async () => {
await cacheService.delete(refreshInProgressKey)
await cacheService.save({
key: `current_${testSid}`,
value: mockTokenResponse,
ttl: 3600,
})
}, 500)

// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(result).toEqual(mockTokenResponse)
expect(idsService.refreshToken).not.toHaveBeenCalled()
})

it('should retry refresh if polling times out', async () => {
// Arrange
await cacheService.save({
key: refreshInProgressKey,
value: true,
ttl: 3000,
})

// Act
const result = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})

// Assert
expect(mockLogger.warn).toHaveBeenCalled()
expect(idsService.refreshToken).toHaveBeenCalledWith(testRefreshToken)
expect(result).toEqual(mockTokenResponse)
})

it('should handle refresh token failure', async () => {
// Arrange
const error = new Error('Refresh token failed')
jest.spyOn(idsService, 'refreshToken').mockRejectedValueOnce(error)

// Act
const cachedTokenResponse = await service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
})
//
expect(cachedTokenResponse).toBe(null)

expect(mockLogger.warn).toHaveBeenCalledWith(
`Token refresh failed for sid: ${testSid}`,
)
})

it('should prevent concurrent refresh token requests', async () => {
// Arrange
const refreshPromises = []
const refreshCount = 5
let firstRequestStarted = false

// Mock cache.get to make sure first request get in progress lock and other requests waits
jest.spyOn(cacheService, 'get').mockImplementation(async (key) => {
if (key.includes(refreshInProgressPrefix)) {
return firstRequestStarted
}
return mockTokenResponse
})

// Mock cache.save to track first request
jest.spyOn(cacheService, 'save').mockImplementation(async ({ key }) => {
if (key.includes(refreshInProgressPrefix)) {
firstRequestStarted = true
// Add delay after setting lock
await delay(50)
}
})

// Mock cache.delete to clear the lock
jest.spyOn(cacheService, 'delete').mockImplementation(async (key) => {
if (key.includes(refreshInProgressPrefix)) {
firstRequestStarted = false
}
})

// Act
// First request
refreshPromises.push(
service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
}),
)

// Wait a tick to ensure first request starts
await delay(10)

// Remaining requests
for (let i = 1; i < refreshCount; i++) {
refreshPromises.push(
service.refreshToken({
sid: testSid,
encryptedRefreshToken: testRefreshToken,
}),
)
}

// Wait for all promises to resolve
const results = await Promise.all(refreshPromises)

// Assert
expect(idsService.refreshToken).toHaveBeenCalledTimes(1)
results.forEach((result) => {
expect(result).toEqual(mockTokenResponse)
})
})
})
})
Loading
Loading