Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Implement removal of orphan nodes #37

Merged
4 changes: 2 additions & 2 deletions bulk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ func TestSparseMerkleTree(t *testing.T) {

// Test all tree operations in bulk, with specified ratio probabilities of insert, update and delete.
func bulkOperations(t *testing.T, operations int, insert int, update int, delete int) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())

max := insert + update + delete
kv := make(map[string]string)
Expand Down
23 changes: 12 additions & 11 deletions deepsubtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@ type DeepSparseMerkleSubTree struct {
}

// NewDeepSparseMerkleSubTree creates a new deep Sparse Merkle subtree on an empty MapStore.
func NewDeepSparseMerkleSubTree(ms MapStore, hasher hash.Hash, root []byte) *DeepSparseMerkleSubTree {
smt := &SparseMerkleTree{
th: *newTreeHasher(hasher),
ms: ms,
func NewDeepSparseMerkleSubTree(nodes, values MapStore, hasher hash.Hash, root []byte) *DeepSparseMerkleSubTree {
return &DeepSparseMerkleSubTree{
SparseMerkleTree: ImportSparseMerkleTree(nodes, values, hasher, root),
}

smt.SetRoot(root)

return &DeepSparseMerkleSubTree{SparseMerkleTree: smt}
}

// AddBranch adds a branch to the tree.
Expand All @@ -32,14 +27,20 @@ func NewDeepSparseMerkleSubTree(ms MapStore, hasher hash.Hash, root []byte) *Dee
// If the leaf may be updated (e.g. during a state transition fraud proof),
// an updatable proof should be used. See SparseMerkleTree.ProveUpdatable.
func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []byte, value []byte) error {
result, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher)
result, updates, valueHash := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher)
if !result {
return ErrBadProof
}

if valueHash != nil {
if err := dsmst.values.Set(dsmst.th.path(key), value); err != nil {
return err
}
}

