Skip to content

Commit

Permalink
Add FHE multiplication as a precompiled contract
Browse files Browse the repository at this point in the history
  • Loading branch information
dartdart26 committed Feb 1, 2023
1 parent 69631db commit 9ed77d6
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 1 deletion.
49 changes: 49 additions & 0 deletions core/vm/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
}

// PrecompiledContractsByzantium contains the default set of pre-compiled Ethereum
Expand All @@ -93,6 +94,7 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
}

// PrecompiledContractsIstanbul contains the default set of pre-compiled Ethereum
Expand All @@ -116,6 +118,7 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
}

// PrecompiledContractsBerlin contains the default set of pre-compiled Ethereum
Expand All @@ -139,6 +142,7 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
}

// PrecompiledContractsBLS contains the set of pre-compiled Ethereum
Expand All @@ -162,6 +166,7 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{
common.BytesToAddress([]byte{69}): &require{},
common.BytesToAddress([]byte{70}): &fheLte{},
common.BytesToAddress([]byte{71}): &fheSub{},
common.BytesToAddress([]byte{72}): &fheMul{},
}

var (
Expand Down Expand Up @@ -1582,3 +1587,47 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad

return ctHash[:], nil
}

type fheMul struct{}

func (e *fheMul) RequiredGas(input []byte) uint64 {
// TODO
return 8
}

func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {
if len(input) != 64 {
return nil, errors.New("input needs to contain two 256-bit sized values")
}

lhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[0:32]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}
rhsCt, exists := getVerifiedCiphertext(accessibleState, common.BytesToHash(input[32:64]))
if !exists {
return nil, errors.New("unverified ciphertext handle")
}

// If we are not committing state, skip execution and insert a random ciphertext as a result.
if !accessibleState.Interpreter().evm.Commit {
return importRandomCiphertext(accessibleState), nil
}

result := lhsCt.mul(rhsCt)
verifiedCiphertext := &verifiedCiphertext{
depth: accessibleState.Interpreter().evm.depth,
ciphertext: result,
}

// TODO: for testing
err := os.WriteFile("/tmp/mul_result", verifiedCiphertext.ciphertext.serialize(), 0644)
if err != nil {
return nil, err
}

ctHash := result.getHash()
accessibleState.Interpreter().verifiedCiphertexts[ctHash] = verifiedCiphertext

return ctHash[:], nil
}
25 changes: 25 additions & 0 deletions core/vm/contracts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,31 @@ func TestFheSub(t *testing.T) {
}
}

func TestFheMul(t *testing.T) {
c := &fheMul{}
depth := 1
state := newTestState()
state.interpreter.evm.depth = depth
state.interpreter.evm.Commit = true
addr := common.Address{}
readOnly := false
_, lhs_hash := verifyCiphertextInTestState(state.interpreter, 2, depth)
_, rhs_hash := verifyCiphertextInTestState(state.interpreter, 1, depth)
input := toPrecompileInput(lhs_hash, rhs_hash)
out, err := c.Run(state, addr, addr, input, readOnly)
if err != nil {
t.Fatalf(err.Error())
}
res, exists := state.interpreter.verifiedCiphertexts[common.BytesToHash(out)]
if !exists {
t.Fatalf("output ciphertext is not found in verifiedCiphertexts")
}
decrypted := res.ciphertext.decrypt()
if decrypted != 2 {
t.Fatalf("invalid decrypted result")
}
}

func TestFheLte(t *testing.T) {
c := &fheLte{}
depth := 1
Expand Down
17 changes: 17 additions & 0 deletions core/vm/tfhe.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ void* tfhe_sub(void* sks, void* ct1, void* ct2)
return result;
}
void* tfhe_mul(void* sks, void* ct1, void* ct2)
{
ShortintCiphertext *result = NULL;
const int r = shortint_bc_server_key_smart_mul(sks, ct1, ct2, &result);
assert(r == 0);
return result;
}
void* tfhe_lte(void* sks, void* ct1, void* ct2)
{
ShortintCiphertext *result = NULL;
Expand Down Expand Up @@ -271,6 +279,15 @@ func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) *tfheCiphertext {
return res
}

func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) *tfheCiphertext {
if !lhs.availableForOps() || !rhs.availableForOps() {
panic("cannot mul on a non-initialized ciphertext")
}
res := new(tfheCiphertext)
res.setPtr(C.tfhe_mul(sks, lhs.ptr, rhs.ptr))
return res
}

func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) *tfheCiphertext {
if !lhs.availableForOps() || !rhs.availableForOps() {
panic("cannot lte on a non-initialized ciphertext")
Expand Down
15 changes: 15 additions & 0 deletions core/vm/tfhe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ func TestTfheSub(t *testing.T) {
}
}

func TestTfheMul(t *testing.T) {
a := uint64(2)
b := uint64(1)
expected := uint64(2)
ctA := new(tfheCiphertext)
ctA.encrypt(a)
ctB := new(tfheCiphertext)
ctB.encrypt(b)
ctRes := ctA.mul(ctB)
res := ctRes.decrypt()
if res != expected {
t.Fatalf("%d != %d", expected, res)
}
}

func TestTfheLte(t *testing.T) {
a := uint64(2)
b := uint64(1)
Expand Down
2 changes: 1 addition & 1 deletion install_thfe_rs_api.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
git clone git@github.com:zama-ai/tfhe-rs.git
mkdir -p core/vm/lib
cd tfhe-rs
git checkout blockchain-demo
git checkout blockchain-demo-deterministic-fft
make build_c_api
cp target/release/libtfhe.* ../core/vm/lib
cp target/release/tfhe.h ../core/vm

0 comments on commit 9ed77d6

Please sign in to comment.