Skip to content

Commit

Permalink
Merge pull request #18762 from gnecula:poly_getitem_next
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587018677
  • Loading branch information
jax authors committed Dec 1, 2023
2 parents 54fee48 + 65fca0e commit b3c579e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 14 deletions.
71 changes: 60 additions & 11 deletions jax/experimental/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,9 @@ def normalize(cls, coeffs: dict[_DimMon, int]) -> DimSize:
def normalize_floordiv_times_divisor(cls, coeffs: dict[_DimMon, int]) -> DimSize:
# Look for floordiv(E, M) * M and turn into E - mod(E, M). This comes
# up when handling strided convolution.
for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV):
for dec in _decompose_expr(_DimExpr(coeffs), _DimAtom.FLOORDIV,
with_exp=1):
# e = factor * floordiv(operands)^exp * rest_monomial + rest_expr
if dec.exp != 1:
continue
if dec.rest_monomial == 1 and dec.factor == 1:
continue
m_trimmed, m_remainder = divmod(dec.factor * dec.rest_monomial, dec.operands[1])
Expand Down Expand Up @@ -472,11 +471,33 @@ def inconclusive_comparison(self, operation: str, op: Any) -> Exception:
"See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#comparison-of-symbolic-dimensions-is-partially-supported.")

def ge(self, other: DimSize) -> bool:
lb, ub = _ensure_poly(self - other, "ge").bounds()
self_minus_other = _ensure_poly(self - other, "ge")
lb, ub = self_minus_other.bounds()
if lb >= 0:
return True
if ub < 0:
return False
# Attempt to handle non_negative
for dec in _decompose_expr(self_minus_other, _DimAtom.NON_NEGATIVE):
# e = factor * non_negative(operands)^exp * rest_monomial + rest_expr
e1 = dec.rest_expr
e2 = dec.rest_expr + dec.factor * (dec.operands[0] ** dec.exp) * dec.rest_monomial
try:
if (e1 >= 0) and (e2 >= 0):
return True
except InconclusiveDimensionOperation:
continue
# Attempt to handle floordiv >= 0
for dec in _decompose_expr(self_minus_other, _DimAtom.FLOORDIV,
with_exp=1, with_rest_monomial=1,
with_rest_expr=0):
# e = factor * floordiv(op1, op2)^1 * 1 + 0
if dec.factor > 0:
try:
if (dec.operands[0] >= 0) and (dec.operands[1] >= 0):
return True
except InconclusiveDimensionOperation:
continue
raise self.inconclusive_comparison(">=", other)

def __hash__(self):
Expand Down Expand Up @@ -680,9 +701,10 @@ def bounds(self) -> tuple[float, float]:
# Watch for special-case: ct*a - ct*mod(b, a) >= 1 when ct >= 0 and a >= 0
# TODO(necula): add more principled support for floordiv and mod
# For example, this will miss "1 + a - mod(b, a)"
for dec in _decompose_expr(self, _DimAtom.MOD):
# E = factor*mod(op1, op2)^exp * rest_monomial + rest_expr
if dec.exp == 1 and dec.rest_monomial == 1 and dec.rest_expr == - dec.factor * dec.operands[1]:
for dec in _decompose_expr(self, _DimAtom.MOD,
with_exp=1, with_rest_monomial=1):
# E = factor*mod(op1, op2)^1 * 1 + rest_expr
if dec.rest_expr == - dec.factor * dec.operands[1]:
try:
if dec.operands[1] <= 0:
continue
Expand Down Expand Up @@ -729,7 +751,9 @@ def __jax_array__(self):
class _Decomposition:
"""Decomposition of an expression around an operation atom.
E = factor * mod(*operands)^exp * rest_monomial + rest_expr
E.g., for decomposing around "mod":
E = factor * mod(*operands)^exp * rest_monomial + rest_expr
"""
factor: int
operands: Sequence[_DimExpr]
Expand All @@ -738,19 +762,44 @@ class _Decomposition:
rest_expr: _DimExpr


def _decompose_expr(e: _DimExpr, operation: str) -> Iterable[_Decomposition]:
def _decompose_expr(e: _DimExpr, operation: str, *,
with_factor: Optional[int] = None,
with_exp: Optional[int] = None,
with_rest_monomial: Optional[Union[_DimExpr, int]] = None,
with_rest_expr: Optional[Union[_DimExpr, int]] = None,
) -> Iterable[_Decomposition]:
"""Computes the decompositions of `e` into `_Decomposition`.
Args:
e: the expression to decompose
operation: the operation atom around which to decompose
with_factor, with_exp, with_rest_monomial, with_rest_expr: if present,
keep only the decompositions that match.
"""
for m, m_factor in e.monomials():
atoms = [(a, aexp) for a, aexp in m.items() if a.operation == operation]
if atoms:
e_minus_m_coeffs = e._coeffs.copy()
del e_minus_m_coeffs[m]
for a, aexp in atoms:
if with_factor is not None and with_factor != m_factor:
continue
if with_exp is not None and with_exp != aexp:
continue
rest_monomial = _DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1})
if (with_rest_monomial is not None and
not core.definitely_equal(with_rest_monomial, rest_monomial)):
continue
rest_expr = _DimExpr(e_minus_m_coeffs)
if (with_rest_expr is not None and
not core.definitely_equal(with_rest_expr, rest_expr)):
continue
yield _Decomposition(
factor=m_factor,
operands=a.operands,
exp=aexp,
rest_monomial=_DimExpr({m.divide(_DimMon.from_atom(a, aexp)): 1}),
rest_expr=_DimExpr(e_minus_m_coeffs))
rest_monomial=rest_monomial,
rest_expr=rest_expr)

core.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
xla.pytype_aval_mappings[_DimExpr] = _DimExpr.get_aval
Expand Down
3 changes: 0 additions & 3 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2353,9 +2353,6 @@ def test_harness(self, harness: PolyHarness):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")

if harness.group_name == "indexing":
raise unittest.SkipTest("TODO(necula): fix the indexing tests")

prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
Expand Down

0 comments on commit b3c579e

Please sign in to comment.