Skip to content

Commit

Permalink
Merge pull request ethereum#95 from zama-ai/louis-output
Browse files Browse the repository at this point in the history
  • Loading branch information
tremblaythibaultl authored May 24, 2023
2 parents e08adf4 + 41d9bcc commit cbb0464
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 83 deletions.
28 changes: 19 additions & 9 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package vm
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
Expand All @@ -40,6 +41,7 @@ import (
"github.com/ethereum/go-ethereum/params"
"github.com/holiman/uint256"
"github.com/naoina/toml"
"golang.org/x/crypto/nacl/box"
"golang.org/x/crypto/ripemd160"
)

Expand Down Expand Up @@ -1384,13 +1386,19 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad
return ctHash[:], nil
}

func fheEncryptToUserKey(value uint64, userAddress common.Address, t fheUintType) ([]byte, error) {
userPublicKey := strings.ToLower(usersKeysDir + userAddress.Hex())
pks, err := os.ReadFile(userPublicKey)
func classicalPublicKeyEncrypt(value *big.Int, userPublicKey []byte) ([]byte, error) {
encrypted, err := box.SealAnonymous(nil, value.Bytes(), (*[32]byte)(userPublicKey), rand.Reader)
if err != nil {
return nil, err
}
return encrypted, nil
}

func encryptToUserKey(value *big.Int, pubKey []byte) ([]byte, error) {
ct, err := classicalPublicKeyEncrypt(value, pubKey)
if err != nil {
return nil, err
}
ct := publicKeyEncrypt(pks, value, t)

// TODO: for testing
err = os.WriteFile("/tmp/public_encrypt_result", ct, 0644)
Expand Down Expand Up @@ -1479,15 +1487,16 @@ func (e *reencrypt) Run(accessibleState PrecompileAccessibleState, caller common
accessibleState.Interpreter().evm.Logger.Error(msg)
return nil, errors.New(msg)
}
if len(input) != 32 {
msg := "reencrypt input len must be 32 bytes"
if len(input) != 64 {
msg := "reencrypt input len must be 64 bytes)"
accessibleState.Interpreter().evm.Logger.Error(msg, "input", hex.EncodeToString(input), "len", len(input))
return nil, errors.New(msg)
}
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input))
ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if ct != nil {
decryptedValue := ct.ciphertext.decrypt()
reencryptedValue, err := fheEncryptToUserKey(decryptedValue, accessibleState.Interpreter().evm.Origin, ct.ciphertext.fheUintType)
pubKey := input[32:64]
reencryptedValue, err := encryptToUserKey(&decryptedValue, pubKey)
if err != nil {
accessibleState.Interpreter().evm.Logger.Error("reencrypt failed to encrypt to user key", "err", err)
return nil, err
Expand Down Expand Up @@ -1534,7 +1543,8 @@ func requireURL(key *string) string {
// Returns the require value.
func putRequire(ct *tfheCiphertext, interpreter *EVMInterpreter) bool {
ciphertext := ct.serialize()
value := (ct.decrypt() != 0)
plaintext := ct.decrypt()
value := (plaintext.BitLen() != 0)
key := requireKey(ciphertext)
j, err := json.Marshal(requireMessage{value, signRequire(ciphertext, value)})
if err != nil {
Expand Down
16 changes: 8 additions & 8 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ func newTestState() *statefulPrecompileAccessibleState {

func verifyCiphertextInTestMemory(interpreter *EVMInterpreter, value uint64, depth int, t fheUintType) *tfheCiphertext {
ct := new(tfheCiphertext)
ct.encrypt(value, t)
ct.encrypt(*new(big.Int).SetUint64(value), t)
return verifyTfheCiphertextInTestMemory(interpreter, ct, depth)
}

Expand Down Expand Up @@ -486,7 +486,7 @@ func FheAdd(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != uint64(expected) {
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
}
}
Expand Down Expand Up @@ -524,7 +524,7 @@ func FheSub(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != expected {
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
}
}
Expand Down Expand Up @@ -562,7 +562,7 @@ func FheMul(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != expected {
if decrypted.Uint64() != expected {
t.Fatalf("invalid decrypted result")
}
}
Expand Down Expand Up @@ -601,7 +601,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 0 {
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
}

Expand All @@ -616,7 +616,7 @@ func FheLte(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted = res.ciphertext.decrypt()
if decrypted != 1 {
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
}
}
Expand Down Expand Up @@ -656,7 +656,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 0 {
if decrypted.Uint64() != 0 {
t.Fatalf("invalid decrypted result")
}

Expand All @@ -671,7 +671,7 @@ func FheLt(t *testing.T, fheUintType fheUintType) {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted = res.ciphertext.decrypt()
if decrypted != 1 {
if decrypted.Uint64() != 1 {
t.Fatalf("invalid decrypted result")
}
}
Expand Down
27 changes: 19 additions & 8 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,10 @@ import "C"

import (
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"math/big"
"os"
"runtime"
"sync/atomic"
Expand Down Expand Up @@ -509,7 +511,7 @@ type tfheCiphertext struct {
ptr unsafe.Pointer
serialization []byte
hash []byte
value *uint64
value *big.Int
random bool
fheUintType fheUintType
}
Expand All @@ -536,17 +538,26 @@ func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error {
return nil
}

func (ct *tfheCiphertext) encrypt(value uint64, t fheUintType) {
func (ct *tfheCiphertext) encrypt(value big.Int, t fheUintType) {
if ct.initialized() {
panic("cannot encrypt to an existing ciphertext")
}

switch t {
case FheUint8:
ct.setPtr(C.client_key_encrypt_fhe_uint8(cks, C.uchar(value)))
valBytes := [1]byte{}
value.FillBytes(valBytes[:])
ct.setPtr(C.client_key_encrypt_fhe_uint8(cks, C.uchar(valBytes[len(valBytes)-1])))
case FheUint16:
ct.setPtr(C.client_key_encrypt_fhe_uint16(cks, C.ushort(value)))
valBytes := [2]byte{}
value.FillBytes(valBytes[:])
var valInt uint16 = binary.BigEndian.Uint16(valBytes[:])
ct.setPtr(C.client_key_encrypt_fhe_uint16(cks, C.ushort(valInt)))
case FheUint32:
ct.setPtr(C.client_key_encrypt_fhe_uint32(cks, C.uint(value)))
valBytes := [4]byte{}
value.FillBytes(valBytes[:])
var valInt uint32 = binary.BigEndian.Uint32(valBytes[:])
ct.setPtr(C.client_key_encrypt_fhe_uint32(cks, C.uint(valInt)))
}
ct.fheUintType = t
ct.value = &value
Expand Down Expand Up @@ -700,7 +711,7 @@ func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) {
return res, nil
}

func (ct *tfheCiphertext) decrypt() uint64 {
func (ct *tfheCiphertext) decrypt() big.Int {
if !ct.availableForOps() {
panic("cannot decrypt a null ciphertext")
} else if ct.value != nil {
Expand All @@ -715,8 +726,8 @@ func (ct *tfheCiphertext) decrypt() uint64 {
case FheUint32:
value = uint64(C.decrypt_fhe_uint32(cks, ct.ptr))
}
ct.value = &value
return value
ct.value = new(big.Int).SetUint64(value)
return *ct.value
}

func (ct *tfheCiphertext) setPtr(ptr unsafe.Pointer) {
Expand Down
Loading

0 comments on commit cbb0464

Please sign in to comment.