This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Create hashcode for operator parameters properly.

zheng-da committed Jan 31, 2018
1 parent 5b3f647 commit 3fbe716
Showing 6 changed files with 53 additions and 81 deletions.
17 changes: 13 additions & 4 deletions src/operator/nn/activation-inl.h
@@ -65,13 +65,22 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
   bool operator==(const ActivationParam& other) const {
     return this->act_type == other.act_type;
   }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
-#if MXNET_USE_MKLDNN == 1
-  uint64_t GetHash() const {
-    return act_type;
+namespace std {
+template<>
+struct hash<mxnet::op::ActivationParam> {
+  size_t operator()(const mxnet::op::ActivationParam& val) {
+    return val.act_type;
   }
-#endif
 };
+}  // namespace std
+
+namespace mxnet {
+namespace op {
 
 template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
 void ActivationForward(const OpContext &ctx, const TBlob &in_data,
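With the specialization in place, ActivationParam hashes through the standard std::hash protocol instead of a bespoke GetHash() method. A self-contained toy analogue of the pattern (ToyParam and the cache below are illustrative, not MXNet code) shows why this form pays off: the struct can now key standard unordered containers directly. Note the toy adds const to operator(), which container-held hashers need; the specializations in this commit omit it because they are only invoked directly.

    #include <cstddef>
    #include <string>
    #include <unordered_map>

    // Toy analogue of the pattern introduced in this commit: specialize
    // std::hash for a parameter struct so it can key standard containers.
    // Everything here is illustrative, not MXNet code.
    namespace toy {
    struct ToyParam {
      int act_type;
      bool operator==(const ToyParam& other) const {
        return act_type == other.act_type;
      }
    };
    }  // namespace toy

    namespace std {
    template<>
    struct hash<toy::ToyParam> {
      // const-qualified so a container-held hasher can invoke it.
      size_t operator()(const toy::ToyParam& val) const {
        return static_cast<size_t>(val.act_type);
      }
    };
    }  // namespace std

    int main() {
      // The specialization is picked up as unordered_map's default hasher.
      std::unordered_map<toy::ToyParam, std::string> primitive_cache;
      primitive_cache[toy::ToyParam{1}] = "cached relu primitive";
      return 0;
    }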
29 changes: 19 additions & 10 deletions src/operator/nn/batch_norm-inl.h
@@ -95,19 +95,28 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
            this->axis == other.axis &&
            this->cudnn_off == other.cudnn_off;
   }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
-#if MXNET_USE_MKLDNN == 1
-  uint64_t GetHash() const {
-    uint64_t hash = 0;
-    hash = hash * 2 + momentum * 10;
-    hash = hash * 2 + fix_gamma;
-    hash = hash * 2 + use_global_stats;
-    hash = hash * 2 + output_mean_var;
-    hash = hash * 2 + axis;
-    return hash;
+namespace std {
+template<>
+struct hash<mxnet::op::BatchNormParam> {
+  size_t operator()(const mxnet::op::BatchNormParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.momentum);
+    ret = dmlc::HashCombine(ret, val.fix_gamma);
+    ret = dmlc::HashCombine(ret, val.use_global_stats);
+    ret = dmlc::HashCombine(ret, val.output_mean_var);
+    ret = dmlc::HashCombine(ret, val.axis);
+    return ret;
   }
-#endif
 };
+}  // namespace std
+
+namespace mxnet {
+namespace op {
 
 static inline bool IsBNWriting(const OpReqType ort) {
   return ort == kWriteTo || ort == kWriteInplace;
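The removed GetHash() folded every field with hash = hash * 2 + field, which discards bits almost immediately. dmlc::HashCombine is a boost-style combiner; a sketch of the idea, assuming dmlc-core's definition follows the usual formula (consult dmlc-core for the authoritative version):

    #include <cstddef>
    #include <functional>

    // Boost-style combiner in the spirit of dmlc::HashCombine (a sketch;
    // the authoritative definition lives in dmlc-core). Each field is
    // hashed with std::hash and folded into the running key.
    template <typename T>
    inline size_t HashCombine(size_t key, const T& value) {
      std::hash<T> hash_func;
      // The golden-ratio constant and shifts spread each field's bits
      // across the word, so swapping two field values changes the result.
      return key ^ (hash_func(value) + 0x9e3779b9 + (key << 6) + (key >> 2));
    }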
23 changes: 0 additions & 23 deletions src/operator/nn/convolution-inl.h
@@ -118,29 +118,6 @@ struct ConvolutionParam : public dmlc::Parameter<ConvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
-#if MXNET_USE_MKLDNN == 1
-  static uint64_t ComputeHash(const TShape &shape) {
-    uint64_t hash = 0;
-    for (size_t i = 0; i < shape.ndim(); i++)
-      hash = hash * 2 + shape[i];
-    return hash;
-  }
-
-  uint64_t GetHash() const {
-    uint64_t hash = 0;
-    hash = hash * 2 + ComputeHash(kernel);
-    hash = hash * 2 + ComputeHash(stride);
-    hash = hash * 2 + ComputeHash(dilate);
-    hash = hash * 2 + ComputeHash(pad);
-    hash = hash * 2 + num_filter;
-    hash = hash * 2 + num_group;
-    hash = hash * 2 + workspace;
-    hash = hash * 2 + no_bias;
-    if (layout.has_value())
-      hash = hash * 2 + layout.value();
-    return hash;
-  }
-#endif
 };
 
 }  // namespace op
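For context on why these ad-hoc hashes are deleted wholesale: hash = hash * 2 + x is a plain shift-and-add, so distinct shapes collide trivially. A self-contained demonstration of the weakness (not part of the commit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // The removed scheme, reproduced over a plain vector: hash = hash * 2 + x
    // is just a base-2 positional sum, so many shapes fold to the same value.
    uint64_t OldComputeHash(const std::vector<uint64_t>& shape) {
      uint64_t hash = 0;
      for (size_t i = 0; i < shape.size(); i++)
        hash = hash * 2 + shape[i];
      return hash;
    }

    int main() {
      // (1, 2) -> (0*2+1)*2+2 = 4, and (2, 0) -> (0*2+2)*2+0 = 4: a collision.
      std::cout << OldComputeHash({1, 2}) << " == "
                << OldComputeHash({2, 0}) << "\n";
      return 0;
    }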
25 changes: 0 additions & 25 deletions src/operator/nn/deconvolution-inl.h
@@ -163,31 +163,6 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
            this->cudnn_off == other.cudnn_off &&
            this->layout == other.layout;
   }
-#if MXNET_USE_MKLDNN == 1
-  static uint64_t ComputeHash(const TShape &shape) {
-    uint64_t hash = 0;
-    for (size_t i = 0; i < shape.ndim(); i++)
-      hash = hash * 2 + shape[i];
-    return hash;
-  }
-
-  uint64_t GetHash() const {
-    uint64_t hash = 0;
-    hash = hash * 2 + ComputeHash(kernel);
-    hash = hash * 2 + ComputeHash(stride);
-    hash = hash * 2 + ComputeHash(dilate);
-    hash = hash * 2 + ComputeHash(pad);
-    hash = hash * 2 + ComputeHash(adj);
-    hash = hash * 2 + ComputeHash(target_shape);
-    hash = hash * 2 + num_filter;
-    hash = hash * 2 + num_group;
-    hash = hash * 2 + workspace;
-    hash = hash * 2 + no_bias;
-    if (layout.has_value())
-      hash = hash * 2 + layout.value();
-    return hash;
-  }
-#endif
 };
 
 }  // namespace op
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -386,7 +386,7 @@ class MKLDNNParamOpSign: public MKLDNNOpSignature {
 
  public:
   explicit MKLDNNParamOpSign(const ParamType &_param): MKLDNNOpSignature(
-      _param.GetHash()), param(_param) {
+      std::hash<ParamType>()(_param)), param(_param) {
   }
 
   bool operator==(const MKLDNNParamOpSign<ParamType> &sign) const {
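Note the two call steps at the new call site: std::hash<ParamType>() constructs a temporary function object, and the trailing (_param) invokes it. A minimal sketch of the same idiom:

    #include <cstddef>
    #include <functional>

    // std::hash is a function object, not a function: hashing a value means
    // constructing a (stateless) hasher and then calling it.
    inline size_t HashOf(int v) {
      return std::hash<int>()(v);  // equivalent to std::hash<int>{}(v)
    }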
38 changes: 20 additions & 18 deletions src/operator/nn/pooling-inl.h
@@ -88,28 +88,30 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
            this->global_pool == other.global_pool &&
            this->cudnn_off == other.cudnn_off;
   }
+};
+
-#if MXNET_USE_MKLDNN == 1
-  static uint64_t ComputeHash(const TShape &shape) {
-    uint64_t hash = 0;
-    for (size_t i = 0; i < shape.ndim(); i++)
-      hash = hash * 2 + shape[i];
-    return hash;
-  }
+}  // namespace op
+}  // namespace mxnet
 
-  uint64_t GetHash() const {
-    uint64_t hash = 0;
-    hash = hash * 2 + ComputeHash(kernel);
-    hash = hash * 2 + ComputeHash(stride);
-    hash = hash * 2 + ComputeHash(pad);
-    hash = hash * 2 + pool_type;
-    hash = hash * 2 + pooling_convention;
-    hash = hash * 2 + global_pool;
-    hash = hash * 2 + cudnn_off;
-    return hash;
+namespace std {
+template<>
+struct hash<mxnet::op::PoolingParam> {
+  size_t operator()(const mxnet::op::PoolingParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.kernel);
+    ret = dmlc::HashCombine(ret, val.stride);
+    ret = dmlc::HashCombine(ret, val.pad);
+    ret = dmlc::HashCombine(ret, val.pool_type);
+    ret = dmlc::HashCombine(ret, val.pooling_convention);
+    ret = dmlc::HashCombine(ret, val.global_pool);
+    ret = dmlc::HashCombine(ret, val.cudnn_off);
+    return ret;
   }
-#endif
 };
+}  // namespace std
+
+namespace mxnet {
+namespace op {
 
 /*
  * When MKLDNN is enabled, we might want 2 outputs instead of one inputs, which
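The HashCombine calls above pass whole TShape values (kernel, stride, pad), which works only because a std::hash specialization for TShape exists elsewhere in the codebase (in nnvm's tuple header, as an assumption about the exact location). A self-contained sketch of what such a shape hash has to do, i.e. the job of the removed ComputeHash minus the collisions:

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // Sketch of hashing an entire shape with a boost-style combiner.
    // Illustrative only; MXNet gets this via its TShape hash specialization.
    inline size_t Combine(size_t seed, size_t h) {
      return seed ^ (h + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    inline size_t HashShape(const std::vector<int64_t>& shape) {
      size_t ret = Combine(0, shape.size());            // mix in ndim first
      for (int64_t dim : shape)
        ret = Combine(ret, std::hash<int64_t>()(dim));  // then each dimension
      return ret;
    }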
