Skip to content

Commit

Permalink
feat: support from_sklearn for trees
Browse files Browse the repository at this point in the history
Support `from_sklearn` for tree based models.

Two options:
- Quantization from thresholds: the main idea is to consider the
  thresholds of the nodes of the trees for quantization to not have
  to use data.
- Quantization from data: build a quantizer from the data provided by
  the user and quantize the thresholds based on that.

This also raises the question of non-uniform input quantization.
We could quantize the data based on the thresholds thus reducing the
number of bits required to log2(max_{feature}(node_{feature})).

That would leak the thresholds used in the model per feature but not the
structure of the tree itself while increasing significantly the number
of bits required.

We could try to automatically determine the n-bits to use to
properly represent all thresholds but this might result in a very high
bit-with.
  • Loading branch information
fd0r committed May 21, 2024
1 parent 55f681a commit e744f62
Show file tree
Hide file tree
Showing 12 changed files with 852 additions and 90 deletions.
1 change: 1 addition & 0 deletions .gitleaksignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ a99389ee01cbb972e46a892d3d0e9c7f8ee23f59:use_case_examples/training/analyze.ipyn
a99389ee01cbb972e46a892d3d0e9c7f8ee23f59:use_case_examples/training/analyze.ipynb:aws-access-token:18379
f41de03048a9ed27946b875e81b34138bb4bb17b:use_case_examples/training/analyze.ipynb:aws-access-token:6404
e2904473898ddd325f245f4faca526a0e9520f49:builders/Dockerfile.zamalang-env:generic-api-key:5
7d5e885816f1f1e432dd94da38c5c8267292056a:docs/advanced_examples/XGBRegressor.ipynb:aws-access-token:1026
21 changes: 11 additions & 10 deletions deps_licenses/licenses_mac_silicon_user.txt
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
Name, Version, License
PyYAML, 6.0.1, MIT License
brevitas, 0.8.0, UNKNOWN
certifi, 2024.2.2, Mozilla Public License 2.0 (MPL 2.0)
certifi, 2023.7.22, Mozilla Public License 2.0 (MPL 2.0)
charset-normalizer, 3.3.2, MIT License
coloredlogs, 15.0.1, MIT License
concrete-python, 2024.4.19, BSD-3-Clause
dependencies, 2.0.1, BSD License
dill, 0.3.8, BSD License
filelock, 3.14.0, The Unlicense (Unlicense)
filelock, 3.13.4, The Unlicense (Unlicense)
flatbuffers, 24.3.25, Apache Software License
fsspec, 2024.3.1, BSD License
huggingface-hub, 0.22.2, Apache Software License
humanfriendly, 10.0, MIT License
hummingbird-ml, 0.4.11, MIT License
hummingbird-ml, 0.4.8, MIT License
idna, 3.7, BSD License
importlib_resources, 6.4.0, Apache Software License
joblib, 1.4.0, BSD License
jsonpickle, 3.0.4, BSD License
mpmath, 1.3.0, BSD License
networkx, 3.1, BSD License
numpy, 1.24.3, BSD License
onnx, 1.16.0, Apache License v2.0
numpy, 1.23.5, BSD License
onnx, 1.15.0, Apache License v2.0
onnxconverter-common, 1.13.0, MIT License
onnxmltools, 1.12.0, Apache Software License
onnxmltools, 1.11.0, Apache Software License
onnxoptimizer, 0.3.13, Apache License v2.0
onnxruntime, 1.17.3, MIT License
packaging, 24.0, Apache Software License; BSD License
Expand All @@ -32,19 +32,20 @@ psutil, 5.9.8, BSD License
python-dateutil, 2.9.0.post0, Apache Software License; BSD License
pytz, 2024.1, MIT License
requests, 2.31.0, Apache Software License
scikit-learn, 1.3.2, BSD License
scikit-learn, 1.1.3, BSD License
scipy, 1.10.1, BSD License
six, 1.16.0, MIT License
skl2onnx, 1.16.0, Apache Software License
skl2onnx, 1.12, Apache Software License
skops, 0.5.0, MIT
skorch, 0.11.0, new BSD 3-Clause
sympy, 1.12, BSD License
tabulate, 0.8.10, MIT License
threadpoolctl, 3.5.0, BSD License
threadpoolctl, 3.4.0, BSD License
torch, 1.13.1, BSD License
tqdm, 4.66.2, MIT License; Mozilla Public License 2.0 (MPL 2.0)
typing_extensions, 4.5.0, Python Software Foundation License
tzdata, 2024.1, Apache Software License
urllib3, 2.2.1, MIT License
xgboost, 1.7.6, Apache Software License
xgboost, 1.6.2, Apache Software License
z3-solver, 4.13.0.0, MIT License
zipp, 3.18.1, MIT License
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt.md5
Original file line number Diff line number Diff line change
@@ -1 +1 @@
7be80ba54850fbc203015560c8acb9a8
9b8316c2a6c823884676b39f52eb018a
279 changes: 262 additions & 17 deletions docs/advanced_examples/DecisionTreeClassifier.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/advanced_examples/XGBClassifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -587,5 +587,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
51 changes: 17 additions & 34 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,18 @@ pylint = "^2.13.0"
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2541
pytest = "7.4.1"
pytest-cov = "^4.1.0"
pytest-xdist = "^3.3.1"
pytest-randomly = "^3.11.0"
pytest-repeat = "^0.9.1"
pytest-subtests = "^0.11.0"
pytest_codeblocks = "^0.14.0"
mypy = "^1.8.0"
pydocstyle = "^6.1.1"
python-semantic-release = "^7.27.0"
semver = "^2.13.0"
tomlkit = "^0.7.0"
pytest-json-report = "^1.5.0"
pytest-xdist = "^3.3.1"
pytest-randomly = "^3.11.0"
nbmake = "^1.3.0"
pygments-style-tomorrow = "^1.0.0"
pytest-repeat = "^0.9.1"
mdformat = "^0.7.14"
mdformat_myst = "^0.1.4"
mdformat-toc = "^0.3.0"
Expand Down
13 changes: 10 additions & 3 deletions src/concrete/ml/onnx/onnx_impl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy
from concrete.fhe import conv as fhe_conv
from concrete.fhe import ones as fhe_ones
from concrete.fhe import round_bit_pattern
from concrete.fhe import truncate_bit_pattern
from concrete.fhe.tracing import Tracer

from ..common.debugging import assert_true
Expand Down Expand Up @@ -265,9 +265,16 @@ def rounded_comparison(
# Workaround: in this context, `round_bit_pattern` is used as a truncate operation.
# Consequently, we subtract a term, called `half` that will subsequently be re-added during the
# `round_bit_pattern` process.
half = 1 << (lsbs_to_remove - 1)
# half = 1 << (lsbs_to_remove - 1)

# To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y'
rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)
# We cast to int because if half is too high the result might be float
# intermediate = ((x - y) - half)
# intermediate_as_int = intermediate.astype(numpy.int64)
#
# if not isinstance(intermediate, Tracer):
# assert (intermediate == intermediate_as_int).all()

rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=lsbs_to_remove)

return (operation(rounded_subtraction),)
2 changes: 1 addition & 1 deletion src/concrete/ml/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[numpy.ndarray, Tracer]:

values = self.scale * (qvalues - numpy.asarray(self.zero_point, dtype=numpy.float64))

assert isinstance(values, (numpy.ndarray, Tracer))
assert isinstance(values, (float, int, numpy.ndarray, Tracer))
return values


Expand Down
Loading

0 comments on commit e744f62

Please sign in to comment.