Skip to content

Commit

Permalink
Delegate handles found in call arguments
Browse files Browse the repository at this point in the history
If a ciphertext handle is part of the arguments to a call opcode (call,
static, delegate, callCode), delegate it such that it is verified for
the callee. Previously, the delegation was implicit for deeper call
depths. This PR makes it explicit, vie either the `delegateCiphertext`
precompile or a call opcode.

Implement by keeping a set of verified stack depths per ciphertext.

Increment the stack depth when calling a precompile such that the depth
is correct inside the precompile itself.

Closes ethereum#52.
  • Loading branch information
dartdart26 committed Mar 16, 2023
1 parent da6a06e commit 21b624e
Show file tree
Hide file tree
Showing 5 changed files with 542 additions and 227 deletions.
169 changes: 90 additions & 79 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ func ActivePrecompiles(rules params.Rules) []common.Address {
// - the _remaining_ gas,
// - any error that occurred
func RunPrecompiledContract(p PrecompiledContract, accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, suppliedGas uint64, readOnly bool) (ret []byte, remainingGas uint64, err error) {
accessibleState.Interpreter().evm.depth++
defer func() { accessibleState.Interpreter().evm.depth-- }()
gasCost := p.RequiredGas(input)
if suppliedGas < gasCost {
return nil, 0, ErrOutOfGas
Expand Down Expand Up @@ -1220,24 +1222,56 @@ func init() {
}
}

func getVerifiedCiphertext(accessibleState PrecompileAccessibleState, ciphertextHash common.Hash) (*tfheCiphertext, bool) {
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[ciphertextHash]
if ok && ct.depth <= accessibleState.Interpreter().evm.depth {
return ct.ciphertext, true
func isVerifiedAtCurrentDepth(interpreter *EVMInterpreter, ct *verifiedCiphertext) bool {
return ct.verifiedDepths.has(interpreter.evm.depth)
}

// Returns a pointer to the ciphertext if the given hash points to a verified ciphertext.
// Else, it returns nil.
func getVerifiedCiphertextFromEVM(interpreter *EVMInterpreter, ciphertextHash common.Hash) *verifiedCiphertext {
ct, ok := interpreter.verifiedCiphertexts[ciphertextHash]
if ok && isVerifiedAtCurrentDepth(interpreter, ct) {
return ct
}
return nil
}

// See getVerifiedCiphertextFromEVM().
func getVerifiedCiphertext(accessibleState PrecompileAccessibleState, ciphertextHash common.Hash) *verifiedCiphertext {
return getVerifiedCiphertextFromEVM(accessibleState.Interpreter(), ciphertextHash)
}

func importCiphertextToEVMAtDepth(interpreter *EVMInterpreter, ct *tfheCiphertext, depth int) *verifiedCiphertext {
existing, ok := interpreter.verifiedCiphertexts[ct.getHash()]
if ok {
existing.verifiedDepths.add(depth)
return existing
} else {
verifiedDepths := newDepthSet()
verifiedDepths.add(depth)
new := &verifiedCiphertext{
verifiedDepths,
ct,
}
interpreter.verifiedCiphertexts[ct.getHash()] = new
return new
}
return nil, false
}

func importCiphertextToEVM(interpreter *EVMInterpreter, ct *tfheCiphertext) *verifiedCiphertext {
return importCiphertextToEVMAtDepth(interpreter, ct, interpreter.evm.depth)
}

func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphertext) *verifiedCiphertext {
return importCiphertextToEVM(accessibleState.Interpreter(), ct)
}

// Used when we want to skip FHE computation, e.g. gas estimation.
func importRandomCiphertext(accessibleState PrecompileAccessibleState) []byte {
ct := new(tfheCiphertext)
ct.makeRandom()
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: ct,
}
importCiphertext(accessibleState, ct)
ctHash := ct.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
return ctHash[:]
}

Expand All @@ -1253,12 +1287,12 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

a, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
b, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1267,20 +1301,16 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := a.add(b)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.add(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/add_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/add_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := verifiedCiphertext.ciphertext.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
ctHash := result.getHash()
return ctHash[:], nil
}

Expand Down Expand Up @@ -1362,7 +1392,7 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller
return nil, err
}
ctHash := ct.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = &verifiedCiphertext{accessibleState.Interpreter().evm.depth, ct}
importCiphertext(accessibleState, ct)
return ctHash.Bytes(), nil
}

Expand Down Expand Up @@ -1392,8 +1422,8 @@ func (e *reencrypt) Run(accessibleState PrecompileAccessibleState, caller common
if len(input) != 32 {
return nil, errors.New("invalid ciphertext handle")
}
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[common.BytesToHash(input)]
if ok && ct.depth <= accessibleState.Interpreter().evm.depth {
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input))
if ct != nil {
decryptedValue := ct.ciphertext.decrypt()
reencryptedValue, err := fheEncryptToUserKey(decryptedValue, accessibleState.Interpreter().evm.Origin)
if err != nil {
Expand All @@ -1415,9 +1445,9 @@ func (e *delegateCiphertext) Run(accessibleState PrecompileAccessibleState, call
if len(input) != 32 {
return nil, errors.New("invalid ciphertext handle")
}
ct, ok := accessibleState.Interpreter().verifiedCiphertexts[common.BytesToHash(input)]
if ok {
ct.depth = minInt(ct.depth, accessibleState.Interpreter().evm.depth-1)
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input))
if ct != nil {
ct.verifiedDepths.add(accessibleState.Interpreter().evm.depth + 1)
return nil, nil
}
return nil, errors.New("unverified ciphertext handle")
Expand Down Expand Up @@ -1591,12 +1621,12 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1605,20 +1635,16 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.lte(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.lte(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/lte_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/lte_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1635,12 +1661,12 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1649,20 +1675,16 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.sub(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.sub(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/sub_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/sub_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1679,12 +1701,12 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1693,20 +1715,16 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.mul(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.mul(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/mul_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/mul_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand All @@ -1723,12 +1741,12 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
lhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if lhs == nil {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
rhs := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if rhs == nil {
return nil, errors.New("unverified ciphertext handle")
}

Expand All @@ -1737,20 +1755,16 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.lt(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}
result := lhs.ciphertext.lt(rhs.ciphertext)
importCiphertext(accessibleState, result)

// TODO: for testing
err := os.WriteFile("/tmp/lt_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err := os.WriteFile("/tmp/lt_result", result.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
Expand Down Expand Up @@ -1822,17 +1836,14 @@ func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.A
randInt := binary.BigEndian.Uint64(randBytes) % fheMessageModulus
randCt := new(tfheCiphertext)
randCt.trivialEncrypt(randInt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: randCt,
}
importCiphertext(accessibleState, randCt)

// TODO: for testing
err = os.WriteFile("/tmp/rand_result", verifiedCiphertext.ciphertext.serialize(), 0644)
err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644)
if err != nil {
return nil, err
}
ctHash := randCt.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
return ctHash[:], nil
}

Expand All @@ -1843,6 +1854,6 @@ func (e *faucet) RequiredGas(input []byte) uint64 {
}

func (e *faucet) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
accessibleState.Interpreter().evm.StateDB.AddBalance(common.BytesToAddress(input[0:20]), big.NewInt(10000000000000000000))
accessibleState.Interpreter().evm.StateDB.AddBalance(common.BytesToAddress(input[0:20]), big.NewInt(1000000000000000000))
return input, nil
}
Loading

0 comments on commit 21b624e

Please sign in to comment.