Skip to content

Commit

Permalink
Merge pull request ethereum#22 from zama-ai/petar/fhe-subtract
Browse files Browse the repository at this point in the history
Add support for the `fheSub` precompiled contract
  • Loading branch information
dartdart26 authored Dec 21, 2022
2 parents 87161ee + 060b231 commit 00d6727
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,33 @@ void add_encrypted_integers(BufferView sks_view, BufferView ct1_view, BufferView
destroy_shortint_ciphertext(result_ct);
}
void sub_encrypted_integers(BufferView sks_view, BufferView ct1_view, BufferView ct2_view, Buffer* result)
{
ShortintServerKey *sks = NULL;
ShortintCiphertext *ct1 = NULL;
ShortintCiphertext *ct2 = NULL;
ShortintCiphertext *result_ct = NULL;
int deser_sks_ok = shortint_deserialize_server_key(sks_view, &sks);
assert(deser_sks_ok == 0);
int deser_ct1_ok = shortint_deserialize_ciphertext(ct1_view, &ct1);
assert(deser_ct1_ok == 0);
int deser_ct2_ok = shortint_deserialize_ciphertext(ct2_view, &ct2);
assert(deser_ct2_ok == 0);
int add_ok = shortint_server_key_smart_sub(sks, ct1, ct2, &result_ct);
assert(add_ok == 0);
int ser_ok = shortint_serialize_ciphertext(result_ct, result);
assert(ser_ok == 0);
destroy_shortint_server_key(sks);
destroy_shortint_ciphertext(ct1);
destroy_shortint_ciphertext(ct2);
destroy_shortint_ciphertext(result_ct);
}
void encrypt_integer(BufferView cks_buff_view, uint64_t val, Buffer* ct_buf)
{
Expand Down Expand Up @@ -160,6 +187,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{68}): &delegateCiphertext{},
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
}

// PrecompiledContractsByzantium contains the default set of pre-compiled Ethereum
Expand All @@ -181,6 +209,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{68}): &delegateCiphertext{},
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
}

// PrecompiledContractsIstanbul contains the default set of pre-compiled Ethereum
Expand All @@ -203,6 +232,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{68}): &delegateCiphertext{},
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
}

// PrecompiledContractsBerlin contains the default set of pre-compiled Ethereum
Expand All @@ -225,6 +255,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{68}): &delegateCiphertext{},
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
}

// PrecompiledContractsBLS contains the set of pre-compiled Ethereum
Expand All @@ -247,6 +278,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{68}): &delegateCiphertext{},
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
}

var (
Expand Down Expand Up @@ -1684,3 +1716,71 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext
return ctHash[:], nil
}

type fheSub struct{}

func (e *fheSub) RequiredGas(input []byte) uint64 {
// TODO
return 8
}

func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
if len(input) != 64 {
return nil, errors.New("input needs to contain two 256-bit sized values")
}

verifiedCiphertext1, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}
verifiedCiphertext2, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}

sks, err := os.ReadFile(networkKeysDir + "sks")
if err != nil {
return nil, err
}

cCiphertext1 := C.CBytes(verifiedCiphertext1)
viewCiphertext1 := C.BufferView{
pointer: (*C.uchar)(cCiphertext1),
length: (C.ulong)(len(verifiedCiphertext1)),
}

cCiphertext2 := C.CBytes(verifiedCiphertext2)
viewCiphertext2 := C.BufferView{
pointer: (*C.uchar)(cCiphertext2),
length: (C.ulong)(len(verifiedCiphertext2)),
}

cServerKey := C.CBytes(sks)
viewServerKey := C.BufferView{
pointer: (*C.uchar)(cServerKey),
length: (C.ulong)(len(sks)),
}

result := &C.Buffer{}
C.sub_encrypted_integers(viewServerKey, viewCiphertext1, viewCiphertext2, result)

ctBytes := C.GoBytes(unsafe.Pointer(result.pointer), C.int(result.length))
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: ctBytes,
}

err = os.WriteFile("/tmp/add_result", ctBytes, 0644)
if err != nil {
return nil, err
}

ctHash := crypto.Keccak256Hash(verifiedCiphertext.ciphertext)
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

C.free(cServerKey)
C.free(cCiphertext1)
C.free(cCiphertext2)

return ctHash[:], nil
}

0 comments on commit 00d6727

Please sign in to comment.