11 changes: 8 additions & 3 deletions skyrl-tx/tests/utils/test_generator.py
@@ -45,10 +45,15 @@ def __call__(
hidden_states = jnp.tile(base[None, None, :], (batch_size, seq_len, 1))
keys = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)]
values = [jnp.zeros((batch_size, seq_len, 1, 1), dtype=jnp.float32)]
kv_cache = KVCache(keys=keys, values=values, cache_position=seq_len)
# Per-sequence cache_position (all same length in this test)
cache_position = (
attention_mask.sum(axis=1) if attention_mask is not None else jnp.full((batch_size,), seq_len)
)
kv_cache = KVCache(keys=keys, values=values, cache_position=cache_position)
else:
# Step: hidden_states vary with cache_position
hidden_states = jnp.tile(base[None, None, :] + kv_cache.cache_position, (batch_size, 1, 1))
# Step: hidden_states vary with cache_position (use mean for batched position)
mean_pos = kv_cache.cache_position.mean()
hidden_states = jnp.tile(base[None, None, :] + mean_pos, (batch_size, 1, 1))
kv_cache = KVCache(keys=kv_cache.keys, values=kv_cache.values, cache_position=kv_cache.cache_position + 1)

return CausalLMOutput(last_hidden_state=hidden_states, kv_cache=kv_cache)
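Note: the per-sequence cache_position built in the stub's prefill branch is simply the count of real tokens in each row of the mask. A minimal check with made-up mask values (the test itself uses sequences of equal length):

import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1], [1, 1, 0]])
cache_position = attention_mask.sum(axis=1)  # -> Array([3, 2]): one position per sequence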
15 changes: 5 additions & 10 deletions skyrl-tx/tx/models/llama3.py
@@ -82,7 +82,7 @@ def __call__(
attention_mask: jax.Array,
positions: jax.Array,
adapter_indices: jax.Array | None = None,
kv_cache: tuple[jax.Array, jax.Array, int] | None = None,
kv_cache: tuple[jax.Array, jax.Array] | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
B, T, _ = x.shape

@@ -97,9 +97,7 @@

# Handle KV cache
if kv_cache is not None:
k_cache, v_cache, cache_position = kv_cache
k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
k, v = KVCache.update_layer(kv_cache, k, v, positions)

updated_cache = (k, v)

Expand Down Expand Up @@ -175,7 +173,7 @@ def __call__(
attention_mask: jax.Array,
positions: jax.Array,
adapter_indices: jax.Array | None = None,
kv_cache: tuple[jax.Array, jax.Array, int] | None = None,
kv_cache: tuple[jax.Array, jax.Array] | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -243,7 +241,7 @@ def __call__(
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position),
kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
)
updated_keys.append(k)
updated_values.append(v)
@@ -252,12 +250,9 @@ def __call__(
if output_hidden_states:
all_hidden_states.append(hidden_states)

# Increment cache_position if cache exists, or use sequence length for new cache
new_cache_position = kv_cache.cache_position + 1 if kv_cache is not None else input_ids.shape[1]

return ModelOutput(
last_hidden_state=hidden_states,
kv_cache=KVCache(keys=updated_keys, values=updated_values, cache_position=new_cache_position),
kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
hidden_states=all_hidden_states if output_hidden_states else None,
)

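For context, a rough sketch of how a caller might thread this cache through prefill and one decode step. The driver below is illustrative only: it assumes the model's __call__ accepts the keyword arguments shown in the layer signatures above, and the names input_ids, next_token, and attention_mask are placeholders rather than the library's API.

# Prefill: kv_cache is None, so the layers skip the cache write and
# KVCache.update records per-sequence lengths from the attention mask.
out = model(input_ids, attention_mask=attention_mask, positions=positions)
cache = out.kv_cache  # cache_position == attention_mask.sum(axis=1)

# Decode one token: feed each sequence its own next position.
# (A real loop would first mark the new slot in attention_mask, as decode_fn does below.)
next_positions = cache.cache_position[:, None]  # shape [B, 1]
out = model(next_token, attention_mask=attention_mask, positions=next_positions, kv_cache=cache)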
15 changes: 5 additions & 10 deletions skyrl-tx/tx/models/qwen3.py
@@ -83,7 +83,7 @@ def __call__(
attention_mask: jax.Array,
positions: jax.Array,
adapter_indices: jax.Array | None = None,
kv_cache: tuple[jax.Array, jax.Array, int] | None = None,
kv_cache: tuple[jax.Array, jax.Array] | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
B, T, _ = x.shape

@@ -98,9 +98,7 @@

# Handle KV cache
if kv_cache is not None:
k_cache, v_cache, cache_position = kv_cache
k = jax.lax.dynamic_update_slice(k_cache, k, (0, cache_position, 0, 0))
v = jax.lax.dynamic_update_slice(v_cache, v, (0, cache_position, 0, 0))
k, v = KVCache.update_layer(kv_cache, k, v, positions)

updated_cache = (k, v)

@@ -290,7 +288,7 @@ def __call__(
attention_mask: jax.Array,
positions: jax.Array,
adapter_indices: jax.Array | None = None,
kv_cache: tuple[jax.Array, jax.Array, int] | None = None,
kv_cache: tuple[jax.Array, jax.Array] | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
@@ -358,7 +356,7 @@ def __call__(
attention_mask=attention_mask,
positions=positions,
adapter_indices=adapter_indices,
kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx], kv_cache.cache_position),
kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
)
updated_keys.append(k)
updated_values.append(v)
@@ -367,12 +365,9 @@ def __call__(
if output_hidden_states:
all_hidden_states.append(hidden_states)

# Increment cache_position if cache exists, or use sequence length for new cache
new_cache_position = kv_cache.cache_position + 1 if kv_cache is not None else input_ids.shape[1]

return ModelOutput(
last_hidden_state=hidden_states,
kv_cache=KVCache(keys=updated_keys, values=updated_values, cache_position=new_cache_position),
kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
hidden_states=all_hidden_states if output_hidden_states else None,
)

75 changes: 55 additions & 20 deletions skyrl-tx/tx/utils/generator.py
@@ -18,7 +18,54 @@ class KVCache:

keys: list[jax.Array]
values: list[jax.Array]
cache_position: int
cache_position: jax.Array # Per-sequence positions of shape [B] for left-aligned decoding

@staticmethod
def update(
kv_cache: KVCache | None,
keys: list[jax.Array],
values: list[jax.Array],
positions: jax.Array,
attention_mask: jax.Array,
) -> KVCache:
"""Create an updated KVCache with computed cache positions for left-aligned decoding.

Args:
kv_cache: Existing KVCache (None during prefill).
keys: List of key arrays per layer.
values: List of value arrays per layer.
positions: Position indices with shape [B, seq_len].
attention_mask: Attention mask with shape [B, seq_len].

Returns:
New KVCache with computed cache_position.
"""
if kv_cache is not None:
# Decode: next position is current position + 1
cache_position = positions[:, 0] + 1
else:
# Prefill: next position is the sequence length (number of real tokens)
cache_position = attention_mask.sum(axis=1)
return KVCache(keys=keys, values=values, cache_position=cache_position)
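A small numeric illustration of the two branches (mask and position values are made up):

import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
attention_mask.sum(axis=1)         # prefill: cache_position -> Array([3, 2])

positions = jnp.array([[3], [2]])  # decode step: shape [B, 1]
positions[:, 0] + 1                # decode: cache_position -> Array([4, 3])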

@staticmethod
def update_layer(kv_cache, k, v, positions):
"""Update a single layer's KV cache at the given positions (for left-aligned decoding).

Args:
kv_cache: Tuple of (k_cache, v_cache) arrays for this layer.
k: New keys with shape [B, seq_len, num_heads, head_dim].
v: New values with shape [B, seq_len, num_heads, head_dim].
positions: Position indices with shape [B, seq_len].
"""
k_cache, v_cache = kv_cache

def update_at_pos(cache_slice, new_val_slice, pos):
return jax.lax.dynamic_update_slice(cache_slice, new_val_slice, (pos, 0, 0))

k = jax.vmap(update_at_pos)(k_cache, k, positions[:, 0])
v = jax.vmap(update_at_pos)(v_cache, v, positions[:, 0])
return k, v
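A toy example of the batched write this performs; the shapes and values below are illustrative, not taken from the models:

import jax
import jax.numpy as jnp

k_cache = jnp.zeros((2, 4, 1, 1))   # [B, max_len, num_heads, head_dim]
k_new = jnp.ones((2, 1, 1, 1))      # one new token per sequence
positions = jnp.array([[3], [2]])   # each sequence writes into its own slot

write = lambda cache, new, pos: jax.lax.dynamic_update_slice(cache, new, (pos, 0, 0))
k = jax.vmap(write)(k_cache, k_new, positions[:, 0])
# k[0, 3] and k[1, 2] now hold the new keys; every other slot stays zero.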

def pad_to_length(self, max_length: int) -> KVCache:
"""Pad KV cache to a specified maximum length.
@@ -68,11 +115,6 @@ class GenerateOutput:
prompt_logprobs: list[list[float]] | None = None


def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
"""Roll each element of a batch along its first non-batch axis (the sequence axis)."""
return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)


def find_string_stop_position(
tokens: list[int],
tokenizer,
@@ -139,9 +181,8 @@ def _prefill_and_decode(
adapter_indices=adapter_indices,
)

# Compute sequence lengths and last token positions
seq_lengths = attention_mask.sum(axis=1) # Shape: [B]
last_token_idx = seq_lengths - 1
# For left-aligned sequences, find the last real token position for each sequence
last_token_idx = attention_mask.sum(axis=1) - 1 # Shape: [B]
batch_idx = jnp.arange(input_ids.shape[0])

# Compute logits for sampling and optionally for prompt logprobs
@@ -156,15 +197,8 @@
last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :]
prompt_logprobs_array = None

# Right-align KV cache and attention mask so decoding doesn't have gaps
prompt_length = attention_mask.shape[1]
shifts = prompt_length - seq_lengths
kv_cache = KVCache(
keys=[batch_roll(k, shifts) for k in outputs.kv_cache.keys],
values=[batch_roll(v, shifts) for v in outputs.kv_cache.values],
cache_position=outputs.kv_cache.cache_position,
)
attention_mask = batch_roll(attention_mask, shifts)
# Pad KV cache and attention mask
kv_cache = outputs.kv_cache.pad_to_length(max_length)

# Pad KV cache and attention mask to max_length
kv_cache = kv_cache.pad_to_length(max_length)
Comment on lines 203 to 204 (Contributor review, severity: medium):

This comment and the following line are redundant. The kv_cache is already padded to max_length on line 159. These lines can be removed.

@@ -196,8 +230,9 @@ def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.A
is_stop = jnp.any(next_token == stop_tokens, axis=1)
stop_pos = jnp.where((s.stop_pos == -1) & is_stop, step + 1, s.stop_pos)

# Update attention mask: set next position to 1
next_attention_mask = s.attention_mask.at[:, s.kv_cache.cache_position].set(1)
# Update attention mask at per-sequence positions (for left-aligned sequences)
batch_idx = jnp.arange(s.attention_mask.shape[0])
next_attention_mask = s.attention_mask.at[batch_idx, s.kv_cache.cache_position].set(1)

outputs = model(
next_token,
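As a concrete, made-up example of the per-sequence mask update in decode_fn above:

import jax.numpy as jnp

attention_mask = jnp.array([[1, 1, 1, 0, 0],
                            [1, 1, 0, 0, 0]])
cache_position = jnp.array([3, 2])
batch_idx = jnp.arange(attention_mask.shape[0])
attention_mask.at[batch_idx, cache_position].set(1)
# -> [[1, 1, 1, 1, 0],
#     [1, 1, 1, 0, 0]]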