From f64914b3a5a1a7ff75c55dad7b9a460542dd16d5 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:16:54 +0100 Subject: [PATCH] Implicit GEMM optimizations/bug fixes (#2499) --- Cargo.lock | 395 +++++++++++++----- Cargo.toml | 8 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 151 ++++--- .../src/kernel/interpolate/bicubic.rs | 18 +- .../src/kernel/interpolate/bilinear.rs | 54 ++- .../src/kernel/reduce/shared/kernel.rs | 35 +- .../src/kernel/reduce/subcube/kernel.rs | 54 ++- .../burn-jit/src/kernel/reduce/tune/base.rs | 7 +- 8 files changed, 480 insertions(+), 242 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0bb14b760f..999a9c8388 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,9 +64,9 @@ checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "android-tzdata" @@ -85,9 +85,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.17" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23a1e53f0f5d86382dafe1cf314783b2044280f406e7e1506368220ad11b1338" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" dependencies = [ "anstyle", "anstyle-parse", @@ -134,15 +134,15 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arbitrary" -version = "1.3.2" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" dependencies = [ "derive_arbitrary", ] @@ -373,7 +373,7 @@ dependencies = [ "burn", "burn-common", "burn-wgpu", - "clap 4.5.20", + "clap 4.5.21", "colored", "cubecl", "derive-new 0.7.0", @@ -490,9 +490,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.5.3" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" +checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" [[package]] name = "blas-src" @@ -531,9 +531,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" dependencies = [ "memchr", "serde", @@ -621,7 +621,7 @@ dependencies = [ "derive-new 0.7.0", "flate2", "half", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "log", "num-traits", "portable-atomic-util", @@ -701,7 +701,7 @@ dependencies = [ "burn-tensor", "derive-new 0.7.0", "half", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "log", "serde", "spin", @@ -762,7 +762,7 @@ dependencies = [ "derive-new 0.7.0", "futures-lite", "half", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "log", "num-traits", "paste", @@ -833,7 +833,7 @@ dependencies = [ "burn-ndarray", "burn-tensor", "burn-wgpu", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "log", "spin", ] @@ -862,7 +862,7 @@ dependencies = [ "cubecl", "derive-new 0.7.0", "half", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "num-traits", "portable-atomic-util", "rand", @@ -1049,9 +1049,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.34" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -1151,9 +1151,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive 4.5.18", @@ -1161,13 +1161,13 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", - "clap_lex 0.7.2", + "clap_lex 0.7.3", "strsim 0.11.1", ] @@ -1207,9 +1207,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "clipboard-win" @@ -1294,14 +1294,14 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ - "crossterm 0.27.0", + "crossterm", "strum", "strum_macros", - "unicode-width 0.1.14", + "unicode-width 0.2.0", ] [[package]] @@ -1404,9 +1404,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "0ca741a962e1b0bff6d724a1a0958b686406e853bb14061f218562e1896f95e6" dependencies = [ "libc", ] @@ -1478,19 +1478,6 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" -[[package]] -name = "crossterm" -version = "0.27.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" -dependencies = [ - "bitflags 2.6.0", - "crossterm_winapi", - "libc", - "parking_lot 0.12.3", - "winapi", -] - [[package]] name = "crossterm" version = "0.28.1" @@ -1556,7 +1543,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1587,7 +1574,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1604,7 +1591,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1622,7 +1609,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1636,7 +1623,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1652,7 +1639,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1677,7 +1664,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "bytemuck", "cubecl-core", @@ -1688,7 +1675,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1703,7 +1690,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1740,7 +1727,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "async-channel", "async-lock", @@ -1761,7 +1748,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1775,7 +1762,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=99df09381aac4e2cd1354a744ec99bbd364bc9ea#99df09381aac4e2cd1354a744ec99bbd364bc9ea" +source = "git+https://github.com/tracel-ai/cubecl?rev=8f4861ebe577065e2209ee94724c05b514e1b860#8f4861ebe577065e2209ee94724c05b514e1b860" dependencies = [ "ash", "async-channel", @@ -1976,9 +1963,9 @@ dependencies = [ [[package]] name = "derive_arbitrary" -version = "1.3.2" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", @@ -2338,9 +2325,9 @@ checksum = "a2a2b11eda1d40935b26cf18f6833c526845ae8c41e58d09af6adeb6f0269183" [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fdeflate" @@ -2371,9 +2358,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2843,7 +2830,7 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -3022,9 +3009,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" dependencies = [ "allocator-api2", "equivalent", @@ -3343,6 +3330,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -3351,12 +3456,23 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -3369,7 +3485,7 @@ dependencies = [ "globset", "log", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "same-file", "walkdir", "winapi-util", @@ -3451,20 +3567,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", ] [[package]] name = "indicatif" -version = "0.17.8" +version = "0.17.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" dependencies = [ "console", - "instant", "number_prefix", "portable-atomic", - "unicode-width 0.1.14", + "unicode-width 0.2.0", + "web-time", ] [[package]] @@ -3484,10 +3600,14 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b23a0c8dfe501baac4adf6ebbfa6eddf8f0c07f56b058cc1288017e32397846c" +checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e" dependencies = [ + "darling", + "indoc", + "pretty_assertions", + "proc-macro2", "quote", "syn 2.0.87", ] @@ -3624,19 +3744,18 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libfuzzer-sys" -version = "0.4.7" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96cfd5557eb82f2b83fed4955246c988d331975a002961b07c81584d107e7f7" +checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa" dependencies = [ "arbitrary", "cc", - "once_cell", ] [[package]] @@ -3683,6 +3802,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" + [[package]] name = "litrs" version = "0.4.1" @@ -3726,7 +3851,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.0", + "hashbrown 0.15.1", ] [[package]] @@ -5362,9 +5487,9 @@ dependencies = [ [[package]] name = "psm" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa37f80ca58604976033fae9515a8a2989fc13797d953f7c04fb8fa36a11f205" +checksum = "200b9ff220857e53e184257720a14553b2f4aa02577d2ed9842d45d4b9654810" dependencies = [ "cc", ] @@ -5503,7 +5628,7 @@ dependencies = [ "bitflags 2.6.0", "cassowary", "compact_str", - "crossterm 0.28.1", + "crossterm", "indoc", "instability", "itertools 0.13.0", @@ -5690,7 +5815,7 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.8", + "regex-automata 0.4.9", "regex-syntax 0.8.5", ] @@ -5705,9 +5830,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -6081,9 +6206,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.4" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8" +checksum = "66b202022bb57c049555430e11fc22fea12909276a80a4c3d368da36ac1d88ed" dependencies = [ "sdd", ] @@ -6133,9 +6258,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -6823,18 +6948,18 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.67" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3c6efbfc763e64eb85c11c25320f0737cb7364c4b6336db90aa9ebe27a0bbd" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.67" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b607164372e89797d78b8e23a6d67d5d1038c1c65efd52e1389ef8b77caba2a6" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", @@ -6904,6 +7029,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -7108,7 +7243,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4126466aafe1c518cb5c23979c286903cb1d1ff1bc3b76891254a243a0ed1e15" dependencies = [ "anyhow", - "clap 4.5.20", + "clap 4.5.21", "derive_more 0.99.18", "env_logger", "log", @@ -7235,12 +7370,6 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" - [[package]] name = "unicode-ident" version = "1.0.13" @@ -7343,9 +7472,9 @@ dependencies = [ [[package]] name = "url" -version = "2.5.2" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", @@ -7358,6 +7487,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -7945,6 +8086,18 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "wsl" version = "0.1.0" @@ -7981,9 +8134,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.22" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af4e2e2f7cba5a093896c1e150fbfe177d1883e7448200efb81d40b9d339ef26" +checksum = "af310deaae937e48a26602b730250b4949e125f468f11e6990be3e5304ddd96f" [[package]] name = "xtask" @@ -8093,6 +8246,28 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index eb52910438..ac443a4659 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,8 +117,8 @@ bincode = { version = "2.0.0-rc.3", features = [ # # The following packages disable the "std" feature for no_std compatibility # -derive-new = { version = "0.7.0", default-features = false } cfg-if = "1.0.0" +derive-new = { version = "0.7.0", default-features = false } blas-src = { version = "0.10.0", default-features = false } half = { version = "2.4.1", features = [ @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.2", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99df09381aac4e2cd1354a744ec99bbd364bc9ea" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8f4861ebe577065e2209ee94724c05b514e1b860" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8f4861ebe577065e2209ee94724c05b514e1b860" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } @@ -166,4 +166,4 @@ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features tracel-xtask = { version = "~1.1" } [profile.dev] -debug = 0 # Speed up compilation time and not necessary. +debug = 0 # Speed up compilation time and not necessary. diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index ed28c2bee9..ae74b4129d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -3,11 +3,16 @@ use burn_tensor::{ Shape, }; use cmma::{Matrix, MatrixIdent, MatrixLayout}; -use cubecl::{cube, prelude::*, Compiler, CubeCount, CubeDim, Feature}; +use cubecl::{ + cube, + ir::{Elem, FloatKind}, + prelude::*, + Compiler, CubeCount, CubeDim, Feature, +}; use half::f16; use crate::{ - kernel::{into_contiguous, slice}, + kernel::{into_contiguous, slice, slice_assign}, ops::{ numeric::{empty_device, zeros_device}, permute, @@ -30,9 +35,17 @@ pub fn conv2d_implicit_gemm( bias: Option>, options: ConvOptions<2>, ) -> JitTensor { + let is_tf32 = F::as_elem() == Elem::Float(FloatKind::F32) + && input + .client + .properties() + .feature_enabled(Feature::Type(Elem::Float(FloatKind::TF32))); + + let k_target = if is_tf32 { 8 } else { 16 }; + let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); - let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w); + let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w, k_target); let padded_out_channels = out_channels.div_ceil(16) * 16; let out_h = calculate_conv_output_size( @@ -66,12 +79,13 @@ pub fn conv2d_implicit_gemm( "Requirements for implicit GEMM not met: - CMMA must be available - `groups` must be 1 +- subcube size must be non-variable (might not hold on Intel) " ); } let input = into_contiguous(permute(input, &[0, 2, 3, 1])); - let weight = into_contiguous(permute(weight, &[0, 2, 3, 1])); + let weight = into_contiguous(permute(weight, &[2, 3, 1, 0])); let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]); let out = empty_device(input.client.clone(), input.device.clone(), out_shape); @@ -81,10 +95,10 @@ pub fn conv2d_implicit_gemm( let gemm_n = padded_out_channels as u32; let gemm_k = (pad_in_channels * pad_kh * pad_kw) as u32; - let slice_size = pad_kh * pad_kw * pad_in_channels; - let (cmma_m, cmma_n, cmma_k) = - find_cmma_size::(&input.client, gemm_m, gemm_k, gemm_n).unwrap(); + find_cmma_size::(&input.client, gemm_m, gemm_k, gemm_n).unwrap(); + + let slice_size = pad_kh * pad_kw * pad_in_channels; let cube_dim_x = 128; let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2); @@ -92,7 +106,8 @@ pub fn conv2d_implicit_gemm( let input_tile_size = cmma_m * cmma_k; let weight_tile_size = cmma_k * cmma_n; - let warp_size = 32; + let topology = input.client.properties().hardware_properties(); + let warp_size = topology.plane_size_min; let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size; let supported_vecs = R::supported_line_sizes(); @@ -102,12 +117,19 @@ pub fn conv2d_implicit_gemm( let weight_elems_per_thread = weight_tile_size / warp_size; let weight_vectorization = - find_common_vec(in_channels, weight_elems_per_thread, supported_vecs); + find_common_vec(out_channels, weight_elems_per_thread, supported_vecs); let has_bias = bias.is_some(); - let bias = bias.unwrap_or_else(|| { - zeros_device(input.client.clone(), input.device.clone(), Shape::new([1])) - }); + let bias = match bias { + Some(bias) if out_channels == padded_out_channels => bias, + Some(bias) => { + let shape = Shape::new([padded_out_channels]); + let padded_bias = zeros_device(bias.client.clone(), bias.device.clone(), shape); + #[allow(clippy::single_range_in_vec_init)] + slice_assign(padded_bias, &[0..out_channels], bias) + } + None => empty_device(input.client.clone(), input.device.clone(), Shape::new([1])), + }; let settings = GemmSettings { cmma_m, @@ -138,7 +160,12 @@ pub fn conv2d_implicit_gemm( let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1); - implicit_gemm_kernel::launch::( + let launch = match is_tf32 { + false => implicit_gemm_kernel::launch::, + true => implicit_gemm_kernel::launch::, + }; + + launch( &input.client, cube_count, cube_dim, @@ -303,7 +330,7 @@ fn implicit_gemm_kernel( let mut out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size); if conv_settings.aligned || pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n { - execute_gemm( + execute_gemm::( input, weight, bias, @@ -396,7 +423,7 @@ fn make_matrices( cmma_m, cmma_n, cmma_k, - MatrixLayout::ColMajor, + MatrixLayout::RowMajor, ) }, acc, @@ -422,8 +449,7 @@ fn execute_gemm( let matrices = make_matrices::(g_settings, has_bias); if has_bias { - let n = UNIT_POS_Y * cmma_n + pos.global_n; - let bias_tile = bias.slice(n, n + cmma_n); + let bias_tile = bias.slice(pos.global_n, pos.global_n + cmma_n); cmma::load_with_layout(&matrices.acc, &bias_tile, 0, MatrixLayout::RowMajor); } @@ -440,8 +466,8 @@ fn execute_gemm( load_weight_tile(weight, weight_tile, dims, pos, k, g_settings, k_settings); // Run CMMA + cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_n); cmma::load(&matrices.a, &input_tile.to_slice(), cmma_k); - cmma::load(&matrices.b, &weight_tile.to_slice(), cmma_k); cmma::execute::(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc); } @@ -573,29 +599,31 @@ fn load_weight_tile( let cmma_filter_tile_size = cmma_k * cmma_n; let elems_per_thread = cmma_filter_tile_size / warp_size; let start = pos.intra_warp_unit_idx * elems_per_thread; - let abs_slice_col = pos.global_n + (start / cmma_k); // Row of the matrix the slice is on - let n_in_bounds = !check_n || abs_slice_col < weight.shape(0); - let col_idx = abs_slice_col * weight.stride(0); + let global_k = start / cmma_n + k; + + let (k_idx, k_in_bounds) = if check_k { + let channel = global_k % dims.pad_channels; + let kernel_x = global_k / dims.pad_channels % dims.pad_kw; + let kernel_y = global_k / (dims.pad_channels * dims.pad_kw); + let k_in_bounds = + !check_k || (channel < weight.shape(2) && kernel_x < kernel_w && kernel_y < kernel_h); + let idx = + kernel_y * weight.stride(0) + kernel_x * weight.stride(1) + channel * weight.stride(2); + (idx, k_in_bounds) + } else { + (global_k * weight.stride(2), true) + }; #[unroll] for n in range_stepped(0, elems_per_thread, vec) { let n = n + start; - // Compute where in the slice we are starting - let rel_slice_row = n % cmma_k; // Relative row (0 - 15) - let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on - - let (idx, k_in_bounds) = if check_k { - let channel = abs_slice_row % dims.pad_channels; - let kernel_x = abs_slice_row / dims.pad_channels % dims.pad_kw; - let kernel_y = abs_slice_row / (dims.pad_channels * dims.pad_kw); - let k_in_bounds = !check_k - || (channel < weight.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h); - let idx = col_idx + kernel_y * weight.stride(1) + kernel_x * weight.stride(2) + channel; - (idx, k_in_bounds) - } else { - (col_idx + abs_slice_row, true) - }; + + let global_n = (n % cmma_n) + pos.global_n; + let n_in_bounds = !check_n || global_n < weight.shape(3); + + let idx = k_idx + global_n; + let value = FMat::cast_from(weight[idx / vec]); let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0)); @@ -617,7 +645,18 @@ pub(crate) fn can_do_implicit_gemm( out_w: usize, client: &ComputeClient, ) -> bool { - let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]); + let cmma_k = match ( + E::as_elem(), + client + .properties() + .feature_enabled(Feature::Type(tf32::as_elem())), + ) { + (Elem::Float(FloatKind::F32), true) => 8, + _ => 16, + }; + + let (in_channels, kernel_h, kernel_w) = + padded_k(in_channels, kernel_size[0], kernel_size[1], cmma_k); let batch_size = padded_batch_size(batch_size, out_h, out_w); let out_channels = out_channels.div_ceil(16) * 16; @@ -625,21 +664,27 @@ pub(crate) fn can_do_implicit_gemm( let gemm_n = out_channels; let gemm_k = in_channels * kernel_h * kernel_w; - let size = find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32); + let size = find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32); if let Some((cmma_m, cmma_k, cmma_n)) = size { let warps_per_cube = 8; let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); + let topology = client.properties().hardware_properties(); + let not_intel = topology.plane_size_min >= 32; - ::max_shared_memory_size() >= smem_size && groups == 1 + ::max_shared_memory_size() >= smem_size && groups == 1 && not_intel } else { false } } -fn padded_k(in_channels: usize, kernel_h: usize, kernel_w: usize) -> (usize, usize, usize) { - let target = 16; +fn padded_k( + in_channels: usize, + kernel_h: usize, + kernel_w: usize, + target: usize, +) -> (usize, usize, usize) { if in_channels * kernel_h * kernel_w % target == 0 { return (in_channels, kernel_h, kernel_w); } @@ -659,7 +704,7 @@ fn padded_k(in_channels: usize, kernel_h: usize, kernel_w: usize) -> (usize, usi fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { let out_size = out_h * out_w; - let target = if out_size % 2 == 0 { + let target = if out_size.is_power_of_two() || out_size % 16 == 0 { (16usize).div_ceil(out_size) } else { 16 @@ -667,13 +712,13 @@ fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { batch_size.div_ceil(target) * target } -fn find_cmma_size( +fn find_cmma_size( client: &ComputeClient, gemm_m: u32, gemm_k: u32, gemm_n: u32, ) -> Option<(u32, u32, u32)> { - supported_cmma_sizes::(client) + supported_cmma_sizes::(client) .into_iter() .find(|(m, k, n)| { gemm_m % *m as u32 == 0 && gemm_k % *k as u32 == 0 && gemm_n % *n as u32 == 0 @@ -681,19 +726,27 @@ fn find_cmma_size( .map(|(m, k, n)| (m as u32, n as u32, k as u32)) } -fn supported_cmma_sizes( +fn supported_cmma_sizes( client: &ComputeClient, ) -> Vec<(u8, u8, u8)> { - let requested_sizes = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]; + let (requested_sizes, matrix_elem) = match ( + F::as_elem(), + client + .properties() + .feature_enabled(Feature::Type(tf32::as_elem())), + ) { + (Elem::Float(FloatKind::F32), true) => (vec![(16, 8, 16)], tf32::as_elem()), + _ => (vec![(16, 16, 16), (32, 16, 8), (8, 16, 32)], f16::as_elem()), + }; requested_sizes .iter() .copied() .filter(|(m, k, n)| { client.properties().feature_enabled(Feature::Cmma { - a: F::as_elem(), - b: F::as_elem(), - c: FAcc::as_elem(), + a: matrix_elem, + b: matrix_elem, + c: F::as_elem(), m: *m, k: *k, n: *n, diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index c88a95e968..2e554bf647 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; -#[cube(launch_unchecked)] +#[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { return; @@ -128,15 +128,13 @@ pub(crate) fn interpolate_bicubic_launch( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - unsafe { - interpolate_bicubic_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), - ) - }; + interpolate_bicubic_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(1), + ); output } diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 840bb13954..0314c77544 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; -#[cube(launch_unchecked)] +#[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { return; @@ -17,23 +17,25 @@ fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor< let denominator = F::cast_from(Max::max(output.shape(2) - 1, 1)); let factor = F::cast_from(y); - let frac = factor * numerator / denominator; + let frac = factor * (numerator / denominator); let v0 = Floor::floor(frac); let v1: F = Ceil::ceil(frac); let yw = frac - v0; let yw_ = F::new(1.0) - yw; + let y0_ok = v0 >= F::new(0.0); let y0 = u32::cast_from(v0); let y1 = u32::cast_from(v1); let numerator = F::cast_from(input.shape(3) - 1); let denominator = F::cast_from(Max::max(output.shape(3) - 1, 1)); let factor = F::cast_from(x); - let frac = factor * numerator / denominator; + let frac = factor * (numerator / denominator); let v0 = Floor::floor(frac); let v1: F = Ceil::ceil(frac); let xw = frac - v0; let xw_ = F::new(1.0) - xw; + let x0_ok = v0 >= F::new(0.0); let x0 = u32::cast_from(v0); let x1 = u32::cast_from(v1); @@ -47,10 +49,32 @@ fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor< let x0_stride = x0 * in_stride_x; let x1_stride = x1 * in_stride_x; - let p_a = input[index_base + y0_stride + x0_stride] * xw_ * yw_; - let p_b = input[index_base + y0_stride + x1_stride] * xw * yw_; - let p_c = input[index_base + y1_stride + x0_stride] * xw_ * yw; - let p_d = input[index_base + y1_stride + x1_stride] * xw * yw; + let height = input.shape(2); + let width = input.shape(3); + + let y1_ok = y1 < height; + let x1_ok = x1 < width; + + let p_a = select( + x0_ok && y0_ok, + input[index_base + y0_stride + x0_stride] * xw_ * yw_, + F::new(0.0), + ); + let p_b = select( + x1_ok && y0_ok, + input[index_base + y0_stride + x1_stride] * xw * yw_, + F::new(0.0), + ); + let p_c = select( + x0_ok && y1_ok, + input[index_base + y1_stride + x0_stride] * xw_ * yw, + F::new(0.0), + ); + let p_d = select( + x1_ok && y1_ok, + input[index_base + y1_stride + x1_stride] * xw * yw, + F::new(0.0), + ); output[ABSOLUTE_POS] = p_a + p_b + p_c + p_d; } @@ -62,15 +86,13 @@ pub(crate) fn interpolate_bilinear_launch( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - unsafe { - interpolate_bilinear_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), - ) - }; + interpolate_bilinear_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(1), + ); output } diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs index 0e004030f1..dbf8ef7a65 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs @@ -4,7 +4,7 @@ use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, J use super::base::ReduceDimShared; -#[cube(launch_unchecked)] +#[cube(launch)] pub fn reduce_dim_shared_kernel< RD: ReduceDimShared, EIn: JitElement, @@ -16,14 +16,9 @@ pub fn reduce_dim_shared_kernel< #[comptime] smem_size: u32, #[comptime] elems_per_thread: u32, #[comptime] divisible_shape: bool, - #[comptime] check_out: bool, ) { let reduce_group_id = CUBE_POS; - if check_out && reduce_group_id >= output.len() { - return; - } - let stride_reduce_dim_input = input.stride(dim); let shape_reduce_dim_input = input.shape(dim); @@ -105,22 +100,18 @@ pub fn reduce_dim_shared< f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - let check_out = (cube_count_x * cube_count_y) as usize != num_elems_output; - - unsafe { - reduce_dim_shared_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), - dim as u32, - cube_dim.num_elems(), - elems_per_thread, - divisible_shape, - check_out, - ) - }; + + reduce_dim_shared_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(1), + dim as u32, + cube_dim.num_elems(), + elems_per_thread, + divisible_shape, + ); output } diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index 1d5654fa9d..d6b7d15f7e 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -8,7 +8,7 @@ use crate::{ use super::base::ReduceDimSubcube; -#[cube(launch_unchecked)] +#[cube(launch)] pub fn reduce_dim_subcube_kernel< RD: ReduceDimSubcube, EIn: JitElement, @@ -17,17 +17,12 @@ pub fn reduce_dim_subcube_kernel< input: &Tensor, output: &mut Tensor, #[comptime] dim: u32, - #[comptime] smem_size: u32, + #[comptime] subcube_size: u32, #[comptime] elems_per_thread: u32, #[comptime] divisible_shape: bool, - #[comptime] check_out: bool, ) { let reduce_group_id = CUBE_POS; - if check_out && reduce_group_id >= output.len() { - return; - } - let stride_reduce_dim_input = input.stride(dim); let shape_reduce_dim_input = input.shape(dim); @@ -35,7 +30,7 @@ pub fn reduce_dim_subcube_kernel< let warp_id = UNIT_POS / PLANE_DIM; - let mut shared_memory = RD::init_shared(smem_size); + let mut shared_memory = RD::init_shared(subcube_size); let mut index_offset = 0; @@ -94,18 +89,22 @@ pub fn reduce_dim_subcube< input: JitTensor, dim: usize, ) -> JitTensor { - if !input.client.properties().feature_enabled(Feature::Plane) { + let topology = input.client.properties().hardware_properties(); + + if !input.client.properties().feature_enabled(Feature::Plane) + || topology.plane_size_min != topology.plane_size_max + { return reduce_dim_shared::(input, dim); } - let output = init_reduce_output::(&input, dim); + let subcube_size = topology.plane_size_min; - let warp_size = 32; // TODO: Add a method to client to query this + let output = init_reduce_output::(&input, dim); let num_elems_output = output.shape.num_elements(); let cube_dim = CubeDim { - x: warp_size, - y: warp_size, + x: subcube_size, + y: subcube_size, z: 1, }; let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); @@ -118,23 +117,18 @@ pub fn reduce_dim_subcube< f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - let check_out = (cube_count_x * cube_count_y) as usize != num_elems_output; - let smem_size = cube_dim.num_elems() / warp_size; - - unsafe { - reduce_dim_subcube_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg(1), - output.as_tensor_arg(1), - dim as u32, - smem_size, - elems_per_thread, - divisible_shape, - check_out, - ) - }; + + reduce_dim_subcube_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(1), + dim as u32, + subcube_size, + elems_per_thread, + divisible_shape, + ); output } diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs index f26d3cca69..cf8a51f16d 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/base.rs @@ -83,7 +83,12 @@ fn should_run< // Shared 1 => key.reduce_dim_length >= 16, // Subcube - 2 => op.input.client.properties().feature_enabled(Feature::Plane), + 2 => { + let props = op.input.client.properties(); + let hardware = props.hardware_properties(); + props.feature_enabled(Feature::Plane) + && hardware.plane_size_min == hardware.plane_size_max + } _ => true, } }