From 9f98122de98c9dde462bfa6afc7b2d19bf484a57 Mon Sep 17 00:00:00 2001 From: simonJJJ <821898965@qq.com> Date: Sat, 2 Apr 2022 13:12:54 +0800 Subject: [PATCH 1/2] fix tensor_scatter_nd_update --- oneflow/user/kernels/nd_index_slice_kernels.cpp | 9 +++++++++ oneflow/user/kernels/nd_index_slice_kernels.cu | 16 ++++++++++++++++ oneflow/user/kernels/nd_index_slice_kernels.h | 6 ++---- oneflow/user/kernels/nd_index_slice_util.h | 16 ++++++++++++++++ python/oneflow/test/tensor/test_tensor_part_1.py | 13 +++++++++++++ 5 files changed, 56 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/nd_index_slice_kernels.cpp b/oneflow/user/kernels/nd_index_slice_kernels.cpp index 009c75b27e7..6c7f44fd677 100644 --- a/oneflow/user/kernels/nd_index_slice_kernels.cpp +++ b/oneflow/user/kernels/nd_index_slice_kernels.cpp @@ -35,6 +35,15 @@ struct ScatterNdAddFunctor final { } }; +template +struct ScatterNdUpdateFunctor final { + void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, + const T* slices, T* dense) const { + DoScatterNdUpdate(args.num_slices * args.slice_size, args.slice_size, + args.index_ndims, args.dense_shape, indices, slices, dense); + } +}; + template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, diff --git a/oneflow/user/kernels/nd_index_slice_kernels.cu b/oneflow/user/kernels/nd_index_slice_kernels.cu index 06f43470127..3e22651ab73 100644 --- a/oneflow/user/kernels/nd_index_slice_kernels.cu +++ b/oneflow/user/kernels/nd_index_slice_kernels.cu @@ -34,6 +34,13 @@ __global__ void CudaScatterNdAdd(NdIndexSliceArgs args, const I* indices, args.index_ndims, args.dense_shape, indices, slices, dense); } +template +__global__ void CudaScatterNdUpdate(NdIndexSliceArgs args, const I* indices, const T* slices, + T* dense) { + DoScatterNdUpdate(args.num_slices * args.slice_size, args.slice_size, + args.index_ndims, args.dense_shape, indices, slices, dense); +} + template __global__ void CudaFillByNdIndex(NdIndexSliceArgs args, const I* indices, T* dense, T value) { @@ -61,6 +68,15 @@ struct ScatterNdAddFunctor final { } }; +template +struct ScatterNdUpdateFunctor final { + void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, + const T* slices, T* dense) const { + RUN_CUDA_KERNEL((CudaScatterNdUpdate), stream, args.num_slices * args.slice_size, args, + indices, slices, dense); + } +}; + template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, diff --git a/oneflow/user/kernels/nd_index_slice_kernels.h b/oneflow/user/kernels/nd_index_slice_kernels.h index 848ed9d1b1a..871c73f47eb 100644 --- a/oneflow/user/kernels/nd_index_slice_kernels.h +++ b/oneflow/user/kernels/nd_index_slice_kernels.h @@ -103,10 +103,8 @@ void TensorScatterNdUpdateKernel::Compute( Memcpy(ctx->stream(), out->mut_dptr(), params->dptr(), out_bytes_size); if (indices->shape().elem_cnt() == 0) { return; } auto args = ConstructNdIndexSliceArgs(*params, *updates, *indices); - FillByNdIndexFunctor()(ctx->stream(), args, indices->dptr(), - out->mut_dptr(), static_cast(0)); - ScatterNdAddFunctor()(ctx->stream(), args, indices->dptr(), - updates->dptr(), out->mut_dptr()); + ScatterNdUpdateFunctor()(ctx->stream(), args, indices->dptr(), + updates->dptr(), out->mut_dptr()); } template diff --git a/oneflow/user/kernels/nd_index_slice_util.h b/oneflow/user/kernels/nd_index_slice_util.h index ad4065b3bee..167dd0cba29 100644 --- a/oneflow/user/kernels/nd_index_slice_util.h +++ b/oneflow/user/kernels/nd_index_slice_util.h @@ -55,6 +55,12 @@ struct ScatterNdAddFunctor final { const T* slices, T* dense) const; }; +template +struct ScatterNdUpdateFunctor final { + void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, + const T* slices, T* dense) const; +}; + template struct FillByNdIndexFunctor final { void operator()(ep::Stream* stream, const NdIndexSliceArgs& args, const I* indices, @@ -101,6 +107,16 @@ OF_DEVICE_FUNC void DoScatterNdAdd(int64_t elem_cnt, int64_t slice_size, int64_t } } +template +OF_DEVICE_FUNC void DoScatterNdUpdate(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, + const int64_t* dense_shape, const I* indices, const T* slices, + T* dense) { + XPU_1D_KERNEL_LOOP(i, elem_cnt) { + int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i); + dense[offset] = slices[i]; + } +} + template OF_DEVICE_FUNC void DoFillByNdIndex(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims, const int64_t* dense_shape, const I* indices, T* dense, diff --git a/python/oneflow/test/tensor/test_tensor_part_1.py b/python/oneflow/test/tensor/test_tensor_part_1.py index 34064a55fc9..c7f4c79e6cc 100644 --- a/python/oneflow/test/tensor/test_tensor_part_1.py +++ b/python/oneflow/test/tensor/test_tensor_part_1.py @@ -474,6 +474,19 @@ def compare_setitem_with_numpy(tensor, slices, value): x = flow.Tensor(2, 3, 4) compare_setitem_with_numpy(x, se[1, :, 2], v) + @flow.unittest.skip_unless_1n1d() + @autotest(auto_backward=False, check_graph=True) + def test_setitem_with_random_data(test_case): + device = random_device() + x = random_tensor(low=0, high=0, ndim=1, dim0=16).to(device) + y = random_tensor(low=-2, high=2, ndim=1, dim0=16).to(device) + idx = random_tensor( + low=0, high=15, ndim=1, dim0=20, dtype=int, requires_grad=False + ).to(device) + z = y[idx] + x[idx] = z + return x + @flow.unittest.skip_unless_1n1d() def test_div(test_case): x = flow.Tensor(np.random.randn(1, 1)) From 8af17ad42471f86bb2c26bc06abdd871c3750ad5 Mon Sep 17 00:00:00 2001 From: simonJJJ <821898965@qq.com> Date: Sat, 2 Apr 2022 14:23:45 +0800 Subject: [PATCH 2/2] auto backward --- python/oneflow/test/tensor/test_tensor_part_1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/tensor/test_tensor_part_1.py b/python/oneflow/test/tensor/test_tensor_part_1.py index c7f4c79e6cc..9cdbd04745a 100644 --- a/python/oneflow/test/tensor/test_tensor_part_1.py +++ b/python/oneflow/test/tensor/test_tensor_part_1.py @@ -475,7 +475,7 @@ def compare_setitem_with_numpy(tensor, slices, value): compare_setitem_with_numpy(x, se[1, :, 2], v) @flow.unittest.skip_unless_1n1d() - @autotest(auto_backward=False, check_graph=True) + @autotest(check_graph=True) def test_setitem_with_random_data(test_case): device = random_device() x = random_tensor(low=0, high=0, ndim=1, dim0=16).to(device)