[Relay] Add support for relay expressions as pad value for static pad (apache#7860)

* add support for expr as inputs to pad

* fix improper number of args

* add dynamic padding test

* improve infer type test

* add comments to type relations

* fix infer type layouts

* proper return shape

* proper shape infer type

* make the tests pass by setting the conditions

* make codegen reflect reality

* make ternary operations more pythonic

* proper infer layout

* fold explicit padding

* fix pattern matching in contrib

* revert tests for contrib now that pattern matching works

* revert import changes

* add newline
AndrewZhaoLuo authored and Trevor Morris committed May 6, 2021
1 parent 4e42e80 commit 6abb2c4
Showing 11 changed files with 101 additions and 66 deletions.
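In short, relay.nn.pad can now take a relay expression for pad_value while pad_width stays static, instead of requiring a Python scalar that gets baked into the op attributes. A minimal usage sketch, modeled on the test added in this commit (shapes and variable names are illustrative only):

    import numpy as np
    from tvm import relay

    # Static pad widths, but the pad value is a relay variable supplied at run time.
    data = relay.var("data", shape=(4, 6, 3, 5), dtype="float32")
    pad_value = relay.var("pad_value", dtype="float32")
    out = relay.nn.pad(data, ((1, 1), (2, 2), (0, 0), (4, 4)), pad_value=pad_value)
    func = relay.Function([data, pad_value], out)

    data_np = np.random.uniform(size=(4, 6, 3, 5)).astype("float32")
    # The fill value is passed as an ordinary input rather than a compile-time attribute.
    result = relay.create_executor("graph").evaluate(func)(data_np, 2.0)
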
3 changes: 0 additions & 3 deletions include/tvm/relay/attrs/nn.h
@@ -1044,13 +1044,10 @@ struct UpSampling3DAttrs : public tvm::AttrsNode<UpSampling3DAttrs> {

/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
double pad_value;
Array<Array<Integer>> pad_width;
std::string pad_mode;

TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_value).set_default(0.0).describe(
"The value used for padding when mode is 'constant'.");
TVM_ATTR_FIELD(pad_width).describe(
"Number of values padded to the edges of each axis, "
"in the format of ((before_1, after_1), ..., (before_N, after_N))");
11 changes: 5 additions & 6 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
@@ -17,16 +17,15 @@
# pylint: disable=invalid-name, unused-argument
"""Arm Compute Library supported operators."""
import tvm

from tvm import relay
from tvm._ffi import register_func
from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import const

from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
from .register import register_pattern_table
from ...dataflow_pattern import is_constant, is_expr, is_op, wildcard
from ..strategy.generic import is_depthwise_conv2d
from .register import register_pattern_table


def is_arm_compute_runtime_enabled():
@@ -140,7 +139,7 @@ def conv_pattern():
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
pattern = is_op("nn.pad")(wildcard()) | wildcard()
pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
pattern = is_op("nn.conv2d")(pattern, is_constant())
pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
pattern = pattern.optional(is_op("nn.relu"))
@@ -154,7 +153,7 @@ def qnn_conv_pattern():
pattern : dataflow_pattern.AltPattern
Denotes the convolution pattern.
"""
pattern = is_op("nn.pad")(wildcard()) | wildcard()
pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
pattern = is_op("qnn.conv2d")(
pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
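Because nn.pad now has two inputs (data and pad value), every dataflow pattern that used to match is_op("nn.pad")(wildcard()) needs a second wildcard, which is what the changes in this file and in ethosn.py do. A small sanity-check sketch of the updated pattern outside the BYOC tables (illustrative only):

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    x = relay.var("x", shape=(1, 3, 8, 8), dtype="float32")
    padded = relay.nn.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))

    # One wildcard for the data and one for the pad value, which is now a call argument.
    pat = is_op("nn.pad")(wildcard(), wildcard())
    assert pat.match(padded)
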
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/ethosn.py
@@ -22,10 +22,10 @@
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import wildcard, is_op, is_constant
from ... import qnn as _qnn
from .register import register_pattern_table
from ...dataflow_pattern import is_constant, is_op, wildcard
from . import _ethosn as support
from .register import register_pattern_table


class Available(Enum):
@@ -82,7 +82,7 @@ def pattern_table():
"""Get the Ethos-N compiler pattern table."""

def qnn_conv_pattern():
pattern = is_op("nn.pad")(wildcard()) | wildcard()
pattern = is_op("nn.pad")(wildcard(), wildcard()) | wildcard()
pattern = is_op("qnn.conv2d")(
pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
)
14 changes: 5 additions & 9 deletions python/tvm/relay/op/nn/nn.py
@@ -18,10 +18,10 @@
"""Neural network operations."""
from tvm.relay import expr

from . import _make
from ...expr import Constant, Expr, const
from ..dyn.nn import _make as _dyn_make
from . import _make
from .utils import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d
from ...expr import const, Expr, Constant


def conv1d(
@@ -1606,15 +1606,11 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
result : tvm.relay.Expr
The computed result.
"""
if isinstance(pad_value, Constant):
pad_value = pad_value.data.asnumpy().item()
if isinstance(pad_width, Constant):
pad_width = [list(i) for i in pad_width.data.asnumpy()]
if isinstance(pad_width, Expr) or (isinstance(pad_value, Expr)):
if not isinstance(pad_width, Expr):
pad_width = const(list(pad_width))
if not isinstance(pad_value, Expr):
pad_value = const(pad_value)
if not isinstance(pad_value, Expr):
pad_value = const(pad_value)
if isinstance(pad_width, Expr):
return _dyn_make.pad(data, pad_width, pad_value, pad_mode)
return _make.pad(data, pad_width, pad_value, pad_mode)

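With this dispatch, a plain Python pad_value is wrapped in relay.const, and the static op is emitted whenever pad_width is concrete; only an expression pad_width falls back to the dynamic op. The value now travels as the second call argument instead of an attribute. A short sketch of what the resulting call looks like (a sanity check, not part of the commit):

    from tvm import relay

    x = relay.var("x", shape=(1, 4), dtype="float32")
    call = relay.nn.pad(x, ((0, 0), (1, 1)), pad_value=2.0)

    assert call.op.name == "nn.pad"
    # Two call arguments now: the data tensor and the const-wrapped pad value.
    assert len(call.args) == 2
    assert isinstance(call.args[1], relay.Constant)
    # pad_value no longer lives on the attrs; only pad_width and pad_mode remain there.
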
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
@@ -58,7 +58,7 @@ Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_l

Expr MakeOnes(Array<Integer> shape, DataType dtype);

Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode);
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode);

Expr MakeReduce(Expr data, Array<Integer> axis, bool keepdims, bool exclude, String op_name);

40 changes: 22 additions & 18 deletions src/relay/op/nn/pad.cc
@@ -25,6 +25,7 @@
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/op.h>
#include <tvm/topi/elemwise.h>
#include <tvm/topi/nn.h>

#include <vector>
@@ -44,7 +45,7 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layou
// NOTE: Discard "const" qualifier here.
PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());

Layout ret;
Layout ret_data;
// If new_in_layouts are defined, this code tries to modify the layout.
bool is_layout_modified = new_in_layouts.defined();
if (new_in_layouts.defined()) {
@@ -55,8 +56,8 @@
// 1) Create a map from axis to param_width using old layout.
std::map<std::string, tvm::Array<Integer>> axis_pad_width;
int index_counter = 0;
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);
ICHECK_EQ(new_in_layouts.size(), 2);
ICHECK_EQ(old_in_layouts.size(), 2);
for (auto iter_var : old_in_layouts[0]->axes) {
const auto& old_layout_axis = LayoutAxis::Get(iter_var);
axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
@@ -95,26 +96,29 @@
// If the above conditions satisfied, we can set the newly created pad_width and use the new
// layout.
if (is_layout_modified) {
ret = new_in_layouts[0];
ret_data = new_in_layouts[0];
params->pad_width = new_pad_width;
}
}

if (!is_layout_modified) {
if (old_in_layouts.defined()) {
ICHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];
ICHECK_EQ(old_in_layouts.size(), 2);
ret_data = old_in_layouts[0];
} else {
ret = Layout::Undef();
ret_data = Layout::Undef();
}
}

return Array<Array<Layout>>{{ret}, {ret}};
// The pad value is always a scalar
Layout ret_pad_value = Layout("1");
return Array<Array<Layout>>{{ret_data, ret_pad_value}, {ret_data}};
}

bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 2);
// types = [pad_data_type, pad_value_type, ret_type]
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

@@ -151,7 +155,7 @@ bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
}

reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
reporter->Assign(types[2], TensorType(Array<IndexExpr>(oshape), data->dtype));
return true;
}

@@ -170,20 +174,19 @@ Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs
for (size_t i = 0; i < pad_width.size(); ++i) {
pad_after.push_back(pad_width[i][1]);
}
const auto* out_ttype = out_type.as<TensorTypeNode>();
return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after,
tvm::tir::make_const(out_ttype->dtype, param->pad_value),
"T_pad", topi::kElementWise, param->pad_mode)};
te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype);
const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>());
return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad",
topi::kElementWise, param->pad_mode)};
}

// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, double pad_value, String pad_mode) {
Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) {
auto attrs = make_object<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("nn.pad");
return Call(op, {data}, Attrs(attrs), {});
return Call(op, {data, pad_value}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad);
@@ -193,8 +196,9 @@ RELAY_REGISTER_OP("nn.pad")
)code" TVM_ADD_FILELINE)
.set_attrs_type<PadAttrs>()
.set_num_inputs(1)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("pad_val", "Tensor", "The value to fill the padded area with")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
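The type relation now sees three types, [data, pad_value, result], and the compute casts the pad value tensor to the data dtype before calling topi::pad, so type inference works even when the pad value is a free variable. A sketch mirroring the infer-type test added below (shapes are illustrative):

    from tvm import relay
    from tvm.relay.testing import run_infer_type

    t = relay.var("t", shape=(1, 2, 3, 4), dtype="float32")
    v = relay.var("v", dtype="float32")  # pad value as a relay expression
    y = relay.nn.pad(t, ((0, 0), (1, 1), (1, 1), (2, 2)), pad_value=v)
    yy = run_infer_type(y)
    # The output keeps the data dtype and the shape implied by pad_width.
    assert yy.checked_type == relay.TensorType((1, 4, 5, 8), "float32")
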
3 changes: 1 addition & 2 deletions src/relay/qnn/op/convolution.cc
@@ -234,8 +234,7 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2D
} else {
LOG(FATAL) << "qnn.conv2d does not support " << param->data_layout << " layout";
}
auto pad_value = GetScalarFromConstant<int>(input_zero_point);
padded_data = Pad(data, pad_width, pad_value, "constant");
padded_data = Pad(data, pad_width, input_zero_point, "constant");
}
return padded_data;
}
4 changes: 3 additions & 1 deletion src/relay/transforms/dynamic_to_static.cc
@@ -179,7 +179,9 @@

const PadAttrs* param = call_node->attrs.as<PadAttrs>();
ICHECK(param);
return MakePad(call_node->args[0], ToMatrix(pad_width->data), ToScalar(pad_fill->data),

Expr pad_value = args[2];
return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_value,
param->pad_mode);
}
return Expr(nullptr);
8 changes: 6 additions & 2 deletions src/relay/transforms/fold_explicit_padding.cc
@@ -45,7 +45,7 @@ class SimplifyConvPad {
SimplifyConvPad() {
x_ = IsWildcard();
w_ = IsWildcard();
pad_ = IsOp("nn.pad")({x_});
pad_ = IsOp("nn.pad")({x_, IsWildcard()});
conv1d_ = IsOp("nn.conv1d");
conv2d_ = IsOp("nn.conv2d");
conv3d_ = IsOp("nn.conv3d");
@@ -119,7 +119,11 @@ class SimplifyConvPad {
ICHECK(pad_node);
const PadAttrs* param = pad_node->attrs.as<PadAttrs>();
ICHECK(param);
if (param->pad_mode == "constant" && param->pad_value == 0.0) {
Array<Expr> args = pad_node->args;

// Possibly perform more optimizations if the pad_value is 0
const ConstantNode* pad_value = args[1].as<ConstantNode>();
if (param->pad_mode == "constant" && pad_value && ToScalar(pad_value->data) == 0.0) {
Attrs attrs;
if (node_map.count(conv1d_)) {
attrs = GetAttrs(param, call_node->attrs.as<Conv1DAttrs>());
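FoldExplicitPadding keeps working, but it now has to look at the pad value argument: the pad is only folded into the convolution when that argument is a constant equal to zero, and is left untouched otherwise. A minimal sketch of the pass on a constant-zero pad (an assumed-typical invocation, not taken from the commit):

    import tvm
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 3, 8, 8), dtype="float32")
    w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
    # Explicit constant-zero padding followed by an unpadded conv2d ...
    padded = relay.nn.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    y = relay.nn.conv2d(padded, w, kernel_size=(3, 3), channels=8)

    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    mod = transform.InferType()(mod)
    # ... which the pass can merge into the conv's own padding attribute, because the
    # pad value is a constant 0; a non-constant pad value would be left alone.
    mod = transform.FoldExplicitPadding()(mod)
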
2 changes: 1 addition & 1 deletion src/relay/transforms/pattern_utils.h
@@ -676,7 +676,7 @@ static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexE
count_include_pad, "nn.avg_pool2d");
}

static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, Expr pad_value,
std::string pad_mode) {
Array<Array<Integer>> pad_width_int;
for (size_t i = 0; i < pad_width.size(); ++i) {
74 changes: 54 additions & 20 deletions tests/python/relay/test_op_level2.py
@@ -18,15 +18,13 @@
"""
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from tvm import relay
import tvm.testing
import tvm.topi.testing
from tvm import autotvm, relay, te
from tvm.contrib import utils
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm.contrib import utils
import tvm.topi.testing
from tvm.topi.cuda.conv3d_winograd import _infer_tile_size
import tvm.testing


@tvm.testing.uses_gpu
@@ -1197,6 +1195,30 @@ def test_pad_infer_type():
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n + (-2), c + (-4), h + (-2), w + 8), "float32")

# dealing with dynamic vals
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(
t, ((1, 1), (2, 2), (3, 3), (4, 4)), pad_value=relay.var("pad_value", "float32")
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")


def _get_numpy_pad(dshape, data, pad, pad_value=0):
mod_pad = []
for axis, (pad_x, pad_y) in enumerate(pad):
indices = range(dshape[axis])
if pad_x < 0:
indices = indices[abs(pad_x) :]
pad_x = 0
if pad_y < 0:
indices = indices[:pad_y]
pad_y = 0
data = np.take(data, indices, axis)
mod_pad.append((pad_x, pad_y))
return np.pad(data, tuple(mod_pad), "constant", constant_values=pad_value)


@tvm.testing.uses_gpu
def test_pad_run():
@@ -1209,20 +1231,7 @@ def _test_run(dtype):
y = relay.nn.pad(x, pad)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
mod_pad = []
mod_data = data
for axis, (pad_x, pad_y) in enumerate(pad):
indices = range(dshape[axis])
if pad_x < 0:
indices = indices[abs(pad_x) :]
pad_x = 0
if pad_y < 0:
indices = indices[:pad_y]
pad_y = 0
mod_data = np.take(mod_data, indices, axis)
mod_pad.append((pad_x, pad_y))

ref_res = np.pad(mod_data, tuple(mod_pad), "constant")
ref_res = _get_numpy_pad(dshape, data, pad)
for target, dev in tvm.testing.enabled_targets():
intrp1 = relay.create_executor("graph", device=dev, target=target)
op_res1 = intrp1.evaluate(func)(data)
@@ -1232,6 +1241,30 @@ def _test_run(dtype):
_test_run("int32")


@tvm.testing.uses_gpu
def test_pad_run_dynamic_pad_value():
def _test_run(dtype):
dshape = (4, 6, 3, 5)
pad = ((-1, -1), (2, -2), (0, -2), (4, 4))

data = relay.var("data", shape=dshape, dtype=dtype)
pad_value = relay.var("pad_value", dtype)
pad_data = relay.nn.pad(data, pad, pad_value=pad_value)
f = relay.Function([data, pad_value], pad_data)

data_arr = np.random.uniform(-10, 10, size=dshape).astype(dtype)
pad_value_arr = 2.0
ref_res = _get_numpy_pad(dshape, data_arr, pad, pad_value=pad_value_arr)

for target, dev in tvm.testing.enabled_targets():
intrp = relay.create_executor(kind="graph", device=dev, target=target)
result = intrp.evaluate(f)(data_arr, pad_value_arr)
tvm.testing.assert_allclose(result.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

_test_run("float32")
_test_run("int32")


@tvm.testing.uses_gpu
def test_lrn():
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
Expand Down Expand Up @@ -1766,6 +1799,7 @@ def _test_correlation(
test_flatten_infer_type()
test_pad_infer_type()
test_pad_run()
test_pad_run_dynamic_pad_value()
test_conv3d_transpose_infer_type()
test_conv3d_transpose_ncdhw_run()
test_conv2d_transpose_infer_type()
