Skip to content

Commit

Permalink
Cube: CubeType (no launch) and Comptime::map (#1853)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Jun 4, 2024
1 parent a5af19b commit c42abad
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 69 deletions.
25 changes: 19 additions & 6 deletions crates/burn-cube-macros/src/analysis.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use syn::{Member, PathArguments, Stmt};
use syn::{Member, Pat, PathArguments, Stmt};

use crate::variable_key::VariableKey;

Expand Down Expand Up @@ -310,12 +310,25 @@ impl CodeAnalysisBuilder {
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Closure(expr) => {
assert!(
expr.inputs.is_empty(),
"Analysis: closure with args not supported"
);
let depth = depth + 1;

for path in expr.inputs.iter() {
let ident = match path {
Pat::Ident(pat_ident) => &pat_ident.ident,
Pat::Type(pat_type) => {
if let Pat::Ident(pat_ident) = &*pat_type.pat {
&pat_ident.ident
} else {
todo!("Analysis: {:?} not supported in closure inputs. ", path);
}
}
_ => todo!("Analysis: {:?} not supported in closure inputs. ", path),
};

self.declarations.push(((ident).into(), depth));
}

self.find_occurrences_in_expr(&expr.body, depth + 1)
self.find_occurrences_in_expr(&expr.body, depth)
}
syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Field(expr) => {
Expand Down
33 changes: 28 additions & 5 deletions crates/burn-cube-macros/src/codegen_function/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,28 @@ pub(crate) fn codegen_closure(
) -> TokenStream {
let mut inputs = quote::quote! {};
for input in closure.inputs.iter() {
let ident = match input {
syn::Pat::Ident(ident) => &ident.ident,
let (ident, ty) = match input {
syn::Pat::Ident(ident) => (&ident.ident, None),
syn::Pat::Type(pat_type) => (
if let syn::Pat::Ident(ident) = &*pat_type.pat {
&ident.ident
} else {
panic!("Codegen: Unsupported {:?}", input);
},
Some(pat_type.ty.clone()),
),
_ => panic!("Codegen: Unsupported {:?}", input),
};
inputs.extend(quote::quote! {
#ident,
});

if let Some(ty) = ty {
inputs.extend(quote::quote! {
#ident : #ty,
});
} else {
inputs.extend(quote::quote! {
#ident,
});
}
}

let body = codegen_expr(closure.body.as_ref(), loop_level, variable_analyses);
Expand Down Expand Up @@ -124,6 +139,14 @@ pub(crate) fn parse_function_call(
let code = call.args.first().unwrap();
quote::quote! {#code}
}
"map" => {
let args = codegen_args(&call.args, loop_level, variable_analyses);

// Codegen
quote::quote! {
Comptime::map_expand(#args)
}
}
"unwrap_or_else" => {
let args = codegen_args(&call.args, loop_level, variable_analyses);

Expand Down
26 changes: 17 additions & 9 deletions crates/burn-cube-macros/src/codegen_type/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl TypeCodegen {
}
}

pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream {
let name = ast.ident.clone();
let generics = ast.generics.clone();
let name_string = name.to_string();
Expand Down Expand Up @@ -210,14 +210,22 @@ pub(crate) fn generate_cube_type(ast: &syn::DeriveInput) -> TokenStream {
let arg_settings_impl = codegen.arg_settings_impl();
let launch_arg_impl = codegen.launch_arg_impl();

quote! {
#expand_ty
#launch_ty
#launch_new
if with_launch {
quote! {
#expand_ty
#launch_ty
#launch_new

#cube_type_impl
#arg_settings_impl
#launch_arg_impl
#cube_type_impl
#arg_settings_impl
#launch_arg_impl
}
.into()
} else {
quote! {
#expand_ty
#cube_type_impl
}
.into()
}
.into()
}
16 changes: 12 additions & 4 deletions crates/burn-cube-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@ enum CubeMode {
Debug,
}

// Derive macro to define a cube type.
#[proc_macro_derive(Cube)]
pub fn module_derive(input: TokenStream) -> TokenStream {
// Derive macro to define a cube type that is launched with a kernel
#[proc_macro_derive(CubeLaunch)]
pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();

generate_cube_type(&input)
generate_cube_type(&input, true)
}

// Derive macro to define a cube type that is not launched
#[proc_macro_derive(CubeType)]
pub fn module_derive_cube_type(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();

generate_cube_type(&input, false)
}

/// Derive macro for the module.
Expand Down
12 changes: 12 additions & 0 deletions crates/burn-cube/src/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ impl<T> Comptime<T> {
pub fn get(_comptime: Self) -> T {
unexpanded!()
}

pub fn map<R, F: Fn(T) -> R>(_comptime: Self, _closure: F) -> Comptime<R> {
unexpanded!()
}

pub fn map_expand<R, F: Fn(&mut CubeContext, T) -> R>(
context: &mut CubeContext,
inner: T,
closure: F,
) -> R {
closure(context, inner)
}
}

impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-cube/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ pub use pod::*;
pub use runtime::*;

pub use burn_cube_macros::cube;
pub use burn_cube_macros::Cube;
pub use burn_cube_macros::CubeLaunch;
pub use burn_cube_macros::CubeType;

/// An approximation of the subcube dimension.
pub const SUBCUBE_DIM_APPROX: usize = 16;
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cube/src/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub use crate::{cube, Cube, Kernel, RuntimeArg};
pub use crate::{cube, CubeLaunch, CubeType, Kernel, RuntimeArg};

pub use crate::codegen::{KernelExpansion, KernelIntegrator, KernelSettings};
pub use crate::compute::{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
use burn_cube::prelude::*;

#[cube]
pub fn if_then_else<F: Float>(lhs: F) {
if lhs < F::from_int(0) {
let _ = lhs + F::from_int(4);
} else {
let _ = lhs - F::from_int(5);
}
#[derive(Clone)]
pub struct State {
cond: bool,
bound: u32,
}

#[cube]
Expand All @@ -18,29 +15,41 @@ pub fn comptime_if_else<T: Numeric>(lhs: T, cond: Comptime<bool>) {
}
}

#[cube]
pub fn comptime_with_map_bool<T: Numeric>(state: Comptime<State>) -> T {
let cond = Comptime::map(state, |s: State| s.cond);

let mut x = T::from_int(3);
if Comptime::get(cond) {
x += T::from_int(4);
} else {
x -= T::from_int(4);
}
x
}

#[cube]
pub fn comptime_with_map_uint<T: Numeric>(state: Comptime<State>) -> T {
let bound = Comptime::map(state, |s: State| s.bound);

let mut x = T::from_int(3);
for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) {
x += T::from_int(4);
}

x
}

mod tests {
use super::*;
use burn_cube::{
cpa,
frontend::{CubeContext, CubeElem, F32},
ir::{Elem, Item, Variable},
ir::{Item, Variable},
};

use super::{comptime_if_else_expand, if_then_else_expand};

type ElemType = F32;

#[test]
fn cube_if_else_test() {
let mut context = CubeContext::root();

let lhs = context.create_local(Item::new(ElemType::as_elem()));

if_then_else_expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
}

#[test]
fn cube_comptime_if_test() {
let mut context = CubeContext::root();
Expand Down Expand Up @@ -71,24 +80,46 @@ mod tests {
);
}

fn inline_macro_ref() -> String {
#[test]
fn cube_comptime_map_bool_test() {
let mut context1 = CubeContext::root();
let mut context2 = CubeContext::root();

let comptime_state_true = State {
cond: true,
bound: 4,
};
let comptime_state_false = State {
cond: false,
bound: 4,
};

comptime_with_map_bool_expand::<ElemType>(&mut context1, comptime_state_true);
comptime_with_map_bool_expand::<ElemType>(&mut context2, comptime_state_false);

let scope1 = context1.into_scope();
let scope2 = context2.into_scope();

assert_ne!(
format!("{:?}", scope1.operations),
format!("{:?}", scope2.operations)
);
}

#[test]
fn cube_comptime_map_uint_test() {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);

let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);
let comptime_state = State {
cond: true,
bound: 4,
};

cpa!(scope, cond = lhs < 0f32);
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
}).else(|scope|{
cpa!(scope, y = lhs - 5.0f32);
}));
comptime_with_map_uint_expand::<ElemType>(&mut context, comptime_state);

format!("{:?}", scope.operations)
let scope = context.into_scope();

assert!(!format!("{:?}", scope.operations).contains("RangeLoop"));
}

fn inline_macro_ref_comptime(cond: bool) -> String {
Expand Down
Loading

0 comments on commit c42abad

Please sign in to comment.