Skip to content

Commit

Permalink
Support SyncBatchNorm5D (apache#14542)
Browse files Browse the repository at this point in the history
* support SyncBatchNorm5D

* fix

* update testcase and reformat code

* retrigger CI

* update test case

* test

* Retrigger CI

* disable cudnn for batchnorm

* fix BatchNorm(cudnn)

* fix build

* Remove a testcase

* Update sync_batch_norm-inl.h

* update unittest

* update unittest

* update test

* fix test

* change atol and rtol

* BN(cudnn) 5d

* update test

* test

* Testing

* Update batch_norm.cu

* test cudnnoff

* Update test_operator.py

* update BN! : )
  • Loading branch information
wkcn authored and ZhennanQin committed Apr 3, 2019
1 parent a72d8c5 commit 89e0051
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 152 deletions.
31 changes: 19 additions & 12 deletions src/operator/contrib/sync_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ struct SyncBatchNormParam : public dmlc::Parameter<SyncBatchNormParam> {
DMLC_DECLARE_FIELD(ndev).set_default(1)
.describe("The count of GPU devices");
DMLC_DECLARE_FIELD(key)
.set_default("")
.describe("Hash key for synchronization, please set the same hash key for same layer, "
"Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`.");
}
Expand Down Expand Up @@ -275,14 +274,18 @@ class SyncBatchNorm : public Operator {
static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
Tensor<xpu, 4> data;
Tensor<xpu, 4> out;
if (in_data[syncbatchnorm::kData].ndim() == 2) {
if (in_data[syncbatchnorm::kData].ndim() == 4) {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
} else {
index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ?
in_data[syncbatchnorm::kData].shape_[1] : 1;
index_t spatial_size = in_data[syncbatchnorm::kData].shape_.ProdShape(2,
in_data[syncbatchnorm::kData].ndim());
Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
in_data[syncbatchnorm::kData].shape_[1], 1, 1);
num_channels, 1, spatial_size);
data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
} else {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
}
Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
Expand Down Expand Up @@ -354,16 +357,20 @@ class SyncBatchNorm : public Operator {
Tensor<xpu, 4> data, grad, grad_in;
const real_t scale = static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
if (in_data[syncbatchnorm::kData].ndim() == 2) {
if (in_data[syncbatchnorm::kData].ndim() == 4) {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
} else {
index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ?
out_grad[syncbatchnorm::kOut].shape_[1] : 1;
index_t spatial_size = out_grad[syncbatchnorm::kOut].shape_.ProdShape(2,
out_grad[syncbatchnorm::kOut].ndim());
Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
num_channels, 1, spatial_size);
data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
} else {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
}

Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
Expand Down
4 changes: 2 additions & 2 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
Expand Down Expand Up @@ -697,7 +697,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
Expand Down
14 changes: 8 additions & 6 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class CuDNNBatchNormOp {
}
CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4);

Init(in_data[cudnnbatchnorm::kData]);
Stream<gpu> *s = ctx.get_stream<gpu>();
Expand Down Expand Up @@ -273,12 +272,15 @@ class CuDNNBatchNormOp {

private:
void Init(const TBlob &in_data) {
for (int i = 0; i < 4; ++i) {
if (i < in_data.ndim()) {
if (in_data.ndim() == 4) {
for (int i = 0; i < 4; ++i)
shape_[i] = in_data.shape_[i];
} else {
shape_[i] = 1;
}
} else {
// when in_data.ndim() != 4
shape_[0] = in_data.shape_[0];
shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
shape_[2] = 1;
shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
}

CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
Expand Down
Loading

0 comments on commit 89e0051

Please sign in to comment.