Skip to content

Commit

Permalink
[Operator] Fix symbolic broadcasting (#131)
Browse files Browse the repository at this point in the history
The previous implementation is incorrect when dealing with a pair of
dimensions that are both symbolic. Minimal example:

import hidet

if __name__ == "__main__":
    x = hidet.symbol(["n"])
    y = hidet.symbol(["m"])
    z = x + y
    print(x.shape, y.shape, z.shape) # before: (n,) (m,) (m,)
  • Loading branch information
jacklee1792 authored and vadiklyutiy committed Jul 23, 2024
1 parent 61b0052 commit 1252220
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 4 deletions.
29 changes: 27 additions & 2 deletions python/hidet/ir/dialects/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,30 @@ def __init__(self, pattern, target, message=""):
self.target = target


def _certainly_equal(lhs: Expr, rhs: Expr) -> bool:
"""
Perform a conservative check of equality by traversing the expression trees in parallel.
"""
if lhs.__class__ != rhs.__class__:
return False

if lhs is rhs:
return True
if isinstance(lhs, Constant):
return lhs == rhs
if isinstance(lhs, BinaryExpr):
rhs: BinaryExpr
if _certainly_equal(lhs.a, rhs.a) and _certainly_equal(lhs.b, rhs.b):
return True
# Check (b, a) == (a, b) if commutative
commutative_ops = (Add, Multiply, BitwiseXor, LogicalAnd, LogicalOr)
if not isinstance(lhs, commutative_ops):
return False
return _certainly_equal(lhs.a, rhs.b) and _certainly_equal(lhs.b, rhs.a)

return False


class MatchContext:
def __init__(self, matcher: PatternMatcher, pattern: Expr, target: Expr):
self.matcher: PatternMatcher = matcher
Expand All @@ -45,9 +69,10 @@ def __init__(self, matcher: PatternMatcher, pattern: Expr, target: Expr):
def __enter__(self):
if self.pattern in self.matched:
if self.matched[self.pattern] is not self.target:
# we think the constant with the same value as the same object
# We have already matched a term in the pattern to another value earlier (call it lhs).
# If we cannot be certain that they are the same value, return no match.
lhs, rhs = self.matched[self.pattern], self.target
if isinstance(lhs, Constant) and isinstance(rhs, Constant) and lhs == rhs:
if _certainly_equal(lhs, rhs):
return
raise NotMatchedError(self.pattern, self.target, 'Can not match a pattern to two different targets')
else:
Expand Down
41 changes: 39 additions & 2 deletions python/hidet/ir/utils/broadcast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, List

import hidet.option
from hidet.ir.expr import Expr, Int, is_constant, if_then_else
from hidet.utils import repeat_until_converge


def can_broadcast(src_shape: Sequence[Int], dst_shape: Sequence[Int]) -> bool:
Expand Down Expand Up @@ -47,18 +50,52 @@ def broadcast_shape(x_shape: Sequence[Int], y_shape: Sequence[Int]) -> List[Int]
y_shape = [int32(1)] + y_shape
result_shape = []
for p, q in zip(x_shape, y_shape):
# Case 1: one of the dimensions is the constant 1
if is_constant(p) and p == 1:
result_shape.append(q)
elif is_constant(q) and q == 1:
result_shape.append(p)
# Case 2: both dimensions are constant
elif is_constant(p, q):
if p != q:
raise ValueError(
'can not broadcast two arrays with shape {} and {}'.format(orig_shapes[0], orig_shapes[1])
'Cannot broadcast operands with shape {} and {}'.format(orig_shapes[0], orig_shapes[1])
)
result_shape.append(p)
# Case 3: exactly one of the dimensions is constant, assume the symbolic dimension is 1
elif is_constant(p):
result_shape.append(p)
elif is_constant(q):
result_shape.append(q)
# Case 4: both dimensions are symbolic, this is only allowed if the dimensions are the same expression or at
# least one of them resolves to 1.
else:
result_shape.append(p if is_constant(p) else q)
if not hidet.option.get_option('debug_strict_broadcast_check'):
# Assume p == q
result_shape.append(p)
continue

from hidet.transforms.rule_based_simplifier import RuleBasedSimplifier

simp = RuleBasedSimplifier()
p = repeat_until_converge(simp, p)
q = repeat_until_converge(simp, q)

if is_constant(p) and p == 1:
result_shape.append(q)
elif is_constant(q) and q == 1:
result_shape.append(p)
else:
diff = repeat_until_converge(simp, p - q)
if not is_constant(diff) or diff != 0:
raise ValueError(
"Broadcasting between operands with symbolic shapes {} and {} is ambiguous,"
" consider explicitly broadcasting before the operator to resolve this ambiguity".format(
*orig_shapes
)
)
result_shape.append(p)

return result_shape


Expand Down
27 changes: 27 additions & 0 deletions python/hidet/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ def register_hidet_options():
Hint: all variable ids will be 0 unless the debug_enable_var_id option is set to True.',
choices=[True, False],
)
register_option(
name='debug_strict_broadcast_check',
type_hint='bool',
default_value=False,
description=(
'Whether to enforce equality of shapes in symbolic broadcasts.'
' If set to True, the symbolic equivalence checker is used to prove correctness of broadcasts,'
' so broadcasting shapes [n] to [m] will raise ValueError. If set to False, broadcasting between'
' shapes [n] and [m] will proceed assuming n == m.'
),
)
register_option(
name='runtime_check',
type_hint='bool',
Expand Down Expand Up @@ -793,6 +804,22 @@ def debug_show_var_id(enable: bool = True):
OptionContext.current().set_option('debug_show_var_id', enable)


def debug_strict_broadcast_check(enable: bool = False):
"""
Whether to enforce equality of shapes in symbolic broadcasts.
If set to True, the symbolic equivalence checker is used to prove correctness of broadcasts,
so broadcasting shapes [n] to [m] will raise ValueError. If set to False, broadcasting between
shapes [n] and [m] will proceed assuming n == m.
Parameters
----------
enable: bool
Whether to enforce equality of shapes in symbolic broadcasts.
"""
OptionContext.current().set_option('debug_strict_broadcast_check', enable)


def runtime_check(enable: bool = True):
"""
Whether to check shapes and dtypes of all input arguments to compiled Graphs or Tasks.
Expand Down
4 changes: 4 additions & 0 deletions python/hidet/transforms/rule_based_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Dict
from itertools import product

from hidet.ir import dtypes
from hidet.ir.dialects.pattern import PlaceholderExpr, match
from hidet.ir.dtypes import boolean
from hidet.ir.expr import Add, convert, Sub, Multiply, Mod, LessThan, LessEqual, Equal, BinaryExpr, LogicalAnd
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self):
((c1 - e1) + e2, (e2 - e1) + c1),
((e1 - c1) + e2, (e1 + e2) - c1),
# sub
(ec1 - ec1, dtypes.int32.zero),
((c1 + e1) - e2, (e1 - e2) + c1),
(e1 - (c1 + e2), (e1 - e2) - c1),
((c1 - e1) - e2, c1 - (e1 + e2)),
Expand All @@ -125,6 +127,7 @@ def __init__(self):
((e1 - c1) * c2, e1 * c2 - c1 * c2, logical_and(c2 <= 1e5, -c2 <= 1e5)),
((e1 * c1) * c2, e1 * (c1 * c2)),
# div
(ec1 // ec1, dtypes.int32.one),
(((e1 * c1) + (e2 % c1)) // c1, e1),
((e1 // c1) // c2, e1 // (c1 * c2)),
((e1 * c1) // c1, e1),
Expand All @@ -145,6 +148,7 @@ def __init__(self):
# if then else
(IfThenElse(true, ec1, ec2), ec1),
(IfThenElse(false, ec1, ec2), ec2),
(IfThenElse(ec1, ec2, ec2), ec2),
]
self.bound_patterns = [
# ((pattern_args, pattern_func, target_args, target_func)
Expand Down
43 changes: 43 additions & 0 deletions tests/operators/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
import pytest
import numpy as np
import torch

import hidet
import hidet as hi
from hidet import ops
from hidet.ir.utils import broadcast_shape
from hidet.utils import prod


Expand Down Expand Up @@ -147,5 +150,45 @@ def test_broadcast(shape, broadcast_shape):
check_transform(shape, lambda x: x + np.zeros(broadcast_shape), lambda x: ops.broadcast(x, broadcast_shape))


def test_symbolic_broadcast():
"""
Test broadcasting semantics with symbolic shapes.
"""

n = hidet.symbol_var("n")
m = hidet.symbol_var("m")

# When strict broadcasting check is disabled, pairs of symbolic dimensions are assumed to be equal and no error
# is raised
with hidet.option.context():
hidet.option.debug_strict_broadcast_check(False)
broadcast_shape([n, m], [m, n])
broadcast_shape([n], [m, m])

with hidet.option.context():
hidet.option.debug_strict_broadcast_check(True)

# Broadcasting between these shapes with the strict broadcasting check enabled will raise an error
with pytest.raises(ValueError):
broadcast_shape([n, m], [m, n])
with pytest.raises(ValueError):
broadcast_shape([n], [m, m])

# If one dimension is 1, the broadcast result takes on the other dimension, even if it is symbolic
assert broadcast_shape([n, 1], [1, 2]) == [n, 2]
assert broadcast_shape([1], [n, n]) == [n, n]

# Pairs of symbolic dimensions don't necessarily have to be the same if one of them can be resolved to 1.
assert broadcast_shape([m // m], [n]) == [n]
assert broadcast_shape([n - n + 1], [3]) == [3]

# In the case where exactly one dimension is symbolic, the symbolic dimension is assumed to be 1.
assert broadcast_shape([2, 3], [n, n]) == [2, 3]

# This should never work without further conditions on n and m
with pytest.raises(ValueError):
broadcast_shape([n], [m])


if __name__ == '__main__':
pytest.main([__file__])
34 changes: 34 additions & 0 deletions tests/transforms/test_rule_based_simplifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import hidet
from hidet.ir.expr import if_then_else
from hidet.transforms.rule_based_simplifier import RuleBasedSimplifier
from hidet.utils import repeat_until_converge


def test_rule_based_simplify():
n = hidet.symbol_var("n")
m = hidet.symbol_var("m")
testcase = namedtuple("testcase", ["expr", "expected"])
cases = [
testcase(expr=(n + 1) - (1 + n), expected=0),
testcase(expr=if_then_else(n > 0, 1, 1), expected=1),
testcase(expr=(n + m - m) - (m + n - m), expected=0),
testcase(expr=(n + m) - (m + n), expected=0),
testcase(expr=n / n, expected=1),
]

simp = RuleBasedSimplifier()
for expr, expected in cases:
assert repeat_until_converge(simp, expr) == expected

0 comments on commit 1252220

Please sign in to comment.