fix tensor_scatter_nd_update (#7953)
* fix tensor_scatter_nd_update

* auto backward

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2 people authored and xiacijie committed Apr 24, 2022
1 parent fc455bb commit ef6dbb1
Showing 5 changed files with 56 additions and 4 deletions.
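
For context, `tensor_scatter_nd_update` takes a dense `params` tensor, an `indices` tensor whose rows are nd-indices, and an `updates` tensor of slices, and returns a copy of `params` with the addressed slices overwritten. A minimal NumPy sketch of that semantics (a reference illustration, not the OneFlow API) is:

```python
import numpy as np

def scatter_nd_update_ref(params, indices, updates):
    """Copy params, then overwrite the slice at each nd-index with its update."""
    out = params.copy()
    for nd_index, update in zip(indices, updates):
        out[tuple(nd_index)] = update   # overwrite, do not accumulate
    return out

params = np.zeros((4, 3))
indices = np.array([[1], [3]])          # each row is an nd-index; index_ndims == 1 here
updates = np.array([[1.0, 1.0, 1.0],
                    [2.0, 2.0, 2.0]])   # one slice of shape (3,) per index
print(scatter_nd_update_ref(params, indices, updates))
```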
9 changes: 9 additions & 0 deletions oneflow/user/kernels/nd_index_slice_kernels.cpp
@@ -35,6 +35,15 @@ struct ScatterNdAddFunctor<DeviceType::kCPU, T, I> final {
   }
 };

+template<typename T, typename I>
+struct ScatterNdUpdateFunctor<DeviceType::kCPU, T, I> final {
+  void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
+                  const T* slices, T* dense) const {
+    DoScatterNdUpdate<DeviceType::kCPU>(args.num_slices * args.slice_size, args.slice_size,
+                                        args.index_ndims, args.dense_shape, indices, slices, dense);
+  }
+};
+
 template<typename T, typename I>
 struct FillByNdIndexFunctor<DeviceType::kCPU, T, I> final {
   void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
16 changes: 16 additions & 0 deletions oneflow/user/kernels/nd_index_slice_kernels.cu
@@ -34,6 +34,13 @@ __global__ void CudaScatterNdAdd(NdIndexSliceArgs<T, I> args, const I* indices,
                                  args.index_ndims, args.dense_shape, indices, slices, dense);
 }

+template<typename T, typename I>
+__global__ void CudaScatterNdUpdate(NdIndexSliceArgs<T, I> args, const I* indices, const T* slices,
+                                    T* dense) {
+  DoScatterNdUpdate<DeviceType::kCUDA>(args.num_slices * args.slice_size, args.slice_size,
+                                       args.index_ndims, args.dense_shape, indices, slices, dense);
+}
+
 template<typename T, typename I>
 __global__ void CudaFillByNdIndex(NdIndexSliceArgs<T, I> args, const I* indices, T* dense,
                                   T value) {
@@ -61,6 +68,15 @@ struct ScatterNdAddFunctor<DeviceType::kCUDA, T, I> final {
   }
 };

+template<typename T, typename I>
+struct ScatterNdUpdateFunctor<DeviceType::kCUDA, T, I> final {
+  void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
+                  const T* slices, T* dense) const {
+    RUN_CUDA_KERNEL((CudaScatterNdUpdate<T, I>), stream, args.num_slices * args.slice_size, args,
+                    indices, slices, dense);
+  }
+};
+
 template<typename T, typename I>
 struct FillByNdIndexFunctor<DeviceType::kCUDA, T, I> final {
   void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
6 changes: 2 additions & 4 deletions oneflow/user/kernels/nd_index_slice_kernels.h
@@ -103,10 +103,8 @@ void TensorScatterNdUpdateKernel<device_type, T, I>::Compute(
   Memcpy<device_type>(ctx->stream(), out->mut_dptr<T>(), params->dptr<T>(), out_bytes_size);
   if (indices->shape().elem_cnt() == 0) { return; }
   auto args = ConstructNdIndexSliceArgs<T, I>(*params, *updates, *indices);
-  FillByNdIndexFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
-                                            out->mut_dptr<T>(), static_cast<T>(0));
-  ScatterNdAddFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
-                                           updates->dptr<T>(), out->mut_dptr<T>());
+  ScatterNdUpdateFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
+                                              updates->dptr<T>(), out->mut_dptr<T>());
 }

 template<DeviceType device_type, typename T, typename I>
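
This is the core of the fix: the old path zero-filled the addressed slices and then scatter-added the updates, while the new path writes the updates in place. The two compositions only agree when no index repeats; a small NumPy sketch (1-D case, illustrative only, not OneFlow code) shows the divergence on a duplicated index:

```python
import numpy as np

params = np.array([10.0, 20.0, 30.0])
indices = np.array([1, 1])          # the same index appears twice
updates = np.array([5.0, 7.0])

# Old path: copy, zero the touched positions, then scatter-add the updates.
old = params.copy()
old[indices] = 0.0
np.add.at(old, indices, updates)    # duplicates accumulate -> old[1] == 12.0

# New path: copy, then scatter-update (in NumPy the last write wins).
new = params.copy()
new[indices] = updates              # new[1] == 7.0

print(old, new)
```

On the CUDA kernel it is unspecified which of the duplicates wins, but the result is still an update rather than an accumulated sum.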
16 changes: 16 additions & 0 deletions oneflow/user/kernels/nd_index_slice_util.h
@@ -55,6 +55,12 @@ struct ScatterNdAddFunctor final {
                   const T* slices, T* dense) const;
 };

+template<DeviceType device_type, typename T, typename I>
+struct ScatterNdUpdateFunctor final {
+  void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
+                  const T* slices, T* dense) const;
+};
+
 template<DeviceType device_type, typename T, typename I>
 struct FillByNdIndexFunctor final {
   void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
@@ -101,6 +107,16 @@ OF_DEVICE_FUNC void DoScatterNdAdd(int64_t elem_cnt, int64_t slice_size, int64_t
   }
 }

+template<DeviceType device_type, typename T, typename I>
+OF_DEVICE_FUNC void DoScatterNdUpdate(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,
+                                      const int64_t* dense_shape, const I* indices, const T* slices,
+                                      T* dense) {
+  XPU_1D_KERNEL_LOOP(i, elem_cnt) {
+    int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);
+    dense[offset] = slices[i];
+  }
+}
+
 template<typename T, typename I>
 OF_DEVICE_FUNC void DoFillByNdIndex(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,
                                     const int64_t* dense_shape, const I* indices, T* dense,
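
`DoScatterNdUpdate` mirrors `DoScatterNdAdd`: every element of the flattened `slices` buffer is mapped to an offset into the row-major `dense` buffer and then assigned rather than accumulated. The Python sketch below approximates that offset arithmetic; the real `OffsetInSliceToOffsetInDense` helper is not shown in this diff, so treat the function here as a hypothetical stand-in.

```python
def offset_in_dense(i, slice_size, index_ndims, dense_shape, indices):
    """Map element i of the flattened updates buffer to an offset in the
    flattened dense buffer (row-major layout assumed)."""
    slice_idx = i // slice_size            # which update slice this element belongs to
    nd_index = indices[slice_idx * index_ndims:(slice_idx + 1) * index_ndims]
    offset = 0
    for d in range(index_ndims):           # row-major offset of the addressed slice
        offset = offset * dense_shape[d] + nd_index[d]
    return offset * slice_size + i % slice_size

# dense shape (4, 3), index_ndims = 1, so each slice holds 3 elements.
dense_shape = [4, 3]
indices = [1, 3]                           # flattened nd-indices of two slices
for i in range(6):                         # 2 slices * slice_size 3
    print(i, offset_in_dense(i, 3, 1, dense_shape, indices))
```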
13 changes: 13 additions & 0 deletions python/oneflow/test/tensor/test_tensor_part_1.py
@@ -483,6 +483,19 @@ def compare_setitem_with_numpy(tensor, slices, value):
         x = flow.Tensor(2, 3, 4)
         compare_setitem_with_numpy(x, se[1, :, 2], v)

+    @flow.unittest.skip_unless_1n1d()
+    @autotest(check_graph=True)
+    def test_setitem_with_random_data(test_case):
+        device = random_device()
+        x = random_tensor(low=0, high=0, ndim=1, dim0=16).to(device)
+        y = random_tensor(low=-2, high=2, ndim=1, dim0=16).to(device)
+        idx = random_tensor(
+            low=0, high=15, ndim=1, dim0=20, dtype=int, requires_grad=False
+        ).to(device)
+        z = y[idx]
+        x[idx] = z
+        return x
+
     @flow.unittest.skip_unless_1n1d()
     def test_div(test_case):
         x = flow.Tensor(np.random.randn(1, 1))
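
The new test drives the fixed kernel through tensor `__setitem__` with an integer index tensor, and the `autotest` decorator compares the program against a reference implementation. A plain NumPy rendering of the same computation (illustrative only; `random_tensor`, `random_device`, and `autotest` are OneFlow test utilities) looks like:

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.zeros(16, dtype=np.float32)              # low=0, high=0 -> all zeros
y = rng.uniform(-2, 2, size=16).astype(np.float32)
idx = rng.integers(0, 15, size=20)              # mirrors low=0, high=15 in the test

x[idx] = y[idx]                                 # advanced-index assignment == scatter update
print(x)
```

With 20 indices drawn from 16 slots, duplicates are almost certain, and under the old zero-then-add path a slot hit k times would have held k * y[j] instead of y[j], which is one way the old kernel diverged from plain update semantics.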
