Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix tensor_scatter_nd_update #7953

Merged
merged 25 commits into from
Apr 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9f98122
fix tensor_scatter_nd_update
simonJJJ Apr 2, 2022
8af17ad
auto backward
simonJJJ Apr 2, 2022
bba60d6
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 2, 2022
115e725
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 2, 2022
fe2d42a
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 2, 2022
f2bc9bd
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 2, 2022
aaf8cc7
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 2, 2022
6b4eb4f
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 3, 2022
7a69447
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 3, 2022
1331114
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 5, 2022
a00b872
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 5, 2022
1d07b69
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 6, 2022
94791ea
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 6, 2022
d8cd3b9
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 6, 2022
0529a58
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 7, 2022
bf31494
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 7, 2022
9519ddf
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 7, 2022
5a4e4c2
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 7, 2022
78a28a5
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 8, 2022
5c0ac33
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 8, 2022
783976b
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 8, 2022
7ff1c75
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 8, 2022
ffd9327
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 9, 2022
9ef1069
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 9, 2022
1c59d2f
Merge branch 'master' into fix_tensor_scatter_nd_update
mergify[bot] Apr 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions oneflow/user/kernels/nd_index_slice_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ struct ScatterNdAddFunctor<DeviceType::kCPU, T, I> final {
}
};

template<typename T, typename I>
struct ScatterNdUpdateFunctor<DeviceType::kCPU, T, I> final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
const T* slices, T* dense) const {
DoScatterNdUpdate<DeviceType::kCPU>(args.num_slices * args.slice_size, args.slice_size,
args.index_ndims, args.dense_shape, indices, slices, dense);
}
};

template<typename T, typename I>
struct FillByNdIndexFunctor<DeviceType::kCPU, T, I> final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
Expand Down
16 changes: 16 additions & 0 deletions oneflow/user/kernels/nd_index_slice_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ __global__ void CudaScatterNdAdd(NdIndexSliceArgs<T, I> args, const I* indices,
args.index_ndims, args.dense_shape, indices, slices, dense);
}

template<typename T, typename I>
__global__ void CudaScatterNdUpdate(NdIndexSliceArgs<T, I> args, const I* indices, const T* slices,
T* dense) {
DoScatterNdUpdate<DeviceType::kCUDA>(args.num_slices * args.slice_size, args.slice_size,
args.index_ndims, args.dense_shape, indices, slices, dense);
}

template<typename T, typename I>
__global__ void CudaFillByNdIndex(NdIndexSliceArgs<T, I> args, const I* indices, T* dense,
T value) {
Expand Down Expand Up @@ -61,6 +68,15 @@ struct ScatterNdAddFunctor<DeviceType::kCUDA, T, I> final {
}
};

template<typename T, typename I>
struct ScatterNdUpdateFunctor<DeviceType::kCUDA, T, I> final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
const T* slices, T* dense) const {
RUN_CUDA_KERNEL((CudaScatterNdUpdate<T, I>), stream, args.num_slices * args.slice_size, args,
indices, slices, dense);
}
};

template<typename T, typename I>
struct FillByNdIndexFunctor<DeviceType::kCUDA, T, I> final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
Expand Down
6 changes: 2 additions & 4 deletions oneflow/user/kernels/nd_index_slice_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ void TensorScatterNdUpdateKernel<device_type, T, I>::Compute(
Memcpy<device_type>(ctx->stream(), out->mut_dptr<T>(), params->dptr<T>(), out_bytes_size);
if (indices->shape().elem_cnt() == 0) { return; }
auto args = ConstructNdIndexSliceArgs<T, I>(*params, *updates, *indices);
FillByNdIndexFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
out->mut_dptr<T>(), static_cast<T>(0));
ScatterNdAddFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
updates->dptr<T>(), out->mut_dptr<T>());
ScatterNdUpdateFunctor<device_type, T, I>()(ctx->stream(), args, indices->dptr<I>(),
updates->dptr<T>(), out->mut_dptr<T>());
}

template<DeviceType device_type, typename T, typename I>
Expand Down
16 changes: 16 additions & 0 deletions oneflow/user/kernels/nd_index_slice_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ struct ScatterNdAddFunctor final {
const T* slices, T* dense) const;
};

template<DeviceType device_type, typename T, typename I>
struct ScatterNdUpdateFunctor final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
const T* slices, T* dense) const;
};

template<DeviceType device_type, typename T, typename I>
struct FillByNdIndexFunctor final {
void operator()(ep::Stream* stream, const NdIndexSliceArgs<T, I>& args, const I* indices,
Expand Down Expand Up @@ -101,6 +107,16 @@ OF_DEVICE_FUNC void DoScatterNdAdd(int64_t elem_cnt, int64_t slice_size, int64_t
}
}

template<DeviceType device_type, typename T, typename I>
OF_DEVICE_FUNC void DoScatterNdUpdate(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,
const int64_t* dense_shape, const I* indices, const T* slices,
T* dense) {
XPU_1D_KERNEL_LOOP(i, elem_cnt) {
int64_t offset = OffsetInSliceToOffsetInDense(slice_size, index_ndims, dense_shape, indices, i);
dense[offset] = slices[i];
}
}

template<typename T, typename I>
OF_DEVICE_FUNC void DoFillByNdIndex(int64_t elem_cnt, int64_t slice_size, int64_t index_ndims,
const int64_t* dense_shape, const I* indices, T* dense,
Expand Down
13 changes: 13 additions & 0 deletions python/oneflow/test/tensor/test_tensor_part_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ def compare_setitem_with_numpy(tensor, slices, value):
x = flow.Tensor(2, 3, 4)
compare_setitem_with_numpy(x, se[1, :, 2], v)

@flow.unittest.skip_unless_1n1d()
@autotest(check_graph=True)
def test_setitem_with_random_data(test_case):
device = random_device()
x = random_tensor(low=0, high=0, ndim=1, dim0=16).to(device)
y = random_tensor(low=-2, high=2, ndim=1, dim0=16).to(device)
idx = random_tensor(
low=0, high=15, ndim=1, dim0=20, dtype=int, requires_grad=False
).to(device)
z = y[idx]
x[idx] = z
return x

@flow.unittest.skip_unless_1n1d()
def test_div(test_case):
x = flow.Tensor(np.random.randn(1, 1))
Expand Down