Remove uses of deprecated jax.lax.tie_in
This has been a no-op since jax v0.2.0, and passes the second argument through unchanged. `tie_in` will be deprecated as of jax v0.4.24; see jax-ml/jax#19413

PiperOrigin-RevId: 602865274
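Editor's note (not part of the original commit): since omnistaging, `jax.lax.tie_in(x, y)` simply returns `y`, so dropping the call changes no behavior. A minimal sketch of the migration pattern, using hypothetical values rather than code from this repository:

```python
import jax.numpy as jnp

attn_weights = jnp.ones((2, 4, 4))  # placeholder array; only used by the old call
dropout_rate = 0.1

# Before (deprecated; a no-op since jax v0.2.0):
#   keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)

# After: the second argument is used directly, and the first is no longer needed.
keep_prob = 1.0 - dropout_rate
```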
Jake VanderPlas authored and copybara-github committed Jan 30, 2024
1 parent 3606248 commit 6dcafcf
Showing 3 changed files with 4 additions and 4 deletions.
aqt/jax_legacy/jax/flax_attention.py (1 addition, 1 deletion)
@@ -449,7 +449,7 @@ def dot_product_attention(query,
   if not deterministic and dropout_rate > 0.0:
     if dropout_rng is None:
       raise ValueError('dropout_rng cannot be None if dropout is requested.')
-    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
+    keep_prob = 1.0 - dropout_rate
     if broadcast_dropout:
       # dropout is broadcast across the batch+head+non-attention dimension
       dropout_dims = attn_weights.shape[-(2 * len(axis)):]
aqt/jax_legacy/jax/quantization.py (1 addition, 2 deletions)
@@ -378,10 +378,9 @@ def fake_quant(self,
                  *,
                  quantized_type: primitives.jnp_dtype,
                  fake_dependency: Optional[jnp.ndarray] = None) -> jnp.ndarray:
+    del fake_dependency  # unused; this was a remnant of pre-omnistaging JAX.
     x_dtype = x.dtype
     quantized_x = self.to_quantized(x, dtype=quantized_type)
-    if fake_dependency is not None:
-      quantized_x = lax.tie_in(fake_dependency, quantized_x)
     return self.from_quantized(quantized_x, dtype=x_dtype)

   # Assumes weights are unsigned int of precision prec.
aqt/jax_legacy/jax/wmt_mlperf/models.py (2 additions, 1 deletion)
@@ -35,7 +35,8 @@


 def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
-  return lax.rng_uniform(lax.tie_in(rng_key, 0.0), 1.0, shape) < p
+  del rng_key  # unused
+  return lax.rng_uniform(0.0, 1.0, shape) < p


 def set_hardware_bernoulli():
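Editor's note (not part of the original commit): in the updated `hardware_bernoulli`, the PRNG key is accepted but ignored, and `lax.rng_uniform` (a stateful, hardware-backed XLA RNG primitive) supplies the randomness, so there is no longer any value to tie to `rng_key`. A minimal sketch of how the updated helper might be exercised under jit, assuming plain jax/numpy outside this repository:

```python
import numpy as np
import jax
from jax import lax


def hardware_bernoulli(rng_key, p=np.float32(0.5), shape=None):
  del rng_key  # unused: lax.rng_uniform draws from the backend's stateful RNG
  return lax.rng_uniform(0.0, 1.0, shape) < p


# The helper is intended to run inside jit (e.g. when patched in for
# jax.random.bernoulli), where lax.rng_uniform lowers to XLA's RngUniform op.
draw = jax.jit(lambda key: hardware_bernoulli(key, np.float32(0.5), (2, 3)))
mask = draw(jax.random.PRNGKey(0))  # the key is accepted but ignored
```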
