diff --git a/trie/binary.go b/trie/binary.go index 6dbf13a9ab75..d25193a8e7e4 100644 --- a/trie/binary.go +++ b/trie/binary.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" "github.com/hashicorp/golang-lru" + "github.com/ethereum/go-ethereum/rlp" ) // BinaryNode represents any node in a binary trie. @@ -37,6 +38,8 @@ type BinaryNode interface { HashM4() []byte hash(off int) []byte Commit() error + isLeaf() bool + Value([]byte) interface{} gv(string) (string, string) } @@ -83,6 +86,13 @@ type ( key []byte // TODO split into leaf and branch value []byte + account struct { + Balance *big.Int + Nonce uint64 + Code common.Hash + Storage common.Hash + } + // Used to send (hash, preimage) pairs when hashing CommitCh chan BinaryHashPreimage @@ -137,7 +147,7 @@ var ( ) // TryGet returns the value for a key stored in the trie. -func (bt *BinaryTrie) TryGet(key []byte) ([]byte, error) { +func (bt *BinaryTrie) TryGet(key []byte) (interface{}, error) { bk := newBinKey(key) off := 0 @@ -159,8 +169,8 @@ func (bt *BinaryTrie) TryGet(key []byte) ([]byte, error) { // If it is a leaf node, then the a leaf node // has been reached, and the value can be returned // right away. - if currentNode.value != nil { - return currentNode.value, nil + if currentNode.isLeaf() { + return currentNode.Value(key), nil } // This node is a fork, get the child node @@ -188,7 +198,7 @@ func (bt *BinaryTrie) ToGraphViz() string { } func newBranchNode(prefix binkey, key []byte, value []byte, ht hashType) *branch { - return &branch{ + br := &branch{ prefix: prefix, left: empty(struct{}{}), right: empty(struct{}{}), @@ -197,6 +207,17 @@ func newBranchNode(prefix binkey, key []byte, value []byte, ht hashType) *branch hType: ht, childCount: 0, } + + if len(key) == 32 { + if err := rlp.DecodeBytes(value, &br.account); err != nil { + panic(err) + } + + // set right node to the storage root hash + br.right = hashBinaryNode(br.account.Storage[:]) + } + + return br } // Hash calculates the hash of an expanded (i.e. not already @@ -225,6 +246,12 @@ func (br *branch) putHasher(hasher *hasher) { } func (br *branch) hash(off int) []byte { + // Special hashing case if this is called + // at account-level + if off+len(br.prefix) > 0 { + return br.hashAccountLevel(off) + } + var hasher *hasher var hash []byte if br.value == nil { @@ -283,6 +310,97 @@ func (br *branch) hash(off int) []byte { return hash } +func (br *branch) hashAccountLevel(off int) []byte { + var hash []byte + var hasher *hasher = br.getHasher() + defer br.putHasher(hasher) + + // Write the balance + hasher.sha.Write(br.key[0:254]) + hasher.sha.Write([]byte{0, 0}) + kh := hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(br.account.Balance.Bytes()) + hash = hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(kh) + hasher.sha.Write(hash) + hash = hasher.sha.Sum(nil) + hasher.sha.Reset() + + // Write the nonce + hasher.sha.Write(br.key[0:254]) + hasher.sha.Write([]byte{0, 1}) + kh = hasher.sha.Sum(nil) + hasher.sha.Reset() + var serialized [32]byte + binary.LittleEndian.PutUint64(serialized[:], br.account.Nonce) + hasher.sha.Write(serialized[:]) + hash = hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(kh) + hasher.sha.Write(hash) + hash = hasher.sha.Sum(nil) + hasher.sha.Reset() + + + // EoA case, add the prefix length + ok := bytes.Equal(br.account.Storage[:], emptyRoot[:]) + if len(br.account.Code) == 0 && ok { + hasher.sha.Write([]byte{255}) // depth is known + hasher.sha.Write(zero32[:31]) + hasher.sha.Write(hash) + hash = hasher.sha.Sum(nil) + return hash + } + + // Contract with storage case + if len(br.account.Code) != 0 && !ok { + // Write the code + hasher.sha.Write(br.key[0:254]) + hasher.sha.Write([]byte{1, 0}) + kh = hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(br.account.Code[:]) + hash2 := hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(kh) + hasher.sha.Write(hash2) + hash2 = hasher.sha.Sum(nil) + hasher.sha.Reset() + + // Write the storage trie + hasher.sha.Write(br.key[0:254]) + hasher.sha.Write([]byte{1, 1}) + kh = hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(br.right.hash(256)[:]) // storage root is right child of account "branch" + hash2 = hasher.sha.Sum(nil) + hasher.sha.Reset() + hasher.sha.Write(kh) + hasher.sha.Write(hash2) + hash2 = hasher.sha.Sum(nil) + hasher.sha.Reset() + + // merge the two sides of the branch at bit #254 + hasher.sha.Write(hash) + hasher.sha.Write(hash2) + hash = hasher.sha.Sum(nil) + + if len(br.prefix) > 0 { + hasher.sha.Reset() + hasher.sha.Write([]byte{254}) // depth is known + hasher.sha.Write(zero32[:31]) + hasher.sha.Write(hash) + hash = hasher.sha.Sum(nil) + } + + return hash + } + + panic("can't hash accounts with either only code or only NewMemoryDatabase") +} + func (br *branch) HashM4() []byte { var hasher *hasher var hash []byte @@ -327,6 +445,41 @@ func (br *branch) HashM4() []byte { return hash } +func (br *branch) isLeaf() bool { + _, ok := br.left.(empty) + return ok +} + +func (br *branch) Value(key []byte) interface{} { + if br.value == nil || (len(key) != 32 && len(key) != 64) { + panic(fmt.Sprintf("trying to get the value of an internal node %d : value = %p", len(key), br.value)) + } + + if len(key) != 32 { + return br.value + } + + switch key[31] & 0x3 { + case 0: + if br.account.Balance == nil { + panic(fmt.Sprintf("nil value at %x", key)) + } + return br.account.Balance + case 1: + return br.account.Nonce + case 2: + return br.account.Code + case 3: + // return the storage root, which is has to be + // recalculated in case it has been updated. + // The root of the storage trie is stored in + // the account level node's right value. + return common.BytesToHash(br.right.hash(256)) + default: + panic("should not be here") + } +} + func (br *branch) gv(path string) (string, string) { me := fmt.Sprintf("br%s", path) var l, r string @@ -542,11 +695,9 @@ func (bt *BinaryTrie) TryUpdate(key, value []byte) error { switch bt.root.(type) { case empty: // This is when the trie hasn't been inserted - // into, so initialize the root as a branch - // node (a value, really). + // into, so initialize the root as a value. bt.root = newBranchNode(bk, key, value, bt.hashType) bt.db.insert(key, value) - return nil case *branch: currentNode = bt.root.(*branch) @@ -617,15 +768,11 @@ func (bt *BinaryTrie) TryUpdate(key, value []byte) error { split := bk[off:].commonLength(currentNode.prefix) // A split is needed - midNode := &branch{ - prefix: currentNode.prefix[split+1:], - left: currentNode.left, - right: currentNode.right, - key: currentNode.key, - value: currentNode.value, - hType: bt.hashType, - parent: currentNode, - } + midNode := newBranchNode(currentNode.prefix[split+1:], currentNode.key, currentNode.value, bt.hashType) + midNode.left = currentNode.left + midNode.right = currentNode.right + midNode.parent = currentNode + currentNode.prefix = currentNode.prefix[:split] currentNode.value = nil childNode := newBranchNode(bk[off+split+1:], key, value, bt.hashType) @@ -703,6 +850,9 @@ func (h hashBinaryNode) gv(path string) (string, string) { return fmt.Sprintf("%s [label=\"H\"]\n", me), me } +func (h hashBinaryNode) isLeaf() bool { panic("calling isLeaf on a hash node") } +func (h hashBinaryNode) Value([]byte) interface{} { panic("calling value on a hash node") } + func (e empty) Hash() []byte { return emptyRoot[:] } @@ -723,6 +873,9 @@ func (e empty) tryGet(key []byte, depth int) ([]byte, error) { return nil, errReadFromEmptyTree } +func (e empty) isLeaf() bool { panic("calling isLeaf on an empty node") } +func (e empty) Value([]byte) interface{} { panic("calling value on an empty node") } + func (e empty) gv(path string) (string, string) { me := fmt.Sprintf("e%s", path) return fmt.Sprintf("%s [label=\"∅\"]\n", me), me diff --git a/trie/binary_test.go b/trie/binary_test.go index 74764ce2fb60..6bced0690062 100644 --- a/trie/binary_test.go +++ b/trie/binary_test.go @@ -18,12 +18,16 @@ package trie import ( "bytes" + "encoding/binary" + "math/big" "math/rand" "testing" "time" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" ) func TestBinaryKeyCreation(t *testing.T) { @@ -182,39 +186,134 @@ func TestBinaryTrieReadEmpty(t *testing.T) { } } +var ( + testAddr0 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") + testAddr1 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000001") + testAddr2 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000002") + testAddr3 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000003") + testAddr4 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000004") + testAddr8 = common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000008") + testAddr11 = common.Hex2Bytes("000000000000000000000000000000000000000000000000000000000000000B") +) + +func int2addr(x int) []byte { + addr := common.Hex2Bytes("0000000000000000000000000000000000000000000000000000000000000000") + binary.BigEndian.PutUint64(addr[24:], uint64(x)) + return addr +} + +type simpleAccount struct { + Balance *big.Int + Nonce uint64 + Code common.Hash + StorageRoot common.Hash +} + +var emptyCodeHash = crypto.Keccak256Hash(nil) + +var aoe = simpleAccount{Balance: big.NewInt(100), Nonce: 1, Code: emptyCodeHash, StorageRoot: emptyRoot} + func TestBinaryTrieReadOneLeaf(t *testing.T) { + payload, err := rlp.EncodeToBytes(aoe) + if err != nil { + t.Fatalf("%v", err) + } trie := NewBinaryTrie() - trie.Update([]byte{0}, []byte{10}) + trie.Update(testAddr0, payload) - v, err := trie.TryGet([]byte{0}) + // Check the balance can be recovered + v, err := trie.TryGet(testAddr0) if err != nil { t.Fatalf("error searching for key 0 in trie, err=%v", err) } - if !bytes.Equal(v, []byte{10}) { - t.Fatalf("could not find correct value %x != 0a", v) + w, ok := v.(*big.Int) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if w.Cmp(aoe.Balance) != 0 { + t.Fatalf("could not find correct value %d != %d", w, aoe.Balance) + } + + // Check the nonce can be recovered + v, err = trie.TryGet(testAddr1) + if err != nil { + t.Fatalf("error searching for key 1 in trie, err=%v", err) + } + x, ok := v.(uint64) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if x != aoe.Nonce { + t.Fatalf("could not find correct value %x != %d", x, aoe.Nonce) + } + + // Check the code can be recovered (and is empty) + v, err = trie.TryGet(testAddr2) + if err != nil { + t.Fatalf("error searching for key 1 in trie, err=%v", err) + } + y, ok := v.(common.Hash) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if !bytes.Equal(y[:], emptyCodeHash[:]) { + t.Fatalf("could not find correct value %x != %x", v, emptyCodeHash) } - _, err = trie.TryGet([]byte{1}) + // Check the root trie can be recovered (and is empty) + v, err = trie.TryGet(testAddr3) + if err != nil { + t.Fatalf("error searching for key 1 in trie, err=%v", err) + } + z, ok := v.(common.Hash) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if !bytes.Equal(z[:], emptyRoot[:]) { + t.Fatalf("could not find correct value %x != %x", v, emptyRoot) + } + + v, err = trie.TryGet(testAddr4) if err != errKeyNotPresent { t.Fatalf("incorrect error received, expected '%v', got '%v'", errKeyNotPresent, err) } } func TestBinaryTrieReadOneFromManyLeaves(t *testing.T) { + payload, err := rlp.EncodeToBytes(aoe) + if err != nil { + t.Fatalf("%v", err) + } trie := NewBinaryTrie() - trie.Update([]byte{0}, []byte{10}) - trie.Update([]byte{8}, []byte{18}) - trie.Update([]byte{11}, []byte{20}) + trie.Update(testAddr0, payload) + trie.Update(int2addr(8), payload) + trie.Update(int2addr(15), payload) - v, err := trie.TryGet([]byte{0}) + v, err := trie.TryGet(testAddr1) if err != nil { - t.Fatalf("error searching for key 0 in trie, err=%v", err) + t.Fatalf("error searching for key 1 in trie, err=%v", err) } - if !bytes.Equal(v, []byte{10}) { - t.Fatalf("could not find correct value %x != 0a", v) + w, ok := v.(uint64) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if w != 1 { + t.Fatalf("could not find correct value %d != %d", v, aoe.Nonce) + } + + v, err = trie.TryGet(int2addr(9)) + if err != nil { + t.Fatalf("error searching for key 9 in trie, err=%v", err) + } + w, ok = v.(uint64) + if !ok { + t.Fatalf("did not recover proper value type: %v", v) + } + if w != 1 { + t.Fatalf("could not find correct value %d != %d", v, aoe.Nonce) } - _, err = trie.TryGet([]byte{1}) + _, err = trie.TryGet(testAddr4) if err != errKeyNotPresent { t.Fatalf("incorrect error received, expected '%v', got '%v'", errKeyNotPresent, err) } @@ -241,6 +340,19 @@ func TestBinaryTrieNodeResolution(t *testing.T) { if len(trie.db.dirties) != 2 { t.Fatalf("invalid number of dirty account entries after insert, %d != 2", len(trie.db.dirties)) } + got := trie.Hash() + + // Insert all the values in a live trie to make sure + // the root hashes are the same. + trieref := NewBinaryTrie() + trieref.Update(key1, []byte{10}) + trieref.Update(key2, []byte{10}) + trieref.Update(key3, []byte{10}) + exp := trieref.Hash() + + if !bytes.Equal(got[:], exp[:]) { + t.Fatalf("invalid root %x != %x", got, exp) + } } func BenchmarkTrieHash(b *testing.B) { diff --git a/trie/binkey.go b/trie/binkey.go index 42194534144b..9a8123731f82 100644 --- a/trie/binkey.go +++ b/trie/binkey.go @@ -48,6 +48,14 @@ func (b binkey) commonLength(other binkey) int { } return length } + +// Compare the prefix by the number of bytes; there is +// a twist for bit #254 and #255, for which 4 out of 5 +// nodes are grouped into one. func (b binkey) samePrefix(other binkey, off int) bool { - return bytes.Equal(b[off:off+len(other)], other[:]) + var boundary = off + len(other) + if boundary >= 255 && boundary <= 256 { + boundary = 254 + } + return bytes.Equal(b[off:boundary], other[:boundary-off]) }