Skip to content

Commit

Permalink
Use smaller cube for gemm_n == 16
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 24, 2024
1 parent 544fcf3 commit 0771f1f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ macro_rules! bench_on_backend {
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;
- bench::<LibTorch>(&device, feature_name, url, token);
+ bench::<LibTorch<half::f16>>(&device, feature_name, url, token);
}

#[cfg(feature = "tch-cpu")]
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
let slice_size = kernel_h * kernel_w * in_channels;

let cube_dim_x = 128;
- let cube_dim_y = 2;
+ let cube_dim_y = Ord::min(gemm_n.div_ceil(16), 2);

let cmma_m = 16;
let cmma_n = 16;
Expand Down

0 comments on commit 0771f1f

Please sign in to comment.