From 161205e72a67d63a2c0ed27f3120cc0107a0b698 Mon Sep 17 00:00:00 2001 From: Shuo Wu Date: Mon, 25 May 2020 16:07:21 -0400 Subject: [PATCH] feat: handle access token refresh - OKTA-291504 (#138) --- CHANGELOG.md | 4 + package.json | 2 +- src/client.js | 2 +- src/http.js | 111 +++++++++++---------- src/jwt.js | 4 +- src/oauth.js | 34 ++++--- test/jest/http.test.js | 215 +++++++++++++++++++++++++++------------- test/jest/jwt.test.js | 16 +-- test/jest/oauth.test.js | 45 ++++----- 9 files changed, 265 insertions(+), 168 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99b6b1588..79c628489 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Okta Node SDK Changelog +## 3.3.1 + +- [#138](https://github.com/okta/okta-sdk-nodejs/pull/138) Add strategy to handle access token refresh + ## 3.2.0 - [#128](https://github.com/okta/okta-sdk-nodejs/pull/128) Adds support for OAuth diff --git a/package.json b/package.json index 7e2135954..ad8c3daf1 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@okta/okta-sdk-nodejs", - "version": "3.3.0", + "version": "3.3.1", "description": "Okta API wrapper for Node.js", "engines": { "node": ">=8.11" diff --git a/src/client.js b/src/client.js index 2e8274b9f..81af54917 100644 --- a/src/client.js +++ b/src/client.js @@ -44,7 +44,7 @@ class Client extends GeneratedApiClient { errors.push('Okta Org URL not provided'); } - if (!parsedConfig.client.token) { + if (!parsedConfig.client.token && parsedConfig.client.authorizationMode !== 'PrivateKey') { errors.push('Okta API token not provided'); } diff --git a/src/http.js b/src/http.js index ccca7e25d..ab64c9e72 100644 --- a/src/http.js +++ b/src/http.js @@ -21,6 +21,26 @@ const defaultCacheMiddleware = require('./default-cache-middleware'); * @class Http */ class Http { + static errorFilter(response) { + if (response.status >= 200 && response.status < 300) { + return Promise.resolve(response); + } else { + return response.text() + .then(body => { + let err; + + // If the response is JSON, assume it's an Okta API error. Otherwise, assume it's some other HTTP error + + try { + err = new OktaApiError(response.url, response.status, JSON.parse(body), response.headers); + } catch (e) { + err = new HttpError(response.url, response.status, body, response.headers); + } + throw err; + }); + } + } + constructor(httpConfig) { this.defaultHeaders = {}; this.requestExecutor = httpConfig.requestExecutor; @@ -36,74 +56,57 @@ class Http { return Promise.resolve(); } - let getToken; - if (this.accessToken) { - getToken = Promise.resolve(this.accessToken); - } else { - getToken = this.oauth.getAccessToken() - .then(this.errorFilter) - .then(res => res.json()) - .then(accessToken => { - this.accessToken = accessToken; - return accessToken; - }); - } - - return getToken + return this.oauth.getAccessToken() .then(accessToken => { request.headers.Authorization = `Bearer ${accessToken.access_token}`; }); } - errorFilter(response) { - if (response.status >= 200 && response.status < 300) { - return response; - } else { - return response.text() - .then(body => { - let err; - - // If the response is JSON, assume it's an Okta API error. Otherwise, assume it's some other HTTP error - - try { - err = new OktaApiError(response.url, response.status, JSON.parse(body), response.headers); - } catch (e) { - err = new HttpError(response.url, response.status, body, response.headers); - } - throw err; - }); - } - } - http(uri, request, context) { request = request || {}; context = context || {}; request.url = uri; request.headers = Object.assign(this.defaultHeaders, request.headers); request.method = request.method || 'get'; - if (!this.cacheMiddleware) { - return this.prepareRequest(request) + + let retriedOnAuthError = false; + const execute = () => { + const promise = this.prepareRequest(request) .then(() => this.requestExecutor.fetch(request)) - .then(this.errorFilter); - } - const ctx = { - uri, // TODO: remove unused property. req.url should be the key. - isCollection: context.isCollection, - resources: context.resources, - req: request, - cacheStore: this.cacheStore - }; - return this.cacheMiddleware(ctx, () => { - if (ctx.res) { - return; + .then(Http.errorFilter) + .catch(error => { + // Clear cached token then retry request one more time + if (this.oauth && error && error.status === 401 && !retriedOnAuthError) { + retriedOnAuthError = true; + this.oauth.clearCachedAccessToken(); + return execute(); + } + + throw error; + }); + + if (!this.cacheMiddleware) { + return promise; } - return this.prepareRequest(request) - .then(() => this.requestExecutor.fetch(request)) - .then(this.errorFilter) - .then(res => ctx.res = res); - }) - .then(() => ctx.res); + const ctx = { + uri, // TODO: remove unused property. req.url should be the key. + isCollection: context.isCollection, + resources: context.resources, + req: request, + cacheStore: this.cacheStore + }; + return this.cacheMiddleware(ctx, () => { + if (ctx.res) { + return; + } + + return promise.then(res => ctx.res = res); + }) + .then(() => ctx.res); + }; + + return execute(); } delete(uri, request, context) { diff --git a/src/jwt.js b/src/jwt.js index d637c58e1..cc8fd97cf 100644 --- a/src/jwt.js +++ b/src/jwt.js @@ -32,12 +32,12 @@ function getPemAndJwk(privateKey) { } } -function makeJwt(client) { +function makeJwt(client, endpoint) { const now = Math.floor(new Date().getTime() / 1000); // seconds since epoch const plus5Minutes = new Date((now + (5 * 60)) * 1000); // Date object const claims = { - aud: `${client.baseUrl}/oauth2/v1/token`, + aud: `${client.baseUrl}${endpoint}`, }; return getPemAndJwk(client.privateKey) .then(res => { diff --git a/src/oauth.js b/src/oauth.js index e0ca17cc2..0121a1587 100644 --- a/src/oauth.js +++ b/src/oauth.js @@ -1,4 +1,5 @@ const { makeJwt } = require('./jwt'); +const Http = require('./http'); function formatParams(obj) { var str = []; @@ -21,11 +22,16 @@ function formatParams(obj) { class OAuth { constructor(client) { this.client = client; - this.jwt = null; + this.accessToken = null; } getAccessToken() { - return this.getJwt() + if (this.accessToken) { + return Promise.resolve(this.accessToken); + } + + const endpoint = '/oauth2/v1/token'; + return this.getJwt(endpoint) .then(jwt => { const params = formatParams({ grant_type: 'client_credentials', @@ -34,7 +40,7 @@ class OAuth { client_assertion: jwt }); return this.client.requestExecutor.fetch({ - url: `${this.client.baseUrl}/oauth2/v1/token`, + url: `${this.client.baseUrl}${endpoint}`, method: 'POST', body: params, headers: { @@ -42,18 +48,22 @@ class OAuth { 'Content-Type': 'application/x-www-form-urlencoded' } }); + }) + .then(Http.errorFilter) + .then(res => res.json()) + .then(accessToken => { + this.accessToken = accessToken; + return this.accessToken; }); } - getJwt() { - if (!this.jwt) { - return makeJwt(this.client) - .then(jwt => { - this.jwt = jwt.compact(); - return this.jwt; - }); - } - return Promise.resolve(this.jwt); + clearCachedAccessToken() { + this.accessToken = null; + } + + getJwt(endpoint) { + return makeJwt(this.client, endpoint) + .then(jwt => jwt.compact()); } } diff --git a/test/jest/http.test.js b/test/jest/http.test.js index 59bbbb0ad..9b24d8585 100644 --- a/test/jest/http.test.js +++ b/test/jest/http.test.js @@ -39,6 +39,72 @@ describe('Http class', () => { expect(http.oauth).toBe(oauth); }); }); + describe('errorFilter', () => { + it('should resolve promise for status in 200 - 300 range', () => { + expect.assertions(1); + const jsonResponse = { data: 'fake data' }; + const response = { + status: 200, + json: jest.fn().mockResolvedValueOnce(jsonResponse) + }; + return Promise.resolve(response) + .then(Http.errorFilter) + .then(res => res.json()) + .then(res => { + expect(res).toEqual(jsonResponse); + }); + }); + it('should reject with OktaApiError for json response with status equal or greater than 300', () => { + expect.assertions(2); + const errorObject = { + errorCode: 'a fake error' + }; + const response = { + status: 401, + url: 'http://fakey.local', + headers: { fakeHeaders: true }, + text: jest.fn().mockResolvedValueOnce(JSON.stringify(errorObject)) + }; + return Promise.resolve(response) + .then(Http.errorFilter) + .catch(err => { + expect(err).toBeInstanceOf(OktaApiError); + expect(err).toEqual({ + name: 'OktaApiError', + status: 401, + errorCode: 'a fake error', + errorSummary: '', + errorCauses: undefined, + errorLink: undefined, + errorId: undefined, + url: 'http://fakey.local', + headers: { fakeHeaders: true }, + message: 'Okta HTTP 401 a fake error ' + }); + }); + }); + it('should reject with HttpError for text response with status equal or greater than 300', () => { + expect.assertions(2); + const response = { + status: 500, + url: 'http://fakey.local', + headers: { fakeHeaders: true }, + text: jest.fn().mockResolvedValueOnce('an unknown error in plain text') + }; + return Promise.resolve(response) + .then(Http.errorFilter) + .catch(err => { + expect(err).toBeInstanceOf(HttpError); + expect(err).toEqual({ + name: 'HttpError', + status: 500, + url: 'http://fakey.local', + headers: { fakeHeaders: true }, + message: 'HTTP 500 an unknown error in plain text' + }); + }); + }); + }); describe('prepareRequest', () => { let request; beforeEach(() => { @@ -47,6 +113,7 @@ describe('Http class', () => { }; }); it('does not modify request headers if there is no "oauth" object', () => { + expect.assertions(1); const http = new Http({}); return http.prepareRequest(request) .then(() => { @@ -54,85 +121,23 @@ describe('Http class', () => { }); }); describe('OAuth', () => { - let oauth; - let getAccessTokenResponse; - let accessToken; - beforeEach(() => { - accessToken = { + it('should set Authorization header', () => { + expect.assertions(2); + const accessToken = { access_token: 'abcd1234' }; - getAccessTokenResponse = { - status: 200, - url: 'http://fakey.local', - headers: { fakeHeaders: true }, - json: jest.fn().mockImplementation(() => Promise.resolve(accessToken)) + const oauth = { + getAccessToken: jest.fn().mockResolvedValueOnce(accessToken) }; - oauth = { - getAccessToken: jest.fn().mockImplementation(() => Promise.resolve(getAccessTokenResponse)) - }; - }); - it('Sets authorization header if accessToken has been set', () => { - const http = new Http({ oauth }); - http.accessToken = accessToken; - return http.prepareRequest(request) - .then(() => { - expect(request.headers).toEqual({ - Authorization: `Bearer ${accessToken.access_token}` - }); - }); - }); - it('Requests and sets accessToken if it is not set', () => { const http = new Http({ oauth }); return http.prepareRequest(request) .then(() => { expect(oauth.getAccessToken).toHaveBeenCalled(); - expect(http.accessToken).toBe(accessToken); expect(request.headers).toEqual({ Authorization: `Bearer ${accessToken.access_token}` }); }); }); - it('Handles API errors', () => { - const http = new Http({ oauth }); - getAccessTokenResponse.status = 401; - const errorObj = { - errorCode: 'a fake error' - }; - getAccessTokenResponse.text = jest.fn().mockReturnValue(Promise.resolve(JSON.stringify(errorObj))); - return http.prepareRequest(request) - .catch(err => { - expect(err).toBeInstanceOf(OktaApiError); - expect(err).toEqual({ - name: 'OktaApiError', - status: 401, - errorCode: 'a fake error', - errorSummary: '', - errorCauses: undefined, - errorLink: undefined, - errorId: undefined, - url: 'http://fakey.local', - headers: { fakeHeaders: true }, - message: 'Okta HTTP 401 a fake error ' - }); - }); - }); - it('Handles unknown errors', () => { - const http = new Http({ oauth }); - getAccessTokenResponse.status = 500; - const errorText = 'an uknown error in plain text'; - getAccessTokenResponse.text = jest.fn().mockReturnValue(Promise.resolve(errorText)); - return http.prepareRequest(request) - .catch(err => { - expect(err).toBeInstanceOf(HttpError); - expect(err).toEqual({ - name: 'HttpError', - status: 500, - url: 'http://fakey.local', - headers: { fakeHeaders: true }, - message: 'HTTP 500 an uknown error in plain text' - }); - }); - }); }); }); describe('http method', () => { @@ -395,5 +400,81 @@ describe('Http class', () => { }, jasmine.any(Function)); }); }); + + describe('Oauth', () => { + it('should retry to get new token when response staus is 401', () => { + expect.assertions(4); + requestExecutor = { + fetch: jest.fn().mockImplementation((request) => { + if (request.headers.Authorization === 'Bearer expired_token') { + response.status = 401; + } else if (request.headers.Authorization === 'Bearer valid_token') { + response.status = 200; + } + return Promise.resolve(response); + }) + }; + const oauth = { + accessToken: { access_token: 'expired_token' }, + getAccessToken: jest.fn().mockImplementation(() => { + if (oauth.accessToken) { + return Promise.resolve(oauth.accessToken); + } + return Promise.resolve({ access_token: 'valid_token' }); + }), + clearCachedAccessToken: jest.fn().mockImplementation(() => { + oauth.accessToken = null; + }) + }; + const http = new Http({ requestExecutor, oauth }); + jest.spyOn(http, 'http'); + return http.http('http://fakey.local') + .then(res => { + expect(http.http).toHaveBeenCalledTimes(1); + expect(oauth.getAccessToken).toHaveBeenCalledTimes(2); + expect(oauth.clearCachedAccessToken).toHaveBeenCalledTimes(1); + expect(res.status).toEqual(200); + }); + }); + it('should retry only one time when response staus is 401', () => { + expect.assertions(5); + requestExecutor = { + fetch: jest.fn().mockImplementation((request) => { + if (request.headers.Authorization === 'Bearer invalid_token') { + response.status = 401; + } else if (request.headers.Authorization === 'Bearer valid_token') { + response.status = 200; + } + return Promise.resolve(response); + }) + }; + const oauth = { + getAccessToken: jest.fn().mockResolvedValue({ access_token: 'invalid_token' }), + clearCachedAccessToken: jest.fn() + }; + const http = new Http({ requestExecutor, oauth }); + jest.spyOn(http, 'http'); + return http.http('http://fakey.local') + .catch(error => { + expect(http.http).toHaveBeenCalledTimes(1); + expect(oauth.getAccessToken).toHaveBeenCalledTimes(2); + expect(oauth.clearCachedAccessToken).toHaveBeenCalledTimes(1); + expect(error).toBeInstanceOf(OktaApiError); + expect(error.status).toEqual(401); + }); + }); + it('should throw error from oauth.getAccessToken', () => { + expect.assertions(1); + response.status = 401; + const oauth = { + getAccessToken: jest.fn().mockRejectedValueOnce(new Error('bad jwk')) + }; + const http = new Http({ requestExecutor, oauth }); + return http.http('http://fakey.local') + .catch(err => { + expect(err.message).toEqual('bad jwk'); + }); + }); + }); }); }); diff --git a/test/jest/jwt.test.js b/test/jest/jwt.test.js index e077b7379..1b1ef1f6b 100644 --- a/test/jest/jwt.test.js +++ b/test/jest/jwt.test.js @@ -43,9 +43,9 @@ describe('JWT', () => { baseUrl: 'http://localhost' }; }); - function verifyJWT(jwt) { + function verifyJWT(jwt, endpoint) { expect(jwt.body).toEqual({ - aud: 'http://localhost/oauth2/v1/token', + aud: `http://localhost${endpoint}`, exp: 300, iat: 0, iss: 'fake-client-id', @@ -61,7 +61,7 @@ describe('JWT', () => { return Rasha.export({ jwk: JWK, 'public': true }).then(function (publicKey) { const verifiedJwt = nJwt.verify(compactedJwt, publicKey, 'RS256'); expect(verifiedJwt.body).toEqual({ - aud: 'http://localhost/oauth2/v1/token', + aud: `http://localhost${endpoint}`, jti: jasmine.any(String), iat: 0, exp: 300, @@ -76,16 +76,18 @@ describe('JWT', () => { } it('creates a valid JWT using PEM', () => { client.privateKey = PEM; - return JWT.makeJwt(client) + const endpoint = '/oauth2/v1/token'; + return JWT.makeJwt(client, endpoint) .then(jwt => { - return verifyJWT(jwt); + return verifyJWT(jwt, endpoint); }); }); it('creates a valid JWT using JWK', () => { client.privateKey = JWK; - return JWT.makeJwt(client) + const endpoint = '/oauth2/v1/token'; + return JWT.makeJwt(client, endpoint) .then(jwt => { - return verifyJWT(jwt); + return verifyJWT(jwt, endpoint); }); }); }); diff --git a/test/jest/oauth.test.js b/test/jest/oauth.test.js index d16d95a5a..0181b8cd6 100644 --- a/test/jest/oauth.test.js +++ b/test/jest/oauth.test.js @@ -1,6 +1,7 @@ const JWT_STRING = 'fake.jwt.string'; +const FAKE_ACCESS_TOKEN = { access_token: 'fake token' }; const mockJwt = { compact: jest.fn().mockReturnValue(JWT_STRING) }; @@ -10,12 +11,21 @@ const JWT = { }) }; jest.setMock('../../src/jwt', JWT); +const Http = { + errorFilter: jest.fn().mockImplementation(() => { + return Promise.resolve({ + json: jest.fn().mockResolvedValue(FAKE_ACCESS_TOKEN) + }); + }) +}; +jest.setMock('../../src/http', Http); const OAuth = require('../../src/oauth'); describe('OAuth', () => { let client; let oauth; + const endpoint = '/oauth/v1/token'; beforeEach(() => { client = { clientId: 'fake-client-id', @@ -30,40 +40,24 @@ describe('OAuth', () => { mockJwt.compact.mockClear(); }); describe('constructor', () => { - it('initializes jwt to null', () => { - expect(oauth.jwt).toBe(null); + it('initializes accessToken to null', () => { + expect(oauth.accessToken).toBe(null); }); }); describe('getJwt', () => { it('calls "makeJwt"', () => { - return oauth.getJwt() + return oauth.getJwt(endpoint) .then(() => { - expect(JWT.makeJwt).toHaveBeenCalledWith(client); + expect(JWT.makeJwt).toHaveBeenCalledWith(client, endpoint); }); }); it('compacts the JWT', () => { - return oauth.getJwt() - .then(jwt => { - expect(mockJwt.compact).toHaveBeenCalled(); - expect(typeof jwt).toBe('string'); - }); - }); - it('stores the compacted JWT in memory', () => { - return oauth.getJwt() + return oauth.getJwt(endpoint) .then(jwt => { - expect(jwt).toBe(JWT_STRING); - expect(oauth.jwt).toBe(JWT_STRING); + expect(mockJwt.compact).toHaveBeenCalled(); + expect(typeof jwt).toBe('string'); }); }); - it('returns JWT from memory if it exists', () => { - const jwtStr = 'a.different.jwt'; - oauth.jwt = jwtStr; - return oauth.getJwt() - .then(jwt => { - expect(jwt).toBe(jwtStr); - expect(oauth.jwt).toBe(jwtStr); - }); - }); }); describe('getAccessToken', () => { it('calls getJwt()', () => { @@ -74,8 +68,9 @@ describe('OAuth', () => { }); }); it('makes a POST request to the token endpoint', () => { + expect.assertions(3); return oauth.getAccessToken() - .then(() => { + .then(accessToken => { expect(client.requestExecutor.fetch).toHaveBeenCalledWith({ url: 'http://localhost/oauth2/v1/token', method: 'POST', @@ -90,6 +85,8 @@ describe('OAuth', () => { 'Content-Type': 'application/x-www-form-urlencoded' } }); + expect(accessToken).toEqual(FAKE_ACCESS_TOKEN); + expect(oauth.accessToken).toEqual(FAKE_ACCESS_TOKEN); }); }); });