diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 570911c23568..8e725f6669b6 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -56,7 +56,7 @@ inline const char *TimingDirectionAsString(const TimingDirection td) {
  * Low-noise operator executor
  * @tparam DType Data type for the operator executions
  */
-template<typename DType>
+template<typename DType, typename AccReal>
 class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
   , public test::op::OperatorExecutorTiming {
   /*! \brief Performance timing categories */
@@ -224,7 +224,43 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
   }

  public:
+
+  enum BlobVectorType {
+    kInput,
+    kOutput,
+    kAux,
+    kInGrad,
+    kOutGrad,
+    kBlobVectorTypeCount
+  };
+
+#define CASE_STR(__v$) case (__v$): return #__v$
+
+  /*! \brief Convert BlobVectorType enum into a string */
+  static inline const char *bvt2String(const BlobVectorType bvt) {
+    switch (bvt) {
+      CASE_STR(kInput);
+      CASE_STR(kOutput);
+      CASE_STR(kAux);
+      CASE_STR(kInGrad);
+      CASE_STR(kOutGrad);
+      default:
+        CHECK(false);
+        return "";
+    }
+  }
+#undef CASE_STR
+
+  inline const std::vector<TBlob>& getBlobVect(const BlobVectorType bvt) const {
+    // Not implemented
+    CHECK(false);
+    static std::vector<TBlob> dummy;
+    return dummy;
+  }
+
+  typedef DType DataType;
+  typedef AccReal AccRealType;

   /*! \brief Add 'fwd_op_name' to kwargs and return the new kwargs */
   static kwargs_t ArgsWithOpName(const kwargs_t& args,
@@ -519,6 +555,8 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
    */
   std::vector<NDArray>& inputs() { return inputs_; }
   const std::vector<NDArray>& inputs() const { return inputs_; }
+  std::vector<TBlob>& input_blobs() { return blob_inputs_; }
+  const std::vector<TBlob>& input_blobs() const { return blob_inputs_; }

   /*!
    * \brief Access input NDArray vector
@@ -526,6 +564,8 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
    */
   std::vector<NDArray>& outputs() { return outputs_; }
   const std::vector<NDArray>& outputs() const { return outputs_; }
+  std::vector<TBlob>& output_blobs() { return blob_outputs_; }
+  const std::vector<TBlob>& output_blobs() const { return blob_outputs_; }

   /*!
    * \brief Backward inputs (i.e. output grad)
@@ -549,6 +589,14 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     verbose_ = verbose;
   }

+  virtual void resetForward() {
+    CHECK(false) << "Not implemented, generally inits forward-pass data";
+  }
+
+  virtual void resetBackward() {
+    CHECK(false) << "Not implemented, generally inits backward-pass data";
+  }
+
  private:
   /*!
    * \brief Has the execution been initialized?
diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc
index 607b9804684a..f38acf6fe6da 100644
--- a/tests/cpp/operator/batchnorm_test.cc
+++ b/tests/cpp/operator/batchnorm_test.cc
@@ -24,8 +24,6 @@
  * \author Chris Olivier
  */

-#if 0
-
 #include
 #include
 #include "../../src/operator/nn/batch_norm-inl.h"
@@ -61,31 +59,30 @@ static constexpr int TIMING_DEPTH = 2;
 static constexpr int TIMING_DH = 28;
 static constexpr int TIMING_DW = 28;

-
 /*! \brief BatchNorm-specific test data  */
 template <typename DType, typename AccReal>
-class BNOperatorExecutor : public test::op::LegacyOperatorExecutor<DType, AccReal> {
+class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
  public:
   BNOperatorExecutor(const bool isGPU, const TShape& inputShape,
                      const bool hasWeightAndBias = false)
-    : test::op::LegacyOperatorExecutor<DType, AccReal>(isGPU, { inputShape })
+    : test::op::CoreOpExecutor<DType, AccReal>(isGPU, { inputShape })
     , hasWeightAndBias_(hasWeightAndBias) {
   }

   void resetForward() override {
     // Init input data
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_input_vec_[mxnet::op::batchnorm::kData].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kData].type_flag_,
       DTypeX,
       {
         DTypeX val = 0;
-        test::patternFill(&this->c_.blob_input_vec_[mxnet::op::batchnorm::kData],
+        test::patternFill(&this->input_blobs()[mxnet::op::batchnorm::kData],
                           [&val]{ return val += 1; });
       });

     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_input_vec_[mxnet::op::batchnorm::kGamma].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
       DTypeX,
       {
-        const TBlob& blob = this->c_.blob_input_vec_[mxnet::op::batchnorm::kGamma];
+        const TBlob& blob = this->input_blobs()[mxnet::op::batchnorm::kGamma];
         test::fill(blob, DTypeX(1));
         if (hasWeightAndBias_) {
           if (blob.size(0) > 1) {
@@ -94,9 +91,9 @@ class BNOperatorExecutor : public test::op::LegacyOperatorExecutor<DType, AccReal>
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_input_vec_[mxnet::op::batchnorm::kBeta].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
       DTypeX,
       {
-        const TBlob& blob = this->c_.blob_input_vec_[mxnet::op::batchnorm::kBeta];
+        const TBlob& blob = this->input_blobs()[mxnet::op::batchnorm::kBeta];
         if (!hasWeightAndBias_) {
           test::fill(blob, DTypeX(0));
         } else {  // This will cause forward pass check to fail when calculating sum == 0
@@ -109,67 +106,67 @@ class BNOperatorExecutor : public test::op::LegacyOperatorExecutor<DType, AccReal>
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingMean].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kMovingMean].type_flag_,
       DTypeX,
       {
-        test::fill(this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingMean], DTypeX(0));
+        test::fill(this->input_blobs()[mxnet::op::batchnorm::kMovingMean], DTypeX(0));
       });

     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingVar].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kMovingVar].type_flag_,
       DTypeX,
-      { test::fill(this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingVar], DTypeX(1));});
+      { test::fill(this->input_blobs()[mxnet::op::batchnorm::kMovingVar], DTypeX(1));});

