From 02b50bff31af20eb2bce370cf51b8501f499de6c Mon Sep 17 00:00:00 2001 From: barry-jin Date: Mon, 10 Jan 2022 14:40:34 -0800 Subject: [PATCH] [BUGFIX] Fix #20471 --- src/operator/tensor/elemwise_binary_op.cc | 2 +- tests/python/gpu/test_operator_gpu.py | 156 ++++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/elemwise_binary_op.cc b/src/operator/tensor/elemwise_binary_op.cc index a1dcc387d9cb..7b6e662aa40e 100644 --- a/src/operator/tensor/elemwise_binary_op.cc +++ b/src/operator/tensor/elemwise_binary_op.cc @@ -279,7 +279,7 @@ void ElemwiseBinaryRTCBwdUseNone::operator()(const nnvm::NodeAttrs& attrs, (req[0] == kWriteInplace && LOP != "identity")); bool write_right_output = req[1] != kNullOp && (req[1] != kWriteInplace || - (req[1] == kWriteInplace && LOP != "identity")); + (req[1] == kWriteInplace && ROP != "identity")); const std::string code = std::string("const OpReqType lreq = ") + util::to_string(req[0]) + ";\n" diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 6592cd490dac..85d20091436c 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2212,3 +2212,159 @@ def test_split_v2_fwd(dtype): data = mx.sym.Variable("data") sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis) check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5) + + +def test_subtract_backward(): + sym_json = \ + """ + { + "nodes":[ + { + "op":"null", + "name":".Inputs.Input", + "inputs":[] + }, + { + "op":"SwapAxis", + "name":".$0", + "attrs":{ + "dim1":"0", + "dim2":"1" + }, + "inputs":[[0,0,0]] + }, + { + "op":"Reshape", + "name":".$1", + "attrs":{ + "shape":"(-3, -2)" + }, + "inputs":[[1,0,0]] + }, + { + "op":"null", + "name":".Inputs.Target", + "inputs":[] + }, + { + "op":"SwapAxis", + "name":".$2", + "attrs":{ + "dim1":"0", + "dim2":"1" + }, + "inputs":[[3,0,0]] + }, + { + "op":"Reshape", + "name":".$3", + "attrs":{ + "shape":"(-3, -2)" + }, + "inputs":[[4,0,0]] + }, + { + "op":"elemwise_sub", + "name":".$4", + "inputs":[[2,0,0],[5,0,0]] + }, + { + "op":"abs", + "name":".$5", + "inputs":[[6,0,0]] + }, + { + "op":"mean", + "name":".$6", + "attrs":{ + "axis":"0", + "exclude":"true", + "keepdims":"false" + }, + "inputs":[[7,0,0]] + }, + { + "op":"reshape_like", + "name":".$7", + "attrs":{ + "lhs_begin":"0", + "lhs_end":"1", + "rhs_begin":"0", + "rhs_end":"2" + }, + "inputs":[[8,0,0],[1,0,0]] + }, + { + "op":"null", + "name":"seq_715248120", + "inputs":[] + }, + { + "op":"SequenceMask", + "name":".$8", + "attrs":{ + "axis":"0", + "use_sequence_length":"true", + "value":"0." + }, + "inputs":[[9,0,0],[10,0,0]] + }, + { + "op":"sum", + "name":".$9", + "attrs":{ + "axis":"0", + "keepdims":"false" + }, + "inputs":[[11,0,0]] + }, + { + "op":"elemwise_div", + "name":".$10", + "inputs":[[12,0,0],[10,0,0]] + }, + { + "op":"_copy", + "name":".Outputs.Loss", + "inputs":[[13,0,0]] + }, + { + "op":"_copy", + "name":"seq_715248120$0", + "inputs":[[10,0,0]] + }], + "arg_nodes":[0,3,10], + "node_row_ptr":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16], + "heads":[[14,0,0],[15,0,0]], + "attrs":{ + "mxnet_version":["int",20000] + } + } + """ + + sym = mx.sym.fromjson(sym_json) + + def run_example(ctx, reqs): + ex = sym._bind( + ctx, + { + '.Inputs.Input': mx.ndarray.array([[1, 2, 3]], ctx=ctx), + '.Inputs.Target': mx.ndarray.array([[4, 5, 6]], ctx=ctx), + 'seq_715248120': mx.ndarray.array([3], ctx=ctx) + }, + args_grad={ + '.Inputs.Input': mx.ndarray.zeros([1, 3], ctx=ctx), + '.Inputs.Target': mx.ndarray.zeros([1, 3], ctx=ctx), + 'seq_715248120': mx.ndarray.zeros([1], ctx=ctx) + }, + grad_req=dict(zip(['.Inputs.Input', '.Inputs.Target', 'seq_715248120'], reqs)) + ) + + ex.forward() + ex.backward(out_grads=[mx.ndarray.array([1], ctx=ctx), mx.ndarray.array([1], ctx=ctx)]) + + return ex.grad_dict['.Inputs.Target'] + + out1 = run_example(mx.gpu(), ['write', 'write', 'null']) + out2 = run_example(mx.gpu(), ['null', 'write', 'null']) + assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3, atol=1e-5)