diff --git a/core/vm/contracts.go b/core/vm/contracts.go index e43b0d2b43c8..f90ad9f30297 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -40,7 +40,6 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/holiman/uint256" "github.com/naoina/toml" - "golang.org/x/crypto/chacha20" "golang.org/x/crypto/ripemd160" ) @@ -74,8 +73,9 @@ var PrecompiledContractsHomestead = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, + // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -101,8 +101,9 @@ var PrecompiledContractsByzantium = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, + // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -129,8 +130,9 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, + // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -157,8 +159,9 @@ var PrecompiledContractsBerlin = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, + // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -185,8 +188,9 @@ var PrecompiledContractsBLS = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{71}): &fheSub{}, common.BytesToAddress([]byte{72}): &fheMul{}, common.BytesToAddress([]byte{73}): &fheLt{}, - common.BytesToAddress([]byte{74}): &fheRand{}, + // common.BytesToAddress([]byte{74}): &fheRand{}, common.BytesToAddress([]byte{75}): &optimisticRequire{}, + common.BytesToAddress([]byte{76}): &cast{}, common.BytesToAddress([]byte{99}): &faucet{}, } @@ -1138,12 +1142,6 @@ type tomlConfigOptions struct { RequireRetryCount uint8 } - Zk struct { - Verify bool - VerifyRPCAddress string - VerifyRetryCount uint8 - } - Tfhe struct { CiphertextsToGarbageCollect uint64 CiphertextsGarbageCollectIntervalSecs uint64 @@ -1169,7 +1167,6 @@ func generateEd25519Keys() error { } var requireHttpClient http.Client = http.Client{} -var zkHttpClient http.Client = http.Client{} var publicSignatureKey []byte var privateSignatureKey []byte @@ -1267,9 +1264,9 @@ func importCiphertext(accessibleState PrecompileAccessibleState, ct *tfheCiphert } // Used when we want to skip FHE computation, e.g. gas estimation. -func importRandomCiphertext(accessibleState PrecompileAccessibleState) []byte { +func importRandomCiphertext(accessibleState PrecompileAccessibleState, t fheUintType) []byte { ct := new(tfheCiphertext) - ct.makeRandom() + ct.makeRandom(t) importCiphertext(accessibleState, ct) ctHash := ct.getHash() return ctHash[:] @@ -1298,14 +1295,21 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil + } + + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + return nil, errors.New("only same type ops are supported for now") } - result := lhs.ciphertext.add(rhs.ciphertext) + result, err := lhs.ciphertext.add(rhs.ciphertext) + if err != nil { + return nil, err + } importCiphertext(accessibleState, result) // TODO: for testing - err := os.WriteFile("/tmp/add_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/add_result", result.serialize(), 0644) if err != nil { return nil, err } @@ -1314,13 +1318,13 @@ func (e *fheAdd) Run(accessibleState PrecompileAccessibleState, caller common.Ad return ctHash[:], nil } -func fheEncryptToUserKey(value uint64, userAddress common.Address) ([]byte, error) { +func fheEncryptToUserKey(value uint64, userAddress common.Address, t fheUintType) ([]byte, error) { userPublicKey := strings.ToLower(usersKeysDir + userAddress.Hex()) pks, err := os.ReadFile(userPublicKey) if err != nil { return nil, err } - ct := publicKeyEncrypt(pks, value) + ct := publicKeyEncrypt(pks, value, t) // TODO: for testing err = os.WriteFile("/tmp/public_encrypt_result", ct, 0644) @@ -1342,50 +1346,23 @@ func (e *verifyCiphertext) RequiredGas(input []byte) uint64 { return 8 } -// Returns the verified ciphertext on success or nil on invalid ZK proof. -// Exits the process on errors. -func verifyZkProof(input []byte) []byte { - for try := uint8(1); try <= tomlConfig.Zk.VerifyRetryCount+1; try++ { - req, err := http.NewRequest(http.MethodPost, tomlConfig.Zk.VerifyRPCAddress, bytes.NewReader(input)) - if err != nil { - continue - } - req.Header.Add("Content-Type", "application/msgpack") - resp, err := zkHttpClient.Do(req) - if err != nil { - continue - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if resp.StatusCode == 406 { - // The ZKPoK service returns 406 if the proof is incorrect. - return nil - } else if resp.StatusCode != 200 || err != nil { - continue - } - return body +func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + if len(input) <= 1 { + return nil, errors.New("input needs to contain one 256-bit sized values and one 8-bit sized value") } - exitProcess() - return nil -} -func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + ctBytes := input[:len(input)-1] + ctType := fheUintType(input[len(input)-1]) + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil - } - var ctBytes []byte - if !tomlConfig.Zk.Verify { - // For testing: if input size <= `fheCiphertextSize`, treat the whole input as ciphertext. - ctBytes = input[0:minInt(fheCiphertextSize, len(input))] - } else { - ctBytes = verifyZkProof(input) - if ctBytes == nil { - return nil, fmt.Errorf("invalid ZK Proof") - } + return importRandomCiphertext(accessibleState, ctType), nil } + ct := new(tfheCiphertext) - err := ct.deserialize(ctBytes) + err := ct.deserialize(ctBytes, ctType) + ct.fheUintType = ctType + if err != nil { return nil, err } @@ -1423,7 +1400,7 @@ func (e *reencrypt) Run(accessibleState PrecompileAccessibleState, caller common ct := getVerifiedCiphertext(accessibleState, common.BytesToHash(input)) if ct != nil { decryptedValue := ct.ciphertext.decrypt() - reencryptedValue, err := fheEncryptToUserKey(decryptedValue, accessibleState.Interpreter().evm.Origin) + reencryptedValue, err := fheEncryptToUserKey(decryptedValue, accessibleState.Interpreter().evm.Origin, ct.ciphertext.fheUintType) if err != nil { return nil, err } @@ -1583,7 +1560,11 @@ func (e *optimisticRequire) Run(accessibleState PrecompileAccessibleState, calle if accessibleState.Interpreter().optimisticRequire == nil { accessibleState.Interpreter().optimisticRequire = ct.ciphertext } else { - accessibleState.Interpreter().optimisticRequire = accessibleState.Interpreter().optimisticRequire.mul(ct.ciphertext) + optimisticRequire, err := accessibleState.Interpreter().optimisticRequire.mul(ct.ciphertext) + if err != nil { + return nil, err + } + accessibleState.Interpreter().optimisticRequire = optimisticRequire } return nil, nil } @@ -1609,16 +1590,23 @@ func (e *fheLte) Run(accessibleState PrecompileAccessibleState, caller common.Ad return nil, errors.New("unverified ciphertext handle") } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + return nil, errors.New("only same type ops are supported for now") + } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil } - result := lhs.ciphertext.lte(rhs.ciphertext) + result, err := lhs.ciphertext.lte(rhs.ciphertext) + if err != nil { + return nil, err + } importCiphertext(accessibleState, result) // TODO: for testing - err := os.WriteFile("/tmp/lte_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/lte_result", result.serialize(), 0644) if err != nil { return nil, err } @@ -1649,16 +1637,23 @@ func (e *fheSub) Run(accessibleState PrecompileAccessibleState, caller common.Ad return nil, errors.New("unverified ciphertext handle") } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + return nil, errors.New("only same type ops are supported for now") + } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil } - result := lhs.ciphertext.sub(rhs.ciphertext) + result, err := lhs.ciphertext.sub(rhs.ciphertext) + if err != nil { + return nil, err + } importCiphertext(accessibleState, result) // TODO: for testing - err := os.WriteFile("/tmp/sub_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/sub_result", result.serialize(), 0644) if err != nil { return nil, err } @@ -1689,16 +1684,23 @@ func (e *fheMul) Run(accessibleState PrecompileAccessibleState, caller common.Ad return nil, errors.New("unverified ciphertext handle") } + if lhs.ciphertext.fheUintType != rhs.ciphertext.fheUintType { + return nil, errors.New("only same type ops are supported for now") + } + // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil } - result := lhs.ciphertext.mul(rhs.ciphertext) + result, err := lhs.ciphertext.mul(rhs.ciphertext) + if err != nil { + return nil, err + } importCiphertext(accessibleState, result) // TODO: for testing - err := os.WriteFile("/tmp/mul_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/mul_result", result.serialize(), 0644) if err != nil { return nil, err } @@ -1731,14 +1733,17 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil + return importRandomCiphertext(accessibleState, lhs.ciphertext.fheUintType), nil } - result := lhs.ciphertext.lt(rhs.ciphertext) + result, err := lhs.ciphertext.lt(rhs.ciphertext) + if err != nil { + return nil, err + } importCiphertext(accessibleState, result) // TODO: for testing - err := os.WriteFile("/tmp/lt_result", result.serialize(), 0644) + err = os.WriteFile("/tmp/lt_result", result.serialize(), 0644) if err != nil { return nil, err } @@ -1748,82 +1753,95 @@ func (e *fheLt) Run(accessibleState PrecompileAccessibleState, caller common.Add return ctHash[:], nil } -type fheRand struct{} - -var globalRngSeed []byte - -var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32() - -func init() { - if chacha20.NonceSizeX != 24 { - panic("expected 24 bytes for NonceSizeX") - } - - // TODO: Since the current implementation is not FHE-based and, hence, not private, - // we just initialize the global seed with non-random public data. We will change - // that once the FHE version is available. - globalRngSeed = make([]byte, chacha20.KeySize) - for i := range globalRngSeed { - globalRngSeed[i] = byte(1 + i) - } -} - -func (e *fheRand) RequiredGas(input []byte) uint64 { - // TODO - return 8 +// type fheRand struct{} + +// var globalRngSeed []byte + +// var rngNonceKey [32]byte = uint256.NewInt(0).Bytes32() + +// func init() { +// if chacha20.NonceSizeX != 24 { +// panic("expected 24 bytes for NonceSizeX") +// } + +// // TODO: Since the current implementation is not FHE-based and, hence, not private, +// // we just initialize the global seed with non-random public data. We will change +// // that once the FHE version is available. +// globalRngSeed = make([]byte, chacha20.KeySize) +// for i := range globalRngSeed { +// globalRngSeed[i] = byte(1 + i) +// } +// } + +// func (e *fheRand) RequiredGas(input []byte) uint64 { +// // TODO +// return 8 +// } + +// func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { +// // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. +// if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { +// return importRandomCiphertext(accessibleState), nil +// } + +// // Get the RNG nonce. +// protectedStorage := crypto.CreateProtectedStorageContractAddress(caller) +// currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes() + +// // Increment the RNG nonce by 1. +// nextRngNonce := newInt(currentRngNonceBytes) +// nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1) +// accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32()) + +// // Compute the seed and use it to create a new cipher. +// hasher := crypto.NewKeccakState() +// hasher.Write(globalRngSeed) +// hasher.Write(caller.Bytes()) +// hasher.Write(currentRngNonceBytes) +// seed := common.Hash{} +// _, err := hasher.Read(seed[:]) +// if err != nil { +// return nil, err +// } +// // The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above). +// // Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes +// // in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because +// // it will always be 0 as the nonce size is 24. +// cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32]) +// if err != nil { +// return nil, err +// } + +// // XOR a byte array of 0s with the stream from the cipher and receive the result in the same array. +// randBytes := make([]byte, 8) +// cipher.XORKeyStream(randBytes, randBytes) + +// // Trivially encrypt the random integer. +// randInt := binary.BigEndian.Uint64(randBytes) % math.BigPow(2, 3).Uint64() +// randCt := new(tfheCiphertext) +// randCt.trivialEncrypt(randInt) +// importCiphertext(accessibleState, randCt) + +// // TODO: for testing +// err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644) +// if err != nil { +// return nil, err +// } +// ctHash := randCt.getHash() +// return ctHash[:], nil +// } + +type cast struct{} + +func (e *cast) RequiredGas(input []byte) uint64 { + return 0 } -func (e *fheRand) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { - // If we are doing gas estimation, skip execution and insert a random ciphertext as a result. - if !accessibleState.Interpreter().evm.Commit && !accessibleState.Interpreter().evm.EthCall { - return importRandomCiphertext(accessibleState), nil - } - - // Get the RNG nonce. - protectedStorage := crypto.CreateProtectedStorageContractAddress(caller) - currentRngNonceBytes := accessibleState.Interpreter().evm.StateDB.GetState(protectedStorage, rngNonceKey).Bytes() - - // Increment the RNG nonce by 1. - nextRngNonce := newInt(currentRngNonceBytes) - nextRngNonce = nextRngNonce.AddUint64(nextRngNonce, 1) - accessibleState.Interpreter().evm.StateDB.SetState(protectedStorage, rngNonceKey, nextRngNonce.Bytes32()) - - // Compute the seed and use it to create a new cipher. - hasher := crypto.NewKeccakState() - hasher.Write(globalRngSeed) - hasher.Write(caller.Bytes()) - hasher.Write(currentRngNonceBytes) - seed := common.Hash{} - _, err := hasher.Read(seed[:]) - if err != nil { - return nil, err - } - // The RNG nonce bytes are of size chacha20.NonceSizeX, which is assumed to be 24 bytes (see init() above). - // Since uint256.Int.z[0] is the least significant byte and since uint256.Int.Bytes32() serializes - // in order of z[3], z[2], z[1], z[0], we want to essentially ignore the first byte, i.e. z[3], because - // it will always be 0 as the nonce size is 24. - cipher, err := chacha20.NewUnauthenticatedCipher(seed.Bytes(), currentRngNonceBytes[32-chacha20.NonceSizeX:32]) - if err != nil { - return nil, err - } - - // XOR a byte array of 0s with the stream from the cipher and receive the result in the same array. - randBytes := make([]byte, 8) - cipher.XORKeyStream(randBytes, randBytes) - - // Trivially encrypt the random integer. - randInt := binary.BigEndian.Uint64(randBytes) % fheMessageModulus - randCt := new(tfheCiphertext) - randCt.trivialEncrypt(randInt) - importCiphertext(accessibleState, randCt) - - // TODO: for testing - err = os.WriteFile("/tmp/rand_result", randCt.serialize(), 0644) - if err != nil { - return nil, err - } - ctHash := randCt.getHash() - return ctHash[:], nil +// Implementation of the following is pending and will be completed once TFHE-rs add type casts to their high-level C API. +func (e *cast) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { + // var ctHandle = common.BytesToHash(input[0:31]) + // var toType = input[32] + return nil, nil } type faucet struct{} diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index ce5ef3d0b9f9..58224287e422 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -430,9 +430,9 @@ func newTestState() *statefulPrecompileAccessibleState { return s } -func verifyCiphertextInTestMemory(interpreter *EVMInterpreter, value uint64, depth int) *tfheCiphertext { +func verifyCiphertextInTestMemory(interpreter *EVMInterpreter, value uint64, depth int, t fheUintType) *tfheCiphertext { ct := new(tfheCiphertext) - ct.encrypt(value) + ct.encrypt(value, t) return verifyTfheCiphertextInTestMemory(interpreter, ct, depth) } @@ -457,8 +457,8 @@ func TestFheAdd(t *testing.T) { state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() + lhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() input := toPrecompileInput(lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { @@ -482,8 +482,8 @@ func TestFheSub(t *testing.T) { state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() + lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() input := toPrecompileInput(lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { @@ -507,8 +507,8 @@ func TestFheMul(t *testing.T) { state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() + lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() input := toPrecompileInput(lhsHash, rhsHash) out, err := c.Run(state, addr, addr, input, readOnly) if err != nil { @@ -532,8 +532,8 @@ func TestFheLte(t *testing.T) { state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() + lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() // 2 <= 1 input1 := toPrecompileInput(lhsHash, rhsHash) @@ -574,8 +574,8 @@ func TestFheLt(t *testing.T) { state.interpreter.evm.Commit = true addr := common.Address{} readOnly := false - lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth).getHash() - rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth).getHash() + lhsHash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash() + rhsHash := verifyCiphertextInTestMemory(state.interpreter, 1, depth, FheUint8).getHash() // 2 < 1 input1 := toPrecompileInput(lhsHash, rhsHash) @@ -608,34 +608,34 @@ func TestFheLt(t *testing.T) { } } -func TestFheRand(t *testing.T) { - c := &fheRand{} - depth := 1 - state := newTestState() - state.interpreter.evm.depth = depth - state.interpreter.evm.Commit = true - addr := common.Address{} - readOnly := false - - out, err := c.Run(state, addr, addr, nil, readOnly) - if err != nil { - t.Fatalf(err.Error()) - } - res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) - if res == nil { - t.Fatalf("output ciphertext is not found in verifiedCiphertexts") - } - decrypted := res.ciphertext.decrypt() - if decrypted >= fheMessageModulus { - t.Fatalf("invalid decrypted result") - } -} +// func TestFheRand(t *testing.T) { +// c := &fheRand{} +// depth := 1 +// state := newTestState() +// state.interpreter.evm.depth = depth +// state.interpreter.evm.Commit = true +// addr := common.Address{} +// readOnly := false + +// out, err := c.Run(state, addr, addr, nil, readOnly) +// if err != nil { +// t.Fatalf(err.Error()) +// } +// res := getVerifiedCiphertextFromEVM(state.interpreter, common.BytesToHash(out)) +// if res == nil { +// t.Fatalf("output ciphertext is not found in verifiedCiphertexts") +// } +// decrypted := res.ciphertext.decrypt() +// if decrypted >= math.Pow(2, 3) { +// t.Fatalf("invalid decrypted result") +// } +// } func TestUnknownCiphertextHandle(t *testing.T) { depth := 1 state := newTestState() state.interpreter.evm.depth = depth - hash := verifyCiphertextInTestMemory(state.interpreter, 2, depth).getHash() + hash := verifyCiphertextInTestMemory(state.interpreter, 2, depth, FheUint8).getHash() ct := getVerifiedCiphertext(state, hash) if ct == nil { @@ -654,7 +654,7 @@ func TestCiphertextNotVerifiedWithoutReturn(t *testing.T) { state := newTestState() state.interpreter.evm.depth = 1 verifiedDepth := 2 - hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth).getHash() + hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth, FheUint8).getHash() ct := getVerifiedCiphertext(state, hash) if ct != nil { @@ -666,7 +666,7 @@ func TestCiphertextNotAutomaticallyDelegated(t *testing.T) { state := newTestState() state.interpreter.evm.depth = 3 verifiedDepth := 2 - hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth).getHash() + hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth, FheUint8).getHash() ct := getVerifiedCiphertext(state, hash) if ct != nil { @@ -677,7 +677,7 @@ func TestCiphertextNotAutomaticallyDelegated(t *testing.T) { func TestCiphertextVerificationConditions(t *testing.T) { state := newTestState() verifiedDepth := 2 - hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth).getHash() + hash := verifyCiphertextInTestMemory(state.interpreter, 1, verifiedDepth, FheUint8).getHash() state.interpreter.evm.depth = verifiedDepth ctPtr := getVerifiedCiphertext(state, hash) diff --git a/core/vm/instructions.go b/core/vm/instructions.go index bb204b06547e..4af266a87e96 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -522,16 +522,18 @@ func opMstore8(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([] } // Ciphertext metadata is stored in protected storage, in a 32-byte slot. -// Currently, we only utilize 16 bytes from the slot. +// Currently, we only utilize 17 bytes from the slot. type ciphertextMetadata struct { - refCount uint64 - length uint64 + refCount uint64 + length uint64 + fheUintType fheUintType } func (m ciphertextMetadata) serialize() [32]byte { u := uint256.NewInt(0) u[0] = m.refCount u[1] = m.length + u[2] = uint64(m.fheUintType) return u.Bytes32() } @@ -540,6 +542,7 @@ func (m *ciphertextMetadata) deserialize(buf [32]byte) *ciphertextMetadata { u.SetBytes(buf[:]) m.refCount = u[0] m.length = u[1] + m.fheUintType = fheUintType(u[2]) return m } @@ -580,7 +583,7 @@ func verifyIfCiphertextHandle(val common.Hash, interpreter *EVMInterpreter, cont ct = verifiedCt.ciphertext } else { ct = new(tfheCiphertext) - err := ct.deserialize(ctBytes) + err := ct.deserialize(ctBytes, metadata.fheUintType) if err != nil { exitProcess() } @@ -612,7 +615,7 @@ func persistIfVerifiedCiphertext(val common.Hash, protectedStorage common.Addres if metadataInt.IsZero() { // If no metadata, it means this ciphertext itself hasn't been persisted to protected storage yet. We do that as part of SSTORE. metadata.refCount = 1 - metadata.length = uint64(fheCiphertextSize) + metadata.length = uint64(fheCiphertextSize[verifiedCiphertext.ciphertext.fheUintType]) ciphertextSlot := newInt(val.Bytes()) ciphertextSlot.AddUint64(ciphertextSlot, 1) ctPart32 := make([]byte, 32) diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 5c500133c9ea..71fb8fb87d23 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -713,7 +713,7 @@ func (c testCallerAddress) Address() common.Address { func newTestScopeConext() *ScopeContext { c := new(ScopeContext) c.Memory = NewMemory() - c.Memory.Resize(uint64(fheCiphertextSize) * 3) + c.Memory.Resize(uint64(fheCiphertextSize[FheUint8]) * 3) c.Stack = newstack() c.Contract = NewContract(testCallerAddress{}, testContractAddress{}, big.NewInt(10), 100000) return c @@ -731,7 +731,7 @@ func TestProtectedStorageSstoreSload(t *testing.T) { pc := uint64(0) depth := 1 interpreter := newTestInterpreter() - ct := verifyCiphertextInTestMemory(interpreter, 2, depth) + ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8) ctHash := ct.getHash() scope := newTestScopeConext() loc := uint256.NewInt(10) @@ -769,7 +769,7 @@ func TestProtectedStorageGarbageCollection(t *testing.T) { pc := uint64(0) depth := 1 interpreter := newTestInterpreter() - ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth).getHash() + ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8).getHash() scope := newTestScopeConext() loc := uint256.NewInt(10) value := uint256FromBig(ctHash.Big()) @@ -789,7 +789,7 @@ func TestProtectedStorageGarbageCollection(t *testing.T) { if metadata.refCount != 1 { t.Fatalf("metadata.refcount of ciphertext is not 1") } - if metadata.length != uint64(fheCiphertextSize) { + if metadata.length != uint64(fheCiphertextSize[FheUint8]) { t.Fatalf("metadata.length of ciphertext is incorrect") } ciphertextLocationsToCheck := (metadata.length + 32 - 1) / 32 @@ -866,7 +866,7 @@ func TestOpReturnDelegation(t *testing.T) { depth := 2 interpreter := newTestInterpreter() scope := newTestScopeConext() - ct := verifyCiphertextInTestMemory(interpreter, 2, depth) + ct := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8) ctHash := ct.getHash() offset := uint256.NewInt(0) @@ -894,7 +894,7 @@ func TestOpReturnUnverifyIfNotReturned(t *testing.T) { depth := 2 interpreter := newTestInterpreter() scope := newTestScopeConext() - ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth).getHash() + ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8).getHash() offset := uint256.NewInt(0) len := uint256.NewInt(32) @@ -915,7 +915,7 @@ func TestOpReturnDoesNotUnverifyIfNotVerified(t *testing.T) { pc := uint64(0) interpreter := newTestInterpreter() scope := newTestScopeConext() - ct := verifyCiphertextInTestMemory(interpreter, 2, 4) + ct := verifyCiphertextInTestMemory(interpreter, 2, 4, FheUint8) ctHash := ct.getHash() // Return from depth 3 to depth 2. However, ct is not verified at 3 and, hence, cannot @@ -999,7 +999,7 @@ func TestOpCallDelegatesIfHandleInArgs(t *testing.T) { depth := 2 interpreter := newTestInterpreter() interpreter.evm.depth = depth - ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth).getHash() + ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8).getHash() pc, scope := setupOpCall(call, interpreter, ctHash) _, err := (*call)(&pc, interpreter, scope) if err != nil { @@ -1027,7 +1027,7 @@ func TestOpCallDoesNotDelegateIfHandleNotInArgs(t *testing.T) { depth := 2 interpreter := newTestInterpreter() interpreter.evm.depth = depth - ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth).getHash() + ctHash := verifyCiphertextInTestMemory(interpreter, 2, depth, FheUint8).getHash() pc, scope := setupOpCall(call, interpreter, common.Hash{}) _, err := (*call)(&pc, interpreter, scope) if err != nil { @@ -1052,7 +1052,7 @@ func TestOpCallVerifySameCiphertextDeeperInStack(t *testing.T) { for _, call := range callsToTest { interpreter := newTestInterpreter() interpreter.evm.depth = 2 - ct := verifyCiphertextInTestMemory(interpreter, 2, 2) + ct := verifyCiphertextInTestMemory(interpreter, 2, 2, FheUint8) pc, scope := setupOpCall(call, interpreter, common.Hash{}) _, err := (*call)(&pc, interpreter, scope) @@ -1079,7 +1079,7 @@ func TestOpCallDoesNotDelegateIfNotVerified(t *testing.T) { verifiedAtDepth := 2 interpreter := newTestInterpreter() interpreter.evm.depth = verifiedAtDepth + 1 - ctHash := verifyCiphertextInTestMemory(interpreter, 2, verifiedAtDepth).getHash() + ctHash := verifyCiphertextInTestMemory(interpreter, 2, verifiedAtDepth, FheUint8).getHash() pc, scope := setupOpCall(call, interpreter, ctHash) _, err := (*call)(&pc, interpreter, scope) if err != nil { diff --git a/core/vm/tfhe.go b/core/vm/tfhe.go index 45b81bdc4858..154e001a61bc 100644 --- a/core/vm/tfhe.go +++ b/core/vm/tfhe.go @@ -26,151 +26,349 @@ package vm #include void* deserialize_server_key(BufferView in) { - ShortintServerKey* sks = NULL; - const int r = shortint_deserialize_server_key(in, &sks); + ServerKey* sks = NULL; + const int r = server_key_deserialize(in, &sks); assert(r == 0); return sks; } void* deserialize_client_key(BufferView in) { - ShortintClientKey* cks = NULL; - const int r = shortint_deserialize_client_key(in, &cks); + ClientKey* cks = NULL; + const int r = client_key_deserialize(in, &cks); assert(r == 0); return cks; } -void* deserialize_tfhe_ciphertext(BufferView in) { - ShortintCiphertext* ct = NULL; - const int r = shortint_deserialize_ciphertext(in, &ct); +void tfhe_set_server_key(void *sks) { + int r = set_server_key(sks); + assert(r == 0); +} + +void serialize_fhe_uint8(void *ct, Buffer* out) { + const int r = fhe_uint8_serialize(ct, out); + assert(r == 0); +} + +void* deserialize_fhe_uint8(BufferView in) { + FheUint8* ct = NULL; + const int r = fhe_uint8_deserialize(in, &ct); + if(r != 0) { + return NULL; + } + return ct; +} + +void serialize_fhe_uint16(void *ct, Buffer* out) { + const int r = fhe_uint16_serialize(ct, out); + assert(r == 0); +} + +void* deserialize_fhe_uint16(BufferView in) { + FheUint16* ct = NULL; + const int r = fhe_uint16_deserialize(in, &ct); if(r != 0) { return NULL; } return ct; } -void serialize_tfhe_ciphertext(void *ct, Buffer* out) { - const int r = shortint_serialize_ciphertext(ct, out); +void serialize_fhe_uint32(void *ct, Buffer* out) { + const int r = fhe_uint32_serialize(ct, out); + assert(r == 0); +} + +void* deserialize_fhe_uint32(BufferView in) { + FheUint32* ct = NULL; + const int r = fhe_uint32_deserialize(in, &ct); + if(r != 0) { + return NULL; + } + return ct; +} + +void destroy_fhe_uint8(void* ct) { + fhe_uint8_destroy(ct); +} + +void destroy_fhe_uint16(void* ct) { + fhe_uint16_destroy(ct); +} + +void destroy_fhe_uint32(void* ct) { + fhe_uint32_destroy(ct); +} + +void* add_fhe_uint8(void* ct1, void* ct2) +{ + FheUint8* result = NULL; + const int r = fhe_uint8_add(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* add_fhe_uint16(void* ct1, void* ct2) +{ + FheUint16* result = NULL; + const int r = fhe_uint16_add(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* add_fhe_uint32(void* ct1, void* ct2) +{ + FheUint32* result = NULL; + const int r = fhe_uint32_add(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* sub_fhe_uint8(void* ct1, void* ct2) +{ + FheUint8* result = NULL; + const int r = fhe_uint8_sub(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* sub_fhe_uint16(void* ct1, void* ct2) +{ + FheUint16* result = NULL; + const int r = fhe_uint16_sub(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* sub_fhe_uint32(void* ct1, void* ct2) +{ + FheUint32* result = NULL; + const int r = fhe_uint32_sub(ct1, ct2, &result); assert(r == 0); + return result; } -void destroy_tfhe_ciphertext(void* ct) { - destroy_shortint_ciphertext(ct); +void* mul_fhe_uint8(void* ct1, void* ct2) +{ + FheUint8* result = NULL; + const int r = fhe_uint8_mul(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* mul_fhe_uint16(void* ct1, void* ct2) +{ + FheUint16* result = NULL; + const int r = fhe_uint16_mul(ct1, ct2, &result); + assert(r == 0); + return result; } -void* tfhe_add(void* sks, void* ct1, void* ct2) +void* mul_fhe_uint32(void* ct1, void* ct2) { - ShortintCiphertext *result = NULL; - const int r = shortint_bc_server_key_smart_add(sks, ct1, ct2, &result); + FheUint32* result = NULL; + const int r = fhe_uint32_mul(ct1, ct2, &result); assert(r == 0); return result; } -void* tfhe_sub(void* sks, void* ct1, void* ct2) +void* le_fhe_uint8(void* ct1, void* ct2) { - ShortintCiphertext *result = NULL; - const int r = shortint_bc_server_key_smart_sub(sks, ct1, ct2, &result); + FheUint8* result = NULL; + const int r = fhe_uint8_le(ct1, ct2, &result); assert(r == 0); return result; } -void* tfhe_mul(void* sks, void* ct1, void* ct2) +void* le_fhe_uint16(void* ct1, void* ct2) { - ShortintCiphertext *result = NULL; - const int r = shortint_bc_server_key_smart_mul(sks, ct1, ct2, &result); + FheUint16* result = NULL; + const int r = fhe_uint16_le(ct1, ct2, &result); assert(r == 0); return result; } -void* tfhe_lte(void* sks, void* ct1, void* ct2) +void* le_fhe_uint32(void* ct1, void* ct2) { - ShortintCiphertext *result = NULL; - const int r = shortint_bc_server_key_smart_less_or_equal(sks, ct1, ct2, &result); + FheUint32* result = NULL; + const int r = fhe_uint32_le(ct1, ct2, &result); assert(r == 0); return result; } -void* tfhe_lt(void* sks, void* ct1, void* ct2) +void* lt_fhe_uint8(void* ct1, void* ct2) { - ShortintCiphertext *result = NULL; - const int r = shortint_bc_server_key_smart_less(sks, ct1, ct2, &result); + FheUint8* result = NULL; + const int r = fhe_uint8_lt(ct1, ct2, &result); assert(r == 0); return result; } -uint64_t decrypt(void* cks, void* ct) +void* lt_fhe_uint16(void* ct1, void* ct2) { - uint64_t res = 0; - const int r = shortint_bc_client_key_decrypt(cks, ct, &res); + FheUint16* result = NULL; + const int r = fhe_uint16_lt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +void* lt_fhe_uint32(void* ct1, void* ct2) +{ + FheUint32* result = NULL; + const int r = fhe_uint32_lt(ct1, ct2, &result); + assert(r == 0); + return result; +} + +uint8_t decrypt_fhe_uint8(void* cks, void* ct) +{ + uint8_t res = 0; + const int r = fhe_uint8_decrypt(ct, cks, &res); assert(r == 0); return res; } -void client_key_encrypt_and_ser(void* cks, uint64_t value, Buffer* out) { - ShortintCiphertext *ct = NULL; +uint16_t decrypt_fhe_uint16(void* cks, void* ct) +{ + uint16_t res = 0; + const int r = fhe_uint16_decrypt(ct, cks, &res); + assert(r == 0); + return res; +} + +uint32_t decrypt_fhe_uint32(void* cks, void* ct) +{ + uint32_t res = 0; + const int r = fhe_uint32_decrypt(ct, cks, &res); + assert(r == 0); + return res; +} + +void client_key_encrypt_and_ser_fhe_uint8(void* cks, uint8_t value, Buffer* out) { + FheUint8* ct = NULL; + + const int encrypt_ok = fhe_uint8_try_encrypt_with_client_key_u8(value, cks, &ct); + assert(encrypt_ok == 0); + + const int ser_ok = fhe_uint8_serialize(ct, out); + assert(ser_ok == 0); + + fhe_uint8_destroy(ct); +} + +void client_key_encrypt_and_ser_fhe_uint16(void* cks, uint16_t value, Buffer* out) { + FheUint16* ct = NULL; + + const int encrypt_ok = fhe_uint16_try_encrypt_with_client_key_u16(value, cks, &ct); + assert(encrypt_ok == 0); + + const int ser_ok = fhe_uint16_serialize(ct, out); + assert(ser_ok == 0); + + fhe_uint16_destroy(ct); +} + +void client_key_encrypt_and_ser_fhe_uint32(void* cks, uint32_t value, Buffer* out) { + FheUint32* ct = NULL; - const int encrypt_ok = shortint_bc_client_key_encrypt(cks, value, &ct); + const int encrypt_ok = fhe_uint32_try_encrypt_with_client_key_u32(value, cks, &ct); assert(encrypt_ok == 0); - const int ser_ok = shortint_serialize_ciphertext(ct, out); + const int ser_ok = fhe_uint32_serialize(ct, out); assert(ser_ok == 0); - destroy_shortint_ciphertext(ct); + fhe_uint32_destroy(ct); } -void* client_key_encrypt(void* cks, uint64_t value) { - ShortintCiphertext *ct = NULL; +void* client_key_encrypt_fhe_uint8(void* cks, uint8_t value) { + FheUint8* ct = NULL; - const int r = shortint_bc_client_key_encrypt(cks, value, &ct); + const int r = fhe_uint8_try_encrypt_with_client_key_u8(value, cks, &ct); assert(r == 0); return ct; } -void public_key_encrypt(BufferView pks_buf, uint64_t value, Buffer* out) +void* client_key_encrypt_fhe_uint16(void* cks, uint16_t value) { + FheUint16* ct = NULL; + + const int r = fhe_uint16_try_encrypt_with_client_key_u16(value, cks, &ct); + assert(r == 0); + + return ct; +} + +void* client_key_encrypt_fhe_uint32(void* cks, uint32_t value) { + FheUint32* ct = NULL; + + const int r = fhe_uint32_try_encrypt_with_client_key_u32(value, cks, &ct); + assert(r == 0); + + return ct; +} + +void public_key_encrypt_fhe_uint8(BufferView pks_buf, uint8_t value, Buffer* out) { - ShortintCiphertext *ct = NULL; - ShortintPublicKey *pks = NULL; + FheUint8 *ct = NULL; + PublicKey *pks = NULL; - const int deser_ok = shortint_deserialize_public_key(pks_buf, &pks); + const int deser_ok = public_key_deserialize(pks_buf, &pks); assert(deser_ok == 0); - const int encrypt_ok = shortint_bc_public_key_encrypt(pks, value, &ct); + const int encrypt_ok = fhe_uint8_try_encrypt_with_public_key_u8(value, pks, &ct); assert(encrypt_ok == 0); - const int ser_ok = shortint_serialize_ciphertext(ct, out); + const int ser_ok = fhe_uint8_serialize(ct, out); assert(ser_ok == 0); - destroy_shortint_public_key(pks); - destroy_shortint_ciphertext(ct); + public_key_destroy(pks); + fhe_uint8_destroy(ct); } -void* trivial_encrypt(void* sks, uint64_t value) { - ShortintCiphertext *ct = NULL; +void public_key_encrypt_fhe_uint16(BufferView pks_buf, uint16_t value, Buffer* out) +{ + FheUint16 *ct = NULL; + PublicKey *pks = NULL; - const int r = shortint_bc_server_key_create_trivial(sks, value, &ct); - assert(r == 0); + const int deser_ok = public_key_deserialize(pks_buf, &pks); + assert(deser_ok == 0); - return ct; + const int encrypt_ok = fhe_uint16_try_encrypt_with_public_key_u16(value, pks, &ct); + assert(encrypt_ok == 0); + + const int ser_ok = fhe_uint16_serialize(ct, out); + assert(ser_ok == 0); + + public_key_destroy(pks); + fhe_uint16_destroy(ct); } -size_t get_message_modulus(void* sks) { - size_t modulus = 0; +void public_key_encrypt_fhe_uint32(BufferView pks_buf, uint32_t value, Buffer* out) +{ + FheUint32 *ct = NULL; + PublicKey *pks = NULL; - const int r = shortint_server_key_get_message_modulus(sks, &modulus); - assert(r == 0); + const int deser_ok = public_key_deserialize(pks_buf, &pks); + assert(deser_ok == 0); - return modulus; -} + const int encrypt_ok = fhe_uint32_try_encrypt_with_public_key_u32(value, pks, &ct); + assert(encrypt_ok == 0); + const int ser_ok = fhe_uint32_serialize(ct, out); + assert(ser_ok == 0); + + public_key_destroy(pks); + fhe_uint32_destroy(ct); +} */ import "C" + +// TODO trivial encrypt + import ( "crypto/rand" "errors" "fmt" "os" "runtime" - "strings" "sync/atomic" "time" "unsafe" @@ -181,8 +379,8 @@ import ( func toBufferView(in []byte) C.BufferView { return C.BufferView{ - pointer: (*C.uchar)(unsafe.Pointer(&in[0])), - length: (C.ulong)(len(in)), + pointer: (*C.uint8_t)(unsafe.Pointer(&in[0])), + length: (C.size_t)(len(in)), } } @@ -195,10 +393,7 @@ func homeDir() string { } // The TFHE ciphertext size, in bytes. -var fheCiphertextSize int - -// The TFHE message modulus. Extracted from the `cks`. -var fheMessageModulus uint64 +var fheCiphertextSize map[fheUintType]uint var sks unsafe.Pointer var cks unsafe.Pointer @@ -230,26 +425,42 @@ func init() { } sks = C.deserialize_server_key(toBufferView(sks_bytes)) - if strings.ToLower(tomlConfig.Oracle.Mode) == "oracle" { - cks_bytes, err := os.ReadFile(networkKeysDir + "cks") - if err != nil { - fmt.Print("WARNING: file cks not found.\n") - return - } - cks = C.deserialize_client_key(toBufferView(cks_bytes)) + cks_bytes, err := os.ReadFile(networkKeysDir + "cks") + if err != nil { + fmt.Print("WARNING: file cks not found.\n") + return } - // Use trivial encryption to determine the ciphertext size for the used parameters. - // Note: parameters are embedded in the client `cks` key. - ct := new(tfheCiphertext) - ct.trivialEncrypt(1) - fheCiphertextSize = len(ct.serialize()) + sks = C.deserialize_server_key(toBufferView(sks_bytes)) + cks = C.deserialize_client_key(toBufferView(cks_bytes)) + + // Cannot use trivial encryption yet as it is not exposed by tfhe-rs + // ct := new(tfheCiphertext) + // ct.trivialEncrypt(1) + // fheCiphertextSize = len(ct.serialize()) - fheMessageModulus = uint64(C.get_message_modulus(sks)) + fheCiphertextSize = make(map[fheUintType]uint) + + fheCiphertextSize[FheUint8] = 28124 + fheCiphertextSize[FheUint16] = 56236 + fheCiphertextSize[FheUint32] = 112460 + + // TODO: understand when and how to set the server key + // C.tfhe_set_server_key(sks) go runGc() } +// Represents a TFHE ciphertext type (i.e., its bit capacity) + +type fheUintType uint8 + +const ( + FheUint8 fheUintType = 0 + FheUint16 fheUintType = 1 + FheUint32 fheUintType = 2 +) + // Represents a TFHE ciphertext. // // Once a ciphertext has a value (either from deserialization, encryption or makeRandom()), @@ -260,13 +471,22 @@ type tfheCiphertext struct { hash []byte value *uint64 random bool + fheUintType fheUintType } -func (ct *tfheCiphertext) deserialize(in []byte) error { +func (ct *tfheCiphertext) deserialize(in []byte, t fheUintType) error { if ct.initialized() { panic("cannot deserialize to an existing ciphertext") } - ptr := C.deserialize_tfhe_ciphertext(toBufferView((in))) + var ptr unsafe.Pointer + switch t { + case FheUint8: + ptr = C.deserialize_fhe_uint8(toBufferView((in))) + case FheUint16: + ptr = C.deserialize_fhe_uint16(toBufferView((in))) + case FheUint32: + ptr = C.deserialize_fhe_uint32(toBufferView((in))) + } if ptr == nil { return errors.New("tfhe ciphertext deserialization failed") } @@ -275,30 +495,38 @@ func (ct *tfheCiphertext) deserialize(in []byte) error { return nil } -func (ct *tfheCiphertext) encrypt(value uint64) { +func (ct *tfheCiphertext) encrypt(value uint64, t fheUintType) { if ct.initialized() { panic("cannot encrypt to an existing ciphertext") } - ct.setPtr(C.client_key_encrypt(cks, C.ulong(value))) + switch t { + case FheUint8: + ct.setPtr(C.client_key_encrypt_fhe_uint8(cks, C.uchar(value))) + case FheUint16: + ct.setPtr(C.client_key_encrypt_fhe_uint16(cks, C.ushort(value))) + case FheUint32: + ct.setPtr(C.client_key_encrypt_fhe_uint32(cks, C.uint(value))) + } ct.value = &value } -func (ct *tfheCiphertext) makeRandom() { +func (ct *tfheCiphertext) makeRandom(t fheUintType) { if ct.initialized() { panic("cannot make an existing ciphertext random") } - ct.serialization = make([]byte, fheCiphertextSize) + ct.serialization = make([]byte, fheCiphertextSize[t]) rand.Read(ct.serialization) + ct.fheUintType = t ct.random = true } -func (ct *tfheCiphertext) trivialEncrypt(value uint64) { - if ct.initialized() { - panic("cannot trivially encrypt to an existing ciphertext") - } - ct.setPtr(C.trivial_encrypt(sks, C.ulong(value))) - ct.value = &value -} +// func (ct *tfheCiphertext) trivialEncrypt(value uint64) { +// if ct.initialized() { +// panic("cannot trivially encrypt to an existing ciphertext") +// } +// ct.setPtr(C.trivial_encrypt(sks, C.ulong(value))) +// ct.value = &value +// } func (ct *tfheCiphertext) serialize() []byte { if !ct.initialized() { @@ -307,55 +535,130 @@ func (ct *tfheCiphertext) serialize() []byte { return ct.serialization } out := &C.Buffer{} - C.serialize_tfhe_ciphertext(ct.ptr, out) + switch ct.fheUintType { + case FheUint8: + C.serialize_fhe_uint8(ct.ptr, out) + case FheUint16: + C.serialize_fhe_uint16(ct.ptr, out) + case FheUint32: + C.serialize_fhe_uint32(ct.ptr, out) + } ct.serialization = C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) C.destroy_buffer(out) return ct.serialization } -func (lhs *tfheCiphertext) add(rhs *tfheCiphertext) *tfheCiphertext { +func (lhs *tfheCiphertext) add(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot add on a non-initialized ciphertext") } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + C.tfhe_set_server_key(sks) res := new(tfheCiphertext) - res.setPtr(C.tfhe_add(sks, lhs.ptr, rhs.ptr)) - return res + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.add_fhe_uint8(lhs.ptr, rhs.ptr)) + case FheUint16: + res.setPtr(C.add_fhe_uint16(lhs.ptr, rhs.ptr)) + case FheUint32: + res.setPtr(C.add_fhe_uint32(lhs.ptr, rhs.ptr)) + } + return res, nil } -func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) *tfheCiphertext { +func (lhs *tfheCiphertext) sub(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot sub on a non-initialized ciphertext") } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + C.tfhe_set_server_key(sks) res := new(tfheCiphertext) - res.setPtr(C.tfhe_sub(sks, lhs.ptr, rhs.ptr)) - return res + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.sub_fhe_uint8(lhs.ptr, rhs.ptr)) + case FheUint16: + res.setPtr(C.sub_fhe_uint16(lhs.ptr, rhs.ptr)) + case FheUint32: + res.setPtr(C.sub_fhe_uint32(lhs.ptr, rhs.ptr)) + } + return res, nil } -func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) *tfheCiphertext { +func (lhs *tfheCiphertext) mul(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot mul on a non-initialized ciphertext") } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + res := new(tfheCiphertext) - res.setPtr(C.tfhe_mul(sks, lhs.ptr, rhs.ptr)) - return res + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.mul_fhe_uint8(lhs.ptr, rhs.ptr)) + case FheUint16: + res.setPtr(C.mul_fhe_uint16(lhs.ptr, rhs.ptr)) + case FheUint32: + res.setPtr(C.mul_fhe_uint32(lhs.ptr, rhs.ptr)) + } + return res, nil } -func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) *tfheCiphertext { +func (lhs *tfheCiphertext) lte(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot lte on a non-initialized ciphertext") } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + + C.tfhe_set_server_key(sks) res := new(tfheCiphertext) - res.setPtr(C.tfhe_lte(sks, lhs.ptr, rhs.ptr)) - return res + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.le_fhe_uint8(lhs.ptr, rhs.ptr)) + case FheUint16: + res.setPtr(C.le_fhe_uint16(lhs.ptr, rhs.ptr)) + case FheUint32: + res.setPtr(C.le_fhe_uint32(lhs.ptr, rhs.ptr)) + } + return res, nil } -func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) *tfheCiphertext { +func (lhs *tfheCiphertext) lt(rhs *tfheCiphertext) (*tfheCiphertext, error) { if !lhs.availableForOps() || !rhs.availableForOps() { panic("cannot lt on a non-initialized ciphertext") } + + if lhs.fheUintType != rhs.fheUintType { + return nil, errors.New("binary operations are only well-defined for identical types") + } + res := new(tfheCiphertext) - res.setPtr(C.tfhe_lt(sks, lhs.ptr, rhs.ptr)) - return res + res.fheUintType = lhs.fheUintType + switch lhs.fheUintType { + case FheUint8: + res.setPtr(C.lt_fhe_uint8(lhs.ptr, rhs.ptr)) + case FheUint16: + res.setPtr(C.lt_fhe_uint16(lhs.ptr, rhs.ptr)) + case FheUint32: + res.setPtr(C.lt_fhe_uint32(lhs.ptr, rhs.ptr)) + } + return res, nil } func (ct *tfheCiphertext) decrypt() uint64 { @@ -364,7 +667,15 @@ func (ct *tfheCiphertext) decrypt() uint64 { } else if ct.value != nil { return *ct.value } - value := uint64(C.decrypt(cks, ct.ptr)) + var value uint64 + switch ct.fheUintType { + case FheUint8: + value = uint64(C.decrypt_fhe_uint8(cks, ct.ptr)) + case FheUint16: + value = uint64(C.decrypt_fhe_uint16(cks, ct.ptr)) + case FheUint32: + value = uint64(C.decrypt_fhe_uint32(cks, ct.ptr)) + } ct.value = &value return value } @@ -375,8 +686,22 @@ func (ct *tfheCiphertext) setPtr(ptr unsafe.Pointer) { } ct.ptr = ptr atomic.AddUint64(&allocatedCiphertexts, 1) + switch ct.fheUintType { + case FheUint8: + runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { + C.destroy_fhe_uint8(ct.ptr) + }) + case FheUint16: + runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { + C.destroy_fhe_uint16(ct.ptr) + }) + case FheUint32: + runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { + C.destroy_fhe_uint32(ct.ptr) + }) + } runtime.SetFinalizer(ct, func(ct *tfheCiphertext) { - C.destroy_tfhe_ciphertext(ct.ptr) + C.destroy_fhe_uint8(ct.ptr) }) } @@ -398,17 +723,31 @@ func (ct *tfheCiphertext) initialized() bool { return (ct.ptr != nil || ct.random) } -func clientKeyEncrypt(value uint64) []byte { +func clientKeyEncrypt(value uint64, t fheUintType) []byte { out := &C.Buffer{} - C.client_key_encrypt_and_ser(cks, C.ulong(value), out) + switch t { + case FheUint8: + C.client_key_encrypt_and_ser_fhe_uint8(cks, C.uchar(value), out) + case FheUint16: + C.client_key_encrypt_and_ser_fhe_uint16(cks, C.ushort(value), out) + case FheUint32: + C.client_key_encrypt_and_ser_fhe_uint32(cks, C.uint(value), out) + } result := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) C.destroy_buffer(out) return result } -func publicKeyEncrypt(pks []byte, value uint64) []byte { +func publicKeyEncrypt(pks []byte, value uint64, t fheUintType) []byte { out := &C.Buffer{} - C.public_key_encrypt(toBufferView(pks), C.ulong(value), out) + switch t { + case FheUint8: + C.public_key_encrypt_fhe_uint8(toBufferView(pks), C.uchar(value), out) + case FheUint16: + C.public_key_encrypt_fhe_uint16(toBufferView(pks), C.ushort(value), out) + case FheUint32: + C.public_key_encrypt_fhe_uint32(toBufferView(pks), C.uint(value), out) + } result := C.GoBytes(unsafe.Pointer(out.pointer), C.int(out.length)) C.destroy_buffer(out) return result diff --git a/core/vm/tfhe_test.go b/core/vm/tfhe_test.go index 0b159219f5a4..0b537e54a082 100644 --- a/core/vm/tfhe_test.go +++ b/core/vm/tfhe_test.go @@ -27,7 +27,7 @@ import ( func TestTfheCksEncryptDecrypt(t *testing.T) { val := uint64(2) ct := new(tfheCiphertext) - ct.encrypt(val) + ct.encrypt(val, FheUint8) res := ct.decrypt() if res != val { t.Fatalf("%d != %d", val, res) @@ -36,9 +36,9 @@ func TestTfheCksEncryptDecrypt(t *testing.T) { func TestTfheSerializeDeserialize(t *testing.T) { val := uint64(2) - ctBytes := clientKeyEncrypt(val) + ctBytes := clientKeyEncrypt(val, FheUint8) ct := new(tfheCiphertext) - err := ct.deserialize(ctBytes) + err := ct.deserialize(ctBytes, FheUint8) if err != nil { t.Fatalf("deserialization failed") } @@ -50,7 +50,7 @@ func TestTfheSerializeDeserialize(t *testing.T) { func TestTfheDeserializeFailure(t *testing.T) { ct := new(tfheCiphertext) - err := ct.deserialize(make([]byte, 10)) + err := ct.deserialize(make([]byte, 10), FheUint8) if err == nil { t.Fatalf("deserialization must have failed") } @@ -61,10 +61,10 @@ func TestTfheAdd(t *testing.T) { b := uint64(1) expected := uint64(2) ctA := new(tfheCiphertext) - ctA.encrypt(a) + ctA.encrypt(a, FheUint8) ctB := new(tfheCiphertext) - ctB.encrypt(b) - ctRes := ctA.add(ctB) + ctB.encrypt(b, FheUint8) + ctRes, _ := ctA.add(ctB) res := ctRes.decrypt() if res != expected { t.Fatalf("%d != %d", expected, res) @@ -76,10 +76,10 @@ func TestTfheSub(t *testing.T) { b := uint64(1) expected := uint64(1) ctA := new(tfheCiphertext) - ctA.encrypt(a) + ctA.encrypt(a, FheUint8) ctB := new(tfheCiphertext) - ctB.encrypt(b) - ctRes := ctA.sub(ctB) + ctB.encrypt(b, FheUint8) + ctRes, _ := ctA.sub(ctB) res := ctRes.decrypt() if res != expected { t.Fatalf("%d != %d", expected, res) @@ -91,10 +91,10 @@ func TestTfheMul(t *testing.T) { b := uint64(1) expected := uint64(2) ctA := new(tfheCiphertext) - ctA.encrypt(a) + ctA.encrypt(a, FheUint8) ctB := new(tfheCiphertext) - ctB.encrypt(b) - ctRes := ctA.mul(ctB) + ctB.encrypt(b, FheUint8) + ctRes, _ := ctA.mul(ctB) res := ctRes.decrypt() if res != expected { t.Fatalf("%d != %d", expected, res) @@ -105,11 +105,11 @@ func TestTfheLte(t *testing.T) { a := uint64(2) b := uint64(1) ctA := new(tfheCiphertext) - ctA.encrypt(a) + ctA.encrypt(a, FheUint8) ctB := new(tfheCiphertext) - ctB.encrypt(b) - ctRes1 := ctA.lte(ctB) - ctRes2 := ctB.lte(ctA) + ctB.encrypt(b, FheUint8) + ctRes1, _ := ctA.lte(ctB) + ctRes2, _ := ctB.lte(ctA) res1 := ctRes1.decrypt() res2 := ctRes2.decrypt() if res1 != 0 { @@ -123,11 +123,11 @@ func TestTfheLt(t *testing.T) { a := uint64(2) b := uint64(1) ctA := new(tfheCiphertext) - ctA.encrypt(a) + ctA.encrypt(a, FheUint8) ctB := new(tfheCiphertext) - ctB.encrypt(b) - ctRes1 := ctA.lte(ctB) - ctRes2 := ctB.lte(ctA) + ctB.encrypt(b, FheUint8) + ctRes1, _ := ctA.lte(ctB) + ctRes2, _ := ctB.lte(ctA) res1 := ctRes1.decrypt() res2 := ctRes2.decrypt() if res1 != 0 { @@ -138,53 +138,53 @@ func TestTfheLt(t *testing.T) { } } -func TestTfheTrivialEncryptDecrypt(t *testing.T) { - val := uint64(2) - ct := new(tfheCiphertext) - ct.trivialEncrypt(val) - res := ct.decrypt() - if res != val { - t.Fatalf("%d != %d", val, res) - } -} +// func TestTfheTrivialEncryptDecrypt(t *testing.T) { +// val := uint64(2) +// ct := new(tfheCiphertext) +// ct.trivialEncrypt(val) +// res := ct.decrypt() +// if res != val { +// t.Fatalf("%d != %d", val, res) +// } +// } -func TestTfheTrivialAndEncryptedLte(t *testing.T) { - a := uint64(2) - b := uint64(1) - ctA := new(tfheCiphertext) - ctA.encrypt(a) - ctB := new(tfheCiphertext) - ctB.trivialEncrypt(b) - ctRes1 := ctA.lte(ctB) - ctRes2 := ctB.lte(ctA) - res1 := ctRes1.decrypt() - res2 := ctRes2.decrypt() - if res1 != 0 { - t.Fatalf("%d != %d", 0, res1) - } - if res2 != 1 { - t.Fatalf("%d != %d", 0, res2) - } -} +// func TestTfheTrivialAndEncryptedLte(t *testing.T) { +// a := uint64(2) +// b := uint64(1) +// ctA := new(tfheCiphertext) +// ctA.encrypt(a) +// ctB := new(tfheCiphertext) +// ctB.trivialEncrypt(b) +// ctRes1 := ctA.lte(ctB) +// ctRes2 := ctB.lte(ctA) +// res1 := ctRes1.decrypt() +// res2 := ctRes2.decrypt() +// if res1 != 0 { +// t.Fatalf("%d != %d", 0, res1) +// } +// if res2 != 1 { +// t.Fatalf("%d != %d", 0, res2) +// } +// } -func TestTfheTrivialAndEncryptedAdd(t *testing.T) { - a := uint64(1) - b := uint64(1) - ctA := new(tfheCiphertext) - ctA.encrypt(a) - ctB := new(tfheCiphertext) - ctB.trivialEncrypt(b) - ctRes := ctA.add(ctB) - res := ctRes.decrypt() - if res != 2 { - t.Fatalf("%d != %d", 0, res) - } -} +// func TestTfheTrivialAndEncryptedAdd(t *testing.T) { +// a := uint64(1) +// b := uint64(1) +// ctA := new(tfheCiphertext) +// ctA.encrypt(a) +// ctB := new(tfheCiphertext) +// ctB.trivialEncrypt(b) +// ctRes := ctA.add(ctB) +// res := ctRes.decrypt() +// if res != 2 { +// t.Fatalf("%d != %d", 0, res) +// } +// } -func TestTfheTrivialSerializeSize(t *testing.T) { - ct := new(tfheCiphertext) - ct.trivialEncrypt(2) - if len(ct.serialize()) != fheCiphertextSize { - t.Fatalf("serialization of trivially encrypted unexpected size") - } -} +// func TestTfheTrivialSerializeSize(t *testing.T) { +// ct := new(tfheCiphertext) +// ct.trivialEncrypt(2) +// if len(ct.serialize()) != fheCiphertextSize { +// t.Fatalf("serialization of trivially encrypted unexpected size") +// } +// } diff --git a/install_thfe_rs_api.sh b/install_thfe_rs_api.sh index a7be17aa0073..8af191b302ed 100755 --- a/install_thfe_rs_api.sh +++ b/install_thfe_rs_api.sh @@ -3,7 +3,6 @@ git clone https://github.com/zama-ai/tfhe-rs.git mkdir -p core/vm/lib cd tfhe-rs -git checkout blockchain-demo-deterministic-fft make build_c_api cp target/release/libtfhe.* ../core/vm/lib cp target/release/tfhe.h ../core/vm