diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 35fbfab53392..506a14a4c6d8 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -25,6 +25,8 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + "github.com/ethereum/go-ethereum/core/state" ) // precompiledTest defines the input/output pairs for precompiled contract tests. @@ -414,24 +416,32 @@ func (s *statefulPrecompileAccessibleState) Interpreter() *EVMInterpreter { return s.interpreter } -func newState() *statefulPrecompileAccessibleState { - s := new(statefulPrecompileAccessibleState) +func newTestInterpreter() *EVMInterpreter { cfg := Config{} evm := &EVM{} - s.interpreter = NewEVMInterpreter(evm, cfg) - evm.interpreter = s.interpreter + interpreter := NewEVMInterpreter(evm, cfg) + db := rawdb.NewMemoryDatabase() + state, _ := state.New(common.Hash{}, state.NewDatabase(db), nil) + interpreter.evm.StateDB = state + return interpreter +} + +func newTestState() *statefulPrecompileAccessibleState { + s := new(statefulPrecompileAccessibleState) + interpreter := newTestInterpreter() + s.interpreter = interpreter return s } -func verifyCiphertextInTestState(s *statefulPrecompileAccessibleState, value uint64, depth int) (*tfheCiphertext, common.Hash) { +func verifyCiphertextInTestState(interpreter *EVMInterpreter, value uint64, depth int) (*tfheCiphertext, common.Hash) { ct := new(tfheCiphertext) ct.encrypt(value) hash := ct.getHash() - s.interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct} + interpreter.verifiedCiphertexts[hash] = &verifiedCiphertext{depth, ct} return ct, ct.getHash() } -func generateInput(hashes ...common.Hash) []byte { +func toPrecompileInput(hashes ...common.Hash) []byte { ret := make([]byte, 0) for _, hash := range hashes { ret = append(ret, hash.Bytes()...) @@ -442,14 +452,14 @@ func generateInput(hashes ...common.Hash) []byte { func TestFheAdd(t *testing.T) { c := &fheAdd{} depth := 1 - state := newState() + state := newTestState() state.interpreter.evm.depth = depth state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - _, lhs_hash := verifyCiphertextInTestState(state, 1, depth) - _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) - input := generateInput(lhs_hash, rhs_hash) + _, lhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth) + _, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth) + input := toPrecompileInput(lhs_hash, rhs_hash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -467,14 +477,14 @@ func TestFheAdd(t *testing.T) { func TestFheSub(t *testing.T) { c := &fheSub{} depth := 1 - state := newState() + state := newTestState() state.interpreter.evm.depth = depth state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - _, lhs_hash := verifyCiphertextInTestState(state, 2, depth) - _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) - input := generateInput(lhs_hash, rhs_hash) + _, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth) + _, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth) + input := toPrecompileInput(lhs_hash, rhs_hash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -492,16 +502,16 @@ func TestFheSub(t *testing.T) { func TestFheLte(t *testing.T) { c := &fheLte{} depth := 1 - state := newState() + state := newTestState() state.interpreter.evm.depth = depth state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - _, lhs_hash := verifyCiphertextInTestState(state, 2, depth) - _, rhs_hash := verifyCiphertextInTestState(state, 1, depth) + _, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth) + _, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth) // 2 <= 1 - input1 := generateInput(lhs_hash, rhs_hash) + input1 := toPrecompileInput(lhs_hash, rhs_hash) out, err := c.Run(state, addr, addr, input1, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -516,7 +526,7 @@ func TestFheLte(t *testing.T) { } // 1 <= 2 - input2 := generateInput(rhs_hash, lhs_hash) + input2 := toPrecompileInput(rhs_hash, lhs_hash) out, err = c.Run(state, addr, addr, input2, readOnly) if err != nil { t.Fatalf(err.Error()) @@ -533,9 +543,9 @@ func TestFheLte(t *testing.T) { func TestUnknownCiphertextHandle(t *testing.T) { depth := 1 - state := newState() + state := newTestState() state.interpreter.evm.depth = depth - _, hash := verifyCiphertextInTestState(state, 2, depth) + _, hash := verifyCiphertextInTestState(state.interpreter, 2, depth) _, found := getVerifiedCiphertext(state, hash) if !found { @@ -551,10 +561,10 @@ func TestUnknownCiphertextHandle(t *testing.T) { } func TestCiphertextNotVerifiedAtDepth(t *testing.T) { - state := newState() + state := newTestState() state.interpreter.evm.depth = 1 verifiedDepth := 2 - _, hash := verifyCiphertextInTestState(state, 1, verifiedDepth) + _, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth) _, found := getVerifiedCiphertext(state, hash) if found { @@ -563,10 +573,10 @@ func TestCiphertextNotVerifiedAtDepth(t *testing.T) { } func TestCiphertextVerifiedAtFurtherDepth(t *testing.T) { - state := newState() + state := newTestState() state.interpreter.evm.depth = 3 verifiedDepth := 2 - _, hash := verifyCiphertextInTestState(state, 1, verifiedDepth) + _, hash := verifyCiphertextInTestState(state.interpreter, 1, verifiedDepth) _, found := getVerifiedCiphertext(state, hash) if !found { diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index fb0fcc1da49d..04ccda69d05f 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -695,3 +695,267 @@ func TestRandom(t *testing.T) { } } } + +type testContractAddress struct{} + +func (c testContractAddress) Address() common.Address { + return common.Address{} +} + +type testCallerAddress struct{} + +func (c testCallerAddress) Address() common.Address { + addr := common.Address{} + addr[0]++ + return addr +} + +func newTestScopeConext() *ScopeContext { + c := new(ScopeContext) + c.Memory = NewMemory() + c.Memory.Resize(ciphertextSize * 3) + c.Stack = newstack() + c.Contract = NewContract(testCallerAddress{}, testContractAddress{}, big.NewInt(10), 100000) + return c +} + +func uint256FromBig(b *big.Int) *uint256.Int { + value, overflow := uint256.FromBig(b) + if overflow { + panic("overflow") + } + return value +} + +func TestProtectedStorageSstoreSload(t *testing.T) { + pc := uint64(0) + depth := 1 + interpreter := newTestInterpreter() + ct, ctHash := verifyCiphertextInTestState(interpreter, 2, depth) + scope := newTestScopeConext() + loc := uint256.NewInt(10) + value := uint256FromBig(ctHash.Big()) + + // Setup and call SSTORE - it requires a location and a value to set there. + scope.Stack.push(value) + scope.Stack.push(loc) + _, err := opSstore(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // Clear the verified ciphertexts. + interpreter.verifiedCiphertexts = make(map[common.Hash]*verifiedCiphertext) + + // Setup and call SLOAD - it requires a location to load. + scope.Stack.push(loc) + _, err = opSload(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // Expect the ciphertext is verified after SLOAD. + ctAfterSload, found := interpreter.verifiedCiphertexts[ctHash] + if !found { + t.Fatalf("expected ciphertext is verified after sload") + } + if !bytes.Equal(ct.serialize(), ctAfterSload.ciphertext.serialize()) { + t.Fatalf("expected ciphertext after sload is the same as original") + } +} + +func TestProtectedStorageGarbageCollection(t *testing.T) { + pc := uint64(0) + depth := 1 + interpreter := newTestInterpreter() + ct, ctHash := verifyCiphertextInTestState(interpreter, 2, depth) + scope := newTestScopeConext() + loc := uint256.NewInt(10) + value := uint256FromBig(ctHash.Big()) + + // Persist the ciphertext in protected storage. + scope.Stack.push(value) + scope.Stack.push(loc) + _, err := opSstore(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // Make sure ciphertext is persisted to protected storage. + protectedStorage := crypto.CreateProtectedStorageContractAddress(scope.Contract.Address()) + metadata := ciphertextMetadata{} + metadata.deserialize(interpreter.evm.StateDB.GetState(protectedStorage, ctHash)) + if metadata.refCount != 1 { + t.Fatalf("metadata.refcount of ciphertext is not 1") + } + if metadata.length != uint64(len(ct.serialize())) { + t.Fatalf("metadata.length of ciphertext is incorrect") + } + ciphertextLocationsToCheck := (metadata.length + 32 - 1) / 32 + startOfCiphertext := newInt(ctHash[:]) + startOfCiphertext.AddUint64(startOfCiphertext, 1) + ctIdx := startOfCiphertext + foundNonZero := false + for i := uint64(0); i < ciphertextLocationsToCheck; i++ { + c := interpreter.evm.StateDB.GetState(protectedStorage, common.BytesToHash(ctIdx.Bytes())) + u := uint256FromBig(c.Big()) + if !u.IsZero() { + foundNonZero = true + break + } + ctIdx.AddUint64(startOfCiphertext, 1) + } + if !foundNonZero { + t.Fatalf("ciphertext is not persisted to protected storage") + } + + // Overwrite the ciphertext handle with 0. + scope.Stack.push(uint256.NewInt(0)) + scope.Stack.push(loc) + _, err = opSstore(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // Make sure the metadata and the ciphertext are garbage collected from protected storage. + protectedStorageIdx := newInt(ctHash[:]) + foundNonZero = false + for i := uint64(0); i < ciphertextLocationsToCheck; i++ { + c := interpreter.evm.StateDB.GetState(protectedStorage, common.BytesToHash(protectedStorageIdx.Bytes())) + u := uint256FromBig(c.Big()) + if !u.IsZero() { + foundNonZero = true + break + } + ctIdx.AddUint64(startOfCiphertext, 1) + } + if foundNonZero { + t.Fatalf("ciphertext is not garbage collected from protected storage") + } +} + +func TestProtectedStorageSloadDoesNotVerifyNonHandle(t *testing.T) { + pc := uint64(0) + interpreter := newTestInterpreter() + scope := newTestScopeConext() + loc := uint256.NewInt(10) + value := uint256.NewInt(42) + + scope.Stack.push(value) + scope.Stack.push(loc) + _, err := opSstore(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + scope.Stack.push(loc) + _, err = opSload(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // Expect no verified ciphertexts + if len(interpreter.verifiedCiphertexts) != 0 { + t.Fatalf("expected no verified ciphetexts") + } +} + +func TestProtectedStorageSloadAlreadyVerified(t *testing.T) { + pc := uint64(0) + depth := 2 + interpreter := newTestInterpreter() + ct, ctHash := verifyCiphertextInTestState(interpreter, 2, depth) + scope := newTestScopeConext() + loc := uint256.NewInt(10) + value := uint256FromBig(ctHash.Big()) + + scope.Stack.push(value) + scope.Stack.push(loc) + _, err := opSstore(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + + // SLOAD at a further depth. + interpreter.evm.depth = depth + 1 + scope.Stack.push(loc) + _, err = opSload(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + ctAfterSload, found := interpreter.verifiedCiphertexts[ctHash] + if !found { + t.Fatalf("expected ciphertext is verified after sload") + } + if !bytes.Equal(ct.serialize(), ctAfterSload.ciphertext.serialize()) { + t.Fatalf("expected ciphertext after sload is the same as original") + } + if ctAfterSload.depth != depth { + t.Fatalf("expected already verified ciphertext to have the minimum depth of the two depths") + } + + // SLOAD at a smaller depth. + interpreter.evm.depth = depth - 1 + scope.Stack.push(loc) + _, err = opSload(&pc, interpreter, scope) + if err != nil { + t.Fatalf(err.Error()) + } + ctAfterSload, found = interpreter.verifiedCiphertexts[ctHash] + if !found { + t.Fatalf("expected ciphertext is verified after sload") + } + if !bytes.Equal(ct.serialize(), ctAfterSload.ciphertext.serialize()) { + t.Fatalf("expected ciphertext after sload is the same as original") + } + if ctAfterSload.depth != depth-1 { + t.Fatalf("expected already verified ciphertext to have the depth at sload") + } +} + +func TestOpReturnDelegation(t *testing.T) { + pc := uint64(0) + depth := 2 + interpreter := newTestInterpreter() + scope := newTestScopeConext() + ct, ctHash := verifyCiphertextInTestState(interpreter, 2, depth) + + offset := uint256.NewInt(0) + len := uint256.NewInt(32) + scope.Stack.push(len) + scope.Stack.push(offset) + scope.Memory.Set(offset.Uint64(), len.Uint64(), ctHash[:]) + interpreter.evm.depth = depth + opReturn(&pc, interpreter, scope) + ctAfterOp, found := interpreter.verifiedCiphertexts[ctHash] + if !found { + t.Fatalf("expected ciphertext is still verified after the return op") + } + if !bytes.Equal(ct.serialize(), ctAfterOp.ciphertext.serialize()) { + t.Fatalf("expected ciphertext after the return op is the same as original") + } + if ctAfterOp.depth != depth-1 { + t.Fatalf("expected ciphertext depth to be reduced by 1 after the return op") + } +} + +func TestOpReturnRemovesVerificationIfNotReturned(t *testing.T) { + pc := uint64(0) + depth := 2 + interpreter := newTestInterpreter() + scope := newTestScopeConext() + _, ctHash := verifyCiphertextInTestState(interpreter, 2, depth) + + offset := uint256.NewInt(0) + len := uint256.NewInt(32) + scope.Stack.push(len) + scope.Stack.push(offset) + // Set 0s as return. + scope.Memory.Set(offset.Uint64(), len.Uint64(), make([]byte, len.Uint64())) + interpreter.evm.depth = depth + opReturn(&pc, interpreter, scope) + _, found := interpreter.verifiedCiphertexts[ctHash] + if found { + t.Fatalf("expected ciphertext is not verified after the return op") + } +}