From 420ed7a2db2e9137145ceccb6c5ed50494c62068 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Mon, 18 Nov 2024 17:48:17 -0800 Subject: [PATCH] Conv1D supports paddings. (#847) Conv1D is designed for processing sequence data, so it is natural to handle paddings as part of its functionality. Thus, it introduces Conv1DWithPadding. --- axlearn/common/layers.py | 76 ++++++++++++++++++++++++------- axlearn/common/layers_test.py | 85 +++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 15 deletions(-) diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index ea02901f..eb1f89a0 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -1216,17 +1216,7 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti class Conv2DWith1DPadding(Conv2D): - """The 2-D convolution with 1-D padding on the time axis. - - Kernel weights have the HWIO layout and in the shape of (window[0], window[1], input_dim, - output_dim). Both inputs and outputs will be in the NHWC layout. - - For audio inputs/outputs, we assume dims correspond to [batch_size, time, frequency, input_dim]. - This layer also returns paddings along the time axis. If specifying `cfg.padding` as a tuple of - (leading, trailing) paddings, leading padding frames are treated as valid (i.e. not masked by - the output paddings) while trailing padding frames are invalid (i.e. masked by the output - paddings). - """ + """The 2-D convolution with 1-D padding on the time axis.""" @config_class class Config(Conv2D.Config): @@ -1499,10 +1489,11 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]: def forward(self, x: Tensor) -> Tensor: cfg = self.config - dilation = (cfg.rhs_dilation,) if cfg.rhs_dilation else None + dilation = cfg.rhs_dilation or 1 conv_padding = conv_explicit_padding( - window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=dilation + window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=(dilation,) ) + transpose_dilation = cfg.lhs_dilation or 1 output = jax.lax.conv_general_dilated( lhs=x, rhs=self.parameters["weight"], @@ -1510,14 +1501,69 @@ def forward(self, x: Tensor) -> Tensor: dimension_numbers=("NWC", "WIO", "NWC"), padding=conv_padding, feature_group_count=cfg.num_input_dim_groups, - lhs_dilation=[cfg.lhs_dilation] if cfg.lhs_dilation is not None else None, - rhs_dilation=[cfg.rhs_dilation] if cfg.rhs_dilation is not None else None, + lhs_dilation=(transpose_dilation,), + rhs_dilation=(dilation,), ) if cfg.bias: output += self.parameters["bias"] return output +class Conv1DWithPadding(Conv1D): + """The 1-D convolution with 1-D padding on the time axis.""" + + @config_class + class Config(Conv1D.Config): + """Configures Conv1DWithPadding.""" + + # An optional integer in the range of [left_time_padding, window - right_time_padding) + # that specifies the anchor position within the convolution window that is used to + # determine output paddings. Specifically, the output token is valid iff the input token + # at the anchor position of the corresponding window is valid. + # If None, defaults to left time padding. See Conv2DWith1DPadding more details. + anchor: Optional[int] = None + + # We add a kwargs "paddings" to the forward method. + # pylint: disable-next=arguments-differ + def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: + """Computes convolution outputs and paddings. + + Args: + x: A Tensor of shape [batch_size, seq_len, frequency, input_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + + Returns: + output: A Tensor of shape [batch_size, seq_len, frequency, output_dim]. + paddings: 0/1 Tensor of shape [batch_size, seq_len]. + """ + cfg = self.config + chex.assert_rank(x, paddings.ndim + 1) + # Apply padding to the input. + x = x * (1 - paddings[..., None]) + + # Apply Conv1D. + output = super().forward(x) + + # TODO(dhwang2): Implement Conv1DTranspose separately for lhs_dilation. It's problematic + # for lhs_dilation (Conv Transpose) and rhs_dilation (Dilated Convolution) to be part of + # the same class. Not only are they never used together, but their combined usage would + # result in undefined behavior. Additionally, the logic for handling explicit padding and + # paddings is fundamentally different between them, so supporting both in a single class + # makes the code error-prone. + # Compute paddings conv output. + output_paddings = compute_conv_paddings( + paddings, + window=cfg.window, + stride=cfg.strides, + conv_padding=cfg.padding, + dilation=cfg.rhs_dilation, + anchor=cfg.anchor, + ) + # Apply padding to the outputs. + output = output * (1 - output_paddings[..., None]) + return output, output_paddings + + class DepthwiseConv1D(BaseConv): """The 1-D depth-wise convolution layer. diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 24d97d0c..1c5ffc5f 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -19,6 +19,7 @@ from functools import partial from typing import Optional, Union +import einops import jax.random import numpy as np import tensorflow as tf @@ -39,6 +40,7 @@ CategoricalHingeLossMetric, ClassificationMetric, Conv1D, + Conv1DWithPadding, Conv2D, Conv2DTranspose, Conv2DWith1DPadding, @@ -1265,6 +1267,89 @@ def test_conv2d_with_1d_padding( jnp.take_along_axis(ref_paddings, permute_idx[:, None], axis=0)[:, :output_len], ) + @parameterized.named_parameters( + ("1_S1", 1, 1, "VALID", None), + ("2_S1_VALID", 2, 1, "VALID", None), + ("2_S2_SAME", 2, 2, "SAME", None), + ("2_S_CAUSAL", 2, 1, "CAUSAL", None), + ("2_S2_VALID", 2, 2, "VALID", None), + ("2_S2_CAUSAL", 2, 2, "CAUSAL", None), + ("3_S1_VALID", 3, 1, "VALID", None), + ("3_S1_VALID_A0", 3, 1, "VALID", 0), + ("3_S1_VALID_A1", 3, 1, "VALID", 1), + ("3_S1_VALID_A2", 3, 1, "VALID", 2), + ("3_S1_SAME", 3, 1, "SAME", None), + ("3_S1_CAUSAL", 3, 1, "CAUSAL", None), + ("3_S2_VALID", 3, 2, "VALID", None), + ("3_S2_CAUSAL", 3, 2, "CAUSAL", None), + ) + def test_conv1d_against_conv2d_with_1d_padding( + self, + window: int, + strides: int, + padding: ConvPaddingType, + anchor: Optional[int], + ): + input_dim, output_dim = 4, 6 + ref_cfg = Conv2DWith1DPadding.default_config().set( + name="ref", + input_dim=input_dim, + output_dim=output_dim, + window=(window, 1), + strides=(strides, 1), + padding=padding, + anchor=anchor, + ) + ref_layer = ref_cfg.instantiate(parent=None) + + test_cfg = Conv1DWithPadding.default_config().set( + name="test", + input_dim=input_dim, + output_dim=output_dim, + window=window, + strides=strides, + padding=padding, + anchor=anchor, + ) + test_layer = test_cfg.instantiate(parent=None) + + # Initialize layer parameters. + prng_key = jax.random.PRNGKey(123) + prng_key, init_key = jax.random.split(prng_key) + state = ref_layer.initialize_parameters_recursively(init_key) + test_state = dict( + bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o") + ) + + # Generate a batch of 10 input sequences. + batch_size, max_seq_len = 10, 10 + + prng_key, input_key = jax.random.split(prng_key) + inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim]) + # The 10 sequences have length 1 to 10. + paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1) + + (test_outputs, test_paddings), _ = F( + test_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=test_state, + prng_key=prng_key, + ) + + inputs = einops.rearrange(inputs, "b t i -> b t 1 i") + (ref_outputs, ref_paddings), _ = F( + ref_layer, + inputs=dict(x=inputs, paddings=paddings), + is_training=True, + state=state, + prng_key=prng_key, + ) + ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o") + + assert_allclose(ref_paddings, test_paddings) + assert_allclose(ref_outputs, test_outputs) + @parameterized.named_parameters( { "testcase_name": "2x2",