Skip to content

Commit

Permalink
Refactor/cube/vectorization (#1781)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored May 19, 2024
1 parent 499ff0d commit 76fe0ed
Show file tree
Hide file tree
Showing 49 changed files with 433 additions and 277 deletions.
9 changes: 9 additions & 0 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ impl CodeAnalysisBuilder {
}
syn::Expr::Break(_) => {}
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Array(expr) => {
for element in expr.elems.iter() {
match element {
syn::Expr::Lit(_) => {}
_ => todo!("Analysis: only array of literals is supported"),
}
}
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
_ => todo!("Analysis: unsupported expr {expr:?}"),
}
}
Expand Down
16 changes: 15 additions & 1 deletion crates/burn-cube-macros/src/codegen/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::{
branch::{codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_while_loop},
function::{codegen_call, codegen_closure, codegen_expr_method_call},
operation::codegen_binary,
variable::{codegen_assign, codegen_index, codegen_lit, codegen_local, codegen_path_rhs},
variable::{
codegen_array_lit, codegen_assign, codegen_index, codegen_lit, codegen_local,
codegen_path_rhs,
},
};

/// Codegen for a statement (generally one line)
Expand Down Expand Up @@ -59,6 +62,15 @@ pub(crate) fn codegen_expr_block(
codegen_block(&block.block, loop_level, variable_analyses)
}

pub(crate) fn codegen_ref(
reference: &syn::ExprReference,
loop_level: usize,
variable_analyses: &mut CodeAnalysis,
) -> TokenStream {
let inner = codegen_expr(&reference.expr, loop_level, variable_analyses);
quote::quote! { & #inner }
}

/// Codegen for expressions
/// There are many variants of expression, treated differently
pub(crate) fn codegen_expr(
Expand All @@ -84,6 +96,8 @@ pub(crate) fn codegen_expr(
syn::Expr::MethodCall(call) => codegen_expr_method_call(call),
syn::Expr::Index(index) => codegen_index(index, loop_level, variable_analyses),
syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_analyses),
syn::Expr::Array(array) => codegen_array_lit(array),
syn::Expr::Reference(reference) => codegen_ref(reference, loop_level, variable_analyses),
_ => panic!("Codegen: Unsupported {:?}", expr),
}
}
7 changes: 1 addition & 6 deletions crates/burn-cube-macros/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ pub(crate) fn codegen_closure(
}

/// Codegen for a function call
/// Supports:
/// func()
/// func::<T>()
/// T::func()
///
/// Should map:
/// Maps
/// [A[::<...>]?::]^* func[::<...>] (args)
/// to
/// [A[::<...>]?::]^* func_expand[::<...>] (context, args)
Expand Down
13 changes: 13 additions & 0 deletions crates/burn-cube-macros/src/codegen/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
}
}