-    for (size_t i = 0, n = this->c_.blob_output_vec_.size(); i < n; ++i) {
-      const int dtype = this->c_.blob_output_vec_[i].type_flag_;
+    for (size_t i = 0, n = this->output_blobs().size(); i < n; ++i) {
+      const int dtype = this->output_blobs()[i].type_flag_;
       MSHADOW_TYPE_SWITCH(dtype, DTypeX,
-                          { test::fill(this->c_.blob_output_vec_[i], DTypeX(0.1234)); });
+                          { test::fill(this->output_blobs()[i], DTypeX(0.1234)); });
     }
   }

   void resetBackward() override {
     DType val = -.001;
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut].type_flag_,
+      this->output_blobs()[mxnet::op::batchnorm::kOut].type_flag_,
       DTypeX,
       {
-        test::patternFill(&this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut],
+        test::patternFill(&this->output_blobs()[mxnet::op::batchnorm::kOut],
                           [&val]{ return val += 1; });
       });

     // out-grad weights
-    if (mxnet::op::batchnorm::kGamma < this->c_.blob_out_grad_.size()) {
+    if (mxnet::op::batchnorm::kGamma < this->output_blobs().size()) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_out_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
+        this->output_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
         DTypeX,
-        { test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kGamma, DTypeX(0.1)); });
+        { test::try_fill(this->output_blobs(), mxnet::op::batchnorm::kGamma, DTypeX(0.1)); });
     }

     // out-grad biases
