Skip to content

Commit

Permalink
Move conv autotune under feature flag (except key) (#2330)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Oct 2, 2024
1 parent 99d9fa2 commit dce5565
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 51 deletions.
6 changes: 3 additions & 3 deletions crates/burn-jit/src/kernel/conv/conv2d/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use burn_tensor::{
use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime};

#[cfg(feature = "autotune")]
use super::conv2d_autotune;
use super::{conv2d_autotune, conv_transpose2d_autotune};
use super::{
conv2d_direct, conv2d_im2col, conv_transpose2d_autotune, conv_transpose2d_col2im,
conv_transpose2d_direct, implicit_gemm::conv2d_implicit_gemm,
conv2d_direct, conv2d_im2col, conv_transpose2d_col2im, conv_transpose2d_direct,
implicit_gemm::conv2d_implicit_gemm,
};

/// The strategy to be used when launching a convolution kernel.
Expand Down
2 changes: 0 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ mod im2col;
mod implicit_gemm;
mod transpose_direct;

#[cfg(feature = "autotune")]
mod tune;

pub use base::*;
Expand All @@ -14,5 +13,4 @@ pub use direct::*;
pub use im2col::*;
pub use implicit_gemm::*;
pub use transpose_direct::*;
#[cfg(feature = "autotune")]
pub use tune::*;
25 changes: 2 additions & 23 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ use burn_tensor::{
use cubecl::{
tune,
tune::{local_tuner, tune_with, LocalTuner},
AutotuneKey,
};
use serde::{Deserialize, Serialize};

use crate::{
kernel::{
Expand All @@ -18,6 +16,8 @@ use crate::{
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
};

use super::Conv2dAutotuneKey;

/// Executes autotune on conv2d operations
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E>,
Expand All @@ -38,27 +38,6 @@ pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
)
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct Conv2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}

#[tune(
operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
create_key = create_key,
Expand Down
24 changes: 1 addition & 23 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape};
use cubecl::{
tune,
tune::{local_tuner, tune_with, LocalTuner},
AutotuneKey,
};
use serde::{Deserialize, Serialize};

use crate::{
kernel::{
Expand All @@ -15,27 +13,7 @@ use crate::{
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
};

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct ConvTranspose2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub padding_out: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}
use super::ConvTranspose2dAutotuneKey;

/// Executes autotune on conv2d operations
pub fn conv_transpose2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
Expand Down
45 changes: 45 additions & 0 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use cubecl::AutotuneKey;
use serde::{Deserialize, Serialize};

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct Conv2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of matmul versions
pub struct ConvTranspose2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub padding_out: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}
7 changes: 7 additions & 0 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#[cfg(feature = "autotune")]
mod conv2d;
#[cfg(feature = "autotune")]
mod conv_transpose2d;

#[cfg(feature = "autotune")]
pub use conv2d::*;
#[cfg(feature = "autotune")]
pub use conv_transpose2d::*;

mod key;
pub use key::*;

0 comments on commit dce5565

Please sign in to comment.