Add ResNetUnit Python API (#35426)
ZzSean authored Oct 15, 2021
1 parent 2de0b58 commit 12882b2
Showing 5 changed files with 289 additions and 14 deletions.
@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
out_var_ptr->GeneratedOp());

// NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
if (right_generated_op->Name() != "conv2d_grad") {
if (right_generated_op->Name() != "conv2d_grad" &&
right_generated_op->Name() != "resnet_unit_grad") {
continue;
}

@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
if (node.inputs.empty()) return false;
auto *generated_op = node.inputs[0];
auto *op_desc = generated_op->Op();
if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
op_desc->Type() != "resnet_unit_grad")) {
return false;
}
const auto &outputs = op_desc->Outputs();
auto iter = outputs.find(GradVarName("Input"));
std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
auto iter = outputs.find(GradVarName(grad_var_name));
return iter != outputs.end() && !iter->second.empty() &&
iter->second[0] == node.Name() &&
!op_desc->GetAttrIfExists<bool>("use_addto");
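Note: the "add-to" strategy referenced in the pass lets a backward op accumulate its result directly into an existing gradient buffer instead of writing a temporary tensor that is summed afterwards; this commit extends the whitelist from conv2d_grad alone to conv2d_grad and resnet_unit_grad. A minimal conceptual sketch of the two accumulation paths (plain NumPy, illustrative names only, not Paddle's actual pass logic):

import numpy as np

def accumulate_without_addto(existing_grad, new_grad):
    # Default path: the backward op writes into a temporary buffer,
    # and a separate elementwise add merges it into the accumulator.
    tmp = new_grad.copy()
    return existing_grad + tmp

def accumulate_with_addto(existing_grad, new_grad):
    # Add-to path: the backward op adds onto the existing buffer in place,
    # saving the temporary tensor and the extra add kernel.
    existing_grad += new_grad
    return existing_grad

grad = np.zeros((2, 3), dtype=np.float32)
delta = np.ones((2, 3), dtype=np.float32)
assert np.allclose(accumulate_without_addto(grad.copy(), delta),
                   accumulate_with_addto(grad.copy(), delta))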
5 changes: 3 additions & 2 deletions paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -232,13 +232,14 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("use_addto", "").SetDefault(false);
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddComment(R"DOC(
Fusion op of the basic unit of resnet block.
Fusion op of the basic unit of resnet block.
The implementation is based on the latest fusion op interface in cuDNN v8.0.
For more details:
For more details:
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t
)DOC");
19 changes: 10 additions & 9 deletions paddle/fluid/operators/fused/resnet_unit_op.cu
@@ -55,7 +55,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -87,7 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_x.Resize(param_dims);
sum_of_squares_x.Resize(param_dims);
CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
output_shape, padding, stride, dilation,
group);
conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
&sum_of_squares_x);
@@ -129,8 +129,8 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
sum_z.Resize(param_dims);
sum_of_squares_z.Resize(param_dims);
CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
output_shape, padding, stride_z, dilate,
group);
output_shape, padding, stride_z,
dilation, group);
conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
&sum_of_squares_z);

@@ -189,7 +189,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int dilation = ctx.Attr<int>("dilation");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -263,7 +263,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
output_shape, padding, stride_z,
dilate, group);
dilation, group);
conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
filter_z_grad);
} else {
@@ -278,11 +278,12 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
}

// 2. Backward of Conv for x, get x_grad and filter_x_grad
bool use_addto = ctx.Attr<bool>("use_addto");
CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
group);
output_shape, padding, stride,
dilation, group);
conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
filter_x_grad);
filter_x_grad, use_addto);
}
};

1 change: 1 addition & 0 deletions python/paddle/incubate/operators/__init__.py
@@ -14,3 +14,4 @@

from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401
from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401
from .resnet_unit import ResNetUnit  # noqa: F401
269 changes: 269 additions & 0 deletions python/paddle/incubate/operators/resnet_unit.py
@@ -0,0 +1,269 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops
__all__ = ['resnet_unit', 'ResNetUnit']


def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z,
scale_z, bias_z, mean_z, var_z, stride, stride_z, padding,
dilation, groups, momentum, eps, data_format, fuse_add,
has_shortcut, use_global_stats, is_test, act):

helper = LayerHelper('resnet_unit', **locals())
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bit_mask_dtype = fluid.core.VarDesc.VarType.INT32
out = helper.create_variable_for_type_inference(x.dtype)
bit_mask = helper.create_variable_for_type_inference(
dtype=bit_mask_dtype, stop_gradient=True)
# intermediate_out for x
conv_x = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_x = mean_x
running_var_x = var_x
# intermediate_out for z
conv_z = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z
running_var_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z

inputs = {
'X': x,
'FilterX': filter_x,
'ScaleX': scale_x,
'BiasX': bias_x,
'MeanX': mean_x,
'VarX': var_x,
'Z': z,
'FilterZ': filter_z,
'ScaleZ': scale_z,
'BiasZ': bias_z,
'MeanZ': mean_z,
'VarZ': var_z
}

attrs = {
'stride': stride,
'stride_z': stride_z,
'padding': padding,
'dilation': dilation,
'group': groups,
'momentum': momentum,
'epsilon': eps,
'data_format': data_format,
'fuse_add': fuse_add,
'has_shortcut': has_shortcut,
'use_global_stats': use_global_stats,
'is_test': is_test,
'act_type': act
}

outputs = {
'Y': out,
'BitMask': bit_mask,
'ConvX': conv_x,
'SavedMeanX': saved_mean_x,
'SavedInvstdX': saved_invstd_x,
'RunningMeanX': running_mean_x,
'RunningVarX': running_var_x,
'ConvZ': conv_z,
'SavedMeanZ': saved_mean_z,
'SavedInvstdZ': saved_invstd_z,
'RunningMeanZ': running_mean_z,
'RunningVarZ': running_var_z,
}

helper.append_op(
type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs)

return out


class ResNetUnit(Layer):
r"""
******Temporary version******.
ResNetUnit is designed to optimize performance by using the cuDNN v8 API.
"""

def __init__(self,
num_channels_x,
num_filters,
filter_size,
stride=1,
momentum=0.9,
eps=1e-5,
data_format='NHWC',
act='relu',
fuse_add=False,
has_shortcut=False,
use_global_stats=False,
is_test=False,
filter_x_attr=None,
scale_x_attr=None,
bias_x_attr=None,
moving_mean_x_name=None,
moving_var_x_name=None,
num_channels_z=1,
stride_z=1,
filter_z_attr=None,
scale_z_attr=None,
bias_z_attr=None,
moving_mean_z_name=None,
moving_var_z_name=None):
super(ResNetUnit, self).__init__()
self._stride = stride
self._stride_z = stride_z
self._dilation = 1
self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size')
self._padding = (filter_size - 1) // 2
self._groups = 1
self._momentum = momentum
self._eps = eps
self._data_format = data_format
self._act = act
self._fuse_add = fuse_add
self._has_shortcut = has_shortcut
self._use_global_stats = use_global_stats
self._is_test = is_test

# check format
valid_format = {'NHWC'}
if data_format not in valid_format:
raise ValueError(
"conv_format must be one of {}, but got conv_format='{}'".
format(valid_format, data_format))

def _get_default_param_initializer(channels):
filter_elem_num = np.prod(self._kernel_size) * channels
std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std)

# initial filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn_param_shape = [1, 1, 1, num_filters]
filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]

self.filter_x = self.create_parameter(
shape=filter_x_shape,
attr=filter_x_attr,
default_initializer=_get_default_param_initializer(num_channels_x))
self.scale_x = self.create_parameter(
shape=bn_param_shape,
attr=scale_x_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_x = self.create_parameter(
shape=bn_param_shape,
attr=bias_x_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_x = self.create_parameter(
attr=ParamAttr(
name=moving_mean_x_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_x.stop_gradient = True
self.var_x = self.create_parameter(
attr=ParamAttr(
name=moving_var_x_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_x.stop_gradient = True
if has_shortcut:
self.filter_z = self.create_parameter(
shape=filter_z_shape,
attr=filter_z_attr,
default_initializer=_get_default_param_initializer(
num_channels_z))
self.scale_z = self.create_parameter(
shape=bn_param_shape,
attr=scale_z_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_z = self.create_parameter(
shape=bn_param_shape,
attr=bias_z_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_z = self.create_parameter(
attr=ParamAttr(
name=moving_mean_z_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_z.stop_gradient = True
self.var_z = self.create_parameter(
attr=ParamAttr(
name=moving_var_z_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_z.stop_gradient = True
else:
self.filter_z = None
self.scale_z = None
self.bias_z = None
self.mean_z = None
self.var_z = None

def forward(self, x, z=None):
if self._fuse_add and z is None:
raise ValueError("z can not be None")

out = resnet_unit(
x, self.filter_x, self.scale_x, self.bias_x, self.mean_x,
self.var_x, z, self.filter_z, self.scale_z, self.bias_z,
self.mean_z, self.var_z, self._stride, self._stride_z,
self._padding, self._dilation, self._groups, self._momentum,
self._eps, self._data_format, self._fuse_add, self._has_shortcut,
self._use_global_stats, self._is_test, self._act)
return out
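A minimal usage sketch of the new layer (the import path and constructor arguments come from this patch; the device, dtype, and shapes are illustrative, since the fused kernels target NHWC inputs on a CUDA device with cuDNN 8):

import paddle
from paddle.incubate.operators import ResNetUnit

# Fused conv + batch norm + ReLU for the main branch of a ResNet block.
unit = ResNetUnit(num_channels_x=64, num_filters=64, filter_size=3,
                  stride=1, data_format='NHWC', act='relu')

# NHWC input; the cuDNN v8 fused kernels are aimed at FP16 tensors on GPU.
x = paddle.randn([2, 56, 56, 64]).astype('float16')
y = unit(x)     # forward without a shortcut branch (fuse_add=False)
print(y.shape)  # [2, 56, 56, 64]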
