move batchnorm tests to nnvm interface.
zheng-da committed Feb 1, 2018
1 parent d79067b commit 9c1745d
Showing 2 changed files with 120 additions and 62 deletions.
50 changes: 49 additions & 1 deletion tests/cpp/include/test_core_op.h
@@ -56,7 +56,7 @@ inline const char *TimingDirectionAsString(const TimingDirection td) {
* Low-noise operator executor
* @tparam DType Data type for the operator executions
*/
template<typename DType>
template<typename DType, typename AccReal = float>
class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
, public test::op::OperatorExecutorTiming {
/*! \brief Performance timing categories */
@@ -224,7 +224,43 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
}

public:

enum BlobVectorType {
kInput,
kOutput,
kAux,
kInGrad,
kOutGrad,
kBlobVectorTypeCount
};

#define CASE_STR(__v$) case (__v$): return #__v$

/*! \brief Convert BlobVectorType enum into a string */
static inline const char *bvt2String(const BlobVectorType bvt) {
switch (bvt) {
CASE_STR(kInput);
CASE_STR(kOutput);
CASE_STR(kAux);
CASE_STR(kInGrad);
CASE_STR(kOutGrad);
default:
CHECK(false);
return "";
}
}
#undef CASE_STR
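// Illustrative usage sketch (not part of this commit): bvt2String exists for
// readable test diagnostics, e.g.
//   LOG(INFO) << "comparing " << bvt2String(kInGrad) << " blob #" << idx;
// where idx is whatever blob index the comparison is currently visiting.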

inline const std::vector<TBlob>& getBlobVect(const BlobVectorType bvt) const {
// Not implemented
CHECK(false);
static std::vector<TBlob> dummy;
return dummy;
}
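// A sketch of the eventual dispatch (an assumption about the intended design,
// not implemented in this commit), built on the blob vectors exposed below:
//   switch (bvt) {
//     case kInput:  return blob_inputs_;
//     case kOutput: return blob_outputs_;
//     default:      LOG(FATAL) << "unsupported vector " << bvt2String(bvt);
//   }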


typedef DType DataType;
typedef AccReal AccRealType;

/*! \brief Add 'fwd_op_name' to kwargs and return the new kwargs */
static kwargs_t ArgsWithOpName(const kwargs_t& args,
@@ -519,13 +555,17 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
*/
std::vector<NDArray>& inputs() { return inputs_; }
const std::vector<NDArray>& inputs() const { return inputs_; }
std::vector<TBlob>& input_blobs() { return blob_inputs_; }
const std::vector<TBlob>& input_blobs() const { return blob_inputs_; }

/*!
* \brief Access input NDArray vector
* \return reference to NDArray vector of forward outputs
*/
std::vector<NDArray>& outputs() { return outputs_; }
const std::vector<NDArray>& outputs() const { return outputs_; }
std::vector<TBlob>& output_blobs() { return blob_outputs_; }
const std::vector<TBlob>& output_blobs() const { return blob_outputs_; }

/*!
* \brief Backward inputs (i.e. output grad)
@@ -549,6 +589,14 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
verbose_ = verbose;
}

virtual void resetForward() {
CHECK(false) << "Not implemented; subclasses init forward-pass data here";
}

virtual void resetBackward() {
CHECK(false) << "Not implemented; subclasses init backward-pass data here";
}
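// Subclasses seed their test data by overriding these two hooks; see the
// resetForward()/resetBackward() overrides of BNOperatorExecutor in
// tests/cpp/operator/batchnorm_test.cc below.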

private:
/*!
* \brief Has the execution been initialized?
132 changes: 71 additions & 61 deletions tests/cpp/operator/batchnorm_test.cc
@@ -24,8 +24,6 @@
* \author Chris Olivier
*/

#if 0

#include <dmlc/logging.h>
#include <mxnet/tensor_blob.h>
#include "../../src/operator/nn/batch_norm-inl.h"
@@ -61,31 +59,30 @@ static constexpr int TIMING_DEPTH = 2;
static constexpr int TIMING_DH = 28;
static constexpr int TIMING_DW = 28;


/*! \brief BatchNorm-specific test data */
template <typename DType, typename AccReal>
class BNOperatorExecutor : public test::op::LegacyOperatorExecutor<DType, AccReal> {
class BNOperatorExecutor : public test::op::CoreOpExecutor<DType, AccReal> {
public:
BNOperatorExecutor(const bool isGPU, const TShape& inputShape,
const bool hasWeightAndBias = false)
: test::op::LegacyOperatorExecutor<DType, AccReal>(isGPU, { inputShape })
: test::op::CoreOpExecutor<DType, AccReal>(isGPU, { inputShape })
, hasWeightAndBias_(hasWeightAndBias) {
}

void resetForward() override {
// Init input data
MSHADOW_TYPE_SWITCH(
this->c_.blob_input_vec_[mxnet::op::batchnorm::kData].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kData].type_flag_,
DTypeX,
{
DTypeX val = 0;
test::patternFill<DTypeX>(&this->c_.blob_input_vec_[mxnet::op::batchnorm::kData],
test::patternFill<DTypeX>(&this->input_blobs()[mxnet::op::batchnorm::kData],
[&val]{ return val += 1; }); });
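// test::patternFill assigns successive generator values across the blob --
// roughly the following, as a simplified sketch (the real helper may fill in
// a channel-aware pattern rather than a flat walk):
//   DTypeX *p = blob->dptr<DTypeX>();
//   for (size_t i = 0, n = blob->Size(); i < n; ++i) { p[i] = getNext(); }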

MSHADOW_TYPE_SWITCH(
this->c_.blob_input_vec_[mxnet::op::batchnorm::kGamma].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
DTypeX, {
const TBlob& blob = this->c_.blob_input_vec_[mxnet::op::batchnorm::kGamma];
const TBlob& blob = this->input_blobs()[mxnet::op::batchnorm::kGamma];
test::fill(blob, DTypeX(1));
if (hasWeightAndBias_) {
if (blob.size(0) > 1) {
@@ -94,9 +91,9 @@ class BNOperatorExecutor : public test::op::LegacyOperatorExecutor<DType, AccRea
}
});
MSHADOW_TYPE_SWITCH(
this->c_.blob_input_vec_[mxnet::op::batchnorm::kBeta].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
DTypeX, {
const TBlob& blob = this->c_.blob_input_vec_[mxnet::op::batchnorm::kBeta];
const TBlob& blob = this->input_blobs()[mxnet::op::batchnorm::kBeta];
if (!hasWeightAndBias_) {
test::fill(blob, DTypeX(0));
} else { // This will cause forward pass check to fail when calculating sum == 0
@@ -109,67 +106,67 @@

// Init the moving data (all mean = 0, all var = 1)
MSHADOW_TYPE_SWITCH(
this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingMean].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kMovingMean].type_flag_,
DTypeX, {
test::fill(this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingMean], DTypeX(0));
test::fill(this->input_blobs()[mxnet::op::batchnorm::kMovingMean], DTypeX(0));
});
MSHADOW_TYPE_SWITCH(
this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingVar].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kMovingVar].type_flag_,
DTypeX, {
test::fill(this->c_.blob_aux_states_[mxnet::op::batchnorm::kMovingVar], DTypeX(1));});
test::fill(this->input_blobs()[mxnet::op::batchnorm::kMovingVar], DTypeX(1));});

for (size_t i = 0, n = this->c_.blob_output_vec_.size(); i < n; ++i) {
const int dtype = this->c_.blob_output_vec_[i].type_flag_;
for (size_t i = 0, n = this->output_blobs().size(); i < n; ++i) {
const int dtype = this->output_blobs()[i].type_flag_;
MSHADOW_TYPE_SWITCH(dtype, DTypeX,
{ test::fill(this->c_.blob_output_vec_[i], DTypeX(0.1234)); });
{ test::fill(this->output_blobs()[i], DTypeX(0.1234)); });
}
}

void resetBackward() override {
DType val = -.001;
MSHADOW_TYPE_SWITCH(
this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut].type_flag_,
this->output_blobs()[mxnet::op::batchnorm::kOut].type_flag_,
DTypeX, {
test::patternFill<DTypeX>(&this->c_.blob_out_grad_[mxnet::op::batchnorm::kOut],
test::patternFill<DTypeX>(&this->output_blobs()[mxnet::op::batchnorm::kOut],
[&val]{ return val += 1; });
});

// out-grad weights
if (mxnet::op::batchnorm::kGamma < this->c_.blob_out_grad_.size()) {
if (mxnet::op::batchnorm::kGamma < this->output_blobs().size()) {
MSHADOW_TYPE_SWITCH(
this->c_.blob_out_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
this->output_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
DTypeX,
{ test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kGamma, DTypeX(0.1)); });
{ test::try_fill(this->output_blobs(), mxnet::op::batchnorm::kGamma, DTypeX(0.1)); });
}

// out-grad biases
if (mxnet::op::batchnorm::kBeta < this->c_.blob_out_grad_.size()) {
if (mxnet::op::batchnorm::kBeta < this->output_blobs().size()) {
MSHADOW_TYPE_SWITCH(
this->c_.blob_out_grad_[mxnet::op::batchnorm::kBeta].type_flag_,
this->output_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
DTypeX,
{ test::try_fill(this->c_.blob_out_grad_, mxnet::op::batchnorm::kBeta, DTypeX(0.1)); });
{ test::try_fill(this->output_blobs(), mxnet::op::batchnorm::kBeta, DTypeX(0.1)); });
}

// in-grad
MSHADOW_TYPE_SWITCH(
this->c_.blob_in_grad_[mxnet::op::batchnorm::kData].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kData].type_flag_,
DTypeX,
{ test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kData, DTypeX(0)); });
{ test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kData, DTypeX(0)); });

// in-grad weights
if (mxnet::op::batchnorm::kGamma < this->c_.blob_in_grad_.size()) {
if (mxnet::op::batchnorm::kGamma < this->input_blobs().size()) {
MSHADOW_TYPE_SWITCH(
this->c_.blob_in_grad_[mxnet::op::batchnorm::kGamma].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kGamma].type_flag_,
DTypeX,
{ test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kGamma, DTypeX(0)); });
{ test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kGamma, DTypeX(0)); });
}

// in-grad biases
if (mxnet::op::batchnorm::kBeta < this->c_.blob_in_grad_.size()) {
if (mxnet::op::batchnorm::kBeta < this->input_blobs().size()) {
MSHADOW_TYPE_SWITCH(
this->c_.blob_in_grad_[mxnet::op::batchnorm::kBeta].type_flag_,
this->input_blobs()[mxnet::op::batchnorm::kBeta].type_flag_,
DTypeX,
{ test::try_fill(this->c_.blob_in_grad_, mxnet::op::batchnorm::kBeta, DTypeX(0)); });
{ test::try_fill(this->input_blobs(), mxnet::op::batchnorm::kBeta, DTypeX(0)); });
}
}
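// test::try_fill, as used above, tolerates absent entries -- approximately
// this, with the signature inferred from the call sites (treat as a sketch):
//   template<typename DTypeX>
//   void try_fill(const std::vector<TBlob>& vec, size_t idx, const DTypeX val) {
//     if (idx < vec.size()) test::fill(vec[idx], val);
//   }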

@@ -338,15 +335,16 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
}

public:
template <typename ExecutorType>
static inline bool compare(const ExecutorType& i1,
const ExecutorType& i2,
template <typename ExecutorType1, typename ExecutorType2>
static inline bool compare(const ExecutorType1& i1,
const ExecutorType2& i2,
const typename
test::op::LegacyOperatorExecutor<DType, AccReal>::BlobVectorType bvt,
const size_t idx, bool print = false) {
test::op::CoreOpExecutor<DType>::BlobVectorType bvt,
const size_t idx,
bool print = false) {
// Validate legacy data
auto *legacy1 = dynamic_cast<const test::op::LegacyOperatorExecutor<DType, AccReal> *>(&i1);
auto *legacy2 = dynamic_cast<const test::op::LegacyOperatorExecutor<DType, AccReal> *>(&i2);
auto *legacy1 = dynamic_cast<const test::op::CoreOpExecutor<DType> *>(&i1);
auto *legacy2 = dynamic_cast<const test::op::CoreOpExecutor<DType> *>(&i2);
CHECK_NOTNULL(legacy1);
CHECK_NOTNULL(legacy2);
const std::vector<TBlob> &bv1 = legacy1->getBlobVect(bvt);
@@ -370,7 +368,7 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
/*! \brief Check batch norm output */
template<typename BNOperatorProp>
static void validateForward(const BNOperatorProp& data) {
const TBlob& outputBlob = data.outputs()[mxnet::op::batchnorm::kData];
const TBlob& outputBlob = data.output_blobs()[mxnet::op::batchnorm::kData];
switch (outputBlob.ndim()) {
case 3:
checkBatchNorm1D(&outputBlob);
@@ -391,50 +389,50 @@ class BatchNormValidator : public test::op::Validator<DType, AccReal> {
template<typename PropType1, typename PropType2>
static void compare(
const test::op::OpInfo<PropType1, BNOperatorExecutor<DType, AccReal>>& info_1,
const test::op::OpInfo<PropType2, BNOperatorExecutor<DType, AccReal>>& info_2) {
// Input
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
test::op::CoreOpExecutor<DType>::kInput,
mxnet::op::batchnorm::kData));
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
test::op::CoreOpExecutor<DType>::kInput,
mxnet::op::batchnorm::kGamma));
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
test::op::CoreOpExecutor<DType>::kInput,
mxnet::op::batchnorm::kBeta));
// Output
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kOutput,
test::op::CoreOpExecutor<DType>::kOutput,
mxnet::op::batchnorm::kOut));
CHECK_EQ(info_2.prop_->getParam().use_global_stats,
info_1.prop_->getParam().use_global_stats);

#if MXNET_USE_CUDNN != 1 /* CUDNN takes a different approach here on first pass */
// Aux
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kAux,
test::op::CoreOpExecutor<DType>::kAux,
mxnet::op::batchnorm::kMovingMean));
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kAux,
test::op::CoreOpExecutor<DType>::kAux,
mxnet::op::batchnorm::kMovingVar));
#endif
if (!info_2.prop_->getParam().use_global_stats) {
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kOutput,
test::op::CoreOpExecutor<DType>::kOutput,
mxnet::op::batchnorm::kMean));
// InGrad
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
test::op::CoreOpExecutor<DType>::kInGrad,
mxnet::op::batchnorm::kData));
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
test::op::CoreOpExecutor<DType>::kInGrad,
mxnet::op::batchnorm::kGamma));
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInGrad,
test::op::CoreOpExecutor<DType>::kInGrad,
mxnet::op::batchnorm::kBeta));
// OutGrad
EXPECT_TRUE(compare(*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kOutGrad,
test::op::CoreOpExecutor<DType>::kOutGrad,
mxnet::op::batchnorm::kData));
}
}
@@ -627,7 +625,7 @@ static test::op::OpInfoPair<OperatorProp1, OperatorProp2, OperatorExecutor> test

BatchNormValidator<DType, AccReal>::compare(
*info_1.executor_, *info_2.executor_,
test::op::LegacyOperatorExecutor<DType, AccReal>::kInput,
test::op::CoreOpExecutor<DType>::kInput,
mxnet::op::batchnorm::kData, false);

if (!thisCount) {
@@ -648,9 +646,9 @@ static test::op::OpInfoPair<OperatorProp1, OperatorProp2, OperatorExecutor> test
BatchNormValidator<DType, AccReal>::compare(info_1, info_2);
} while (++thisCount < cycleCount);

if (dumpC) {
info_1.executor_->dumpC(&std::cerr, "BN_testForwardAndBackward");
}
// if (dumpC) {
// info_1.executor_->dumpC(&std::cerr, "BN_testForwardAndBackward");
// }

return { info_1, info_2 };
}
Expand All @@ -673,15 +671,24 @@ testForwardAndBackward(const bool isGPU,
cycleCount);
}

// NOTE: This should know which version to use (V1, MKL, etc.)
struct BatchNormCoreOpProp : public mxnet::test::op::CoreOpProp {
const mxnet::op::BatchNormParam& getParam() const {
CHECK(false); // Not implemented
static mxnet::op::BatchNormParam dummy;
return dummy;
}
};
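// A non-failing variant could parse the param from the op kwargs instead --
// a sketch that assumes the kwargs are reachable from CoreOpProp (they are
// not exposed in this commit); BatchNormParam is a dmlc::Parameter, so
// Init() accepts a list of string key/value pairs:
//   mxnet::op::BatchNormParam param_;
//   param_.Init(kwargs);  // kwargs: std::vector<std::pair<std::string, std::string>>
//   const mxnet::op::BatchNormParam& getParam() const { return param_; }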

template<typename OperatorExecutor>
static test::op::OpInfoPair<mxnet::op::BatchNormV1Prop, mxnet::op::BatchNormProp, OperatorExecutor>
static test::op::OpInfoPair<BatchNormCoreOpProp, BatchNormCoreOpProp, OperatorExecutor>
testBNForwardAndBackward2D(const bool isGPU,
const TShape &inputShape,
const test::op::kwargs_t kwargs,
const bool dumpC = false) {
CHECK_EQ(inputShape.ndim(), 4); // V1 can only handle 2D
return testForwardAndBackward<mxnet::op::BatchNormV1Prop,
mxnet::op::BatchNormProp, OperatorExecutor>(
return testForwardAndBackward<BatchNormCoreOpProp,
BatchNormCoreOpProp, OperatorExecutor>(
isGPU,
isGPU,
inputShape,
@@ -698,11 +705,14 @@ TEST(BATCH_NORM, Test2DForwardV1V2) {
DType,
AccReal,
{
// Have to somehow specify which of v1 and v2 to use
auto infoA = testBNForwardAndBackward2D<BNOperatorExecutor<DType, AccReal>>(
false, {BATCH_SIZE, CHANNELS, DH, DW}, blank_kwargs);
});
}
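// blank_kwargs runs BatchNorm with default parameters. kwargs_t is a list of
// string key/value pairs, so a variant exercising global stats might look
// like this (illustrative sketch):
//   test::op::kwargs_t kwargs = {{"use_global_stats", "true"}};
//   testBNForwardAndBackward2D<BNOperatorExecutor<DType, AccReal>>(
//       false, {BATCH_SIZE, CHANNELS, DH, DW}, kwargs);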

#if 0

static const std::vector<int> v2_types = {mshadow::kFloat32,
mshadow::kFloat64,
mshadow::kFloat16};
