Skip to content

Commit

Permalink
tests: fix tests that requires decryption after kms integration
Browse files Browse the repository at this point in the history
- we now run tests that requires decryption using a local decryption
  instead of calling the kms
- tests that necessarily require a kms to be live were disabled for now
  • Loading branch information
youben11 committed Jan 5, 2024
1 parent c04cf17 commit 4e7ede6
Showing 1 changed file with 120 additions and 34 deletions.
154 changes: 120 additions & 34 deletions fhevm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package fhevm

import (
"bytes"
"encoding/hex"
"errors"
"math/big"
"strings"
"testing"
Expand Down Expand Up @@ -73,6 +76,73 @@ func toPrecompileInputNoScalar(isScalar bool, hashes ...common.Hash) []byte {
return ret
}

func evaluateRemainingOptimisticRequiresWithoutKms(environment EVMEnvironment) (bool, error) {
requires := environment.FhevmData().optimisticRequires
len := len(requires)
defer func() { environment.FhevmData().optimisticRequires = make([]*tfheCiphertext, 0) }()
if len != 0 {
var cumulative *tfheCiphertext = requires[0]
var err error
for i := 1; i < len; i++ {
cumulative, err = cumulative.bitand(requires[i])
if err != nil {
environment.GetLogger().Error("evaluateRemainingOptimisticRequires bitand failed", "err", err)
return false, err
}
}
result, err := cumulative.decrypt()
return result.Uint64() != 0, err
}
return true, nil
}

func decryptRunWithoutKms(environment EVMEnvironment, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
logger := environment.GetLogger()
// if not gas estimation and not view function fail if decryptions are disabled in transactions
if environment.IsCommitting() && !environment.IsEthCall() && environment.FhevmParams().DisableDecryptionsInTransaction {
msg := "decryptions during transaction are disabled"
logger.Error(msg, "input", hex.EncodeToString(input))
return nil, errors.New(msg)
}
if len(input) != 32 {
msg := "decrypt input len must be 32 bytes"
logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input))
return nil, errors.New(msg)
}
ct := getVerifiedCiphertext(environment, common.BytesToHash(input))
if ct == nil {
msg := "decrypt unverified handle"
logger.Error(msg, "input", hex.EncodeToString(input))
return nil, errors.New(msg)
}

// If we are doing gas estimation, skip decryption and make sure we return the maximum possible value.
// We need that, because non-zero bytes cost more than zero bytes in some contexts (e.g. SSTORE or memory operations).
if !environment.IsCommitting() && !environment.IsEthCall() {
return bytes.Repeat([]byte{0xFF}, 32), nil
}
// Make sure we don't decrypt before any optimistic requires are checked.
optReqResult, optReqErr := evaluateRemainingOptimisticRequiresWithoutKms(environment)
if optReqErr != nil {
return nil, optReqErr
} else if !optReqResult {
return nil, ErrExecutionReverted
}

plaintext, err := ct.ciphertext.decrypt()
if err != nil {
logger.Error("decrypt failed", "err", err)
return nil, err
}

logger.Info("decrypt success", "plaintext", plaintext)

// Always return a 32-byte big-endian integer.
ret := make([]byte, 32)
plaintext.FillBytes(ret)
return ret, nil
}

var scalarBytePadding = make([]byte, 31)

func toLibPrecompileInput(method string, isScalar bool, hashes ...common.Hash) []byte {
Expand Down Expand Up @@ -1290,7 +1360,7 @@ func FheLibIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted, err := res.ciphertext.decrypt()
if err != nil || condition == 1 && decrypted.Uint64() != second || condition == 0 && decrypted.Uint64() != third {
if err != nil || condition == 1 && decrypted.Uint64() != second || condition == 0 && decrypted.Uint64() != third {
t.Fatalf("invalid decrypted result, decrypted %v != expected %v", decrypted.Uint64(), 1)
}
}
Expand Down Expand Up @@ -1392,29 +1462,30 @@ func TestLibVerifyCiphertextInvalidType(t *testing.T) {
}
}

func TestLibReencrypt(t *testing.T) {
signature := "reencrypt(uint256,uint256)"
hashRes := crypto.Keccak256([]byte(signature))
signatureBytes := hashRes[0:4]
depth := 1
environment := newTestEVMEnvironment()
environment.depth = depth
environment.ethCall = true
toEncrypt := 7
fheUintType := FheUint8
encCiphertext := verifyCiphertextInTestMemory(environment, uint64(toEncrypt), depth, fheUintType).getHash()
addr := common.Address{}
readOnly := false
input := make([]byte, 0)
input = append(input, signatureBytes...)
input = append(input, encCiphertext.Bytes()...)
// just append twice not to generate public key
input = append(input, encCiphertext.Bytes()...)
_, err := FheLibRun(environment, addr, addr, input, readOnly)
if err != nil {
t.Fatalf("Reencrypt error: %s", err.Error())
}
}
// TODO: can be enabled if mocking kms or running a kms during tests
// func TestLibReencrypt(t *testing.T) {
// signature := "reencrypt(uint256,uint256)"
// hashRes := crypto.Keccak256([]byte(signature))
// signatureBytes := hashRes[0:4]
// depth := 1
// environment := newTestEVMEnvironment()
// environment.depth = depth
// environment.ethCall = true
// toEncrypt := 7
// fheUintType := FheUint8
// encCiphertext := verifyCiphertextInTestMemory(environment, uint64(toEncrypt), depth, fheUintType).getHash()
// addr := common.Address{}
// readOnly := false
// input := make([]byte, 0)
// input = append(input, signatureBytes...)
// input = append(input, encCiphertext.Bytes()...)
// // just append twice not to generate public key
// input = append(input, encCiphertext.Bytes()...)
// _, err := FheLibRun(environment, addr, addr, input, readOnly)
// if err != nil {
// t.Fatalf("Reencrypt error: %s", err.Error())
// }
// }

