Skip to content

Commit

Permalink
quantile_scalar (apache#17572)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu authored and anirudh2290 committed May 29, 2020
1 parent 98dd56a commit f4b49ce
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 26 deletions.
10 changes: 8 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6378,8 +6378,11 @@ def percentile(a, q, axis=None, out=None, overwrite_input=None, interpolation='l
"""
if overwrite_input is not None:
raise NotImplementedError('overwrite_input is not supported yet')
if isinstance(q, numeric_types):
return _npi.percentile(a, axis=axis, interpolation=interpolation,
keepdims=keepdims, q_scalar=q, out=out)
return _npi.percentile(a, q, axis=axis, interpolation=interpolation,
keepdims=keepdims, out=out)
keepdims=keepdims, q_scalar=None, out=out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -6460,8 +6463,11 @@ def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='lin
"""
if overwrite_input is not None:
raise NotImplementedError('overwrite_input is not supported yet')
if isinstance(q, numeric_types):
return _npi.percentile(a, axis=axis, interpolation=interpolation,
keepdims=keepdims, q_scalar=q * 100, out=out)
return _npi.percentile(a, q * 100, axis=axis, interpolation=interpolation,
keepdims=keepdims, out=out)
keepdims=keepdims, q_scalar=None, out=out)


@set_module('mxnet.ndarray.numpy')
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8367,8 +8367,8 @@ def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='lin
>>> out
array([6.5, 4.5, 2.5])
"""
return _mx_nd_np.quantile(a, q, axis=axis, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims, out=out)
return _mx_nd_np.quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
interpolation=interpolation, keepdims=keepdims)


@set_module('mxnet.numpy')
Expand Down
10 changes: 8 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5747,8 +5747,11 @@ def percentile(a, q, axis=None, out=None, overwrite_input=None, interpolation='l
"""
if overwrite_input is not None:
raise NotImplementedError('overwrite_input is not supported yet')
if isinstance(q, numeric_types):
return _npi.percentile(a, axis=axis, interpolation=interpolation,
keepdims=keepdims, q_scalar=q, out=out)
return _npi.percentile(a, q, axis=axis, interpolation=interpolation,
keepdims=keepdims, out=out)
keepdims=keepdims, q_scalar=None, out=out)


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -5805,8 +5808,11 @@ def quantile(a, q, axis=None, out=None, overwrite_input=None, interpolation='lin
"""
if overwrite_input is not None:
raise NotImplementedError('overwrite_input is not supported yet')
if isinstance(q, numeric_types):
return _npi.percentile(a, axis=axis, interpolation=interpolation,
keepdims=keepdims, q_scalar=q * 100, out=out)
return _npi.percentile(a, q * 100, axis=axis, interpolation=interpolation,
keepdims=keepdims, out=out)
keepdims=keepdims, q_scalar=None, out=out)


@set_module('mxnet.symbol.numpy')
Expand Down
25 changes: 21 additions & 4 deletions src/operator/numpy/np_percentile_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct NumpyPercentileParam : public dmlc::Parameter<NumpyPercentileParam> {
dmlc::optional<mxnet::Tuple<int>> axis;
int interpolation;
bool keepdims;
dmlc::optional<double> q_scalar;
DMLC_DECLARE_PARAMETER(NumpyPercentileParam) {
DMLC_DECLARE_FIELD(axis)
.set_default(dmlc::optional<mxnet::Tuple<int>>())
Expand All @@ -61,6 +62,8 @@ struct NumpyPercentileParam : public dmlc::Parameter<NumpyPercentileParam> {
DMLC_DECLARE_FIELD(keepdims).set_default(false)
.describe("If this is set to `True`, the reduced axes are left "
"in the result as dimension with size one.");
DMLC_DECLARE_FIELD(q_scalar).set_default(dmlc::optional<double>())
.describe("inqut q is a scalar");
}
};

Expand Down Expand Up @@ -133,22 +136,22 @@ void NumpyPercentileForward(const nnvm::NodeAttrs& attrs,
if (req[0] == kNullOp) return;
using namespace mxnet;
using namespace mxnet_op;
CHECK_EQ(inputs.size(), 2U);
CHECK_GE(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);

Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob &data = inputs[0];
const TBlob &percentile = inputs[1];
const TBlob &out = outputs[0];
const NumpyPercentileParam& param = nnvm::get<NumpyPercentileParam>(attrs.parsed);
const int interpolation = param.interpolation;
dmlc::optional<mxnet::Tuple<int>> axis = param.axis;
dmlc::optional<double> q_scalar = param.q_scalar;

auto small = NumpyReduceAxesShapeImpl(data.shape_, axis, false);

TShape r_shape;
r_shape = TShape(small.ndim()+1, 1);
r_shape[0] = percentile.Size();
r_shape[0] = q_scalar.has_value()? 1 : inputs[1].Size();
for (int i = 1; i < r_shape.ndim(); ++i) {
r_shape[i] = small[i-1];
}
Expand Down Expand Up @@ -216,14 +219,28 @@ void NumpyPercentileForward(const nnvm::NodeAttrs& attrs,
size_t temp_data_size = data.Size() * sizeof(DType);
size_t idx_size = data.Size() * sizeof(index_t);
size_t temp_mem_size = 2 * temp_data_size + idx_size;
size_t workspace_size = topk_workspace_size * 2 + temp_mem_size + 8;
size_t workspace_size = topk_workspace_size * 2 + temp_mem_size + 16;

Tensor<xpu, 1, char> temp_mem =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);

char* workspace_curr_ptr = temp_mem.dptr_;
DType* trans_ptr, *sort_ptr;
index_t* idx_ptr;
TBlob percentile;
double q;

if (q_scalar.has_value()) {
q = q_scalar.value();
Tensor<cpu, 1, double> host_q(&q, Shape1(1), ctx.get_stream<cpu>());
Tensor<xpu, 1, double> device_q(reinterpret_cast<double*>(workspace_curr_ptr),
Shape1(1), ctx.get_stream<xpu>());
mshadow::Copy(device_q, host_q, ctx.get_stream<xpu>());
percentile = TBlob(device_q.dptr_, TShape(0, 1), xpu::kDevMask);
workspace_curr_ptr += 8;
} else {
percentile = inputs[1];
} // handle input q is a scalar

char* is_valid_ptr = reinterpret_cast<char*>(workspace_curr_ptr);
MSHADOW_TYPE_SWITCH(percentile.type_flag_, QType, {
Expand Down
15 changes: 10 additions & 5 deletions src/operator/numpy/np_percentile_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ bool CheckInvalidInput(mshadow::Stream<cpu> *s, const QType *data,
inline bool NumpyPercentileShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_GE(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape qshape = in_attrs->at(1);
CHECK_LE(qshape.ndim(), 1U);
if (!shape_is_known(in_attrs->at(0))) {
return false;
}
const NumpyPercentileParam& param = nnvm::get<NumpyPercentileParam>(attrs.parsed);
mxnet::TShape shape = NumpyReduceAxesShapeImpl((*in_attrs)[0], param.axis, param.keepdims);

mxnet::TShape qshape = param.q_scalar.has_value()? mxnet::TShape(0, 1) : in_attrs->at(1);
CHECK_LE(qshape.ndim(), 1U);

if (qshape.ndim() == 0) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
} else {
Expand All @@ -67,7 +68,7 @@ inline bool NumpyPercentileShape(const nnvm::NodeAttrs& attrs,
inline bool NumpyPercentileType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_GE(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

if (in_attrs->at(0) == mshadow::kFloat64) {
Expand All @@ -81,7 +82,11 @@ inline bool NumpyPercentileType(const nnvm::NodeAttrs& attrs,
DMLC_REGISTER_PARAMETER(NumpyPercentileParam);

NNVM_REGISTER_OP(_npi_percentile)
.set_num_inputs(2)
.set_num_inputs([](const NodeAttrs& attrs) {
const NumpyPercentileParam& param =
nnvm::get<NumpyPercentileParam>(attrs.parsed);
return param.q_scalar.has_value()? 1 : 2;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyPercentileParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyPercentileShape)
Expand Down
13 changes: 9 additions & 4 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,17 @@ def _add_workload_quantile():
q2 = np.array(1)
q3 = np.array(0.5)
q4 = np.array([0, 0.75, 0.25, 0.5, 1.0])
q5 = 0.4

OpArgMngr.add_workload('quantile', x1, q1)
OpArgMngr.add_workload('quantile', x1, q2)
OpArgMngr.add_workload('quantile', x1, q3)
OpArgMngr.add_workload('quantile', x2, q4, interpolation="midpoint")
OpArgMngr.add_workload('quantile', x2, q4, interpolation="nearest")
OpArgMngr.add_workload('quantile', x2, q4, interpolation="lower")
OpArgMngr.add_workload('quantile', x2, q5, interpolation="midpoint")
OpArgMngr.add_workload('quantile', x2, q5, interpolation="nearest")
OpArgMngr.add_workload('quantile', x2, q5, interpolation="lower")


def _add_workload_percentile():
Expand All @@ -192,6 +196,7 @@ def _add_workload_percentile():
q2 = np.array(60)
x3 = np.arange(10)
q3 = np.array([25, 50, 100])
q4 = 65
x4 = np.arange(11 * 2).reshape(11, 1, 2, 1)
x5 = np.array([0, np.nan])

Expand All @@ -206,12 +211,12 @@ def _add_workload_percentile():
OpArgMngr.add_workload('percentile', x3, q3)
OpArgMngr.add_workload('percentile', x4, q2, axis=0)
OpArgMngr.add_workload('percentile', x4, q2, axis=1)
OpArgMngr.add_workload('percentile', x4, q2, axis=2)
OpArgMngr.add_workload('percentile', x4, q2, axis=3)
OpArgMngr.add_workload('percentile', x4, q4, axis=2)
OpArgMngr.add_workload('percentile', x4, q4, axis=3)
OpArgMngr.add_workload('percentile', x4, q2, axis=-1)
OpArgMngr.add_workload('percentile', x4, q2, axis=-2)
OpArgMngr.add_workload('percentile', x4, q2, axis=-3)
OpArgMngr.add_workload('percentile', x4, q2, axis=-4)
OpArgMngr.add_workload('percentile', x4, q4, axis=-3)
OpArgMngr.add_workload('percentile', x4, q4, axis=-4)
OpArgMngr.add_workload('percentile', x4, q2, axis=(1,2))
OpArgMngr.add_workload('percentile', x4, q3, axis=(-2,-1))
OpArgMngr.add_workload('percentile', x4, q2, axis=(1,2), keepdims=True)
Expand Down
22 changes: 15 additions & 7 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6908,8 +6908,8 @@ def hybrid_forward(self, F, a, q):
((2, 3, 4), (3,), (0, 2)),
((2, 3, 4), (3,), 1)
]
for hybridize, keepdims, (a_shape, q_shape, axis), interpolation, dtype in \
itertools.product(flags, flags, tensor_shapes, interpolation_options, dtypes):
for hybridize, keepdims, q_scalar, (a_shape, q_shape, axis), interpolation, dtype in \
itertools.product(flags, flags, flags, tensor_shapes, interpolation_options, dtypes):
if dtype == np.float16 and interpolation == 'linear': continue
atol = 3e-4 if dtype == np.float16 else 1e-4
rtol = 3e-2 if dtype == np.float16 else 1e-2
Expand All @@ -6923,9 +6923,13 @@ def hybrid_forward(self, F, a, q):
mx_out = test_quantile(a, q)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)


np_q = q.asnumpy()
if q_scalar and q_shape == ():
q = q.item()
np_q = q
mx_out = np.quantile(a, q, axis=axis, interpolation=interpolation, keepdims=keepdims)
np_out = _np.quantile(a.asnumpy(), q.asnumpy(), axis=axis, interpolation=interpolation, keepdims=keepdims)
np_out = _np.quantile(a.asnumpy(), np_q, axis=axis, interpolation=interpolation, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)


Expand Down Expand Up @@ -6955,8 +6959,8 @@ def hybrid_forward(self, F, a, q):
((2, 3, 4), (3,), (0, 2)),
((2, 3, 4), (3,), 1)
]
for hybridize, keepdims, (a_shape, q_shape, axis), interpolation, dtype in \
itertools.product(flags, flags, tensor_shapes, interpolation_options, dtypes):
for hybridize, keepdims, q_scalar, (a_shape, q_shape, axis), interpolation, dtype in \
itertools.product(flags, flags, flags, tensor_shapes, interpolation_options, dtypes):
if dtype == np.float16 and interpolation == 'linear': continue
atol = 3e-4 if dtype == np.float16 else 1e-4
rtol = 3e-2 if dtype == np.float16 else 1e-2
Expand All @@ -6971,8 +6975,12 @@ def hybrid_forward(self, F, a, q):
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)

np_q = q.asnumpy()
if q_scalar and q_shape == ():
q = q.item()
np_q = q
mx_out = np.percentile(a, q, axis=axis, interpolation=interpolation, keepdims=keepdims)
np_out = _np.percentile(a.asnumpy(), q.asnumpy(), axis=axis, interpolation=interpolation, keepdims=keepdims)
np_out = _np.percentile(a.asnumpy(), np_q, axis=axis, interpolation=interpolation, keepdims=keepdims)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=atol, rtol=rtol)


Expand Down

0 comments on commit f4b49ce

Please sign in to comment.