From 0b614b7024df8673592b1ff8e8c5e8dc6c738334 Mon Sep 17 00:00:00 2001
From: Genna Wingert <1058083+wingertge@users.noreply.github.com>
Date: Mon, 25 Nov 2024 18:20:01 +0100
Subject: [PATCH] [Optimization] Add custom NCHW to NHWC kernel for implicit GEMM (#2530)

---
 Cargo.lock | 24 +--
 Cargo.toml | 4 +-
 .../src/kernel/conv/conv2d/implicit_gemm.rs | 8 +-
 .../src/kernel/conv/conv2d/layout_swap.rs | 202 ++++++++++++++++++
 crates/burn-jit/src/kernel/conv/conv2d/mod.rs | 3 +-
 crates/burn-jit/src/kernel/conv/mod.rs | 2 +-
 crates/burn-jit/src/ops/base.rs | 16 +-
 crates/burn-jit/src/ops/mod.rs | 2 +-
 crates/burn-jit/src/template/base.rs | 5 +-
 crates/burn-jit/src/tests/conv2d.rs | 27 ++-
 crates/burn-jit/src/tests/mod.rs | 1 +
 11 files changed, 270 insertions(+), 24 deletions(-)
 create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs

diff --git a/Cargo.lock b/Cargo.lock
index 65c9af9137..b8e1591c3d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1544,7 +1544,7 @@ dependencies = [
 [[package]]
 name = "cubecl"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
@@ -1575,7 +1575,7 @@ dependencies = [
 [[package]]
 name = "cubecl-common"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "derive-new 0.6.0",
  "embassy-futures",
@@ -1592,7 +1592,7 @@ dependencies = [
 [[package]]
 name = "cubecl-core"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1610,7 +1610,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cpp"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1624,7 +1624,7 @@ dependencies = [
 [[package]]
 name = "cubecl-cuda"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1640,7 +1640,7 @@ dependencies = [
 [[package]]
 name = "cubecl-hip"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "bytemuck",
  "cubecl-common 0.4.0",
@@ -1665,7 +1665,7 @@ dependencies = [
 [[package]]
 name = "cubecl-linalg"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "bytemuck",
  "cubecl-core",
@@ -1676,7 +1676,7 @@ dependencies = [
 [[package]]
 name = "cubecl-macros"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "cubecl-common 0.4.0",
  "darling",
@@ -1691,7 +1691,7 @@ dependencies = [
 [[package]]
 name = "cubecl-opt"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "cubecl-common 0.4.0",
  "cubecl-core",
@@ -1728,7 +1728,7 @@ dependencies = [
 [[package]]
 name = "cubecl-runtime"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "async-channel",
  "async-lock",
@@ -1749,7 +1749,7 @@ dependencies = [
 [[package]]
 name = "cubecl-spirv"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "cubecl-common 0.4.0",
  "cubecl-core",
@@ -1763,7 +1763,7 @@ dependencies = [
 [[package]]
 name = "cubecl-wgpu"
 version = "0.4.0"
-source = "git+https://github.com/tracel-ai/cubecl?rev=d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8#d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8"
+source = "git+https://github.com/tracel-ai/cubecl?rev=2c09d4dd1ecb9f474e524dc47b05599edb7049e7#2c09d4dd1ecb9f474e524dc47b05599edb7049e7"
 dependencies = [
  "ash",
  "async-channel",
diff --git a/Cargo.toml b/Cargo.toml
index b01e07da2d..2e8c3bec66 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "d8ebb40bc8c4900b0f1ee738b1dd0022b8d340e8" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2c09d4dd1ecb9f474e524dc47b05599edb7049e7" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
index fd5d3857f2..49a639ef43 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -21,6 +21,8 @@ use crate::{
     FloatElement, IntElement, JitRuntime,
 };
 
+use super::nchw_to_nhwc;
+
 /// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available.
 ///
 /// * `input` - The input feature map
@@ -84,7 +86,11 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
         );
     }
 
-    let input = into_contiguous(permute(input, &[0, 2, 3, 1]));
+    // If input is contiguous NCHW, use custom transpose kernel
+    let input = match input.is_contiguous() {
+        true => nchw_to_nhwc::<R, F>(input),
+        false => into_contiguous(permute(input, &[0, 2, 3, 1])),
+    };
     let weight = into_contiguous(permute(weight, &[2, 3, 1, 0]));
 
     let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]);
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs
new file mode 100644
index 0000000000..a998bea86d
--- /dev/null
+++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs
@@ -0,0 +1,202 @@
+use burn_tensor::Shape;
+use cubecl::{prelude::*, CubeCount, CubeDim};
+
+use crate::{
+    ops::{max_vectorization, numeric::empty_device},
+    tensor::JitTensor,
+    JitElement, JitRuntime,
+};
+
+/// Efficiently transpose an NCHW tensor to NHWC for use in kernels that prefer NHWC for performance.
+/// Faster than `into_contiguous`, but specialized only for this specific permutation.
+///
+/// # Arguments
+///
+/// * `input` - The input in NCHW format
+///
+/// # Output
+///
+/// The input in NHWC format
+///
+pub fn nchw_to_nhwc<R: JitRuntime, E: JitElement>(input: JitTensor<R>) -> JitTensor<R> {
+    let tiles_per_block = 8;
+    let warp_size = 32;
+    let tile_dim = 16;
+
+    let [batch_size, in_c, h, w] = input.shape.dims();
+    let hw = h * w;
+
+    let out_shape = Shape::new([batch_size, h, w, in_c]);
+    let out = empty_device::<R, E>(input.client.clone(), input.device.clone(), out_shape);
+
+    let tiles_channel = in_c.div_ceil(tile_dim) as u32;
+    let tiles_hw = hw.div_ceil(tile_dim) as u32;
+
+    let block_tiles_y = Ord::min(tiles_channel.next_power_of_two(), tiles_per_block);
+    let block_tiles_x = Ord::min(tiles_per_block / block_tiles_y, tiles_hw);
+
+    let cube_count_y = tiles_channel.div_ceil(block_tiles_y);
+    let cube_count_x = tiles_hw.div_ceil(block_tiles_x);
+    let cube_count_z = batch_size as u32;
+
+    let config = ComptimeConfig {
+        tiles_x: block_tiles_x,
+        warps_per_cube: tiles_per_block,
+        tile_dim: tile_dim as u32,
+        warp_size,
+        num_banks: 32,
+    };
+
+    let cube_dim = CubeDim {
+        x: block_tiles_x * warp_size,
+        y: block_tiles_y,
+        z: 1,
+    };
+    let cube_count = CubeCount::Static(cube_count_x, cube_count_y, cube_count_z);
+
+    let in_vec = max_vectorization(&input);
+    let out_vec = R::supported_line_sizes()
+        .iter()
+        .copied()
+        .find(|vec| in_c % *vec as usize == 0)
+        .unwrap_or(1);
+
+    unsafe {
+        nchw_to_nhwc_kernel::launch_unchecked::<E, R>(
+            &input.client,
+            cube_count,
+            cube_dim,
+            input.as_tensor_arg::<E>(in_vec),
+            out.as_tensor_arg::<E>(out_vec),
+            ScalarArg::new(hw as u32),
+            config,
+        )
+    };
+
+    out
+}
+
+#[derive(Clone, PartialEq, Eq, Hash, Debug)]
+struct ComptimeConfig {
+    tiles_x: u32,
+    warps_per_cube: u32,
+    tile_dim: u32,
+    warp_size: u32,
+    num_banks: i32,
+}
+
+#[cube(launch_unchecked)]
+fn nchw_to_nhwc_kernel<E: Numeric>(
+    input: &Tensor<Line<E>>,
+    out: &mut Tensor<Line<E>>,
+    shape_hw: u32,
+    #[comptime] config: ComptimeConfig,
+) {
+    let ComptimeConfig {
+        tiles_x,
+        warps_per_cube,
+        tile_dim,
+        warp_size,
+        num_banks,
+    } = config;
+
+    let tile_elems = tile_dim * tile_dim;
+
+    let unit_pos = UNIT_POS;
+    let intra_warp_unit_idx = unit_pos % 32;
+    let batch = CUBE_POS_Z;
+
+    if batch >= input.shape(0) {
+        return;
+    }
+
+    let batch_offset = batch * input.stride(0);
+
+    let warp_id = plane_broadcast(unit_pos / 32, 0);
+    let warp_id_x = warp_id / CUBE_DIM_Y;
+
+    let tile_x = CUBE_POS_X * tiles_x + warp_id_x;
+    let tile_y = ABSOLUTE_POS_Y;
+
+    let mut shared = SharedMemory::<E>::new(warps_per_cube * tile_elems);
+    let shared_start = warp_id * tile_elems;
+
+    let base_hw = tile_x * tile_dim;
+    let base_c = tile_y * tile_dim;
+
+    let elems_per_unit = tile_elems / warp_size;
+    let unit_start = intra_warp_unit_idx * elems_per_unit;
+
+    let mat_hw_start = unit_start % tile_dim;
+
+    let mat_c = unit_start / tile_dim;
+    let channel = base_c + mat_c;
+    let offset = channel * input.stride(1) + batch_offset;
+
+    let input_vec = input.line_size();
+    let out_vec = out.line_size();
+    let in_max = input.buffer_len() - 1;
+
+    let channels = input.shape(1);
+
+    let mat_offset_base = shared_start + mat_c * tile_dim;
+
+    #[unroll]
+    for hw in range_stepped(0, elems_per_unit, input_vec) {
+        let mat_hw = mat_hw_start + hw;
+        let hw = base_hw + mat_hw;
+        let offset = Min::min((offset + hw) / input_vec, in_max);
+        let value = input[offset];
+
+        let mat_idx = mat_offset_base + mat_hw;
+
+        #[unroll]
+        for v in 0..input_vec {
+            let shared_idx = swizzle(mat_idx + v, num_banks);
+            shared[shared_idx] = value[v];
+        }
+    }
+
+    sync_units();
+
+    let mat_hw = mat_c;
+    let hw = base_hw + mat_hw;
+
+    if hw >= shape_hw {
+        return;
+    }
+
+    let mat_c_start = mat_hw_start;
+    let offset = hw * out.stride(2) + batch_offset;
+    let mat_base = shared_start + mat_hw;
+
+    #[unroll]
+    for ch in range_stepped(0, elems_per_unit, out_vec) {
+        let mat_c = mat_c_start + ch;
+        let ch = base_c + mat_c;
+
+        let mat_idx = mat_base + mat_c * tile_dim;
+        let mut value = Line::empty(out_vec);
+        let offset = (offset + ch) / out_vec;
+
+        #[unroll]
+        for v in 0..out_vec {
+            let shared_idx = swizzle(mat_idx + v * tile_dim, num_banks);
+            value[v] = shared[shared_idx];
+        }
+
+        if ch < channels {
+            out[offset] = value;
+        }
+    }
+}
+
+#[cube]
+pub fn swizzle(offset: u32, #[comptime] bank_count: i32) -> u32 {
+    let num_bits = comptime!(i32::BITS - bank_count.leading_zeros() - 1);
+    let bit_mask = (1 << num_bits) - 1;
+    let yyy_mask = bit_mask << (num_bits);
+    let mask_shift = num_bits;
+
+    offset ^ ((offset & yyy_mask) >> mask_shift)
+}
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs
index 07bf617656..13900acdc1 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs
@@ -3,8 +3,8 @@ mod col2im;
 mod direct;
 mod im2col;
 mod implicit_gemm;
+mod layout_swap;
 mod transpose_direct;
-
 mod tune;
 
 pub use base::*;
@@ -12,5 +12,6 @@ pub use col2im::*;
 pub use direct::*;
 pub use im2col::*;
 pub use implicit_gemm::*;
+pub use layout_swap::*;
 pub use transpose_direct::*;
 pub use tune::*;
diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs
index 5ed7aa570f..5d6794495f 100644
--- a/crates/burn-jit/src/kernel/conv/mod.rs
+++ b/crates/burn-jit/src/kernel/conv/mod.rs
@@ -10,4 +10,4 @@ pub(crate) use conv_transpose3d::*;
 pub(crate) use deform_conv2d::*;
 pub(crate) use deform_conv_transpose2d::*;
 
-pub use conv2d::{conv2d, conv_transpose2d, Conv2dStrategy, ConvTranspose2dStrategy};
+pub use conv2d::{conv2d, conv_transpose2d, nchw_to_nhwc, Conv2dStrategy, ConvTranspose2dStrategy};
diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs
index 9f6b8f2234..58e3b25c0c 100644
--- a/crates/burn-jit/src/ops/base.rs
+++ b/crates/burn-jit/src/ops/base.rs
@@ -1,6 +1,6 @@
 use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime};
 use burn_tensor::{Shape, TensorData};
-use cubecl::CubeElement;
+use cubecl::{tensor_vectorization_factor, CubeElement};
 
 pub(crate) fn from_data<R: JitRuntime>(
     data: TensorData,
@@ -20,8 +20,9 @@ pub(crate) async fn into_data<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>
     TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
 }
 
+/// Read data from a `JitTensor` synchronously
 #[allow(unused, reason = "useful for debugging kernels")]
-pub(crate) fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
+pub fn into_data_sync<R: JitRuntime, E: JitElement>(tensor: JitTensor<R>) -> TensorData {
     let tensor = kernel::into_contiguous(tensor);
 
     let bytes = tensor.client.read_one(tensor.handle.binding());
@@ -67,7 +68,7 @@ pub(crate) fn swap_dims<R: JitRuntime>(
     tensor
 }
 
-pub(crate) fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
+pub fn permute<R: JitRuntime>(mut tensor: JitTensor<R>, axes: &[usize]) -> JitTensor<R> {
     // remap strides
     tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect();
 
@@ -138,3 +139,12 @@ pub(crate) fn reshape<R: JitRuntime>(tensor: JitTensor<R>, shape: Shape) -> JitT
         tensor.dtype,
     )
 }
+
+pub(crate) fn max_vectorization<R: JitRuntime>(tensor: &JitTensor<R>) -> u8 {
+    tensor_vectorization_factor(
+        R::supported_line_sizes(),
+        &tensor.shape.dims,
+        &tensor.strides,
+        tensor.shape.num_dims() - 1,
+    )
+}
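A note on the `swizzle` helper added in layout_swap.rs above: each warp stages a 16x16 tile in shared memory, and reading that tile column-wise (stride 16) would keep landing on the same shared-memory banks, so the kernel XORs the bank-selecting low bits of every shared index with the next group of bits. The following standalone sketch of that index transform is plain Rust written for this note, not part of the patch; it assumes 32 banks (matching the kernel's `num_banks: 32`) and renames `yyy_mask` to `row_mask`.

// Sketch of the XOR swizzle: the low bits of `offset` pick the bank, and XORing
// them with the row bits rotates each tile row to a different starting bank.
fn swizzle(offset: u32, bank_count: i32) -> u32 {
    let num_bits = i32::BITS - bank_count.leading_zeros() - 1; // log2(bank_count)
    let bit_mask = (1u32 << num_bits) - 1;
    let row_mask = bit_mask << num_bits;
    offset ^ ((offset & row_mask) >> num_bits)
}

fn main() {
    let tile_dim = 16u32;
    // Walking down column 0 of a 16x16 tile: raw offsets alternate between
    // banks 0 and 16 only; swizzled offsets spread across more banks.
    for row in 0..4u32 {
        let linear = row * tile_dim;
        println!(
            "row {row}: bank {:2} -> {:2}",
            linear % 32,
            swizzle(linear, 32) % 32
        );
    }
}

With the kernel's constants this maps the first four rows of a column to banks 0, 16, 1, 17 instead of 0, 16, 0, 16, which is the bank-conflict avoidance the tiled transpose relies on when it writes the tile out column-wise.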
diff --git a/crates/burn-jit/src/ops/mod.rs b/crates/burn-jit/src/ops/mod.rs
index d1e82684c7..2e23e3835d 100644
--- a/crates/burn-jit/src/ops/mod.rs
+++ b/crates/burn-jit/src/ops/mod.rs
@@ -6,7 +6,7 @@ mod module_ops;
 mod qtensor;
 mod transaction;
 
-mod base;
+pub(crate) mod base;
 pub(crate) use base::*;
 
 pub(crate) mod numeric;
diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs
index 61c181cbb4..9ff5f28247 100644
--- a/crates/burn-jit/src/template/base.rs
+++ b/crates/burn-jit/src/template/base.rs
@@ -19,12 +19,13 @@ pub struct SourceKernel<K> {
 }
 
 impl<C: Compiler, K: KernelSource> CubeTask<C> for SourceKernel<K> {
-    fn compile(&self, _mode: ExecutionMode) -> CompiledKernel<C> {
+    fn compile(&self, _options: &C::CompilationOptions, _mode: ExecutionMode) -> CompiledKernel<C> {
         let source_template = self.kernel_source.source();
         let source = source_template.complete();
 
         CompiledKernel {
-            name: Some(core::any::type_name::<K>()),
+            entrypoint_name: "kernel".to_string(),
+            debug_name: Some(core::any::type_name::<K>()),
             source,
             cube_dim: self.cube_dim,
             shared_mem_bytes: 0,
diff --git a/crates/burn-jit/src/tests/conv2d.rs b/crates/burn-jit/src/tests/conv2d.rs
index e64d337e96..f93adffe8f 100644
--- a/crates/burn-jit/src/tests/conv2d.rs
+++ b/crates/burn-jit/src/tests/conv2d.rs
@@ -1,7 +1,11 @@
 #[burn_tensor_testgen::testgen(conv2d)]
 mod tests {
     use super::*;
-    use burn_tensor::{module, Distribution, Tensor};
+    use burn_jit::{
+        kernel::{conv::nchw_to_nhwc, into_contiguous},
+        tests::into_data_sync,
+    };
+    use burn_tensor::{backend::Backend, module, Distribution, Tensor};
 
     #[test]
     fn conv2d_should_match_reference_backend() {
@@ -50,4 +54,25 @@ mod tests {
         .into_data()
         .assert_approx_eq(&output_ref.into_data(), 1);
     }
+
+    #[test]
+    fn nchw_to_nhwc_should_match_into_contiguous() {
+        let test_device = Default::default();
+        let input =
+            Tensor::<TestBackend, 4>::random([4, 72, 53, 56], Distribution::Default, &test_device);
+
+        type Float = <TestBackend as Backend>::FloatElem;
+
+        let output = nchw_to_nhwc::<TestRuntime, Float>(input.clone().into_primitive().tensor());
+        let output_ref = into_contiguous(
+            input
+                .clone()
+                .permute([0, 2, 3, 1])
+                .into_primitive()
+                .tensor(),
+        );
+
+        into_data_sync::<TestRuntime, Float>(output)
+            .assert_approx_eq(&into_data_sync::<TestRuntime, Float>(output_ref), 1);
+    }
 }
diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs
index 3528f94872..b1ee4ce26d 100644
--- a/crates/burn-jit/src/tests/mod.rs
+++ b/crates/burn-jit/src/tests/mod.rs
@@ -28,6 +28,7 @@ mod unary;
 mod uniform;
 
 // Re-export dependencies for tests
+pub use crate::ops::base::into_data_sync;
 pub use burn_autodiff;
 pub use burn_fusion;
 pub use burn_ndarray;
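For reference, the data movement that `nchw_to_nhwc` implements, and that the new test checks against `permute([0, 2, 3, 1])` followed by `into_contiguous`, is the plain layout permutation out[n][y][x][c] = in[n][c][y][x]. The sketch below is a minimal CPU reference written for illustration under that assumption; `nchw_to_nhwc_ref` is a hypothetical helper, not part of the patch or of burn-jit.

// CPU reference for the NCHW -> NHWC permutation, using the row-major strides
// of both layouts. The GPU kernel above produces the same values, but moves
// them 16x16 tiles at a time through swizzled shared memory so that both the
// NCHW reads and the NHWC writes stay coalesced.
fn nchw_to_nhwc_ref(input: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    assert_eq!(input.len(), n * c * h * w);
    let mut out = vec![0.0; n * c * h * w];
    for b in 0..n {
        for ch in 0..c {
            for y in 0..h {
                for x in 0..w {
                    let src = ((b * c + ch) * h + y) * w + x; // NCHW index
                    let dst = ((b * h + y) * w + x) * c + ch; // NHWC index
                    out[dst] = input[src];
                }
            }
        }
    }
    out
}

fn main() {
    // 1 batch, 2 channels, 2x2 spatial: channel 0 holds 0..4, channel 1 holds 4..8.
    let input: Vec<f32> = (0..8).map(|i| i as f32).collect();
    let nhwc = nchw_to_nhwc_ref(&input, 1, 2, 2, 2);
    // In NHWC the two channel values of pixel (0, 0) are adjacent.
    assert_eq!(&nhwc[0..2], &[0.0, 4.0]);
    println!("{nhwc:?}");
}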