From f27339d78519b2eb42591f401e064962c672520c Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 16 Mar 2021 04:10:33 +0000 Subject: [PATCH 1/6] init --- .../truncated_gaussian_random_op_npu.cc | 81 +++++++++++++++++ .../test_truncated_gaussian_random_op_npu.py | 91 +++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 paddle/fluid/operators/truncated_gaussian_random_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py diff --git a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc new file mode 100644 index 0000000000000..82a04335ae4e9 --- /dev/null +++ b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/truncated_gaussian_random_op.h" +#include +#include +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class TruncatedGaussianRandomNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // to do: select_rows + std::vector shape = ctx.Attr>("shape"); + Tensor shape_tensor(framework::proto::VarType::INT32); + shape_tensor.mutable_data({static_cast(shape.size())}, + ctx.GetPlace()); + TensorFromVector(shape, ctx.device_context(), &shape_tensor); + float mean = ctx.Attr("mean"); + Tensor mean_tensor(framework::proto::VarType::FP32); + mean_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{mean}, ctx.device_context(), + &mean_tensor); + + float std = ctx.Attr("std"); + Tensor std_tensor(framework::proto::VarType::FP32); + std_tensor.mutable_data({1}, ctx.GetPlace()); + TensorFromVector(std::vector{std}, ctx.device_context(), + &std_tensor); + + int32_t seed_var = ctx.Attr("seed"); + + Tensor min_tensor(framework::proto::VarType::FP32); + min_tensor.mutable_data({1}, ctx.GetPlace()); + float min_value = mean - std * 2.0; + TensorFromVector(std::vector{min_value}, ctx.device_context(), + &min_tensor); + + Tensor max_tensor(framework::proto::VarType::FP32); + max_tensor.mutable_data({1}, ctx.GetPlace()); + float max_value = mean + std * 2.0; + TensorFromVector(std::vector{max_value}, ctx.device_context(), + &max_tensor); + + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto stream = + ctx.template device_context() + .stream(); + auto runner = NpuOpRunner( + "ParameterizedTruncatedNormal", + {shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor}, {*out}, + {{"seed", seed_var}}); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + truncated_gaussian_random, + ops::TruncatedGaussianRandomNPUKernel, + ops::TruncatedGaussianRandomNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py new file mode 100644 index 0000000000000..9fb02fca3a5f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from paddle.fluid.executor import Executor + +paddle.enable_static() +SEED = 2021 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestTrunctedGaussianRandom(OpTest): + def setUp(self): + self.set_npu() + self.op_type = "truncated_gaussian_random" + self.place = paddle.NPUPlace(0) + self.inputs = {} + self.attrs = { + "shape": [10], + "mean": .0, + "std": 1., + "seed": 10, + } + + self.outputs = {"Out": np.random.random(10).astype(self.dtype)} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + #def test_npu(self): + # self.gaussian_random_test(place=paddle.NPUPlace(0)) + + #def gaussian_random_test(self, place): + + # program = fluid.Program() + # block = program.global_block() + # vout = block.create_var(name="Out") + # op = block.append_op( + # type=self.op_type, outputs={"Out": vout}, attrs=self.attrs) + + # op.desc.infer_var_type(block.desc) + # op.desc.infer_shape(block.desc) + + # fetch_list = [] + # for var_name in self.outputs: + # fetch_list.append(block.var(var_name)) + + # exe = Executor(place) + # outs = exe.run(program, fetch_list=fetch_list) + # tensor = outs[0] + # self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1) + # self.assertAlmostEqual(numpy.var(tensor), 0.773, delta=0.1) + + # TODO(ascendrc): Add grad test + # def test_check_grad(self): + # if self.dtype == np.float16: + # return + # self.check_grad(['X'], 'Out') + # + + +if __name__ == '__main__': + unittest.main() From b98a3a38364914bc6c168e773e0b5325ee8c74fd Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Tue, 16 Mar 2021 04:38:13 +0000 Subject: [PATCH 2/6] add todo --- paddle/fluid/operators/truncated_gaussian_random_op_npu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc index 82a04335ae4e9..9dd292f869628 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc @@ -23,7 +23,7 @@ template class TruncatedGaussianRandomNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - // to do: select_rows + // TODO(zhiqiu): support dynamic shape and call ParameterizedTruncatedNormal std::vector shape = ctx.Attr>("shape"); Tensor shape_tensor(framework::proto::VarType::INT32); shape_tensor.mutable_data({static_cast(shape.size())}, From e87258bd7889299db31c728d0f3eeabc20669fe0 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 25 Mar 2021 10:16:18 +0000 Subject: [PATCH 3/6] add npu kernel for truncated_gaussian_random --- .../truncated_gaussian_random_op_npu.cc | 42 +++++++-- .../test_truncated_gaussian_random_op_npu.py | 85 +++++++------------ 2 files changed, 67 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc index 9dd292f869628..103f574d73fc9 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -68,14 +71,41 @@ class TruncatedGaussianRandomNPUKernel : public framework::OpKernel { } }; +// NOTE(zhiqiu): actually, this is cpu version kernel, and we need to make the +// above +// npu version work in the future. +template +class NPUTruncatedGaussianRandomKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + float mean = context.Attr("mean"); + float std = context.Attr("std"); + auto* tensor = context.Output("Out"); + tensor->mutable_data(context.GetPlace()); + + Tensor cpu_tensor(tensor->type()); + cpu_tensor.Resize(tensor->dims()); + T* cpu_data = cpu_tensor.mutable_data(platform::CPUPlace()); + std::uniform_real_distribution dist(std::numeric_limits::min(), + 1.0); + TruncatedNormal truncated_normal(mean, std); + int64_t size = tensor->numel(); + + unsigned int seed = static_cast(context.Attr("seed")); + auto engine = framework::GetCPURandomEngine(seed); + for (int64_t i = 0; i < size; ++i) { + cpu_data[i] = truncated_normal(dist(*engine)); + } + framework::TensorCopy( + cpu_tensor, context.GetPlace(), + context.template device_context(), tensor); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_NPU_KERNEL( - truncated_gaussian_random, - ops::TruncatedGaussianRandomNPUKernel, - ops::TruncatedGaussianRandomNPUKernel); +REGISTER_OP_NPU_KERNEL(truncated_gaussian_random, + ops::NPUTruncatedGaussianRandomKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py index 9fb02fca3a5f4..4b95db7915327 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py @@ -31,60 +31,37 @@ @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") -class TestTrunctedGaussianRandom(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "truncated_gaussian_random" - self.place = paddle.NPUPlace(0) - self.inputs = {} - self.attrs = { - "shape": [10], - "mean": .0, - "std": 1., - "seed": 10, - } - - self.outputs = {"Out": np.random.random(10).astype(self.dtype)} - - def set_npu(self): - self.__class__.use_npu = True - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_output(self): - self.check_output_with_place(self.place, check_dygraph=False) - - #def test_npu(self): - # self.gaussian_random_test(place=paddle.NPUPlace(0)) - - #def gaussian_random_test(self, place): - - # program = fluid.Program() - # block = program.global_block() - # vout = block.create_var(name="Out") - # op = block.append_op( - # type=self.op_type, outputs={"Out": vout}, attrs=self.attrs) - - # op.desc.infer_var_type(block.desc) - # op.desc.infer_shape(block.desc) - - # fetch_list = [] - # for var_name in self.outputs: - # fetch_list.append(block.var(var_name)) - - # exe = Executor(place) - # outs = exe.run(program, fetch_list=fetch_list) - # tensor = outs[0] - # self.assertAlmostEqual(numpy.mean(tensor), .0, delta=0.1) - # self.assertAlmostEqual(numpy.var(tensor), 0.773, delta=0.1) - - # TODO(ascendrc): Add grad test - # def test_check_grad(self): - # if self.dtype == np.float16: - # return - # self.check_grad(['X'], 'Out') - # +class TestPowNet(unittest.TestCase): + def _test(self, run_npu=True): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + paddle.seed(SEED) + + with paddle.static.program_guard(main_prog, startup_prog): + weight_attr = paddle.framework.ParamAttr( + name="linear_weight", + initializer=paddle.nn.initializer.TruncatedNormal( + mean=0.0, std=2.0)) + linear = paddle.nn.Linear( + 2, 2, weight_attr=weight_attr, bias_attr=False) + + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() + + exe = paddle.static.Executor(place) + w = exe.run(startup_prog, fetch_list=['linear_weight']) + return w + + def test_npu(self): + cpu_w = self._test(False) + npu_w = self._test(True) + + self.assertTrue(np.allclose(npu_w, cpu_w)) if __name__ == '__main__': From 4b36ee2c762398b454ad043a88472212307c9353 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 25 Mar 2021 10:21:10 +0000 Subject: [PATCH 4/6] add sync --- paddle/fluid/operators/truncated_gaussian_random_op_npu.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc index 103f574d73fc9..4253187fdde74 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc @@ -99,6 +99,8 @@ class NPUTruncatedGaussianRandomKernel : public framework::OpKernel { framework::TensorCopy( cpu_tensor, context.GetPlace(), context.template device_context(), tensor); + context.template device_context() + .Wait(); } }; From f07de67a814ab5a67726fbfec7e4116f2a9c2ba4 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 25 Mar 2021 11:19:03 +0000 Subject: [PATCH 5/6] fix concat_grad --- paddle/fluid/operators/concat_op_npu.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/concat_op_npu.cc b/paddle/fluid/operators/concat_op_npu.cc index 9b979dede048f..87bb3397ca267 100644 --- a/paddle/fluid/operators/concat_op_npu.cc +++ b/paddle/fluid/operators/concat_op_npu.cc @@ -80,7 +80,6 @@ class ConcatGradNPUKernel : public framework::OpKernel { axis = ComputeAxis(static_cast(axis), static_cast(ins[0]->dims().size())); - std::vector sizes; int offset = 0; auto stream = ctx.template device_context() @@ -91,7 +90,6 @@ class ConcatGradNPUKernel : public framework::OpKernel { if (out_var_names[j] != framework::kEmptyVarName && outs[j]->numel() != 0UL) { outs[j]->mutable_data(ctx.GetPlace()); - sizes.push_back(outs[j]->dims()[axis]); std::vector offsets; std::vector sizes; for (int dim = 0; dim < ins[j]->dims().size(); ++dim) { @@ -103,9 +101,8 @@ class ConcatGradNPUKernel : public framework::OpKernel { sizes.push_back(ins[j]->dims()[dim]); } } - auto runner = - NpuOpRunner("SliceD", {*out_grad}, {*outs[j]}, - {{"offsets", offset}, {"size", ins[j]->dims()[axis]}}); + auto runner = NpuOpRunner("SliceD", {*out_grad}, {*outs[j]}, + {{"offsets", offsets}, {"size", sizes}}); runner.Run(stream); } if (ins[j]->numel() != 0UL) { From b42cc8341a7ede4825dea591a7ac9a6c45fc6b02 Mon Sep 17 00:00:00 2001 From: zhiqiu Date: Thu, 25 Mar 2021 11:30:38 +0000 Subject: [PATCH 6/6] fix typo --- .../test_truncated_gaussian_random_op_npu.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py index 4b95db7915327..ff89508d19623 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_truncated_gaussian_random_op_npu.py @@ -31,31 +31,34 @@ @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") -class TestPowNet(unittest.TestCase): +class TestTruncatedNormal(unittest.TestCase): def _test(self, run_npu=True): main_prog = paddle.static.Program() startup_prog = paddle.static.Program() + scope = paddle.fluid.core.Scope() + main_prog.random_seed = SEED startup_prog.random_seed = SEED np.random.seed(SEED) paddle.seed(SEED) - with paddle.static.program_guard(main_prog, startup_prog): - weight_attr = paddle.framework.ParamAttr( - name="linear_weight", - initializer=paddle.nn.initializer.TruncatedNormal( - mean=0.0, std=2.0)) - linear = paddle.nn.Linear( - 2, 2, weight_attr=weight_attr, bias_attr=False) + with fluid.scope_guard(scope): + with paddle.static.program_guard(main_prog, startup_prog): + weight_attr = paddle.framework.ParamAttr( + name="linear_weight", + initializer=paddle.nn.initializer.TruncatedNormal( + mean=0.0, std=2.0)) + linear = paddle.nn.Linear( + 2, 2, weight_attr=weight_attr, bias_attr=False) - if run_npu: - place = paddle.NPUPlace(0) - else: - place = paddle.CPUPlace() + if run_npu: + place = paddle.NPUPlace(0) + else: + place = paddle.CPUPlace() - exe = paddle.static.Executor(place) - w = exe.run(startup_prog, fetch_list=['linear_weight']) - return w + exe = paddle.static.Executor(place) + w = exe.run(startup_prog, fetch_list=['linear_weight']) + return w def test_npu(self): cpu_w = self._test(False)