// Update nodes along branch
for _, update := range updates {
err := dsmst.ms.Set(update[0], update[1])
err := dsmst.nodes.Set(update[0], update[1])
if err != nil {
return err
}
Expand All @@ -48,7 +49,7 @@ func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []b
// Update sibling node
if proof.SiblingData != nil {
if proof.SideNodes != nil && len(proof.SideNodes) > 0 {
err := dsmst.ms.Set(proof.SideNodes[0], proof.SiblingData)
err := dsmst.nodes.Set(proof.SideNodes[0], proof.SiblingData)
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions deepsubtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@ import (
)

func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), sha256.New())
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
smt.Update([]byte("testKey3"), []byte("testValue3"))
smt.Update([]byte("testKey4"), []byte("testValue4"))
smt.Update([]byte("testKey6"), []byte("testValue6"))

var originalRoot []byte
originalRoot := make([]byte, len(smt.Root()))
copy(originalRoot, smt.Root())

proof1, _ := smt.ProveUpdatable([]byte("testKey1"))
proof2, _ := smt.ProveUpdatable([]byte("testKey2"))
proof5, _ := smt.ProveUpdatable([]byte("testKey5"))

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), sha256.New(), smt.Root())
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(proof1, []byte("testKey1"), []byte("testValue1"))
if err != nil {
t.Errorf("returned error when adding branch to deep subtree: %v", err)
Expand All @@ -39,21 +39,21 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {

value, err := dsmst.Get([]byte("testKey1"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, []byte("testValue1")) {
t.Error("did not get correct value in deep subtree")
}
value, err = dsmst.Get([]byte("testKey2"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, []byte("testValue2")) {
t.Error("did not get correct value in deep subtree")
}
value, err = dsmst.Get([]byte("testKey5"))
if err != nil {
t.Error("returned error when getting value in deep subtree")
t.Errorf("returned error when getting value in deep subtree: %v", err)
}
if !bytes.Equal(value, defaultValue) {
t.Error("did not get correct value in deep subtree")
Expand Down Expand Up @@ -120,7 +120,7 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) {
}

func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) {
smt := NewSparseMerkleTree(NewSimpleMap(), sha256.New())
smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New())

smt.Update([]byte("testKey1"), []byte("testValue1"))
smt.Update([]byte("testKey2"), []byte("testValue2"))
Expand All @@ -130,7 +130,7 @@ func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) {
badProof, _ := smt.Prove([]byte("testKey1"))
badProof.SideNodes[0][0] = byte(0)

dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), sha256.New(), smt.Root())
dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root())
err := dsmst.AddBranch(badProof, []byte("testKey1"), []byte("testValue1"))
if !errors.Is(err, ErrBadProof) {
t.Error("did not return ErrBadProof for bad proof input")
Expand Down
7 changes: 0 additions & 7 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,3 @@ package smt

// Option is a function that configures SMT.
type Option func(*SparseMerkleTree)

// AutoRemoveOrphans option configures SMT to automatically remove orphaned nodes during Update/Delete operation.
func AutoRemoveOrphans() Option {
return func(smt *SparseMerkleTree) {
smt.prune = true
}
}
19 changes: 8 additions & 11 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,20 @@ func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher) bool {

// VerifyProof verifies a Merkle proof.
func VerifyProof(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) bool {
result, _ := verifyProofWithUpdates(proof, root, key, value, hasher)
result, _, _ := verifyProofWithUpdates(proof, root, key, value, hasher)
return result
}

func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) (bool, [][][]byte) {
func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) (bool, [][][]byte, []byte) {
th := newTreeHasher(hasher)
path := th.path(key)

if !proof.sanityCheck(th) {
return false, nil
return false, nil, nil
}

var updates [][][]byte
var memberValueHash []byte

// Determine what the leaf hash should be.
var currentHash, currentData []byte
Expand All @@ -126,7 +127,7 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
actualPath, valueHash := th.parseLeaf(proof.NonMembershipLeafData)
if bytes.Equal(actualPath, path) {
// This is not an unrelated leaf; non-membership proof failed.
return false, nil
return false, nil, nil
}
currentHash, currentData = th.digestLeaf(actualPath, valueHash)

Expand All @@ -135,13 +136,9 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
updates = append(updates, update)
}
} else { // Membership proof.
valueHash := th.digest(value)
memberValueHash = th.digest(value)
currentHash, currentData = th.digestLeaf(path, memberValueHash)
update := make([][]byte, 2)
update[0], update[1] = valueHash, value
updates = append(updates, update)

currentHash, currentData = th.digestLeaf(path, valueHash)
update = make([][]byte, 2)
update[0], update[1] = currentHash, currentData
updates = append(updates, update)
}
Expand All @@ -162,7 +159,7 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va
updates = append(updates, update)
}

return bytes.Equal(currentHash, root), updates
return bytes.Equal(currentHash, root), updates, memberValueHash
adlerjohn marked this conversation as resolved.
Show resolved Hide resolved
}

// VerifyCompactProof verifies a compacted Merkle proof.
Expand Down
14 changes: 7 additions & 7 deletions proofs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import (

// Test base case Merkle proof operations.
func TestProofsBasic(t *testing.T) {
var sm *SimpleMap
var smn, smv *SimpleMap
var smt *SparseMerkleTree
var proof SparseMerkleProof
var result bool
var root []byte
var err error

sm = NewSimpleMap()
smt = NewSparseMerkleTree(sm, sha256.New())
smn, smv = NewSimpleMap(), NewSimpleMap()
smt = NewSparseMerkleTree(smn, smv, sha256.New())

// Generate and verify a proof on an empty key.
proof, err = smt.Prove([]byte("testKey3"))
Expand Down Expand Up @@ -123,8 +123,8 @@ func TestProofsBasic(t *testing.T) {

// Test sanity check cases for non-compact proofs.
func TestProofsSanityCheck(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
th := &smt.th

smt.Update([]byte("testKey1"), []byte("testValue1"))
Expand Down Expand Up @@ -199,8 +199,8 @@ func TestProofsSanityCheck(t *testing.T) {

// Test sanity check cases for compact proofs.
func TestCompactProofsSanityCheck(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
smn, smv := NewSimpleMap(), NewSimpleMap()
smt := NewSparseMerkleTree(smn, smv, sha256.New())
th := &smt.th

smt.Update([]byte("testKey1"), []byte("testValue1"))
Expand Down
Loading