diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc index a500d4c2df6d..3bf3228a4b44 100644 --- a/tests/cpp/operator/mkldnn_operator_test.cc +++ b/tests/cpp/operator/mkldnn_operator_test.cc @@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in return attrs; } +OpAttrs GetBNOp() { + OpAttrs attrs; + attrs.attrs.op = Op::Get("BatchNorm"); + attrs.num_inputs = 5; + attrs.num_outputs = 3; + attrs.accept_dims.insert(4); + attrs.requests.insert(OpReqType::kWriteTo); + attrs.attrs.op->attr_parser(&attrs.attrs); + attrs.input_types = ArrayTypes::Normal | + ArrayTypes::MKLDNN; + attrs.output_types = ArrayTypes::Normal | + ArrayTypes::MKLDNN; + return attrs; +} + +OpAttrs GetBNBackwardOp() { + OpAttrs attrs; + attrs.attrs.op = Op::Get("_backward_BatchNorm"); + attrs.num_inputs = 8; + attrs.num_outputs = 3; + attrs.attrs.op->attr_parser(&attrs.attrs); + attrs.requests.insert(OpReqType::kWriteTo); + return attrs; +} + void AssertEqual(const std::vector &in_arrs, const std::vector &out_arrs, float rtol = 1e-5, float atol = 1e-8) { @@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { // If the array is a view, we shouldn't write data to it. if (in_arr.arr.IsView()) - continue; + continue; NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy"); for (int i = 0; i < forward_attrs.num_inputs; i++) @@ -735,6 +760,124 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { } } + +void TestOpExBNBackward(const OpAttrs &forward_attrs, + const OpAttrs &backwards_attrs, + const OpReqType &req, + const std::vector &inputs, + const std::vector &outputs, + const NDArrayAttrs &in_arr, + NDArrayAttrs* out_arr) { + std::vector backwards_input(backwards_attrs.num_inputs); + + std::vector backwards_buffer(backwards_attrs.num_outputs); + std::vector backwards_buffer2(backwards_attrs.num_outputs); + + std::vector backwards_outputs(backwards_attrs.num_outputs); + std::vector backwards_ex_outputs(backwards_attrs.num_outputs); + std::vector backwards_req(backwards_attrs.num_outputs); + + if (req == kWriteTo) { + backwards_input[0] = &(out_arr->arr); // output grad + backwards_input[1] = outputs[1]; // mean + backwards_input[2] = outputs[2]; // var + backwards_input[3] = inputs[0]; // data + backwards_input[4] = inputs[1]; // gamma + backwards_input[5] = inputs[2]; // beta + backwards_input[6] = inputs[3]; // moving mean + backwards_input[7] = inputs[4]; // moving var + + for (size_t i = 0; i < backwards_attrs.num_outputs; i++) { + auto tmp_output = in_arr.arr; + backwards_buffer.emplace_back(tmp_output.Copy(Context())); + backwards_buffer2.emplace_back(tmp_output.Copy(Context())); + backwards_outputs[i] = &backwards_buffer.back(); + backwards_ex_outputs[i] = &backwards_buffer2.back(); + Engine::Get()->WaitForAll(); + backwards_req[i] = kWriteTo; + } + + std::cout << "Backwards: "; + PrintVerifyMsg(*out_arr, in_arr); + Imperative::Get()->InvokeOp( + Context(), backwards_attrs.attrs, backwards_input, backwards_outputs, + backwards_req, DispatchMode::kFCompute, mxnet::OpStatePtr()); + Imperative::Get()->InvokeOp( + Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs, + backwards_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr()); + Engine::Get()->WaitForAll(); + AssertEqual(backwards_outputs, backwards_ex_outputs); + } +} + +// compares output of fcompute with fcomputex +void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) { + std::vector inputs(forward_attrs.num_inputs); + std::vector inputs2(forward_attrs.num_inputs); + std::vector inputs_buffer(forward_attrs.num_inputs); + std::vector inputs2_buffer(forward_attrs.num_inputs); + std::vector outputs(forward_attrs.num_outputs); + std::vector ex_outputs(forward_attrs.num_outputs); + std::vector req(forward_attrs.num_outputs); + + TestArrayShapes tas = GetTestArrayShapes(); + std::vector pds = tas.pds; + + std::vector in_arrs = GetTestInputArrays(forward_attrs.input_types, false); + std::vector> out_arrs(forward_attrs.num_outputs); + std::vector> ex_out_arrs(forward_attrs.num_outputs); + + if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) { + for (int i1 = 0; i1 < in_arrs.size(); i1++) { + auto in_arr = in_arrs[i1]; + + CHECK_NE(forward_attrs.accept_dims.size(), 0); + if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) == + forward_attrs.accept_dims.end()) + continue; + for (int i = 0; i < forward_attrs.num_outputs; i++) { + out_arrs[i] = + GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types); + ex_out_arrs[i] = + GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types); + } + for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { + inputs_buffer.clear(); + inputs2_buffer.clear(); + + for (int i = 0; i < forward_attrs.num_inputs; i++) { + inputs_buffer.emplace_back(in_arr.arr.Copy(Context())); + inputs2_buffer.emplace_back(in_arr.arr.Copy(Context())); + Engine::Get()->WaitForAll(); + inputs[i] = &inputs_buffer.back(); + inputs2[i] = &inputs2_buffer.back(); + } + for (int i = 0; i < forward_attrs.num_outputs; i++) { + req[i] = kWriteTo; + outputs[i] = &out_arrs[i][output_i].arr; + ex_outputs[i] = &ex_out_arrs[i][output_i].arr; + } + Imperative::Get()->set_is_training(true); + + PrintVerifyMsg(in_arr, out_arrs[0][output_i]); + Imperative::Get()->InvokeOp( + Context(), forward_attrs.attrs, inputs, outputs, req, + DispatchMode::kFCompute, mxnet::OpStatePtr()); + Imperative::Get()->InvokeOp( + Context(), forward_attrs.attrs, inputs2, ex_outputs, req, + DispatchMode::kFComputeEx, mxnet::OpStatePtr()); + Engine::Get()->WaitForAll(); + AssertEqual(outputs, ex_outputs); + + if (!backwards_attrs.requests.empty()) { + TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo, + inputs, outputs, in_arr, &out_arrs[0][output_i]); + } + } + } + } +} + // Computes second dimension of FC weight matrix based on input shape uint32_t GetFCWeightDim2(const nnvm::TShape arr) { uint32_t dim = 1; @@ -1204,4 +1347,10 @@ TEST(IMPERATIVE, DeconvOp) { } } +TEST(IMPERATIVE, BNOp) { + OpAttrs forward_attrs = GetBNOp(); + OpAttrs backwards_attrs = GetBNBackwardOp(); + TestOpExBN(forward_attrs, backwards_attrs); +} + #endif