diff --git a/jax/_src/array.py b/jax/_src/array.py
index d8182976254e..d5f742915284 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -1120,7 +1120,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
         bufs.append(buf)
         break
     else:
-      bufs.append(buf)
+      bufs.append(candidates_list[-1])
 
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
 
diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py
index 1c4671ee6451..010edef1e54a 100644
--- a/jax/_src/export/shape_poly.py
+++ b/jax/_src/export/shape_poly.py
@@ -1992,7 +1992,8 @@ def compute_dim_vars_from_arg_shapes(
   generate the code for computing the dimension variables. It also generates
   the shape assertions.
 
-  Returns: the values of the dimension variables, in the order determined by
+  Returns:
+    The values of the dimension variables, in the order determined by
     `all_dim_vars(args_avals)`.
   """
   dim_vars = all_dim_vars(args_avals)
@@ -2006,8 +2007,7 @@ def compute_dim_vars_from_arg_shapes(
   }
   synthetic_eval = ShapeEvaluator(synthetic_env)
   shape_constraints.shape_assertions(synthetic_eval)
-  dim_values = [synthetic_eval.evaluate(solution[var]) for var in dim_vars]
-  return tuple(dim_values)
+  return tuple(synthetic_eval.evaluate(solution[var]) for var in dim_vars)
 
 def _solve_dim_equations(
     eqns: list[_DimEquation],
@@ -2141,7 +2141,8 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv):
     eqns = [eqn for eqn in eqns if not process_one_eqn(eqn)]
     if not eqns:
       add_explicit_symbolic_constraints(shape_env)
-      return shape_env, shape_constraints  # SUCCESS
+      # SUCCESS
+      return shape_env, shape_constraints  # pytype: disable=bad-return-type
     elif len(eqns) >= nr_eqns:
       break
 
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 102e4f490b5c..531177b7244c 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1699,6 +1699,7 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value:
   # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2),
   # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None).
   # The below custom call achieves the sharding like above example.
+  assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
   if config.use_shardy_partitioner.value:
     physical_ndim = core.physical_aval(aval).ndim
     s = sharding_impls.SdyArraySharding(
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 6e2f11833b9d..4b4f8f7eddee 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -177,9 +177,12 @@ def new_arg(self, pval: PartialVal) -> JaxprTracer:
     if const is None:
       aval = pval.get_aval()
       if type(aval) is DShapedArray:
+        # TODO(dougalm): Fix the type error and remove the pytype pragmas.
+        # pytype: disable=attribute-error
         shape = [self.new_instantiated_const(d)
                  if isinstance(d, Tracer) and d._trace.level < self.level else d
                  for d in aval.shape]
+        # pytype: enable=attribute-error
         aval = aval.update(shape=tuple(shape))
       return JaxprTracer(self, PartialVal.unknown(aval), LambdaBinding())
     else:
@@ -1776,6 +1779,9 @@ def lit(a: Atom) -> Literal | None:
   newvars: dict[Var, Var] = {}
   newvar = lambda aval: newname(_substitute_vars_in_type(lits, newvars, aval))
   var = lambda v: newvars.get(v) or newvars.setdefault(v, newvar(v.aval))
+  lit_or_var = (
+      lambda a: a if isinstance(a, Literal) else (lit(a) or var(a))
+  )
   dropvar = lambda aval: DropVar(_substitute_vars_in_type(lits, newvars, aval))
 
   def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
@@ -1794,10 +1800,10 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
   new_invars = [var(v) for v in jaxpr.invars]
   new_eqns = []
   for eqn in jaxpr.eqns:
-    invars = [lit(x) or var(x) for x in eqn.invars]
+    invars = [lit_or_var(x) for x in eqn.invars]
     outvars = [var(v) if v in used else dropvar(v.aval) for v in eqn.outvars]
     new_eqns.append(eqn.replace(invars=invars, outvars=outvars))
-  new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
+  new_outvars = [lit_or_var(v) for v in jaxpr.outvars]
   jaxpr_effects = make_jaxpr_effects(new_constvars, new_invars, new_outvars,
                                      new_eqns)
   new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns,
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index 418240a4a86e..547415c098b4 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -513,7 +513,7 @@ def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   # jaxpr for each branch.
   branches_known_ : list[core.ClosedJaxpr] = []
   branches_staged_: list[core.ClosedJaxpr] = []
-  branch_res_avals: list[core.AbstractValue] = []
+  branch_res_avals: list[list[core.AbstractValue]] = []
   for jaxpr in branches:
     jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
       pe.partial_eval_jaxpr_custom(
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 1c529b8938f1..b4609282e2f8 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1651,7 +1651,7 @@ def _partial_eval_jaxpr_custom_rule(
 
 def _add_reshapes(which, jaxpr_known, jaxpr_staged):
   # add singleton axes to residuals which are from jaxpr_known and are scalars
-  which_ = [w and not v.aval.shape
+  which_ = [w and not v.aval.shape  # pytype: disable=attribute-error
             for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
   if not any(which_): return jaxpr_known, jaxpr_staged
   assert not jaxpr_known.constvars and not jaxpr_staged.constvars
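
For context on the first hunk (`jax/_src/array.py`): the `else` clause of a Python `for` loop runs only when the loop finishes without hitting `break`, and reusing the loop variable in that clause depends on the iterable having been non-empty, which a static checker may flag as a possibly undefined name. The sketch below is a minimal, standalone illustration of that pattern in plain Python — the `pick_one_per_device` helper and the `device`/`id` dicts are made up for the example and are not JAX's actual buffer objects — showing why indexing the candidate list expresses the fallback more explicitly than the leaked loop variable.

```python
def pick_one_per_device(devices, candidates):
  """For each device, pick a candidate already on it, else fall back to the last one."""
  picked = []
  for device in devices:
    for buf in candidates:
      if buf["device"] == device:
        picked.append(buf)
        break
    else:
      # No `break` fired: nothing lives on `device`.  Index the list explicitly
      # instead of reusing the loop variable `buf`, which a checker may treat
      # as possibly undefined when `candidates` could be empty.
      picked.append(candidates[-1])
  return picked


candidates = [{"device": "cpu:0", "id": 0}, {"device": "gpu:0", "id": 1}]
print(pick_one_per_device(["gpu:0", "tpu:0"], candidates))
# -> [{'device': 'gpu:0', 'id': 1}, {'device': 'gpu:0', 'id': 1}]
```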