Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix proof generation handling of empty sibilings #72

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.DS_Store
.vscode
20 changes: 8 additions & 12 deletions go-sdk/integration-test/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,11 @@ func TestZeto_3_SuccessfulProving(t *testing.T) {
assert.NoError(t, err)
err = mt.AddLeaf(n2)
assert.NoError(t, err)
proof1, _, err := mt.GenerateProof(input1, nil)
proofs, _, err := mt.GenerateProofs([]*big.Int{input1, input2}, nil)
assert.NoError(t, err)
circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
assert.NoError(t, err)
proof2, _, err := mt.GenerateProof(input2, nil)
assert.NoError(t, err)
circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT)
circomProof2, err := proofs[1].ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT)
assert.NoError(t, err)

salt3 := crypto.NewSalt()
Expand Down Expand Up @@ -370,13 +368,11 @@ func TestZeto_4_SuccessfulProving(t *testing.T) {
assert.NoError(t, err)
err = mt.AddLeaf(n2)
assert.NoError(t, err)
proof1, _, err := mt.GenerateProof(input1, nil)
assert.NoError(t, err)
circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
proofs, _, err := mt.GenerateProofs([]*big.Int{input1, input2}, nil)
assert.NoError(t, err)
proof2, _, err := mt.GenerateProof(input2, nil)
circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
assert.NoError(t, err)
circomProof2, err := proof2.ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT)
circomProof2, err := proofs[1].ToCircomVerifierProof(input2, input2, mt.Root(), MAX_HEIGHT)
assert.NoError(t, err)

