From 13244715ce24a914777a3f98678c36a336ede5a6 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:08:27 -0500 Subject: [PATCH] Display for MaxPool1d and MaxPool2d --- crates/burn-core/src/nn/pool/max_pool1d.rs | 48 ++++++++++++++++++++-- crates/burn-core/src/nn/pool/max_pool2d.rs | 48 ++++++++++++++++++++-- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 040a7a1027..5be363e908 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig1d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool1dConfig { /// /// Should be created with [MaxPool1dConfig](MaxPool1dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool1d { - stride: usize, - kernel_size: usize, - padding: Ignored, - dilation: usize, + /// The stride. + pub stride: usize, + /// The size of the kernel. + pub kernel_size: usize, + /// The padding configuration. + pub padding: Ignored, + /// The dilation. + pub dilation: usize, +} + +impl ModuleDisplay for MaxPool1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &self.kernel_size) + .add("stride", &self.stride) + .add("padding", &self.padding) + .add("dilation", &self.dilation) + .optional() + } } impl MaxPool1dConfig { @@ -65,3 +88,20 @@ impl MaxPool1d { max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool1dConfig::new(3); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}" + ); + } +} diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index 552cde9b35..ab9c60d276 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,6 +1,7 @@ use crate as burn; use crate::config::Config; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; use crate::module::{Ignored, Module}; use crate::nn::PaddingConfig2d; use crate::tensor::backend::Backend; @@ -28,11 +29,33 @@ pub struct MaxPool2dConfig { /// /// Should be created with [MaxPool2dConfig](MaxPool2dConfig). #[derive(Module, Clone, Debug)] +#[module(custom_display)] pub struct MaxPool2d { - stride: [usize; 2], - kernel_size: [usize; 2], - padding: Ignored, - dilation: [usize; 2], + /// The strides. + pub stride: [usize; 2], + /// The size of the kernel. + pub kernel_size: [usize; 2], + /// The padding configuration. + pub padding: Ignored, + /// The dilation. + pub dilation: [usize; 2], +} + +impl ModuleDisplay for MaxPool2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .optional() + } } impl MaxPool2dConfig { @@ -65,3 +88,20 @@ impl MaxPool2d { max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = MaxPool2dConfig::new([3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}" + ); + } +}