diff --git a/paddle/fluid/operators/increment_op_npu.cc b/paddle/fluid/operators/increment_op_npu.cc new file mode 100644 index 0000000000000..90f9787cc38cf --- /dev/null +++ b/paddle/fluid/operators/increment_op_npu.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "paddle/fluid/operators/increment_op.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace framework { +class OpDesc; +class Variable; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +} // namespace paddle + +namespace paddle { +namespace operators { + + +template +class IncrementalNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x_tensor = context.Input("X"); + auto* out_tensor = context.Output("Out"); + float step = context.Attr("step"); + out_tensor->mutable_data(context.GetPlace()); + + Tensor step_tensor(x_tensor->type()); + std::vector step_vec; + step_vec.push_back(static_cast(step)); + framework::TensorFromVector( + step_vec, + context.device_context(), + &step_tensor); + + auto runner = NpuOpRunner("Add", + {*x_tensor, step_tensor}, + {*out_tensor}, + {}); + + auto stream = + context.template device_context() + .stream(); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + increment, + ops::IncrementalNPUKernel, + ops::IncrementalNPUKernel, + ops::IncrementalNPUKernel, + ops::IncrementalNPUKernel, + ops::IncrementalNPUKernel) + diff --git a/paddle/fluid/operators/increment_op_npu_test.cc b/paddle/fluid/operators/increment_op_npu_test.cc new file mode 100644 index 0000000000000..f4ce9ffe40b0d --- /dev/null +++ b/paddle/fluid/operators/increment_op_npu_test.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef _WIN32 +#include +#endif + +#include +#include // NOLINT +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/dropout_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace f = paddle::framework; +namespace p = paddle::platform; +namespace m = paddle::operators::math; + +USE_OP(increment); +USE_OP_DEVICE_KERNEL(increment, NPU); + +template +void Compare(f::Scope* scope, const p::DeviceContext& ctx, + std::string op_type) { + // init + auto x = scope->Var("X"); + auto tensor_x = x->GetMutable(); + + std::vector init; + init.push_back(static_cast(1.0)); + + TensorFromVector(init, ctx, tensor_x); + tensor_x->Resize({1}); + + ctx.Wait(); + + auto place = ctx.GetPlace(); + auto out = scope->Var("Out"); + auto tensor_out = out->GetMutable(); + + f::AttributeMap attr_input = { {"step", static_cast(2.0)} }; + auto op = f::OpRegistry::CreateOp("increment", {{"X", {"X"}}}, + {{"Out", {"Out"}}}, + attr_input); + + op->Run(*scope, place); + + std::vector out_vec; + TensorToVector(*tensor_out, ctx, &out_vec); + + ctx.Wait(); + + EXPECT_EQ((uint32_t)out_vec.size(), (uint32_t)1); + EXPECT_EQ(out_vec[0], static_cast(3.0)); +} + + +TEST(increment, NPU_fp32) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "increment"); +} + +TEST(increment, NPU_fp64) { + f::Scope scope; + p::NPUDeviceContext ctx(p::NPUPlace(0)); + Compare(&scope, ctx, "increment"); +} + diff --git a/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py new file mode 100644 index 0000000000000..09019e36c82fa --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py @@ -0,0 +1,116 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest +import sys +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import core + +paddle.enable_static() +SEED = 2021 + +NPUPlace = 5 + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestIncrement(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(NPUPlace) + self.op_type = "increment" + self.init_dtype() + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), } + + self.attrs = {"Step": 1} + self.outputs = {'Out': np.array([2])} + + def set_npu(self): + self.__class__.use_npu = True + self.__class__.no_need_check_grad = True + + def init_dtype(self): + self.dtype = np.int64 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestIncrementFP16(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(NPUPlace) + self.op_type = "increment" + self.init_dtype() + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), } + self.pre_input_id = id(self.inputs['X']) + + self.attrs = {"Step": 1} + self.outputs = {'Out': np.array([2])} + + def set_npu(self): + self.__class__.use_npu = True + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output_with_place(self.place, check_dygraph=False) + + +@unittest.skipIf(not paddle.is_compiled_with_npu(), + "core is not compiled with NPU") +class TestIncrementInplace(unittest.TestCase): + def test_npu(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + main_prog.random_seed = SEED + startup_prog.random_seed = SEED + np.random.seed(SEED) + + a_np = np.array([1]).astype('float32') + + with paddle.static.program_guard(main_prog, startup_prog): + a = paddle.static.data(name="a", shape=[1], dtype='float32') + b = fluid.layers.increment(a) + + place = paddle.NPUPlace(NPUPlace) + + exe = paddle.static.Executor(place) + exe.run(startup_prog) + + b_value = exe.run( + main_prog, + feed={"a": a_np,}, + fetch_list=[b]) + + print('input a id is : {}'.format(id(a))) + print('input b id is : {}'.format(id(b))) + + self.assertEqual(id(a), id(b)) + self.assertEqual(b_value[0], 2) + + +if __name__ == '__main__': + unittest.main()