Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Remove nnvm namespace for FInferShape, FInferType, and FInferStorageType
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 25, 2017
1 parent bcf9676 commit 8aedf05
Show file tree
Hide file tree
Showing 28 changed files with 168 additions and 128 deletions.
40 changes: 40 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,46 @@ using FComputeEx = std::function<void (const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs)>;

/*!
* \brief Inference function of certain type.
* \tparam AttrType The type of the attribute to be infered.
* \return whether all attributes are inferred.
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using FInferShape = FInferNodeEntryAttr<TShape>;

/*!
* \brief Type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferType",
* by default set all the output types to 0.
*/
using FInferType = FInferNodeEntryAttr<int>;

/*!
* \brief Storage type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferStorageType",
* by default set all the output types to 1.
*/
using FInferStorageType = FInferNodeEntryAttr<int>;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
6 changes: 3 additions & 3 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ void SetShapeType(const nnvm::Op* op,
std::vector<NDArray>* p_ndoutputs,
int* dispatch_stype) {
std::vector<NDArray>& ndoutputs = *p_ndoutputs;
static auto& infershape = nnvm::Op::GetAttr<nnvm::FInferShape>("FInferShape");
static auto& infertype = nnvm::Op::GetAttr<nnvm::FInferType>("FInferType");
static auto& inferstorage = nnvm::Op::GetAttr<nnvm::FInferStorageType>("FInferStorageType");
static auto& infershape = nnvm::Op::GetAttr<FInferShape>("FInferShape");
static auto& infertype = nnvm::Op::GetAttr<FInferType>("FInferType");
static auto& inferstorage = nnvm::Op::GetAttr<FInferStorageType>("FInferStorageType");
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
// infer shape
std::vector<TShape>& in_shapes = ret->arg_shapes;
Expand Down
2 changes: 1 addition & 1 deletion src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,

const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<nnvm::FInferNodeEntryAttr<AttrType> >(infer_name);
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& is_backward =
Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
// gradient function, used to get node correspondence.
Expand Down
8 changes: 4 additions & 4 deletions src/io/image_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ NNVM_REGISTER_OP(_cvimresize)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(op::ParamParser<ResizeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ResizeShape)
.set_attr<nnvm::FInferType>("FInferType", op::ElemwiseType<1, 1>)
.set_attr<FInferShape>("FInferShape", ResizeShape)
.set_attr<FInferType>("FInferType", op::ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", Imresize)
.add_argument("src", "NDArray", "source image")
.add_arguments(ResizeParam::__FIELDS__());
Expand All @@ -292,8 +292,8 @@ NNVM_REGISTER_OP(_cvcopyMakeBorder)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(op::ParamParser<MakeBorderParam>)
.set_attr<nnvm::FInferShape>("FInferShape", MakeBorderShape)
.set_attr<nnvm::FInferType>("FInferType", op::ElemwiseType<1, 1>)
.set_attr<FInferShape>("FInferShape", MakeBorderShape)
.set_attr<FInferType>("FInferType", op::ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", copyMakeBorder)
.add_argument("src", "NDArray", "source image")
.add_arguments(MakeBorderParam::__FIELDS__());
Expand Down
4 changes: 2 additions & 2 deletions src/nnvm/legacy_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,8 @@ void RegisterLegacyOpProp() {
op.set_attr<nnvm::FListInputNames>("FListInputNames", OpPropListInputNames);
op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpPropListOutputNames);
op.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", OpPropNumVisibleOutputs);
op.set_attr<nnvm::FInferShape>("FInferShape", OpPropInferShape);
op.set_attr<nnvm::FInferType>("FInferType", OpPropInferType);
op.set_attr<FInferShape>("FInferShape", OpPropInferShape);
op.set_attr<FInferType>("FInferType", OpPropInferType);
op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpPropMutateInputs);
op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpPropInplaceOption);
op.set_attr<FResourceRequest>("FResourceRequest", OpPropResourceRequest);
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/dequantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ here `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`
.set_attr_parser(ParamParser<DequantizeParam>)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", DequantizeShape)
.set_attr<nnvm::FInferType>("FInferType", DequantizeType)
.set_attr<FInferShape>("FInferShape", DequantizeShape)
.set_attr<FInferType>("FInferType", DequantizeType)
.set_attr<FCompute>("FCompute<cpu>", DequantizeCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_dequantize"})
.add_argument("input", "NDArray-or-Symbol", "A ndarray/symbol of type `uint8`")
Expand Down
4 changes: 2 additions & 2 deletions src/operator/contrib/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ here `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`
.set_attr_parser(ParamParser<QuantizeParam>)
.set_num_inputs(3)
.set_num_outputs(3)
.set_attr<nnvm::FInferShape>("FInferShape", QuantizeShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizeType)
.set_attr<FInferShape>("FInferShape", QuantizeShape)
.set_attr<FInferType>("FInferType", QuantizeType)
.set_attr<FCompute>("FCompute<cpu>", QuantizeCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_quantize"})
.add_argument("input", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`")
Expand Down
4 changes: 2 additions & 2 deletions src/operator/loss_binary_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ Example::
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", SoftmaxCrossEntropyShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferShape>("FInferShape", SoftmaxCrossEntropyShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
Expand Down
6 changes: 3 additions & 3 deletions src/operator/nn/cast_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ NNVM_REGISTER_OP(cast_storage)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<CastStorageParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FInferStorageType>("FInferStorageType", CastStorageInferStorageType)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferStorageType>("FInferStorageType", CastStorageInferStorageType)
.set_attr<FCompute>("FCompute<cpu>", IdentityCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", CastStorageComputeEx<cpu>)
.add_argument("data", "NDArray-or-Symbol", "The input.")
Expand Down
20 changes: 10 additions & 10 deletions src/operator/optimizer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ update is applied only to rows whose gradient has non-zero entries.
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
Expand Down Expand Up @@ -63,8 +63,8 @@ only rows whose gradients contain non-zero entries are updated (for both weight
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2};
Expand Down Expand Up @@ -100,8 +100,8 @@ It updates the weights using::
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
Expand Down Expand Up @@ -152,8 +152,8 @@ Hinton suggests the momentum term :math:`\gamma` to be 0.9 and the learning rate
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<RMSPropParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs &attrs) {
return std::vector<uint32_t>{2};
Expand Down Expand Up @@ -191,8 +191,8 @@ to be 0.9 and the learning rate :math:`\eta` to be 0.0001.
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<RMSPropAlexParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
.set_attr<FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<5, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3, 4};
Expand Down
4 changes: 2 additions & 2 deletions src/operator/random/multisample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ DMLC_REGISTER_PARAMETER(MultiSampleParam);
[](const NodeAttrs& attrs) { \
std::vector<std::string> v = {input_name_1, input_name_2}; v.resize(num_inputs); return v; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", MultiSampleOpShape) \
.set_attr<nnvm::FInferType>("FInferType", MultiSampleOpType) \
.set_attr<FInferShape>("FInferShape", MultiSampleOpShape) \
.set_attr<FInferType>("FInferType", MultiSampleOpType) \
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) { \
return std::vector<ResourceRequest>(1, ResourceRequest::kRandom); \
}) \
Expand Down
4 changes: 2 additions & 2 deletions src/operator/random/sample_multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ Examples::
return param.get_prob ? 2U : 1U;
})
.set_attr_parser(ParamParser<SampleMultinomialParam>)
.set_attr<nnvm::FInferShape>("FInferShape", SampleMultinomialOpShape)
.set_attr<nnvm::FInferType>("FInferType", SampleMultinomialOpType)
.set_attr<FInferShape>("FInferShape", SampleMultinomialOpShape)
.set_attr<FInferType>("FInferType", SampleMultinomialOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<ResourceRequest>{
Expand Down
4 changes: 2 additions & 2 deletions src/operator/random/sample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialParam);
.set_num_inputs(0) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ParamType>) \
.set_attr<nnvm::FInferShape>("FInferShape", InitShape<ParamType>) \
.set_attr<nnvm::FInferType>("FInferType", SampleOpType<ParamType>) \
.set_attr<FInferShape>("FInferShape", InitShape<ParamType>) \
.set_attr<FInferType>("FInferType", SampleOpType<ParamType>) \
.set_attr<FResourceRequest>("FResourceRequest", SampleResource) \
.add_arguments(ParamType::__FIELDS__())

Expand Down
10 changes: 5 additions & 5 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,8 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ReduceAxisParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", ReduceAxisShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferShape>("FInferShape", ReduceAxisShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_argument("data", "NDArray-or-Symbol", "The input") \
.add_arguments(ReduceAxisParam::__FIELDS__())

Expand All @@ -673,8 +673,8 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr_parser(AxesParamParser<ReduceAxesParam>) \
.set_attr<nnvm::FInferShape>("FInferShape", ReduceAxesShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferShape>("FInferShape", ReduceAxesShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.add_argument("data", "NDArray-or-Symbol", "The input") \
.add_arguments(ReduceAxesParam::__FIELDS__())

Expand All @@ -688,7 +688,7 @@ void PickOpBackward(const nnvm::NodeAttrs& attrs,
NNVM_REGISTER_OP(name) \
.set_num_inputs(1) \
.set_num_outputs(1) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<nnvm::FGradient>("FGradient", \
[](const nnvm::NodePtr& n, \
const std::vector<nnvm::NodeEntry>& ograds) { \
Expand Down
8 changes: 4 additions & 4 deletions src/operator/tensor/broadcast_reduce_op_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ Examples::
param.keepdims = false;
attrs->parsed = param;
})
.set_attr<nnvm::FInferShape>("FInferShape", ReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferShape>("FInferShape", ReduceAxisShape)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
.add_argument("data", "NDArray-or-Symbol", "The input array");

Expand Down Expand Up @@ -131,8 +131,8 @@ Examples::
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "index"};
})
.set_attr<nnvm::FInferShape>("FInferShape", PickOpShape)
.set_attr<nnvm::FInferType>("FInferType", PickOpType)
.set_attr<FInferShape>("FInferShape", PickOpShape)
.set_attr<FInferType>("FInferType", PickOpType)
.set_attr<FCompute>("FCompute<cpu>", PickOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
Expand Down
8 changes: 4 additions & 4 deletions src/operator/tensor/broadcast_reduce_op_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ Example::
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<BroadcastAxesParam>)
.add_arguments(BroadcastAxesParam::__FIELDS__())
.set_attr<nnvm::FInferShape>("FInferShape", BroadcastAxesShape)
.set_attr<FInferShape>("FInferShape", BroadcastAxesShape)
.set_attr<FCompute>("FCompute<cpu>", BroadcastCompute<cpu>);

MXNET_OPERATOR_REGISTER_BROADCAST(broadcast_to)
Expand All @@ -192,7 +192,7 @@ So with `shape=(2,0)`, we will obtain the same result as in the above example.
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<BroadcastToParam>)
.add_arguments(BroadcastToParam::__FIELDS__())
.set_attr<nnvm::FInferShape>("FInferShape", BroadcastToShape)
.set_attr<FInferShape>("FInferShape", BroadcastToShape)
.set_attr<FCompute>("FCompute<cpu>", BroadcastCompute<cpu>);

// backward op for broadcast.
Expand All @@ -218,7 +218,7 @@ Examples::
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape",
.set_attr<FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
Expand All @@ -228,7 +228,7 @@ Examples::
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1));
return true;
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", L2NormCompute<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input");

Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/control_flow_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ NNVM_REGISTER_OP(where)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"condition", "x", "y"};
})
.set_attr<nnvm::FInferShape>("FInferShape", WhereOpShape)
.set_attr<nnvm::FInferType>("FInferType", WhereOpType)
.set_attr<FInferShape>("FInferShape", WhereOpShape)
.set_attr<FInferType>("FInferType", WhereOpType)
.set_attr<FCompute>("FCompute<cpu>", WhereOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
// Use the following lambda function instead of ElemwiseGradUseIn
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferShape>("FInferShape", BinaryBroadcastShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,8 @@ void BinaryBackwardUseInWithHalf2(const nnvm::NodeAttrs& attrs,
[](const NodeAttrs& attrs) { \
return std::vector<std::string>{"lhs", "rhs"}; \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<2, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, mshadow::op::plus>)
.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryComputeEx<cpu, mshadow::op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_add"})
.set_attr<nnvm::FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1>);
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<2, 1>);

// specialized gradient add function to do add to optimization
// this must differ from elemwise_add to prevent add to optimization in forward pass.
Expand All @@ -33,7 +33,7 @@ NNVM_REGISTER_OP(_backward_add)
mshadow_op::identity>)
.set_attr<FComputeEx>("FComputeEx<cpu>",
BinaryBackwardUseNoneEx<cpu, mshadow_op::identity, mshadow_op::identity>)
.set_attr<nnvm::FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 2>);
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 2>);

MXNET_OPERATOR_REGISTER_BINARY(_sub)
.add_alias("_minus").add_alias("_Minus")
Expand Down
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_scalar_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ void BinaryScalarBackward(const nnvm::NodeAttrs& attrs,
.set_attr_parser([](NodeAttrs* attrs) { \
attrs->parsed = std::stod(attrs->dict["scalar"]); \
}) \
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<nnvm::FInplaceOption>("FInplaceOption", \
[](const NodeAttrs& attrs){ \
return std::vector<std::pair<int, int> >{{0, 0}}; \
Expand Down
Loading

0 comments on commit 8aedf05

Please sign in to comment.