Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
implementation for equivalence of tf.moments
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed May 6, 2019
1 parent a722db4 commit 83e2a08
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 0 deletions.
254 changes: 254 additions & 0 deletions src/operator/nn/moments-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file moments-inl.h
* \brief Moments operator
* \author Hao Jin
*/

#ifndef MXNET_OPERATOR_NN_MOMENTS_INL_H_
#define MXNET_OPERATOR_NN_MOMENTS_INL_H_

#include <vector>
#include "../tensor/broadcast_reduce_op.h"

namespace mxnet {
namespace op {

struct MomentsParam : public dmlc::Parameter<MomentsParam> {
dmlc::optional<mxnet::TShape> axes;
bool keepdims;
DMLC_DECLARE_PARAMETER(MomentsParam) {
DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional<mxnet::TShape>())
.describe("Array of ints. Axes along which to compute mean and variance.");
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("produce moments with the same dimensionality as the input.");
}
};

inline bool MomentsShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
const MomentsParam& param = nnvm::get<MomentsParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 2U);

mxnet::TShape out_shape =
ReduceAxesShapeImpl((*in_attrs)[0], param.axes, param.keepdims, false);
if (!param.axes.has_value() || param.axes.value().ndim() == 0) {
LOG(FATAL) << "Empty axes is not supported, if you would like to do global moments, "
<< "please pass all axes to axes argument";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);
SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_shape);
return true;
}

inline bool MomentsType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 2U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(1));
return out_attrs->at(0) != -1 && out_attrs->at(1) != -1;
}

struct VarBroadcastKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
DType *out,
const DType *data,
const DType *mean,
mshadow::Shape<6> data_shape,
mshadow::Shape<6> mean_shape) {
size_t data_idx = i;
size_t mean_idx = i;
size_t data_stride = 1;
size_t mean_stride = 1;
for (int axis = 5; axis >= 0; --axis) {
size_t axis_idx = data_idx % data_shape[axis];
mean_idx -= axis_idx * data_stride;
if (mean_shape[axis] != 1) {
mean_idx += axis_idx * mean_stride;
}
data_idx /= data_shape[axis];
data_stride *= data_shape[axis];
mean_stride *= mean_shape[axis];
}
DType res = (data[i] - mean[mean_idx]);
out[i] = res * res;
}
};

template<typename xpu>
inline void MomentsForwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const dmlc::optional<mxnet::TShape>& axes,
const bool keepdims) {
using namespace mshadow;
using namespace mshadow_op;
using namespace mxnet_op;

Stream<xpu> *s = ctx.get_stream<xpu>();

const TBlob& data = inputs[0];
const TBlob& mean = outputs[0];
const TBlob& var = outputs[1];

mxnet::TShape small;
if (keepdims) {
small = outputs[0].shape_;
} else {
small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false);
}

ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(ctx, {data}, {req[0]}, {mean}, small);
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Shape<6> data_shape, mean_shape;
for (int i = 0; i < 6; ++i) {
data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
mean_shape[i] = (i < small.ndim()) ? small[i] : 1;
}
Tensor<xpu, 1, DType> temp_data =
ctx.requested[0].get_space_typed<xpu, 1, DType>(Shape1(data.shape_.Size()), s);;
Kernel<VarBroadcastKernel, xpu>::Launch(s, data.shape_.Size(), temp_data.dptr_,
data.dptr<DType>(), mean.dptr<DType>(), data_shape, mean_shape);
ReduceAxesComputeImpl<xpu, mshadow_op::sum, true, true>(
ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small);
});
}

template<typename xpu>
inline void MomentsForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow_op;
using namespace mxnet_op;

CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U);

const MomentsParam& param = nnvm::get<MomentsParam>(attrs.parsed);

MomentsForwardImpl<xpu>(ctx, inputs, req, outputs, param.axes, param.keepdims);
}

