[RELAY/PASS] Simplify inference. #2033

Merged 7 commits on Oct 31, 2018.
1 change: 0 additions & 1 deletion nnvm/tests/python/compiler/test_simplify_inference.py
@@ -10,7 +10,6 @@ def simple_bn(x, gamma, beta, moving_mean, moving_var,
scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
shift = sym.elemwise_add(
sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
- shape = [-1 if i == axis else 1 for i in range(len(shape))]
# for 2D
num_newaxis=len(shape) - axis - 1
if num_newaxis:
61 changes: 60 additions & 1 deletion python/tvm/relay/expr.py
@@ -1,6 +1,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from numbers import Number as _Number

import numpy as _np
from .base import RelayNode, register_relay_node
@@ -11,6 +12,8 @@
from .. import nd as _nd
from .. import convert

# will be registered afterwards
_op_make = None

class Expr(RelayNode):
"""The base type for all Relay expressions."""
@@ -48,6 +51,62 @@ def astype(self, dtype):
"""
return _make.dtype_cast(self, dtype)

def __add__(self, other):
if isinstance(other, Expr):
return _op_make.add(self, other)
elif isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __radd__(self, other):
return self.__add__(other)

def __sub__(self, other):
if isinstance(other, Expr):
return _op_make.subtract(self, other)
elif isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __rsub__(self, other):
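# reached for `other - self` when `other` is not an Expr; a plain number
# on the left must be wrapped with `const` explicitly, so this always raises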
if isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __mul__(self, other):
if isinstance(other, Expr):
return _op_make.multiply(self, other)
elif isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __rmul__(self, other):
return self.__mul__(other)

def __div__(self, other):
if isinstance(other, Expr):
return _op_make.divide(self, other)
elif isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __rdiv__(self, other):
if isinstance(other, _Number):
raise TypeError('convert "%s" with `const` first' % str(other))
else:
raise TypeError("type %s not supported" % str(type(other)))

def __truediv__(self, other):
return self.__div__(other)

def __rtruediv__(self, other):
return self.__rdiv__(other)


@register_relay_node
class Constant(Expr):
@@ -305,7 +364,7 @@ def __len__(self):

def __repr__(self):
return ("TupleWrapper(" + self.tuple_value.__repr__() +
", " + self.size + ")")
", " + str(self.size) + ")")

def astype(self, _):
raise TypeError("astype cannot be used on tuple")
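With these overloads, arithmetic between Relay expressions can be written with plain Python operators, while bare numbers are rejected with a hint to wrap them in `const`. A minimal usage sketch (names are illustrative; it mirrors the new test at the bottom of this PR):

from tvm import relay as rly

t = rly.TensorType((10,), 'float32')
x = rly.var("x", t)
y = rly.var("y", t)

z = x * y + rly.const(1.0, 'float32')  # builds multiply/add call nodes
# z = x + 1 would raise TypeError: convert "1" with `const` first
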
15 changes: 15 additions & 0 deletions python/tvm/relay/ir_pass.py
@@ -160,6 +160,21 @@ def free_type_vars(expr):
"""
return _ir_pass.free_type_vars(expr)

def simplify_inference(expr):
""" Simplify the data-flow graph for inference phase.

Parameters
----------
expr: tvm.relay.Expr
The input expression.

Returns
-------
result: tvm.relay.Expr
An expression which is semantically equal to the input expression,
but simplified for the inference phase (batch_norm is unpacked into
scale and shift arithmetic; dropout is removed).
"""
return _ir_pass.simplify_inference(expr)

def dead_code_elimination(expr):
""" Remove expressions which does not effect the program result (dead code).
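A usage sketch for the new pass (mirroring the test below; `y` is assumed to be some expression containing nn.batch_norm or nn.dropout calls). The C++ implementation reads the data tensor's type, so the expression should be type-checked first:

from tvm import relay as rly

y = rly.ir_pass.infer_type(y)          # the pass needs the data tensor's type
y = rly.ir_pass.simplify_inference(y)  # batch_norm unpacked, dropout removed
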
8 changes: 8 additions & 0 deletions python/tvm/relay/op/__init__.py
@@ -15,3 +15,11 @@
from . import _tensor
from ..expr import Expr
from ..base import register_relay_node


def _register_op_make():
from . import _make
from .. import expr
expr._op_make = _make

_register_op_make()
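
This back-fill exists to break an import cycle: `op/__init__.py` imports `Expr` from `..expr` (see above), so `expr.py` cannot import `op._make` at module load time; instead the `_make` module is assigned into `expr._op_make` once both modules are loaded. The same pattern in isolation, with hypothetical module names:

# shapes.py -- plays the role of expr.py
_ops = None  # back-filled by ops.py after both modules are loaded

class Shape:
    def __add__(self, other):
        return _ops.combine(self, other)

# ops.py -- plays the role of op/__init__.py
import sys
import shapes  # safe: shapes does not import ops at load time

def combine(a, b):
    return ("combined", a, b)

shapes._ops = sys.modules[__name__]  # late binding breaks the cycle
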
34 changes: 34 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -120,6 +120,40 @@ inline bool IsDepthwiseConv2D(const Call& call,
}


/*!
* \brief Create a Constant with a scalar
*
* \param dtype The data type.
* \param value The value of the scalar.
* \return A Constant.
*/
template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
*static_cast<T*>(arr->data) = value;
return ConstantNode::make(arr);
}


inline Expr Negative(Expr x) {
static const Op& op = Op::Get("negative");
return CallNode::make(op, {x}, Attrs(), {});
}


inline Expr Sqrt(Expr x) {
static const Op& op = Op::Get("sqrt");
return CallNode::make(op, {x}, Attrs(), {});
}


inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}


inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
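
These helpers construct the same call nodes as the Python-level operators. A rough Python-side sketch of what they build, with an assumed 1-D variable:

from tvm import relay as rly

x = rly.var("x", rly.TensorType((10,), 'float32'))
eps = rly.const(1e-5, 'float32')      # ~ MakeConstantScalar(Float(32), 1e-5f)
y = rly.sqrt(rly.add(x, eps))         # ~ Sqrt(Add(x, eps))
z = rly.multiply(y, rly.negative(x))  # ~ Multiply(y, Negative(x))
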
77 changes: 77 additions & 0 deletions src/relay/pass/simplify_inference.cc
@@ -0,0 +1,77 @@
/*!
* Copyright (c) 2018 by Contributors
* \file simplify_inference.cc
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include "./pattern_util.h"

namespace tvm {
namespace relay {

Expr BatchNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
Expr beta,
Expr moving_mean,
Expr moving_var) {
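// Rewrite batch_norm(x) = (x - mean) / sqrt(var + eps) * gamma + beta
// into x * scale + shift, where
//   scale = gamma / sqrt(var + eps)  (gamma applied only if param->scale)
//   shift = beta - mean * scale      (beta applied only if param->center)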
const auto param = attrs.as<BatchNormAttrs>();
Expr epsilon = MakeConstantScalar(Float(32), static_cast<float>(param->epsilon));
Expr var_add_eps = Add(moving_var, epsilon);
Expr sqrt_var = Sqrt(var_add_eps);
Expr scale = Divide(MakeConstantScalar(Float(32), 1.0f), sqrt_var);

if (param->scale) {
scale = Multiply(scale, gamma);
}
Expr neg_mean = Negative(moving_mean);
Expr shift = Multiply(neg_mean, scale);
if (param->center) {
shift = Add(shift, beta);
}

int axis = param->axis;
const auto* tdata = data->type_as<TensorTypeNode>();
scale = ExpandBiasToMatchAxis(scale, tdata->shape.size(), {axis});
shift = ExpandBiasToMatchAxis(shift, tdata->shape.size(), {axis});

Expr out = Multiply(data, scale);
out = Add(out, shift);
return out;
}

class InferenceSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const TupleGetItemNode* n) final {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& dropout = Op::Get("nn.dropout");

Expr new_e = ExprMutator::VisitExpr_(n);
const auto* new_n = new_e.as<TupleGetItemNode>();
if (new_n->index != 0) {
return new_e;
}
if (const auto* call = new_n->tuple.as<CallNode>()) {
if (call->op.same_as(batch_norm)) {
return BatchNormToInferUnpack(call->attrs,
call->args[0], call->args[1], call->args[2], call->args[3], call->args[4]);
} else if (call->op.same_as(dropout)) {
return call->args[0];
}
}
return new_e;
}
};

Expr SimplifyInference(const Expr& e) {
return InferenceSimplifier().Mutate(e);
}

TVM_REGISTER_API("relay._ir_pass.simplify_inference")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SimplifyInference(args[0]);
});

} // namespace relay
} // namespace tvm
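
The unpacking is a pure algebraic rewrite (and the dropout branch simply forwards its input, since dropout is the identity at inference time). A quick numpy check of the identity, with made-up values:

import numpy as np

x = np.random.randn(4).astype('float32')
mean, var = np.float32(0.5), np.float32(2.0)
gamma, beta, eps = np.float32(1.5), np.float32(-0.25), np.float32(0.01)

bn = (x - mean) / np.sqrt(var + eps) * gamma + beta  # definition

scale = gamma / np.sqrt(var + eps)  # precomputed once by the pass
shift = beta - mean * scale
np.testing.assert_allclose(x * scale + shift, bn, rtol=1e-6)
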
47 changes: 47 additions & 0 deletions tests/python/relay/test_pass_simplify_inference.py
@@ -0,0 +1,47 @@
from tvm import relay as rly
from tvm.relay.ir_pass import simplify_inference

def test_simplify_batchnorm():
def simple_bn(x, gamma, beta, moving_mean, moving_var,
axis=1, epsilon=1e-5, shape=None):
# expect = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
scale = rly.multiply(rly.const(1, 'float32') /
rly.sqrt(moving_var + rly.const(epsilon, 'float32')), gamma)
shift = rly.add(
rly.multiply(rly.negative(moving_mean), scale), beta)
num_newaxis = len(shape) - (axis + 1)
if num_newaxis:
scale = rly.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
shift = rly.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
return x * scale + shift

def check(dim, axis, nstep):
eps = 0.01
ttype1 = rly.TensorType(tuple(10 for i in range(dim)), 'float32')
ttype2 = rly.TensorType((10,), 'float32')
x = rly.var("x", ttype1)
beta = rly.var("beta", ttype2)
gamma = rly.var("gamma", ttype2)
moving_var = rly.var("moving_var", ttype2)
moving_mean = rly.var("moving_mean", ttype2)
y1, y2 = x, x

for _ in range(nstep):
y1, _, _ = rly.nn.batch_norm(y1 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = rly.nn.dropout(y1)
y1 = rly.ir_pass.infer_type(y1)
y1 = simplify_inference(y1)

y2 = simple_bn(y2 + rly.const(1, 'float32'),
gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ttype1.shape)
assert rly.ir_pass.graph_equal(y1, y2)

check(2, 1, 1)
check(4, 1, 1)
check(4, 0, 3)


if __name__ == "__main__":
test_simplify_batchnorm()