Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 24, 2024
1 parent 423f320 commit 544fcf3
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
);
}

let weight = into_contiguous(permute(weight, [0, 2, 3, 1]));
// channel last is more efficient even with the extra into_contiguous kernel
let input = into_contiguous(permute(input, [0, 2, 3, 1]));
let weight = into_contiguous(permute(weight, [0, 2, 3, 1]));

let out_shape = Shape::new([batch_size, out_h, out_w, out_channels]);
let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape);
Expand Down Expand Up @@ -483,9 +484,9 @@ fn load_input_tile<F: Float, FMat: Float>(

if x >= 0 && x < width && y >= 0 && y < height {
let idx = batch * input.stride(0)
+ channel * input.stride(3)
+ y as u32 * input.stride(1)
+ x as u32 * input.stride(2);
+ x as u32 * input.stride(2)
+ channel * input.stride(3);
FMat::cast_from(input[idx / vec])
} else {
FMat::vectorized(0.0, vec)
Expand All @@ -494,9 +495,9 @@ fn load_input_tile<F: Float, FMat: Float>(
let y = out_y * args.stride_h + kernel_y * args.dilation_h;
let x = out_x * args.stride_w + kernel_x * args.dilation_w;
let idx = batch * input.stride(0)
+ channel * input.stride(3)
+ y * input.stride(1)
+ x * input.stride(2);
+ x * input.stride(2)
+ channel * input.stride(3);

FMat::cast_from(input[idx / vec])
};
Expand Down

0 comments on commit 544fcf3

Please sign in to comment.