func TestLibCast(t *testing.T) {
signature := "cast(uint256,bytes1)"
Expand Down Expand Up @@ -2398,7 +2469,6 @@ func FheNot(t *testing.T, fheUintType FheUintType, scalar bool) {
}
}


func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
var lhs, rhs uint64
switch fheUintType {
Expand All @@ -2420,7 +2490,7 @@ func FheIfThenElse(t *testing.T, fheUintType FheUintType, condition uint64) {
conditionHash := verifyCiphertextInTestMemory(environment, condition, depth, fheUintType).getHash()
lhsHash := verifyCiphertextInTestMemory(environment, lhs, depth, fheUintType).getHash()
rhsHash := verifyCiphertextInTestMemory(environment, rhs, depth, fheUintType).getHash()

input1 := toPrecompileInputNoScalar(false, conditionHash, lhsHash, rhsHash)
out, err := fheIfThenElseRun(environment, addr, addr, input1, readOnly)
if err != nil {
Expand Down Expand Up @@ -2452,7 +2522,7 @@ func Decrypt(t *testing.T, fheUintType FheUintType) {
addr := common.Address{}
readOnly := false
hash := verifyCiphertextInTestMemory(environment, value, depth, fheUintType).getHash()
out, err := decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
out, err := decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
Expand Down Expand Up @@ -2730,9 +2800,10 @@ func TestFheLibTrivialEncrypt8(t *testing.T) {
LibTrivialEncrypt(t, FheUint8)
}

func TestLibDecrypt8(t *testing.T) {
LibDecrypt(t, FheUint8)
}
// TODO: can be enabled if mocking kms or running a kms during tests
// func TestLibDecrypt8(t *testing.T) {
// LibDecrypt(t, FheUint8)
// }

func TestFheAdd8(t *testing.T) {
FheAdd(t, FheUint8, false)
Expand Down Expand Up @@ -3430,12 +3501,27 @@ func TestFheRandBoundedEthCall(t *testing.T) {
}
}

func EvalRemOptReqWhenStopTokenWithoutKms(env EVMEnvironment) (err error) {
err = nil
// If we are finishing execution (about to go to from depth 1 to depth 0), evaluate
// any remaining optimistic requires.
if env.GetDepth() == 1 {
result, evalErr := evaluateRemainingOptimisticRequiresWithoutKms(env)
if evalErr != nil {
err = evalErr
} else if !result {
err = ErrExecutionReverted
}
}
return err
}

func interpreterRunWithStopContract(environment *MockEVMEnvironment, interpreter *vm.EVMInterpreter, contract *vm.Contract, input []byte, readOnly bool) (ret []byte, err error) {
ret, _ = interpreter.Run(contract, input, readOnly)
// the following functions are meant to be ran from within interpreter.run so we increment depth to emulate that
environment.depth++
RemoveVerifiedCipherextsAtCurrentDepth(environment)
err = EvalRemOptReqWhenStopToken(environment)
err = EvalRemOptReqWhenStopTokenWithoutKms(environment)
environment.depth--
return ret, err
}
Expand Down Expand Up @@ -3622,7 +3708,7 @@ func TestDecryptWithFalseOptimisticRequire(t *testing.T) {
t.Fatalf("require expected output len of 0, got %v", len(out))
}
// Call decrypt and expect it to fail due to the optimistic require being false.
_, err = decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
_, err = decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err == nil {
t.Fatalf("expected decrypt fails due to false optimistic require")
}
Expand All @@ -3647,7 +3733,7 @@ func TestDecryptWithTrueOptimisticRequire(t *testing.T) {
t.Fatalf("require expected output len of 0, got %v", len(out))
}
// Call decrypt and expect it to succeed due to the optimistic require being true.
out, err = decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
out, err = decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err != nil {
t.Fatalf(err.Error())
} else if len(out) != 32 {
Expand All @@ -3670,7 +3756,7 @@ func TestDecryptInTransactionDisabled(t *testing.T) {
readOnly := false
hash := verifyCiphertextInTestMemory(environment, 1, depth, FheUint8).getHash()
// Call decrypt and expect it to fail due to disabling of decryptions during commit
_, err := decryptRun(environment, addr, addr, hash.Bytes(), readOnly)
_, err := decryptRunWithoutKms(environment, addr, addr, hash.Bytes(), readOnly)
if err == nil {
t.Fatalf("expected to error out in test")
} else if err.Error() != "decryptions during transaction are disabled" {
Expand Down

0 comments on commit 4e7ede6

Please sign in to comment.