Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Decode witness to SMT #1363

Open
wants to merge 7 commits into
base: zkevm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions smt/pkg/db/mdbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"context"
"encoding/hex"
"math/big"

"fmt"
Expand Down Expand Up @@ -304,6 +305,19 @@ func (m *EriRoDb) GetCode(codeHash []byte) ([]byte, error) {
return data, nil
}

func (m *EriDb) AddCode(code []byte) error {
codeHash := utils.HashContractBytecode(hex.EncodeToString(code))

codeHashBytes, err := hex.DecodeString(strings.TrimPrefix(codeHash, "0x"))
if err != nil {
return err
}

codeHashBytes = utils.ResizeHashTo32BytesByPrefixingWithZeroes(codeHashBytes)

return m.tx.Put(kv.Code, codeHashBytes, code)
}

func (m *EriRoDb) PrintDb() {
err := m.kvTxRo.ForEach(TableSmt, []byte{}, func(k, v []byte) error {
println(string(k), string(v))
Expand Down
69 changes: 41 additions & 28 deletions smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,55 @@ import (
"github.com/ledgerwatch/erigon/smt/pkg/utils"
)

// SetAccountState sets the balance and nonce of an account
func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) {
_, err := s.SetAccountBalance(ethAddr, balance)
if err != nil {
return nil, err
}

auxOut, err := s.SetAccountNonce(ethAddr, nonce)
if err != nil {
return nil, err
}

return auxOut, nil
}

// SetAccountBalance sets the balance of an account
func (s *SMT) SetAccountBalance(ethAddr string, balance *big.Int) (*big.Int, error) {
keyBalance := utils.KeyEthAddrBalance(ethAddr)
keyNonce := utils.KeyEthAddrNonce(ethAddr)

if _, err := s.InsertKA(keyBalance, balance); err != nil {
response, err := s.InsertKA(keyBalance, balance)
if err != nil {
return nil, err
}

ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
if err := s.Db.InsertKeySource(keyBalance, ks); err != nil {
err = s.Db.InsertKeySource(keyBalance, ks)
if err != nil {
return nil, err
}

auxRes, err := s.InsertKA(keyNonce, nonce)
return response.NewRootScalar.ToBigInt(), err
}

// SetAccountNonce sets the nonce of an account
func (s *SMT) SetAccountNonce(ethAddr string, nonce *big.Int) (*big.Int, error) {
keyNonce := utils.KeyEthAddrNonce(ethAddr)

response, err := s.InsertKA(keyNonce, nonce)
if err != nil {
return nil, err
}

ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
if err := s.Db.InsertKeySource(keyNonce, ks); err != nil {
ks := utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyNonce, ks)
if err != nil {
return nil, err
}

return auxRes.NewRootScalar.ToBigInt(), nil
return response.NewRootScalar.ToBigInt(), nil
}

func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error {
Expand Down Expand Up @@ -80,13 +105,7 @@ func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {

ks = utils.EncodeKeySource(utils.SC_LENGTH, utils.ConvertHexToAddress(ethAddr), common.Hash{})

err = s.Db.InsertKeySource(keyContractLength, ks)

if err != nil {
return err
}

return err
return s.Db.InsertKeySource(keyContractLength, ks)
}

func (s *SMT) SetContractStorage(ethAddr string, storage map[string]string, progressChan chan uint64) (*big.Int, error) {
Expand Down Expand Up @@ -203,7 +222,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, acc := range accChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}
ethAddr := addr.String()
Expand Down Expand Up @@ -250,7 +269,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, code := range codeChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}

Expand Down Expand Up @@ -295,7 +314,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, storage := range storageChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}
ethAddr := addr.String()
Expand All @@ -304,7 +323,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l

for k, v := range storage {
keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k)
valueBigInt := convertStrintToBigInt(v)
valueBigInt := convertStringToBigInt(v)
keysBatchStorage = append(keysBatchStorage, &keyStoragePosition)
if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil {
return nil, nil, err
Expand Down Expand Up @@ -341,7 +360,7 @@ func (s *SMT) DeleteKeySource(nodeKey *utils.NodeKey) error {
}

func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
val := convertStrintToBigInt(v)
val := convertStringToBigInt(v)

x := utils.ScalarToArrayBig(val)
value, err := utils.NodeValue8FromBigIntArray(x)
Expand All @@ -354,10 +373,10 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
return value, h, nil
}

func convertStrintToBigInt(v string) *big.Int {
func convertStringToBigInt(v string) *big.Int {
base := 10
if strings.HasPrefix(v, "0x") {
v = v[2:]
v = strings.TrimPrefix(v, "0x")
base = 16
}

Expand All @@ -374,14 +393,8 @@ func appendToValuesBatchStorageBigInt(valuesBatchStorage []*utils.NodeValue8, va
}

func convertBytecodeToBigInt(bytecode string) (*big.Int, int, error) {
var parsedBytecode string
bi := utils.HashContractBytecodeBigInt(bytecode)

if strings.HasPrefix(bytecode, "0x") {
parsedBytecode = bytecode[2:]
} else {
parsedBytecode = bytecode
}
parsedBytecode := strings.TrimPrefix(bytecode, "0x")

if len(parsedBytecode)%2 != 0 {
parsedBytecode = "0" + parsedBytecode
Expand Down
94 changes: 92 additions & 2 deletions smt/pkg/smt/smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type DB interface {
InsertKeySource(key utils.NodeKey, value []byte) error
DeleteKeySource(key utils.NodeKey) error
InsertHashKey(key utils.NodeKey, value utils.NodeKey) error
AddCode(code []byte) error
DeleteHashKey(key utils.NodeKey) error
Delete(string) error
DeleteByNodeKey(key utils.NodeKey) error
Expand Down Expand Up @@ -297,7 +298,9 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old
if err != nil {
return nil, err
}
s.Db.InsertHashKey(newLeafHash, k)
if err := s.Db.InsertHashKey(newLeafHash, k); err != nil {
return nil, err
}
if level >= 0 {
for j := 0; j < 4; j++ {
siblings[level][keys[level]*4+j] = new(big.Int).SetUint64(newLeafHash[j])
Expand Down Expand Up @@ -649,7 +652,7 @@ func (s *SMT) updateDepth(newDepth int) {

newDepthAsByte := byte(newDepth & 0xFF)
if oldDepth < newDepthAsByte {
s.Db.SetDepth(newDepthAsByte)
_ = s.Db.SetDepth(newDepthAsByte)
}
}

Expand Down Expand Up @@ -728,3 +731,90 @@ func (s *RoSMT) traverseAndMark(ctx context.Context, node *big.Int, visited Visi
return true, nil
})
}

// InsertHashNode inserts a hash node into the SMT. The SMT should not contain any other leaf nodes with the same path prefix. Otherwise, the new root hash will be incorrect.
// TODO: Support insertion of hash nodes even if there are leaf nodes with the same path prefix in SMT.
func (s *SMT) InsertHashNode(path []int, hash *big.Int) (*big.Int, error) {
s.clearUpMutex.Lock()
defer s.clearUpMutex.Unlock()

or, err := s.getLastRoot()
if err != nil {
return nil, err
}

h := utils.ScalarToArray(hash)

var nodeHash [4]uint64
copy(nodeHash[:], h[:4])

lastRoot, err := s.insertHashNode(path, nodeHash, or)
if err != nil {
return nil, err
}

if err = s.setLastRoot(lastRoot); err != nil {
return nil, err
}

return lastRoot.ToBigInt(), nil
}

func (s *SMT) insertHashNode(path []int, hash [4]uint64, root utils.NodeKey) (utils.NodeKey, error) {
if len(path) == 0 {
newValHBig := utils.ArrayToScalar(hash[:])
v := utils.ScalarToNodeValue8(newValHBig)

err := s.hashSave(v.ToUintArray(), utils.LeafCapacity, hash)
if err != nil {
return utils.NodeKey{}, err
}

return hash, nil
}

rootVal := utils.NodeValue12{}

if !root.IsZero() {
v, err := s.Db.Get(root)
if err != nil {
return utils.NodeKey{}, err
}

rootVal = v
}

childIndex := path[0]

childOldRoot := rootVal[childIndex*4 : childIndex*4+4]

childNewRoot, err := s.insertHashNode(path[1:], hash, utils.NodeKeyFromBigIntArray(childOldRoot))

if err != nil {
return utils.NodeKey{}, err
}

var newIn [8]uint64

emptyRootVal := utils.NodeValue12{}

if childIndex == 0 {
var sibling [4]uint64
if rootVal == emptyRootVal {
sibling = [4]uint64{0, 0, 0, 0}
} else {
sibling = *rootVal.Get4to8()
}
newIn = utils.ConcatArrays4(childNewRoot, sibling)
} else {
var sibling [4]uint64
if rootVal == emptyRootVal {
sibling = [4]uint64{0, 0, 0, 0}
} else {
sibling = *rootVal.Get0to4()
}
newIn = utils.ConcatArrays4(sibling, childNewRoot)
}

return s.hashcalcAndSave(newIn, utils.BranchCapacity)
}
Loading
Loading