An ONNX-backed array library that is compliant with the Array API standard.
Releases are available on PyPI and conda-forge.
```bash
# using pip
pip install ndonnx
# using conda
conda install -c conda-forge ndonnx
# using pixi
pixi add ndonnx
```
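After installing, you may want to confirm which versions ended up in the active environment, which is handy when pip and conda packages are mixed. The following is a minimal, stdlib-only sketch; `installed_versions` is a hypothetical helper and not part of ndonnx.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_versions(*dists: str) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for checking an environment; not part of ndonnx.
    """
    result: dict[str, str | None] = {}
    for name in dists:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            result[name] = None
    return result


print(installed_versions("ndonnx", "onnx", "onnxruntime"))
```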
You can install the package in development mode using:
```bash
git clone https://github.com/quantco/ndonnx
cd ndonnx
# For Array API tests
git submodule update --init --recursive

pixi shell
pre-commit run -a
pip install --no-build-isolation --no-deps -e .
pytest tests -n auto
```
`ndonnx` is an ONNX-based Python array library. It has a couple of key features:
- It implements the Array API standard. Standard-compliant code can be executed without changes across numerous backends such as NumPy, JAX and now `ndonnx`.

  ```python
  import numpy as np
  import ndonnx as ndx
  import jax.numpy as jnp

  def mean_drop_outliers(a, low=-5, high=5):
      # Look up the array's own namespace (numpy, jax.numpy or ndonnx)
      xp = a.__array_namespace__()
      return xp.mean(a[(low < a) & (a < high)])

  np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 5]))
  jax_result = mean_drop_outliers(jnp.asarray([-10, 0.5, 1, 5]))
  onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 5]))

  assert np_result == onnx_result.to_numpy() == jax_result == 0.75
  ```

- It supports ONNX export. This allows you to persist your logic into an ONNX computation graph.

  ```python
  import ndonnx as ndx
  import onnx

  # Instantiate a placeholder ndonnx array: 1-D float32 input with symbolic length "N"
  x = ndx.array(shape=("N",), dtype=ndx.float32)
  # mean_drop_outliers is the function defined in the previous example
  y = mean_drop_outliers(x)

  # Build and save ONNX model to disk
  model = ndx.build({"x": x}, {"y": y})
  onnx.save(model, "mean_drop_outliers.onnx")
  ```

  You can then make predictions using a runtime of your choice.

  ```python
  import onnxruntime as ort
  import numpy as np

  inference_session = ort.InferenceSession("mean_drop_outliers.onnx")
  # Passing None as output names returns all outputs (here only "y")
  prediction, = inference_session.run(None, {
      "x": np.array([-10, 0.5, 1, 5], dtype=np.float32),
  })
  assert prediction == 0.75
  ```
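The `mean_drop_outliers` example calls `a.__array_namespace__()` on a single array. When a function takes several arrays, it can be worth checking that they all come from the same library before mixing them. The sketch below is one way to do that; `get_namespace` and `mean_abs_diff` are hypothetical helpers, not part of ndonnx (the separate `array-api-compat` package offers a comparable `array_namespace` function).

```python
from typing import Any


def get_namespace(*arrays: Any) -> Any:
    """Return the single Array API namespace shared by all ``arrays``.

    Hypothetical helper (not part of ndonnx). Raises TypeError if an argument
    lacks ``__array_namespace__`` or if the arrays come from different libraries.
    """
    namespaces: list = []
    for a in arrays:
        if not hasattr(a, "__array_namespace__"):
            raise TypeError(f"{type(a).__name__} does not implement __array_namespace__")
        xp = a.__array_namespace__()
        # Compare by identity; libraries typically return the same module object each time.
        if not any(xp is seen for seen in namespaces):
            namespaces.append(xp)
    if len(namespaces) != 1:
        raise TypeError(f"expected arrays from exactly one namespace, got {len(namespaces)}")
    return namespaces[0]


def mean_abs_diff(a: Any, b: Any) -> Any:
    # Hypothetical usage: works for NumPy, JAX or ndonnx inputs, but not a mix of them.
    xp = get_namespace(a, b)
    return xp.mean(xp.abs(a - b))
```

Failing early with a clear `TypeError` is usually easier to debug than a confusing error from deep inside one library receiving another library's arrays.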
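The expected value 0.75 follows from the strict bounds: -10 falls below `low` and 5 is not strictly less than `high`, so only 0.5 and 1 remain and their mean is (0.5 + 1) / 2 = 0.75. A dependency-free reference of the same filtering logic can help cross-check whichever runtime you choose. This is a sketch; `mean_drop_outliers_ref` is a hypothetical helper, and its NaN result for an empty selection mirrors the Array API standard's rule for `mean` of zero elements.

```python
import math


def mean_drop_outliers_ref(values, low=-5, high=5):
    """Plain-Python reference for mean_drop_outliers (hypothetical helper).

    Keeps values strictly inside (low, high) and returns their arithmetic mean.
    """
    kept = [v for v in values if low < v < high]
    if not kept:
        # The Array API standard specifies NaN for the mean of zero elements.
        return math.nan
    return sum(kept) / len(kept)


print(mean_drop_outliers_ref([-10, 0.5, 1, 5]))  # 0.75: only 0.5 and 1 are kept
```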
In the future, we will provide a stable API for an extensible data type system. This will allow users to define their own data types and operations on arrays with these data types.
Array API compatibility is tracked in `api-coverage-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome!
Summary (1119 total):
- 961 passed
- 107 failed
- 51 deselected
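The three counts add up to the stated total, and the headline pass rate depends on whether deselected tests (which are never run) count toward the denominator. A small sketch of that arithmetic using the numbers above; `coverage_rates` is a hypothetical helper, not part of the test tooling.

```python
def coverage_rates(passed: int, failed: int, deselected: int) -> dict:
    """Summarise an Array API test run (hypothetical helper)."""
    total = passed + failed + deselected
    executed = passed + failed  # deselected tests are never run
    return {
        "total": total,
        "pass_rate_all": passed / total,
        "pass_rate_executed": passed / executed,
    }


rates = coverage_rates(passed=961, failed=107, deselected=51)
print(rates["total"])                        # 1119
print(f"{rates['pass_rate_all']:.1%}")       # 85.9%
print(f"{rates['pass_rate_executed']:.1%}")  # 90.0%
```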
Run the tests with:
```bash
pixi run arrayapitests
```