Skip to content

Commit

Permalink
Merge pull request ethereum#36 from zama-ai/petar/sstore-sload-tests
Browse files Browse the repository at this point in the history
Add unit tests for SSTORE, SLOAD and RETURN ops
  • Loading branch information
dartdart26 authored Feb 1, 2023
2 parents 83bbc84 + 6b2a383 commit 69631db
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 26 deletions.
62 changes: 36 additions & 26 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
)

// precompiledTest defines the input/output pairs for precompiled contract tests.
Expand Down Expand Up @@ -414,24 +416,32 @@ func (s *statefulPrecompileAccessibleState) Interpreter() *EVMInterpreter {
return s.interpreter
}

func newState() *statefulPrecompileAccessibleState {
s := new(statefulPrecompileAccessibleState)
func newTestInterpreter() *EVMInterpreter {
cfg := Config{}
evm := &EVM{}
s.interpreter = NewEVMInterpreter(evm, cfg)
evm.interpreter = s.interpreter
interpreter := NewEVMInterpreter(evm, cfg)
db := rawdb.NewMemoryDatabase()
state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
interpreter.evm.StateDB = state
return interpreter
}

func newTestState() *statefulPrecompileAccessibleState {
s := new(statefulPrecompileAccessibleState)
interpreter := newTestInterpreter()
s.interpreter = interpreter
return s
}

func verifyCiphertextInTestState(s *statefulPrecompileAccessibleState, value uint64, depth int) (*tfheCiphertext, common.Hash) {
func verifyCiphertextInTestState(interpreter *EVMInterpreter, value uint64, depth int) (*tfheCiphertext, common.Hash) {
ct := new(tfheCiphertext)
ct.encrypt(value)
hash := ct.getHash()
s.interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct}
interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct}
return ct, ct.getHash()
}

func generateInput(hashes ...common.Hash) []byte {
func toPrecompileInput(hashes ...common.Hash) []byte {
ret := make([]byte, 0)
for _, hash := range hashes {
ret = append(ret, hash.Bytes()...)
Expand All @@ -442,14 +452,14 @@ func generateInput(hashes ...common.Hash) []byte {
func TestFheAdd(t *testing.T) {
c := &fheAdd{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 1, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
input := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -467,14 +477,14 @@ func TestFheAdd(t *testing.T) {
func TestFheSub(t *testing.T) {
c := &fheSub{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
input := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -492,16 +502,16 @@ func TestFheSub(t *testing.T) {
func TestFheLte(t *testing.T) {
c := &fheLte{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)

// 2 <= 1
input1 := generateInput(lhs_hash, rhs_hash)
input1 := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -516,7 +526,7 @@ func TestFheLte(t *testing.T) {
}

// 1 <= 2
input2 := generateInput(rhs_hash, lhs_hash)
input2 := toPrecompileInput(rhs_hash, lhs_hash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -533,9 +543,9 @@ func TestFheLte(t *testing.T) {

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
_, hash := verifyCiphertextInTestState(state, 2, depth)
_, hash := verifyCiphertextInTestState(state.interpreter, 2, depth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
Expand All @@ -551,10 +561,10 @@ func TestUnknownCiphertextHandle(t *testing.T) {
}

func TestCiphertextNotVerifiedAtDepth(t *testing.T) {
state := newState()
state := newTestState()
state.interpreter.evm.depth = 1
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)
_, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if found {
Expand All @@ -563,10 +573,10 @@ func TestCiphertextNotVerifiedAtDepth(t *testing.T) {
}

func TestCiphertextVerifiedAtFurtherDepth(t *testing.T) {
state := newState()
state := newTestState()
state.interpreter.evm.depth = 3
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)
_, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
Expand Down
Loading

0 comments on commit 69631db

Please sign in to comment.