Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Add Delete and Has methods to SparseMerkleTree #12

Merged
merged 7 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error)
return value, nil
}

// Has returns true if tree cointains given key, false otherwise.
func (smt *SparseMerkleTree) Has(key []byte) (bool, error) {
val, err := smt.Get(key)
return !bytes.Equal(defaultValue, val), err
}

// HasForRoot returns true if tree cointains given key at a specific root, false otherwise.
func (smt *SparseMerkleTree) HasForRoot(key, root []byte) (bool, error) {
val, err := smt.GetForRoot(key, root)
return !bytes.Equal(defaultValue, val), err
}

// Update sets a new value for a key in the tree, and sets and returns the new root of the tree.
func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
newRoot, err := smt.UpdateForRoot(key, value, smt.Root())
Expand All @@ -132,6 +144,11 @@ func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) {
return newRoot, err
}

// Delete deletes a value from tree. It returns the new root of the tree.
func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) {
return smt.Update(key, defaultValue)
}

// UpdateForRoot sets a new value for a key in the tree at a specific root, and returns the new root.
func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) {
path := smt.th.path(key)
Expand All @@ -155,6 +172,11 @@ func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte
return newRoot, err
}

// Delete deletes a value from tree at a specific root. It returns the new root of the tree.
func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) {
return smt.UpdateForRoot(key, defaultValue, root)
}

func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, oldLeafHash []byte, oldLeafData []byte) ([]byte, error) {
if bytes.Equal(oldLeafHash, smt.th.placeholder()) {
// This key is already empty as it is a placeholder; return an error.
Expand Down
89 changes: 89 additions & 0 deletions smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
sm := NewSimpleMap()
smt := NewSparseMerkleTree(sm, sha256.New())
var value []byte
var has bool
var err error

// Test getting an empty key.
Expand All @@ -23,6 +24,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting empty key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking presence of empty key: %v", err)
}
if has {
t.Error("did not get 'false' when checking presence of empty key")
}

// Test updating the empty key.
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
Expand All @@ -36,6 +44,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal([]byte("testValue"), value) {
t.Error("did not get correct value when getting non-empty key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking presence of non-empty key: %v", err)
}
if !has {
t.Error("did not get 'true' when checking presence of non-empty key")
}

// Test updating the non-empty key.
_, err = smt.Update([]byte("testKey"), []byte("testValue2"))
Expand Down Expand Up @@ -114,6 +129,33 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) {
if !bytes.Equal([]byte("testValue2"), value) {
t.Error("did not get correct value when getting non-empty key")
}
has, err = smt.HasForRoot([]byte("testKey"), root)
if err != nil {
t.Errorf("returned error when checking presence of non-empty key: %v", err)
}
if !has {
t.Error("did not get 'false' when checking presence of non-empty key")
}

// Test that it is possible to delete key in an older root.
root, err = smt.DeleteForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("unable to delete key: %v", err)
}
value, err = smt.GetForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("returned error when getting empty key: %v", err)
}
if !bytes.Equal(defaultValue, value) {
t.Error("did not get correct value when getting empty key")
}
has, err = smt.HasForRoot([]byte("testKey3"), root)
if err != nil {
t.Errorf("returned error when checking presence of empty key: %v", err)
}
if has {
t.Error("did not get 'false' when checking presence of empty key")
}

// Test that a tree can be imported from a MapStore.
smt2 := ImportSparseMerkleTree(sm, sha256.New(), smt.Root())
Expand Down Expand Up @@ -195,6 +237,13 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) {
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting deleted key")
}
has, err := smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking existence of deleted key: %v", err)
}
if has {
t.Error("returned 'true' when checking existernce of deleted key")
}
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
Expand Down Expand Up @@ -269,6 +318,46 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) {
if !bytes.Equal(root1, smt.Root()) {
t.Error("tree root is not as expected after deleting second key")
}

// Testing inserting, deleting a key, and inserting it again, using Delete
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
root1 = smt.Root()
_, err = smt.Delete([]byte("testKey"))
if err != nil {
t.Errorf("returned error when deleting key: %v", err)
}
value, err = smt.Get([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting deleted key: %v", err)
}
if !bytes.Equal(defaultValue, value) {
t.Error("did not get default value when getting deleted key")
}
has, err = smt.Has([]byte("testKey"))
if err != nil {
t.Errorf("returned error when checking existence of deleted key: %v", err)
}
if has {
t.Error("returned 'true' when checking existernce of deleted key")
}
_, err = smt.Update([]byte("testKey"), []byte("testValue"))
if err != nil {
t.Errorf("returned error when updating empty key: %v", err)
}
value, err = smt.Get([]byte("testKey"))
if err != nil {
t.Errorf("returned error when getting non-empty key: %v", err)
}
if !bytes.Equal([]byte("testValue"), value) {
t.Error("did not get correct value when getting non-empty key")
}
if !bytes.Equal(root1, smt.Root()) {
t.Error("tree root is not as expected after re-inserting key after deletion")
}

}

// dummyHasher is a dummy hasher for tests, where the digest of keys is equivalent to the preimage.
Expand Down