Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use trie for prefix matching #851

Merged
merged 13 commits into from
Jul 5, 2019
64 changes: 34 additions & 30 deletions publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
package badger

import (
"bytes"
"sync"

"github.com/dgraph-io/badger/pb"
"github.com/dgraph-io/badger/trie"
"github.com/dgraph-io/badger/y"
)

Expand All @@ -35,13 +35,15 @@ type publisher struct {
pubCh chan requests
subscribers map[uint64]subscriber
nextID uint64
indexer *trie.Trie
}

func newPublisher() *publisher {
return &publisher{
pubCh: make(chan requests, 1000),
subscribers: make(map[uint64]subscriber),
nextID: 0,
indexer: trie.NewTrie(),
}
}

Expand Down Expand Up @@ -72,42 +74,37 @@ func (p *publisher) listenForUpdates(c *y.Closer) {
}

func (p *publisher) publishUpdates(reqs requests) {
kvs := &pb.KVList{}
p.Lock()
defer func() {
p.Unlock()
// Release all the request.
reqs.DecrRef()
}()

// TODO: Optimize this, so we can figure out key -> subscriber quickly, without iterating over
// all the prefixes.
// TODO: Use trie to find subscribers.
for _, s := range p.subscribers {
// BUG: This would send out the same entry multiple times on multiple matches for the same
// subscriber.
for _, prefix := range s.prefixes {
for _, req := range reqs {
for _, e := range req.Entries {
if bytes.HasPrefix(e.Key, prefix) {
// TODO: Maybe we can optimize this by creating the KV once and sending it
// over to multiple subscribers.
k := y.SafeCopy(nil, e.Key)
kv := &pb.KV{
Key: y.ParseKey(k),
Value: y.SafeCopy(nil, e.Value),
UserMeta: []byte{e.UserMeta},
ExpiresAt: e.ExpiresAt,
Version: y.ParseTs(k),
}
kvs.Kv = append(kvs.Kv, kv)
batchedUpdates := make(map[uint64]*pb.KVList)
for _, req := range reqs {
for _, e := range req.Entries {
ids := p.indexer.Get(e.Key)
if len(ids) > 0 {
k := y.SafeCopy(nil, e.Key)
kv := &pb.KV{
Key: y.ParseKey(k),
Value: y.SafeCopy(nil, e.Value),
Meta: []byte{e.UserMeta},
ExpiresAt: e.ExpiresAt,
Version: y.ParseTs(k),
}
for id := range ids {
if _, ok := batchedUpdates[id]; !ok {
batchedUpdates[id] = &pb.KVList{}
}
batchedUpdates[id].Kv = append(batchedUpdates[id].Kv, kv)
}
}
}
if len(kvs.GetKv()) > 0 {
s.sendCh <- kvs
}
}

for id, kvs := range batchedUpdates {
p.subscribers[id].sendCh <- kvs
}
}

Expand All @@ -123,6 +120,9 @@ func (p *publisher) newSubscriber(c *y.Closer, prefixes ...[]byte) (<-chan *pb.K
sendCh: ch,
subCloser: c,
}
for _, prefix := range prefixes {
p.indexer.Add(prefix, id)
}
return ch, id
}

Expand All @@ -131,6 +131,9 @@ func (p *publisher) cleanSubscribers() {
p.Lock()
defer p.Unlock()
for id, s := range p.subscribers {
for _, prefix := range s.prefixes {
p.indexer.Delete(prefix, id)
}
delete(p.subscribers, id)
s.subCloser.SignalAndWait()
}
Expand All @@ -139,14 +142,15 @@ func (p *publisher) cleanSubscribers() {
func (p *publisher) deleteSubscriber(id uint64) {
p.Lock()
defer p.Unlock()
if _, ok := p.subscribers[id]; !ok {
return
if s, ok := p.subscribers[id]; ok {
for _, prefix := range s.prefixes {
p.indexer.Delete(prefix, id)
}
}
delete(p.subscribers, id)
}

func (p *publisher) sendUpdates(reqs []*request) {
// TODO: Prefix check before pushing into pubCh.
if p.noOfSubscribers() != 0 {
p.pubCh <- reqs
}
Expand Down
97 changes: 97 additions & 0 deletions trie/trie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright 2019 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package trie

type node struct {
children map[byte]*node
ids []uint64
}

func newNode() *node {
return &node{
children: make(map[byte]*node),
ids: []uint64{},
}
}

// Trie datastructure.
type Trie struct {
root *node
}

// NewTrie returns Trie.
func NewTrie() *Trie {
return &Trie{
root: newNode(),
}
}

// Add adds the id in the trie for the given prefix path.
func (t *Trie) Add(prefix []byte, id uint64) {
node := t.root
for _, val := range prefix {
child, ok := node.children[val]
if !ok {
child = newNode()
node.children[val] = child
}
node = child
}
// We only need to add the id to the last node of the given prefix.
node.ids = append(node.ids, id)
}

// Get returns prefix matched ids for the given key.
func (t *Trie) Get(key []byte) map[uint64]struct{} {
out := make(map[uint64]struct{})
node := t.root
for _, val := range key {
child, ok := node.children[val]
if !ok {
break
}
// We need ids of the all the node in the matching key path.
for _, id := range child.ids {
out[id] = struct{}{}
}
node = child
}
return out
}

// Delete will delete the id if the id exist in the given index path.
func (t *Trie) Delete(index []byte, id uint64) {
node := t.root
for _, val := range index {
child, ok := node.children[val]
if !ok {
return
}
node = child
}
// We're just removing the id not the hanging path.
out := node.ids[:0]
for _, val := range node.ids {
if val != id {
out = append(out, val)
}
}
for i := len(out); i < len(node.ids); i++ {
node.ids[i] = 0 // garbage collecting
}
node.ids = out
}
52 changes: 52 additions & 0 deletions trie/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2019 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package trie

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGet(t *testing.T) {
trie := NewTrie()
trie.Add([]byte("hello"), 1)
trie.Add([]byte("hello"), 3)
trie.Add([]byte("hello"), 4)
trie.Add([]byte("hel"), 20)
trie.Add([]byte("he"), 20)
trie.Add([]byte("badger"), 30)
ids := trie.Get([]byte("hel"))
require.Equal(t, 1, len(ids))

require.Equal(t, map[uint64]struct{}{20: struct{}{}}, ids)
ids = trie.Get([]byte("badger"))
require.Equal(t, 1, len(ids))
require.Equal(t, map[uint64]struct{}{30: struct{}{}}, ids)
ids = trie.Get([]byte("hello"))
require.Equal(t, 4, len(ids))
require.Equal(t, map[uint64]struct{}{1: struct{}{}, 3: struct{}{}, 4: struct{}{}, 20: struct{}{}}, ids)
}

func TestTrieDelete(t *testing.T) {
trie := NewTrie()
trie.Add([]byte("hello"), 1)
trie.Add([]byte("hello"), 3)
trie.Add([]byte("hello"), 4)
trie.Delete([]byte("hello"), 4)
require.Equal(t, map[uint64]struct{}{1: struct{}{}, 3: struct{}{}}, trie.Get([]byte("hello")))
}