diff --git a/mapstore.go b/mapstore.go index 3db56d3..5fe381b 100644 --- a/mapstore.go +++ b/mapstore.go @@ -17,7 +17,7 @@ type InvalidKeyError struct { } func (e *InvalidKeyError) Error() string { - return fmt.Sprintf("invalid key: %s", e.Key) + return fmt.Sprintf("invalid key: %x", e.Key) } // SimpleMap is a simple in-memory map. diff --git a/proofs.go b/proofs.go index 86610a5..3a9efc4 100644 --- a/proofs.go +++ b/proofs.go @@ -129,11 +129,11 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va } // Recompute root. - for i := len(proof.SideNodes) - 1; i >= 0; i-- { + for i := 0; i < len(proof.SideNodes); i++ { node := make([]byte, th.pathSize()) copy(node, proof.SideNodes[i]) - if hasBit(path, i) == right { + if getBitAtFromMSB(path, len(proof.SideNodes)-1-i) == right { currentHash, currentData = th.digestNode(node, currentHash) } else { currentHash, currentData = th.digestNode(currentHash, node) @@ -170,7 +170,7 @@ func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkl node := make([]byte, th.hasher.Size()) copy(node, proof.SideNodes[i]) if bytes.Equal(node, th.placeholder()) { - setBit(bitMask, i) + setBitAtFromMSB(bitMask, i) } else { compactedSideNodes = append(compactedSideNodes, node) } @@ -195,7 +195,7 @@ func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash) (SparseMer decompactedSideNodes := make([][]byte, proof.NumSideNodes) position := 0 for i := 0; i < proof.NumSideNodes; i++ { - if hasBit(proof.BitMask, i) == 1 { + if getBitAtFromMSB(proof.BitMask, i) == 1 { decompactedSideNodes[i] = th.placeholder() } else { decompactedSideNodes[i] = proof.SideNodes[position] diff --git a/proofs_test.go b/proofs_test.go index f0be57b..893262e 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -210,7 +210,7 @@ func TestCompactProofsSanityCheck(t *testing.T) { // Case (compact proofs): unexpected bit mask length. proof, _ = smt.ProveCompact([]byte("testKey1")) - proof.NumSideNodes = 1 + proof.NumSideNodes = 10 if proof.sanityCheck(th) { t.Error("sanity check incorrectly passed") } @@ -221,7 +221,7 @@ func TestCompactProofsSanityCheck(t *testing.T) { // Case (compact proofs): unexpected number of sidenodes for number of side nodes. proof, _ = smt.ProveCompact([]byte("testKey1")) - proof.SideNodes = proof.SideNodes[:1] + proof.SideNodes = append(proof.SideNodes, proof.SideNodes...) if proof.sanityCheck(th) { t.Error("sanity check incorrectly passed") } diff --git a/smt.go b/smt.go index 2882426..99edfec 100644 --- a/smt.go +++ b/smt.go @@ -96,7 +96,7 @@ func (smt *SparseMerkleTree) GetForRoot(key []byte, root []byte) ([]byte, error) } leftNode, rightNode := smt.th.parseNode(currentData) - if hasBit(path, i) == right { + if getBitAtFromMSB(path, i) == right { currentHash = rightNode } else { currentHash = leftNode @@ -188,7 +188,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte var currentHash, currentData []byte nonPlaceholderReached := false - for i := smt.depth() - 1; i >= 0; i-- { + for i := 0; i < len(sideNodes); i++ { if sideNodes[i] == nil { continue } @@ -215,7 +215,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte } if !nonPlaceholderReached && bytes.Equal(sideNode, smt.th.placeholder()) { - // We found another placeholder sibling node, keep going down the + // We found another placeholder sibling node, keep going up the // tree until we find the first sibling that is not a placeholder. continue } else if !nonPlaceholderReached { @@ -224,7 +224,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte nonPlaceholderReached = true } - if hasBit(path, i) == right { + if getBitAtFromMSB(path, len(sideNodes)-1-i) == right { currentHash, currentData = smt.th.digestNode(sideNode, currentData) } else { currentHash, currentData = smt.th.digestNode(currentData, sideNode) @@ -269,7 +269,7 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side commonPrefixCount = countCommonPrefix(path, actualPath) } if commonPrefixCount != smt.depth() { - if hasBit(path, commonPrefixCount) == right { + if getBitAtFromMSB(path, commonPrefixCount) == right { currentHash, currentData = smt.th.digestNode(oldLeafHash, currentData) } else { currentHash, currentData = smt.th.digestNode(currentData, oldLeafHash) @@ -283,11 +283,15 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side currentData = currentHash } - for i := smt.depth() - 1; i >= 0; i-- { + for i := 0; i < smt.depth(); i++ { sideNode := make([]byte, smt.th.pathSize()) - if sideNodes[i] == nil { - if commonPrefixCount != smt.depth() && commonPrefixCount > i { + // The offset from the bottom of the tree to the start of the side nodes + // i-offsetOfSideNodes is the index into sideNodes[] + offsetOfSideNodes := smt.depth() - len(sideNodes) + + if i-offsetOfSideNodes < 0 || sideNodes[i-offsetOfSideNodes] == nil { + if commonPrefixCount != smt.depth() && commonPrefixCount > smt.depth()-1-i { // If there are no sidenodes at this height, but the number of // bits that the paths of the two leaf nodes share in common is // greater than this height, then we need to build up the tree @@ -297,10 +301,10 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side continue } } else { - copy(sideNode, sideNodes[i]) + copy(sideNode, sideNodes[i-offsetOfSideNodes]) } - if hasBit(path, i) == right { + if getBitAtFromMSB(path, smt.depth()-1-i) == right { currentHash, currentData = smt.th.digestNode(sideNode, currentData) } else { currentHash, currentData = smt.th.digestNode(currentData, sideNode) @@ -319,7 +323,9 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side // Returns an array of sibling nodes, the leaf hash found at that path and the // leaf data. If the leaf is a placeholder, the leaf data is nil. func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byte, []byte, []byte, error) { - sideNodes := make([][]byte, smt.depth()) + // Side nodes for the path. Nodes are inserted in reverse order, then the + // slice is reversed at the end. + sideNodes := make([][]byte, 0, smt.depth()) if bytes.Equal(root, smt.th.placeholder()) { // If the root is a placeholder, there are no sidenodes to return. @@ -340,17 +346,17 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt leftNode, rightNode := smt.th.parseNode(currentData) // Get sidenode depending on whether the path bit is on or off. - if hasBit(path, i) == right { - sideNodes[i] = leftNode + if getBitAtFromMSB(path, i) == right { + sideNodes = append(sideNodes, leftNode) nodeHash = rightNode } else { - sideNodes[i] = rightNode + sideNodes = append(sideNodes, rightNode) nodeHash = leftNode } if bytes.Equal(nodeHash, smt.th.placeholder()) { // If the node is a placeholder, we've reached the end. - return sideNodes, nodeHash, nil, nil + return reverseSideNodes(sideNodes), nodeHash, nil, nil } currentData, err = smt.ms.Get(nodeHash) @@ -362,7 +368,7 @@ func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte) ([][]byt } } - return sideNodes, nodeHash, currentData, err + return reverseSideNodes(sideNodes), nodeHash, currentData, nil } // Prove generates a Merkle proof for a key. diff --git a/smt_test.go b/smt_test.go index 1b9509a..41bfedc 100644 --- a/smt_test.go +++ b/smt_test.go @@ -168,7 +168,117 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { } } -// Test tree operations when two leafs are immediate neighbours. +// Test known tree ops +func TestSparseMerkleTreeKnown(t *testing.T) { + h := newDummyHasher(sha256.New()) + sm := NewSimpleMap() + smt := NewSparseMerkleTree(sm, h) + var value []byte + var err error + + baseKey := make([]byte, h.Size()+4) + key1 := make([]byte, h.Size()+4) + copy(key1, baseKey) + key1[4] = byte(0b00000000) + key2 := make([]byte, h.Size()+4) + copy(key2, baseKey) + key2[4] = byte(0b01000000) + key3 := make([]byte, h.Size()+4) + copy(key3, baseKey) + key3[4] = byte(0b10000000) + key4 := make([]byte, h.Size()+4) + copy(key4, baseKey) + key4[4] = byte(0b11000000) + key5 := make([]byte, h.Size()+4) + copy(key5, baseKey) + key5[4] = byte(0b11010000) + + _, err = smt.Update(key1, []byte("testValue1")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + _, err = smt.Update(key2, []byte("testValue2")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + _, err = smt.Update(key3, []byte("testValue3")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + _, err = smt.Update(key4, []byte("testValue4")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + _, err = smt.Update(key5, []byte("testValue5")) + if err != nil { + t.Errorf("returned error when updating empty key: %v", err) + } + + value, err = smt.Get(key1) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue1"), value) { + t.Error("did not get correct value when getting non-empty key") + } + value, err = smt.Get(key2) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue2"), value) { + t.Error("did not get correct value when getting non-empty key") + } + value, err = smt.Get(key3) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue3"), value) { + t.Error("did not get correct value when getting non-empty key") + } + value, err = smt.Get(key4) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue4"), value) { + t.Error("did not get correct value when getting non-empty key") + } + value, err = smt.Get(key5) + if err != nil { + t.Errorf("returned error when getting non-empty key: %v", err) + } + if !bytes.Equal([]byte("testValue5"), value) { + t.Error("did not get correct value when getting non-empty key") + } + + proof1, _ := smt.Prove(key1) + proof2, _ := smt.Prove(key2) + proof3, _ := smt.Prove(key3) + proof4, _ := smt.Prove(key4) + proof5, _ := smt.Prove(key5) + dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), h, smt.Root()) + err = dsmst.AddBranch(proof1, key1, []byte("testValue1")) + if err != nil { + t.Errorf("returned error when adding branch to deep subtree: %v", err) + } + err = dsmst.AddBranch(proof2, key2, []byte("testValue2")) + if err != nil { + t.Errorf("returned error when adding branch to deep subtree: %v", err) + } + err = dsmst.AddBranch(proof3, key3, []byte("testValue3")) + if err != nil { + t.Errorf("returned error when adding branch to deep subtree: %v", err) + } + err = dsmst.AddBranch(proof4, key4, []byte("testValue4")) + if err != nil { + t.Errorf("returned error when adding branch to deep subtree: %v", err) + } + err = dsmst.AddBranch(proof5, key5, []byte("testValue5")) + if err != nil { + t.Errorf("returned error when adding branch to deep subtree: %v", err) + } +} + +// Test tree operations when two leafs are immediate neighbors. func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { h := newDummyHasher(sha256.New()) sm := NewSimpleMap() @@ -176,9 +286,9 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { var value []byte var err error - // Make two neighbouring keys. + // Make two neighboring keys. // - // The dummy hash function excepts keys to prefixed with four bytes of 0, + // The dummy hash function expects keys to prefixed with four bytes of 0, // which will cause it to return the preimage itself as the digest, without // the first four bytes. key1 := make([]byte, h.Size()+4) @@ -187,7 +297,8 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { key1[h.Size()+4-1] = byte(0) key2 := make([]byte, h.Size()+4) copy(key2, key1) - setBit(key2, (h.Size()+4)*8-1) + // We make key2's least significant bit different than key1's + key2[h.Size()+4-1] = byte(1) _, err = smt.Update(key1, []byte("testValue1")) if err != nil { @@ -205,7 +316,6 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } - value, err = smt.Get(key2) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) @@ -213,6 +323,14 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { if !bytes.Equal([]byte("testValue2"), value) { t.Error("did not get correct value when getting non-empty key") } + + proof, err := smt.Prove(key1) + if err != nil { + t.Errorf("returned error when proving key: %v", err) + } + if len(proof.SideNodes) != 256 { + t.Errorf("unexpected proof size") + } } // Test base case tree delete operations with a few keys. diff --git a/utils.go b/utils.go index d413bfd..55163cd 100644 --- a/utils.go +++ b/utils.go @@ -1,22 +1,24 @@ package smt -func hasBit(data []byte, position int) int { - if int(data[position/8])&(1<<(uint(position)%8)) > 0 { +// getBitAtFromMSB gets the bit at an offset from the most significant bit +func getBitAtFromMSB(data []byte, position int) int { + if int(data[position/8])&(1<<(8-1-uint(position)%8)) > 0 { return 1 } return 0 } -func setBit(data []byte, position int) { +// setBitAtFromMSB sets the bit at an offset from the most significant bit +func setBitAtFromMSB(data []byte, position int) { n := int(data[position/8]) - n |= (1 << (uint(position) % 8)) + n |= (1 << (8 - 1 - uint(position)%8)) data[position/8] = byte(n) } func countSetBits(data []byte) int { count := 0 for i := 0; i < len(data)*8; i++ { - if hasBit(data, i) == 1 { + if getBitAtFromMSB(data, i) == 1 { count++ } } @@ -26,7 +28,7 @@ func countSetBits(data []byte) int { func countCommonPrefix(data1 []byte, data2 []byte) int { count := 0 for i := 0; i < len(data1)*8; i++ { - if hasBit(data1, i) == hasBit(data2, i) { + if getBitAtFromMSB(data1, i) == getBitAtFromMSB(data2, i) { count++ } else { break @@ -39,3 +41,11 @@ func emptyBytes(length int) []byte { b := make([]byte, length) return b } + +func reverseSideNodes(sideNodes [][]byte) [][]byte { + for left, right := 0, len(sideNodes)-1; left < right; left, right = left+1, right-1 { + sideNodes[left], sideNodes[right] = sideNodes[right], sideNodes[left] + } + + return sideNodes +}