From 76fe0ed881b3965782f78896433f8bb5e2f13a1b Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Sun, 19 May 2024 13:20:55 -0400 Subject: [PATCH] Refactor/cube/vectorization (#1781) --- crates/burn-cube-macros/src/analysis.rs | 9 +++ crates/burn-cube-macros/src/codegen/base.rs | 16 ++++- .../burn-cube-macros/src/codegen/function.rs | 7 +- .../burn-cube-macros/src/codegen/variable.rs | 13 ++++ crates/burn-cube/src/codegen/compilation.rs | 21 ++---- .../burn-cube/src/codegen/dialect/branch.rs | 2 +- .../src/codegen/dialect/procedure/assign.rs | 47 ++++++------- .../src/codegen/dialect/procedure/index.rs | 8 +-- .../src/codegen/dialect/procedure/read.rs | 12 +--- crates/burn-cube/src/codegen/dialect/scope.rs | 8 +-- .../burn-cube/src/codegen/dialect/shader.rs | 34 ++++++---- .../burn-cube/src/codegen/dialect/variable.rs | 42 ++++++------ .../src/codegen/dialect/vectorization.rs | 32 ++------- crates/burn-cube/src/language/branch.rs | 2 +- crates/burn-cube/src/language/element/bool.rs | 17 ++++- .../src/language/element/conversion.rs | 2 +- .../burn-cube/src/language/element/float.rs | 22 ++++-- crates/burn-cube/src/language/element/int.rs | 22 ++++-- .../burn-cube/src/language/element/numeric.rs | 23 ++++++- .../src/language/element/primitive.rs | 7 +- crates/burn-cube/src/language/element/uint.rs | 20 +++++- .../src/language/operation/assignation.rs | 5 +- .../burn-cube/src/language/operation/base.rs | 26 +++++-- crates/burn-cube/tests/language/cast_elem.rs | 56 ++++++++-------- crates/burn-cube/tests/language/cast_kind.rs | 16 ++--- crates/burn-cube/tests/language/for_loop.rs | 10 +-- .../burn-cube/tests/language/function_call.rs | 12 ++-- .../tests/language/generic_kernel.rs | 8 +-- crates/burn-cube/tests/language/if.rs | 6 +- crates/burn-cube/tests/language/if_else.rs | 6 +- crates/burn-cube/tests/language/literal.rs | 6 +- crates/burn-cube/tests/language/loop.rs | 8 +-- crates/burn-cube/tests/language/mod.rs | 1 + .../burn-cube/tests/language/module_import.rs | 4 +- .../burn-cube/tests/language/parenthesis.rs | 8 +-- crates/burn-cube/tests/language/reuse.rs | 12 ++-- crates/burn-cube/tests/language/trait.rs | 20 +++--- .../burn-cube/tests/language/vectorization.rs | 67 +++++++++++++++++++ crates/burn-jit/src/fusion/elemwise/kernel.rs | 13 ++-- crates/burn-jit/src/fusion/tracing/builder.rs | 8 +-- crates/burn-jit/src/fusion/tracing/trace.rs | 4 +- crates/burn-jit/src/kernel/binary.rs | 6 +- crates/burn-jit/src/kernel/cast/bool_cast.rs | 2 +- .../tiling2d_shader/shader_information.rs | 8 +-- .../src/kernel/pool/max_pool2d_backward.rs | 4 +- .../burn-jit/src/kernel/pool/pool2d_shader.rs | 4 +- crates/burn-jit/src/kernel/unary.rs | 8 +-- crates/burn-jit/src/ops/int_ops.rs | 2 +- .../burn-wgpu/src/compiler/wgsl/compiler.rs | 14 ++-- 49 files changed, 433 insertions(+), 277 deletions(-) create mode 100644 crates/burn-cube/tests/language/vectorization.rs diff --git a/crates/burn-cube-macros/src/analysis.rs b/crates/burn-cube-macros/src/analysis.rs index 50f819a84a..256ea8d3b9 100644 --- a/crates/burn-cube-macros/src/analysis.rs +++ b/crates/burn-cube-macros/src/analysis.rs @@ -239,6 +239,15 @@ impl CodeAnalysisBuilder { } syn::Expr::Break(_) => {} syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), + syn::Expr::Array(expr) => { + for element in expr.elems.iter() { + match element { + syn::Expr::Lit(_) => {} + _ => todo!("Analysis: only array of literals is supported"), + } + } + } + syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth), _ => todo!("Analysis: unsupported expr {expr:?}"), } } diff --git a/crates/burn-cube-macros/src/codegen/base.rs b/crates/burn-cube-macros/src/codegen/base.rs index 857d3bb518..8574e5aa66 100644 --- a/crates/burn-cube-macros/src/codegen/base.rs +++ b/crates/burn-cube-macros/src/codegen/base.rs @@ -6,7 +6,10 @@ use super::{ branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop}, function::{codegen_call, codegen_closure, codegen_expr_method_call}, operation::codegen_binary, - variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs}, + variable::{ + codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local, + codegen_path_rhs, + }, }; /// Codegen for a statement (generally one line) @@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block( codegen_block(&block.block, loop_level, variable_analyses) } +pub(crate) fn codegen_ref( + reference: &syn::ExprReference, + loop_level: usize, + variable_analyses: &mut CodeAnalysis, +) -> TokenStream { + let inner = codegen_expr(&reference.expr, loop_level, variable_analyses); + quote::quote! { & #inner } +} + /// Codegen for expressions /// There are many variants of expression, treated differently pub(crate) fn codegen_expr( @@ -84,6 +96,8 @@ pub(crate) fn codegen_expr( syn::Expr::MethodCall(call) => codegen_expr_method_call(call), syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses), syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses), + syn::Expr::Array(array) => codegen_array_lit(array), + syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses), _ => panic!("Codegen: Unsupported {:?}", expr), } } diff --git a/crates/burn-cube-macros/src/codegen/function.rs b/crates/burn-cube-macros/src/codegen/function.rs index 7152cee3f7..e2e549f7a3 100644 --- a/crates/burn-cube-macros/src/codegen/function.rs +++ b/crates/burn-cube-macros/src/codegen/function.rs @@ -34,12 +34,7 @@ pub(crate) fn codegen_closure( } /// Codegen for a function call -/// Supports: -/// func() -/// func::() -/// T::func() -/// -/// Should map: +/// Maps /// [A[::<...>]?::]^* func[::<...>] (args) /// to /// [A[::<...>]?::]^* func_expand[::<...>] (context, args) diff --git a/crates/burn-cube-macros/src/codegen/variable.rs b/crates/burn-cube-macros/src/codegen/variable.rs index 769bc6dff0..6d6b709e5f 100644 --- a/crates/burn-cube-macros/src/codegen/variable.rs +++ b/crates/burn-cube-macros/src/codegen/variable.rs @@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { } } +/// Codegen for arrays of literals +pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream { + let mut tokens = quote::quote! {}; + for element in array.elems.iter() { + let token = match element { + syn::Expr::Lit(lit) => codegen_lit(lit), + _ => todo!("Codegen: Only arrays of literals are supported"), + }; + tokens.extend(quote::quote! { #token, }); + } + quote::quote! { [ #tokens ] } +} + /// Codegen for a local declaration (let ...) /// Supports: /// let x = ... diff --git a/crates/burn-cube/src/codegen/compilation.rs b/crates/burn-cube/src/codegen/compilation.rs index bd5504699a..ae30869ebc 100644 --- a/crates/burn-cube/src/codegen/compilation.rs +++ b/crates/burn-cube/src/codegen/compilation.rs @@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings { } match self.vectorization { - Some(vectorization) => match vectorization { - Vectorization::Vec4 => f.write_str("v4"), - Vectorization::Vec3 => f.write_str("v3"), - Vectorization::Vec2 => f.write_str("v2"), - Vectorization::Scalar => f.write_str("v1"), - }?, + Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?, None => f.write_str("vn")?, }; @@ -154,7 +149,7 @@ impl InputInfo { item, visibility: _, } => *item, - InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem), + InputInfo::Scalar { elem, size: _ } => Item::new(*elem), } } } @@ -252,7 +247,7 @@ impl Compilation { named.push(( "info".to_string(), Binding { - item: Item::Scalar(Elem::UInt), + item: Item::new(Elem::UInt), visibility: Visibility::Read, location: Location::Storage, size: None, // We avoid putting the length here since it will force a new kernel @@ -300,7 +295,7 @@ impl Compilation { self.named_bindings.push(( format!("scalars_{}", elem), Binding { - item: Item::Scalar(elem), + item: Item::new(elem), visibility: Visibility::Read, location: Location::Storage, size: Some(size), @@ -440,11 +435,9 @@ impl Compilation { } fn bool_item(ty: Item) -> Item { - match ty { - Item::Vec4(elem) => Item::Vec4(bool_elem(elem)), - Item::Vec3(elem) => Item::Vec3(bool_elem(elem)), - Item::Vec2(elem) => Item::Vec2(bool_elem(elem)), - Item::Scalar(elem) => Item::Scalar(bool_elem(elem)), + Item { + elem: bool_elem(ty.elem), + vectorization: ty.vectorization, } } diff --git a/crates/burn-cube/src/codegen/dialect/branch.rs b/crates/burn-cube/src/codegen/dialect/branch.rs index fdc3ea9ba4..52ef709684 100644 --- a/crates/burn-cube/src/codegen/dialect/branch.rs +++ b/crates/burn-cube/src/codegen/dialect/branch.rs @@ -94,7 +94,7 @@ impl RangeLoop { func: F, ) { let mut scope = parent_scope.child(); - let index_ty = Item::Scalar(Elem::UInt); + let index_ty = Item::new(Elem::UInt); let i = scope.create_local_undeclared(index_ty); func(i, &mut scope); diff --git a/crates/burn-cube/src/codegen/dialect/procedure/assign.rs b/crates/burn-cube/src/codegen/dialect/procedure/assign.rs index ed2571fb1a..d0bf535b7c 100644 --- a/crates/burn-cube/src/codegen/dialect/procedure/assign.rs +++ b/crates/burn-cube/src/codegen/dialect/procedure/assign.rs @@ -1,4 +1,7 @@ -use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization}; +use crate::{ + branch::range, + codegen::dialect::{macros::cpa, Scope, Variable, Vectorization}, +}; use serde::{Deserialize, Serialize}; /// Assign value to a variable based on a given condition. @@ -19,14 +22,15 @@ impl ConditionalAssign { let rhs = self.rhs; let out = self.out; - let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() { - Item::Scalar(_) => var, - _ => { - let out = scope.create_local(var.item().elem()); - cpa!(scope, out = var[index]); - out - } - }; + let index_var = + |scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 { + true => var, + false => { + let out = scope.create_local(var.item().elem()); + cpa!(scope, out = var[index]); + out + } + }; let mut assign_index = |index: usize| { let cond = index_var(scope, cond, index); @@ -42,29 +46,20 @@ impl ConditionalAssign { })); }; - match out.item() { - Item::Vec4(_) => { - assign_index(0); - assign_index(1); - assign_index(2); - assign_index(3); - } - Item::Vec3(_) => { - assign_index(0); - assign_index(1); - assign_index(2); - } - Item::Vec2(_) => { - assign_index(0); - assign_index(1); - } - Item::Scalar(_) => { + let vectorization = out.item().vectorization; + match vectorization == 1 { + true => { cpa!(scope, if (cond).then(|scope| { cpa!(scope, out = lhs); }).else(|scope| { cpa!(scope, out = rhs); })); } + false => { + for i in range(0u32, vectorization as u32, true) { + assign_index(i); + } + } }; } diff --git a/crates/burn-cube/src/codegen/dialect/procedure/index.rs b/crates/burn-cube/src/codegen/dialect/procedure/index.rs index 614f43efe0..45f6c4bbec 100644 --- a/crates/burn-cube/src/codegen/dialect/procedure/index.rs +++ b/crates/burn-cube/src/codegen/dialect/procedure/index.rs @@ -19,8 +19,8 @@ impl CheckedIndex { let lhs = self.lhs; let rhs = self.rhs; let out = self.out; - let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt)); - let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool)); + let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt)); + let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool)); cpa!(scope, array_len = len(lhs)); cpa!(scope, inside_bound = rhs < array_len); @@ -56,8 +56,8 @@ impl CheckedIndexAssign { let lhs = self.lhs; let rhs = self.rhs; let out = self.out; - let array_len = scope.create_local(Item::Scalar(Elem::UInt)); - let inside_bound = scope.create_local(Item::Scalar(Elem::Bool)); + let array_len = scope.create_local(Item::new(Elem::UInt)); + let inside_bound = scope.create_local(Item::new(Elem::Bool)); cpa!(scope, array_len = len(out)); cpa!(scope, inside_bound = lhs < array_len); diff --git a/crates/burn-cube/src/codegen/dialect/procedure/read.rs b/crates/burn-cube/src/codegen/dialect/procedure/read.rs index 3351412c21..920334274c 100644 --- a/crates/burn-cube/src/codegen/dialect/procedure/read.rs +++ b/crates/burn-cube/src/codegen/dialect/procedure/read.rs @@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout { #[allow(missing_docs)] pub fn expand(self, scope: &mut Scope) { let layout = self.layout; - let index_item_ty = Item::Scalar(Elem::UInt); + let index_item_ty = Item::new(Elem::UInt); let offset_ref = self.position; let zero: Variable = 0u32.into(); - let vectorization_factor: Variable = match self.tensors[0].item() { - Item::Vec4(_) => 4u32, - Item::Vec3(_) => 3u32, - Item::Vec2(_) => 2u32, - Item::Scalar(_) => 1u32, - } - .into(); - + let vectorization_factor: u8 = self.tensors[0].item().vectorization; + let vectorization_factor: Variable = (vectorization_factor as u32).into(); for index in self.indexes.iter() { cpa!(scope, index = zero); } diff --git a/crates/burn-cube/src/codegen/dialect/scope.rs b/crates/burn-cube/src/codegen/dialect/scope.rs index c31679acd8..427bc08e05 100644 --- a/crates/burn-cube/src/codegen/dialect/scope.rs +++ b/crates/burn-cube/src/codegen/dialect/scope.rs @@ -336,11 +336,9 @@ impl Scope { position: Variable, ) -> Variable { let item_global = match item.elem() { - Elem::Bool => match item { - Item::Vec4(_) => Item::Vec4(Elem::UInt), - Item::Vec3(_) => Item::Vec3(Elem::UInt), - Item::Vec2(_) => Item::Vec2(Elem::UInt), - Item::Scalar(_) => Item::Scalar(Elem::UInt), + Elem::Bool => Item { + elem: Elem::UInt, + vectorization: item.vectorization, }, _ => item, }; diff --git a/crates/burn-cube/src/codegen/dialect/shader.rs b/crates/burn-cube/src/codegen/dialect/shader.rs index 64be9e0741..c864f69fe4 100644 --- a/crates/burn-cube/src/codegen/dialect/shader.rs +++ b/crates/burn-cube/src/codegen/dialect/shader.rs @@ -1,4 +1,4 @@ -use super::Scope; +use super::{Scope, Vectorization}; use crate::WORKGROUP_DEFAULT; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -44,7 +44,7 @@ pub enum Elem { impl From for Item { fn from(val: Elem) -> Self { - Item::Scalar(val) + Item::new(val) } } @@ -81,22 +81,30 @@ impl Display for Elem { } #[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)] -#[allow(missing_docs)] -pub enum Item { - Vec4(Elem), - Vec3(Elem), - Vec2(Elem), - Scalar(Elem), +pub struct Item { + pub elem: Elem, + pub vectorization: Vectorization, } impl Item { /// Fetch the elem of the item. pub fn elem(&self) -> Elem { - match self { - Self::Vec4(elem) => *elem, - Self::Vec3(elem) => *elem, - Self::Vec2(elem) => *elem, - Self::Scalar(elem) => *elem, + self.elem + } + + /// Create a new item without vectorization + pub fn new(elem: Elem) -> Self { + Self { + elem, + vectorization: 1, + } + } + + /// Create a new item with vectorization + pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self { + Self { + elem, + vectorization, } } } diff --git a/crates/burn-cube/src/codegen/dialect/variable.rs b/crates/burn-cube/src/codegen/dialect/variable.rs index 6fe537b9b6..5b748b974b 100644 --- a/crates/burn-cube/src/codegen/dialect/variable.rs +++ b/crates/burn-cube/src/codegen/dialect/variable.rs @@ -69,30 +69,30 @@ impl Variable { match self { Variable::GlobalInputArray(_, item) => *item, Variable::GlobalOutputArray(_, item) => *item, - Variable::GlobalScalar(_, elem) => Item::Scalar(*elem), + Variable::GlobalScalar(_, elem) => Item::new(*elem), Variable::Local(_, item, _) => *item, - Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem), - Variable::ConstantScalar(_, elem) => Item::Scalar(*elem), + Variable::LocalScalar(_, elem, _) => Item::new(*elem), + Variable::ConstantScalar(_, elem) => Item::new(*elem), Variable::SharedMemory(_, item, _) => *item, Variable::LocalArray(_, item, _, _) => *item, - Variable::Id => Item::Scalar(Elem::UInt), - Variable::Rank => Item::Scalar(Elem::UInt), - Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt), - Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt), - Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt), - Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt), - Variable::WorkgroupIdX => Item::Scalar(Elem::UInt), - Variable::WorkgroupIdY => Item::Scalar(Elem::UInt), - Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt), - Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt), - Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt), - Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt), - Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt), - Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt), - Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt), - Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt), - Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt), - Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt), + Variable::Id => Item::new(Elem::UInt), + Variable::Rank => Item::new(Elem::UInt), + Variable::LocalInvocationIndex => Item::new(Elem::UInt), + Variable::LocalInvocationIdX => Item::new(Elem::UInt), + Variable::LocalInvocationIdY => Item::new(Elem::UInt), + Variable::LocalInvocationIdZ => Item::new(Elem::UInt), + Variable::WorkgroupIdX => Item::new(Elem::UInt), + Variable::WorkgroupIdY => Item::new(Elem::UInt), + Variable::WorkgroupIdZ => Item::new(Elem::UInt), + Variable::GlobalInvocationIdX => Item::new(Elem::UInt), + Variable::GlobalInvocationIdY => Item::new(Elem::UInt), + Variable::GlobalInvocationIdZ => Item::new(Elem::UInt), + Variable::WorkgroupSizeX => Item::new(Elem::UInt), + Variable::WorkgroupSizeY => Item::new(Elem::UInt), + Variable::WorkgroupSizeZ => Item::new(Elem::UInt), + Variable::NumWorkgroupsX => Item::new(Elem::UInt), + Variable::NumWorkgroupsY => Item::new(Elem::UInt), + Variable::NumWorkgroupsZ => Item::new(Elem::UInt), } } } diff --git a/crates/burn-cube/src/codegen/dialect/vectorization.rs b/crates/burn-cube/src/codegen/dialect/vectorization.rs index 1bec46323c..5aa00aac03 100644 --- a/crates/burn-cube/src/codegen/dialect/vectorization.rs +++ b/crates/burn-cube/src/codegen/dialect/vectorization.rs @@ -1,19 +1,6 @@ use super::{BinaryOperator, ClampOperator, Item, Operation, Operator, UnaryOperator, Variable}; -/// Define a vectorization scheme. -#[allow(dead_code)] -#[derive(Copy, Clone, Debug, Default, Hash)] -pub enum Vectorization { - /// Use vec4 for vectorization. - Vec4, - /// Use vec3 for vectorization. - Vec3, - /// Use vec2 for vectorization. - Vec2, - /// Don't vectorize. - #[default] - Scalar, -} +pub type Vectorization = u8; impl Operation { pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { @@ -169,21 +156,14 @@ impl Variable { } impl Item { - pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Item { - match vectorize { - Vectorization::Vec4 => Item::Vec4(self.elem()), - Vectorization::Vec3 => Item::Vec3(self.elem()), - Vectorization::Vec2 => Item::Vec2(self.elem()), - Vectorization::Scalar => Item::Scalar(self.elem()), + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item { + Item { + elem: self.elem, + vectorization, } } pub(crate) fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 { - match vectorize { - Vectorization::Vec4 => size / 4, - Vectorization::Vec3 => size / 3, - Vectorization::Vec2 => size / 2, - Vectorization::Scalar => size, - } + size / (vectorize as u32) } } diff --git a/crates/burn-cube/src/language/branch.rs b/crates/burn-cube/src/language/branch.rs index 5ac7ce2be2..c796a271fa 100644 --- a/crates/burn-cube/src/language/branch.rs +++ b/crates/burn-cube/src/language/branch.rs @@ -41,7 +41,7 @@ pub fn range_expand( } } else { let mut child = context.child(); - let index_ty = Item::Scalar(Elem::UInt); + let index_ty = Item::new(Elem::UInt); let i = child.scope.borrow_mut().create_local_undeclared(index_ty); let i = ExpandElement::new(Rc::new(i)); diff --git a/crates/burn-cube/src/language/element/bool.rs b/crates/burn-cube/src/language/element/bool.rs index ebb1094bb7..f30faac8f3 100644 --- a/crates/burn-cube/src/language/element/bool.rs +++ b/crates/burn-cube/src/language/element/bool.rs @@ -1,4 +1,4 @@ -use crate::dialect::Elem; +use crate::dialect::{Elem, Vectorization}; use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; @@ -34,10 +34,14 @@ impl Bool { impl PrimitiveVariable for Bool { type Primitive = bool; - fn into_elem() -> Elem { + fn as_elem() -> Elem { Elem::Bool } + fn vectorization(&self) -> Vectorization { + self.vectorization + } + fn to_f64(&self) -> f64 { match self.val { true => 1., @@ -52,4 +56,13 @@ impl PrimitiveVariable for Bool { fn from_i64(val: i64) -> Self { Self::from_f64(val as f64) } + + fn from_i64_vec(vec: &[i64]) -> Self { + Self { + // We take only one value, because type implements copy and we can't copy an unknown sized vec + // For debugging prefer unvectorized types + val: *vec.first().expect("Should be at least one value") > 0, + vectorization: vec.len() as u8, + } + } } diff --git a/crates/burn-cube/src/language/element/conversion.rs b/crates/burn-cube/src/language/element/conversion.rs index d9a08d0ac1..299d824798 100644 --- a/crates/burn-cube/src/language/element/conversion.rs +++ b/crates/burn-cube/src/language/element/conversion.rs @@ -8,7 +8,7 @@ pub trait Cast: PrimitiveVariable { context: &mut CubeContext, value: ::ExpandType, ) -> ::ExpandType { - let new_var = context.create_local(Item::Scalar(::into_elem())); + let new_var = context.create_local(Item::new(::as_elem())); assign::expand(context, value, new_var.clone()); new_var } diff --git a/crates/burn-cube/src/language/element/float.rs b/crates/burn-cube/src/language/element/float.rs index 24aa13d234..938f330bfb 100644 --- a/crates/burn-cube/src/language/element/float.rs +++ b/crates/burn-cube/src/language/element/float.rs @@ -1,4 +1,4 @@ -use crate::dialect::{Elem, FloatKind, Variable}; +use crate::dialect::{Elem, FloatKind, Variable, Vectorization}; use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; use std::rc::Rc; @@ -13,7 +13,7 @@ macro_rules! impl_float { #[derive(Clone, Copy)] pub struct $type { pub val: ::Primitive, - pub vectorization: usize, + pub vectorization: u8, } impl CubeType for $type { @@ -24,10 +24,14 @@ macro_rules! impl_float { type Primitive = f64; /// Return the element type to use on GPU - fn into_elem() -> Elem { + fn as_elem() -> Elem { Elem::Float(FloatKind::$type) } + fn vectorization(&self) -> Vectorization { + self.vectorization.into() + } + fn to_f64(&self) -> f64 { self.val } @@ -39,6 +43,16 @@ macro_rules! impl_float { fn from_i64(val: i64) -> Self { Self::new(val as f64) } + + fn from_i64_vec(vec: &[i64]) -> Self { + Self { + // We take only one value, because type implements copy and we can't copy an unknown sized vec + // When using CPU-side values for debugging kernels, prefer using unvectorized types + val: *vec.first().expect("Should be at least one value") + as ::Primitive, + vectorization: vec.len() as u8, + } + } } impl Numeric for $type {} @@ -55,7 +69,7 @@ macro_rules! impl_float { _context: &mut CubeContext, val: ::Primitive, ) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); ExpandElement::new(Rc::new(new_var)) } } diff --git a/crates/burn-cube/src/language/element/int.rs b/crates/burn-cube/src/language/element/int.rs index c3a02de437..dece9a078e 100644 --- a/crates/burn-cube/src/language/element/int.rs +++ b/crates/burn-cube/src/language/element/int.rs @@ -1,4 +1,4 @@ -use crate::dialect::{Elem, IntKind, Variable}; +use crate::dialect::{Elem, IntKind, Variable, Vectorization}; use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; use std::rc::Rc; @@ -13,7 +13,7 @@ macro_rules! impl_int { #[derive(Clone, Copy)] pub struct $type { pub val: ::Primitive, - pub vectorization: usize, + pub vectorization: u8, } impl CubeType for $type { @@ -23,10 +23,14 @@ macro_rules! impl_int { impl PrimitiveVariable for $type { type Primitive = i64; - fn into_elem() -> Elem { + fn as_elem() -> Elem { Elem::Int(IntKind::$type) } + fn vectorization(&self) -> Vectorization { + self.vectorization.into() + } + fn to_f64(&self) -> f64 { self.val as f64 } @@ -38,6 +42,16 @@ macro_rules! impl_int { fn from_i64(val: i64) -> Self { Self::new(val) } + + fn from_i64_vec(vec: &[i64]) -> Self { + Self { + // We take only one value, because type implements copy and we can't copy an unknown sized vec + // For debugging prefer unvectorized types + val: *vec.first().expect("Should be at least one value") + as ::Primitive, + vectorization: vec.len() as u8, + } + } } impl Numeric for $type {} @@ -53,7 +67,7 @@ macro_rules! impl_int { _context: &mut CubeContext, val: ::Primitive, ) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); ExpandElement::new(Rc::new(new_var)) } } diff --git a/crates/burn-cube/src/language/element/numeric.rs b/crates/burn-cube/src/language/element/numeric.rs index 3325a52230..e2d0bd0318 100644 --- a/crates/burn-cube/src/language/element/numeric.rs +++ b/crates/burn-cube/src/language/element/numeric.rs @@ -1,4 +1,5 @@ -use crate::dialect::Variable; +use crate::dialect::{Item, Variable}; +use crate::index_assign; use crate::language::{CubeContext, CubeType, ExpandElement, PrimitiveVariable}; use std::rc::Rc; @@ -24,9 +25,25 @@ pub trait Numeric: ::from_i64(val) } - /// Expand version of lit + /// Expand version of from_int fn from_int_expand(_context: &mut CubeContext, val: i64) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); ExpandElement::new(Rc::new(new_var)) } + + fn from_vec(vec: &[i64]) -> Self { + ::from_i64_vec(vec) + } + + fn from_vec_expand(context: &mut CubeContext, vec: &[i64]) -> ::ExpandType { + let mut new_var = context.create_local(Item { + elem: Self::as_elem(), + vectorization: (vec.len() as u8), + }); + for (i, element) in vec.iter().enumerate() { + new_var = index_assign::expand(context, new_var, i.into(), (*element).into()); + } + + new_var + } } diff --git a/crates/burn-cube/src/language/element/primitive.rs b/crates/burn-cube/src/language/element/primitive.rs index f56c5d9933..5c4cfe1984 100644 --- a/crates/burn-cube/src/language/element/primitive.rs +++ b/crates/burn-cube/src/language/element/primitive.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use crate::dialect::{Elem, Variable}; +use crate::dialect::{Elem, Variable, Vectorization}; use crate::language::{CubeType, ExpandElement}; /// Form of CubeType that encapsulates all primitive types: @@ -9,12 +9,15 @@ pub trait PrimitiveVariable: CubeType { type Primitive; /// Return the element type to use on GPU - fn into_elem() -> Elem; + fn as_elem() -> Elem; + fn vectorization(&self) -> Vectorization; // For easy CPU-side casting fn to_f64(&self) -> f64; fn from_f64(val: f64) -> Self; fn from_i64(val: i64) -> Self; + + fn from_i64_vec(vec: &[i64]) -> Self; } macro_rules! impl_into_expand_element { diff --git a/crates/burn-cube/src/language/element/uint.rs b/crates/burn-cube/src/language/element/uint.rs index 32263f3ade..304bbedc0f 100644 --- a/crates/burn-cube/src/language/element/uint.rs +++ b/crates/burn-cube/src/language/element/uint.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use crate::dialect::{Elem, Variable}; +use crate::dialect::{Elem, Variable, Vectorization}; use crate::language::{CubeContext, CubeType, ExpandElement, Numeric, PrimitiveVariable}; #[derive(Clone, Copy)] @@ -18,10 +18,14 @@ impl CubeType for UInt { impl PrimitiveVariable for UInt { type Primitive = u32; - fn into_elem() -> Elem { + fn as_elem() -> Elem { Elem::UInt } + fn vectorization(&self) -> Vectorization { + self.vectorization + } + fn to_f64(&self) -> f64 { self.val as f64 } @@ -33,6 +37,16 @@ impl PrimitiveVariable for UInt { fn from_i64(val: i64) -> Self { Self::new(val as ::Primitive) } + + fn from_i64_vec(vec: &[i64]) -> Self { + Self { + // We take only one value, because type implements copy and we can't copy an unknown sized vec + // For debugging prefer unvectorized types + val: *vec.first().expect("Should be at least one value") + as ::Primitive, + vectorization: vec.len() as u8, + } + } } impl Numeric for UInt {} @@ -48,7 +62,7 @@ impl UInt { _context: &mut CubeContext, val: ::Primitive, ) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::into_elem()); + let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); ExpandElement::new(Rc::new(new_var)) } } diff --git a/crates/burn-cube/src/language/operation/assignation.rs b/crates/burn-cube/src/language/operation/assignation.rs index 4ada9df10e..0d83035d19 100644 --- a/crates/burn-cube/src/language/operation/assignation.rs +++ b/crates/burn-cube/src/language/operation/assignation.rs @@ -26,12 +26,13 @@ pub mod index_assign { array: ExpandElement, index: ExpandElement, value: ExpandElement, - ) { + ) -> ExpandElement { context.register(Operator::IndexAssign(BinaryOperator { lhs: *index, rhs: *value, out: *array, - })) + })); + array } impl> core::ops::IndexMut for Array { diff --git a/crates/burn-cube/src/language/operation/base.rs b/crates/burn-cube/src/language/operation/base.rs index 6ccc0bc22d..b22c6f72d4 100644 --- a/crates/burn-cube/src/language/operation/base.rs +++ b/crates/burn-cube/src/language/operation/base.rs @@ -1,4 +1,4 @@ -use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable}; +use crate::dialect::{BinaryOperator, Elem, Item, Operator, Variable, Vectorization}; use crate::language::{CubeContext, ExpandElement}; pub(crate) fn binary_expand( @@ -14,6 +14,8 @@ where let rhs: Variable = *rhs; let item = lhs.item(); + check_vectorization(item.vectorization, rhs.item().vectorization); + let out = context.create_local(item); let out_var = *out; @@ -39,13 +41,15 @@ where { let lhs: Variable = *lhs; let rhs: Variable = *rhs; + let item = lhs.item(); + + check_vectorization(item.vectorization, rhs.item().vectorization); - let out_item = match lhs.item() { - Item::Vec4(_) => Item::Vec4(Elem::Bool), - Item::Vec3(_) => Item::Vec3(Elem::Bool), - Item::Vec2(_) => Item::Vec2(Elem::Bool), - Item::Scalar(_) => Item::Scalar(Elem::Bool), + let out_item = Item { + elem: Elem::Bool, + vectorization: item.vectorization, }; + let out = context.create_local(out_item); let out_var = *out; @@ -82,3 +86,13 @@ where lhs } + +fn check_vectorization(lhs: Vectorization, rhs: Vectorization) { + if lhs == 1 || rhs == 1 { + return; + } + assert!( + lhs == rhs, + "Tried to perform binary operation on different vectorization schemes." + ); +} diff --git a/crates/burn-cube/tests/language/cast_elem.rs b/crates/burn-cube/tests/language/cast_elem.rs index 4174b0563d..f22f8b1aaf 100644 --- a/crates/burn-cube/tests/language/cast_elem.rs +++ b/crates/burn-cube/tests/language/cast_elem.rs @@ -151,109 +151,109 @@ mod tests { cast_test!( cube_float_to_float_test, float_to_float_expand, - Item::Scalar(F32::into_elem()) + Item::new(F32::as_elem()) ); cast_test!( cube_float_to_int_test, float_to_int_expand, - Item::Scalar(F32::into_elem()), - Item::Scalar(I32::into_elem()) + Item::new(F32::as_elem()), + Item::new(I32::as_elem()) ); cast_test!( cube_float_to_uint_test, float_to_uint_expand, - Item::Scalar(F32::into_elem()), - Item::Scalar(Elem::UInt) + Item::new(F32::as_elem()), + Item::new(Elem::UInt) ); cast_test!( cube_float_to_bool_test, float_to_bool_expand, - Item::Scalar(F32::into_elem()), - Item::Scalar(Elem::Bool) + Item::new(F32::as_elem()), + Item::new(Elem::Bool) ); cast_test!( cube_int_to_float_test, int_to_float_expand, - Item::Scalar(I32::into_elem()), - Item::Scalar(F32::into_elem()) + Item::new(I32::as_elem()), + Item::new(F32::as_elem()) ); cast_test!( cube_int_to_int_test, int_to_int_expand, - Item::Scalar(I32::into_elem()) + Item::new(I32::as_elem()) ); cast_test!( cube_int_to_uint_test, int_to_uint_expand, - Item::Scalar(I32::into_elem()), - Item::Scalar(Elem::UInt) + Item::new(I32::as_elem()), + Item::new(Elem::UInt) ); cast_test!( cube_int_to_bool_test, int_to_bool_expand, - Item::Scalar(I32::into_elem()), - Item::Scalar(Elem::Bool) + Item::new(I32::as_elem()), + Item::new(Elem::Bool) ); cast_test!( cube_uint_to_float_test, uint_to_float_expand, - Item::Scalar(Elem::UInt), - Item::Scalar(F32::into_elem()) + Item::new(Elem::UInt), + Item::new(F32::as_elem()) ); cast_test!( cube_uint_to_int_test, uint_to_int_expand, - Item::Scalar(Elem::UInt), - Item::Scalar(I32::into_elem()) + Item::new(Elem::UInt), + Item::new(I32::as_elem()) ); cast_test!( cube_uint_to_uint_test, uint_to_uint_expand, - Item::Scalar(Elem::UInt) + Item::new(Elem::UInt) ); cast_test!( cube_uint_to_bool_test, uint_to_bool_expand, - Item::Scalar(Elem::UInt), - Item::Scalar(Elem::Bool) + Item::new(Elem::UInt), + Item::new(Elem::Bool) ); cast_test!( cube_bool_to_float_test, bool_to_float_expand, - Item::Scalar(Elem::Bool), - Item::Scalar(F32::into_elem()) + Item::new(Elem::Bool), + Item::new(F32::as_elem()) ); cast_test!( cube_bool_to_int_test, bool_to_int_expand, - Item::Scalar(Elem::Bool), - Item::Scalar(I32::into_elem()) + Item::new(Elem::Bool), + Item::new(I32::as_elem()) ); cast_test!( cube_bool_to_uint_test, bool_to_uint_expand, - Item::Scalar(Elem::Bool), - Item::Scalar(Elem::UInt) + Item::new(Elem::Bool), + Item::new(Elem::UInt) ); cast_test!( cube_bool_to_bool_test, bool_to_bool_expand, - Item::Scalar(Elem::Bool) + Item::new(Elem::Bool) ); fn inline_macro_ref_cast(from_item: Item, to_item: Item) -> String { diff --git a/crates/burn-cube/tests/language/cast_kind.rs b/crates/burn-cube/tests/language/cast_kind.rs index d31bd03697..8f4130c9a2 100644 --- a/crates/burn-cube/tests/language/cast_kind.rs +++ b/crates/burn-cube/tests/language/cast_kind.rs @@ -35,7 +35,7 @@ mod tests { #[test] fn cube_cast_float_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(F64::into_elem()); + let item = Item::new(F64::as_elem()); let input = context.create_local(item); @@ -48,7 +48,7 @@ mod tests { #[test] fn cube_cast_int_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(I32::into_elem()); + let item = Item::new(I32::as_elem()); let input = context.create_local(item); @@ -61,7 +61,7 @@ mod tests { #[test] fn cube_cast_numeric_kind_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(I32::into_elem()); + let item = Item::new(I32::as_elem()); let input = context.create_local(item); @@ -74,7 +74,7 @@ mod tests { #[test] fn cube_cast_kind_numeric_test() { let mut context = CubeContext::root(); - let item = Item::Scalar(I32::into_elem()); + let item = Item::new(I32::as_elem()); let input = context.create_local(item); @@ -86,8 +86,8 @@ mod tests { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let float_64 = Item::Scalar(F64::into_elem()); - let float_32 = Item::Scalar(F32::into_elem()); + let float_64 = Item::new(F64::as_elem()); + let float_32 = Item::new(F32::as_elem()); let input = context.create_local(float_64); let mut scope = context.into_scope(); @@ -104,8 +104,8 @@ mod tests { fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); - let int_32 = Item::Scalar(I32::into_elem()); - let int_64 = Item::Scalar(I64::into_elem()); + let int_32 = Item::new(I32::as_elem()); + let int_64 = Item::new(I64::as_elem()); let input = context.create_local(int_32); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/language/for_loop.rs b/crates/burn-cube/tests/language/for_loop.rs index 3b43a98678..792964f60e 100644 --- a/crates/burn-cube/tests/language/for_loop.rs +++ b/crates/burn-cube/tests/language/for_loop.rs @@ -25,8 +25,8 @@ mod tests { let mut context = CubeContext::root(); let unroll = true; - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); - let rhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); + let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); for_loop_expand::(&mut context, lhs, rhs, end, unroll); @@ -40,8 +40,8 @@ mod tests { let mut context = CubeContext::root(); let unroll = false; - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); - let rhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); + let rhs = context.create_local(Item::new(ElemType::as_elem())); let end = 4u32.into(); for_loop_expand::(&mut context, lhs, rhs, end, unroll); @@ -52,7 +52,7 @@ mod tests { fn inline_macro_ref(unroll: bool) -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let lhs = context.create_local(item); let rhs = context.create_local(item); diff --git a/crates/burn-cube/tests/language/function_call.rs b/crates/burn-cube/tests/language/function_call.rs index 3b24c4d77f..8e05b89a40 100644 --- a/crates/burn-cube/tests/language/function_call.rs +++ b/crates/burn-cube/tests/language/function_call.rs @@ -55,12 +55,12 @@ mod tests { #[test] fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); - let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + let x = caller_context.create_local(Item::new(Elem::UInt)); caller_no_arg_expand(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + let x = no_call_context.create_local(Item::new(Elem::UInt)); no_call_no_arg_expand(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); @@ -74,12 +74,12 @@ mod tests { fn cube_call_equivalent_to_no_call_with_arg_test() { let mut caller_context = CubeContext::root(); - let x = caller_context.create_local(Item::Scalar(Elem::UInt)); + let x = caller_context.create_local(Item::new(Elem::UInt)); caller_with_arg_expand(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::Scalar(Elem::UInt)); + let x = no_call_context.create_local(Item::new(Elem::UInt)); no_call_with_arg_expand(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); @@ -93,12 +93,12 @@ mod tests { fn cube_call_equivalent_to_no_call_with_generics_test() { let mut caller_context = CubeContext::root(); type ElemType = I64; - let x = caller_context.create_local(Item::Scalar(ElemType::into_elem())); + let x = caller_context.create_local(Item::new(ElemType::as_elem())); caller_with_generics_expand::(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + let x = no_call_context.create_local(Item::new(ElemType::as_elem())); no_call_with_generics_expand::(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); diff --git a/crates/burn-cube/tests/language/generic_kernel.rs b/crates/burn-cube/tests/language/generic_kernel.rs index aa057f14f0..cdeb7da0cc 100644 --- a/crates/burn-cube/tests/language/generic_kernel.rs +++ b/crates/burn-cube/tests/language/generic_kernel.rs @@ -14,7 +14,7 @@ mod tests { fn cube_generic_float_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(F32::into_elem())); + let lhs = context.create_local(Item::new(F32::as_elem())); generic_kernel_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -26,7 +26,7 @@ mod tests { fn cube_generic_int_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(I32::into_elem())); + let lhs = context.create_local(Item::new(I32::as_elem())); generic_kernel_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -36,7 +36,7 @@ mod tests { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(F32::into_elem()); + let item = Item::new(F32::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); @@ -48,7 +48,7 @@ mod tests { fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(I32::into_elem()); + let item = Item::new(I32::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/language/if.rs b/crates/burn-cube/tests/language/if.rs index 0fe62154ca..af3f6f0fa8 100644 --- a/crates/burn-cube/tests/language/if.rs +++ b/crates/burn-cube/tests/language/if.rs @@ -22,7 +22,7 @@ mod tests { fn cube_if_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); if_greater_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -32,11 +32,11 @@ mod tests { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let cond = scope.create_local(Item::new(Elem::Bool)); let lhs: Variable = lhs.into(); let y = scope.create_local(item); diff --git a/crates/burn-cube/tests/language/if_else.rs b/crates/burn-cube/tests/language/if_else.rs index 5fadc3906f..a17de7a05e 100644 --- a/crates/burn-cube/tests/language/if_else.rs +++ b/crates/burn-cube/tests/language/if_else.rs @@ -24,7 +24,7 @@ mod tests { fn cube_if_else_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); if_then_else_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -34,11 +34,11 @@ mod tests { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let cond = scope.create_local(Item::new(Elem::Bool)); let lhs: Variable = lhs.into(); let y = scope.create_local(item); diff --git a/crates/burn-cube/tests/language/literal.rs b/crates/burn-cube/tests/language/literal.rs index a4aeed88a1..3296395a58 100644 --- a/crates/burn-cube/tests/language/literal.rs +++ b/crates/burn-cube/tests/language/literal.rs @@ -20,7 +20,7 @@ mod tests { fn cube_literal_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); literal_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -32,7 +32,7 @@ mod tests { fn cube_literal_float_no_decimal_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); literal_float_no_decimals_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -42,7 +42,7 @@ mod tests { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/burn-cube/tests/language/loop.rs b/crates/burn-cube/tests/language/loop.rs index 71d913edab..2968e90bab 100644 --- a/crates/burn-cube/tests/language/loop.rs +++ b/crates/burn-cube/tests/language/loop.rs @@ -32,7 +32,7 @@ mod tests { fn cube_while_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); while_not_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -44,7 +44,7 @@ mod tests { fn cube_loop_break_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::Scalar(ElemType::into_elem())); + let lhs = context.create_local(Item::new(ElemType::as_elem())); manual_loop_break_expand::(&mut context, lhs); let scope = context.into_scope(); @@ -54,11 +54,11 @@ mod tests { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let lhs = context.create_local(item); let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let cond = scope.create_local(Item::new(Elem::Bool)); let lhs: Variable = lhs.into(); let rhs = scope.create_local(item); diff --git a/crates/burn-cube/tests/language/mod.rs b/crates/burn-cube/tests/language/mod.rs index efa001e9ed..23c80d9dc7 100644 --- a/crates/burn-cube/tests/language/mod.rs +++ b/crates/burn-cube/tests/language/mod.rs @@ -11,3 +11,4 @@ mod module_import; mod parenthesis; mod reuse; mod r#trait; +mod vectorization; diff --git a/crates/burn-cube/tests/language/module_import.rs b/crates/burn-cube/tests/language/module_import.rs index a03f5b5a64..15123c682c 100644 --- a/crates/burn-cube/tests/language/module_import.rs +++ b/crates/burn-cube/tests/language/module_import.rs @@ -33,12 +33,12 @@ mod tests { #[test] fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); - let x = caller_context.create_local(Item::Scalar(ElemType::into_elem())); + let x = caller_context.create_local(Item::new(ElemType::as_elem())); here::caller_expand::(&mut caller_context, x); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::Scalar(ElemType::into_elem())); + let x = no_call_context.create_local(Item::new(ElemType::as_elem())); here::no_call_ref_expand::(&mut no_call_context, x); let no_call_scope = no_call_context.into_scope(); diff --git a/crates/burn-cube/tests/language/parenthesis.rs b/crates/burn-cube/tests/language/parenthesis.rs index 8536608c84..7adda62c80 100644 --- a/crates/burn-cube/tests/language/parenthesis.rs +++ b/crates/burn-cube/tests/language/parenthesis.rs @@ -20,9 +20,9 @@ mod tests { fn cube_parenthesis_priority_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); - let y = context.create_local(Item::Scalar(ElemType::into_elem())); - let z = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); + let y = context.create_local(Item::new(ElemType::as_elem())); + let z = context.create_local(Item::new(ElemType::as_elem())); parenthesis_expand::(&mut context, x, y, z); let scope = context.into_scope(); @@ -32,7 +32,7 @@ mod tests { fn inline_macro_ref() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let x = context.create_local(item); let y = context.create_local(item); let z = context.create_local(item); diff --git a/crates/burn-cube/tests/language/reuse.rs b/crates/burn-cube/tests/language/reuse.rs index 03551789f4..51d5e1410a 100644 --- a/crates/burn-cube/tests/language/reuse.rs +++ b/crates/burn-cube/tests/language/reuse.rs @@ -32,7 +32,7 @@ mod tests { fn cube_reuse_assign_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); reuse_expand::(&mut context, x); let scope = context.into_scope(); @@ -44,7 +44,7 @@ mod tests { fn cube_reuse_incr_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); reuse_incr_expand::(&mut context, x); let scope = context.into_scope(); @@ -54,11 +54,11 @@ mod tests { fn inline_macro_ref_assign() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let x = context.create_local(item); let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let cond = scope.create_local(Item::new(Elem::Bool)); let x: Variable = x.into(); let tmp = scope.create_local(item); @@ -80,11 +80,11 @@ mod tests { fn inline_macro_ref_incr() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let x = context.create_local(item); let mut scope = context.into_scope(); - let cond = scope.create_local(Item::Scalar(Elem::Bool)); + let cond = scope.create_local(Item::new(Elem::Bool)); let x: Variable = x.into(); cpa!( diff --git a/crates/burn-cube/tests/language/trait.rs b/crates/burn-cube/tests/language/trait.rs index 31abb0c83c..fc9ef82e1e 100644 --- a/crates/burn-cube/tests/language/trait.rs +++ b/crates/burn-cube/tests/language/trait.rs @@ -114,8 +114,8 @@ mod tests { fn cube_strategy_trait_add_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); - let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); + let y = context.create_local(Item::new(ElemType::as_elem())); with_strategy_trait_expand::(&mut context, x, y); let scope = context.into_scope(); @@ -130,8 +130,8 @@ mod tests { fn cube_strategy_trait_sub_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); - let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); + let y = context.create_local(Item::new(ElemType::as_elem())); with_strategy_trait_expand::(&mut context, x, y); let scope = context.into_scope(); @@ -146,8 +146,8 @@ mod tests { fn cube_two_strategy_traits_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); - let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); + let y = context.create_local(Item::new(ElemType::as_elem())); two_strategy_traits_expand::(&mut context, x, y); let scope = context.into_scope(); @@ -159,8 +159,8 @@ mod tests { fn cube_trait_generic_method_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::Scalar(ElemType::into_elem())); - let y = context.create_local(Item::Scalar(ElemType::into_elem())); + let x = context.create_local(Item::new(ElemType::as_elem())); + let y = context.create_local(Item::new(ElemType::as_elem())); with_trait_generic_method_expand::(&mut context, x, y); let scope = context.into_scope(); @@ -173,7 +173,7 @@ mod tests { fn inline_macro_ref_one(is_add_strategy: bool) -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let x = context.create_local(item); let y = context.create_local(item); @@ -192,7 +192,7 @@ mod tests { fn inline_macro_ref_two() -> String { let mut context = CubeContext::root(); - let item = Item::Scalar(ElemType::into_elem()); + let item = Item::new(ElemType::as_elem()); let x = context.create_local(item); let y = context.create_local(item); diff --git a/crates/burn-cube/tests/language/vectorization.rs b/crates/burn-cube/tests/language/vectorization.rs new file mode 100644 index 0000000000..0d527d01dd --- /dev/null +++ b/crates/burn-cube/tests/language/vectorization.rs @@ -0,0 +1,67 @@ +use burn_cube::{cube, Numeric}; + +#[cube] +pub fn vectorization_binary(lhs: T) { + let _ = lhs + T::from_vec(&[4, 5]); +} + +#[cube] +pub fn vectorization_cmp(rhs: T) { + let _ = T::from_vec(&[4, 5]) > rhs; +} + +mod tests { + + use burn_cube::{dialect::Item, CubeContext, PrimitiveVariable, F32}; + + use crate::language::vectorization::{vectorization_binary_expand, vectorization_cmp_expand}; + + type ElemType = F32; + + #[test] + fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + + vectorization_binary_expand::(&mut context, lhs); + } + + #[test] + #[should_panic] + fn cube_vectorization_binary_op_with_different_scheme_fails() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + + vectorization_binary_expand::(&mut context, lhs); + } + + #[test] + fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + + vectorization_cmp_expand::(&mut context, lhs); + } + + #[test] + #[should_panic] + fn cube_vectorization_cmp_op_with_different_scheme_fails() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + + vectorization_cmp_expand::(&mut context, lhs); + } + + #[test] + fn cube_vectorization_can_be_broadcasted() { + let mut context = CubeContext::root(); + + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1)); + + vectorization_cmp_expand::(&mut context, lhs); + } +} diff --git a/crates/burn-jit/src/fusion/elemwise/kernel.rs b/crates/burn-jit/src/fusion/elemwise/kernel.rs index df51463606..5f282fbdaa 100644 --- a/crates/burn-jit/src/fusion/elemwise/kernel.rs +++ b/crates/burn-jit/src/fusion/elemwise/kernel.rs @@ -1,7 +1,6 @@ use burn_cube::{ - calculate_num_elems_dyn_rank, - dialect::{Vectorization, WorkgroupSize}, - elemwise_workgroup, CompilationInfo, CompilationSettings, + calculate_num_elems_dyn_rank, dialect::WorkgroupSize, elemwise_workgroup, CompilationInfo, + CompilationSettings, }; use burn_tensor::repr::TensorDescription; @@ -56,12 +55,10 @@ impl FusionKernelFactory for ElementWiseKernelFactory { ); if vectorize_4 { - settings = settings.vectorize(Vectorization::Vec4); + settings = settings.vectorize(4); factor = 4; - } - - if !vectorize_4 && vectorize_2 { - settings = settings.vectorize(Vectorization::Vec2); + } else if vectorize_2 { + settings = settings.vectorize(2); factor = 2; } diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index b796dfd176..8b2d3e56ef 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -47,7 +47,7 @@ impl TraceBuilder { false => { // New input let index = self.inputs.len() as u16; - let item = Item::Scalar(elem); + let item = Item::new(elem); let local = self.scope.read_array(index, item, position); self.inputs.push((tensor.clone(), local)); @@ -56,7 +56,7 @@ impl TraceBuilder { true => match self.output_to_local.get(&tensor.id) { // Is a local variable. Some(local_index) => { - Variable::Local(*local_index, Item::Scalar(elem), self.scope.depth) + Variable::Local(*local_index, Item::new(elem), self.scope.depth) } // Isn't an operation output variable, so must be an existing input. None => self @@ -84,10 +84,10 @@ impl TraceBuilder { // Output already registered as a local variable. if let Some(index) = self.output_to_local.get(&tensor.id) { - return Variable::Local(*index, Item::Scalar(elem), self.scope.depth); + return Variable::Local(*index, Item::new(elem), self.scope.depth); } - let variable = self.scope.create_local(Item::Scalar(elem)); + let variable = self.scope.create_local(Item::new(elem)); let local_index = variable.index().unwrap(); self.output_to_local.insert(tensor.id, local_index); variable diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs index bf1def94da..771d27d67b 100644 --- a/crates/burn-jit/src/fusion/tracing/trace.rs +++ b/crates/burn-jit/src/fusion/tracing/trace.rs @@ -45,7 +45,7 @@ impl Trace { .inputs .iter() .map(|(_tensor, elem, _)| InputInfo::Array { - item: Item::Scalar(*elem), + item: Item::new(*elem), visibility: Visibility::Read, }) .collect::>(); @@ -56,7 +56,7 @@ impl Trace { .zip(self.locals.iter()) .map( |((_tensor, elem, index_ref), local)| OutputInfo::ArrayWrite { - item: Item::Scalar(*elem), + item: Item::new(*elem), local: *local, position: *index_ref, }, diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index 236a6d84af..1a53d2e19b 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -65,15 +65,15 @@ macro_rules! binary { let local = scope.last_local_index().unwrap().index().unwrap(); let lhs = burn_cube::InputInfo::Array { - item: burn_cube::dialect::Item::Scalar(I::cube_elem()), + item: burn_cube::dialect::Item::new(I::cube_elem()), visibility: burn_cube::dialect::Visibility::Read, }; let rhs = burn_cube::InputInfo::Array { - item: burn_cube::dialect::Item::Scalar(I::cube_elem()), + item: burn_cube::dialect::Item::new(I::cube_elem()), visibility: burn_cube::dialect::Visibility::Read, }; let out = burn_cube::OutputInfo::ArrayWrite { - item: burn_cube::dialect::Item::Scalar(O::cube_elem()), + item: burn_cube::dialect::Item::new(O::cube_elem()), local, position, }; diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index a7493568a7..f8427d1d15 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -56,7 +56,7 @@ pub(crate) struct BoolCastEagerKernel { impl GpuComputeShaderPhase for BoolCastEagerKernel { fn compile(&self) -> ComputeShader { let mut scope = Scope::root(); - let item_input = Item::Scalar(Elem::Bool); + let item_input = Item::new(Elem::Bool); let item_output = EO::cube_elem().into(); let tensor = Variable::GlobalInputArray(0, item_input); diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs index 29b492f70f..b367bd1332 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs @@ -125,14 +125,14 @@ pub(crate) fn gather_shader_information( // Registers used in the compute pass let results = scope.create_local_array(elem, results_size); - let register_m = scope.create_local(Item::Vec4(elem)); - let register_n = scope.create_local(Item::Vec4(elem)); + let register_m = scope.create_local(Item::vectorized(elem, 4)); + let register_n = scope.create_local(Item::vectorized(elem, 4)); let shared_lhs = scope.create_shared( - Item::Vec4(elem), + Item::vectorized(elem, 4), shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32, ); let shared_rhs = scope.create_shared( - Item::Vec4(elem), + Item::vectorized(elem, 4), shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32, ); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index c7c86c298c..21ea74a018 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -268,7 +268,7 @@ impl GpuComputeShaderPhase let mut scope = Scope::root(); let item = E::cube_elem().into(); - let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32))); + let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32))); let grad = Variable::GlobalInputArray(1, item); let output = Variable::GlobalOutputArray(0, item); @@ -283,7 +283,7 @@ impl GpuComputeShaderPhase .expand(&mut scope); let indices = InputInfo::Array { - item: Item::Scalar(Elem::Int(IntKind::I32)), + item: Item::new(Elem::Int(IntKind::I32)), visibility: Visibility::Read, }; diff --git a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs index f7ce467938..2cfec02c06 100644 --- a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs @@ -184,7 +184,7 @@ impl GpuComputeShaderPhase let indices = if P::with_indices() { Some(Variable::GlobalOutputArray( 1, - Item::Scalar(Elem::Int(IntKind::I32)), + Item::new(Elem::Int(IntKind::I32)), )) } else { None @@ -216,7 +216,7 @@ impl GpuComputeShaderPhase vec![ output, OutputInfo::Array { - item: Item::Scalar(Elem::Int(IntKind::I32)), + item: Item::new(Elem::Int(IntKind::I32)), }, ] } else { diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs index 1030f11b4d..e956c714e6 100644 --- a/crates/burn-jit/src/kernel/unary.rs +++ b/crates/burn-jit/src/kernel/unary.rs @@ -70,11 +70,11 @@ macro_rules! unary { let local = scope.last_local_index().unwrap().index().unwrap(); let input = burn_cube::InputInfo::Array { - item: burn_cube::dialect::Item::Scalar(E::cube_elem()), + item: burn_cube::dialect::Item::new(E::cube_elem()), visibility: burn_cube::dialect::Visibility::Read, }; let out = burn_cube::OutputInfo::ArrayWrite { - item: burn_cube::dialect::Item::Scalar(E::cube_elem()), + item: burn_cube::dialect::Item::new(E::cube_elem()), local, position: burn_cube::dialect::Variable::Id, }; @@ -146,7 +146,7 @@ macro_rules! unary { let local = scope.last_local_index().unwrap().index().unwrap(); let input = burn_cube::InputInfo::Array { - item: burn_cube::dialect::Item::Scalar(E::cube_elem()), + item: burn_cube::dialect::Item::new(E::cube_elem()), visibility: burn_cube::dialect::Visibility::Read, }; let scalars = burn_cube::InputInfo::Scalar { @@ -154,7 +154,7 @@ macro_rules! unary { size: $num, }; let out = burn_cube::OutputInfo::ArrayWrite { - item: burn_cube::dialect::Item::Scalar(E::cube_elem()), + item: burn_cube::dialect::Item::new(E::cube_elem()), local, position: burn_cube::dialect::Variable::Id, }; diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index c959d4a624..028a417b47 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -295,7 +295,7 @@ where fn int_abs(tensor: IntTensor) -> IntTensor { unary!( operation: |scope: &mut Scope, elem: Elem, position: Variable| Operator::Abs(UnaryOperator { - input: scope.read_array(0, Item::Scalar(elem), position), + input: scope.read_array(0, Item::new(elem), position), out: scope.create_local(elem), }), runtime: R, diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 016ca2b0c6..af2a43a994 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -89,11 +89,13 @@ impl WgslCompiler { } fn compile_item(item: cube::Item) -> Item { - match item { - cube::Item::Vec4(elem) => wgsl::Item::Vec4(Self::compile_elem(elem)), - cube::Item::Vec3(elem) => wgsl::Item::Vec3(Self::compile_elem(elem)), - cube::Item::Vec2(elem) => wgsl::Item::Vec2(Self::compile_elem(elem)), - cube::Item::Scalar(elem) => wgsl::Item::Scalar(Self::compile_elem(elem)), + let elem = Self::compile_elem(item.elem); + match item.vectorization { + 1 => wgsl::Item::Scalar(elem), + 2 => wgsl::Item::Vec2(elem), + 3 => wgsl::Item::Vec3(elem), + 4 => wgsl::Item::Vec4(elem), + _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization), } } @@ -101,7 +103,7 @@ impl WgslCompiler { match value { cube::Elem::Float(f) => match f { cube::FloatKind::F16 => panic!("f16 is not yet supported"), - cube::FloatKind::BF16 => panic!("f64 is not a valid WgpuElement"), + cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"), cube::FloatKind::F32 => wgsl::Elem::F32, cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"), },