Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Implement np.random.pareto backward
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Roberts committed Feb 21, 2020
1 parent b6b1de0 commit 478705f
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 27 deletions.
8 changes: 5 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.
Parameters
Expand Down Expand Up @@ -659,13 +659,15 @@ def pareto(a, size=None):
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.pareto(a, a=None, size=size)
return _npi.pareto(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.pareto(a=a, size=size)
return _npi.pareto(a=a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _mx_nd_np.random.weibull(a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.
Parameters
Expand Down Expand Up @@ -688,7 +688,7 @@ def pareto(a, size=None):
where a is the shape and m the scale. Here m is assumed 1. The Pareto distribution
is a power law distribution. Pareto created it to describe the wealth in the economy.
"""
return _mx_nd_np.random.pareto(a, size)
return _mx_nd_np.random.pareto(a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ def weibull(a, size=None, ctx=None, out=None):
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
def pareto(a, size=None, ctx=None, out=None):
r"""Draw samples from a Pareto II or Lomax distribution with specified shape a.
Parameters
Expand Down Expand Up @@ -729,13 +729,15 @@ def pareto(a, size=None):
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.pareto(a, a=None, size=size)
return _npi.pareto(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.pareto(a=a, size=size)
return _npi.pareto(a=a, size=size, ctx=ctx, out=out)


def power(a, size=None):
Expand Down
38 changes: 35 additions & 3 deletions src/operator/numpy/random/np_pareto_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyParetoParam);

NNVM_REGISTER_OP(_npi_pareto)
.describe("Numpy behavior Pareto")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Expand All @@ -41,7 +42,11 @@ NNVM_REGISTER_OP(_npi_pareto)
}
return num_inputs;
})
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs){
return 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Expand All @@ -52,10 +57,11 @@ NNVM_REGISTER_OP(_npi_pareto)
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyParetoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyParetoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyParetoParam>)
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
(*out_attrs)[0] = mshadow::kFloat32;
(*out_attrs)[1] = mshadow::kFloat32;
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
Expand All @@ -64,9 +70,35 @@ NNVM_REGISTER_OP(_npi_pareto)
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyParetoForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_pareto"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyParetoParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_broadcast_pareto)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyParetoParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs){
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
int num_inputs = 5;
if (param.a.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs){
const NumpyParetoParam& param = nnvm::get<NumpyParetoParam>(attrs.parsed);
int num_outputs = 1;
if (param.a.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs){
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", ParetoReparamBackward<cpu>)
.add_arguments(NumpyParetoParam::__FIELDS__());

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_pareto_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_pareto)
.set_attr<FCompute>("FCompute<gpu>", NumpyParetoForward<gpu>);

NNVM_REGISTER_OP(_backward_broadcast_pareto)
.set_attr<FCompute>("FCompute<gpu>", ParetoReparamBackward<gpu>);

} // namespace op
} // namespace mxnet
88 changes: 74 additions & 14 deletions src/operator/numpy/random/np_pareto_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace op {
struct NumpyParetoParam : public dmlc::Parameter<NumpyParetoParam> {
dmlc::optional<float> a;
dmlc::optional<mxnet::Tuple<int>> size;
std::string ctx;
DMLC_DECLARE_PARAMETER(NumpyParetoParam) {
DMLC_DECLARE_FIELD(a)
.set_default(dmlc::optional<float>());
Expand All @@ -52,22 +53,25 @@ struct NumpyParetoParam : public dmlc::Parameter<NumpyParetoParam> {
.describe("Output shape. If the given shape is, "
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe(
"Context of output, in format [cpu|gpu|cpu_pinned](n)."
" Only used for imperative calls.");
}
};

template <typename DType>
struct scalar_pareto_kernel {
MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold,
MSHADOW_XINLINE static void Map(index_t i, float a, float *noise,
DType *out) {
out[i] = exp(-log(threshold[i])/a) - DType(1);
out[i] = exp(-log(noise[i])/a) - DType(1);
}
};

namespace mxnet_op {

template <typename IType>
struct check_legal_a_kernel {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float *flag) {
if (a[i] <= 0.0) {
flag[0] = -1.0;
}
Expand All @@ -80,35 +84,37 @@ struct pareto_kernel {
MSHADOW_XINLINE static void Map(index_t i,
const Shape<ndim> &stride,
const Shape<ndim> &oshape,
IType *aparams, float* threshold, OType *out) {
IType *aparams, float *noise, OType *out) {
Shape<ndim> coord = unravel(i, oshape);
auto idx = static_cast<index_t>(dot(coord, stride));
out[i] = exp(-log(threshold[i])/aparams[idx]) - IType(1);
noise[i] = -log(noise[i]);
out[i] = exp(noise[i]/aparams[idx]) - IType(1);
// get grad
noise[i] = -noise[i] * (out[i] + 1.0) * (1.0/(aparams[idx] * aparams[idx]));
}
};

} // namespace mxnet_op

template <typename xpu>
void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
const NumpyParetoParam &param = nnvm::get<NumpyParetoParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
index_t output_len = outputs[0].Size();
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
Tensor<xpu, 1, float> uniform_tensor = outputs[1].FlatTo1D<xpu, float>(s);
Tensor<xpu, 1, float> indicator_device = workspace;
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleUniform(&workspace, 0.0, 1.0);
prnd->SampleUniform(&uniform_tensor, 0.0, 1.0);
if (param.a.has_value()) {
CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0";
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Expand Down Expand Up @@ -140,6 +146,60 @@ void NumpyParetoForward(const nnvm::NodeAttrs &attrs,
}
}

template<typename xpu, int ndim, typename DType>
inline void ScalarParetoReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
// samples, noise]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob itensor = inputs[2].reshape(new_ishape);
const TBlob samples = inputs[3].reshape(new_oshape);
const TBlob noise = inputs[4].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, igrad, req[0], workspace, ograd, noise, noise);
}

template<typename xpu>
void ParetoReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar] case
if (outputs.size() == 0U) {
return;
}
// [tensor] case
if (inputs.size() == 5U) {
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
ScalarParetoReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, reqs, outputs, new_ishape, new_oshape);
});
});
}
}

} // namespace op
} // namespace mxnet