-    if (mxnet::op::batchnorm::kBeta < this->c_.blob_out_grad_.size()) {
+    if (mxnet::op::batchnorm::kBeta < this->output_blobs().size()) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_out_grad_[mxnet::op::batchnorm::kBeta].type_flag_,
+        this->output_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
         DTypeX,
-        { test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kBeta, DTypeX(0.1)); });
+        { test::try_fill(this->output_blobs(), mxnet::op::batchnorm::kBeta, DTypeX(0.1)); });
     }

     // in-grad
     MSHADOW_TYPE_SWITCH(
-      this->c_.blob_in_grad_[mxnet::op::batchnorm::kData].type_flag_,
+      this->input_blobs()[mxnet::op::batchnorm::kData].type_flag_,
       DTypeX,
-      { test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kData, DTypeX(0)); });
+      { test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kData, DTypeX(0)); });

     // in-grad weights
-    if (mxnet::op::batchnorm::kGamma < this->c_.blob_in_grad_.size()) {
+    if (mxnet::op::batchnorm::kGamma < this->input_blobs().size()) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_in_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
+        this->input_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
         DTypeX,
-        { test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kGamma, DTypeX(0)); });
+        { test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kGamma, DTypeX(0)); });
     }

     // in-grad biases
-    if (mxnet::op::batchnorm::kBeta < this->c_.blob_in_grad_.size()) {
+    if (mxnet::op::batchnorm::kBeta < this->input_blobs().size()) {
       MSHADOW_TYPE_SWITCH(
-        this->c_.blob_in_grad_[mxnet::op::batchnorm::kBeta].type_flag_,
+        this->input_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
         DTypeX,
-        { test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kBeta, DTypeX(0)); });
+        { test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kBeta, DTypeX(0)); });
     }
   }
@@ -338,15 +335,16 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
   }

  public:
-  template<typename ExecutorType>
-  static inline bool compare(const ExecutorType& i1,
-                             const ExecutorType& i2,
+  template<typename ExecutorType1, typename ExecutorType2>
+  static inline bool compare(const ExecutorType1& i1,
+                             const ExecutorType2& i2,
                              const typename
-                             test::op::LegacyOperatorExecutor<DType, AccReal>::BlobVectorType bvt,
-                             const size_t idx, bool print = false) {
+                             test::op::CoreOpExecutor<DType, AccReal>::BlobVectorType bvt,
+                             const size_t idx,
+                             bool print = false) {
     // Validate legacy data
-    auto *legacy1 = dynamic_cast<const test::op::LegacyOperatorExecutor<DType, AccReal> *>(&i1);
-    auto *legacy2 = dynamic_cast<const test::op::LegacyOperatorExecutor<DType, AccReal> *>(&i2);
+    auto *legacy1 = dynamic_cast<const test::op::CoreOpExecutor<DType, AccReal> *>(&i1);
+    auto *legacy2 = dynamic_cast<const test::op::CoreOpExecutor<DType, AccReal> *>(&i2);
     CHECK_NOTNULL(legacy1);
     CHECK_NOTNULL(legacy2);
     const std::vector<TBlob> &bv1 = legacy1->getBlobVect(bvt);
@@ -370,7 +368,7 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
   /*! \brief Check batch norm output */
   template <typename BNOperatorProp>
   static void validateForward(const BNOperatorProp& data) {
-    const TBlob& outputBlob = data.outputs()[mxnet::op::batchnorm::kData];
+    const TBlob& outputBlob = data.output_blobs()[mxnet::op::batchnorm::kData];
     switch (outputBlob.ndim()) {
       case 3:
         checkBatchNorm1D(&outputBlob);
@@ -391,20 +389,20 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
   template
   static void compare(
     const test::op::OpInfo>& info_1,
-    const test::op::OpInfo>& info_2) {
+    const test::op::OpInfo>& info_2) {
     // Input
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
+                        test::op::CoreOpExecutor<DType, AccReal>::kInput,
                         mxnet::op::batchnorm::kData));
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
+                        test::op::CoreOpExecutor<DType, AccReal>::kInput,
                         mxnet::op::batchnorm::kGamma));
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
+                        test::op::CoreOpExecutor<DType, AccReal>::kInput,
                         mxnet::op::batchnorm::kBeta));
     // Output
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kOutput,
+                        test::op::CoreOpExecutor<DType, AccReal>::kOutput,
                         mxnet::op::batchnorm::kOut));
     CHECK_EQ(info_2.prop_->getParam().use_global_stats,
              info_1.prop_->getParam().use_global_stats);
