diff --git a/smt.go b/smt.go index 78e549e..6ffec1e 100644 --- a/smt.go +++ b/smt.go @@ -123,6 +123,18 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error) return value, nil } +// Has returns true if tree cointains given key, false otherwise. +func (smt *SparseMerkleTree) Has(key []byte) (bool, error) { + val, err := smt.Get(key) + return !bytes.Equal(defaultValue, val), err +} + +// HasForRoot returns true if tree cointains given key at a specific root, false otherwise. +func (smt *SparseMerkleTree) HasForRoot(key, root []byte) (bool, error) { + val, err := smt.GetForRoot(key, root) + return !bytes.Equal(defaultValue, val), err +} + // Update sets a new value for a key in the tree, and sets and returns the new root of the tree. func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) { newRoot, err := smt.UpdateForRoot(key, value, smt.Root()) @@ -132,6 +144,11 @@ func (smt *SparseMerkleTree) Update(key []byte, value []byte) ([]byte, error) { return newRoot, err } +// Delete deletes a value from tree. It returns the new root of the tree. +func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) { + return smt.Update(key, defaultValue) +} + // UpdateForRoot sets a new value for a key in the tree at a specific root, and returns the new root. func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) { path := smt.th.path(key) @@ -155,6 +172,11 @@ func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte return newRoot, err } +// Delete deletes a value from tree at a specific root. It returns the new root of the tree. +func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) { + return smt.UpdateForRoot(key, defaultValue, root) +} + func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, oldLeafHash []byte, oldLeafData []byte) ([]byte, error) { if bytes.Equal(oldLeafHash, smt.th.placeholder()) { // This key is already empty as it is a placeholder; return an error. diff --git a/smt_test.go b/smt_test.go index 4533ec7..1b9509a 100644 --- a/smt_test.go +++ b/smt_test.go @@ -13,6 +13,7 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { sm := NewSimpleMap() smt := NewSparseMerkleTree(sm, sha256.New()) var value []byte + var has bool var err error // Test getting an empty key. @@ -23,6 +24,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting empty key") } + has, err = smt.Has([]byte("testKey")) + if err != nil { + t.Errorf("returned error when checking presence of empty key: %v", err) + } + if has { + t.Error("did not get 'false' when checking presence of empty key") + } // Test updating the empty key. _, err = smt.Update([]byte("testKey"), []byte("testValue")) @@ -36,6 +44,13 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { if !bytes.Equal([]byte("testValue"), value) { t.Error("did not get correct value when getting non-empty key") } + has, err = smt.Has([]byte("testKey")) + if err != nil { + t.Errorf("returned error when checking presence of non-empty key: %v", err) + } + if !has { + t.Error("did not get 'true' when checking presence of non-empty key") + } // Test updating the non-empty key. _, err = smt.Update([]byte("testKey"), []byte("testValue2")) @@ -114,6 +129,33 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { if !bytes.Equal([]byte("testValue2"), value) { t.Error("did not get correct value when getting non-empty key") } + has, err = smt.HasForRoot([]byte("testKey"), root) + if err != nil { + t.Errorf("returned error when checking presence of non-empty key: %v", err) + } + if !has { + t.Error("did not get 'false' when checking presence of non-empty key") + } + + // Test that it is possible to delete key in an older root. + root, err = smt.DeleteForRoot([]byte("testKey3"), root) + if err != nil { + t.Errorf("unable to delete key: %v", err) + } + value, err = smt.GetForRoot([]byte("testKey3"), root) + if err != nil { + t.Errorf("returned error when getting empty key: %v", err) + } + if !bytes.Equal(defaultValue, value) { + t.Error("did not get correct value when getting empty key") + } + has, err = smt.HasForRoot([]byte("testKey3"), root) + if err != nil { + t.Errorf("returned error when checking presence of empty key: %v", err) + } + if has { + t.Error("did not get 'false' when checking presence of empty key") + } // Test that a tree can be imported from a MapStore. smt2 := ImportSparseMerkleTree(sm, sha256.New(), smt.Root()) @@ -195,6 +237,13 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) { if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting deleted key") } + has, err := smt.Has([]byte("testKey")) + if err != nil { + t.Errorf("returned error when checking existence of deleted key: %v", err) + } + if has { + t.Error("returned 'true' when checking existernce of deleted key") + } _, err = smt.Update([]byte("testKey"), []byte("testValue")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) @@ -269,6 +318,46 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) { if !bytes.Equal(root1, smt.Root()) { t.Error("tree root is not as expected after deleting second key") } + + // Testing inserting, deleting a key, and inserting it again, using Delete + _, err = smt.Update([]byte("testKey"), []byte("testValue")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + root1 = smt.Root() + _, err = smt.Delete([]byte("testKey")) + if err != nil { + t.Errorf("returned error when deleting key: %v", err) + } + value, err = smt.Get([]byte("testKey")) + if err != nil { + t.Errorf("returned error when getting deleted key: %v", err) + } + if !bytes.Equal(defaultValue, value) { + t.Error("did not get default value when getting deleted key") + } + has, err = smt.Has([]byte("testKey")) + if err != nil { + t.Errorf("returned error when checking existence of deleted key: %v", err) + } + if has { + t.Error("returned 'true' when checking existernce of deleted key") + } + _, err = smt.Update([]byte("testKey"), []byte("testValue")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + value, err = smt.Get([]byte("testKey")) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue"), value) { + t.Error("did not get correct value when getting non-empty key") + } + if !bytes.Equal(root1, smt.Root()) { + t.Error("tree root is not as expected after re-inserting key after deletion") + } + } // dummyHasher is a dummy hasher for tests, where the digest of keys is equivalent to the preimage.