Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
eye operator, for default storage type (#9770)
Browse files Browse the repository at this point in the history
* eye

* more test

* change to two kernels

* address comments

* Update init_op.h
  • Loading branch information
ZiyueHuang authored and piiswrong committed Feb 26, 2018
1 parent 88f763e commit db24ac1
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 2 deletions.
41 changes: 40 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from ._internal import NDArrayBase

__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "divide", "equal", "full", "greater", "greater_equal",
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis", "modulo",
"multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide",
"waitall", "_new_empty_handle"]
Expand Down Expand Up @@ -3411,6 +3411,45 @@ def zeros(shape, ctx=None, dtype=None, **kwargs):
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
# pylint: enable= no-member, protected-access

def eye(N, M=0, k=0, ctx=None, dtype=None, **kwargs):
"""Return a 2-D array with ones on the diagonal and zeros elsewhere.
Parameters
----------
N: int
Number of rows in the output.
M: int, optional
Number of columns in the output. If 0, defaults to N.
k: int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
ctx: Context, optional
An optional device context (default is the current default context)
dtype: str or numpy.dtype, optional
An optional value type (default is `float32`)
Returns
-------
NDArray
A created array
Examples
--------
>>> mx.nd.eye(2)
[[ 1. 0.]
[ 0. 1.]]
<NDArray 2x2 @cpu(0)>
>>> mx.nd.eye(2, 3, 1)
[[ 0. 1. 0.]
[ 0. 0. 1.]]
<NDArray 2x3 @cpu(0)>
"""
# pylint: disable= unused-argument
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
return _internal._eye(N=N, M=M, k=k, ctx=ctx, dtype=dtype, **kwargs)
# pylint: enable= no-member, protected-access


def empty(shape, ctx=None, dtype=None):
"""Returns a new array of given shape and type, without initializing entries.
Expand Down
25 changes: 24 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "zeros", "ones", "full", "arange"]
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -2731,6 +2731,29 @@ def hypot(left, right):
else:
raise TypeError('types (%s, %s) not supported' % (str(type(left)), str(type(right))))

def eye(N, M=0, k=0, dtype=None, **kwargs):
"""Returns a new symbol of 2-D shpae, filled with ones on the diagonal
and zeros elsewhere.
Parameters
----------
N: int
Number of rows in the output.
M: int, optional
Number of columns in the output. If 0, defaults to N.
k: int, optional
Index of the diagonal: 0 (the default) refers to the main diagonal,
a positive value refers to an upper diagonal,
and a negative value to a lower diagonal.
dtype : str or numpy.dtype, optional
The value type of the inner value, default to ``np.float32``.
Returns
-------
out : Symbol
The created Symbol.
"""
if dtype is None:
dtype = _numpy.float32
return _internal._eye(N, M, k, dtype=dtype, **kwargs)

def zeros(shape, dtype=None, **kwargs):
"""Returns a new symbol of given shape and type, filled with zeros.
Expand Down
11 changes: 11 additions & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(InitOpParam);
DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);


NNVM_REGISTER_OP(_zeros)
Expand All @@ -45,6 +46,16 @@ NNVM_REGISTER_OP(_zeros)
.set_attr<FComputeEx>("FComputeEx<cpu>", FillComputeZerosEx<cpu>)
.add_arguments(InitOpParam::__FIELDS__());

NNVM_REGISTER_OP(_eye)
.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<EyeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", InitEyeShape<EyeParam>)
.set_attr<nnvm::FInferType>("FInferType", InitType<EyeParam>)
.set_attr<FCompute>("FCompute<cpu>", EyeFill<cpu>)
.add_arguments(EyeParam::__FIELDS__());

NNVM_REGISTER_OP(_ones)
.describe("fill target with ones")
.set_num_inputs(0)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ NNVM_REGISTER_OP(_zeros)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
.set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);

NNVM_REGISTER_OP(_eye)
.set_attr<FCompute>("FCompute<gpu>", EyeFill<gpu>);

NNVM_REGISTER_OP(_ones)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);

Expand Down
89 changes: 89 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <dmlc/optional.h>
#include <vector>
#include <string>
#include <algorithm>
#include <limits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
Expand Down Expand Up @@ -60,6 +61,62 @@ struct InitOpParam : public dmlc::Parameter<InitOpParam> {
}
};

struct EyeParam : public dmlc::Parameter<EyeParam> {
nnvm::dim_t N;
nnvm::dim_t M;
nnvm::dim_t k;
std::string ctx;
int dtype;

DMLC_DECLARE_PARAMETER(EyeParam) {
DMLC_DECLARE_FIELD(N)
.describe("Number of rows in the output.");
DMLC_DECLARE_FIELD(M)
.set_default(0)
.describe("Number of columns in the output. If 0, defaults to N");
DMLC_DECLARE_FIELD(k)
.set_default(0)
.describe("Index of the diagonal. 0 (the default) refers to the main diagonal."
"A positive value refers to an upper diagonal."
"A negative value to a lower diagonal.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
.add_enum("float32", mshadow::kFloat32)
.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int64", mshadow::kInt64)
.describe("Target data type.");
}
};

template<typename ParamType>
inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, param.M > 0 ? param.M : param.N));
return true;
}

template<int req>
struct eye_dns_fill {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const nnvm::dim_t init_col,
const nnvm::dim_t k,
const nnvm::dim_t num_cols) {
KERNEL_ASSIGN(out_data[(i+init_col-k)*num_cols+i+init_col], req, static_cast<DType>(1));
}
};


struct RangeParam : public dmlc::Parameter<RangeParam> {
double start;
dmlc::optional<double> stop;
Expand Down Expand Up @@ -336,6 +393,38 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
}
}


template<typename xpu>
void EyeFill(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 0U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
const TBlob& out_data = outputs[0];
const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;

const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
std::min(rnnz, num_cols);
using namespace mxnet_op;
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Fill(s, out_data, req[0], static_cast<DType>(0));
if (nnz > 0) {
Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
}
});
});
}


struct range_fwd {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, int repeat, DType start, DType step,
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,15 @@ def test_output():
assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)
arange_out = mx.nd.arange(0, 20, dtype='int64')
assert_almost_equal(arange_out.asnumpy(), np.arange(0, 20))
N_array = np.random.randint(1, high=8, size=10)
M_array = np.random.randint(1, high=8, size=10)
k_array = np.random.randint(-10, high=10, size=10)
for i in range(10):
N = N_array[i]
M = M_array[i]
k = k_array[i]
assert_almost_equal(np.eye(N, M, k), mx.nd.eye(N, M, k).asnumpy())
assert_almost_equal(np.eye(N, k=k), mx.nd.eye(N, k=k).asnumpy())


@with_seed()
Expand Down

0 comments on commit db24ac1

Please sign in to comment.