Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fusion wgpu compilation cache #1069

Merged
merged 14 commits into from
Dec 18, 2023
Merged

Fusion wgpu compilation cache #1069

merged 14 commits into from
Dec 18, 2023

Conversation

nathanielsimard
Copy link
Member

@nathanielsimard nathanielsimard commented Dec 14, 2023

Changes

Add a cache step where the compilation of fused kernels are only executed once, then the kernel ID is used without any overhead.

Refactor the way we create kernels with fusion into a DSL so that we can use it to create WGSL shaders easily, not just when using fusion. I migrated the unary kernels to use the new way, but following PRs will migrate more kernels and remove a lot of code.

Drop

I removed the clamp max and min functions in Fusion as well as in WGPU, since we need to reduce the number of operations that can be fused to their basic operations. Long story short, they will be fused, so we don't need specific kernels, and having those functions will actually break the fusing stream, so they actually hurt performance as well as add complexity.

@nathanielsimard nathanielsimard changed the title Refactor/fusion wgpu Fusion wgpu compilation cache Dec 14, 2023
Copy link

codecov bot commented Dec 14, 2023

Codecov Report

Attention: 31 lines in your changes are missing coverage. Please review.

Comparison is base (d9f93d3) 85.55% compared to head (0e1a5e6) 85.66%.
Report is 12 commits behind head on main.

Files Patch % Lines
burn-wgpu/src/fusion/elemwise/builder.rs 10.00% 18 Missing ⚠️
burn-wgpu/src/codegen/operator.rs 77.27% 5 Missing ⚠️
burn-wgpu/src/codegen/shader.rs 20.00% 4 Missing ⚠️
burn-wgpu/src/element.rs 66.66% 2 Missing ⚠️
burn-wgpu/src/fusion/cache.rs 91.66% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1069      +/-   ##
==========================================
+ Coverage   85.55%   85.66%   +0.11%     
==========================================
  Files         508      509       +1     
  Lines       53910    54126     +216     
==========================================
+ Hits        46122    46368     +246     
+ Misses       7788     7758      -30     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@louisfd louisfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, please look at my comments

impl ElemWiseKernelCodegen<CompilationPhase> {
/// Compile the kernel into a [compute shader](ComputeShader).
pub fn compile(self) -> ComputeShader {
let mut inputs = Vec::with_capacity(self.input_bindings.len());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inputs and outpus can be directly self.input_bindings and self.output_bindings

pub enum Visibility {
Read,
ReadWrite,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub enum Elem {
F32,
#[allow(dead_code)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check

items: HashSet<String>,
}

pub enum CachedComputeShader {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cachable

fn source(&self) -> SourceTemplate {
match self {
CachedComputeShader::Cached(_) => {
panic!("NoSource compute shader should only be used by a higher level cache.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NoSource

inputs,
outputs,
locals,
operators: self.operators.clone(),
scalars_f32: self.scalars_f32,
device: self.device.clone(),
cache: KernelCache::default(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it useless

@@ -1,21 +0,0 @@
@group(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to delete safe tanh file

Copy link
Member

@louisfd louisfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect

@louisfd louisfd merged commit b5c49c5 into main Dec 18, 2023
15 checks passed
@louisfd louisfd deleted the refactor/fusion-wgpu branch December 18, 2023 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants