From 9d5b63cfd00469364486807f70b7b07bc5f7d38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Oct 2023 15:28:32 +0200 Subject: [PATCH 1/6] Update all dependencies to current versions --- Cargo.lock | 619 +++++++++--------- syntaxdot-cli/Cargo.toml | 12 +- syntaxdot-cli/src/progress.rs | 7 +- syntaxdot-cli/src/subcommands/distill.rs | 45 +- syntaxdot-cli/src/subcommands/finetune.rs | 22 +- syntaxdot-cli/src/subcommands/prepare.rs | 8 +- syntaxdot-encoders/Cargo.toml | 6 +- syntaxdot-summary/Cargo.toml | 2 +- syntaxdot-tch-ext/Cargo.toml | 4 +- syntaxdot-tokenizers/Cargo.toml | 2 +- syntaxdot-tokenizers/src/bert.rs | 8 +- syntaxdot-transformers/Cargo.toml | 6 +- syntaxdot-transformers/src/activations.rs | 2 +- syntaxdot-transformers/src/layers.rs | 6 +- syntaxdot-transformers/src/loss.rs | 27 +- .../src/models/albert/encoder.rs | 6 +- .../src/models/bert/embeddings.rs | 5 +- .../src/models/bert/encoder.rs | 9 +- .../src/models/roberta/mod.rs | 4 +- .../src/models/sinusoidal/mod.rs | 5 +- .../src/models/squeeze_bert/embeddings.rs | 2 +- .../src/models/squeeze_bert/encoder.rs | 6 +- .../src/models/squeeze_bert/layer.rs | 2 +- syntaxdot-transformers/src/util.rs | 6 +- syntaxdot/Cargo.toml | 10 +- syntaxdot/src/dataset/mod.rs | 5 +- syntaxdot/src/model/pooling.rs | 23 +- syntaxdot/src/optimizers/grad_scale.rs | 10 +- syntaxdot/src/tensor.rs | 43 +- 29 files changed, 478 insertions(+), 434 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8031d95..b1cea03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -22,39 +22,44 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" -version = "0.3.0" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e579a7752471abc2a8268df8b20005e3eadd975f585398f17efcfd8d4927371" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.0" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e765fd216e48e067936442276d1d57399e37bce53c264d6fefbe298080cb57ee" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -70,9 +75,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.0" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcd8291a340dd8ac70e18878bc4501dd7b4ff970cfa21c207d36ece51ea88fd" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys 0.48.0", @@ -80,30 +85,19 @@ 
dependencies = [ [[package]] name = "anyhow" -version = "1.0.70" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" [[package]] name = "approx" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f2a05fd1bd10b2527e20a2cd32d8873d115b8b39fe219ee25f42a8aca6ba278" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" dependencies = [ "num-traits", ] -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.1.0" @@ -133,9 +127,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.12.1" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1ce199063694f33ffb7dd4e0ee620741495c32833cde5aa08f02a0bf96f0c8" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "bytecount" @@ -145,15 +139,15 @@ checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "bzip2" @@ -194,11 +188,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ "jobserver", + "libc", ] [[package]] @@ -208,45 +203,59 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "cipher" -version = "0.3.0" +name = "ciborium" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" dependencies = [ - "generic-array", + "ciborium-io", + "ciborium-ll", + "serde", ] [[package]] -name = "clap" -version = "2.34.0" +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" dependencies = [ - "bitflags", - 
"textwrap", - "unicode-width", + "ciborium-io", + "half 1.8.2", +] + +[[package]] +name = "cipher" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ee52072ec15386f770805afd189a01c8841be8696bed250fa2f13c4c0d6dfb7" +dependencies = [ + "generic-array", ] [[package]] name = "clap" -version = "4.2.4" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "956ac1f6381d8d82ab4684768f89c0ea3afe66925ceadb4eeb3fc452ffc55d62" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.2.4" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84080e799e54cff944f4b4a4b0e71630b0e0443b25b985175c7dddc1a859b749" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" dependencies = [ "anstream", "anstyle", - "bitflags", "clap_lex", - "once_cell", "strsim", ] @@ -256,14 +265,14 @@ version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a19591b2ab0e3c04b588a0e04ddde7b9eaa423646d1b4a8092879216bf47473" dependencies = [ - "clap 4.2.4", + "clap", ] [[package]] name = "clap_lex" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "cmake" @@ -286,7 +295,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb103d8de17a098acd48b5b6092e3f3b5467fb95311770f9c6b9fa4e12dea486" dependencies = [ - "itertools", + "itertools 0.10.5", "thiserror", "udgraph", ] @@ -300,6 +309,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.42.0", ] @@ -311,9 +321,9 @@ checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "cpufeatures" -version = "0.2.7" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +checksum = "a17b76ff3a4162b0b27f354a0c87015ddad39d35f9c0c36607a3bdd175dde1f1" dependencies = [ "libc", ] @@ -329,24 +339,24 @@ dependencies = [ [[package]] name = "criterion" -version = "0.3.6" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" dependencies = [ - "atty", + "anes", "cast", - "clap 2.34.0", + "ciborium", + "clap", "criterion-plot", - "csv", - "itertools", - "lazy_static", + "is-terminal", + "itertools 0.10.5", "num-traits", + "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", - "serde_cbor", "serde_derive", "serde_json", "tinytemplate", @@ -355,22 +365,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.4.5" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", -] - -[[package]] -name = "crossbeam-channel" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" -dependencies = [ - "cfg-if", - "crossbeam-utils", + "itertools 0.10.5", ] [[package]] @@ -386,9 +386,9 @@ dependencies = [ [[package]] name = "crossbeam-epoch" -version = "0.9.14" +version = "0.9.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" dependencies = [ "autocfg", "cfg-if", @@ -399,9 +399,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] @@ -422,32 +422,11 @@ dependencies = [ "typenum", ] -[[package]] -name = "csv" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b015497079b9a9d69c02ad25de6c0a6edef051ea6360a327d0bd05802ef64ad" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" -dependencies = [ - "memchr", -] - [[package]] name = "digest" -version = "0.10.6" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", @@ -456,9 +435,9 @@ dependencies = [ [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "encode_unicode" @@ -468,36 +447,31 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "env_logger" -version = "0.9.3" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" dependencies = [ - "atty", "humantime", + "is-terminal", "log", "regex", "termcolor", ] [[package]] -name = "errno" -version = "0.3.1" +name = "equivalent" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" -dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys 0.48.0", -] +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "errno" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "ac3e13f66a2f95e32a39eaa81f6b95d42878ca0e1db0c7543723dfe12557e860" dependencies = [ - "cc", "libc", + "windows-sys 0.48.0", ] [[package]] @@ -534,9 +508,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.10" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", @@ -565,28 +539,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] -name = "hermit-abi" -version = "0.1.19" +name = "hashbrown" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" [[package]] name = "hermit-abi" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" -dependencies = [ - "libc", -] - -[[package]] -name = "hermit-abi" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "hmac" @@ -621,28 +583,48 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" +dependencies = [ + "equivalent", + "hashbrown 0.14.1", ] [[package]] name = "indicatif" -version = "0.16.2" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" dependencies = [ "console", - "lazy_static", + "instant", "number_prefix", - "regex", + "portable-atomic", + "unicode-width", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", ] [[package]] name = "io-lifetimes" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "libc", "windows-sys 0.48.0", ] @@ -653,7 +635,7 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ - "hermit-abi 0.3.1", + "hermit-abi", "io-lifetimes", "rustix", "windows-sys 0.48.0", @@ -668,17 +650,26 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" -version = "1.0.6" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "jobserver" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936cfd212a0155903bcbc060e316fb6cc7cbf2e1907329391ebadc1fe0ce77c2" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" dependencies = [ "libc", ] @@ -700,36 +691,27 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.142" +version = "0.2.149" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" [[package]] name = "libm" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" - -[[package]] -name = "linked-hash-map" -version = "0.5.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "linux-raw-sys" -version = "0.3.4" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36eb31c1778188ae1e64398743890d0877fef36d11521ac60406b42016e8c2cf" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "log" -version = "0.4.17" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "maplit" @@ -745,24 +727,25 @@ checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" [[package]] name = "matrixmultiply" -version = "0.3.3" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb99c395ae250e1bf9133673f03ca9f97b7e71b705436bf8f089453445d1e9fe" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" dependencies = [ + "autocfg", "rawpointer", ] [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memoffset" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] @@ -803,9 +786,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" dependencies = [ "num-traits", ] @@ -833,24 +816,14 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", ] -[[package]] -name = "num_cpus" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" -dependencies = [ - "hermit-abi 0.2.6", - "libc", -] - [[package]] name = "number_prefix" version = "0.4.0" @@ -869,9 +842,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oorandom" @@ -887,11 +860,12 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "ordered-float" -version = "2.10.0" +version = "4.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "536900a8093134cf9ccf00a27deb3532421099e958d9dd431135d0c7543ca1e8" dependencies = [ "num-traits", + "rand", "serde", ] @@ -925,14 +899,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", - "indexmap", + "indexmap 1.9.3", ] [[package]] name = "pkg-config" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "plotters" @@ -962,6 +936,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31114a898e107c51bb1609ffaf55a0e011cf6a4d7f1170d0015a165082c0338b" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -970,41 +950,41 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.56" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" dependencies = [ "unicode-ident", ] [[package]] name = "prost" -version = "0.9.0" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "444879275cb4fd84958b1a1d5420d15e6fcf7c235fe47f053c9c2a80aceb6001" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" dependencies = [ "bytes", - "prost-derive 0.9.0", + "prost-derive 0.11.9", ] [[package]] name = "prost" -version = "0.11.9" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +checksum = "f4fdd22f3b9c31b53c060df4a0613a1c7f062d4115a2b984dd15b1858f7e340d" dependencies = [ "bytes", - "prost-derive 0.11.9", + "prost-derive 0.12.1", ] [[package]] name = "prost-derive" -version = "0.9.0" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "f9cc1a3263e07e0bf68e96268f37665207b49560d98739662cdfaae215c720fe" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" dependencies = [ "anyhow", - "itertools", + "itertools 0.10.5", "proc-macro2", "quote", "syn 1.0.109", @@ -1012,22 +992,22 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.11.9" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +checksum = "265baba7fabd416cf5078179f7d2cbeca4ce7a9041111900675ea7c4cb8a4c32" dependencies = [ "anyhow", - "itertools", + "itertools 0.11.0", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.38", ] [[package]] name = "quote" -version = "1.0.26" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -1041,6 +1021,7 @@ dependencies = [ "libc", "rand_chacha", "rand_core", + "serde", ] [[package]] @@ -1060,6 +1041,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom", + "serde", ] [[package]] @@ -1089,9 +1071,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -1099,21 +1081,31 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] name = "regex" -version = "1.8.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" +checksum = "d119d7c7ca818f8a53c300863d4f87566aac09943aef5b355bb83969dae75d87" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465c6fc0621e4abc4187a2bda0937bfd4f722c2730b29562e19689ea796c9a4b" dependencies = [ "aho-corasick", "memchr", @@ -1122,15 +1114,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +checksum = "56d84fdd47036b038fc80dd333d10b6aab10d5d31f4a366e20014def75328d33" [[package]] name = "rustix" -version = "0.37.14" +version = "0.37.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b864d3c18a5785a05953adeed93e2dca37ed30f18e69bba9f30079d51f363f" +checksum = "d4eb579851244c2c03e7c24f501c3432bed80b8f720af1d6e5b0e0f01555a035" dependencies = [ "bitflags", "errno", @@ -1142,9 +1134,19 @@ dependencies = [ 
[[package]] name = "ryu" -version = "1.0.13" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "safetensors" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] [[package]] name = "same-file" @@ -1157,9 +1159,9 @@ dependencies = [ [[package]] name = "scopeguard" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sentencepiece" @@ -1195,62 +1197,62 @@ checksum = "4749ccfc2197a9853b509e11542fb9a339340fa448f8b6f65ca1a70b3d9b63f7" [[package]] name = "serde" -version = "1.0.160" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2f3770c8bce3bcda7e149193a069a0f4365bda1fa5cd88e03bca26afc1216c" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] -[[package]] -name = "serde_cbor" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" -dependencies = [ - "half 1.8.2", - "serde", -] - [[package]] name = "serde_derive" -version = "1.0.160" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291a097c63d8497e00160b166a967a4a79c64f3facdd01cbd7502231688d77df" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.38", ] [[package]] name = "serde_json" -version = "1.0.96" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "057d394a50403bcac12672b2b18fb387ab6d289d957dab67dd201875391e52f1" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96426c9936fd7a0124915f9185ea1d20aa9445cc9821142f0a73bc9207a2e186" +dependencies = [ + "serde", +] + [[package]] name = "serde_yaml" -version = "0.8.26" +version = "0.9.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578a7433b776b56a35785ed5ce9a7e777ac0598aac5a6dd1b4b18a307c7fc71b" +checksum = "1a49e178e4452f45cb61d0cd8cebc1b0fafd3e41929e996cef79aa3aca91f574" dependencies = [ - "indexmap", + "indexmap 2.0.2", + "itoa", "ryu", "serde", - "yaml-rust", + "unsafe-libyaml", ] [[package]] name = "sha1" -version = "0.10.5" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", @@ -1299,9 +1301,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.15" +version = "2.0.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822" 
+checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" dependencies = [ "proc-macro2", "quote", @@ -1341,12 +1343,12 @@ version = "0.5.0" dependencies = [ "anyhow", "bytecount", - "clap 4.2.4", + "clap", "clap_complete", "conllu", "env_logger", "indicatif", - "itertools", + "itertools 0.11.0", "log", "ndarray", "ordered-float", @@ -1371,7 +1373,7 @@ dependencies = [ "conllu", "criterion", "fst", - "itertools", + "itertools 0.11.0", "lazy_static", "maplit", "ndarray", @@ -1394,14 +1396,14 @@ name = "syntaxdot-summary" version = "0.5.0" dependencies = [ "hostname", - "prost 0.9.0", + "prost 0.12.1", ] [[package]] name = "syntaxdot-tch-ext" version = "0.5.0" dependencies = [ - "itertools", + "itertools 0.11.0", "tch", ] @@ -1431,15 +1433,16 @@ dependencies = [ [[package]] name = "tch" -version = "0.11.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3394fea57e43ef3708196025beb3445dc13998b162f972233f11b9fdf091401" +checksum = "0ed5dddab3812892bf5fb567136e372ea49f31672931e21cec967ca68aec03da" dependencies = [ "half 2.2.1", "lazy_static", "libc", "ndarray", "rand", + "safetensors", "thiserror", "torch-sys", "zip", @@ -1447,40 +1450,31 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" dependencies = [ "winapi-util", ] -[[package]] -name = "textwrap" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" -dependencies = [ - "unicode-width", -] - [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.38", ] [[package]] @@ -1526,18 +1520,43 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "toml" -version = "0.5.11" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "185d8ab0dfbb35cf1399a6344d8484209c088f75f8f68230da55d48d95d43e3d" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" dependencies = [ + "indexmap 2.0.2", "serde", + "serde_spanned", + "toml_datetime", + "winnow", ] [[package]] name = 
"torch-sys" -version = "0.11.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff928d3e632acda675e388d8e0efe526136350620800cc770e3727c9abc6ea1" +checksum = "803446f89fb877a117503dbfb8375b6a29fa8b0e0f44810fac3863c798ecef22" dependencies = [ "anyhow", "cc", @@ -1547,9 +1566,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "udgraph" @@ -1563,9 +1582,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.8" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -1578,9 +1597,15 @@ dependencies = [ [[package]] name = "unicode-width" -version = "0.1.10" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + +[[package]] +name = "unsafe-libyaml" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +checksum = "f28467d3e1d3c6586d8f25fa243f544f5800fec42d97032474e17222c2b75cfa" [[package]] name = "utf8parse" @@ -1829,22 +1854,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" [[package]] -name = "wordpieces" -version = "0.5.0" +name = "winnow" +version = "0.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b28fb50e814eb5e0fadbdbdd2eabb4f70d54d6b1981943c32d019555c0247b9e" +checksum = "037711d82167854aff2018dfd193aa0fef5370f456732f0d5a0c59b0f1b4b907" dependencies = [ - "fst", - "thiserror", + "memchr", ] [[package]] -name = "yaml-rust" -version = "0.4.5" +name = "wordpieces" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" +checksum = "e23923addebb8cc91761c6c01624e96d3fabeea2fddc4c5ccf0f9c1a3be6abf1" dependencies = [ - "linked-hash-map", + "fst", + "thiserror", ] [[package]] diff --git a/syntaxdot-cli/Cargo.toml b/syntaxdot-cli/Cargo.toml index 65050c5..884a5fd 100644 --- a/syntaxdot-cli/Cargo.toml +++ b/syntaxdot-cli/Cargo.toml @@ -19,14 +19,14 @@ bytecount = "0.6" clap = { version = "4", features = ["cargo"] } clap_complete = "4" conllu = "0.8" -env_logger = "0.9" -indicatif = "0.16" -itertools = "0.10" +env_logger = "0.10" +indicatif = "0.17" +itertools = "0.11" log = "0.4" ndarray = "0.15" -ordered-float = { version = "2", features = ["serde"] } +ordered-float = { version = "4", features = ["serde"] } rayon = "1" -serde_yaml = "0.8" +serde_yaml = "0.9" stdinout = "0.4" syntaxdot = { path = "../syntaxdot", version = "0.5.0", default-features = false } syntaxdot-encoders = { path = "../syntaxdot-encoders", version = "0.5.0" } @@ -34,5 +34,5 @@ syntaxdot-summary = { path = "../syntaxdot-summary", version = "0.5.0" } syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" } syntaxdot-tokenizers = { path = "../syntaxdot-tokenizers", version = "0.5.0" } 
syntaxdot-transformers = { path = "../syntaxdot-transformers", version = "0.5.0", default-features = false } -tch = { version = "0.11", default-features = false } +tch = { version = "0.14", default-features = false } udgraph = "0.8" diff --git a/syntaxdot-cli/src/progress.rs b/syntaxdot-cli/src/progress.rs index 63ee623..0da367d 100644 --- a/syntaxdot-cli/src/progress.rs +++ b/syntaxdot-cli/src/progress.rs @@ -21,8 +21,11 @@ where let len = read.seek(SeekFrom::End(0))? + 1; read.seek(SeekFrom::Start(0))?; let progress_bar = ProgressBar::new(len); - progress_bar - .set_style(ProgressStyle::default_bar().template("{bar} {bytes}/{total_bytes}")); + progress_bar.set_style( + ProgressStyle::default_bar() + .template("{bar} {bytes}/{total_bytes}") + .expect("Invalid progress style"), + ); Ok(ReadProgress { inner: read, diff --git a/syntaxdot-cli/src/subcommands/distill.rs b/syntaxdot-cli/src/subcommands/distill.rs index 419449a..d3511be 100644 --- a/syntaxdot-cli/src/subcommands/distill.rs +++ b/syntaxdot-cli/src/subcommands/distill.rs @@ -1,6 +1,7 @@ use std::cell::RefCell; use std::collections::btree_map::{BTreeMap, Entry}; use std::collections::{HashMap, VecDeque}; +use std::convert::{TryFrom, TryInto}; use std::fs::{self, File}; use std::io::{BufRead, BufReader, Seek}; @@ -298,7 +299,7 @@ impl DistillApp { let train_progress = ProgressBar::new(n_steps as u64); train_progress.set_style(ProgressStyle::default_bar().template( "[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}% train {msg}", - )); + )?); while global_step < n_steps - 1 { let mut teacher_train_dataset = Self::open_dataset(teacher_train_file)?; @@ -664,9 +665,9 @@ impl DistillApp { global_step, lr_encoder, lr_classifier, - f32::from(distill_loss.soft_loss), - f32::from(distill_loss.attention_loss), - f32::from(distill_loss.hidden_loss) + f32::try_from(distill_loss.soft_loss)?, + f32::try_from(distill_loss.attention_loss)?, + f32::try_from(distill_loss.hidden_loss)? )); progress.inc(1); @@ -847,7 +848,7 @@ impl DistillApp { let progress_bar = read_progress.progress_bar().clone(); progress_bar.set_style(ProgressStyle::default_bar().template( "[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}% validation {msg}", - )); + )?); let mut dataset = ConlluDataSet::new(BufReader::new(read_progress)); @@ -868,7 +869,8 @@ impl DistillApp { { let batch = batch?; - let n_batch_tokens = i64::from(batch.token_spans.token_mask()?.f_sum(Kind::Int64)?); + let n_batch_tokens = + i64::try_from(batch.token_spans.token_mask()?.f_sum(Kind::Int64)?)?; let attention_mask = batch.seq_lens.attention_mask()?; @@ -902,41 +904,41 @@ impl DistillApp { .seq_classifiers .summed_loss .f_sum(Kind::Float)? - .into(); + .try_into()?; for (encoder_name, loss) in model_loss.seq_classifiers.encoder_losses { match encoder_accuracy.entry(encoder_name.clone()) { Entry::Vacant(entry) => { entry.insert( - f32::from( + f32::try_from( &model_loss.seq_classifiers.encoder_accuracies[&encoder_name], - ) * n_batch_tokens as f32, + )? * n_batch_tokens as f32, ); } Entry::Occupied(mut entry) => { - *entry.get_mut() += f32::from( + *entry.get_mut() += f32::try_from( &model_loss.seq_classifiers.encoder_accuracies[&encoder_name], - ) * n_batch_tokens as f32; + )? * n_batch_tokens as f32; } }; match encoder_loss.entry(encoder_name) { Entry::Vacant(entry) => { - entry.insert(f32::from(loss) * n_batch_tokens as f32); + entry.insert(f32::try_from(loss)? 
* n_batch_tokens as f32); } Entry::Occupied(mut entry) => { - *entry.get_mut() += f32::from(loss) * n_batch_tokens as f32 + *entry.get_mut() += f32::try_from(loss)? * n_batch_tokens as f32 } }; } if let Some(biaffine_loss) = model_loss.biaffine.as_ref() { - let head_loss = f32::from(&biaffine_loss.head_loss); - let relation_loss = f32::from(&biaffine_loss.relation_loss); + let head_loss = f32::try_from(&biaffine_loss.head_loss)?; + let relation_loss = f32::try_from(&biaffine_loss.relation_loss)?; - biaffine_las += f32::from(&biaffine_loss.acc.las) * n_batch_tokens as f32; - biaffine_ls += f32::from(&biaffine_loss.acc.ls) * n_batch_tokens as f32; - biaffine_uas += f32::from(&biaffine_loss.acc.uas) * n_batch_tokens as f32; + biaffine_las += f32::try_from(&biaffine_loss.acc.las)? * n_batch_tokens as f32; + biaffine_ls += f32::try_from(&biaffine_loss.acc.ls)? * n_batch_tokens as f32; + biaffine_uas += f32::try_from(&biaffine_loss.acc.uas)? * n_batch_tokens as f32; biaffine_head_loss += head_loss * n_batch_tokens as f32; biaffine_relation_loss += relation_loss * n_batch_tokens as f32; @@ -1306,10 +1308,11 @@ impl TrainDuration { ReadProgress::new(train_file.try_clone()?).context("Cannot open train file")?; let progress_bar = read_progress.progress_bar().clone(); - progress_bar - .set_style(ProgressStyle::default_bar().template( + progress_bar.set_style( + ProgressStyle::default_bar().template( "[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}%", - )); + )?, + ); let n_sentences = count_sentences(BufReader::new(read_progress))?; diff --git a/syntaxdot-cli/src/subcommands/finetune.rs b/syntaxdot-cli/src/subcommands/finetune.rs index 6c7f139..df0ccf1 100644 --- a/syntaxdot-cli/src/subcommands/finetune.rs +++ b/syntaxdot-cli/src/subcommands/finetune.rs @@ -1,4 +1,5 @@ use std::collections::BTreeMap; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::BufReader; @@ -248,7 +249,7 @@ impl FinetuneApp { progress_bar.set_style(ProgressStyle::default_bar().template(&format!( "[Time: {{elapsed_precise}}, ETA: {{eta_precise}}] {{bar}} {{percent}}% {} {{msg}}", epoch_type - ))); + ))?); let mut dataset = ConlluDataSet::new(BufReader::new(read_progress)); @@ -287,7 +288,8 @@ impl FinetuneApp { let attention_mask = batch.seq_lens.attention_mask()?; - let n_batch_tokens = i64::from(batch.token_spans.token_mask()?.f_sum(Kind::Int64)?); + let n_batch_tokens = + i64::try_from(batch.token_spans.token_mask()?.f_sum(Kind::Int64)?)?; let model_loss = autocast_or_preserve(self.mixed_precision, || { model.loss( @@ -319,7 +321,7 @@ impl FinetuneApp { .seq_classifiers .summed_loss .f_sum(Kind::Float)? - .into(); + .try_into()?; if let Some(scaler) = &mut grad_scaler { let optimizer = scaler.optimizer_mut(); @@ -361,19 +363,19 @@ impl FinetuneApp { for (encoder_name, loss) in model_loss.seq_classifiers.encoder_losses { *encoder_accuracy.entry(encoder_name.clone()).or_insert(0f32) += - f32::from(&model_loss.seq_classifiers.encoder_accuracies[&encoder_name]) + f32::try_from(&model_loss.seq_classifiers.encoder_accuracies[&encoder_name])? * n_batch_tokens as f32; *encoder_loss.entry(encoder_name).or_insert(0f32) += - f32::from(loss) * n_batch_tokens as f32; + f32::try_from(loss)? 
* n_batch_tokens as f32; } if let Some(biaffine_loss) = model_loss.biaffine.as_ref() { - let head_loss = f32::from(&biaffine_loss.head_loss); - let relation_loss = f32::from(&biaffine_loss.relation_loss); + let head_loss = f32::try_from(&biaffine_loss.head_loss)?; + let relation_loss = f32::try_from(&biaffine_loss.relation_loss)?; - biaffine_las += f32::from(&biaffine_loss.acc.las) * n_batch_tokens as f32; - biaffine_ls += f32::from(&biaffine_loss.acc.ls) * n_batch_tokens as f32; - biaffine_uas += f32::from(&biaffine_loss.acc.uas) * n_batch_tokens as f32; + biaffine_las += f32::try_from(&biaffine_loss.acc.las)? * n_batch_tokens as f32; + biaffine_ls += f32::try_from(&biaffine_loss.acc.ls)? * n_batch_tokens as f32; + biaffine_uas += f32::try_from(&biaffine_loss.acc.uas)? * n_batch_tokens as f32; biaffine_head_loss += head_loss * n_batch_tokens as f32; biaffine_relation_loss += relation_loss * n_batch_tokens as f32; diff --git a/syntaxdot-cli/src/subcommands/prepare.rs b/syntaxdot-cli/src/subcommands/prepare.rs index 37b4d1c..7b7b0ad 100644 --- a/syntaxdot-cli/src/subcommands/prepare.rs +++ b/syntaxdot-cli/src/subcommands/prepare.rs @@ -86,10 +86,10 @@ impl SyntaxDotApp for PrepareApp { .context(format!("Cannot open train data file: {}", self.train_data))?; let read_progress = ReadProgress::new(train_file).context("Cannot create progress bar")?; let progress_bar = read_progress.progress_bar().clone(); - progress_bar.set_style( - ProgressStyle::default_bar() - .template("[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}% {msg}"), - ); + progress_bar + .set_style(ProgressStyle::default_bar().template( + "[Time: {elapsed_precise}, ETA: {eta_precise}] {bar} {percent}% {msg}", + )?); let treebank_reader = Reader::new(BufReader::new(read_progress)); diff --git a/syntaxdot-encoders/Cargo.toml b/syntaxdot-encoders/Cargo.toml index c05a38a..285fbfd 100644 --- a/syntaxdot-encoders/Cargo.toml +++ b/syntaxdot-encoders/Cargo.toml @@ -13,12 +13,12 @@ license = "MIT OR Apache-2.0" caseless = "0.2" conllu = "0.8" fst = "0.4" -itertools = "0.10" +itertools = "0.11" numberer = "0.2" lazy_static = "1" maplit = "1" ndarray = "0.15" -ordered-float = "2" +ordered-float = "4" petgraph = "0.6" seqalign = "0.2" serde = { version = "1", features = ["derive"] } @@ -28,7 +28,7 @@ udgraph = "0.8" unicode-normalization = "0.1" [dev-dependencies] -criterion = "0.3" +criterion = "0.5" ndarray-rand = "0.14" rand = "0.8" rand_xorshift = "0.3" diff --git a/syntaxdot-summary/Cargo.toml b/syntaxdot-summary/Cargo.toml index 35b044f..9b876a0 100644 --- a/syntaxdot-summary/Cargo.toml +++ b/syntaxdot-summary/Cargo.toml @@ -12,4 +12,4 @@ license = "MIT OR Apache-2.0" [dependencies] hostname = "0.3" -prost = { version = "0.9", features = ["prost-derive"] } +prost = { version = "0.12", features = ["prost-derive"] } diff --git a/syntaxdot-tch-ext/Cargo.toml b/syntaxdot-tch-ext/Cargo.toml index 7e69c7d..e071ca2 100644 --- a/syntaxdot-tch-ext/Cargo.toml +++ b/syntaxdot-tch-ext/Cargo.toml @@ -10,8 +10,8 @@ documentation = "https://docs.rs/syntaxdot-tch-ext/" license = "MIT OR Apache-2.0" [dependencies] -itertools = "0.10" -tch = { version = "0.11", default-features = false } +itertools = "0.11" +tch = { version = "0.14", default-features = false } [features] doc-only = ["tch/doc-only"] diff --git a/syntaxdot-tokenizers/Cargo.toml b/syntaxdot-tokenizers/Cargo.toml index e4434c1..f1cbce5 100644 --- a/syntaxdot-tokenizers/Cargo.toml +++ b/syntaxdot-tokenizers/Cargo.toml @@ -14,7 +14,7 @@ ndarray = "0.15" sentencepiece = "0.11" 
 thiserror = "1"
 udgraph = "0.8"
-wordpieces = "0.5"
+wordpieces = "0.6"
 
 [features]
 model-tests = []
diff --git a/syntaxdot-tokenizers/src/bert.rs b/syntaxdot-tokenizers/src/bert.rs
index b1f59cc..e46d418 100644
--- a/syntaxdot-tokenizers/src/bert.rs
+++ b/syntaxdot-tokenizers/src/bert.rs
@@ -1,4 +1,3 @@
-use std::convert::TryFrom;
 use std::fs::File;
 use std::io::{BufRead, BufReader};
 
@@ -56,7 +55,7 @@ impl BertTokenizer {
     where
         R: BufRead,
     {
-        let word_pieces = WordPieces::try_from(buf_read.lines())?;
+        let word_pieces = WordPieces::from_buf_read(buf_read)?;
         Ok(Self::new(word_pieces, unknown_piece))
     }
 }
@@ -103,9 +102,8 @@ impl Tokenize for BertTokenizer {
 #[cfg(feature = "model-tests")]
 #[cfg(test)]
 mod tests {
-    use std::convert::TryFrom;
     use std::fs::File;
-    use std::io::{BufRead, BufReader};
+    use std::io::BufReader;
     use std::iter::FromIterator;
 
     use ndarray::array;
@@ -118,7 +116,7 @@ mod tests {
 
     fn read_pieces() -> WordPieces {
         let f = File::open(env!("BERT_BASE_GERMAN_CASED_VOCAB")).unwrap();
-        WordPieces::try_from(BufReader::new(f).lines()).unwrap()
+        WordPieces::from_buf_read(BufReader::new(f)).unwrap()
     }
 
     fn sentence_from_forms(forms: &[&str]) -> Sentence {
diff --git a/syntaxdot-transformers/Cargo.toml b/syntaxdot-transformers/Cargo.toml
index 4339f17..e3c3bc7 100644
--- a/syntaxdot-transformers/Cargo.toml
+++ b/syntaxdot-transformers/Cargo.toml
@@ -12,13 +12,13 @@ license = "MIT OR Apache-2.0"
 [dependencies]
 serde = { version = "1", features = ["derive"] }
 syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" }
-tch = { version = "0.11", default-features = false }
+tch = { version = "0.14", default-features = false }
 thiserror = "1"
 
 [dev-dependencies]
-approx = "0.4"
+approx = "0.5"
 maplit = "1"
-ndarray = { version = "0.15", features = ["approx"] }
+ndarray = { version = "0.15", features = ["approx-0_5"] }
 
 [features]
 model-tests = []
diff --git a/syntaxdot-transformers/src/activations.rs b/syntaxdot-transformers/src/activations.rs
index ad824f3..3e2c8b5 100644
--- a/syntaxdot-transformers/src/activations.rs
+++ b/syntaxdot-transformers/src/activations.rs
@@ -88,7 +88,7 @@ mod tests {
     fn gelu_new_returns_correct_values() {
         let gelu_new = Activation::GeluNew;
         let activations: ArrayD<f32> = (&gelu_new
-            .forward(&Tensor::of_slice(&[-1., -0.5, 0., 0.5, 1.]))
+            .forward(&Tensor::from_slice(&[-1., -0.5, 0., 0.5, 1.]))
             .unwrap())
             .try_into()
             .unwrap();
diff --git a/syntaxdot-transformers/src/layers.rs b/syntaxdot-transformers/src/layers.rs
index 2b53a67..ba81ead 100644
--- a/syntaxdot-transformers/src/layers.rs
+++ b/syntaxdot-transformers/src/layers.rs
@@ -302,18 +302,18 @@ impl PairwiseBilinear {
 
         if self.pairwise {
             // [batch_size, max_seq_len, out_features, v features].
-            let intermediate = Tensor::f_einsum("blu,uov->blov", &[&u, &self.weight], None)?;
+            let intermediate = Tensor::f_einsum("blu,uov->blov", &[&u, &self.weight], None::<i64>)?;
 
             // We perform a matrix multiplication to get the output with
             // the shape [batch_size, seq_len, seq_len, out_features].
-            let bilinear = Tensor::f_einsum("bmv,blov->bmlo", &[&v, &intermediate], None)?;
+            let bilinear = Tensor::f_einsum("bmv,blov->bmlo", &[&v, &intermediate], None::<i64>)?;
 
             Ok(bilinear.f_squeeze_dim(-1)?)
         } else {
             Ok(Tensor::f_einsum(
                 "blu,uov,blv->blo",
                 &[&u, &self.weight, &v],
-                None,
+                None::<i64>,
             )?)
         }
     }
diff --git a/syntaxdot-transformers/src/loss.rs b/syntaxdot-transformers/src/loss.rs
index a4f0aac..54f1815 100644
--- a/syntaxdot-transformers/src/loss.rs
+++ b/syntaxdot-transformers/src/loss.rs
@@ -168,6 +168,7 @@ impl MSELoss {
 
 #[cfg(test)]
 mod tests {
+    use std::convert::TryFrom;
     use std::convert::TryInto;
 
     use approx::assert_abs_diff_eq;
@@ -180,8 +181,8 @@ mod tests {
     #[test]
     fn cross_entropy_loss_without_label_smoothing() {
-        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
-        let targets = Tensor::of_slice(&[2i64]).view([1]);
+        let logits = Tensor::from_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
+        let targets = Tensor::from_slice(&[2i64]).view([1]);
         let cross_entropy_loss = CrossEntropyLoss::new(-1, None, Reduction::None);
         let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
             .try_into()
             .unwrap();
@@ -192,8 +193,8 @@ mod tests {
     #[test]
     fn cross_entropy_with_label_smoothing() {
-        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
-        let targets = Tensor::of_slice(&[2i64]).view([1]);
+        let logits = Tensor::from_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
+        let targets = Tensor::from_slice(&[2i64]).view([1]);
         let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
         let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
             .try_into()
             .unwrap();
@@ -203,9 +204,9 @@ mod tests {
 
     #[test]
     fn cross_entropy_with_label_smoothing_and_mask() {
-        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
-        let target_mask = Tensor::of_slice(&[true, false, true, false, true]).view([1, 5]);
-        let targets = Tensor::of_slice(&[2i64]).view([1]);
+        let logits = Tensor::from_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
+        let target_mask = Tensor::from_slice(&[true, false, true, false, true]).view([1, 5]);
+        let targets = Tensor::from_slice(&[2i64]).view([1]);
         let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
         let loss: ArrayD<f32> = (&cross_entropy_loss
             .forward(&logits, &targets, Some(&target_mask))
@@ -217,19 +218,19 @@ mod tests {
 
     #[test]
     fn mse_loss_with_averaging() {
-        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
-        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
+        let prediction = Tensor::from_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
+        let target = Tensor::from_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
         let mse_loss = MSELoss::new(super::MSELossNormalization::Mean);
         let loss = &mse_loss.forward(&prediction, &target).unwrap();
-        assert_abs_diff_eq!(f32::from(loss), 0.375f32, epsilon = 1e-6);
+        assert_abs_diff_eq!(f32::try_from(loss).unwrap(), 0.375f32, epsilon = 1e-6);
     }
 
     #[test]
     fn mse_loss_with_squared_l2_norm() {
-        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([2, 2]);
-        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([2, 2]);
+        let prediction = Tensor::from_slice(&[-0.5, -0.5, 0.0, 1.0]).view([2, 2]);
+        let target = Tensor::from_slice(&[-1.0, 0.0, 1.0, 1.0]).view([2, 2]);
         let mse_loss = MSELoss::new(super::MSELossNormalization::SquaredL2Norm);
         let loss = mse_loss.forward(&prediction, &target).unwrap();
-        assert_abs_diff_eq!(f32::from(loss), 0.5, epsilon = 1e-6);
+        assert_abs_diff_eq!(f32::try_from(&loss).unwrap(), 0.5, epsilon = 1e-6);
     }
 }
diff --git a/syntaxdot-transformers/src/models/albert/encoder.rs b/syntaxdot-transformers/src/models/albert/encoder.rs
index 99636f6..07d9f8c 100644
--- a/syntaxdot-transformers/src/models/albert/encoder.rs
+++ 
b/syntaxdot-transformers/src/models/albert/encoder.rs @@ -185,7 +185,7 @@ mod tests { vs.load(ALBERT_BASE_V2).unwrap(); // Pierre Vinken [...] - let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9, ]) .reshape(&[1, 13]); @@ -227,12 +227,12 @@ mod tests { vs.load(ALBERT_BASE_V2).unwrap(); // Pierre Vinken [...] - let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 5399i64, 9730, 2853, 15, 6784, 122, 315, 15, 129, 1865, 14, 686, 9, 0, 0, ]) .reshape(&[1, 15]); - let attention_mask = seqlen_to_mask(Tensor::of_slice(&[13]), pieces.size()[1]); + let attention_mask = seqlen_to_mask(Tensor::from_slice(&[13]), pieces.size()[1]); let embeddings = embeddings.forward_t(&pieces, false).unwrap(); diff --git a/syntaxdot-transformers/src/models/bert/embeddings.rs b/syntaxdot-transformers/src/models/bert/embeddings.rs index 83fa6bb..4e91fda 100644 --- a/syntaxdot-transformers/src/models/bert/embeddings.rs +++ b/syntaxdot-transformers/src/models/bert/embeddings.rs @@ -195,8 +195,9 @@ mod tests { vs.load(BERT_BASE_GERMAN_CASED).unwrap(); // Word pieces of: Veruntreute die AWO spendengeld ? - let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) - .reshape(&[1, 10]); + let pieces = + Tensor::from_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) + .reshape(&[1, 10]); let summed_embeddings = embeddings diff --git a/syntaxdot-transformers/src/models/bert/encoder.rs b/syntaxdot-transformers/src/models/bert/encoder.rs index 0f317b8..202ba7a 100644 --- a/syntaxdot-transformers/src/models/bert/encoder.rs +++ b/syntaxdot-transformers/src/models/bert/encoder.rs @@ -165,8 +165,9 @@ mod tests { vs.load(BERT_BASE_GERMAN_CASED).unwrap(); // Word pieces of: Veruntreute die AWO spendengeld ? - let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) - .reshape(&[1, 10]); + let pieces = + Tensor::from_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) + .reshape(&[1, 10]); let embeddings = embeddings.forward_t(&pieces, false).unwrap(); @@ -206,12 +207,12 @@ mod tests { // Word pieces of: Veruntreute die AWO spendengeld ? // Add some padding to simulate inactive time steps. - let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2, 0, 0, 0, 0, 0, ]) .reshape(&[1, 15]); - let attention_mask = seqlen_to_mask(Tensor::of_slice(&[10]), pieces.size()[1]); + let attention_mask = seqlen_to_mask(Tensor::from_slice(&[10]), pieces.size()[1]); let embeddings = embeddings.forward_t(&pieces, false).unwrap(); diff --git a/syntaxdot-transformers/src/models/roberta/mod.rs b/syntaxdot-transformers/src/models/roberta/mod.rs index 870159a..5a5e0df 100644 --- a/syntaxdot-transformers/src/models/roberta/mod.rs +++ b/syntaxdot-transformers/src/models/roberta/mod.rs @@ -127,7 +127,7 @@ mod tests { vs.load(XLM_ROBERTA_BASE).unwrap(); // Subtokenization of: Veruntreute die AWO spendengeld ? - let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2, ]) .reshape(&[1, 12]); @@ -165,7 +165,7 @@ mod tests { vs.load(XLM_ROBERTA_BASE).unwrap(); // Subtokenization of: Veruntreute die AWO spendengeld ? 
- let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 0i64, 310, 23451, 107, 6743, 68, 62, 43789, 207126, 49004, 705, 2, ]) .reshape(&[1, 12]); diff --git a/syntaxdot-transformers/src/models/sinusoidal/mod.rs b/syntaxdot-transformers/src/models/sinusoidal/mod.rs index c894c59..812fab1 100644 --- a/syntaxdot-transformers/src/models/sinusoidal/mod.rs +++ b/syntaxdot-transformers/src/models/sinusoidal/mod.rs @@ -172,8 +172,9 @@ mod tests { vs.load(BERT_BASE_GERMAN_CASED).unwrap(); // Word pieces of: Veruntreute die AWO spendengeld ? - let pieces = Tensor::of_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) - .reshape(&[1, 10]); + let pieces = + Tensor::from_slice(&[133i64, 1937, 14010, 30, 32, 26939, 26962, 12558, 2739, 2]) + .reshape(&[1, 10]); let summed_embeddings = embeddings diff --git a/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs b/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs index 4041fb6..54c1eeb 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs @@ -55,7 +55,7 @@ mod tests { // Word pieces of: Did the AWO embezzle donations ? let pieces = - Tensor::of_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029]) + Tensor::from_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029]) .reshape(&[1, 9]); let summed_embeddings = diff --git a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs index 90899de..599c34c 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs @@ -185,7 +185,7 @@ mod tests { // Word pieces of: Did the AWO embezzle donations ? let pieces = - Tensor::of_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029]) + Tensor::from_slice(&[2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029]) .reshape(&[1, 9]); let embeddings = embeddings.forward_t(&pieces, false).unwrap(); @@ -226,12 +226,12 @@ mod tests { // Word pieces of: Did the AWO embezzle donations ? // Add some padding to simulate inactive time steps. 
- let pieces = Tensor::of_slice(&[ + let pieces = Tensor::from_slice(&[ 2106i64, 1996, 22091, 2080, 7861, 4783, 17644, 11440, 1029, 0, 0, 0, 0, 0, ]) .reshape(&[1, 14]); - let attention_mask = seqlen_to_mask(Tensor::of_slice(&[9]), pieces.size()[1]); + let attention_mask = seqlen_to_mask(Tensor::from_slice(&[9]), pieces.size()[1]); let embeddings = embeddings.forward_t(&pieces, false).unwrap(); diff --git a/syntaxdot-transformers/src/models/squeeze_bert/layer.rs b/syntaxdot-transformers/src/models/squeeze_bert/layer.rs index c6cc7c6..fb51be3 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/layer.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/layer.rs @@ -112,7 +112,7 @@ impl ConvActivation { let vs = vs.borrow(); Ok(ConvActivation { - conv1d: Conv1D::new(vs.borrow() / "conv1d", cin, cout, 1, groups)?, + conv1d: Conv1D::new(vs / "conv1d", cin, cout, 1, groups)?, activation, }) } diff --git a/syntaxdot-transformers/src/util.rs b/syntaxdot-transformers/src/util.rs index 133e75f..0d74721 100644 --- a/syntaxdot-transformers/src/util.rs +++ b/syntaxdot-transformers/src/util.rs @@ -166,19 +166,19 @@ pub mod tests { #[test] #[should_panic] fn mask_dimensionality_should_be_correct_for_logits_mask() { - LogitsMask::from_bool_mask(&Tensor::of_slice(&[false])).unwrap(); + LogitsMask::from_bool_mask(&Tensor::from_slice(&[false])).unwrap(); } #[test] fn logits_mask_is_constructed_correctly() { let mask = LogitsMask::from_bool_mask( - &Tensor::of_slice(&[true, false, true, false, true, true, false, false]).view((2, 4)), + &Tensor::from_slice(&[true, false, true, false, true, true, false, false]).view((2, 4)), ) .unwrap(); assert_eq!( mask.inner, - Tensor::of_slice(&[0i64, -10_000, 0, -10_000, 0, 0, -10_000, -10_000]) + Tensor::from_slice(&[0i64, -10_000, 0, -10_000, 0, 0, -10_000, -10_000]) .view((2, 1, 1, 4)) ); } diff --git a/syntaxdot/Cargo.toml b/syntaxdot/Cargo.toml index a791b46..77e40a0 100644 --- a/syntaxdot/Cargo.toml +++ b/syntaxdot/Cargo.toml @@ -14,7 +14,7 @@ conllu = "0.8" ndarray = "0.15" numberer = "0.2" log = "0.4" -ordered-float = "2" +ordered-float = "4" rand = "0.8" rand_xorshift = "0.3" serde = { version = "1", features = [ "derive" ] } @@ -23,16 +23,16 @@ syntaxdot-encoders = { path = "../syntaxdot-encoders", version = "0.5.0" } syntaxdot-tch-ext = { path = "../syntaxdot-tch-ext", version = "0.5.0" } syntaxdot-tokenizers = { path = "../syntaxdot-tokenizers", default-features = false, version = "0.5.0" } syntaxdot-transformers = { path = "../syntaxdot-transformers", default-features = false, version = "0.5.0" } -tch = { version = "0.11", default-features = false } +tch = { version = "0.14", default-features = false } thiserror = "1" -toml = "0.5" +toml = "0.8" udgraph = "0.8" [dev-dependencies] -approx = "0.4" +approx = "0.5" lazy_static = "1" maplit = "1" -wordpieces = "0.5" +wordpieces = "0.6" [features] model-tests = [] diff --git a/syntaxdot/src/dataset/mod.rs b/syntaxdot/src/dataset/mod.rs index 539070f..1180ec6 100644 --- a/syntaxdot/src/dataset/mod.rs +++ b/syntaxdot/src/dataset/mod.rs @@ -31,8 +31,7 @@ pub trait DataSet<'a> { #[cfg(test)] pub(crate) mod tests { - use std::convert::TryFrom; - use std::io::{BufRead, BufReader, Cursor}; + use std::io::{BufReader, Cursor}; use lazy_static::lazy_static; use ndarray::{array, Array1}; @@ -77,7 +76,7 @@ nu"#; } pub fn wordpiece_tokenizer() -> BertTokenizer { - let pieces = WordPieces::try_from(BufReader::new(Cursor::new(PIECES)).lines()).unwrap(); + let pieces = 
WordPieces::from_buf_read(BufReader::new(Cursor::new(PIECES))).unwrap(); BertTokenizer::new(pieces, "[UNK]") } } diff --git a/syntaxdot/src/model/pooling.rs b/syntaxdot/src/model/pooling.rs index a999e73..83cb88a 100644 --- a/syntaxdot/src/model/pooling.rs +++ b/syntaxdot/src/model/pooling.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use std::convert::TryFrom; use syntaxdot_tch_ext::tensor::SumDim; use syntaxdot_transformers::models::LayerOutput; use syntaxdot_transformers::TransformerError; @@ -126,7 +127,7 @@ impl EmbeddingsPerToken for TokenSpansWithRoot { let (batch_size, _pieces_len, embed_size) = embeddings.size3()?; let (_batch_size, tokens_len) = self.offsets().size2()?; - let max_token_len = i64::from(self.lens().max()); + let max_token_len = i64::try_from(&self.lens().max())?; let piece_range = Tensor::f_arange(max_token_len, (Kind::Int64, self.lens().device()))? .f_view([1, 1, max_token_len])?; @@ -163,8 +164,8 @@ mod tests { #[test] fn discard_pooler_works_correctly() { let spans = TokenSpans::new( - Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), - Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), + Tensor::from_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), + Tensor::from_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), ); let hidden = Tensor::arange_start_step(36, 0, -1, (Kind::Int64, Device::Cpu)) @@ -179,7 +180,7 @@ mod tests { assert_eq!( token_embeddings, - Tensor::of_slice2(&[ + Tensor::from_slice2(&[ &[36, 35, 34, 33, 30, 29, 28, 27, 0, 0, 0, 0], &[18, 17, 16, 15, 12, 11, 10, 9, 6, 5, 4, 3] ]) @@ -190,8 +191,8 @@ mod tests { #[test] fn embeddings_are_returned_per_token() { let spans = TokenSpansWithRoot::new( - Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), - Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), + Tensor::from_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), + Tensor::from_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), ); let hidden = @@ -201,7 +202,7 @@ mod tests { assert_eq!( token_embeddings.embeddings, - Tensor::of_slice(&[ + Tensor::from_slice(&[ 30, 29, 28, 27, 26, 25, 0, 0, 24, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 13, 12, 11, 10, 9, 0, 0, 8, 7, 6, 5, 4, 3, 0, 0, 2, 1, 0, 0 ]) @@ -210,7 +211,7 @@ mod tests { assert_eq!( token_embeddings.mask, - Tensor::of_slice(&[ + Tensor::from_slice(&[ true, true, true, false, true, false, false, false, false, false, true, true, true, false, true, true, true, false, true, false ]) @@ -221,8 +222,8 @@ mod tests { #[test] fn mean_pooler_works_correctly() { let spans = TokenSpans::new( - Tensor::of_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), - Tensor::of_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), + Tensor::from_slice2(&[[1, 3, 4, -1, -1], [1, 3, 4, 6, 7]]), + Tensor::from_slice2(&[[2, 1, 1, -1, -1], [2, 1, 2, 1, 1]]), ); let hidden = Tensor::arange_start_step(36, 0, -1, (Kind::Int64, Device::Cpu)) @@ -237,7 +238,7 @@ mod tests { assert_eq!( token_embeddings, - Tensor::of_slice2(&[ + Tensor::from_slice2(&[ &[36, 35, 33, 32, 30, 29, 28, 27, 0, 0, 0, 0,], &[18, 17, 15, 14, 12, 11, 9, 8, 6, 5, 4, 3] ]) diff --git a/syntaxdot/src/optimizers/grad_scale.rs b/syntaxdot/src/optimizers/grad_scale.rs index b665d25..9c3fe69 100644 --- a/syntaxdot/src/optimizers/grad_scale.rs +++ b/syntaxdot/src/optimizers/grad_scale.rs @@ -1,3 +1,5 @@ +use std::convert::TryFrom; + use tch::{Kind, Tensor}; use super::{Optimizer, ZeroGrad}; @@ -72,7 +74,7 @@ where /// Get the current scale. 
pub fn current_scale(&self) -> f32 { - Vec::<f32>::from(&self.scale)[0] + Vec::<f32>::try_from(&self.scale).expect("Tensor cannot be converted to Vec")[0] } /// Get a reference to the wrapped optimizer. @@ -150,7 +152,11 @@ where .internal_amp_non_finite_check_and_unscale(&mut self.found_inf, &inv_scale); } - let found_inf = (f32::from(&self.found_inf) - 1.0).abs() < f32::EPSILON; + let found_inf = (f32::try_from(&self.found_inf) .expect("Cannot convert boolean for infinity detection to f32") - 1.0) .abs() < f32::EPSILON; // Only step when there are no infinite gradients. if !found_inf { diff --git a/syntaxdot/src/tensor.rs b/syntaxdot/src/tensor.rs index 9250ac9..72977da 100644 --- a/syntaxdot/src/tensor.rs +++ b/syntaxdot/src/tensor.rs @@ -328,7 +328,7 @@ impl SequenceLengths { /// Convert sequence lengths to masks. pub fn attention_mask(&self) -> Result { - let max_len = i64::from(self.inner.max()); + let max_len = i64::try_from(self.inner.max())?; let batch_size = self.inner.size()[0]; Ok(Tensor::f_arange(max_len, (Kind::Int, self.inner.device()))? // Construct a matrix [batch_size, max_len] where each row @@ -509,10 +509,10 @@ mod tests { #[test] fn attention_masking_is_correct() { - let seq_lens = SequenceLengths::new(Tensor::of_slice(&[3, 5, 1])); + let seq_lens = SequenceLengths::new(Tensor::from_slice(&[3, 5, 1])); assert_eq!( seq_lens.attention_mask().unwrap(), - Tensor::of_slice(&[ + Tensor::from_slice(&[ true, true, true, false, false, // Sequence 0 true, true, true, true, true, // Sequence 1 true, false, false, false, false, // Sequence 2 ]) @@ -542,10 +542,10 @@ // No labels. assert_eq!(tensors.labels, None); - assert_eq!(*tensors.seq_lens, Tensor::of_slice(&[2, 3])); + assert_eq!(*tensors.seq_lens, Tensor::from_slice(&[2, 3])); assert_eq!( tensors.inputs, - Tensor::of_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3]) + Tensor::from_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3]) ); } @@ -580,8 +580,8 @@ assert_eq!( tensors.biaffine_encodings, Some(BiaffineTensors { - heads: Tensor::of_slice(&[1, -1, 0, 1]).reshape(&[2, 2]), - relations: Tensor::of_slice(&[2, -1, 3, 1]).reshape(&[2, 2]) + heads: Tensor::from_slice(&[1, -1, 0, 1]).reshape(&[2, 2]), + relations: Tensor::from_slice(&[2, -1, 3, 1]).reshape(&[2, 2]) }) ); @@ -592,11 +592,11 @@ vec![ ( "a".to_string(), - Tensor::of_slice(&[12, 0, 13, 15]).reshape(&[2, 2]) + Tensor::from_slice(&[12, 0, 13, 15]).reshape(&[2, 2]) ), ( "b".to_string(), - Tensor::of_slice(&[21, 0, 24, 25]).reshape(&[2, 2]) + Tensor::from_slice(&[21, 0, 24, 25]).reshape(&[2, 2]) ) ] .into_iter() ) ); - assert_eq!(*tensors.seq_lens, Tensor::of_slice(&[2, 3])); + assert_eq!(*tensors.seq_lens, Tensor::from_slice(&[2, 3])); assert_eq!( tensors.inputs, - Tensor::of_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3]) + Tensor::from_slice(&[1, 2, 0, 3, 4, 5]).reshape(&[2, 3]) ); } #[test] fn token_masking_is_correct() { let token_offsets = TokenSpans::new( - Tensor::of_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), - Tensor::of_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), + Tensor::from_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), + Tensor::from_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), ); assert_eq!( *token_offsets.token_mask().unwrap(), - Tensor::of_slice(&[ + Tensor::from_slice(&[ true, true, true, false, false, // Sequence 0 true, true, true, true, true // Sequence 1 ]) ); } #[test] fn token_masking_with_root_is_correct() { let
token_offsets = TokenSpans::new( - Tensor::of_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), - Tensor::of_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), + Tensor::from_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), + Tensor::from_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), ); assert_eq!( *token_offsets.token_mask().unwrap().with_root().unwrap(), - Tensor::of_slice(&[ + Tensor::from_slice(&[ true, true, true, true, false, false, // Sequence 0 true, true, true, true, true, true // Sequence 1 ]) @@ -697,9 +697,12 @@ mod tests { #[test] fn token_sequence_lengths_are_correct() { let token_offsets = TokenSpans::new( - Tensor::of_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), - Tensor::of_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), + Tensor::from_slice2(&[&[1, 3, 5, -1, -1], &[1, 2, 8, 11, 13]]), + Tensor::from_slice2(&[&[2, 2, 1, -1, -1], &[1, 6, 3, 2, 1]]), + ); + assert_eq!( + token_offsets.seq_lens().unwrap(), + Tensor::from_slice(&[3, 5]) ); - assert_eq!(token_offsets.seq_lens().unwrap(), Tensor::of_slice(&[3, 5])); } } From f47af997346fdc645ca18807a5284d86947a2572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Oct 2023 17:40:27 +0200 Subject: [PATCH 2/6] Rename test checkpoint extension to `.ot` Some extra processing is used for `.pt`, which fails for loading into `VarStore`. --- scripts/test-all.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scripts/test-all.sh b/scripts/test-all.sh index 2fb5353..f8916fe 100755 --- a/scripts/test-all.sh +++ b/scripts/test-all.sh @@ -32,6 +32,11 @@ for var in "${!models[@]}"; do url="${models[$var]}" data="${cache_dir}/$(basename "${url}")" + # Since these checkpoints were generated, an assumption was added that + # .pt files are created from Python code. Rename to .ot to avoid loading + # issues. + data=${data/%.pt/.ot} + if [ ! -e "${data}" ]; then curl -fo "${data}" "${url}" fi From ccb59b5bccc8f6fd67cebd8b1aec3a59c7208cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Oct 2023 17:41:45 +0200 Subject: [PATCH 3/6] Run SqueezeBert tests in full precision --- syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs | 1 + syntaxdot-transformers/src/models/squeeze_bert/encoder.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs b/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs index 54c1eeb..4d595a8 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/embeddings.rs @@ -52,6 +52,7 @@ mod tests { let embeddings = BertEmbeddings::new(root.sub("embeddings"), &bert_config).unwrap(); vs.load(SQUEEZEBERT_UNCASED).unwrap(); + vs.float(); // Word pieces of: Did the AWO embezzle donations ? let pieces = diff --git a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs index 599c34c..f93634d 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs @@ -182,6 +182,7 @@ mod tests { let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap(); vs.load(SQUEEZEBERT_UNCASED).unwrap(); + vs.float(); // Word pieces of: Did the AWO embezzle donations ? 
let pieces = @@ -223,6 +224,7 @@ mod tests { let encoder = SqueezeBertEncoder::new(root.sub("encoder"), &config).unwrap(); vs.load(SQUEEZEBERT_UNCASED).unwrap(); + vs.float(); // Word pieces of: Did the AWO embezzle donations ? // Add some padding to simulate inactive time steps. From ce36ac42e19682b8e7e5723a295ad80ae1781b50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Oct 2023 19:47:31 +0200 Subject: [PATCH 4/6] Update Rust CI toolchain --- .github/workflows/rust.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 9a88562..5516473 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,7 +34,7 @@ jobs: - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: 1.69.0 + toolchain: stable override: true components: clippy - uses: tensordot/libtorch-action@v2.0.0 From a3d67caf7ca4f591df001fad89eb669283f71fcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 13 Oct 2023 19:59:02 +0200 Subject: [PATCH 5/6] Use libtorch 2.1.0 action --- .github/workflows/release.yml | 2 +- .github/workflows/rust.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ffeaf48..526f3f5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -42,7 +42,7 @@ jobs: echo "699a31cf52211cf5ad6e35a8801eb637bc7f3c43117140426400d67b7babd792 patchelf-0.12.tar.bz2" | sha256sum -c - tar jxf patchelf-0.12.tar.bz2 ( cd patchelf-0.12.20200827.8d3a16e && ./configure && make -j4 ) - - uses: tensordot/libtorch-action@v2.0.0 + - uses: tensordot/libtorch-action@v2.1.0 with: device: ${{matrix.device}} - uses: actions-rs/toolchain@v1 diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5516473..3c2f04b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -37,7 +37,7 @@ jobs: toolchain: stable override: true components: clippy - - uses: tensordot/libtorch-action@v2.0.0 + - uses: tensordot/libtorch-action@v2.1.0 - uses: actions-rs/cargo@v1 with: command: clippy @@ -60,7 +60,7 @@ jobs: profile: minimal toolchain: stable override: true - - uses: tensordot/libtorch-action@v2.0.0 + - uses: tensordot/libtorch-action@v2.1.0 - uses: actions-rs/cargo@v1 with: command: test From ffc691e7f85f0a0c24cad49eba15b5d467d6df15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Sat, 14 Oct 2023 10:50:50 +0200 Subject: [PATCH 6/6] Fix Clippy warnings --- syntaxdot-cli/src/subcommands/distill.rs | 18 +++++++++--------- syntaxdot-transformers/src/layers.rs | 10 +++++----- syntaxdot-transformers/src/loss.rs | 2 +- .../src/models/bert/layer.rs | 4 ++-- .../src/models/squeeze_albert/mod.rs | 4 ++-- .../src/models/squeeze_bert/encoder.rs | 4 ++-- .../src/models/squeeze_bert/layer.rs | 8 ++++---- syntaxdot-transformers/src/util.rs | 4 ++-- .../src/model/biaffine_dependency_layer.rs | 16 ++++++++-------- syntaxdot/src/model/pooling.rs | 2 +- syntaxdot/src/model/seq_classifiers.rs | 2 +- syntaxdot/src/optimizers/grad_scale.rs | 6 +++--- syntaxdot/src/tensor.rs | 10 +++++----- 13 files changed, 45 insertions(+), 45 deletions(-) diff --git a/syntaxdot-cli/src/subcommands/distill.rs b/syntaxdot-cli/src/subcommands/distill.rs index d3511be..8ca4841 100644 --- a/syntaxdot-cli/src/subcommands/distill.rs +++ b/syntaxdot-cli/src/subcommands/distill.rs @@ -405,7 +405,7 @@ impl DistillApp { let token_mask = 
token_mask.with_root()?; let mse_loss = MSELoss::new(MSELossNormalization::SquaredL2Norm); - let mut loss = Tensor::zeros(&[], (Kind::Float, self.device)); + let mut loss = Tensor::zeros([], (Kind::Float, self.device)); let (batch_size, _) = token_mask .size2() @@ -427,8 +427,8 @@ impl DistillApp { .f_matmul(&mapping.mapping)? .f_masked_select(&token_mask.f_unsqueeze(-1)?)?; - let teacher_hidden = teacher_hidden.f_reshape(&[batch_size, -1])?; - let student_hidden = student_hidden.f_reshape(&[batch_size, -1])?; + let teacher_hidden = teacher_hidden.f_reshape([batch_size, -1])?; + let student_hidden = student_hidden.f_reshape([batch_size, -1])?; let _ = loss.f_add_(&mse_loss.forward(&student_hidden, &teacher_hidden)?); } @@ -472,7 +472,7 @@ impl DistillApp { student_encoder_logits: HashMap, token_mask: &TokenMask, ) -> Result { - let mut loss = Tensor::zeros(&[], (Kind::Float, token_mask.device())); + let mut loss = Tensor::zeros([], (Kind::Float, token_mask.device())); for (encoder_name, teacher_logits) in teacher_encoder_logits { let n_labels = teacher_logits.size()[2]; @@ -480,10 +480,10 @@ impl DistillApp { // Select the outputs for the relevant time steps. let student_logits = student_encoder_logits[&encoder_name] .masked_select(&token_mask.unsqueeze(-1)) - .reshape(&[-1, n_labels]); + .reshape([-1, n_labels]); let teacher_logits = teacher_logits .masked_select(&token_mask.unsqueeze(-1)) - .reshape(&[-1, n_labels]); + .reshape([-1, n_labels]); // Compute the soft loss. let teacher_probs = teacher_logits.f_softmax(-1, Kind::Float)?; @@ -551,7 +551,7 @@ impl DistillApp { }, )?; - let mut soft_loss = Tensor::zeros(&[], (Kind::Float, self.device)); + let mut soft_loss = Tensor::zeros([], (Kind::Float, self.device)); // Compute biaffine encoder/decoder loss. match ( @@ -582,7 +582,7 @@ impl DistillApp { let attention_loss = if self.attention_loss { self.attention_loss(&teacher_layer_outputs, &student_layer_outputs)? } else { - Tensor::zeros(&[], (Kind::Float, self.device)) + Tensor::zeros([], (Kind::Float, self.device)) }; let hidden_loss = match auxiliary_params.hidden_mappings { @@ -592,7 +592,7 @@ impl DistillApp { &teacher_layer_outputs, &student_layer_outputs, )?, - None => Tensor::zeros(&[], (Kind::Float, self.device)), + None => Tensor::zeros([], (Kind::Float, self.device)), }; Ok(DistillLoss { diff --git a/syntaxdot-transformers/src/layers.rs b/syntaxdot-transformers/src/layers.rs index ba81ead..ec7c0a4 100644 --- a/syntaxdot-transformers/src/layers.rs +++ b/syntaxdot-transformers/src/layers.rs @@ -59,9 +59,9 @@ impl FallibleModule for Conv1D { xs, &self.ws, self.bs.as_ref(), - &[self.config.stride], - &[self.config.padding], - &[self.config.dilation], + [self.config.stride], + [self.config.padding], + [self.config.dilation], self.config.groups, )?) } @@ -286,7 +286,7 @@ impl PairwiseBilinear { let (batch_size, seq_len, _) = u.size3()?; - let ones = Tensor::ones(&[batch_size, seq_len, 1], (u.kind(), u.device())); + let ones = Tensor::ones([batch_size, seq_len, 1], (u.kind(), u.device())); let u = if self.bias_u { Tensor::f_cat(&[u, &ones], -1)? @@ -346,7 +346,7 @@ impl FallibleModuleT for VariationalDropout { } let (batch_size, _, repr_size) = xs.size3()?; - let dropout_mask = Tensor::f_ones(&[batch_size, 1, repr_size], (xs.kind(), xs.device()))? + let dropout_mask = Tensor::f_ones([batch_size, 1, repr_size], (xs.kind(), xs.device()))? .f_dropout_(self.p, true)?; Ok(xs.f_mul(&dropout_mask)?) 
} diff --git a/syntaxdot-transformers/src/loss.rs b/syntaxdot-transformers/src/loss.rs index 54f1815..2c97004 100644 --- a/syntaxdot-transformers/src/loss.rs +++ b/syntaxdot-transformers/src/loss.rs @@ -155,7 +155,7 @@ impl MSELoss { match self.normalization { MSELossNormalization::Mean => loss, MSELossNormalization::SquaredL2Norm => { - let norm = target.f_frobenius_norm(&[1], true)?.f_square()?; + let norm = target.f_frobenius_norm([1], true)?.f_square()?; let (batch_size, _) = target.size2()?; loss? .f_div(&norm)? diff --git a/syntaxdot-transformers/src/models/bert/layer.rs b/syntaxdot-transformers/src/models/bert/layer.rs index f1e56f6..5e3f41f 100644 --- a/syntaxdot-transformers/src/models/bert/layer.rs +++ b/syntaxdot-transformers/src/models/bert/layer.rs @@ -257,7 +257,7 @@ impl BertSelfAttention { let context_layer = attention_probs.f_matmul(&value_layer)?; - let context_layer = context_layer.f_permute(&[0, 2, 1, 3])?.f_contiguous()?; + let context_layer = context_layer.f_permute([0, 2, 1, 3])?.f_contiguous()?; let mut new_context_layer_shape = context_layer.size(); new_context_layer_shape.splice( new_context_layer_shape.len() - 2.., @@ -273,7 +273,7 @@ impl BertSelfAttention { new_x_shape.pop(); new_x_shape.extend(&[self.num_attention_heads, self.attention_head_size]); - Ok(x.f_view_(&new_x_shape)?.f_permute(&[0, 2, 1, 3])?) + Ok(x.f_view_(&new_x_shape)?.f_permute([0, 2, 1, 3])?) } } diff --git a/syntaxdot-transformers/src/models/squeeze_albert/mod.rs b/syntaxdot-transformers/src/models/squeeze_albert/mod.rs index 832e98f..c6a8ba8 100644 --- a/syntaxdot-transformers/src/models/squeeze_albert/mod.rs +++ b/syntaxdot-transformers/src/models/squeeze_albert/mod.rs @@ -232,7 +232,7 @@ impl Encoder for SqueezeAlbertEncoder { ) -> Result<Vec<LayerOutput>, TransformerError> { let hidden_states = self.projection.forward(input); - let input = hidden_states.f_permute(&[0, 2, 1])?; + let input = hidden_states.f_permute([0, 2, 1])?; let mut all_layer_outputs = Vec::with_capacity(self.n_layers as usize + 1); all_layer_outputs.push(LayerOutput::Embedding(hidden_states.shallow_clone())); @@ -256,7 +256,7 @@ impl Encoder for SqueezeAlbertEncoder { // Convert hidden states to [batch_size, seq_len, hidden_size]. for layer_output in &mut all_layer_outputs { - *layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?; + *layer_output.output_mut() = layer_output.output().f_permute([0, 2, 1])?; } Ok(all_layer_outputs) diff --git a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs index f93634d..71c94ae 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/encoder.rs @@ -59,7 +59,7 @@ impl Encoder for SqueezeBertEncoder { let attention_mask = attention_mask.map(LogitsMask::from_bool_mask).transpose()?; // [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size, seq_len] - let mut hidden_states = input.f_permute(&[0, 2, 1])?; + let mut hidden_states = input.f_permute([0, 2, 1])?; let mut all_layer_outputs = Vec::with_capacity(self.layers.len() + 1); all_layer_outputs.push(LayerOutput::Embedding(hidden_states.shallow_clone())); @@ -73,7 +73,7 @@ impl Encoder for SqueezeBertEncoder { // Convert hidden states to [batch_size, seq_len, hidden_size].
for layer_output in &mut all_layer_outputs { - *layer_output.output_mut() = layer_output.output().f_permute(&[0, 2, 1])?; + *layer_output.output_mut() = layer_output.output().f_permute([0, 2, 1])?; } Ok(all_layer_outputs) diff --git a/syntaxdot-transformers/src/models/squeeze_bert/layer.rs b/syntaxdot-transformers/src/models/squeeze_bert/layer.rs index fb51be3..fe9e840 100644 --- a/syntaxdot-transformers/src/models/squeeze_bert/layer.rs +++ b/syntaxdot-transformers/src/models/squeeze_bert/layer.rs @@ -49,9 +49,9 @@ impl FallibleModule for SqueezeBertLayerNorm { type Error = TransformerError; fn forward(&self, xs: &Tensor) -> Result<Tensor, Self::Error> { - let xs_perm = xs.f_permute(&[0, 2, 1])?; + let xs_perm = xs.f_permute([0, 2, 1])?; let xs_perm_norm = self.layer_norm.forward(&xs_perm)?; - Ok(xs_perm_norm.f_permute(&[0, 2, 1])?) + Ok(xs_perm_norm.f_permute([0, 2, 1])?) } } @@ -235,7 +235,7 @@ impl SqueezeBertSelfAttention { *x_size.last().unwrap(), ]; - Ok(x.f_view_(new_x_shape)?.f_permute(&[0, 1, 3, 2])?) + Ok(x.f_view_(new_x_shape)?.f_permute([0, 1, 3, 2])?) } fn transpose_key_for_scores(&self, x: &Tensor) -> Result<Tensor, TransformerError> { @@ -251,7 +251,7 @@ } fn transpose_output(&self, x: &Tensor) -> Result<Tensor, TransformerError> { - let x = x.f_permute(&[0, 1, 3, 2])?.f_contiguous()?; + let x = x.f_permute([0, 1, 3, 2])?.f_contiguous()?; let x_size = x.size(); let new_x_shape = &[x_size[0], self.all_head_size, x_size[3]]; Ok(x.f_view_(new_x_shape)?) diff --git a/syntaxdot-transformers/src/util.rs b/syntaxdot-transformers/src/util.rs index 0d74721..aa9af10 100644 --- a/syntaxdot-transformers/src/util.rs +++ b/syntaxdot-transformers/src/util.rs @@ -123,7 +123,7 @@ impl SinusoidalPositions for Tensor { if let Some(p) = p_norm { // Compute the p-norm. - let norm = self.f_norm_scalaropt_dim(p, &[-1], true)?; + let norm = self.f_norm_scalaropt_dim(p, [-1], true)?; // Normalize embeddings. let _ = self.f_div_(&norm)?; @@ -145,7 +145,7 @@ impl SinusoidalPositions for Tensor { dims ); - let mut positions = Tensor::f_empty(&[n_positions, dims], options)?; + let mut positions = Tensor::f_empty([n_positions, dims], options)?; positions.sinusoidal_positions_(p_norm)?; Ok(positions) diff --git a/syntaxdot/src/model/biaffine_dependency_layer.rs b/syntaxdot/src/model/biaffine_dependency_layer.rs index 54edc92..8682c83 100644 --- a/syntaxdot/src/model/biaffine_dependency_layer.rs +++ b/syntaxdot/src/model/biaffine_dependency_layer.rs @@ -268,7 +268,7 @@ impl BiaffineDependencyLayer { 1, &heads .f_unsqueeze(-1)? - .f_expand(&[batch_size, n_tokens, label_hidden_size], true)?, + .f_expand([batch_size, n_tokens, label_hidden_size], true)?, false, )?; @@ -370,8 +370,8 @@ impl BiaffineDependencyLayer { let head_logits = biaffine_logits .head_score_logits // Last dimension is ROOT + all tokens as head candidates. - .f_reshape(&[-1, seq_len + 1])?; - let head_targets = &targets.heads.f_view_(&[-1])?; + .f_reshape([-1, seq_len + 1])?; + let head_targets = &targets.heads.f_view_([-1])?; let head_loss = CrossEntropyLoss::new(-1, label_smoothing, Reduction::Mean).forward( &head_logits, head_targets, @@ -380,18 +380,18 @@ // [batch_size, seq_len + 1] -> [batch_size, 1, seq_len + 1] .f_unsqueeze(1)? // [batch_size, 1, seq_len + 1] -> [batch_size, seq_len, seq_len + 1]. - .f_expand(&[-1, seq_len, -1], true)? + .f_expand([-1, seq_len, -1], true)?
// [batch_size, seq_len, seq_len + 1] -> [batch_size * seq_len, seq_len + 1] - .f_reshape(&[-1, seq_len + 1])?, + .f_reshape([-1, seq_len + 1])?, ), )?; // Get the logits for the correct heads. let label_score_logits = biaffine_logits .relation_score_logits - .f_reshape(&[-1, self.n_relations])?; + .f_reshape([-1, self.n_relations])?; - let relation_targets = targets.relations.f_view_(&[-1])?; + let relation_targets = targets.relations.f_view_([-1])?; let relation_loss = CrossEntropyLoss::new(-1, label_smoothing, Reduction::Mean).forward( &label_score_logits, &relation_targets, @@ -423,7 +423,7 @@ impl BiaffineDependencyLayer { .f_argmax(-1, false)?; let relations_correct = relations_predicted .f_eq_tensor(&targets.relations)? - .f_view_(&[batch_size, seq_len])?; + .f_view_([batch_size, seq_len])?; let head_and_relations_correct = head_correct.f_logical_and(&relations_correct)?; diff --git a/syntaxdot/src/model/pooling.rs b/syntaxdot/src/model/pooling.rs index 83cb88a..9a292a8 100644 --- a/syntaxdot/src/model/pooling.rs +++ b/syntaxdot/src/model/pooling.rs @@ -141,7 +141,7 @@ impl EmbeddingsPerToken for TokenSpansWithRoot { 1, &piece_indices .f_view([batch_size, -1, 1])? - .f_expand(&[-1, -1, embed_size], true)?, + .f_expand([-1, -1, embed_size], true)?, false, )? .f_view([batch_size, tokens_len, max_token_len, embed_size])? diff --git a/syntaxdot/src/model/seq_classifiers.rs b/syntaxdot/src/model/seq_classifiers.rs index 8202dd0..d5d5a36 100644 --- a/syntaxdot/src/model/seq_classifiers.rs +++ b/syntaxdot/src/model/seq_classifiers.rs @@ -136,7 +136,7 @@ impl SequenceClassifiers { } let summed_loss = encoder_losses.values().try_fold( - Tensor::f_zeros(&[], (Kind::Float, layers_without_root[0].output().device()))?, + Tensor::f_zeros([], (Kind::Float, layers_without_root[0].output().device()))?, |summed_loss, loss| summed_loss.f_add(loss), )?; diff --git a/syntaxdot/src/optimizers/grad_scale.rs b/syntaxdot/src/optimizers/grad_scale.rs index 9c3fe69..71e9298 100644 --- a/syntaxdot/src/optimizers/grad_scale.rs +++ b/syntaxdot/src/optimizers/grad_scale.rs @@ -59,9 +59,9 @@ where optimizer, - found_inf: Tensor::full(&[1], 0.0, (Kind::Float, device)), - growth_tracker: Tensor::full(&[1], 0, (Kind::Int, device)), - scale: Tensor::full(&[1], init_scale, (Kind::Float, device)), + found_inf: Tensor::full([1], 0.0, (Kind::Float, device)), + growth_tracker: Tensor::full([1], 0, (Kind::Int, device)), + scale: Tensor::full([1], init_scale, (Kind::Float, device)), }) } diff --git a/syntaxdot/src/tensor.rs b/syntaxdot/src/tensor.rs index 72977da..8117718 100644 --- a/syntaxdot/src/tensor.rs +++ b/syntaxdot/src/tensor.rs @@ -333,8 +333,8 @@ impl SequenceLengths { Ok(Tensor::f_arange(max_len, (Kind::Int, self.inner.device()))? // Construct a matrix [batch_size, max_len] where each row // is 0..(max_len - 1). - .f_repeat(&[batch_size])? - .f_view_(&[batch_size, max_len])? + .f_repeat([batch_size])? + .f_view_([batch_size, max_len])? // Time steps less than the length in the sequence lengths are active. .f_lt_tensor(&self.inner.unsqueeze(1))? // For some reason the kind is Int? @@ -403,13 +403,13 @@ impl TokenSpans { let root_offset = Tensor::from(0) .f_view([1, 1])? - .f_expand(&[batch_size, 1], true)? + .f_expand([batch_size, 1], true)? .to_device(self.offsets.device()); let offsets = Tensor::f_cat(&[&root_offset, &self.offsets], 1)?; let root_len = Tensor::from(1) .f_view([1, 1])? - .f_expand(&[batch_size, 1], true)? + .f_expand([batch_size, 1], true)? 
.to_device(self.lens.device()); let lens = Tensor::f_cat(&[&root_len, &self.lens], 1)?; @@ -460,7 +460,7 @@ impl TokenMask { let (batch_size, _seq_len) = self.inner.size2()?; let root_mask = Tensor::from(true) - .f_expand(&[batch_size, 1], true)? + .f_expand([batch_size, 1], true)? .to_device(self.inner.device()); let token_mask_with_root = Tensor::f_cat(&[&root_mask, &self.inner], -1)?;