Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Implement removal of orphan nodes #37

Merged
4 changes: 2 additions & 2 deletions bulk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestSparseMerkleTree(t *testing.T) {

// Test all tree operations in bulk, with specified ratio probabilities of insert, update and delete.
func bulkOperations(t *testing.T, operations int, insert int, update int, delete int) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())

max := insert + update + delete
kv := make(map[string]string)
Expand Down
22 changes: 12 additions & 10 deletions deepsubtree.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smt

import (
"bytes"
"errors"
"hash"
)
Expand All @@ -14,15 +15,10 @@ type DeepSparseMerkleSubTree struct {
}

// NewDeepSparseMerkleSubTree creates a new deep Sparse Merkle subtree on an empty MapStore.
func NewDeepSparseMerkleSubTree(ms MapStore, hasher hash.Hash, root []byte) *DeepSparseMerkleSubTree {
smt := &SparseMerkleTree{
th: *newTreeHasher(hasher),
ms: ms,
func NewDeepSparseMerkleSubTree(nodes, values MapStore, hasher hash.Hash, root []byte) *DeepSparseMerkleSubTree {
return &DeepSparseMerkleSubTree{
SparseMerkleTree: ImportSparseMerkleTree(nodes, values, hasher, root),
}

smt.SetRoot(root)

return &DeepSparseMerkleSubTree{SparseMerkleTree: smt}
}

// AddBranch adds a branch to the tree.
Expand All @@ -37,9 +33,15 @@ func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []b
return ErrBadProof
}

if !bytes.Equal(value, defaultValue) { // Membership proof.
if err := dsmst.values.Set(dsmst.th.path(key), value); err != nil {
return err
}
}

// Update nodes along branch
for _, update := range updates {
err := dsmst.ms.Set(update[0], update[1])
err := dsmst.nodes.Set(update[0], update[1])
if err != nil {
return err
}
Expand All @@ -48,7 +50,7 @@ func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []b
// Update sibling node
if proof.SiblingData != nil {
if proof.SideNodes != nil && len(proof.SideNodes) > 0 {
err := dsmst.ms.Set(proof.SideNodes[0], proof.SiblingData)
err := dsmst.nodes.Set(proof.SideNodes[0], proof.SiblingData)
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions deepsubtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@ import (
)

func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), sha256.New())
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
smt.Update([]byte("testKey3"), []byte("testValue3"))
smt.Update([]byte("testKey4"), []byte("testValue4"))
smt.Update([]byte("testKey6"), []byte("testValue6"))

var originalRoot []byte
originalRoot := make([]byte, len(smt.Root()))
copy(originalRoot, smt.Root())

proof1, _ := smt.ProveUpdatable([]byte("testKey1"))
proof2, _ := smt.ProveUpdatable([]byte("testKey2"))
proof5, _ := smt.ProveUpdatable([]byte("testKey5"))

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), sha256.New(), smt.Root())
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(proof1, []byte("testKey1"), []byte("testValue1"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
Expand All @@ -39,21 +39,21 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {

value, err := dsmst.Get([]byte("testKey1"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, []byte("testValue1")) {
t.Error("did not get correct value in deep subtree")
}
value, err = dsmst.Get([]byte("testKey2"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, []byte("testValue2")) {
t.Error("did not get correct value in deep subtree")
}
value, err = dsmst.Get([]byte("testKey5"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, defaultValue) {
t.Error("did not get correct value in deep subtree")
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
}

func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), sha256.New())
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
Expand All @@ -130,7 +130,7 @@ func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) {
badProof, _ := smt.Prove([]byte("testKey1"))
badProof.SideNodes[0][0] = byte(0)

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), sha256.New(), smt.Root())
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(badProof, []byte("testKey1"), []byte("testValue1"))
if !errors.Is(err, ErrBadProof) {
t.Error("did not return ErrBadProof for bad proof input")
Expand Down
7 changes: 0 additions & 7 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,3 @@ package smt

// Option is a function that configures SMT.
type Option func(*SparseMerkleTree)

// AutoRemoveOrphans option configures SMT to automatically remove orphaned nodes during Update/Delete operation.
func AutoRemoveOrphans() Option {
return func(smt *SparseMerkleTree) {
smt.prune = true
}
}
6 changes: 1 addition & 5 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,8 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
}
} else { // Membership proof.
valueHash := th.digest(value)
update := make([][]byte, 2)
update[0], update[1] = valueHash, value
updates = append(updates, update)

currentHash, currentData = th.digestLeaf(path, valueHash)
update = make([][]byte, 2)
update := make([][]byte, 2)
update[0], update[1] = currentHash, currentData
updates = append(updates, update)
}
Expand Down
14 changes: 7 additions & 7 deletions proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import (

// Test base case Merkle proof operations.
func TestProofsBasic(t *testing.T) {
var sm *SimpleMap
var smn, smv *SimpleMap
var smt *SparseMerkleTree
var proof SparseMerkleProof
var result bool
var root []byte
var err error

sm = NewSimpleMap()
smt = NewSparseMerkleTree(sm, sha256.New())
smn, smv = NewSimpleMap(), NewSimpleMap()
smt = NewSparseMerkleTree(smn, smv, sha256.New())

// Generate and verify a proof on an empty key.
proof, err = smt.Prove([]byte("testKey3"))
Expand Down Expand Up @@ -123,8 +123,8 @@ func TestProofsBasic(t *testing.T) {

// Test sanity check cases for non-compact proofs.
func TestProofsSanityCheck(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
th := &smt.th

smt.Update([]byte("testKey1"), []byte("testValue1"))
Expand Down Expand Up @@ -199,8 +199,8 @@ func TestProofsSanityCheck(t *testing.T) {

// Test sanity check cases for compact proofs.
func TestCompactProofsSanityCheck(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
th := &smt.th

smt.Update([]byte("testKey1"), []byte("testValue1"))
Expand Down
Loading