Expand Down
37 changes: 35 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4035,14 +4035,13 @@ def hybrid_forward(self, F, a):
expected_shape = a.shape
assert mx_out.shape == expected_shape

# test illegal parameter values (as numpy produces)
# test illegal parameter values
def _test_exception(a):
output = op(a=a).asnumpy()
for op in op_names:
op = getattr(np.random, op_name, None)
if op is not None:
assertRaises(ValueError, _test_exception, -1)
if op in ['pareto', 'power']:
assertRaises(ValueError, _test_exception, 0)


Expand Down Expand Up @@ -4079,6 +4078,40 @@ def hybrid_forward(self, F, a):
assert_almost_equal(a.grad.asnumpy().sum(), formula_grad.asnumpy().sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_pareto_grad():
class TestRandomP(HybridBlock):
def __init__(self, shape):
super(TestRandomP, self).__init__()
self._shape = shape

def hybrid_forward(self, F, a):
return F.np.random.pareto(a, self._shape)

output_shapes = [
(3, 2),
(4, 3, 2, 2),
(3, 4, 5)
]
for hybridize in [False, True]:
for out_shape in output_shapes:
test_w_grad = TestRandomP(out_shape)
if hybridize:
test_w_grad.hybridize()
a = np.ones(out_shape)
a.attach_grad()
with mx.autograd.record():
mx_out = test_w_grad(a)
mx_out.backward()

# gradient formula from calculus (a=1)
noise = np.log(mx_out + np.ones(mx_out.shape))
formula_grad = - (mx_out + np.ones(mx_out.shape)) * noise
assert a.grad.shape == out_shape
assert_almost_equal(a.grad.asnumpy().sum(), formula_grad.asnumpy().sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_randn():
Expand Down

0 comments on commit 478705f

Please sign in to comment.