Conv1D supports paddings. (#847)
Conv1D is designed for processing sequence data, so it is natural for it to handle
paddings as part of its functionality.
This commit therefore introduces Conv1DWithPadding.
ds-hwang authored Nov 19, 2024
1 parent 2335d13 commit 420ed7a
Showing 2 changed files with 146 additions and 15 deletions.
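
Before the diff, a minimal usage sketch of the new layer. It is modeled on the test added in this commit: the config fields and the functional-call helper `F` mirror layers_test.py, while the exact import paths and values here are assumptions.

# Minimal usage sketch (assumed imports; API mirrors the test below).
import jax
import jax.numpy as jnp

from axlearn.common.layers import Conv1DWithPadding
from axlearn.common.module import functional as F

cfg = Conv1DWithPadding.default_config().set(
    name="conv",
    input_dim=4,
    output_dim=6,
    window=3,
    strides=2,
    padding="CAUSAL",
)
layer = cfg.instantiate(parent=None)
state = layer.initialize_parameters_recursively(jax.random.PRNGKey(0))

batch_size, seq_len = 2, 10
x = jax.random.normal(jax.random.PRNGKey(1), [batch_size, seq_len, 4])
# In axlearn, paddings are 0 for valid frames and 1 for padded frames.
paddings = jnp.zeros([batch_size, seq_len]).at[:, 7:].set(1)

# Returns both the convolved outputs and the paddings of the output sequence.
(outputs, output_paddings), _ = F(
    layer,
    inputs=dict(x=x, paddings=paddings),
    is_training=False,
    state=state,
    prng_key=jax.random.PRNGKey(2),
)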
76 changes: 61 additions & 15 deletions axlearn/common/layers.py
@@ -1216,17 +1216,7 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Optional[int]]:


class Conv2DWith1DPadding(Conv2D):
-   """The 2-D convolution with 1-D padding on the time axis.
-
-   Kernel weights have the HWIO layout and in the shape of (window[0], window[1], input_dim,
-   output_dim). Both inputs and outputs will be in the NHWC layout.
-
-   For audio inputs/outputs, we assume dims correspond to [batch_size, time, frequency, input_dim].
-
-   This layer also returns paddings along the time axis. If specifying `cfg.padding` as a tuple of
-   (leading, trailing) paddings, leading padding frames are treated as valid (i.e. not masked by
-   the output paddings) while trailing padding frames are invalid (i.e. masked by the output
-   paddings).
-   """
+   """The 2-D convolution with 1-D padding on the time axis."""

    @config_class
    class Config(Conv2D.Config):
@@ -1499,25 +1489,81 @@ def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:

    def forward(self, x: Tensor) -> Tensor:
        cfg = self.config
-       dilation = (cfg.rhs_dilation,) if cfg.rhs_dilation else None
+       dilation = cfg.rhs_dilation or 1
        conv_padding = conv_explicit_padding(
-           window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=dilation
+           window=(cfg.window,), strides=(cfg.strides,), padding=cfg.padding, dilation=(dilation,)
        )
+       transpose_dilation = cfg.lhs_dilation or 1
        output = jax.lax.conv_general_dilated(
            lhs=x,
            rhs=self.parameters["weight"],
            window_strides=(cfg.strides,),
            dimension_numbers=("NWC", "WIO", "NWC"),
            padding=conv_padding,
            feature_group_count=cfg.num_input_dim_groups,
-           lhs_dilation=[cfg.lhs_dilation] if cfg.lhs_dilation is not None else None,
-           rhs_dilation=[cfg.rhs_dilation] if cfg.rhs_dilation is not None else None,
+           lhs_dilation=(transpose_dilation,),
+           rhs_dilation=(dilation,),
        )
        if cfg.bias:
            output += self.parameters["bias"]
        return output


class Conv1DWithPadding(Conv1D):
    """The 1-D convolution with 1-D padding on the time axis."""

    @config_class
    class Config(Conv1D.Config):
        """Configures Conv1DWithPadding."""

        # An optional integer in the range of [left_time_padding, window - right_time_padding)
        # that specifies the anchor position within the convolution window that is used to
        # determine output paddings. Specifically, the output token is valid iff the input token
        # at the anchor position of the corresponding window is valid.
        # If None, defaults to the left time padding. See Conv2DWith1DPadding for more details.
        anchor: Optional[int] = None

    # We add a kwarg "paddings" to the forward method.
    # pylint: disable-next=arguments-differ
    def forward(self, x: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]:
        """Computes convolution outputs and paddings.

        Args:
            x: A Tensor of shape [batch_size, seq_len, input_dim].
            paddings: 0/1 Tensor of shape [batch_size, seq_len].

        Returns:
            output: A Tensor of shape [batch_size, seq_len, output_dim].
            paddings: 0/1 Tensor of shape [batch_size, seq_len].
        """
        cfg = self.config
        chex.assert_rank(x, paddings.ndim + 1)
        # Zero out padded input frames so they do not contribute to valid outputs.
        x = x * (1 - paddings[..., None])

        # Apply Conv1D.
        output = super().forward(x)

        # TODO(dhwang2): Implement Conv1DTranspose separately for lhs_dilation. It's problematic
        # for lhs_dilation (Conv Transpose) and rhs_dilation (Dilated Convolution) to be part of
        # the same class. Not only are they never used together, but their combined usage would
        # result in undefined behavior. Additionally, the logic for handling explicit padding and
        # paddings is fundamentally different between them, so supporting both in a single class
        # makes the code error-prone.
        # Compute paddings of the conv output.
        output_paddings = compute_conv_paddings(
            paddings,
            window=cfg.window,
            stride=cfg.strides,
            conv_padding=cfg.padding,
            dilation=cfg.rhs_dilation,
            anchor=cfg.anchor,
        )
        # Zero out padded output frames.
        output = output * (1 - output_paddings[..., None])
        return output, output_paddings


class DepthwiseConv1D(BaseConv):
    """The 1-D depth-wise convolution layer.
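
To make the anchor rule in Conv1DWithPadding.Config concrete, here is a small worked example against compute_conv_paddings, which the new forward method calls above. The import path is an assumption (the diff calls it unqualified inside layers.py); the expected outputs follow directly from the documented rule.

# Worked example of the anchor rule documented on Conv1DWithPadding.Config.
import jax.numpy as jnp

from axlearn.common.layers import compute_conv_paddings  # assumed import path

paddings = jnp.array([[0, 0, 0, 1, 1]], dtype=jnp.float32)  # 3 valid frames out of 5.

# VALID padding, window=3, stride=1 -> 3 output frames; window t covers inputs t..t+2.
# The default anchor is the left time padding (0 for VALID), so output t is valid
# iff input t is valid: expect [[0, 0, 0]].
print(compute_conv_paddings(paddings, window=3, stride=1, conv_padding="VALID", dilation=None))

# With anchor=2, output t is valid iff input t+2 is valid: expect [[0, 1, 1]].
print(compute_conv_paddings(paddings, window=3, stride=1, conv_padding="VALID", dilation=None, anchor=2))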
85 changes: 85 additions & 0 deletions axlearn/common/layers_test.py
@@ -19,6 +19,7 @@
from functools import partial
from typing import Optional, Union

import einops
import jax.random
import numpy as np
import tensorflow as tf
@@ -39,6 +40,7 @@
    CategoricalHingeLossMetric,
    ClassificationMetric,
    Conv1D,
    Conv1DWithPadding,
    Conv2D,
    Conv2DTranspose,
    Conv2DWith1DPadding,
@@ -1265,6 +1267,89 @@ def test_conv2d_with_1d_padding(
            jnp.take_along_axis(ref_paddings, permute_idx[:, None], axis=0)[:, :output_len],
        )

    @parameterized.named_parameters(
        ("1_S1", 1, 1, "VALID", None),
        ("2_S1_VALID", 2, 1, "VALID", None),
        ("2_S2_SAME", 2, 2, "SAME", None),
        ("2_S_CAUSAL", 2, 1, "CAUSAL", None),
        ("2_S2_VALID", 2, 2, "VALID", None),
        ("2_S2_CAUSAL", 2, 2, "CAUSAL", None),
        ("3_S1_VALID", 3, 1, "VALID", None),
        ("3_S1_VALID_A0", 3, 1, "VALID", 0),
        ("3_S1_VALID_A1", 3, 1, "VALID", 1),
        ("3_S1_VALID_A2", 3, 1, "VALID", 2),
        ("3_S1_SAME", 3, 1, "SAME", None),
        ("3_S1_CAUSAL", 3, 1, "CAUSAL", None),
        ("3_S2_VALID", 3, 2, "VALID", None),
        ("3_S2_CAUSAL", 3, 2, "CAUSAL", None),
    )
    def test_conv1d_against_conv2d_with_1d_padding(
        self,
        window: int,
        strides: int,
        padding: ConvPaddingType,
        anchor: Optional[int],
    ):
        input_dim, output_dim = 4, 6
        ref_cfg = Conv2DWith1DPadding.default_config().set(
            name="ref",
            input_dim=input_dim,
            output_dim=output_dim,
            window=(window, 1),
            strides=(strides, 1),
            padding=padding,
            anchor=anchor,
        )
        ref_layer = ref_cfg.instantiate(parent=None)

        test_cfg = Conv1DWithPadding.default_config().set(
            name="test",
            input_dim=input_dim,
            output_dim=output_dim,
            window=window,
            strides=strides,
            padding=padding,
            anchor=anchor,
        )
        test_layer = test_cfg.instantiate(parent=None)

        # Initialize layer parameters.
        prng_key = jax.random.PRNGKey(123)
        prng_key, init_key = jax.random.split(prng_key)
        state = ref_layer.initialize_parameters_recursively(init_key)
        test_state = dict(
            bias=state["bias"], weight=einops.rearrange(state["weight"], "t 1 i o -> t i o")
        )

        # Generate a batch of 10 input sequences.
        batch_size, max_seq_len = 10, 10

        prng_key, input_key = jax.random.split(prng_key)
        inputs = jax.random.normal(input_key, [batch_size, max_seq_len, input_dim])
        # The 10 sequences have length 1 to 10.
        paddings = jnp.triu(jnp.ones((batch_size, max_seq_len)), k=1)

        (test_outputs, test_paddings), _ = F(
            test_layer,
            inputs=dict(x=inputs, paddings=paddings),
            is_training=True,
            state=test_state,
            prng_key=prng_key,
        )

        inputs = einops.rearrange(inputs, "b t i -> b t 1 i")
        (ref_outputs, ref_paddings), _ = F(
            ref_layer,
            inputs=dict(x=inputs, paddings=paddings),
            is_training=True,
            state=state,
            prng_key=prng_key,
        )
        ref_outputs = einops.rearrange(ref_outputs, "b t 1 o -> b t o")

        assert_allclose(ref_paddings, test_paddings)
        assert_allclose(ref_outputs, test_outputs)

    @parameterized.named_parameters(
        {
            "testcase_name": "2x2",
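
A closing note on why Conv1DWithPadding.forward masks its input before convolving (x = x * (1 - paddings[..., None]) in the diff above): even output positions that remain valid can have windows overlapping padded frames (e.g. window=2, VALID padding, anchor 0), so without masking, arbitrary values stored in padded frames would leak into valid outputs. A self-contained sketch:

import jax.numpy as jnp

paddings = jnp.array([[0.0, 0.0, 1.0]])    # last frame is padding
x = jnp.array([[[1.0], [2.0], [999.0]]])   # 999.0 is garbage in the padded frame

x_masked = x * (1 - paddings[..., None])
# x_masked == [[[1.0], [2.0], [0.0]]]: a window covering frames 1..2 now sees 0.0
# instead of 999.0, so outputs at valid positions are unaffected by padding content.
print(x_masked)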
