
Commit

Minor cleanup. (#843)
markblee authored Nov 16, 2024
1 parent 8702977 commit 594313d
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions axlearn/common/attention.py
@@ -878,21 +878,24 @@ def extend_step(
         q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, time_step=time_step)
         updated_state = dict(time_step=time_step + num_query_steps)
         if kv_state is None:
-            # Update the cache via one-hot broadcast and addition. [B, S, N, H].
+            # Update the cache via dynamic slice. [B, S, N, H].
             cached_key = cached_states["key"]
             cached_value = cached_states["value"]

             # Ensure that we accumulate using the original dtype.
             k_proj = k_proj.astype(cached_key.dtype)
             v_proj = v_proj.astype(cached_value.dtype)

-            # Function to update the cached_key for a single batch element.
-            def update_single(cached_key_slice, k_proj_slice, time_idx):
-                start_indices = (time_idx, 0, 0)
-                return jax.lax.dynamic_update_slice(cached_key_slice, k_proj_slice, start_indices)
+            # Function to update the cache for a single batch element.
+            def update_single(cached_kv_slice, kv_proj_slice, time_idx):
+                return jax.lax.dynamic_update_slice_in_dim(
+                    cached_kv_slice, kv_proj_slice, time_idx, axis=0
+                )

             # Use jax.vmap to vectorize over the batch dimension.
-            k_proj = jax.vmap(update_single)(cached_key, k_proj, time_step)
-            v_proj = jax.vmap(update_single)(cached_value, v_proj, time_step)
+            vmap_update = jax.vmap(update_single)
+            k_proj = vmap_update(cached_key, k_proj, time_step)
+            v_proj = vmap_update(cached_value, v_proj, time_step)
             updated_state.update(key=k_proj, value=v_proj)
         return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj)
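
The new update path writes each batch element's fresh key/value projection into its own offset of the cache with jax.lax.dynamic_update_slice_in_dim, and jax.vmap vectorizes that write over the batch. Below is a minimal, self-contained sketch of the same pattern; the shapes and variable values are invented purely for illustration and are not taken from axlearn.

    import jax
    import jax.numpy as jnp

    # Hypothetical sizes: cache of [batch, max_seq_len, num_heads, per_head_dim],
    # one new step per example, and a per-example write position time_step of [batch].
    batch, max_seq_len, num_steps, num_heads, per_head_dim = 2, 8, 1, 4, 16
    cached_key = jnp.zeros((batch, max_seq_len, num_heads, per_head_dim))
    k_proj = jnp.ones((batch, num_steps, num_heads, per_head_dim))
    time_step = jnp.array([3, 5])

    def update_single(cached_kv_slice, kv_proj_slice, time_idx):
        # Write kv_proj_slice into cached_kv_slice starting at time_idx along the time axis.
        return jax.lax.dynamic_update_slice_in_dim(cached_kv_slice, kv_proj_slice, time_idx, axis=0)

    # vmap maps over the leading (batch) axis of all three arguments, so every
    # example is updated at its own time step in a single call.
    updated_key = jax.vmap(update_single)(cached_key, k_proj, time_step)
    assert updated_key.shape == cached_key.shape

Compared with the removed jax.lax.dynamic_update_slice call, the _in_dim variant takes a single start index along one axis, so there is no need to hand-build the (time_idx, 0, 0) start tuple.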

@@ -1348,17 +1351,10 @@ def forward(
         cfg = self.config
         # Query should have shape of [batch_size, seq_len, num_heads, per_head_dim].
         query, key, value = self.i_proj(query, key=key, value=value)
-        if time_step is None:
-            # If time_step is None, then we set it to [batch_size, seq_len].
-            # In this case, batch_size can be set as 1.
-            time_step = jnp.expand_dims(jnp.arange(query.shape[1]), 0)
-        else:
-            # Time step shape is [batch_size].
-            # The expected input shape for rope_pos_emb_layer is [batch_size, seq_len].
-            # Therefore, expanding the shape of time_step to [batch_size, step].
-            step = query.shape[1]
-            time_step = jnp.arange(step)[None] + time_step[:, None]
-        sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(time_step).astype(query.dtype)
+        query_pos = jnp.arange(query.shape[1])[None]  # [batch_size=1, seq_len].
+        if time_step is not None:
+            query_pos = query_pos + time_step[:, None]  # [batch_size, seq_len].
+        sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(query_pos).astype(query.dtype)
         # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim].
         sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2)
         query, key, value = apply_rotary_position_embeddings(
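The forward() change collapses the old if/else into one query_pos computation: positions default to arange(seq_len) with a broadcastable batch dimension of 1, and are offset by time_step only when decoding. A small sketch of just that position logic, using an illustrative helper name (query_positions) and omitting the rope_pos_emb_layer call itself:

    import jax.numpy as jnp

    def query_positions(seq_len, time_step=None):
        # Prefill: positions 0..seq_len-1 with a batch dimension of 1 that broadcasts.
        query_pos = jnp.arange(seq_len)[None]           # [1, seq_len]
        if time_step is not None:
            # Decode: offset each example by how many tokens its cache already holds.
            query_pos = query_pos + time_step[:, None]  # [batch, seq_len]
        return query_pos

    print(query_positions(4))                     # prefill positions: [[0 1 2 3]]
    print(query_positions(1, jnp.array([3, 7])))  # decode positions: 3 and 7, one per example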
