Skip to content

Commit

Permalink
implement Gary's suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
gballet committed Aug 3, 2022
1 parent 7b4c185 commit ff32da7
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 70 deletions.
4 changes: 2 additions & 2 deletions core/state/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Trie interface {
// GetKey returns the sha3 preimage of a hashed key that was previously used
// to store a value.
//
// TODO(fjl): remove this when SecureTrie is removed
// TODO(fjl): remove this when StateTrie is removed
GetKey([]byte) []byte

// TryGet returns the value for key stored in the trie. The value bytes must
Expand Down Expand Up @@ -155,7 +155,7 @@ func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
// CopyTrie returns an independent copy of the given trie.
func (db *cachingDB) CopyTrie(t Trie) Trie {
switch t := t.(type) {
case *trie.SecureTrie:
case *trie.StateTrie:
return t.Copy()
default:
panic(fmt.Errorf("unknown trie type %T", t))
Expand Down
2 changes: 1 addition & 1 deletion core/state/snapshot/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func checkSnapRoot(t *testing.T, snap *diskLayer, trieRoot common.Hash) {
type testHelper struct {
diskdb ethdb.Database
triedb *trie.Database
accTrie *trie.SecureTrie
accTrie *trie.StateTrie
}

func newHelper() *testHelper {
Expand Down
6 changes: 3 additions & 3 deletions eth/protocols/snap/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,15 +413,15 @@ func ServiceGetStorageRangesQuery(chain *core.BlockChain, req *GetStorageRangesP
if origin != (common.Hash{}) || (abort && len(storage) > 0) {
// Request started at a non-zero hash or was capped prematurely, add
// the endpoint Merkle proofs
accTrie, err := trie.New(common.Hash{}, req.Root, chain.StateCache().TrieDB())
accTrie, err := trie.NewSecure(common.Hash{}, req.Root, chain.StateCache().TrieDB())
if err != nil {
return nil, nil
}
acc, err := accTrie.TryGetAccount(account[:])
if err != nil {
if err != nil || acc == nil {
return nil, nil
}
stTrie, err := trie.New(account, acc.Root, chain.StateCache().TrieDB())
stTrie, err := trie.NewSecure(account, acc.Root, chain.StateCache().TrieDB())
if err != nil {
return nil, nil
}
Expand Down
14 changes: 10 additions & 4 deletions light/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,18 @@ func (t *odrTrie) TryGet(key []byte) ([]byte, error) {

func (t *odrTrie) TryGetAccount(key []byte) (*types.StateAccount, error) {
key = crypto.Keccak256(key)
var res *types.StateAccount
var res types.StateAccount
err := t.do(key, func() (err error) {
res, err = t.trie.TryGetAccount(key)
return err
value, err := t.trie.TryGet(key)
if err != nil {
return err
}
if value == nil {
return nil
}
return rlp.DecodeBytes(value, &res)
})
return res, err
return &res, err
}

func (t *odrTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
Expand Down
2 changes: 1 addition & 1 deletion trie/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func (l *loggingDb) Close() error {
}

// makeLargeTestTrie create a sample test trie
func makeLargeTestTrie() (*Database, *SecureTrie, *loggingDb) {
func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
// Create an empty trie
logDb := &loggingDb{0, memorydb.New()}
triedb := NewDatabase(logDb)
Expand Down
2 changes: 1 addition & 1 deletion trie/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
// If the trie does not contain a value for key, the returned proof contains all
// nodes of the longest existing prefix of the key (at least the root node), ending
// with the node that proves the absence of the key.
func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error {
func (t *StateTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error {
return t.trie.Prove(key, fromLevel, proofDb)
}

Expand Down
64 changes: 37 additions & 27 deletions trie/secure_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ import (
"github.com/ethereum/go-ethereum/rlp"
)

// SecureTrie wraps a trie with key hashing. In a secure trie, all
// StateTrie wraps a trie with key hashing. In a secure trie, all
// access operations hash the key using keccak256. This prevents
// calling code from creating long chains of nodes that
// increase the access time.
//
// Contrary to a regular trie, a SecureTrie can only be created with
// Contrary to a regular trie, a StateTrie can only be created with
// New and must have an attached database. The database also stores
// the preimage of each key.
//
// SecureTrie is not safe for concurrent use.
type SecureTrie struct {
// StateTrie is not safe for concurrent use.
type StateTrie struct {
trie Trie
preimages *preimageStore
hashKeyBuf [common.HashLength]byte
secKeyCache map[string][]byte
secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch
secKeyCacheOwner *StateTrie // Pointer to self, replace the key cache on mismatch
}

// NewSecure creates a trie with an existing root node from a backing database
Expand All @@ -54,20 +54,20 @@ type SecureTrie struct {
// Loaded nodes are kept around until their 'cache generation' expires.
// A new cache generation is created by each call to Commit.
// cachelimit sets the number of past cache generations to keep.
func NewSecure(owner common.Hash, root common.Hash, db *Database) (*SecureTrie, error) {
func NewSecure(owner common.Hash, root common.Hash, db *Database) (*StateTrie, error) {
if db == nil {
panic("trie.NewSecure called without a database")
}
trie, err := New(owner, root, db)
if err != nil {
return nil, err
}
return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
return &StateTrie{trie: *trie, preimages: db.preimages}, nil
}

// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
func (t *StateTrie) Get(key []byte) []byte {
res, err := t.TryGet(key)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
Expand All @@ -78,23 +78,33 @@ func (t *SecureTrie) Get(key []byte) []byte {
// TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *SecureTrie) TryGet(key []byte) ([]byte, error) {
func (t *StateTrie) TryGet(key []byte) ([]byte, error) {
return t.trie.TryGet(t.hashKey(key))
}

func (t *SecureTrie) TryGetAccount(key []byte) (*types.StateAccount, error) {
return t.trie.TryGetAccount(t.hashKey(key))
func (t *StateTrie) TryGetAccount(key []byte) (*types.StateAccount, error) {
var ret types.StateAccount
res, err := t.trie.TryGet(t.hashKey(key))
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return &ret, err
}
if res == nil {
return nil, nil
}
err = rlp.DecodeBytes(res, &ret)
return &ret, err
}

// TryGetNode attempts to retrieve a trie node by compact-encoded path. It is not
// possible to use keybyte-encoding as the path might contain odd nibbles.
func (t *SecureTrie) TryGetNode(path []byte) ([]byte, int, error) {
func (t *StateTrie) TryGetNode(path []byte) ([]byte, int, error) {
return t.trie.TryGetNode(path)
}

// TryUpdateAccount account will abstract the write of an account to the
// secure trie.
func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
func (t *StateTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
hk := t.hashKey(key)
data, err := rlp.EncodeToBytes(acc)
if err != nil {
Expand All @@ -113,7 +123,7 @@ func (t *SecureTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error
//
// The value bytes must not be modified by the caller while they are
// stored in the trie.
func (t *SecureTrie) Update(key, value []byte) {
func (t *StateTrie) Update(key, value []byte) {
if err := t.TryUpdate(key, value); err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
}
Expand All @@ -127,7 +137,7 @@ func (t *SecureTrie) Update(key, value []byte) {
// stored in the trie.
//
// If a node was not found in the database, a MissingNodeError is returned.
func (t *SecureTrie) TryUpdate(key, value []byte) error {
func (t *StateTrie) TryUpdate(key, value []byte) error {
hk := t.hashKey(key)
err := t.trie.TryUpdate(hk, value)
if err != nil {
Expand All @@ -138,23 +148,23 @@ func (t *SecureTrie) TryUpdate(key, value []byte) error {
}

// Delete removes any existing value for key from the trie.
func (t *SecureTrie) Delete(key []byte) {
func (t *StateTrie) Delete(key []byte) {
if err := t.TryDelete(key); err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
}
}

// TryDelete removes any existing value for key from the trie.
// If a node was not found in the database, a MissingNodeError is returned.
func (t *SecureTrie) TryDelete(key []byte) error {
func (t *StateTrie) TryDelete(key []byte) error {
hk := t.hashKey(key)
delete(t.getSecKeyCache(), string(hk))
return t.trie.TryDelete(hk)
}

// GetKey returns the sha3 preimage of a hashed key that was
// previously used to store a value.
func (t *SecureTrie) GetKey(shaKey []byte) []byte {
func (t *StateTrie) GetKey(shaKey []byte) []byte {
if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
return key
}
Expand All @@ -169,7 +179,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
//
// Committing flushes nodes from memory. Subsequent Get calls will load nodes
// from the database.
func (t *SecureTrie) Commit(onleaf LeafCallback) (common.Hash, int, error) {
func (t *StateTrie) Commit(onleaf LeafCallback) (common.Hash, int, error) {
// Write all the pre-images to the actual disk database
if len(t.getSecKeyCache()) > 0 {
if t.preimages != nil {
Expand All @@ -185,15 +195,15 @@ func (t *SecureTrie) Commit(onleaf LeafCallback) (common.Hash, int, error) {
return t.trie.Commit(onleaf)
}

// Hash returns the root hash of SecureTrie. It does not write to the
// Hash returns the root hash of StateTrie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *SecureTrie) Hash() common.Hash {
func (t *StateTrie) Hash() common.Hash {
return t.trie.Hash()
}

// Copy returns a copy of SecureTrie.
func (t *SecureTrie) Copy() *SecureTrie {
return &SecureTrie{
// Copy returns a copy of StateTrie.
func (t *StateTrie) Copy() *StateTrie {
return &StateTrie{
trie: *t.trie.Copy(),
preimages: t.preimages,
secKeyCache: t.secKeyCache,
Expand All @@ -202,14 +212,14 @@ func (t *SecureTrie) Copy() *SecureTrie {

// NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration
// starts at the key after the given start key.
func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
func (t *StateTrie) NodeIterator(start []byte) NodeIterator {
return t.trie.NodeIterator(start)
}

// hashKey returns the hash of key as an ephemeral buffer.
// The caller must not hold onto the return value because it will become
// invalid on the next call to hashKey or secKey.
func (t *SecureTrie) hashKey(key []byte) []byte {
func (t *StateTrie) hashKey(key []byte) []byte {
h := newHasher(false)
h.sha.Reset()
h.sha.Write(key)
Expand All @@ -221,7 +231,7 @@ func (t *SecureTrie) hashKey(key []byte) []byte {
// getSecKeyCache returns the current secure key cache, creating a new one if
// ownership changed (i.e. the current secure trie is a copy of another owning
// the actual cache).
func (t *SecureTrie) getSecKeyCache() map[string][]byte {
func (t *StateTrie) getSecKeyCache() map[string][]byte {
if t != t.secKeyCacheOwner {
t.secKeyCacheOwner = t
t.secKeyCache = make(map[string][]byte)
Expand Down
12 changes: 6 additions & 6 deletions trie/secure_trie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import (
"github.com/ethereum/go-ethereum/ethdb/memorydb"
)

func newEmptySecure() *SecureTrie {
func newEmptySecure() *StateTrie {
trie, _ := NewSecure(common.Hash{}, common.Hash{}, NewDatabase(memorydb.New()))
return trie
}

// makeTestSecureTrie creates a large enough secure trie for testing.
func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
// makeTestStateTrie creates a large enough secure trie for testing.
func makeTestStateTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(memorydb.New())
trie, _ := NewSecure(common.Hash{}, common.Hash{}, triedb)
Expand Down Expand Up @@ -105,12 +105,12 @@ func TestSecureGetKey(t *testing.T) {
}
}

func TestSecureTrieConcurrency(t *testing.T) {
func TestStateTrieConcurrency(t *testing.T) {
// Create an initial trie and copy if for concurrent access
_, trie, _ := makeTestSecureTrie()
_, trie, _ := makeTestStateTrie()

threads := runtime.NumCPU()
tries := make([]*SecureTrie, threads)
tries := make([]*StateTrie, threads)
for i := 0; i < threads; i++ {
tries[i] = trie.Copy()
}
Expand Down
2 changes: 1 addition & 1 deletion trie/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// makeTestTrie create a sample test trie to test node-wise reconstruction.
func makeTestTrie() (*Database, *SecureTrie, map[string][]byte) {
func makeTestTrie() (*Database, *StateTrie, map[string][]byte) {
// Create an empty trie
triedb := NewDatabase(memorydb.New())
trie, _ := NewSecure(common.Hash{}, common.Hash{}, triedb)
Expand Down
24 changes: 0 additions & 24 deletions trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)

var (
Expand Down Expand Up @@ -154,20 +152,6 @@ func (t *Trie) Get(key []byte) []byte {
return res
}

func (t *Trie) TryGetAccount(key []byte) (*types.StateAccount, error) {
res, err := t.TryGet(key)
if err != nil {
log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
return &types.StateAccount{}, err
}
if res == nil {
return nil, nil
}
var ret types.StateAccount
err = rlp.DecodeBytes(res, &ret)
return &ret, err
}

// TryGet returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
// If a node was not found in the database, a MissingNodeError is returned.
Expand Down Expand Up @@ -304,14 +288,6 @@ func (t *Trie) Update(key, value []byte) {
}
}

func (t *Trie) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
data, err := rlp.EncodeToBytes(acc)
if err != nil {
return fmt.Errorf("can't encode object at %x: %w", key[:], err)
}
return t.tryUpdate(key, data)
}

// TryUpdate associates key with value in the trie. Subsequent calls to
// Get will return value. If value has length zero, any existing value
// is deleted from the trie and calls to Get will return nil.
Expand Down

0 comments on commit ff32da7

Please sign in to comment.