
Commit

Ensured that JAX type checks under pytype on Python 3.12
Some errors uncovered by pytype look genuine and need to be revisited in the future.

PiperOrigin-RevId: 704268742
superbobry authored and Google-ML-Automation committed Dec 9, 2024
1 parent 5a1c4c5 commit 1ac6b76
Showing 6 changed files with 17 additions and 9 deletions.
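
Several of the hunks below silence one specific pytype check with an inline pragma rather than changing runtime behavior. A minimal, self-contained sketch of that pragma style (hypothetical function, not code from this commit):

def trailing_dim(x):
  # pytype cannot prove that `x` has a `.shape` attribute here, so the
  # specific error is disabled for this region only and re-enabled after.
  # pytype: disable=attribute-error
  dim = x.shape[-1]
  # pytype: enable=attribute-error
  return dim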
2 changes: 1 addition & 1 deletion jax/_src/array.py
@@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
        bufs.append(buf)
        break
    else:
-      bufs.append(buf)
+      bufs.append(candidates_list[-1])
  return pxla.batched_device_put(x.aval, sharding, bufs, devices)
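
The change above relies on Python's for/else: the else suite runs only when the loop finishes without hitting break, so the fallback buffer is now taken explicitly from candidates_list instead of the loop variable. A standalone sketch of the construct, with hypothetical names unrelated to JAX:

def find_or_last(values, target):
  # for/else: the else clause runs only if the loop did not `break`.
  for v in values:
    if v == target:
      found = v
      break
  else:
    found = values[-1]  # fall back to the last candidate explicitly
  return found

assert find_or_last([1, 2, 3], 2) == 2
assert find_or_last([1, 2, 3], 9) == 3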


9 changes: 5 additions & 4 deletions jax/_src/export/shape_poly.py
@@ -1992,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes(
  generate the code for computing the dimension variables. It also generates
  the shape assertions.
-  Returns: the values of the dimension variables, in the order determined by
+  Returns:
+    The values of the dimension variables, in the order determined by
    `all_dim_vars(args_avals)`.
  """
  dim_vars = all_dim_vars(args_avals)
@@ -2006,8 +2007,7 @@ def compute_dim_vars_from_arg_shapes(
  }
  synthetic_eval = ShapeEvaluator(synthetic_env)
  shape_constraints.shape_assertions(synthetic_eval)
-  dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
-  return tuple(dim_values)
+  return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)

def _solve_dim_equations(
    eqns: list[_DimEquation],
@@ -2141,7 +2141,8 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
    eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
    if not eqns:
      add_explicit_symbolic_constraints(shape_env)
-      return shape_env, shape_constraints  # SUCCESS
+      # SUCCESS
+      return shape_env, shape_constraints  # pytype: disable=bad-return-type
    elif len(eqns) >= nr_eqns:
      break

1 change: 1 addition & 0 deletions jax/_src/interpreters/mlir.py
@@ -1699,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
  # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
  # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
  # The below custom call achieves the sharding like above example.
+  assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
  if config.use_shardy_partitioner.value:
    physical_ndim = core.physical_aval(aval).ndim
    s = sharding_impls.SdyArraySharding(
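
The added assert is a standard way to narrow a loosely typed parameter for a static checker: after the isinstance assertion, pytype treats aval as one of the listed classes for the rest of the function. A minimal sketch of the idea, using hypothetical types unrelated to JAX:

class Boxed:
  def __init__(self, value: int):
    self.value = value

def unbox(x: object) -> int:
  # Without the assert, `x.value` is an attribute-error for the checker;
  # with it, `x` is narrowed to `Boxed` from here on.
  assert isinstance(x, Boxed)
  return x.value

print(unbox(Boxed(3)))  # 3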
10 changes: 8 additions & 2 deletions jax/_src/interpreters/partial_eval.py
@@ -177,9 +177,12 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer:
    if const is None:
      aval = pval.get_aval()
      if type(aval) is DShapedArray:
+        # TODO(dougalm): Fix the type error and remove the pytype pragmas.
+        # pytype: disable=attribute-error
        shape = [self.new_instantiated_const(d)
                 if isinstance(d, Tracer) and d._trace.level < self.level else d
                 for d in aval.shape]
+        # pytype: enable=attribute-error
        aval = aval.update(shape=tuple(shape))
      return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
    else:
@@ -1776,6 +1779,9 @@ def lit(a: Atom) -> Literal | None:
  newvars: dict[Var, Var] = {}
  newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
  var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
+  lit_or_var = (
+      lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
+  )
  dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))

def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@@ -1794,10 +1800,10 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
  new_invars = [var(v) for v in jaxpr.invars]
  new_eqns = []
  for eqn in jaxpr.eqns:
-    invars = [lit(x) or var(x) for x in eqn.invars]
+    invars = [lit_or_var(x) for x in eqn.invars]
    outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
    new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
-  new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
+  new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
  jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
                                     new_eqns)
  new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
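
One plausible reading of the new lit_or_var helper (my inference, not stated in the commit): `lit(x) or var(x)` hands every Atom to `var`, whose parameter is effectively a Var, so pytype flags the call; branching on `isinstance(a, Literal)` first keeps the argument type precise. A self-contained sketch of the pattern with hypothetical names:

from typing import Union

class Var: ...
class Literal: ...
Atom = Union[Var, Literal]

def rename(v: Var) -> Var:   # accepts only Var
  return v

def lit_or_var(a: Atom) -> Atom:
  # Branch on the type first so only a Var ever reaches rename();
  # calling rename(a) with `a: Atom` would be flagged by the checker.
  return a if isinstance(a, Literal) else rename(a)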
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/conditionals.py
@@ -513,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
  # jaxpr for each branch.
  branches_known_ : list[core.ClosedJaxpr] = []
  branches_staged_: list[core.ClosedJaxpr] = []
-  branch_res_avals: list[core.AbstractValue] = []
+  branch_res_avals: list[list[core.AbstractValue]] = []
  for jaxpr in branches:
    jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
      pe.partial_eval_jaxpr_custom(
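
The widened annotation presumably reflects that each branch contributes a whole list of residual avals, and appending a list to a container annotated with the element type is exactly what a checker flags. A tiny illustration with plain ints, not JAX types:

per_branch: list[list[int]] = []   # previously the analogue of list[int]
for branch in ([1, 2], [3]):
  per_branch.append(list(branch))  # appending a list needs the nested annotation
print(per_branch)                  # [[1, 2], [3]]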
2 changes: 1 addition & 1 deletion jax/experimental/shard_map.py
@@ -1651,7 +1651,7 @@ def _partial_eval_jaxpr_custom_rule(

def _add_reshapes(which, jaxpr_known, jaxpr_staged):
  # add singleton axes to residuals which are from jaxpr_known and are scalars
-  which_ = [w and not v.aval.shape
+  which_ = [w and not v.aval.shape  # pytype: disable=attribute-error
            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
  if not any(which_): return jaxpr_known, jaxpr_staged
  assert not jaxpr_known.constvars and not jaxpr_staged.constvars
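
For context on the `not v.aval.shape` test the pragma is attached to: a rank-0 (scalar) array has an empty shape tuple, which is falsy, so the expression selects scalar residuals. A quick check in plain JAX, independent of shard_map internals:

import jax.numpy as jnp

x = jnp.array(1.0)
print(x.shape)       # ()
print(not x.shape)   # True: an empty shape tuple marks a scalar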
