Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Print module - implement module display for remaining modules (part2) #1933

Merged
merged 56 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
32cd4c2
Display for ConvTranspose1d
antimora Jun 26, 2024
dab0d25
Display for ConvTranspose2d
antimora Jun 26, 2024
a342324
Test for batch norm display
antimora Jun 26, 2024
9e2fa95
Display for group norm
antimora Jun 26, 2024
1a6f303
Display for Instance norm
antimora Jun 26, 2024
1e73430
Test for layer display
antimora Jun 26, 2024
eea0821
Display for RMS
antimora Jun 26, 2024
a762ede
Fix conv transpose
antimora Jun 26, 2024
80cb068
Test for conv1d display
antimora Jun 26, 2024
359f7b0
Test for conv2d display
antimora Jun 26, 2024
d76fca7
Add display for BinaryCrossEntropyLoss
antimora Jun 26, 2024
5e5f4a3
Make attributes pub
antimora Jun 26, 2024
edf2945
Display for cross entropy
antimora Jun 26, 2024
7451557
Rename print to display in tests
antimora Jun 26, 2024
7da0e1f
Removed PhantomData
antimora Jun 26, 2024
d13b8bb
Added huber display
antimora Jun 26, 2024
4711f9a
Add display for MSE
antimora Jun 26, 2024
0e40d35
Merge remote-tracking branch 'upstream/main' into print-module-part2
antimora Jun 26, 2024
1a291a7
Update log message format
antimora Jun 26, 2024
a7bf1fc
Display for AdaptiveAvgPool1d
antimora Jun 26, 2024
d6c7db0
Display for AdaptiveAvgPool2d
antimora Jun 26, 2024
1c10e32
Make attributes pub for AvgPool1d
antimora Jun 26, 2024
f453b99
Display for AvgPool1d and AvgPool2d
antimora Jun 26, 2024
c4c3d96
Merge remote-tracking branch 'upstream/main' into print-module-part2
antimora Jun 26, 2024
1324471
Display for MaxPool1d and MaxPool2d
antimora Jun 26, 2024
7b736cb
Add display for Gru
antimora Jun 26, 2024
f38fc2e
Add display for lstm
antimora Jun 26, 2024
a7f3717
Add display test to dropout
antimora Jun 26, 2024
169d426
Make dropout attributes pub
antimora Jun 26, 2024
ade898e
Add display for Embedding
antimora Jun 26, 2024
9632ff4
Clean up
antimora Jun 26, 2024
de92657
Add linear display test
antimora Jun 26, 2024
cd38b99
Make attribute pub
antimora Jun 26, 2024
635b2cb
Clean up
antimora Jun 27, 2024
5d605c2
Add display test to gelu
antimora Jun 27, 2024
a2ab8dd
Add display to leaky relu
antimora Jun 27, 2024
84221ad
Add display to prelu
antimora Jun 27, 2024
213732c
Add test to relu
antimora Jun 27, 2024
c7b6df4
Merge remote-tracking branch 'upstream/main' into print-module-part2
antimora Jun 27, 2024
d7c488e
Fix prelu test
antimora Jun 27, 2024
bfcf067
Clean up
antimora Jun 27, 2024
81ad5f7
Add display for PositionalEncoding
antimora Jun 27, 2024
129e220
Add display for role encoding
antimora Jun 27, 2024
30f01f4
Add display for SwiGlu
antimora Jun 27, 2024
e29943a
Add display test for Tanh
antimora Jun 27, 2024
fd53c9a
Fix burn-import
antimora Jun 27, 2024
4ebc60c
Add display to unfold
antimora Jun 27, 2024
004fc39
Add display for pwff
antimora Jun 27, 2024
b50d3d2
Add display for Transformer decoder
antimora Jun 27, 2024
d2d6fec
Add display for encoder
antimora Jun 27, 2024
604129e
Add display for MultiHeadAttention
antimora Jun 27, 2024
957e239
Fix test
antimora Jun 27, 2024
4a437ce
Merge remote-tracking branch 'upstream/main' into print-module-part2
antimora Jun 28, 2024
fbd0e51
Make module attributes pub
antimora Jun 28, 2024
dd5be66
Add module display for BiLstm
antimora Jun 28, 2024
6045411
Clean up
antimora Jun 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 56 additions & 11 deletions crates/burn-core/src/nn/attention/mha.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate as burn;

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::cache::TensorCache;
use crate::nn::Initializer;
use crate::{
config::Config,
module::Module,
nn,
tensor::{activation, backend::Backend, Bool, Tensor},
};
Expand Down Expand Up @@ -53,17 +53,49 @@ pub struct MultiHeadAttentionConfig {
///
/// Should be created with [MultiHeadAttentionConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
output: nn::Linear<B>,
dropout: nn::Dropout,
activation: nn::Gelu,
n_heads: usize,
d_k: usize,
min_float: f64,
quiet_softmax: bool,
/// Linear layer to transform the input features into the query space.
pub query: nn::Linear<B>,
/// Linear layer to transform the input features into the key space.
pub key: nn::Linear<B>,
/// Linear layer to transform the input features into the value space.
pub value: nn::Linear<B>,
/// Linear layer to transform the output features back to the original space.
pub output: nn::Linear<B>,
/// Dropout layer.
pub dropout: nn::Dropout,
/// Activation function.
pub activation: nn::Gelu,
/// The size of each linear layer.
pub d_model: usize,
/// The number of heads.
pub n_heads: usize,
/// Size of the key and query vectors.
pub d_k: usize,
/// Minimum value a float can take.
pub min_float: f64,
/// Use "quiet softmax" instead of regular softmax.
pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for MultiHeadAttention<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}

fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("d_model", &self.d_model)
.add("n_heads", &self.n_heads)
.add("d_k", &self.d_k)
.add("dropout", &self.dropout.prob)
.add("min_float", &self.min_float)
.add("quiet_softmax", &self.quiet_softmax)
.optional()
}
}

/// [Multihead attention](MultiHeadAttention) forward pass input argument.
Expand Down Expand Up @@ -99,6 +131,7 @@ impl MultiHeadAttentionConfig {
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
d_model: self.d_model,
}
}
}
Expand Down Expand Up @@ -478,4 +511,16 @@ mod tests {
.into_data()
.assert_approx_eq(&output_2.into_data(), 3);
}

#[test]
fn display() {
let config = MultiHeadAttentionConfig::new(2, 4);
let mha = config.init::<TestBackend>(&Default::default());

assert_eq!(
alloc::format!("{}", mha),
"MultiHeadAttention {d_model: 2, n_heads: 4, d_k: 0, \
dropout: 0.1, min_float: -10000, quiet_softmax: false, params: 24}"
);
}
}
26 changes: 21 additions & 5 deletions crates/burn-core/src/nn/conv/conv1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,16 @@ pub struct Conv1d<B: Backend> {
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
dilation: usize,
groups: usize,
padding: Ignored<PaddingConfig1d>,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: Ignored<PaddingConfig1d>,
}

impl<B: Backend> ModuleDisplay for Conv1d<B> {
Expand Down Expand Up @@ -169,4 +174,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}

#[test]
fn display() {
let config = Conv1dConfig::new(5, 5, 5);
let conv = config.init::<TestBackend>(&Default::default());

assert_eq!(
alloc::format!("{}", conv),
"Conv1d {stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: Valid, params: 130}"
);
}
}
26 changes: 21 additions & 5 deletions crates/burn-core/src/nn/conv/conv2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@ pub struct Conv2d<B: Backend> {
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: Ignored<PaddingConfig2d>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
}

impl Conv2dConfig {
Expand Down Expand Up @@ -214,4 +219,15 @@ mod tests {

assert_eq!(config.initializer, init);
}

#[test]
fn display() {
let config = Conv2dConfig::new([5, 1], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());

assert_eq!(
alloc::format!("{}", conv),
"Conv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: Valid, params: 126}"
);
}
}
58 changes: 52 additions & 6 deletions crates/burn-core/src/nn/conv/conv_transpose1d.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use alloc::format;

use crate as burn;

use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::Initializer;
Expand Down Expand Up @@ -45,17 +50,46 @@ pub struct ConvTranspose1dConfig {

/// Applies a 1D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose1d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size]`
pub weight: Param<Tensor<B, 3>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
dilation: usize,
groups: usize,
padding: usize,
padding_out: usize,
/// Stride of the convolution.
pub stride: usize,
/// Size of the kernel.
pub kernel_size: usize,
/// Spacing between kernel elements.
pub dilation: usize,
/// Controls the connections between input and output channels.
pub groups: usize,
/// The padding configuration.
pub padding: usize,
/// The padding output configuration.
pub padding_out: usize,
/// The number of channels.
pub channels: [usize; 2],
}

