Skip to content

Commit

Permalink
Add unit tests for SSTORE, SLOAD and RETURN ops
Browse files Browse the repository at this point in the history
Focus on protected storage and in-memory ciphertext verification.

Also, change code in contract_tests.go such that test code reuse is
possible.
  • Loading branch information
dartdart26 committed Feb 1, 2023
1 parent 83bbc84 commit 6b2a383
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 26 deletions.
62 changes: 36 additions & 26 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"time"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
)

// precompiledTest defines the input/output pairs for precompiled contract tests.
Expand Down Expand Up @@ -414,24 +416,32 @@ func (s *statefulPrecompileAccessibleState) Interpreter() *EVMInterpreter {
return s.interpreter
}

func newState() *statefulPrecompileAccessibleState {
s := new(statefulPrecompileAccessibleState)
func newTestInterpreter() *EVMInterpreter {
cfg := Config{}
evm := &EVM{}
s.interpreter = NewEVMInterpreter(evm, cfg)
evm.interpreter = s.interpreter
interpreter := NewEVMInterpreter(evm, cfg)
db := rawdb.NewMemoryDatabase()
state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil)
interpreter.evm.StateDB = state
return interpreter
}

func newTestState() *statefulPrecompileAccessibleState {
s := new(statefulPrecompileAccessibleState)
interpreter := newTestInterpreter()
s.interpreter = interpreter
return s
}

func verifyCiphertextInTestState(s *statefulPrecompileAccessibleState, value uint64, depth int) (*tfheCiphertext, common.Hash) {
func verifyCiphertextInTestState(interpreter *EVMInterpreter, value uint64, depth int) (*tfheCiphertext, common.Hash) {
ct := new(tfheCiphertext)
ct.encrypt(value)
hash := ct.getHash()
s.interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct}
interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct}
return ct, ct.getHash()
}

func generateInput(hashes ...common.Hash) []byte {
func toPrecompileInput(hashes ...common.Hash) []byte {
ret := make([]byte, 0)
for _, hash := range hashes {
ret = append(ret, hash.Bytes()...)
Expand All @@ -442,14 +452,14 @@ func generateInput(hashes ...common.Hash) []byte {
func TestFheAdd(t *testing.T) {
c := &fheAdd{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 1, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
input := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -467,14 +477,14 @@ func TestFheAdd(t *testing.T) {
func TestFheSub(t *testing.T) {
c := &fheSub{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
input := generateInput(lhs_hash, rhs_hash)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
input := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -492,16 +502,16 @@ func TestFheSub(t *testing.T) {
func TestFheLte(t *testing.T) {
c := &fheLte{}
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state, 1, depth)
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)

// 2 <= 1
input1 := generateInput(lhs_hash, rhs_hash)
input1 := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input1, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -516,7 +526,7 @@ func TestFheLte(t *testing.T) {
}

// 1 <= 2
input2 := generateInput(rhs_hash, lhs_hash)
input2 := toPrecompileInput(rhs_hash, lhs_hash)
out, err = c.Run(state, addr, addr, input2, readOnly)
if err != nil {
t.Fatalf(err.Error())
Expand All @@ -533,9 +543,9 @@ func TestFheLte(t *testing.T) {

func TestUnknownCiphertextHandle(t *testing.T) {
depth := 1
state := newState()
state := newTestState()
state.interpreter.evm.depth = depth
_, hash := verifyCiphertextInTestState(state, 2, depth)
_, hash := verifyCiphertextInTestState(state.interpreter, 2, depth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
Expand All @@ -551,10 +561,10 @@ func TestUnknownCiphertextHandle(t *testing.T) {
}

func TestCiphertextNotVerifiedAtDepth(t *testing.T) {
state := newState()
state := newTestState()
state.interpreter.evm.depth = 1
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)
_, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if found {
Expand All @@ -563,10 +573,10 @@ func TestCiphertextNotVerifiedAtDepth(t *testing.T) {
}

func TestCiphertextVerifiedAtFurtherDepth(t *testing.T) {
state := newState()
state := newTestState()
state.interpreter.evm.depth = 3
verifiedDepth := 2
_, hash := verifyCiphertextInTestState(state, 1, verifiedDepth)
_, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth)

_, found := getVerifiedCiphertext(state, hash)
if !found {
Expand Down
Loading

0 comments on commit 6b2a383

Please sign in to comment.