Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom state handling #60

Merged
merged 8 commits into from
Feb 13, 2020
29 changes: 29 additions & 0 deletions EXAMPLES.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,32 @@ app.use(auth({
}
}));
```

## 8. Custom state handling

If your application needs to keep track of the request state before redirecting to log in, you can use the built-in state handling. By default, this library stores the post-callback redirect URL in a state object (along with a generated nonce) that is converted to a string, base64 encoded, and verified during callback (see [our documentation](https://auth0.com/docs/protocols/oauth2/oauth-state) for general information about this parameter). This state object can be added to and used during callback.

You can define a `getLoginState` configuration key set to a function that takes an Express `RequestHandler` and an options object and returns a plain object:

```js
app.use(auth({
getLoginState: function (req, options) {
// This object will be stringified and base64 URL-safe encoded.
return {
// Property used by the library for redirecting after logging in.
returnTo: '/custom-return-path',
// Additional properties as needed.
customProperty: req.someProperty,
};
},
handleCallback: function (req, res, next) {
// The req.openidState.customProperty is now available to use.
if ( req.openidState.customProperty ) {
// Do something ...
}

// Call next() to redirect to req.openidState.returnTo.
next();
}
}));
```
5 changes: 5 additions & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ interface ConfigParams {
*/
errorOnRequiredAuth?: boolean;

/**
* Function that returns a URL-safe state value for `res.openid.login()`.
*/
getLoginState?: (req: Request, config: object) => object;

/**
* Function that returns the profile for `req.openid.user`.
*/
Expand Down
2 changes: 2 additions & 0 deletions lib/config.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const Joi = require('@hapi/joi');
const clone = require('clone');
const { defaultState: getLoginState } = require('./hooks/getLoginState');
const getUser = require('./hooks/getUser');
const handleCallback = require('./hooks/handleCallback');

Expand Down Expand Up @@ -54,6 +55,7 @@ const paramsSchema = Joi.object({
),
clockTolerance: Joi.number().optional().default(60),
errorOnRequiredAuth: Joi.boolean().optional().default(false),
getLoginState: Joi.function().optional().default(() => getLoginState),
getUser: Joi.function().optional().default(() => getUser),
handleCallback: Joi.function().optional().default(() => handleCallback),
httpOptions: Joi.object().optional(),
Expand Down
60 changes: 37 additions & 23 deletions lib/context.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
const cb = require('cb');
const url = require('url');
const urlJoin = require('url-join');
const { TokenSet } = require('openid-client');

const transient = require('./transientHandler');
const { get: getClient } = require('./client');
const { TokenSet } = require('openid-client');
const { encodeState } = require('../lib/hooks/getLoginState');

class RequestContext {
constructor(config, req, res, next) {
Expand Down Expand Up @@ -46,37 +48,49 @@ class ResponseContext {
return urlJoin(this._config.baseURL, this._config.redirectUriPath);
}

async login(params = {}) {
async login(options = {}) {
const next = cb(this._next).once();
const req = this._req;
const res = this._res;
const config = this._config;

const client = req.openid.client;
const authorizeParams = config.authorizationParams;

// Set default returnTo value, allow passed-in options to override.
options = {
returnTo: this._config.baseURL,
authorizationParams: {},
...options
};

// Ensure a redirect_uri, merge in configuration options, then passed-in options.
options.authorizationParams = {
redirect_uri: this.getRedirectUri(),
...config.authorizationParams,
...options.authorizationParams
};

const transientOpts = {
legacySameSiteCookie: config.legacySameSiteCookie,
sameSite: config.authorizationParams.response_mode === 'form_post' ? 'None' : 'Lax'
sameSite: options.authorizationParams.response_mode === 'form_post' ? 'None' : 'Lax'
};

let stateValue = await config.getLoginState(req, options);
if ( typeof stateValue !== 'object' ) {
next(new Error( 'Custom state value must be an object.' ));
}
stateValue.nonce = transient.createNonce();

const stateTransientOpts = {
...transientOpts,
value: encodeState(stateValue)
};

try {
let returnTo;
if (params.returnTo) {
returnTo = params.returnTo;
} else if (req.method === 'GET') {
returnTo = req.originalUrl;
} else {
returnTo = this._config.baseURL;
}

// TODO: Store this in state
transient.store('returnTo', res, Object.assign({value: returnTo}, transientOpts));

const authParams = Object.assign({
const authParams = {
...options.authorizationParams,
joshcanhelp marked this conversation as resolved.
Show resolved Hide resolved
nonce: transient.store('nonce', res, transientOpts),
state: transient.store('state', res, transientOpts),
redirect_uri: this.getRedirectUri()
}, authorizeParams, params.authorizationParams || {});
state: transient.store('state', res, stateTransientOpts)
};

const authorizationUrl = client.authorizationUrl(authParams);
res.redirect(authorizationUrl);
Expand All @@ -91,11 +105,11 @@ class ResponseContext {
const res = this._res;

let returnURL = params.returnTo || req.query.returnTo || this._config.postLogoutRedirectUri;

if (url.parse(returnURL).host === null) {
returnURL = urlJoin(this._config.baseURL, returnURL);
}

if (!req.isAuthenticated()) {
return res.redirect(returnURL);
}
Expand Down
41 changes: 41 additions & 0 deletions lib/hooks/getLoginState.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
const { encode: base64encode, decode: base64decode } = require('base64url');

module.exports.defaultState = defaultState;
module.exports.encodeState = encodeState;
module.exports.decodeState = decodeState;

/**
* Generate a unique state value for use during login transactions.
*
* @param {RequestHandler} req
* @param {object} options
*
* @return {object}
*/
function defaultState(req, options) {
return {
returnTo: options.returnTo || req.originalUrl
};
}

/**
* Prepare a state object to send.
*
* @param {object} stateObject
*
* @return {string}
*/
function encodeState(stateObject) {
return base64encode(JSON.stringify(stateObject));
}

/**
* Decode a state value.
*
* @param {string} stateValue
*
* @return {object}
*/
function decodeState(stateValue) {
return JSON.parse(base64decode(stateValue));
}
8 changes: 4 additions & 4 deletions lib/transientHandler.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
const crypto = require('crypto');

exports.store = store;
exports.getOnce = getOnce;
exports.createNonce = createNonce;

/**
* Set a cookie with a value or a generated nonce.
*
Expand Down Expand Up @@ -84,7 +88,3 @@ function createNonce() {
function deleteCookie(name, res) {
res.cookie(name, '', {maxAge: 0});
}

exports.store = store;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moving these to the top for clarity

exports.getOnce = getOnce;
exports.createNonce = createNonce;
10 changes: 6 additions & 4 deletions middleware/auth.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const requiresAuth = require('./requiresAuth');
const transient = require('../lib/transientHandler');
const { RequestContext, ResponseContext } = require('../lib/context');
const appSession = require('../lib/appSession');
const { decodeState } = require('../lib/hooks/getLoginState');

const enforceLeadingSlash = (path) => {
return '/' === path.split('')[0] ? path : '/' + path;
Expand Down Expand Up @@ -83,19 +84,21 @@ module.exports = function (params) {
const redirectUri = res.openid.getRedirectUri();
const client = req.openid.client;

let tokenSet;
req.openidState = transient.getOnce('state', req, res, transientOpts);
joshcanhelp marked this conversation as resolved.
Show resolved Hide resolved

let tokenSet;
try {
const callbackParams = client.callbackParams(req);
tokenSet = await client.callback(redirectUri, callbackParams, {
nonce: transient.getOnce('nonce', req, res, transientOpts),
state: transient.getOnce('state', req, res, transientOpts),
state: req.openidState,
response_type: authorizeParams.response_type,
});
} catch (err) {
throw createError.BadRequest(err.message);
}

req.openidState = decodeState(req.openidState);
req.openidTokens = tokenSet;

if (config.appSessionSecret) {
Expand All @@ -115,8 +118,7 @@ module.exports = function (params) {
},
config.handleCallback,
function (req, res) {
const returnTo = transient.getOnce('returnTo', req, res, transientOpts) || config.baseURL;
res.redirect(returnTo);
res.redirect(req.openidState.returnTo || config.baseURL);
}
);

Expand Down
Loading