salt3 := crypto.NewSalt()
Expand Down Expand Up @@ -510,9 +506,9 @@ func TestZeto_6_SuccessfulProving(t *testing.T) {
assert.NoError(t, err)
err = mt.AddLeaf(n1)
assert.NoError(t, err)
proof1, _, err := mt.GenerateProof(input1, nil)
proofs, _, err := mt.GenerateProofs([]*big.Int{input1}, nil)
assert.NoError(t, err)
circomProof1, err := proof1.ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
circomProof1, err := proofs[0].ToCircomVerifierProof(input1, input1, mt.Root(), MAX_HEIGHT)
assert.NoError(t, err)
proof1Siblings := make([]*big.Int, len(circomProof1.Siblings)-1)
for i, s := range circomProof1.Siblings[0 : len(circomProof1.Siblings)-1] {
Expand Down
21 changes: 18 additions & 3 deletions go-sdk/internal/sparse-merkle-tree/smt/merkletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,29 @@ func (mt *sparseMerkleTree) GetNode(key core.NodeIndex) (core.Node, error) {
// GenerateProof generates the proof of existence (or non-existence) of a leaf node
// for a Merkle Tree given the root. It uses the node's index to represent the node.
// If the rootKey is nil, the current merkletree root is used
jimthematrix marked this conversation as resolved.
Show resolved Hide resolved
func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) {
func (mt *sparseMerkleTree) GenerateProofs(keys []*big.Int, rootKey core.NodeIndex) ([]core.Proof, []*big.Int, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read this code several times trying to understand how this could relate to the problem described. then realized it's a just a refactor for code convenience. the description of the methods is not updated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah sorry this was a necessary improvement so that multiple proofs can be generated against the same root inside the sync'ed routine

mt.RLock()
defer mt.RUnlock()

merkleProofs := make([]core.Proof, len(keys))
foundValues := make([]*big.Int, len(keys))
for i, key := range keys {
proof, value, err := mt.generateProof(key, rootKey)
if err != nil {
return nil, nil, err
}
merkleProofs[i] = proof
foundValues[i] = value
}

return merkleProofs, foundValues, nil
}

func (mt *sparseMerkleTree) generateProof(key *big.Int, rootKey core.NodeIndex) (core.Proof, *big.Int, error) {
p := &proof{}
var siblingKey core.NodeIndex

kHash, err := node.NewNodeIndexFromBigInt(k)
kHash, err := node.NewNodeIndexFromBigInt(key)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -160,7 +175,7 @@ func (mt *sparseMerkleTree) GenerateProof(k *big.Int, rootKey core.NodeIndex) (c
p.siblings = append(p.siblings, siblingKey)
}
}
return nil, nil, ErrKeyNotFound
return nil, nil, ErrReachedMaxLevel
}

// must be called from inside a read lock
Expand Down
16 changes: 8 additions & 8 deletions go-sdk/internal/sparse-merkle-tree/smt/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@ func (p *proof) ExistingNode() core.Node {
}

func (p *proof) MarkNonEmptySibling(level uint) {
desiredLength := (level + 7) / 8
if desiredLength == 0 {
desiredLength = 1
}
if len(p.nonEmptySiblings) <= int(desiredLength) {
newBytes := make([]byte, desiredLength)
desiredByteLength := level/8 + 1
if len(p.nonEmptySiblings) <= int(desiredByteLength) {
// the bitmap is not big enough, resize it
newBytes := make([]byte, desiredByteLength)
if len(p.nonEmptySiblings) == 0 {
p.nonEmptySiblings = newBytes
} else {
Expand Down Expand Up @@ -186,10 +184,12 @@ func calculateRootFromProof(proof *proof, leafNode core.Node) (core.NodeIndex, e

// isBitOnBigEndian tests whether the bit n in bitmap is 1, in Big Endian.
func isBitOnBigEndian(bitmap []byte, n uint) bool {
return bitmap[uint(len(bitmap))-n/8-1]&(1<<(n%8)) != 0
byteIdxToCheck := n / 8
return bitmap[byteIdxToCheck]&(1<<(n%8)) != 0
}

// setBitBigEndian sets the bit n in the bitmap to 1, in Big Endian.
func setBitBigEndian(bitmap []byte, n uint) {
bitmap[uint(len(bitmap))-n/8-1] |= 1 << (n % 8)
byteIdxToSet := n / 8
bitmap[byteIdxToSet] |= 1 << (n % 8)
}
68 changes: 68 additions & 0 deletions go-sdk/internal/sparse-merkle-tree/smt/proof_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright © 2024 Kaleido, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package smt

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestMarkNonEmptySibling(t *testing.T) {
p := &proof{}
for i := 0; i < 256; i++ {
p.MarkNonEmptySibling(uint(i))
}
expected := make([]byte, 32)
for i := 0; i < 32; i++ {
expected[i] = 0xff
}
assert.Equal(t, p.nonEmptySiblings, expected)
}

func TestIsBitOnBigEndian(t *testing.T) {
p := &proof{}
expected := make([]byte, 32)
for i := 0; i < 32; i++ {
expected[i] = 0xff
}
p.nonEmptySiblings = expected
for i := 0; i < 256; i++ {
assert.True(t, isBitOnBigEndian(p.nonEmptySiblings, uint(i)))
}
}

func TestMarkAndCheck(t *testing.T) {
p := &proof{}
p.MarkNonEmptySibling(0)
p.MarkNonEmptySibling(10)
p.MarkNonEmptySibling(136)
assert.True(t, p.IsNonEmptySibling(0))
assert.False(t, p.IsNonEmptySibling(1))
assert.False(t, p.IsNonEmptySibling(2))
assert.False(t, p.IsNonEmptySibling(3))
assert.False(t, p.IsNonEmptySibling(4))
assert.False(t, p.IsNonEmptySibling(5))
assert.False(t, p.IsNonEmptySibling(6))
assert.False(t, p.IsNonEmptySibling(7))
assert.False(t, p.IsNonEmptySibling(8))
assert.False(t, p.IsNonEmptySibling(9))
assert.True(t, p.IsNonEmptySibling(10))
assert.False(t, p.IsNonEmptySibling(55))
assert.True(t, p.IsNonEmptySibling(136))
assert.False(t, p.IsNonEmptySibling(137))
}
22 changes: 10 additions & 12 deletions go-sdk/internal/sparse-merkle-tree/smt/smt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,22 +131,20 @@ func TestGenerateProof(t *testing.T) {
assert.NoError(t, err)

target1 := node1.Index().BigInt()
proof1, foundValue1, err := mt.GenerateProof(target1, mt.Root())
assert.NoError(t, err)
assert.Equal(t, target1, foundValue1)
assert.True(t, proof1.(*proof).existence)
valid := VerifyProof(mt.Root(), proof1, node1)
assert.True(t, valid)

utxo3 := node.NewFungible(big.NewInt(10), alice.PublicKey, big.NewInt(12347))
node3, err := node.NewLeafNode(utxo3)
assert.NoError(t, err)
target2 := node3.Index().BigInt()
proof2, _, err := mt.GenerateProof(target2, mt.Root())
proofs, foundValues, err := mt.GenerateProofs([]*big.Int{target1, target2}, mt.Root())
assert.NoError(t, err)
assert.False(t, proof2.(*proof).existence)
assert.Equal(t, target1, foundValues[0])
assert.True(t, proofs[0].(*proof).existence)
valid := VerifyProof(mt.Root(), proofs[0], node1)
assert.True(t, valid)
assert.False(t, proofs[1].(*proof).existence)

proof3, err := proof1.ToCircomVerifierProof(target1, foundValue1, mt.Root(), levels)
proof3, err := proofs[0].ToCircomVerifierProof(target1, foundValues[0], mt.Root(), levels)
assert.NoError(t, err)
assert.False(t, proof3.IsOld0)
}
Expand Down Expand Up @@ -181,11 +179,11 @@ func TestVerifyProof(t *testing.T) {

target := n.Index().BigInt()
root := mt.Root()
p, _, err := mt.GenerateProof(target, root)
p, _, err := mt.GenerateProofs([]*big.Int{target}, root)
assert.NoError(t, err)
assert.True(t, p.(*proof).existence)
assert.True(t, p[0].(*proof).existence)

valid := VerifyProof(root, p, n)
valid := VerifyProof(root, p[0], n)
assert.True(t, valid)
}()

Expand Down
6 changes: 3 additions & 3 deletions go-sdk/pkg/sparse-merkle-tree/core/merkletree.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ type SparseMerkleTree interface {
// Root returns the root hash of the tree
Root() NodeIndex
// AddLeaf adds a key-value pair to the tree
AddLeaf(Node) error
AddLeaf(leaf Node) error
// GetNode returns the node at the given reference hash
GetNode(NodeIndex) (Node, error)
GetNode(node NodeIndex) (Node, error)
// GetnerateProof generates a proof of existence (or non-existence) of a leaf node
GenerateProof(*big.Int, NodeIndex) (Proof, *big.Int, error)
GenerateProofs(nodeIndexes []*big.Int, root NodeIndex) ([]Proof, []*big.Int, error)
}
Loading