diff --git a/nmt.go b/nmt.go index 9f52d03..63a2e13 100644 --- a/nmt.go +++ b/nmt.go @@ -143,7 +143,7 @@ func New(h hash.Hash, setters ...Option) *NamespacedMerkleTree { // Prove returns a NMT inclusion proof for the leaf at the supplied index. Note // this is not really NMT specific but the tree supports inclusions proofs like // any vanilla Merkle tree. Prove is a thin wrapper around the ProveRange. -// If the supplied index is invalid i.e., if index < 0 or index > len(n.leaves), then Prove returns an ErrInvalidRange error. Any other errors rather than this are irrecoverable and indicate an illegal state of the tree (n). +// If the supplied index is invalid i.e., if index < 0 or index > n.Size(), then Prove returns an ErrInvalidRange error. Any other errors rather than this are irrecoverable and indicate an illegal state of the tree (n). func (n *NamespacedMerkleTree) Prove(index int) (Proof, error) { return n.ProveRange(index, index+1) } @@ -164,7 +164,7 @@ func (n *NamespacedMerkleTree) Prove(index int) (Proof, error) { // generated using a modified version of the namespace hash with a custom // namespace ID range calculation. For more information on this, please refer to // the HashNode method in the Hasher. -// If the supplied (start, end) range is invalid i.e., if start < 0 or end > len(n.leafHashes) or start >= end, +// If the supplied (start, end) range is invalid i.e., if start < 0 or end > n.Size() or start >= end, // then ProveRange returns an ErrInvalidRange error. Any errors rather than ErrInvalidRange are irrecoverable and indicate an illegal state of the tree (n). func (n *NamespacedMerkleTree) ProveRange(start, end int) (Proof, error) { isMaxNsIgnored := n.treeHasher.IsMaxNamespaceIDIgnored() @@ -248,7 +248,7 @@ func (n *NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) { // validateRange validates the range [start, end) against the size of the tree. // start is inclusive and end is non-inclusive. func (n *NamespacedMerkleTree) validateRange(start, end int) error { - if start < 0 || start >= end || end > len(n.leaves) { + if start < 0 || start >= end || end > n.Size() { return ErrInvalidRange } return nil @@ -268,12 +268,12 @@ func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]by } // start, end are indices of leaves in the tree hence they should be within - // the size of the tree i.e., less than or equal to the len(n.leaves) + // the size of the tree i.e., less than or equal to n.Size() // includeNode indicates whether the hash of the current subtree (covering // the supplied range i.e., [start, end)) or one of its constituent subtrees // should be part of the proof recurse = func(start, end int, includeNode bool) ([]byte, error) { - if start >= len(n.leafHashes) { + if start >= n.Size() { return nil, nil } @@ -341,7 +341,7 @@ func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]by return hash, nil } - fullTreeSize := getSplitPoint(len(n.leafHashes)) * 2 + fullTreeSize := getSplitPoint(n.Size()) * 2 if fullTreeSize < 1 { fullTreeSize = 1 } @@ -449,7 +449,7 @@ func (n *NamespacedMerkleTree) Push(namespacedData namespace.PrefixedData) error // Any error returned by this method is irrecoverable and indicate an illegal state of the tree (n). func (n *NamespacedMerkleTree) Root() ([]byte, error) { if n.rawRoot == nil { - res, err := n.computeRoot(0, len(n.leaves)) + res, err := n.computeRoot(0, n.Size()) if err != nil { return nil, err // this should never happen since leaves are validated in the Push method } @@ -484,7 +484,7 @@ func (n *NamespacedMerkleTree) MaxNamespace() (namespace.ID, error) { func (n *NamespacedMerkleTree) computeRoot(start, end int) ([]byte, error) { // in computeRoot, start may be equal to end which indicates an empty tree hence empty root. // Due to this, we need to perform custom range check instead of using validateRange() in which start=end is considered invalid. - if start < 0 || start > end || end > len(n.leaves) { + if start < 0 || start > end || end > n.Size() { return nil, fmt.Errorf("failed to compute root [%d, %d): %w", start, end, ErrInvalidRange) } switch end - start { @@ -533,8 +533,8 @@ func getSplitPoint(length int) int { } func (n *NamespacedMerkleTree) updateNamespaceRanges() { - if len(n.leaves) > 0 { - lastIndex := len(n.leaves) - 1 + if n.Size() > 0 { + lastIndex := n.Size() - 1 lastPushed := n.leaves[lastIndex] lastNsStr := string(lastPushed[:n.treeHasher.NamespaceSize()]) lastRange, found := n.namespaceRanges[lastNsStr] @@ -570,7 +570,7 @@ func (n *NamespacedMerkleTree) validateAndExtractNamespace(ndata namespace.Prefi nID := namespace.ID(ndata[:n.NamespaceSize()]) // ensure pushed data doesn't have a smaller namespace than the previous // one: - curSize := len(n.leaves) + curSize := n.Size() if curSize > 0 { if nID.Less(n.leaves[curSize-1][:nidSize]) { return nil, fmt.Errorf( @@ -615,3 +615,8 @@ func MaxNamespace(hash []byte, size namespace.IDSize) []byte { max := make([]byte, 0, size) return append(max, hash[size:size*2]...) } + +// Size returns the number of leaves in the tree. +func (n *NamespacedMerkleTree) Size() int { + return len(n.leaves) +}