diff --git a/middleware/requiresAuth.js b/middleware/requiresAuth.js index ef2fae11..54ce6117 100644 --- a/middleware/requiresAuth.js +++ b/middleware/requiresAuth.js @@ -54,13 +54,9 @@ function checkJSONprimitive(value) { } } -// TODO: find a better name -module.exports.withClaimEqualCheck = function withClaimEqualCheck( - claim, - expected -) { +module.exports.claimEquals = function claimEquals(claim, expected) { // check that claim is a string value - if (typeof claims !== 'string') { + if (typeof claim !== 'string') { throw new TypeError('"claim" must be a string'); } // check that expected is a JSON supported primitive @@ -84,13 +80,9 @@ module.exports.withClaimEqualCheck = function withClaimEqualCheck( return requiresLoginMiddleware.bind(undefined, authenticationCheck); }; -// TODO: find a better name -module.exports.withClaimIncluding = function withClaimIncluding( - claim, - ...expected -) { +module.exports.claimIncludes = function claimIncludes(claim, ...expected) { // check that claim is a string value - if (typeof claims !== 'string') { + if (typeof claim !== 'string') { throw new TypeError('"claim" must be a string'); } // check that all expected are JSON supported primitives @@ -109,7 +101,9 @@ module.exports.withClaimIncluding = function withClaimIncluding( if (typeof actual === 'string') { actual = actual.split(' '); } else if (!Array.isArray(actual)) { - // TODO: log unexpected type; + debug.trace( + `Unexpected claim type. Expected array or string, got ${typeof actual}` + ); return true; } @@ -120,11 +114,10 @@ module.exports.withClaimIncluding = function withClaimIncluding( return requiresLoginMiddleware.bind(undefined, authenticationCheck); }; -// TODO: find a better name -module.exports.custom = function custom(func) { +module.exports.claimCheck = function claimCheck(func) { // check that func is a function if (typeof func !== 'function' || func.constructor.name !== 'Function') { - throw new TypeError('"function" must be a function'); + throw new TypeError('"claimCheck" expects a function'); } const authenticationCheck = (req) => { if (defaultRequiresLogin(req)) { @@ -133,7 +126,7 @@ module.exports.custom = function custom(func) { const { idTokenClaims } = req.oidc; - return func(req, idTokenClaims); + return !func(req, idTokenClaims); }; return requiresLoginMiddleware.bind(undefined, authenticationCheck); }; diff --git a/test/callback.tests.js b/test/callback.tests.js index 08362173..651d0052 100644 --- a/test/callback.tests.js +++ b/test/callback.tests.js @@ -10,7 +10,7 @@ const TransientCookieHandler = require('../lib/transientHandler'); const { encodeState } = require('../lib/hooks/getLoginState'); const expressOpenid = require('..'); const { create: createServer } = require('./fixture/server'); -const cert = require('./fixture/cert'); +const { makeIdToken } = require('./fixture/cert'); const clientID = '__test_client_id__'; const expectedDefaultState = encodeState({ returnTo: 'https://example.org' }); const nock = require('nock'); @@ -100,26 +100,6 @@ const setup = async (params) => { }; }; -function makeIdToken(payload) { - payload = Object.assign( - { - nickname: '__test_nickname__', - sub: '__test_sub__', - iss: 'https://op.example.com/', - aud: clientID, - iat: Math.round(Date.now() / 1000), - exp: Math.round(Date.now() / 1000) + 60000, - nonce: '__test_nonce__', - }, - payload - ); - - return jose.JWT.sign(payload, cert.key, { - algorithm: 'RS256', - header: { kid: cert.kid }, - }); -} - // For the purpose of this test the fake SERVER returns the error message in the body directly // production application should have an error middleware. // http://expressjs.com/en/guide/error-handling.html diff --git a/test/fixture/cert.js b/test/fixture/cert.js index 945300d7..d5c3626c 100644 --- a/test/fixture/cert.js +++ b/test/fixture/cert.js @@ -1,6 +1,6 @@ -const jose = require('jose'); +const { JWK, JWKS, JWT } = require('jose'); -const key = jose.JWK.asKey({ +const key = JWK.asKey({ e: 'AQAB', n: 'wQrThQ9HKf8ksCQEzqOu0ofF8DtLJgexeFSQBNnMQetACzt4TbHPpjhTWUIlD8bFCkyx88d2_QV3TewMtfS649Pn5hV6adeYW2TxweAA8HVJxskcqTSa_ktojQ-cD43HIStsbqJhHoFv0UY6z5pwJrVPT-yt38ciKo9Oc9IhEl6TSw-zAnuNW0zPOhKjuiIqpAk1lT3e6cYv83ahx82vpx3ZnV83dT9uRbIbcgIpK4W64YnYb5uDH7hGI8-4GnalZDfdApTu-9Y8lg_1v5ul-eQDsLCkUCPkqBaNiCG3gfZUAKp9rrFRE_cJTv_MJn-y_XSTMWILvTY7vdSMRMo4kQ', @@ -21,7 +21,27 @@ const key = jose.JWK.asKey({ alg: 'RS256', }); -module.exports.jwks = new jose.JWKS.KeyStore(key).toJWKS(false); +module.exports.jwks = new JWKS.KeyStore(key).toJWKS(false); module.exports.key = key.toPEM(true); module.exports.kid = key.kid; + +module.exports.makeIdToken = (payload) => { + payload = Object.assign( + { + nickname: '__test_nickname__', + sub: '__test_sub__', + iss: 'https://op.example.com/', + aud: '__test_client_id__', + iat: Math.round(Date.now() / 1000), + exp: Math.round(Date.now() / 1000) + 60000, + nonce: '__test_nonce__', + }, + payload + ); + + return JWT.sign(payload, key.toPEM(true), { + algorithm: 'RS256', + header: { kid: key.kid }, + }); +}; diff --git a/test/invalid_response_type.tests.js b/test/invalid_response_type.tests.js deleted file mode 100644 index 42c6fe3b..00000000 --- a/test/invalid_response_type.tests.js +++ /dev/null @@ -1,18 +0,0 @@ -const assert = require('chai').assert; -const expressOpenid = require('..'); - -describe('with an unsupported response type', function () { - it('should return an error', async function () { - assert.throws(() => { - expressOpenid.auth({ - secret: '__test_session_secret__', - clientID: '__test_client_id__', - baseURL: 'https://example.org', - issuerBaseURL: 'https://op.example.com', - authorizationParams: { - response_type: '__invalid_response_type__', - }, - }); - }, '"authorizationParams.response_type" must be one of [id_token, code id_token, code]'); - }); -}); diff --git a/test/requiresAuth.tests.js b/test/requiresAuth.tests.js index bae7dad7..43c12d4d 100644 --- a/test/requiresAuth.tests.js +++ b/test/requiresAuth.tests.js @@ -1,5 +1,7 @@ const { assert } = require('chai'); +const sinon = require('sinon'); const { create: createServer } = require('./fixture/server'); +const { makeIdToken } = require('./fixture/cert'); const { auth, requiresAuth } = require('./..'); const request = require('request-promise-native').defaults({ simple: false, @@ -7,6 +9,10 @@ const request = require('request-promise-native').defaults({ followRedirect: false, }); +const baseUrl = 'http://localhost:3000'; + +const { claimEquals, claimIncludes, claimCheck } = requiresAuth; + const defaultConfig = { secret: '__test_session_secret__', clientID: '__test_client_id__', @@ -14,9 +20,20 @@ const defaultConfig = { issuerBaseURL: 'https://op.example.com', }; +const login = async (claims) => { + const jar = request.jar(); + await request.post('/session', { + baseUrl, + jar, + json: { + id_token: makeIdToken(claims), + }, + }); + return jar; +}; + describe('requiresAuth', () => { let server; - const baseUrl = 'http://localhost:3000'; afterEach(async () => { if (server) { @@ -24,6 +41,20 @@ describe('requiresAuth', () => { } }); + it('should allow logged in users to visit a protected route', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + }), + requiresAuth() + ); + const jar = await login(); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + it('should ask anonymous user to login when visiting a protected route', async () => { server = await createServer( auth({ @@ -56,7 +87,7 @@ describe('requiresAuth', () => { assert.equal(response.statusCode, 401); }); - it('should throw when no auth middleware', async () => { + it("should throw when there's no auth middleware", async () => { server = await createServer(null, requiresAuth()); const { body: { err }, @@ -66,4 +97,245 @@ describe('requiresAuth', () => { 'req.oidc is not found, did you include the auth middleware?' ); }); + + it('should allow logged in users with the right claim', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimEquals('foo', 'bar') + ); + const jar = await login({ foo: 'bar' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + + it("should return 401 when logged in user doesn't have the right value for claim", async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimEquals('foo', 'bar') + ); + const jar = await login({ foo: 'baz' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it("should return 401 when logged in user doesn't have the claim", async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimEquals('baz', 'bar') + ); + const jar = await login({ foo: 'bar' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it("should return 401 when anonymous user doesn't have the right claim", async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimEquals('foo', 'bar') + ); + const response = await request({ baseUrl, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it('should throw when claim is not a string', () => { + assert.throws( + () => claimEquals(true, 'bar'), + TypeError, + '"claim" must be a string' + ); + }); + + it('should throw when claim value is a non primitive', () => { + assert.throws( + () => claimEquals('foo', { bar: 1 }), + TypeError, + '"expected" must be a string, number, boolean or null' + ); + }); + + it('should allow logged in users with all of the requested claims', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz') + ); + const jar = await login({ foo: ['baz', 'bar'] }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + + it('should return 401 for logged with some of the requested claims', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz', 'qux') + ); + const jar = await login({ foo: 'baz bar' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it('should accept claim values as a space separated list', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz') + ); + const jar = await login({ foo: 'baz bar' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + + it("should not accept claim values that aren't a string or array", async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz') + ); + const jar = await login({ foo: { bar: 'baz' } }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it('should throw when claim value for checking many claims is a non primitive', () => { + assert.throws( + () => claimIncludes(false, 'bar'), + TypeError, + '"claim" must be a string' + ); + }); + + it("should return 401 when checking multiple claims and the user doesn't have the claim", async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz') + ); + const jar = await login({ bar: 'bar baz' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it('should return 401 when checking many claims with anonymous user', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimIncludes('foo', 'bar', 'baz') + ); + const response = await request({ baseUrl, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it("should throw when custom claim check doesn't get a function", async () => { + assert.throws( + () => claimCheck(null), + TypeError, + '"claimCheck" expects a function' + ); + }); + + it('should allow user when custom claim check returns truthy', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimCheck(() => true) + ); + const jar = await login(); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + + it('should not allow user when custom claim check returns falsey', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimCheck(() => false) + ); + const jar = await login(); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 401); + }); + + it('should make the token claims available to custom check', async () => { + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimCheck((req, claims) => claims.foo === 'some_claim') + ); + const jar = await login({ foo: 'some_claim' }); + const response = await request({ baseUrl, jar, url: '/protected' }); + + assert.equal(response.statusCode, 200); + }); + + it('should not allow anonymouse users to check custom claims', async () => { + const checkSpy = sinon.spy(); + server = await createServer( + auth({ + ...defaultConfig, + authRequired: false, + errorOnRequiredAuth: true, + }), + claimCheck(checkSpy) + ); + const response = await request({ baseUrl, url: '/protected' }); + + assert.equal(response.statusCode, 401); + sinon.assert.notCalled(checkSpy); + }); });