impl<B: Backend> ModuleDisplay for ConvTranspose1d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}

fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &self.stride)
.add("kernel_size", &self.kernel_size)
.add("dilation", &self.dilation)
.add("groups", &self.groups)
.add("padding", &self.padding)
.add("padding_out", &self.padding_out)
.optional()
}
}

impl ConvTranspose1dConfig {
Expand Down Expand Up @@ -91,6 +125,7 @@ impl ConvTranspose1dConfig {
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
Expand Down Expand Up @@ -150,4 +185,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}

#[test]
fn display() {
let config = ConvTranspose1dConfig::new([5, 2], 5);
let conv = config.init::<TestBackend>(&Default::default());

assert_eq!(
format!("{}", conv),
"ConvTranspose1d {channels: [5, 2], stride: 1, kernel_size: 5, dilation: 1, groups: 1, padding: 0, padding_out: 0, params: 52}"
);
}
}
58 changes: 52 additions & 6 deletions crates/burn-core/src/nn/conv/conv_transpose2d.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
use alloc::format;

use crate as burn;

use crate::config::Config;
use crate::module::Content;
use crate::module::DisplaySettings;
use crate::module::Module;
use crate::module::ModuleDisplay;
use crate::module::Param;
use crate::nn::conv::checks;
use crate::nn::Initializer;
Expand Down Expand Up @@ -45,17 +50,46 @@ pub struct ConvTranspose2dConfig {

/// Applies a 2D transposed convolution over input tensors.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct ConvTranspose2d<B: Backend> {
/// Tensor of shape `[channels_in, channels_out / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
dilation: [usize; 2],
groups: usize,
padding: [usize; 2],
padding_out: [usize; 2],
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub groups: usize,
/// Padding configuration.
pub padding: [usize; 2],
/// Padding output configuration.
pub padding_out: [usize; 2],
/// Number of channels.
pub channels: [usize; 2],
}

impl<B: Backend> ModuleDisplay for ConvTranspose2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}

fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("channels", &format!("{:?}", &self.channels))
.add("stride", &format!("{:?}", &self.stride))
.add("kernel_size", &format!("{:?}", &self.kernel_size))
.add("dilation", &format!("{:?}", &self.dilation))
.add("groups", &self.groups)
.add("padding", &format!("{:?}", &self.padding))
.add("padding_out", &format!("{:?}", &self.padding_out))
.optional()
}
}

impl ConvTranspose2dConfig {
Expand Down Expand Up @@ -92,6 +126,7 @@ impl ConvTranspose2dConfig {
groups: self.groups,
padding: self.padding,
padding_out: self.padding_out,
channels: self.channels,
}
}
}
Expand Down Expand Up @@ -152,4 +187,15 @@ mod tests {
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}

#[test]
fn display() {
let config = ConvTranspose2dConfig::new([5, 2], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());

assert_eq!(
format!("{}", conv),
"ConvTranspose2d {channels: [5, 2], stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], groups: 1, padding: [0, 0], padding_out: [0, 0], params: 252}"
);
}
}
15 changes: 12 additions & 3 deletions crates/burn-core/src/nn/dropout.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::{DisplaySettings, Module, ModuleDisplay};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

Expand All @@ -23,7 +23,8 @@ pub struct DropoutConfig {
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Dropout {
prob: f64,
/// The probability of randomly zeroes some elements of the input tensor during training.
pub prob: f64,
}

impl DropoutConfig {
Expand Down Expand Up @@ -62,7 +63,7 @@ impl ModuleDisplay for Dropout {
.optional()
}

fn custom_content(&self, content: crate::module::Content) -> Option<crate::module::Content> {
fn custom_content(&self, content: Content) -> Option<Content> {
content.add("prob", &self.prob).optional()
}
}
Expand Down Expand Up @@ -99,4 +100,12 @@ mod tests {

assert_eq!(tensor.to_data(), output.to_data());
}

#[test]
fn display() {
let config = DropoutConfig::new(0.5);
let layer = config.init();

assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}");
}
}
Loading
Loading