
Commit

add handling for grad req type other than kNullOp for indices
Hao Jin committed Aug 14, 2018
1 parent 525ead9 commit a8e341b
Showing 2 changed files with 34 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/operator/tensor/indexing_op.h
@@ -1034,8 +1034,8 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
-CHECK_EQ(req[take_::kIdx], kNullOp)
-  << "take layer doesn't support gradient into index";
+CHECK_NE(req[take_::kIdx], kAddTo)
+  << "take layer doesn't support gradient of req type kAddTo to index";

const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);

@@ -1052,6 +1052,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
const TShape& arrshape = outputs[0].shape_;
const TShape& oshape = inputs[0].shape_;

+if (req[take_::kIdx] != kNullOp) {
+  mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+    s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>());
+}

const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

int idxndim = idxshape.ndim();
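The change above relaxes the old requirement that the gradient request for the index input of take be kNullOp: only kAddTo (gradient accumulation) is still rejected, and for any write request the index gradient buffer is now explicitly zero-filled with the set_zero kernel, since take has no meaningful gradient with respect to integer indices. Below is a minimal sketch of the usage pattern this enables; it is illustrative and not part of the commit, and the way the autograd engine ends up requesting a gradient for the index input (here, by making the index the output of a recorded operation) is an assumption based on the new unit test.

import mxnet as mx

data = mx.nd.random.uniform(shape=(4, 3))
data.attach_grad()
idx = mx.nd.array([0], dtype='int64')

with mx.autograd.record():
    # Because idx is itself produced by a recorded op, the autograd pass may
    # request a gradient for the index input of take as well, which is the
    # situation the new check and zero-fill handle. Before this commit that
    # request hit the hard CHECK_EQ(req[take_::kIdx], kNullOp) failure.
    idx = idx + 1
    y = data.take(idx).sum()

y.backward()
print(data.grad.asnumpy())  # row 1 receives gradient 1, the other rows stay 0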
27 changes: 27 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -3810,6 +3810,31 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
exe.backward([mx.nd.array(grad_out)])
assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in)

+def check_autograd_req():
+    row_len = 2
+    col_len = 8
+    shape = (row_len, col_len)
+    sc = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype="float32")
+    sc.attach_grad()
+    i = mx.nd.array([0], dtype="int64")
+    j = mx.nd.array([0], dtype="int64")
+    with mx.autograd.record(train_mode=True):
+        xs = []
+        for _ in range(row_len):
+            x_i = []
+            for _ in range(col_len):
+                x_ij = sc.take(i).squeeze(axis=0).take(j).squeeze(axis=0)
+                x_i.append(x_ij)
+                j = j + 1
+            i = i + 1
+            j = j - col_len  # reset j
+            xs.append(mx.nd.stack(*x_i))
+        x = mx.nd.stack(*xs)
+        x = x.sum()
+
+    x.backward()
+    assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())

for mode in ['clip', 'wrap']:
    for data_ndim in range(1, 5):
        for idx_ndim in range(1, 4):
@@ -3822,6 +3847,8 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
                    idx_shape += (np.random.randint(low=1, high=5), )
                check_output_n_grad(data_shape, idx_shape, axis, mode)

+check_autograd_req()


@with_seed()
def test_grid_generator():
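A note on the assertion in check_autograd_req: every element of sc is gathered exactly once through the nested take calls and then summed, so the gradient with respect to sc is a matrix of ones, while the gradients requested for the integer indices are zero-filled by the new code path instead of aborting. The small standalone sketch below checks only that data-gradient claim (it uses the explicit axis argument of take; the variable names are illustrative, not from the test):

import mxnet as mx
import numpy as np

sc = mx.nd.random.uniform(shape=(2, 8))
sc.attach_grad()
rows = mx.nd.array([0, 1], dtype='int64')
cols = mx.nd.arange(8, dtype='int64')

with mx.autograd.record():
    # Gather every element of sc exactly once, then sum:
    # d(sum)/d(sc) is therefore a matrix of ones.
    gathered = mx.nd.take(mx.nd.take(sc, rows, axis=0), cols, axis=1)
    total = gathered.sum()

total.backward()
assert np.allclose(sc.grad.asnumpy(), np.ones((2, 8)))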