template<int req>
struct VarBackwardKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
DType *igrad,
const DType *ograd,
const DType *data,
const DType *mean,
mshadow::Shape<6> data_shape,
mshadow::Shape<6> mean_shape,
const float N,
const float ddof = 0.0f) {
size_t data_idx = i;
size_t mean_idx = i;
size_t data_stride = 1;
size_t mean_stride = 1;
for (int axis = 5; axis >= 0; --axis) {
size_t axis_idx = data_idx % data_shape[axis];
mean_idx -= axis_idx * data_stride;
if (mean_shape[axis] != 1) {
mean_idx += axis_idx * mean_stride;
}
data_idx /= data_shape[axis];
data_stride *= data_shape[axis];
mean_stride *= mean_shape[axis];
}
KERNEL_ASSIGN(igrad[i], req, ograd[mean_idx] * (data[i] - mean[mean_idx]) * 2 / (N - ddof));
}
};

template<typename xpu>
inline void MomentsBackwardImpl(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const dmlc::optional<mxnet::TShape>& axes) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace mshadow_op;
using namespace mxnet_op;

Stream<xpu> *s = ctx.get_stream<xpu>();

const TBlob& mean_grad = inputs[0];
const TBlob& var_grad = inputs[1];
const TBlob& data = inputs[2];
const TBlob& mean = inputs[3];
const TBlob& var = inputs[4];
const TBlob& data_grad = outputs[0];

mxnet::TShape small = ReduceAxesShapeImpl(data.shape_, axes, true, false);
BroadcastComputeImpl<xpu>(attrs, ctx, {mean_grad}, req, outputs, small);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Tensor<xpu, 1, DType> igrad = outputs[0].FlatTo1D<xpu, DType>(s);
igrad /= scalar<DType>(outputs[0].Size()/inputs[0].Size());
});

Shape<6> data_shape, var_shape;
float N = data_grad.Size() / var.Size();
for (int i = 0; i < 6; ++i) {
data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1;
var_shape[i] = (i < small.ndim()) ? small[i] : 1;
}
MSHADOW_TYPE_SWITCH(data_grad.type_flag_, DType, {
Kernel<VarBackwardKernel<kAddTo>, xpu>::Launch(
s, data_grad.shape_.Size(), data_grad.dptr<DType>(), var_grad.dptr<DType>(),
data.dptr<DType>(), mean.dptr<DType>(), data_shape, var_shape, N);
});
}

template<typename xpu>
inline void MomentsBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow_op;
using namespace mxnet_op;

CHECK_EQ(inputs.size(), 5U);
CHECK_EQ(outputs.size(), 1U);

const MomentsParam& param = nnvm::get<MomentsParam>(attrs.parsed);

MomentsBackwardImpl<xpu>(attrs, ctx, inputs, req, outputs, param.axes);
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_NN_MOMENTS_INL_H_
85 changes: 85 additions & 0 deletions src/operator/nn/moments.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file moments.cc
* \brief Moments operator
* \author Hao Jin
*/

#include "./moments-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(MomentsParam);

NNVM_REGISTER_OP(moments)
.describe(R"code(
Calculate the mean and variance of `data`.
The mean and variance are calculated by aggregating the contents of data across axes.
If x is 1-D and axes = [0] this is just the mean and variance of a vector.
Example:
x = [[1, 2, 3], [4, 5, 6]]
mean, var = moments(data=x, axes=[0])
mean = [2.5, 3.5, 4.5]
var = [2.25, 2.25, 2.25]
mean, var = moments(data=x, axes=[1])
mean = [2.0, 5.0]
var = [0.66666667, 0.66666667]
mean, var = moments(data=x, axis=[0, 1])
mean = [3.5]
var = [2.9166667]
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<MomentsParam>)
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<mxnet::FInferShape>("FInferShape", MomentsShape)
.set_attr<nnvm::FInferType>("FInferType", MomentsType)
.set_attr<FCompute>("FCompute<cpu>", MomentsForward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_moments"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(MomentsParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_moments)
.set_attr_parser(ParamParser<MomentsParam>)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", MomentsBackward<cpu>);

} // namespace op
} // namespace mxnet
39 changes: 39 additions & 0 deletions src/operator/nn/moments.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file moments.cu
* \brief Moments operator
* \author Hao Jin
*/

#include "./moments-inl.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(moments)
.set_attr<FCompute>("FCompute<gpu>", MomentsForward<gpu>);

NNVM_REGISTER_OP(_backward_moments)
.set_attr<FCompute>("FCompute<gpu>", MomentsBackward<gpu>);

} // namespace op
} // namespace mxnet
Loading

0 comments on commit 83e2a08

Please sign in to comment.