diff --git a/docs/dev/api/README.md b/docs/dev/api/README.md index 349fec2528..fb7482710a 100644 --- a/docs/dev/api/README.md +++ b/docs/dev/api/README.md @@ -364,6 +364,7 @@ - [`utils.format_constant`](./concrete.fhe.representation.utils.md): Get the textual representation of a constant. - [`utils.format_indexing_element`](./concrete.fhe.representation.utils.md): Format an indexing element. - [`tfhers.get_type_from_params`](./concrete.fhe.tfhers.md): Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. +- [`tfhers.get_type_from_params_dict`](./concrete.fhe.tfhers.md): Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. - [`bridge.new_bridge`](./concrete.fhe.tfhers.bridge.md): Create a TFHErs bridge from a circuit or module. - [`tracing.from_native`](./concrete.fhe.tfhers.tracing.md): Convert a Concrete integer to the tfhers representation. - [`tracing.to_native`](./concrete.fhe.tfhers.tracing.md): Convert a tfhers integer to the Concrete representation. diff --git a/docs/dev/api/concrete.fhe.tfhers.md b/docs/dev/api/concrete.fhe.tfhers.md index 1d1d7519b2..20c3c7d265 100644 --- a/docs/dev/api/concrete.fhe.tfhers.md +++ b/docs/dev/api/concrete.fhe.tfhers.md @@ -14,7 +14,7 @@ tfhers module to represent, and compute on tfhers integer values. --- - + ## function `get_type_from_params` @@ -43,3 +43,34 @@ Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. - `TFHERSIntegerType`: constructed type from the loaded parameters +--- + + + +## function `get_type_from_params_dict` + +```python +get_type_from_params_dict( + crypto_param_dict: Dict, + is_signed: bool, + precision: int +) → TFHERSIntegerType +``` + +Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. + + + +**Args:** + + - `crypto_param_dict` (Dict): dictionary of TFHE-rs parameters + - `is_signed` (bool): sign of the result type + - `precision` (int): precision of the result type + + + +**Returns:** + + - `TFHERSIntegerType`: constructed type from the loaded parameters + + diff --git a/frontends/concrete-python/concrete/fhe/tfhers/__init__.py b/frontends/concrete-python/concrete/fhe/tfhers/__init__.py index a604ee00a0..84612c92ff 100644 --- a/frontends/concrete-python/concrete/fhe/tfhers/__init__.py +++ b/frontends/concrete-python/concrete/fhe/tfhers/__init__.py @@ -4,6 +4,7 @@ import json from math import log2 +from typing import Dict from .bridge import new_bridge from .dtypes import ( @@ -41,6 +42,23 @@ def get_type_from_params( with open(path_to_params_json) as f: crypto_param_dict = json.load(f) + return get_type_from_params_dict(crypto_param_dict, is_signed, precision) + + +def get_type_from_params_dict( + crypto_param_dict: Dict, is_signed: bool, precision: int +) -> TFHERSIntegerType: + """Get a TFHE-rs integer type from TFHE-rs parameters in JSON format. + + Args: + crypto_param_dict (Dict): dictionary of TFHE-rs parameters + is_signed (bool): sign of the result type + precision (int): precision of the result type + + Returns: + TFHERSIntegerType: constructed type from the loaded parameters + """ + lwe_dim = crypto_param_dict["lwe_dimension"] glwe_dim = crypto_param_dict["glwe_dimension"] poly_size = crypto_param_dict["polynomial_size"] diff --git a/frontends/concrete-python/tests/dtypes/test_tfhers.py b/frontends/concrete-python/tests/dtypes/test_tfhers.py index 4cccca7052..f8ca1f38bd 100644 --- a/frontends/concrete-python/tests/dtypes/test_tfhers.py +++ b/frontends/concrete-python/tests/dtypes/test_tfhers.py @@ -2,6 +2,10 @@ Tests of `TFHERSIntegerType` data type. """ +import json +import os +import tempfile + import numpy as np import pytest @@ -18,6 +22,24 @@ tfhers.EncryptionKeyChoice.BIG, ) +DEFAULT_TFHERS_PARAMS_DICT = { + "lwe_dimension": 902, + "glwe_dimension": 1, + "polynomial_size": 4096, + "lwe_noise_distribution": {"Gaussian": {"std": 1.0994794733558207e-6, "mean": 0.0}}, + "glwe_noise_distribution": {"Gaussian": {"std": 2.168404344971009e-19, "mean": 0.0}}, + "pbs_base_log": 15, + "pbs_level": 2, + "ks_base_log": 3, + "ks_level": 6, + "message_modulus": 4, + "carry_modulus": 8, + "max_noise_level": 10, + "log2_p_fail": -64.084, + "ciphertext_modulus": {"modulus": 0, "scalar_bits": 64}, + "encryption_key_choice": "Big", +} + def parameterize_partial_dtype(partial_dtype) -> tfhers.TFHERSIntegerType: """Create a tfhers type from a partial func missing tfhers params. @@ -150,3 +172,28 @@ def test_tfhers_encryption_variance(crypto_params: tfhers.CryptoParams): return assert crypto_params.encryption_key_choice == tfhers.EncryptionKeyChoice.SMALL assert crypto_params.encryption_variance() == crypto_params.lwe_noise_distribution**2 + + +@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,)) +@pytest.mark.parametrize("is_signed", [True, False]) +@pytest.mark.parametrize("n_bits", [5, 8, 13, 16]) +def test_load_tfhers_params_dict(params_dict, is_signed, n_bits): + dtype = tfhers.get_type_from_params_dict(params_dict, is_signed, n_bits) + assert dtype.bit_width == n_bits + assert dtype.is_signed == is_signed + + test_keys = ["lwe_dimension", "glwe_dimension", "polynomial_size", "pbs_base_log", "pbs_level"] + + for k in test_keys: + assert getattr(dtype.params, k) == params_dict[k] + + +@pytest.mark.parametrize("params_dict", (DEFAULT_TFHERS_PARAMS_DICT,)) +def test_load_tfhers_params_file(params_dict): + ftemp = tempfile.NamedTemporaryFile(delete=False) + fpath = ftemp.name + ftemp.write(bytes(json.dumps(params_dict), "utf8")) + ftemp.close() + + tfhers.get_type_from_params(fpath, True, 8) + os.unlink(fpath)