/// Codegen for arrays of literals
pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
let mut tokens = quote::quote! {};
for element in array.elems.iter() {
let token = match element {
syn::Expr::Lit(lit) => codegen_lit(lit),
_ => todo!("Codegen: Only arrays of literals are supported"),
};
tokens.extend(quote::quote! { #token, });
}
quote::quote! { [ #tokens ] }
}

/// Codegen for a local declaration (let ...)
/// Supports:
/// let x = ...
Expand Down
21 changes: 7 additions & 14 deletions crates/burn-cube/src/codegen/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@ impl core::fmt::Display for CompilationSettings {
}

match self.vectorization {
Some(vectorization) => match vectorization {
Vectorization::Vec4 => f.write_str("v4"),
Vectorization::Vec3 => f.write_str("v3"),
Vectorization::Vec2 => f.write_str("v2"),
Vectorization::Scalar => f.write_str("v1"),
}?,
Some(vectorization) => f.write_fmt(format_args!("v{}", vectorization))?,
None => f.write_str("vn")?,
};

Expand Down Expand Up @@ -154,7 +149,7 @@ impl InputInfo {
item,
visibility: _,
} => *item,
InputInfo::Scalar { elem, size: _ } => Item::Scalar(*elem),
InputInfo::Scalar { elem, size: _ } => Item::new(*elem),
}
}
}
Expand Down Expand Up @@ -252,7 +247,7 @@ impl Compilation {
named.push((
"info".to_string(),
Binding {
item: Item::Scalar(Elem::UInt),
item: Item::new(Elem::UInt),
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
Expand Down Expand Up @@ -300,7 +295,7 @@ impl Compilation {
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
item: Item::Scalar(elem),
item: Item::new(elem),
visibility: Visibility::Read,
location: Location::Storage,
size: Some(size),
Expand Down Expand Up @@ -440,11 +435,9 @@ impl Compilation {
}

fn bool_item(ty: Item) -> Item {
match ty {
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
Item {
elem: bool_elem(ty.elem),
vectorization: ty.vectorization,
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cube/src/codegen/dialect/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl RangeLoop {
func: F,
) {
let mut scope = parent_scope.child();
let index_ty = Item::Scalar(Elem::UInt);
let index_ty = Item::new(Elem::UInt);
let i = scope.create_local_undeclared(index_ty);

func(i, &mut scope);
Expand Down
47 changes: 21 additions & 26 deletions crates/burn-cube/src/codegen/dialect/procedure/assign.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::codegen::dialect::{macros::cpa, Item, Scope, Variable, Vectorization};
use crate::{
branch::range,
codegen::dialect::{macros::cpa, Scope, Variable, Vectorization},
};
use serde::{Deserialize, Serialize};

/// Assign value to a variable based on a given condition.
Expand All @@ -19,14 +22,15 @@ impl ConditionalAssign {
let rhs = self.rhs;
let out = self.out;

let index_var = |scope: &mut Scope, var: Variable, index: usize| match var.item() {
Item::Scalar(_) => var,
_ => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};
let index_var =
|scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 {
true => var,
false => {
let out = scope.create_local(var.item().elem());
cpa!(scope, out = var[index]);
out
}
};

let mut assign_index = |index: usize| {
let cond = index_var(scope, cond, index);
Expand All @@ -42,29 +46,20 @@ impl ConditionalAssign {
}));
};

match out.item() {
Item::Vec4(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
assign_index(3);
}
Item::Vec3(_) => {
assign_index(0);
assign_index(1);
assign_index(2);
}
Item::Vec2(_) => {
assign_index(0);
assign_index(1);
}
Item::Scalar(_) => {
let vectorization = out.item().vectorization;
match vectorization == 1 {
true => {
cpa!(scope, if (cond).then(|scope| {
cpa!(scope, out = lhs);
}).else(|scope| {
cpa!(scope, out = rhs);
}));
}
false => {
for i in range(0u32, vectorization as u32, true) {
assign_index(i);
}
}
};
}

Expand Down
8 changes: 4 additions & 4 deletions crates/burn-cube/src/codegen/dialect/procedure/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ impl CheckedIndex {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(crate::dialect::Elem::Bool));
let array_len = scope.create_local(Item::new(crate::dialect::Elem::UInt));
let inside_bound = scope.create_local(Item::new(crate::dialect::Elem::Bool));

cpa!(scope, array_len = len(lhs));
cpa!(scope, inside_bound = rhs < array_len);
Expand Down Expand Up @@ -56,8 +56,8 @@ impl CheckedIndexAssign {
let lhs = self.lhs;
let rhs = self.rhs;
let out = self.out;
let array_len = scope.create_local(Item::Scalar(Elem::UInt));
let inside_bound = scope.create_local(Item::Scalar(Elem::Bool));
let array_len = scope.create_local(Item::new(Elem::UInt));
let inside_bound = scope.create_local(Item::new(Elem::Bool));

cpa!(scope, array_len = len(out));
cpa!(scope, inside_bound = lhs < array_len);
Expand Down
12 changes: 3 additions & 9 deletions crates/burn-cube/src/codegen/dialect/procedure/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,11 @@ impl IndexOffsetGlobalWithLayout {
#[allow(missing_docs)]
pub fn expand(self, scope: &mut Scope) {
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let index_item_ty = Item::new(Elem::UInt);
let offset_ref = self.position;
let zero: Variable = 0u32.into();
let vectorization_factor: Variable = match self.tensors[0].item() {
Item::Vec4(_) => 4u32,
Item::Vec3(_) => 3u32,
Item::Vec2(_) => 2u32,
Item::Scalar(_) => 1u32,
}
.into();

let vectorization_factor: u8 = self.tensors[0].item().vectorization;
let vectorization_factor: Variable = (vectorization_factor as u32).into();
for index in self.indexes.iter() {
cpa!(scope, index = zero);
}
Expand Down
8 changes: 3 additions & 5 deletions crates/burn-cube/src/codegen/dialect/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,9 @@ impl Scope {
position: Variable,
) -> Variable {
let item_global = match item.elem() {
Elem::Bool => match item {
Item::Vec4(_) => Item::Vec4(Elem::UInt),
Item::Vec3(_) => Item::Vec3(Elem::UInt),
Item::Vec2(_) => Item::Vec2(Elem::UInt),
Item::Scalar(_) => Item::Scalar(Elem::UInt),
Elem::Bool => Item {
elem: Elem::UInt,
vectorization: item.vectorization,
},
_ => item,
};
Expand Down
34 changes: 21 additions & 13 deletions crates/burn-cube/src/codegen/dialect/shader.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Scope;
use super::{Scope, Vectorization};
use crate::WORKGROUP_DEFAULT;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
Expand Down Expand Up @@ -44,7 +44,7 @@ pub enum Elem {

impl From<Elem> for Item {
fn from(val: Elem) -> Self {
Item::Scalar(val)
Item::new(val)
}
}

Expand Down Expand Up @@ -81,22 +81,30 @@ impl Display for Elem {
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub enum Item {
Vec4(Elem),
Vec3(Elem),
Vec2(Elem),
Scalar(Elem),
pub struct Item {
pub elem: Elem,
pub vectorization: Vectorization,
}

impl Item {
/// Fetch the elem of the item.
pub fn elem(&self) -> Elem {
match self {
Self::Vec4(elem) => *elem,
Self::Vec3(elem) => *elem,
Self::Vec2(elem) => *elem,
Self::Scalar(elem) => *elem,
self.elem
}

/// Create a new item without vectorization
pub fn new(elem: Elem) -> Self {
Self {
elem,
vectorization: 1,
}
}

/// Create a new item with vectorization
pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
Self {
elem,
vectorization,
}
}
}
Expand Down
42 changes: 21 additions & 21 deletions crates/burn-cube/src/codegen/dialect/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,30 @@ impl Variable {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
Variable::GlobalScalar(_, elem) => Item::new(*elem),
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
Variable::ConstantScalar(_, elem) => Item::new(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
Variable::Id => Item::new(Elem::UInt),
Variable::Rank => Item::new(Elem::UInt),
Variable::LocalInvocationIndex => Item::new(Elem::UInt),
Variable::LocalInvocationIdX => Item::new(Elem::UInt),
Variable::LocalInvocationIdY => Item::new(Elem::UInt),
Variable::LocalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupIdX => Item::new(Elem::UInt),
Variable::WorkgroupIdY => Item::new(Elem::UInt),
Variable::WorkgroupIdZ => Item::new(Elem::UInt),
Variable::GlobalInvocationIdX => Item::new(Elem::UInt),
Variable::GlobalInvocationIdY => Item::new(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::new(Elem::UInt),
Variable::WorkgroupSizeX => Item::new(Elem::UInt),
Variable::WorkgroupSizeY => Item::new(Elem::UInt),
Variable::WorkgroupSizeZ => Item::new(Elem::UInt),
Variable::NumWorkgroupsX => Item::new(Elem::UInt),
Variable::NumWorkgroupsY => Item::new(Elem::UInt),
Variable::NumWorkgroupsZ => Item::new(Elem::UInt),
}
}
}
Expand Down
Loading

0 comments on commit 76fe0ed

Please sign in to comment.