Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP pytorch/onnx backend #144

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
655 changes: 582 additions & 73 deletions Cargo.lock

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ rust-version = "1.76"

[features]
extension-module = ["pyo3/extension-module"]
default = ["extension-module"]
default = ["extension-module", "onnx"]
simd_support = ["nuts-rs/simd_support"]
torch = ["dep:tch"]
onnx = ["dep:ort"]

[lib]
name = "_lib"
Expand All @@ -27,17 +29,24 @@ numpy = "0.21.0"
rand = "0.8.5"
thiserror = "1.0.44"
rand_chacha = "0.3.1"
rayon = "1.9.0"
rayon = "1.10.0"
# Keep arrow in sync with nuts-rs requirements
arrow = { version = "52.0.0", default-features = false, features = ["ffi"] }
arrow = { version = "52.1.0", default-features = false, features = ["ffi"] }
anyhow = "1.0.72"
itertools = "0.13.0"
bridgestan = "2.5.0"
rand_distr = "0.4.3"
smallvec = "1.11.0"
smallvec = "1.13.0"
upon = { version = "0.8.1", default-features = false, features = [] }
time-humanize = { version = "0.1.3", default-features = false }
indicatif = "0.17.8"
tch = { version = "0.16.0", optional = true }
ort = { version = "2.0.0-rc.4", optional = true, features = [
"cuda",
#"tensorrt",
#"openvino",
"load-dynamic",
] }

[dependencies.pyo3]
version = "0.21.0"
Expand Down
11 changes: 9 additions & 2 deletions python/nutpie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from nutpie import _lib
from nutpie.compile_onnx import compile_pytensor_module
from nutpie.compile_pymc import compile_pymc_model
from nutpie.compile_stan import compile_stan_model
from nutpie.sample import sample
from nutpie.sampling import sample

__version__: str = _lib.__version__
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
__all__ = [
"__version__",
"sample",
"compile_pymc_model",
"compile_stan_model",
"compile_pytensor_module",
]
67 changes: 67 additions & 0 deletions python/nutpie/compile_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import dataclasses
import io
from typing import Any

from nutpie import _lib
from nutpie.sampling import CompiledModel


def compile_pytensor_module(module, n_dim):
    """Export a torch module to ONNX and wrap it as a nutpie model.

    The module is traced with a zero vector of length ``n_dim`` (presumably
    the unconstrained parameter vector of the log-density — confirm against
    callers), serialized to ONNX bytes, and wrapped in a ``CompiledOnnx``
    configured with the CPU execution provider.
    """
    # Imported lazily so torch is only required when this backend is used.
    import torch

    dummy_input = torch.zeros(n_dim)
    onnx_program = torch.onnx.dynamo_export(module, dummy_input)

    # Serialize the exported program into an in-memory buffer.
    buffer = io.BytesIO()
    onnx_program.save(buffer)

    model = CompiledOnnx(
        _n_dim=n_dim,
        providers=None,
        logp_module_bytes=buffer.getvalue(),
        dims={"unconstrained_draw": ("unconstrained_parameter",)},
    )
    return model.with_providers(["cpu"])


@dataclasses.dataclass(frozen=True)
class CompiledOnnx(CompiledModel):
    """A log-density model stored as serialized ONNX bytes.

    Instances are immutable; ``with_providers`` returns a reconfigured copy.
    """

    # Serialized ONNX model (bytes) evaluated by the Rust backend.
    logp_module_bytes: Any
    # An ``_lib.OnnxProviders`` instance, or None before configuration.
    providers: Any
    _n_dim: int

    @property
    def shapes(self):
        """Shape of each produced variable."""
        return {"unconstrained_draw": (self.n_dim,)}

    @property
    def coords(self):
        """No coordinate values are attached to this model."""
        return {}

    @property
    def n_dim(self):
        """Number of unconstrained parameters."""
        return self._n_dim

    def _make_model(self, init_mean):
        # init_mean is accepted for interface compatibility but not used here.
        return _lib.OnnxModel(self.n_dim, self.logp_module_bytes, self.providers)

    def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None):
        """Build a sampler for this model via the Rust extension."""
        return _lib.PySampler.from_onnx(
            settings,
            cores,
            self._make_model(init_mean),
            template,
            rate,
            callback,
        )

    def with_providers(self, provider_names):
        """Return a copy configured with the given ONNX execution providers.

        Raises ``ValueError`` for any name other than ``"cuda"``,
        ``"tensorrt"`` or ``"cpu"``.
        """
        providers = _lib.OnnxProviders()
        register = {
            "cuda": providers.add_cuda,
            "tensorrt": providers.add_tensorrt,
            "cpu": providers.add_cpu,
        }
        for name in provider_names:
            add = register.get(name)
            if add is None:
                raise ValueError(f"Unknown provider {name}")
            add()
        return dataclasses.replace(self, providers=providers)
3 changes: 2 additions & 1 deletion python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from nutpie import _lib
from nutpie.compiled_pyfunc import from_pyfunc
from nutpie.sample import CompiledModel
from nutpie.sampling import CompiledModel

try:
from numba.extending import intrinsic
Expand Down Expand Up @@ -427,6 +427,7 @@ def _compute_shapes(model):


def _make_functions(model, *, mode, compute_grad, join_expanded):
# TODO do we want to freeze the model?
import pytensor
import pytensor.link.numba.dispatch
import pytensor.tensor as pt
Expand Down
2 changes: 1 addition & 1 deletion python/nutpie/compile_stan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numpy.typing import NDArray

from nutpie import _lib
from nutpie.sample import CompiledModel
from nutpie.sampling import CompiledModel


class _NumpyArrayEncoder(json.JSONEncoder):
Expand Down
2 changes: 1 addition & 1 deletion python/nutpie/compiled_pyfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from nutpie import _lib
from nutpie.sample import CompiledModel
from nutpie.sampling import CompiledModel


@dataclass(frozen=True)
Expand Down
File renamed without changes.
Loading
Loading