@@ -412,29 +410,29 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
 #if MXNET_USE_CUDNN != 1  /* CUDNN takes a different approach here on first pass */
     // Aux
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kAux,
+                        test::op::CoreOpExecutor<DType, AccReal>::kAux,
                         mxnet::op::batchnorm::kMovingMean));
     EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                        test::op::LegacyOperatorExecutor<DType, AccReal>::kAux,
+                        test::op::CoreOpExecutor<DType, AccReal>::kAux,
                         mxnet::op::batchnorm::kMovingVar));
 #endif
     if (!info_2.prop_->getParam().use_global_stats) {
       EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                          test::op::LegacyOperatorExecutor<DType, AccReal>::kOutput,
+                          test::op::CoreOpExecutor<DType, AccReal>::kOutput,
                           mxnet::op::batchnorm::kMean));
       // InGrad
       EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                          test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
+                          test::op::CoreOpExecutor<DType, AccReal>::kInGrad,
                           mxnet::op::batchnorm::kData));
       EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                          test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
+                          test::op::CoreOpExecutor<DType, AccReal>::kInGrad,
                           mxnet::op::batchnorm::kGamma));
       EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                          test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
+                          test::op::CoreOpExecutor<DType, AccReal>::kInGrad,
                           mxnet::op::batchnorm::kBeta));
       // OutGrad
       EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
-                          test::op::LegacyOperatorExecutor<DType, AccReal>::kOutGrad,
+                          test::op::CoreOpExecutor<DType, AccReal>::kOutGrad,
                           mxnet::op::batchnorm::kData));
     }
   }
@@ -627,7 +625,7 @@ static test::op::OpInfoPair test
     BatchNormValidator::compare(
       *info_1.executor_, *info_2.executor_,
-      test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
+      test::op::CoreOpExecutor<DType, AccReal>::kInput,
       mxnet::op::batchnorm::kData, false);
     if (!thisCount) {
@@ -648,9 +646,9 @@ static test::op::OpInfoPair test
     BatchNormValidator::compare(info_1, info_2);
   } while (++thisCount < cycleCount);

-  if (dumpC) {
-    info_1.executor_->dumpC(&std::cerr, "BN_testForwardAndBackward");
-  }
+//  if (dumpC) {
+//    info_1.executor_->dumpC(&std::cerr, "BN_testForwardAndBackward");
+//  }

   return  { info_1, info_2 };
 }
@@ -673,15 +671,24 @@ testForwardAndBackward(const bool isGPU,
                        cycleCount);
 }

+// NOTE: This should know which version to use (V1, mkl, etc)
+struct BatchNormCoreOpProp : public mxnet::test::op::CoreOpProp {
+  const mxnet::op::BatchNormParam& getParam() const {
+    CHECK(false);  // Not implemented
+    static mxnet::op::BatchNormParam dummy;
+    return dummy;
+  }
+};
+
 template
-static test::op::OpInfoPair
+static test::op::OpInfoPair
 testBNForwardAndBackward2D(const bool isGPU,
                            const TShape &inputShape,
                            const test::op::kwargs_t kwargs,
                            const bool dumpC = false) {
   CHECK_EQ(inputShape.ndim(), 4);  // V1 can only handle 2D
-  return testForwardAndBackward(
+  return testForwardAndBackward(
     isGPU,
     isGPU,
     inputShape,
@@ -698,11 +705,14 @@ TEST(BATCH_NORM, Test2DForwardV1V2) {
     DType, AccReal,
     {
+      // Have to specify somehow v1 and v2
       auto infoA = testBNForwardAndBackward2D>(
         false, {BATCH_SIZE, CHANNELS, DH, DW}, blank_kwargs);
     });
 }

+#if 0
+
 static const std::vector v2_types = {mshadow::kFloat32,
                                      mshadow::kFloat64,
                                      mshadow::kFloat16};