Skip to content

Commit

Permalink
Merge pull request #1196 from zama-ai/chore/add_load_tfhers_params_dict
Browse files Browse the repository at this point in the history
chore: add load tfhers params from dict
  • Loading branch information
BourgerieQuentin authored Jan 23, 2025
2 parents fd61000 + af6095d commit 04bab98
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/dev/api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@
- [`utils.format_constant`](./concrete.fhe.representation.utils.md): Get the textual representation of a constant.
- [`utils.format_indexing_element`](./concrete.fhe.representation.utils.md): Format an indexing element.
- [`tfhers.get_type_from_params`](./concrete.fhe.tfhers.md): Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.
- [`tfhers.get_type_from_params_dict`](./concrete.fhe.tfhers.md): Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.
- [`bridge.new_bridge`](./concrete.fhe.tfhers.bridge.md): Create a TFHErs bridge from a circuit or module.
- [`tracing.from_native`](./concrete.fhe.tfhers.tracing.md): Convert a Concrete integer to the tfhers representation.
- [`tracing.to_native`](./concrete.fhe.tfhers.tracing.md): Convert a tfhers integer to the Concrete representation.
Expand Down
33 changes: 32 additions & 1 deletion docs/dev/api/concrete.fhe.tfhers.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tfhers module to represent, and compute on tfhers integer values.

---

<a href="../../frontends/concrete-python/concrete/fhe/tfhers/__init__.py#L26"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>
<a href="../../frontends/concrete-python/concrete/fhe/tfhers/__init__.py#L27"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_type_from_params`

Expand Down Expand Up @@ -43,3 +43,34 @@ Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.
- <b>`TFHERSIntegerType`</b>: constructed type from the loaded parameters


---

<a href="../../frontends/concrete-python/concrete/fhe/tfhers/__init__.py#L48"><img align="right" style="float:right;" src="https://img.shields.io/badge/-source-cccccc?style=flat-square"></a>

## <kbd>function</kbd> `get_type_from_params_dict`

```python
get_type_from_params_dict(
crypto_param_dict: Dict,
is_signed: bool,
precision: int
) → TFHERSIntegerType
```

Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.



**Args:**

- <b>`crypto_param_dict`</b> (Dict): dictionary of TFHE-rs parameters
- <b>`is_signed`</b> (bool): sign of the result type
- <b>`precision`</b> (int): precision of the result type



**Returns:**

- <b>`TFHERSIntegerType`</b>: constructed type from the loaded parameters


18 changes: 18 additions & 0 deletions frontends/concrete-python/concrete/fhe/tfhers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
from math import log2
from typing import Dict

from .bridge import new_bridge
from .dtypes import (
Expand Down Expand Up @@ -41,6 +42,23 @@ def get_type_from_params(
with open(path_to_params_json) as f:
crypto_param_dict = json.load(f)

return get_type_from_params_dict(crypto_param_dict, is_signed, precision)


def get_type_from_params_dict(
crypto_param_dict: Dict, is_signed: bool, precision: int
) -> TFHERSIntegerType:
"""Get a TFHE-rs integer type from TFHE-rs parameters in JSON format.
Args:
crypto_param_dict (Dict): dictionary of TFHE-rs parameters
is_signed (bool): sign of the result type
precision (int): precision of the result type
Returns:
TFHERSIntegerType: constructed type from the loaded parameters
"""

lwe_dim = crypto_param_dict["lwe_dimension"]
glwe_dim = crypto_param_dict["glwe_dimension"]
poly_size = crypto_param_dict["polynomial_size"]
Expand Down
47 changes: 47 additions & 0 deletions frontends/concrete-python/tests/dtypes/test_tfhers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Tests of `TFHERSIntegerType` data type.
"""

import json
import os
import tempfile

import numpy as np
import pytest

Expand All @@ -18,6 +22,24 @@
tfhers.EncryptionKeyChoice.BIG,
)

DEFAULT_TFHERS_PARAMS_DICT = {
"lwe_dimension": 902,
"glwe_dimension": 1,
"polynomial_size": 4096,
"lwe_noise_distribution": {"Gaussian": {"std": 1.0994794733558207e-6, "mean": 0.0}},
"glwe_noise_distribution": {"Gaussian": {"std": 2.168404344971009e-19, "mean": 0.0}},
"pbs_base_log": 15,
"pbs_level": 2,
"ks_base_log": 3,
"ks_level": 6,
"message_modulus": 4,
"carry_modulus": 8,
"max_noise_level": 10,
"log2_p_fail": -64.084,
"ciphertext_modulus": {"modulus": 0, "scalar_bits": 64},
"encryption_key_choice": "Big",
}


def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType:
"""Create a tfhers type from a partial func missing tfhers params.
Expand Down Expand Up @@ -150,3 +172,28 @@ def test_tfhers_encryption_variance(crypto_params: tfhers.CryptoParams):
return
assert crypto_params.encryption_key_choice == tfhers.EncryptionKeyChoice.SMALL
assert crypto_params.encryption_variance() == crypto_params.lwe_noise_distribution**2


@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,))
@pytest.mark.parametrize("is_signed", [True, False])
@pytest.mark.parametrize("n_bits", [5, 8, 13, 16])
def test_load_tfhers_params_dict(params_dict, is_signed, n_bits):
dtype = tfhers.get_type_from_params_dict(params_dict, is_signed, n_bits)
assert dtype.bit_width == n_bits
assert dtype.is_signed == is_signed

test_keys = ["lwe_dimension", "glwe_dimension", "polynomial_size", "pbs_base_log", "pbs_level"]

for k in test_keys:
assert getattr(dtype.params, k) == params_dict[k]


@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,))
def test_load_tfhers_params_file(params_dict):
ftemp = tempfile.NamedTemporaryFile(delete=False)
fpath = ftemp.name
ftemp.write(bytes(json.dumps(params_dict), "utf8"))
ftemp.close()

tfhers.get_type_from_params(fpath, True, 8)
os.unlink(fpath)

0 comments on commit 04bab98

Please sign in to comment.