Skip to content

Commit

Permalink
all: refactor trie API
Browse files Browse the repository at this point in the history
In this PR, all TryXXX(e.g. TryGet) APIs of trie are renamed to
XXX(e.g. Get) with an error returned.

The original XXX(e.g. Get) APIs are renamed to MustXXX(e.g. MustGet)
which will panic in case any error occurs.
  • Loading branch information
rjl493456442 committed Apr 15, 2023
1 parent 4a9fa31 commit 225e20b
Show file tree
Hide file tree
Showing 28 changed files with 257 additions and 235 deletions.
3 changes: 2 additions & 1 deletion core/rawdb/accessors_indexes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ func (h *testHasher) Reset() {
h.hasher.Reset()
}

func (h *testHasher) Update(key, val []byte) {
func (h *testHasher) Update(key, val []byte) error {
h.hasher.Write(key)
h.hasher.Write(val)
return nil
}

func (h *testHasher) Hash() common.Hash {
Expand Down
2 changes: 1 addition & 1 deletion core/state/snapshot/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func stackTrieGenerate(db ethdb.KeyValueWriter, scheme string, owner common.Hash
}
t := trie.NewStackTrieWithOwner(nodeWriter, owner)
for leaf := range in {
t.TryUpdate(leaf.key[:], leaf.value)
t.Update(leaf.key[:], leaf.value)
}
var root common.Hash
if db == nil {
Expand Down
2 changes: 1 addition & 1 deletion core/state/snapshot/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func (dl *diskLayer) proveRange(ctx *generatorContext, trieId *trie.ID, prefix [
if origin == nil && !diskMore {
stackTr := trie.NewStackTrie(nil)
for i, key := range keys {
stackTr.TryUpdate(key, vals[i])
stackTr.Update(key, vals[i])
}
if gotRoot := stackTr.Hash(); gotRoot != root {
return &proofResult{
Expand Down
14 changes: 7 additions & 7 deletions core/state/snapshot/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func newHelper() *testHelper {

func (t *testHelper) addTrieAccount(acckey string, acc *Account) {
val, _ := rlp.EncodeToBytes(acc)
t.accTrie.Update([]byte(acckey), val)
t.accTrie.MustUpdate([]byte(acckey), val)
}

func (t *testHelper) addSnapAccount(acckey string, acc *Account) {
Expand All @@ -186,7 +186,7 @@ func (t *testHelper) makeStorageTrie(stateRoot, owner common.Hash, keys []string
id := trie.StorageTrieID(stateRoot, owner, common.Hash{})
stTrie, _ := trie.NewStateTrie(id, t.triedb)
for i, k := range keys {
stTrie.Update([]byte(k), []byte(vals[i]))
stTrie.MustUpdate([]byte(k), []byte(vals[i]))
}
if !commit {
return stTrie.Hash().Bytes()
Expand Down Expand Up @@ -491,7 +491,7 @@ func TestGenerateWithExtraAccounts(t *testing.T) {
)
acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e

// Identical in the snap
key := hashData([]byte("acc-1"))
Expand Down Expand Up @@ -562,7 +562,7 @@ func TestGenerateWithManyExtraAccounts(t *testing.T) {
)
acc := &Account{Balance: big.NewInt(1), Root: stRoot, CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e
helper.accTrie.MustUpdate([]byte("acc-1"), val) // 0x9250573b9c18c664139f3b6a7a8081b7d8f8916a8fcc5d94feec6c29f5fd4e9e

// Identical in the snap
key := hashData([]byte("acc-1"))
Expand Down Expand Up @@ -613,8 +613,8 @@ func TestGenerateWithExtraBeforeAndAfter(t *testing.T) {
{
acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val)
helper.accTrie.Update(common.HexToHash("0x07").Bytes(), val)
helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val)
helper.accTrie.MustUpdate(common.HexToHash("0x07").Bytes(), val)

rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x01"), val)
rawdb.WriteAccountSnapshot(helper.diskdb, common.HexToHash("0x02"), val)
Expand Down Expand Up @@ -650,7 +650,7 @@ func TestGenerateWithMalformedSnapdata(t *testing.T) {
{
acc := &Account{Balance: big.NewInt(1), Root: types.EmptyRootHash.Bytes(), CodeHash: types.EmptyCodeHash.Bytes()}
val, _ := rlp.EncodeToBytes(acc)
helper.accTrie.Update(common.HexToHash("0x03").Bytes(), val)
helper.accTrie.MustUpdate(common.HexToHash("0x03").Bytes(), val)

junk := make([]byte, 100)
copy(junk, []byte{0xde, 0xad})
Expand Down
6 changes: 3 additions & 3 deletions core/state/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,22 @@ func testIterativeStateSync(t *testing.T, count int, commit bool, bypath bool) {
for i, node := range nodeElements {
if bypath {
if len(node.syncPath) == 1 {
data, _, err := srcTrie.TryGetNode(node.syncPath[0])
data, _, err := srcTrie.GetNode(node.syncPath[0])
if err != nil {
t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[0], err)
}
nodeResults[i] = trie.NodeSyncResult{Path: node.path, Data: data}
} else {
var acc types.StateAccount
if err := rlp.DecodeBytes(srcTrie.Get(node.syncPath[0]), &acc); err != nil {
if err := rlp.DecodeBytes(srcTrie.MustGet(node.syncPath[0]), &acc); err != nil {
t.Fatalf("failed to decode account on path %x: %v", node.syncPath[0], err)
}
id := trie.StorageTrieID(srcRoot, common.BytesToHash(node.syncPath[0]), acc.Root)
stTrie, err := trie.New(id, srcDb.TrieDB())
if err != nil {
t.Fatalf("failed to retriev storage trie for path %x: %v", node.syncPath[1], err)
}
data, _, err := stTrie.TryGetNode(node.syncPath[1])
data, _, err := stTrie.GetNode(node.syncPath[1])
if err != nil {
t.Fatalf("failed to retrieve node data for path %x: %v", node.syncPath[1], err)
}
Expand Down
3 changes: 2 additions & 1 deletion core/types/block_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,10 @@ func (h *testHasher) Reset() {
h.hasher.Reset()
}

func (h *testHasher) Update(key, val []byte) {
func (h *testHasher) Update(key, val []byte) error {
h.hasher.Write(key)
h.hasher.Write(val)
return nil
}

func (h *testHasher) Hash() common.Hash {
Expand Down
5 changes: 4 additions & 1 deletion core/types/hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) {
// This is internal, do not use.
type TrieHasher interface {
Reset()
Update([]byte, []byte)
Update([]byte, []byte) error
Hash() common.Hash
}

Expand Down Expand Up @@ -93,6 +93,9 @@ func DeriveSha(list DerivableList, hasher TrieHasher) common.Hash {
// StackTrie requires values to be inserted in increasing hash order, which is not the
// order that `list` provides hashes in. This insertion sequence ensures that the
// order is correct.
//
// The error returned by hasher is omitted because hasher will produce an incorrect
// hash in case any error occurs.
var indexBuf []byte
for i := 1; i < list.Len() && i <= 0x7f; i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
Expand Down
3 changes: 2 additions & 1 deletion core/types/hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,10 @@ func (d *hashToHumanReadable) Reset() {
d.data = make([]byte, 0)
}

func (d *hashToHumanReadable) Update(i []byte, i2 []byte) {
func (d *hashToHumanReadable) Update(i []byte, i2 []byte) error {
l := fmt.Sprintf("%x %x\n", i, i2)
d.data = append(d.data, []byte(l)...)
return nil
}

func (d *hashToHumanReadable) Hash() common.Hash {
Expand Down
20 changes: 10 additions & 10 deletions eth/protocols/snap/sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
for _, pathset := range paths {
switch len(pathset) {
case 1:
blob, _, err := t.accountTrie.TryGetNode(pathset[0])
blob, _, err := t.accountTrie.GetNode(pathset[0])
if err != nil {
t.logger.Info("Error handling req", "error", err)
break
Expand All @@ -225,7 +225,7 @@ func defaultTrieRequestHandler(t *testPeer, requestId uint64, root common.Hash,
default:
account := t.storageTries[(common.BytesToHash(pathset[0]))]
for _, path := range pathset[1:] {
blob, _, err := account.TryGetNode(path)
blob, _, err := account.GetNode(path)
if err != nil {
t.logger.Info("Error handling req", "error", err)
break
Expand Down Expand Up @@ -1381,7 +1381,7 @@ func makeAccountTrieNoStorage(n int) (string, *trie.Trie, entrySlice) {
})
key := key32(i)
elem := &kv{key, value}
accTrie.Update(elem.k, elem.v)
accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
sort.Sort(entries)
Expand Down Expand Up @@ -1431,7 +1431,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) {
CodeHash: getCodeHash(uint64(i)),
})
elem := &kv{boundaries[i].Bytes(), value}
accTrie.Update(elem.k, elem.v)
accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
// Fill other accounts if required
Expand All @@ -1443,7 +1443,7 @@ func makeBoundaryAccountTrie(n int) (string, *trie.Trie, entrySlice) {
CodeHash: getCodeHash(i),
})
elem := &kv{key32(i), value}
accTrie.Update(elem.k, elem.v)
accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
sort.Sort(entries)
Expand Down Expand Up @@ -1487,7 +1487,7 @@ func makeAccountTrieWithStorageWithUniqueStorage(accounts, slots int, code bool)
CodeHash: codehash,
})
elem := &kv{key, value}
accTrie.Update(elem.k, elem.v)
accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)

storageRoots[common.BytesToHash(key)] = stRoot
Expand Down Expand Up @@ -1551,7 +1551,7 @@ func makeAccountTrieWithStorage(accounts, slots int, code, boundary bool) (strin
CodeHash: codehash,
})
elem := &kv{key, value}
accTrie.Update(elem.k, elem.v)
accTrie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)

// we reuse the same one for all accounts
Expand Down Expand Up @@ -1599,7 +1599,7 @@ func makeStorageTrieWithSeed(owner common.Hash, n, seed uint64, db *trie.Databas
key := crypto.Keccak256Hash(slotKey[:])

elem := &kv{key[:], rlpSlotValue}
trie.Update(elem.k, elem.v)
trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
sort.Sort(entries)
Expand Down Expand Up @@ -1638,7 +1638,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
val := []byte{0xde, 0xad, 0xbe, 0xef}

elem := &kv{key[:], val}
trie.Update(elem.k, elem.v)
trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
// Fill other slots if required
Expand All @@ -1650,7 +1650,7 @@ func makeBoundaryStorageTrie(owner common.Hash, n int, db *trie.Database) (commo
rlpSlotValue, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(slotValue[:]))

elem := &kv{key[:], rlpSlotValue}
trie.Update(elem.k, elem.v)
trie.MustUpdate(elem.k, elem.v)
entries = append(entries, elem)
}
sort.Sort(entries)
Expand Down
2 changes: 1 addition & 1 deletion les/server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ func getAccount(triedb *trie.Database, root, hash common.Hash) (types.StateAccou
if err != nil {
return types.StateAccount{}, err
}
blob, err := trie.TryGet(hash[:])
blob, err := trie.Get(hash[:])
if err != nil {
return types.StateAccount{}, err
}
Expand Down
12 changes: 8 additions & 4 deletions light/postprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ func (c *ChtIndexerBackend) Process(ctx context.Context, header *types.Header) e
var encNumber [8]byte
binary.BigEndian.PutUint64(encNumber[:], num)
data, _ := rlp.EncodeToBytes(ChtNode{hash, td})
c.trie.Update(encNumber[:], data)
return nil
return c.trie.Update(encNumber[:], data)
}

// Commit implements core.ChainIndexerBackend
Expand Down Expand Up @@ -450,10 +449,15 @@ func (b *BloomTrieIndexerBackend) Commit() error {

decompSize += uint64(len(decomp))
compSize += uint64(len(comp))

var terr error
if len(comp) > 0 {
b.trie.Update(encKey[:], comp)
terr = b.trie.Update(encKey[:], comp)
} else {
b.trie.Delete(encKey[:])
terr = b.trie.Delete(encKey[:])
}
if terr != nil {
return terr
}
}
root, nodes := b.trie.Commit(false)
Expand Down
12 changes: 6 additions & 6 deletions light/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (t *odrTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) {
key = crypto.Keccak256(key)
var res []byte
err := t.do(key, func() (err error) {
res, err = t.trie.TryGet(key)
res, err = t.trie.Get(key)
return err
})
return res, err
Expand All @@ -119,7 +119,7 @@ func (t *odrTrie) GetAccount(address common.Address) (*types.StateAccount, error
var res types.StateAccount
key := crypto.Keccak256(address.Bytes())
err := t.do(key, func() (err error) {
value, err := t.trie.TryGet(key)
value, err := t.trie.Get(key)
if err != nil {
return err
}
Expand All @@ -138,29 +138,29 @@ func (t *odrTrie) UpdateAccount(address common.Address, acc *types.StateAccount)
return fmt.Errorf("decoding error in account update: %w", err)
}
return t.do(key, func() error {
return t.trie.TryUpdate(key, value)
return t.trie.Update(key, value)
})
}

func (t *odrTrie) UpdateStorage(_ common.Address, key, value []byte) error {
key = crypto.Keccak256(key)
return t.do(key, func() error {
return t.trie.TryUpdate(key, value)
return t.trie.Update(key, value)
})
}

func (t *odrTrie) DeleteStorage(_ common.Address, key []byte) error {
key = crypto.Keccak256(key)
return t.do(key, func() error {
return t.trie.TryDelete(key)
return t.trie.Delete(key)
})
}

// TryDeleteAccount abstracts an account deletion from the trie.
func (t *odrTrie) DeleteAccount(address common.Address) error {
key := crypto.Keccak256(address.Bytes())
return t.do(key, func() error {
return t.trie.TryDelete(key)
return t.trie.Delete(key)
})
}

Expand Down
4 changes: 2 additions & 2 deletions tests/fuzzers/les/les-fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ func makeTries() (chtTrie *trie.Trie, bloomTrie *trie.Trie, chtKeys, bloomKeys [
// The element in CHT is <big-endian block number> -> <block hash>
key := make([]byte, 8)
binary.BigEndian.PutUint64(key, uint64(i+1))
chtTrie.Update(key, []byte{0x1, 0xf})
chtTrie.MustUpdate(key, []byte{0x1, 0xf})
chtKeys = append(chtKeys, key)

// The element in Bloom trie is <2 byte bit index> + <big-endian block number> -> bloom
key2 := make([]byte, 10)
binary.BigEndian.PutUint64(key2[2:], uint64(i+1))
bloomTrie.Update(key2, []byte{0x2, 0xe})
bloomTrie.MustUpdate(key2, []byte{0x2, 0xe})
bloomKeys = append(bloomKeys, key2)
}
return
Expand Down
6 changes: 3 additions & 3 deletions tests/fuzzers/rangeproof/rangeproof-fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
for i := byte(0); i < byte(size); i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v)
trie.MustUpdate(value.k, value.v)
trie.MustUpdate(value2.k, value2.v)
vals[string(value.k)] = value
vals[string(value2.k)] = value2
}
Expand All @@ -82,7 +82,7 @@ func (f *fuzzer) randomTrie(n int) (*trie.Trie, map[string]*kv) {
k := f.randBytes(32)
v := f.randBytes(20)
value := &kv{k, v, false}
trie.Update(k, v)
trie.MustUpdate(k, v)
vals[string(k)] = value
if f.exhausted {
return nil, nil
Expand Down
6 changes: 3 additions & 3 deletions tests/fuzzers/stacktrie/trie_fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (f *fuzzer) fuzz() int {
}
keys[string(k)] = struct{}{}
vals = append(vals, kv{k: k, v: v})
trieA.Update(k, v)
trieA.MustUpdate(k, v)
useful = true
}
if !useful {
Expand All @@ -195,7 +195,7 @@ func (f *fuzzer) fuzz() int {
if f.debugging {
fmt.Printf("{\"%#x\" , \"%#x\"} // stacktrie.Update\n", kv.k, kv.v)
}
trieB.Update(kv.k, kv.v)
trieB.MustUpdate(kv.k, kv.v)
}
rootB := trieB.Hash()
trieB.Commit()
Expand Down Expand Up @@ -223,7 +223,7 @@ func (f *fuzzer) fuzz() int {
checked int
)
for _, kv := range vals {
trieC.Update(kv.k, kv.v)
trieC.MustUpdate(kv.k, kv.v)
}
rootC, _ := trieC.Commit()
if rootA != rootC {
Expand Down
Loading

0 comments on commit 225e20b

Please sign in to comment.