Skip to content

Commit

Permalink
Merge pull request #293 from josh-hook/josh/xread-resolve-last-id
Browse files Browse the repository at this point in the history
Resolve $ to latest ID in XREAD
  • Loading branch information
alicebob authored Oct 12, 2022
2 parents 12c3ec2 + ca6b916 commit 1d7ae5f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cmd_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,11 +924,14 @@ parsing:
}

opts.streams, opts.ids = args[0:len(args)/2], args[len(args)/2:]
for _, id := range opts.ids {
for i, id := range opts.ids {
if _, err := parseStreamID(id); id != `$` && err != nil {
setDirty(c)
c.WriteError(msgInvalidStreamID)
return
} else if id == "$" {
db := m.DB(getCtx(c).selectedDB)
opts.ids[i] = db.streamKeys[opts.streams[i]].lastID()
}
}
args = nil
Expand Down
29 changes: 29 additions & 0 deletions cmd_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"math"
"regexp"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -417,6 +418,34 @@ func TestStreamRead(t *testing.T) {
),
),
)

t.Run("blocking async", func(t *testing.T) {
// XREAD blocking test using latest ID
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
mustDo(t, c,
"XREAD", "BLOCK", "0", "STREAMS", "planets", "$",
proto.Array(
proto.Array(proto.String("planets"),
proto.Array(
proto.Array(proto.String("5-1"), proto.Strings("name", "block", "idx", "6")),
),
),
),
)
}()

// Wait for the blocking XREAD to start and then run XADD
xaddClient, err := proto.Dial(s.Addr())
ok(t, err)
defer xaddClient.Close()

_, err = xaddClient.Do("XADD", "planets", "5-1", "name", "block", "idx", "6")
ok(t, err)
wg.Wait()
})
})

t.Run("error cases", func(t *testing.T) {
Expand Down
22 changes: 22 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
)

Expand All @@ -17,6 +18,7 @@ type streamKey struct {
entries []StreamEntry
groups map[string]*streamGroup
lastAllocatedID string
mu sync.Mutex
}

// a StreamEntry is an entry in a stream. The ID is always of the form
Expand Down Expand Up @@ -52,6 +54,7 @@ func newStreamKey() *streamKey {
}
}

// generateID doesn't lock the mutex
func (s *streamKey) generateID(now time.Time) string {
ts := uint64(now.UnixNano()) / 1_000_000

Expand All @@ -71,6 +74,7 @@ func (s *streamKey) generateID(now time.Time) string {
return next
}

// lastID doesn't lock the mutex
func (s *streamKey) lastID() string {
if len(s.entries) == 0 {
return "0-0"
Expand All @@ -80,6 +84,9 @@ func (s *streamKey) lastID() string {
}

func (s *streamKey) copy() *streamKey {
s.mu.Lock()
defer s.mu.Unlock()

cpy := &streamKey{
entries: s.entries,
}
Expand Down Expand Up @@ -194,6 +201,9 @@ func reversedStreamEntries(o []StreamEntry) []StreamEntry {
}

func (s *streamKey) createGroup(group, id string) error {
s.mu.Lock()
defer s.mu.Unlock()

if _, ok := s.groups[group]; ok {
return errors.New("BUSYGROUP Consumer Group name already exists")
}
Expand All @@ -213,6 +223,9 @@ func (s *streamKey) createGroup(group, id string) error {
// If id is empty or "*" the ID will be generated automatically.
// `values` should have an even length.
func (s *streamKey) add(entryID string, values []string, now time.Time) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()

if entryID == "" || entryID == "*" {
entryID = s.generateID(now)
}
Expand All @@ -236,13 +249,19 @@ func (s *streamKey) add(entryID string, values []string, now time.Time) (string,
}

func (s *streamKey) trim(n int) {
s.mu.Lock()
defer s.mu.Unlock()

if len(s.entries) > n {
s.entries = s.entries[len(s.entries)-n:]
}
}

// all entries after "id"
func (s *streamKey) after(id string) []StreamEntry {
s.mu.Lock()
defer s.mu.Unlock()

pos := sort.Search(len(s.entries), func(i int) bool {
return streamCmp(id, s.entries[i].ID) < 0
})
Expand All @@ -252,6 +271,9 @@ func (s *streamKey) after(id string) []StreamEntry {
// get a stream entry by ID
// Also returns the position in the entries slice, if found.
func (s *streamKey) get(id string) (int, *StreamEntry) {
s.mu.Lock()
defer s.mu.Unlock()

pos := sort.Search(len(s.entries), func(i int) bool {
return streamCmp(id, s.entries[i].ID) <= 0
})
Expand Down

0 comments on commit 1d7ae5f

Please sign in to comment.