Fusion wgpu compilation cache #1069
Codecov Report
Coverage diff, main vs. #1069:
Coverage: 85.55% -> 85.66% (+0.11%)
Files: 508 -> 509 (+1)
Lines: 53910 -> 54126 (+216)
Hits: 46122 -> 46368 (+246)
Misses: 7788 -> 7758 (-30)
Awesome, please look at my comments
burn-wgpu/src/codegen/kernel.rs
Outdated
impl ElemWiseKernelCodegen<CompilationPhase> {
    /// Compile the kernel into a [compute shader](ComputeShader).
    pub fn compile(self) -> ComputeShader {
        let mut inputs = Vec::with_capacity(self.input_bindings.len());
inputs and outputs can directly be self.input_bindings and self.output_bindings
burn-wgpu/src/codegen/shader.rs
Outdated
pub enum Visibility {
    Read,
    ReadWrite,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
    F32,
    #[allow(dead_code)]
check
burn-wgpu/src/fusion/cache.rs
Outdated
    items: HashSet<String>,
}

pub enum CachedComputeShader {
Cachable
burn-wgpu/src/fusion/cache.rs
Outdated
fn source(&self) -> SourceTemplate {
    match self {
        CachedComputeShader::Cached(_) => {
            panic!("NoSource compute shader should only be used by a higher level cache.")
NoSource
    inputs,
    outputs,
    locals,
    operators: self.operators.clone(),
    scalars_f32: self.scalars_f32,
    device: self.device.clone(),
    cache: KernelCache::default(),
Is it useless?
@@ -1,21 +0,0 @@
@group(0)
Don't forget to delete the safe tanh file
Perfect
Changes
Add a cache step so that each fused kernel is compiled only once; after that, the kernel ID is used without any compilation overhead.
Refactor the way we create fused kernels into a DSL, so that we can use it to create WGSL shaders easily, not just when using fusion. I migrated the unary kernels to the new approach; follow-up PRs will migrate more kernels and remove a lot of code.
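To illustrate the compile-once idea from the first change, here is a minimal sketch of a cache keyed by kernel ID. It is not the code from this PR: `CompiledKernel`, `CompileOnceCache`, `get_or_compile` and `compile_count` are hypothetical names, and the "compilation" is just a closure returning a source string.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a compiled compute shader.
#[derive(Debug, Clone, PartialEq)]
pub struct CompiledKernel {
    pub source: String,
}

/// Hypothetical cache that maps a kernel ID to its compiled form.
#[derive(Default)]
pub struct CompileOnceCache {
    compiled: HashMap<String, CompiledKernel>,
    compile_count: usize,
}

impl CompileOnceCache {
    /// Return the kernel registered under `id`. The `compile` closure runs
    /// only when the ID has not been seen before.
    pub fn get_or_compile<F>(&mut self, id: &str, compile: F) -> &CompiledKernel
    where
        F: FnOnce() -> String,
    {
        if !self.compiled.contains_key(id) {
            self.compile_count += 1;
            let source = compile();
            self.compiled
                .insert(id.to_string(), CompiledKernel { source });
        }
        &self.compiled[id]
    }

    /// Number of times a compilation closure has actually been executed.
    pub fn compile_count(&self) -> usize {
        self.compile_count
    }
}
```

With this shape, a second call to `get_or_compile` with the same ID never invokes its closure, so the only cost on a cache hit is the map lookup.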
Drop
I removed the clamp max and min functions in Fusion as well as in WGPU, since we need to reduce the set of fusable operations to basic operations. In short, clamp will be fused from those basic operations, so it doesn't need dedicated kernels. Keeping the dedicated functions would break the fusion stream, which hurts performance and adds complexity.
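As an illustration of what "reducing to basic operations" can look like, the sketch below builds clamp min and clamp max from a comparison and a masked fill. The PR does not say which basic operations are used, so this decomposition is an assumption, and `lower_elem`, `greater_elem`, `mask_fill`, `clamp_min` and `clamp_max` are hypothetical CPU helpers rather than the library's API.

```rust
/// Hypothetical basic op: mask of elements strictly below `threshold`.
pub fn lower_elem(input: &[f32], threshold: f32) -> Vec<bool> {
    input.iter().map(|x| *x < threshold).collect()
}

/// Hypothetical basic op: mask of elements strictly above `threshold`.
pub fn greater_elem(input: &[f32], threshold: f32) -> Vec<bool> {
    input.iter().map(|x| *x > threshold).collect()
}

/// Hypothetical basic op: replace masked elements with `value`.
pub fn mask_fill(input: &[f32], mask: &[bool], value: f32) -> Vec<f32> {
    input
        .iter()
        .zip(mask.iter())
        .map(|(x, m)| if *m { value } else { *x })
        .collect()
}

/// Clamp min expressed only through the basic ops above.
pub fn clamp_min(input: &[f32], min: f32) -> Vec<f32> {
    mask_fill(input, &lower_elem(input, min), min)
}

/// Clamp max expressed only through the basic ops above.
pub fn clamp_max(input: &[f32], max: f32) -> Vec<f32> {
    mask_fill(input, &greater_elem(input, max), max)
}
```

Because each step is element-wise, a fusion backend could in principle merge the comparison and the fill into a single kernel, which is the reason given above for dropping the dedicated clamp kernels.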