Skip to content

Commit

Permalink
Merge pull request ethereum#101 from zama-ai/petar/recover-verified-d…
Browse files Browse the repository at this point in the history
…epths-for-precompile-calls

Recover verified depths for calls to precompiles
  • Loading branch information
dartdart26 authored May 31, 2023
2 parents d0af5bf + 39fbc04 commit 5587be3
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 69 deletions.
2 changes: 1 addition & 1 deletion core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, suppliedGas uint64, readOnly bool) (ret []byte, remainingGas uint64, err error) {
if accessibleState.Interpreter().evm.Commit {
accessibleState.Interpreter().evm.Logger.Info("Calling precompiled contract", "callerAddr", caller, "precompile", addr)
accessibleState.Interpreter().evm.Logger.Info("Calling precompile", "callerAddr", caller, "precompile", addr)
}
gasCost := p.RequiredGas(accessibleState, input)
if suppliedGas < gasCost {
Expand Down
22 changes: 11 additions & 11 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ func FheAdd(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -525,7 +525,7 @@ func FheSub(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -562,7 +562,7 @@ func FheMul(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), expected)
}
}

Expand Down Expand Up @@ -600,7 +600,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}

// rhs <= lhs
Expand All @@ -615,7 +615,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
}
decrypted = res.ciphertext.decrypt()
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

Expand Down Expand Up @@ -654,7 +654,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
}
decrypted := res.ciphertext.decrypt()
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 0)
}

// rhs < lhs
Expand All @@ -669,7 +669,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
}
decrypted = res.ciphertext.decrypt()
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}

Expand Down Expand Up @@ -794,7 +794,7 @@ func TestCiphertextNotAutomaticallyDelegated(t *testing.T) {

ct := getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified")
t.Fatalf("expected that ciphertext is not verified at depth (%d)", state.interpreter.evm.depth)
}
}

Expand All @@ -806,18 +806,18 @@ func TestCiphertextVerificationConditions(t *testing.T) {
state.interpreter.evm.depth = verifiedDepth
ctPtr := getVerifiedCiphertext(state, hash)
if ctPtr == nil {
t.Fatalf("expected that ciphertext is verified at verifiedDepth")
t.Fatalf("expected that ciphertext is verified at verifiedDepth (%d)", verifiedDepth)
}

state.interpreter.evm.depth = verifiedDepth + 1
ct := getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth + 1")
t.Fatalf("expected that ciphertext is not verified at verifiedDepth + 1 (%d)", verifiedDepth+1)
}

state.interpreter.evm.depth = verifiedDepth - 1
ct = getVerifiedCiphertext(state, hash)
if ct != nil {
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1")
t.Fatalf("expected that ciphertext is not verified at verifiedDepth - 1 (%d)", verifiedDepth-1)
}
}
29 changes: 19 additions & 10 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -835,12 +835,22 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]
}

// If there are ciphertext handles in the arguments to a call, delegate them to the callee.
func delegateCiphertextHandlesInArgs(interpreter *EVMInterpreter, args []byte) {
// Return a map from ciphertext hash -> depthSet before delegation.
func delegateCiphertextHandlesInArgs(interpreter *EVMInterpreter, args []byte) (verified map[common.Hash]*depthSet) {
verified = make(map[common.Hash]*depthSet)
for key, verifiedCiphertext := range interpreter.verifiedCiphertexts {
if contains(args, key.Bytes()) && isVerifiedAtCurrentDepth(interpreter, verifiedCiphertext) {
verified[key] = verifiedCiphertext.verifiedDepths.clone()
verifiedCiphertext.verifiedDepths.add(interpreter.evm.depth + 1)
}
}
return
}

func restoreVerifiedDepths(interpreter *EVMInterpreter, verified map[common.Hash]*depthSet) {
for k, v := range verified {
interpreter.verifiedCiphertexts[k].verifiedDepths = v
}
}

func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) {
Expand All @@ -867,15 +877,14 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byt
bigVal = value.ToBig()
}

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.Call(scope.Contract, toAddr, args, gas, bigVal)

if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand Down Expand Up @@ -906,14 +915,14 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([
bigVal = value.ToBig()
}

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, toAddr, args, gas, bigVal)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand All @@ -937,14 +946,14 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand All @@ -968,14 +977,14 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext)
// Get arguments from the memory.
args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64()))

delegateCiphertextHandlesInArgs(interpreter, args)

verifiedBefore := delegateCiphertextHandlesInArgs(interpreter, args)
ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, toAddr, args, gas)
if err != nil {
temp.Clear()
} else {
temp.SetOne()
}
restoreVerifiedDepths(interpreter, verifiedBefore)
stack.push(&temp)
if err == nil || err == ErrExecutionReverted {
ret = common.CopyBytes(ret)
Expand Down
Loading

0 comments on commit 5587be3

Please sign in to comment.