Remove calls to deprecated function jax.lax.tie_in
`jax.lax.tie_in` has been a no-op since jax v0.2.0: it simply passes its second argument through unchanged. `tie_in` will be deprecated as of jax v0.4.24; see jax-ml/jax#19413

PiperOrigin-RevId: 600515954
Jake VanderPlas authored and Flax Authors committed Jan 22, 2024
1 parent 0370a61 commit 993aa2c
Showing 1 changed file with 3 additions and 3 deletions.
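
For context, a minimal sketch of why the change below is behavior-preserving: on JAX versions where `jax.lax.tie_in` still exists (it is slated for removal after the deprecation noted above), the call simply returns its second argument. The array shape and dropout rate here are made-up illustration values, not taken from the Flax code.

    import jax
    import jax.numpy as jnp

    attn_weights = jnp.ones((2, 4, 4))   # placeholder attention weights
    dropout_rate = 0.1                   # placeholder rate

    # Old form: tie_in is a pass-through, so this is just 1.0 - dropout_rate.
    keep_prob_old = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    # New form after this commit:
    keep_prob_new = 1.0 - dropout_rate

    assert keep_prob_old == keep_prob_new  # identical values; tie_in added nothing
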
6 changes: 3 additions & 3 deletions flax/core/nn/attention.py
@@ -124,7 +124,7 @@ def dot_product_attention(
   if not deterministic and dropout_rate > 0.0:
     if dropout_rng is None:
       dropout_rng = scope.make_rng('dropout')
-    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
+    keep_prob = 1.0 - dropout_rate
     if broadcast_dropout:
       # dropout is broadcast across the batch+head+non-attention dimension
       dropout_dims = attn_weights.shape[-(2 * len(axis)) :]
@@ -511,8 +511,8 @@ def tri(n, m, k=0):
     # Tie in the key to avoid the mask becoming a constant.
     # This way XLA can construct the mask during computation and fuse it
     # with the attention ops.
-    x = lax.tie_in(key, jnp.arange(n, dtype=jnp.int32))
-    y = lax.tie_in(key, jnp.arange(m, dtype=jnp.int32))
+    x = jnp.arange(n, dtype=jnp.int32)
+    y = jnp.arange(m, dtype=jnp.int32)
     mask = lax.ge(
         (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k,
         lax.broadcast(y, [n]),
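
As a side note on what the surrounding `tri` helper computes: the `lax.ge` of the two broadcasted index arrays yields a lower-triangular boolean mask, mask[i, j] = (i + k >= j). A minimal standalone sketch of that construction (the function name `tri_like` and the comparison against `jnp.tri` are illustrative, not part of the Flax source):

    import jax.numpy as jnp
    from jax import lax

    def tri_like(n, m, k=0):
      # mask[i, j] is True where i + k >= j, i.e. a lower-triangular mask with offset k.
      x = jnp.arange(n, dtype=jnp.int32)
      y = jnp.arange(m, dtype=jnp.int32)
      return lax.ge(
          lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,)) + k,
          lax.broadcast(y, [n]),
      )

    # Matches the numpy-style jnp.tri for the same shape and offset.
    assert (tri_like(4, 5, k=0) == jnp.tri(4, 5, k=0, dtype=bool)).all()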