
Tensordot backward
meta-project-ci committed Feb 10, 2020
1 parent b4592cc commit d29ae7f
Showing 4 changed files with 73 additions and 140 deletions.
110 changes: 2 additions & 108 deletions python/mxnet/ndarray/numpy/_op.py
@@ -44,7 +44,7 @@
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'nan_to_num', 'where', 'bincount', 'zeros1', 'tensordot1', 'nop']
'diff', 'resize', 'nan_to_num', 'where', 'bincount', 'nop']


@set_module('mxnet.ndarray.numpy')
@@ -1282,21 +1282,7 @@ def tensordot(a, b, axes=2):
[ 4796., 5162.],
[ 4928., 5306.]])
"""
if _np.isscalar(axes):
return _npi.tensordot_int_axes(a, b, axes)

if len(axes) != 2:
raise ValueError('Axes must consist of two arrays.')
a_axes_summed, b_axes_summed = axes
if _np.isscalar(a_axes_summed):
a_axes_summed = (a_axes_summed,)
if _np.isscalar(b_axes_summed):
b_axes_summed = (b_axes_summed,)

if len(a_axes_summed) != len(b_axes_summed):
raise ValueError('Axes length mismatch')

return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)
return _api_internal.tensordot(a, b, axes)
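
The front end now forwards `axes` to the C++ FFI unmodified; both call forms are dispatched in `src/api/api_npi.cc` below. A doctest-style sketch of the two forms: the tuple case is taken from the docstring above, the scalar case assumes standard NumPy `tensordot` semantics.

>>> a = np.arange(60.).reshape(3, 4, 5)
>>> b = np.arange(24.).reshape(4, 3, 2)
>>> np.tensordot(a, b, axes=([1, 0], [0, 1])).shape  # tuple axes -> _npi_tensordot
(5, 2)
>>> x = np.arange(6.).reshape(2, 3)
>>> y = np.arange(12.).reshape(3, 4)
>>> np.tensordot(x, y, axes=1).shape  # int axes -> _npi_tensordot_int_axes
(2, 4)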


@set_module('mxnet.ndarray.numpy')
@@ -6676,98 +6662,6 @@ def bincount(x, weights=None, minlength=0):
return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True)


@set_module('mxnet.ndarray.numpy')
def zeros1(shape, dtype=None, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Parameters
----------
shape : int or tuple of int
The shape of the empty array.
dtype : str or numpy.dtype, optional
An optional value type. Default is `numpy.float32`. Note that this
behavior is different from NumPy's `zeros` function where `float64`
is the default value, because `float32` is considered as the default
data type in deep learning.
order : {'C'}, optional, default: 'C'
How to store multi-dimensional data in memory, currently only row-major
(C-style) is supported.
ctx : Context, optional
An optional device context (default is the current default context).
Returns
-------
out : ndarray
Array of zeros with the given shape, dtype, and ctx.
"""
if order != 'C':
raise NotImplementedError
# if ctx is None:
# ctx = str(current_context())
if dtype is not None and not isinstance(dtype, str):
dtype = _np.dtype(dtype).name
return _api_internal.zeros(shape, dtype, ctx)


@set_module('mxnet.ndarray.numpy')
def tensordot1(a, b, axes=2):
r"""
tensordot(a, b, axes=2)
Compute tensor dot product along specified axes for arrays >= 1-D.
Given two tensors (arrays of dimension greater than or equal to one),
`a` and `b`, and an ndarray object containing two ndarray
objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
elements (components) over the axes specified by ``a_axes`` and
``b_axes``. The third argument can be a single non-negative
integer_like scalar, ``N``; if it is such, then the last ``N``
dimensions of `a` and the first ``N`` dimensions of `b` are summed
over.
Parameters
----------
a, b : ndarray, len(shape) >= 1
Tensors to "dot".
axes : int or (2,) ndarray
* integer_like
If an int N, sum over the last N axes of `a` and the first N axes
of `b` in order. The sizes of the corresponding axes must match.
* (2,) ndarray
Or, a list of axes to be summed over, first sequence applying to `a`,
second to `b`. Both elements ndarray must be of the same length.
See Also
--------
dot, einsum
Notes
-----
Three common use cases are:
* ``axes = 0`` : tensor product :math:`a\otimes b`
* ``axes = 1`` : tensor dot product :math:`a\cdot b`
* ``axes = 2`` : (default) tensor double contraction :math:`a:b`
When `axes` is integer_like, the sequence for evaluation will be: first
the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
Nth axis in `b` last.
When there is more than one axis to sum over - and they are not the last
(first) axes of `a` (`b`) - the argument `axes` should consist of
two sequences of the same length, with the first axis to sum over given
first in both sequences, the second axis second, and so forth.
Examples
--------
>>> a = np.arange(60.).reshape(3,4,5)
>>> b = np.arange(24.).reshape(4,3,2)
>>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
>>> c.shape
(5, 2)
>>> c
array([[ 4400., 4730.],
[ 4532., 4874.],
[ 4664., 5018.],
[ 4796., 5162.],
[ 4928., 5306.]])
"""
return _api_internal.tensordot(a, b, axes)


@set_module('mxnet.ndarray.numpy')
def nop(*args):
r"""
81 changes: 49 additions & 32 deletions src/api/api_npi.cc
@@ -76,6 +76,7 @@ inline void SetInOut(std::vector<NDArray*>* ndinputs,
}
}

template<typename T>
inline std::vector<NDArray*> Invoke(const nnvm::Op* op,
nnvm::NodeAttrs* attrs,
int num_inputs,
@@ -92,6 +93,7 @@ inline std::vector<NDArray*> Invoke(const nnvm::Op* op,

auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs);
if (Imperative::Get()->is_recording()) {
::dmlc::get<T>(attrs->parsed).SetAttrDict(&(attrs->dict));
Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state);
}
for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i];
@@ -120,51 +122,66 @@ MXNET_REGISTER_API("_npi.zeros")
attrs.dict["ctx"] = args[2].operator std::string();
}
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr);
auto ndoutputs = Invoke<op::InitOpParam>(op, &attrs, 0, nullptr, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.tensordot")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args,
runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_tensordot_int_axes");
op::TensordotIntAxesParam param;
nnvm::NodeAttrs attrs;
attrs.op = op;
param.axes = args[2].operator int();
// we directly copy TensordotIntAxesParam, which is trivially-copyable
attrs.parsed = param;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke<op::TensordotIntAxesParam>(op, &attrs, 2, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

inline static void _npi_tensordot(runtime::MXNetArgs args,
runtime::MXNetRetValue* ret) {
using namespace runtime;
bool isscalar = args[2].type_code() == kDLInt;
const nnvm::Op* op = Op::Get(isscalar ?
"_npi_tensordot_int_axes" :
"_npi_tensordot");
const nnvm::Op* op = Op::Get("_npi_tensordot");
op::TensordotParam param;
nnvm::NodeAttrs attrs;
attrs.op = op;
if (isscalar) {
mxnet::op::TensordotIntAxesParam param;
param.axes = args[2].operator int();
// we directly copy TensordotIntAxesParam, which is trivially-copyable
attrs.parsed = param;
const ObjectRef ref = args[2].operator ObjectRef();
if (const ADTObj* obj = ref.as<ADTObj>()) {
if (const IntegerObj* lop = (*obj)[0].as<IntegerObj>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<Integer, ObjectRef>((*obj)[1])->value);
} else {
param.a_axes_summed = Tuple<int>((*obj)[0]);
param.b_axes_summed = Tuple<int>((*obj)[1]);
}
} else {
mxnet::op::TensordotParam param;
const ObjectRef ref = args[2].operator ObjectRef();
if (const ADTObj* obj = ref.as<ADTObj>()) {
if (const IntegerObj* lop = (*obj)[0].as<IntegerObj>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<Integer, ObjectRef>((*obj)[1])->value);
} else {
param.a_axes_summed = Tuple<int>((*obj)[0]);
param.b_axes_summed = Tuple<int>((*obj)[1]);
}
Array<ObjectRef> arr = Downcast<Array<ObjectRef>, ObjectRef>(ref);
if (const IntImmNode* lop = arr[0].as<IntImmNode>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<IntImm, ObjectRef>(arr[1])->value);
} else {
Array<ObjectRef> arr = Downcast<Array<ObjectRef>, ObjectRef>(ref);
if (const IntegerObj* lop = arr[0].as<IntegerObj>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<Integer, ObjectRef>(arr[1])->value);
} else {
param.a_axes_summed = Tuple<int>(arr[0]);
param.b_axes_summed = Tuple<int>(arr[1]);
}
param.a_axes_summed = Tuple<int>(arr[0]);
param.b_axes_summed = Tuple<int>(arr[1]);
}
attrs.parsed = std::move(param);
}
attrs.parsed = std::move(param);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr);
auto ndoutputs = Invoke<op::TensordotParam>(op, &attrs, 2, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

MXNET_REGISTER_API("_npi.tensordot")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
if (args[2].type_code() == kDLInt) {
_npi_tensordot_int_axes(args, ret);
} else {
_npi_tensordot(args, ret);
}
});
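
This registration now performs, in C++, the scalar-vs-sequence dispatch and axis normalization that the deleted Python front-end code used to do. For reference, a minimal Python sketch of the equivalent logic (a hypothetical helper, mirroring the lines removed from `_op.py` above):

def _normalize_tensordot_axes(axes):
    # Hypothetical helper, not part of this commit. A scalar selects the
    # int-axes kernel; anything else must be a pair of axis sequences of
    # equal length, with bare ints promoted to 1-tuples.
    if isinstance(axes, int):
        return axes  # handled by _npi_tensordot_int_axes
    if len(axes) != 2:
        raise ValueError('Axes must consist of two arrays.')
    a_axes, b_axes = axes
    if isinstance(a_axes, int):
        a_axes = (a_axes,)
    if isinstance(b_axes, int):
        b_axes = (b_axes,)
    if len(a_axes) != len(b_axes):
        raise ValueError('Axes length mismatch')
    return tuple(a_axes), tuple(b_axes)  # handled by _npi_tensordot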

MXNET_REGISTER_API("_npi.nop")
13 changes: 13 additions & 0 deletions src/operator/numpy/np_tensordot_op-inl.h
@@ -25,6 +25,7 @@
#define MXNET_OPERATOR_NUMPY_NP_TENSORDOT_OP_INL_H_

#include <vector>
#include <string>
#include "../tensor/matrix_op-inl.h"

namespace mxnet {
Expand All @@ -38,6 +39,13 @@ struct TensordotParam : public dmlc::Parameter<TensordotParam> {
DMLC_DECLARE_FIELD(a_axes_summed);
DMLC_DECLARE_FIELD(b_axes_summed);
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream a_axes_summed_s, b_axes_summed_s;
a_axes_summed_s << a_axes_summed;
b_axes_summed_s << b_axes_summed;
(*dict)["a_axes_summed"] = a_axes_summed_s.str();
(*dict)["b_axes_summed"] = b_axes_summed_s.str();
}
};

/**
@@ -553,6 +561,11 @@ struct TensordotIntAxesParam : public dmlc::Parameter<TensordotIntAxesParam> {
DMLC_DECLARE_PARAMETER(TensordotIntAxesParam) {
DMLC_DECLARE_FIELD(axes);
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axes_s;
axes_s << axes;
(*dict)["axes"] = axes_s.str();
}
};

/**
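
Both tensordot parameter structs now expose `SetAttrDict`, which the templated `Invoke<T>` in `api_npi.cc` calls while autograd is recording, so the recorded node carries its parameters as string attributes; this is what lets the imperative runtime set up the backward pass (per the commit title). A sketch of what this enables from Python, assuming the usual `mxnet.autograd` API:

>>> from mxnet import autograd, np, npx
>>> npx.set_np()
>>> a = np.arange(60.).reshape(3, 4, 5)
>>> b = np.arange(24.).reshape(4, 3, 2)
>>> a.attach_grad(); b.attach_grad()
>>> with autograd.record():  # is_recording() is true, so SetAttrDict runs
...     c = np.tensordot(a, b, axes=([1, 0], [0, 1]))
>>> c.backward()
>>> a.grad.shape
(3, 4, 5)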
9 changes: 9 additions & 0 deletions src/operator/tensor/init_op.h
@@ -60,6 +60,15 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
MXNET_ADD_ALL_TYPES_WITH_BOOL
.describe("Target data type.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream shape_s, dtype_s;
shape_s << shape;
dtype_s << dtype;
(*dict)["shape"] = shape_s.str();
(*dict)["dtype"] = dtype_s.str();
// We do not set ctx here: the FFI handler stores it directly in attrs.dict
// rather than in InitOpParam, and setting it here results in an error.
}
};

struct InitOpWithoutDTypeParam : public dmlc::Parameter<InitOpWithoutDTypeParam> {
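`InitOpParam` gets the same treatment so that `_npi.zeros` can go through the templated `Invoke<op::InitOpParam>`: shape and dtype are round-tripped as strings, while ctx is already placed in `attrs.dict` by the FFI handler in `api_npi.cc`. A minimal sketch of the front-end call this serves (output repr is indicative):

>>> from mxnet import np
>>> np.zeros((2, 3), dtype='float32')
array([[0., 0., 0.],
       [0., 0., 0.]])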
