Skip to content

Commit

Permalink
Add support for dilation to hk.DepthwiseConvND.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 545685750
  • Loading branch information
Haiku Contributor authored and copybara-github committed Jul 17, 2023
1 parent e0e9443 commit e954549
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion haiku/_src/depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
num_spatial_dims: int,
data_format: str,
stride: Union[int, Sequence[int]] = 1,
rate: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
Expand All @@ -79,6 +80,9 @@ def __init__(
default, ``channels_last``. See :func:`get_channel_index`.
stride: Optional stride for the kernel. Either an integer or a sequence of
length ``num_spatial_dims``. Defaults to 1.
rate: Optional kernel dilation rate. Either an integer or a sequence of
length ``num_spatial_dims``. 1 corresponds to standard ND convolution,
``rate > 1`` corresponds to dilated convolution. Defaults to 1.
padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
Expand All @@ -92,7 +96,7 @@ def __init__(
self.kernel_shape = utils.replicate(kernel_shape, self.num_spatial_dims,
"kernel_shape")
self.lhs_dilation = (1,) * len(self.kernel_shape)
self.rhs_dilation = (1,) * len(self.kernel_shape)
self.rhs_dilation = utils.replicate(rate, num_spatial_dims, "rhs_dilation")
self.channel_multiplier = channel_multiplier
self.padding = padding
self.stride = utils.replicate(stride, self.num_spatial_dims, "strides")
Expand Down Expand Up @@ -209,6 +213,7 @@ def __init__(
channel_multiplier: int,
kernel_shape: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
rate: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
Expand All @@ -225,6 +230,9 @@ def __init__(
length 1.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 1. Defaults to 1.
rate: Optional kernel dilation rate. Either an integer or a sequence of
length 1. 1 corresponds to standard ND convolution,
``rate > 1`` corresponds to dilated convolution. Defaults to 1.
padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
Expand All @@ -242,6 +250,7 @@ def __init__(
channel_multiplier=channel_multiplier,
kernel_shape=kernel_shape,
stride=stride,
rate=rate,
padding=padding,
with_bias=with_bias,
w_init=w_init,
Expand All @@ -257,6 +266,7 @@ def __init__(
channel_multiplier: int,
kernel_shape: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
rate: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
Expand All @@ -273,6 +283,9 @@ def __init__(
length 2.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 2. Defaults to 1.
rate: Optional kernel dilation rate. Either an integer or a sequence of
length 1. 1 corresponds to standard ND convolution,
``rate > 1`` corresponds to dilated convolution. Defaults to 1.
padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
Expand All @@ -290,6 +303,7 @@ def __init__(
channel_multiplier=channel_multiplier,
kernel_shape=kernel_shape,
stride=stride,
rate=rate,
padding=padding,
with_bias=with_bias,
w_init=w_init,
Expand All @@ -305,6 +319,7 @@ def __init__(
channel_multiplier: int,
kernel_shape: Union[int, Sequence[int]],
stride: Union[int, Sequence[int]] = 1,
rate: Union[int, Sequence[int]] = 1,
padding: Union[str, Sequence[Tuple[int, int]]] = "SAME",
with_bias: bool = True,
w_init: Optional[hk.initializers.Initializer] = None,
Expand All @@ -321,6 +336,9 @@ def __init__(
length 3.
stride: Optional stride for the kernel. Either an integer or a sequence of
length 3. Defaults to 1.
rate: Optional kernel dilation rate. Either an integer or a sequence of
length 1. 1 corresponds to standard ND convolution,
``rate > 1`` corresponds to dilated convolution. Defaults to 1.
padding: Optional padding algorithm. Either ``VALID``, ``SAME`` or a
sequence of ``before, after`` pairs. Defaults to ``SAME``. See:
https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
Expand All @@ -338,6 +356,7 @@ def __init__(
channel_multiplier=channel_multiplier,
kernel_shape=kernel_shape,
stride=stride,
rate=rate,
padding=padding,
with_bias=with_bias,
w_init=w_init,
Expand Down

0 comments on commit e954549

Please sign in to comment.