diff --git a/lib/connect-redis.js b/lib/connect-redis.js index d16d650..a104755 100644 --- a/lib/connect-redis.js +++ b/lib/connect-redis.js @@ -10,6 +10,7 @@ module.exports = function (session) { // All callbacks should have a noop if none provided for compatibility // with the most Redis clients. const noop = () => {} + const TOMBSTONE = 'TOMBSTONE' class RedisStore extends Store { constructor(options = {}) { @@ -27,12 +28,13 @@ module.exports = function (session) { this.disableTouch = options.disableTouch || false } - get(sid, cb = noop) { + get(sid, cb = noop, showTombs = false) { let key = this.prefix + sid this.client.get(key, (err, data) => { if (err) return cb(err) if (!data) return cb() + if (data === TOMBSTONE) return cb(null, showTombs ? data : undefined) let result try { @@ -45,28 +47,40 @@ module.exports = function (session) { } set(sid, sess, cb = noop) { - let args = [this.prefix + sid] - - let value - try { - value = this.serializer.stringify(sess) - } catch (er) { - return cb(er) - } - args.push(value) + this.get( + sid, + (err, oldSess) => { + if (oldSess === TOMBSTONE) { + return cb() + } else if (oldSess && oldSess.lastModified !== sess.lastModified) { + sess = mergeDeep(oldSess, sess) + } + let args = [this.prefix + sid] + let value + sess.lastModified = Date.now() + try { + value = this.serializer.stringify(sess) + } catch (er) { + return cb(er) + } + args.push(value) + args.push('EX', this._getTTL(sess)) - let ttl = 1 - if (!this.disableTTL) { - ttl = this._getTTL(sess) - args.push('EX', ttl) - } + let ttl = 1 + if (!this.disableTTL) { + ttl = this._getTTL(sess) + args.push('EX', ttl) + } - if (ttl > 0) { - this.client.set(args, cb) - } else { - // If the resulting TTL is negative we can delete / destroy the key - this.destroy(sid, cb) - } + if (ttl > 0) { + this.client.set(args, cb) + } else { + // If the resulting TTL is negative we can delete / destroy the key + this.destroy(sid, cb) + } + }, + true + ) } touch(sid, sess, cb = noop) { @@ -81,7 +95,9 @@ module.exports = function (session) { destroy(sid, cb = noop) { let key = this.prefix + sid - this.client.del(key, cb) + this.client.set([key, TOMBSTONE, 'EX', 300], (err) => { + cb(err, 1) + }) } clear(cb = noop) { @@ -92,9 +108,9 @@ module.exports = function (session) { } length(cb = noop) { - this._getAllKeys((err, keys) => { + this.all((err, result) => { if (err) return cb(err) - return cb(null, keys.length) + return cb(null, result.length) }) } @@ -121,7 +137,7 @@ module.exports = function (session) { let result try { result = sessions.reduce((accum, data, index) => { - if (!data) return accum + if (!data || data === TOMBSTONE) return accum data = this.serializer.parse(data) data.id = keys[index].substr(prefixLen) accum.push(data) @@ -173,3 +189,35 @@ module.exports = function (session) { return RedisStore } + +/** + * Simple object check. + * @param item + * @returns {boolean} + */ +function isObject(item) { + return item && typeof item === 'object' && !Array.isArray(item) +} + +/** + * Deep merge two objects. + * @param target + * @param ...sources + */ +function mergeDeep(target, ...sources) { + if (!sources.length) return target + const source = sources.shift() + + if (isObject(target) && isObject(source)) { + for (const key in source) { + if (isObject(source[key])) { + if (!target[key]) Object.assign(target, { [key]: {} }) + mergeDeep(target[key], source[key]) + } else { + Object.assign(target, { [key]: source[key] }) + } + } + } + + return mergeDeep(target, ...sources) +} diff --git a/package.json b/package.json index 1f91246..eb5147e 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ "eslint-config-prettier": "^8.3.0", "express-session": "^1.17.0", "ioredis": "^4.17.1", + "mockdate": "^2.0.5", "nyc": "^15.0.1", "prettier": "^2.0.5", "redis": "^3.1.2", diff --git a/test/connect-redis-test.js b/test/connect-redis-test.js index 0ee5e7f..a21b934 100644 --- a/test/connect-redis-test.js +++ b/test/connect-redis-test.js @@ -4,8 +4,10 @@ const session = require('express-session') const redis = require('redis') const ioRedis = require('ioredis') const redisMock = require('redis-mock') +const MockDate = require('mockdate') let RedisStore = require('../')(session) +MockDate.set('2000-11-22') let p = (ctx, method) => @@ -58,11 +60,34 @@ test('redis-mock client', async (t) => { test('teardown', redisSrv.disconnect) async function lifecycleTest(store, t) { - let res = await p(store, 'set')('123', { foo: 'bar' }) + await p(store, 'set')('123', { foo: 'bar3' }) + let res = await p(store, 'get')('123') + t.same(res, { foo: 'bar3', lastModified: 974851200000 }, 'get value 1') + await p(store, 'set')('123', { + foo: 'bar3', + luke: 'skywalker', + obi: 'wan', + lastModified: 974851000000, + }) + await p(store, 'set')('123', { + luke: 'skywalker', + lastModified: 974851000000, + }) + res = await p(store, 'get')('123') + t.same( + res, + { foo: 'bar3', luke: 'skywalker', obi: 'wan', lastModified: 974851200000 }, + 'get merged value' + ) + + res = await p(store, 'clear')() + t.ok(res >= 1, 'cleared key') + + res = await p(store, 'set')('123', { foo: 'bar' }) t.equal(res, 'OK', 'set value') res = await p(store, 'get')('123') - t.same(res, { foo: 'bar' }, 'get value') + t.same(res, { foo: 'bar', lastModified: 974851200000 }, 'get value') res = await p(store.client, 'ttl')('sess:123') t.ok(res >= 86399, 'check one day ttl') @@ -96,8 +121,8 @@ async function lifecycleTest(store, t) { t.same( res, [ - { id: '123', foo: 'bar' }, - { id: '456', cookie: { expires } }, + { id: '123', foo: 'bar', lastModified: 974851200000 }, + { id: '456', cookie: { expires }, lastModified: 974851200000 }, ], 'stored two keys data' ) @@ -105,11 +130,20 @@ async function lifecycleTest(store, t) { res = await p(store, 'destroy')('456') t.equal(res, 1, 'destroyed one') + res = await p(store, 'get')('456') + t.equal(res, undefined, 'tombstoned one') + + res = await p(store, 'set')('456', { a: 'new hope' }) + t.equal(res, undefined, 'tombstoned set') + + res = await p(store, 'get')('456') + t.equal(res, undefined, 'tombstoned two') + res = await p(store, 'length')() t.equal(res, 1, 'one key remains') res = await p(store, 'clear')() - t.equal(res, 1, 'cleared remaining key') + t.equal(res, 2, 'cleared remaining key') res = await p(store, 'length')() t.equal(res, 0, 'no key remains')