[GRADIENT] Register more gradient operators (apache#300)
* Add conv2d max_pool backward op

* Added tests

* Fix testing

* Address comments

* Change dot to matmul

* Address comments

* Break down indicator function

* Make greater, less numpy compatible
yuruofeifei authored and sergei-mironov committed Aug 8, 2018
1 parent f9684a1 commit 47ff210
Showing 18 changed files with 1,016 additions and 162 deletions.
12 changes: 12 additions & 0 deletions nnvm/docs/top.rst
@@ -28,6 +28,7 @@ This level enables fully connected multi-layer perceptron.
:nosignatures:

nnvm.symbol.dense
nnvm.symbol.matmul
nnvm.symbol.relu
nnvm.symbol.tanh
nnvm.symbol.sigmoid
@@ -38,6 +39,7 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
nnvm.symbol.elemwise_sum
nnvm.symbol.full
nnvm.symbol.full_like
nnvm.symbol.ones
@@ -54,6 +56,8 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.softmax
nnvm.symbol.log_softmax
nnvm.symbol.pad
nnvm.symbol.block_grad
nnvm.symbol.indicator


**Level 2: Convolutions**
@@ -77,6 +81,8 @@ This level enables typical convnet models.
:nosignatures:

nnvm.symbol.reshape
nnvm.symbol.reshape_like
nnvm.symbol.expand_like
nnvm.symbol.copy
nnvm.symbol.negative
nnvm.symbol.leaky_relu
@@ -107,6 +113,7 @@ This level enables typical convnet models.
Detailed Definitions
--------------------
.. autofunction:: nnvm.symbol.dense
.. autofunction:: nnvm.symbol.matmul
.. autofunction:: nnvm.symbol.relu
.. autofunction:: nnvm.symbol.tanh
.. autofunction:: nnvm.symbol.sigmoid
@@ -117,6 +124,7 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.elemwise_sum
.. autofunction:: nnvm.symbol.full
.. autofunction:: nnvm.symbol.full_like
.. autofunction:: nnvm.symbol.ones
@@ -133,6 +141,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.softmax
.. autofunction:: nnvm.symbol.log_softmax
.. autofunction:: nnvm.symbol.pad
.. autofunction:: nnvm.symbol.block_grad
.. autofunction:: nnvm.symbol.indicator

.. autofunction:: nnvm.symbol.conv2d
.. autofunction:: nnvm.symbol.conv2d_transpose
@@ -142,6 +152,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.global_avg_pool2d

.. autofunction:: nnvm.symbol.reshape
.. autofunction:: nnvm.symbol.reshape_like
.. autofunction:: nnvm.symbol.expand_like
.. autofunction:: nnvm.symbol.copy
.. autofunction:: nnvm.symbol.negative
.. autofunction:: nnvm.symbol.leaky_relu
78 changes: 74 additions & 4 deletions nnvm/include/nnvm/top/tensor.h
@@ -62,6 +62,13 @@ enum TypeFlag {
kUint64 = 10,
};

enum IndicatorRuleFlag {
kGT0 = 0,
kLT0 = 1,
kMax = 2,
kMin = 3,
};

#define DMLC_DECLARE_DTYPE_FIELD(name) \
DMLC_DECLARE_FIELD(name) \
.add_enum("float16", kFloat16) \
@@ -84,6 +91,28 @@ struct CastParam : public dmlc::Parameter<CastParam> {
}
};

struct IndicatorParam : public dmlc::Parameter<IndicatorParam> {
TShape axis;
bool exclude;
DMLC_DECLARE_PARAMETER(IndicatorParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
.describe(R"code(The axis or axes along which to perform the indicator rule.
The default, `axis=()`, will compute over all elements into a
scalar array with shape `(1,)`.
If `axis` is an int, the rule is applied along that particular axis.
If `axis` is a tuple of ints, the rule is applied along all axes
specified in the tuple.
If `exclude` is true, the rule is instead applied along the axes
NOT in `axis`.)code");
DMLC_DECLARE_FIELD(exclude).set_default(false)
.describe("Whether to apply the rule along the axes that are NOT in `axis` instead.");
}
};
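
For intuition, the four IndicatorRuleFlag values can be read as element-wise masks. The following is a minimal NumPy sketch of the assumed semantics; `indicator_rule` is a hypothetical helper for illustration, not part of this commit:

import numpy as np

def indicator_rule(x, rule, axis=None, exclude=False):
    # Hypothetical reference semantics for IndicatorRuleFlag.
    if exclude and axis is not None:
        # Apply the rule along every axis NOT listed in `axis`.
        axis = tuple(i for i in range(x.ndim) if i not in axis)
    if rule == "gt0":   # kGT0: 1 where x > 0
        return (x > 0).astype(x.dtype)
    if rule == "lt0":   # kLT0: 1 where x < 0
        return (x < 0).astype(x.dtype)
    if rule == "max":   # kMax: 1 where x equals the max along `axis`
        return (x == x.max(axis=axis, keepdims=True)).astype(x.dtype)
    if rule == "min":   # kMin: 1 where x equals the min along `axis`
        return (x == x.min(axis=axis, keepdims=True)).astype(x.dtype)
    raise ValueError("unknown rule")

The kMax rule is what a max_pool2d backward pass needs: it marks the positions that produced the pooled maximum, so the incoming gradient is routed only to those elements.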

struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
Tuple<int64_t> shape;

@@ -97,8 +126,7 @@ struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {

DMLC_DECLARE_PARAMETER(SqueezeParam) {
DMLC_DECLARE_FIELD(axis).set_default(TShape())
.describe("The axis to squeeze in the input tensor."
" If set to None, all size=1 axes will be squeezed");
.describe("The axis to squeeze in the input tensor.");
}
};

@@ -110,6 +138,15 @@ struct ScalarParam : public dmlc::Parameter<ScalarParam> {
}
};

struct FillValueParam : public dmlc::Parameter<FillValueParam> {
double fill_value;

DMLC_DECLARE_PARAMETER(FillValueParam) {
DMLC_DECLARE_FIELD(fill_value)
.describe("Scalar value to be filled");
}
};

struct TransposeParam : public dmlc::Parameter<TransposeParam> {
TShape axes;

@@ -158,16 +195,49 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
}
};

struct InitOpWithScalarParam : public dmlc::Parameter<InitOpWithScalarParam> {
TShape shape;
int dtype;
double fill_value;

DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) {
DMLC_DECLARE_FIELD(shape).set_default(TShape());
DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
.describe("Target data type.");
DMLC_DECLARE_FIELD(fill_value).describe("Scalar value to fill");
}
};

struct InitOpParam : public dmlc::Parameter<InitOpParam> {
TShape shape;
int dtype;
double value;

DMLC_DECLARE_PARAMETER(InitOpParam) {
DMLC_DECLARE_FIELD(shape).set_default(TShape());
DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
.describe("Target data type.");
DMLC_DECLARE_FIELD(value).describe("Value to fill");
}
};

struct ElementWiseReduceParam : public dmlc::Parameter<ElementWiseReduceParam> {
int num_args;
DMLC_DECLARE_PARAMETER(ElementWiseReduceParam) {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
.describe("Number of inputs to be reduced.");
}
};
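
The `num_args` field lets `elemwise_sum` take a variable number of inputs; a short usage sketch, mirroring the pattern used in this commit's tests:

import nnvm.symbol as sym

inputs = [sym.Variable("input" + str(i)) for i in range(3)]
# num_args must match the number of symbols being reduced.
out = sym.elemwise_sum(*inputs, num_args=3)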

struct MatMulParam : public dmlc::Parameter<MatMulParam> {
bool transpose_a;
bool transpose_b;

DMLC_DECLARE_PARAMETER(MatMulParam) {
DMLC_DECLARE_FIELD(transpose_a)
.describe("If true then transpose the first input before dot.")
.set_default(false);
DMLC_DECLARE_FIELD(transpose_b)
.describe("If true then transpose the second input before dot.")
.set_default(false);
}
};
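
A hedged usage sketch of the new `matmul` symbol (argument order and keyword names assumed from the fields declared above):

import nnvm.symbol as sym

a = sym.Variable("a")   # e.g. shape (m, k)
b = sym.Variable("b")   # e.g. shape (n, k)
# With transpose_b=True this computes a dot b^T, a pattern that
# shows up when expressing the backward of dense/matmul.
y = sym.matmul(a, b, transpose_b=True)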

2 changes: 1 addition & 1 deletion nnvm/python/nnvm/compiler/build_module.py
@@ -188,7 +188,7 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
The input types to the graph
params : dict of str to NDArray
Input parameetrs to the graph that do not change
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
57 changes: 57 additions & 0 deletions nnvm/python/nnvm/compiler/graph_util.py
@@ -5,6 +5,9 @@
import tvm
from . import graph_attr

from ..graph import create
from ..symbol import Group, ones_like

def infer_shape(graph, **shape):
"""Infer the shape given the shape of inputs.
@@ -89,3 +92,57 @@ def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
err = _deep_compare(grapha, graphb, compare_variable_attrs)
if err:
raise ValueError("Graph compare error: " + err)

def get_gradient_graph(ys, xs, grad_ys=None):
"""Create gradient graph of ys with respect to xs.
Parameters
----------
ys : Symbol or list of Symbol
Symbols from which the gradient is calculated.
xs : Symbol or list of Symbol
Symbols the gradient respect to.
For group symbol, gradients for all outputs will be calculated.
grad_ys : Symbol or list of Symbol
Head gradients for ys.
Returns
-------
ret : Graph
Generated gradient graph.
"""
if isinstance(ys, list):
ys = Group(ys)
g = create(ys)
g._set_symbol_list_attr('grad_ys', ys)
g._set_symbol_list_attr('grad_xs', xs)
ny = len(ys.list_output_names())
if grad_ys is None:
grad_ys = [ones_like(ys[i]) for i in range(ny)]
g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
return g.apply('Gradient')

def gradients(ys, xs, grad_ys=None):
"""Create gradient symbol of ys respect to xs.
Parameters
----------
ys : Symbol or list of Symbol
Symbols from which the gradient is calculated.
xs : Symbol or list of Symbol
Symbols the gradient respect to.
For group symbol, gradients for all outputs will be calculated.
grad_ys : Symbol or list of Symbol
Head gradients for ys.
Returns
-------
ret : list of Symbol
Generated gradient symbol. For each xs,
all gradients from ys are merged into a single symbol.
"""
grad_g = get_gradient_graph(ys, xs, grad_ys)
nx = len(Group(xs).list_output_names()) \
if isinstance(xs, list) else len(xs.list_output_names())
ret = [grad_g.symbol[i] for i in range(nx)]
return ret
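
A minimal usage sketch for the two helpers above (shapes and the dense keyword arguments are illustrative):

import nnvm.symbol as sym
from nnvm.compiler import graph_util

x = sym.Variable("x")
w = sym.Variable("w")
y = sym.dense(x, w, units=8, use_bias=False)

# One merged gradient symbol per entry in xs.
dx, dw = graph_util.gradients(y, [x, w])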
36 changes: 0 additions & 36 deletions nnvm/python/nnvm/graph.py
@@ -13,7 +13,6 @@
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Variable, Symbol, Group as _Group
from .symbol import ones_like

class GraphIndex(object):
"""Index for quickly accessing graph attributes.
@@ -271,38 +270,3 @@ def create(symbol):
check_call(_LIB.NNGraphCreate(
symbol.handle, ctypes.byref(ghandle)))
return Graph(ghandle)


def gradients(ys, xs, grad_ys=None):
"""Create gradient symbol of ys respect to xs.
Parameters
----------
ys : Symbol or list of Symbol
Symbols from which the gradient is calculated.
xs : Symbol or list of Symbol
Symbols the gradient respect to.
For group symbol, gradients for all outputs will be calculated.
grad_ys : Symbol or list of Symbol
Head gradients for ys.
Returns
-------
ret : list of Symbol
Generated gradient symbol. For each xs,
all gradients from ys are merged into a single symbol.
"""
if isinstance(ys, list):
ys = _Group(ys)
g = create(ys)
g._set_symbol_list_attr('grad_ys', ys)
g._set_symbol_list_attr('grad_xs', xs)
ny = len(ys.list_output_names())
if grad_ys is None:
grad_ys = [ones_like(ys[i]) for i in range(ny)]
g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
sym = g.apply('Gradient').symbol
nx = len(_Group(xs).list_output_names()) \
if isinstance(xs, list) else len(xs.list_output_names())
ret = [sym[i] for i in range(nx)]
return ret
9 changes: 7 additions & 2 deletions nnvm/src/pass/gradient.cc
@@ -14,18 +14,23 @@ namespace pass {
namespace {

// default aggregate gradient function
// require operator __zero__ and __ewise_sum__ to be presented.
// require operators zeros and elemwise_sum to be present.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
if (v.size() == 1) {
return std::move(v[0]);
} else if (v.size() == 0) {
NodePtr zero_node = Node::Create();
zero_node->attrs.op = Op::Get("_zeros");
zero_node->attrs.op = Op::Get("zeros");
zero_node->attrs.name = "zero_grad";
zero_node->attrs.op->attr_parser(&(zero_node->attrs));
return NodeEntry{zero_node, 0, 0};
} else {
NodePtr sum_node = Node::Create();
sum_node->attrs.op = Op::Get("elemwise_sum");
sum_node->inputs = std::move(v);
sum_node->attrs.name = "grad_sum";
sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size());
sum_node->attrs.op->attr_parser(&(sum_node->attrs));
return NodeEntry{sum_node, 0, 0};
}
}
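
In Python terms, the aggregation rule amounts to the following sketch (illustrative semantics only, using the nnvm symbols named in the strings above; the zeros shape is a placeholder):

import nnvm.symbol as sym

def aggregate_gradients(grads):
    # Mirrors DefaultAggregateGradient: merge all gradients
    # flowing into one node output.
    if len(grads) == 1:
        return grads[0]               # single path: pass through
    if not grads:
        return sym.zeros(shape=(1,))  # no path: the "zero_grad" node
    return sym.elemwise_sum(*grads, num_args=len(grads))  # "grad_sum"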