From 32cd4c2626774342d988269e97227e088d455430 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:47:16 -0500 Subject: [PATCH 01/52] Display for ConvTranspose1d --- .../burn-core/src/nn/conv/conv_transpose1d.rs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/crates/burn-core/src/nn/conv/conv_transpose1d.rs b/crates/burn-core/src/nn/conv/conv_transpose1d.rs index 1a73191798..d9c2268ba8 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose1d.rs @@ -1,7 +1,12 @@ +use alloc::format; + use crate as burn; use crate::config::Config; +use crate::module::Content; +use crate::module::DisplaySettings; use crate::module::Module; +use crate::module::ModuleDisplay; use crate::module::Param; use crate::nn::conv::checks; use crate::nn::Initializer; @@ -45,6 +50,7 @@ pub struct ConvTranspose1dConfig { /// Applies a 1D transposed convolution over input tensors. #[derive(Module, Debug)] +#[module(custom_display)] pub struct ConvTranspose1d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size]` pub weight: Param>, @@ -56,6 +62,27 @@ pub struct ConvTranspose1d { groups: usize, padding: usize, padding_out: usize, + channels: [usize; 2], +} + +impl ModuleDisplay for ConvTranspose1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("channels", &format!("{:?}", &self.channels)) + .add("stride", &self.stride) + .add("kernel_size", &self.kernel_size) + .add("dilation", &self.dilation) + .add("groups", &self.groups) + .add("padding", &self.padding) + .add("padding_out", &self.padding_out) + .optional() + } } impl ConvTranspose1dConfig { @@ -91,6 +118,7 @@ impl ConvTranspose1dConfig { groups: self.groups, padding: self.padding, padding_out: self.padding_out, + channels: self.channels, } } } @@ -150,4 +178,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); } + + #[test] + fn print() { + let config = ConvTranspose1dConfig::new([5, 2], 5); + let conv = config.init::(&Default::default()); + + assert_eq!( + format!("{}", conv), + "ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}" + ); + } } From dab0d25b374ed45ca96e573a4c8f9fd14edb9373 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:47:44 -0500 Subject: [PATCH 02/52] Display for ConvTranspose2d --- .../burn-core/src/nn/conv/conv_transpose2d.rs | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/crates/burn-core/src/nn/conv/conv_transpose2d.rs b/crates/burn-core/src/nn/conv/conv_transpose2d.rs index 0d5132942d..8540f2dbf8 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose2d.rs @@ -1,7 +1,12 @@ +use alloc::format; + use crate as burn; use crate::config::Config; +use crate::module::Content; +use crate::module::DisplaySettings; use crate::module::Module; +use crate::module::ModuleDisplay; use crate::module::Param; use crate::nn::conv::checks; use crate::nn::Initializer; @@ -45,6 +50,7 @@ pub struct ConvTranspose2dConfig { /// Applies a 2D transposed convolution over input tensors. 
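///
/// # Example
///
/// A minimal usage sketch, mirroring the config API exercised in the tests
/// below; the `forward` call and the 4-D input shape are assumptions about
/// the module's API rather than part of this patch:
///
/// ```rust,ignore
/// // channels: [in, out], kernel size: [k1, k2]
/// let conv = ConvTranspose2dConfig::new([5, 2], [5, 5]).init::<B>(&device);
/// // input: [batch, channels_in, height, width]
/// let output = conv.forward(input);
/// ```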
#[derive(Module, Debug)] +#[module(custom_display)] pub struct ConvTranspose2d { /// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]` pub weight: Param>, @@ -56,6 +62,27 @@ pub struct ConvTranspose2d { groups: usize, padding: [usize; 2], padding_out: [usize; 2], + channels: [usize; 2], +} + +impl ModuleDisplay for ConvTranspose2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("channels", &format!("{:?}", &self.channels)) + .add("stride", &format!("{:?}", &self.stride)) + .add("kernel_size", &format!("{:?}", &self.kernel_size)) + .add("dilation", &format!("{:?}", &self.dilation)) + .add("groups", &self.groups) + .add("padding", &format!("{:?}", &self.padding)) + .add("padding_out", &format!("{:?}", &self.padding_out)) + .optional() + } } impl ConvTranspose2dConfig { @@ -92,6 +119,7 @@ impl ConvTranspose2dConfig { groups: self.groups, padding: self.padding, padding_out: self.padding_out, + channels: self.channels, } } } @@ -152,4 +180,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); } + + #[test] + fn print() { + let config = ConvTranspose2dConfig::new([5, 2], [5, 5]); + let conv = config.init::(&Default::default()); + + assert_eq!( + format!("{}", conv), + "ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}" + ); + } } From a34232438ce4d4e409d36a487880dc55102d32b3 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:48:05 -0500 Subject: [PATCH 03/52] Test for batch norm display --- crates/burn-core/src/nn/norm/batch.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crates/burn-core/src/nn/norm/batch.rs b/crates/burn-core/src/nn/norm/batch.rs index e0d06d4923..9aa46f3626 100644 --- a/crates/burn-core/src/nn/norm/batch.rs +++ b/crates/burn-core/src/nn/norm/batch.rs @@ -426,4 +426,15 @@ mod tests_2d { device, ) } + + #[test] + fn print() { + let batch_norm = + BatchNormConfig::new(3).init::(&Default::default()); + + assert_eq!( + format!("{}", batch_norm), + "BatchNorm {num_features: 3, momentum: 0.1, epsilon: 0.00001, params: 12}" + ); + } } From 9e2fa95ff8cf2c43dcce3be7c5d44444f10b79aa Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:48:16 -0500 Subject: [PATCH 04/52] Display for group norm --- crates/burn-core/src/nn/norm/group.rs | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/crates/burn-core/src/nn/norm/group.rs b/crates/burn-core/src/nn/norm/group.rs index 91c04c3d1d..48b6a64716 100644 --- a/crates/burn-core/src/nn/norm/group.rs +++ b/crates/burn-core/src/nn/norm/group.rs @@ -4,6 +4,7 @@ use crate::nn::Initializer; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -36,6 +37,7 @@ pub struct GroupNormConfig { /// /// Should be created using [GroupNormConfig](GroupNormConfig). 
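///
/// # Example
///
/// A minimal usage sketch based on the `display` test in this patch; the
/// `forward` call is assumed from the module's API, and `num_channels` must
/// be divisible by `num_groups`:
///
/// ```rust,ignore
/// let norm = GroupNormConfig::new(3, 6).init::<B>(&device);
/// let output = norm.forward(input);
/// ```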
#[derive(Module, Debug)] +#[module(custom_display)] pub struct GroupNorm { /// The learnable weight pub gamma: Option>>, @@ -48,6 +50,23 @@ pub struct GroupNorm { pub(crate) affine: bool, } +impl ModuleDisplay for GroupNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("num_groups", &self.num_groups) + .add("num_channels", &self.num_channels) + .add("epsilon", &self.epsilon) + .add("affine", &self.affine) + .optional() + } +} + impl GroupNormConfig { /// Initialize a new [group norm](GroupNorm) module. pub fn init(&self, device: &B::Device) -> GroupNorm { @@ -169,6 +188,7 @@ mod tests { use super::*; use crate::tensor::Data; use crate::TestBackend; + use alloc::format; #[test] fn group_norm_forward_affine_false() { @@ -296,4 +316,15 @@ mod tests { 3, ); } + + #[test] + fn print() { + let config = GroupNormConfig::new(3, 6); + let group_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", group_norm), + "GroupNorm {num_groups: 3, num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" + ); + } } From 1a6f30343e902c9a1624a89de0320581caf88f33 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:48:29 -0500 Subject: [PATCH 05/52] Display for Instance norm --- crates/burn-core/src/nn/norm/instance.rs | 30 ++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/crates/burn-core/src/nn/norm/instance.rs b/crates/burn-core/src/nn/norm/instance.rs index fb47505b60..55f5805641 100644 --- a/crates/burn-core/src/nn/norm/instance.rs +++ b/crates/burn-core/src/nn/norm/instance.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Module, Param}; use crate::nn::norm::group_norm; use crate::nn::Initializer; @@ -25,6 +26,7 @@ pub struct InstanceNormConfig { /// /// Should be created using [InstanceNormConfig](InstanceNormConfig). #[derive(Module, Debug)] +#[module(custom_display)] pub struct InstanceNorm { /// The learnable weight pub gamma: Option>>, @@ -36,6 +38,22 @@ pub struct InstanceNorm { affine: bool, } +impl ModuleDisplay for InstanceNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("num_channels", &self.num_channels) + .add("epsilon", &self.epsilon) + .add("affine", &self.affine) + .optional() + } +} + impl InstanceNormConfig { /// Initialize a new [instance norm](InstanceNorm) module. 
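///
/// A sketch of the resulting usage, following the tests in this file (the
/// `forward` call is assumed from the module's API):
///
/// ```rust,ignore
/// let norm = InstanceNormConfig::new(6).init::<B>(&device);
/// let output = norm.forward(input);
/// ```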
pub fn init(&self, device: &B::Device) -> InstanceNorm { @@ -83,6 +101,7 @@ mod tests { use super::*; use crate::tensor::Data; use crate::TestBackend; + use alloc::format; #[test] fn instance_norm_forward_affine_false() { @@ -191,4 +210,15 @@ mod tests { 3, ); } + + #[test] + fn print() { + let config = InstanceNormConfig::new(6); + let instance_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", instance_norm), + "InstanceNorm {num_channels: 6, epsilon: 0.00001, affine: true, params: 12}" + ); + } } From 1e73430c79aeb2496dcacef0f619859c7759bd0c Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:48:58 -0500 Subject: [PATCH 06/52] Test for layer display --- crates/burn-core/src/nn/norm/layer.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/norm/layer.rs b/crates/burn-core/src/nn/norm/layer.rs index f425f0c12f..7a700c65cc 100644 --- a/crates/burn-core/src/nn/norm/layer.rs +++ b/crates/burn-core/src/nn/norm/layer.rs @@ -1,5 +1,6 @@ use crate as burn; use crate::config::Config; +use crate::module::Content; use crate::module::DisplaySettings; use crate::module::Module; use crate::module::ModuleDisplay; @@ -80,7 +81,7 @@ impl ModuleDisplay for LayerNorm { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { let [d_model] = self.gamma.shape().dims; content .add("d_model", &d_model) @@ -93,6 +94,7 @@ impl ModuleDisplay for LayerNorm { mod tests { use super::*; use crate::tensor::Data; + use alloc::format; #[cfg(feature = "std")] use crate::{TestAutodiffBackend, TestBackend}; @@ -160,4 +162,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(tensor_2_grad.shape()), 3); } + + #[test] + fn display() { + let config = LayerNormConfig::new(6); + let layer_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", layer_norm), + "LayerNorm {d_model: 6, epsilon: 0.00001, params: 12}" + ); + } } From eea082130af34846f3bd6a55a9b703d8fc450add Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:49:08 -0500 Subject: [PATCH 07/52] Display for RMS --- crates/burn-core/src/nn/norm/rms.rs | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/crates/burn-core/src/nn/norm/rms.rs b/crates/burn-core/src/nn/norm/rms.rs index 2ac15df227..508d69694f 100644 --- a/crates/burn-core/src/nn/norm/rms.rs +++ b/crates/burn-core/src/nn/norm/rms.rs @@ -3,6 +3,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -48,6 +49,7 @@ impl RmsNormConfig { /// /// Should be created using the [RmsNormConfig](RmsNormConfig) configuration. 
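///
/// # Example
///
/// A minimal sketch matching the `display` test below (the `forward` call is
/// assumed from the module's API):
///
/// ```rust,ignore
/// let norm = RmsNormConfig::new(6).init::<B>(&device);
/// let output = norm.forward(input);
/// ```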
#[derive(Module, Debug)] +#[module(custom_display)] pub struct RmsNorm { /// The learnable parameter to scale the normalized tensor pub gamma: Param>, @@ -71,11 +73,28 @@ impl RmsNorm { } } +impl ModuleDisplay for RmsNorm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_model] = self.gamma.shape().dims; + content + .add("d_model", &d_model) + .add("epsilon", &self.epsilon) + .optional() + } +} + #[cfg(test)] mod tests { use super::*; use crate::tensor::Data; use crate::TestBackend; + use alloc::format; #[test] fn rms_norm_forward() { @@ -97,4 +116,15 @@ mod tests { 4, ); } + + #[test] + fn display() { + let config = RmsNormConfig::new(6); + let layer_norm = config.init::(&Default::default()); + + assert_eq!( + format!("{}", layer_norm), + "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}" + ); + } } From a762ede4ab0738aabe47ab4df6a0784609f0a667 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:49:31 -0500 Subject: [PATCH 08/52] Fix conv transpose --- crates/burn-import/src/burn/node/conv_transpose_2d.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/burn-import/src/burn/node/conv_transpose_2d.rs b/crates/burn-import/src/burn/node/conv_transpose_2d.rs index c0e7df7f1c..c0262ef784 100644 --- a/crates/burn-import/src/burn/node/conv_transpose_2d.rs +++ b/crates/burn-import/src/burn/node/conv_transpose_2d.rs @@ -98,6 +98,7 @@ impl NodeCodegen for ConvTranspose2dNode { groups: ConstantRecord::new(), padding: [ConstantRecord::new(); 2], padding_out: [ConstantRecord::new(); 2], + channels: [ConstantRecord::new(); 2], }; let item = Record::into_item::(record); From 80cb068910d9f9e3444e188f0f72b66eb9c11014 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:55:54 -0500 Subject: [PATCH 09/52] Test for conv1d display --- crates/burn-core/src/nn/conv/conv1d.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index e05231b274..8157a4ee23 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -169,4 +169,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(conv.weight.shape()), 3); } + + #[test] + fn print() { + let config = Conv1dConfig::new(5, 5, 5); + let conv = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", conv), + "Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}" + ); + } } From 359f7b03819baadd55e75ffbdacdcb481ce8dc9d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:56:00 -0500 Subject: [PATCH 10/52] Test for conv2d display --- crates/burn-core/src/nn/conv/conv2d.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index ed34a089d4..cfb9fa6ea4 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -214,4 +214,15 @@ mod tests { assert_eq!(config.initializer, init); } + + #[test] + fn print() { + let config = Conv2dConfig::new([5, 1], [5, 5]); + let conv = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", conv), + "Conv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], 
groups: 1, padding: Valid, params: 126}" + ); + } } From d76fca7b9988159f8e9b8f107d2dd5b8eb1c9474 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:31:47 -0500 Subject: [PATCH 11/52] Add display for BinaryCrossEntropyLoss --- .../src/nn/loss/binary_cross_entropy.rs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index 192e07ab2a..a3015d45b5 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::activation::log_sigmoid; use crate::tensor::{backend::Backend, Int, Tensor}; @@ -59,6 +60,7 @@ impl BinaryCrossEntropyLossConfig { /// /// Should be created using [BinaryCrossEntropyLossConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct BinaryCrossEntropyLoss { /// Weights for cross-entropy. pub weights: Option>, @@ -66,6 +68,22 @@ pub struct BinaryCrossEntropyLoss { logits: bool, } +impl ModuleDisplay for BinaryCrossEntropyLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("weights", &self.weights) + .add("smoothing", &self.smoothing) + .add("logits", &self.logits) + .optional() + } +} + impl BinaryCrossEntropyLoss { /// Compute the criterion on the input tensor. /// @@ -356,4 +374,16 @@ mod tests { .init(&device) .forward(logits, targets); } + + #[test] + fn print() { + let config = + BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9])); + let loss = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", loss), + "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}" + ); + } } From 5e5f4a3d45454c32d6229c80934fb13d7fd71ca2 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:35:38 -0500 Subject: [PATCH 12/52] Make attributes pub --- crates/burn-core/src/nn/loss/binary_cross_entropy.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index a3015d45b5..407feb5700 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -64,8 +64,10 @@ impl BinaryCrossEntropyLossConfig { pub struct BinaryCrossEntropyLoss { /// Weights for cross-entropy. pub weights: Option>, - smoothing: Option, - logits: bool, + /// Label smoothing alpha. 
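+ /// Expected to lie in `[0, 1]`; `None` disables smoothing.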
+ pub smoothing: Option, + /// Treat the inputs as logits + pub logits: bool, } impl ModuleDisplay for BinaryCrossEntropyLoss { From edf2945e7b712f16079163cc3a730b343385be0b Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:50:09 -0500 Subject: [PATCH 13/52] Display for cross entropy --- crates/burn-core/src/nn/loss/cross_entropy.rs | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/crates/burn-core/src/nn/loss/cross_entropy.rs b/crates/burn-core/src/nn/loss/cross_entropy.rs index 0dea70b517..0bf985505a 100644 --- a/crates/burn-core/src/nn/loss/cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/cross_entropy.rs @@ -1,8 +1,10 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::activation::log_softmax; use crate::tensor::{backend::Backend, Bool, Int, Tensor}; use crate::{config::Config, module::Module}; +use alloc::string::ToString; use alloc::vec; use alloc::vec::Vec; @@ -29,7 +31,7 @@ pub struct CrossEntropyLossConfig { /// Alpha = 0 would be the same as default. pub smoothing: Option, - /// Create cross-entropy with probabilities as input instead of logits. + /// Create cross-entropy with probabilities as input instead of logits. /// #[config(default = true)] pub logits: bool, @@ -71,12 +73,39 @@ impl CrossEntropyLossConfig { /// /// Should be created using [CrossEntropyLossConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct CrossEntropyLoss { - pad_tokens: Option>, + /// Pad tokens to ignore in the loss calculation. + pub pad_tokens: Option>, /// Weights for cross-entropy. pub weights: Option>, - smoothing: Option, - logits: bool, + /// Label smoothing factor. + pub smoothing: Option, + /// Use logits as input. 
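+ /// When `true`, the inputs are treated as raw, unnormalized scores.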
+ pub logits: bool, +} + +impl ModuleDisplay for CrossEntropyLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens { + alloc::format!("Vec<0..{}>", pad_tokens.len()) + } else { + "None".to_string() + }; + + content + .add("pad_tokens", &pad_tokens) + .add("weights", &self.weights) + .add("smoothing", &self.smoothing) + .add("logits", &self.logits) + .optional() + } } impl CrossEntropyLoss { @@ -406,4 +435,17 @@ mod tests { loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); } + + #[test] + fn display() { + let config = CrossEntropyLossConfig::new() + .with_weights(Some(alloc::vec![3., 7., 0.9])) + .with_smoothing(Some(0.5)); + let loss = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", loss), + "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}" + ); + } } From 7451557a444db6c7bf1ec05451acd47d1dcf8b2c Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:50:44 -0500 Subject: [PATCH 14/52] Rename print to display in tests --- crates/burn-core/src/nn/conv/conv1d.rs | 2 +- crates/burn-core/src/nn/conv/conv2d.rs | 2 +- crates/burn-core/src/nn/conv/conv_transpose1d.rs | 2 +- crates/burn-core/src/nn/conv/conv_transpose2d.rs | 2 +- crates/burn-core/src/nn/loss/binary_cross_entropy.rs | 2 +- crates/burn-core/src/nn/norm/batch.rs | 2 +- crates/burn-core/src/nn/norm/group.rs | 2 +- crates/burn-core/src/nn/norm/instance.rs | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index 8157a4ee23..acb62f15b1 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -171,7 +171,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = Conv1dConfig::new(5, 5, 5); let conv = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index cfb9fa6ea4..60b0147a3d 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -216,7 +216,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = Conv2dConfig::new([5, 1], [5, 5]); let conv = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/conv/conv_transpose1d.rs b/crates/burn-core/src/nn/conv/conv_transpose1d.rs index d9c2268ba8..27cc217c6a 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose1d.rs @@ -180,7 +180,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = ConvTranspose1dConfig::new([5, 2], 5); let conv = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/conv/conv_transpose2d.rs b/crates/burn-core/src/nn/conv/conv_transpose2d.rs index 8540f2dbf8..b9f9f3aa21 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose2d.rs @@ -182,7 +182,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = ConvTranspose2dConfig::new([5, 2], [5, 5]); let conv = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index 407feb5700..f0de046af9 100644 --- 
a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -378,7 +378,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9])); let loss = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/norm/batch.rs b/crates/burn-core/src/nn/norm/batch.rs index 9aa46f3626..d44a2d5f4f 100644 --- a/crates/burn-core/src/nn/norm/batch.rs +++ b/crates/burn-core/src/nn/norm/batch.rs @@ -428,7 +428,7 @@ mod tests_2d { } #[test] - fn print() { + fn display() { let batch_norm = BatchNormConfig::new(3).init::(&Default::default()); diff --git a/crates/burn-core/src/nn/norm/group.rs b/crates/burn-core/src/nn/norm/group.rs index 48b6a64716..f73e5c423a 100644 --- a/crates/burn-core/src/nn/norm/group.rs +++ b/crates/burn-core/src/nn/norm/group.rs @@ -318,7 +318,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = GroupNormConfig::new(3, 6); let group_norm = config.init::(&Default::default()); diff --git a/crates/burn-core/src/nn/norm/instance.rs b/crates/burn-core/src/nn/norm/instance.rs index 55f5805641..bae03bdfaa 100644 --- a/crates/burn-core/src/nn/norm/instance.rs +++ b/crates/burn-core/src/nn/norm/instance.rs @@ -212,7 +212,7 @@ mod tests { } #[test] - fn print() { + fn display() { let config = InstanceNormConfig::new(6); let instance_norm = config.init::(&Default::default()); From 7da0e1fabba382520ca019aa12b8c45df36d8c00 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 10:57:06 -0500 Subject: [PATCH 15/52] Removed PhantomData --- crates/burn-core/src/nn/loss/huber.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs index 06e4b0f143..758fdce5e0 100644 --- a/crates/burn-core/src/nn/loss/huber.rs +++ b/crates/burn-core/src/nn/loss/huber.rs @@ -3,7 +3,6 @@ use crate as burn; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use crate::{config::Config, module::Module}; -use core::marker::PhantomData; use super::Reduction; @@ -16,7 +15,7 @@ pub struct HuberLossConfig { impl HuberLossConfig { /// Initialize [Huber loss](HuberLoss). - pub fn init(&self, device: &B::Device) -> HuberLoss { + pub fn init(&self, device: &B::Device) -> HuberLoss { // device is not needed as of now, but we might want to prepare some data on it // and its consistent with other loss functions let _ = device; @@ -24,7 +23,6 @@ impl HuberLossConfig { HuberLoss { delta: self.delta, lin_bias: self.delta * self.delta * 0.5, - _backend: PhantomData, } } @@ -52,14 +50,13 @@ impl HuberLossConfig { /// This loss function is less sensitive to outliers than the mean squared error loss. /// /// See also: -#[derive(Module, Debug)] -pub struct HuberLoss { +#[derive(Module, Debug, Clone)] +pub struct HuberLoss { delta: f32, lin_bias: f32, // delta * delta * 0.5 precomputed - _backend: PhantomData, } -impl HuberLoss { +impl HuberLoss { /// Compute the loss element-wise for the predictions and targets, then reduce /// to a single loss value. 
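///
/// For a residual `r = prediction - target`, the element-wise loss is
/// `0.5 * r^2` when `|r| <= delta` and `delta * |r| - 0.5 * delta^2`
/// otherwise (the constant term is the precomputed `lin_bias` above).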
/// @@ -70,7 +67,7 @@ impl HuberLoss { /// - predictions: \[...dims\] /// - targets: \[...dims\] /// - output: \[1\] - pub fn forward( + pub fn forward( &self, predictions: Tensor, targets: Tensor, @@ -89,7 +86,7 @@ impl HuberLoss { /// - predictions: [...dims] /// - targets: [...dims] /// - output: [...dims] - pub fn forward_no_reduction( + pub fn forward_no_reduction( &self, predictions: Tensor, targets: Tensor, @@ -103,7 +100,10 @@ impl HuberLoss { /// /// - residuals: [...dims] /// - output: [...dims] - pub fn forward_residuals(&self, residuals: Tensor) -> Tensor { + pub fn forward_residuals( + &self, + residuals: Tensor, + ) -> Tensor { let is_large = residuals.clone().abs().greater_elem(self.delta); // We are interested in `sign(r)` when `abs(r) > self.delta`. Note that the // `sign()` function, in general, suffers from a jump at 0. From d13b8bbfc1434de6595124e7b06c6bccf7f14f71 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:03:41 -0500 Subject: [PATCH 16/52] Added huber display --- crates/burn-core/src/nn/loss/huber.rs | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs index 758fdce5e0..d0be944b5a 100644 --- a/crates/burn-core/src/nn/loss/huber.rs +++ b/crates/burn-core/src/nn/loss/huber.rs @@ -1,5 +1,6 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; use crate::{config::Config, module::Module}; @@ -15,10 +16,7 @@ pub struct HuberLossConfig { impl HuberLossConfig { /// Initialize [Huber loss](HuberLoss). - pub fn init(&self, device: &B::Device) -> HuberLoss { - // device is not needed as of now, but we might want to prepare some data on it - // and its consistent with other loss functions - let _ = device; + pub fn init(&self) -> HuberLoss { self.assertions(); HuberLoss { delta: self.delta, @@ -51,11 +49,27 @@ impl HuberLossConfig { /// /// See also: #[derive(Module, Debug, Clone)] +#[module(custom_display)] pub struct HuberLoss { delta: f32, lin_bias: f32, // delta * delta * 0.5 precomputed } +impl ModuleDisplay for HuberLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("delta", &self.delta) + .add("lin_bias", &self.lin_bias) + .optional() + } +} + impl HuberLoss { /// Compute the loss element-wise for the predictions and targets, then reduce /// to a single loss value. 
@@ -138,7 +152,7 @@ mod tests { let predict = TestTensor::<1>::from_data(predict, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let huber = HuberLossConfig::new(0.5).init(&device); + let huber = HuberLossConfig::new(0.5).init(); let loss_sum = huber.forward(predict.clone(), targets.clone(), Reduction::Sum); let loss = huber.forward(predict.clone(), targets.clone(), Reduction::Auto); @@ -165,7 +179,7 @@ mod tests { let predict = TestAutodiffTensor::from_data(predict, &device).require_grad(); let targets = TestAutodiffTensor::from_data(targets, &device); - let loss = HuberLossConfig::new(0.5).init(&device); + let loss = HuberLossConfig::new(0.5).init(); let loss = loss.forward_no_reduction(predict.clone(), targets); let grads = loss.backward(); @@ -175,4 +189,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::from([-0.5, -0.5, 0., 0.3, 0.5]), 3); } + + #[test] + fn display() { + let config = HuberLossConfig::new(0.5); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "HuberLoss {delta: 0.5, lin_bias: 0.125}" + ); + } } From 4711f9a9980403e66117792e4cf035871df78d5d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:09:42 -0500 Subject: [PATCH 17/52] Add display for MSE --- crates/burn-core/src/nn/loss/mse.rs | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/burn-core/src/nn/loss/mse.rs b/crates/burn-core/src/nn/loss/mse.rs index 0bd873a887..19dd50d38e 100644 --- a/crates/burn-core/src/nn/loss/mse.rs +++ b/crates/burn-core/src/nn/loss/mse.rs @@ -1,26 +1,24 @@ +use crate as burn; + use crate::nn::loss::reduction::Reduction; -use core::marker::PhantomData; +use crate::module::Module; use crate::tensor::{backend::Backend, Tensor}; /// Calculate the mean squared error loss from the input logits and the targets. -#[derive(Clone, Debug)] -pub struct MseLoss { - backend: PhantomData, -} +#[derive(Module, Clone, Debug)] +pub struct MseLoss; -impl Default for MseLoss { +impl Default for MseLoss { fn default() -> Self { Self::new() } } -impl MseLoss { +impl MseLoss { /// Create the criterion. pub fn new() -> Self { - Self { - backend: PhantomData, - } + Self } /// Compute the criterion on the input tensor. @@ -29,7 +27,7 @@ impl MseLoss { /// /// - logits: [batch_size, num_targets] /// - targets: [batch_size, num_targets] - pub fn forward( + pub fn forward( &self, logits: Tensor, targets: Tensor, @@ -43,7 +41,7 @@ impl MseLoss { } /// Compute the criterion on the input tensor without reducing. 
- pub fn forward_no_reduction( + pub fn forward_no_reduction( &self, logits: Tensor, targets: Tensor, @@ -79,4 +77,10 @@ mod tests { assert_eq!(loss.into_data(), Data::from([1.5])); assert_eq!(loss_sum.into_data(), Data::from([6.0])); } + + #[test] + fn display() { + let loss = MseLoss::new(); + assert_eq!(alloc::format!("{}", loss), "MseLoss"); + } } From 1a291a72cfb093f3c8a8725c1f4d5b077db0506f Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 12:11:02 -0500 Subject: [PATCH 18/52] Update log message format --- crates/burn-train/src/learner/train_val.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-train/src/learner/train_val.rs b/crates/burn-train/src/learner/train_val.rs index cbca2895a5..ed01271fc2 100644 --- a/crates/burn-train/src/learner/train_val.rs +++ b/crates/burn-train/src/learner/train_val.rs @@ -122,7 +122,7 @@ impl Learner { >::InnerModule: ValidStep, LC::EventProcessor: EventProcessor, { - log::info!("Fitting {}", self.model.to_string()); + log::info!("Fitting the model:\n {}", self.model.to_string()); // The reference model is always on the first device provided. if let Some(device) = self.devices.first() { self.model = self.model.fork(device); From a7bf1fc8bd38251d447d24ec2e67ad0684541048 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:45:22 -0500 Subject: [PATCH 19/52] Display for AdaptiveAvgPool1d --- .../src/nn/pool/adaptive_avg_pool1d.rs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs index dd2c1d33c7..cb2d50bcce 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -18,10 +19,23 @@ pub struct AdaptiveAvgPool1dConfig { /// /// Should be created with [AdaptiveAvgPool1dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AdaptiveAvgPool1d { output_size: usize, } +impl ModuleDisplay for AdaptiveAvgPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content.add("output_size", &self.output_size).optional() + } +} + impl AdaptiveAvgPool1dConfig { /// Initialize a new [adaptive avg pool 1d](AdaptiveAvgPool1d) module. 
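///
/// A sketch of the resulting usage; the input is assumed to follow the
/// `[batch, channels, length]` convention of 1-D pooling:
///
/// ```rust,ignore
/// let pool = AdaptiveAvgPool1dConfig::new(3).init();
/// // [batch, channels, length] -> [batch, channels, 3]
/// let output = pool.forward(input);
/// ```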
pub fn init(&self) -> AdaptiveAvgPool1d { @@ -44,3 +58,19 @@ impl AdaptiveAvgPool1d { adaptive_avg_pool1d(input, self.output_size) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AdaptiveAvgPool1dConfig::new(3); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AdaptiveAvgPool1d {output_size: 3}" + ); + } +} From d6c7db0c05efe922b22db54905cd22d950f7c247 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:45:31 -0500 Subject: [PATCH 20/52] Display for AdaptiveAvgPool2d --- .../src/nn/pool/adaptive_avg_pool2d.rs | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs index 8d4d55d424..587b723a9b 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -18,10 +19,25 @@ pub struct AdaptiveAvgPool2dConfig { /// /// Should be created with [AdaptiveAvgPool2dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AdaptiveAvgPool2d { output_size: [usize; 2], } +impl ModuleDisplay for AdaptiveAvgPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let output_size = alloc::format!("{:?}", self.output_size); + + content.add("output_size", &output_size).optional() + } +} + impl AdaptiveAvgPool2dConfig { /// Initialize a new [adaptive avg pool 2d](AdaptiveAvgPool2d) module. pub fn init(&self) -> AdaptiveAvgPool2d { @@ -44,3 +60,19 @@ impl AdaptiveAvgPool2d { adaptive_avg_pool2d(input, self.output_size) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AdaptiveAvgPool2dConfig::new([3, 3]); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AdaptiveAvgPool2d {output_size: [3, 3]}" + ); + } +} From 1c10e325012b4de95d9a0ea8027ebf281c964dbb Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:45:59 -0500 Subject: [PATCH 21/52] Make attributes pub for AvgPool1d --- crates/burn-core/src/nn/pool/avg_pool1d.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 5787cc5e2a..f5bc14ef69 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -41,10 +41,14 @@ pub struct AvgPool1dConfig { #[derive(Module, Clone, Debug)] pub struct AvgPool1d { - stride: usize, - kernel_size: usize, - padding: Ignored, - count_include_pad: bool, + /// The stride. + pub stride: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The padding configuration. + pub padding: Ignored, + /// If the padding is counted in the denominator when computing the average. 
+ pub count_include_pad: bool, } impl AvgPool1dConfig { From f453b99c9b0ca4b45186ba2d3c8bfe3ea3e6a591 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:58:07 -0500 Subject: [PATCH 22/52] Display for AvgPool1d and AvgPool2d --- crates/burn-core/src/nn/pool/avg_pool1d.rs | 35 +++++++++++++++++++++ crates/burn-core/src/nn/pool/avg_pool2d.rs | 36 ++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index f5bc14ef69..949160fd5b 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; @@ -40,6 +41,7 @@ pub struct AvgPool1dConfig { /// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AvgPool1d { /// The stride. pub stride: usize, @@ -51,6 +53,23 @@ pub struct AvgPool1d { pub count_include_pad: bool, } +impl ModuleDisplay for AvgPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &self.kernel_size) + .add("stride", &self.stride) + .add("padding", &self.padding) + .add("count_include_pad", &self.count_include_pad) + .optional() + } +} + impl AvgPool1dConfig { /// Initialize a new [avg pool 1d](AvgPool1d) module. pub fn init(&self) -> AvgPool1d { @@ -87,3 +106,19 @@ impl AvgPool1d { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AvgPool1dConfig::new(3); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AvgPool1d {kernel_size: 3, stride: 1, padding: Valid, count_include_pad: true}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 00bf712f80..36950b9436 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -39,6 +40,7 @@ pub struct AvgPool2dConfig { /// TODO: Add support for `count_include_pad=False`, see /// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct AvgPool2d { stride: [usize; 2], kernel_size: [usize; 2], @@ -46,6 +48,23 @@ pub struct AvgPool2d { count_include_pad: bool, } +impl ModuleDisplay for AvgPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("count_include_pad", &self.count_include_pad) + .optional() + } +} + impl AvgPool2dConfig { /// Initialize a new [avg pool 2d](AvgPool2d) module. 
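///
/// A sketch of the resulting usage; the input is assumed to follow the
/// `[batch, channels, height, width]` convention of 2-D pooling:
///
/// ```rust,ignore
/// let pool = AvgPool2dConfig::new([3, 3]).init();
/// let output = pool.forward(input);
/// ```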
pub fn init(&self) -> AvgPool2d { @@ -82,3 +101,20 @@ impl AvgPool2d { ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AvgPool2dConfig::new([3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "AvgPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, count_include_pad: true}" + ); + } +} From 13244715ce24a914777a3f98678c36a336ede5a6 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:08:27 -0500 Subject: [PATCH 23/52] Display for MaxPool1d and MaxPool2d --- crates/burn-core/src/nn/pool/max_pool1d.rs | 48 ++++++++++++++++++++-- crates/burn-core/src/nn/pool/max_pool2d.rs | 48 ++++++++++++++++++++-- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 040a7a1027..5be363e908 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool1dConfig { /// /// Should be created with [MaxPool1dConfig](MaxPool1dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool1d { - stride: usize, - kernel_size: usize, - padding: Ignored, - dilation: usize, + /// The stride. + pub stride: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The padding configuration. + pub padding: Ignored, + /// The dilation. + pub dilation: usize, +} + +impl ModuleDisplay for MaxPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &self.kernel_size) + .add("stride", &self.stride) + .add("padding", &self.padding) + .add("dilation", &self.dilation) + .optional() + } } impl MaxPool1dConfig { @@ -65,3 +88,20 @@ impl MaxPool1d { max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool1dConfig::new(3); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index 552cde9b35..ab9c60d276 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool2dConfig { /// /// Should be created with [MaxPool2dConfig](MaxPool2dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: Ignored, - dilation: [usize; 2], + /// The strides. + pub stride: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The padding configuration. + pub padding: Ignored, + /// The dilation. 
+ pub dilation: [usize; 2], +} + +impl ModuleDisplay for MaxPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .optional() + } } impl MaxPool2dConfig { @@ -65,3 +88,20 @@ impl MaxPool2d { max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool2dConfig::new([3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}" + ); + } +} From 7b736cb33d27bcbd6ae9a3054bf2acbcbc3dbebe Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:28:25 -0500 Subject: [PATCH 24/52] Add display for Gru --- crates/burn-core/src/nn/rnn/gru.rs | 45 +++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs index fc29bdc192..02c41050bc 100644 --- a/crates/burn-core/src/nn/rnn/gru.rs +++ b/crates/burn-core/src/nn/rnn/gru.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::rnn::gate_controller; use crate::nn::Initializer; use crate::tensor::activation; @@ -30,11 +31,35 @@ pub struct GruConfig { /// /// Should be created with [GruConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Gru { - update_gate: GateController, - reset_gate: GateController, - new_gate: GateController, - d_hidden: usize, + /// The update gate controller. + pub update_gate: GateController, + /// The reset gate controller. + pub reset_gate: GateController, + /// The new gate controller. + pub new_gate: GateController, + /// The size of the hidden state. 
+ pub d_hidden: usize, +} + +impl ModuleDisplay for Gru { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, _] = self.update_gate.input_transform.weight.shape().dims; + let bias = self.update_gate.input_transform.bias.is_some(); + + content + .add("d_input", &d_input) + .add("d_hidden", &self.d_hidden) + .add("bias", &bias) + .optional() + } } impl GruConfig { @@ -271,4 +296,16 @@ mod tests { assert_eq!(hidden_state.shape().dims, [8, 10, 1024]); } + + #[test] + fn display() { + let config = GruConfig::new(2, 8, true); + + let layer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}" + ); + } } From f38fc2e2176e51234c04163450a495be176b9caf Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:34:37 -0500 Subject: [PATCH 25/52] Add display for lstm --- crates/burn-core/src/nn/rnn/lstm.rs | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs index 8690b16d27..1d73beecb8 100644 --- a/crates/burn-core/src/nn/rnn/lstm.rs +++ b/crates/burn-core/src/nn/rnn/lstm.rs @@ -2,6 +2,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::nn::rnn::gate_controller::GateController; use crate::nn::Initializer; use crate::tensor::activation; @@ -43,6 +44,7 @@ pub struct LstmConfig { /// /// Should be created with [LstmConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Lstm { /// The input gate regulates which information to update and store in the cell state at each time step. pub input_gate: GateController, @@ -55,6 +57,24 @@ pub struct Lstm { d_hidden: usize, } +impl ModuleDisplay for Lstm { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, _] = self.input_gate.input_transform.weight.shape().dims; + let bias = self.input_gate.input_transform.bias.is_some(); + + content + .add("d_input", &d_input) + .add("d_hidden", &self.d_hidden) + .add("bias", &bias) + .optional() + } +} impl LstmConfig { /// Initialize a new [lstm](Lstm) module. 
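///
/// A sketch of the resulting usage; the `forward(input, state)` signature
/// and its `(output, state)` return value are assumptions about the Lstm
/// API, not shown in this patch:
///
/// ```rust,ignore
/// let lstm = LstmConfig::new(2, 3, true).init::<B>(&device);
/// let (output, state) = lstm.forward(input, None);
/// ```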
pub fn init(&self, device: &B::Device) -> Lstm { @@ -686,4 +706,16 @@ mod tests { .to_data() .assert_approx_eq(&expected_cn_without_init_state, 3); } + + #[test] + fn display() { + let config = LstmConfig::new(2, 3, true); + + let layer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", layer), + "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}" + ); + } } From a7f371754f20452789735e49c4e60c903bbf3beb Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:38:12 -0500 Subject: [PATCH 26/52] Add display test to dropout --- crates/burn-core/src/nn/dropout.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index b4bee8d61e..3cfb1f16e4 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -99,4 +99,12 @@ mod tests { assert_eq!(tensor.to_data(), output.to_data()); } + + #[test] + fn display() { + let config = DropoutConfig::new(0.5); + let layer = config.init(); + + assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}"); + } } From 169d42621c8c19c6adccc9f4fb6e25bb0af37600 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:50:13 -0500 Subject: [PATCH 27/52] Make dropout attributes pub --- crates/burn-core/src/nn/dropout.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index 3cfb1f16e4..4224ed5e14 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -23,7 +23,8 @@ pub struct DropoutConfig { #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct Dropout { - prob: f64, + /// The probability of randomly zeroes some elements of the input tensor during training. + pub prob: f64, } impl DropoutConfig { From ade898e468e1f783561d4f9d3d44b001e415a492 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:50:33 -0500 Subject: [PATCH 28/52] Add display for Embedding --- crates/burn-core/src/nn/embedding.rs | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/crates/burn-core/src/nn/embedding.rs b/crates/burn-core/src/nn/embedding.rs index 3ad02f141d..0ded3fcaca 100644 --- a/crates/burn-core/src/nn/embedding.rs +++ b/crates/burn-core/src/nn/embedding.rs @@ -4,6 +4,7 @@ use super::Initializer; use crate::config::Config; use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Int; use crate::tensor::Tensor; @@ -26,12 +27,29 @@ pub struct EmbeddingConfig { /// /// Should be created with [EmbeddingConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct Embedding { /// The learnable weights of the module of shape `[n_embedding, d_model]` initialized /// from a normal distribution `N(0, 1)`. pub weight: Param>, } +impl ModuleDisplay for Embedding { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [n_embedding, d_model] = self.weight.shape().dims; + content + .add("n_embedding", &n_embedding) + .add("d_model", &d_model) + .optional() + } +} + impl EmbeddingConfig { /// Initialize a new [embedding](Embedding) module. 
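///
/// A sketch of the resulting usage; `forward` is assumed to map integer
/// token ids of shape `[batch, seq_length]` to embeddings of shape
/// `[batch, seq_length, d_model]`:
///
/// ```rust,ignore
/// let embed = EmbeddingConfig::new(100, 10).init::<B>(&device);
/// let output = embed.forward(token_ids);
/// ```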
pub fn init(&self, device: &B::Device) -> Embedding { @@ -98,4 +116,15 @@ mod tests { .to_data() .assert_approx_eq(&Data::zeros(embed.weight.shape()), 3); } + + #[test] + fn display() { + let config = EmbeddingConfig::new(100, 10); + let embed = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", embed), + "Embedding {n_embedding: 100, d_model: 10, params: 1000}" + ); + } } From 9632ff42b049f2671d1e573fdcd932aa0888bdf9 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:53:33 -0500 Subject: [PATCH 29/52] Clean up --- crates/burn-core/src/nn/gelu.rs | 2 +- crates/burn-core/src/nn/relu.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-core/src/nn/gelu.rs b/crates/burn-core/src/nn/gelu.rs index 421f83452d..77a4511eb9 100644 --- a/crates/burn-core/src/nn/gelu.rs +++ b/crates/burn-core/src/nn/gelu.rs @@ -7,7 +7,7 @@ use crate::tensor::Tensor; /// Applies the Gaussian Error Linear Units function element-wise. /// See also [gelu](burn::tensor::activation::gelu) #[derive(Module, Clone, Debug, Default)] -pub struct Gelu {} +pub struct Gelu; impl Gelu { /// Create the module. diff --git a/crates/burn-core/src/nn/relu.rs b/crates/burn-core/src/nn/relu.rs index 262c393134..8d398c697b 100644 --- a/crates/burn-core/src/nn/relu.rs +++ b/crates/burn-core/src/nn/relu.rs @@ -8,7 +8,7 @@ use crate::tensor::Tensor; /// See also [relu](burn::tensor::activation::relu) /// #[derive(Module, Clone, Debug, Default)] -pub struct Relu {} +pub struct Relu; impl Relu { /// Create the module. From de9265789b6477d06b93d46204b6082e0f0b4461 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:53:53 -0500 Subject: [PATCH 30/52] Add linear display test --- crates/burn-core/src/nn/linear.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 0bc16552f1..588fcc86ff 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -196,4 +196,15 @@ mod tests { assert_eq!(result_1d.into_data(), result_2d.into_data()); } + + #[test] + fn display() { + let config = LinearConfig::new(3, 5); + let linear = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", linear), + "Linear {d_input: 3, d_output: 5, bias: true, params: 20}" + ); + } } From cd38b99efc93db6d8176045886996a0aabb28d7c Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:54:09 -0500 Subject: [PATCH 31/52] Make attribute pub --- crates/burn-core/src/nn/pos_encoding.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/pos_encoding.rs b/crates/burn-core/src/nn/pos_encoding.rs index 703927dc3d..6e6f3c18a5 100644 --- a/crates/burn-core/src/nn/pos_encoding.rs +++ b/crates/burn-core/src/nn/pos_encoding.rs @@ -41,7 +41,8 @@ pub struct PositionalEncodingConfig { /// Should be created using [PositionalEncodingConfig] #[derive(Module, Debug)] pub struct PositionalEncoding { - sinusoids: Tensor, + /// The sinusoids used to add positional information to the input embeddings. 
+ pub sinusoids: Tensor, } impl PositionalEncodingConfig { From 635b2cb253ca877c2def73881b70394272513a73 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:42:32 -0500 Subject: [PATCH 32/52] Clean up --- crates/burn-core/src/nn/dropout.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index 4224ed5e14..d03e95c1f3 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::{DisplaySettings, Module, ModuleDisplay}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::{Distribution, Tensor}; @@ -63,7 +63,7 @@ impl ModuleDisplay for Dropout { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { content.add("prob", &self.prob).optional() } } From 5d605c296a06064238b7607d947b8d52467e87f7 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:42:44 -0500 Subject: [PATCH 33/52] Add display test to gelu --- crates/burn-core/src/nn/gelu.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/burn-core/src/nn/gelu.rs b/crates/burn-core/src/nn/gelu.rs index 77a4511eb9..f56bc29f83 100644 --- a/crates/burn-core/src/nn/gelu.rs +++ b/crates/burn-core/src/nn/gelu.rs @@ -25,3 +25,15 @@ impl Gelu { crate::tensor::activation::gelu(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Gelu::new(); + + assert_eq!(alloc::format!("{}", layer), "Gelu"); + } +} From a2ab8dd0c34c1b3fbb99744ae9ea65a2c5622210 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:43:01 -0500 Subject: [PATCH 34/52] Add display to leaky relu --- crates/burn-core/src/nn/leaky_relu.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/crates/burn-core/src/nn/leaky_relu.rs b/crates/burn-core/src/nn/leaky_relu.rs index 1a230a4841..aa1447d546 100644 --- a/crates/burn-core/src/nn/leaky_relu.rs +++ b/crates/burn-core/src/nn/leaky_relu.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; use crate::module::Module; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -10,6 +11,7 @@ use crate::tensor::activation::leaky_relu; /// /// Should be created with [LeakyReluConfig](LeakyReluConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct LeakyRelu { /// The negative slope. pub negative_slope: f64, @@ -30,6 +32,20 @@ impl LeakyReluConfig { } } +impl ModuleDisplay for LeakyRelu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("negative_slope", &self.negative_slope) + .optional() + } +} + impl LeakyRelu { /// Forward pass for the Leaky ReLu layer. 
/// @@ -92,4 +108,13 @@ mod tests { .to_data() .assert_approx_eq(&Data::from(expected_output), 4) } + + #[test] + fn display() { + let config = LeakyReluConfig::new().init(); + assert_eq!( + alloc::format!("{}", config), + "LeakyRelu {negative_slope: 0.01}" + ); + } } From 84221ad71909ac72f9ffabc1b0bb70e629e2da51 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:43:23 -0500 Subject: [PATCH 35/52] Add display to prelu --- crates/burn-core/src/nn/prelu.rs | 37 +++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/prelu.rs b/crates/burn-core/src/nn/prelu.rs index f15c96481c..b625bb29f2 100644 --- a/crates/burn-core/src/nn/prelu.rs +++ b/crates/burn-core/src/nn/prelu.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::Initializer; use crate::tensor::backend::Backend; use crate::tensor::Tensor; @@ -9,11 +9,33 @@ use crate::tensor::Tensor; /// /// Should be created using [PReluConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct PRelu { /// the weights learnt for PReLu. can be of shape \[1\] or \[num_parameters\] in which case it must /// be the same as number of channels in the input tensor pub alpha: Param>, + + /// Alpha value for the PRelu layer + pub alpha_value: f64, +} + +impl ModuleDisplay for PRelu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [num_parameters] = self.alpha.shape().dims; + + content + .add("num_parameters", &num_parameters) + .add("alpha_value", &self.alpha_value) + .optional() + } } + /// Configuration to create a [Parametric Relu](PRelu) layer using the [init function](PReluConfig::init). 
#[derive(Config, Debug)] pub struct PReluConfig { @@ -24,6 +46,7 @@ pub struct PReluConfig { #[config(default = "0.25")] pub alpha: f64, } + impl PReluConfig { /// Initialize a new [Parametric Relu](PRelu) Layer pub fn init(&self, device: &B::Device) -> PRelu { @@ -47,3 +70,15 @@ impl PRelu { crate::tensor::activation::prelu(input, self.alpha.val()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Relu::new(); + + assert_eq!(alloc::format!("{}", layer), "Relu"); + } +} From 213732c14cd501f3b211f9940049842ca384abde Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:43:32 -0500 Subject: [PATCH 36/52] Add test to relu --- crates/burn-core/src/nn/relu.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/burn-core/src/nn/relu.rs b/crates/burn-core/src/nn/relu.rs index 8d398c697b..67ed033b9a 100644 --- a/crates/burn-core/src/nn/relu.rs +++ b/crates/burn-core/src/nn/relu.rs @@ -25,3 +25,15 @@ impl Relu { crate::tensor::activation::relu(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Relu::new(); + + assert_eq!(alloc::format!("{}", layer), "Relu"); + } +} From d7c488e11a07958f62d0ce4d867123d29c761aa7 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 09:51:18 -0500 Subject: [PATCH 37/52] Fix prelu test --- crates/burn-core/src/nn/prelu.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/crates/burn-core/src/nn/prelu.rs b/crates/burn-core/src/nn/prelu.rs index b625bb29f2..6bb6c32f7f 100644 --- a/crates/burn-core/src/nn/prelu.rs +++ b/crates/burn-core/src/nn/prelu.rs @@ -53,6 +53,7 @@ impl PReluConfig { PRelu { // alpha is a tensor of length num_parameters alpha: Initializer::Constant { value: self.alpha }.init([self.num_parameters], device), + alpha_value: self.alpha, } } } @@ -74,11 +75,15 @@ impl PRelu { #[cfg(test)] mod tests { use super::*; + use crate::TestBackend; #[test] fn display() { - let layer = Relu::new(); + let layer = PReluConfig::new().init::(&Default::default()); - assert_eq!(alloc::format!("{}", layer), "Relu"); + assert_eq!( + alloc::format!("{}", layer), + "PRelu {num_parameters: 1, alpha_value: 0.25, params: 1}" + ); } } From bfcf06703cfd2528a7b98c11328e8003e8daba17 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:42:32 -0500 Subject: [PATCH 38/52] Clean up --- crates/burn-core/src/nn/linear.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs index 4046c2ec9a..c90074b6db 100644 --- a/crates/burn-core/src/nn/linear.rs +++ b/crates/burn-core/src/nn/linear.rs @@ -1,10 +1,8 @@ use crate as burn; -use crate::module::DisplaySettings; -use crate::module::ModuleDisplay; use crate::config::Config; -use crate::module::Module; use crate::module::Param; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::{backend::Backend, Tensor}; use super::Initializer; @@ -93,7 +91,7 @@ impl ModuleDisplay for Linear { .optional() } - fn custom_content(&self, content: crate::module::Content) -> Option { + fn custom_content(&self, content: Content) -> Option { let [d_input, d_output] = self.weight.shape().dims; content .add("d_input", &d_input) From 81ad5f7bbbbc3d6c58de01b0485cd407b3dd2512 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev 
<939125+antimora@users.noreply.github.com>
Date: Thu, 27 Jun 2024 11:42:48 -0500
Subject: [PATCH 39/52] Add display for PositionalEncoding

---
 crates/burn-core/src/nn/pos_encoding.rs | 42 +++++++++++++++++++++++--
 1 file changed, 40 insertions(+), 2 deletions(-)

diff --git a/crates/burn-core/src/nn/pos_encoding.rs b/crates/burn-core/src/nn/pos_encoding.rs
index 80b4177175..ff1db2cfcc 100644
--- a/crates/burn-core/src/nn/pos_encoding.rs
+++ b/crates/burn-core/src/nn/pos_encoding.rs
@@ -2,7 +2,8 @@ use alloc::vec::Vec;

 use crate as burn;
 use crate::config::Config;
-use crate::module::Module;
+use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
+
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 use crate::tensor::TensorData;
@@ -40,9 +41,31 @@ pub struct PositionalEncodingConfig {
 ///
 /// Should be created using [PositionalEncodingConfig]
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct PositionalEncoding<B: Backend> {
     /// The sinusoids used to add positional information to the input embeddings.
     pub sinusoids: Tensor<B, 3>,
+    /// The maximum sequence size to use.
+    pub max_sequence_size: usize,
+    /// The maximum time scale to use.
+    pub max_timescale: usize,
+}
+
+impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        let [_, _, d_model] = self.sinusoids.shape().dims;
+        content
+            .add("d_model", &d_model)
+            .add("max_sequence_size", &self.max_sequence_size)
+            .add("max_timescale", &self.max_timescale)
+            .optional()
+    }
 }

 impl PositionalEncodingConfig {
@@ -56,7 +79,11 @@ impl PositionalEncodingConfig {
         )
         .unsqueeze::<3>();

-        PositionalEncoding { sinusoids }
+        PositionalEncoding {
+            sinusoids,
+            max_sequence_size: self.max_sequence_size,
+            max_timescale: self.max_timescale,
+        }
     }
 }

@@ -246,4 +273,15 @@ mod tests {
         let input = Tensor::zeros([1, 6_000, d_model], &device);
         let _output = pe.forward(input);
     }
+
+    #[test]
+    fn display() {
+        let config = PositionalEncodingConfig::new(4);
+        let pe = config.init::<TestBackend>(&Default::default());
+
+        assert_eq!(
+            alloc::format!("{}", pe),
+            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
+        );
+    }
 }

From 129e220755ff484089e721710855a1693b8d343f Mon Sep 17 00:00:00 2001
From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Date: Thu, 27 Jun 2024 11:50:26 -0500
Subject: [PATCH 40/52] Add display for rope encoding

---
 crates/burn-core/src/nn/rope_encoding.rs | 43 ++++++++++++++++++++++--
 1 file changed, 40 insertions(+), 3 deletions(-)

diff --git a/crates/burn-core/src/nn/rope_encoding.rs b/crates/burn-core/src/nn/rope_encoding.rs
index 4351bcbe22..2f94014710 100644
--- a/crates/burn-core/src/nn/rope_encoding.rs
+++ b/crates/burn-core/src/nn/rope_encoding.rs
@@ -1,6 +1,6 @@ use crate as burn;
 use crate::config::Config;
-use crate::module::Module;
+use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
 use crate::tensor::backend::Backend;
 use crate::tensor::Int;
 use crate::tensor::Tensor;
@@ -74,7 +74,11 @@ impl RotaryEncodingConfig {
         .repeat(2, 2)
         .reshape([self.max_sequence_length, self.d_model, 2]);

-        RotaryEncoding { freq_complex }
+        RotaryEncoding {
+            freq_complex,
+            max_sequence_length: self.max_sequence_length,
+            theta: self.theta,
+        }
     }
 }

@@ -87,9 +91,31 @@ impl RotaryEncodingConfig {
 ///
 /// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)] +#[module(custom_display)] pub struct RotaryEncoding { /// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components - freq_complex: Tensor, + pub freq_complex: Tensor, + /// Maximum sequence length of input + pub max_sequence_length: usize, + /// Scaling factor for frequency computation. + pub theta: f32, +} + +impl ModuleDisplay for RotaryEncoding { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [_, _, d_model] = self.freq_complex.shape().dims; + content + .add("d_model", &d_model) + .add("max_sequence_length", &self.max_sequence_length) + .add("theta", &self.theta) + .optional() + } } #[allow(clippy::single_range_in_vec_init)] @@ -238,4 +264,15 @@ mod tests { let input = Tensor::zeros([1, 5, d_model], &device); let _output = pe.forward(input); } + + #[test] + fn display() { + let config = RotaryEncodingConfig::new(10, 4); + let pe = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", pe), + "RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}" + ); + } } From 30f01f4c816c3c4c316750523210fab90b4cc413 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:55:00 -0500 Subject: [PATCH 41/52] Add display for SwiGlu --- crates/burn-core/src/nn/swiglu.rs | 33 +++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/crates/burn-core/src/nn/swiglu.rs b/crates/burn-core/src/nn/swiglu.rs index 3dacbae68a..2227db5457 100644 --- a/crates/burn-core/src/nn/swiglu.rs +++ b/crates/burn-core/src/nn/swiglu.rs @@ -1,7 +1,7 @@ use crate as burn; use crate::config::Config; -use crate::module::Module; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::tensor::activation::silu; use crate::tensor::{backend::Backend, Tensor}; @@ -31,6 +31,7 @@ pub struct SwiGluConfig { /// /// Should be created with [SwiGluConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct SwiGlu { /// The inner linear layer for Swish activation function /// with `d_input` input features and `d_output` output features. @@ -40,6 +41,23 @@ pub struct SwiGlu { pub linear_outer: Linear, } +impl ModuleDisplay for SwiGlu { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_input, d_output] = self.linear_inner.weight.shape().dims; + content + .add("d_input", &d_input) + .add("d_output", &d_output) + .add("bias", &self.linear_inner.bias.is_some()) + .optional() + } +} + impl SwiGluConfig { /// Initialize a new [SwiGLU](SwiGlu) activation layer. 
pub fn init(&self, device: &B::Device) -> SwiGlu { @@ -61,7 +79,7 @@ impl SwiGlu { /// /// # Shapes /// - /// - input: `[batch_size, seq_length, d_input]` + /// - input: `[batch_size, seq_length, d_input]` /// - output: `[batch_size, seq_length, d_output]` pub fn forward(&self, input: Tensor) -> Tensor { let x = self.linear_inner.forward(input.clone()); @@ -112,4 +130,15 @@ mod tests { .to_data() .assert_approx_eq(&expected_output.to_data(), 4); } + + #[test] + fn display() { + let config = SwiGluConfig::new(3, 5); + let swiglu = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", swiglu), + "SwiGlu {d_input: 3, d_output: 5, bias: false, params: 30}" + ); + } } From e29943a8cc2f38b0040ef18d35b553e5feb19b08 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:56:13 -0500 Subject: [PATCH 42/52] Add display test for Tanh --- crates/burn-core/src/nn/tanh.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/burn-core/src/nn/tanh.rs b/crates/burn-core/src/nn/tanh.rs index 293da36ce7..e9bbfb0ac5 100644 --- a/crates/burn-core/src/nn/tanh.rs +++ b/crates/burn-core/src/nn/tanh.rs @@ -24,3 +24,15 @@ impl Tanh { crate::tensor::activation::tanh(input) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let layer = Tanh::new(); + + assert_eq!(alloc::format!("{}", layer), "Tanh"); + } +} From fd53c9aa7612a744d480d752dddbb2a4914c3ba3 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 15:24:39 -0500 Subject: [PATCH 43/52] Fix burn-import --- crates/burn-import/src/burn/node/prelu.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/burn-import/src/burn/node/prelu.rs b/crates/burn-import/src/burn/node/prelu.rs index 57c948a3a6..59f9baed28 100644 --- a/crates/burn-import/src/burn/node/prelu.rs +++ b/crates/burn-import/src/burn/node/prelu.rs @@ -1,7 +1,7 @@ use super::{Node, NodeCodegen, SerializationBackend}; use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type}; use burn::{ - module::{Param, ParamId}, + module::{ConstantRecord, Param, ParamId}, nn::{PReluConfig, PReluRecord}, record::{PrecisionSettings, Record}, tensor::{Tensor, TensorData}, @@ -70,6 +70,7 @@ impl NodeCodegen for PReluNode { ParamId::new(), Tensor::from_data(self.alpha.clone().convert::(), &device), ), + alpha_value: ConstantRecord, }; let item = Record::into_item::(record); From 4ebc60c22f4ad4adcf6fe59771f8596b23a36a8d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 15:28:49 -0500 Subject: [PATCH 44/52] Add display to unfold --- crates/burn-core/src/nn/unfold.rs | 59 ++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/crates/burn-core/src/nn/unfold.rs b/crates/burn-core/src/nn/unfold.rs index c958883e2b..41d2cedbb9 100644 --- a/crates/burn-core/src/nn/unfold.rs +++ b/crates/burn-core/src/nn/unfold.rs @@ -1,7 +1,8 @@ use crate as burn; use crate::config::Config; -use crate::module::{Ignored, Module}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; + use burn_tensor::backend::Backend; use burn_tensor::module::unfold4d; use burn_tensor::ops::UnfoldOptions; @@ -27,15 +28,43 @@ pub struct Unfold4dConfig { /// /// Should be created with [Unfold4dConfig]. #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct Unfold4d { - config: Ignored, + /// The size of the kernel. 
+ pub kernel_size: [usize; 2], + /// The stride of the convolution. + pub stride: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// The padding configuration. + pub padding: [usize; 2], +} + +impl ModuleDisplay for Unfold4d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .add("padding", &alloc::format!("{:?}", &self.padding)) + .optional() + } } impl Unfold4dConfig { /// Initializes a new [Unfold4d] module. pub fn init(&self) -> Unfold4d { Unfold4d { - config: Ignored(self.clone()), + kernel_size: self.kernel_size, + stride: self.stride, + dilation: self.dilation, + padding: self.padding, } } } @@ -52,12 +81,24 @@ impl Unfold4d { pub fn forward(&self, input: Tensor) -> Tensor { unfold4d( input, - self.config.kernel_size, - UnfoldOptions::new( - self.config.stride, - self.config.padding, - self.config.dilation, - ), + self.kernel_size, + UnfoldOptions::new(self.stride, self.padding, self.dilation), ) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = Unfold4dConfig::new([3, 3]); + let unfold = config.init(); + + assert_eq!( + alloc::format!("{}", unfold), + "Unfold4d {kernel_size: [3, 3], stride: [1, 1], dilation: [1, 1], padding: [0, 0]}" + ); + } +} From 004fc39ef0ca579aba357d577b465b4f4c012c9d Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 15:52:37 -0500 Subject: [PATCH 45/52] Add display for pwff --- crates/burn-core/src/nn/transformer/pwff.rs | 38 ++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/crates/burn-core/src/nn/transformer/pwff.rs b/crates/burn-core/src/nn/transformer/pwff.rs index bd168b6433..b54278ff11 100644 --- a/crates/burn-core/src/nn/transformer/pwff.rs +++ b/crates/burn-core/src/nn/transformer/pwff.rs @@ -1,9 +1,9 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::Initializer; use crate::{ config::Config, - module::Module, nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig}, tensor::{backend::Backend, Tensor}, }; @@ -36,6 +36,7 @@ pub struct PositionWiseFeedForwardConfig { /// /// Should be created using [PositionWiseFeedForwardConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct PositionWiseFeedForward { linear_inner: Linear, linear_outer: Linear, @@ -43,6 +44,24 @@ pub struct PositionWiseFeedForward { gelu: Gelu, } +impl ModuleDisplay for PositionWiseFeedForward { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let [d_model, dff] = self.linear_inner.weight.shape().dims; + + content + .add("d_model", &d_model) + .add("d_ff", &dff) + .add("prob", &self.dropout.prob) + .optional() + } +} + impl PositionWiseFeedForwardConfig { /// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module. 
pub fn init(&self, device: &B::Device) -> PositionWiseFeedForward { @@ -74,3 +93,20 @@ impl PositionWiseFeedForward { self.linear_outer.forward(x) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn display() { + let config = PositionWiseFeedForwardConfig::new(2, 4); + let pwff = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", pwff), + "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}" + ); + } +} From b50d3d23806290d4b6d5db5d9aba74d105f3fef0 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:41:59 -0500 Subject: [PATCH 46/52] Add display for Transformer decoder --- .../burn-core/src/nn/transformer/decoder.rs | 76 +++++++++++++++++-- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/crates/burn-core/src/nn/transformer/decoder.rs b/crates/burn-core/src/nn/transformer/decoder.rs index 85fc50159f..7784972c79 100644 --- a/crates/burn-core/src/nn/transformer/decoder.rs +++ b/crates/burn-core/src/nn/transformer/decoder.rs @@ -1,15 +1,15 @@ -use crate::tensor::Bool; use alloc::vec::Vec; +use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; + +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; +use crate::tensor::Bool; use crate::{ self as burn, nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; - -use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ config::Config, - module::Module, nn::{ attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, Dropout, DropoutConfig, LayerNorm, LayerNormConfig, @@ -57,8 +57,51 @@ pub struct TransformerDecoderConfig { /// /// Should be created using [TransformerDecoderConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct TransformerDecoder { - layers: Vec>, + /// Transformer decoder layers. + pub layers: Vec>, + + /// The size of the model. + pub d_model: usize, + + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + + /// The number of attention heads. + pub n_heads: usize, + + /// The number of layers. + pub n_layers: usize, + + /// The dropout rate. Default: 0.1 + pub dropout: f64, + + /// Layer norm will be applied first instead of after the other modules. + pub norm_first: bool, + + /// Use "quiet softmax" instead of regular softmax. 
+ pub quiet_softmax: bool, +} + +impl ModuleDisplay for TransformerDecoder { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("d_ff", &self.d_ff) + .add("n_heads", &self.n_heads) + .add("n_layers", &self.n_layers) + .add("dropout", &self.dropout) + .add("norm_first", &self.norm_first) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } impl TransformerDecoderConfig { @@ -68,7 +111,16 @@ impl TransformerDecoderConfig { .map(|_| TransformerDecoderLayer::new(self, device)) .collect::>(); - TransformerDecoder { layers } + TransformerDecoder { + layers, + d_model: self.d_model, + d_ff: self.d_ff, + n_heads: self.n_heads, + n_layers: self.n_layers, + dropout: self.dropout, + norm_first: self.norm_first, + quiet_softmax: self.quiet_softmax, + } } } @@ -473,4 +525,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = TransformerDecoderConfig::new(2, 4, 2, 3); + let transformer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", transformer), + "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \ + dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}" + ); + } } From d2d6fec3980b9646f6c3229d53d2ade5c4f90733 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:59:47 -0500 Subject: [PATCH 47/52] Add display for encoder --- .../burn-core/src/nn/transformer/encoder.rs | 73 +++++++++++++++++-- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/crates/burn-core/src/nn/transformer/encoder.rs b/crates/burn-core/src/nn/transformer/encoder.rs index 0eb226a3b9..6aea721d30 100644 --- a/crates/burn-core/src/nn/transformer/encoder.rs +++ b/crates/burn-core/src/nn/transformer/encoder.rs @@ -1,15 +1,14 @@ use crate::tensor::Bool; use alloc::vec::Vec; +use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::{ self as burn, nn::{attention::MhaCache, cache::TensorCache, Initializer}, }; - -use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig}; use crate::{ config::Config, - module::Module, nn::{ attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig}, Dropout, DropoutConfig, LayerNorm, LayerNormConfig, @@ -57,8 +56,51 @@ pub struct TransformerEncoderConfig { /// /// Should be created using [TransformerEncoderConfig] #[derive(Module, Debug)] +#[module(custom_display)] pub struct TransformerEncoder { - layers: Vec>, + /// The transformer encoder layers. + pub layers: Vec>, + + /// The size of the model. + pub d_model: usize, + + /// The size of the position-wise feed-forward network. + pub d_ff: usize, + + /// The number of attention heads. + pub n_heads: usize, + + /// The number of layers. + pub n_layers: usize, + + /// The dropout rate. Default: 0.1 + pub dropout: f64, + + /// Layer norm will be applied first instead of after the other modules. + pub norm_first: bool, + + /// Use "quiet softmax" instead of regular softmax. 
+ pub quiet_softmax: bool, +} + +impl ModuleDisplay for TransformerEncoder { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("d_ff", &self.d_ff) + .add("n_heads", &self.n_heads) + .add("n_layers", &self.n_layers) + .add("dropout", &self.dropout) + .add("norm_first", &self.norm_first) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } /// [Transformer Encoder](TransformerEncoder) forward pass input argument. @@ -98,7 +140,16 @@ impl TransformerEncoderConfig { .map(|_| TransformerEncoderLayer::new(self, device)) .collect::>(); - TransformerEncoder { layers } + TransformerEncoder { + layers, + d_model: self.d_model, + d_ff: self.d_ff, + n_heads: self.n_heads, + n_layers: self.n_layers, + dropout: self.dropout, + norm_first: self.norm_first, + quiet_softmax: self.quiet_softmax, + } } } @@ -392,4 +443,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = TransformerEncoderConfig::new(2, 4, 2, 3); + let transformer = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", transformer), + "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \ + n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}" + ); + } } From 604129e1e285ac66e0a307b432547b7d5ef2f078 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:08:11 -0500 Subject: [PATCH 48/52] Add display for MultiHeadAttention --- crates/burn-core/src/nn/attention/mha.rs | 67 ++++++++++++++++++++---- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/crates/burn-core/src/nn/attention/mha.rs b/crates/burn-core/src/nn/attention/mha.rs index ed6eb49235..ad754bdf41 100644 --- a/crates/burn-core/src/nn/attention/mha.rs +++ b/crates/burn-core/src/nn/attention/mha.rs @@ -1,10 +1,10 @@ use crate as burn; +use crate::module::{Content, DisplaySettings, Module, ModuleDisplay}; use crate::nn::cache::TensorCache; use crate::nn::Initializer; use crate::{ config::Config, - module::Module, nn, tensor::{activation, backend::Backend, Bool, Tensor}, }; @@ -53,17 +53,49 @@ pub struct MultiHeadAttentionConfig { /// /// Should be created with [MultiHeadAttentionConfig]. #[derive(Module, Debug)] +#[module(custom_display)] pub struct MultiHeadAttention { - query: nn::Linear, - key: nn::Linear, - value: nn::Linear, - output: nn::Linear, - dropout: nn::Dropout, - activation: nn::Gelu, - n_heads: usize, - d_k: usize, - min_float: f64, - quiet_softmax: bool, + /// Linear layer to transform the input features into the query space. + pub query: nn::Linear, + /// Linear layer to transform the input features into the key space. + pub key: nn::Linear, + /// Linear layer to transform the input features into the value space. + pub value: nn::Linear, + /// Linear layer to transform the output features back to the original space. + pub output: nn::Linear, + /// Dropout layer. + pub dropout: nn::Dropout, + /// Activation function. + pub activation: nn::Gelu, + /// The size of each linear layer. + pub d_model: usize, + /// The number of heads. + pub n_heads: usize, + /// Size of the key and query vectors. + pub d_k: usize, + /// Minimum value a float can take. + pub min_float: f64, + /// Use "quiet softmax" instead of regular softmax. 
+ pub quiet_softmax: bool, +} + +impl ModuleDisplay for MultiHeadAttention { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("d_model", &self.d_model) + .add("n_heads", &self.n_heads) + .add("d_k", &self.d_k) + .add("dropout", &self.dropout.prob) + .add("min_float", &self.min_float) + .add("quiet_softmax", &self.quiet_softmax) + .optional() + } } /// [Multihead attention](MultiHeadAttention) forward pass input argument. @@ -99,6 +131,7 @@ impl MultiHeadAttentionConfig { d_k: self.d_model / self.n_heads, min_float: self.min_float, quiet_softmax: self.quiet_softmax, + d_model: self.d_model, } } } @@ -478,4 +511,16 @@ mod tests { .into_data() .assert_approx_eq(&output_2.into_data(), 3); } + + #[test] + fn display() { + let config = MultiHeadAttentionConfig::new(2, 4); + let mha = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", mha), + "MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \ + dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}" + ); + } } From 957e2397f2c7621f1d6836ce2d7cec0a494ad4b0 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 27 Jun 2024 17:29:47 -0500 Subject: [PATCH 49/52] Fix test --- crates/burn-import/pytorch-tests/tests/linear/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/pytorch-tests/tests/linear/mod.rs b/crates/burn-import/pytorch-tests/tests/linear/mod.rs index 4244a35a4c..3ba09fa0ef 100644 --- a/crates/burn-import/pytorch-tests/tests/linear/mod.rs +++ b/crates/burn-import/pytorch-tests/tests/linear/mod.rs @@ -16,7 +16,7 @@ impl Net { pub fn init(device: &B::Device) -> Self { let fc1 = LinearConfig::new(2, 3).init(device); let fc2 = LinearConfig::new(3, 4).init(device); - let relu = Relu::default(); + let relu = Relu; Self { fc1, fc2, relu } } From fbd0e5103f0b25c5f7e61937a2344f566b55994b Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Fri, 28 Jun 2024 04:01:40 -0500 Subject: [PATCH 50/52] Make module attributes pub --- crates/burn-core/src/nn/conv/conv1d.rs | 15 ++++++++----- crates/burn-core/src/nn/conv/conv2d.rs | 15 ++++++++----- .../burn-core/src/nn/conv/conv_transpose1d.rs | 21 ++++++++++++------- .../burn-core/src/nn/conv/conv_transpose2d.rs | 21 ++++++++++++------- crates/burn-core/src/nn/loss/huber.rs | 6 ++++-- crates/burn-core/src/nn/norm/batch.rs | 6 ++++-- crates/burn-core/src/nn/norm/group.rs | 13 +++++++----- crates/burn-core/src/nn/norm/instance.rs | 10 +++++---- crates/burn-core/src/nn/norm/layer.rs | 5 +++-- crates/burn-core/src/nn/norm/rms.rs | 2 +- .../src/nn/pool/adaptive_avg_pool1d.rs | 3 ++- .../src/nn/pool/adaptive_avg_pool2d.rs | 3 ++- crates/burn-core/src/nn/pool/avg_pool2d.rs | 12 +++++++---- crates/burn-core/src/nn/transformer/pwff.rs | 12 +++++++---- 14 files changed, 94 insertions(+), 50 deletions(-) diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index 15ec93759f..0b64eab324 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -50,11 +50,16 @@ pub struct Conv1d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: Ignored, + /// Stride of the convolution. 
+ pub stride: usize, + /// Size of the kernel. + pub kernel_size: usize, + /// Spacing between kernel elements. + pub dilation: usize, + /// Controls the connections between input and output channels. + pub groups: usize, + /// Padding configuration. + pub padding: Ignored, } impl ModuleDisplay for Conv1d { diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index 494cf658d7..bf31fd9661 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -52,11 +52,16 @@ pub struct Conv2d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: Ignored, + /// Stride of the convolution. + pub stride: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + pub groups: usize, + /// The padding configuration. + pub padding: Ignored, } impl Conv2dConfig { diff --git a/crates/burn-core/src/nn/conv/conv_transpose1d.rs b/crates/burn-core/src/nn/conv/conv_transpose1d.rs index 0aa83cbcd5..3598474713 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose1d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose1d.rs @@ -56,13 +56,20 @@ pub struct ConvTranspose1d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: usize, - kernel_size: usize, - dilation: usize, - groups: usize, - padding: usize, - padding_out: usize, - channels: [usize; 2], + /// Stride of the convolution. + pub stride: usize, + /// Size of the kernel. + pub kernel_size: usize, + /// Spacing between kernel elements. + pub dilation: usize, + /// Controls the connections between input and output channels. + pub groups: usize, + /// The padding configuration. + pub padding: usize, + /// The padding output configuration. + pub padding_out: usize, + /// The number of channels. + pub channels: [usize; 2], } impl ModuleDisplay for ConvTranspose1d { diff --git a/crates/burn-core/src/nn/conv/conv_transpose2d.rs b/crates/burn-core/src/nn/conv/conv_transpose2d.rs index febb8aa89a..7fa3bd788c 100644 --- a/crates/burn-core/src/nn/conv/conv_transpose2d.rs +++ b/crates/burn-core/src/nn/conv/conv_transpose2d.rs @@ -56,13 +56,20 @@ pub struct ConvTranspose2d { pub weight: Param>, /// Tensor of shape `[channels_out]` pub bias: Option>>, - stride: [usize; 2], - kernel_size: [usize; 2], - dilation: [usize; 2], - groups: usize, - padding: [usize; 2], - padding_out: [usize; 2], - channels: [usize; 2], + /// Stride of the convolution. + pub stride: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Spacing between kernel elements. + pub dilation: [usize; 2], + /// Controls the connections between input and output channels. + pub groups: usize, + /// Padding configuration. + pub padding: [usize; 2], + /// Padding output configuration. + pub padding_out: [usize; 2], + /// Number of channels. 
+    pub channels: [usize; 2],
 }

 impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs
index 343647217c..8b227b0a44 100644
--- a/crates/burn-core/src/nn/loss/huber.rs
+++ b/crates/burn-core/src/nn/loss/huber.rs
@@ -51,8 +51,10 @@ impl HuberLossConfig {
 #[derive(Module, Debug, Clone)]
 #[module(custom_display)]
 pub struct HuberLoss {
-    delta: f32,
-    lin_bias: f32, // delta * delta * 0.5 precomputed
+    /// The bound where the Huber loss function changes from quadratic to linear behaviour.
+    pub delta: f32,
+    /// Precomputed value for the linear bias.
+    pub lin_bias: f32, // delta * delta * 0.5 precomputed
 }

 impl ModuleDisplay for HuberLoss {
diff --git a/crates/burn-core/src/nn/norm/batch.rs b/crates/burn-core/src/nn/norm/batch.rs
index f64c11cbcf..4a0cb7d2c6 100644
--- a/crates/burn-core/src/nn/norm/batch.rs
+++ b/crates/burn-core/src/nn/norm/batch.rs
@@ -44,8 +44,10 @@ pub struct BatchNorm<B: Backend, const D: usize> {
     pub running_mean: RunningState<Tensor<B, 1>>,
     /// The running variance.
     pub running_var: RunningState<Tensor<B, 1>>,
-    momentum: f64,
-    epsilon: f64,
+    /// Momentum used to update the metrics.
+    pub momentum: f64,
+    /// A value required for numerical stability.
+    pub epsilon: f64,
 }

 impl BatchNormConfig {
diff --git a/crates/burn-core/src/nn/norm/group.rs b/crates/burn-core/src/nn/norm/group.rs
index 916d591f0d..3741565009 100644
--- a/crates/burn-core/src/nn/norm/group.rs
+++ b/crates/burn-core/src/nn/norm/group.rs
@@ -43,11 +43,14 @@ pub struct GroupNorm<B: Backend> {
     pub gamma: Option<Param<Tensor<B, 1>>>,
     /// The learnable bias
     pub beta: Option<Param<Tensor<B, 1>>>,
-
-    pub(crate) num_groups: usize,
-    pub(crate) num_channels: usize,
-    pub(crate) epsilon: f64,
-    pub(crate) affine: bool,
+    /// The number of groups to separate the channels into
+    pub num_groups: usize,
+    /// The number of channels expected in the input
+    pub num_channels: usize,
+    /// A value required for numerical stability
+    pub epsilon: f64,
+    /// A boolean value that when set to `true`, this module has learnable affine parameters
+    pub affine: bool,
 }

 impl<B: Backend> ModuleDisplay for GroupNorm<B> {
diff --git a/crates/burn-core/src/nn/norm/instance.rs b/crates/burn-core/src/nn/norm/instance.rs
index 23ceda6d23..3ee5ef4c0b 100644
--- a/crates/burn-core/src/nn/norm/instance.rs
+++ b/crates/burn-core/src/nn/norm/instance.rs
@@ -32,10 +32,12 @@ pub struct InstanceNorm<B: Backend> {
     pub gamma: Option<Param<Tensor<B, 1>>>,
     /// The learnable bias
     pub beta: Option<Param<Tensor<B, 1>>>,
-
-    num_channels: usize,
-    epsilon: f64,
-    affine: bool,
+    /// The number of channels expected in the input
+    pub num_channels: usize,
+    /// A value required for numerical stability
+    pub epsilon: f64,
+    /// A boolean value that when set to `true`, this module has learnable affine parameters
+    pub affine: bool,
 }

 impl<B: Backend> ModuleDisplay for InstanceNorm<B> {
diff --git a/crates/burn-core/src/nn/norm/layer.rs b/crates/burn-core/src/nn/norm/layer.rs
index 2e2a50ab1f..ea196906c2 100644
--- a/crates/burn-core/src/nn/norm/layer.rs
+++ b/crates/burn-core/src/nn/norm/layer.rs
@@ -34,9 +34,10 @@ pub struct LayerNormConfig {
 #[module(custom_display)]
 pub struct LayerNorm<B: Backend> {
     /// The learnable weight.
-    gamma: Param<Tensor<B, 1>>,
+    pub gamma: Param<Tensor<B, 1>>,
     /// The learnable bias.
-    beta: Param<Tensor<B, 1>>,
+    pub beta: Param<Tensor<B, 1>>,
+    /// A value required for numerical stability.
epsilon: f64, } diff --git a/crates/burn-core/src/nn/norm/rms.rs b/crates/burn-core/src/nn/norm/rms.rs index 0815f65d6c..e054c4f19a 100644 --- a/crates/burn-core/src/nn/norm/rms.rs +++ b/crates/burn-core/src/nn/norm/rms.rs @@ -54,7 +54,7 @@ pub struct RmsNorm { /// The learnable parameter to scale the normalized tensor pub gamma: Param>, /// A value required for numerical stability - epsilon: f64, + pub epsilon: f64, } impl RmsNorm { diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs index cb2d50bcce..5322fa600f 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool1d.rs @@ -21,7 +21,8 @@ pub struct AdaptiveAvgPool1dConfig { #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AdaptiveAvgPool1d { - output_size: usize, + /// The size of the output. + pub output_size: usize, } impl ModuleDisplay for AdaptiveAvgPool1d { diff --git a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs index 587b723a9b..1f63fb8c92 100644 --- a/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/adaptive_avg_pool2d.rs @@ -21,7 +21,8 @@ pub struct AdaptiveAvgPool2dConfig { #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AdaptiveAvgPool2d { - output_size: [usize; 2], + /// The size of the output. + pub output_size: [usize; 2], } impl ModuleDisplay for AdaptiveAvgPool2d { diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 36950b9436..6c6ffc87ed 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -42,10 +42,14 @@ pub struct AvgPool2dConfig { #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: Ignored, - count_include_pad: bool, + /// Stride of the pooling. + pub stride: [usize; 2], + /// Size of the kernel. + pub kernel_size: [usize; 2], + /// Padding configuration. + pub padding: Ignored, + /// If the padding is counted in the denominator when computing the average. + pub count_include_pad: bool, } impl ModuleDisplay for AvgPool2d { diff --git a/crates/burn-core/src/nn/transformer/pwff.rs b/crates/burn-core/src/nn/transformer/pwff.rs index b54278ff11..1c7af01496 100644 --- a/crates/burn-core/src/nn/transformer/pwff.rs +++ b/crates/burn-core/src/nn/transformer/pwff.rs @@ -38,10 +38,14 @@ pub struct PositionWiseFeedForwardConfig { #[derive(Module, Debug)] #[module(custom_display)] pub struct PositionWiseFeedForward { - linear_inner: Linear, - linear_outer: Linear, - dropout: Dropout, - gelu: Gelu, + /// Linear layer with `d_model` input features and `d_ff` output features. + pub linear_inner: Linear, + /// Linear layer with `d_ff` input features and `d_model` output features. + pub linear_outer: Linear, + /// Dropout layer. + pub dropout: Dropout, + /// GELU activation function. 
+    pub gelu: Gelu,
 }

 impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {

From dd5be6620d8796996ca54dbee9f1c06ea0b2936c Mon Sep 17 00:00:00 2001
From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Date: Fri, 28 Jun 2024 04:01:58 -0500
Subject: [PATCH 51/52] Add module display for BiLstm

---
 crates/burn-core/src/nn/rnn/lstm.rs | 41 ++++++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/crates/burn-core/src/nn/rnn/lstm.rs b/crates/burn-core/src/nn/rnn/lstm.rs
index d38049510f..c1d1b23d44 100644
--- a/crates/burn-core/src/nn/rnn/lstm.rs
+++ b/crates/burn-core/src/nn/rnn/lstm.rs
@@ -54,7 +54,8 @@ pub struct Lstm<B: Backend> {
     pub output_gate: GateController<B>,
     /// The cell gate is used to compute the cell state that stores and carries information through time.
     pub cell_gate: GateController<B>,
-    d_hidden: usize,
+    /// The size of the hidden state.
+    pub d_hidden: usize,
 }

 impl<B: Backend> ModuleDisplay for Lstm<B> {
@@ -75,6 +76,7 @@ impl<B: Backend> ModuleDisplay for Lstm<B> {
             .optional()
     }
 }
+
 impl LstmConfig {
     /// Initialize a new [lstm](Lstm) module.
     pub fn init<B: Backend>(&self, device: &B::Device) -> Lstm<B> {
@@ -215,12 +217,33 @@ pub struct BiLstmConfig {
 ///
 /// Should be created with [BiLstmConfig].
 #[derive(Module, Debug)]
+#[module(custom_display)]
 pub struct BiLstm<B: Backend> {
     /// LSTM for the forward direction.
     pub forward: Lstm<B>,
     /// LSTM for the reverse direction.
     pub reverse: Lstm<B>,
-    d_hidden: usize,
+    /// The size of the hidden state.
+    pub d_hidden: usize,
+}
+
+impl<B: Backend> ModuleDisplay for BiLstm<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        let [d_input, _] = self.forward.input_gate.input_transform.weight.shape().dims;
+        let bias = self.forward.input_gate.input_transform.bias.is_some();
+
+        content
+            .add("d_input", &d_input)
+            .add("d_hidden", &self.d_hidden)
+            .add("bias", &bias)
+            .optional()
+    }
 }

 impl BiLstmConfig {
@@ -715,7 +738,7 @@ mod tests {
     }

     #[test]
-    fn display() {
+    fn display_lstm() {
         let config = LstmConfig::new(2, 3, true);

         let layer = config.init::<TestBackend>(&Default::default());
@@ -725,4 +748,16 @@ mod tests {
             "Lstm {d_input: 2, d_hidden: 3, bias: true, params: 84}"
         );
     }
+
+    #[test]
+    fn display_bilstm() {
+        let config = BiLstmConfig::new(2, 3, true);
+
+        let layer = config.init::<TestBackend>(&Default::default());
+
+        assert_eq!(
+            alloc::format!("{}", layer),
+            "BiLstm {d_input: 2, d_hidden: 3, bias: true, params: 168}"
+        );
+    }
 }

From 60454110da0976f9f136d9f983db33eefd94cbb4 Mon Sep 17 00:00:00 2001
From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Date: Fri, 28 Jun 2024 04:02:09 -0500
Subject: [PATCH 52/52] Clean up

---
 crates/burn-core/src/nn/tanh.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-core/src/nn/tanh.rs b/crates/burn-core/src/nn/tanh.rs
index e9bbfb0ac5..322ac68bd4 100644
--- a/crates/burn-core/src/nn/tanh.rs
+++ b/crates/burn-core/src/nn/tanh.rs
@@ -7,7 +7,7 @@ use crate::tensor::Tensor;
 /// Applies the tanh activation function element-wise
 /// See also [tanh](burn::tensor::activation::tanh)
 #[derive(Module, Clone, Debug, Default)]
-pub struct Tanh {}
+pub struct Tanh;

 impl Tanh {
     /// Create the module.
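
A note on the parameter counts asserted by the LSTM display tests in this series: with `d_input: 2`, `d_hidden: 3` and bias enabled, each of the four LSTM gates owns an input transform (2 * 3 weights + 3 biases = 9 parameters) and a hidden transform (3 * 3 weights + 3 biases = 12 parameters), so `Lstm` reports 4 * (9 + 12) = 84, and `BiLstm`, which wraps a forward and a reverse `Lstm`, reports 2 * 84 = 168, matching the expected strings.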
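
Every patch in the series applies the same recipe: opt the module into `#[module(custom_display)]`, implement `ModuleDisplay` with single-line settings, and derive the displayed attributes from parameter shapes rather than duplicating configuration. A minimal sketch of that recipe as downstream code, assuming the same items are re-exported under `burn::module`; the module name `MyBlock` and its fields are hypothetical:

    use burn::module::{Content, DisplaySettings, Module, ModuleDisplay, Param};
    use burn::tensor::{backend::Backend, Tensor};

    /// Hypothetical module, used only to illustrate the display recipe.
    #[derive(Module, Debug)]
    #[module(custom_display)]
    pub struct MyBlock<B: Backend> {
        /// Learnable weight of shape `[d_input, d_output]`.
        pub weight: Param<Tensor<B, 2>>,
        /// Dropout probability reported in the display output.
        pub prob: f64,
    }

    impl<B: Backend> ModuleDisplay for MyBlock<B> {
        fn custom_settings(&self) -> Option<DisplaySettings> {
            // Render all attributes on one line, matching the other modules.
            DisplaySettings::new()
                .with_new_line_after_attribute(false)
                .optional()
        }

        fn custom_content(&self, content: Content) -> Option<Content> {
            // Recover sizes from the parameter shape instead of storing them twice.
            let [d_input, d_output] = self.weight.shape().dims;
            content
                .add("d_input", &d_input)
                .add("d_output", &d_output)
                .add("prob", &self.prob)
                .optional()
        }
    }

With this in place, `format!("{}", block)` yields a one-line `MyBlock {d_input: .., d_output: .., prob: .., params: ..}` summary; judging by the tests in the patches, the trailing `params` count is appended by the framework, not by `custom_content`.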
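
And a usage sketch of what the series buys a downstream crate. The `NdArray` backend import is an assumption here (any `Backend` whose device implements `Default` would do); the expected strings in the comments are taken verbatim from the tests above:

    use burn::backend::NdArray; // assumed backend; any Backend works
    use burn::nn::{DropoutConfig, LinearConfig};

    fn main() {
        let device = Default::default();

        // Modules now render a readable one-line summary via `Display`.
        let linear = LinearConfig::new(3, 5).init::<NdArray>(&device);
        println!("{}", linear); // Linear {d_input: 3, d_output: 5, bias: true, params: 20}

        let dropout = DropoutConfig::new(0.5).init();
        println!("{}", dropout); // Dropout {prob: 0.5}
    }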