
Mixed precision binary op backward (use in) for numpy #16791

Merged 2 commits on Nov 20, 2019
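In brief: this PR wires up the backward ("use in") pass for mixed-precision binary operators in the mxnet.numpy front end, so an expression mixing, say, an int32 tensor with a float32 tensor is differentiable with respect to its float input. A minimal usage sketch, assuming an MXNet build that includes this change (values are illustrative):

import mxnet as mx
from mxnet import autograd, npx

npx.set_np()  # enable NumPy shape and array semantics

a = mx.np.array([[1, 2], [3, 4]], dtype='int32')  # integer operand
b = mx.np.array([[0.5], [2.0]], dtype='float32')  # float operand, broadcast along axis 1
b.attach_grad()
with autograd.record():
    y = mx.np.multiply(a, b)                      # mixed int32 * float32 -> float32
y.backward()
print(b.grad)                                     # d(a*b)/db = a, sum-reduced to b's shape: [[3.], [7.]]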
29 changes: 19 additions & 10 deletions python/mxnet/util.py
@@ -35,6 +35,9 @@
'subok': True,
}

+_set_np_shape_logged = False
+_set_np_array_logged = False


def makedirs(d):
"""Create directories recursively if they don't exist. os.makedirs(exist_ok=True) is not
@@ -87,13 +90,16 @@ def set_np_shape(active):
>>> print(mx.is_np_shape())
True
"""
+    global _set_np_shape_logged
    if active:
-        import logging
-        logging.info('NumPy-shape semantics has been activated in your code. '
-                     'This is required for creating and manipulating scalar and zero-size '
-                     'tensors, which were not supported in MXNet before, as in the official '
-                     'NumPy library. Please DO NOT manually deactivate this semantics while '
-                     'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
+        if not _set_np_shape_logged:
+            import logging
+            logging.info('NumPy-shape semantics has been activated in your code. '
+                         'This is required for creating and manipulating scalar and zero-size '
+                         'tensors, which were not supported in MXNet before, as in the official '
+                         'NumPy library. Please DO NOT manually deactivate this semantics while '
+                         'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
+            _set_np_shape_logged = True
elif is_np_array():
raise ValueError('Deactivating NumPy shape semantics while NumPy array semantics is still'
' active is not allowed. Please consider calling `npx.reset_np()` to'
@@ -678,11 +684,14 @@ def _set_np_array(active):
-------
A bool value indicating the previous state of NumPy array semantics.
"""
+    global _set_np_array_logged
    if active:
-        import logging
-        logging.info('NumPy array semantics has been activated in your code. This allows you'
-                     ' to use operators from MXNet NumPy and NumPy Extension modules as well'
-                     ' as MXNet NumPy `ndarray`s.')
+        if not _set_np_array_logged:
+            import logging
+            logging.info('NumPy array semantics has been activated in your code. This allows you'
+                         ' to use operators from MXNet NumPy and NumPy Extension modules as well'
+                         ' as MXNet NumPy `ndarray`s.')
+            _set_np_array_logged = True
cur_state = is_np_array()
_NumpyArrayScope._current.value = _NumpyArrayScope(active)
return cur_state
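The util.py change above guards the two activation messages with module-level flags so each is logged once per process rather than on every call. A standalone distillation of the pattern (not MXNet code):

import logging

_activated_logged = False  # module-level flag: one message per process

def activate_semantics():
    global _activated_logged
    if not _activated_logged:
        logging.info('Semantics activated; this message is emitted only once.')
        _activated_logged = True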
17 changes: 16 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -137,7 +137,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
"FCompute<cpu>",
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
#endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+                                [](const NodeAttrs& attrs){
+                                  return std::vector<std::pair<int, int> >{{0, 1}};
+                                })
+.set_attr<FResourceRequest>("FResourceRequest",
+                            [](const NodeAttrs& attrs) {
+                              return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                            })
+.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
+                                                              mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
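For reference, the _backward_npi_broadcast_mul registration above takes three inputs (head gradient, lhs, rhs) and produces two outputs (the lhs and rhs gradients); the FInplaceOption pair {0, 1} allows the head gradient to share its buffer with the second output, and kTempSpace backs the scratch memory the kernel requests. It instantiates NumpyBinaryBackwardUseIn with mshadow_op::right and mshadow_op::left (which, reading the kernel, select the second and first operand respectively), because for broadcast multiply y = a * b the input gradients are

\[ \frac{\partial L}{\partial a} = \operatorname{reduce}_{\mathrm{bcast}(a)}\!\Big(\frac{\partial L}{\partial y} \cdot b\Big), \qquad \frac{\partial L}{\partial b} = \operatorname{reduce}_{\mathrm{bcast}(b)}\!\Big(\frac{\partial L}{\partial y} \cdot a\Big), \]

where each reduction sums over the axes that input was broadcast along.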
4 changes: 4 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -64,6 +64,10 @@ NNVM_REGISTER_OP(_npi_multiply)
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
#endif

+NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
+.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::right,
+                                                              mshadow_op::left>);

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);

104 changes: 102 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.h
@@ -25,6 +25,7 @@
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

+#include <algorithm>
#include <vector>
#include <string>

@@ -391,11 +392,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
}

template<typename xpu, typename LOP, typename ROP>
-void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
+void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U);

@@ -406,7 +409,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
return;
}

-  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  const TBlob& ograd = inputs[0];
+  const TBlob& lgrad = outputs[0];
+  const TBlob& rgrad = outputs[1];
+
+  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+    // If either input is a float, it has the same type as the output,
+    // so two of the three tensors share a data type.
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    using namespace broadcast;
+    const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
+                                                     &new_lshape, &new_rshape, &new_oshape) != 0;
+
+    // Prepare all the temporary memory
+    size_t workspace_size_l = 0, workspace_size_r = 0;
+    TBlob temp_tblob;  // The TBlob for the cast input data
+    TBlob temp_igrad;  // The TBlob for the cast grad results
+    size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
+    Tensor<xpu, 1, char> workspace;
+
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
+      BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
+        workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
+            s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
+        workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
+            s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
+      });
+      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
+      size_t cast_tensor_size = tensor_size * sizeof(OType);
+      // Allocate the temporary memory now
+      Tensor<xpu, 1, char> temp_space =
+          ctx.requested[0].get_space_typed<xpu, 1, char>(
+              Shape1(workspace_size + cast_tensor_size * 2), s);
+      // Tensor for temp_tblob
+      Tensor<xpu, 1, OType> temp_tblob_tensor(
+          reinterpret_cast<OType*>(temp_space.dptr_),
+          Shape1(tensor_size), s);
+      // Tensor for temp_igrad
+      Tensor<xpu, 1, OType> temp_igrad_tensor(
+          reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
+          Shape1(tensor_size), s);
+      temp_tblob =
+          TBlob(temp_tblob_tensor)
+              .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
+      temp_igrad =
+          TBlob(temp_igrad_tensor)
+              .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
+      if (temp_igrad.Size() != 0) {
+        Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
+      }
+      workspace =
+          Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
+    });
+    // Cast the input whose dtype differs from the output's into temp_tblob
+    CastCompute<xpu>(
+        attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});
+    if (!need_bc) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+            attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
+      } else {
+        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
+            attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
+      }
+    } else {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
+                ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
+                workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      } else {
+        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
+            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
+                ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
+                workspace, new_lshape, new_rshape, new_oshape);
+          });
+        });
+      }
+    }
+
+    // If both inputs are floats, cast the temporary igrad back to the dtype of
+    // the input that differs from the output.
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      if (lhs.type_flag_ != ograd.type_flag_) {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
+      } else {
+        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
+      }
+    }
+  } else {
+    // Both inputs are integer types; backward computation is not supported
+    // in that case.
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  }
}

} // namespace op
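NumpyBinaryBackwardUseIn above handles dtype promotion by casting the odd-typed input into scratch space, running the ordinary backward in the output dtype, and casting the resulting gradient back only when both inputs are floats (an all-integer pair is rejected via PrintErrorMessage). A rough NumPy analogue of that control flow for multiply, with hypothetical helper names (integer inputs do not receive a meaningful gradient in the real kernel):

import numpy as np

def _reduce_to(grad, shape):
    # Sum-reduce grad over the broadcast axes so it matches `shape`.
    extra = grad.ndim - len(shape)
    if extra:
        grad = grad.sum(axis=tuple(range(extra)))
    axes = tuple(i for i, d in enumerate(shape) if d == 1 and grad.shape[i] != 1)
    if axes:
        grad = grad.sum(axis=axes, keepdims=True)
    return grad.reshape(shape)

def mixed_mul_backward(ograd, lhs, rhs):
    # Work in the output dtype, mirroring temp_tblob/temp_igrad above.
    otype = ograd.dtype
    lgrad = _reduce_to(ograd * rhs.astype(otype), lhs.shape)
    rgrad = _reduce_to(ograd * lhs.astype(otype), rhs.shape)
    # Only float inputs get their gradient cast back to the input dtype,
    # mirroring the final CastCompute step.
    if np.issubdtype(lhs.dtype, np.floating):
        lgrad = lgrad.astype(lhs.dtype)
    if np.issubdtype(rhs.dtype, np.floating):
        rgrad = rgrad.astype(rhs.dtype)
    return lgrad, rgrad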
26 changes: 26 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -699,6 +699,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs);

+template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
+void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx,
+                                                   const std::vector<TBlob>& inputs,
+                                                   const std::vector<OpReqType>& req,
+                                                   const std::vector<TBlob>& outputs,
+                                                   const mshadow::Tensor<xpu, 1, char>& workspace,
+                                                   const mxnet::TShape& new_lshape,
+                                                   const mxnet::TShape& new_rshape,
+                                                   const mxnet::TShape& new_oshape) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob lgrad = outputs[0].reshape(new_lshape);
+  const TBlob rgrad = outputs[1].reshape(new_rshape);
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob lhs = inputs[1].reshape(new_lshape);
+  const TBlob rhs = inputs[2].reshape(new_rshape);
+  if (ograd.Size() != 0) {
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, LOP>(s, lgrad, req[0], workspace,
+                                                            ograd, lhs, rhs);
+    Reduce<red::sum, ndim, DType, op::mshadow_op::mul, ROP>(s, rgrad, req[1], workspace,
+                                                            ograd, lhs, rhs);
+  }
+}

template<typename xpu, int ndim, typename DType, typename LOP, typename ROP>
inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_unary_op.h
@@ -454,8 +454,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, DstDType> out = outputs[0].FlatTo1D<xpu, DstDType>(s);
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, {
Tensor<xpu, 1, SrcDType> data = inputs[0].FlatTo1D<xpu, SrcDType>(s);
-      if (outputs[0].type_flag_ != inputs[0].type_flag_ ||
-          req[0] != kWriteInplace) {
+      if ((outputs[0].type_flag_ != inputs[0].type_flag_ ||
+          req[0] != kWriteInplace) && outputs[0].Size() != 0) {
Assign(out, req[0], tcast<DstDType>(data));
}
});
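The CastCompute tweak adds a zero-size guard so that casting an empty tensor skips the Assign entirely; the new mixed-precision backward can issue such casts for zero-size shapes like (3, 0), which the updated test below exercises. A quick behavioral sketch, assuming a build with this fix:

import mxnet as mx
from mxnet import npx

npx.set_np()  # zero-size tensors require NumPy semantics

x = mx.np.ones((3, 0), dtype='int32')
y = x.astype('float32')  # cast of a zero-size tensor is now a no-op write
assert y.shape == (3, 0)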
20 changes: 12 additions & 8 deletions tests/python/unittest/test_numpy_op.py
@@ -1685,7 +1685,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
@with_seed()
@use_np
def test_np_mixed_precision_binary_funcs():
-    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
class TestMixedBinary(HybridBlock):
def __init__(self, func):
super(TestMixedBinary, self).__init__()
@@ -1719,13 +1721,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

funcs = {
-        'add': (-1.0, 1.0),
-        'subtract': (-1.0, 1.0),
-        'multiply': (-1.0, 1.0),
+        'add': (-1.0, 1.0, None, None),
+        'subtract': (-1.0, 1.0, None, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
+                     lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
}

shape_pairs = [((3, 2), (3, 2)),
((3, 2), (3, 1)),
+                   ((3, 0), (3, 0)),
((3, 1), (3, 0)),
((0, 2), (1, 2)),
((2, 3, 4), (3, 1)),
@@ -1735,16 +1739,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
itypes = [np.bool, np.int8, np.int32, np.int64]
ftypes = [np.float16, np.float32, np.float64]
for func, func_data in funcs.items():
-        low, high = func_data
+        low, high, lgrad, rgrad = func_data
for lshape, rshape in shape_pairs:
for type1, type2 in itertools.product(itypes, ftypes):
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)

for type1, type2 in itertools.product(ftypes, ftypes):
if type1 == type2:
continue
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)


@with_seed()
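The new lgrad/rgrad entries for 'multiply' encode its expected gradients: with respect to either input, the elementwise partial derivative is simply the other input broadcast to the output shape, which is exactly what each lambda returns. A worked check in plain NumPy (shapes chosen for illustration):

import numpy as _np

x1 = _np.array([[1.0, 2.0]], dtype=_np.float32)    # shape (1, 2)
x2 = _np.array([[3.0], [4.0]], dtype=_np.float16)  # shape (2, 1)
y = x1 * x2                                        # shape (2, 2), promoted to float32
lgrad = _np.broadcast_to(x2, y.shape)              # elementwise dy/dx1
rgrad = _np.broadcast_to(x1, y.shape)              # elementwise dy/dx2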