Skip to content

Commit

Permalink
Display for MaxPool1d and MaxPool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
antimora committed Jun 26, 2024
1 parent c4c3d96 commit 1324471
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 8 deletions.
48 changes: 44 additions & 4 deletions crates/burn-core/src/nn/pool/max_pool1d.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig1d;
use crate::tensor::backend::Backend;
Expand Down Expand Up @@ -28,11 +29,33 @@ pub struct MaxPool1dConfig {
///
/// Should be created with [MaxPool1dConfig](MaxPool1dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool1d {
stride: usize,
kernel_size: usize,
padding: Ignored<PaddingConfig1d>,
dilation: usize,
/// The stride.
pub stride: usize,
/// The size of the kernel.
pub kernel_size: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig1d>,
/// The dilation.
pub dilation: usize,
}

impl ModuleDisplay for MaxPool1d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}

fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &self.kernel_size)
.add("stride", &self.stride)
.add("padding", &self.padding)
.add("dilation", &self.dilation)
.optional()
}
}

impl MaxPool1dConfig {
Expand Down Expand Up @@ -65,3 +88,20 @@ impl MaxPool1d {
max_pool1d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn display() {
let config = MaxPool1dConfig::new(3);

let layer = config.init();

assert_eq!(
alloc::format!("{}", layer),
"MaxPool1d {kernel_size: 3, stride: 1, padding: Valid, dilation: 1}"
);
}
}
48 changes: 44 additions & 4 deletions crates/burn-core/src/nn/pool/max_pool2d.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::module::{Ignored, Module};
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
Expand Down Expand Up @@ -28,11 +29,33 @@ pub struct MaxPool2dConfig {
///
/// Should be created with [MaxPool2dConfig](MaxPool2dConfig).
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: Ignored<PaddingConfig2d>,
dilation: [usize; 2],
/// The strides.
pub stride: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
/// The dilation.
pub dilation: [usize; 2],
}

impl ModuleDisplay for MaxPool2d {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}

fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("kernel_size", &alloc::format!("{:?}", &self.kernel_size))
.add("stride", &alloc::format!("{:?}", &self.stride))
.add("padding", &self.padding)
.add("dilation", &alloc::format!("{:?}", &self.dilation))
.optional()
}
}

impl MaxPool2dConfig {
Expand Down Expand Up @@ -65,3 +88,20 @@ impl MaxPool2d {
max_pool2d(input, self.kernel_size, self.stride, padding, self.dilation)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn display() {
let config = MaxPool2dConfig::new([3, 3]);

let layer = config.init();

assert_eq!(
alloc::format!("{}", layer),
"MaxPool2d {kernel_size: [3, 3], stride: [1, 1], padding: Valid, dilation: [1, 1]}"
);
}
}

0 comments on commit 1324471

Please sign in to comment.