Skip to content

Commit

Permalink
Recover verified depths for calls to precompiles
Browse files Browse the repository at this point in the history
Rationale is that if a smart contract at depth N calls a precompile by a
CALL opcode (or its variants), there is no corresponding call to the
RETURN opcode. Therefore, the ciphertext will remain delegated to depth
N + 1 and that might not be the smart contract's original intention.

The reason is that if depth N + 1 is EVM bytecode, it will have to go
back to N via RETURN. However, if N + 1 is a precompile, it will return
"immediately" in Go.

In order to solve it, remember the depth set of any verified ciphertext
that has a handle to it in a CALL opcode. Then, when returning from the
CALL, recover the depth set for the remembered ciphertexts before the
call. Rationale is that the CALL is supposed to go back to the same EVM
depth in one way or another.

Add unit tests to check for these scenarios.
  • Loading branch information
dartdart26 committed May 30, 2023
1 parent d0af5bf commit 39fbc04
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 69 deletions.
2 changes: 1 addition & 1 deletion core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, suppliedGas uint64, readOnly bool) (ret []byte, remainingGas uint64, err error) {
if accessibleState.Interpreter().evm.Commit {
accessibleState.Interpreter().evm.Logger.Info("Calling precompiled contract", "callerAddr", caller, "precompile", addr)
accessibleState.Interpreter().evm.Logger.Info("Calling precompile", "callerAddr", caller, "precompile", addr)
}
gasCost := p.RequiredGas(accessibleState, input)
if suppliedGas < gasCost {
Expand Down
22 changes: 11 additions & 11 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ func FheAdd(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -525,7 +525,7 @@ func FheSub(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -562,7 +562,7 @@ func FheMul(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -600,7 +600,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}

// rhs <= lhs
Expand All @@ -615,7 +615,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
}
decrypted = res.ciphertext.decrypt()
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

Expand Down Expand Up @@ -654,7 +654,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}

// rhs < lhs
Expand All @@ -669,7 +669,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
}
decrypted = res.ciphertext.decrypt()
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

Expand Down Expand Up @@ -794,7 +794,7 @@ func TestCiphertextNotAutomaticallyDelegated(t *testing.T) {

ct := getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified")
t.Fatalf("expected that ciphertext is not verified at depth (%d)", state.interpreter.evm.depth)
}
}

Expand All @@ -806,18 +806,18 @@ func TestCiphertextVerificationConditions(t *testing.T) {
state.interpreter.evm.depth = verifiedDepth
ctPtr := getVerifiedCiphertext(state, hash)
if ctPtr == nil {
t.Fatalf("expected that ciphertext is verified at verifiedDepth")
t.Fatalf("expected that ciphertext is verified at verifiedDepth (%d)", verifiedDepth)
}

state.interpreter.evm.depth = verifiedDepth + 1
ct := getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth + 1")
t.Fatalf("expected that ciphertext is not verified at verifiedDepth + 1 (%d)", verifiedDepth+1)
}

state.interpreter.evm.depth = verifiedDepth - 1
ct = getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1")
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1 (%d)", verifiedDepth-1)
}
}
29 changes: 19 additions & 10 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,12 +835,22 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]
}

// If there are ciphertext handles in the arguments to a call, delegate them to the callee.
func delegateCiphertextHandlesInArgs(interpreter *EVMInterpreter, args []byte) {
// Return a map from ciphertext hash -> depthSet before delegation.
func delegateCiphertextHandlesInArgs(interpreter *EVMInterpreter, args []byte) (verified map[common.Hash]*depthSet) {
verified = make(map[common.Hash]*depthSet)
for key, verifiedCiphertext := range interpreter.verifiedCiphertexts {
if contains(args, key.Bytes()) && isVerifiedAtCurrentDepth(interpreter, verifiedCiphertext) {
verified[key] = verifiedCiphertext.verifiedDepths.clone()
verifiedCiphertext.verifiedDepths.add(interpreter.evm.depth + 1)
}
}
return
}

func restoreVerifiedDepths(interpreter *EVMInterpreter, verified map[common.Hash]*depthSet) {
for k, v := range verified {
interpreter.verifiedCiphertexts[k].verifiedDepths = v
}
}

func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {
Expand All @@ -867,15 +877,14 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byt
bigVal = value.ToBig()
}

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.Call(scope.Contract, toAddr, args, gas, bigVal)

if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand Down Expand Up @@ -906,14 +915,14 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([
bigVal = value.ToBig()
}

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, toAddr, args, gas, bigVal)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand All @@ -937,14 +946,14 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand All @@ -968,14 +977,14 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext)
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand Down
Loading

0 comments on commit 39fbc04

Please sign in to comment.