Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Commit

Permalink
Merge pull request #12 from lazyledger/tzdybal/has_and_delete
Browse files Browse the repository at this point in the history
Add Delete and Has methods to SparseMerkleTree
  • Loading branch information
tzdybal authored Jan 13, 2021
2 parents 42131aa + 870d9a9 commit 22852aa
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
22 changes: 22 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error)
return value, nil
}

// Has returns true if tree cointains given key, false otherwise.
func (smt *SparseMerkleTree) Has(key []byte) (bool, error) {
val, err := smt.Get(key)
return !bytes.Equal(defaultValue, val), err
}

// HasForRoot returns true if tree cointains given key at a specific root, false otherwise.
func (smt *SparseMerkleTree) HasForRoot(key, root []byte) (bool, error) {
val, err := smt.GetForRoot(key, root)
return !bytes.Equal(defaultValue, val), err
}

// Update sets a new value for a key in the tree, and sets and returns the new root of the tree.
func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
newRoot, err := smt.UpdateForRoot(key, value, smt.Root())
Expand All @@ -132,6 +144,11 @@ func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
return newRoot, err
}

// Delete deletes a value from tree. It returns the new root of the tree.
func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) {
return smt.Update(key, defaultValue)
}

// UpdateForRoot sets a new value for a key in the tree at a specific root, and returns the new root.
func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) {
path := smt.th.path(key)
Expand All @@ -155,6 +172,11 @@ func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte
return newRoot, err
}

// Delete deletes a value from tree at a specific root. It returns the new root of the tree.
func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) {
return smt.UpdateForRoot(key, defaultValue, root)
}

func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, oldLeafHash []byte, oldLeafData []byte) ([]byte, error) {
if bytes.Equal(oldLeafHash, smt.th.placeholder()) {
// This key is already empty as it is a placeholder; return an error.
Expand Down
89 changes: 89 additions & 0 deletions smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
var value []byte
var has bool
var err error

// Test getting an empty key.
Expand All @@ -23,6 +24,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting empty key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking presence of empty key: %v", err)
}
if has {
t.Error("did not get 'false' when checking presence of empty key")
}

// Test updating the empty key.
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
Expand All @@ -36,6 +44,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal([]byte("testValue"), value) {
t.Error("did not get correct value when getting non-empty key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking presence of non-empty key: %v", err)
}
if !has {
t.Error("did not get 'true' when checking presence of non-empty key")
}

// Test updating the non-empty key.
_, err = smt.Update([]byte("testKey"), []byte("testValue2"))
Expand Down Expand Up @@ -114,6 +129,33 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}
has, err = smt.HasForRoot([]byte("testKey"), root)
if err != nil {
t.Errorf("returned error when checking presence of non-empty key: %v", err)
}
if !has {
t.Error("did not get 'false' when checking presence of non-empty key")
}

// Test that it is possible to delete key in an older root.
root, err = smt.DeleteForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("unable to delete key: %v", err)
}
value, err = smt.GetForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("returned error when getting empty key: %v", err)
}
if !bytes.Equal(defaultValue, value) {
t.Error("did not get correct value when getting empty key")
}
has, err = smt.HasForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("returned error when checking presence of empty key: %v", err)
}
if has {
t.Error("did not get 'false' when checking presence of empty key")
}

// Test that a tree can be imported from a MapStore.
smt2 := ImportSparseMerkleTree(sm, sha256.New(), smt.Root())
Expand Down Expand Up @@ -195,6 +237,13 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) {
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting deleted key")
}
has, err := smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking existence of deleted key: %v", err)
}
if has {
t.Error("returned 'true' when checking existernce of deleted key")
}
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
Expand Down Expand Up @@ -269,6 +318,46 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) {
if !bytes.Equal(root1, smt.Root()) {
t.Error("tree root is not as expected after deleting second key")
}

// Testing inserting, deleting a key, and inserting it again, using Delete
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
root1 = smt.Root()
_, err = smt.Delete([]byte("testKey"))
if err != nil {
t.Errorf("returned error when deleting key: %v", err)
}
value, err = smt.Get([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting deleted key: %v", err)
}
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting deleted key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking existence of deleted key: %v", err)
}
if has {
t.Error("returned 'true' when checking existernce of deleted key")
}
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
value, err = smt.Get([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue"), value) {
t.Error("did not get correct value when getting non-empty key")
}
if !bytes.Equal(root1, smt.Root()) {
t.Error("tree root is not as expected after re-inserting key after deletion")
}

}

// dummyHasher is a dummy hasher for tests, where the digest of keys is equivalent to the preimage.
Expand Down

0 comments on commit 22852aa

Please sign in to comment.