Skip to content

Commit

Permalink
Fix edge cases in const_int_bound and fold_scale_axis (apache#6911)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Dec 4, 2020
1 parent ee7788f commit 8f903c5
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ class ConstIntBoundAnalyzer::Impl
*/
static Entry MakeBound(int64_t min_value, int64_t max_value) {
Entry e;
e.min_value = min_value;
e.max_value = max_value;
e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value;
e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value;
return e;
}
/*!
Expand Down
13 changes: 12 additions & 1 deletion src/relay/transforms/fold_scale_axis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,18 @@ class ForwardPrep : private ExprVisitor {
}
}
// Visitor pattern override.
void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; }
void VisitExpr_(const LetNode* op) {
ExprVisitor::VisitExpr_(op);
// do pass through condition
// by assigning NullValue<Message>
// it means fuse signal cannot pass
// through into these subexpressions.
auto flazy = [this, op]() {
this->Update(op->value, NullValue<Message>());
this->Update(op->body, NullValue<Message>());
};
flist_.push_back(flazy);
}

void VisitExpr_(const FunctionNode* op) {
ExprVisitor::VisitExpr_(op);
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relay/test_pass_fold_scale_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,44 @@ def check(shape, channels, blocking, in_scale):
check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)


def test_fold_fwd_let_fail():
"""testcase where we canont fold"""

def before(x, conv_weight, in_bias, in_scale, channels):
args = [x, conv_weight, in_bias]
x = relay.multiply(x, in_scale)
x = relay.nn.relu(x)
x = relay.add(x, in_bias)
x_var = relay.Var("x_var")
y1 = relay.nn.conv2d(
x_var,
conv_weight,
channels=channels,
kernel_size=(3, 3),
data_layout="NHWC",
kernel_layout="HWIO",
padding=(1, 1),
)
z = relay.add(y1, x)
let = relay.Let(x_var, x, z)
return relay.Function(args, let)

def check(shape, channels):
x = relay.var("x", shape=shape)
in_channels = shape[-1]
in_bias = relay.var("in_bias", shape=(in_channels,))
in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
# test depthwise
assert in_channels == channels
weight = relay.var("weight")
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
assert tvm.ir.structural_equal(y1, y1_folded)

check((2, 11, 10, 4), 4)


def test_fold_fwd_negative_scale():
"""Testcase of folding negative scale"""

Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_arith_const_int_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ def test_add_sub_bound():
assert bd.min_value == bd.NEG_INF
assert bd.max_value == 1

## constants with negative or positive max(int64) occassionally show up
## in models, this is to ensure we can handle those cases
analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.NEG_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF

analyzer.update(x, tvm.arith.ConstIntBound(bd.POS_INF, bd.POS_INF), override=True)
analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == bd.NEG_INF
assert bd.max_value == bd.POS_INF


def test_mul_bound():
analyzer = tvm.arith.Analyzer()
Expand Down

0 comments on commit 8f903c5

Please sign in to comment.