Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Implement np.random.pareto backward #17607

Merged
merged 1 commit into from
Feb 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.

Parameters
Expand Down Expand Up @@ -659,13 +659,15 @@ def pareto(a, size=None):
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.pareto(a, a=None, size=size)
return _npi.pareto(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.pareto(a=a, size=size)
return _npi.pareto(a=a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _mx_nd_np.random.weibull(a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.

Parameters
Expand Down Expand Up @@ -688,7 +688,7 @@ def pareto(a, size=None):
where a is the shape and m the scale. Here m is assumed 1. The Pareto distribution
is a power law distribution. Pareto created it to describe the wealth in the economy.
"""
return _mx_nd_np.random.pareto(a, size)
return _mx_nd_np.random.pareto(a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.

Parameters
Expand Down Expand Up @@ -729,13 +729,15 @@ def pareto(a, size=None):
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.pareto(a, a=None, size=size)
return _npi.pareto(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.pareto(a=a, size=size)
return _npi.pareto(a=a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
38 changes: 35 additions & 3 deletions src/operator/numpy/random/np_pareto_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyParetoParam);

NNVM_REGISTER_OP(_npi_pareto)
.describe("Numpy behavior Pareto")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Expand All @@ -41,7 +42,11 @@ NNVM_REGISTER_OP(_npi_pareto)
}
return num_inputs;
})
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs){
return 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Expand All @@ -52,10 +57,11 @@ NNVM_REGISTER_OP(_npi_pareto)
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyParetoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyParetoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyParetoParam>)
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
(*out_attrs)[0] = mshadow::kFloat32;
(*out_attrs)[1] = mshadow::kFloat32;
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
Expand All @@ -64,9 +70,35 @@ NNVM_REGISTER_OP(_npi_pareto)
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyParetoForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_pareto"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyParetoParam::__FIELDS__());

// Backward operator used as the FGradient target of _npi_pareto.
// Its arity depends on whether the shape parameter `a` was supplied as a
// scalar attribute (param.a has a value) or as a tensor input.
NNVM_REGISTER_OP(_backward_broadcast_pareto)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyParetoParam>)
.set_num_inputs(
  [](const nnvm::NodeAttrs& attrs) {
    // Tensor `a`: five inputs (two output gradients, the input tensor,
    // samples and noise). Scalar `a`: the input tensor is absent, so four.
    const NumpyParetoParam& parsed = nnvm::get<NumpyParetoParam>(attrs.parsed);
    return parsed.a.has_value() ? 4 : 5;
  })
.set_num_outputs(
  [](const nnvm::NodeAttrs& attrs) {
    // A gradient is only produced for a tensor-valued `a`.
    const NumpyParetoParam& parsed = nnvm::get<NumpyParetoParam>(attrs.parsed);
    return parsed.a.has_value() ? 0 : 1;
  })
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    // Temp space is requested for the reduction workspace in the backward impl.
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", ParetoReparamBackward<cpu>)
.add_arguments(NumpyParetoParam::__FIELDS__());

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_pareto_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ namespace op {
// Attach the GPU compute function to the forward Pareto sampling operator.
NNVM_REGISTER_OP(_npi_pareto)
.set_attr<FCompute>("FCompute<gpu>", NumpyParetoForward<gpu>);

// Attach the GPU compute function to the matching backward operator.
NNVM_REGISTER_OP(_backward_broadcast_pareto)
.set_attr<FCompute>("FCompute<gpu>", ParetoReparamBackward<gpu>);

} // namespace op
} // namespace mxnet
88 changes: 74 additions & 14 deletions src/operator/numpy/random/np_pareto_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace op {
struct NumpyParetoParam : public dmlc::Parameter<NumpyParetoParam> {
dmlc::optional<float> a;
dmlc::optional<mxnet::Tuple<int>> size;
std::string ctx;
DMLC_DECLARE_PARAMETER(NumpyParetoParam) {
DMLC_DECLARE_FIELD(a)
.set_default(dmlc::optional<float>());
Expand All @@ -52,22 +53,25 @@ struct NumpyParetoParam : public dmlc::Parameter<NumpyParetoParam> {
.describe("Output shape. If the given shape is, "
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe(
"Context of output, in format [cpu|gpu|cpu_pinned](n)."
" Only used for imperative calls.");
}
};

template <typename DType>
struct scalar_pareto_kernel {
MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold,
MSHADOW_XINLINE static void Map(index_t i, float a, float *noise,
DType *out) {
out[i] = exp(-log(threshold[i])/a) - DType(1);
out[i] = exp(-log(noise[i])/a) - DType(1);
}
};

namespace mxnet_op {

template <typename IType>
struct check_legal_a_kernel {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float *flag) {
if (a[i] <= 0.0) {
flag[0] = -1.0;
}
Expand All @@ -80,35 +84,37 @@ struct pareto_kernel {
MSHADOW_XINLINE static void Map(index_t i,
const Shape<ndim> &stride,
const Shape<ndim> &oshape,
IType *aparams, float* threshold, OType *out) {
IType *aparams, float *noise, OType *out) {
Shape<ndim> coord = unravel(i, oshape);
auto idx = static_cast<index_t>(dot(coord, stride));
out[i] = exp(-log(threshold[i])/aparams[idx]) - IType(1);
noise[i] = -log(noise[i]);
out[i] = exp(noise[i]/aparams[idx]) - IType(1);
// get grad
noise[i] = -noise[i] * (out[i] + 1.0) * (1.0/(aparams[idx] * aparams[idx]));
}
};

} // namespace mxnet_op

template <typename xpu>
void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
const NumpyParetoParam &param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
index_t output_len = outputs[0].Size();
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
Tensor<xpu, 1, float> uniform_tensor = outputs[1].FlatTo1D<xpu, float>(s);
Tensor<xpu, 1, float> indicator_device = workspace;
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleUniform(&workspace, 0.0, 1.0);
prnd->SampleUniform(&uniform_tensor, 0.0, 1.0);
if (param.a.has_value()) {
CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0";
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Expand Down Expand Up @@ -140,6 +146,60 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
}
}

/*!
 * \brief Reduce the sample gradient onto the shape of the parameter `a`.
 *
 * Expected input layout:
 *   [grad_from_samples, grad_from_noise (invisible), input_tensor, samples, noise]
 * The forward tensor kernel (pareto_kernel) overwrites the noise buffer with
 * the per-element derivative of the sample w.r.t. `a`, so the reduction below
 * combines the incoming gradient with that buffer.
 *
 * \param ctx        operator context (stream and requested temp space)
 * \param inputs     backward inputs in the layout listed above
 * \param req        write request for the single output gradient
 * \param outputs    outputs[0] receives the gradient w.r.t. `a`
 * \param new_ishape shape used to view `a` and its gradient
 * \param new_oshape shape used to view the sample-shaped blobs
 */
template<typename xpu, int ndim, typename DType>
inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx,
                                            const std::vector<TBlob>& inputs,
                                            const std::vector<OpReqType>& req,
                                            const std::vector<TBlob>& outputs,
                                            const mxnet::TShape& new_ishape,
                                            const mxnet::TShape& new_oshape) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace broadcast;
  Stream<xpu>* stream = ctx.get_stream<xpu>();

  // Gradient destination, viewed with the input shape.
  const TBlob a_grad = outputs[0].reshape(new_ishape);
  // Sample-shaped views of the backward inputs.
  const TBlob out_grad = inputs[0].reshape(new_oshape);
  // The next two views are not consumed below; they are kept so the reshape
  // calls still run exactly as before.
  // NOTE(review): presumably reshape validates element counts - confirm.
  const TBlob a_tensor = inputs[2].reshape(new_ishape);
  const TBlob drawn = inputs[3].reshape(new_oshape);
  const TBlob noise_grad = inputs[4].reshape(new_oshape);

  // Scratch buffer sized by the reduction helper.
  const size_t scratch_bytes =
    ReduceWorkspaceSize<ndim, DType>(stream, a_grad.shape_, req[0], out_grad.shape_);
  Tensor<xpu, 1, char> scratch =
    ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(scratch_bytes), stream);

  // Sum-reduce into a_grad; with the (mul, left) operator pair this reads as
  // accumulating out_grad * noise_grad.
  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
    stream, a_grad, req[0], scratch, out_grad, noise_grad, noise_grad);
}

/*!
 * \brief Backward entry point for the reparameterized Pareto sampler.
 *
 * Does nothing when the incoming gradient is empty, when there is no output
 * to write (scalar `a`), or when the input count is not the five-blob tensor
 * layout. Otherwise it computes compact broadcast shapes and dispatches to
 * ScalarParetoReparamBackwardImpl on dtype and dimension count.
 *
 * \param attrs   node attributes (not read here)
 * \param ctx     operator context forwarded to the implementation
 * \param inputs  backward inputs; inputs[0] is the gradient from the samples
 * \param reqs    write requests for the outputs
 * \param outputs outputs[0] receives the gradient w.r.t. `a` (tensor case)
 */
template<typename xpu>
void ParetoReparamBackward(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& reqs,
                           const std::vector<TBlob>& outputs) {
  // Zero-size tensors: skip the kernel launch entirely.
  if (inputs[0].shape_.Size() == 0U) return;
  // [scalar] case: no gradient output exists.
  if (outputs.size() == 0U) return;
  // Only the [tensor] case (five inputs) is handled below.
  if (inputs.size() != 5U) return;

  mxnet::TShape compact_ishape, compact_oshape;
  int bcast_ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
                             &compact_ishape, &compact_ishape, &compact_oshape);
  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    BROADCAST_NDIM_SWITCH(bcast_ndim, NDim, {
      ScalarParetoReparamBackwardImpl<xpu, NDim, DType>(
        ctx, inputs, reqs, outputs, compact_ishape, compact_oshape);
    });
  });
}

} // namespace op
} // namespace mxnet

Expand Down
37 changes: 35 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,14 +4035,13 @@ def hybrid_forward(self, F, a):
expected_shape = a.shape
assert mx_out.shape == expected_shape

# test illegal parameter values (as numpy produces)
# test illegal parameter values
def _test_exception(a):
output = op(a=a).asnumpy()
for op in op_names:
op = getattr(np.random, op_name, None)
if op is not None:
assertRaises(ValueError, _test_exception, -1)
if op in ['pareto', 'power']:
assertRaises(ValueError, _test_exception, 0)


Expand Down Expand Up @@ -4079,6 +4078,40 @@ def hybrid_forward(self, F, a):
assert_almost_equal(a.grad.asnumpy().sum(), formula_grad.asnumpy().sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_pareto_grad():
    """Check the gradient of np.random.pareto w.r.t. its shape parameter.

    Samples are drawn with ``a`` set to all ones, in both imperative and
    hybridized mode, and the summed autograd gradient is compared with the
    summed closed-form gradient evaluated at a = 1.
    """
    class TestRandomP(HybridBlock):
        """Block that draws Pareto samples of a fixed output shape."""

        def __init__(self, shape):
            super(TestRandomP, self).__init__()
            self._shape = shape

        def hybrid_forward(self, F, a):
            return F.np.random.pareto(a, self._shape)

    shapes = [
        (3, 2),
        (4, 3, 2, 2),
        (3, 4, 5)
    ]
    cases = [(use_hybrid, shape) for use_hybrid in [False, True] for shape in shapes]
    for use_hybrid, shape in cases:
        block = TestRandomP(shape)
        if use_hybrid:
            block.hybridize()
        a = np.ones(shape)
        a.attach_grad()
        with mx.autograd.record():
            samples = block(a)
        samples.backward()

        # Closed-form gradient at a = 1: -(x + 1) * log(x + 1)
        log_shifted = np.log(samples + np.ones(samples.shape))
        expected_grad = - (samples + np.ones(samples.shape)) * log_shifted
        assert a.grad.shape == shape
        # Only the sums are compared, not the individual elements.
        assert_almost_equal(a.grad.asnumpy().sum(), expected_grad.asnumpy().sum(),
                            rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_randn():
Expand Down