[MLU] fix sync_bn of mlu and add unittests (PaddlePaddle#45707)
* [MLU] fix sync_bn of mlu and add unittests

* [MLU] remove redundant code of pytest
qipengh authored and Caozhou1995 committed Sep 8, 2022
1 parent c40a36b commit cf08c44
Showing 6 changed files with 386 additions and 45 deletions.
57 changes: 34 additions & 23 deletions paddle/fluid/operators/sync_batch_norm_op_mlu.cc
@@ -159,9 +159,9 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
GetBasePtr(&local_var));

Tensor input_count;
input_count.mutable_data<T>(phi::make_ddim({1}), ctx.GetPlace());
FillMLUTensorWithHostValue<T>(
ctx, static_cast<T>(x->numel() / C), &input_count);
input_count.mutable_data<MPDType>(phi::make_ddim({1}), ctx.GetPlace());
FillMLUTensorWithHostValue<MPDType>(
ctx, static_cast<MPDType>(x->numel() / C), &input_count);

Tensor count_all;
Tensor mean_all(mean->dtype());
@@ -170,28 +170,31 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.cncl_comm();
if (comm) {
auto *comm = paddle::platform::CNCLCommContext::Instance()
.Get(0, ctx.GetPlace())
->comm();
auto cncl_comm = paddle::platform::CNCLCommContext::Instance().Get(
0, ctx.GetPlace());
auto *comm = cncl_comm->comm();
auto comm_stream = cncl_comm->stream();
int count;
PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCommCount(&count, comm));
count_all.mutable_data<T>(phi::make_ddim({count}), ctx.GetPlace());
count_all.mutable_data<MPDType>(phi::make_ddim({count}),
ctx.GetPlace());
mean_all.mutable_data<MPDType>(phi::make_ddim({count, mean->numel()}),
ctx.GetPlace());
invstd_all.mutable_data<MPDType>(
phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
// before comm_stream exec, need sync compute_stream.
dev_ctx.Wait();

cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(count_all.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&input_count),
GetBasePtr(&count_all),
1,
dtype,
comm,
stream));

mean_all.mutable_data<MPDType>(phi::make_ddim({count, mean->numel()}),
ctx.GetPlace());
invstd_all.mutable_data<MPDType>(
phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
comm_stream));

auto cncl_dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(mean_all.dtype()));
Expand All @@ -200,14 +203,17 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
local_mean.numel(),
cncl_dtype,
comm,
stream));
comm_stream));

PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_var),
GetBasePtr(&invstd_all),
local_var.numel(),
cncl_dtype,
comm,
stream));
comm_stream));
// after comm_stream exec, need sync queue for using compute_stream
// correctly.
PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
#else
if (NO_USE_CNCL) {
#endif
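
As an aside (not part of this commit), a minimal NumPy sketch of what the gathered per-device statistics above are used for: combining local counts, means, and variances into the global mean and inverse standard deviation that synchronized batch norm normalizes with. All names and values are illustrative.

```python
import numpy as np

# Hypothetical per-device statistics for one channel, standing in for the
# tensors filled by cnclAllGather (count_all, mean_all, invstd_all).
counts = np.array([32.0, 32.0])      # elements seen per device (N * H * W)
means = np.array([0.10, 0.30])       # local means
variances = np.array([1.00, 0.80])   # local biased variances
eps = 1e-5

total = counts.sum()
global_mean = (counts * means).sum() / total
# Per-device E[x^2] is var + mean^2; combine, then subtract the global mean^2.
global_var = (counts * (variances + means**2)).sum() / total - global_mean**2
global_invstd = 1.0 / np.sqrt(global_var + eps)

print(global_mean, global_var, global_invstd)  # 0.2, 0.91, ~1.048
```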
@@ -412,12 +418,14 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.cncl_comm();
if (comm) {
auto *comm = paddle::platform::CNCLCommContext::Instance()
.Get(0, ctx.GetPlace())
->comm();
auto cncl_comm =
paddle::platform::CNCLCommContext::Instance().Get(0, ctx.GetPlace());
auto *comm = cncl_comm->comm();
auto comm_stream = cncl_comm->stream();
// before comm_stream exec, need sync compute_stream.
dev_ctx.Wait();
cnclDataType_t dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(numel_count.dtype()));
PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&numel_count),
Expand All @@ -426,7 +434,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
dtype,
cnclSum,
comm,
stream));
comm_stream));

auto cncl_dtype = platform::ToCNCLDataType(
framework::TransToProtoVarType(sum_dy.dtype()));
Expand All @@ -436,15 +444,18 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
cncl_dtype,
cnclSum,
comm,
stream));
comm_stream));

PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy_xmu),
GetBasePtr(&sum_dy_xmu),
sum_dy_xmu.numel(),
cncl_dtype,
cnclSum,
comm,
stream));
comm_stream));
// after comm_stream exec, need sync queue for using compute_stream
// correctly.
PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
}
#endif

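In the same spirit, a hedged sketch (again not from the commit) of what the three cnclAllReduce calls in the gradient kernel accomplish: after the sum reduction every card holds the same global numel_count, sum_dy, and sum_dy_xmu, so the per-element means derived from them, and hence dx, dscale, and dbias, agree across devices. The numbers are made up.

```python
# Partial sums each device would hold for one channel before the all-reduce.
rank_counts = [32, 32]           # numel_count on each device
rank_sum_dy = [1.5, -0.5]        # sum of dy on each device
rank_sum_dy_xmu = [0.8, 0.2]     # sum of dy * (x - mean) on each device

# cnclAllReduce with cnclSum leaves every device holding these totals.
numel_count = sum(rank_counts)
sum_dy = sum(rank_sum_dy)
sum_dy_xmu = sum(rank_sum_dy_xmu)

mean_dy = sum_dy / numel_count          # feeds the dx computation
mean_dy_xmu = sum_dy_xmu / numel_count
print(numel_count, mean_dy, mean_dy_xmu)
```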
@@ -0,0 +1,105 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import contextlib
import unittest
import numpy as np
import six
import pickle

import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Conv2D, Linear, SyncBatchNorm
from paddle.fluid.dygraph.base import to_variable
import sys

sys.path.append("..")
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase


class TestLayer(fluid.dygraph.Layer):

def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(TestLayer, self).__init__()

self._conv = Conv2D(in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)

self._sync_batch_norm = SyncBatchNorm(num_filters)

self._conv2 = Conv2D(in_channels=num_filters,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)

self._sync_batch_norm2 = SyncBatchNorm(num_filters,
weight_attr=False,
bias_attr=False)

def forward(self, inputs):
y = self._conv(inputs)
y = self._sync_batch_norm(y)
y = self._conv2(y)
y = self._sync_batch_norm2(y)

return y


class TestSyncBatchNorm(TestParallelDyGraphRunnerBase):

def get_model(self):
model = TestLayer(3, 64, 7)
train_reader = paddle.batch(paddle.dataset.flowers.test(use_xmap=False),
batch_size=32,
drop_last=True)
opt = fluid.optimizer.Adam(learning_rate=1e-3,
parameter_list=model.parameters())
return model, train_reader, opt

def run_one_loop(self, model, opt, data):
batch_size = len(data)
dy_x_data = np.array([x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
img = to_variable(dy_x_data)
img.stop_gradient = False

out = model(img)

out = paddle.mean(out)

return out


if __name__ == "__main__":
runtime_main(TestSyncBatchNorm)
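
The test above builds its network from explicit SyncBatchNorm layers. In typical user code an existing model with ordinary BatchNorm layers is converted in place and launched with paddle.distributed.launch instead; a minimal usage sketch, assuming a multi-card environment and the current public Paddle APIs (not taken from this commit):

```python
import paddle
import paddle.nn as nn
from paddle.distributed import init_parallel_env

# Set up the data-parallel environment (expects to be started via
# `python -m paddle.distributed.launch ...`; the device-selection flag
# differs between Paddle versions).
init_parallel_env()

model = nn.Sequential(
    nn.Conv2D(3, 8, kernel_size=3, padding=1, bias_attr=False),
    nn.BatchNorm2D(8),
    nn.ReLU(),
)
# Replace every BatchNorm layer with SyncBatchNorm so statistics are
# synchronized across cards, then wrap for data-parallel training.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = paddle.DataParallel(model)

x = paddle.rand([4, 3, 32, 32])
loss = model(x).mean()
loss.backward()
```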