diff --git a/Cargo.lock b/Cargo.lock index 6dd8489..6ad4fe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,23 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -241,6 +258,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bindgen" version = "0.69.4" @@ -336,6 +359,27 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cast" version = "0.3.0" @@ -347,6 +391,10 @@ name = "cc" version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +dependencies = [ + "jobserver", + "libc", +] [[package]] name = "cexpr" @@ -402,6 +450,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -415,18 +473,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.10" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f6b81fb3c84f5563d509c59b5a48d935f689e993afa90fe39047f05adef9142" +checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.10" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca6706fd5224857d9ac5eb9355f6683563cc0541c7cd9d014043b57cbec78ac" +checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" dependencies = [ "anstyle", "clap_lex", @@ -477,6 +535,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -492,6 +556,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -575,6 +648,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ca96b45ca70b8045e0462f191bd209fcb3c3bfe8dbfb1257ada54c4dd59169" +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -583,6 +665,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -666,15 +749,10 @@ dependencies = [ "matrixcompare", "matrixcompare-core", "nano-gemm", - "npyz", "num-complex", "num-traits", "paste", - "rand", - "rand_distr", - "rayon", "reborrow", - "serde", ] [[package]] @@ -692,6 +770,37 @@ dependencies = [ "reborrow", ] +[[package]] +name = "filetime" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.4.1", + "windows-sys", +] + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "gemm" version = "0.18.0" @@ -757,7 +866,6 @@ dependencies = [ "paste", "pulp", "raw-cpuid", - "rayon", "seq-macro", "sysctl", ] @@ -776,7 +884,6 @@ dependencies = [ "num-traits", "paste", "raw-cpuid", - "rayon", "seq-macro", ] @@ -867,6 +974,15 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -899,6 +1015,16 @@ dependencies = [ "cc", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indicatif" version = "0.17.8" @@ -918,6 +1044,15 @@ version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "instant" version = "0.1.13" @@ -971,6 +1106,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -1147,6 +1291,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" +dependencies = [ + "adler", +] + [[package]] name = "multiversion" version = "0.7.4" @@ -1262,17 +1415,6 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "npyz" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f27ea175875c472b3df61ece89a6d6ef4e0627f43704e400c782f174681ebd" -dependencies = [ - "byteorder", - "num-bigint", - "py_literal", -] - [[package]] name = "num" version = "0.4.3" @@ -1305,9 +1447,14 @@ checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "bytemuck", "num-traits", - "rand", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1382,12 +1529,14 @@ dependencies = [ "itertools 0.13.0", "numpy", "nuts-rs", + "ort", "pyo3", "rand", "rand_chacha", "rand_distr", "rayon", "smallvec", + "tch", "thiserror", "time-humanize", "upon", @@ -1396,8 +1545,6 @@ dependencies = [ [[package]] name = "nuts-rs" version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8573e3b5c83e8ec0570ebbd75dd6fdc7dfcfa5da9b5f9d9d63fedefebbd9cf8" dependencies = [ "anyhow", "arrow", @@ -1424,6 +1571,35 @@ version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +[[package]] +name = "ort" +version = "2.0.0-rc.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86d83095ae3c1258738d70ae7a06195c94d966a8e546f0d3609dc90885fb61f5" +dependencies = [ + "half", + "js-sys", + "libloading", + "ndarray", + "ort-sys", + "thiserror", + "tracing", + "web-sys", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f2f6427193c808010b126bef45ebd33f8dee43770223a1200f84d3734d6c656" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -1442,11 +1618,22 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.3", "smallvec", "windows-targets", ] +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -1472,49 +1659,34 @@ dependencies = [ ] [[package]] -name = "pest" -version = "2.7.11" +name = "pbkdf2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd53dff83f26735fdc1ca837098ccf133605d794cdae66acfc2bfac3ec809d95" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" dependencies = [ - "memchr", - "thiserror", - "ucd-trie", + "digest", + "hmac", + "password-hash", + "sha2", ] [[package]] -name = "pest_derive" -version = "2.7.11" +name = "percent-encoding" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a548d2beca6773b1c244554d36fcf8548a8a58e74156968211567250e48e49a" -dependencies = [ - "pest", - "pest_generator", -] +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] -name = "pest_generator" -version = "2.7.11" +name = "pin-project-lite" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c93a82e8d145725dcbaf44e5ea887c8a869efdcc28706df2d08c69e17077183" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn 2.0.72", -] +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" [[package]] -name = "pest_meta" -version = "2.7.11" +name = "pkg-config" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a941429fea7e08bedec25e4f6785b6ffaacc6b755da98df5ef3e7dcf4a124c4f" -dependencies = [ - "once_cell", - "pest", - "sha2", -] +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "plotters" @@ -1550,6 +1722,12 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1587,19 +1765,6 @@ dependencies = [ "reborrow", ] -[[package]] -name = "py_literal" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" -dependencies = [ - "num-bigint", - "num-complex", - "num-traits", - "pest", - "pest_derive", -] - [[package]] name = "pyo3" version = "0.21.2" @@ -1754,6 +1919,15 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.3" @@ -1792,6 +1966,21 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1811,12 +2000,54 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "rustls" +version = "0.23.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" + +[[package]] +name = "rustls-webpki" +version = "0.102.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "ryu" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -1869,6 +2100,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1892,12 +2134,35 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "static_assertions" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "1.0.109" @@ -1934,6 +2199,17 @@ dependencies = [ "walkdir", ] +[[package]] +name = "tar" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "target-features" version = "0.1.6" @@ -1946,6 +2222,23 @@ version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +[[package]] +name = "tch" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61fd89a98303b22acd6d4969b4c8940f7a30ba79af32b744a2028375d156e95a" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "safetensors", + "thiserror", + "torch-sys", + "zip", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -1966,6 +2259,25 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + [[package]] name = "time-humanize" version = "0.1.3" @@ -1991,6 +2303,64 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "torch-sys" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5997681f7f3700fa475f541fcda44c8959ea42a724194316fe7297cb96ebb08" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.17.0" @@ -1998,10 +2368,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] -name = "ucd-trie" -version = "0.1.6" +name = "unicode-bidi" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" [[package]] name = "unicode-ident" @@ -2009,6 +2379,15 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-width" version = "0.1.13" @@ -2021,12 +2400,45 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "upon" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" +[[package]] +name = "ureq" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" +dependencies = [ + "base64", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "socks", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "version_check" version = "0.9.4" @@ -2113,6 +2525,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "4.4.2" @@ -2125,6 +2546,22 @@ dependencies = [ "rustix", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.8" @@ -2134,6 +2571,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.52.0" @@ -2216,6 +2659,17 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "xattr" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +dependencies = [ + "libc", + "linux-raw-sys", + "rustix", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -2235,3 +2689,58 @@ dependencies = [ "quote", "syn 2.0.72", ] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.12+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 48634d8..387339a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,10 @@ rust-version = "1.76" [features] extension-module = ["pyo3/extension-module"] -default = ["extension-module"] +default = ["extension-module", "onnx"] simd_support = ["nuts-rs/simd_support"] +torch = ["dep:tch"] +onnx = ["dep:ort"] [lib] name = "_lib" @@ -27,17 +29,24 @@ numpy = "0.21.0" rand = "0.8.5" thiserror = "1.0.44" rand_chacha = "0.3.1" -rayon = "1.9.0" +rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "52.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.13.0" bridgestan = "2.5.0" rand_distr = "0.4.3" -smallvec = "1.11.0" +smallvec = "1.13.0" upon = { version = "0.8.1", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" +tch = { version = "0.16.0", optional = true } +ort = { version = "2.0.0-rc.4", optional = true, features = [ + "cuda", + #"tensorrt", + #"openvino", + "load-dynamic", +] } [dependencies.pyo3] version = "0.21.0" diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 980b7e5..bf0fb20 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,7 +1,14 @@ from nutpie import _lib +from nutpie.compile_onnx import compile_pytensor_module from nutpie.compile_pymc import compile_pymc_model from nutpie.compile_stan import compile_stan_model -from nutpie.sample import sample +from nutpie.sampling import sample __version__: str = _lib.__version__ -__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"] +__all__ = [ + "__version__", + "sample", + "compile_pymc_model", + "compile_stan_model", + "compile_pytensor_module", +] diff --git a/python/nutpie/compile_onnx.py b/python/nutpie/compile_onnx.py new file mode 100644 index 0000000..4b65701 --- /dev/null +++ b/python/nutpie/compile_onnx.py @@ -0,0 +1,67 @@ +import dataclasses +import io +from typing import Any + +from nutpie import _lib +from nutpie.sampling import CompiledModel + + +def compile_pytensor_module(module, n_dim): + import torch + + x = torch.zeros(n_dim) + exported = torch.onnx.dynamo_export(module, x) + + exported_bytes = io.BytesIO() + exported.save(exported_bytes) + exported_bytes = exported_bytes.getvalue() + + compiled = CompiledOnnx( + _n_dim=n_dim, + providers=None, + logp_module_bytes=exported_bytes, + dims={"unconstrained_draw": ("unconstrained_parameter",)}, + ) + + return compiled.with_providers(["cpu"]) + + +@dataclasses.dataclass(frozen=True) +class CompiledOnnx(CompiledModel): + logp_module_bytes: Any + providers: Any + _n_dim: int + + @property + def shapes(self): + return {"unconstrained_draw": (self.n_dim,)} + + @property + def coords(self): + return {} + + @property + def n_dim(self): + return self._n_dim + + def _make_model(self, init_mean): + return _lib.OnnxModel(self.n_dim, self.logp_module_bytes, self.providers) + + def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): + model = self._make_model(init_mean) + return _lib.PySampler.from_onnx( + settings, cores, model, template, rate, callback + ) + + def with_providers(self, provider_names): + providers = _lib.OnnxProviders() + for name in provider_names: + if name == "cuda": + providers.add_cuda() + elif name == "tensorrt": + providers.add_tensorrt() + elif name == "cpu": + providers.add_cpu() + else: + raise ValueError(f"Unknown provider {name}") + return dataclasses.replace(self, providers=providers) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index aaee456..cf439e3 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -12,7 +12,7 @@ from nutpie import _lib from nutpie.compiled_pyfunc import from_pyfunc -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel try: from numba.extending import intrinsic @@ -427,6 +427,7 @@ def _compute_shapes(model): def _make_functions(model, *, mode, compute_grad, join_expanded): + # TODO do we want to freeze the model? import pytensor import pytensor.link.numba.dispatch import pytensor.tensor as pt diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 7a28052..26e74f6 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -10,7 +10,7 @@ from numpy.typing import NDArray from nutpie import _lib -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel class _NumpyArrayEncoder(json.JSONEncoder): diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 4db549c..5debd19 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -6,7 +6,7 @@ import numpy as np from nutpie import _lib -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel @dataclass(frozen=True) diff --git a/python/nutpie/sample.py b/python/nutpie/sampling.py similarity index 100% rename from python/nutpie/sample.py rename to python/nutpie/sampling.py diff --git a/src/iree.rs b/src/iree.rs new file mode 100644 index 0000000..7bd0ea6 --- /dev/null +++ b/src/iree.rs @@ -0,0 +1,723 @@ +use std::{ + io::{stderr, stdout, Write}, + mem::{forget, transmute, ManuallyDrop}, + sync::{ + mpsc::{sync_channel, Receiver, SyncSender}, + Arc, Mutex, OnceLock, + }, + thread::{spawn, JoinHandle}, +}; + +use anyhow::{anyhow, Context, Result}; +use arrow2::{ + array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, + datatypes::{DataType, Field}, +}; +use eerie::runtime::{ + api::{Call, Instance, InstanceOptions, Session, SessionOptions}, + hal::{BufferMapping, BufferView, Device, DriverRegistry, EncodingType}, + vm::{DynamicList, Function, List, Ref, ToRef, Undefined, Value}, +}; +use numpy::{PyArray1, PyReadonlyArray1, PyReadwriteArray1}; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Math, Model}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, Py, Python, +}; +use rand_distr::{num_traits::ToPrimitive, Distribution, StandardNormal}; +use thiserror::Error; + +static INSTANCE: OnceLock> = OnceLock::new(); + +fn get_instance() -> Result<&'static Instance> { + match INSTANCE.get_or_init(|| { + let mut registry = DriverRegistry::new(); + let options = InstanceOptions::new(&mut registry).use_all_available_drivers(); + let instance = Instance::new(&options)?; + + Ok(instance) + }) { + &Ok(ref instance) => Ok(instance), + &Err(ref err) => Err(anyhow!("Could not access iree instance: {}", err)), + } +} + +#[pyclass] +#[derive(Clone, Debug)] +pub struct IreeModel { + //logp_module: Box<[u8]>, + //expand_module: Box<[u8]>, + //devices: Arc]>>, + //devices: Box<[String]>, + ndim: usize, + session_maker: Arc>>>>>, + maker_thread: Arc>>, +} + +#[pymethods] +impl IreeModel { + #[new] + pub fn new_py<'py>( + device: String, + logp_module: Bound<'py, PyBytes>, + expand_module: Bound<'py, PyBytes>, + ndim: usize, + ) -> Result { + let logp_module: Box<[u8]> = logp_module.as_bytes().into(); + let expand_module: Box<[u8]> = expand_module.as_bytes().into(); + + Self::new(device, logp_module, expand_module, ndim) + } + + pub fn call_logp( + &self, + position: PyReadonlyArray1, + mut gradient: PyReadwriteArray1, + ) -> Result { + let mut math = self.math()?; + let logp = math.logp(&position.as_slice()?, gradient.as_slice_mut()?)?; + Ok(logp) + } +} + +impl IreeModel { + fn new( + device: String, + logp_module: Box<[u8]>, + expand_module: Box<[u8]>, + ndim: usize, + ) -> Result { + let (session_maker_sender, session_maker) = sync_channel(0); + + let maker_thread = spawn(move || { + let run_loop = move || { + let instance = get_instance()?; + let device = instance.try_create_default_device(&device)?; + + // FIXME + let device: Device<'static> = unsafe { transmute(device) }; + let devices = vec![device]; + + let logp_module: Arc<[u8]> = Arc::from(logp_module); + let expand_module: Arc<[u8]> = Arc::from(expand_module); + + for device in devices.iter().cycle() { + let make_math = || { + let logp_func = LogpFunc::new( + ndim, + logp_module.clone(), + expand_module.clone(), + device, + )?; + Ok(CpuMath::new(logp_func)) + }; + + let math_result = make_math(); + session_maker_sender + .send(math_result) + .map_err(|_| anyhow!("Could not send iree math"))?; + } + Ok(()) + }; + let res = run_loop(); + dbg!(res) + }); + + let session_maker = Arc::new(Mutex::new(session_maker)); + let maker_thread = Arc::new(maker_thread); + + Ok(IreeModel { + ndim, + //devices: vec![device].into(), + //logp_module: logp_module.as_bytes().into(), + //expand_module: expand_module.as_bytes().into(), + maker_thread, + session_maker, + }) + } +} + +#[derive(Debug)] +pub struct LogpFunc<'model> { + pub outputs: DynamicList<'model, Undefined>, + //pub inputs: DynamicList<'model, Ref<'model, BufferView<'model, f32>>>, + pub inputs: DynamicList<'model, Undefined>, + pub logp_func: ManuallyDrop>, + pub session: ManuallyDrop>, + //pub device: ManuallyDrop>, + //pub device: &'model Device<'model>, + logp_compiled: Arc<[u8]>, + expand_compiled: Arc<[u8]>, + pub ndim: usize, + pub buffer: Box<[f32]>, +} + +impl<'model> LogpFunc<'model> { + pub fn new( + ndim: usize, + //device: &'model Device<'model>, + logp_compiled: Arc<[u8]>, + expand_compiled: Arc<[u8]>, + //session: Session<'model>, + //device: &'model str, + device: &Device<'static>, + ) -> Result { + let instance = get_instance()?; + //let device = instance.try_create_default_device(device)?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let session: Session<'static> = unsafe { transmute(session) }; + + // TODO fix the lifetime of this reference + unsafe { session.append_module_from_memory(&logp_compiled) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let call = Call::new(&session, &logp_func)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let call: Call<'model> = unsafe { transmute(call) }; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'model> = unsafe { transmute(logp_func) }; + + let inputs = DynamicList::new(2, instance)?; + let outputs = DynamicList::new(2, instance)?; + + Ok(Self { + //device: ManuallyDrop::new(device), + //device, + inputs, + outputs, + logp_compiled, + expand_compiled, + ndim, + session: ManuallyDrop::new(session), + logp_func: ManuallyDrop::new(logp_func), + buffer: vec![0.; ndim].into(), + //call, + }) + } +} + +impl<'model> Drop for LogpFunc<'model> { + fn drop(&mut self) { + unsafe { + drop(ManuallyDrop::take(&mut self.logp_func)); + drop(ManuallyDrop::take(&mut self.session)); + //drop(ManuallyDrop::take(&mut self.device)); + } + } +} + +#[derive(Error, Debug)] +pub enum IreeLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for IreeLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +impl<'model> CpuLogpFunc for LogpFunc<'model> { + type LogpError = IreeLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + let instance = get_instance()?; + + self.buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as f32); + + let input_buffer = BufferView::::new( + &self.session, + &[position.len()], + EncodingType::DenseRowMajor, + &self.buffer, + ) + .context("Could not create buffer view")?; + + let input_buffer_ref = input_buffer + .to_ref(instance) + .context("Could not create iree ref")?; + + //dbg!(&input_buffer_ref); + + self.inputs + .push_ref(&input_buffer_ref) + .context("Could not push input buffer to inputs")?; + + let logp_func = self + .session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //dbg!(&self.inputs.get_ref::>(0)); + //stderr().lock().flush(); + //stdout().lock().flush(); + + logp_func + .invoke(&self.inputs, &self.outputs) + .context("Could not invoke logp function")?; + //let mut call = Call::new(&self.session, &self.logp_func).context("Could not create iree Call")?; + + //let inputs = call.input_list(); + + //inputs.push_ref(&input_buffer_ref).context("Could not push input")?; + //drop(input_buffer_ref); + //drop(input_buffer); + //drop(inputs); + + //call.invoke().context("Could not invoke iree function")?; + + let output_val: Value = self + .outputs + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + + /* + let logp_buffer_ref: Ref> = self + .outputs + .get_ref(0) + .context("Could not get logp buffer")?; + let logp_buffer = logp_buffer_ref.to_buffer_view(&self.session); + */ + + let gradient_buffer_ref: Ref> = self + .outputs + .get_ref(1) + .context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer_ref.to_buffer_view(&self.session); + + gradient_buffer + .copy_to_host(&mut self.buffer) + .context("Could not copy gradient buffer from iree device")?; + + //let mut logp_array = [0f32]; + //logp_buffer.copy_to_host(&self.device, &mut logp_array).context("Could not copy logp value")?; + //let logp = logp_array[0]; + + drop(input_buffer_ref); + drop(input_buffer); + + drop(gradient_buffer_ref); + drop(gradient_buffer); + + //drop(logp_buffer_ref); + //drop(logp_buffer); + + self.inputs.clear(); + self.outputs.clear(); + + let mut has_bad_grad = false; + gradient + .iter_mut() + .zip(self.buffer.iter()) + .for_each(|(out, &val)| { + *out = val as f64; + if !val.is_finite() { + has_bad_grad = true; + } + }); + + if (!logp.is_finite()) | has_bad_grad { + return Err(IreeLogpError::BadLogp()); + } + + Ok(logp as f64) + } +} + +#[derive(Clone)] +pub struct IreeTrace { + trace: MutableFixedSizeListArray>, +} + +impl DrawStorage for IreeTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + Ok(()) + } + + fn finalize(mut self) -> Result> { + let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + let fields = vec![field]; + let data_type = DataType::Struct(fields); + let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); + Ok(Box::new(struct_array)) + } + + fn inspect(&mut self) -> Result> { + self.clone().finalize() + } +} + +impl Model for IreeModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = IreeTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = MutablePrimitiveArray::new(); + let trace = MutableFixedSizeListArray::new(items, self.ndim); + + Ok(IreeTrace { trace }) + } + + fn math(&self) -> Result> { + self.session_maker + .lock() + .expect("Poisoned mutex") + .recv() + .context("Could not create iree session")? + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = StandardNormal; + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + fs::File, + io::Read, + mem::{transmute, ManuallyDrop}, + path::Path, + }; + + use anyhow::{Context, Result}; + use eerie::runtime::{ + api::{Call, Session, SessionOptions}, + hal::{BufferView, EncodingType}, + vm::{DynamicList, Function, List, Ref, ToRef, Undefined, Value}, + }; + use nuts_rs::{Math, Model}; + + use super::{get_instance, IreeModel, LogpFunc}; + + #[test] + fn run_logp_manual1() -> Result<()> { + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let instance = get_instance()?; + let device = instance.try_create_default_device("local-task")?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let session: Session<'static> = unsafe { transmute(session) }; + + unsafe { session.append_module_from_memory(&logp_compiled) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let inputs: DynamicList>> = DynamicList::new(2, instance)?; + //let inputs: DynamicList = DynamicList::new(2, instance)?; + //let outputs: DynamicList = DynamicList::new(2, instance)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'static> = unsafe { transmute(logp_func) }; + + let mut call = Call::new(&session, &logp_func)?; + + let mut buffer: Box<[f32]> = vec![0.; 2].into(); + let position = vec![1., 2.]; + let mut gradient: Box<[f64]> = vec![-1., -1.].into(); + + buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as _); + + let input_buffer = BufferView::::new( + &session, + &[position.len()], + EncodingType::DenseRowMajor, + &buffer, + ) + .context("Could not create buffer view")?; + + let input_buffer_ref = input_buffer + .to_ref(instance) + .context("Could not create iree ref")?; + + let inputs = call.input_list(); + inputs + .push_ref(&input_buffer_ref) + .context("Could not push input buffer to inputs")?; + + //dbg!(&inputs.get_ref::>(0)); + //dbg!(&input_buffer_ref); + + /* + logp_func + .invoke(&inputs, &outputs) + .context("Could not invoke logp function")?; + */ + drop(inputs); + call.invoke().context("Could not invoke iree function")?; + + drop(input_buffer_ref); + drop(input_buffer); + + // TODO For some reason it seems we need to keep this alive until after the call... + // Maybe a missing refcount increase somewhere? + + let outputs = call.output_list(); + + let output_val: Value = outputs + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + dbg!(logp); + + let gradient_buffer: Ref> = + outputs.get_ref(1).context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer.to_buffer_view(&session); + + gradient_buffer + .copy_to_host(&mut buffer) + .context("Could not copy gradient buffer from iree device")?; + + gradient + .iter_mut() + .zip(buffer.iter()) + .for_each(|(out, &val)| *out = val as _); + + dbg!(gradient); + + Ok(()) + } + + #[test] + fn run_logp_seg() -> Result<()> { + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let logp_expand = vec![]; + + let model = IreeModel::new( + "local-task".into(), + logp_compiled.into(), + logp_expand.into(), + 2, + )?; + + let mut math = model.math()?; + + let position = vec![1., 2.]; + let mut gradient = vec![-1., -1.]; + math.logp(&position, &mut gradient)?; + + drop(math); + drop(model); + + Ok(()) + } + + #[test] + fn run_logp_manual2() -> Result<()> { + /* + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let logp_expand = vec![]; + + let model = IreeModel::new( + "cuda".into(), + logp_compiled.into(), + logp_expand.into(), + 2, + ); + + let instance = get_instance()?; + + let device = instance.try_create_default_device(&model.devices[0])?; + + //let mut math_obj = LogpFunc::new(model.ndim, &model.logp_module, &model.expand_module, device)?; + + let mut math_obj = { + + let instance = get_instance()?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let session: Session<'static> = unsafe { transmute(session) }; + + unsafe { session.append_module_from_memory(&model.logp_module) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let call = Call::new(&session, &logp_func)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let call: Call<'model> = unsafe { transmute(call) }; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'static> = unsafe { transmute(logp_func) }; + + let inputs = DynamicList::new(2, instance)?; + let outputs = DynamicList::new(2, instance)?; + + LogpFunc { + device: ManuallyDrop::new(device), + inputs, + outputs, + logp_compiled: &model.logp_module, + expand_compiled: &model.expand_module, + ndim: 2, + session: ManuallyDrop::new(session), + //logp_func: ManuallyDrop::new(logp_func), + buffer: vec![0.; 2].into(), + } + + }; + + let math = &mut math_obj; + + let position = vec![1., 2.]; + let mut gradient = vec![-1., -1.]; + + let instance = get_instance()?; + + math.buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as _); + + let input_buffer = BufferView::::new( + &math.session, + &[position.len()], + EncodingType::DenseRowMajor, + &math.buffer, + ) + .context("Could not create buffer view")?; + + (math.inputs) + .push_ref( + &input_buffer + .to_ref(instance) + .context("Could not dereference input buffer")?, + ) + .context("Could not push input buffer to inputs")?; + + (math.logp_func) + .invoke(&math.inputs, &math.outputs) + .context("Could not invoke logp function")?; + + drop(input_buffer); + + let output_val: Value = (math.outputs) + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + + let gradient_buffer_ref: Ref> = (&math.outputs) + .get_ref(1) + .context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer_ref.to_buffer_view(&math.session); + + gradient_buffer + .copy_into(&math.device, &mut math.buffer) + .context("Could not copy gradient buffer from iree device")?; + + let mut has_bad_grad = false; + gradient + .iter_mut() + .zip(math.buffer.iter()) + .for_each(|(out, &val)| { + *out = val as f64; + if !val.is_finite() { + has_bad_grad = true; + } + }); + + drop(gradient_buffer_ref); + drop(gradient_buffer); + + math.inputs.clear(); + math.outputs.clear(); + + drop(math); + drop(math_obj); + + /* + drop(gradient_buffer); + drop(gradient_buffer_ref); + drop(math); + drop(math_obj); + drop(model); + */ + + */ + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6154f92..719e66f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ +#[cfg(feature = "onnx")] +mod ort; mod progress; mod pyfunc; mod pymc; mod stan; +#[cfg(feature = "torch")] +mod torch; mod wrapper; pub use wrapper::_lib; diff --git a/src/ort.rs b/src/ort.rs new file mode 100644 index 0000000..1754883 --- /dev/null +++ b/src/ort.rs @@ -0,0 +1,396 @@ +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use anyhow::Result; +use anyhow::{anyhow, Context}; +use arrow::array::{Array, FixedSizeListBuilder, PrimitiveBuilder, StructArray}; +use arrow::datatypes::{Field, Float64Type}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Math, Model}; +use ort::{ + AllocationDevice, Allocator, CPUExecutionProvider, CUDAExecutionProvider, + ExecutionProviderDispatch, InMemorySession, IoBinding, MemoryInfo, MemoryType, + OpenVINOExecutionProvider, Session, SessionInputs, TVMExecutionProvider, Tensor, + TensorRTExecutionProvider, Value, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, +}; +use rand_distr::{Distribution, Uniform}; +use thiserror::Error; + +#[derive(Debug, Clone)] +#[pyclass] +pub struct OnnxModel { + ndim: usize, + logp_model: Box<[u8]>, + providers: OnnxProviders, + sessions: Arc>, + count: Arc, +} + +impl OnnxModel { + fn make_plain_logp_session<'a>(&'a self) -> Result { + let logp_session = Session::builder()? + .with_optimization_level(ort::GraphOptimizationLevel::Level3)? + .with_memory_pattern(true)? + //.commit_from_memory_directly(&self.logp_model)?; + .commit_from_memory(&self.logp_model)?; + // + + Ok(logp_session) + } + + fn make_logp_session<'a>(&'a self) -> Result { + let logp_session = Session::builder()? + .with_optimization_level(ort::GraphOptimizationLevel::Level3)? + .with_execution_providers( + self.providers + .clone() + .providers + .into_iter() + .map(|val| val.into()), + )? + .with_memory_pattern(true)? + //.commit_from_memory_directly(&self.logp_model)?; + .commit_from_memory(&self.logp_model)?; + // + + Ok(logp_session) + } +} + +#[pymethods] +impl OnnxModel { + #[new] + pub fn new_py<'py>( + ndim: usize, + logp_model: Bound<'py, PyBytes>, + providers: OnnxProviders, + ) -> Result { + let mut model = Self { + ndim, + providers, + logp_model: logp_model.as_bytes().into(), + sessions: Arc::new(vec![]), + count: Arc::new(0usize.into()), + }; + for _ in 0..6 { + let session = model.make_logp_session()?; + Arc::get_mut(&mut model.sessions).unwrap().push(session); + } + + let session = model.make_plain_logp_session()?; + + let pos = vec![0f32; ndim]; + let input = Tensor::from_array(([ndim], pos))?; + + session.run(ort::inputs![input]?)?; + + Ok(model) + } +} + +pub struct OnnxTrace { + trace: FixedSizeListBuilder>, +} + +impl DrawStorage for OnnxTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.values().append_slice(point); + self.trace.append(true); + Ok(()) + } + + fn finalize(mut self) -> Result> { + //let data_type = DataType::Struct(fields.into()); + let data: Arc = Arc::new(self.trace.finish()); + let field = Field::new("unconstrained_draw", data.data_type().clone(), false); + let fields = vec![field]; + let struct_array = StructArray::new(fields.into(), vec![data], None); + Ok(Arc::new(struct_array)) + } + + fn inspect(&self) -> Result> { + let data: Arc = Arc::new(self.trace.finish_cloned()); + let field = Field::new("unconstrained_draw", data.data_type().clone(), false); + let fields = vec![field]; + let struct_array = StructArray::new(fields.into(), vec![data], None); + Ok(Arc::new(struct_array)) + } +} + +#[derive(Error, Debug)] +pub enum OnnxLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for OnnxLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +pub struct OnnxLogpFunc<'model> { + //session: &'model InMemorySession<'model>, + input: Tensor, + binding: IoBinding<'model>, + session: &'model Session, + ndim: usize, + input_allocator: Allocator, + output_allocator: Allocator, +} + +impl<'model> OnnxLogpFunc<'model> { + //fn new(ndim: usize, session: &'model InMemorySession<'model>) -> Result { + fn new( + ndim: usize, + binding: IoBinding<'model>, + session: &'model Session, + input: Tensor, + input_allocator: Allocator, + output_allocator: Allocator, + ) -> Result { + Ok(Self { + session, + binding, + ndim, + input, + input_allocator, + output_allocator, + }) + } +} + +impl<'model> CpuLogpFunc for OnnxLogpFunc<'model> { + type LogpError = OnnxLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + /* + let position = position.iter().map(|&x| x as f32).collect_vec(); + let position = + Value::from_array(([position.len()], position)).context("Could not create input")?; + let inputs = SessionInputs::ValueArray([position.into()]); + let mut outputs = self + .session + .run(inputs) + .context("Could not run logp function")?; + let logp = outputs + .pop_first() + .context("Could not extract first output")?; + let grad = outputs + .pop_first() + .context("Could not extract second output")?; + let logp: f32 = logp + .1 + .try_extract_raw_tensor() + .context("Could not read logp value")? + .1[0]; + let vals = grad + .1 + .try_extract_raw_tensor::() + .context("Could not read grad value")? + .1; + if vals.len() != gradient.len() { + Err(anyhow!("Logp return gradient with incorrect length"))?; + } + gradient + .iter_mut() + .zip(vals.iter()) + .for_each(|(out, &val)| *out = val as f64); + */ + + let (_, input_vals) = self.input.extract_raw_tensor_mut(); + position + .iter() + .zip(input_vals.iter_mut()) + .for_each(|(val, loc)| *loc = *val as _); + + self.binding + .bind_input(&self.session.inputs[0].name, &self.input) + .context("Coud not bind input to logp function")?; + + let outputs = self.binding.run().context("Could not run logp function")?; + let first = &outputs[0]; + let logp: f32 = first + .try_extract_scalar() + .context("First output wnot a scalar")?; + + let grad = &outputs[1]; + let (_, grad): (_, &[f32]) = grad + .try_extract_raw_tensor() + .context("First output wnot a scalar")?; + + gradient + .iter_mut() + .zip(grad.iter()) + .for_each(|(out, &val)| *out = val as f64); + + Ok(logp as f64) + } +} + +impl Model for OnnxModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = OnnxTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = PrimitiveBuilder::new(); + let trace = FixedSizeListBuilder::new(items, self.ndim.try_into().unwrap()); + + Ok(OnnxTrace { trace }) + } + + fn math(&self) -> Result> { + //let session = self.make_logp_session()?; + let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let count = count % self.sessions.len(); + + let session = &self.sessions[count]; + + let input_allocator = Allocator::new( + session, + MemoryInfo::new( + AllocationDevice::CUDAPinned, + 0, + ort::AllocatorType::Device, + MemoryType::CPUInput, + )?, + )?; + let output_allocator = Allocator::new( + session, + MemoryInfo::new( + AllocationDevice::CUDAPinned, + 0, + ort::AllocatorType::Device, + MemoryType::CPUOutput, + )?, + )?; + + let mut binding = session.create_binding()?; + + let input = Tensor::::new(&input_allocator, [self.ndim])?; + + binding.bind_input(&session.inputs[0].name, &input)?; + + let scalar_shape: [usize; 0] = []; + let logp_output = Tensor::::new(&output_allocator, scalar_shape)?; + let grad_output = Tensor::::new(&output_allocator, [self.ndim])?; + + binding.bind_output(&session.outputs[0].name, logp_output)?; + binding.bind_output(&session.outputs[1].name, grad_output)?; + + Ok(CpuMath::new(OnnxLogpFunc::new( + self.ndim, + binding, + session, + input, + input_allocator, + output_allocator, + )?)) + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = Uniform::new(-2., 2.); + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} + +#[derive(Debug, Clone)] +enum Provider { + Cpu(CPUExecutionProvider), + Cuda(CUDAExecutionProvider), + TensorRt(TensorRTExecutionProvider), + Tvm(TVMExecutionProvider), + OpenVINO(OpenVINOExecutionProvider), +} + +impl Into for Provider { + fn into(self) -> ExecutionProviderDispatch { + match self { + Self::Cpu(val) => val.build().error_on_failure().into(), + Self::Cuda(val) => val.build().error_on_failure().into(), + Self::TensorRt(val) => val.build().error_on_failure().into(), + Self::Tvm(val) => val.build().error_on_failure().into(), + Self::OpenVINO(val) => val.build().error_on_failure().into(), + } + } +} + +#[derive(Debug, Clone)] +#[pyclass] +pub struct OnnxProviders { + providers: Vec, +} + +#[pymethods] +impl OnnxProviders { + #[new] + pub fn new() -> Self { + Self { providers: vec![] } + } + + pub fn add_cpu(&mut self) -> Result<()> { + self.providers + .push(Provider::Cpu(CPUExecutionProvider::default())); + Ok(()) + } + + pub fn add_cuda(&mut self) -> Result<()> { + self.providers.push(Provider::Cuda( + CUDAExecutionProvider::default().with_cuda_graph(), + )); + Ok(()) + } + + pub fn add_tvm(&mut self) -> Result<()> { + let provider = TVMExecutionProvider::default(); + self.providers.push(Provider::Tvm(provider)); + Ok(()) + } + + pub fn add_openvino(&mut self) -> Result<()> { + let provider = OpenVINOExecutionProvider::default(); + self.providers.push(Provider::OpenVINO(provider)); + Ok(()) + } + + pub fn add_tensorrt(&mut self) -> Result<()> { + self.providers + .push(Provider::TensorRt(TensorRTExecutionProvider::default())); + Ok(()) + } +} diff --git a/src/torch.rs b/src/torch.rs new file mode 100644 index 0000000..d11cecb --- /dev/null +++ b/src/torch.rs @@ -0,0 +1,147 @@ +use anyhow::{anyhow, Context}; +use anyhow::{bail, Result}; +use arrow2::{ + array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, + datatypes::{DataType, Field}, +}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use ort::{ + inputs, CPUExecutionProvider, CUDAExecutionProvider, ExecutionProviderDispatch, + InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, Value, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, +}; +use rand_distr::{Distribution, Uniform}; +use thiserror::Error; + +#[pyclass] +#[derive(Clone, Debug)] +pub struct TorchModel { + ndim: usize, + logp_model: Box<[u8]>, + providers: Vec, +} + +impl TorchModel { + fn make_logp_session<'a>(&'a self) -> Result<()> { + todo!() + } +} + +#[pymethods] +impl TorchModel { + #[new] + pub fn new_py<'py>(ndim: usize, logp_model: Bound<'py, PyBytes>) -> Result { + todo!() + } +} + +#[derive(Clone)] +pub struct TorchTrace { + trace: MutableFixedSizeListArray>, +} + +impl DrawStorage for TorchTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + Ok(()) + } + + fn finalize(mut self) -> Result> { + let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + let fields = vec![field]; + let data_type = DataType::Struct(fields); + let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); + Ok(Box::new(struct_array)) + } + + fn inspect(&mut self) -> Result> { + self.clone().finalize() + } +} + +#[derive(Error, Debug)] +pub enum TorchLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for TorchLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +pub struct TorchLogpFunc<'model> { + ndim: usize, +} + +impl<'model> TorchLogpFunc<'model> { + fn new(ndim: usize) -> Result { + todo!() + } +} + +impl<'model> CpuLogpFunc for TorchLogpFunc<'model> { + type LogpError = TorchLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + todo!() + } +} + +impl Model for TorchModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = TorchTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = MutablePrimitiveArray::new(); + let trace = MutableFixedSizeListArray::new(items, self.ndim); + + Ok(OnnxTrace { trace }) + } + + fn math(&self) -> Result> { + let session = self.make_logp_session()?; + Ok(CpuMath::new(OnnxLogpFunc::new(self.ndim, session)?)) + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = Uniform::new(-2., 2.); + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} diff --git a/src/wrapper.rs b/src/wrapper.rs index 6f9ad49..99ce92f 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -4,7 +4,11 @@ use std::{ time::{Duration, Instant}, }; +#[cfg(feature = "onnx")] +use crate::ort::OnnxModel; + use crate::{ + ort::OnnxProviders, progress::{IndicatifHandler, ProgressHandler}, pyfunc::{ExpandDtype, PyModel, PyVariable, TensorShape}, pymc::{ExpandFunc, LogpFunc, PyMcModel}, @@ -569,6 +573,27 @@ impl PySampler { } } + #[cfg(feature = "onnx")] + #[staticmethod] + fn from_onnx( + settings: PyNutsSettings, + cores: usize, + model: OnnxModel, + progress_type: ProgressType, + ) -> PyResult { + let callback = progress_type.into_callback()?; + match settings.into_settings() { + Settings::LowRank(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + Settings::Diag(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + } + } + fn is_finished(&mut self, py: Python<'_>) -> PyResult { py.allow_threads(|| { let state = std::mem::replace(&mut self.0, SamplerState::Empty); @@ -773,6 +798,10 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + #[cfg(feature = "onnx")] + m.add_class::()?; + #[cfg(feature = "onnx")] + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;