diff --git a/basic_test.go b/basic_test.go index 8245dc12a..418a4d23f 100644 --- a/basic_test.go +++ b/basic_test.go @@ -16,19 +16,23 @@ import ( func TestBasic(t *testing.T) { tree, err := getTestTree(0) require.NoError(t, err) - up := tree.Set([]byte("1"), []byte("one")) + up, err := tree.Set([]byte("1"), []byte("one")) + require.NoError(t, err) if up { t.Error("Did not expect an update (should have been create)") } - up = tree.Set([]byte("2"), []byte("two")) + up, err = tree.Set([]byte("2"), []byte("two")) + require.NoError(t, err) if up { t.Error("Did not expect an update (should have been create)") } - up = tree.Set([]byte("2"), []byte("TWO")) + up, err = tree.Set([]byte("2"), []byte("TWO")) + require.NoError(t, err) if !up { t.Error("Expected an update") } - up = tree.Set([]byte("5"), []byte("five")) + up, err = tree.Set([]byte("5"), []byte("five")) + require.NoError(t, err) if up { t.Error("Did not expect an update (should have been create)") } @@ -38,7 +42,8 @@ func TestBasic(t *testing.T) { key := []byte{0x00} expected := "" - idx, val := tree.GetWithIndex(key) + idx, val, err := tree.GetWithIndex(key) + require.NoError(t, err) if val != nil { t.Error("Expected no value to exist") } @@ -49,7 +54,7 @@ func TestBasic(t *testing.T) { t.Errorf("Unexpected value %s", val) } - val = tree.Get(key) + val, err = tree.Get(key) if val != nil { t.Error("Fast method - expected no value to exist") } @@ -63,7 +68,8 @@ func TestBasic(t *testing.T) { key := []byte("1") expected := "one" - idx, val := tree.GetWithIndex(key) + idx, val, err := tree.GetWithIndex(key) + require.NoError(t, err) if val == nil { t.Error("Expected value to exist") } @@ -74,7 +80,8 @@ func TestBasic(t *testing.T) { t.Errorf("Unexpected value %s", val) } - val = tree.Get(key) + val, err = tree.Get(key) + require.NoError(t, err) if val == nil { t.Error("Fast method - expected value to exist") } @@ -88,7 +95,8 @@ func TestBasic(t *testing.T) { key := []byte("2") expected := "TWO" - idx, val := tree.GetWithIndex(key) + idx, val, err := tree.GetWithIndex(key) + require.NoError(t, err) if val == nil { t.Error("Expected value to exist") } @@ -99,7 +107,7 @@ func TestBasic(t *testing.T) { t.Errorf("Unexpected value %s", val) } - val = tree.Get(key) + val, err = tree.Get(key) if val == nil { t.Error("Fast method - expected value to exist") } @@ -113,7 +121,8 @@ func TestBasic(t *testing.T) { key := []byte("4") expected := "" - idx, val := tree.GetWithIndex(key) + idx, val, err := tree.GetWithIndex(key) + require.NoError(t, err) if val != nil { t.Error("Expected no value to exist") } @@ -124,7 +133,7 @@ func TestBasic(t *testing.T) { t.Errorf("Unexpected value %s", val) } - val = tree.Get(key) + val, err = tree.Get(key) if val != nil { t.Error("Fast method - expected no value to exist") } @@ -138,7 +147,8 @@ func TestBasic(t *testing.T) { key := []byte("6") expected := "" - idx, val := tree.GetWithIndex(key) + idx, val, err := tree.GetWithIndex(key) + require.NoError(t, err) if val != nil { t.Error("Expected no value to exist") } @@ -149,7 +159,7 @@ func TestBasic(t *testing.T) { t.Errorf("Unexpected value %s", val) } - val = tree.Get(key) + val, err = tree.Get(key) if val != nil { t.Error("Fast method - expected no value to exist") } @@ -163,7 +173,8 @@ func TestUnit(t *testing.T) { expectHash := func(tree *ImmutableTree, hashCount int64) { // ensure number of new hash calculations is as expected. - hash, count := tree.root.hashWithCount() + hash, count, err := tree.root.hashWithCount() + require.NoError(t, err) if count != hashCount { t.Fatalf("Expected %v new hashes, got %v", hashCount, count) } @@ -173,7 +184,8 @@ func TestUnit(t *testing.T) { return false }) // ensure that the new hash after nuking is the same as the old. - newHash, _ := tree.root.hashWithCount() + newHash, _, err := tree.root.hashWithCount() + require.NoError(t, err) if !bytes.Equal(hash, newHash) { t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) } @@ -181,7 +193,8 @@ func TestUnit(t *testing.T) { expectSet := func(tree *MutableTree, i int, repr string, hashCount int64) { origNode := tree.root - updated := tree.Set(i2b(i), []byte{}) + updated, err := tree.Set(i2b(i), []byte{}) + require.NoError(t, err) // ensure node was added & structure is as expected. if updated || P(tree.root) != repr { t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", @@ -194,7 +207,8 @@ func TestUnit(t *testing.T) { expectRemove := func(tree *MutableTree, i int, repr string, hashCount int64) { origNode := tree.root - value, removed := tree.Remove(i2b(i)) + value, removed, err := tree.Remove(i2b(i)) + require.NoError(t, err) // ensure node was added & structure is as expected. if len(value) != 0 || !removed || P(tree.root) != repr { t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v removed:%v", @@ -208,35 +222,41 @@ func TestUnit(t *testing.T) { // Test Set cases: // Case 1: - t1 := T(N(4, 20)) + t1, err := T(N(4, 20)) + require.NoError(t, err) expectSet(t1, 8, "((4 8) 20)", 3) expectSet(t1, 25, "(4 (20 25))", 3) - t2 := T(N(4, N(20, 25))) + t2, err := T(N(4, N(20, 25))) + require.NoError(t, err) expectSet(t2, 8, "((4 8) (20 25))", 3) expectSet(t2, 30, "((4 20) (25 30))", 4) - t3 := T(N(N(1, 2), 6)) + t3, err := T(N(N(1, 2), 6)) + require.NoError(t, err) expectSet(t3, 4, "((1 2) (4 6))", 4) expectSet(t3, 8, "((1 2) (6 8))", 3) - t4 := T(N(N(1, 2), N(N(5, 6), N(7, 9)))) + t4, err := T(N(N(1, 2), N(N(5, 6), N(7, 9)))) + require.NoError(t, err) expectSet(t4, 8, "(((1 2) (5 6)) ((7 8) 9))", 5) expectSet(t4, 10, "(((1 2) (5 6)) (7 (9 10)))", 5) // Test Remove cases: - t10 := T(N(N(1, 2), 3)) + t10, err := T(N(N(1, 2), 3)) + require.NoError(t, err) expectRemove(t10, 2, "(1 3)", 1) expectRemove(t10, 3, "(1 2)", 0) - t11 := T(N(N(N(1, 2), 3), N(4, 5))) + t11, err := T(N(N(N(1, 2), 3), N(4, 5))) + require.NoError(t, err) expectRemove(t11, 4, "((1 2) (3 5))", 2) expectRemove(t11, 3, "((1 2) (4 5))", 1) @@ -287,11 +307,13 @@ func TestIntegration(t *testing.T) { for i := range records { r := randomRecord() records[i] = r - updated := tree.Set([]byte(r.key), []byte{}) + updated, err := tree.Set([]byte(r.key), []byte{}) + require.NoError(t, err) if updated { t.Error("should have not been updated") } - updated = tree.Set([]byte(r.key), []byte(r.value)) + updated, err = tree.Set([]byte(r.key), []byte(r.value)) + require.NoError(t, err) if !updated { t.Error("should have been updated") } @@ -301,31 +323,49 @@ func TestIntegration(t *testing.T) { } for _, r := range records { - if has := tree.Has([]byte(r.key)); !has { + has, err := tree.Has([]byte(r.key)) + require.NoError(t, err) + if !has { t.Error("Missing key", r.key) } - if has := tree.Has([]byte(randstr(12))); has { + + has, err = tree.Has([]byte(randstr(12))) + require.NoError(t, err) + if has { t.Error("Table has extra key") } - if val := tree.Get([]byte(r.key)); string(val) != r.value { + + val, err := tree.Get([]byte(r.key)) + require.NoError(t, err) + if string(val) != r.value { t.Error("wrong value") } } for i, x := range records { - if val, removed := tree.Remove([]byte(x.key)); !removed { + if val, removed, err := tree.Remove([]byte(x.key)); err != nil { + require.NoError(t, err) + } else if !removed { t.Error("Wasn't removed") } else if string(val) != x.value { t.Error("Wrong value") } + require.NoError(t, err) for _, r := range records[i+1:] { - if has := tree.Has([]byte(r.key)); !has { + has, err := tree.Has([]byte(r.key)) + require.NoError(t, err) + if !has { t.Error("Missing key", r.key) } - if has := tree.Has([]byte(randstr(12))); has { + + has, err = tree.Has([]byte(randstr(12))) + require.NoError(t, err) + if has { t.Error("Table has extra key") } - val := tree.Get([]byte(r.key)) + + val, err := tree.Get([]byte(r.key)) + require.NoError(t, err) if string(val) != r.value { t.Error("wrong value") } @@ -365,7 +405,8 @@ func TestIterateRange(t *testing.T) { // insert all the data for _, r := range records { - updated := tree.Set([]byte(r.key), []byte(r.value)) + updated, err := tree.Set([]byte(r.key), []byte(r.value)) + require.NoError(t, err) if updated { t.Error("should have not been updated") } @@ -443,7 +484,8 @@ func TestPersistence(t *testing.T) { require.NoError(t, err) t2.Load() for key, value := range records { - t2value := t2.Get([]byte(key)) + t2value, err := t2.Get([]byte(key)) + require.NoError(t, err) if string(t2value) != value { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } @@ -475,7 +517,9 @@ func TestProof(t *testing.T) { assert.NoError(t, err) assert.Equal(t, value, value2) if assert.NotNil(t, proof) { - verifyProof(t, proof, tree.WorkingHash()) + hash, err := tree.WorkingHash() + require.NoError(t, err) + verifyProof(t, proof, hash) } return false }) @@ -485,7 +529,9 @@ func TestTreeProof(t *testing.T) { db := db.NewMemDB() tree, err := NewMutableTree(db, 100) require.NoError(t, err) - assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(tree.Hash())) + hash, err := tree.Hash() + require.NoError(t, err) + assert.Equal(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(hash)) // should get false for proof with nil root value, proof, err := tree.GetWithProof([]byte("foo")) @@ -509,11 +555,14 @@ func TestTreeProof(t *testing.T) { assert.Nil(t, value) assert.NotNil(t, proof) assert.NoError(t, err) - assert.NoError(t, proof.Verify(tree.Hash())) + hash, err = tree.Hash() + assert.NoError(t, err) + assert.NoError(t, proof.Verify(hash)) assert.NoError(t, proof.VerifyAbsence([]byte("foo"))) // valid proof for real keys - root := tree.WorkingHash() + root, err := tree.WorkingHash() + assert.NoError(t, err) for _, key := range keys { value, proof, err := tree.GetWithProof(key) if assert.NoError(t, err) { diff --git a/benchmarks/bench_test.go b/benchmarks/bench_test.go index b7cd6553b..de79d9561 100644 --- a/benchmarks/bench_test.go +++ b/benchmarks/bench_test.go @@ -58,7 +58,9 @@ func commitTree(b *testing.B, t *iavl.MutableTree) { // queries random keys against live state. Keys are almost certainly not in the tree. func runQueriesFast(b *testing.B, t *iavl.MutableTree, keyLen int) { - require.True(b, t.IsFastCacheEnabled()) + isFastCacheEnabled, err := t.IsFastCacheEnabled() + require.NoError(b, err) + require.True(b, isFastCacheEnabled) for i := 0; i < b.N; i++ { q := randBytes(keyLen) t.Get(q) @@ -67,7 +69,9 @@ func runQueriesFast(b *testing.B, t *iavl.MutableTree, keyLen int) { // queries keys that are known to be in state func runKnownQueriesFast(b *testing.B, t *iavl.MutableTree, keys [][]byte) { - require.True(b, t.IsFastCacheEnabled()) // to ensure fast storage is enabled + isFastCacheEnabled, err := t.IsFastCacheEnabled() // to ensure fast storage is enabled + require.NoError(b, err) + require.True(b, isFastCacheEnabled) l := int32(len(keys)) for i := 0; i < b.N; i++ { q := keys[rand.Int31n(l)] @@ -84,7 +88,9 @@ func runQueriesSlow(b *testing.B, t *iavl.MutableTree, keyLen int) { itree, err := t.GetImmutable(version - 1) require.NoError(b, err) - require.False(b, itree.IsFastCacheEnabled()) // to ensure fast storage is not enabled + isFastCacheEnabled, err := t.IsFastCacheEnabled() // to ensure fast storage is enabled + require.NoError(b, err) + require.False(b, isFastCacheEnabled) // to ensure fast storage is not enabled b.StartTimer() for i := 0; i < b.N; i++ { @@ -102,21 +108,27 @@ func runKnownQueriesSlow(b *testing.B, t *iavl.MutableTree, keys [][]byte) { itree, err := t.GetImmutable(version - 1) require.NoError(b, err) - require.False(b, itree.IsFastCacheEnabled()) // to ensure fast storage is not enabled + isFastCacheEnabled, err := t.IsFastCacheEnabled() // to ensure fast storage is not enabled + require.NoError(b, err) + require.False(b, isFastCacheEnabled) b.StartTimer() l := int32(len(keys)) for i := 0; i < b.N; i++ { q := keys[rand.Int31n(l)] - index, value := itree.GetWithIndex(q) + index, value, err := itree.GetWithIndex(q) + require.NoError(b, err) require.True(b, index >= 0, "the index must not be negative") require.NotNil(b, value, "the value should exist") } } func runIterationFast(b *testing.B, t *iavl.MutableTree, expectedSize int) { - require.True(b, t.IsFastCacheEnabled()) // to ensure fast storage is enabled + isFastCacheEnabled, err := t.IsFastCacheEnabled() + require.NoError(b, err) + require.True(b, isFastCacheEnabled) // to ensure fast storage is enabled for i := 0; i < b.N; i++ { - itr := t.ImmutableTree.Iterator(nil, nil, false) + itr, err := t.ImmutableTree.Iterator(nil, nil, false) + require.NoError(b, err) iterate(b, itr, expectedSize) require.Nil(b, itr.Close(), ".Close should not error out") } diff --git a/cmd/iaviewer/main.go b/cmd/iaviewer/main.go index 75c47a4f7..ee881d9ea 100644 --- a/cmd/iaviewer/main.go +++ b/cmd/iaviewer/main.go @@ -47,7 +47,12 @@ func main() { switch args[0] { case "data": PrintKeys(tree) - fmt.Printf("Hash: %X\n", tree.Hash()) + hash, err := tree.Hash() + if err != nil { + fmt.Fprintf(os.Stderr, "Error hashing tree: %s\n", err) + os.Exit(1) + } + fmt.Printf("Hash: %X\n", hash) fmt.Printf("Size: %X\n", tree.Size()) case "shape": PrintShape(tree) @@ -157,7 +162,8 @@ func encodeID(id []byte) string { func PrintShape(tree *iavl.MutableTree) { // shape := tree.RenderShape(" ", nil) - shape := tree.RenderShape(" ", nodeEncoder) + //TODO: handle this error + shape, _ := tree.RenderShape(" ", nodeEncoder) fmt.Println(strings.Join(shape, "\n")) } diff --git a/export_test.go b/export_test.go index 56d3ef818..bea1befd3 100644 --- a/export_test.go +++ b/export_test.go @@ -75,20 +75,22 @@ func setupExportTreeRandom(t *testing.T) *ImmutableTree { index := r.Intn(len(keys)) key = keys[index] keys = append(keys[:index], keys[index+1:]...) - _, removed := tree.Remove(key) + _, removed, err := tree.Remove(key) + require.NoError(t, err) require.True(t, removed) case len(keys) > 0 && r.Float64() <= updateRatio: key = keys[r.Intn(len(keys))] r.Read(value) - updated := tree.Set(key, value) + updated, err := tree.Set(key, value) + require.NoError(t, err) require.True(t, updated) default: r.Read(key) r.Read(value) // If we get an update, set again - for tree.Set(key, value) { + for updated, err := tree.Set(key, value); updated && err == nil; { key = make([]byte, keySize) r.Read(key) } @@ -125,7 +127,9 @@ func setupExportTreeSized(t require.TestingT, treeSize int) *ImmutableTree { value := make([]byte, valueSize) r.Read(key) r.Read(value) - updated := tree.Set(key, value) + updated, err := tree.Set(key, value) + require.NoError(t, err) + if updated { i-- } @@ -206,13 +210,20 @@ func TestExporter_Import(t *testing.T) { require.NoError(t, err) } - require.Equal(t, tree.Hash(), newTree.Hash(), "Tree hash mismatch") + treeHash, err := tree.Hash() + require.NoError(t, err) + newTreeHash, err := newTree.Hash() + require.NoError(t, err) + + require.Equal(t, treeHash, newTreeHash, "Tree hash mismatch") require.Equal(t, tree.Size(), newTree.Size(), "Tree size mismatch") require.Equal(t, tree.Version(), newTree.Version(), "Tree version mismatch") tree.Iterate(func(key, value []byte) bool { - index, _ := tree.GetWithIndex(key) - newIndex, newValue := newTree.GetWithIndex(key) + index, _, err := tree.GetWithIndex(key) + require.NoError(t, err) + newIndex, newValue, err := newTree.GetWithIndex(key) + require.NoError(t, err) require.Equal(t, index, newIndex, "Index mismatch for key %v", key) require.Equal(t, value, newValue, "Value mismatch for key %v", key) return false diff --git a/immutable_tree.go b/immutable_tree.go index bbce557d7..5a59b1afd 100644 --- a/immutable_tree.go +++ b/immutable_tree.go @@ -51,7 +51,7 @@ func (t *ImmutableTree) String() string { // RenderShape provides a nested tree shape, ident is prepended in each level // Returns an array of strings, one per line, to join with "\n" or display otherwise -func (t *ImmutableTree) RenderShape(indent string, encoder NodeEncoder) []string { +func (t *ImmutableTree) RenderShape(indent string, encoder NodeEncoder) ([]string, error) { if encoder == nil { encoder = defaultNodeEncoder } @@ -75,25 +75,44 @@ func defaultNodeEncoder(id []byte, depth int, isLeaf bool) string { return fmt.Sprintf("%s%X", prefix, id) } -func (t *ImmutableTree) renderNode(node *Node, indent string, depth int, encoder func([]byte, int, bool) string) []string { +func (t *ImmutableTree) renderNode(node *Node, indent string, depth int, encoder func([]byte, int, bool) string) ([]string, error) { prefix := strings.Repeat(indent, depth) // handle nil if node == nil { - return []string{fmt.Sprintf("%s", prefix)} + return []string{fmt.Sprintf("%s", prefix)}, nil } // handle leaf if node.isLeaf() { here := fmt.Sprintf("%s%s", prefix, encoder(node.key, depth, true)) - return []string{here} + return []string{here}, nil } // recurse on inner node here := fmt.Sprintf("%s%s", prefix, encoder(node.hash, depth, false)) - right := t.renderNode(node.getRightNode(t), indent, depth+1, encoder) - result := t.renderNode(node.getLeftNode(t), indent, depth+1, encoder) // left + + rightNode, err := node.getRightNode(t) + if err != nil { + return nil, err + } + + leftNode, err := node.getLeftNode(t) + if err != nil { + return nil, err + } + + right, err := t.renderNode(rightNode, indent, depth+1, encoder) + if err != nil { + return nil, err + } + + result, err := t.renderNode(leftNode, indent, depth+1, encoder) // left + if err != nil { + return nil, err + } + result = append(result, here) result = append(result, right...) - return result + return result, nil } // Size returns the number of leaf nodes in the tree. @@ -118,17 +137,17 @@ func (t *ImmutableTree) Height() int8 { } // Has returns whether or not a key exists. -func (t *ImmutableTree) Has(key []byte) bool { +func (t *ImmutableTree) Has(key []byte) (bool, error) { if t.root == nil { - return false + return false, nil } return t.root.has(t, key) } // Hash returns the root hash. -func (t *ImmutableTree) Hash() []byte { - hash, _ := t.root.hashWithCount() - return hash +func (t *ImmutableTree) Hash() ([]byte, error) { + hash, _, err := t.root.hashWithCount() + return hash, err } // Export returns an iterator that exports tree nodes as ExportNodes. These nodes can be @@ -143,9 +162,9 @@ func (t *ImmutableTree) Export() *Exporter { // // The index is the index in the list of leaf nodes sorted lexicographically by key. The leftmost leaf has index 0. // It's neighbor has index 1 and so on. -func (t *ImmutableTree) GetWithIndex(key []byte) (int64, []byte) { +func (t *ImmutableTree) GetWithIndex(key []byte) (int64, []byte, error) { if t.root == nil { - return 0, nil + return 0, nil, nil } return t.root.get(t, key) } @@ -153,17 +172,17 @@ func (t *ImmutableTree) GetWithIndex(key []byte) (int64, []byte) { // Get returns the value of the specified key if it exists, or nil. // The returned value must not be modified, since it may point to data stored within IAVL. // Get potentially employs a more performant strategy than GetWithIndex for retrieving the value. -func (t *ImmutableTree) Get(key []byte) []byte { +func (t *ImmutableTree) Get(key []byte) ([]byte, error) { if t.root == nil { - return nil + return nil, nil } // attempt to get a FastNode directly from db/cache. // if call fails, fall back to the original IAVL logic in place. fastNode, err := t.ndb.GetFastNode(key) if err != nil { - _, result := t.root.get(t, key) - return result + _, result, err := t.root.get(t, key) + return result, err } if fastNode == nil { @@ -171,27 +190,27 @@ func (t *ImmutableTree) Get(key []byte) []byte { // then the regular node is not in the tree either because fast node // represents live state. if t.version == t.ndb.latestVersion { - return nil + return nil, nil } - _, result := t.root.get(t, key) - return result + _, result, err := t.root.get(t, key) + return result, err } if fastNode.versionLastUpdatedAt <= t.version { - return fastNode.value + return fastNode.value, nil } // Otherwise the cached node was updated later than the current tree. In this case, // we need to use the regular stategy for reading from the current tree to avoid staleness. - _, result := t.root.get(t, key) - return result + _, result, err := t.root.get(t, key) + return result, err } // GetByIndex gets the key and value at the specified index. -func (t *ImmutableTree) GetByIndex(index int64) (key []byte, value []byte) { +func (t *ImmutableTree) GetByIndex(index int64) (key []byte, value []byte, err error) { if t.root == nil { - return nil, nil + return nil, nil, nil } return t.root.getByIndex(t, index) @@ -199,28 +218,36 @@ func (t *ImmutableTree) GetByIndex(index int64) (key []byte, value []byte) { // Iterate iterates over all keys of the tree. The keys and values must not be modified, // since they may point to data stored within IAVL. Returns true if stopped by callback, false otherwise -func (t *ImmutableTree) Iterate(fn func(key []byte, value []byte) bool) bool { +func (t *ImmutableTree) Iterate(fn func(key []byte, value []byte) bool) (bool, error) { if t.root == nil { - return false + return false, nil } - itr := t.Iterator(nil, nil, true) + itr, err := t.Iterator(nil, nil, true) defer itr.Close() + if err != nil { + return false, err + } for ; itr.Valid(); itr.Next() { if fn(itr.Key(), itr.Value()) { - return true + return true, nil } } - return false + return false, nil } // Iterator returns an iterator over the immutable tree. -func (t *ImmutableTree) Iterator(start, end []byte, ascending bool) dbm.Iterator { - if t.IsFastCacheEnabled() { - return NewFastIterator(start, end, ascending, t.ndb) +func (t *ImmutableTree) Iterator(start, end []byte, ascending bool) (dbm.Iterator, error) { + isFastCacheEnabled, err := t.IsFastCacheEnabled() + if err != nil { + return nil, err } - return NewIterator(start, end, ascending, t) + + if isFastCacheEnabled { + return NewFastIterator(start, end, ascending, t.ndb), nil + } + return NewIterator(start, end, ascending, t), nil } // IterateRange makes a callback for all nodes with key between start and end non-inclusive. @@ -257,12 +284,20 @@ func (t *ImmutableTree) IterateRangeInclusive(start, end []byte, ascending bool, // For fast cache to be enabled, the following 2 conditions must be met: // 1. The tree is of the latest version. // 2. The underlying storage has been upgraded to fast cache -func (t *ImmutableTree) IsFastCacheEnabled() bool { - return t.isLatestTreeVersion() && t.ndb.hasUpgradedToFastStorage() +func (t *ImmutableTree) IsFastCacheEnabled() (bool, error) { + isLatestTreeVersion, err := t.isLatestTreeVersion() + if err != nil { + return false, err + } + return isLatestTreeVersion && t.ndb.hasUpgradedToFastStorage(), nil } -func (t *ImmutableTree) isLatestTreeVersion() bool { - return t.version == t.ndb.getLatestVersion() +func (t *ImmutableTree) isLatestTreeVersion() (bool, error) { + latestVersion, err := t.ndb.getLatestVersion() + if err != nil { + return false, err + } + return t.version == latestVersion, nil } // Clone creates a clone of the tree. diff --git a/import.go b/import.go index b0b609832..aff719bda 100644 --- a/import.go +++ b/import.go @@ -113,8 +113,12 @@ func (i *Importer) Add(exportNode *ExportNode) error { node.size += node.rightNode.size } - node._hash() - err := node.validate() + _, err := node._hash() + if err != nil { + return err + } + + err = node.validate() if err != nil { return err } @@ -168,11 +172,11 @@ func (i *Importer) Commit() error { switch len(i.stack) { case 0: if err := i.batch.Set(i.tree.ndb.rootKey(i.version), []byte{}); err != nil { - panic(err) + return err } case 1: if err := i.batch.Set(i.tree.ndb.rootKey(i.version), i.stack[0].hash); err != nil { - panic(err) + return err } default: return errors.Errorf("invalid node structure, found stack size %v when committing", diff --git a/import_test.go b/import_test.go index 26035c858..a5f494f44 100644 --- a/import_test.go +++ b/import_test.go @@ -164,7 +164,8 @@ func TestImporter_Close(t *testing.T) { require.NoError(t, err) importer.Close() - has := tree.Has([]byte("key")) + has, err := tree.Has([]byte("key")) + require.NoError(t, err) require.False(t, has) importer.Close() @@ -181,7 +182,8 @@ func TestImporter_Commit(t *testing.T) { err = importer.Commit() require.NoError(t, err) - has := tree.Has([]byte("key")) + has, err := tree.Has([]byte("key")) + require.NoError(t, err) require.True(t, has) } diff --git a/iterator.go b/iterator.go index cb22e50c9..a509ed1ad 100644 --- a/iterator.go +++ b/iterator.go @@ -89,17 +89,17 @@ func (nodes *delayedNodes) length() int { // 1. If the traversal is postorder, the current node will be append to the `delayedNodes` with `delayed` // set to false, and immediately returned at the subsequent call of `traversal.next()` at the last line. // 2. If the traversal is preorder, the current node will be returned. -func (t *traversal) next() *Node { +func (t *traversal) next() (*Node, error) { // End of traversal. if t.delayedNodes.length() == 0 { - return nil + return nil, nil } node, delayed := t.delayedNodes.pop() // Already expanded, immediately return. if !delayed || node == nil { - return node + return node, nil } afterStart := t.start == nil || bytes.Compare(t.start, node.key) < 0 @@ -122,22 +122,38 @@ func (t *traversal) next() *Node { if t.ascending { if beforeEnd { // push the delayed traversal for the right nodes, - t.delayedNodes.push(node.getRightNode(t.tree), true) + rightNode, err := node.getRightNode(t.tree) + if err != nil { + return nil, err + } + t.delayedNodes.push(rightNode, true) } if afterStart { // push the delayed traversal for the left nodes, - t.delayedNodes.push(node.getLeftNode(t.tree), true) + leftNode, err := node.getLeftNode(t.tree) + if err != nil { + return nil, err + } + t.delayedNodes.push(leftNode, true) } } else { // if node is a branch node and the order is not ascending // We traverse through the right subtree, then the left subtree. if afterStart { // push the delayed traversal for the left nodes, - t.delayedNodes.push(node.getLeftNode(t.tree), true) + leftNode, err := node.getLeftNode(t.tree) + if err != nil { + return nil, err + } + t.delayedNodes.push(leftNode, true) } if beforeEnd { // push the delayed traversal for the right nodes, - t.delayedNodes.push(node.getRightNode(t.tree), true) + rightNode, err := node.getRightNode(t.tree) + if err != nil { + return nil, err + } + t.delayedNodes.push(rightNode, true) } } } @@ -145,7 +161,7 @@ func (t *traversal) next() *Node { // case of preorder traversal. A-3 and B-2. // Process root then (recursively) processing left child, then process right child if !t.post && (!node.isLeaf() || (startOrAfter && beforeEnd)) { - return node + return node, nil } // Keep traversing and expanding the remaning delayed nodes. A-4. @@ -211,8 +227,9 @@ func (iter *Iterator) Next() { return } - node := iter.t.next() - if node == nil { + node, err := iter.t.next() + // TODO: double-check if this error is correctly handled. + if node == nil || err != nil { iter.t = nil iter.valid = false return diff --git a/iterator_test.go b/iterator_test.go index c5db3b0e2..84da66d36 100644 --- a/iterator_test.go +++ b/iterator_test.go @@ -182,7 +182,9 @@ func TestIterator_WithDelete_Full_Ascending_Success(t *testing.T) { err = tree.DeleteVersion(1) require.NoError(t, err) - immutableTree, err := tree.GetImmutable(tree.ndb.getLatestVersion()) + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + immutableTree, err := tree.GetImmutable(latestVersion) require.NoError(t, err) // sort mirror for assertion @@ -252,7 +254,9 @@ func setupIteratorAndMirror(t *testing.T, config *iteratorTestConfig) (dbm.Itera _, _, err = tree.SaveVersion() require.NoError(t, err) - immutableTree, err := tree.GetImmutable(tree.ndb.getLatestVersion()) + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + immutableTree, err := tree.GetImmutable(latestVersion) require.NoError(t, err) itr := NewIterator(config.startIterate, config.endIterate, config.ascending, immutableTree) @@ -314,7 +318,8 @@ func setupUnsavedFastIterator(t *testing.T, config *iteratorTestConfig) (dbm.Ite randIndex := rand.Intn(len(mirror)) keyToRemove := mirror[randIndex][0] - _, removed := tree.Remove([]byte(keyToRemove)) + _, removed, err := tree.Remove([]byte(keyToRemove)) + require.NoError(t, err) require.True(t, removed) mirror = append(mirror[:randIndex], mirror[randIndex+1:]...) diff --git a/mutable_tree.go b/mutable_tree.go index 05e54225c..654b1347f 100644 --- a/mutable_tree.go +++ b/mutable_tree.go @@ -103,12 +103,12 @@ func (tree *MutableTree) AvailableVersions() []int { // Hash returns the hash of the latest saved version of the tree, as returned // by SaveVersion. If no versions have been saved, Hash returns nil. -func (tree *MutableTree) Hash() []byte { +func (tree *MutableTree) Hash() ([]byte, error) { return tree.lastSaved.Hash() } // WorkingHash returns the hash of the current working tree. -func (tree *MutableTree) WorkingHash() []byte { +func (tree *MutableTree) WorkingHash() ([]byte, error) { return tree.ImmutableTree.Hash() } @@ -127,22 +127,28 @@ func (tree *MutableTree) prepareOrphansSlice() []*Node { // key/value byte slices must not be modified after this call, since they point // to slices stored within IAVL. It returns true when an existing value was // updated, while false means it was a new key. -func (tree *MutableTree) Set(key, value []byte) (updated bool) { +func (tree *MutableTree) Set(key, value []byte) (updated bool, err error) { var orphaned []*Node - orphaned, updated = tree.set(key, value) - tree.addOrphans(orphaned) - return updated + orphaned, updated, err = tree.set(key, value) + if err != nil { + return false, err + } + err = tree.addOrphans(orphaned) + if err != nil { + return updated, err + } + return updated, nil } // Get returns the value of the specified key if it exists, or nil otherwise. // The returned value must not be modified, since it may point to data stored within IAVL. -func (tree *MutableTree) Get(key []byte) []byte { +func (tree *MutableTree) Get(key []byte) ([]byte, error) { if tree.root == nil { - return nil + return nil, nil } if fastNode, ok := tree.unsavedFastNodeAdditions[string(key)]; ok { - return fastNode.value + return fastNode.value, nil } return tree.ImmutableTree.Get(key) @@ -162,12 +168,17 @@ func (tree *MutableTree) Import(version int64) (*Importer, error) { // Iterate iterates over all keys of the tree. The keys and values must not be modified, // since they may point to data stored within IAVL. Returns true if stopped by callnack, false otherwise -func (tree *MutableTree) Iterate(fn func(key []byte, value []byte) bool) (stopped bool) { +func (tree *MutableTree) Iterate(fn func(key []byte, value []byte) bool) (stopped bool, err error) { if tree.root == nil { - return false + return false, nil + } + + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + if err != nil { + return false, err } - if !tree.IsFastCacheEnabled() { + if !isFastCacheEnabled { return tree.ImmutableTree.Iterate(fn) } @@ -175,39 +186,44 @@ func (tree *MutableTree) Iterate(fn func(key []byte, value []byte) bool) (stoppe defer itr.Close() for ; itr.Valid(); itr.Next() { if fn(itr.Key(), itr.Value()) { - return true + return true, nil } } - return false + return false, nil } // Iterator returns an iterator over the mutable tree. // CONTRACT: no updates are made to the tree while an iterator is active. -func (tree *MutableTree) Iterator(start, end []byte, ascending bool) dbm.Iterator { - if tree.IsFastCacheEnabled() { - return NewUnsavedFastIterator(start, end, ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals) +func (tree *MutableTree) Iterator(start, end []byte, ascending bool) (dbm.Iterator, error) { + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + if err != nil { + return nil, err + } + + if isFastCacheEnabled { + return NewUnsavedFastIterator(start, end, ascending, tree.ndb, tree.unsavedFastNodeAdditions, tree.unsavedFastNodeRemovals), nil } return tree.ImmutableTree.Iterator(start, end, ascending) } -func (tree *MutableTree) set(key []byte, value []byte) (orphans []*Node, updated bool) { +func (tree *MutableTree) set(key []byte, value []byte) (orphans []*Node, updated bool, err error) { if value == nil { - panic(fmt.Sprintf("Attempt to store nil value at key '%s'", key)) + return nil, updated, fmt.Errorf("attempt to store nil value at key '%s'", key) } if tree.ImmutableTree.root == nil { tree.addUnsavedAddition(key, NewFastNode(key, value, tree.version+1)) tree.ImmutableTree.root = NewNode(key, value, tree.version+1) - return nil, updated + return nil, updated, nil } orphans = tree.prepareOrphansSlice() - tree.ImmutableTree.root, updated = tree.recursiveSet(tree.ImmutableTree.root, key, value, &orphans) - return orphans, updated + tree.ImmutableTree.root, updated, err = tree.recursiveSet(tree.ImmutableTree.root, key, value, &orphans) + return orphans, updated, err } func (tree *MutableTree) recursiveSet(node *Node, key []byte, value []byte, orphans *[]*Node) ( - newSelf *Node, updated bool, + newSelf *Node, updated bool, err error, ) { version := tree.version + 1 @@ -223,7 +239,7 @@ func (tree *MutableTree) recursiveSet(node *Node, key []byte, value []byte, orph leftNode: NewNode(key, value, version), rightNode: node, version: version, - }, false + }, false, nil case 1: return &Node{ key: key, @@ -232,60 +248,97 @@ func (tree *MutableTree) recursiveSet(node *Node, key []byte, value []byte, orph leftNode: node, rightNode: NewNode(key, value, version), version: version, - }, false + }, false, nil default: *orphans = append(*orphans, node) - return NewNode(key, value, version), true + return NewNode(key, value, version), true, nil } } else { *orphans = append(*orphans, node) - node = node.clone(version) + node, err = node.clone(version) + if err != nil { + return nil, false, err + } if bytes.Compare(key, node.key) < 0 { - node.leftNode, updated = tree.recursiveSet(node.getLeftNode(tree.ImmutableTree), key, value, orphans) + leftNode, err := node.getLeftNode(tree.ImmutableTree) + if err != nil { + return nil, false, err + } + node.leftNode, updated, err = tree.recursiveSet(leftNode, key, value, orphans) + if err != nil { + return nil, updated, err + } node.leftHash = nil // leftHash is yet unknown } else { - node.rightNode, updated = tree.recursiveSet(node.getRightNode(tree.ImmutableTree), key, value, orphans) + rightNode, err := node.getRightNode(tree.ImmutableTree) + if err != nil { + return nil, false, err + } + node.rightNode, updated, err = tree.recursiveSet(rightNode, key, value, orphans) + if err != nil { + return nil, updated, err + } node.rightHash = nil // rightHash is yet unknown } if updated { - return node, updated + return node, updated, nil + } + err = node.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, false, err + } + + newNode, err := tree.balance(node, orphans) + if err != nil { + return nil, false, err } - node.calcHeightAndSize(tree.ImmutableTree) - newNode := tree.balance(node, orphans) - return newNode, updated + return newNode, updated, err } } // Remove removes a key from the working tree. The given key byte slice should not be modified // after this call, since it may point to data stored inside IAVL. -func (tree *MutableTree) Remove(key []byte) ([]byte, bool) { - val, orphaned, removed := tree.remove(key) - tree.addOrphans(orphaned) - return val, removed +func (tree *MutableTree) Remove(key []byte) ([]byte, bool, error) { + val, orphaned, removed, err := tree.remove(key) + if err != nil { + return nil, false, err + } + + err = tree.addOrphans(orphaned) + if err != nil { + return val, removed, err + } + return val, removed, nil } // remove tries to remove a key from the tree and if removed, returns its // value, nodes orphaned and 'true'. -func (tree *MutableTree) remove(key []byte) (value []byte, orphaned []*Node, removed bool) { +func (tree *MutableTree) remove(key []byte) (value []byte, orphaned []*Node, removed bool, err error) { if tree.root == nil { - return nil, nil, false + return nil, nil, false, nil } orphaned = tree.prepareOrphansSlice() - newRootHash, newRoot, _, value := tree.recursiveRemove(tree.root, key, &orphaned) + newRootHash, newRoot, _, value, err := tree.recursiveRemove(tree.root, key, &orphaned) + if err != nil { + return nil, nil, false, err + } if len(orphaned) == 0 { - return nil, nil, false + return nil, nil, false, nil } tree.addUnsavedRemoval(key) if newRoot == nil && newRootHash != nil { - tree.root = tree.ndb.GetNode(newRootHash) + tree.root, err = tree.ndb.GetNode(newRootHash) + if err != nil { + return nil, nil, false, err + } } else { tree.root = newRoot } - return value, orphaned, true + return value, orphaned, true, nil } // removes the node corresponding to the passed key and balances the tree. @@ -295,54 +348,90 @@ func (tree *MutableTree) remove(key []byte) (value []byte, orphaned []*Node, rem // - new leftmost leaf key for tree after successfully removing 'key' if changed. // - the removed value // - the orphaned nodes. -func (tree *MutableTree) recursiveRemove(node *Node, key []byte, orphans *[]*Node) (newHash []byte, newSelf *Node, newKey []byte, newValue []byte) { +func (tree *MutableTree) recursiveRemove(node *Node, key []byte, orphans *[]*Node) (newHash []byte, newSelf *Node, newKey []byte, newValue []byte, err error) { version := tree.version + 1 if node.isLeaf() { if bytes.Equal(key, node.key) { *orphans = append(*orphans, node) - return nil, nil, nil, node.value + return nil, nil, nil, node.value, nil } - return node.hash, node, nil, nil + return node.hash, node, nil, nil, nil } // node.key < key; we go to the left to find the key: if bytes.Compare(key, node.key) < 0 { - newLeftHash, newLeftNode, newKey, value := tree.recursiveRemove(node.getLeftNode(tree.ImmutableTree), key, orphans) + leftNode, err := node.getLeftNode(tree.ImmutableTree) + if err != nil { + return nil, nil, nil, nil, err + } + newLeftHash, newLeftNode, newKey, value, err := tree.recursiveRemove(leftNode, key, orphans) + if err != nil { + return nil, nil, nil, nil, err + } if len(*orphans) == 0 { - return node.hash, node, nil, value + return node.hash, node, nil, value, nil } *orphans = append(*orphans, node) if newLeftHash == nil && newLeftNode == nil { // left node held value, was removed - return node.rightHash, node.rightNode, node.key, value + return node.rightHash, node.rightNode, node.key, value, nil + } + + newNode, err := node.clone(version) + if err != nil { + return nil, nil, nil, nil, err } - newNode := node.clone(version) newNode.leftHash, newNode.leftNode = newLeftHash, newLeftNode - newNode.calcHeightAndSize(tree.ImmutableTree) - newNode = tree.balance(newNode, orphans) - return newNode.hash, newNode, newKey, value + err = newNode.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, nil, nil, err + } + newNode, err = tree.balance(newNode, orphans) + if err != nil { + return nil, nil, nil, nil, err + } + + return newNode.hash, newNode, newKey, value, nil } // node.key >= key; either found or look to the right: - newRightHash, newRightNode, newKey, value := tree.recursiveRemove(node.getRightNode(tree.ImmutableTree), key, orphans) - + rightNode, err := node.getRightNode(tree.ImmutableTree) + if err != nil { + return nil, nil, nil, nil, err + } + newRightHash, newRightNode, newKey, value, err := tree.recursiveRemove(rightNode, key, orphans) + if err != nil { + return nil, nil, nil, nil, err + } if len(*orphans) == 0 { - return node.hash, node, nil, value + return node.hash, node, nil, value, nil } *orphans = append(*orphans, node) if newRightHash == nil && newRightNode == nil { // right node held value, was removed - return node.leftHash, node.leftNode, nil, value + return node.leftHash, node.leftNode, nil, value, nil + } + + newNode, err := node.clone(version) + if err != nil { + return nil, nil, nil, nil, err } - newNode := node.clone(version) newNode.rightHash, newNode.rightNode = newRightHash, newRightNode if newKey != nil { newNode.key = newKey } - newNode.calcHeightAndSize(tree.ImmutableTree) - newNode = tree.balance(newNode, orphans) - return newNode.hash, newNode, nil, value + err = newNode.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, nil, nil, err + } + + newNode, err = tree.balance(newNode, orphans) + if err != nil { + return nil, nil, nil, nil, err + } + + return newNode.hash, newNode, nil, value, nil } // Load the latest versioned tree from disk. @@ -358,7 +447,10 @@ func (tree *MutableTree) Load() (int64, error) { // performs a no-op. Otherwise, if the root does not exist, an error will be // returned. func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) { - latestVersion := tree.ndb.getLatestVersion() + latestVersion, err := tree.ndb.getLatestVersion() + if err != nil { + return 0, err + } if latestVersion < targetVersion { return latestVersion, fmt.Errorf("wanted to load target %d but only found up to %d", targetVersion, latestVersion) } @@ -399,7 +491,10 @@ func (tree *MutableTree) LazyLoadVersion(targetVersion int64) (int64, error) { if len(rootHash) > 0 { // If rootHash is empty then root of tree should be nil // This makes `LazyLoadVersion` to do the same thing as `LoadVersion` - iTree.root = tree.ndb.GetNode(rootHash) + iTree.root, err = tree.ndb.GetNode(rootHash) + if err != nil { + return 0, err + } } tree.orphans = map[string]int64{} @@ -465,7 +560,10 @@ func (tree *MutableTree) LoadVersion(targetVersion int64) (int64, error) { } if len(latestRoot) != 0 { - t.root = tree.ndb.GetNode(latestRoot) + t.root, err = tree.ndb.GetNode(latestRoot) + if err != nil { + return 0, err + } } tree.orphans = map[string]int64{} @@ -514,8 +612,12 @@ func (tree *MutableTree) LoadVersionForOverwriting(targetVersion int64) (int64, // Returns true if the tree may be auto-upgraded, false otherwise // An example of when an upgrade may be performed is when we are enaling fast storage for the first time or // need to overwrite fast nodes due to mismatch with live state. -func (tree *MutableTree) IsUpgradeable() bool { - return !tree.ndb.hasUpgradedToFastStorage() || tree.ndb.shouldForceFastStorageUpgrade() +func (tree *MutableTree) IsUpgradeable() (bool, error) { + shouldForce, err := tree.ndb.shouldForceFastStorageUpgrade() + if err != nil { + return false, err + } + return !tree.ndb.hasUpgradedToFastStorage() || shouldForce, nil } // enableFastStorageAndCommitIfNotEnabled if nodeDB doesn't mark fast storage as enabled, enable it, and commit the update. @@ -523,10 +625,18 @@ func (tree *MutableTree) IsUpgradeable() bool { // from latest tree. // nolint: unparam func (tree *MutableTree) enableFastStorageAndCommitIfNotEnabled() (bool, error) { - shouldForceUpdate := tree.ndb.shouldForceFastStorageUpgrade() + shouldForceUpdate, err := tree.ndb.shouldForceFastStorageUpgrade() + if err != nil { + return false, err + } isFastStorageEnabled := tree.ndb.hasUpgradedToFastStorage() - if !tree.IsUpgradeable() { + isUpgradeable, err := tree.IsUpgradeable() + if err != nil { + return false, err + } + + if !isUpgradeable { return false, nil } @@ -638,8 +748,13 @@ func (tree *MutableTree) GetImmutable(version int64) (*ImmutableTree, error) { }, nil } tree.versions[version] = true + + root, err := tree.ndb.GetNode(rootHash) + if err != nil { + return nil, err + } return &ImmutableTree{ - root: tree.ndb.GetNode(rootHash), + root: root, ndb: tree.ndb, version: version, }, nil @@ -660,26 +775,34 @@ func (tree *MutableTree) Rollback() { // GetVersioned gets the value at the specified key and version. The returned value must not be // modified, since it may point to data stored within IAVL. -func (tree *MutableTree) GetVersioned(key []byte, version int64) []byte { +func (tree *MutableTree) GetVersioned(key []byte, version int64) ([]byte, error) { if tree.VersionExists(version) { - if tree.IsFastCacheEnabled() { + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + if err != nil { + return nil, err + } + + if isFastCacheEnabled { fastNode, _ := tree.ndb.GetFastNode(key) if fastNode == nil && version == tree.ndb.latestVersion { - return nil + return nil, nil } if fastNode != nil && fastNode.versionLastUpdatedAt <= version { - return fastNode.value + return fastNode.value, nil } } t, err := tree.GetImmutable(version) if err != nil { - return nil + return nil, nil + } + value, err := t.Get(key) + if err != nil { + return nil, err } - value := t.Get(key) - return value + return value, nil } - return nil + return nil, nil } // SaveVersion saves a new tree version to disk, based on the current state of @@ -704,7 +827,10 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { existingHash = sha256.New().Sum(nil) } - var newHash = tree.WorkingHash() + newHash, err := tree.WorkingHash() + if err != nil { + return nil, version, err + } if bytes.Equal(existingHash, newHash) { tree.version = version @@ -721,7 +847,9 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { // There can still be orphans, for example if the root is the node being // removed. logger.Debug("SAVE EMPTY TREE %v\n", version) - tree.ndb.SaveOrphans(version, tree.orphans) + if err := tree.ndb.SaveOrphans(version, tree.orphans); err != nil { + return nil, 0, err + } if err := tree.ndb.SaveEmptyRoot(version); err != nil { return nil, 0, err } @@ -730,7 +858,9 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { if _, err := tree.ndb.SaveBranch(tree.root); err != nil { return nil, 0, err } - tree.ndb.SaveOrphans(version, tree.orphans) + if err := tree.ndb.SaveOrphans(version, tree.orphans); err != nil { + return nil, 0, err + } if err := tree.ndb.SaveRoot(tree.root, version); err != nil { return nil, 0, err } @@ -756,7 +886,12 @@ func (tree *MutableTree) SaveVersion() ([]byte, int64, error) { tree.unsavedFastNodeAdditions = make(map[string]*FastNode) tree.unsavedFastNodeRemovals = make(map[string]interface{}) - return tree.Hash(), version, nil + hash, err := tree.Hash() + if err != nil { + return nil, version, err + } + + return hash, version, nil } func (tree *MutableTree) saveFastNodeVersion() error { @@ -917,99 +1052,183 @@ func (tree *MutableTree) DeleteVersion(version int64) error { } // Rotate right and return the new node and orphan. -func (tree *MutableTree) rotateRight(node *Node) (*Node, *Node) { +func (tree *MutableTree) rotateRight(node *Node) (*Node, *Node, error) { version := tree.version + 1 + var err error // TODO: optimize balance & rotate. - node = node.clone(version) - orphaned := node.getLeftNode(tree.ImmutableTree) - newNode := orphaned.clone(version) + node, err = node.clone(version) + if err != nil { + return nil, nil, err + } + + orphaned, err := node.getLeftNode(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } + newNode, err := orphaned.clone(version) + if err != nil { + return nil, nil, err + } newNoderHash, newNoderCached := newNode.rightHash, newNode.rightNode newNode.rightHash, newNode.rightNode = node.hash, node node.leftHash, node.leftNode = newNoderHash, newNoderCached - node.calcHeightAndSize(tree.ImmutableTree) - newNode.calcHeightAndSize(tree.ImmutableTree) + err = node.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } + + err = newNode.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } - return newNode, orphaned + return newNode, orphaned, nil } // Rotate left and return the new node and orphan. -func (tree *MutableTree) rotateLeft(node *Node) (*Node, *Node) { +func (tree *MutableTree) rotateLeft(node *Node) (*Node, *Node, error) { version := tree.version + 1 + var err error // TODO: optimize balance & rotate. - node = node.clone(version) - orphaned := node.getRightNode(tree.ImmutableTree) - newNode := orphaned.clone(version) + node, err = node.clone(version) + if err != nil { + return nil, nil, err + } + + orphaned, err := node.getRightNode(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } + newNode, err := orphaned.clone(version) + if err != nil { + return nil, nil, err + } newNodelHash, newNodelCached := newNode.leftHash, newNode.leftNode newNode.leftHash, newNode.leftNode = node.hash, node node.rightHash, node.rightNode = newNodelHash, newNodelCached - node.calcHeightAndSize(tree.ImmutableTree) - newNode.calcHeightAndSize(tree.ImmutableTree) + err = node.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } + + err = newNode.calcHeightAndSize(tree.ImmutableTree) + if err != nil { + return nil, nil, err + } - return newNode, orphaned + return newNode, orphaned, nil } // NOTE: assumes that node can be modified // TODO: optimize balance & rotate -func (tree *MutableTree) balance(node *Node, orphans *[]*Node) (newSelf *Node) { +func (tree *MutableTree) balance(node *Node, orphans *[]*Node) (newSelf *Node, err error) { if node.persisted { - panic("Unexpected balance() call on persisted node") + return nil, fmt.Errorf("unexpected balance() call on persisted node") + } + balance, err := node.calcBalance(tree.ImmutableTree) + if err != nil { + return nil, err } - balance := node.calcBalance(tree.ImmutableTree) if balance > 1 { - if node.getLeftNode(tree.ImmutableTree).calcBalance(tree.ImmutableTree) >= 0 { + leftNode, err := node.getLeftNode(tree.ImmutableTree) + if err != nil { + return nil, err + } + + lftBalance, err := leftNode.calcBalance(tree.ImmutableTree) + if err != nil { + return nil, err + } + + if lftBalance >= 0 { // Left Left Case - newNode, orphaned := tree.rotateRight(node) + newNode, orphaned, err := tree.rotateRight(node) + if err != nil { + return nil, err + } *orphans = append(*orphans, orphaned) - return newNode + return newNode, nil } // Left Right Case var leftOrphaned *Node - left := node.getLeftNode(tree.ImmutableTree) + left, err := node.getLeftNode(tree.ImmutableTree) + if err != nil { + return nil, err + } node.leftHash = nil - node.leftNode, leftOrphaned = tree.rotateLeft(left) - newNode, rightOrphaned := tree.rotateRight(node) + node.leftNode, leftOrphaned, err = tree.rotateLeft(left) + if err != nil { + return nil, err + } + + newNode, rightOrphaned, err := tree.rotateRight(node) + if err != nil { + return nil, err + } *orphans = append(*orphans, left, leftOrphaned, rightOrphaned) - return newNode + return newNode, nil } if balance < -1 { - if node.getRightNode(tree.ImmutableTree).calcBalance(tree.ImmutableTree) <= 0 { + rightNode, err := node.getRightNode(tree.ImmutableTree) + if err != nil { + return nil, err + } + + rightBalance, err := rightNode.calcBalance(tree.ImmutableTree) + if err != nil { + return nil, err + } + if rightBalance <= 0 { // Right Right Case - newNode, orphaned := tree.rotateLeft(node) + newNode, orphaned, err := tree.rotateLeft(node) + if err != nil { + return nil, err + } *orphans = append(*orphans, orphaned) - return newNode + return newNode, nil } // Right Left Case var rightOrphaned *Node - right := node.getRightNode(tree.ImmutableTree) + right, err := node.getRightNode(tree.ImmutableTree) + if err != nil { + return nil, err + } node.rightHash = nil - node.rightNode, rightOrphaned = tree.rotateRight(right) - newNode, leftOrphaned := tree.rotateLeft(node) + node.rightNode, rightOrphaned, err = tree.rotateRight(right) + if err != nil { + return nil, err + } + newNode, leftOrphaned, err := tree.rotateLeft(node) + if err != nil { + return nil, err + } *orphans = append(*orphans, right, leftOrphaned, rightOrphaned) - return newNode + return newNode, nil } // Nothing changed - return node + return node, nil } -func (tree *MutableTree) addOrphans(orphans []*Node) { +func (tree *MutableTree) addOrphans(orphans []*Node) error { for _, node := range orphans { if !node.persisted { // We don't need to orphan nodes that were never persisted. continue } if len(node.hash) == 0 { - panic("Expected to find node hash, but was empty") + return fmt.Errorf("expected to find node hash, but was empty") } tree.orphans[string(node.hash)] = node.version } + return nil } diff --git a/mutable_tree_test.go b/mutable_tree_test.go index d24cb0237..4e1d6efd8 100644 --- a/mutable_tree_test.go +++ b/mutable_tree_test.go @@ -77,7 +77,8 @@ func TestMutableTree_DeleteVersions(t *testing.T) { v := randBytes(10) entries[j] = entry{k, v} - _ = tree.Set(k, v) + _, err = tree.Set(k, v) + require.NoError(t, err) } _, v, err := tree.SaveVersion() @@ -106,7 +107,8 @@ func TestMutableTree_DeleteVersions(t *testing.T) { require.NoError(t, err) for _, e := range versionEntries[v] { - val := tree.Get(e.key) + val, err := tree.Get(e.key) + require.NoError(t, err) require.Equal(t, e.value, val) } } @@ -183,12 +185,14 @@ func TestMutableTree_DeleteVersionsRange(t *testing.T) { require.NoError(err, version) require.Equal(v, version) - value := tree.Get([]byte("aaa")) + value, err := tree.Get([]byte("aaa")) + require.NoError(err) require.Equal(string(value), "bbb") for _, count := range versions[:version] { countStr := strconv.Itoa(int(count)) - value := tree.Get([]byte("key" + countStr)) + value, err := tree.Get([]byte("key" + countStr)) + require.NoError(err) require.Equal(string(value), "value"+countStr) } } @@ -207,17 +211,20 @@ func TestMutableTree_DeleteVersionsRange(t *testing.T) { require.NoError(err) require.Equal(v, version) - value := tree.Get([]byte("aaa")) + value, err := tree.Get([]byte("aaa")) + require.NoError(err) require.Equal(string(value), "bbb") for _, count := range versions[:fromLength] { countStr := strconv.Itoa(int(count)) - value := tree.Get([]byte("key" + countStr)) + value, err := tree.Get([]byte("key" + countStr)) + require.NoError(err) require.Equal(string(value), "value"+countStr) } for _, count := range versions[int64(maxLength/2)-1 : version] { countStr := strconv.Itoa(int(count)) - value := tree.Get([]byte("key" + countStr)) + value, err := tree.Get([]byte("key" + countStr)) + require.NoError(err) require.Equal(string(value), "value"+countStr) } } @@ -324,7 +331,8 @@ func TestMutableTree_VersionExists(t *testing.T) { } func checkGetVersioned(t *testing.T, tree *MutableTree, version int64, key, value []byte) { - val := tree.GetVersioned(key, version) + val, err := tree.GetVersioned(key, version) + require.NoError(t, err) require.True(t, bytes.Equal(val, value)) } @@ -393,11 +401,14 @@ func TestMutableTree_SetSimple(t *testing.T) { const testKey1 = "a" const testVal1 = "test" - isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + isUpdated, err := tree.Set([]byte(testKey1), []byte(testVal1)) + require.NoError(t, err) require.False(t, isUpdated) - fastValue := tree.Get([]byte(testKey1)) - _, regularValue := tree.GetWithIndex([]byte(testKey1)) + fastValue, err := tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err := tree.GetWithIndex([]byte(testKey1)) + require.NoError(t, err) require.Equal(t, []byte(testVal1), fastValue) require.Equal(t, []byte(testVal1), regularValue) @@ -422,19 +433,25 @@ func TestMutableTree_SetTwoKeys(t *testing.T) { const testKey2 = "b" const testVal2 = "test2" - isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + isUpdated, err := tree.Set([]byte(testKey1), []byte(testVal1)) + require.NoError(t, err) require.False(t, isUpdated) - isUpdated = tree.Set([]byte(testKey2), []byte(testVal2)) + isUpdated, err = tree.Set([]byte(testKey2), []byte(testVal2)) + require.NoError(t, err) require.False(t, isUpdated) - fastValue := tree.Get([]byte(testKey1)) - _, regularValue := tree.GetWithIndex([]byte(testKey1)) + fastValue, err := tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err := tree.GetWithIndex([]byte(testKey1)) + require.NoError(t, err) require.Equal(t, []byte(testVal1), fastValue) require.Equal(t, []byte(testVal1), regularValue) - fastValue2 := tree.Get([]byte(testKey2)) - _, regularValue2 := tree.GetWithIndex([]byte(testKey2)) + fastValue2, err := tree.Get([]byte(testKey2)) + require.NoError(t, err) + _, regularValue2, err := tree.GetWithIndex([]byte(testKey2)) + require.NoError(t, err) require.Equal(t, []byte(testVal2), fastValue2) require.Equal(t, []byte(testVal2), regularValue2) @@ -461,14 +478,18 @@ func TestMutableTree_SetOverwrite(t *testing.T) { const testVal1 = "test" const testVal2 = "test2" - isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + isUpdated, err := tree.Set([]byte(testKey1), []byte(testVal1)) + require.NoError(t, err) require.False(t, isUpdated) - isUpdated = tree.Set([]byte(testKey1), []byte(testVal2)) + isUpdated, err = tree.Set([]byte(testKey1), []byte(testVal2)) + require.NoError(t, err) require.True(t, isUpdated) - fastValue := tree.Get([]byte(testKey1)) - _, regularValue := tree.GetWithIndex([]byte(testKey1)) + fastValue, err := tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err := tree.GetWithIndex([]byte(testKey1)) + require.NoError(t, err) require.Equal(t, []byte(testVal2), fastValue) require.Equal(t, []byte(testVal2), regularValue) @@ -490,11 +511,13 @@ func TestMutableTree_SetRemoveSet(t *testing.T) { const testVal1 = "test" // Set 1 - isUpdated := tree.Set([]byte(testKey1), []byte(testVal1)) + isUpdated, err := tree.Set([]byte(testKey1), []byte(testVal1)) + require.NoError(t, err) require.False(t, isUpdated) - fastValue := tree.Get([]byte(testKey1)) - _, regularValue := tree.GetWithIndex([]byte(testKey1)) + fastValue, err := tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err := tree.GetWithIndex([]byte(testKey1)) require.Equal(t, []byte(testVal1), fastValue) require.Equal(t, []byte(testVal1), regularValue) @@ -507,7 +530,8 @@ func TestMutableTree_SetRemoveSet(t *testing.T) { require.Equal(t, int64(1), fastNodeAddition.versionLastUpdatedAt) // Remove - removedVal, isRemoved := tree.Remove([]byte(testKey1)) + removedVal, isRemoved, err := tree.Remove([]byte(testKey1)) + require.NoError(t, err) require.NotNil(t, removedVal) require.True(t, isRemoved) @@ -517,17 +541,22 @@ func TestMutableTree_SetRemoveSet(t *testing.T) { fastNodeRemovals := tree.getUnsavedFastNodeRemovals() require.Equal(t, 1, len(fastNodeRemovals)) - fastValue = tree.Get([]byte(testKey1)) - _, regularValue = tree.GetWithIndex([]byte(testKey1)) + fastValue, err = tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err = tree.GetWithIndex([]byte(testKey1)) + require.NoError(t, err) require.Nil(t, fastValue) require.Nil(t, regularValue) // Set 2 - isUpdated = tree.Set([]byte(testKey1), []byte(testVal1)) + isUpdated, err = tree.Set([]byte(testKey1), []byte(testVal1)) + require.NoError(t, err) require.False(t, isUpdated) - fastValue = tree.Get([]byte(testKey1)) - _, regularValue = tree.GetWithIndex([]byte(testKey1)) + fastValue, err = tree.Get([]byte(testKey1)) + require.NoError(t, err) + _, regularValue, err = tree.GetWithIndex([]byte(testKey1)) + require.NoError(t, err) require.Equal(t, []byte(testVal1), fastValue) require.Equal(t, []byte(testVal1), regularValue) @@ -556,35 +585,40 @@ func TestMutableTree_FastNodeIntegration(t *testing.T) { const testVal2 = "test2" // Set key1 - res := tree.Set([]byte(key1), []byte(testVal1)) + res, err := tree.Set([]byte(key1), []byte(testVal1)) + require.NoError(t, err) require.False(t, res) unsavedNodeAdditions := tree.getUnsavedFastNodeAdditions() require.Equal(t, len(unsavedNodeAdditions), 1) // Set key2 - res = tree.Set([]byte(key2), []byte(testVal1)) + res, err = tree.Set([]byte(key2), []byte(testVal1)) + require.NoError(t, err) require.False(t, res) unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() require.Equal(t, len(unsavedNodeAdditions), 2) // Set key3 - res = tree.Set([]byte(key3), []byte(testVal1)) + res, err = tree.Set([]byte(key3), []byte(testVal1)) + require.NoError(t, err) require.False(t, res) unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() require.Equal(t, len(unsavedNodeAdditions), 3) // Set key3 with new value - res = tree.Set([]byte(key3), []byte(testVal2)) + res, err = tree.Set([]byte(key3), []byte(testVal2)) + require.NoError(t, err) require.True(t, res) unsavedNodeAdditions = tree.getUnsavedFastNodeAdditions() require.Equal(t, len(unsavedNodeAdditions), 3) // Remove key2 - removedVal, isRemoved := tree.Remove([]byte(key2)) + removedVal, isRemoved, err := tree.Remove([]byte(key2)) + require.NoError(t, err) require.True(t, isRemoved) require.Equal(t, []byte(testVal1), removedVal) @@ -612,18 +646,24 @@ func TestMutableTree_FastNodeIntegration(t *testing.T) { require.NoError(t, err) // Get and GetFast - fastValue := t2.Get([]byte(key1)) - _, regularValue := tree.GetWithIndex([]byte(key1)) + fastValue, err := t2.Get([]byte(key1)) + require.NoError(t, err) + _, regularValue, err := tree.GetWithIndex([]byte(key1)) + require.NoError(t, err) require.Equal(t, []byte(testVal1), fastValue) require.Equal(t, []byte(testVal1), regularValue) - fastValue = t2.Get([]byte(key2)) - _, regularValue = t2.GetWithIndex([]byte(key2)) + fastValue, err = t2.Get([]byte(key2)) + require.NoError(t, err) + _, regularValue, err = t2.GetWithIndex([]byte(key2)) + require.NoError(t, err) require.Nil(t, fastValue) require.Nil(t, regularValue) - fastValue = t2.Get([]byte(key3)) - _, regularValue = tree.GetWithIndex([]byte(key3)) + fastValue, err = t2.Get([]byte(key3)) + require.NoError(t, err) + _, regularValue, err = tree.GetWithIndex([]byte(key3)) + require.NoError(t, err) require.Equal(t, []byte(testVal2), fastValue) require.Equal(t, []byte(testVal2), regularValue) } @@ -659,8 +699,8 @@ func TestIterator_MutableTree_Invalid(t *testing.T) { tree, err := getTestTree(0) require.NoError(t, err) - itr := tree.Iterator([]byte("a"), []byte("b"), true) - + itr, err := tree.Iterator([]byte("a"), []byte("b"), true) + require.NoError(t, err) require.NotNil(t, itr) require.False(t, itr.Valid()) } @@ -671,21 +711,28 @@ func TestUpgradeStorageToFast_LatestVersion_Success(t *testing.T) { tree, err := NewMutableTree(db, 1000) // Default version when storage key does not exist in the db + isFastCacheEnabled, err := tree.IsFastCacheEnabled() require.NoError(t, err) - require.False(t, tree.IsFastCacheEnabled()) + require.False(t, isFastCacheEnabled) mirror := make(map[string]string) // Fill with some data randomizeTreeAndMirror(t, tree, mirror) // Enable fast storage - require.True(t, tree.IsUpgradeable()) + isUpgradeable, err := tree.IsUpgradeable() + require.True(t, isUpgradeable) + require.NoError(t, err) enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() require.NoError(t, err) require.True(t, enabled) - require.False(t, tree.IsUpgradeable()) + isUpgradeable, err = tree.IsUpgradeable() + require.False(t, isUpgradeable) + require.NoError(t, err) - require.True(t, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) } func TestUpgradeStorageToFast_AlreadyUpgraded_Success(t *testing.T) { @@ -695,25 +742,35 @@ func TestUpgradeStorageToFast_AlreadyUpgraded_Success(t *testing.T) { // Default version when storage key does not exist in the db require.NoError(t, err) - require.False(t, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) mirror := make(map[string]string) // Fill with some data randomizeTreeAndMirror(t, tree, mirror) // Enable fast storage - require.True(t, tree.IsUpgradeable()) + isUpgradeable, err := tree.IsUpgradeable() + require.True(t, isUpgradeable) + require.NoError(t, err) enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() require.NoError(t, err) require.True(t, enabled) - require.True(t, tree.IsFastCacheEnabled()) - require.False(t, tree.IsUpgradeable()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) + isUpgradeable, err = tree.IsUpgradeable() + require.False(t, isUpgradeable) + require.NoError(t, err) // Test enabling fast storage when already enabled enabled, err = tree.enableFastStorageAndCommitIfNotEnabled() require.NoError(t, err) require.False(t, enabled) - require.True(t, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) } @@ -736,7 +793,10 @@ func TestUpgradeStorageToFast_DbErrorConstructor_Failure(t *testing.T) { tree, err := NewMutableTree(dbMock, 0) require.Nil(t, err) require.NotNil(t, tree) - require.False(t, tree.IsFastCacheEnabled()) + + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) } func TestUpgradeStorageToFast_DbErrorEnableFastStorage_Failure(t *testing.T) { @@ -762,12 +822,18 @@ func TestUpgradeStorageToFast_DbErrorEnableFastStorage_Failure(t *testing.T) { tree, err := NewMutableTree(dbMock, 0) require.Nil(t, err) require.NotNil(t, tree) - require.False(t, tree.IsFastCacheEnabled()) + + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() require.ErrorIs(t, err, expectedError) require.False(t, enabled) - require.False(t, tree.IsFastCacheEnabled()) + + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) } func TestFastStorageReUpgradeProtection_NoForceUpgrade_Success(t *testing.T) { @@ -800,11 +866,17 @@ func TestFastStorageReUpgradeProtection_NoForceUpgrade_Success(t *testing.T) { // Pretend that we called Load and have the latest state in the tree tree.version = latestTreeVersion - require.Equal(t, tree.ndb.getLatestVersion(), int64(latestTreeVersion)) + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + require.Equal(t, latestVersion, int64(latestTreeVersion)) // Ensure that the right branch of enableFastStorageAndCommitIfNotEnabled will be triggered - require.True(t, tree.IsFastCacheEnabled()) - require.False(t, tree.ndb.shouldForceFastStorageUpgrade()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) + shouldForce, err := tree.ndb.shouldForceFastStorageUpgrade() + require.False(t, shouldForce) + require.NoError(t, err) enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() require.NoError(t, err) @@ -887,11 +959,17 @@ func TestFastStorageReUpgradeProtection_ForceUpgradeFirstTime_NoForceSecondTime_ // Pretend that we called Load and have the latest state in the tree tree.version = latestTreeVersion - require.Equal(t, tree.ndb.getLatestVersion(), int64(latestTreeVersion)) + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + require.Equal(t, latestVersion, int64(latestTreeVersion)) // Ensure that the right branch of enableFastStorageAndCommitIfNotEnabled will be triggered - require.True(t, tree.IsFastCacheEnabled()) - require.True(t, tree.ndb.shouldForceFastStorageUpgrade()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) + shouldForce, err := tree.ndb.shouldForceFastStorageUpgrade() + require.True(t, shouldForce) + require.NoError(t, err) // Actual method under test enabled, err := tree.enableFastStorageAndCommitIfNotEnabled() @@ -908,26 +986,40 @@ func TestUpgradeStorageToFast_Integration_Upgraded_FastIterator_Success(t *testi // Setup tree, mirror := setupTreeAndMirrorForUpgrade(t) - require.False(t, tree.IsFastCacheEnabled()) - require.True(t, tree.IsUpgradeable()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) + isUpgradeable, err := tree.IsUpgradeable() + require.True(t, isUpgradeable) + require.NoError(t, err) // Should auto enable in save version - _, _, err := tree.SaveVersion() + _, _, err = tree.SaveVersion() require.NoError(t, err) - require.True(t, tree.IsFastCacheEnabled()) - require.False(t, tree.IsUpgradeable()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) + isUpgradeable, err = tree.IsUpgradeable() + require.False(t, isUpgradeable) + require.NoError(t, err) sut, _ := NewMutableTree(tree.ndb.db, 1000) - require.False(t, sut.IsFastCacheEnabled()) - require.False(t, sut.IsUpgradeable()) // upgraded in save version + isFastCacheEnabled, err = sut.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) + isUpgradeable, err = sut.IsUpgradeable() + require.False(t, isUpgradeable) // upgraded in save version + require.NoError(t, err) // Load version - should auto enable fast storage version, err := sut.Load() require.NoError(t, err) - require.True(t, sut.IsFastCacheEnabled()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) require.Equal(t, int64(1), version) @@ -961,32 +1053,47 @@ func TestUpgradeStorageToFast_Integration_Upgraded_GetFast_Success(t *testing.T) // Setup tree, mirror := setupTreeAndMirrorForUpgrade(t) - require.False(t, tree.IsFastCacheEnabled()) - require.True(t, tree.IsUpgradeable()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) + isUpgradeable, err := tree.IsUpgradeable() + require.True(t, isUpgradeable) + require.NoError(t, err) // Should auto enable in save version - _, _, err := tree.SaveVersion() + _, _, err = tree.SaveVersion() require.NoError(t, err) - require.True(t, tree.IsFastCacheEnabled()) - require.False(t, tree.IsUpgradeable()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) + isUpgradeable, err = tree.IsUpgradeable() + require.False(t, isUpgradeable) + require.NoError(t, err) sut, _ := NewMutableTree(tree.ndb.db, 1000) - require.False(t, sut.IsFastCacheEnabled()) - require.False(t, sut.IsUpgradeable()) // upgraded in save version + isFastCacheEnabled, err = sut.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) + isUpgradeable, err = sut.IsUpgradeable() + require.False(t, isUpgradeable) // upgraded in save version + require.NoError(t, err) // LazyLoadVersion - should auto enable fast storage version, err := sut.LazyLoadVersion(1) require.NoError(t, err) - require.True(t, sut.IsFastCacheEnabled()) + isFastCacheEnabled, err = tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) require.Equal(t, int64(1), version) t.Run("Mutable tree", func(t *testing.T) { for _, kv := range mirror { - v := sut.Get([]byte(kv[0])) + v, err := sut.Get([]byte(kv[0])) + require.NoError(t, err) require.Equal(t, []byte(kv[1]), v) } }) @@ -996,7 +1103,8 @@ func TestUpgradeStorageToFast_Integration_Upgraded_GetFast_Success(t *testing.T) require.NoError(t, err) for _, kv := range mirror { - v := immutableTree.Get([]byte(kv[0])) + v, err := immutableTree.Get([]byte(kv[0])) + require.NoError(t, err) require.Equal(t, []byte(kv[1]), v) } }) @@ -1015,7 +1123,9 @@ func setupTreeAndMirrorForUpgrade(t *testing.T) (*MutableTree, [][]string) { key := fmt.Sprintf("%s_%d", keyPrefix, i) val := fmt.Sprintf("%s_%d", valPrefix, i) mirror = append(mirror, []string{key, val}) - require.False(t, tree.Set([]byte(key), []byte(val))) + updated, err := tree.Set([]byte(key), []byte(val)) + require.False(t, updated) + require.NoError(t, err) } // Delete fast nodes from database to mimic a version with no upgrade diff --git a/node.go b/node.go index b7e369727..740cde71b 100644 --- a/node.go +++ b/node.go @@ -122,9 +122,9 @@ func (node *Node) String() string { } // clone creates a shallow copy of a node with its hash set to nil. -func (node *Node) clone(version int64) *Node { +func (node *Node) clone(version int64) (*Node, error) { if node.isLeaf() { - panic("Attempt to copy a leaf node") + return nil, ErrCloneLeafNode } return &Node{ key: node.key, @@ -137,7 +137,7 @@ func (node *Node) clone(version int64) *Node { rightHash: node.rightHash, rightNode: node.rightNode, persisted: false, - } + }, nil } func (node *Node) isLeaf() bool { @@ -145,107 +145,140 @@ func (node *Node) isLeaf() bool { } // Check if the node has a descendant with the given key. -func (node *Node) has(t *ImmutableTree, key []byte) (has bool) { +func (node *Node) has(t *ImmutableTree, key []byte) (has bool, err error) { if bytes.Equal(node.key, key) { - return true + return true, nil } if node.isLeaf() { - return false + return false, nil } if bytes.Compare(key, node.key) < 0 { - return node.getLeftNode(t).has(t, key) + leftNode, err := node.getLeftNode(t) + if err != nil { + return false, err + } + return leftNode.has(t, key) + } + + rightNode, err := node.getRightNode(t) + if err != nil { + return false, err } - return node.getRightNode(t).has(t, key) + + return rightNode.has(t, key) } // Get a key under the node. // // The index is the index in the list of leaf nodes sorted lexicographically by key. The leftmost leaf has index 0. // It's neighbor has index 1 and so on. -func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte) { +func (node *Node) get(t *ImmutableTree, key []byte) (index int64, value []byte, err error) { if node.isLeaf() { switch bytes.Compare(node.key, key) { case -1: - return 1, nil + return 1, nil, nil case 1: - return 0, nil + return 0, nil, nil default: - return 0, node.value + return 0, node.value, nil } } if bytes.Compare(key, node.key) < 0 { - return node.getLeftNode(t).get(t, key) + leftNode, err := node.getLeftNode(t) + if err != nil { + return 0, nil, err + } + + return leftNode.get(t, key) + } + + rightNode, err := node.getRightNode(t) + if err != nil { + return 0, nil, err + } + + index, value, err = rightNode.get(t, key) + if err != nil { + return 0, nil, err } - rightNode := node.getRightNode(t) - index, value = rightNode.get(t, key) + index += node.size - rightNode.size - return index, value + return index, value, nil } -func (node *Node) getByIndex(t *ImmutableTree, index int64) (key []byte, value []byte) { +func (node *Node) getByIndex(t *ImmutableTree, index int64) (key []byte, value []byte, err error) { if node.isLeaf() { if index == 0 { - return node.key, node.value + return node.key, node.value, nil } - return nil, nil + return nil, nil, nil } // TODO: could improve this by storing the // sizes as well as left/right hash. - leftNode := node.getLeftNode(t) + leftNode, err := node.getLeftNode(t) + if err != nil { + return nil, nil, err + } if index < leftNode.size { return leftNode.getByIndex(t, index) } - return node.getRightNode(t).getByIndex(t, index-leftNode.size) + + rightNode, err := node.getRightNode(t) + if err != nil { + return nil, nil, err + } + + return rightNode.getByIndex(t, index-leftNode.size) } // Computes the hash of the node without computing its descendants. Must be // called on nodes which have descendant node hashes already computed. -func (node *Node) _hash() []byte { +func (node *Node) _hash() ([]byte, error) { if node.hash != nil { - return node.hash + return node.hash, nil } h := sha256.New() buf := new(bytes.Buffer) if err := node.writeHashBytes(buf); err != nil { - panic(err) + return nil, err } _, err := h.Write(buf.Bytes()) if err != nil { - panic(err) + return nil, err } node.hash = h.Sum(nil) - return node.hash + return node.hash, nil } // Hash the node and its descendants recursively. This usually mutates all // descendant nodes. Returns the node hash and number of nodes hashed. // If the tree is empty (i.e. the node is nil), returns the hash of an empty input, // to conform with RFC-6962. -func (node *Node) hashWithCount() ([]byte, int64) { +func (node *Node) hashWithCount() ([]byte, int64, error) { if node == nil { - return sha256.New().Sum(nil), 0 + return sha256.New().Sum(nil), 0, nil } if node.hash != nil { - return node.hash, 0 + return node.hash, 0, nil } h := sha256.New() buf := new(bytes.Buffer) hashCount, err := node.writeHashBytesRecursively(buf) if err != nil { - panic(err) + return nil, 0, err } _, err = h.Write(buf.Bytes()) if err != nil { - panic(err) + return nil, 0, err } node.hash = h.Sum(nil) - return node.hash, hashCount + 1 + return node.hash, hashCount + 1, nil } // validate validates the node contents @@ -323,7 +356,7 @@ func (node *Node) writeHashBytes(w io.Writer) error { } } else { if node.leftHash == nil || node.rightHash == nil { - panic("Found an empty child hash") + return ErrEmptyChildHash } err = encoding.EncodeBytes(w, node.leftHash) if err != nil { @@ -342,12 +375,18 @@ func (node *Node) writeHashBytes(w io.Writer) error { // This function has the side-effect of calling hashWithCount. func (node *Node) writeHashBytesRecursively(w io.Writer) (hashCount int64, err error) { if node.leftNode != nil { - leftHash, leftCount := node.leftNode.hashWithCount() + leftHash, leftCount, err := node.leftNode.hashWithCount() + if err != nil { + return 0, err + } node.leftHash = leftHash hashCount += leftCount } if node.rightNode != nil { - rightHash, rightCount := node.rightNode.hashWithCount() + rightHash, rightCount, err := node.rightNode.hashWithCount() + if err != nil { + return 0, err + } node.rightHash = rightHash hashCount += rightCount } @@ -401,7 +440,7 @@ func (node *Node) writeBytes(w io.Writer) error { } } else { if node.leftHash == nil { - panic("node.leftHash was nil in writeBytes") + return ErrLeftHashIsNil } cause = encoding.EncodeBytes(w, node.leftHash) if cause != nil { @@ -409,7 +448,7 @@ func (node *Node) writeBytes(w io.Writer) error { } if node.rightHash == nil { - panic("node.rightHash was nil in writeBytes") + return ErrRightHashIsNil } cause = encoding.EncodeBytes(w, node.rightHash) if cause != nil { @@ -419,28 +458,59 @@ func (node *Node) writeBytes(w io.Writer) error { return nil } -func (node *Node) getLeftNode(t *ImmutableTree) *Node { +func (node *Node) getLeftNode(t *ImmutableTree) (*Node, error) { if node.leftNode != nil { - return node.leftNode + return node.leftNode, nil } - return t.ndb.GetNode(node.leftHash) + leftNode, err := t.ndb.GetNode(node.leftHash) + if err != nil { + return nil, err + } + + return leftNode, nil } -func (node *Node) getRightNode(t *ImmutableTree) *Node { +func (node *Node) getRightNode(t *ImmutableTree) (*Node, error) { if node.rightNode != nil { - return node.rightNode + return node.rightNode, nil } - return t.ndb.GetNode(node.rightHash) + rightNode, err := t.ndb.GetNode(node.rightHash) + if err != nil { + return nil, err + } + + return rightNode, nil } // NOTE: mutates height and size -func (node *Node) calcHeightAndSize(t *ImmutableTree) { - node.height = maxInt8(node.getLeftNode(t).height, node.getRightNode(t).height) + 1 - node.size = node.getLeftNode(t).size + node.getRightNode(t).size +func (node *Node) calcHeightAndSize(t *ImmutableTree) error { + leftNode, err := node.getLeftNode(t) + if err != nil { + return err + } + + rightNode, err := node.getRightNode(t) + if err != nil { + return err + } + + node.height = maxInt8(leftNode.height, rightNode.height) + 1 + node.size = leftNode.size + rightNode.size + return nil } -func (node *Node) calcBalance(t *ImmutableTree) int { - return int(node.getLeftNode(t).height) - int(node.getRightNode(t).height) +func (node *Node) calcBalance(t *ImmutableTree) (int, error) { + leftNode, err := node.getLeftNode(t) + if err != nil { + return 0, err + } + + rightNode, err := node.getRightNode(t) + if err != nil { + return 0, err + } + + return int(leftNode.height) - int(rightNode.height), nil } // traverse is a wrapper over traverseInRange when we want the whole tree @@ -461,7 +531,8 @@ func (node *Node) traversePost(t *ImmutableTree, ascending bool, cb func(*Node) func (node *Node) traverseInRange(tree *ImmutableTree, start, end []byte, ascending bool, inclusive bool, post bool, cb func(*Node) bool) bool { stop := false t := node.newTraversal(tree, start, end, ascending, inclusive, post) - for node2 := t.next(); node2 != nil; node2 = t.next() { + // TODO: figure out how to handle these errors + for node2, err := t.next(); node2 != nil && err == nil; node2, err = t.next() { stop = cb(node2) if stop { return stop @@ -469,3 +540,10 @@ func (node *Node) traverseInRange(tree *ImmutableTree, start, end []byte, ascend } return stop } + +var ( + ErrCloneLeafNode = fmt.Errorf("attempt to copy a leaf node") + ErrEmptyChildHash = fmt.Errorf("found an empty child hash") + ErrLeftHashIsNil = fmt.Errorf("node.leftHash was nil in writeBytes") + ErrRightHashIsNil = fmt.Errorf("node.rightHash was nil in writeBytes") +) diff --git a/nodedb.go b/nodedb.go index 06e887002..7d58ad926 100644 --- a/nodedb.go +++ b/nodedb.go @@ -115,40 +115,40 @@ func newNodeDB(db dbm.DB, cacheSize int, opts *Options) *nodeDB { // GetNode gets a node from memory or disk. If it is an inner node, it does not // load its children. -func (ndb *nodeDB) GetNode(hash []byte) *Node { +func (ndb *nodeDB) GetNode(hash []byte) (*Node, error) { ndb.mtx.Lock() defer ndb.mtx.Unlock() if len(hash) == 0 { - panic("nodeDB.GetNode() requires hash") + return nil, ErrNodeMissingHash } // Check the cache. if elem, ok := ndb.nodeCache[string(hash)]; ok { // Already exists. Move to back of nodeCacheQueue. ndb.nodeCacheQueue.MoveToBack(elem) - return elem.Value.(*Node) + return elem.Value.(*Node), nil } // Doesn't exist, load. buf, err := ndb.db.Get(ndb.nodeKey(hash)) if err != nil { - panic(fmt.Sprintf("can't get node %X: %v", hash, err)) + return nil, fmt.Errorf("can't get node %X: %v", hash, err) } if buf == nil { - panic(fmt.Sprintf("Value missing for hash %x corresponding to nodeKey %x", hash, ndb.nodeKey(hash))) + return nil, fmt.Errorf("Value missing for hash %x corresponding to nodeKey %x", hash, ndb.nodeKey(hash)) } node, err := MakeNode(buf) if err != nil { - panic(fmt.Sprintf("Error reading Node. bytes: %x, error: %v", buf, err)) + return nil, fmt.Errorf("Error reading Node. bytes: %x, error: %v", buf, err) } node.hash = hash node.persisted = true ndb.cacheNode(node) - return node + return node, nil } func (ndb *nodeDB) GetFastNode(key []byte) (*FastNode, error) { @@ -189,15 +189,15 @@ func (ndb *nodeDB) GetFastNode(key []byte) (*FastNode, error) { } // SaveNode saves a node to disk. -func (ndb *nodeDB) SaveNode(node *Node) { +func (ndb *nodeDB) SaveNode(node *Node) error { ndb.mtx.Lock() defer ndb.mtx.Unlock() if node.hash == nil { - panic("Expected to find node.hash, but none found.") + return ErrNodeMissingHash } if node.persisted { - panic("Shouldn't be calling save on an already persisted node.") + return ErrNodeAlreadyPersisted } // Save node bytes to db. @@ -205,15 +205,16 @@ func (ndb *nodeDB) SaveNode(node *Node) { buf.Grow(node.encodedSize()) if err := node.writeBytes(&buf); err != nil { - panic(err) + return err } if err := ndb.batch.Set(ndb.nodeKey(node.hash), buf.Bytes()); err != nil { - panic(err) + return err } logger.Debug("BATCH SAVE %X %p\n", node.hash, node) node.persisted = true ndb.cacheNode(node) + return nil } // SaveNode saves a FastNode to disk and add to cache. @@ -248,7 +249,12 @@ func (ndb *nodeDB) setFastStorageVersionToBatch() error { newVersion = fastStorageVersionValue } - newVersion += fastStorageVersionDelimiter + strconv.Itoa(int(ndb.getLatestVersion())) + latestVersion, err := ndb.getLatestVersion() + if err != nil { + return err + } + + newVersion += fastStorageVersionDelimiter + strconv.Itoa(int(latestVersion)) if err := ndb.batch.Set(metadataKeyFormat.Key([]byte(storageVersionKey)), []byte(newVersion)); err != nil { return err @@ -270,15 +276,20 @@ func (ndb *nodeDB) hasUpgradedToFastStorage() bool { // When the live state is not matched, we must force reupgrade. // We determine this by checking the version of the live state and the version of the live state when // latest storage was updated on disk the last time. -func (ndb *nodeDB) shouldForceFastStorageUpgrade() bool { +func (ndb *nodeDB) shouldForceFastStorageUpgrade() (bool, error) { versions := strings.Split(ndb.storageVersion, fastStorageVersionDelimiter) if len(versions) == 2 { - if versions[1] != strconv.Itoa(int(ndb.getLatestVersion())) { - return true + latestVersion, err := ndb.getLatestVersion() + if err != nil { + // TODO: should be true or false as default? (removed panic here) + return false, err + } + if versions[1] != strconv.Itoa(int(latestVersion)) { + return true, nil } } - return false + return false, nil } // SaveNode saves a FastNode to disk. @@ -349,8 +360,15 @@ func (ndb *nodeDB) SaveBranch(node *Node) ([]byte, error) { return nil, err } - node._hash() - ndb.SaveNode(node) + _, err = node._hash() + if err != nil { + return nil, err + } + + err = ndb.SaveNode(node) + if err != nil { + return nil, err + } // resetBatch only working on generate a genesis block if node.version <= genesisVersion { @@ -409,7 +427,10 @@ func (ndb *nodeDB) DeleteVersion(version int64, checkLatestVersion bool) error { // DeleteVersionsFrom permanently deletes all tree versions from the given version upwards. func (ndb *nodeDB) DeleteVersionsFrom(version int64) error { - latest := ndb.getLatestVersion() + latest, err := ndb.getLatestVersion() + if err != nil { + return err + } if latest < version { return nil } @@ -510,12 +531,18 @@ func (ndb *nodeDB) DeleteVersionsRange(fromVersion, toVersion int64) error { ndb.mtx.Lock() defer ndb.mtx.Unlock() - latest := ndb.getLatestVersion() + latest, err := ndb.getLatestVersion() + if err != nil { + return err + } if latest < toVersion { return errors.Errorf("cannot delete latest saved version (%d)", latest) } - predecessor := ndb.getPreviousVersion(fromVersion) + predecessor, err := ndb.getPreviousVersion(fromVersion) + if err != nil { + return err + } for v, r := range ndb.versionReaders { if v < toVersion && v > predecessor && r != 0 { @@ -534,11 +561,13 @@ func (ndb *nodeDB) DeleteVersionsRange(fromVersion, toVersion int64) error { } if from > predecessor { if err := ndb.batch.Delete(ndb.nodeKey(hash)); err != nil { - panic(err) + return err } ndb.uncacheNode(hash) } else { - ndb.saveOrphan(hash, from, predecessor) + if err := ndb.saveOrphan(hash, from, predecessor); err != nil { + return err + } } return nil }) @@ -548,7 +577,7 @@ func (ndb *nodeDB) DeleteVersionsRange(fromVersion, toVersion int64) error { } // Delete the version root entries - err := ndb.traverseRange(rootKeyFormat.Key(fromVersion), rootKeyFormat.Key(toVersion), func(k, v []byte) error { + err = ndb.traverseRange(rootKeyFormat.Key(fromVersion), rootKeyFormat.Key(toVersion), func(k, v []byte) error { if err := ndb.batch.Delete(k); err != nil { return err } @@ -578,7 +607,11 @@ func (ndb *nodeDB) deleteNodesFrom(version int64, hash []byte) error { return nil } - node := ndb.GetNode(hash) + node, err := ndb.GetNode(hash) + if err != nil { + return err + } + if node.leftHash != nil { if err := ndb.deleteNodesFrom(version, node.leftHash); err != nil { return err @@ -604,33 +637,45 @@ func (ndb *nodeDB) deleteNodesFrom(version int64, hash []byte) error { // Saves orphaned nodes to disk under a special prefix. // version: the new version being saved. // orphans: the orphan nodes created since version-1 -func (ndb *nodeDB) SaveOrphans(version int64, orphans map[string]int64) { +func (ndb *nodeDB) SaveOrphans(version int64, orphans map[string]int64) error { ndb.mtx.Lock() defer ndb.mtx.Unlock() - toVersion := ndb.getPreviousVersion(version) + toVersion, err := ndb.getPreviousVersion(version) + if err != nil { + return err + } + for hash, fromVersion := range orphans { logger.Debug("SAVEORPHAN %v-%v %X\n", fromVersion, toVersion, hash) - ndb.saveOrphan([]byte(hash), fromVersion, toVersion) + err := ndb.saveOrphan([]byte(hash), fromVersion, toVersion) + if err != nil { + return err + } } + return nil } // Saves a single orphan to disk. -func (ndb *nodeDB) saveOrphan(hash []byte, fromVersion, toVersion int64) { +func (ndb *nodeDB) saveOrphan(hash []byte, fromVersion, toVersion int64) error { if fromVersion > toVersion { - panic(fmt.Sprintf("Orphan expires before it comes alive. %d > %d", fromVersion, toVersion)) + return fmt.Errorf("orphan expires before it comes alive. %d > %d", fromVersion, toVersion) } key := ndb.orphanKey(fromVersion, toVersion, hash) if err := ndb.batch.Set(key, hash); err != nil { - panic(err) + return err } + return nil } // deleteOrphans deletes orphaned nodes from disk, and the associated orphan // entries. func (ndb *nodeDB) deleteOrphans(version int64) error { // Will be zero if there is no previous version. - predecessor := ndb.getPreviousVersion(version) + predecessor, err := ndb.getPreviousVersion(version) + if err != nil { + return err + } // Traverse orphans with a lifetime ending at the version specified. // TODO optimize. @@ -681,11 +726,15 @@ func (ndb *nodeDB) rootKey(version int64) []byte { return rootKeyFormat.Key(version) } -func (ndb *nodeDB) getLatestVersion() int64 { +func (ndb *nodeDB) getLatestVersion() (int64, error) { if ndb.latestVersion == 0 { - ndb.latestVersion = ndb.getPreviousVersion(1<<63 - 1) + var err error + ndb.latestVersion, err = ndb.getPreviousVersion(1<<63 - 1) + if err != nil { + return 0, err + } } - return ndb.latestVersion + return ndb.latestVersion, nil } func (ndb *nodeDB) updateLatestVersion(version int64) { @@ -698,13 +747,13 @@ func (ndb *nodeDB) resetLatestVersion(version int64) { ndb.latestVersion = version } -func (ndb *nodeDB) getPreviousVersion(version int64) int64 { +func (ndb *nodeDB) getPreviousVersion(version int64) (int64, error) { itr, err := ndb.db.ReverseIterator( rootKeyFormat.Key(1), rootKeyFormat.Key(version), ) if err != nil { - panic(err) + return 0, err } defer itr.Close() @@ -712,19 +761,24 @@ func (ndb *nodeDB) getPreviousVersion(version int64) int64 { for ; itr.Valid(); itr.Next() { k := itr.Key() rootKeyFormat.Scan(k, &pversion) - return pversion + return pversion, nil } if err := itr.Error(); err != nil { - panic(err) + return 0, err } - return 0 + return 0, nil } // deleteRoot deletes the root entry from disk, but not the node it points to. func (ndb *nodeDB) deleteRoot(version int64, checkLatestVersion bool) error { - if checkLatestVersion && version == ndb.getLatestVersion() { + latestVersion, err := ndb.getLatestVersion() + if err != nil { + return err + } + + if checkLatestVersion && version == latestVersion { return errors.New("tried to delete latest version") } if err := ndb.batch.Delete(ndb.rootKey(version)); err != nil { @@ -900,7 +954,7 @@ func (ndb *nodeDB) getRoots() (roots map[int64][]byte, err error) { // loaded later. func (ndb *nodeDB) SaveRoot(root *Node, version int64) error { if len(root.hash) == 0 { - panic("SaveRoot: root hash should not be empty") + return ErrRootMissingHash } return ndb.saveRoot(root.hash, version) } @@ -915,7 +969,10 @@ func (ndb *nodeDB) saveRoot(hash []byte, version int64) error { defer ndb.mtx.Unlock() // We allow the initial version to be arbitrary - latest := ndb.getLatestVersion() + latest, err := ndb.getLatestVersion() + if err != nil { + return err + } if latest > 0 && version != latest+1 { return fmt.Errorf("must save consecutive versions; expected %d, got %d", latest+1, version) } @@ -1093,3 +1150,9 @@ func (ndb *nodeDB) String() (string, error) { return "-" + "\n" + buf.String() + "-", nil } + +var ( + ErrNodeMissingHash = fmt.Errorf("node does not have a hash") + ErrNodeAlreadyPersisted = fmt.Errorf("shouldn't be calling save on an already persisted node") + ErrRootMissingHash = fmt.Errorf("root hash must not be empty") +) diff --git a/nodedb_test.go b/nodedb_test.go index 3db11769e..389e9d7ff 100644 --- a/nodedb_test.go +++ b/nodedb_test.go @@ -94,7 +94,10 @@ func TestSetStorageVersion_Success(t *testing.T) { err := ndb.setFastStorageVersionToBatch() require.NoError(t, err) - require.Equal(t, expectedVersion+fastStorageVersionDelimiter+strconv.Itoa(int(ndb.getLatestVersion())), ndb.getStorageVersion()) + + latestVersion, err := ndb.getLatestVersion() + require.NoError(t, err) + require.Equal(t, expectedVersion+fastStorageVersionDelimiter+strconv.Itoa(int(latestVersion)), ndb.getStorageVersion()) require.NoError(t, ndb.batch.Write()) } @@ -200,7 +203,9 @@ func TestShouldForceFastStorageUpdate_DefaultVersion_True(t *testing.T) { ndb.storageVersion = defaultStorageVersionValue ndb.latestVersion = 100 - require.False(t, ndb.shouldForceFastStorageUpgrade()) + shouldForce, err := ndb.shouldForceFastStorageUpgrade() + require.False(t, shouldForce) + require.NoError(t, err) } func TestShouldForceFastStorageUpdate_FastVersion_Greater_True(t *testing.T) { @@ -209,7 +214,9 @@ func TestShouldForceFastStorageUpdate_FastVersion_Greater_True(t *testing.T) { ndb.latestVersion = 100 ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion+1)) - require.True(t, ndb.shouldForceFastStorageUpgrade()) + shouldForce, err := ndb.shouldForceFastStorageUpgrade() + require.True(t, shouldForce) + require.NoError(t, err) } func TestShouldForceFastStorageUpdate_FastVersion_Smaller_True(t *testing.T) { @@ -218,7 +225,9 @@ func TestShouldForceFastStorageUpdate_FastVersion_Smaller_True(t *testing.T) { ndb.latestVersion = 100 ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion-1)) - require.True(t, ndb.shouldForceFastStorageUpgrade()) + shouldForce, err := ndb.shouldForceFastStorageUpgrade() + require.True(t, shouldForce) + require.NoError(t, err) } func TestShouldForceFastStorageUpdate_FastVersion_Match_False(t *testing.T) { @@ -227,7 +236,9 @@ func TestShouldForceFastStorageUpdate_FastVersion_Match_False(t *testing.T) { ndb.latestVersion = 100 ndb.storageVersion = fastStorageVersionValue + fastStorageVersionDelimiter + strconv.Itoa(int(ndb.latestVersion)) - require.False(t, ndb.shouldForceFastStorageUpgrade()) + shouldForce, err := ndb.shouldForceFastStorageUpgrade() + require.False(t, shouldForce) + require.NoError(t, err) } func TestIsFastStorageEnabled_True(t *testing.T) { @@ -245,7 +256,9 @@ func TestIsFastStorageEnabled_False(t *testing.T) { ndb.latestVersion = 100 ndb.storageVersion = defaultStorageVersionValue - require.False(t, ndb.shouldForceFastStorageUpgrade()) + shouldForce, err := ndb.shouldForceFastStorageUpgrade() + require.False(t, shouldForce) + require.NoError(t, err) } func makeHashes(b *testing.B, seed int64) [][]byte { diff --git a/proof.go b/proof.go index fdc10a06b..51a8860e2 100644 --- a/proof.go +++ b/proof.go @@ -61,7 +61,7 @@ func (pin ProofInnerNode) stringIndented(indent string) string { indent) } -func (pin ProofInnerNode) Hash(childHash []byte) []byte { +func (pin ProofInnerNode) Hash(childHash []byte) ([]byte, error) { hasher := sha256.New() buf := bufPool.Get().(*bytes.Buffer) @@ -92,14 +92,14 @@ func (pin ProofInnerNode) Hash(childHash []byte) []byte { } } if err != nil { - panic(fmt.Sprintf("Failed to hash ProofInnerNode: %v", err)) + return nil, fmt.Errorf("Failed to hash ProofInnerNode: %v", err) } _, err = hasher.Write(buf.Bytes()) if err != nil { - panic(err) + return nil, err } - return hasher.Sum(nil) + return hasher.Sum(nil), nil } // toProto converts the inner node proof to Protobuf, for use in ProofOps. @@ -154,7 +154,7 @@ func (pln ProofLeafNode) stringIndented(indent string) string { indent) } -func (pln ProofLeafNode) Hash() []byte { +func (pln ProofLeafNode) Hash() ([]byte, error) { hasher := sha256.New() buf := bufPool.Get().(*bytes.Buffer) @@ -175,15 +175,15 @@ func (pln ProofLeafNode) Hash() []byte { err = encoding.EncodeBytes(buf, pln.ValueHash) } if err != nil { - panic(fmt.Sprintf("Failed to hash ProofLeafNode: %v", err)) + return nil, fmt.Errorf("failed to hash ProofLeafNode: %v", err) } _, err = hasher.Write(buf.Bytes()) if err != nil { - panic(err) + return nil, err } - return hasher.Sum(nil) + return hasher.Sum(nil), nil } // toProto converts the leaf node proof to Protobuf, for use in ProofOps. @@ -235,26 +235,47 @@ func (node *Node) pathToLeaf(t *ImmutableTree, key []byte, path *PathToLeaf) (*N // already stored in the next ProofInnerNode in PathToLeaf. if bytes.Compare(key, node.key) < 0 { // left side + rightNode, err := node.getRightNode(t) + if err != nil { + return nil, err + } + pin := ProofInnerNode{ Height: node.height, Size: node.size, Version: node.version, Left: nil, - Right: node.getRightNode(t).hash, + Right: rightNode.hash, } *path = append(*path, pin) - n, err := node.getLeftNode(t).pathToLeaf(t, key, path) + + leftNode, err := node.getLeftNode(t) + if err != nil { + return nil, err + } + n, err := leftNode.pathToLeaf(t, key, path) return n, err } // right side + leftNode, err := node.getLeftNode(t) + if err != nil { + return nil, err + } + pin := ProofInnerNode{ Height: node.height, Size: node.size, Version: node.version, - Left: node.getLeftNode(t).hash, + Left: leftNode.hash, Right: nil, } *path = append(*path, pin) - n, err := node.getRightNode(t).pathToLeaf(t, key, path) + + rightNode, err := node.getRightNode(t) + if err != nil { + return nil, err + } + + n, err := rightNode.pathToLeaf(t, key, path) return n, err } diff --git a/proof_iavl_test.go b/proof_iavl_test.go index ef247ea44..cf3fcd877 100644 --- a/proof_iavl_test.go +++ b/proof_iavl_test.go @@ -19,7 +19,8 @@ func TestProofOp(t *testing.T) { key := []byte{ikey} tree.Set(key, key) } - root := tree.WorkingHash() + root, err := tree.WorkingHash() + require.NoError(t, err) testcases := []struct { key byte diff --git a/proof_ics23.go b/proof_ics23.go index 04c29ba6d..991359a79 100644 --- a/proof_ics23.go +++ b/proof_ics23.go @@ -30,18 +30,26 @@ If the key exists in the tree, this will return an error. */ func (t *ImmutableTree) GetNonMembershipProof(key []byte) (*ics23.CommitmentProof, error) { // idx is one node right of what we want.... - idx, val := t.GetWithIndex(key) + var err error + idx, val, err := t.GetWithIndex(key) + if err != nil { + return nil, err + } + if val != nil { return nil, fmt.Errorf("cannot create NonExistanceProof when Key in State") } - var err error nonexist := &ics23.NonExistenceProof{ Key: key, } if idx >= 1 { - leftkey, _ := t.GetByIndex(idx - 1) + leftkey, _, err := t.GetByIndex(idx - 1) + if err != nil { + return nil, err + } + nonexist.Left, err = createExistenceProof(t, leftkey) if err != nil { return nil, err @@ -49,7 +57,11 @@ func (t *ImmutableTree) GetNonMembershipProof(key []byte) (*ics23.CommitmentProo } // this will be nil if nothing right of the queried key - rightkey, _ := t.GetByIndex(idx) + rightkey, _, err := t.GetByIndex(idx) + if err != nil { + return nil, err + } + if rightkey != nil { nonexist.Right, err = createExistenceProof(t, rightkey) if err != nil { diff --git a/proof_ics23_test.go b/proof_ics23_test.go index 43b2f656c..db682a969 100644 --- a/proof_ics23_test.go +++ b/proof_ics23_test.go @@ -46,11 +46,13 @@ func TestGetMembership(t *testing.T) { require.NoError(t, err, "Creating tree: %+v", err) key := GetKey(allkeys, tc.loc) - val := tree.Get(key) + val, err := tree.Get(key) + require.NoError(t, err) proof, err := tree.GetMembershipProof(key) require.NoError(t, err, "Creating Proof: %+v", err) - root := tree.Hash() + root, err := tree.Hash() + require.NoError(t, err) valid := ics23.VerifyMembership(ics23.IavlSpec, root, proof, key, val) if !valid { require.NoError(t, err, "Membership Proof Invalid") @@ -78,7 +80,8 @@ func TestGetNonMembership(t *testing.T) { proof, err := tree.GetNonMembershipProof(key) require.NoError(t, err, "Creating Proof: %+v", err) - root := tree.Hash() + root, err := tree.Hash() + require.NoError(t, err) valid := ics23.VerifyNonMembership(ics23.IavlSpec, root, proof, key) if !valid { require.NoError(t, err, "Non Membership Proof Invalid") @@ -94,7 +97,9 @@ func TestGetNonMembership(t *testing.T) { _, _, err = tree.SaveVersion() require.NoError(t, err) - require.True(t, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) performTest(tree, allkeys, tc.loc) }) @@ -102,7 +107,9 @@ func TestGetNonMembership(t *testing.T) { t.Run("regular-"+name, func(t *testing.T) { tree, allkeys, err := BuildTree(tc.size, 0) require.NoError(t, err, "Creating tree: %+v", err) - require.False(t, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(t, err) + require.False(t, isFastCacheEnabled) performTest(tree, allkeys, tc.loc) }) @@ -129,7 +136,8 @@ func BenchmarkGetNonMembership(b *testing.B) { require.NoError(b, err, "Creating Proof: %+v", err) b.StopTimer() - root := tree.Hash() + root, err := tree.Hash() + require.NoError(b, err) valid := ics23.VerifyNonMembership(ics23.IavlSpec, root, proof, key) if !valid { require.NoError(b, err, "Non Membership Proof Invalid") @@ -149,7 +157,9 @@ func BenchmarkGetNonMembership(b *testing.B) { _, _, err = tree.SaveVersion() require.NoError(b, err) - require.True(b, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(b, err) + require.True(b, isFastCacheEnabled) b.StartTimer() performTest(tree, allkeys, tc.loc) } @@ -163,7 +173,9 @@ func BenchmarkGetNonMembership(b *testing.B) { tree, allkeys, err := BuildTree(tc.size, 100000) require.NoError(b, err, "Creating tree: %+v", err) - require.False(b, tree.IsFastCacheEnabled()) + isFastCacheEnabled, err := tree.IsFastCacheEnabled() + require.NoError(b, err) + require.False(b, isFastCacheEnabled) b.StartTimer() performTest(tree, allkeys, tc.loc) @@ -205,7 +217,10 @@ func GenerateResult(size int, loc Where) (*Result, error) { if len(proof.Leaves) != 1 { return nil, fmt.Errorf("tree.GetWithProof returned %d leaves", len(proof.Leaves)) } - root := tree.Hash() + root, err := tree.Hash() + if err != nil { + return nil, err + } res := &Result{ Key: key, diff --git a/proof_path.go b/proof_path.go index 86915a1b8..95397858a 100644 --- a/proof_path.go +++ b/proof_path.go @@ -27,8 +27,11 @@ func (pwl pathWithLeaf) StringIndented(indent string) string { // `computeRootHash` computes the root hash with leaf node. // Does not verify the root hash. -func (pwl pathWithLeaf) computeRootHash() []byte { - leafHash := pwl.Leaf.Hash() +func (pwl pathWithLeaf) computeRootHash() ([]byte, error) { + leafHash, err := pwl.Leaf.Hash() + if err != nil { + return nil, err + } return pwl.Path.computeRootHash(leafHash) } @@ -64,13 +67,17 @@ func (pl PathToLeaf) stringIndented(indent string) string { // `computeRootHash` computes the root hash assuming some leaf hash. // Does not verify the root hash. -func (pl PathToLeaf) computeRootHash(leafHash []byte) []byte { +func (pl PathToLeaf) computeRootHash(leafHash []byte) ([]byte, error) { + var err error hash := leafHash for i := len(pl) - 1; i >= 0; i-- { pin := pl[i] - hash = pin.Hash(hash) + hash, err = pin.Hash(hash) + if err != nil { + return nil, err + } } - return hash + return hash, nil } func (pl PathToLeaf) isLeftmost() bool { diff --git a/proof_range.go b/proof_range.go index 7daa8eb5b..345bbe90d 100644 --- a/proof_range.go +++ b/proof_range.go @@ -245,11 +245,15 @@ func (proof *RangeProof) _computeRootHash() (rootHash []byte, treeEnd bool, err leaves = rleaves // Compute hash. - hash = (pathWithLeaf{ + hash, err = (pathWithLeaf{ Path: path, Leaf: nleaf, }).computeRootHash() + if err != nil { + return nil, treeEnd, false, err + } + // If we don't have any leaves left, we're done. if len(leaves) == 0 { rightmost = rightmost && path.isRightmost() @@ -370,20 +374,24 @@ func RangeProofFromProto(pbProof *iavlproto.RangeProof) (RangeProof, error) { // If keyStart or keyEnd don't exist, the leaf before keyStart // or after keyEnd will also be included, but not be included in values. // If keyEnd-1 exists, no later leaves will be included. -// If keyStart >= keyEnd and both not nil, panics. +// If keyStart >= keyEnd and both not nil, errors out. // Limit is never exceeded. //nolint:unparam func (t *ImmutableTree) getRangeProof(keyStart, keyEnd []byte, limit int) (proof *RangeProof, keys, values [][]byte, err error) { if keyStart != nil && keyEnd != nil && bytes.Compare(keyStart, keyEnd) >= 0 { - panic("if keyStart and keyEnd are present, need keyStart < keyEnd.") + return nil, nil, nil, fmt.Errorf("if keyStart and keyEnd are present, need keyStart < keyEnd") } if limit < 0 { - panic("limit must be greater or equal to 0 -- 0 means no limit") + return nil, nil, nil, fmt.Errorf("limit must be greater or equal to 0 -- 0 means no limit") } if t.root == nil { return nil, nil, nil, nil } - t.root.hashWithCount() // Ensure that all hashes are calculated. + + _, _, err = t.root.hashWithCount() // Ensure that all hashes are calculated. + if err != nil { + return nil, nil, nil, err + } // Get the first key/value pair proof, which provides us with the left key. path, left, err := t.root.PathToLeaf(t, keyStart) diff --git a/proof_test.go b/proof_test.go index b05e73223..8eb568ab0 100644 --- a/proof_test.go +++ b/proof_test.go @@ -22,7 +22,8 @@ func TestTreeGetWithProof(t *testing.T) { key := []byte{ikey} tree.Set(key, []byte(iavlrand.RandStr(8))) } - root := tree.WorkingHash() + root, err := tree.WorkingHash() + require.NoError(err) key := []byte{0x32} val, proof, err := tree.GetWithProof(key) @@ -52,7 +53,8 @@ func TestTreeGetWithProof(t *testing.T) { func TestTreeKeyExistsProof(t *testing.T) { tree, err := getTestTree(0) require.NoError(t, err) - root := tree.WorkingHash() + root, err := tree.WorkingHash() + require.NoError(t, err) // should get false for proof with nil root proof, keys, values, err := tree.getRangeProof([]byte("foo"), nil, 1) @@ -71,7 +73,8 @@ func TestTreeKeyExistsProof(t *testing.T) { allkeys[i] = []byte(key) } sortByteSlices(allkeys) // Sort all keys - root = tree.WorkingHash() + root, err = tree.WorkingHash() + require.NoError(t, err) // query random key fails proof, _, _, err = tree.getRangeProof([]byte("foo"), nil, 2) @@ -125,7 +128,8 @@ func TestTreeKeyInRangeProofs(t *testing.T) { key := []byte{ikey} tree.Set(key, key) } - root := tree.WorkingHash() + root, err := tree.WorkingHash() + require.NoError(err) // For spacing: T := 10 @@ -139,7 +143,7 @@ func TestTreeKeyInRangeProofs(t *testing.T) { pkeys []byte // proof keys, one byte per key. vals []byte // keys and values, one byte per key. lidx int64 // proof left index (index of first proof key). - pnc bool // does panic + err bool // does error }{ {start: 0x0a, end: 0xf7, pkeys: keys[0:T], vals: keys[0:9], lidx: 0}, // #0 {start: 0x0a, end: 0xf8, pkeys: keys[0:T], vals: keys[0:T], lidx: 0}, // #1 @@ -159,11 +163,11 @@ func TestTreeKeyInRangeProofs(t *testing.T) { {start: 0xf8, end: 0xff, pkeys: keys[9:T], vals: nil______, lidx: 9}, // #15 {start: 0x12, end: 0x20, pkeys: keys[1:3], vals: nil______, lidx: 1}, // #16 {start: 0x00, end: 0x09, pkeys: keys[0:1], vals: nil______, lidx: 0}, // #17 - {start: 0xf7, end: 0x00, pnc: true}, // #18 - {start: 0xf8, end: 0x00, pnc: true}, // #19 - {start: 0x10, end: 0x10, pnc: true}, // #20 - {start: 0x12, end: 0x12, pnc: true}, // #21 - {start: 0xff, end: 0xf7, pnc: true}, // #22 + {start: 0xf7, end: 0x00, err: true}, // #18 + {start: 0xf8, end: 0x00, err: true}, // #19 + {start: 0x10, end: 0x10, err: true}, // #20 + {start: 0x12, end: 0x12, err: true}, // #21 + {start: 0xff, end: 0xf7, err: true}, // #22 } // fmt.Println("PRINT TREE") @@ -175,35 +179,36 @@ func TestTreeKeyInRangeProofs(t *testing.T) { start := []byte{c.start} end := []byte{c.end} - if c.pnc { - require.Panics(func() { tree.GetRangeWithProof(start, end, 0) }) - continue - } - // Compute range proof. keys, values, proof, err := tree.GetRangeWithProof(start, end, 0) - require.NoError(err, "%+v", err) - require.Equal(c.pkeys, flatten(proof.Keys())) - require.Equal(c.vals, flatten(keys)) - require.Equal(c.vals, flatten(values)) - require.Equal(c.lidx, proof.LeftIndex()) - - // Verify that proof is valid. - err = proof.Verify(root) - require.NoError(err, "%+v", err) - verifyProof(t, proof, root) - - // Verify each value of pkeys. - for _, key := range c.pkeys { - err := proof.VerifyItem([]byte{key}, []byte{key}) - require.NoError(err) - } - // Verify each value of vals. - for _, key := range c.vals { - err := proof.VerifyItem([]byte{key}, []byte{key}) - require.NoError(err) + if c.err { + require.Error(err, "%+v", err) + } else { + require.NoError(err, "%+v", err) + require.Equal(c.pkeys, flatten(proof.Keys())) + require.Equal(c.vals, flatten(keys)) + require.Equal(c.vals, flatten(values)) + require.Equal(c.lidx, proof.LeftIndex()) + + // Verify that proof is valid. + err = proof.Verify(root) + require.NoError(err, "%+v", err) + verifyProof(t, proof, root) + + // Verify each value of pkeys. + for _, key := range c.pkeys { + err := proof.VerifyItem([]byte{key}, []byte{key}) + require.NoError(err) + } + + // Verify each value of vals. + for _, key := range c.vals { + err := proof.VerifyItem([]byte{key}, []byte{key}) + require.NoError(err) + } } + } } diff --git a/repair.go b/repair.go index e688b9cda..952e9ff9b 100644 --- a/repair.go +++ b/repair.go @@ -30,15 +30,15 @@ import ( // the case. func Repair013Orphans(db dbm.DB) (uint64, error) { ndb := newNodeDB(db, 0, &Options{Sync: true}) - version := ndb.getLatestVersion() + version, err := ndb.getLatestVersion() + if err != nil { + return 0, err + } if version == 0 { return 0, errors.New("no versions found") } - var ( - repaired uint64 - err error - ) + var repaired uint64 batch := db.NewBatch() defer batch.Close() err = ndb.traverseRange(orphanKeyFormat.Key(version), orphanKeyFormat.Key(int64(math.MaxInt64)), func(k, v []byte) error { diff --git a/repair_test.go b/repair_test.go index 560ea51c3..c60923d65 100644 --- a/repair_test.go +++ b/repair_test.go @@ -59,7 +59,8 @@ func TestRepair013Orphans(t *testing.T) { require.NoError(t, err) // Reading "rm7" (which should not have been deleted now) would panic with a broken database. - value := tree.Get([]byte("rm7")) + value, err := tree.Get([]byte("rm7")) + require.NoError(t, err) require.Equal(t, []byte{1}, value) // Check all persisted versions. @@ -91,7 +92,8 @@ func assertVersion(t *testing.T, tree *MutableTree, version int64) { version = itree.version // The "current" value should have the current version for <= 6, then 6 afterwards - value := itree.Get([]byte("current")) + value, err := itree.Get([]byte("current")) + require.NoError(t, err) if version >= 6 { require.EqualValues(t, []byte{6}, value) } else { @@ -101,14 +103,16 @@ func assertVersion(t *testing.T, tree *MutableTree, version int64) { // The "addX" entries should exist for 1-6 in the respective versions, and the // "rmX" entries should have been removed for 1-6 in the respective versions. for i := byte(1); i < 8; i++ { - value = itree.Get([]byte(fmt.Sprintf("add%v", i))) + value, err = itree.Get([]byte(fmt.Sprintf("add%v", i))) + require.NoError(t, err) if i <= 6 && int64(i) <= version { require.Equal(t, []byte{i}, value) } else { require.Nil(t, value) } - value = itree.Get([]byte(fmt.Sprintf("rm%v", i))) + value, err = itree.Get([]byte(fmt.Sprintf("rm%v", i))) + require.NoError(t, err) if i <= 6 && version >= int64(i) { require.Nil(t, value) } else { diff --git a/testutils_test.go b/testutils_test.go index 540561ac7..d45890d82 100644 --- a/testutils_test.go +++ b/testutils_test.go @@ -71,12 +71,15 @@ func N(l, r interface{}) *Node { } // Setup a deep node -func T(n *Node) *MutableTree { +func T(n *Node) (*MutableTree, error) { t, _ := getTestTree(0) - n.hashWithCount() + _, _, err := n.hashWithCount() + if err != nil { + return nil, err + } t.root = n - return t + return t, nil } // Convenience for simple printing of keys & tree structure @@ -189,7 +192,8 @@ func randomizeTreeAndMirror(t *testing.T, tree *MutableTree, mirror map[string]s key := randBytes(keyValLength) value := randBytes(keyValLength) - isUpdated := tree.Set(key, value) + isUpdated, err := tree.Set(key, value) + require.NoError(t, err) require.False(t, isUpdated) mirror[string(key)] = string(value) @@ -210,7 +214,8 @@ func randomizeTreeAndMirror(t *testing.T, tree *MutableTree, mirror map[string]s key := randBytes(keyValLength) value := randBytes(keyValLength) - isUpdated := tree.Set(key, value) + isUpdated, err := tree.Set(key, value) + require.NoError(t, err) require.False(t, isUpdated) mirror[string(key)] = string(value) case 1: @@ -223,7 +228,8 @@ func randomizeTreeAndMirror(t *testing.T, tree *MutableTree, mirror map[string]s key := getRandomKeyFrom(mirror) value := randBytes(keyValLength) - isUpdated := tree.Set([]byte(key), value) + isUpdated, err := tree.Set([]byte(key), value) + require.NoError(t, err) require.True(t, isUpdated) mirror[key] = string(value) case 2: @@ -234,7 +240,8 @@ func randomizeTreeAndMirror(t *testing.T, tree *MutableTree, mirror map[string]s key := getRandomKeyFrom(mirror) - val, isRemoved := tree.Remove([]byte(key)) + val, isRemoved, err := tree.Remove([]byte(key)) + require.NoError(t, err) require.True(t, isRemoved) require.NotNil(t, val) delete(mirror, key) @@ -269,7 +276,8 @@ func setupMirrorForIterator(t *testing.T, config *iteratorTestConfig, tree *Muta mirror = append(mirror, []string{string(curByte), string(value)}) } - isUpdated := tree.Set([]byte{curByte}, value) + isUpdated, err := tree.Set([]byte{curByte}, value) + require.NoError(t, err) require.False(t, isUpdated) if config.ascending { @@ -350,5 +358,9 @@ func (node *Node) lmd(t *ImmutableTree) *Node { if node.isLeaf() { return node } - return node.getLeftNode(t).lmd(t) + + // TODO: Should handle this error? + leftNode, _ := node.getLeftNode(t) + + return leftNode.lmd(t) } diff --git a/tree_dotgraph.go b/tree_dotgraph.go index 83acd46b3..38df20539 100644 --- a/tree_dotgraph.go +++ b/tree_dotgraph.go @@ -44,6 +44,7 @@ var defaultGraphNodeAttrs = map[string]string{ func WriteDOTGraph(w io.Writer, tree *ImmutableTree, paths []PathToLeaf) { ctx := &graphContext{} + // TODO: handle error tree.root.hashWithCount() tree.root.traverse(tree, true, func(node *Node) bool { graphNode := &graphNode{ diff --git a/tree_random_test.go b/tree_random_test.go index 422a3e3ed..7fee3b173 100644 --- a/tree_random_test.go +++ b/tree_random_test.go @@ -122,24 +122,27 @@ func testRandomOperations(t *testing.T, randSeed int64) { index := r.Intn(len(mirrorKeys)) key := mirrorKeys[index] mirrorKeys = append(mirrorKeys[:index], mirrorKeys[index+1:]...) - _, removed := tree.Remove([]byte(key)) + _, removed, err := tree.Remove([]byte(key)) + require.NoError(t, err) require.True(t, removed) delete(mirror, key) case len(mirror) > 0 && r.Float64() < updateRatio: key := mirrorKeys[r.Intn(len(mirrorKeys))] value := randString(valueSize) - updated := tree.Set([]byte(key), []byte(value)) + updated, err := tree.Set([]byte(key), []byte(value)) + require.NoError(t, err) require.True(t, updated) mirror[key] = value default: key := randString(keySize) value := randString(valueSize) - for tree.Has([]byte(key)) { + for has, err := tree.Has([]byte(key)); has && err == nil; { key = randString(keySize) } - updated := tree.Set([]byte(key), []byte(value)) + updated, err := tree.Set([]byte(key), []byte(value)) + require.NoError(t, err) require.False(t, updated) mirror[key] = value mirrorKeys = append(mirrorKeys, key) @@ -316,7 +319,8 @@ func testRandomOperations(t *testing.T, randSeed int64) { return false }) for _, key := range keys { - _, removed := tree.Remove(key) + _, removed, err := tree.Remove(key) + require.NoError(t, err) require.True(t, removed) } _, _, err = tree.SaveVersion() @@ -353,7 +357,9 @@ func assertEmptyDatabase(t *testing.T, tree *MutableTree) { storageVersionValue, err := tree.ndb.db.Get([]byte(firstKey)) require.NoError(t, err) - require.Equal(t, fastStorageVersionValue+fastStorageVersionDelimiter+strconv.Itoa(int(tree.ndb.getLatestVersion())), string(storageVersionValue)) + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + require.Equal(t, fastStorageVersionValue+fastStorageVersionDelimiter+strconv.Itoa(int(latestVersion)), string(storageVersionValue)) var foundVersion int64 rootKeyFormat.Scan([]byte(secondKey), &foundVersion) @@ -401,9 +407,11 @@ func assertMirror(t *testing.T, tree *MutableTree, mirror map[string]string, ver require.EqualValues(t, len(mirror), itree.Size()) require.EqualValues(t, len(mirror), iterated) for key, value := range mirror { - actualFast := itree.Get([]byte(key)) + actualFast, err := itree.Get([]byte(key)) + require.NoError(t, err) require.Equal(t, value, string(actualFast)) - _, actual := itree.GetWithIndex([]byte(key)) + _, actual, err := itree.GetWithIndex([]byte(key)) + require.NoError(t, err) require.Equal(t, value, string(actual)) } @@ -413,7 +421,9 @@ func assertMirror(t *testing.T, tree *MutableTree, mirror map[string]string, ver // Checks that fast node cache matches live state. func assertFastNodeCacheIsLive(t *testing.T, tree *MutableTree, mirror map[string]string, version int64) { - if tree.ndb.getLatestVersion() != version { + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + if latestVersion != version { // The fast node cache check should only be done to the latest version return } @@ -428,13 +438,15 @@ func assertFastNodeCacheIsLive(t *testing.T, tree *MutableTree, mirror map[strin // Checks that fast nodes on disk match live state. func assertFastNodeDiskIsLive(t *testing.T, tree *MutableTree, mirror map[string]string, version int64) { - if tree.ndb.getLatestVersion() != version { + latestVersion, err := tree.ndb.getLatestVersion() + require.NoError(t, err) + if latestVersion != version { // The fast node disk check should only be done to the latest version return } count := 0 - err := tree.ndb.traverseFastNodes(func(keyWithPrefix, v []byte) error { + err = tree.ndb.traverseFastNodes(func(keyWithPrefix, v []byte) error { key := keyWithPrefix[1:] count++ fastNode, err := DeserializeFastNode(key, v) diff --git a/tree_test.go b/tree_test.go index f2f9b6c59..e335dbfd7 100644 --- a/tree_test.go +++ b/tree_test.go @@ -148,20 +148,22 @@ func TestTreeHash(t *testing.T) { index := r.Intn(len(keys)) key = keys[index] keys = append(keys[:index], keys[index+1:]...) - _, removed := tree.Remove(key) + _, removed, err := tree.Remove(key) + require.NoError(t, err) require.True(t, removed) case len(keys) > 0 && r.Float64() <= updateRatio: key = keys[r.Intn(len(keys))] r.Read(value) - updated := tree.Set(key, value) + updated, err := tree.Set(key, value) + require.NoError(t, err) require.True(t, updated) default: r.Read(key) r.Read(value) // If we get an update, set again - for tree.Set(key, value) { + for updated, err := tree.Set(key, value); err == nil && updated; { key = make([]byte, keySize) r.Read(key) } @@ -220,7 +222,8 @@ func TestVersionedRandomTreeSmallKeys(t *testing.T) { // Try getting random keys. for i := 0; i < keysPerVersion; i++ { - val := tree.Get([]byte(iavlrand.RandStr(1))) + val, err := tree.Get([]byte(iavlrand.RandStr(1))) + require.NoError(err) require.NotNil(val) require.NotEmpty(val) } @@ -269,7 +272,8 @@ func TestVersionedRandomTreeSmallKeysRandomDeletes(t *testing.T) { // Try getting random keys. for i := 0; i < keysPerVersion; i++ { - val := tree.Get([]byte(iavlrand.RandStr(1))) + val, err := tree.Get([]byte(iavlrand.RandStr(1))) + require.NoError(err) require.NotNil(val) require.NotEmpty(val) } @@ -504,42 +508,54 @@ func TestVersionedTree(t *testing.T) { tree.Set([]byte("key1"), []byte("val0")) // "key2" - val := tree.GetVersioned([]byte("key2"), 0) + val, err := tree.GetVersioned([]byte("key2"), 0) + require.NoError(err) require.Nil(val) - val = tree.GetVersioned([]byte("key2"), 1) + val, err = tree.GetVersioned([]byte("key2"), 1) + require.NoError(err) require.Equal("val0", string(val)) - val = tree.GetVersioned([]byte("key2"), 2) + val, err = tree.GetVersioned([]byte("key2"), 2) + require.NoError(err) require.Equal("val1", string(val)) - val = tree.Get([]byte("key2")) + val, err = tree.Get([]byte("key2")) + require.NoError(err) require.Equal("val2", string(val)) // "key1" - val = tree.GetVersioned([]byte("key1"), 1) + val, err = tree.GetVersioned([]byte("key1"), 1) + require.NoError(err) require.Equal("val0", string(val)) - val = tree.GetVersioned([]byte("key1"), 2) + val, err = tree.GetVersioned([]byte("key1"), 2) + require.NoError(err) require.Equal("val1", string(val)) - val = tree.GetVersioned([]byte("key1"), 3) + val, err = tree.GetVersioned([]byte("key1"), 3) + require.NoError(err) require.Nil(val) - val = tree.GetVersioned([]byte("key1"), 4) + val, err = tree.GetVersioned([]byte("key1"), 4) + require.NoError(err) require.Nil(val) - val = tree.Get([]byte("key1")) + val, err = tree.Get([]byte("key1")) + require.NoError(err) require.Equal("val0", string(val)) // "key3" - val = tree.GetVersioned([]byte("key3"), 0) + val, err = tree.GetVersioned([]byte("key3"), 0) + require.NoError(err) require.Nil(val) - val = tree.GetVersioned([]byte("key3"), 2) + val, err = tree.GetVersioned([]byte("key3"), 2) + require.NoError(err) require.Equal("val1", string(val)) - val = tree.GetVersioned([]byte("key3"), 3) + val, err = tree.GetVersioned([]byte("key3"), 3) + require.NoError(err) require.Equal("val1", string(val)) // Delete a version. After this the keys in that version should not be found. @@ -560,26 +576,30 @@ func TestVersionedTree(t *testing.T) { require.True(len(nodes5) < len(nodes4), "db should have shrunk after delete %d !< %d", len(nodes5), len(nodes4)) - val = tree.GetVersioned([]byte("key2"), 2) + val, err = tree.GetVersioned([]byte("key2"), 2) require.Nil(val) - val = tree.GetVersioned([]byte("key3"), 2) + val, err = tree.GetVersioned([]byte("key3"), 2) require.Nil(val) // But they should still exist in the latest version. - val = tree.Get([]byte("key2")) + val, err = tree.Get([]byte("key2")) + require.NoError(err) require.Equal("val2", string(val)) - val = tree.Get([]byte("key3")) + val, err = tree.Get([]byte("key3")) + require.NoError(err) require.Equal("val1", string(val)) // Version 1 should still be available. - val = tree.GetVersioned([]byte("key1"), 1) + val, err = tree.GetVersioned([]byte("key1"), 1) + require.NoError(err) require.Equal("val0", string(val)) - val = tree.GetVersioned([]byte("key2"), 1) + val, err = tree.GetVersioned([]byte("key2"), 1) + require.NoError(err) require.Equal("val0", string(val)) } @@ -660,16 +680,20 @@ func TestVersionedTreeOrphanDeleting(t *testing.T) { tree.DeleteVersion(2) - val := tree.Get([]byte("key0")) + val, err := tree.Get([]byte("key0")) + require.NoError(t, err) require.Equal(t, val, []byte("val2")) - val = tree.Get([]byte("key1")) + val, err = tree.Get([]byte("key1")) + require.NoError(t, err) require.Nil(t, val) - val = tree.Get([]byte("key2")) + val, err = tree.Get([]byte("key2")) + require.NoError(t, err) require.Equal(t, val, []byte("val2")) - val = tree.Get([]byte("key3")) + val, err = tree.Get([]byte("key3")) + require.NoError(t, err) require.Equal(t, val, []byte("val1")) tree.DeleteVersion(1) @@ -700,7 +724,8 @@ func TestVersionedTreeSpecialCase(t *testing.T) { tree.DeleteVersion(2) - val := tree.GetVersioned([]byte("key2"), 1) + val, err := tree.GetVersioned([]byte("key2"), 1) + require.NoError(err) require.Equal("val0", string(val)) } @@ -729,7 +754,8 @@ func TestVersionedTreeSpecialCase2(t *testing.T) { require.NoError(tree.DeleteVersion(2)) - val := tree.GetVersioned([]byte("key2"), 1) + val, err := tree.GetVersioned([]byte("key2"), 1) + require.NoError(err) require.Equal("val0", string(val)) } @@ -786,7 +812,8 @@ func TestVersionedTreeSaveAndLoad(t *testing.T) { tree.SaveVersion() tree.SaveVersion() - preHash := tree.Hash() + preHash, err := tree.Hash() + require.NoError(err) require.NotNil(preHash) require.Equal(int64(6), tree.Version()) @@ -799,7 +826,8 @@ func TestVersionedTreeSaveAndLoad(t *testing.T) { require.False(ntree.IsEmpty()) require.Equal(int64(6), ntree.Version()) - postHash := ntree.Hash() + postHash, err := ntree.Hash() + require.NoError(err) require.Equal(preHash, postHash) ntree.Set([]byte("T"), []byte("MhkWjkVy")) @@ -838,7 +866,8 @@ func TestVersionedTreeErrors(t *testing.T) { require.Error(tree.DeleteVersion(1)) // Trying to get a key from a version which doesn't exist. - val := tree.GetVersioned([]byte("key"), 404) + val, err := tree.GetVersioned([]byte("key"), 404) + require.NoError(err) require.Nil(val) // Same thing with proof. We get an error because a proof couldn't be @@ -882,7 +911,8 @@ func TestVersionedCheckpoints(t *testing.T) { // Make sure all keys exist at least once. for _, ks := range keys { for _, k := range ks { - val := tree.Get(k) + val, err := tree.Get(k) + require.NoError(err) require.NotEmpty(val) } } @@ -891,7 +921,8 @@ func TestVersionedCheckpoints(t *testing.T) { for i := 1; i <= versions; i++ { if i%versionsPerCheckpoint != 0 { for _, k := range keys[int64(i)] { - val := tree.GetVersioned(k, int64(i)) + val, err := tree.GetVersioned(k, int64(i)) + require.NoError(err) require.Nil(val) } } @@ -901,7 +932,8 @@ func TestVersionedCheckpoints(t *testing.T) { for i := 1; i <= versions; i++ { for _, k := range keys[int64(i)] { if i%versionsPerCheckpoint == 0 { - val := tree.GetVersioned(k, int64(i)) + val, err := tree.GetVersioned(k, int64(i)) + require.NoError(err) require.NotEmpty(val) } } @@ -930,7 +962,7 @@ func TestVersionedCheckpointsSpecialCase(t *testing.T) { // checkpoint, which is version 10. tree.DeleteVersion(1) - val := tree.GetVersioned(key, 2) + val, err := tree.GetVersioned(key, 2) require.NotEmpty(val) require.Equal([]byte("val1"), val) } @@ -994,19 +1026,19 @@ func TestVersionedCheckpointsSpecialCase4(t *testing.T) { tree.Set([]byte("X"), []byte("New")) tree.SaveVersion() - val := tree.GetVersioned([]byte("A"), 2) + val, err := tree.GetVersioned([]byte("A"), 2) require.Nil(t, val) - val = tree.GetVersioned([]byte("A"), 1) + val, err = tree.GetVersioned([]byte("A"), 1) require.NotEmpty(t, val) tree.DeleteVersion(1) tree.DeleteVersion(2) - val = tree.GetVersioned([]byte("A"), 2) + val, err = tree.GetVersioned([]byte("A"), 2) require.Nil(t, val) - val = tree.GetVersioned([]byte("A"), 1) + val, err = tree.GetVersioned([]byte("A"), 1) require.Nil(t, val) } @@ -1158,7 +1190,8 @@ func TestVersionedTreeProofs(t *testing.T) { // printNode(tree.ndb, tree.root, 0) // fmt.Println("TREE VERSION 1 END") - root1 := tree.Hash() + root1, err := tree.Hash() + require.NoError(err) tree.Set([]byte("k2"), []byte("v2")) tree.Set([]byte("k4"), []byte("v2")) @@ -1169,18 +1202,16 @@ func TestVersionedTreeProofs(t *testing.T) { // printNode(tree.ndb, tree.root, 0) // fmt.Println("TREE VERSION END") - root2 := tree.Hash() + root2, err := tree.Hash() + require.NoError(err) require.NotEqual(root1, root2) tree.Remove([]byte("k2")) _, _, err = tree.SaveVersion() require.NoError(err) - // fmt.Println("TREE VERSION 3") - // printNode(tree.ndb, tree.root, 0) - // fmt.Println("TREE VERSION END") - - root3 := tree.Hash() + root3, err := tree.Hash() + require.NoError(err) require.NotEqual(root2, root3) val, proof, err := tree.GetVersionedWithProof([]byte("k2"), 1) @@ -1257,15 +1288,21 @@ func TestVersionedTreeHash(t *testing.T) { tree, err := getTestTree(0) require.NoError(err) - require.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(tree.Hash())) + hash, err := tree.Hash() + require.NoError(err) + require.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(hash)) tree.Set([]byte("I"), []byte("D")) - require.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(tree.Hash())) + hash, err = tree.Hash() + require.NoError(err) + require.Equal("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", hex.EncodeToString(hash)) hash1, _, err := tree.SaveVersion() require.NoError(err) tree.Set([]byte("I"), []byte("F")) - require.EqualValues(hash1, tree.Hash()) + hash, err = tree.Hash() + require.NoError(err) + require.EqualValues(hash1, hash) hash2, _, err := tree.SaveVersion() require.NoError(err) @@ -1282,9 +1319,8 @@ func TestNilValueSemantics(t *testing.T) { tree, err := getTestTree(0) require.NoError(err) - require.Panics(func() { - tree.Set([]byte("k"), nil) - }) + _, err = tree.Set([]byte("k"), nil) + require.Error(err) } func TestCopyValueSemantics(t *testing.T) { @@ -1296,12 +1332,13 @@ func TestCopyValueSemantics(t *testing.T) { val := []byte("v1") tree.Set([]byte("k"), val) - v := tree.Get([]byte("k")) + v, err := tree.Get([]byte("k")) + require.NoError(err) require.Equal([]byte("v1"), v) val[1] = '2' - val = tree.Get([]byte("k")) + val, err = tree.Get([]byte("k")) require.Equal([]byte("v2"), val) } @@ -1325,13 +1362,13 @@ func TestRollback(t *testing.T) { require.Equal(int64(2), tree.Size()) - val := tree.Get([]byte("r")) + val, err := tree.Get([]byte("r")) require.Nil(val) - val = tree.Get([]byte("s")) + val, err = tree.Get([]byte("s")) require.Nil(val) - val = tree.Get([]byte("t")) + val, err = tree.Get([]byte("t")) require.Equal([]byte("v"), val) } @@ -1356,7 +1393,8 @@ func TestLazyLoadVersion(t *testing.T) { require.NoError(t, err, "unexpected error when lazy loading version") require.Equal(t, version, int64(maxVersions)) - value := tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions))) + value, err := tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions))) + require.NoError(t, err) require.Equal(t, value, []byte(fmt.Sprintf("value_%d", maxVersions)), "unexpected value") // require the ability to lazy load an older version @@ -1364,7 +1402,8 @@ func TestLazyLoadVersion(t *testing.T) { require.NoError(t, err, "unexpected error when lazy loading version") require.Equal(t, version, int64(maxVersions-1)) - value = tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions-1))) + value, err = tree.Get([]byte(fmt.Sprintf("key_%d", maxVersions-1))) + require.NoError(t, err) require.Equal(t, value, []byte(fmt.Sprintf("value_%d", maxVersions-1)), "unexpected value") // require the inability to lazy load a non-valid version @@ -1690,7 +1729,8 @@ func TestLoadVersionForOverwritingCase2(t *testing.T) { require.NoError(err, "LoadVersionForOverwriting should not fail") for i := byte(0); i < 20; i++ { - v := tree.Get([]byte{i}) + v, err := tree.Get([]byte{i}) + require.NoError(err) require.Equal([]byte{i}, v) } @@ -1756,7 +1796,8 @@ func TestLoadVersionForOverwritingCase3(t *testing.T) { } for i := byte(0); i < 20; i++ { - v := tree.Get([]byte{i}) + v, err := tree.Get([]byte{i}) + require.NoError(err) require.Equal([]byte{i}, v) } } @@ -1800,12 +1841,15 @@ func TestGetByIndex_ImmutableTree(t *testing.T) { immutableTree, err := tree.GetImmutable(1) require.NoError(t, err) - require.True(t, immutableTree.IsFastCacheEnabled()) + isFastCacheEnabled, err := immutableTree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) for index, expectedKey := range mirrorKeys { expectedValue := mirror[expectedKey] - actualKey, actualValue := immutableTree.GetByIndex(int64(index)) + actualKey, actualValue, err := immutableTree.GetByIndex(int64(index)) + require.NoError(t, err) require.Equal(t, expectedKey, string(actualKey)) require.Equal(t, expectedValue, string(actualValue)) @@ -1822,12 +1866,15 @@ func TestGetWithIndex_ImmutableTree(t *testing.T) { immutableTree, err := tree.GetImmutable(1) require.NoError(t, err) - require.True(t, immutableTree.IsFastCacheEnabled()) + isFastCacheEnabled, err := immutableTree.IsFastCacheEnabled() + require.NoError(t, err) + require.True(t, isFastCacheEnabled) for expectedIndex, key := range mirrorKeys { expectedValue := mirror[key] - actualIndex, actualValue := immutableTree.GetWithIndex([]byte(key)) + actualIndex, actualValue, err := immutableTree.GetWithIndex([]byte(key)) + require.NoError(t, err) require.Equal(t, expectedValue, string(actualValue)) require.Equal(t, int64(expectedIndex), actualIndex) @@ -1857,7 +1904,9 @@ func Benchmark_GetWithIndex(b *testing.B) { runtime.GC() b.Run("fast", func(sub *testing.B) { - require.True(b, t.IsFastCacheEnabled()) + isFastCacheEnabled, err := t.IsFastCacheEnabled() + require.NoError(b, err) + require.True(b, isFastCacheEnabled) b.ResetTimer() for i := 0; i < sub.N; i++ { randKey := rand.Intn(numKeyVals) @@ -1873,7 +1922,9 @@ func Benchmark_GetWithIndex(b *testing.B) { itree, err := t.GetImmutable(latestVersion - 1) require.NoError(b, err) - require.False(b, itree.IsFastCacheEnabled()) + isFastCacheEnabled, err := itree.IsFastCacheEnabled() + require.NoError(b, err) + require.False(b, isFastCacheEnabled) b.ResetTimer() for i := 0; i < sub.N; i++ { randKey := rand.Intn(numKeyVals) @@ -1902,7 +1953,9 @@ func Benchmark_GetByIndex(b *testing.B) { runtime.GC() b.Run("fast", func(sub *testing.B) { - require.True(b, t.IsFastCacheEnabled()) + isFastCacheEnabled, err := t.IsFastCacheEnabled() + require.NoError(b, err) + require.True(b, isFastCacheEnabled) b.ResetTimer() for i := 0; i < sub.N; i++ { randIdx := rand.Intn(numKeyVals) @@ -1918,7 +1971,10 @@ func Benchmark_GetByIndex(b *testing.B) { itree, err := t.GetImmutable(latestVersion - 1) require.NoError(b, err) - require.False(b, itree.IsFastCacheEnabled()) + isFastCacheEnabled, err := itree.IsFastCacheEnabled() + require.NoError(b, err) + require.False(b, isFastCacheEnabled) + b.ResetTimer() for i := 0; i < sub.N; i++ { randIdx := rand.Intn(numKeyVals) diff --git a/util.go b/util.go index 9b451f9aa..676da1313 100644 --- a/util.go +++ b/util.go @@ -12,7 +12,7 @@ func PrintTree(tree *ImmutableTree) { printNode(ndb, root, 0) } -func printNode(ndb *nodeDB, node *Node, indent int) { +func printNode(ndb *nodeDB, node *Node, indent int) error { indentPrefix := "" for i := 0; i < indent; i++ { indentPrefix += " " @@ -20,16 +20,23 @@ func printNode(ndb *nodeDB, node *Node, indent int) { if node == nil { fmt.Printf("%s\n", indentPrefix) - return + return nil } if node.rightNode != nil { printNode(ndb, node.rightNode, indent+1) } else if node.rightHash != nil { - rightNode := ndb.GetNode(node.rightHash) + rightNode, err := ndb.GetNode(node.rightHash) + if err != nil { + return err + } printNode(ndb, rightNode, indent+1) } - hash := node._hash() + hash, err := node._hash() + if err != nil { + return err + } + fmt.Printf("%sh:%X\n", indentPrefix, hash) if node.isLeaf() { fmt.Printf("%s%X:%X (%v)\n", indentPrefix, node.key, node.value, node.height) @@ -38,10 +45,13 @@ func printNode(ndb *nodeDB, node *Node, indent int) { if node.leftNode != nil { printNode(ndb, node.leftNode, indent+1) } else if node.leftHash != nil { - leftNode := ndb.GetNode(node.leftHash) + leftNode, err := ndb.GetNode(node.leftHash) + if err != nil { + return err + } printNode(ndb, leftNode, indent+1) } - + return nil } func maxInt8(a, b int8) int8 {