diff --git a/.travis.yml b/.travis.yml index d9122d17..c66fc906 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,4 +10,4 @@ script: make test testrace int sudo: false go: - - 1.11 + - 1.12 diff --git a/cmd_connection.go b/cmd_connection.go index ca648f4b..1f35b98f 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -21,7 +21,33 @@ func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } - c.WriteInline("PONG") + + if len(args) > 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + payload := "" + if len(args) > 0 { + payload = args[0] + } + + // PING is allowed in subscribed state + if sub := getCtx(c).subscriber; sub != nil { + c.Block(func(c *server.Writer) { + c.WriteLen(2) + c.WriteBulk("pong") + c.WriteBulk(payload) + }) + return + } + + if payload == "" { + c.WriteInline("PONG") + return + } + c.WriteBulk(payload) } // AUTH @@ -31,6 +57,10 @@ func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) { c.WriteError(errWrongNumber(cmd)) return } + if m.checkPubsub(c) { + return + } + pw := args[0] m.Lock() @@ -58,6 +88,9 @@ func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } msg := args[0] c.WriteBulk(msg) @@ -73,6 +106,9 @@ func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } id, err := strconv.Atoi(args[0]) if err != nil { diff --git a/cmd_connection_test.go b/cmd_connection_test.go index 62b74212..eb28569a 100644 --- a/cmd_connection_test.go +++ b/cmd_connection_test.go @@ -30,6 +30,26 @@ func TestAuth(t *testing.T) { ok(t, err) } +func TestPing(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + r, err := redis.String(c.Do("PING")) + ok(t, err) + equals(t, "PONG", r) + + r, err = redis.String(c.Do("PING", "hi")) + ok(t, err) + equals(t, "hi", r) + + _, err = c.Do("PING", "foo", "bar") + mustFail(t, err, errWrongNumber("ping")) + +} + func TestEcho(t *testing.T) { s, err := Run() ok(t, err) diff --git a/cmd_generic.go b/cmd_generic.go index fa394790..129df634 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -49,6 +49,9 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] value := args[1] @@ -102,6 +105,10 @@ func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -133,6 +140,10 @@ func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -164,6 +175,10 @@ func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -191,6 +206,9 @@ func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -216,6 +234,9 @@ func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -242,6 +263,9 @@ func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -266,6 +290,9 @@ func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] targetDB, err := strconv.Atoi(args[1]) @@ -299,6 +326,9 @@ func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -323,6 +353,9 @@ func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -352,6 +385,9 @@ func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } from, to := args[0], args[1] @@ -378,6 +414,9 @@ func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } from, to := args[0], args[1] @@ -409,6 +448,9 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } cursor, err := strconv.Atoi(args[0]) if err != nil { diff --git a/cmd_hash.go b/cmd_hash.go index 1c65ebec..78fffb2c 100644 --- a/cmd_hash.go +++ b/cmd_hash.go @@ -37,6 +37,9 @@ func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, value := args[0], args[1], args[2] @@ -66,6 +69,9 @@ func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, value := args[0], args[1], args[2] @@ -102,6 +108,9 @@ func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] if len(args)%2 != 0 { @@ -138,6 +147,9 @@ func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field := args[0], args[1] @@ -172,6 +184,9 @@ func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, fields := args[0], args[1:] @@ -217,6 +232,9 @@ func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field := args[0], args[1] @@ -251,6 +269,9 @@ func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -285,6 +306,9 @@ func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -318,6 +342,9 @@ func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -351,6 +378,9 @@ func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -381,6 +411,9 @@ func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -419,6 +452,9 @@ func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, deltas := args[0], args[1], args[2] @@ -456,6 +492,9 @@ func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, deltas := args[0], args[1], args[2] @@ -493,6 +532,9 @@ func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_list.go b/cmd_list.go index ae543dc6..23aa62f4 100644 --- a/cmd_list.go +++ b/cmd_list.go @@ -57,6 +57,10 @@ func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftr if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + timeoutS := args[len(args)-1] keys := args[:len(args)-1] @@ -121,6 +125,9 @@ func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, offsets := args[0], args[1] @@ -167,6 +174,9 @@ func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] where := 0 @@ -231,6 +241,9 @@ func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -271,6 +284,9 @@ func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftri if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -317,6 +333,9 @@ func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftr if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -360,6 +379,9 @@ func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr left if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -398,6 +420,9 @@ func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -445,6 +470,9 @@ func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] count, err := strconv.Atoi(args[1]) @@ -514,6 +542,9 @@ func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] index, err := strconv.Atoi(args[1]) @@ -561,6 +592,9 @@ func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -612,6 +646,9 @@ func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src, dst := args[0], args[1] @@ -642,6 +679,9 @@ func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src := args[0] dst := args[1] diff --git a/cmd_pubsub.go b/cmd_pubsub.go new file mode 100644 index 00000000..4763f9bf --- /dev/null +++ b/cmd_pubsub.go @@ -0,0 +1,321 @@ +// Commands from https://redis.io/commands#pubsub + +package miniredis + +import ( + "fmt" + "regexp" + "strings" + + "github.com/alicebob/miniredis/server" +) + +// commandsPubsub handles all PUB/SUB operations. +func commandsPubsub(m *Miniredis) { + m.srv.Register("SUBSCRIBE", m.cmdSubscribe) + m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe) + m.srv.Register("PSUBSCRIBE", m.cmdPsubscribe) + m.srv.Register("PUNSUBSCRIBE", m.cmdPunsubscribe) + m.srv.Register("PUBLISH", m.cmdPublish) + m.srv.Register("PUBSUB", m.cmdPubSub) +} + +// SUBSCRIBE +func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + for _, channel := range args { + n := sub.Subscribe(channel) + c.Block(func(c *server.Writer) { + c.WriteLen(3) + c.WriteBulk("subscribe") + c.WriteBulk(channel) + c.WriteInt(n) + }) + } + }) +} + +// UNSUBSCRIBE +func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + + channels := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + + if len(channels) == 0 { + channels = sub.Channels() + } + + // there is no de-duplication + for _, channel := range channels { + n := sub.Unsubscribe(channel) + c.Block(func(c *server.Writer) { + c.WriteLen(3) + c.WriteBulk("unsubscribe") + c.WriteBulk(channel) + c.WriteInt(n) + }) + } + + if sub.Count() == 0 { + endSubscriber(m, c) + } + }) +} + +// PSUBSCRIBE +func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + for _, pat := range args { + n := sub.Psubscribe(pat) + c.Block(func(c *server.Writer) { + c.WriteLen(3) + c.WriteBulk("psubscribe") + c.WriteBulk(pat) + c.WriteInt(n) + }) + } + }) +} + +func compileChannelPattern(pattern string) *regexp.Regexp { + const readingLiteral uint8 = 0 + const afterEscape uint8 = 1 + const inClass uint8 = 2 + + rgx := []rune{'\\', 'A'} + state := readingLiteral + literals := []rune{} + klass := map[rune]struct{}{} + + for _, c := range pattern { + switch state { + case readingLiteral: + switch c { + case '\\': + state = afterEscape + case '?': + rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.')...) + literals = []rune{} + case '*': + rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.', '*')...) + literals = []rune{} + case '[': + rgx = append(rgx, []rune(regexp.QuoteMeta(string(literals)))...) + literals = []rune{} + state = inClass + default: + literals = append(literals, c) + } + case afterEscape: + literals = append(literals, c) + state = readingLiteral + case inClass: + if c == ']' { + expr := []rune{'['} + + if _, hasDash := klass['-']; hasDash { + delete(klass, '-') + expr = append(expr, '-') + } + + flatClass := make([]rune, len(klass)) + i := 0 + + for c := range klass { + flatClass[i] = c + i++ + } + + klass = map[rune]struct{}{} + expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') + + if len(expr) < 3 { + rgx = append(rgx, 'x', '\\', 'b', 'y') + } else { + rgx = append(rgx, expr...) + } + + state = readingLiteral + } else { + klass[c] = struct{}{} + } + } + } + + switch state { + case afterEscape: + rgx = append(rgx, '\\', '\\') + case inClass: + if len(klass) < 0 { + rgx = append(rgx, '\\', '[') + } else { + expr := []rune{'['} + + if _, hasDash := klass['-']; hasDash { + delete(klass, '-') + expr = append(expr, '-') + } + + flatClass := make([]rune, len(klass)) + i := 0 + + for c := range klass { + flatClass[i] = c + i++ + } + + expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') + + if len(expr) < 3 { + rgx = append(rgx, 'x', '\\', 'b', 'y') + } else { + rgx = append(rgx, expr...) + } + } + } + + return regexp.MustCompile(string(append(rgx, '\\', 'z'))) +} + +// PUNSUBSCRIBE +func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + + patterns := args + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + + if len(patterns) == 0 { + patterns = sub.Patterns() + } + + // there is no de-duplication + for _, pat := range patterns { + n := sub.Punsubscribe(pat) + c.Block(func(c *server.Writer) { + c.WriteLen(3) + c.WriteBulk("punsubscribe") + c.WriteBulk(pat) + c.WriteInt(n) + }) + } + + if sub.Count() == 0 { + endSubscriber(m, c) + } + }) +} + +// PUBLISH +func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + if m.checkPubsub(c) { + return + } + + channel, mesg := args[0], args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + c.WriteInt(m.publish(channel, mesg)) + }) +} + +// PUBSUB +func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + if m.checkPubsub(c) { + return + } + + subcommand := strings.ToUpper(args[0]) + subargs := args[1:] + var argsOk bool + + switch subcommand { + case "CHANNELS": + argsOk = len(subargs) < 2 + case "NUMSUB": + argsOk = true + case "NUMPAT": + argsOk = len(subargs) == 0 + default: + argsOk = false + } + + if !argsOk { + setDirty(c) + c.WriteError(fmt.Sprintf(msgFPubsubUsage, subcommand)) + return + } + + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + switch subcommand { + case "CHANNELS": + pat := "" + if len(subargs) == 1 { + pat = subargs[0] + } + + channels := activeChannels(m.allSubscribers(), pat) + + c.WriteLen(len(channels)) + for _, channel := range channels { + c.WriteBulk(channel) + } + + case "NUMSUB": + subs := m.allSubscribers() + c.WriteLen(len(subargs) * 2) + for _, channel := range subargs { + c.WriteBulk(channel) + c.WriteInt(countSubs(subs, channel)) + } + case "NUMPAT": + c.WriteInt(countPsubs(m.allSubscribers())) + } + }) +} diff --git a/cmd_pubsub_test.go b/cmd_pubsub_test.go new file mode 100644 index 00000000..b01026d9 --- /dev/null +++ b/cmd_pubsub_test.go @@ -0,0 +1,820 @@ +package miniredis + +import ( + "testing" + + "github.com/gomodule/redigo/redis" +) + +func TestSubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + defer c.Close() + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event1"), int64(1)}, a) + } + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event2")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event2"), int64(2)}, a) + } + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event3", "event4")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event3"), int64(3)}, a) + + a, err = redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event4"), int64(4)}, a) + } + + { + // publish something! + a, err := redis.Values(c.Do("SUBSCRIBE", "colors")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("colors"), int64(5)}, a) + + n := s.Publish("colors", "green") + equals(t, 1, n) + + s, err := redis.Strings(c.Receive()) + ok(t, err) + equals(t, []string{"message", "colors", "green"}, s) + } +} + +func TestUnsubscribe(t *testing.T) { + _, c, done := setup(t) + defer done() + + ok(t, c.Send("SUBSCRIBE", "event1", "event2", "event3", "event4", "event5")) + c.Flush() + c.Receive() + c.Receive() + c.Receive() + c.Receive() + c.Receive() + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event1", "event2")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event1"), int64(4)}, a) + + a, err = redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event2"), int64(3)}, a) + } + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event3")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event3"), int64(2)}, a) + } + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event999")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event999"), int64(2)}, a) + } + + { + // unsub the rest + ok(t, c.Send("UNSUBSCRIBE")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 2; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "unsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 1-i, int(vs[2].(int64))) + } + equals(t, + map[string]bool{ + "event4": true, + "event5": true, + }, + seen, + ) + } +} + +func TestPsubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event1"), int64(1)}, a) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event2?")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event2?"), int64(2)}, a) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event3*", "event4[abc]")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event3*"), int64(3)}, a) + + a, err = redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event4[abc]"), int64(4)}, a) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event5[]")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event5[]"), int64(5)}, a) + } + + { + // publish some things! + n := s.Publish("event4b", "hello 4b!") + equals(t, 1, n) + + n = s.Publish("event4d", "hello 4d?") + equals(t, 0, n) + + s, err := redis.Strings(c.Receive()) + ok(t, err) + equals(t, []string{"message", "event4b", "hello 4b!"}, s) + } +} + +func TestPunsubscribe(t *testing.T) { + _, c, done := setup(t) + defer done() + + c.Send("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") + c.Flush() + c.Receive() + c.Receive() + c.Receive() + c.Receive() + c.Receive() + + { + ok(t, c.Send("PUNSUBSCRIBE", "event1", "event2?")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 2; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "punsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 4-i, int(vs[2].(int64))) + } + equals(t, + map[string]bool{ + "event1": true, + "event2?": true, + }, + seen, + ) + } + + // punsub the rest + { + ok(t, c.Send("PUNSUBSCRIBE")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 3; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "punsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 2-i, int(vs[2].(int64))) + } + equals(t, + map[string]bool{ + "event3*": true, + "event4[abc]": true, + "event5[]": true, + }, + seen, + ) + } +} + +func TestPublishMode(t *testing.T) { + // only pubsub related commands should be accepted while there are + // subscriptions. + _, c, done := setup(t) + defer done() + + _, err := c.Do("SUBSCRIBE", "birds") + ok(t, err) + + _, err = c.Do("SET", "foo", "bar") + mustFail(t, err, "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + + _, err = c.Do("UNSUBSCRIBE", "birds") + ok(t, err) + + // no subs left. All should be fine now. + _, err = c.Do("SET", "foo", "bar") + ok(t, err) +} + +func TestPublish(t *testing.T) { + s, c, c2, done := setup2(t) + defer done() + + a, err := redis.Values(c2.Do("SUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event1"), int64(1)}, a) + + { + n, err := redis.Int(c.Do("PUBLISH", "event1", "message2")) + ok(t, err) + equals(t, 1, n) + + s, err := redis.Strings(c2.Receive()) + ok(t, err) + equals(t, []string{"message", "event1", "message2"}, s) + } + + // direct access + { + equals(t, 1, s.Publish("event1", "message3")) + + s, err := redis.Strings(c2.Receive()) + ok(t, err) + equals(t, []string{"message", "event1", "message3"}, s) + } + + // Wrong usage + { + _, err := c2.Do("PUBLISH", "foo", "bar") + mustFail(t, err, "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + } +} + +func TestPublishMix(t *testing.T) { + // SUBSCRIBE and PSUBSCRIBE + _, c, done := setup(t) + defer done() + + a, err := redis.Values(c.Do("SUBSCRIBE", "c1")) + ok(t, err) + equals(t, 1, int(a[2].(int64))) + + a, err = redis.Values(c.Do("PSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 2, int(a[2].(int64))) + + a, err = redis.Values(c.Do("SUBSCRIBE", "c2")) + ok(t, err) + equals(t, 3, int(a[2].(int64))) + + a, err = redis.Values(c.Do("PUNSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 2, int(a[2].(int64))) + + a, err = redis.Values(c.Do("UNSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 1, int(a[2].(int64))) +} + +func TestPubsubChannels(t *testing.T) { + _, c1, c2, done := setup2(t) + defer done() + + a, err := redis.Strings(c1.Do("PUBSUB", "CHANNELS")) + ok(t, err) + equals(t, []string{}, a) + + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS", "event1[abc]")) + ok(t, err) + equals(t, []string{}, a) + + _, err = c2.Do("SUBSCRIBE", "event1", "event1b", "event1c") + ok(t, err) + + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS")) + ok(t, err) + equals(t, []string{"event1", "event1b", "event1c"}, a) + + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS", "event1[abc]")) + ok(t, err) + equals(t, []string{"event1b", "event1c"}, a) +} + +func TestPubsubNumsub(t *testing.T) { + _, c, c2, done := setup2(t) + defer done() + + _, err := c2.Do("SUBSCRIBE", "event1", "event2", "event3") + ok(t, err) + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB")) + ok(t, err) + equals(t, []interface{}{}, a) + } + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("event1"), int64(1)}, a) + } + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event12", "event3")) + ok(t, err) + equals(t, + []interface{}{ + []byte("event12"), int64(0), + []byte("event3"), int64(1), + }, + a, + ) + } +} + +func TestPubsubNumpat(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Int(c.Do("PUBSUB", "NUMPAT")) + ok(t, err) + equals(t, 0, a) + } + + equals(t, 0, s.PubSubNumPat()) +} + +func TestPubSubBadArgs(t *testing.T) { + for _, command := range [9]struct { + command string + args []interface{} + err string + }{ + {"SUBSCRIBE", []interface{}{}, "ERR wrong number of arguments for 'subscribe' command"}, + {"PSUBSCRIBE", []interface{}{}, "ERR wrong number of arguments for 'psubscribe' command"}, + {"PUBLISH", []interface{}{}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBLISH", []interface{}{"event1"}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBLISH", []interface{}{"event1", "message2", "message3"}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBSUB", []interface{}{}, "ERR wrong number of arguments for 'pubsub' command"}, + {"PUBSUB", []interface{}{"FOOBAR"}, "ERR Unknown subcommand or wrong number of arguments for 'FOOBAR'. Try PUBSUB HELP."}, + {"PUBSUB", []interface{}{"NUMPAT", "FOOBAR"}, "ERR Unknown subcommand or wrong number of arguments for 'NUMPAT'. Try PUBSUB HELP."}, + {"PUBSUB", []interface{}{"CHANNELS", "FOOBAR1", "FOOBAR2"}, "ERR Unknown subcommand or wrong number of arguments for 'CHANNELS'. Try PUBSUB HELP."}, + } { + _, c, done := setup(t) + + _, err := c.Do(command.command, command.args...) + mustFail(t, err, command.err) + + done() + } +} + +func TestPubSubInteraction(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + ch := make(chan struct{}, 8) + tasks := [5]func(){} + directTasks := [4]func(){} + + for i, tester := range [5]func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{}){ + testPubSubInteractionSub1, + testPubSubInteractionSub2, + testPubSubInteractionPsub1, + testPubSubInteractionPsub2, + testPubSubInteractionPub, + } { + tasks[i] = runActualRedisClientForPubSub(t, s, ch, tester) + } + + for i, tester := range [4]func(t *testing.T, s *Miniredis, chCtl chan struct{}){ + testPubSubInteractionDirectSub1, + testPubSubInteractionDirectSub2, + testPubSubInteractionDirectPsub1, + testPubSubInteractionDirectPsub2, + } { + directTasks[i] = runDirectRedisClientForPubSub(t, s, ch, tester) + } + + for _, task := range tasks { + task() + } + + for _, task := range directTasks { + task() + } +} + +func testPubSubInteractionSub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event1", "event2", "event3", "event4"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '2', '3', '4') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event2", "event3"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '4') +} + +func testPubSubInteractionSub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event3", "event4", "event5", "event6"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '3', '4', '5', '6') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event4", "event5"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '3', '6') +} + +func testPubSubInteractionDirectSub1(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event1") + sub.Subscribe("event3") + sub.Subscribe("event4") + sub.Subscribe("event6") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '4', '6') + + sub.Unsubscribe("event1") + sub.Unsubscribe("event4") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6') +} + +func testPubSubInteractionDirectSub2(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event2") + sub.Subscribe("event3") + sub.Subscribe("event4") + sub.Subscribe("event5") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '2', '3', '4', '5') + + sub.Unsubscribe("event3") + sub.Unsubscribe("event5") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '2', '4') +} + +func testPubSubInteractionPsub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ab1]", "event[cd]", "event[ef3]", "event[gh]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '3', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[cd]", "event[ef3]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', 'a', 'b', 'g', 'h') +} + +func testPubSubInteractionPsub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ef]", "event[gh4]", "event[ij]", "event[kl6]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '4', '6', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[gh4]", "event[ij]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '6', 'e', 'f', 'k', 'l') +} + +func testPubSubInteractionDirectPsub1(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Psubscribe(`event[ab1]`) + sub.Psubscribe(`event[ef3]`) + sub.Psubscribe(`event[gh]`) + sub.Psubscribe(`event[kl6]`) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '6', 'a', 'b', 'e', 'f', 'g', 'h', 'k', 'l') + + sub.Punsubscribe(`event[ab1]`) + sub.Punsubscribe(`event[gh]`) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6', 'e', 'f', 'k', 'l') +} + +func testPubSubInteractionDirectPsub2(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Psubscribe(`event[cd]`) + sub.Psubscribe(`event[ef]`) + sub.Psubscribe(`event[gh4]`) + sub.Psubscribe(`event[ij]`) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j') + + sub.Punsubscribe(`event[ef]`) + sub.Punsubscribe(`event[ij]`) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'g', 'h') +} + +func testPubSubInteractionPub(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + testPubSubInteractionPubStage1(t, s, c, ch) + testPubSubInteractionPubStage2(t, s, c, ch) +} + +func testPubSubInteractionPubStage1(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + for i := uint8(0); i < 8; i++ { + <-ch + } + + for _, pattern := range []string{ + "", + "event?", + } { + assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ + "event1", "event2", "event3", "event4", "event5", "event6", + }) + } + + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{ + "event1", "event2", "event3", + }) + + assertNumSubDuringPubSub(t, s, c, map[string]int{ + "event1": 2, "event2": 2, "event3": 4, "event4": 4, "event5": 2, "event6": 2, + "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, + }) + + assertNumPatDuringPubSub(t, s, c, 16) + + for _, message := range [18]struct { + channelSuffix rune + subscribers uint8 + }{ + {'1', 4}, {'2', 2}, {'3', 6}, {'4', 6}, {'5', 2}, {'6', 4}, + {'a', 2}, {'b', 2}, {'c', 2}, {'d', 2}, {'e', 4}, {'f', 4}, + {'g', 4}, {'h', 4}, {'i', 2}, {'j', 2}, {'k', 2}, {'l', 2}, + } { + suffix := string([]rune{message.channelSuffix}) + replies := runCmdDuringPubSub(t, c, 0, "PUBLISH", "event"+suffix, "message"+suffix) + equals(t, []interface{}{int64(message.subscribers)}, replies) + } +} + +func testPubSubInteractionPubStage2(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + for i := uint8(0); i < 8; i++ { + <-ch + } + + for _, pattern := range []string{ + "", + "event?", + } { + assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ + "event1", "event2", "event3", "event4", "event6", + }) + } + + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{"event1", "event2", "event3"}) + + assertNumSubDuringPubSub(t, s, c, map[string]int{ + "event1": 1, "event2": 1, "event3": 2, "event4": 2, "event5": 0, "event6": 2, + "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, + }) + + assertNumPatDuringPubSub(t, s, c, 8) + + for _, message := range [18]struct { + channelSuffix rune + subscribers uint8 + }{ + {'1', 2}, {'2', 1}, {'3', 3}, {'4', 3}, {'5', 0}, {'6', 4}, + {'a', 1}, {'b', 1}, {'c', 1}, {'d', 1}, {'e', 2}, {'f', 2}, + {'g', 2}, {'h', 2}, {'i', 0}, {'j', 0}, {'k', 2}, {'l', 2}, + } { + suffix := string([]rune{message.channelSuffix}) + equals(t, int(message.subscribers), s.Publish("event"+suffix, "message"+suffix)) + } +} + +func runActualRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{})) (wait func()) { + t.Helper() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + ch := make(chan struct{}) + + go func() { + t.Helper() + + tester(t, s, c, chCtl) + c.Close() + close(ch) + }() + + return func() { <-ch } +} + +func runDirectRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, chCtl chan struct{})) (wait func()) { + t.Helper() + + ch := make(chan struct{}) + + go func() { + t.Helper() + + tester(t, s, chCtl) + close(ch) + }() + + return func() { <-ch } +} + +func runCmdDuringPubSub(t *testing.T, c redis.Conn, followUpMessages uint8, command string, args ...interface{}) (replies []interface{}) { + t.Helper() + + replies = make([]interface{}, followUpMessages+1) + + reply, err := c.Do(command, args...) + ok(t, err) + + replies[0] = reply + i := 1 + + for ; followUpMessages > 0; followUpMessages-- { + reply, err := c.Receive() + ok(t, err) + + replies[i] = reply + i++ + } + + return +} + +func assertCorrectSubscriptionsCounts(t *testing.T, subscriptionsCounts []int64, replies []interface{}) { + t.Helper() + + for i, subscriptionsCount := range subscriptionsCounts { + if arrayReply, isArrayReply := replies[i].([]interface{}); isArrayReply && len(arrayReply) > 2 { + equals(t, subscriptionsCount, arrayReply[2]) + } + } +} + +func receiveMessagesDuringPubSub(t *testing.T, c redis.Conn, suffixes ...rune) { + t.Helper() + + for _, suffix := range suffixes { + msg, err := c.Receive() + ok(t, err) + + suff := string([]rune{suffix}) + equals(t, []interface{}{[]byte("message"), []byte("event" + suff), []byte("message" + suff)}, msg) + } +} + +func receiveMessagesDirectlyDuringPubSub(t *testing.T, sub *Subscriber, suffixes ...rune) { + t.Helper() + + for _, suffix := range suffixes { + suff := string([]rune{suffix}) + equals(t, PubsubMessage{"event" + suff, "message" + suff}, <-sub.Messages()) + } +} + +func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, pattern string, channels []string) { + var args []interface{} + if pattern == "" { + args = []interface{}{"CHANNELS"} + } else { + args = []interface{}{"CHANNELS", pattern} + } + + actual, err := redis.Strings(c.Do("PUBSUB", args...)) + ok(t, err) + + equals(t, channels, actual) + + equals(t, channels, s.PubSubChannels(pattern)) +} + +func assertNumSubDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, channels map[string]int) { + t.Helper() + + args := make([]interface{}, 1+len(channels)) + args[0] = "NUMSUB" + i := 1 + + flatChannels := make([]string, len(channels)) + j := 0 + + for channel := range channels { + args[i] = channel + i++ + + flatChannels[j] = channel + j++ + } + + a, err := redis.Values(c.Do("PUBSUB", args...)) + ok(t, err) + equals(t, len(channels)*2, len(a)) + + actualChannels := make(map[string]int, len(a)) + + var currentChannel string + currentState := uint8(0) + + for _, item := range a { + if currentState&uint8(1) == 0 { + if channelString, channelIsString := item.([]byte); channelIsString { + currentChannel = string(channelString) + currentState |= 2 + } else { + currentState &= ^uint8(2) + } + + currentState |= 1 + } else { + if subsInt, subsIsInt := item.(int64); subsIsInt && currentState&uint8(2) != 0 { + actualChannels[currentChannel] = int(subsInt) + } + + currentState &= ^uint8(1) + } + } + + equals(t, channels, actualChannels) + + equals(t, channels, s.PubSubNumSub(flatChannels...)) +} + +func assertNumPatDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, numPat int) { + t.Helper() + + a, err := redis.Int(c.Do("PUBSUB", "NUMPAT")) + ok(t, err) + equals(t, numPat, a) + + equals(t, numPat, s.PubSubNumPat()) +} diff --git a/cmd_scripting.go b/cmd_scripting.go index 296e61b9..13b3deca 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -113,6 +113,10 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + script, args := args[0], args[1:] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -129,6 +133,9 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } sha, args := args[0], args[1:] @@ -152,6 +159,9 @@ func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } subcmd, args := args[0], args[1:] diff --git a/cmd_server.go b/cmd_server.go index c021644c..1ed9ad2f 100644 --- a/cmd_server.go +++ b/cmd_server.go @@ -27,6 +27,9 @@ func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -45,10 +48,12 @@ func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) { c.WriteError(msgSyntaxError) return } - if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { m.flushAll() @@ -66,10 +71,12 @@ func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) { c.WriteError(msgSyntaxError) return } - if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { m.db(ctx.selectedDB).flush() @@ -77,7 +84,7 @@ func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) { }) } -// TIME: time values are returned in string format instead of int +// TIME func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) { if len(args) > 0 { setDirty(c) @@ -87,6 +94,9 @@ func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { now := m.now diff --git a/cmd_set.go b/cmd_set.go index 2220cf55..4cb6ee1b 100644 --- a/cmd_set.go +++ b/cmd_set.go @@ -39,6 +39,9 @@ func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, elems := args[0], args[1:] @@ -65,6 +68,9 @@ func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -96,6 +102,9 @@ func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -125,6 +134,9 @@ func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -153,6 +165,9 @@ func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -182,6 +197,9 @@ func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -210,6 +228,9 @@ func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -244,6 +265,9 @@ func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -279,6 +303,9 @@ func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src, dst, member := args[0], args[1], args[2] @@ -320,6 +347,9 @@ func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -401,6 +431,9 @@ func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] count := 0 @@ -467,6 +500,9 @@ func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, fields := args[0], args[1:] @@ -497,6 +533,9 @@ func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -526,6 +565,9 @@ func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -554,6 +596,9 @@ func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_sorted_set.go b/cmd_sorted_set.go index 5252b015..afc0a632 100644 --- a/cmd_sorted_set.go +++ b/cmd_sorted_set.go @@ -52,6 +52,9 @@ func (m *Miniredis) cmdZadd(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] var ( @@ -170,6 +173,9 @@ func (m *Miniredis) cmdZcard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -200,6 +206,9 @@ func (m *Miniredis) cmdZcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -244,6 +253,9 @@ func (m *Miniredis) cmdZincrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.ParseFloat(args[1], 64) @@ -276,6 +288,9 @@ func (m *Miniredis) cmdZinterstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } destination := args[0] numKeys, err := strconv.Atoi(args[1]) @@ -405,6 +420,9 @@ func (m *Miniredis) cmdZlexcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -453,6 +471,9 @@ func (m *Miniredis) makeCmdZrange(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -526,6 +547,9 @@ func (m *Miniredis) makeCmdZrangebylex(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -637,6 +661,9 @@ func (m *Miniredis) makeCmdZrangebyscore(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -758,6 +785,9 @@ func (m *Miniredis) makeCmdZrank(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, member := args[0], args[1] @@ -798,6 +828,9 @@ func (m *Miniredis) cmdZrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, members := args[0], args[1:] @@ -834,6 +867,9 @@ func (m *Miniredis) cmdZremrangebylex(c *server.Peer, cmd string, args []string) if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -884,6 +920,9 @@ func (m *Miniredis) cmdZremrangebyrank(c *server.Peer, cmd string, args []string if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -931,6 +970,9 @@ func (m *Miniredis) cmdZremrangebyscore(c *server.Peer, cmd string, args []strin if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -979,6 +1021,9 @@ func (m *Miniredis) cmdZscore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, member := args[0], args[1] @@ -1128,6 +1173,9 @@ func (m *Miniredis) cmdZunionstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } destination := args[0] numKeys, err := strconv.Atoi(args[1]) @@ -1256,6 +1304,9 @@ func (m *Miniredis) cmdZscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_string.go b/cmd_string.go index 930da992..b99a34bd 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -47,6 +47,9 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } var ( nx = false // set iff not exists @@ -133,6 +136,9 @@ func (m *Miniredis) cmdSetex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] ttl, err := strconv.Atoi(args[1]) @@ -168,6 +174,9 @@ func (m *Miniredis) cmdPsetex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] ttl, err := strconv.Atoi(args[1]) @@ -203,6 +212,9 @@ func (m *Miniredis) cmdSetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -229,6 +241,9 @@ func (m *Miniredis) cmdMset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } if len(args)%2 != 0 { setDirty(c) @@ -261,6 +276,9 @@ func (m *Miniredis) cmdMsetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } if len(args)%2 != 0 { setDirty(c) @@ -306,6 +324,9 @@ func (m *Miniredis) cmdGet(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -335,6 +356,9 @@ func (m *Miniredis) cmdGetset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -369,6 +393,9 @@ func (m *Miniredis) cmdMget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -400,6 +427,9 @@ func (m *Miniredis) cmdIncr(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -429,6 +459,9 @@ func (m *Miniredis) cmdIncrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.Atoi(args[1]) @@ -466,6 +499,9 @@ func (m *Miniredis) cmdIncrbyfloat(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.ParseFloat(args[1], 64) @@ -503,6 +539,9 @@ func (m *Miniredis) cmdDecr(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -532,6 +571,9 @@ func (m *Miniredis) cmdDecrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.Atoi(args[1]) @@ -569,6 +611,9 @@ func (m *Miniredis) cmdStrlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -594,6 +639,9 @@ func (m *Miniredis) cmdAppend(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -622,6 +670,9 @@ func (m *Miniredis) cmdGetrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -660,6 +711,9 @@ func (m *Miniredis) cmdSetrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] pos, err := strconv.Atoi(args[1]) @@ -705,6 +759,9 @@ func (m *Miniredis) cmdBitcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } var ( useRange = false @@ -767,6 +824,9 @@ func (m *Miniredis) cmdBitop(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } var ( op = strings.ToUpper(args[0]) @@ -844,6 +904,9 @@ func (m *Miniredis) cmdBitpos(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) @@ -926,6 +989,9 @@ func (m *Miniredis) cmdGetbit(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) @@ -969,6 +1035,9 @@ func (m *Miniredis) cmdSetbit(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) diff --git a/cmd_transactions.go b/cmd_transactions.go index 64912cf5..d90ff73d 100644 --- a/cmd_transactions.go +++ b/cmd_transactions.go @@ -24,6 +24,9 @@ func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) @@ -47,6 +50,9 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) @@ -57,6 +63,8 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) { if ctx.dirtyTransaction { c.WriteError("EXECABORT Transaction discarded because of previous errors.") + // a failed EXEC finishes the tx + stopTx(ctx) return } @@ -93,6 +101,9 @@ func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) if !inTx(ctx) { @@ -114,6 +125,9 @@ func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) if inTx(ctx) { @@ -141,6 +155,9 @@ func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } // Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me. unwatch(getCtx(c)) diff --git a/direct.go b/direct.go index ca41449f..6d71f8b2 100644 --- a/direct.go +++ b/direct.go @@ -547,3 +547,42 @@ func (db *RedisDB) ZScore(k, member string) (float64, error) { } return db.ssetScore(k, member), nil } + +// Publish a message to subscribers. Returns the number of receivers. +func (m *Miniredis) Publish(channel, message string) int { + m.Lock() + defer m.Unlock() + return m.publish(channel, message) +} + +// PubSubChannels is "PUBSUB CHANNELS ". An empty pattern is fine +// (meaning all channels). +// Returned channels will be ordered alphabetically. +func (m *Miniredis) PubSubChannels(pattern string) []string { + m.Lock() + defer m.Unlock() + + return activeChannels(m.allSubscribers(), pattern) +} + +// PubSubNumSub is "PUBSUB NUMSUB [channels]". It returns all channels with their +// subscriber count. +func (m *Miniredis) PubSubNumSub(channels ...string) map[string]int { + m.Lock() + defer m.Unlock() + + subs := m.allSubscribers() + res := map[string]int{} + for _, channel := range channels { + res[channel] = countSubs(subs, channel) + } + return res +} + +// PubSubNumPat is "PUBSUB NUMPAT" +func (m *Miniredis) PubSubNumPat() int { + m.Lock() + defer m.Unlock() + + return countPsubs(m.allSubscribers()) +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..f15be7b4 --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module github.com/alicebob/miniredis + +require ( + github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 + github.com/gomodule/redigo v2.0.0+incompatible + github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..3bc2524c --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583 h1:SZPG5w7Qxq7bMcMVl6e3Ht2X7f+AAGQdzjkbyOnNNZ8= +github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/integration/Makefile b/integration/Makefile index d95d141b..79218443 100644 --- a/integration/Makefile +++ b/integration/Makefile @@ -4,3 +4,7 @@ all: test test: go test -tags int + +commands.txt: ../*.go + grep Register ../*.go|perl -ne '/"(.*)"/ && print "$$1\n"' | sort > commands.txt + diff --git a/integration/ephemeral.go b/integration/ephemeral.go index 1b5477b4..3dcb7e98 100644 --- a/integration/ephemeral.go +++ b/integration/ephemeral.go @@ -37,7 +37,7 @@ func runRedis(extraConfig string) (*ephemeral, string) { port := arbitraryPort() // we prefer the executable from ./redis_src, if any. See ./get_redis.sh - os.Setenv("PATH", fmt.Sprintf("%s:PATH", localSrc)) + os.Setenv("PATH", fmt.Sprintf("%s:%s", localSrc, os.Getenv("PATH"))) c := exec.Command(executable, "-") stdin, err := c.StdinPipe() diff --git a/integration/generic_test.go b/integration/generic_test.go index 45d4e09a..f1487a26 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -18,6 +18,14 @@ func TestEcho(t *testing.T) { ) } +func TestPing(t *testing.T) { + testCommands(t, + succ("PING"), + succ("PING", "hello world"), + fail("PING", "hello", "world"), + ) +} + func TestKeys(t *testing.T) { testCommands(t, succ("SET", "one", "1"), diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go new file mode 100644 index 00000000..2263fe88 --- /dev/null +++ b/integration/pubsub_test.go @@ -0,0 +1,404 @@ +// +build int + +package main + +import ( + "sync" + "testing" + + "github.com/alicebob/miniredis" +) + +func TestSubscribe(t *testing.T) { + testCommands(t, + fail("SUBSCRIBE"), + + succ("SUBSCRIBE", "foo"), + succ("UNSUBSCRIBE"), + + succ("SUBSCRIBE", "foo"), + succ("UNSUBSCRIBE", "foo"), + + succ("SUBSCRIBE", "foo", "bar"), + succ("UNSUBSCRIBE", "foo", "bar"), + + succ("SUBSCRIBE", -1), + succ("UNSUBSCRIBE", -1), + ) +} + +func TestPSubscribe(t *testing.T) { + testCommands(t, + fail("PSUBSCRIBE"), + + succ("PSUBSCRIBE", "foo"), + succ("PUNSUBSCRIBE"), + + succ("PSUBSCRIBE", "foo"), + succ("PUNSUBSCRIBE", "foo"), + + succ("PSUBSCRIBE", "foo", "bar"), + succ("PUNSUBSCRIBE", "foo", "bar"), + + succ("PSUBSCRIBE", "f?o"), + succ("PUNSUBSCRIBE", "f?o"), + + succ("PSUBSCRIBE", "f*o"), + succ("PUNSUBSCRIBE", "f*o"), + + succ("PSUBSCRIBE", "f[oO]o"), + succ("PUNSUBSCRIBE", "f[oO]o"), + + succ("PSUBSCRIBE", "f\\?o"), + succ("PUNSUBSCRIBE", "f\\?o"), + + succ("PSUBSCRIBE", "f\\*o"), + succ("PUNSUBSCRIBE", "f\\*o"), + + succ("PSUBSCRIBE", "f\\[oO]o"), + succ("PUNSUBSCRIBE", "f\\[oO]o"), + + succ("PSUBSCRIBE", "f\\\\oo"), + succ("PUNSUBSCRIBE", "f\\\\oo"), + + succ("PSUBSCRIBE", -1), + succ("PUNSUBSCRIBE", -1), + ) +} + +func TestPublish(t *testing.T) { + testCommands(t, + fail("PUBLISH"), + fail("PUBLISH", "foo"), + succ("PUBLISH", "foo", "bar"), + fail("PUBLISH", "foo", "bar", "deadbeef"), + succ("PUBLISH", -1, -2), + ) +} + +func TestPubSub(t *testing.T) { + testCommands(t, + fail("PUBSUB"), + fail("PUBSUB", "FOO"), + + succ("PUBSUB", "CHANNELS"), + succ("PUBSUB", "CHANNELS", "foo"), + fail("PUBSUB", "CHANNELS", "foo", "bar"), + succ("PUBSUB", "CHANNELS", "f?o"), + succ("PUBSUB", "CHANNELS", "f*o"), + succ("PUBSUB", "CHANNELS", "f[oO]o"), + succ("PUBSUB", "CHANNELS", "f\\?o"), + succ("PUBSUB", "CHANNELS", "f\\*o"), + succ("PUBSUB", "CHANNELS", "f\\[oO]o"), + succ("PUBSUB", "CHANNELS", "f\\\\oo"), + succ("PUBSUB", "CHANNELS", -1), + + succ("PUBSUB", "NUMSUB"), + succ("PUBSUB", "NUMSUB", "foo"), + succ("PUBSUB", "NUMSUB", "foo", "bar"), + succ("PUBSUB", "NUMSUB", -1), + + succ("PUBSUB", "NUMPAT"), + fail("PUBSUB", "NUMPAT", "foo"), + ) +} + +func TestPubsubFull(t *testing.T) { + var wg1 sync.WaitGroup + wg1.Add(1) + testMultiCommands(t, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "news", "sport") + r <- receive() + wg1.Done() + r <- receive() + r <- receive() + r <- receive() + r <- succ("UNSUBSCRIBE", "news", "sport") + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("PUBLISH", "news", "revolution!") + r <- succ("PUBLISH", "news", "alien invasion!") + r <- succ("PUBLISH", "sport", "lady biked too fast") + r <- succ("PUBLISH", "gossip", "man bites dog") + }, + ) +} + +func TestPubsubMulti(t *testing.T) { + var wg1 sync.WaitGroup + wg1.Add(2) + testMultiCommands(t, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "news", "sport") + r <- receive() + wg1.Done() + r <- receive() + r <- receive() + r <- receive() + r <- succ("UNSUBSCRIBE", "news", "sport") + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "sport") + wg1.Done() + r <- receive() + r <- succ("UNSUBSCRIBE", "sport") + }, + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("PUBLISH", "news", "revolution!") + r <- succ("PUBLISH", "news", "alien invasion!") + r <- succ("PUBLISH", "sport", "lady biked too fast") + }, + ) +} + +func TestPubsubSelect(t *testing.T) { + testClients2(t, func(r1, r2 chan<- command) { + r1 <- succ("SUBSCRIBE", "news", "sport") + r1 <- receive() + r2 <- succ("SELECT", 3) + r2 <- succ("PUBLISH", "news", "revolution!") + r1 <- receive() + }) +} + +func TestPubsubMode(t *testing.T) { + // most commands aren't allowed in publish mode + testCommands(t, + succ("SUBSCRIBE", "news", "sport"), + receive(), + succ("PING"), + succ("PING", "foo"), + fail("ECHO", "foo"), + fail("HGET", "foo", "bar"), + fail("SET", "foo", "bar"), + succ("QUIT"), + ) + + e := "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" + cbs := []command{ + succ("SUBSCRIBE", "news"), + // failWith(e, "PING"), + // failWith(e, "PSUBSCRIBE"), + // failWith(e, "PUNSUBSCRIBE"), + // failWith(e, "QUIT"), + // failWith(e, "SUBSCRIBE"), + // failWith(e, "UNSUBSCRIBE"), + + failWith(e, "APPEND", "foo", "foo"), + failWith(e, "AUTH", "foo"), + failWith(e, "BITCOUNT", "foo"), + failWith(e, "BITOP", "OR", "foo", "bar"), + failWith(e, "BITPOS", "foo", 0), + failWith(e, "BLPOP", "key", 1), + failWith(e, "BRPOP", "key", 1), + failWith(e, "BRPOPLPUSH", "foo", "bar", 1), + failWith(e, "DBSIZE"), + failWith(e, "DECR", "foo"), + failWith(e, "DECRBY", "foo", 3), + failWith(e, "DEL", "foo"), + failWith(e, "DISCARD"), + failWith(e, "ECHO", "foo"), + failWith(e, "EVAL", "foo", "{}"), + failWith(e, "EVALSHA", "foo", "{}"), + failWith(e, "EXEC"), + failWith(e, "EXISTS", "foo"), + failWith(e, "EXPIRE", "foo", 12), + failWith(e, "EXPIREAT", "foo", 12), + failWith(e, "FLUSHALL"), + failWith(e, "FLUSHDB"), + failWith(e, "GET", "foo"), + failWith(e, "GETBIT", "foo", 12), + failWith(e, "GETRANGE", "foo", 12, 12), + failWith(e, "GETSET", "foo", "bar"), + failWith(e, "HDEL", "foo", "bar"), + failWith(e, "HEXISTS", "foo", "bar"), + failWith(e, "HGET", "foo", "bar"), + failWith(e, "HGETALL", "foo"), + failWith(e, "HINCRBY", "foo", "bar", 12), + failWith(e, "HINCRBYFLOAT", "foo", "bar", 12.34), + failWith(e, "HKEYS", "foo"), + failWith(e, "HLEN", "foo"), + failWith(e, "HMGET", "foo", "bar"), + failWith(e, "HMSET", "foo", "bar", "baz"), + failWith(e, "HSCAN", "foo", 0), + failWith(e, "HSET", "foo", "bar", "baz"), + failWith(e, "HSETNX", "foo", "bar", "baz"), + failWith(e, "HVALS", "foo"), + failWith(e, "INCR", "foo"), + failWith(e, "INCRBY", "foo", 12), + failWith(e, "INCRBYFLOAT", "foo", 12.34), + failWith(e, "KEYS", "*"), + failWith(e, "LINDEX", "foo", 0), + failWith(e, "LINSERT", "foo", "after", "bar", 0), + failWith(e, "LLEN", "foo"), + failWith(e, "LPOP", "foo"), + failWith(e, "LPUSH", "foo", "bar"), + failWith(e, "LPUSHX", "foo", "bar"), + failWith(e, "LRANGE", "foo", 1, 1), + failWith(e, "LREM", "foo", 0, "bar"), + failWith(e, "LSET", "foo", 0, "bar"), + failWith(e, "LTRIM", "foo", 0, 0), + failWith(e, "MGET", "foo", "bar"), + failWith(e, "MOVE", "foo", "bar"), + failWith(e, "MSET", "foo", "bar"), + failWith(e, "MSETNX", "foo", "bar"), + failWith(e, "MULTI"), + failWith(e, "PERSIST", "foo"), + failWith(e, "PEXPIRE", "foo", 12), + failWith(e, "PEXPIREAT", "foo", 12), + failWith(e, "PSETEX", "foo", 12, "bar"), + failWith(e, "PTTL", "foo"), + failWith(e, "PUBLISH", "foo", "bar"), + failWith(e, "PUBSUB", "CHANNELS"), + failWith(e, "RANDOMKEY"), + failWith(e, "RENAME", "foo", "bar"), + failWith(e, "RENAMENX", "foo", "bar"), + failWith(e, "RPOP", "foo"), + failWith(e, "RPOPLPUSH", "foo", "bar"), + failWith(e, "RPUSH", "foo", "bar"), + failWith(e, "RPUSHX", "foo", "bar"), + failWith(e, "SADD", "foo", "bar"), + failWith(e, "SCAN", 0), + failWith(e, "SCARD", "foo"), + failWith(e, "SCRIPT", "FLUSH"), + failWith(e, "SDIFF", "foo"), + failWith(e, "SDIFFSTORE", "foo", "bar"), + failWith(e, "SELECT", 12), + failWith(e, "SET", "foo", "bar"), + failWith(e, "SETBIT", "foo", 0, 1), + failWith(e, "SETEX", "foo", 12, "bar"), + failWith(e, "SETNX", "foo", "bar"), + failWith(e, "SETRANGE", "foo", 0, "bar"), + failWith(e, "SINTER", "foo", "bar"), + failWith(e, "SINTERSTORE", "foo", "bar", "baz"), + failWith(e, "SISMEMBER", "foo", "bar"), + failWith(e, "SMEMBERS", "foo"), + failWith(e, "SMOVE", "foo", "bar", "baz"), + failWith(e, "SPOP", "foo"), + failWith(e, "SRANDMEMBER", "foo"), + failWith(e, "SREM", "foo", "bar", "baz"), + failWith(e, "SSCAN", "foo", 0), + failWith(e, "STRLEN", "foo"), + failWith(e, "SUNION", "foo", "bar"), + failWith(e, "SUNIONSTORE", "foo", "bar", "baz"), + failWith(e, "TIME"), + failWith(e, "TTL", "foo"), + failWith(e, "TYPE", "foo"), + failWith(e, "UNWATCH"), + failWith(e, "WATCH", "foo"), + failWith(e, "ZADD", "foo", "INCR", 1, "bar"), + failWith(e, "ZCARD", "foo"), + failWith(e, "ZCOUNT", "foo", 0, 1), + failWith(e, "ZINCRBY", "foo", "bar", 12), + failWith(e, "ZINTERSTORE", "foo", 1, "bar"), + failWith(e, "ZLEXCOUNT", "foo", "-", "+"), + failWith(e, "ZRANGE", "foo", 0, -1), + failWith(e, "ZRANGEBYLEX", "foo", "-", "+"), + failWith(e, "ZRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZRANK", "foo", "bar"), + failWith(e, "ZREM", "foo", "bar"), + failWith(e, "ZREMRANGEBYLEX", "foo", "-", "+"), + failWith(e, "ZREMRANGEBYRANK", "foo", 0, 1), + failWith(e, "ZREMRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZREVRANGE", "foo", 0, -1), + failWith(e, "ZREVRANGEBYLEX", "foo", "+", "-"), + failWith(e, "ZREVRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZREVRANK", "foo", "bar"), + failWith(e, "ZSCAN", "foo", 0), + failWith(e, "ZSCORE", "foo", "bar"), + failWith(e, "ZUNIONSTORE", "foo", 1, "bar"), + } + testCommands(t, cbs...) +} + +func TestSubscriptions(t *testing.T) { + testClients2(t, func(r1, r2 chan<- command) { + r1 <- succ("SUBSCRIBE", "foo", "bar", "foo") + r2 <- succ("PUBSUB", "NUMSUB") + r1 <- succ("UNSUBSCRIBE", "bar", "bar", "bar") + r2 <- succ("PUBSUB", "NUMSUB") + }) +} + +func TestPubsubUnsub(t *testing.T) { + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("SUBSCRIBE", "news", "sport") + c1 <- receive() + c2 <- succSorted("PUBSUB", "CHANNELS") + c1 <- succ("QUIT") + c2 <- succSorted("PUBSUB", "CHANNELS") + }) +} + +func TestPubsubTx(t *testing.T) { + // publish is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("SUBSCRIBE", "foo") + c2 <- succ("MULTI") + c2 <- succ("PUBSUB", "CHANNELS") + c2 <- succ("PUBLISH", "foo", "hello one") + c2 <- fail("GET") + c2 <- succ("PUBLISH", "foo", "hello two") + c2 <- fail("EXEC") + + c2 <- succ("PUBLISH", "foo", "post tx") + c1 <- receive() + }) + + // SUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("SUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "CHANNELS") + + c1 <- fail("MULTI") // we're in SUBSCRIBE mode + }) + + // DISCARDing a tx prevents from entering publish mode + testCommands(t, + succ("MULTI"), + succ("SUBSCRIBE", "foo"), + succ("DISCARD"), + succ("PUBSUB", "CHANNELS"), + ) + + // UNSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("SUBSCRIBE", "foo") + c1 <- succ("UNSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("PUBSUB", "CHANNELS") + }) + + // PSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("PSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "NUMPAT") + + c1 <- fail("MULTI") // we're in SUBSCRIBE mode + }) + + // PUNSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("PSUBSCRIBE", "foo") + c1 <- succ("PUNSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("PUBSUB", "NUMPAT") + }) +} diff --git a/integration/test.go b/integration/test.go index 5db7a36f..f55f8715 100644 --- a/integration/test.go +++ b/integration/test.go @@ -4,6 +4,7 @@ package main import ( "bytes" + "context" "fmt" "reflect" "sort" @@ -16,12 +17,13 @@ import ( ) type command struct { - cmd string // 'GET', 'SET', &c. - args []interface{} - error bool // Whether the command should return an error or not. - sort bool // Sort real redis's result. Used for 'keys'. - loosely bool // Don't compare values, only structure. (for random things) - errorSub string // Both errors need this substring + cmd string // 'GET', 'SET', &c. + args []interface{} + error bool // Whether the command should return an error or not. + sort bool // Sort real redis's result. Used for 'keys'. + loosely bool // Don't compare values, only structure. (for random things) + errorSub string // Both errors need this substring + receiveOnly bool // no command, only receive. For pubsub messages. } func succ(cmd string, args ...interface{}) command { @@ -78,6 +80,13 @@ func failLoosely(cmd string, args ...interface{}) command { } } +// don't send a message, only read one. For pubsub messages. +func receive() command { + return command{ + receiveOnly: true, + } +} + // ok fails the test if an err is not nil. func ok(tb testing.TB, err error) { tb.Helper() @@ -109,7 +118,7 @@ func testMultiCommands(t *testing.T, cs ...func(chan<- command, *miniredis.Minir var wg sync.WaitGroup for _, c := range cs { - // one connections per cs + // one connection per cs cMini, err := redis.Dial("tcp", sMini.Addr()) ok(t, err) @@ -134,6 +143,58 @@ func testMultiCommands(t *testing.T, cs ...func(chan<- command, *miniredis.Minir wg.Wait() } +// like testCommands, but multiple connections +func testClients2(t *testing.T, f func(c1, c2 chan<- command)) { + t.Helper() + sMini, err := miniredis.Run() + ok(t, err) + defer sMini.Close() + + sReal, realAddr := Redis() + defer sReal.Close() + + type aChan struct { + c chan command + cMini, cReal redis.Conn + } + chans := [2]aChan{} + for i := range chans { + gen := make(chan command) + cMini, err := redis.Dial("tcp", sMini.Addr()) + ok(t, err) + + cReal, err := redis.Dial("tcp", realAddr) + ok(t, err) + chans[i] = aChan{ + c: gen, + cMini: cMini, + cReal: cReal, + } + } + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + f(chans[0].c, chans[1].c) + cancel() + for _, c := range chans { + close(c.c) + } + }() + +loop: + for { + select { + case <-ctx.Done(): + break loop + case cm := <-chans[0].c: + runCommand(t, chans[0].cMini, chans[0].cReal, cm) + case cm := <-chans[1].c: + runCommand(t, chans[1].cMini, chans[1].cReal, cm) + } + } +} + func testAuthCommands(t *testing.T, passwd string, commands ...command) { sMini, err := miniredis.Run() ok(t, err) @@ -160,8 +221,17 @@ func runCommands(t *testing.T, realAddr, miniAddr string, commands []command) { func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { t.Helper() - vReal, errReal := cReal.Do(p.cmd, p.args...) - vMini, errMini := cMini.Do(p.cmd, p.args...) + var ( + vReal, vMini interface{} + errReal, errMini error + ) + if p.receiveOnly { + vReal, errReal = cReal.Receive() + vMini, errMini = cMini.Receive() + } else { + vReal, errReal = cReal.Do(p.cmd, p.args...) + vMini, errMini = cMini.Do(p.cmd, p.args...) + } if p.error { if errReal == nil { t.Errorf("got no error from realredis. case: %#v", p) @@ -211,6 +281,8 @@ func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { } else { if !reflect.DeepEqual(vReal, vMini) { t.Errorf("value error. expected: %#v got: %#v case: %#v", vReal, vMini, p) + dump(vReal, " --real-") + dump(vMini, " --mini-") return } } @@ -259,3 +331,16 @@ func looselyEqual(a, b interface{}) bool { panic(fmt.Sprintf("unhandled case, got a %#v", a)) } } + +func dump(r interface{}, prefix string) { + if ls, ok := r.([]interface{}); ok { + for _, k := range ls { + switch k := k.(type) { + case []byte: + fmt.Printf(" %s %s\n", prefix, string(k)) + default: + fmt.Printf(" %s %#v\n", prefix, k) + } + } + } +} diff --git a/integration/tx_test.go b/integration/tx_test.go index e3f958d4..b34cb599 100644 --- a/integration/tx_test.go +++ b/integration/tx_test.go @@ -142,4 +142,28 @@ func TestTx(t *testing.T) { succ("BITOP", "BROKEN", "str", ""), succ("EXEC"), ) + + // fail on invalid command + testCommands(t, + succ("MULTI"), + fail("GET"), + fail("EXEC"), + ) + + /* FIXME + // fail on unknown command + testCommands(t, + succ("MULTI"), + fail("NOSUCH"), + fail("EXEC"), + ) + */ + + // failed EXEC cleaned up the tx + testCommands(t, + succ("MULTI"), + fail("GET"), + fail("EXEC"), + succ("MULTI"), + ) } diff --git a/miniredis.go b/miniredis.go index 0688bdfe..a861237d 100644 --- a/miniredis.go +++ b/miniredis.go @@ -47,14 +47,15 @@ type RedisDB struct { // Miniredis is a Redis server implementation. type Miniredis struct { sync.Mutex - srv *server.Server - port int - password string - dbs map[int]*RedisDB - selectedDB int // DB id used in the direct Get(), Set() &c. - scripts map[string]string // sha1 -> lua src - signal *sync.Cond - now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. + srv *server.Server + port int + password string + dbs map[int]*RedisDB + selectedDB int // DB id used in the direct Get(), Set() &c. + scripts map[string]string // sha1 -> lua src + signal *sync.Cond + now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. + subscribers map[*Subscriber]struct{} } type txCmd func(*server.Peer, *connCtx) @@ -70,15 +71,17 @@ type connCtx struct { selectedDB int // selected DB authenticated bool // auth enabled and a valid AUTH seen transaction []txCmd // transaction callbacks. Or nil. - dirtyTransaction bool // any error during QUEUEing. - watch map[dbKey]uint // WATCHed keys. + dirtyTransaction bool // any error during QUEUEing + watch map[dbKey]uint // WATCHed keys + subscriber *Subscriber // client is in PUBSUB mode if not nil } // NewMiniRedis makes a new, non-started, Miniredis object. func NewMiniRedis() *Miniredis { m := Miniredis{ - dbs: map[int]*RedisDB{}, - scripts: map[string]string{}, + dbs: map[int]*RedisDB{}, + scripts: map[string]string{}, + subscribers: map[*Subscriber]struct{}{}, } m.signal = sync.NewCond(&m) return &m @@ -137,6 +140,7 @@ func (m *Miniredis) start(s *server.Server) error { commandsString(m) commandsHash(m) commandsList(m) + commandsPubsub(m) commandsSet(m) commandsSortedSet(m) commandsTransaction(m) @@ -154,12 +158,18 @@ func (m *Miniredis) Restart() error { // Close shuts down a Miniredis. func (m *Miniredis) Close() { m.Lock() - defer m.Unlock() + if m.srv == nil { + m.Unlock() return } - m.srv.Close() + srv := m.srv m.srv = nil + m.Unlock() + + // the OnDisconnect callbacks can lock m, so run Close() outside the lock. + srv.Close() + } // RequireAuth makes every connection need to AUTH first. Disable again by @@ -323,6 +333,21 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool { return true } +// handlePubsub sends an error to the user if the connection is in PUBSUB mode. +// It'll return true if it did. +func (m *Miniredis) checkPubsub(c *server.Peer) bool { + m.Lock() + defer m.Unlock() + + ctx := getCtx(c) + if ctx.subscriber == nil { + return false + } + + c.WriteError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + return true +} + func getCtx(c *server.Peer) *connCtx { if c.Ctx == nil { c.Ctx = &connCtx{} @@ -371,3 +396,80 @@ func setDirty(c *server.Peer) { func setAuthenticated(c *server.Peer) { getCtx(c).authenticated = true } + +func (m *Miniredis) addSubscriber(s *Subscriber) { + m.subscribers[s] = struct{}{} +} + +// closes and remove the subscriber. +func (m *Miniredis) removeSubscriber(s *Subscriber) { + _, ok := m.subscribers[s] + delete(m.subscribers, s) + if ok { + s.Close() + } +} + +func (m *Miniredis) publish(c, msg string) int { + n := 0 + for s := range m.subscribers { + n += s.Publish(c, msg) + } + return n +} + +// enter 'subscribed state', or return the existing one. +func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber { + ctx := getCtx(c) + sub := ctx.subscriber + if sub != nil { + return sub + } + + sub = newSubscriber() + m.addSubscriber(sub) + + c.OnDisconnect(func() { + m.Lock() + m.removeSubscriber(sub) + m.Unlock() + }) + + ctx.subscriber = sub + + go monitorPublish(c, sub.publish) + + return sub +} + +// whenever the p?sub count drops to 0 subscribed state should be stopped, and +// all redis commands are allowed again. +func endSubscriber(m *Miniredis, c *server.Peer) { + ctx := getCtx(c) + if sub := ctx.subscriber; sub != nil { + m.removeSubscriber(sub) // will Close() the sub + } + ctx.subscriber = nil +} + +// Start a new pubsub subscriber. It can (un) subscribe to channels and +// patterns, and has a channel to get published messages. Close it with +// Close(). +// Does not close itself when there are no subscriptions left. +func (m *Miniredis) NewSubscriber() *Subscriber { + sub := newSubscriber() + + m.Lock() + m.addSubscriber(sub) + m.Unlock() + + return sub +} + +func (m *Miniredis) allSubscribers() []*Subscriber { + var subs []*Subscriber + for s := range m.subscribers { + subs = append(subs, s) + } + return subs +} diff --git a/pubsub.go b/pubsub.go new file mode 100644 index 00000000..2d0f04ec --- /dev/null +++ b/pubsub.go @@ -0,0 +1,211 @@ +package miniredis + +import ( + "regexp" + "sort" + "sync" + + "github.com/alicebob/miniredis/server" +) + +// PubsubMessage is what gets broadcasted over pubsub channels. +type PubsubMessage struct { + Channel string + Message string +} + +// Subscriber has the (p)subscriptions. +type Subscriber struct { + publish chan PubsubMessage + channels map[string]struct{} + patterns map[string]*regexp.Regexp + mu sync.Mutex +} + +// Make a new subscriber. The channel is not buffered, so you will need to keep +// reading using Messages(). Use Close() when done, or unsubscribe. +func newSubscriber() *Subscriber { + return &Subscriber{ + publish: make(chan PubsubMessage), + channels: map[string]struct{}{}, + patterns: map[string]*regexp.Regexp{}, + } +} + +// Close the listening channel +func (s *Subscriber) Close() { + close(s.publish) +} + +// Count the total number of channels and patterns +func (s *Subscriber) Count() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.count() +} + +func (s *Subscriber) count() int { + return len(s.channels) + len(s.patterns) +} + +// Subscribe to a channel. Returns the total number of (p)subscriptions after +// subscribing. +func (s *Subscriber) Subscribe(c string) int { + s.mu.Lock() + defer s.mu.Unlock() + + s.channels[c] = struct{}{} + return s.count() +} + +// Unsubscribe a channel. Returns the total number of (p)subscriptions after +// unsubscribing. +func (s *Subscriber) Unsubscribe(c string) int { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.channels, c) + return s.count() +} + +// Subscribe to a pattern. Returns the total number of (p)subscriptions after +// subscribing. +func (s *Subscriber) Psubscribe(pat string) int { + s.mu.Lock() + defer s.mu.Unlock() + + s.patterns[pat] = compileChannelPattern(pat) + return s.count() +} + +// Unsubscribe a pattern. Returns the total number of (p)subscriptions after +// unsubscribing. +func (s *Subscriber) Punsubscribe(pat string) int { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.patterns, pat) + return s.count() +} + +// List all subscribed channels, in alphabetical order +func (s *Subscriber) Channels() []string { + s.mu.Lock() + defer s.mu.Unlock() + + var cs []string + for c := range s.channels { + cs = append(cs, c) + } + sort.Strings(cs) + return cs +} + +// List all subscribed patterns, in alphabetical order +func (s *Subscriber) Patterns() []string { + s.mu.Lock() + defer s.mu.Unlock() + + var ps []string + for p := range s.patterns { + ps = append(ps, p) + } + sort.Strings(ps) + return ps +} + +// Publish a message. Will return return how often we sent the message (can be +// a match for a subscription and for a psubscription. +func (s *Subscriber) Publish(c, msg string) int { + s.mu.Lock() + defer s.mu.Unlock() + + found := 0 + +subs: + for sub := range s.channels { + if sub == c { + s.publish <- PubsubMessage{c, msg} + found++ + break subs + } + } + +pats: + for _, pat := range s.patterns { + if pat.MatchString(c) { + s.publish <- PubsubMessage{c, msg} + found++ + break pats + } + } + + return found +} + +// The channel to read messages for this subscriber +func (s *Subscriber) Messages() <-chan PubsubMessage { + return s.publish +} + +// List all pubsub channels. If `pat` isn't empty channels names must match the +// pattern. Channels are returned alphabetically. +func activeChannels(subs []*Subscriber, pat string) []string { + channels := map[string]struct{}{} + for _, s := range subs { + for c := range s.channels { + channels[c] = struct{}{} + } + } + + var cpat *regexp.Regexp + if pat != "" { + cpat = compileChannelPattern(pat) + } + + var cs []string + for k := range channels { + if cpat != nil && !cpat.MatchString(k) { + continue + } + cs = append(cs, k) + } + sort.Strings(cs) + return cs +} + +// Count all subscribed (not psubscribed) clients for the given channel +// pattern. Channels are returned alphabetically. +func countSubs(subs []*Subscriber, channel string) int { + n := 0 + for _, p := range subs { + for c := range p.channels { + if c == channel { + n++ + break + } + } + } + return n +} + +// Count the total of all client psubscriptions. +func countPsubs(subs []*Subscriber) int { + n := 0 + for _, p := range subs { + n += len(p.patterns) + } + return n +} + +func monitorPublish(conn *server.Peer, msgs <-chan PubsubMessage) { + for msg := range msgs { + conn.Block(func(c *server.Writer) { + c.WriteLen(3) + c.WriteBulk("message") + c.WriteBulk(msg.Channel) + c.WriteBulk(msg.Message) + c.Flush() + }) + } +} diff --git a/redis.go b/redis.go index 49ff7bc3..ab353dbc 100644 --- a/redis.go +++ b/redis.go @@ -29,6 +29,7 @@ const ( msgInvalidKeysNumber = "ERR Number of keys can't be greater than number of args" msgNegativeKeysNumber = "ERR Number of keys can't be negative" msgFScriptUsage = "ERR Unknown subcommand or wrong number of arguments for '%s'. Try SCRIPT HELP." + msgFPubsubUsage = "ERR Unknown subcommand or wrong number of arguments for '%s'. Try PUBSUB HELP." msgSingleElementPair = "ERR INCR option supports a single increment-element pair" msgNoScriptFound = "NOSCRIPT No matching script. Please use EVAL." ) diff --git a/server/server.go b/server/server.go index 1796453d..c924bb3c 100644 --- a/server/server.go +++ b/server/server.go @@ -23,6 +23,8 @@ func errUnknownCommand(cmd string, args []string) string { // Cmd is what Register expects type Cmd func(c *Peer, cmd string, args []string) +type DisconnectHandler func(c *Peer) + // Server is a simple redis server type Server struct { l net.Listener @@ -124,19 +126,27 @@ func (s *Server) Register(cmd string, f Cmd) error { func (s *Server) servePeer(c net.Conn) { r := bufio.NewReader(c) - cl := &Peer{ + peer := &Peer{ w: bufio.NewWriter(c), } + defer func() { + for _, f := range peer.onDisconnect { + f() + } + }() + for { args, err := readArray(r) if err != nil { return } - s.dispatch(cl, args) - cl.w.Flush() - if cl.closed { + s.dispatch(peer, args) + peer.Flush() + s.mu.Lock() + closed := peer.closed + s.mu.Unlock() + if closed { c.Close() - return } } } @@ -182,29 +192,52 @@ func (s *Server) TotalConnections() int { // Peer is a client connected to the server type Peer struct { - w *bufio.Writer - closed bool - Ctx interface{} // anything goes, server won't touch this + w *bufio.Writer + closed bool + Ctx interface{} // anything goes, server won't touch this + onDisconnect []func() // list of callbacks + mu sync.Mutex // for Block() } // Flush the write buffer. Called automatically after every redis command func (c *Peer) Flush() { + c.mu.Lock() + defer c.mu.Unlock() c.w.Flush() } // Close the client connection after the current command is done. func (c *Peer) Close() { + c.mu.Lock() + defer c.mu.Unlock() c.closed = true } +// Register a function to execute on disconnect. There can be multiple +// functions registered. +func (c *Peer) OnDisconnect(f func()) { + c.onDisconnect = append(c.onDisconnect, f) +} + +// issue multiple calls, guarded with a mutex +func (c *Peer) Block(f func(*Writer)) { + c.mu.Lock() + defer c.mu.Unlock() + f(&Writer{c.w}) +} + // WriteError writes a redis 'Error' func (c *Peer) WriteError(e string) { - fmt.Fprintf(c.w, "-%s\r\n", toInline(e)) + c.Block(func(w *Writer) { + w.WriteError(e) + }) } // WriteInline writes a redis inline string func (c *Peer) WriteInline(s string) { - fmt.Fprintf(c.w, "+%s\r\n", toInline(s)) + c.Block(func(w *Writer) { + w.WriteInline(s) + }) } // WriteOK write the inline string `OK` @@ -214,22 +247,30 @@ func (c *Peer) WriteOK() { // WriteBulk writes a bulk string func (c *Peer) WriteBulk(s string) { - fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s) + c.Block(func(w *Writer) { + w.WriteBulk(s) + }) } // WriteNull writes a redis Null element func (c *Peer) WriteNull() { - fmt.Fprintf(c.w, "$-1\r\n") + c.Block(func(w *Writer) { + w.WriteNull() + }) } // WriteLen starts an array with the given length func (c *Peer) WriteLen(n int) { - fmt.Fprintf(c.w, "*%d\r\n", n) + c.Block(func(w *Writer) { + w.WriteLen(n) + }) } // WriteInt writes an integer func (c *Peer) WriteInt(i int) { - fmt.Fprintf(c.w, ":%d\r\n", i) + c.Block(func(w *Writer) { + w.WriteInt(i) + }) } func toInline(s string) string { @@ -240,3 +281,41 @@ func toInline(s string) string { return r }, s) } + +// A Writer is given to the callback in Block() +type Writer struct { + w *bufio.Writer +} + +// WriteError writes a redis 'Error' +func (w *Writer) WriteError(e string) { + fmt.Fprintf(w.w, "-%s\r\n", toInline(e)) +} + +func (w *Writer) WriteLen(n int) { + fmt.Fprintf(w.w, "*%d\r\n", n) +} + +// WriteBulk writes a bulk string +func (w *Writer) WriteBulk(s string) { + fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(s), s) +} + +// WriteInt writes an integer +func (w *Writer) WriteInt(i int) { + fmt.Fprintf(w.w, ":%d\r\n", i) +} + +// WriteNull writes a redis Null element +func (w *Writer) WriteNull() { + fmt.Fprintf(w.w, "$-1\r\n") +} + +// WriteInline writes a redis inline string +func (w *Writer) WriteInline(s string) { + fmt.Fprintf(w.w, "+%s\r\n", toInline(s)) +} + +func (w *Writer) Flush() { + w.w.Flush() +}