From 3faf544bc417c70a33eaa6a636c9c21a563fcb63 Mon Sep 17 00:00:00 2001 From: towerpark <24645162+towerpark@users.noreply.github.com> Date: Tue, 25 Jun 2024 22:15:34 +0900 Subject: [PATCH 1/4] Book: Fix the link to burn-train in "Learner" page (#1920) Add the missing "crates/" to the link. Co-authored-by: towerpark --- burn-book/src/building-blocks/learner.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/burn-book/src/building-blocks/learner.md b/burn-book/src/building-blocks/learner.md index bae1a4fed5..e9ceb38a7f 100644 --- a/burn-book/src/building-blocks/learner.md +++ b/burn-book/src/building-blocks/learner.md @@ -1,12 +1,12 @@ # Learner -The [burn-train](https://github.com/tracel-ai/burn/tree/main/burn-train) crate encapsulates multiple -utilities for training deep learning models. The goal of the crate is to provide users with a -well-crafted and flexible training loop, so that projects do not have to write such components from -the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder` struct, -briefly presented in the previous [training section](../basic-workflow/training.md). This struct -enables you to configure the training loop, offering support for registering metrics, enabling -logging, checkpointing states, using multiple devices, and so on. +The [burn-train](https://github.com/tracel-ai/burn/tree/main/crates/burn-train) crate encapsulates +multiple utilities for training deep learning models. The goal of the crate is to provide users with +a well-crafted and flexible training loop, so that projects do not have to write such components +from the ground up. Most of the interactions with `burn-train` will be with the `LearnerBuilder` +struct, briefly presented in the previous [training section](../basic-workflow/training.md). This +struct enables you to configure the training loop, offering support for registering metrics, +enabling logging, checkpointing states, using multiple devices, and so on. There are still some assumptions in the current provided APIs, which may make them inappropriate for your learning requirements. Indeed, they assume your model will learn from a training dataset and be From 2c516154716f734e84fff13f9b442eadd4edf7eb Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Tue, 25 Jun 2024 08:23:10 -0500 Subject: [PATCH 2/4] Print model structure like with PyTorch - Part 1 (#1912) --- Cargo.lock | 8 +- Cargo.toml | 13 +- crates/burn-common/Cargo.toml | 6 +- crates/burn-common/src/id.rs | 17 +- crates/burn-core/src/module/display.rs | 547 ++++++++++++++++++ crates/burn-core/src/module/mod.rs | 2 + crates/burn-core/src/module/param/constant.rs | 63 +- .../burn-core/src/module/param/primitive.rs | 65 ++- crates/burn-core/src/module/param/running.rs | 26 +- crates/burn-core/src/module/param/tensor.rs | 57 +- crates/burn-core/src/nn/conv/conv1d.rs | 43 +- crates/burn-core/src/nn/conv/conv2d.rs | 36 +- crates/burn-core/src/nn/dropout.rs | 15 +- crates/burn-core/src/nn/linear.rs | 20 + crates/burn-core/src/nn/norm/batch.rs | 20 + crates/burn-core/src/nn/norm/layer.rs | 20 +- crates/burn-core/src/nn/padding.rs | 5 +- crates/burn-core/src/nn/pool/avg_pool1d.rs | 6 +- crates/burn-core/src/nn/pool/avg_pool2d.rs | 6 +- crates/burn-core/src/nn/pool/max_pool1d.rs | 6 +- crates/burn-core/src/nn/pool/max_pool2d.rs | 6 +- crates/burn-core/src/nn/unfold.rs | 17 +- crates/burn-core/src/record/serde/ser.rs | 2 +- crates/burn-derive/src/lib.rs | 2 +- crates/burn-derive/src/module/codegen.rs | 110 +++- crates/burn-derive/src/module/display.rs | 93 ++- crates/burn-train/src/learner/summary.rs | 2 +- 27 files changed, 1135 insertions(+), 78 deletions(-) create mode 100644 crates/burn-core/src/module/display.rs diff --git a/Cargo.lock b/Cargo.lock index 663763a962..827ce13c76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -463,6 +463,7 @@ version = "0.14.0" dependencies = [ "async-trait", "dashmap", + "data-encoding", "derive-new", "getrandom", "indicatif", @@ -471,7 +472,6 @@ dependencies = [ "serde", "spin", "tokio", - "uuid", "web-time", ] @@ -1469,6 +1469,12 @@ dependencies = [ "parking_lot_core 0.9.10", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "deflate64" version = "0.1.8" diff --git a/Cargo.toml b/Cargo.toml index fab90620ee..8e9737edf9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,9 @@ colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.0" dashmap = "5.5.3" +data-encoding = { version = "2.6.0", default-features = false, features = [ + "alloc", +] } dirs = "5.0.1" fake = "2.9.2" flate2 = "1.0.30" @@ -42,16 +45,19 @@ getrandom = { version = "0.2.15", default-features = false } gix-tempfile = { version = "13.1.1", features = ["signals"] } globwalk = "0.9.1" hashbrown = "0.14.5" +hound = "3.5.1" +image = "0.25.1" indicatif = "0.17.8" js-sys = "0.3.69" libm = "0.2.8" log = { default-features = false, version = "0.4.21" } +md5 = "0.7.0" +percent-encoding = "2.3.1" pretty_assertions = "1.4.0" proc-macro2 = "1.0.85" protobuf = "3.4.0" protobuf-codegen = "3.4.0" quote = "1.0.36" -percent-encoding = "2.3.1" r2d2 = "0.8.10" r2d2_sqlite = { version = "0.24.0" } rayon = "1.10.0" @@ -63,6 +69,7 @@ rusqlite = { version = "0.31.0" } rust-format = { version = "0.3.4" } sanitize-filename = "0.5.0" serde_rusqlite = "0.35.0" +serial_test = "3.1.1" spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } strum = "0.26.2" strum_macros = "0.26.4" @@ -73,11 +80,7 @@ tokio = { version = "1.38.0", features = ["rt", "macros"] } tracing-appender = "0.2.3" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" -md5 = "0.7.0" -serial_test = "3.1.1" web-time = "1.1.0" -hound = "3.5.1" -image = "0.25.1" zip = "2.1.3" # Terminal UI diff --git a/crates/burn-common/Cargo.toml b/crates/burn-common/Cargo.toml index ecd48b27bb..ec35ffbf8a 100644 --- a/crates/burn-common/Cargo.toml +++ b/crates/burn-common/Cargo.toml @@ -12,7 +12,7 @@ version.workspace = true [features] default = ["std"] -std = ["rand/std"] +std = ["rand/std", "data-encoding/std"] doc = ["default"] wasm-sync = [] network = ["dep:indicatif", "dep:reqwest", "dep:tokio"] @@ -27,10 +27,10 @@ web-time = { version = "1.1.0" } # ** Please make sure all dependencies support no_std when std is disabled ** rand = { workspace = true } -spin = { workspace = true } # using in place of use std::sync::Mutex; -uuid = { workspace = true } +spin = { workspace = true } # using in place of use std::sync::Mutex; derive-new = { workspace = true } serde = { workspace = true } +data-encoding = { workspace = true } # Network downloader indicatif = { workspace = true, optional = true } diff --git a/crates/burn-common/src/id.rs b/crates/burn-common/src/id.rs index 25c2161817..eb6bd900c9 100644 --- a/crates/burn-common/src/id.rs +++ b/crates/burn-common/src/id.rs @@ -1,18 +1,21 @@ +use alloc::string::String; + use crate::rand::gen_random; -use alloc::string::{String, ToString}; -use uuid::{Builder, Bytes}; + +use data_encoding::BASE32_DNSSEC; /// Simple ID generator. pub struct IdGenerator {} impl IdGenerator { - /// Generates a new ID in the form of a UUID. + /// Generates a new ID. pub fn generate() -> String { - let random_bytes: Bytes = gen_random(); - - let uuid = Builder::from_random_bytes(random_bytes).into_uuid(); + // Generate 6 random bytes (281,474,976,710,656 combinations) + let random_bytes: [u8; 6] = gen_random(); - uuid.as_hyphenated().to_string() + // Encode the random bytes in base32 DNSSEC + // 6 bytes encodes to 10 lower case characters, e.g. "3uu5e6vv7c" + BASE32_DNSSEC.encode(&random_bytes) } } diff --git a/crates/burn-core/src/module/display.rs b/crates/burn-core/src/module/display.rs new file mode 100644 index 0000000000..7a43b3e6b0 --- /dev/null +++ b/crates/burn-core/src/module/display.rs @@ -0,0 +1,547 @@ +use alloc::{ + borrow::ToOwned, + format, + string::{String, ToString}, + vec::Vec, +}; +use core::any; +use core::fmt::{Display, Write}; + +/// Default display settings for a module. +pub trait ModuleDisplayDefault { + /// Attributes of the module used for display purposes. + /// + /// # Arguments + /// + /// * `_content` - The content object that contains display settings and attributes. + /// + /// # Returns + /// + /// An optional content object containing the display attributes. + fn content(&self, _content: Content) -> Option; + + /// Gets the number of the parameters of the module. + fn num_params(&self) -> usize { + 0 + } +} + +/// Trait to implement custom display settings for a module. +/// +/// In order to implement custom display settings for a module, +/// 1. Add #[module(custom_display)] attribute to the module struct after #[derive(Module)] +/// 2. Implement ModuleDisplay trait for the module +pub trait ModuleDisplay: ModuleDisplayDefault { + /// Formats the module with provided display settings. + /// + /// # Arguments + /// + /// * `passed_settings` - Display settings passed to the module. + /// + /// # Returns + /// + /// A string representation of the formatted module. + fn format(&self, passed_settings: DisplaySettings) -> String { + let settings = if let Some(custom_settings) = self.custom_settings() { + custom_settings.inherit(passed_settings) + } else { + passed_settings + }; + + let indent = " ".repeat(settings.level * settings.indentation_size()); + let indent_close_braces = " ".repeat((settings.level - 1) * settings.indentation_size()); + + let settings = settings.level_up(); + + let self_type = extract_type_name::(); + + // Use custom content if it is implemented and show_all_attributes is false, + // otherwise use default content + let content = if !settings.show_all_attributes() { + self.custom_content(Content::new(settings.clone())) + .unwrap_or_else(|| { + self.content(Content::new(settings.clone())) + .unwrap_or_else(|| { + panic!("Default content should be implemented for {self_type}.") + }) + }) + } else { + self.content(Content::new(settings.clone())) + .unwrap_or_else(|| panic!("Default content should be implemented for {self_type}.")) + }; + + let top_level_type = if let Some(top_level_type) = content.top_level_type { + top_level_type.to_owned() + } else { + self_type.to_owned() + }; + + // If there is only one item in the content, return it or no attributes + if let Some(item) = content.single_item { + return item; + } else if content.attributes.is_empty() { + return top_level_type.to_string(); + } + + let mut result = String::new(); + + // Print the struct name + if settings.new_line_after_attribute() { + writeln!(result, "{} {{", top_level_type).unwrap(); + } else { + write!(result, "{} {{", top_level_type).unwrap(); + } + + for (i, attribute) in content.attributes.iter().enumerate() { + if settings.new_line_after_attribute() { + writeln!(result, "{indent}{}: {}", attribute.name, attribute.value).unwrap(); + } else if i == 0 { + write!(result, "{}: {}", attribute.name, attribute.value).unwrap(); + } else { + write!(result, ", {}: {}", attribute.name, attribute.value).unwrap(); + } + } + + if settings.show_num_parameters() { + let num_params = self.num_params(); + if num_params > 0 { + if settings.new_line_after_attribute() { + writeln!(result, "{indent}params: {}", num_params).unwrap(); + } else { + write!(result, ", params: {}", num_params).unwrap(); + } + } + } + + if settings.new_line_after_attribute() { + write!(result, "{indent_close_braces}}}").unwrap(); + } else { + write!(result, "}}").unwrap(); + } + + result + } + + /// Custom display settings for the module. + /// + /// # Returns + /// + /// An optional display settings object. + fn custom_settings(&self) -> Option { + None + } + + /// Custom attributes for the module. + /// + /// # Arguments + /// + /// * `_content` - The content object that contains display settings and attributes. + /// + /// # Returns + /// + /// An optional content object containing the custom attributes. + fn custom_content(&self, _content: Content) -> Option { + None + } +} + +/// Custom module display settings. +#[derive(Debug, Clone)] +pub struct DisplaySettings { + /// Whether to print the module parameter ids. + show_param_id: Option, + + /// Whether to print the module attributes. + show_all_attributes: Option, + + /// Whether to print the module number of parameters. + show_num_parameters: Option, + + /// Print new line after an attribute. + new_line_after_attribute: Option, + + /// Indentation size. + indentation_size: Option, + + /// Level of indentation. + level: usize, +} + +impl Default for DisplaySettings { + fn default() -> Self { + DisplaySettings { + show_param_id: None, + show_all_attributes: None, + show_num_parameters: None, + new_line_after_attribute: None, + indentation_size: None, + level: 1, + } + } +} + +impl DisplaySettings { + /// Create a new format settings. + /// + /// # Returns + /// + /// A new instance of `DisplaySettings`. + pub fn new() -> Self { + Default::default() + } + + /// Sets a flag to show module parameters. + /// + /// # Arguments + /// + /// * `flag` - Boolean flag to show module parameters. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn with_show_param_id(mut self, flag: bool) -> Self { + self.show_param_id = Some(flag); + self + } + + /// Sets a flag to show module attributes. + /// + /// # Arguments + /// + /// * `flag` - Boolean flag to show all module attributes. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn with_show_all_attributes(mut self, flag: bool) -> Self { + self.show_all_attributes = Some(flag); + self + } + + /// Sets a flag to show the number of module parameters. + /// + /// # Arguments + /// + /// * `flag` - Boolean flag to show the number of module parameters. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn with_show_num_parameters(mut self, flag: bool) -> Self { + self.show_num_parameters = Some(flag); + self + } + + /// Sets a flag to print a new line after an attribute. + /// + /// # Arguments + /// + /// * `flag` - Boolean flag to print a new line after an attribute. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn with_new_line_after_attribute(mut self, flag: bool) -> Self { + self.new_line_after_attribute = Some(flag); + self + } + + /// Sets the indentation size. + /// + /// # Arguments + /// + /// * `size` - The size of the indentation. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn with_indentation_size(mut self, size: usize) -> Self { + self.indentation_size = Some(size); + self + } + + /// Inherits settings from the provided settings and return a new settings object. + /// + /// # Arguments + /// + /// * `top` - The top level `DisplaySettings` to inherit from. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance. + pub fn inherit(self, top: Self) -> Self { + let mut updated = self.clone(); + + if let Some(show_param_id) = top.show_param_id { + updated.show_param_id = Some(show_param_id); + }; + + if let Some(show_all_attributes) = top.show_all_attributes { + updated.show_all_attributes = Some(show_all_attributes); + } + + if let Some(show_num_parameters) = top.show_num_parameters { + updated.show_num_parameters = Some(show_num_parameters); + } + + if let Some(new_line_after_attribute) = top.new_line_after_attribute { + updated.new_line_after_attribute = Some(new_line_after_attribute); + } + + if let Some(indentation_size) = top.indentation_size { + updated.indentation_size = Some(indentation_size); + } + + updated.level = top.level; + + updated + } + + /// A convenience method to wrap the DisplaySettings struct in an option. + /// + /// # Returns + /// + /// An optional `DisplaySettings`. + pub fn optional(self) -> Option { + Some(self) + } + + /// Increases the level of indentation. + /// + /// # Returns + /// + /// Updated `DisplaySettings` instance with increased indentation level. + pub fn level_up(mut self) -> Self { + self.level += 1; + self + } + + /// Gets `show_param_id` flag, substitutes false if not set. + /// + /// This flag is used to print the module parameter ids. + /// + /// # Returns + /// + /// A boolean value indicating whether to show parameter ids. + pub fn show_param_id(&self) -> bool { + self.show_param_id.unwrap_or(false) + } + + /// Gets `show_all_attributes`, substitutes false if not set. + /// + /// This flag is used to force to print all module attributes, overriding custom attributes. + /// + /// # Returns + /// + /// A boolean value indicating whether to show all attributes. + pub fn show_all_attributes(&self) -> bool { + self.show_all_attributes.unwrap_or(false) + } + + /// Gets `show_num_parameters`, substitutes true if not set. + /// + /// This flag is used to print the number of module parameters. + /// + /// # Returns + /// + /// A boolean value indicating whether to show the number of parameters. + pub fn show_num_parameters(&self) -> bool { + self.show_num_parameters.unwrap_or(true) + } + + /// Gets `new_line_after_attribute`, substitutes true if not set. + /// + /// This flag is used to print a new line after an attribute. + /// + /// # Returns + /// + /// A boolean value indicating whether to print a new line after an attribute. + pub fn new_line_after_attribute(&self) -> bool { + self.new_line_after_attribute.unwrap_or(true) + } + + /// Gets `indentation_size`, substitutes 2 if not set. + /// + /// This flag is used to set the size of indentation. + /// + /// # Returns + /// + /// An integer value indicating the size of indentation. + pub fn indentation_size(&self) -> usize { + self.indentation_size.unwrap_or(2) + } +} + +/// Struct to store the attributes of a module for formatting. +#[derive(Clone, Debug)] +pub struct Content { + /// List of attributes. + pub attributes: Vec, + + /// Single item content. + pub single_item: Option, + + /// Display settings. + pub display_settings: DisplaySettings, + + /// Top level type name. + pub top_level_type: Option, +} + +impl Content { + /// Creates a new attributes struct. + /// + /// # Arguments + /// + /// * `display_settings` - Display settings for the content. + /// + /// # Returns + /// + /// A new instance of `Content`. + pub fn new(display_settings: DisplaySettings) -> Self { + Content { + attributes: Vec::new(), + single_item: None, + display_settings, + top_level_type: None, + } + } + + /// Adds an attribute to the format settings. The value will be formatted and stored as a string. + /// + /// # Arguments + /// + /// * `name` - Name of the attribute. + /// * `value` - Value of the attribute. + /// + /// # Returns + /// + /// Updated `Content` instance with the new attribute added. + pub fn add(mut self, name: &str, value: &T) -> Self { + if self.single_item.is_some() { + panic!("Cannot add multiple attributes when single item is set."); + } + + let attribute = Attribute { + name: name.to_owned(), + value: value.format(self.display_settings.clone()), // TODO level + 1 + ty: any::type_name::().to_string(), + }; + self.attributes.push(attribute); + self + } + + /// Adds a single item. + /// + /// # Arguments + /// + /// * `value` - Rendered string of the single item. + /// + /// # Returns + /// + /// Updated `Content` instance with the single item added. + pub fn add_single(mut self, value: &T) -> Self { + if !self.attributes.is_empty() { + panic!("Cannot add single item when attributes are set."); + } + + self.single_item = Some(value.format(self.display_settings.clone())); + + self + } + + /// Adds a single item. + /// + /// # Arguments + /// + /// * `value` - Formatted display value. + /// + /// # Returns + /// + /// Updated `Content` instance with the formatted single item added. + pub fn add_formatted(mut self, value: &T) -> Self { + if !self.attributes.is_empty() { + panic!("Cannot add single item when attributes are set."); + } + + self.single_item = Some(format!("{}", value)); + self + } + + /// A convenience method to wrap the Attributes struct in an option + /// because it is often used as an optional field. + /// + /// # Returns + /// + /// An optional `Content`. + pub fn optional(self) -> Option { + if self.attributes.is_empty() && self.single_item.is_none() && self.top_level_type.is_none() + { + None + } else { + Some(self) + } + } + + /// Sets the top level type name. + /// + /// # Arguments + /// + /// * `ty` - The type name to set. + /// + /// # Returns + /// + /// Updated `Content` instance with the top level type name set. + pub fn set_top_level_type(mut self, ty: &str) -> Self { + self.top_level_type = Some(ty.to_owned()); + self + } +} + +/// Attribute to print in the display method. +#[derive(Clone, Debug)] +pub struct Attribute { + /// Name of the attribute. + pub name: String, + + /// Value of the attribute. + pub value: String, + + /// Type of the attribute. + pub ty: String, +} + +/// Extracts the short name of a type T +/// +/// # Returns +/// +/// A string slice representing the short name of the type. +pub fn extract_type_name() -> &'static str { + // Get the full type name of T, including module path and generic parameters + let ty = any::type_name::(); + + // Find the first occurrence of '<' in the full type name + // If not found, use the length of the type name + let end = ty.find('<').unwrap_or(ty.len()); + + // Slice the type name up to the first '<' or the end + let ty = &ty[0..end]; + + // Find the last occurrence of "::" in the sliced type name + // If found, add 2 to skip the "::" itself + // If not found, start from the beginning of the type name + let start = ty.rfind("::").map(|i| i + 2).unwrap_or(0); + + // Find the last occurrence of '<' in the sliced type name + // If not found, use the length of the type name + let end = ty.rfind('<').unwrap_or(ty.len()); + + // If the start index is less than the end index, + // return the slice of the type name from start to end + // Otherwise, return the entire sliced type name + if start < end { + &ty[start..end] + } else { + ty + } +} diff --git a/crates/burn-core/src/module/mod.rs b/crates/burn-core/src/module/mod.rs index 60d1567523..5e9e8ef934 100644 --- a/crates/burn-core/src/module/mod.rs +++ b/crates/burn-core/src/module/mod.rs @@ -1,5 +1,7 @@ mod base; +mod display; mod param; pub use base::*; +pub use display::*; pub use param::*; diff --git a/crates/burn-core/src/module/param/constant.rs b/crates/burn-core/src/module/param/constant.rs index 24c8e3a99e..5d969f520a 100644 --- a/crates/burn-core/src/module/param/constant.rs +++ b/crates/burn-core/src/module/param/constant.rs @@ -1,6 +1,12 @@ +use alloc::{format, string::ToString}; +use core::{fmt::Display, marker::PhantomData}; + use crate::{ self as burn, - module::{AutodiffModule, Devices, Module, ModuleMapper, ModuleVisitor}, + module::{ + AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault, + ModuleMapper, ModuleVisitor, + }, record::Record, }; use burn::record::PrecisionSettings; @@ -8,7 +14,6 @@ use burn_tensor::{ backend::{AutodiffBackend, Backend}, BasicAutodiffOps, BasicOps, Tensor, }; -use core::marker::PhantomData; /// Record used for constant type implementing the [module](crate::module::Module) trait. #[derive(Debug, Clone, Copy, new, Default)] @@ -96,6 +101,15 @@ macro_rules! constant { impl burn::module::AutodiffModule for $type { constant!(ad_module, $type); } + + impl burn::module::ModuleDisplayDefault for $type { + fn content(&self, content: burn::module::Content) -> Option { + let string = format!("{}", self); + content.add_formatted(&string).optional() + } + } + + impl burn::module::ModuleDisplay for $type {} }; } @@ -122,6 +136,13 @@ constant!(i32); constant!(i16); constant!(i8); +impl burn::module::ModuleDisplay for str {} +impl burn::module::ModuleDisplayDefault for str { + fn content(&self, content: burn::module::Content) -> Option { + content.add_formatted(&self).optional() + } +} + impl> Module for Tensor { type Record = ConstantRecord; @@ -158,6 +179,15 @@ impl> Module for Tensor { } } +impl> ModuleDisplayDefault for Tensor { + fn content(&self, content: Content) -> Option { + let string = format!("Tensor {{rank: {D}, shape: {:?}}}", self.shape().dims); + content.add_single(&string).optional() + } +} + +impl> ModuleDisplay for Tensor {} + impl> AutodiffModule for Tensor { @@ -200,6 +230,14 @@ impl Module for PhantomData { } } +impl ModuleDisplayDefault for PhantomData { + fn content(&self, content: Content) -> Option { + content.add_single(&"PhantomData".to_string()).optional() + } +} + +impl ModuleDisplay for PhantomData {} + impl AutodiffModule for PhantomData { type InnerModule = PhantomData; @@ -248,6 +286,27 @@ where } } +impl ModuleDisplayDefault for Ignored +where + T: Sync + Send + core::fmt::Debug + Clone, +{ + fn content(&self, content: Content) -> Option { + // For now, just print the debug representation of the ignored value + content.add_single(&format!("{:?}", self.0)).optional() + } +} + +impl ModuleDisplay for Ignored where T: Sync + Send + core::fmt::Debug + Clone {} + +impl Display for Ignored +where + T: Sync + Send + core::fmt::Debug + Clone, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{:?}", self.0) + } +} + impl AutodiffModule for Ignored where B: AutodiffBackend, diff --git a/crates/burn-core/src/module/param/primitive.rs b/crates/burn-core/src/module/param/primitive.rs index 840732fd46..719b61d5c7 100644 --- a/crates/burn-core/src/module/param/primitive.rs +++ b/crates/burn-core/src/module/param/primitive.rs @@ -1,5 +1,10 @@ -use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}; -use alloc::vec::Vec; +use crate::module::{ + AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, + ModuleVisitor, +}; + +use alloc::{format, vec::Vec}; + use burn_tensor::backend::{AutodiffBackend, Backend}; use core::fmt::Debug; @@ -52,6 +57,17 @@ where } } +impl ModuleDisplayDefault for Option { + fn content(&self, content: Content) -> Option { + match self { + Some(module) => content.add_single(module).optional(), + None => content.add_single("None").optional(), + } + } +} + +impl ModuleDisplay for Option {} + impl AutodiffModule for Option where T: AutodiffModule + Debug + Send + Clone, @@ -128,6 +144,21 @@ where } } +impl ModuleDisplayDefault for Vec { + fn content(&self, content: Content) -> Option { + self.iter() + .enumerate() + .fold(content, |acc, (i, module)| { + let index = format!("{}", i); + acc.add(&index, module) + }) + .set_top_level_type(format!("Vec<0..{}>", self.len()).as_str()) + .optional() + } +} + +impl ModuleDisplay for Vec {} + impl AutodiffModule for Vec where T: AutodiffModule + Debug + Send + Clone, @@ -197,6 +228,21 @@ where } } +impl ModuleDisplayDefault for [T; N] { + fn content(&self, content: Content) -> Option { + self.iter() + .enumerate() + .fold(content, |acc, (i, module)| { + let index = format!("{}", i); + acc.add(&index, module) + }) + .set_top_level_type(format!("[0..{}]", self.len()).as_str()) + .optional() + } +} + +impl ModuleDisplay for [T; N] {} + impl AutodiffModule for [T; N] where T: AutodiffModule + Debug + Send + Clone + Copy, @@ -269,6 +315,21 @@ macro_rules! impl_module_tuple { ($(self.$i.valid(),)*) } } + + impl<$($l,)*> ModuleDisplayDefault for ($($l,)*) + where + $($l: ModuleDisplay,)* + { + fn content(&self, content: Content) -> Option { + let content = content + $(.add(&format!("{}", $i), &self.$i))* + .set_top_level_type(format!("({})", stringify!($($l),*)).as_str()); + content.optional() + } + } + + impl<$($l,)*> ModuleDisplay for ($($l,)*) where $($l: ModuleDisplay,)* {} + }; } diff --git a/crates/burn-core/src/module/param/running.rs b/crates/burn-core/src/module/param/running.rs index 7b2b2cf6a6..b81e2576f0 100644 --- a/crates/burn-core/src/module/param/running.rs +++ b/crates/burn-core/src/module/param/running.rs @@ -1,7 +1,13 @@ use super::ParamId; -use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor, Param}; +use crate::module::{ + AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, + ModuleVisitor, Param, +}; + +use alloc::string::ToString; use alloc::sync::Arc; use alloc::vec::Vec; + use burn_common::stub::Mutex; use burn_tensor::{ backend::{AutodiffBackend, Backend}, @@ -45,6 +51,24 @@ pub struct RunningState { value: Arc>, } +// Implement display for the module + +impl core::fmt::Display for RunningState { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "RunningState(id={})", self.id) + } +} + +impl ModuleDisplayDefault for RunningState { + fn content(&self, content: Content) -> Option { + content + .add_formatted(&"RunningState".to_string()) + .optional() + } +} + +impl ModuleDisplay for RunningState {} + impl Module for RunningState> { type Record = Param>; diff --git a/crates/burn-core/src/module/param/tensor.rs b/crates/burn-core/src/module/param/tensor.rs index cb60252835..f9ce7e913c 100644 --- a/crates/burn-core/src/module/param/tensor.rs +++ b/crates/burn-core/src/module/param/tensor.rs @@ -1,10 +1,13 @@ use super::{Param, ParamId, Parameter}; -use crate::module::{AutodiffModule, Module, ModuleMapper, ModuleVisitor}; +use crate::module::{ + AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper, + ModuleVisitor, +}; use crate::tensor::{ backend::{AutodiffBackend, Backend}, Tensor, }; -use alloc::vec::Vec; +use alloc::{format, string::ToString, vec::Vec}; use burn_tensor::{Bool, Data, Float, Int}; impl Parameter for Tensor { @@ -147,6 +150,22 @@ impl Module for Param> { } } +impl ModuleDisplayDefault for Param> { + fn content(&self, content: Content) -> Option { + let id = if content.display_settings.show_param_id() { + format!(", id: {}", self.id) + } else { + "".to_string() + }; + let string = format!( + "ParamTensor {{rank: {D}, shape: {:?}, kind: float{id}}}", + self.shape().dims + ); + content.add_formatted(&string).optional() + } +} +impl ModuleDisplay for Param> {} + impl Module for Param> { type Record = Param>; @@ -198,6 +217,22 @@ impl Module for Param> { } } +impl ModuleDisplayDefault for Param> { + fn content(&self, content: Content) -> Option { + let id = if content.display_settings.show_param_id() { + format!(", id: {}", self.id) + } else { + "".to_string() + }; + let string = format!( + "ParamTensor {{rank: {D}, shape: {:?}, kind: int{id}}}", + self.shape().dims + ); + content.add_formatted(&string).optional() + } +} +impl ModuleDisplay for Param> {} + impl Module for Param> { type Record = Param>; @@ -249,6 +284,24 @@ impl Module for Param> { } } +impl ModuleDisplayDefault for Param> { + fn content(&self, content: Content) -> Option { + let id = if content.display_settings.show_param_id() { + format!(", id: {}", self.id) + } else { + "".to_string() + }; + + let string = format!( + "ParamTensor {{rank: {D}, shape: {:?}, kind: bool{id}}}", + self.shape().dims + ); + content.add_formatted(&string).optional() + } +} + +impl ModuleDisplay for Param> {} + impl AutodiffModule for Param> { type InnerModule = Param>; diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index a14d668a82..e05231b274 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -1,14 +1,13 @@ +use alloc::format; + use crate as burn; -use crate::config::Config; -use crate::module::Module; -use crate::module::Param; -use crate::nn::conv::checks; -use crate::nn::{Initializer, PaddingConfig1d}; -use crate::tensor::backend::Backend; -use crate::tensor::module::conv1d; -use crate::tensor::ops::ConvOptions; -use crate::tensor::Tensor; +use crate::{ + config::Config, + module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param}, + nn::{conv::checks, Initializer, PaddingConfig1d}, + tensor::{backend::Backend, module::conv1d, ops::ConvOptions, Tensor}, +}; /// Configuration to create a [1D convolution](Conv1d) layer using the [init function](Conv1dConfig::init). #[derive(Config, Debug)] @@ -45,6 +44,7 @@ pub struct Conv1dConfig { /// /// Should be created with [Conv1dConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Conv1d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size]` pub weight: Param>, @@ -54,7 +54,28 @@ pub struct Conv1d { kernel_size: usize, dilation: usize, groups: usize, - padding: PaddingConfig1d, + padding: Ignored, +} + +impl ModuleDisplay for Conv1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + // Since padding does not implement ModuleDisplay, we need to format it manually. + let padding_formatted = format!("{}", &self.padding); + + content + .add("stride", &self.stride) + .add("kernel_size", &self.kernel_size) + .add("dilation", &self.dilation) + .add("groups", &self.groups) + .add("padding", &padding_formatted) + .optional() + } } impl Conv1dConfig { @@ -87,7 +108,7 @@ impl Conv1dConfig { bias, stride: self.stride, kernel_size: self.kernel_size, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), dilation: self.dilation, groups: self.groups, } diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index c7350d9916..ed34a089d4 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -1,8 +1,9 @@ +use alloc::format; + use crate as burn; use crate::config::Config; -use crate::module::Module; -use crate::module::Param; +use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param}; use crate::nn::Initializer; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -45,6 +46,7 @@ pub struct Conv2dConfig { /// /// Should be created with [Conv2dConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Conv2d { /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, @@ -54,7 +56,7 @@ pub struct Conv2d { kernel_size: [usize; 2], dilation: [usize; 2], groups: usize, - padding: PaddingConfig2d, + padding: Ignored, } impl Conv2dConfig { @@ -93,12 +95,38 @@ impl Conv2dConfig { stride: self.stride, kernel_size: self.kernel_size, dilation: self.dilation, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), groups: self.groups, } } } +impl ModuleDisplay for Conv2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + // Since padding does not implement ModuleDisplay, we need to format it manually. + let padding_formatted = format!("{}", &self.padding); + + // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed. + let stride = format!("{:?}", self.stride); + let kernel_size = format!("{:?}", self.kernel_size); + let dilation = format!("{:?}", self.dilation); + + content + .add("stride", &stride) + .add("kernel_size", &kernel_size) + .add("dilation", &dilation) + .add("groups", &self.groups) + .add("padding", &padding_formatted) + .optional() + } +} + impl Conv2d { /// Applies the forward pass on the input tensor. /// diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index e10dcf50b4..b4bee8d61e 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{DisplaySettings, Module, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::{Distribution, Tensor}; @@ -21,6 +21,7 @@ pub struct DropoutConfig { /// /// Should be created with [DropoutConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct Dropout { prob: f64, } @@ -54,6 +55,18 @@ impl Dropout { } } +impl ModuleDisplay for Dropout { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: crate::module::Content) -> Option { + content.add("prob", &self.prob).optional() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index d54da5f6fd..0bc16552f1 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -1,4 +1,6 @@ use crate as burn; +use crate::module::DisplaySettings; +use crate::module::ModuleDisplay; use crate::config::Config; use crate::module::Module; @@ -30,6 +32,7 @@ pub struct LinearConfig { /// /// `O = IW + b` #[derive(Module, Debug)] +#[module(custom_display)] pub struct Linear { /// Matrix of shape `[d_input, d_output]` initialized from a uniform distribution: /// `U(-k, k)`, where `k = sqrt(1 / d_input)` @@ -83,6 +86,23 @@ impl Linear { } } +impl ModuleDisplay for Linear { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: crate::module::Content) -> Option { + let [d_input, d_output] = self.weight.shape().dims; + content + .add("d_input", &d_input) + .add("d_output", &d_output) + .add("bias", &self.bias.is_some()) + .optional() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/burn-core/src/nn/norm/batch.rs b/crates/burn-core/src/nn/norm/batch.rs index 9fdca5ed46..e0d06d4923 100644 --- a/crates/burn-core/src/nn/norm/batch.rs +++ b/crates/burn-core/src/nn/norm/batch.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::Initializer; use crate::{ @@ -33,6 +34,7 @@ pub struct BatchNormConfig { /// /// Should be created using [BatchNormConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct BatchNorm { /// The learnable weight gamma. pub gamma: Param>, @@ -183,6 +185,24 @@ impl BatchNorm { } } +impl ModuleDisplay for BatchNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [num_features] = self.beta.shape().dims; + + content + .add("num_features", &num_features) + .add("momentum", &self.momentum) + .add("epsilon", &self.epsilon) + .optional() + } +} + #[cfg(feature = "std")] #[cfg(test)] mod tests_1d { diff --git a/crates/burn-core/src/nn/norm/layer.rs b/crates/burn-core/src/nn/norm/layer.rs index c0dc71afa8..f425f0c12f 100644 --- a/crates/burn-core/src/nn/norm/layer.rs +++ b/crates/burn-core/src/nn/norm/layer.rs @@ -1,7 +1,8 @@ use crate as burn; - use crate::config::Config; +use crate::module::DisplaySettings; use crate::module::Module; +use crate::module::ModuleDisplay; use crate::module::Param; use crate::nn::Initializer; use crate::tensor::backend::Backend; @@ -29,6 +30,7 @@ pub struct LayerNormConfig { /// /// Should be created using [LayerNormConfig](LayerNormConfig). #[derive(Module, Debug)] +#[module(custom_display)] pub struct LayerNorm { /// The learnable weight. gamma: Param>, @@ -71,6 +73,22 @@ impl LayerNorm { } } +impl ModuleDisplay for LayerNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: crate::module::Content) -> Option { + let [d_model] = self.gamma.shape().dims; + content + .add("d_model", &d_model) + .add("epsilon", &self.epsilon) + .optional() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/burn-core/src/nn/padding.rs b/crates/burn-core/src/nn/padding.rs index 8f64340108..f4ab4750ee 100644 --- a/crates/burn-core/src/nn/padding.rs +++ b/crates/burn-core/src/nn/padding.rs @@ -3,10 +3,9 @@ use crate as burn; use crate::tensor::ops::conv::calculate_conv_padding; use crate::config::Config; -use crate::module::Module; /// Padding configuration for 1D operators. -#[derive(Module, Config, Debug, PartialEq)] +#[derive(Config, Debug, PartialEq)] pub enum PaddingConfig1d { /// Dynamically calculate the amount of padding necessary to ensure that the output size will be /// the same as the input. @@ -34,7 +33,7 @@ impl PaddingConfig1d { } /// Padding configuration for 2D operators. -#[derive(Module, Config, Debug, PartialEq)] +#[derive(Config, Debug, PartialEq)] pub enum PaddingConfig2d { /// Dynamically calculate the amount of padding necessary to ensure that the output size will be /// the same as the input. diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 4b58eb3025..5787cc5e2a 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -43,7 +43,7 @@ pub struct AvgPool1dConfig { pub struct AvgPool1d { stride: usize, kernel_size: usize, - padding: PaddingConfig1d, + padding: Ignored, count_include_pad: bool, } @@ -53,7 +53,7 @@ impl AvgPool1dConfig { AvgPool1d { stride: self.stride, kernel_size: self.kernel_size, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), count_include_pad: self.count_include_pad, } } diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index fb2aff8d2e..00bf712f80 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -42,7 +42,7 @@ pub struct AvgPool2dConfig { pub struct AvgPool2d { stride: [usize; 2], kernel_size: [usize; 2], - padding: PaddingConfig2d, + padding: Ignored, count_include_pad: bool, } @@ -52,7 +52,7 @@ impl AvgPool2dConfig { AvgPool2d { stride: self.strides, kernel_size: self.kernel_size, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), count_include_pad: self.count_include_pad, } } diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 632ab6622d..040a7a1027 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -31,7 +31,7 @@ pub struct MaxPool1dConfig { pub struct MaxPool1d { stride: usize, kernel_size: usize, - padding: PaddingConfig1d, + padding: Ignored, dilation: usize, } @@ -41,7 +41,7 @@ impl MaxPool1dConfig { MaxPool1d { stride: self.stride, kernel_size: self.kernel_size, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), dilation: self.dilation, } } diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index 63dee1326d..552cde9b35 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -31,7 +31,7 @@ pub struct MaxPool2dConfig { pub struct MaxPool2d { stride: [usize; 2], kernel_size: [usize; 2], - padding: PaddingConfig2d, + padding: Ignored, dilation: [usize; 2], } @@ -41,7 +41,7 @@ impl MaxPool2dConfig { MaxPool2d { stride: self.strides, kernel_size: self.kernel_size, - padding: self.padding.clone(), + padding: Ignored(self.padding.clone()), dilation: self.dilation, } } diff --git a/crates/burn-core/src/nn/unfold.rs b/crates/burn-core/src/nn/unfold.rs index 31acb1a87f..c958883e2b 100644 --- a/crates/burn-core/src/nn/unfold.rs +++ b/crates/burn-core/src/nn/unfold.rs @@ -1,12 +1,11 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; -use crate::tensor::backend::Backend; -use crate::tensor::ops::UnfoldOptions; -use crate::tensor::Tensor; - -use crate::tensor::module::unfold4d; +use crate::module::{Ignored, Module}; +use burn_tensor::backend::Backend; +use burn_tensor::module::unfold4d; +use burn_tensor::ops::UnfoldOptions; +use burn_tensor::Tensor; /// Configuration to create an [unfold 4d](Unfold4d) layer using the [init function](Unfold4dConfig::init). #[derive(Config, Debug)] @@ -29,14 +28,14 @@ pub struct Unfold4dConfig { /// Should be created with [Unfold4dConfig]. #[derive(Module, Clone, Debug)] pub struct Unfold4d { - config: Unfold4dConfig, + config: Ignored, } impl Unfold4dConfig { /// Initializes a new [Unfold4d] module. pub fn init(&self) -> Unfold4d { Unfold4d { - config: self.clone(), + config: Ignored(self.clone()), } } } @@ -48,7 +47,7 @@ impl Unfold4d { /// /// # Shapes /// - /// input: `[batch_size, channels_in, height, width]` + /// input: `[batch_size, channels_in, height, width]` /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]` pub fn forward(&self, input: Tensor) -> Tensor { unfold4d( diff --git a/crates/burn-core/src/record/serde/ser.rs b/crates/burn-core/src/record/serde/ser.rs index effeeab7e1..acc765ec88 100644 --- a/crates/burn-core/src/record/serde/ser.rs +++ b/crates/burn-core/src/record/serde/ser.rs @@ -370,6 +370,6 @@ mod tests { // Compare the lengths of expected and actual serialized strings because // the order of the fields is not guaranteed for HashMaps. - assert_eq!(serialized_str.len(), 134); + assert_eq!(serialized_str.len(), 108); } } diff --git a/crates/burn-derive/src/lib.rs b/crates/burn-derive/src/lib.rs index 91f254b88f..a9af743ca3 100644 --- a/crates/burn-derive/src/lib.rs +++ b/crates/burn-derive/src/lib.rs @@ -13,7 +13,7 @@ pub(crate) mod record; pub(crate) mod shared; /// Derive macro for the module. -#[proc_macro_derive(Module)] +#[proc_macro_derive(Module, attributes(module))] pub fn module_derive(input: TokenStream) -> TokenStream { let input = syn::parse(input).unwrap(); module::derive_impl(&input) diff --git a/crates/burn-derive/src/module/codegen.rs b/crates/burn-derive/src/module/codegen.rs index f8ae69145a..a6ab07efb3 100644 --- a/crates/burn-derive/src/module/codegen.rs +++ b/crates/burn-derive/src/module/codegen.rs @@ -2,7 +2,7 @@ use super::{display, record::ModuleRecordCodegen}; use crate::shared::generics::GenericsHelper; use proc_macro2::{Ident, TokenStream}; use quote::quote; -use syn::{parse_quote, Generics}; +use syn::{parse_quote, Attribute, Generics}; /// Basic trait to be implemented for Module generation. pub(crate) trait ModuleCodegen { @@ -30,8 +30,8 @@ pub(crate) fn generate_module_standard( let generics = GenericsParser::from_ast(&ast.generics); - let display_fn = display::display_fn(name); - + let display_fn = display::display_fn(ast); + let attributes_fn = display::attributes_fn(ast); let num_params_fn = codegen.gen_num_params(); let visit = codegen.gen_visit(); let map_mut = codegen.gen_map(); @@ -54,7 +54,7 @@ pub(crate) fn generate_module_standard( let generics_ty_inner_module = generics.inner_module_ty; - let gen = quote! { + let mut gen = quote! { impl #generics_module burn::module::Module for #name #generics_ty_module #generics_where_module { type Record = #record_name #generics_ty_module; @@ -69,6 +69,7 @@ pub(crate) fn generate_module_standard( #collect_devices #to_device #fork + } impl #generics_module_autodiff burn::module::AutodiffModule for #name #generics_ty_module_autodiff #generics_where_module_autodiff @@ -82,6 +83,15 @@ pub(crate) fn generate_module_standard( #display_fn } + + impl #generics_module burn::module::ModuleDisplayDefault for #name #generics_ty_module #generics_where_module { + #attributes_fn + + fn num_params(&self) -> usize { + burn::module::Module::num_params(self) + } + } + impl #generics_module Clone for #name #generics_ty_module #generics_where_module { #clone_fn } @@ -89,13 +99,21 @@ pub(crate) fn generate_module_standard( #record_type }; + if !has_custom_display(&ast.attrs) { + gen.extend(quote! { + impl #generics_module burn::module::ModuleDisplay for #name #generics_ty_module #generics_where_module { + + } + }); + } + gen } // When there is no backend in the generic parameter, the type is considered as a constant. pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream { let name = &ast.ident; - let (_generics, generics_ty, generics_where) = ast.generics.split_for_impl(); + let (generics, generics_ty, generics_where) = ast.generics.split_for_impl(); let backend: syn::Generics = parse_quote! { }; let backend_ad: syn::Generics = parse_quote! { }; @@ -112,7 +130,10 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream { let (generics_module, _, _) = generics_module.split_for_impl(); let (generics_module_ad, _, _) = generics_module_autodiff.split_for_impl(); - let gen = quote! { + let display_fn = display::display_fn(ast); + let attributes_fn = display::attributes_fn(ast); + + let mut gen = quote! { impl #generics_module burn::module::Module for #name #generics_ty #generics_where { burn::constant!(module); } @@ -121,8 +142,26 @@ pub(crate) fn generate_module_const(ast: &syn::DeriveInput) -> TokenStream { for #name #generics_ty #generics_where { burn::constant!(ad_module, #name #generics_ty); } + + impl #generics core::fmt::Display for #name #generics_ty #generics_where { + #display_fn + } + + + impl #generics burn::module::ModuleDisplayDefault for #name #generics_ty #generics_where { + #attributes_fn + } + }; + if !has_custom_display(&ast.attrs) { + gen.extend(quote! { + impl #generics burn::module::ModuleDisplay for #name #generics_ty #generics_where { + + } + }); + } + gen } @@ -159,22 +198,64 @@ impl GenericsParser { #ident: burn::module::Module } ); + + module.add_predicate( + parse_quote! { + #ident: burn::module::ModuleDisplayDefault + } + ); + + module.add_predicate( + parse_quote! { + #ident: burn::module::ModuleDisplay + } + ); + + module_autodiff.add_predicate( parse_quote! { #ident: burn::module::AutodiffModule } ); - module_autodiff.add_predicate( + + module_autodiff.add_predicate( parse_quote! { <#ident as burn::module::AutodiffModule>::InnerModule: burn::module::Module } ); + + module_autodiff.add_predicate( + parse_quote! { + <#ident as burn::module::AutodiffModule>::InnerModule: burn::module::ModuleDisplay + } + ); + + module_autodiff.add_predicate( + parse_quote! { + <#ident as burn::module::AutodiffModule>::InnerModule: burn::module::ModuleDisplay + } + ); + + generics_names_except_backend.extend(quote! { <#ident as burn::module::AutodiffModule>::InnerModule, }); module_autodiff.add_predicate( parse_quote! { #ident: burn::module::Module } ); + + module_autodiff.add_predicate( + parse_quote! { + #ident: burn::module::ModuleDisplayDefault + } + ); + + module_autodiff.add_predicate( + parse_quote! { + #ident: burn::module::ModuleDisplay + } + ); + }); module.consts().into_iter().for_each(|ident| { @@ -188,3 +269,18 @@ impl GenericsParser { } } } + +fn has_custom_display(attrs: &[Attribute]) -> bool { + attrs.iter().any(|attr| { + attr.path().is_ident("module") + && attr + .parse_nested_meta(|meta| { + if meta.path.is_ident("custom_display") { + Ok(()) + } else { + Err(meta.error("unsupported attribute")) + } + }) + .is_ok() + }) +} diff --git a/crates/burn-derive/src/module/display.rs b/crates/burn-derive/src/module/display.rs index f9c024ff49..7c70726997 100644 --- a/crates/burn-derive/src/module/display.rs +++ b/crates/burn-derive/src/module/display.rs @@ -1,11 +1,96 @@ -use proc_macro2::Ident; use quote::quote; -pub fn display_fn(name: &Ident) -> proc_macro2::TokenStream { +pub fn attributes_fn(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { + match &ast.data { + syn::Data::Struct(ref data_struct) => { + let fields = match &data_struct.fields { + syn::Fields::Named(ref named_fields) => { + named_fields.named.iter().collect::>() + } + syn::Fields::Unit => Vec::new(), + _ => panic!("attributes_fn only supports structs with named or unit fields"), + }; + let field_prints = fields.iter().map(|field| { + let field_name = &field.ident; + quote! { .add(stringify!(#field_name), &self.#field_name) } + }); + let struct_name = &ast.ident; + quote! { + fn content(&self, mut content: burn::module::Content) -> Option { + content + .set_top_level_type(&stringify!(#struct_name)) + #(#field_prints)* + .optional() + } + } + } + syn::Data::Enum(ref data_enum) => { + let variant_prints = data_enum.variants.iter().map(|variant| { + let variant_name = &variant.ident; + match &variant.fields { + syn::Fields::Unit => { + quote! { + Self::#variant_name => { + content.add_formatted(&stringify!(#variant_name).to_string()) + .optional() + + } + } + } + syn::Fields::Named(ref named_fields) => { + let field_prints = named_fields.named.iter().map(|field| { + let field_name = &field.ident; + quote! { .add(stringify!(#field_name), &self.#field_name) } + }); + + let field_names = named_fields.named.iter().map(|field| { + let field_name = &field.ident; + quote! { #field_name } + }); + + quote! { + Self::#variant_name { #(#field_names),* } => { + content.set_top_level_type(&stringify!(#variant_name)) + #(#field_prints)* + .optional() + } + } + } + syn::Fields::Unnamed(ref unnamed_fields) => { + let field_names = (0..unnamed_fields.unnamed.len()).map(|i| { + syn::Ident::new(&format!("_{}", i), proc_macro2::Span::call_site()) + }); + + let field_prints = field_names.clone().map(|field_name| { + quote! { .add(stringify!(#field_name), #field_name) } + }); + quote! { + Self::#variant_name(#(#field_names),*) => { + content.set_top_level_type(&stringify!(#variant_name)) + #(#field_prints)* + .optional() + } + } + } + } + }); + quote! { + fn content(&self, mut content: burn::module::Content) -> Option { + match self { + #(#variant_prints)* + } + } + } + } + _ => panic!("attributes_fn only supports structs and enums"), + } +} + +pub fn display_fn(_ast: &syn::DeriveInput) -> proc_macro2::TokenStream { quote! { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{}[num_params={}]", stringify!(#name), self.num_params()) - + let formatted = burn::module::ModuleDisplay::format(self, Default::default()); + write!(f, "{}", formatted) } } } diff --git a/crates/burn-train/src/learner/summary.rs b/crates/burn-train/src/learner/summary.rs index 7fee928bd8..4e058e302b 100644 --- a/crates/burn-train/src/learner/summary.rs +++ b/crates/burn-train/src/learner/summary.rs @@ -151,7 +151,7 @@ impl Display for LearnerSummary { )?; if let Some(model) = &self.model { - writeln!(f, "Model: {model}")?; + writeln!(f, "Model:\n{model}")?; } writeln!(f, "Total Epochs: {epochs}\n\n", epochs = self.epochs)?; From 0f8dd57d9cdce98909c1821d2885ac0929a84fe4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:23:36 -0400 Subject: [PATCH 3/4] Combined PRs (#1928) * Bump syn from 2.0.66 to 2.0.68 Bumps [syn](https://github.com/dtolnay/syn) from 2.0.66 to 2.0.68. - [Release notes](https://github.com/dtolnay/syn/releases) - [Commits](https://github.com/dtolnay/syn/compare/2.0.66...2.0.68) --- updated-dependencies: - dependency-name: syn dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump proc-macro2 from 1.0.85 to 1.0.86 Bumps [proc-macro2](https://github.com/dtolnay/proc-macro2) from 1.0.85 to 1.0.86. - [Release notes](https://github.com/dtolnay/proc-macro2/releases) - [Commits](https://github.com/dtolnay/proc-macro2/compare/1.0.85...1.0.86) --- updated-dependencies: - dependency-name: proc-macro2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump wgpu from 0.20.0 to 0.20.1 Bumps [wgpu](https://github.com/gfx-rs/wgpu) from 0.20.0 to 0.20.1. - [Release notes](https://github.com/gfx-rs/wgpu/releases) - [Changelog](https://github.com/gfx-rs/wgpu/blob/v0.20.1/CHANGELOG.md) - [Commits](https://github.com/gfx-rs/wgpu/compare/v0.20.0...v0.20.1) --- updated-dependencies: - dependency-name: wgpu dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump strum from 0.26.2 to 0.26.3 Bumps [strum](https://github.com/Peternator7/strum) from 0.26.2 to 0.26.3. - [Release notes](https://github.com/Peternator7/strum/releases) - [Changelog](https://github.com/Peternator7/strum/blob/master/CHANGELOG.md) - [Commits](https://github.com/Peternator7/strum/compare/v0.26.2...v0.26.3) --- updated-dependencies: - dependency-name: strum dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump bytemuck from 1.16.0 to 1.16.1 Bumps [bytemuck](https://github.com/Lokathor/bytemuck) from 1.16.0 to 1.16.1. - [Changelog](https://github.com/Lokathor/bytemuck/blob/main/changelog.md) - [Commits](https://github.com/Lokathor/bytemuck/compare/v1.16.0...v1.16.1) --- updated-dependencies: - dependency-name: bytemuck dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 108 ++++++++++++++++++++++++++--------------------------- Cargo.toml | 10 ++--- 2 files changed, 59 insertions(+), 59 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 827ce13c76..0fbbafa610 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,7 +171,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -197,7 +197,7 @@ checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -546,7 +546,7 @@ dependencies = [ "derive-new", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -604,7 +604,7 @@ dependencies = [ "derive-new", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -642,7 +642,7 @@ dependencies = [ "serde_json", "strum", "strum_macros", - "syn 2.0.66", + "syn 2.0.68", "thiserror", "tracing-core", "tracing-subscriber", @@ -777,9 +777,9 @@ dependencies = [ [[package]] name = "bytemuck" -version = "1.16.0" +version = "1.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" +checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" dependencies = [ "bytemuck_derive", ] @@ -792,7 +792,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1030,7 +1030,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1442,7 +1442,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.10.0", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1453,7 +1453,7 @@ checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" dependencies = [ "darling_core", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1498,7 +1498,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1509,7 +1509,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1530,7 +1530,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1540,7 +1540,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1551,7 +1551,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1626,7 +1626,7 @@ checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1678,7 +1678,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1864,7 +1864,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -1955,7 +1955,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -2888,7 +2888,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3288,7 +3288,7 @@ checksum = "bf307cbbbd777a9c10cec88ddafee572b3484caad5cce0c9236523c3803105a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3448,7 +3448,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3710,7 +3710,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3870,7 +3870,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -3970,9 +3970,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -3993,7 +3993,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -4544,7 +4544,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.66", + "syn 2.0.68", "unicode-ident", ] @@ -4822,7 +4822,7 @@ checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -4889,7 +4889,7 @@ checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5039,7 +5039,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ff9eaf853dec4c8802325d8b6d3dffa86cc707fd7a1a4cdbf416e13b061787a" dependencies = [ "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5068,9 +5068,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ "strum_macros", ] @@ -5085,7 +5085,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5107,9 +5107,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.66" +version = "2.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c42f3f41a2de00b01c0aaad383c5a45241efc8b2d1eda5661812fda5f3cdcff5" +checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" dependencies = [ "proc-macro2", "quote", @@ -5130,7 +5130,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5321,7 +5321,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5459,7 +5459,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5595,7 +5595,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -5858,7 +5858,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "wasm-bindgen-shared", ] @@ -5892,7 +5892,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -5966,9 +5966,9 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "wgpu" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ff1bfee408e1028e2e3acbf6d32d98b08a5a059ccbf5f33305534453ba5d3e" +checksum = "90e37c7b9921b75dfd26dd973fdcbce36f13dfa6e2dc82aece584e0ed48c355c" dependencies = [ "arrayvec", "cfg-if", @@ -5992,9 +5992,9 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6a86eaa5e763e59c73cf9e97d55fffd4dfda69fd8bda19589fcf851ddfef1f" +checksum = "d59e0d5fc509601c69e4e1fa06c1eb3c4c9f12956a5e30c79b61ef1c1be7daf0" dependencies = [ "arrayvec", "bit-vec", @@ -6019,9 +6019,9 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d71c8ae05170583049b65ee562fd839fdc0b3e9ddb84f4e40c9d5f8ea0d4c8c" +checksum = "6aa24c3889f885a3fb9133b454c8418bfcfaadcfe4ed3be96ac80e76703b863b" dependencies = [ "android_system_properties", "arrayvec", @@ -6318,7 +6318,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -6401,7 +6401,7 @@ checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "synstructure", ] @@ -6422,7 +6422,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] @@ -6442,7 +6442,7 @@ checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", "synstructure", ] @@ -6463,7 +6463,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.66", + "syn 2.0.68", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8e9737edf9..14628caf54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ license = "MIT OR Apache-2.0" [workspace.dependencies] async-trait = "0.1.80" -bytemuck = "1.16.0" +bytemuck = "1.16.1" candle-core = { version = "0.5.1" } clap = { version = "4.5.7", features = ["derive"] } colored = "2.1.0" @@ -54,7 +54,7 @@ log = { default-features = false, version = "0.4.21" } md5 = "0.7.0" percent-encoding = "2.3.1" pretty_assertions = "1.4.0" -proc-macro2 = "1.0.85" +proc-macro2 = "1.0.86" protobuf = "3.4.0" protobuf-codegen = "3.4.0" quote = "1.0.36" @@ -71,9 +71,9 @@ sanitize-filename = "0.5.0" serde_rusqlite = "0.35.0" serial_test = "3.1.1" spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } -strum = "0.26.2" +strum = "0.26.3" strum_macros = "0.26.4" -syn = { version = "2.0.66", features = ["full", "extra-traits"] } +syn = { version = "2.0.68", features = ["full", "extra-traits"] } tempfile = "3.10.1" thiserror = "1.0.61" tokio = { version = "1.38.0", features = ["rt", "macros"] } @@ -91,7 +91,7 @@ crossterm = "0.27.0" futures-intrusive = "0.5.0" text_placeholder = "0.5.0" pollster = "0.3.0" -wgpu = "0.20.0" +wgpu = "0.20.1" # Benchmarks and Burnbench arboard = "3.4.0" From 2fbc4628f3ddbd25e223084985c94d2c46dccb58 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 25 Jun 2024 09:55:55 -0400 Subject: [PATCH 4/4] Feat/cube/array assign ops (#1914) --- .../src/codegen_function/base.rs | 8 ++ .../src/codegen_function/expr.rs | 11 +- .../src/codegen_function/operation.rs | 111 ++++++++++++++---- .../src/codegen_function/variable.rs | 14 ++- .../src/frontend/operation/assignation.rs | 84 +++++++++++++ .../burn-cube/src/frontend/operation/base.rs | 40 +++++++ crates/burn-cube/tests/frontend/array.rs | 84 +++++++++++++ crates/burn-cube/tests/frontend/mod.rs | 1 + 8 files changed, 323 insertions(+), 30 deletions(-) create mode 100644 crates/burn-cube/tests/frontend/array.rs diff --git a/crates/burn-cube-macros/src/codegen_function/base.rs b/crates/burn-cube-macros/src/codegen_function/base.rs index f6f82621fe..e0e25c8c1d 100644 --- a/crates/burn-cube-macros/src/codegen_function/base.rs +++ b/crates/burn-cube-macros/src/codegen_function/base.rs @@ -49,6 +49,12 @@ pub(crate) fn codegen_block( pub(crate) struct Codegen { pub tokens: proc_macro2::TokenStream, pub is_comptime: bool, + pub array_indexing: Option, +} + +pub(crate) struct ArrayIndexing { + pub array: proc_macro2::TokenStream, + pub index: proc_macro2::TokenStream, } impl From for Codegen { @@ -56,6 +62,7 @@ impl From for Codegen { Self { tokens, is_comptime: false, + array_indexing: None, } } } @@ -65,6 +72,7 @@ impl Codegen { Self { tokens: tokens.into(), is_comptime, + array_indexing: None, } } diff --git a/crates/burn-cube-macros/src/codegen_function/expr.rs b/crates/burn-cube-macros/src/codegen_function/expr.rs index a509642abf..2b46f1eb80 100644 --- a/crates/burn-cube-macros/src/codegen_function/expr.rs +++ b/crates/burn-cube-macros/src/codegen_function/expr.rs @@ -25,6 +25,7 @@ pub(crate) fn codegen_expr( syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker), syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker), _ => { + let mut array_indexing = None; let tokens = match expr { syn::Expr::Path(path) => { return codegen_path_var(path, loop_level, variable_tracker) @@ -50,7 +51,11 @@ pub(crate) fn codegen_expr( syn::Expr::MethodCall(call) => { codegen_expr_method_call(call, loop_level, variable_tracker) } - syn::Expr::Index(index) => codegen_index(index, loop_level, variable_tracker), + syn::Expr::Index(index) => { + let codegen = codegen_index(index, loop_level, variable_tracker); + array_indexing = codegen.array_indexing; + codegen.tokens + } syn::Expr::Array(array) => codegen_array_lit(array), syn::Expr::Reference(reference) => { codegen_ref(reference, loop_level, variable_tracker) @@ -67,7 +72,9 @@ pub(crate) fn codegen_expr( } }; - Codegen::new(tokens, false) + let mut codegen = Codegen::new(tokens, false); + codegen.array_indexing = array_indexing; + codegen } } } diff --git a/crates/burn-cube-macros/src/codegen_function/operation.rs b/crates/burn-cube-macros/src/codegen_function/operation.rs index 8a9da3e60d..fa6235c045 100644 --- a/crates/burn-cube-macros/src/codegen_function/operation.rs +++ b/crates/burn-cube-macros/src/codegen_function/operation.rs @@ -8,7 +8,8 @@ pub(crate) fn codegen_binary( loop_level: usize, variable_tracker: &mut VariableTracker, ) -> Codegen { - let (lhs, is_comptime_lhs) = codegen_expr(&binary.left, loop_level, variable_tracker).split(); + let lhs = codegen_expr(&binary.left, loop_level, variable_tracker); + let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing); let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split(); if is_comptime_lhs && is_comptime_rhs { @@ -99,34 +100,94 @@ pub(crate) fn codegen_binary( burn_cube::frontend::eq::expand(context, _lhs, _rhs) } }, - syn::BinOp::AddAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs) + syn::BinOp::AddAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::SubAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::SubAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::MulAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::MulAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs) + } + } } - }, - syn::BinOp::DivAssign(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs) + } + syn::BinOp::DivAssign(_) => { + if let Some(array) = lhs_array { + let (array, index) = (array.array, array.index); + + quote::quote! { + { + let _array = #array; + let _index = #index; + let _value = #rhs; + burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value) + } + } + } else { + quote::quote! { + { + let _lhs = #lhs; + let _rhs = #rhs; + burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs) + } + } } - }, + } syn::BinOp::And(_) => quote::quote! { { diff --git a/crates/burn-cube-macros/src/codegen_function/variable.rs b/crates/burn-cube-macros/src/codegen_function/variable.rs index fe8a15b402..2ec75d9c1b 100644 --- a/crates/burn-cube-macros/src/codegen_function/variable.rs +++ b/crates/burn-cube-macros/src/codegen_function/variable.rs @@ -99,17 +99,25 @@ pub(crate) fn codegen_index( index: &syn::ExprIndex, loop_level: usize, variable_tracker: &mut VariableTracker, -) -> TokenStream { +) -> Codegen { let array = codegen_expr(&index.expr, loop_level, variable_tracker); let index = codegen_expr(&index.index, loop_level, variable_tracker); - quote::quote! { + let tokens = quote::quote! { { let _array = #array; let _index = #index; burn_cube::frontend::index::expand(context, _array, _index) } - } + }; + + let mut codegen = Codegen::new(tokens, false); + codegen.array_indexing = Some(super::base::ArrayIndexing { + array: array.tokens, + index: index.tokens, + }); + + codegen } /// Codegen for assignation diff --git a/crates/burn-cube/src/frontend/operation/assignation.rs b/crates/burn-cube/src/frontend/operation/assignation.rs index 5b1922774d..86edb022ee 100644 --- a/crates/burn-cube/src/frontend/operation/assignation.rs +++ b/crates/burn-cube/src/frontend/operation/assignation.rs @@ -113,6 +113,90 @@ pub mod index { impl_index!(SharedMemory); } +pub mod add_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Add); + } +} + +pub mod sub_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Sub); + } +} + +pub mod mul_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Mul); + } +} + +pub mod div_assign_array_op { + use crate::prelude::array_assign_binary_op_expand; + + use self::ir::Operator; + + use super::*; + + pub fn expand< + Array: Into, + Index: Into, + Value: Into, + >( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + ) { + array_assign_binary_op_expand(context, array, index, value, Operator::Div); + } +} + pub mod add_assign_op { use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; diff --git a/crates/burn-cube/src/frontend/operation/base.rs b/crates/burn-cube/src/frontend/operation/base.rs index b7263db485..4d0c705486 100644 --- a/crates/burn-cube/src/frontend/operation/base.rs +++ b/crates/burn-cube/src/frontend/operation/base.rs @@ -203,3 +203,43 @@ fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization output } + +pub fn array_assign_binary_op_expand< + Array: Into, + Index: Into, + Value: Into, + F: Fn(BinaryOperator) -> Operator, +>( + context: &mut CubeContext, + array: Array, + index: Index, + value: Value, + func: F, +) { + let array: ExpandElement = array.into(); + let index: ExpandElement = index.into(); + let value: ExpandElement = value.into(); + + let tmp = context.create_local(array.item()); + + let read = Operator::Index(BinaryOperator { + lhs: *array, + rhs: *index, + out: *tmp, + }); + let calculate = func(BinaryOperator { + lhs: *tmp, + rhs: *value, + out: *tmp, + }); + + let write = Operator::IndexAssign(BinaryOperator { + lhs: *index, + rhs: *tmp, + out: *array, + }); + + context.register(read); + context.register(calculate); + context.register(write); +} diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs new file mode 100644 index 0000000000..44a7fca4fa --- /dev/null +++ b/crates/burn-cube/tests/frontend/array.rs @@ -0,0 +1,84 @@ +use burn_cube::prelude::*; + +#[cube] +fn array_add_assign_simple(mut array: Array) { + array[UInt::new(1)] += UInt::new(1); +} + +#[cube] +fn array_add_assign_expr(mut array: Array) { + array[UInt::new(1) + UInt::new(5)] += UInt::new(1); +} + +mod tests { + use super::*; + use burn_cube::{ + cpa, + ir::{Elem, Item, Variable}, + }; + + #[test] + fn array_add_assign() { + let mut context = CubeContext::root(); + let array = context.input(0, Item::new(Elem::UInt)); + + array_add_assign_simple_expand(&mut context, array); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_array_add_assign_simple() + ); + } + + #[test] + fn array_add_assign_expr() { + let mut context = CubeContext::root(); + let array = context.input(0, Item::new(Elem::UInt)); + + array_add_assign_expr_expand(&mut context, array); + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_array_add_assign_expr() + ); + } + + fn inline_macro_array_add_assign_simple() -> String { + let context = CubeContext::root(); + + let mut scope = context.into_scope(); + let local = scope.create_local(Item::new(Elem::UInt)); + + let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); + let index = Variable::ConstantScalar(1., Elem::UInt); + let value = Variable::ConstantScalar(1., Elem::UInt); + + cpa!(scope, local = array[index]); + cpa!(scope, local += value); + cpa!(scope, array[index] = local); + + format!("{:?}", scope.operations) + } + + fn inline_macro_array_add_assign_expr() -> String { + let context = CubeContext::root(); + + let mut scope = context.into_scope(); + let index = scope.create_local(Item::new(Elem::UInt)); + let local = scope.create_local(Item::new(Elem::UInt)); + + let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); + let const1 = Variable::ConstantScalar(1., Elem::UInt); + let const2 = Variable::ConstantScalar(5., Elem::UInt); + let value = Variable::ConstantScalar(1., Elem::UInt); + + cpa!(scope, index = const1 + const2); + cpa!(scope, local = array[index]); + cpa!(scope, local += value); + cpa!(scope, array[index] = local); + + format!("{:?}", scope.operations) + } +} diff --git a/crates/burn-cube/tests/frontend/mod.rs b/crates/burn-cube/tests/frontend/mod.rs index c13c1300b3..f9b433e0ab 100644 --- a/crates/burn-cube/tests/frontend/mod.rs +++ b/crates/burn-cube/tests/frontend/mod.rs @@ -1,3 +1,4 @@ +mod array; mod assign; mod cast_elem; mod cast_kind;