diff --git a/paddle/fluid/operators/scatter_op_npu.cc b/paddle/fluid/operators/scatter_op_npu.cc
new file mode 100755
index 0000000000000..fb6958e9046cd
--- /dev/null
+++ b/paddle/fluid/operators/scatter_op_npu.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/operators/scatter_op.h"
+#include "paddle/fluid/operators/kron_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class ScatterNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* index = ctx.Input<Tensor>("Ids");
+    auto* updates = ctx.Input<Tensor>("Updates");
+    bool overwrite = ctx.Attr<bool>("overwrite");
+
+    auto* out = ctx.Output<Tensor>("Out");
+
+    auto place = ctx.GetPlace();
+    out->mutable_data<T>(place);
+
+    // TensorScatterUpdate/TensorScatterAdd expect an index of rank >= 2,
+    // so a 1-D index of shape [n] is viewed as [n, 1] without copying.
+    framework::Tensor tmp_tensor(index->type());
+    const auto index_dims = index->dims();
+    if (index_dims.size() == 1) {
+      tmp_tensor.ShareDataWith(*index);
+      std::vector<int64_t> new_dim = {index_dims[0], 1};
+      tmp_tensor.Resize(framework::make_ddim(new_dim));
+      index = &tmp_tensor;
+    }
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    // overwrite == true replaces the indexed rows of X with Updates;
+    // otherwise Updates are accumulated into them.
+    if (overwrite) {
+      auto runner_update =
+          NpuOpRunner("TensorScatterUpdate", {*x, *index, *updates}, {*out}, {});
+      runner_update.Run(stream);
+    } else {
+      auto runner_add =
+          NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*out}, {});
+      runner_add.Run(stream);
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    scatter,
+    ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::ScatterNPUKernel<paddle::platform::NPUDeviceContext, int>);
+#endif
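For reference, the kernel above simply dispatches on the `overwrite` attribute: `TensorScatterUpdate` replaces the indexed rows, while `TensorScatterAdd` accumulates into them. Below is a minimal NumPy sketch of those semantics (illustration only, not part of the patch; `scatter_reference` is a hypothetical name):

```python
import numpy as np


def scatter_reference(x, ids, updates, overwrite=True):
    """NumPy model of the scatter semantics handled by ScatterNPUKernel."""
    out = np.copy(x)
    if overwrite:
        out[ids] = updates  # TensorScatterUpdate path: last write wins
    else:
        np.add.at(out, ids, updates)  # TensorScatterAdd path: duplicates accumulate
    return out


x = np.ones((3, 2), dtype=np.float32)
ids = np.array([1], dtype=np.int32)
updates = np.full((1, 2), 5.0, dtype=np.float32)
print(scatter_reference(x, ids, updates, overwrite=True))   # row 1 -> [5., 5.]
print(scatter_reference(x, ids, updates, overwrite=False))  # row 1 -> [6., 6.]
```

Note the difference for repeated indices: plain assignment keeps only the last write, whereas `np.add.at` accumulates every contribution, mirroring the two NPU ops.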
diff --git a/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py
new file mode 100755
index 0000000000000..3110672b2dab6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+paddle.enable_static()
+SEED = 2021
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestScatterOverwriteFp32(OpTest):
+    # overwrite=True with float32 data: Out[Ids] = Updates.
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "scatter"
+        self.place = paddle.NPUPlace(0)
+
+        ref_np = np.ones((3, 2)).astype("float32")
+        index_np = np.array([1]).astype("int32")
+        updates_np = np.random.random((1, 2)).astype("float32")
+
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+        self.attrs = {'overwrite': True}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestScatterOverwriteInt32(OpTest):
+    # Same as above but with int32 data.
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "scatter"
+        self.place = paddle.NPUPlace(0)
+
+        ref_np = np.ones((3, 2)).astype("int32")
+        index_np = np.array([1]).astype("int32")
+        updates_np = np.zeros((1, 2)).astype("int32")
+
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+        self.attrs = {'overwrite': True}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestScatterAddFp32(OpTest):
+    # overwrite=False: Out[Ids] = X[Ids] + Updates.
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "scatter"
+        self.place = paddle.NPUPlace(0)
+
+        ref_np = np.ones((3, 2)).astype("float32")
+        index_np = np.array([1]).astype("int32")
+        updates_np = np.random.random((1, 2)).astype("float32")
+
+        output_np = np.copy(ref_np)
+        output_np[index_np] += updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+        self.attrs = {'overwrite': False}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestScatterOverwriteMultiRow(OpTest):
+    # overwrite=True with two indexed rows updated in one call.
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "scatter"
+        self.place = paddle.NPUPlace(0)
+
+        ref_np = np.ones((3, 2)).astype("float32")
+        index_np = np.array([1, 2]).astype("int32")
+        updates_np = np.random.random((2, 2)).astype("float32")
+
+        output_np = np.copy(ref_np)
+        output_np[1] = updates_np[0]
+        output_np[2] = updates_np[1]
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+        self.attrs = {'overwrite': True}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+if __name__ == '__main__':
+    unittest.main()
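Beyond the OpTest cases above, the kernel can also be exercised end to end through the public `paddle.scatter` API in static mode. A minimal sketch, assuming a build with `PADDLE_WITH_ASCEND_CL` and an NPU device at index 0 (illustration only, not part of the patch):

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
    ids = paddle.static.data(name='ids', shape=[2], dtype='int32')
    updates = paddle.static.data(name='updates', shape=[2, 2], dtype='float32')
    # overwrite=True lowers to TensorScatterUpdate on NPU; False to TensorScatterAdd.
    out = paddle.scatter(x, ids, updates, overwrite=True)

exe = paddle.static.Executor(paddle.NPUPlace(0))  # requires an Ascend NPU build
exe.run(startup_prog)
res, = exe.run(main_prog,
               feed={'x': np.ones((3, 2), dtype=np.float32),
                     'ids': np.array([1, 2], dtype=np.int32),
                     'updates': np.full((2, 2), 9.0, dtype=np.float32)},
               fetch_list=[out])
print(res)  # rows 1 and 2 of the ones matrix are replaced with 9.0
```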