From 2f779be3ef06476aee101ed713b55a512e991f40 Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 7 Feb 2018 15:49:17 -0800 Subject: [PATCH] update tests for grad_req change --- src/ngraph/ngraph_nnvm_utils.h | 1 + tests/cpp/ngraph/test_ngraph_imperative.cc | 2 +- tests/cpp/ngraph/test_ngraph_imperative.h | 2 ++ tests/cpp/ngraph/test_ngraph_utils.cc | 11 +++++++---- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/ngraph/ngraph_nnvm_utils.h b/src/ngraph/ngraph_nnvm_utils.h index b988a2d80..0e320e1b4 100644 --- a/src/ngraph/ngraph_nnvm_utils.h +++ b/src/ngraph/ngraph_nnvm_utils.h @@ -19,6 +19,7 @@ #include #include +#include #include "ngraph_sgcompiler_utils.h" namespace ngraph_bridge { diff --git a/tests/cpp/ngraph/test_ngraph_imperative.cc b/tests/cpp/ngraph/test_ngraph_imperative.cc index 7992c0a11..984ef79d6 100644 --- a/tests/cpp/ngraph/test_ngraph_imperative.cc +++ b/tests/cpp/ngraph/test_ngraph_imperative.cc @@ -82,7 +82,7 @@ TEST_F(NGRAPH_IMPERATIVE, INVOKE_OP) { EXPECT_TRUE(op_ng); EXPECT_TRUE(test.op_ngraph_->ngraph_forward); EXPECT_EQ(vec3, std::vector({0, 0})); - compute_forward(opctx, op_ng, inputs, outputs); + compute_forward(opctx, op_ng, inputs, req, outputs); EXPECT_EQ(vec3, std::vector({2, 6})); } diff --git a/tests/cpp/ngraph/test_ngraph_imperative.h b/tests/cpp/ngraph/test_ngraph_imperative.h index 255cc2db6..c6e39fcb4 100644 --- a/tests/cpp/ngraph/test_ngraph_imperative.h +++ b/tests/cpp/ngraph/test_ngraph_imperative.h @@ -20,6 +20,7 @@ #include "../../src/ngraph/ngraph_imperative.h" #include "../../src/ngraph/ngraph_nnvm_utils.h" #include "test_util.h" + namespace ngraph_bridge { class NGRAPH_IMPERATIVE : public ::testing::Test { @@ -39,6 +40,7 @@ class NGRAPH_IMPERATIVE : public ::testing::Test { std::vector vec3{0, 0}; std::vector inputs; std::vector outputs; + std::vector req{mxnet::kWriteTo}; }; class testImperative : public NGImperative { diff --git a/tests/cpp/ngraph/test_ngraph_utils.cc b/tests/cpp/ngraph/test_ngraph_utils.cc index 98a92dfc7..7bbbc85e9 100644 --- a/tests/cpp/ngraph/test_ngraph_utils.cc +++ b/tests/cpp/ngraph/test_ngraph_utils.cc @@ -115,17 +115,20 @@ TEST(NGRAPH_NNVM, copy_TBlobs) { /* placeholders[1]) */ /* ->get_vector()); */ std::vector vec3{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - std::vector vec4{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::vector vec4{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; mxnet::TBlob TBlob3(vec3.data(), shape, 0); mxnet::TBlob TBlob4(vec4.data(), shape, 0); std::vector outblobs; outblobs.push_back(TBlob3); outblobs.push_back(TBlob4); - result_to_TBlob(placeholders[0], outblobs, 0); - result_to_TBlob(placeholders[1], outblobs, 1); + // test 1: kWriteTo - vec3 = vec1 + // test 2: kAddTo - vec4 += vec2 + std::vector req{mxnet::kWriteTo, mxnet::kAddTo}; + result_to_TBlob(placeholders, req, outblobs); EXPECT_EQ(vec1, vec3); - EXPECT_EQ(vec2, vec4); + std::vector vec4_plus_vec2{12, 13, 14, 15, 16, 17, 18, 19, 20, 11}; + EXPECT_EQ(vec4_plus_vec2, vec4); } } // namespace ngraph_bridge