Skip to content

Commit

Permalink
implement round ste
Browse files Browse the repository at this point in the history
  • Loading branch information
Jopyth committed Sep 5, 2018
1 parent ada4ea1 commit 044f81f
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,14 @@ def round(self, *args, **kwargs):
"""
return op.round(self, *args, **kwargs)

def round_ste(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`round_ste`.
The arguments are the same as for :py:func:`round_ste`, with
this array as data.
"""
return op.round_ste(self, *args, **kwargs)

def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
Expand Down
8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,6 +2150,14 @@ def round(self, *args, **kwargs):
"""
return op.round(self, *args, **kwargs)

def round_ste(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`round_ste`.
The arguments are the same as for :py:func:`round_ste`, with
this array as data.
"""
return op.round_ste(self, *args, **kwargs)

def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
Expand Down
15 changes: 15 additions & 0 deletions smd_hpi/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,18 @@ def test_det_sign():
assert_almost_equal(exp_y, y.asnumpy())
y.backward()
assert_almost_equal(exp_grad, x.grad.asnumpy())


def test_round_ste():
npy = np.random.uniform(-10, 10, (2, 3, 4))

exp_y = np.round(npy)
exp_grad = np.ones_like(npy)

x = mx.nd.array(npy)
x.attach_grad()
with autograd.record():
y = x.round_ste()
assert_almost_equal(exp_y, y.asnumpy())
y.backward()
assert_almost_equal(exp_grad, x.grad.asnumpy())
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::det_sign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::round); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::round); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::floor); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::trunc); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rint); // NOLINT()
Expand Down
19 changes: 19 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,25 @@ The storage type of ``round`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

// round_ste
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP(round_ste, cpu, mshadow_op::round)
MXNET_ADD_SPARSE_OP_ALIAS(round_ste)
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input but with STE.
Example::
round_ste([-1.5, 1.5, -1.9, 1.9, 2.1]) = [-2., 2., -2., 2., 2.]
The storage type of ``round_ste`` output depends upon the input storage type:
- round_ste(default) = default
- round_ste(row_sparse) = row_sparse
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_round_ste"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_round_ste, unary_bwd<mshadow_op::identity_grad>);

// rint
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(rint, cpu, mshadow_op::rint)
.describe(R"code(Returns element-wise rounded value to the nearest integer of the input.
Expand Down
9 changes: 9 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ NNVM_REGISTER_OP(round)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);

// round_ste
NNVM_REGISTER_OP(round_ste)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::round>)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::ComputeEx<gpu, mshadow_op::round>);

NNVM_REGISTER_OP(_backward_round_ste)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
gpu, unary_bwd<mshadow_op::identity_grad> >);

// ceil
NNVM_REGISTER_OP(ceil)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::ceil>)
Expand Down

0 comments on commit 044f81f

Please sign in to comment.