diff --git a/LICENSES.md b/LICENSES.md index 5de8358331..3cdeddf6ff 100644 --- a/LICENSES.md +++ b/LICENSES.md @@ -2,7 +2,10 @@ In this file, we list the operations with other licenses instead of Apache 2.0. Users should be careful about adopting these operations in any commercial matters. -| Operation | Files | License | -| :--------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: | -| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License | -| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License | +| Operation | Files | License | +| :--------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------: | +| upfirdn2d | [mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu) | NVIDIA License | +| fused_leaky_relu | [mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/fused_bias_leakyrelu_cuda.cu) | NVIDIA License | +| bias_act | [mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu) | NVIDIA License | +| filtered_lrelu | [mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu) | NVIDIA License | +| conv2d_gradfix | [mmcv/ops/conv2d_gradfix.py](https://github.com/open-mmlab/mmcv/tree/2.x/mmcv/ops/conv2d_gradfix.py) | NVIDIA License | diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 0fe338d47c..e60f77c772 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -61,3 +61,6 @@ We implement common ops used in detection, segmentation, etc. 
 | Voxelization | √ | √ | | | √ |
 | PrRoIPool | | √ | | | |
 | BezierAlign | √ | √ | | | |
+| BiasAct | | √ | | | |
+| FilteredLrelu | | √ | | | |
+| Conv2dGradfix | | √ | | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 9c2dc3f1e2..11b885d37c 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -61,3 +61,6 @@ MMCV 提供了检测、分割等任务中常用的算子
 | Voxelization | √ | √ | | | √ |
 | PrRoIPool | | √ | | | |
 | BezierAlign | √ | √ | | | |
+| BiasAct | | √ | | | |
+| FilteredLrelu | | √ | | | |
+| Conv2dGradfix | | √ | | | |
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index 76558f3532..0c36433317 100755
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -4,6 +4,7 @@
 from .ball_query import ball_query
 from .bbox import bbox_overlaps
 from .bezier_align import BezierAlign, bezier_align
+from .bias_act import bias_act
 from .border_align import BorderAlign, border_align
 from .box_iou_quadri import box_iou_quadri
 from .box_iou_rotated import box_iou_rotated
@@ -11,6 +12,7 @@
 from .cc_attention import CrissCrossAttention
 from .chamfer_distance import chamfer_distance
 from .contour_expand import contour_expand
+from .conv2d_gradfix import conv2d, conv_transpose2d
 from .convex_iou import convex_giou, convex_iou
 from .corner_pool import CornerPool
 from .correlation import Correlation
@@ -22,6 +24,7 @@
 from .deprecated_wrappers import Linear_deprecated as Linear
 from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
 from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
+from .filtered_lrelu import filtered_lrelu
 from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
                          sigmoid_focal_loss, softmax_focal_loss)
 from .furthest_point_sample import (furthest_point_sample,
@@ -71,6 +74,6 @@
 from .tin_shift import TINShift, tin_shift
-from .upfirdn2d import upfirdn2d
+from .upfirdn2d import filter2d, upfirdn2d, upsample2d
 from .vector_pool_with_voxel_query import vector_pool_with_voxel_query
 from .voxelize import Voxelization, voxelization

 __all__ = [
@@ -107,5 +111,7 @@
     'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
     'PrRoIPool', 'prroi_pool', 'three_nn_vector_pool_by_two_step',
     'stack_three_interpolate', 'vector_pool_with_voxel_query',
+    'BezierAlign', 'bezier_align', 'bias_act', 'filtered_lrelu', 'conv2d',
+    'conv_transpose2d', 'filter2d', 'upsample2d'
 ]
diff --git a/mmcv/ops/bias_act.py b/mmcv/ops/bias_act.py
new file mode 100644
index 0000000000..3dfa55743e
--- /dev/null
+++ b/mmcv/ops/bias_act.py
@@ -0,0 +1,375 @@
+# Modified from
+# https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.py
+
+# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+ +# source: https://github.com/open-mmlab/mmediting/blob/dev-1.x/mmedit/models/editors/stylegan3/stylegan3_ops/ops/bias_act.py # noqa +"""Custom PyTorch ops for efficient bias and activation.""" + +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch + +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', ['bias_act']) + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the + attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +activation_funcs = { + 'linear': + EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + cuda_idx=1, + ref='', + has_2nd_grad=False), + 'relu': + EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=2, + ref='y', + has_2nd_grad=False), + 'lrelu': + EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + cuda_idx=3, + ref='y', + has_2nd_grad=False), + 'tanh': + EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + cuda_idx=4, + ref='y', + has_2nd_grad=True), + 'sigmoid': + EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + cuda_idx=5, + ref='y', + has_2nd_grad=True), + 'elu': + EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + cuda_idx=6, + ref='y', + has_2nd_grad=True), + 'selu': + EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + cuda_idx=7, + ref='y', + has_2nd_grad=True), + 'softplus': + EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + cuda_idx=8, + ref='y', + has_2nd_grad=True), + 'swish': + EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=9, + ref='x', + has_2nd_grad=True), +} + +_null_tensor = torch.empty([0]) + + +def bias_act(input: torch.Tensor, + bias: Optional[torch.Tensor] = None, + dim: int = 1, + act: str = 'linear', + alpha: Optional[Union[float, int]] = None, + gain: Optional[float] = None, + clamp: Optional[float] = None, + use_custom_op: bool = True): + r"""Fused bias and activation function. + + Adds `bias` to activation tensor `input`, and evaluates activation + function `act`, and scales the result by `gain`. Each of the steps is + optional. + + In most cases, the fused op is considerably more efficient than performing + the same calculation using standard PyTorch ops. It supports first and + second order gradients, but not third order gradients. + + Args: + input (torch.Tensor): Input activation tensor. Can be of any shape. + bias (torch.Tensor): Bias vector, or `None` to disable. + Must be a 1D tensor of the same type as `input`. The shape must + be known, and it must match the dimension of `input` corresponding + to `dim`. Defaults to None. + dim (int): The dimension in `input` corresponding to the elements of + `bias`. The value of `dim` is ignored if `b` is not specified. + Defaults to 1. + act (str): Name of the activation function to evaluate, or `"linear"` + to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid", + "swish", etc. See `activation_funcs` for a full list. `None` is not + allowed. Defaults to `linear`. 
+ alpha (float or int): Shape parameter for the activation + function, or `None` to use the default. Defaults to None. + gain (float): Scaling factor for the output tensor, or `None` + to use default. See `activation_funcs` for the default scaling of + each activation function. If unsure, consider specifying 1. + Defaults to None. + clamp (float): Clamp the output values to `[-clamp, +clamp]`, + or `None` to disable the clamping (default). Defaults to None. + use_custom_op (bool): Whether to use customized op. + Defaults to True. + + Returns: + torch.Tensor: Tensor of the same shape and datatype as `input`. + """ + assert isinstance(input, torch.Tensor) + if use_custom_op and input.is_cuda: + return _bias_act_cuda( + dim=dim, act=act, alpha=alpha, gain=gain, + clamp=clamp).apply(input, bias) + return _bias_act_ref( + input=input, + bias=bias, + dim=dim, + act=act, + alpha=alpha, + gain=gain, + clamp=clamp) + + +def _bias_act_ref(input: torch.Tensor, + bias: Optional[torch.Tensor] = None, + dim: int = 1, + act: str = 'linear', + alpha: Optional[Union[float, int]] = None, + gain: Optional[float] = None, + clamp: Optional[float] = None): + """Slow reference implementation of `bias_act()` using standard PyTorch + ops. + + Adds `bias` to activation tensor `input`, and evaluates activation + function `act`, and scales the result by `gain`. Each of the steps is + optional. + + In most cases, the fused op is considerably more efficient than performing + the same calculation using standard PyTorch ops. It supports first and + second order gradients, but not third order gradients. + + Args: + input (torch.Tensor): Input activation tensor. Can be of any shape. + bias (torch.Tensor): Bias vector, or `None` to disable. + Must be a 1D tensor of the same type as `input`. The shape must + be known, and it must match the dimension of `input` corresponding + to `dim`. Defaults to None. + dim (int): The dimension in `input` corresponding to the elements of + `bias`. The value of `dim` is ignored if `b` is not specified. + Defaults to 1. + act (str): Name of the activation function to evaluate, or `"linear"` + to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid", + "swish", etc. See `activation_funcs` for a full list. `None` is not + allowed. Defaults to `linear`. + alpha (float or int): Shape parameter for the activation + function, or `None` to use the default. Defaults to None. + gain (float): Scaling factor for the output tensor, or `None` + to use default. See `activation_funcs` for the default scaling of + each activation function. If unsure, consider specifying 1. + Defaults to None. + clamp (float): Clamp the output values to + `[-clamp, +clamp]`, or `None` to disable the clamping (default). + Defaults to None. + + Returns: + torch.Tensor: Tensor of the same shape and datatype as `input`. + """ + assert isinstance(input, torch.Tensor) + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Add bias. + if bias is not None: + assert isinstance(bias, torch.Tensor) and bias.ndim == 1 + assert 0 <= dim < input.ndim + assert bias.shape[0] == input.shape[dim] + input = input + bias.reshape( + [-1 if i == dim else 1 for i in range(input.ndim)]) + + # Evaluate activation function. + alpha = float(alpha) + output = spec.func(input, alpha=alpha) + + # Scale by gain. 
+ gain = float(gain) + if gain != 1: + output = output * gain + + # Clamp. + if clamp >= 0: + # pylint: disable=invalid-unary-operand-type + output = output.clamp(-clamp, clamp) + return output + + +_bias_act_cuda_cache: Dict = dict() + + +def _bias_act_cuda(dim: int = 1, + act: str = 'linear', + alpha: Optional[Union[float, int]] = None, + gain: Optional[float] = None, + clamp: Optional[float] = None): + """"Fast CUDA implementation of `bias_act()` using custom ops. + + Args: + dim (int): The dimension in `x` corresponding to the elements of `b`. + The value of `dim` is ignored if `b` is not specified. + Defaults to 1. + act (str): Name of the activation function to evaluate, or `"linear"` + to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid", + "swish", etc. See `activation_funcs` for a full list. `None` is not + allowed. Defaults to `linear`. + alpha (float | int): Shape parameter for the activation + function, or `None` to use the default. Defaults to None. + gain (float): Scaling factor for the output tensor, or `None` + to use default. See `activation_funcs` for the default scaling of + each activation function. If unsure, consider specifying 1. + Defaults to None. + clamp (float): Clamp the output values to `[-clamp, +clamp]`, + or `None` to disable the clamping (default). Defaults to None. + + Returns: + torch.Tensor: Tensor of the same shape and datatype as `x`. + """ + # Parse arguments. + assert clamp is None or clamp >= 0 + spec = activation_funcs[act] + alpha = float(alpha if alpha is not None else spec.def_alpha) + gain = float(gain if gain is not None else spec.def_gain) + clamp = float(clamp if clamp is not None else -1) + + # Lookup from cache. + key = (dim, act, alpha, gain, clamp) + if key in _bias_act_cuda_cache: + return _bias_act_cuda_cache[key] + + # Forward op. + class BiasActCuda(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, b): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride( + 1) == 1 else torch.contiguous_format + x = x.contiguous(memory_format=ctx.memory_format) + b = b.contiguous() if b is not None else _null_tensor.to(x.device) + y = x + if act != 'linear' or gain != 1 or clamp >= 0 or ( + b is not _null_tensor.to(x.device)): + y = ext_module.bias_act(x, b, _null_tensor.to(x.device), + _null_tensor.to(x.device), + _null_tensor.to(x.device), 0, dim, + spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to( + x.device), b if 'x' in spec.ref or spec.has_2nd_grad else + _null_tensor.to(x.device), + y if 'y' in spec.ref else _null_tensor.to(x.device)) + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + dy = dy.contiguous(memory_format=ctx.memory_format) + x, b, y = ctx.saved_tensors + dx = None + db = None + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + dx = dy + if act != 'linear' or gain != 1 or clamp >= 0: + dx = BiasActCudaGrad.apply(dy, x, b, y) + + if ctx.needs_input_grad[1]: + db = dx.sum([i for i in range(dx.ndim) if i != dim]) + + return dx, db + + # Backward op. 
+ class BiasActCudaGrad(torch.autograd.Function): + + @staticmethod + def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ + ctx.memory_format = torch.channels_last if dy.ndim > 2 and ( + dy.stride(1) == 1) else torch.contiguous_format + dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1, + dim, spec.cuda_idx, alpha, gain, clamp) + ctx.save_for_backward( + dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b, + y) + return dx + + @staticmethod + def backward(ctx, d_dx): # pylint: disable=arguments-differ + d_dx = d_dx.contiguous(memory_format=ctx.memory_format) + dy, x, b, y = ctx.saved_tensors + d_dy = None + d_x = None + d_b = None + d_y = None + + if ctx.needs_input_grad[0]: + d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) + + if spec.has_2nd_grad and (ctx.needs_input_grad[1] + or ctx.needs_input_grad[2]): + d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim, + spec.cuda_idx, alpha, gain, clamp) + + if spec.has_2nd_grad and ctx.needs_input_grad[2]: + d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) + + return d_dy, d_x, d_b, d_y + + # Add to cache. + _bias_act_cuda_cache[key] = BiasActCuda + return BiasActCuda diff --git a/mmcv/ops/conv2d_gradfix.py b/mmcv/ops/conv2d_gradfix.py new file mode 100644 index 0000000000..fb998a8fd0 --- /dev/null +++ b/mmcv/ops/conv2d_gradfix.py @@ -0,0 +1,301 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
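Editorial aside, outside the diff: a minimal usage sketch for the `bias_act` wrapper added above, assuming a CUDA-enabled build; shapes and the `clamp` value are illustrative only.

```python
import torch

from mmcv.ops import bias_act

x = torch.randn(2, 64, 32, 32, device='cuda', requires_grad=True)
b = torch.randn(64, device='cuda')  # one bias value per channel (dim=1)

# Fused bias + leaky ReLU: for act='lrelu' the defaults are alpha=0.2 and
# gain=sqrt(2) (see `activation_funcs`), and the output is clamped to
# [-256, 256] here.
y = bias_act(x, b, dim=1, act='lrelu', clamp=256)
y.sum().backward()  # first- and second-order gradients are supported

# use_custom_op=False (or a CPU input) falls back to the pure-PyTorch
# reference path `_bias_act_ref`.
y_ref = bias_act(x, b, dim=1, act='lrelu', clamp=256, use_custom_op=False)
```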
+ +# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/conv2d_gradfix.py # noqa +"""Custom replacement for `torch.nn.functional.conv2d` that supports +arbitrarily high order gradients with zero performance penalty.""" + +import contextlib +import warnings +from typing import Dict, Optional, Tuple, Union + +import torch + +enabled = True +weight_gradients_disabled = False + + +@contextlib.contextmanager +def no_weight_gradients(disable=True): + global weight_gradients_disabled + old = weight_gradients_disabled + if disable: + weight_gradients_disabled = True + yield + weight_gradients_disabled = old + + +def conv2d(input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1): + flag = True + if torch.__version__ >= '1.10.0': + warnings.warn('Since ' + 'aten:cudnn_convolution_backward_weight is ' + f'not supported in torch=={torch.__version__},' + ' rolling back to `torch.nn.functional.conv2d`') + flag = False + if _should_use_custom_op(input) and flag: + return _conv2d_gradfix( + transpose=False, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=0, + dilation=dilation, + groups=groups).apply(input, weight, bias) + return torch.nn.functional.conv2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups) + + +def conv_transpose2d(input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + groups: int = 1, + dilation: Union[int, Tuple[int, ...]] = 1): + if _should_use_custom_op(input): + return _conv2d_gradfix( + transpose=True, + weight_shape=weight.shape, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation).apply(input, weight, bias) + return torch.nn.functional.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + + +def _should_use_custom_op(input): + assert isinstance(input, torch.Tensor) + if (not enabled) or (not torch.backends.cudnn.enabled): + return False + if input.device.type != 'cuda': + return False + return True + + +def _to_tuple(x, ndim): + xs = tuple(x) if isinstance(x, (tuple, list)) else (x, ) * ndim + assert len(xs) == ndim + assert all(isinstance(x, int) for x in xs) + return xs + + +_conv2d_gradfix_cache: Dict = dict() +_null_tensor = torch.empty([0]) + + +def _conv2d_gradfix( + transpose: bool, + weight_shape: Tuple[int, ...], + stride: Union[int, Tuple[int, ...]], + padding: Union[int, Tuple[int, ...]], + output_padding: Union[int, Tuple[int, ...]], + dilation: Union[int, Tuple[int, ...]], + groups: int, +): + # Parse arguments. + ndim = 2 + weight_shape = tuple(weight_shape) + stride = _to_tuple(stride, ndim) + padding = _to_tuple(padding, ndim) + output_padding = _to_tuple(output_padding, ndim) + dilation = _to_tuple(dilation, ndim) + + # Lookup from cache. + key = (transpose, weight_shape, stride, padding, output_padding, dilation, + groups) + if key in _conv2d_gradfix_cache: + return _conv2d_gradfix_cache[key] + + # Validate arguments. 
+ + assert groups >= 1 + assert len(weight_shape) == ndim + 2 + assert all(stride[i] >= 1 for i in range(ndim)) # type: ignore + assert all(padding[i] >= 0 for i in range(ndim)) # type: ignore + assert all(dilation[i] >= 0 for i in range(ndim)) # type: ignore + if not transpose: + assert all(output_padding[i] == 0 for i in range(ndim)) # type: ignore + else: # transpose + for i in range(ndim): + assert 0 <= output_padding[i] < max( # type: ignore + stride[i], # type: ignore + dilation[i]) # type: ignore + + # Helpers. + common_kwargs = dict( + stride=stride, padding=padding, dilation=dilation, groups=groups) + + def calc_output_padding(input_shape, output_shape): + if transpose: + return [0, 0] + return [ + input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - + (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) + for i in range(ndim) + ] + + # Forward & backward. + class Conv2d(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias): + assert weight.shape == weight_shape + ctx.save_for_backward( + input if weight.requires_grad else _null_tensor, + weight if input.requires_grad else _null_tensor, + ) + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). + if weight_shape[2:] == stride == dilation == ( + 1, 1) and padding == ( + 0, 0) and torch.cuda.get_device_capability( + input.device) < (8, 0): + a = weight.reshape(groups, weight_shape[0] // groups, + weight_shape[1]) + b = input.reshape(input.shape[0], groups, + input.shape[1] // groups, -1) + c = (a.transpose(1, 2) if transpose else a) @ b.permute( + 1, 2, 0, 3).flatten(2) + c = c.reshape(-1, input.shape[0], + *input.shape[2:]).transpose(0, 1) + c = c if bias is None else c + bias.unsqueeze(0).unsqueeze( + 2).unsqueeze(3) + return c.contiguous( + memory_format=(torch.channels_last if input.stride(1) == + 1 else torch.contiguous_format)) + + # General case => cuDNN. + if transpose: + return torch.nn.functional.conv_transpose2d( + input=input, + weight=weight, + bias=bias, + output_padding=output_padding, + **common_kwargs) + return torch.nn.functional.conv2d( + input=input, weight=weight, bias=bias, **common_kwargs) + + @staticmethod + def backward(ctx, grad_output): + input, weight = ctx.saved_tensors + input_shape = ctx.input_shape + grad_input = None + grad_weight = None + grad_bias = None + + if ctx.needs_input_grad[0]: + p = calc_output_padding( + input_shape=input_shape, output_shape=grad_output.shape) + op = _conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs) + grad_input = op.apply(grad_output, weight, None) + assert grad_input.shape == input_shape + + if ctx.needs_input_grad[1] and not weight_gradients_disabled: + grad_weight = Conv2dGradWeight.apply(grad_output, input) + assert grad_weight.shape == weight_shape + + if ctx.needs_input_grad[2]: + grad_bias = grad_output.sum([0, 2, 3]) + + return grad_input, grad_weight, grad_bias + + # Gradient with respect to the weights. + class Conv2dGradWeight(torch.autograd.Function): + + @staticmethod + def forward(ctx, grad_output, input): + ctx.save_for_backward( + grad_output if input.requires_grad else _null_tensor, + input if grad_output.requires_grad else _null_tensor, + ) + ctx.grad_output_shape = grad_output.shape + ctx.input_shape = input.shape + + # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 
+ if weight_shape[2:] == stride == dilation == ( + 1, 1) and padding == (0, 0): + a = grad_output.reshape(grad_output.shape[0], groups, + grad_output.shape[1] // groups, + -1).permute(1, 2, 0, 3).flatten(2) + b = input.reshape(input.shape[0], groups, + input.shape[1] // groups, + -1).permute(1, 2, 0, 3).flatten(2) + c = (b @ a.transpose(1, 2) if transpose else + a @ b.transpose(1, 2)).reshape(weight_shape) + return c.contiguous( + memory_format=(torch.channels_last if input.stride(1) == + 1 else torch.contiguous_format)) + + # General case => cuDNN. + name = ('aten::cudnn_convolution_transpose_backward_weight' if + transpose else 'aten::cudnn_convolution_backward_weight') + flags = [ + torch.backends.cudnn.benchmark, + torch.backends.cudnn.deterministic, + torch.backends.cudnn.allow_tf32 + ] + return torch._C._jit_get_operation(name)(weight_shape, grad_output, + input, padding, stride, + dilation, groups, *flags) + + @staticmethod + def backward(ctx, grad2_grad_weight): + grad_output, input = ctx.saved_tensors + grad_output_shape = ctx.grad_output_shape + input_shape = ctx.input_shape + grad2_grad_output = None + grad2_input = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, + None) + assert grad2_grad_output.shape == grad_output_shape + + if ctx.needs_input_grad[1]: + p = calc_output_padding( + input_shape=input_shape, output_shape=grad_output_shape) + op = _conv2d_gradfix( + transpose=(not transpose), + weight_shape=weight_shape, + output_padding=p, + **common_kwargs) + grad2_input = op.apply(grad_output, grad2_grad_weight, None) + assert grad2_input.shape == input_shape + + return grad2_grad_output, grad2_input + + _conv2d_gradfix_cache[key] = Conv2d + return Conv2d diff --git a/mmcv/ops/csrc/pytorch/bias_act.cpp b/mmcv/ops/csrc/pytorch/bias_act.cpp new file mode 100644 index 0000000000..5ad32dc501 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/bias_act.cpp @@ -0,0 +1,20 @@ +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +torch::Tensor bias_act_op_impl(const torch::Tensor &input, + const torch::Tensor &bias, + const torch::Tensor &xref, + const torch::Tensor &yref, + const torch::Tensor &dy, int grad, int dim, + int act, float alpha, float gain, float clamp) { + return DISPATCH_DEVICE_IMPL(bias_act_op_impl, input, bias, xref, yref, dy, + grad, dim, act, alpha, gain, clamp); +} + +torch::Tensor bias_act(const torch::Tensor &input, const torch::Tensor &bias, + const torch::Tensor &xref, const torch::Tensor &yref, + const torch::Tensor &dy, int grad, int dim, int act, + float alpha, float gain, float clamp) { + return bias_act_op_impl(input, bias, xref, yref, dy, grad, dim, act, alpha, + gain, clamp); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu new file mode 100644 index 0000000000..bf601dde8e --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu @@ -0,0 +1,295 @@ +// Modified from +// https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/bias_act.cpp + +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ +#include +#include +#include + +#include "pytorch_cuda_helper.hpp" + +struct bias_act_kernel_params { + const void* x; // [sizeX] + const void* b; // [sizeB] or NULL + const void* xref; // [sizeX] or NULL + const void* yref; // [sizeX] or NULL + const void* dy; // [sizeX] or NULL + void* y; // [sizeX] + + int grad; + int act; + float alpha; + float gain; + float clamp; + + int sizeX; + int sizeB; + int stepB; + int loopX; +}; + +// CUDA kernel selection. + +template +void* choose_bias_act_kernel(const bias_act_kernel_params& p); +//------------------------------------------------------------------------ +// Helpers. + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; + +//------------------------------------------------------------------------ +// CUDA kernel. + +template +__global__ void bias_act_kernel(bias_act_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + int G = p.grad; + scalar_t alpha = (scalar_t)p.alpha; + scalar_t gain = (scalar_t)p.gain; + scalar_t clamp = (scalar_t)p.clamp; + scalar_t one = (scalar_t)1; + scalar_t two = (scalar_t)2; + scalar_t expRange = (scalar_t)80; + scalar_t halfExpRange = (scalar_t)40; + scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; + scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; + + // Loop over elements. + int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; + for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; + loopIdx++, xi += blockDim.x) { + // Load. + scalar_t x = (scalar_t)((const T*)p.x)[xi]; + scalar_t b = + (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; + scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; + scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; + scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; + scalar_t yy = (gain != 0) ? yref / gain : 0; + scalar_t y = 0; + + // Apply bias. + ((G == 0) ? x : xref) += b; + + // linear + if (A == 1) { + if (G == 0) y = x; + if (G == 1) y = x; + } + + // relu + if (A == 2) { + if (G == 0) y = (x > 0) ? x : 0; + if (G == 1) y = (yy > 0) ? x : 0; + } + + // lrelu + if (A == 3) { + if (G == 0) y = (x > 0) ? x : x * alpha; + if (G == 1) y = (yy > 0) ? x : x * alpha; + } + + // tanh + if (A == 4) { + if (G == 0) { + scalar_t c = exp(x); + scalar_t d = one / c; + y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); + } + if (G == 1) y = x * (one - yy * yy); + if (G == 2) y = x * (one - yy * yy) * (-two * yy); + } + + // sigmoid + if (A == 5) { + if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); + if (G == 1) y = x * yy * (one - yy); + if (G == 2) y = x * yy * (one - yy) * (one - two * yy); + } + + // elu + if (A == 6) { + if (G == 0) y = (x >= 0) ? x : exp(x) - one; + if (G == 1) y = (yy >= 0) ? x : x * (yy + one); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); + } + + // selu + if (A == 7) { + if (G == 0) + y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); + if (G == 1) + y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); + if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); + } + + // softplus + if (A == 8) { + if (G == 0) y = (x > expRange) ? 
x : log(exp(x) + one); + if (G == 1) y = x * (one - exp(-yy)); + if (G == 2) { + scalar_t c = exp(-yy); + y = x * c * (one - c); + } + } + + // swish + if (A == 9) { + if (G == 0) + y = (x < -expRange) ? 0 : x / (exp(-x) + one); + else { + scalar_t c = exp(xref); + scalar_t d = c + one; + if (G == 1) + y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); + else + y = (xref > halfExpRange) + ? 0 + : x * c * (xref * (two - d) + two * d) / (d * d * d); + yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; + } + } + + // Apply gain. + y *= gain * dy; + + // Clamp. + if (clamp >= 0) { + if (G == 0) + y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; + else + y = (yref > -clamp & yref < clamp) ? y : 0; + } + + // Store. + ((T*)p.y)[xi] = (T)y; + } +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template +void* choose_bias_act_kernel(const bias_act_kernel_params& p) { + if (p.act == 1) return (void*)bias_act_kernel; + if (p.act == 2) return (void*)bias_act_kernel; + if (p.act == 3) return (void*)bias_act_kernel; + if (p.act == 4) return (void*)bias_act_kernel; + if (p.act == 5) return (void*)bias_act_kernel; + if (p.act == 6) return (void*)bias_act_kernel; + if (p.act == 7) return (void*)bias_act_kernel; + if (p.act == 8) return (void*)bias_act_kernel; + if (p.act == 9) return (void*)bias_act_kernel; + return NULL; +} + +//------------------------------------------------------------------------ + +static bool has_same_layout(torch::Tensor x, torch::Tensor y) { + if (x.dim() != y.dim()) return false; + for (int64_t i = 0; i < x.dim(); i++) { + if (x.size(i) != y.size(i)) return false; + if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false; + } + return true; +} + +//------------------------------------------------------------------------ +torch::Tensor bias_act_op(const torch::Tensor& x, const torch::Tensor& b, + const torch::Tensor& xref, const torch::Tensor& yref, + const torch::Tensor& dy, int grad, int dim, int act, + float alpha, float gain, float clamp) { + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK( + b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), + "b must have the same dtype and device as x"); + TORCH_CHECK(xref.numel() == 0 || + (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && + xref.device() == x.device()), + "xref must have the same shape, dtype, and device as x"); + TORCH_CHECK(yref.numel() == 0 || + (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && + yref.device() == x.device()), + "yref must have the same shape, dtype, and device as x"); + TORCH_CHECK( + dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && + dy.device() == x.device()), + "dy must have the same dtype and device as x"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(b.dim() == 1, "b must have rank 1"); + TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), + "dim is out of bounds"); + TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), + "b has wrong number of elements"); + TORCH_CHECK(grad >= 0, "grad must be non-negative"); + + // Validate layout. 
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), + "x must be non-overlapping and dense"); + TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); + TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), + "xref must have the same layout as x"); + TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), + "yref must have the same layout as x"); + TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), + "dy must have the same layout as x"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + torch::Tensor y = torch::empty_like(x); + TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); + + // Initialize CUDA kernel parameters. + bias_act_kernel_params p; + p.x = x.data_ptr(); + p.b = (b.numel()) ? b.data_ptr() : NULL; + p.xref = (xref.numel()) ? xref.data_ptr() : NULL; + p.yref = (yref.numel()) ? yref.data_ptr() : NULL; + p.dy = (dy.numel()) ? dy.data_ptr() : NULL; + p.y = y.data_ptr(); + p.grad = grad; + p.act = act; + p.alpha = alpha; + p.gain = gain; + p.clamp = clamp; + p.sizeX = (int)x.numel(); + p.sizeB = (int)b.numel(); + p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; + + // Choose CUDA kernel. + void* kernel; + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + kernel = choose_bias_act_kernel(p); + }); + TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); + + // Launch CUDA kernel. + p.loopX = 4; + int blockSize = 4 * 32; + int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; + void* args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, + at::cuda::getCurrentCUDAStream())); + return y; +} diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp index 5e496b9c50..0266439613 100644 --- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp +++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp @@ -3,45 +3,45 @@ void AssignScoreWithKForwardCUDAKernelLauncher( int B, int N0, int N1, int M, int K, int O, int aggregate, - const Tensor& points, const Tensor& centers, const Tensor& scores, - const Tensor& knn_idx, Tensor& output); + const Tensor &points, const Tensor ¢ers, const Tensor &scores, + const Tensor &knn_idx, Tensor &output); void AssignScoreWithKBackwardCUDAKernelLauncher( int B, int N0, int N1, int M, int K, int O, int aggregate, - const Tensor& grad_out, const Tensor& points, const Tensor& centers, - const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, - Tensor& grad_centers, Tensor& grad_scores); + const Tensor &grad_out, const Tensor &points, const Tensor ¢ers, + const Tensor &scores, const Tensor &knn_idx, Tensor &grad_points, + Tensor &grad_centers, Tensor &grad_scores); void assign_score_withk_forward_cuda(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& points, - const Tensor& centers, - const Tensor& scores, - const Tensor& knn_idx, Tensor& output) { + int aggregate, const Tensor &points, + const Tensor ¢ers, + const Tensor &scores, + const Tensor &knn_idx, Tensor &output) { AssignScoreWithKForwardCUDAKernelLauncher( B, N0, N1, M, K, O, aggregate, points, centers, scores, knn_idx, output); }; void assign_score_withk_backward_cuda( int B, int N0, int N1, int M, int K, int O, int aggregate, - const Tensor& grad_out, const Tensor& points, const Tensor& centers, - const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, - Tensor& grad_centers, Tensor& grad_scores) { + const Tensor &grad_out, const Tensor &points, const Tensor ¢ers, + const Tensor 
&scores, const Tensor &knn_idx, Tensor &grad_points, + Tensor &grad_centers, Tensor &grad_scores) { AssignScoreWithKBackwardCUDAKernelLauncher( B, N0, N1, M, K, O, aggregate, grad_out, points, centers, scores, knn_idx, grad_points, grad_centers, grad_scores); }; void assign_score_withk_forward_impl(int B, int N0, int N1, int M, int K, int O, - int aggregate, const Tensor& points, - const Tensor& centers, - const Tensor& scores, - const Tensor& knn_idx, Tensor& output); + int aggregate, const Tensor &points, + const Tensor ¢ers, + const Tensor &scores, + const Tensor &knn_idx, Tensor &output); void assign_score_withk_backward_impl( int B, int N0, int N1, int M, int K, int O, int aggregate, - const Tensor& grad_out, const Tensor& points, const Tensor& centers, - const Tensor& scores, const Tensor& knn_idx, Tensor& grad_points, - Tensor& grad_centers, Tensor& grad_scores); + const Tensor &grad_out, const Tensor &points, const Tensor ¢ers, + const Tensor &scores, const Tensor &knn_idx, Tensor &grad_points, + Tensor &grad_centers, Tensor &grad_scores); REGISTER_DEVICE_IMPL(assign_score_withk_forward_impl, CUDA, assign_score_withk_forward_cuda); @@ -104,37 +104,37 @@ void bbox_overlaps_impl(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, const int mode, const bool aligned, const int offset); REGISTER_DEVICE_IMPL(bbox_overlaps_impl, CUDA, bbox_overlaps_cuda); -void BorderAlignForwardCUDAKernelLauncher(const Tensor& input, - const Tensor& boxes, Tensor output, +void BorderAlignForwardCUDAKernelLauncher(const Tensor &input, + const Tensor &boxes, Tensor output, Tensor argmax_idx, const int pool_size); -void BorderAlignBackwardCUDAKernelLauncher(const Tensor& grad_output, - const Tensor& boxes, - const Tensor& argmax_idx, +void BorderAlignBackwardCUDAKernelLauncher(const Tensor &grad_output, + const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, const int pool_size); -void border_align_forward_cuda(const Tensor& input, const Tensor& boxes, +void border_align_forward_cuda(const Tensor &input, const Tensor &boxes, Tensor output, Tensor argmax_idx, const int pool_size) { BorderAlignForwardCUDAKernelLauncher(input, boxes, output, argmax_idx, pool_size); } -void border_align_backward_cuda(const Tensor& grad_output, const Tensor& boxes, - const Tensor& argmax_idx, Tensor grad_input, +void border_align_backward_cuda(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, const int pool_size) { BorderAlignBackwardCUDAKernelLauncher(grad_output, boxes, argmax_idx, grad_input, pool_size); } -void border_align_forward_impl(const Tensor& input, const Tensor& boxes, +void border_align_forward_impl(const Tensor &input, const Tensor &boxes, Tensor output, Tensor argmax_idx, const int pool_size); -void border_align_backward_impl(const Tensor& grad_output, const Tensor& boxes, - const Tensor& argmax_idx, Tensor grad_input, +void border_align_backward_impl(const Tensor &grad_output, const Tensor &boxes, + const Tensor &argmax_idx, Tensor grad_input, const int pool_size); REGISTER_DEVICE_IMPL(border_align_forward_impl, CUDA, @@ -472,18 +472,18 @@ REGISTER_DEVICE_IMPL(softmax_focal_loss_backward_impl, CUDA, softmax_focal_loss_backward_cuda); void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m, - const float* dataset, - float* temp, int* idxs); + const float *dataset, + float *temp, int *idxs); void FurthestPointSamplingWithDistForwardCUDAKernelLauncher( - int b, int n, int m, const float* dataset, float* temp, int* idxs); + int b, int n, 
int m, const float *dataset, float *temp, int *idxs); void furthest_point_sampling_forward_cuda(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m) { - const float* dataset = points_tensor.data_ptr(); - float* temp = temp_tensor.data_ptr(); - int* idxs = idx_tensor.data_ptr(); + const float *dataset = points_tensor.data_ptr(); + float *temp = temp_tensor.data_ptr(); + int *idxs = idx_tensor.data_ptr(); FurthestPointSamplingForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs); } @@ -491,9 +491,9 @@ void furthest_point_sampling_with_dist_forward_cuda(Tensor points_tensor, Tensor temp_tensor, Tensor idx_tensor, int b, int n, int m) { - const float* dataset = points_tensor.data_ptr(); - float* temp = temp_tensor.data_ptr(); - int* idxs = idx_tensor.data_ptr(); + const float *dataset = points_tensor.data_ptr(); + float *temp = temp_tensor.data_ptr(); + int *idxs = idx_tensor.data_ptr(); FurthestPointSamplingWithDistForwardCUDAKernelLauncher(b, n, m, dataset, temp, idxs); } @@ -529,18 +529,57 @@ void stack_furthest_point_sampling_forward_impl( REGISTER_DEVICE_IMPL(stack_furthest_point_sampling_forward_impl, CUDA, stack_furthest_point_sampling_forward_cuda); -torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor& input, - const torch::Tensor& bias, - const torch::Tensor& refer, int act, +torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input, + const torch::Tensor &bias, + const torch::Tensor &refer, int act, int grad, float alpha, float scale); -torch::Tensor fused_bias_leakyrelu_op_impl(const torch::Tensor& input, - const torch::Tensor& bias, - const torch::Tensor& refer, int act, +torch::Tensor fused_bias_leakyrelu_op_impl(const torch::Tensor &input, + const torch::Tensor &bias, + const torch::Tensor &refer, int act, int grad, float alpha, float scale); REGISTER_DEVICE_IMPL(fused_bias_leakyrelu_op_impl, CUDA, fused_bias_leakyrelu_op); +torch::Tensor bias_act_op_impl(const torch::Tensor &input, + const torch::Tensor &bias, + const torch::Tensor &xref, + const torch::Tensor &yref, + const torch::Tensor &dy, int grad, int dim, + int act, float alpha, float gain, float clamp); + +torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias, + const torch::Tensor &xref, const torch::Tensor &yref, + const torch::Tensor &dy, int grad, int dim, int act, + float alpha, float gain, float clamp); + +REGISTER_DEVICE_IMPL(bias_act_op_impl, CUDA, bias_act_op); + +std::tuple filtered_lrelu_op_impl( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns); + +std::tuple filtered_lrelu_op( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns); + +REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, CUDA, filtered_lrelu_op); + +torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si, + int sx, int sy, float gain, + float slope, float clamp, + bool writeSigns); + +torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, + float clamp, bool writeSigns); + +REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, CUDA, filtered_lrelu_act_op); + void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints, const 
Tensor points, const Tensor idx, Tensor out); @@ -668,12 +707,12 @@ void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a, const Tensor boxes_b, Tensor ans_overlap); -void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep, - Tensor& keep_num, +void IoU3DNMS3DForwardCUDAKernelLauncher(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); -void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor& keep, - Tensor& keep_num, +void IoU3DNMS3DNormalForwardCUDAKernelLauncher(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a, @@ -683,14 +722,14 @@ void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a, ans_overlap); }; -void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor& keep, - Tensor& keep_num, float nms_overlap_thresh) { +void iou3d_nms3d_forward_cuda(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh) { IoU3DNMS3DForwardCUDAKernelLauncher(boxes, keep, keep_num, nms_overlap_thresh); }; -void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor& keep, - Tensor& keep_num, +void iou3d_nms3d_normal_forward_cuda(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh) { IoU3DNMS3DNormalForwardCUDAKernelLauncher(boxes, keep, keep_num, nms_overlap_thresh); @@ -700,11 +739,11 @@ void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a, const int num_b, const Tensor boxes_b, Tensor ans_overlap); -void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor& keep, - Tensor& keep_num, float nms_overlap_thresh); +void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); -void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor& keep, - Tensor& keep_num, +void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep, + Tensor &keep_num, float nms_overlap_thresh); REGISTER_DEVICE_IMPL(iou3d_boxes_overlap_bev_forward_impl, CUDA, @@ -829,31 +868,31 @@ REGISTER_DEVICE_IMPL(modulated_deformable_col2im_impl, CUDA, REGISTER_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, CUDA, modulated_deformable_col2im_coord_cuda); -Tensor ms_deform_attn_cuda_forward(const Tensor& value, - const Tensor& spatial_shapes, - const Tensor& level_start_index, - const Tensor& sampling_loc, - const Tensor& attn_weight, +Tensor ms_deform_attn_cuda_forward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor &level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, const int im2col_step); void ms_deform_attn_cuda_backward( - const Tensor& value, const Tensor& spatial_shapes, - const Tensor& level_start_index, const Tensor& sampling_loc, - const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, - Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step); - -Tensor ms_deform_attn_impl_forward(const Tensor& value, - const Tensor& spatial_shapes, - const Tensor& level_start_index, - const Tensor& sampling_loc, - const Tensor& attn_weight, + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step); + +Tensor ms_deform_attn_impl_forward(const Tensor &value, + const Tensor &spatial_shapes, + const Tensor 
&level_start_index, + const Tensor &sampling_loc, + const Tensor &attn_weight, const int im2col_step); void ms_deform_attn_impl_backward( - const Tensor& value, const Tensor& spatial_shapes, - const Tensor& level_start_index, const Tensor& sampling_loc, - const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, - Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step); + const Tensor &value, const Tensor &spatial_shapes, + const Tensor &level_start_index, const Tensor &sampling_loc, + const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value, + Tensor &grad_sampling_loc, Tensor &grad_attn_weight, const int im2col_step); REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, CUDA, ms_deform_attn_cuda_forward); @@ -1261,26 +1300,26 @@ REGISTER_DEVICE_IMPL(roi_pool_backward_impl, CUDA, roi_pool_backward_cuda); typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; std::vector DynamicPointToVoxelForwardCUDAKernelLauncher( - const at::Tensor& feats, const at::Tensor& coors, + const at::Tensor &feats, const at::Tensor &coors, const reduce_t reduce_type); void DynamicPointToVoxelBackwardCUDAKernelLauncher( - at::Tensor& grad_feats, const at::Tensor& grad_reduced_feats, - const at::Tensor& feats, const at::Tensor& reduced_feats, - const at::Tensor& coors_map, const at::Tensor& reduce_count, + at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats, + const at::Tensor &feats, const at::Tensor &reduced_feats, + const at::Tensor &coors_map, const at::Tensor &reduce_count, const reduce_t reduce_type); std::vector dynamic_point_to_voxel_forward_cuda( - const torch::Tensor& feats, const torch::Tensor& coors, + const torch::Tensor &feats, const torch::Tensor &coors, const reduce_t reduce_type) { return DynamicPointToVoxelForwardCUDAKernelLauncher(feats, coors, reduce_type); }; void dynamic_point_to_voxel_backward_cuda( - torch::Tensor& grad_feats, const torch::Tensor& grad_reduced_feats, - const torch::Tensor& feats, const torch::Tensor& reduced_feats, - const torch::Tensor& coors_idx, const torch::Tensor& reduce_count, + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, const reduce_t reduce_type) { DynamicPointToVoxelBackwardCUDAKernelLauncher(grad_feats, grad_reduced_feats, feats, reduced_feats, coors_idx, @@ -1288,13 +1327,13 @@ void dynamic_point_to_voxel_backward_cuda( }; std::vector dynamic_point_to_voxel_forward_impl( - const torch::Tensor& feats, const torch::Tensor& coors, + const torch::Tensor &feats, const torch::Tensor &coors, const reduce_t reduce_type); void dynamic_point_to_voxel_backward_impl( - torch::Tensor& grad_feats, const torch::Tensor& grad_reduced_feats, - const torch::Tensor& feats, const torch::Tensor& reduced_feats, - const torch::Tensor& coors_idx, const torch::Tensor& reduce_count, + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, const reduce_t reduce_type); REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, CUDA, @@ -1495,37 +1534,36 @@ void tin_shift_backward_impl(Tensor grad_output, Tensor shift, REGISTER_DEVICE_IMPL(tin_shift_forward_impl, CUDA, tin_shift_forward_cuda); REGISTER_DEVICE_IMPL(tin_shift_backward_impl, CUDA, tin_shift_backward_cuda); -torch::Tensor upfirdn2d_op(const torch::Tensor& input, - const 
torch::Tensor& kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, - int pad_y0, int pad_y1); +torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx, + int upy, int downx, int downy, int padx0, int padx1, + int pady0, int pady1, bool flip, float gain); -torch::Tensor upfirdn2d_op_impl(const torch::Tensor& input, - const torch::Tensor& kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, - int pad_y0, int pad_y1); +torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter, + int upx, int upy, int downx, int downy, + int padx0, int padx1, int pady0, int pady1, + bool flip, float gain); REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, CUDA, upfirdn2d_op); int HardVoxelizeForwardCUDAKernelLauncher( - const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, - at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int max_points, const int max_voxels, const int NDim = 3); int NondeterministicHardVoxelizeForwardCUDAKernelLauncher( - const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, - at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int max_points, const int max_voxels, const int NDim = 3); void DynamicVoxelizeForwardCUDAKernelLauncher( - const at::Tensor& points, at::Tensor& coors, + const at::Tensor &points, at::Tensor &coors, const std::vector voxel_size, const std::vector coors_range, const int NDim = 3); -int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels, - at::Tensor& coors, - at::Tensor& num_points_per_voxel, +int hard_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int max_points, const int max_voxels, @@ -1536,8 +1574,8 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels, }; int nondeterministic_hard_voxelize_forward_cuda( - const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, - at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int max_points, const int max_voxels, const int NDim) { return NondeterministicHardVoxelizeForwardCUDAKernelLauncher( @@ -1545,7 +1583,7 @@ int nondeterministic_hard_voxelize_forward_cuda( max_points, max_voxels, NDim); }; -void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors, +void dynamic_voxelize_forward_cuda(const at::Tensor &points, at::Tensor &coors, const std::vector voxel_size, const std::vector coors_range, const int NDim) { @@ -1553,21 +1591,21 @@ void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors, coors_range, NDim); }; -int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels, - at::Tensor& coors, - at::Tensor& num_points_per_voxel, +int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int 
max_points, const int max_voxels, const int NDim); int nondeterministic_hard_voxelize_forward_impl( - const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors, - at::Tensor& num_points_per_voxel, const std::vector voxel_size, + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, const std::vector coors_range, const int max_points, const int max_voxels, const int NDim); -void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors, +void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors, const std::vector voxel_size, const std::vector coors_range, const int NDim); diff --git a/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu b/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu new file mode 100644 index 0000000000..b0d76908ca --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/filtered_lrelu.cu @@ -0,0 +1,1956 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. +#include +#include +#include + +#include + +#include "pytorch_cuda_helper.hpp" + +//------------------------------------------------------------------------ +// CUDA kernel parameters. + +struct filtered_lrelu_kernel_params { + // These parameters decide which kernel to use. + int up; // upsampling ratio (1, 2, 4) + int down; // downsampling ratio (1, 2, 4) + int2 fuShape; // [size, 1] | [size, size] + int2 fdShape; // [size, 1] | [size, size] + + int _dummy; // Alignment. + + // Rest of the parameters. + const void* x; // Input tensor. + void* y; // Output tensor. + const void* b; // Bias tensor. + unsigned char* s; // Sign tensor in/out. NULL if unused. + const float* fu; // Upsampling filter. + const float* fd; // Downsampling filter. + + int2 pad0; // Left/top padding. + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + int flip; // Filter kernel flip for gradient computation. + + int tilesXdim; // Original number of horizontal output tiles. + int tilesXrep; // Number of horizontal tiles per CTA. + int blockZofs; // Block z offset to support large minibatch, channel + // dimensions. + + int4 xShape; // [width, height, channel, batch] + int4 yShape; // [width, height, channel, batch] + int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if + // unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. + int swLimit; // Active width of sign tensor in bytes. + + longlong4 xStride; // Strides of all tensors except signs, same component + // order as shapes. + longlong4 yStride; // + int64_t bStride; // + longlong3 fuStride; // + longlong3 fdStride; // +}; + +struct filtered_lrelu_act_kernel_params { + void* x; // Input/output, modified in-place. + unsigned char* s; // Sign tensor in/out. NULL if unused. + + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + + int4 xShape; // [width, height, channel, batch] + longlong4 xStride; // Input/output tensor strides, same order as in shape. 
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if + // unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. +}; + +//------------------------------------------------------------------------ +// CUDA kernel specialization. + +struct filtered_lrelu_kernel_spec { + void* setup; // Function for filter kernel setup. + void* exec; // Function for main operation. + int2 tileOut; // Width/height of launch tile. + int numWarps; // Number of warps per thread block, determines launch block + // size. + int xrep; // For processing multiple horizontal tiles per thread block. + int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. +}; + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template +filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel( + const filtered_lrelu_kernel_params& p, int sharedKB); +template +void* choose_filtered_lrelu_act_kernel(void); + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. +}; + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; + typedef double2 vec2_t; + typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_double2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_double4(0, 0, 0, 0); + } + __device__ __forceinline__ static double clamp(double x, double c) { + return fmin(fmax(x, -c), c); + } +}; +template <> +struct InternalType { + typedef float scalar_t; + typedef float2 vec2_t; + typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_float2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_float4(0, 0, 0, 0); + } + __device__ __forceinline__ static float clamp(float x, float c) { + return fminf(fmaxf(x, -c), c); + } +}; +template <> +struct InternalType { + typedef float scalar_t; + typedef float2 vec2_t; + typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_float2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_float4(0, 0, 0, 0); + } + __device__ __forceinline__ static float clamp(float x, float c) { + return fminf(fmaxf(x, -c), c); + } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) \ + (((B) == 1) \ + ? (A) \ + : ((B) == 2) ? ((int)((A) + 1) >> 1) \ + : ((B) == 4) ? ((int)((A) + 3) >> 2) \ + : (((A) + ((A) > 0 ? (B)-1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers +// of two. +template +__device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i) { + if ((N & (N - 1)) && N <= 256) + y = (i * ((1 << 24) / N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i / N; + + x = i - y * N; +} + +// Type cast stride before reading it. +template +__device__ __forceinline__ T get_stride(const int64_t& x) { + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. 
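+//
+// The filter taps are staged in two steps: setup_filters_kernel() below first
+// writes the (optionally flipped) up/down taps into the global-memory buffer
+// g_fbuf, and copy_filters() then copies that buffer into the __constant__
+// buffer c_fbuf from the host, because __constant__ memory cannot be written
+// by device code.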
+ +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float + g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, + // written by setup kernel. +__device__ __constant__ float + c_fbuf[2 * MAX_FILTER_SIZE * + MAX_FILTER_SIZE]; // Filters in constant memory, read by main + // kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. +static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) { + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; + idx += blockDim.x) { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) + ? 0.0f + : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = + (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) + ? 0.0f + : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = + (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer +// for main kernel. +static cudaError_t copy_filters(cudaStream_t stream) { + void* src = 0; + cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return cudaMemcpyToSymbolAsync( + c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, + cudaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char + s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically + // inside the kernel, otherwise use the externally allocated + // shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) { + // Check that we don't try to support non-existing filter modes. 
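+  // Each template instantiation of this kernel is specialized for one
+  // combination of scale factors, filter sizes and separable/full filter
+  // modes (among other launch parameters); the static_asserts below enforce
+  // the supported combinations at compile time.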
+ static_assert(up == 1 || up == 2 || up == 4, + "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, + "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, + "upsampling filter size must be at least upsampling factor"); + static_assert( + fdSize >= down, + "downsampling filter size must be at least downsampling factor"); + static_assert( + fuSize % up == 0, + "upsampling filter size must be divisible with upsampling factor"); + static_assert( + fdSize % down == 0, + "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, + "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || + filterMode == MODE_FUSD)), + "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || + filterMode == MODE_SUFD)), + "down=1 supported only for 1x1 full filters"); + static_assert( + !(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), + "full filters not supported for up=4"); + static_assert( + !(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), + "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & + ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = + tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = + CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = + CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = + CEIL_DIV(tileUpH, up) * + up; // Upsampled tile height rounded up to a multiple of up. + const int tileInH_up = + CEIL_DIV(tileUpH_up + (fuSize - 1), + up); // For allocations only, to avoid shared memory read + // overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = + (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || + (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) + ? MAX(szIn, szUpXY) + : (filterMode == MODE_FUSD) + ? MAX(szIn, szDownX) + : (filterMode == MODE_SUFD) + ? MAX(szIn, szUpXY) + : (filterMode == MODE_FUFD) ? szIn : -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) + ? MAX(szUpX, szDownX) + : (filterMode == MODE_FUSD) + ? szUpXY + : (filterMode == MODE_SUFD) + ? szUpX + : (filterMode == MODE_FUFD) ? szUpXY : -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert( + (s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), + "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t* s_buf0; + scalar_t* s_buf1; + if (sharedKB <= 48) { + // Allocate shared memory arrays here. + __shared__ scalar_t + s_buf0_st[(sharedKB > 48) + ? 
(1 << 24) + : (s_buf0_size + + s_buf1_size)]; // Prevent launching if this isn't + // optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } else { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t*)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. + scalar_t* + s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + + // relUpX] + scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + + // relUpX] + scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + // + relOutX] + if (filterMode == MODE_SUSD) { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } else if (filterMode == MODE_FUSD) { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } else if (filterMode == MODE_SUFD) { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } else if (filterMode == MODE_FUFD) { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + +// Inner tile loop. +#pragma unroll 1 + for (int tileIdx = 0; + !enableXrep || + (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); + tileIdx++) { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on + // first tile. + if (enableXrep && tileIdx > 0 && + (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || + (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = + (scalar_t) * (const T*)((const char*)p.b + + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); +#pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t) * ((const T*)((const char*)p.x + + (inX * get_stride(p.xStride.x) + + inY * get_stride(p.xStride.y) + + mapOfsIn))) + + b; + + bool skip = (loop == loopCountIN - 1) && (idx >= tileInW * tileInH); + if (!skip) s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || + filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
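+      // Each loop iteration produces `up` horizontally adjacent upsampled
+      // samples, selecting the polyphase subset of the filter taps according
+      // to phaseInX; the vertical pass below applies the same scheme along y
+      // using phaseInY.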
+ __syncthreads(); + if (up == 4) { + for (int idx = threadIdx.x * up; idx < tileUpW * tileInH; + idx += blockDim.x * up) { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } else if (phaseInX == 1) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } else if (phaseInX == 2) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } else // (phaseInX == 3) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst + 0] = v.x; + s_tileUpX[dst + 1] = v.y; + s_tileUpX[dst + 2] = v.z; + s_tileUpX[dst + 3] = v.w; + } + } else if (up == 2) { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x * up; idx < tileUpW * tileInH; + idx += blockDim.x * up) { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } else // (phaseInX == 1) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst + 0] = v.x; + s_tileUpX[dst + 1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + : 0; // Skip already written signs. + int sShapeMaxY = + MIN(p.sShape.y, + tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) { + minY -= 3; // Adjust according to block height. 
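+        // up == 4 vertical pass: each iteration walks down one column of the
+        // horizontally upsampled tile and produces four vertically adjacent
+        // samples (vec4), applying gain, leaky ReLU and clamping before the
+        // stores below.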
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; + idx += blockDim.x) { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } else if (phaseInY == 1) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } else if (phaseInY == 2) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } else // (phaseInY == 3) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (fabsf(v.z) > p.clamp) { + sz = 2 << 16; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (fabsf(v.w) > p.clamp) { + sw = 2 << 24; + v.w = InternalType::clamp(v.w, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. 
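+                // Each sign byte covers four horizontally adjacent pixels at
+                // 2 bits per pixel (bit 0: negative, slope applied; bit 1:
+                // clamped); the shuffles above OR together the bits
+                // contributed by the four threads of the group.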
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + if ((uint32_t)(signY + 2) < sShapeMaxY) { + p.s[si2] = (unsigned char)(s >> 16); + } + if ((uint32_t)(signY + 3) < sShapeMaxY) { + p.s[si3] = (unsigned char)(s >> 24); + } + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (fabsf(v.z) > p.clamp) { + sz = 2 << 16; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (fabsf(v.w) > p.clamp) { + sw = 2 << 24; + v.w = InternalType::clamp(v.w, p.clamp); + } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + if ((uint32_t)(signY + 2) < sShapeMaxY) { + p.s[si2] = (unsigned char)(s >> 16); + } + if ((uint32_t)(signY + 3) < sShapeMaxY) { + p.s[si3] = (unsigned char)(s >> 24); + } + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + } + } else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { + int s = p.s[si0] >> ss; + if (s & 1) v.x *= p.slope; + if (s & 2) v.x = 0.f; + } + if ((uint32_t)(signY + 1) < p.sShape.y) { + int s = p.s[si1] >> ss; + if (s & 1) v.y *= p.slope; + if (s & 2) v.y = 0.f; + } + if ((uint32_t)(signY + 2) < p.sShape.y) { + int s = p.s[si2] >> ss; + if (s & 1) v.z *= p.slope; + if (s & 2) v.z = 0.f; + } + if ((uint32_t)(signY + 3) < p.sShape.y) { + int s = p.s[si3] >> ss; + if (s & 1) v.w *= p.slope; + if (s & 2) v.w = 0.f; + } + } + } else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } else if (up == 2) { + minY -= 1; // Adjust according to block height. 
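+        // up == 2 vertical pass: same scheme as the up == 4 branch above,
+        // but each iteration produces two vertically adjacent samples (vec2)
+        // and only phaseInY == 0 / 1 need to be distinguished.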
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; + idx += blockDim.x) { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } else // (phaseInY == 1) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + + // Combine signs. + int s = sx + sy; + s <<= signXo; + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + } + } + } else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) { + if ((uint32_t)(signY + 0) < p.sShape.y) { + int s = p.s[si0] >> signXo; + if (s & 1) v.x *= p.slope; + if (s & 2) v.x = 0.f; + } + if ((uint32_t)(signY + 1) < p.sShape.y) { + int s = p.s[si1] >> signXo; + if (s & 1) v.y *= p.slope; + if (s & 2) v.y = 0.f; + } + } + } else // Forward pass with no sign write. 
+ { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) { + // Write into temporary buffer. + s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y; + } else { + // Write directly into output buffer. + if ((uint32_t)x < p.yShape.x) { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) + *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) + *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = + (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) { + // Full upsampling filter. + + if (up == 2) { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y + : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; + idx += blockDim.x * 4) { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + +#define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + if ((PX) == 0) { \ + a = b; \ + b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \ + } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + if ((PX) == 1) { \ + a = b; \ + b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \ + } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(0, 0) + } + if (tap0y == 0 && phaseInX == 1) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(0, 1) + } + if (tap0y == 1 && phaseInX == 0) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(1, 0) + } + if (tap0y == 1 && phaseInX == 1) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(1, 1) + } + +#undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if 
(!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (sy) v.y *= p.slope; + if (fabsf(v.y) > p.clamp) { + sy = 2; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (sz) v.z *= p.slope; + if (fabsf(v.z) > p.clamp) { + sz = 2; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (sw) v.w *= p.slope; + if (fabsf(v.w) > p.clamp) { + sw = 2; + v.w = InternalType::clamp(v.w, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (sy) v.y *= p.slope; + if (fabsf(v.y) > p.clamp) { + sy = 2; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (sz) v.z *= p.slope; + if (fabsf(v.z) > p.clamp) { + sz = 2; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (sw) v.w *= p.slope; + if (fabsf(v.w) > p.clamp) { + sw = 2; + v.w = InternalType::clamp(v.w, p.clamp); + } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + } + } else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; + if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; + if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; + if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; + if (s & 0x80) v.w = 0.f; + } + } else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } else if (up == 1) { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; + idx += blockDim.x) { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. 
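+          // up == 1 implies a trivial 1x1 "full" upsampling filter, so
+          // upsampling reduces to a per-pixel scale; the sign write/read
+          // handling below mirrors the vectorized paths but operates on a
+          // single scalar value.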
+ + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } + } else { + // Determine and write sign. + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. + p.s[si] = s; // Write. + } else { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } else if (signRead) { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && + (uint32_t)y < + p.yShape.y) // Write directly into output buffer + *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + + y * get_stride(p.yStride.y) + + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) { + // Calculate 4 pixels at a time. + for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; + idx += blockDim.x * 4) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); +#pragma unroll + for (int step = 0; step < fdSize; step++) { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx + 0] = v.x; + s_tileDownX[idx + 1] = v.y; + s_tileDownX[idx + 2] = v.z; + s_tileDownX[idx + 3] = v.w; + } + } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) { + // Calculate 2 pixels at a time. 
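+        // Each thread accumulates two horizontally adjacent output pixels
+        // (vec2), reading the upsampled tile at offsets src0 + step and
+        // src0 + down + step while reusing the same filter tap c_fd[step]
+        // for both.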
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; + idx += blockDim.x * 2) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); +#pragma unroll + for (int step = 0; step < fdSize; step++) { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx + 0] = v.x; + s_tileDownX[idx + 1] = v.y; + } + } else { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; + idx += blockDim.x) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; +#pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; + idx += blockDim.x) { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; +#pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T*)((char*)p.y + + (outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; + } + } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) { + // Full downsampling filter. + if (down == 2) { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; + idx += blockDim.x * 2) { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); +#pragma unroll + for (int sy = 0; sy < fdSize; sy++) +#pragma unroll + for (int sx = 0; sx < fdSize; sx++) { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * + (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * + (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) { + index_t ofs = outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) + *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = + (T)v.y; + } + } + } else if (down == 1 && !downInline) { + // Thread per pixel. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; + idx += blockDim.x) { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying +// data tensor in-place. 
Used for accelerating the generic variant. Sign tensor +// is known to be contiguous, and p.x and p.s have the same z, w dimensions. +// 64-bit indexing is always used. + +template +static __global__ void filtered_lrelu_act_kernel( + filtered_lrelu_act_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + + // Indexing. + int32_t x = threadIdx.x + blockIdx.x * blockDim.x; + int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; + int32_t qmax = + p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. + + // Loop to accommodate oversized tensors. + for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) + for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) { + // Extract z and w (channel, minibatch index). + int32_t w = q / p.xShape.z; + int32_t z = q - w * p.xShape.z; + + // Choose behavior based on sign read/write mode. + if (signWrite) { + // Process value if in p.x. + uint32_t s = 0; + if (x < p.xShape.x && y < p.xShape.y) { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + + // Gain, LReLU, clamp. + v *= p.gain; + if (v < 0.f) { + v *= p.slope; + s = 1; // Sign. + } + if (fabsf(v) > p.clamp) { + v = InternalType::clamp(v, p.clamp); + s = 2; // Clamp. + } + + *pv = (T)v; // Write value. + } + + // Coalesce into threads 0 and 16 of warp. + uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; + s <<= ((threadIdx.x & 15) << 1); // Shift into place. + s |= __shfl_xor_sync(m, s, 1); // Distribute. + s |= __shfl_xor_sync(m, s, 2); + s |= __shfl_xor_sync(m, s, 4); + s |= __shfl_xor_sync(m, s, 8); + + // Write signs if leader and in p.s. + if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. + { + uint64_t is = + x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. + ((uint32_t*)p.s)[is >> 4] = s; + } + } else if (signRead) { + // Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + + // Apply sign buffer offset. + uint32_t sx = x + p.sOfs.x; + uint32_t sy = y + p.sOfs.y; + + // Read and apply signs if we land inside valid region of sign buffer. + if (sx < p.sShape.x && sy < p.sShape.y) { + uint64_t is = + (sx >> 2) + (p.sShape.x >> 2) * + (sy + (uint64_t)p.sShape.y * q); // Contiguous. + unsigned char s = p.s[is]; + s >>= (sx & 3) << 1; // Shift into place. + if (s & 1) // Sign? + v *= p.slope; + if (s & 2) // Clamp? + v = 0.f; + } + + *pv = (T)v; // Write value. + } + } else { + // Forward pass with no sign write. Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T* pv = ((T*)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + if (v < 0.f) v *= p.slope; + if (fabsf(v) > p.clamp) v = InternalType::clamp(v, p.clamp); + *pv = (T)v; // Write value. + } + } + } +} + +template +void* choose_filtered_lrelu_act_kernel(void) { + return (void*)filtered_lrelu_act_kernel; +} + +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template +filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel( + const filtered_lrelu_kernel_params& p, int sharedKB) { + filtered_lrelu_kernel_spec s = {0}; + + // Return the first matching kernel. 
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ + if (sharedKB >= SH) \ + if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || \ + (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ + if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || \ + (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ + if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && \ + p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) { \ + static_assert((D * TW % 4) == 0, \ + "down * tileWidth must be divisible by 4"); \ + static_assert( \ + FU % U == 0, \ + "upscaling filter size must be multiple of upscaling factor"); \ + static_assert(FD % D == 0, \ + "downscaling filter size must be multiple of " \ + "downscaling factor"); \ + s.setup = (void*)setup_filters_kernel; \ + s.exec = (void*) \ + filtered_lrelu_kernel; \ + s.tileOut = make_int2(TW, TH); \ + s.numWarps = W; \ + s.xrep = XR; \ + s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ + return s; \ + } + + // Launch parameters for various kernel specializations. + // Small filters must be listed before large filters, otherwise the kernel for + // larger filter will always match first. Kernels that use more shared memory + // must be listed before those that use less, for the same reason. + + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 1, 1, /*mode*/ MODE_FUFD, + /*tw,th,warps,xrep,wskip*/ 64, 178, 32, 0, 0) // 1t-upf1-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 152, 95, 16, 0, 0) // 4t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 56, 22, 16, 0, 0) // 4t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 56, 29, 16, 11, 0) // 4t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 60, 28, 16, 0, 0) // 4t-upf2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 56, 28, 16, 0, 0) // 4t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 56, 31, 16, 11, 0) // 4t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 56, 36, 16, 0, 0) // 4t-ups4-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 22, 16, 12, 0) // 4t-ups2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 29, 15, 16, 0, 0) // 4t-upf2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 96, 150, 28, 0, 0) // 6t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 32, 35, 24, 0, 0) // 6t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 46, 16, 10, 0) // 6t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 58, 28, 24, 8, 0) // 6t-upf2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 52, 28, 16, 0, 0) // 6t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD, + 
/*tw,th,warps,xrep,wskip*/ 32, 51, 16, 5, 0) // 6t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 56, 16, 6, 0) // 6t-ups4-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 18, 16, 12, 0) // 6t-ups2-downs4 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 27, 13, 24, 0, 0) // 6t-upf2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 148, 89, 24, 0, 0) // 8t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 32, 31, 16, 5, 0) // 8t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 41, 16, 9, 0) // 8t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 56, 26, 24, 0, 0) // 8t-upf2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 40, 16, 0, 0) // 8t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 46, 24, 5, 0) // 8t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 50, 16, 0, 0) // 8t-ups4-downf2 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 13, 16, 10, 1) // 8t-ups2-downs4 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 25, 10, 24, 0, 0) // 8t-upf2-downs4 + +#undef CASE + return s; // No kernel found. +} + +//------------------------------------------------------------------------ + +std::tuple filtered_lrelu_op( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns) { + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. 
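+  // fu/fd may be rank 1 (a separable filter applied along both axes) or
+  // rank 2 (a full 2D filter); all shapes are additionally bounded by
+  // INT_MAX because the kernel parameter struct stores sizes as 32-bit ints.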
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && + b.device() == x.device(), + "all input tensors must reside on the same device"); + TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, + "fu and fd must be float32"); + TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, + "x and b must be float16 or float32"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && + x.size(3) <= INT_MAX, + "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK( + (fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), + "fu and fd must be rank 1 or 2"); + TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, + "fu is too large"); + TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, + "fd is too large"); + TORCH_CHECK(fu.numel() > 0, "fu is empty"); + TORCH_CHECK(fd.numel() > 0, "fd is empty"); + TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), + "b must be a vector with the same number of channels as x"); + TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); + + // Figure out how much shared memory is available on the device. + int maxSharedBytes = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, + cudaDevAttrMaxSharedMemoryPerBlockOptin, + x.device().index())); + int sharedKB = maxSharedBytes >> 10; + + // Populate enough launch parameters to check if a CUDA kernel exists. + filtered_lrelu_kernel_params p; + p.up = up; + p.down = down; + p.fuShape = + make_int2((int)fu.size(-1), + fu.dim() == 2 ? (int)fu.size(0) + : 0); // shape [n, 0] indicates separable filter. + p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); + filtered_lrelu_kernel_spec test_spec = + choose_filtered_lrelu_kernel(p, sharedKB); + if (!test_spec.exec) { + // No kernel found - return empty tensors and indicate missing kernel with + // return code of -1. + return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); + } + + // Input/output element size. + int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; + + // Input sizes. + int64_t xw = (int)x.size(3); + int64_t xh = (int)x.size(2); + int64_t fut_w = (int)fu.size(-1) - 1; + int64_t fut_h = (int)fu.size(0) - 1; + int64_t fdt_w = (int)fd.size(-1) - 1; + int64_t fdt_h = (int)fd.size(0) - 1; + + // Logical size of upsampled buffer. + int64_t cw = xw * up + (px0 + px1) - fut_w; + int64_t ch = xh * up + (py0 + py1) - fut_h; + TORCH_CHECK( + cw > fdt_w && ch > fdt_h, + "upsampled buffer must be at least the size of downsampling filter"); + TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); + + // Compute output size and allocate. + int64_t yw = (cw - fdt_w + (down - 1)) / down; + int64_t yh = (ch - fdt_h + (down - 1)) / down; + TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); + TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), + x.suggest_memory_format()); + + // Allocate sign tensor. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + int64_t sw_active = 0; // Active width of sign tensor. + if (writeSigns) { + sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. + int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. 
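+    // Signs are stored packed, four per byte (2 bits each), so the tensor
+    // allocated below has width sw >> 2 bytes; rounding sw up to a multiple
+    // of 16 elements keeps every row a whole number of 32-bit words.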
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, + // rounded up to multiple of 16. + TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); + s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, + x.options().dtype(torch::kUInt8), + at::MemoryFormat::Contiguous); + } else if (readSigns) + sw_active = s.size(3) << 2; + + // Validate sign tensor if in use. + if (readSigns || writeSigns) { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), + "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), + "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, + "signs is too large"); + } + + // Populate rest of CUDA kernel parameters. + p.x = x.data_ptr(); + p.y = y.data_ptr(); + p.b = b.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.fu = fu.data_ptr(); + p.fd = fd.data_ptr(); + p.pad0 = make_int2(px0, py0); + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.flip = (flip_filters) ? 1 : 0; + p.xShape = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.yShape = + make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.sShape = (readSigns || writeSigns) + ? make_int2((int)s.size(3), (int)s.size(2)) + : make_int2(0, 0); // Width is in bytes. Contiguous. + p.sOfs = make_int2(sx, sy); + p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. + + // x, y, b strides are in bytes. + p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), + sz * x.stride(1), sz * x.stride(0)); + p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), + sz * y.stride(1), sz * y.stride(0)); + p.bStride = sz * b.stride(0); + + // fu, fd strides are in elements. + p.fuStride = + make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); + p.fdStride = + make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); + + // Determine if indices don't fit in int32. Support negative strides although + // Torch currently never produces those. + bool index64b = false; + if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; + if (std::min(x.size(0) * p.xStride.w, 0ll) + + std::min(x.size(1) * p.xStride.z, 0ll) + + std::min(x.size(2) * p.xStride.y, 0ll) + + std::min(x.size(3) * p.xStride.x, 0ll) < + -INT_MAX) + index64b = true; + if (std::max(x.size(0) * p.xStride.w, 0ll) + + std::max(x.size(1) * p.xStride.z, 0ll) + + std::max(x.size(2) * p.xStride.y, 0ll) + + std::max(x.size(3) * p.xStride.x, 0ll) > + INT_MAX) + index64b = true; + if (std::min(y.size(0) * p.yStride.w, 0ll) + + std::min(y.size(1) * p.yStride.z, 0ll) + + std::min(y.size(2) * p.yStride.y, 0ll) + + std::min(y.size(3) * p.yStride.x, 0ll) < + -INT_MAX) + index64b = true; + if (std::max(y.size(0) * p.yStride.w, 0ll) + + std::max(y.size(1) * p.yStride.z, 0ll) + + std::max(y.size(2) * p.yStride.y, 0ll) + + std::max(y.size(3) * p.yStride.x, 0ll) > + INT_MAX) + index64b = true; + if (s.numel() > INT_MAX) index64b = true; + + // Choose CUDA kernel. + filtered_lrelu_kernel_spec spec = {0}; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x.scalar_type(), "filtered_lrelu_cuda", [&] { + if constexpr (sizeof(scalar_t) <= + 4) // Exclude doubles. constexpr prevents template + // instantiation. 
+ { + // Choose kernel based on index type, datatype and sign read/write + // modes. + if (!index64b && writeSigns && !readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (!index64b && !writeSigns && readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (!index64b && !writeSigns && !readSigns) + spec = + choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && writeSigns && !readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && !writeSigns && readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && !writeSigns && !readSigns) + spec = + choose_filtered_lrelu_kernel( + p, sharedKB); + } + }); + TORCH_CHECK( + spec.exec, + "internal error - CUDA kernel not found") // This should not happen + // because we tested earlier + // that kernel exists. + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = spec.numWarps * 32; + int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; + int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; + int gz = p.yShape.z * p.yShape.w; + + // Repeat multiple horizontal tiles in a CTA? + if (spec.xrep) { + p.tilesXrep = spec.xrep; + p.tilesXdim = gx; + + gx = (gx + p.tilesXrep - 1) / p.tilesXrep; + std::swap(gx, gy); + } else { + p.tilesXrep = 0; + p.tilesXdim = 0; + } + + // Launch filter setup kernel. + AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, + at::cuda::getCurrentCUDAStream())); + + // Copy kernels to constant memory. + if (writeSigns && !readSigns) + AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && readSigns) + AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + else if (!writeSigns && !readSigns) + AT_CUDA_CHECK((copy_filters(at::cuda::getCurrentCUDAStream()))); + + // Set cache and shared memory configurations for main kernel. + AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); + if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? + AT_CUDA_CHECK(cudaFuncSetAttribute( + spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, + spec.dynamicSharedKB << 10)); + AT_CUDA_CHECK( + cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); + + // Launch main kernel. + const int maxSubGz = 65535; // CUDA maximum for block z dimension. + for (int zofs = 0; zofs < gz; + zofs += maxSubGz) // Do multiple launches if gz is too big. + { + p.blockZofs = zofs; + int subGz = std::min(maxSubGz, gz - zofs); + AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, + spec.dynamicSharedKB << 10, + at::cuda::getCurrentCUDAStream())); + } + + // Done. + return std::make_tuple(y, so, 0); +} + +//------------------------------------------------------------------------ + +torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, + float clamp, bool writeSigns) { + // Set CUDA device. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && + x.size(3) <= INT_MAX, + "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || + x.dtype() == torch::kDouble, + "x must be float16, float32 or float64"); + + // Output signs if we don't have sign input. 
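+  // The sign tensor uses the same packed layout as in filtered_lrelu_op():
+  // four signs per byte, with the row width rounded up to a multiple of 16
+  // elements so the activation kernel can write 32-bit words.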
+ torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + if (writeSigns) { + int64_t sw = x.size(3); + sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. + s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, + x.options().dtype(torch::kUInt8), + at::MemoryFormat::Contiguous); + } + + // Validate sign tensor if in use. + if (readSigns || writeSigns) { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), + "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), + "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, + "signs tensor is too large"); + } + + // Initialize CUDA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = + make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) + ? make_int2((int)s.size(3) << 2, (int)s.size(2)) + : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose CUDA kernel. + void* func = 0; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + x.scalar_type(), "filtered_lrelu_act_cuda", [&] { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - CUDA kernel not found"); + + // Launch CUDA kernel. + void* args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = + p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. + gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within CUDA launch limits. Kernel + // loops internally to do the rest. + const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. + AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, + at::cuda::getCurrentCUDAStream())); + return so; +} diff --git a/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu b/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu index ea2f088200..a3aadf67bb 100644 --- a/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu +++ b/mmcv/ops/csrc/pytorch/cuda/upfirdn2d_kernel.cu @@ -1,370 +1,740 @@ -// Modified from -// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu -// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // -// This work is made available under the Nvidia Source Code License-NC. -// To view a copy of this license, visit -// https://nvlabs.github.io/stylegan2/license.html +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. 
Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. +#include +#include + +#include "pytorch_cuda_helper.hpp" + +struct upfirdn2d_kernel_params { + const void *x; + const float *f; + void *y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; -#include -#include -#include -#include -#include -#include +//------------------------------------------------------------------------ +// CUDA kernel specialization. -#include +struct upfirdn2d_kernel_spec { + void *kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; -static __host__ __device__ __forceinline__ int floor_div(int a, int b) { - int c = a / b; +//------------------------------------------------------------------------ +// CUDA kernel selection. - if (c * b > a) { - c--; - } +template +upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p); +//------------------------------------------------------------------------ - return c; +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Helpers. + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; + +static __device__ __forceinline__ int floor_div(int a, int b) { + int t = 1 - a / b; + return (a + t * b) / b - t; } -struct UpFirDn2DKernelParams { - int up_x; - int up_y; - int down_x; - int down_y; - int pad_x0; - int pad_x1; - int pad_y0; - int pad_y1; - - int major_dim; - int in_h; - int in_w; - int minor_dim; - int kernel_h; - int kernel_w; - int out_h; - int out_w; - int loop_major; - int loop_x; -}; +//------------------------------------------------------------------------ +// Generic CUDA implementation for large filters. -template -__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, - const scalar_t *kernel, - const UpFirDn2DKernelParams p) { - int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; - int out_y = minor_idx / p.minor_dim; - minor_idx -= out_y * p.minor_dim; - int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; - int major_idx_base = blockIdx.z * p.loop_major; - - if (out_x_base >= p.out_w || out_y >= p.out_h || - major_idx_base >= p.major_dim) { +template +static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. 
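+  // Grid mapping, as set up by the host launch code: blockIdx.x covers the
+  // minor dimension together with the output row, blockIdx.y covers groups
+  // of p.loopX output columns, and blockIdx.z covers groups of p.loopMajor
+  // batch/channel slices.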
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) return; - } - int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; - int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); - int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; - int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; - - for (int loop_major = 0, major_idx = major_idx_base; - loop_major < p.loop_major && major_idx < p.major_dim; - loop_major++, major_idx++) { - for (int loop_x = 0, out_x = out_x_base; - loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { - int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; - int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); - int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; - int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; - - const scalar_t *x_p = - &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + - minor_idx]; - const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; - int x_px = p.minor_dim; - int k_px = -p.up_x; - int x_py = p.in_w * p.minor_dim; - int k_py = -p.up_y * p.kernel_w; - - scalar_t v = 0.0f; - - for (int y = 0; y < h; y++) { - for (int x = 0; x < w; x++) { - v += static_cast(*x_p) * static_cast(*k_p); - x_p += x_px; - k_p += k_px; + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = + min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; + majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; + minorIdx < p.loopMinor & minor < p.sizeMinor; + minorIdx++, minor += p.launchMinor) { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; + loopX++, outX += blockDim.y) { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = + min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - + inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T *xp = + &((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y + + c * p.inStride.z + n * p.inStride.w]; + const float *fp = + &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. + scalar_t v = 0; + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; } - x_p += x_py - w * x_px; - k_p += k_py - w * k_px; + // Store result. 
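+        // Scale by the overall gain and write the value back through the
+        // explicit output strides, which cover both contiguous and
+        // channels_last layouts.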
+ v *= p.gain; + ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y + + c * p.outStride.z + n * p.outStride.w] = (T)v; } - - out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + - minor_idx] = v; } - } } -template -__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, - const scalar_t *kernel, - const UpFirDn2DKernelParams p) { - const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; - const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; - - __shared__ volatile float sk[kernel_h][kernel_w]; - __shared__ volatile float sx[tile_in_h][tile_in_w]; - - int minor_idx = blockIdx.x; - int tile_out_y = minor_idx / p.minor_dim; - minor_idx -= tile_out_y * p.minor_dim; - tile_out_y *= tile_out_h; - int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; - int major_idx_base = blockIdx.z * p.loop_major; - - if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | - major_idx_base >= p.major_dim) { +//------------------------------------------------------------------------ +// Specialized CUDA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | + majorBase >= p.sizeMajor) return; - } - for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; - tap_idx += blockDim.x) { - int ky = tap_idx / kernel_w; - int kx = tap_idx - ky * kernel_w; - scalar_t v = 0.0; - - if (kx < p.kernel_w & ky < p.kernel_h) { - v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; + tapIdx += blockDim.x) { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; } - - sk[ky][kx] = v; + sf[fy][fx] = v; } - for (int loop_major = 0, major_idx = major_idx_base; - loop_major < p.loop_major & major_idx < p.major_dim; - loop_major++, major_idx++) { - for (int loop_x = 0, tile_out_x = tile_out_x_base; - loop_x < p.loop_x & tile_out_x < p.out_w; - loop_x++, tile_out_x += tile_out_w) { - int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; - int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; - int tile_in_x = floor_div(tile_mid_x, up_x); - int tile_in_y = floor_div(tile_mid_y, up_y); - + // Loop over major and X. 
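+  // Each iteration maps the combined major/minor index back to a batch
+  // index n and a base channel; in the channels_last variants several
+  // channels (loopMinor of them) are processed together within one tile.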
+ for (int majorIdx = 0, major = majorBase; + majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; + loopX < p.loopX & tileOutX < p.outSize.x; + loopX++, tileOutX += tileOutW) { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); __syncthreads(); - - for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; - in_idx += blockDim.x) { - int rel_in_y = in_idx / tile_in_w; - int rel_in_x = in_idx - rel_in_y * tile_in_w; - int in_x = rel_in_x + tile_in_x; - int in_y = rel_in_y + tile_in_y; - - scalar_t v = 0.0; - - if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { - v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * - p.minor_dim + - minor_idx]; - } - - sx[rel_in_y][rel_in_x] = v; + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; + inIdx += blockDim.x) { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & + c < p.inSize.z) + v = (scalar_t)( + (const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y + + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; } + // Loop over output pixels. __syncthreads(); - for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; - out_idx += blockDim.x) { - int rel_out_y = out_idx / tile_out_w; - int rel_out_x = out_idx - rel_out_y * tile_out_w; - int out_x = rel_out_x + tile_out_x; - int out_y = rel_out_y + tile_out_y; - - int mid_x = tile_mid_x + rel_out_x * down_x; - int mid_y = tile_mid_y + rel_out_y * down_y; - int in_x = floor_div(mid_x, up_x); - int in_y = floor_div(mid_y, up_y); - int rel_in_x = in_x - tile_in_x; - int rel_in_y = in_y - tile_in_y; - int kernel_x = (in_x + 1) * up_x - mid_x - 1; - int kernel_y = (in_y + 1) * up_y - mid_y - 1; - - scalar_t v = 0.0; - + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; + outIdx += blockDim.x) { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. + int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. 
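+        // Accumulate over the filter taps that line up with real input
+        // samples: with zero-insertion upsampling only every upx-th /
+        // upy-th tap hits a non-inserted sample, hence the filterW / upx by
+        // filterH / upy loop bounds over the shared-memory tiles.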
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) { + scalar_t v = 0; #pragma unroll - for (int y = 0; y < kernel_h / up_y; y++) + for (int y = 0; y < filterH / upy; y++) #pragma unroll - for (int x = 0; x < kernel_w / up_x; x++) - v += sx[rel_in_y + y][rel_in_x + x] * - sk[kernel_y + y * up_y][kernel_x + x * up_x]; - - if (out_x < p.out_w & out_y < p.out_h) { - out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + - minor_idx] = v; + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * + sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y + + c * p.outStride.z + n * p.outStride.w] = (T)v; } } } } } -torch::Tensor upfirdn2d_op(const torch::Tensor &input, - const torch::Tensor &kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, - int pad_y0, int pad_y1) { - int curDevice = -1; - cudaGetDevice(&curDevice); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); - - UpFirDn2DKernelParams p; - - auto x = input.contiguous(); - auto k = kernel.contiguous(); - - p.major_dim = x.size(0); - p.in_h = x.size(1); - p.in_w = x.size(2); - p.minor_dim = x.size(3); - p.kernel_h = k.size(0); - p.kernel_w = k.size(1); - p.up_x = up_x; - p.up_y = up_y; - p.down_x = down_x; - p.down_y = down_y; - p.pad_x0 = pad_x0; - p.pad_x1 = pad_x1; - p.pad_y0 = pad_y0; - p.pad_y1 = pad_y1; - - p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / - p.down_y; - p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / - p.down_x; - - auto out = - at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); - - int mode = -1; - - int tile_out_h = -1; - int tile_out_w = -1; - - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 1; - tile_out_h = 16; - tile_out_w = 64; +//------------------------------------------------------------------------ +// CUDA kernel selection. + +template +upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p) { + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large, -1, -1, 1, + 4}; // contiguous + if (s == 1) + spec = {(void *)upfirdn2d_kernel_large, -1, -1, 4, 1}; // channels_last + + // No up/downsampling. 
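+  // In the if-ladders below every match overwrites spec, so the last
+  // matching (tightest filter-size) specialization wins; s != 1 selects the
+  // contiguous variants and s == 1 the channels_last variants.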
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 7 && fy <= 7) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 5 && fy <= 5) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 3 && fy <= 3) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 7 && fy <= 7) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 5 && fy <= 5) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 3 && fy <= 3) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; } - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 3 && p.kernel_w <= 3) { - mode = 2; - tile_out_h = 16; - tile_out_w = 64; + // 2x upsampling. 
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; } - - if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 3; - tile_out_h = 16; - tile_out_w = 64; + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; } - - if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && - p.kernel_h <= 2 && p.kernel_w <= 2) { - mode = 4; - tile_out_h = 16; - tile_out_w = 64; + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; } - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && - p.kernel_h <= 4 && p.kernel_w <= 4) { - mode = 5; - tile_out_h = 8; - tile_out_w = 32; + // 2x downsampling. 
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 1, 1}; + if (s == 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, 64, + 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 1, 8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 1, 8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, 64, + 1, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; } - if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && - p.kernel_h <= 2 && p.kernel_w <= 2) { - mode = 6; - tile_out_h = 8; - tile_out_w = 32; + // 4x upsampling. 
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; } - dim3 block_size; - dim3 grid_size; - - if (tile_out_h > 0 && tile_out_w > 0) { - p.loop_major = (p.major_dim - 1) / 16384 + 1; - p.loop_x = 1; - block_size = dim3(32 * 8, 1, 1); - grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, - (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, - (p.major_dim - 1) / p.loop_major + 1); - } else { - p.loop_major = (p.major_dim - 1) / 16384 + 1; - p.loop_x = 4; - block_size = dim3(4, 32, 1); - grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, - (p.out_w - 1) / (p.loop_x * block_size.y) + 1, - (p.major_dim - 1) / p.loop_major + 1); + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 1, 8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 1, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, 1, + 32, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, 1, + 32, 8, 1}; } + return spec; +} +//------------------------------------------------------------------------ +// Template specializations. 
+ +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); + +//------------------------------------------------------------------------ + +//------------------------------------------------------------------------ + +torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy, + int downx, int downy, int padx0, int padx1, + int pady0, int pady1, bool flip, float gain) { + // Validate arguments. + TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); + TORCH_CHECK(f.device() == x.device(), + "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) + + (x.size(2) - 1) * x.stride(2) + + (x.size(3) - 1) * x.stride(3) <= + INT_MAX, + "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, + "downsampling factor must be at least 1"); + + // Create output tensor. + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + int outW = + ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = + ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, + x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) + + (y.size(2) - 1) * y.stride(2) + + (y.size(3) - 1) * y.stride(3) <= + INT_MAX, + "output memory footprint is too large"); + + // Initialize CUDA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), + (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = + make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), + (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose CUDA kernel. 
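+  // The dtype is resolved via AT_DISPATCH_FLOATING_TYPES_AND_HALF and the
+  // concrete kernel variant via the stride/filter-size heuristics in
+  // choose_upfirdn2d_kernel() above.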
+ upfirdn2d_kernel_spec spec; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { - switch (mode) { - case 1: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 2: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 3: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 4: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 5: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - case 6: - upfirdn2d_kernel - <<>>(out.data_ptr(), - x.data_ptr(), - k.data_ptr(), p); - - break; - - default: - upfirdn2d_kernel_large<<>>( - out.data_ptr(), x.data_ptr(), - k.data_ptr(), p); - } + spec = choose_upfirdn2d_kernel(p); }); - return out; + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = + dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor); + } else // small + { + blockSize = dim3(256, 1, 1); + gridSize = + dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor); + } + + // Launch CUDA kernel. + void *args[] = {&p}; + AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, + at::cuda::getCurrentCUDAStream())); + return y; } diff --git a/mmcv/ops/csrc/pytorch/filtered_lrelu.cpp b/mmcv/ops/csrc/pytorch/filtered_lrelu.cpp new file mode 100644 index 0000000000..c7ecc14cf9 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/filtered_lrelu.cpp @@ -0,0 +1,37 @@ +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +std::tuple filtered_lrelu_op_impl( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns) { + return DISPATCH_DEVICE_IMPL(filtered_lrelu_op_impl, x, fu, fd, b, si, up, + down, px0, px1, py0, py1, sx, sy, gain, slope, + clamp, flip_filters, writeSigns); +} + +std::tuple filtered_lrelu( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns) { + return filtered_lrelu_op_impl(x, fu, fd, b, si, up, down, px0, px1, py0, py1, + sx, sy, gain, slope, clamp, flip_filters, + writeSigns); +} + +torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si, + int sx, int sy, float gain, + float slope, float clamp, + bool writeSigns) { + return DISPATCH_DEVICE_IMPL(filtered_lrelu_act_op_impl, x, si, sx, sy, gain, + slope, clamp, writeSigns); +} + +torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, float clamp, + bool writeSigns) { + return filtered_lrelu_act_op_impl(x, si, sx, sy, gain, slope, clamp, + writeSigns); +} diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index bdebc6b175..4841cd27c2 100644 --- 
a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -322,9 +322,9 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, const Tensor dets_sorted, const Tensor labels, const float iou_threshold, const int multi_label); -Tensor upfirdn2d(const Tensor &input, const Tensor &kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, - int pad_y1); +Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx, int upy, + int downx, int downy, int padx0, int padx1, int pady0, + int pady1, bool flip, float gain); Tensor fused_bias_leakyrelu(const Tensor &input, const Tensor &bias, const Tensor &refer, int act, int grad, float alpha, @@ -449,6 +449,20 @@ void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, Tensor graddist2, Tensor gradxyz1, Tensor gradxyz2); +Tensor bias_act(const Tensor &input, const Tensor &bias, const Tensor &xref, + const Tensor &yref, const Tensor &dy, int grad, int dim, + int act, float alpha, float gain, float clamp); + +std::tuple filtered_lrelu( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns); + +torch::Tensor filtered_lrelu_act_(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, float clamp, + bool writeSigns); + void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious, const int mode_flag, const bool aligned); @@ -494,9 +508,9 @@ void bezier_align_backward(Tensor grad_output, Tensor rois, Tensor grad_input, PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), - py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), - py::arg("down_y"), py::arg("pad_x0"), py::arg("pad_x1"), - py::arg("pad_y0"), py::arg("pad_y1")); + py::arg("filter"), py::arg("upx"), py::arg("upy"), py::arg("downx"), + py::arg("downy"), py::arg("padx0"), py::arg("padx1"), py::arg("pady0"), + py::arg("pady1"), py::arg("flip"), py::arg("gain")); m.def("fused_bias_leakyrelu", &fused_bias_leakyrelu, "fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"), py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"), @@ -951,6 +965,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("input"), py::arg("rois"), py::arg("grad_rois"), py::arg("pooled_height"), py::arg("pooled_width"), py::arg("spatial_scale")); + m.def("bias_act", &bias_act, "bias_act (CUDA)", py::arg("input"), + py::arg("bias"), py::arg("xref"), py::arg("yref"), py::arg("dy"), + py::arg("grad"), py::arg("dim"), py::arg("act"), py::arg("alpha"), + py::arg("gain"), py::arg("clamp")); + m.def("filtered_lrelu", &filtered_lrelu, "filtered_lrelu (CUDA)", + py::arg("x"), py::arg("fu"), py::arg("fd"), py::arg("b"), py::arg("si"), + py::arg("up"), py::arg("down"), py::arg("px0"), py::arg("px1"), + py::arg("py0"), py::arg("py1"), py::arg("sx"), py::arg("sy"), + py::arg("gain"), py::arg("slope"), py::arg("clamp"), + py::arg("flip_filters"), py::arg("writeSigns")); + m.def("filtered_lrelu_act_", &filtered_lrelu_act_, + "filtered_lrelu_act_ (CUDA)", py::arg("x"), py::arg("si"), + py::arg("sx"), py::arg("sy"), py::arg("gain"), py::arg("slope"), + py::arg("clamp"), py::arg("writeSigns")); m.def("box_iou_quadri", &box_iou_quadri, "IoU for quadrilateral boxes", py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), 
py::arg("mode_flag"), py::arg("aligned")); diff --git a/mmcv/ops/csrc/pytorch/upfirdn2d.cpp b/mmcv/ops/csrc/pytorch/upfirdn2d.cpp index dd325bd788..4a3e928c1a 100644 --- a/mmcv/ops/csrc/pytorch/upfirdn2d.cpp +++ b/mmcv/ops/csrc/pytorch/upfirdn2d.cpp @@ -102,17 +102,17 @@ THE POSSIBILITY OF SUCH DAMAGES. #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" -torch::Tensor upfirdn2d_op_impl(const torch::Tensor& input, - const torch::Tensor& kernel, int up_x, int up_y, - int down_x, int down_y, int pad_x0, int pad_x1, - int pad_y0, int pad_y1) { - return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, kernel, up_x, up_y, - down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor filter, + int upx, int upy, int downx, int downy, + int padx0, int padx1, int pady0, int pady1, + bool flip, float gain) { + return DISPATCH_DEVICE_IMPL(upfirdn2d_op_impl, input, filter, upx, upy, downx, + downy, padx0, padx1, pady0, pady1, flip, gain); } -torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, - int up_x, int up_y, int down_x, int down_y, int pad_x0, - int pad_x1, int pad_y0, int pad_y1) { - return upfirdn2d_op_impl(input, kernel, up_x, up_y, down_x, down_y, pad_x0, - pad_x1, pad_y0, pad_y1); +torch::Tensor upfirdn2d(torch::Tensor input, torch::Tensor filter, int upx, + int upy, int downx, int downy, int padx0, int padx1, + int pady0, int pady1, bool flip, float gain) { + return upfirdn2d_op_impl(input, filter, upx, upy, downx, downy, padx0, padx1, + pady0, pady1, flip, gain); } diff --git a/mmcv/ops/filtered_lrelu.py b/mmcv/ops/filtered_lrelu.py new file mode 100644 index 0000000000..04a98484ab --- /dev/null +++ b/mmcv/ops/filtered_lrelu.py @@ -0,0 +1,414 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/filtered_lrelu.py # noqa +import warnings +from typing import Dict, Optional, Union + +import numpy as np +import torch + +from ..utils import ext_loader +from .bias_act import bias_act +from .upfirdn2d import _get_filter_size, _parse_padding, upfirdn2d + +ext_module = ext_loader.load_ext('_ext', + ['filtered_lrelu', 'filtered_lrelu_act_']) + +_plugin = None + + +def filtered_lrelu(input: torch.Tensor, + filter_up: Optional[torch.Tensor] = None, + filter_down: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + up: int = 1, + down: int = 1, + padding: int = 0, + gain: float = np.sqrt(2), + slope: float = 0.2, + clamp: Optional[Union[float, int]] = None, + flip_filter: bool = False, + use_custom_op: bool = True): + """Filtered leaky ReLU for a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Add channel-specific bias if `bias` is provided. + + 2. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 3. Pad the image with the specified number of zeros on each side + (`padding`). Negative padding corresponds to cropping the image. + + 4. 
Convolve the image with the specified upsampling FIR filter + (`filter_up`), shrinking it so that the footprint of all output pixels + lies within the input image. + + 5. Multiply each value by the provided gain factor (`gain`). + + 6. Apply leaky ReLU activation function to each value. + + 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is + provided. + + 8. Convolve the image with the specified downsampling FIR filter + (`filter_down`), shrinking it so that the footprint of all output + pixels lies within the input image. + + 9. Downsample the image by keeping every Nth pixel (`down`). + + The fused op is considerably more efficient than performing the same + calculation using standard PyTorch ops. It supports gradients of arbitrary + order. + + Args: + input (torch.Tensor): Float32/float16/float64 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter_up (torch.Tensor): Float32 upsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), `[filter_taps]` + (separable), or `None` (identity). Defaults to None. + filter_down (torch.Tensor): Float32 downsampling FIR filter of the + shape `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or `None` (identity). + Defaults to None. + bias (torch.Tensor): Bias vector, or `None` to disable. Must be + a 1D tensor of the same type as `input`. The length of vector must + match the channel dimension of `input`. Defaults to None. + up (int): Integer upsampling factor. Defaults to 1. + down (int): Integer downsampling factor. Defaults to 1. + padding (int): Padding with respect to the upsampled image. Can be a + single number or a list/tuple `[x, y]` or `[x_before, x_after, + y_before, y_after]`. Defaults to 0. + gain (float): Overall scaling factor for signal magnitude. + Defaults to np.sqrt(2). + slope (float): Slope on the negative side of leaky ReLU. + Defaults to 0.2. + clamp (Optional[Union[float, int]]): Maximum magnitude for leaky ReLU + output. Defaults to None. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + use_custom_op (bool): Whether to use customized op. + Defaults to True. + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, + out_width]`. + """ + assert isinstance(input, torch.Tensor) + if use_custom_op and input.is_cuda: + return _filtered_lrelu_cuda( + up=up, + down=down, + padding=padding, + gain=gain, + slope=slope, + clamp=clamp, + flip_filter=flip_filter).apply(input, filter_up, filter_down, bias, + None, 0, 0) + return _filtered_lrelu_ref( + input, + filter_up=filter_up, + filter_down=filter_down, + bias=bias, + up=up, + down=down, + padding=padding, + gain=gain, + slope=slope, + clamp=clamp, + flip_filter=flip_filter) + + +def _filtered_lrelu_ref(input: torch.Tensor, + filter_up: Optional[torch.Tensor] = None, + filter_down: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + up: int = 1, + down: int = 1, + padding: int = 0, + gain: float = np.sqrt(2), + slope: float = 0.2, + clamp: Optional[Union[float, int]] = None, + flip_filter: bool = False): + """Slow and memory-inefficient reference implementation of + `filtered_lrelu()` using existing `upfirdn2n()` and `bias_act()` ops. + + Args: + input (torch.Tensor): Float32/float16/float64 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. 
+ filter_up (torch.Tensor): Float32 upsampling FIR filter of the shape + `[filter_height, filter_width]` (non-separable), `[filter_taps]` + (separable), or `None` (identity). Defaults to None. + filter_down (torch.Tensor): Float32 downsampling FIR filter of the + shape `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or `None` (identity). + Defaults to None. + bias (torch.Tensor): Bias vector, or `None` to disable. Must be + a 1D tensor of the same type as `input`. The length of vector must + match the channel dimension of `input`. Defaults to None. + up (int): Integer upsampling factor. Defaults to 1. + down (int): Integer downsampling factor. Defaults to 1. + padding (int): Padding with respect to the upsampled image. Can be a + single number or a list/tuple `[x, y]` or `[x_before, x_after, + y_before, y_after]`. Defaults to 0. + gain (float): Overall scaling factor for signal magnitude. + Defaults to np.sqrt(2). + slope (float): Slope on the negative side of leaky ReLU. + Defaults to 0.2. + clamp (float or int): Maximum magnitude for leaky ReLU + output. Defaults to None. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, + out_width]`. + """ + assert isinstance(input, torch.Tensor) and input.ndim == 4 + filter_up_w, filter_up_h = _get_filter_size(filter_up) + filter_down_w, filter_down_h = _get_filter_size(filter_down) + if bias is not None: + assert isinstance(bias, torch.Tensor) and bias.dtype == input.dtype + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + assert slope == float(slope) and slope >= 0 + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + + # Calculate output size. + batch_size, channels, in_h, in_w = input.shape + in_dtype = input.dtype + out_w = (in_w * up + (px0 + px1) - (filter_up_w - 1) - + (filter_down_w - 1) + (down - 1)) // down + out_h = (in_h * up + (py0 + py1) - (filter_up_h - 1) - + (filter_down_h - 1) + (down - 1)) // down + + # Compute using existing ops. + output = bias_act(input=input, bias=bias) # Apply bias. + output = upfirdn2d( + input=output, + filter=filter_up, + up=up, + padding=[px0, px1, py0, py1], + gain=up**2, + flip_filter=flip_filter) # Upsample. + output = bias_act( + input=output, act='lrelu', alpha=slope, gain=gain, + clamp=clamp) # Bias, leaky ReLU, clamp. + output = upfirdn2d( + input=output, filter=filter_down, down=down, + flip_filter=flip_filter) # Downsample. + + assert output.shape == (batch_size, channels, out_h, out_w) + assert output.dtype == in_dtype + return output + + +_filtered_lrelu_cuda_cache: Dict = dict() + + +def _filtered_lrelu_cuda(up: int = 1, + down: int = 1, + padding: int = 0, + gain: float = np.sqrt(2), + slope: float = 0.2, + clamp: Optional[Union[float, int]] = None, + flip_filter: bool = False): + """Fast CUDA implementation of `filtered_lrelu()` using custom ops. + + Args: + up (int): Integer upsampling factor. Defaults to 1. + down (int): Integer downsampling factor. Defaults to 1. + padding (int): Padding with respect to the upsampled image. Can be a + single number or a list/tuple `[x, y]` or `[x_before, x_after, + y_before, y_after]`. Defaults to 0. + gain (float): Overall scaling factor for signal magnitude. + Defaults to np.sqrt(2). + slope (float): Slope on the negative side of leaky ReLU. + Defaults to 0.2. 
+ clamp (float or int): Maximum magnitude for leaky ReLU + output. Defaults to None. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, + out_width]`. + """ + assert isinstance(up, int) and up >= 1 + assert isinstance(down, int) and down >= 1 + px0, px1, py0, py1 = _parse_padding(padding) + assert gain == float(gain) and gain > 0 + gain = float(gain) + assert slope == float(slope) and slope >= 0 + slope = float(slope) + assert clamp is None or (clamp == float(clamp) and clamp >= 0) + clamp = float(clamp if clamp is not None else 'inf') + + # Lookup from cache. + key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter) + if key in _filtered_lrelu_cuda_cache: + return _filtered_lrelu_cuda_cache[key] + + # Forward op. + class FilteredLReluCuda(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, filter_up, filter_down, bias, si, sx, sy): + # pylint: disable=arguments-differ + assert isinstance(input, torch.Tensor) and input.ndim == 4 + + # Replace empty up/downsample kernels with full 1x1 kernels + # (faster than separable). + if filter_up is None: + filter_up = torch.ones([1, 1], + dtype=torch.float32, + device=input.device) + if filter_down is None: + filter_down = torch.ones([1, 1], + dtype=torch.float32, + device=input.device) + assert 1 <= filter_up.ndim <= 2 + assert 1 <= filter_down.ndim <= 2 + + # Replace separable 1x1 kernels with full 1x1 kernels when scale + # factor is 1. + if up == 1 and filter_up.ndim == 1 and filter_up.shape[0] == 1: + filter_up = filter_up.square()[None] + if down == 1 and filter_down.ndim == 1 and filter_down.shape[ + 0] == 1: + filter_down = filter_down.square()[None] + + # Missing sign input tensor. + if si is None: + si = torch.empty([0]) + + # Missing bias tensor. + if bias is None: + bias = torch.zeros([input.shape[1]], + dtype=input.dtype, + device=input.device) + + # Construct internal sign tensor only if gradients are needed. + write_signs = (si.numel() == 0) and (input.requires_grad + or bias.requires_grad) + + # Warn if input storage strides are not in decreasing order due to + # e.g. channels-last layout. + strides = [ + input.stride(i) for i in range(input.ndim) if input.size(i) > 1 + ] + if any(a < b for a, b in zip(strides[:-1], strides[1:])): + warnings.warn( + 'low-performance memory layout detected in filtered_lrelu ' + 'input', RuntimeWarning) + + # Call C++/Cuda plugin if datatype is supported. + if input.dtype in [torch.float16, torch.float32]: + if torch.cuda.current_stream( + input.device) != torch.cuda.default_stream( + input.device): + warnings.warn( + 'filtered_lrelu called with non-default cuda stream ' + 'but concurrent execution is not supported', + RuntimeWarning) + y, so, return_code = ext_module.filtered_lrelu( + input, filter_up, filter_down, bias, si.to(input.device), + up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, + flip_filter, write_signs) + else: + return_code = -1 + + # No Cuda kernel found? Fall back to generic implementation. + # Still more memory efficient than the reference implementation + # because only the bit-packed sign tensor is retained for gradient + # computation. + if return_code < 0: + warnings.warn( + 'filtered_lrelu called with parameters that have no ' + 'optimized CUDA kernel, using generic fallback', + RuntimeWarning) + + y = input.add(bias.unsqueeze(-1).unsqueeze(-1)) # Add bias. 
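+                # The fallback below mirrors the reference path step by
+                # step: upsample, run the fused activation kernel in place
+                # (which also records the packed signs when gradients are
+                # needed), then downsample.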
+ y = upfirdn2d( + input=y, + filter=filter_up, + up=up, + padding=[px0, px1, py0, py1], + gain=float(up**2), + flip_filter=flip_filter) # Upsample. + # Activation function and sign handling. Modifies y in-place. + so = ext_module.filtered_lrelu_act_(y, si.to(y.device), sx, sy, + gain, slope, clamp, + write_signs) + y = upfirdn2d( + input=y, + filter=filter_down, + down=down, + flip_filter=flip_filter) # Downsample. + + # Prepare for gradient computation. + ctx.save_for_backward(filter_up, filter_down, + (si if si.numel() else so)) + ctx.x_shape = input.shape + ctx.y_shape = y.shape + ctx.s_ofs = sx, sy + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + filter_up, filter_down, si = ctx.saved_tensors + _, _, xh, xw = ctx.x_shape + _, _, yh, yw = ctx.y_shape + sx, sy = ctx.s_ofs + dx = None # 0 + dfu = None + assert not ctx.needs_input_grad[1] + dfd = None + assert not ctx.needs_input_grad[2] + db = None # 3 + dsi = None + assert not ctx.needs_input_grad[4] + dsx = None + assert not ctx.needs_input_grad[5] + dsy = None + assert not ctx.needs_input_grad[6] + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: + pp = [ + (filter_up.shape[-1] - 1) + (filter_down.shape[-1] - 1) - + px0, + xw * up - yw * down + px0 - (up - 1), + (filter_up.shape[0] - 1) + (filter_down.shape[0] - 1) - + py0, + xh * up - yh * down + py0 - (up - 1), + ] + gg = gain * (up**2) / (down**2) + ff = (not flip_filter) + sx = sx - (filter_up.shape[-1] - 1) + px0 + sy = sy - (filter_up.shape[0] - 1) + py0 + dx = _filtered_lrelu_cuda( + up=down, + down=up, + padding=pp, + gain=gg, + slope=slope, + clamp=None, + flip_filter=ff).apply(dy, filter_down, filter_up, None, si, + sx, sy) + + if ctx.needs_input_grad[3]: + db = dx.sum([0, 2, 3]) + + return dx, dfu, dfd, db, dsi, dsx, dsy + + # Add to cache. + _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda + return FilteredLReluCuda diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py index 574d4d315b..857e840c1b 100644 --- a/mmcv/ops/upfirdn2d.py +++ b/mmcv/ops/upfirdn2d.py @@ -1,341 +1,460 @@ -# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 - -# Copyright (c) 2021, NVIDIA Corporation. All rights reserved. -# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator -# Augmentation (ADA) -# ======================================================================= - -# 1. Definitions - -# "Licensor" means any person or entity that distributes its Work. - -# "Software" means the original work of authorship made available under -# this License. - -# "Work" means the Software and any additions to or derivative works of -# the Software that are made available under this License. - -# The terms "reproduce," "reproduction," "derivative works," and -# "distribution" have the meaning as provided under U.S. copyright law; -# provided, however, that for the purposes of this License, derivative -# works shall not include works that remain separable from, or merely -# link (or bind by name) to the interfaces of, the Work. - -# Works, including the Software, are "made available" under this License -# by including in or with the Work either (a) a copyright notice -# referencing the applicability of this License to the Work, or (b) a -# copy of this License. - -# 2. License Grants - -# 2.1 Copyright Grant. 
Subject to the terms and conditions of this -# License, each Licensor grants to you a perpetual, worldwide, -# non-exclusive, royalty-free, copyright license to reproduce, -# prepare derivative works of, publicly display, publicly perform, -# sublicense and distribute its Work and any resulting derivative -# works in any form. - -# 3. Limitations - -# 3.1 Redistribution. You may reproduce or distribute the Work only -# if (a) you do so under this License, (b) you include a complete -# copy of this License with your distribution, and (c) you retain -# without modification any copyright, patent, trademark, or -# attribution notices that are present in the Work. - -# 3.2 Derivative Works. You may specify that additional or different -# terms apply to the use, reproduction, and distribution of your -# derivative works of the Work ("Your Terms") only if (a) Your Terms -# provide that the use limitation in Section 3.3 applies to your -# derivative works, and (b) you identify the specific derivative -# works that are subject to Your Terms. Notwithstanding Your Terms, -# this License (including the redistribution requirements in Section -# 3.1) will continue to apply to the Work itself. - -# 3.3 Use Limitation. The Work and any derivative works thereof only -# may be used or intended for use non-commercially. Notwithstanding -# the foregoing, NVIDIA and its affiliates may use the Work and any -# derivative works commercially. As used herein, "non-commercially" -# means for research or evaluation purposes only. - -# 3.4 Patent Claims. If you bring or threaten to bring a patent claim -# against any Licensor (including any claim, cross-claim or -# counterclaim in a lawsuit) to enforce any patents that you allege -# are infringed by any Work, then your rights under this License from -# such Licensor (including the grant in Section 2.1) will terminate -# immediately. - -# 3.5 Trademarks. This License does not grant any rights to use any -# Licensor’s or its affiliates’ names, logos, or trademarks, except -# as necessary to reproduce the notices described in this License. - -# 3.6 Termination. If you violate any term of this License, then your -# rights under this License (including the grant in Section 2.1) will -# terminate immediately. - -# 4. Disclaimer of Warranty. - -# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR -# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER -# THIS LICENSE. - -# 5. Limitation of Liability. - -# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL -# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE -# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, -# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF -# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK -# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, -# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER -# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF -# THE POSSIBILITY OF SUCH DAMAGES. - -# ======================================================================= - -from typing import Any, List, Tuple, Union +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +# source: https://github.com/NVlabs/stylegan3/blob/main/torch_utils/ops/upfirdn2d.py # noqa +"""Custom PyTorch ops for efficient resampling of 2D images.""" +from typing import Dict, List, Union import torch -from mmengine.utils import to_2tuple -from torch.autograd import Function -from torch.nn import functional as F from ..utils import ext_loader +from .conv2d_gradfix import conv2d + +ext_module = ext_loader.load_ext('_ext', ['upfirdn2d']) + + +def _parse_scaling(scaling): + """parse scaling into list [x, y]""" + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + """parse padding into list [padx0, padx1, pady0, pady1]""" + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def _get_filter_size(filter): + """get width and height of filter kernel.""" + if filter is None: + return 1, 1 + assert isinstance(filter, torch.Tensor) and filter.ndim in [1, 2] + fw = filter.shape[-1] + fh = filter.shape[0] + fw = int(fw) + fh = int(fh) + assert fw >= 1 and fh >= 1 + return fw, fh + + +def upfirdn2d(input: torch.Tensor, + filter: torch.Tensor, + up: int = 1, + down: int = 1, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1, + use_custom_op: bool = True): + """Pad, upsample, filter, and downsample a batch of 2D images. + + Performs the following sequence of operations for each channel: + + 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). + + 2. Pad the image with the specified number of zeros on each side + (`padding`). Negative padding corresponds to cropping the image. + + 3. Convolve the image with the specified 2D FIR filter (`f`), + shrinking it so that the footprint of all output pixels lies within + the input image. + + 4. Downsample the image by keeping every Nth pixel (`down`). + + This sequence of operations bears close resemblance to + scipy.signal.upfirdn(). + + The fused op is considerably more efficient than performing the same + calculation using standard PyTorch ops. It supports gradients of arbitrary + order. 
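+
+    The output height equals
+    `(in_height * up + y_before + y_after - filter_height) // down + 1`,
+    and the output width follows the same formula along the x axis.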
-upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) - - -class UpFirDn2dBackward(Function): - - @staticmethod - def forward(ctx: Any, grad_output: torch.Tensor, kernel: torch.Tensor, - grad_kernel: torch.Tensor, up: tuple, down: tuple, pad: tuple, - g_pad: tuple, in_size: Union[List, Tuple], - out_size: Union[List, Tuple]) -> torch.Tensor: - - up_x, up_y = up - down_x, down_y = down - g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad - - grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) - - grad_input = upfirdn2d_ext.upfirdn2d( - grad_output, - grad_kernel, - up_x=down_x, - up_y=down_y, - down_x=up_x, - down_y=up_y, - pad_x0=g_pad_x0, - pad_x1=g_pad_x1, - pad_y0=g_pad_y0, - pad_y1=g_pad_y1) - grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], - in_size[3]) - - ctx.save_for_backward(kernel) - - pad_x0, pad_x1, pad_y0, pad_y1 = pad - - ctx.up_x = up_x - ctx.up_y = up_y - ctx.down_x = down_x - ctx.down_y = down_y - ctx.pad_x0 = pad_x0 - ctx.pad_x1 = pad_x1 - ctx.pad_y0 = pad_y0 - ctx.pad_y1 = pad_y1 - ctx.in_size = in_size - ctx.out_size = out_size - - return grad_input - - @staticmethod - def backward(ctx: Any, gradgrad_input: torch.Tensor) -> tuple: - kernel, = ctx.saved_tensors - - gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], - ctx.in_size[3], 1) - - gradgrad_out = upfirdn2d_ext.upfirdn2d( - gradgrad_input, - kernel, - up_x=ctx.up_x, - up_y=ctx.up_y, - down_x=ctx.down_x, - down_y=ctx.down_y, - pad_x0=ctx.pad_x0, - pad_x1=ctx.pad_x1, - pad_y0=ctx.pad_y0, - pad_y1=ctx.pad_y1) - # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], - # ctx.out_size[1], ctx.in_size[3]) - gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], - ctx.out_size[0], ctx.out_size[1]) - - return gradgrad_out, None, None, None, None, None, None, None, None - - -class UpFirDn2d(Function): - - @staticmethod - def forward(ctx: Any, input: torch.Tensor, kernel: torch.Tensor, up: tuple, - down: tuple, pad: tuple) -> torch.Tensor: - up_x, up_y = up - down_x, down_y = down - pad_x0, pad_x1, pad_y0, pad_y1 = pad - - kernel_h, kernel_w = kernel.shape - batch, channel, in_h, in_w = input.shape - ctx.in_size = input.shape - - input = input.reshape(-1, in_h, in_w, 1) - - ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - ctx.out_size = (out_h, out_w) - - ctx.up = (up_x, up_y) - ctx.down = (down_x, down_y) - ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) - - g_pad_x0 = kernel_w - pad_x0 - 1 - g_pad_y0 = kernel_h - pad_y0 - 1 - g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 - g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 - - ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) - - out = upfirdn2d_ext.upfirdn2d( - input, - kernel, - up_x=up_x, - up_y=up_y, - down_x=down_x, - down_y=down_y, - pad_x0=pad_x0, - pad_x1=pad_x1, - pad_y0=pad_y0, - pad_y1=pad_y1) - # out = out.view(major, out_h, out_w, minor) - out = out.view(-1, channel, out_h, out_w) - - return out - - @staticmethod - def backward(ctx: Any, grad_output: torch.Tensor) -> tuple: - kernel, grad_kernel = ctx.saved_tensors - - grad_input = UpFirDn2dBackward.apply( - grad_output, - kernel, - grad_kernel, - ctx.up, - ctx.down, - ctx.pad, - ctx.g_pad, - ctx.in_size, - ctx.out_size, - ) - - return grad_input, None, None, None, None - - -def upfirdn2d( - input: torch.Tensor, - kernel: torch.Tensor, - up: Union[int, tuple] = 1, - down: 
Union[int, tuple] = 1, - pad: tuple = (0, 0)) -> torch.Tensor: # noqa E125 - """UpFRIDn for 2d features. - - UpFIRDn is short for upsample, apply FIR filter and downsample. More - details can be found in: - https://www.mathworks.com/help/signal/ref/upfirdn.html + Args: + input (torch.Tensor): Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height, + filter_width]` (non-separable), `[filter_taps]` (separable), or + `None` (identity). + up (int): Integer upsampling factor. Can be a single int or a + list/tuple `[x, y]`. Defaults to 1. + down (int): Integer downsampling factor. Can be a single int + or a list/tuple `[x, y]`. Defaults to 1. + padding (int | tuple[int]): Padding with respect to the upsampled + image. Can be a single number or a list/tuple `[x, y]` or + `[x_before, x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + gain (int): Overall scaling factor for signal magnitude. + Defaults to 1. + use_custom_op (bool): Whether to use customized op. + Defaults to True. + + Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]` + """ + assert isinstance(input, torch.Tensor) + if use_custom_op and input.device.type == 'cuda': + return _upfirdn2d_cuda( + up=up, + down=down, + padding=padding, + flip_filter=flip_filter, + gain=gain).apply(input, filter) + return _upfirdn2d_ref( + input, + filter, + up=up, + down=down, + padding=padding, + flip_filter=flip_filter, + gain=gain) + + +def _upfirdn2d_ref(input: torch.Tensor, + filter: torch.Tensor, + up: int = 1, + down: int = 1, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1): + """Slow reference implementation of `upfirdn2d()` using standard PyTorch + ops. + + Args: + input (torch.Tensor): Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height, + filter_width]` (non-separable), `[filter_taps]` (separable), or + `None` (identity). + up (int): Integer upsampling factor. Can be a single int or a + list/tuple `[x, y]`. Defaults to 1. + down (int): Integer downsampling factor. Can be a single int + or a list/tuple `[x, y]`. Defaults to 1. + padding (int | tuple[int]): Padding with respect to the upsampled + image. Can be a single number or a list/tuple `[x, y]` or + `[x_before, x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + gain (int): Overall scaling factor for signal magnitude. + Defaults to 1. + + Returns: + torch.Tensor: Tensor of the shape `[batch_size, num_channels, + out_height, out_width]`. + """ + # Validate arguments. + assert isinstance(input, torch.Tensor) and input.ndim == 4 + if filter is None: + filter = torch.ones([1, 1], dtype=torch.float32, device=input.device) + assert isinstance(filter, torch.Tensor) and filter.ndim in [1, 2] + assert filter.dtype == torch.float32 and not filter.requires_grad + batch_size, num_channels, in_height, in_width = input.shape + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Check that upsampled buffer is not smaller than the filter. 
+ upW = in_width * upx + padx0 + padx1 + upH = in_height * upy + pady0 + pady1 + assert upW >= filter.shape[-1] and upH >= filter.shape[0] + + # Upsample by inserting zeros. + x = input.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) + x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) + x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) + + # Pad or crop. + x = torch.nn.functional.pad( + x, [max(padx0, 0), + max(padx1, 0), + max(pady0, 0), + max(pady1, 0)]) + x = x[:, :, + max(-pady0, 0):x.shape[2] - max(-pady1, 0), + max(-padx0, 0):x.shape[3] - max(-padx1, 0)] + + # Setup filter. + filter = filter * (gain**(filter.ndim / 2)) + filter = filter.to(x.dtype) + if not flip_filter: + filter = filter.flip(list(range(filter.ndim))) + + # Convolve with the filter. + filter = filter[None, None].repeat([num_channels, 1] + [1] * filter.ndim) + if filter.ndim == 4: + x = conv2d(input=x, weight=filter, groups=num_channels) + else: + x = conv2d(input=x, weight=filter.unsqueeze(2), groups=num_channels) + x = conv2d(input=x, weight=filter.unsqueeze(3), groups=num_channels) + + # Downsample by throwing away pixels. + x = x[:, :, ::downy, ::downx] + return x + + +_upfirdn2d_cuda_cache: Dict = dict() + + +def _upfirdn2d_cuda(up: int = 1, + down: int = 1, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1): + """Fast CUDA implementation of `upfirdn2d()` using custom ops. Args: - input (torch.Tensor): Tensor with shape of (n, c, h, w). - kernel (torch.Tensor): Filter kernel. - up (int | tuple[int], optional): Upsampling factor. If given a number, - we will use this factor for the both height and width side. + up (int): Integer upsampling factor. Can be a single int or a + list/tuple `[x, y]`. Defaults to 1. + down (int): Integer downsampling factor. Can be a single int + or a list/tuple `[x, y]`. Defaults to 1. + padding (int | tuple[int]): Padding with respect to the upsampled + image. Can be a single number or a list/tuple `[x, y]` or + `[x_before, x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + gain (int): Overall scaling factor for signal magnitude. Defaults to 1. - down (int | tuple[int], optional): Downsampling factor. If given a - number, we will use this factor for the both height and width side. + + Returns: + torch.Tensor: Tensor of the shape `[batch_size, num_channels, + out_height, out_width]` + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, + gain) + if key in _upfirdn2d_cuda_cache: + return _upfirdn2d_cuda_cache[key] + + # Forward op. + class Upfirdn2dCuda(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze( + 0) # Convert separable-1 into full-1x1. 
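+            # A 2D filter is applied in a single non-separable pass; a 1D
+            # (separable) filter is applied twice, once along each axis.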
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0, + padx1, pady0, pady1, flip_filter, + gain) + else: + y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, + padx0, padx1, 0, 0, flip_filter, 1.0) + y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, + 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_cuda( + up=down, + down=up, + padding=p, + flip_filter=(not flip_filter), + gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. + _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda + return Upfirdn2dCuda + + +def filter2d(input: torch.Tensor, + filter: torch.Tensor, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1, + use_custom_op: bool = True): + """Filter a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape matches the input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + + Args: + input (torch.Tensor): Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height, + filter_width]` (non-separable), `[filter_taps]` (separable), or + `None`. + padding (int | tuple[int]): Padding with respect to the output. + Can be a single number or a list/tuple `[x, y]` or `[x_before, + x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + gain (int): Overall scaling factor for signal magnitude. Defaults to 1. - pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or - (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0). + use_custom_op (bool): Whether to use customized op. + Defaults to True. Returns: - torch.Tensor: Tensor after UpFIRDn. + Tensor of the shape `[batch_size, num_channels, out_height, + out_width]`. """ - if input.device.type == 'cpu': - if len(pad) == 2: - pad = (pad[0], pad[1], pad[0], pad[1]) # type: ignore + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(filter) + p = [ + padx0 + fw // 2, + padx1 + (fw - 1) // 2, + pady0 + fh // 2, + pady1 + (fh - 1) // 2, + ] + return upfirdn2d( + input, + filter, + padding=p, + flip_filter=flip_filter, + gain=gain, + use_custom_op=use_custom_op) + + +def upsample2d(input: torch.Tensor, + filter: torch.Tensor, + up: int = 2, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1, + use_custom_op: bool = True): + """Upsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a multiple of the + input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. 
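+
+    The filter gain is scaled internally by the product of the x and y
+    upsampling factors, so the signal magnitude is preserved after
+    zero-insertion.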
- _up = to_2tuple(up) + Args: + input (torch.Tensor): Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height, + filter_width]` (non-separable), `[filter_taps]` (separable), or + `None` (identity). + up (int): Integer upsampling factor. Can be a single int or a + list/tuple `[x, y]`. Defaults to 2. + padding (int | tuple[int]): Padding with respect to the output. + Can be a single number or a list/tuple `[x, y]` or `[x_before, + x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. Defaults + to False. + gain (int): Overall scaling factor for signal magnitude. Defaults to 1. + use_custom_op (bool): Whether to use customized op. + Defaults to True. - _down = to_2tuple(down) + Returns: + torch.Tensor: Tensor of the shape `[batch_size, num_channels, + out_height, out_width]` + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(filter) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + return upfirdn2d( + input, + filter, + up=up, + padding=p, + flip_filter=flip_filter, + gain=gain * upx * upy, + use_custom_op=use_custom_op) + + +def downsample2d(input: torch.Tensor, + filter: torch.Tensor, + down: int = 2, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1, + use_custom_op: bool = True): + """Downsample a batch of 2D images using the given 2D FIR filter. + + By default, the result is padded so that its shape is a fraction of the + input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. 
- out = upfirdn2d_native(input, kernel, _up[0], _up[1], _down[0], - _down[1], pad[0], pad[1], pad[2], pad[3]) - else: - _up = to_2tuple(up) - - _down = to_2tuple(down) - - if len(pad) == 4: - _pad = pad - elif len(pad) == 2: - _pad = (pad[0], pad[1], pad[0], pad[1]) - - out = UpFirDn2d.apply(input, kernel, _up, _down, _pad) - - return out - - -def upfirdn2d_native(input: torch.Tensor, kernel: torch.Tensor, up_x: int, - up_y: int, down_x: int, down_y: int, pad_x0: int, - pad_x1: int, pad_y0: int, pad_y1: int) -> torch.Tensor: - _, channel, in_h, in_w = input.shape - input = input.reshape(-1, in_h, in_w, 1) - - _, in_h, in_w, minor = input.shape - kernel_h, kernel_w = kernel.shape - - out = input.view(-1, in_h, 1, in_w, 1, minor) - out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) - out = out.view(-1, in_h * up_y, in_w * up_x, minor) - - out = F.pad( - out, - [0, 0, - max(pad_x0, 0), - max(pad_x1, 0), - max(pad_y0, 0), - max(pad_y1, 0)]) - out = out[:, - max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), - max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] - - out = out.permute(0, 3, 1, 2) - out = out.reshape( - [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) - w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) - out = F.conv2d(out, w) - out = out.reshape( - -1, - minor, - in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, - in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, - ) - out = out.permute(0, 2, 3, 1) - out = out[:, ::down_y, ::down_x, :] - - out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 - out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 - - return out.view(-1, channel, out_h, out_w) + Args: + input (torch.Tensor): Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + filter (torch.Tensor): Float32 FIR filter of the shape `[filter_height, + filter_width]` (non-separable), `[filter_taps]` (separable), or + `None` (identity). + down (int): Integer downsampling factor. Can be a single int or a + list/tuple `[x, y]` (default: 1). Defaults to 2. + padding (int | tuple[int]): Padding with respect to the input. + Can be a single number or a list/tuple `[x, y]` or `[x_before, + x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. Defaults + to False. + gain (int): Overall scaling factor for signal magnitude. Defaults to 1. + use_custom_op (bool): Whether to use customized op. + Defaults to True. + + Returns: + torch.Tensor: Tensor of the shape `[batch_size, num_channels, + out_height, out_width]`. + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(filter) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + return upfirdn2d( + input, + filter, + down=down, + padding=p, + flip_filter=flip_filter, + gain=gain, + use_custom_op=use_custom_op) diff --git a/tests/test_ops/test_bias_act.py b/tests/test_ops/test_bias_act.py new file mode 100644 index 0000000000..01b57c4ae1 --- /dev/null +++ b/tests/test_ops/test_bias_act.py @@ -0,0 +1,144 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
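+# These tests check bias_act output shapes for the different activation
+# choices and for the dim, alpha, gain and clamp arguments; the CUDA variant
+# also runs gradient checks against the custom kernel.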
+import pytest +import torch + +from mmcv.ops import bias_act +from mmcv.ops.bias_act import EasyDict + +_USING_PARROTS = True +try: + from parrots.autograd import gradcheck +except ImportError: + from torch.autograd import gradcheck, gradgradcheck + _USING_PARROTS = False + + +class TestBiasAct: + + @classmethod + def setup_class(cls): + cls.input_tensor = torch.randn((1, 3), requires_grad=True) + cls.bias = torch.randn(3, requires_grad=True) + + def test_bias_act_cpu(self): + out = bias_act(self.input_tensor, self.bias) + assert out.shape == (1, 3) + + # test with different dim + input_tensor = torch.randn((1, 1, 3), requires_grad=True) + bias = torch.randn(3, requires_grad=True) + out = bias_act(input_tensor, bias, dim=2) + assert out.shape == (1, 1, 3) + + # test with different act + out = bias_act(self.input_tensor, self.bias, act='relu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='lrelu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='tanh') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='sigmoid') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='elu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='selu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='softplus') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor, self.bias, act='swish') + assert out.shape == (1, 3) + + # test with different alpha + out = bias_act(self.input_tensor, self.bias, act='lrelu', alpha=0.1) + assert out.shape == (1, 3) + + # test with different gain + out1 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.2) + out2 = bias_act(self.input_tensor, self.bias, act='lrelu', gain=0.1) + assert torch.allclose(out1, out2 * 2) + + # test with different clamp + out1 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.5) + out2 = bias_act(self.input_tensor, self.bias, act='lrelu', clamp=0.2) + assert out1.max() <= 0.5 + assert out2.max() <= 0.5 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + def test_bias_act_cuda(self): + if _USING_PARROTS: + gradcheck( + bias_act, (self.input_tensor.cuda(), self.bias.cuda()), + delta=1e-4, + pt_atol=1e-3) + else: + gradcheck( + bias_act, (self.input_tensor.cuda(), self.bias.cuda()), + eps=1e-4, + atol=1e-3) + + gradgradcheck( + bias_act, (self.input_tensor.cuda(), self.bias.cuda()), + eps=1e-4, + atol=1e-3) + + out = bias_act(self.input_tensor.cuda(), self.bias.cuda()) + assert out.shape == (1, 3) + + # test with different dim + input_tensor = torch.randn((1, 1, 3), requires_grad=True).cuda() + bias = torch.randn(3, requires_grad=True).cuda() + out = bias_act(input_tensor, bias, dim=2) + assert out.shape == (1, 1, 3) + + # test with different act + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='relu') + assert out.shape == (1, 3) + + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='lrelu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='tanh') + assert out.shape == (1, 3) + out = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='sigmoid') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='elu') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='selu') + assert out.shape == (1, 3) + out = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), 
act='softplus') + assert out.shape == (1, 3) + out = bias_act(self.input_tensor.cuda(), self.bias.cuda(), act='swish') + assert out.shape == (1, 3) + + # test with different alpha + out = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', alpha=0.1) + assert out.shape == (1, 3) + + # test with different gain + out1 = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.2) + out2 = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', gain=0.1) + assert torch.allclose(out1, out2 * 2) + + # test with different clamp + out1 = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.5) + out2 = bias_act( + self.input_tensor.cuda(), self.bias.cuda(), act='lrelu', clamp=0.2) + assert out1.max() <= 0.5 + assert out2.max() <= 0.5 + + def test_easy_dict(self): + easy_dict = EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + cuda_idx=1, + ref='', + has_2nd_grad=False) + _ = easy_dict.def_alpha + easy_dict.def_alpha = 1 + del easy_dict.def_alpha diff --git a/tests/test_ops/test_conv_gradfix.py b/tests/test_ops/test_conv_gradfix.py new file mode 100644 index 0000000000..29148abc05 --- /dev/null +++ b/tests/test_ops/test_conv_gradfix.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch +import torch.nn as nn +from torch.autograd import gradcheck, gradgradcheck + +from mmcv.ops import conv2d, conv_transpose2d + + +class TestCond2d: + + @classmethod + def setup_class(cls): + cls.input = torch.randn((1, 3, 32, 32), requires_grad=True) + cls.weight = nn.Parameter(torch.randn(1, 3, 3, 3)) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + def test_conv2d_cuda(self): + x = self.input.cuda() + weight = self.weight.cuda() + res = conv2d(x, weight, None, 1, 1) + assert res.shape == (1, 1, 32, 32) + gradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) + gradgradcheck(conv2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) + + +class TestCond2dTansposed: + + @classmethod + def setup_class(cls): + cls.input = torch.randn((1, 3, 32, 32), requires_grad=True) + cls.weight = nn.Parameter(torch.randn(3, 1, 3, 3)) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + def test_conv2d_transposed_cuda(self): + x = self.input.cuda() + weight = self.weight.cuda() + res = conv_transpose2d(x, weight, None, 1, 1) + assert res.shape == (1, 1, 32, 32) + gradcheck( + conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) + gradgradcheck( + conv_transpose2d, (x, weight, None, 1, 1), eps=1e-2, atol=1e-2) diff --git a/tests/test_ops/test_filtered_lrelu.py b/tests/test_ops/test_filtered_lrelu.py new file mode 100644 index 0000000000..9322605d08 --- /dev/null +++ b/tests/test_ops/test_filtered_lrelu.py @@ -0,0 +1,219 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
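+# These tests sweep the up/down factors, the separate up/down FIR filters,
+# and the bias, gain, slope, clamp and flip_filter arguments on both CPU and
+# CUDA inputs, checking output shapes and simple value relations.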
+import pytest +import torch + +from mmcv.ops import filtered_lrelu + + +class TestFilteredLrelu: + + @classmethod + def setup_class(cls): + cls.input_tensor = torch.randn((1, 3, 16, 16), requires_grad=True) + cls.bias = torch.randn(3, requires_grad=True) + cls.filter_up = torch.randn((2, 2)) + cls.filter_down = torch.randn((2, 2)) + + def test_filtered_lrelu_cpu(self): + out = filtered_lrelu(self.input_tensor, bias=self.bias) + assert out.shape == (1, 3, 16, 16) + + out = filtered_lrelu( + self.input_tensor, + bias=self.bias, + filter_up=self.filter_up, + filter_down=self.filter_down, + up=2, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different filter_up + filter_up = torch.randn((4, 4)) + out = filtered_lrelu( + self.input_tensor, + bias=self.bias, + filter_up=filter_up, + filter_down=self.filter_down, + up=2, + down=2, + padding=2, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different filter_down + filter_down = torch.randn((4, 4)) + out = filtered_lrelu( + self.input_tensor, + bias=self.bias, + filter_up=self.filter_up, + filter_down=filter_down, + up=2, + down=2, + padding=2, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different b + input_tensor = torch.randn((1, 4, 16, 16), requires_grad=True) + bias = torch.randn(4, requires_grad=True) + out = filtered_lrelu( + input_tensor, + bias=bias, + filter_up=self.filter_up, + filter_down=self.filter_down, + up=2, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 4, 16, 16) + + # test with different up + out = filtered_lrelu( + self.input_tensor, + bias=self.bias, + filter_up=self.filter_up, + filter_down=self.filter_down, + up=4, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 32, 32) + + # test with different down + out = filtered_lrelu( + self.input_tensor, + bias=self.bias, + filter_up=self.filter_up, + filter_down=self.filter_down, + up=2, + down=4, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 8, 8) + + # test with different gain + out1 = filtered_lrelu(self.input_tensor, bias=self.bias, gain=0.2) + out2 = filtered_lrelu(self.input_tensor, bias=self.bias, gain=0.1) + assert torch.allclose(out1, 2 * out2) + + # test with different slope + out = filtered_lrelu(self.input_tensor, bias=self.bias, slope=0.2) + assert out.shape == (1, 3, 16, 16) + + # test with different clamp + out1 = filtered_lrelu(self.input_tensor, bias=self.bias, clamp=0.2) + out2 = filtered_lrelu(self.input_tensor, bias=self.bias, clamp=0.1) + assert out1.max() <= 0.2 + assert out2.max() <= 0.1 + + # test with different flip_filter + out1 = filtered_lrelu( + self.input_tensor, bias=self.bias, flip_filter=True) + assert out.shape == (1, 3, 16, 16) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') + def test_filtered_lrelu_cuda(self): + out = filtered_lrelu(self.input_tensor.cuda(), bias=self.bias.cuda()) + assert out.shape == (1, 3, 16, 16) + + out = filtered_lrelu( + self.input_tensor.cuda(), + bias=self.bias.cuda(), + filter_up=self.filter_up.cuda(), + filter_down=self.filter_down.cuda(), + up=2, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different filter_up + filter_up = torch.randn((4, 4)) + out = filtered_lrelu( + self.input_tensor.cuda(), + bias=self.bias.cuda(), + filter_up=filter_up.cuda(), + filter_down=self.filter_down.cuda(), + up=2, + down=2, + padding=2, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different filter_down + filter_down = 
torch.randn((4, 4)) + out = filtered_lrelu( + self.input_tensor.cuda(), + bias=self.bias.cuda(), + filter_up=self.filter_up.cuda(), + filter_down=filter_down.cuda(), + up=2, + down=2, + padding=2, + clamp=0.5) + assert out.shape == (1, 3, 16, 16) + + # test with different b + input_tensor = torch.randn((1, 4, 16, 16), requires_grad=True) + bias = torch.randn(4, requires_grad=True) + out = filtered_lrelu( + input_tensor.cuda(), + bias=bias.cuda(), + filter_up=self.filter_up.cuda(), + filter_down=self.filter_down.cuda(), + up=2, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 4, 16, 16) + + # test with different up + out = filtered_lrelu( + self.input_tensor.cuda(), + bias=self.bias.cuda(), + filter_up=self.filter_up.cuda(), + filter_down=self.filter_down.cuda(), + up=4, + down=2, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 32, 32) + + # test with different down + out = filtered_lrelu( + self.input_tensor.cuda(), + bias=self.bias.cuda(), + filter_up=self.filter_up.cuda(), + filter_down=self.filter_down.cuda(), + up=2, + down=4, + padding=1, + clamp=0.5) + assert out.shape == (1, 3, 8, 8) + + # test with different gain + out1 = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), gain=0.2) + out2 = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), gain=0.1) + assert torch.allclose(out1, 2 * out2) + + # test with different slope + out = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), slope=0.2) + assert out.shape == (1, 3, 16, 16) + + # test with different clamp + out1 = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), clamp=0.2) + out2 = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), clamp=0.1) + assert out1.max() <= 0.2 + assert out2.max() <= 0.1 + + # test with different flip_filter + out1 = filtered_lrelu( + self.input_tensor.cuda(), bias=self.bias.cuda(), flip_filter=True) + assert out.shape == (1, 3, 16, 16) diff --git a/tests/test_ops/test_upfirdn2d.py b/tests/test_ops/test_upfirdn2d.py index 6037a51c2f..1342480a63 100644 --- a/tests/test_ops/test_upfirdn2d.py +++ b/tests/test_ops/test_upfirdn2d.py @@ -56,3 +56,29 @@ def test_upfirdn2d(self): self.input_tensor).cuda(), self.factor, 1, self.pad), eps=1e-4, atol=1e-3) + + # test with different up + kernel = torch.randn(3, 3) + out = upfirdn2d( + self.input_tensor.cuda(), filter=kernel.cuda(), up=2, padding=1) + assert out.shape == (2, 3, 8, 8) + + # test with different down + input_tensor = torch.randn(2, 3, 8, 8) + out = upfirdn2d( + input_tensor.cuda(), filter=self.kernel.cuda(), down=2, padding=1) + assert out.shape == (2, 3, 4, 4) + + # test with different flip_filter + out = upfirdn2d( + self.input_tensor.cuda(), + filter=self.kernel.cuda(), + flip_filter=True) + assert out.shape == (2, 3, 1, 1) + + # test with different gain + out1 = upfirdn2d( + self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.2) + out2 = upfirdn2d( + self.input_tensor.cuda(), filter=self.kernel.cuda(), gain=0.1) + assert torch.allclose(out1, out2 * 2)
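For reviewers, a minimal usage sketch of the ops touched by this patch follows. It is illustrative only and not part of the diff: the tensor and filter names (`x`, `fir`, `w`, ...) are invented for the example, the shapes in the comments follow the padding formulas in `mmcv/ops/upfirdn2d.py` and the argument combinations used in the new tests, and on CPU the calls are expected to take the reference/fallback paths since the custom kernels require CUDA tensors.

```python
import torch

from mmcv.ops import (bias_act, conv2d, filter2d, filtered_lrelu, upfirdn2d,
                      upsample2d)

x = torch.randn(2, 3, 16, 16)          # NCHW feature map
fir = torch.ones(4, 4) / 16            # normalized 4x4 FIR filter

# Same-resolution filtering: padding is chosen so the output matches the input.
y = filter2d(x, fir)                   # -> (2, 3, 16, 16)

# 2x upsampling; the filter gain is scaled by the up factor squared internally.
y_up = upsample2d(x, fir, up=2)        # -> (2, 3, 32, 32)

# Raw upfirdn2d with explicit padding [x_before, x_after, y_before, y_after].
y_raw = upfirdn2d(x, fir, up=2, padding=[1, 2, 1, 2])    # -> (2, 3, 32, 32)

# Fused bias + leaky ReLU + up/downsampling, mirroring the arguments used in
# tests/test_ops/test_filtered_lrelu.py.
z = filtered_lrelu(
    x,
    bias=torch.zeros(3),
    filter_up=torch.ones(2, 2) / 4,
    filter_down=torch.ones(2, 2) / 4,
    up=2,
    down=2,
    padding=1,
    clamp=0.5)                         # -> (2, 3, 16, 16)

# Fused bias + activation on a 2D tensor (dim=1 is the channel axis).
feat = torch.randn(4, 8)
out = bias_act(feat, torch.zeros(8), act='lrelu')        # -> (4, 8)

# Gradfix convolution: same call convention as torch.nn.functional.conv2d,
# with support for higher-order gradients when the custom path is active.
w = torch.randn(5, 3, 3, 3)
y_conv = conv2d(x, w, None, 1, 1)      # stride=1, padding=1 -> (2, 5, 16, 16)
```

The resampling ops additionally accept `use_custom_op=False` to force the pure PyTorch reference implementation even on GPU, which is convenient when debugging numerical differences between the two paths.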