From f6f2d25d11ea4774bc075a06ed961db11d6396ce Mon Sep 17 00:00:00 2001 From: cleroux Date: Sun, 16 Feb 2020 13:46:16 -0800 Subject: [PATCH] Support for TOUCH command --- README.md | 1 + cmd_generic.go | 32 +++++++++++++++ cmd_generic_test.go | 79 +++++++++++++++++++++++++++++++++++++ cmd_string.go | 4 ++ db.go | 11 ++++++ direct.go | 13 ++++++ integration/generic_test.go | 19 +++++++++ miniredis.go | 2 + 8 files changed, 161 insertions(+) diff --git a/README.md b/README.md index 516f651b..1742089c 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Implemented commands: - RENAMENX - RANDOMKEY -- see m.Seed(...) - SCAN + - TOUCH - TTL - TYPE - UNLINK diff --git a/cmd_generic.go b/cmd_generic.go index 63197e70..b9fa2abe 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -31,6 +31,7 @@ func commandsGeneric(m *Miniredis) { m.srv.Register("RENAMENX", m.cmdRenamenx) // RESTORE // SORT + m.srv.Register("TOUCH", m.cmdTouch) m.srv.Register("TTL", m.cmdTTL) m.srv.Register("TYPE", m.cmdType) m.srv.Register("SCAN", m.cmdScan) @@ -85,6 +86,7 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, } else { db.ttl[key] = time.Duration(i) * d } + db.origTtl[key] = db.ttl[key] db.keyVersion[key]++ db.checkTTL(key) c.WriteInt(1) @@ -92,6 +94,35 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, } } +// TOUCH +func (m *Miniredis) cmdTouch(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c) { + return + } + + if len(args) == 0 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + count := 0 + for _, key := range args { + if db.exists(key) { + count++ + db.touch(key) + } + } + c.WriteInt(count) + }) +} + // TTL func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) { if len(args) != 1 { @@ -192,6 +223,7 @@ func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) { c.WriteInt(0) return } + delete(db.origTtl, key) delete(db.ttl, key) db.keyVersion[key]++ c.WriteInt(1) diff --git a/cmd_generic_test.go b/cmd_generic_test.go index 5003cfda..323349d3 100644 --- a/cmd_generic_test.go +++ b/cmd_generic_test.go @@ -143,6 +143,85 @@ func TestExpireat(t *testing.T) { } } +func TestTouch(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + // Set something + t.Run("basic", func(t *testing.T) { + s.SetTime(time.Unix(1234567890, 0)) + _, err := c.Do("SET", "foo", "bar", "EX", 100) + ok(t, err) + _, err = c.Do("SET", "baz", "qux", "EX", 100) + ok(t, err) + + // Advance time, keys still exist with 1 second TTL + s.FastForward(time.Second * 99) + equals(t, time.Second, s.TTL("foo")) + equals(t, time.Second, s.TTL("baz")) + + // Change TTL on a key to test that TOUCH will use the new value + _, err = c.Do("EXPIRE", "foo", "200") + ok(t, err) + + // Touch one key + n, err := redis.Int(c.Do("TOUCH", "baz")) + ok(t, err) + equals(t, 1, n) + + s.FastForward(time.Second * 99) + equals(t, time.Second*101, s.TTL("foo")) + equals(t, time.Second, s.TTL("baz")) + + // Reset TTL on multiple keys, "nay" doesn't exist + n, err = redis.Int(c.Do("TOUCH", "foo", "baz", "nay")) + ok(t, err) + equals(t, 2, n) + + equals(t, time.Second*200, s.TTL("foo")) + equals(t, time.Second*100, s.TTL("baz")) + }) + + t.Run("rename", func(t *testing.T) { + s.SetTime(time.Unix(1234567890, 0)) + _, err := c.Do("SET", "foo", "bar", "EX", 100) + ok(t, err) + n, err := redis.Int(c.Do("TOUCH", "foo")) + ok(t, err) + equals(t, 1, n) + + s.FastForward(time.Second * 60) + equals(t, 40*time.Second, s.TTL("foo")) + + _, err = redis.String(c.Do("RENAME", "foo", "foo2")) + ok(t, err) + n, err = redis.Int(c.Do("TOUCH", "foo2")) + ok(t, err) + equals(t, 1, n) + equals(t, 100*time.Second, s.TTL("foo2")) + }) + + t.Run("failure cases", func(t *testing.T) { + _, err := c.Do("TOUCH") + mustFail(t, err, "ERR wrong number of arguments for 'touch' command") + }) + + t.Run("direct", func(t *testing.T) { + s.Set("dir", "bar") + s.SetTTL("dir", 30*time.Second) + equals(t, 30*time.Second, s.TTL("dir")) + + s.FastForward(10 * time.Second) + equals(t, 20*time.Second, s.TTL("dir")) + + s.Touch("dir") + equals(t, 30*time.Second, s.TTL("dir")) + }) +} + func TestPexpireat(t *testing.T) { s, err := Run() ok(t, err) diff --git a/cmd_string.go b/cmd_string.go index d4551c3b..889e745c 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -121,6 +121,7 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { // a vanilla SET clears the expire db.stringSet(key, value) if ttl != 0 { + db.origTtl[key] = ttl db.ttl[key] = ttl } c.WriteOK() @@ -161,6 +162,7 @@ func (m *Miniredis) cmdSetex(c *server.Peer, cmd string, args []string) { db.del(key, true) // Clear any existing keys. db.stringSet(key, value) db.ttl[key] = time.Duration(ttl) * time.Second + db.origTtl[key] = db.ttl[key] c.WriteOK() }) } @@ -199,6 +201,7 @@ func (m *Miniredis) cmdPsetex(c *server.Peer, cmd string, args []string) { db.del(key, true) // Clear any existing keys. db.stringSet(key, value) db.ttl[key] = time.Duration(ttl) * time.Millisecond + db.origTtl[key] = db.ttl[key] c.WriteOK() }) } @@ -374,6 +377,7 @@ func (m *Miniredis) cmdGetset(c *server.Peer, cmd string, args []string) { old, ok := db.stringKeys[key] db.stringSet(key, value) // a GETSET clears the ttl + delete(db.origTtl, key) delete(db.ttl, key) if !ok { diff --git a/db.go b/db.go index 78557e75..72231cb2 100644 --- a/db.go +++ b/db.go @@ -41,6 +41,7 @@ func (db *RedisDB) flush() { db.listKeys = map[string]listKey{} db.setKeys = map[string]setKey{} db.sortedsetKeys = map[string]sortedSet{} + db.origTtl = map[string]time.Duration{} db.ttl = map[string]time.Duration{} } @@ -102,6 +103,10 @@ func (db *RedisDB) rename(from, to string) { if v, ok := db.ttl[from]; ok { db.ttl[to] = v } + if v, ok := db.origTtl[from]; ok { + delete(db.origTtl, from) + db.origTtl[to] = v + } db.del(from, true) } @@ -114,6 +119,7 @@ func (db *RedisDB) del(k string, delTTL bool) { delete(db.keys, k) db.keyVersion[k]++ if delTTL { + delete(db.origTtl, k) delete(db.ttl, k) } switch t { @@ -803,3 +809,8 @@ func (db *RedisDB) checkTTL(key string) { db.del(key, true) } } + +// Touch resets a key's TTL to its original unmodified value. +func (db *RedisDB) touch(k string) { + db.ttl[k] = db.origTtl[k] +} diff --git a/direct.go b/direct.go index ec64f387..2f28ac2a 100644 --- a/direct.go +++ b/direct.go @@ -416,10 +416,23 @@ func (db *RedisDB) SetTTL(k string, ttl time.Duration) { defer db.master.Unlock() defer db.master.signal.Broadcast() + db.origTtl[k] = ttl db.ttl[k] = ttl db.keyVersion[k]++ } +// Touch resets the TTL of the key +func (m *Miniredis) Touch(k string) { + m.DB(m.selectedDB).Touch(k) +} + +func (db *RedisDB) Touch(k string) { + db.master.Lock() + defer db.master.Unlock() + + db.touch(k) +} + // Type gives the type of a key, or "" func (m *Miniredis) Type(k string) string { return m.DB(m.selectedDB).Type(k) diff --git a/integration/generic_test.go b/integration/generic_test.go index eb4e1287..302f7f9e 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -279,3 +279,22 @@ func TestUnlink(t *testing.T) { succSorted("KEYS", "*"), ) } + +func TestTouch(t *testing.T) { + testCommands(t, + succ("SET", "a", "some value"), + succ("TTL", "a"), + succ("EXPIRE", "a", 400), + succ("TTL", "a"), + + succ("TOUCH", "a"), + succ("TTL", "a"), + succ("TOUCH", "a", "foobar", "a"), + + succ("RENAME", "a", "a2"), + succ("TOUCH", "a"), + succ("TTL", "a"), // hard to test + + fail("TOUCH"), + ) +} diff --git a/miniredis.go b/miniredis.go index 0565b92c..2c0d0c92 100644 --- a/miniredis.go +++ b/miniredis.go @@ -41,6 +41,7 @@ type RedisDB struct { sortedsetKeys map[string]sortedSet // ZADD &c. keys streamKeys map[string]streamKey // XADD &c. keys streamGroupKeys map[string]streamGroupKey // XREADGROUP &c. keys + origTtl map[string]time.Duration // unmodified TTL values ttl map[string]time.Duration // effective TTL values keyVersion map[string]uint // used to watch values } @@ -105,6 +106,7 @@ func newRedisDB(id int, m *Miniredis) RedisDB { sortedsetKeys: map[string]sortedSet{}, streamKeys: map[string]streamKey{}, streamGroupKeys: map[string]streamGroupKey{}, + origTtl: map[string]time.Duration{}, ttl: map[string]time.Duration{}, keyVersion: map[string]uint{}, }