diff --git a/hamt.go b/hamt.go index 6b6c42a..4eb4c5c 100644 --- a/hamt.go +++ b/hamt.go @@ -256,38 +256,75 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error { return n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, d) } -func (n *Node) cleanChild(chnd *Node, cindex byte) error { - l := len(chnd.Pointers) - switch { - case l == 0: - return fmt.Errorf("incorrectly formed HAMT") - case l == 1: - // TODO: only do this if its a value, cant do this for shards unless pairs requirements are met. +// the number of links to child nodes this node contains +func (n *Node) directChildCount() int { + count := 0 + for _, p := range n.Pointers { + if p.isShard() { + count++ + } + } + return count +} - ps := chnd.Pointers[0] - if ps.isShard() { - return nil +// the number of KV entries this node contains +func (n *Node) directKVCount() int { + count := 0 + for _, p := range n.Pointers { + if !p.isShard() { + count = count + len(p.KVs) } + } + return count +} - return n.setChild(cindex, ps) - case l <= arrayWidth: - var chvals []*KV - for _, p := range chnd.Pointers { - if p.isShard() { - return nil - } +// This happens after deletes to ensure that we retain canonical form for the +// given set of data this HAMT contains. This is a key part of the CHAMP +// algorithm. Any node that could be represented as a bucket in a parent node +// should be collapsed as such. This collapsing process could continue back up +// the tree as far as necessary to represent the data in the minimal HAMT form. +// This operation is done from a parent perspective, so we clean the child +// below us first and then our parent cleans us. 
+func (n *Node) cleanChild(chnd *Node, cindex byte) error { + if chnd.directChildCount() != 0 { + // child has its own children, nothing to collapse + return nil + } - for _, sp := range p.KVs { - if len(chvals) == arrayWidth { - return nil - } - chvals = append(chvals, sp) - } - } - return n.setChild(cindex, &Pointer{KVs: chvals}) - default: + if chnd.directKVCount() > arrayWidth { + // child contains more local elements than could be collapsed return nil } + + l := len(chnd.Pointers) + if l == 0 { + return fmt.Errorf("incorrectly formed HAMT") + } + + if l == 1 { + // The case where the child node has a single bucket, which we know can + // only contain `arrayWidth` elements (maximum), so we need to pull that + // bucket up into this node. + // This case should only happen when it bubbles up from the case below + // where a lower child has its elements compacted into a single bucket. We + // shouldn't be able to reach this block unless a delete has been + // performed on a lower block and we are performing a post-delete clean on + // a parent block. + return n.setChild(cindex, chnd.Pointers[0]) + } + + // The case where the child node contains few enough elements to fit in a + // single bucket and therefore can't justify its existence as a node on its + // own. So we collapse all entries into a single bucket and replace the + // link to the child with that bucket. + // This may cause cascading collapses if this is the only bucket in the + // current node; that case will be handled by our parent node by the l==1 + // case above. + var chvals []*KV + for _, p := range chnd.Pointers { + chvals = append(chvals, p.KVs...) 
+ } + return n.setChild(cindex, &Pointer{KVs: chvals}) } func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k []byte, v *cbg.Deferred) error { diff --git a/hamt_test.go b/hamt_test.go index 753cf5b..9175fa1 100644 --- a/hamt_test.go +++ b/hamt_test.go @@ -113,6 +113,174 @@ func TestOverflow(t *testing.T) { } } +func TestFillAndCollapse(t *testing.T) { + ctx := context.Background() + cs := cbor.NewCborStore(newMockBlocks()) + root := NewNode(cs, UseHashFunction(identityHash)) + val := randValue() + + // start with a single node and a single full bucket + if err := root.Set(ctx, "AAAAAA11", val); err != nil { + t.Fatal(err) + } + if err := root.Set(ctx, "AAAAAA12", val); err != nil { + t.Fatal(err) + } + if err := root.Set(ctx, "AAAAAA21", val); err != nil { + t.Fatal(err) + } + + st := stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 1 || st.totalKvs != 3 || st.counts[3] != 1 { + t.Fatal("Should be 1 node with 1 bucket") + } + + baseCid, err := cs.Put(ctx, root) + if err != nil { + t.Fatal(err) + } + + // add a 4th colliding entry that forces a chain of new nodes so that it can + // be accommodated in a new node where there aren't collisions (7th byte) + if err := root.Set(ctx, "AAAAAA22", val); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 7 || st.totalKvs != 4 || st.counts[2] != 2 { + t.Fatal("Should be 7 nodes with 4 buckets") + } + + // remove and we should be back to the same structure as before + if err := root.Delete(ctx, "AAAAAA22"); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 1 || st.totalKvs != 3 || st.counts[3] != 1 { + t.Fatal("Should be 1 node with 1 bucket") + } + + c, err := cs.Put(ctx, root) + if err != nil { + t.Fatal(err) + } + if !c.Equals(baseCid) { + t.Fatal("CID mismatch on mutation") + } + + // insert elements that collide at the 4th position and so push the tree down by + // 3 nodes + if err := 
root.Set(ctx, "AAA11AA", val); err != nil { + t.Fatal(err) + } + if err := root.Set(ctx, "AAA12AA", val); err != nil { + t.Fatal(err) + } + if err := root.Set(ctx, "AAA13AA", val); err != nil { + t.Fatal(err) + } + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 4 || st.totalKvs != 6 || st.counts[3] != 2 { + t.Fatal("Should be 4 nodes with 2 buckets of 3") + } + + midCid, err := cs.Put(ctx, root) + if err != nil { + t.Fatal(err) + } + + // insert an overflowing entry that pushes these 4 entries into a separate node + if err := root.Set(ctx, "AAA14AA", val); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 5 || st.totalKvs != 7 || st.counts[1] != 4 || st.counts[3] != 1 { + t.Fatal("Should be 4 node with 2 buckets") + } + + // put the colliding 4th entry back in, which will push down to full height + if err := root.Set(ctx, "AAAAAA22", val); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 8 || st.totalKvs != 8 || st.counts[1] != 4 || st.counts[2] != 2 { + t.Fatal("Should be 7 nodes with 5 buckets") + } + + // rewind back one step + if err := root.Delete(ctx, "AAAAAA22"); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 5 || st.totalKvs != 7 || st.counts[1] != 4 || st.counts[3] != 1 { + t.Fatal("Should be 4 node with 2 buckets") + } + + // rewind another step + if err := root.Delete(ctx, "AAA14AA"); err != nil { + t.Fatal(err) + } + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 4 || st.totalKvs != 6 || st.counts[3] != 2 { + t.Fatal("Should be 4 nodes with 2 buckets of 3") + } + + c, err = cs.Put(ctx, root) + if err != nil { + t.Fatal(err) + } + if !c.Equals(midCid) { + t.Fatal("CID mismatch on mutation") + } + + // remove the 3 colliding entries so we should be back to the initial state + if err := root.Delete(ctx, "AAA11AA"); err != nil { + 
t.Fatal(err) + } + if err := root.Delete(ctx, "AAA12AA"); err != nil { + t.Fatal(err) + } + if err := root.Delete(ctx, "AAA13AA"); err != nil { + t.Fatal(err) + } + + st = stats(root) + fmt.Println(st) + printHamt(root) + if st.totalNodes != 1 || st.totalKvs != 3 || st.counts[3] != 1 { + t.Fatal("Should be 1 node with 1 bucket") + } + + // should have the same CID as original + c, err = cs.Put(ctx, root) + if err != nil { + t.Fatal(err) + } + if !c.Equals(baseCid) { + t.Fatal("CID mismatch on mutation") + } +} + func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string, options ...Option) { ctx := context.Background() vals := make(map[string][]byte) @@ -249,6 +417,10 @@ type hamtStats struct { counts map[int]int } +func (hs hamtStats) String() string { + return fmt.Sprintf("nodes=%d, kvs=%d, counts=%v", hs.totalNodes, hs.totalKvs, hs.counts) +} + func stats(n *Node) *hamtStats { st := &hamtStats{counts: make(map[int]int)} statsrec(n, st)