2 changes: 1 addition & 1 deletion skyrl-tx/tests/models/test_qwen3_generate.py
@@ -44,7 +44,7 @@ def test_qwen3_generate():
with tempfile.TemporaryDirectory() as tmp:
hf_model.save_pretrained(tmp, safe_serialization=True)
base_config = PretrainedConfig.from_pretrained(model_name)
- config = Qwen3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
+ config = Qwen3Config(base_config, max_lora_adapters=2, max_lora_rank=32, shard_attention_heads=True)

mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
with jax.set_mesh(mesh):
4 changes: 2 additions & 2 deletions skyrl-tx/tx/models/llama3.py
@@ -9,7 +9,7 @@
from tx.layers.layernorm import RMSNorm
from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
from tx.models.types import CausalLMOutput, ModelOutput
- from tx.utils.generator import GeneratorMixin, KVCache, compute_positions
+ from tx.utils.generator import GeneratorMixin, KVCache


class Llama3Attention(nnx.Module):
@@ -303,7 +303,7 @@ def __call__(
kv_cache: KVCache | None = None,
) -> CausalLMOutput:
if positions is None:
- positions = compute_positions(attention_mask)
+ positions = jnp.arange(attention_mask.shape[1])[None, :]

outputs = self.model(
input_ids,
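Note: the new default (applied identically in llama3.py here and qwen3.py below) assumes left-aligned, right-padded prompts, so position 0 coincides with each sequence's first real token. A minimal illustrative sketch; the toy mask values are mine, not from the PR:

```python
import jax.numpy as jnp

# Two right-padded (left-aligned) prompts: 1 = real token, 0 = padding.
attention_mask = jnp.array([
    [1, 1, 1, 0],   # prompt of length 3
    [1, 1, 0, 0],   # prompt of length 2
])

# New default: plain 0..S-1 positions, broadcast over the batch.
positions = jnp.arange(attention_mask.shape[1])[None, :]
print(positions)  # [[0 1 2 3]]

# The removed compute_positions helper additionally offset positions by the
# index of the first real token, which only matters for left-padded inputs.
```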
4 changes: 2 additions & 2 deletions skyrl-tx/tx/models/qwen3.py
@@ -10,7 +10,7 @@
from tx.models.configs import Qwen3Config
from tx.layers.layernorm import RMSNorm
from tx.models.types import CausalLMOutput, ModelOutput
- from tx.utils.generator import GeneratorMixin, KVCache, compute_positions
+ from tx.utils.generator import GeneratorMixin, KVCache


class Qwen3Attention(nnx.Module):
@@ -418,7 +418,7 @@ def __call__(
kv_cache: KVCache | None = None,
) -> CausalLMOutput:
if positions is None:
- positions = compute_positions(attention_mask)
+ positions = jnp.arange(attention_mask.shape[1])[None, :]

outputs = self.model(
input_ids,
1 change: 0 additions & 1 deletion skyrl-tx/tx/tinker/backends/jax.py
@@ -725,7 +725,6 @@ def sample(

# Pad sequences to same length within the batch to minimize memory usage.
# Also bin it so the JIT has to compile fewer kernels.
- # Use right-padding, which means during decoding there will be "gaps" in the attention mask.
max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
input_ids = pad_batch(batch_prompts, max_len, np.int32)
attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32)
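The surrounding comments in `sample` describe padding each batch to a common length and binning that length so the JIT compiles fewer kernel shapes. `round_up_seq_len` itself is not shown in this diff; the following is a purely hypothetical sketch of such binning, and the real helper in skyrl-tx may use different buckets:

```python
def round_up_seq_len(seq_len: int, bucket: int = 256) -> int:
    """Hypothetical binning: round up to the next multiple of `bucket` so that
    batches with similar lengths reuse the same compiled kernel shape."""
    return -(-seq_len // bucket) * bucket  # ceiling division

# e.g. round_up_seq_len(100) == 256, round_up_seq_len(300) == 512
```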
36 changes: 19 additions & 17 deletions skyrl-tx/tx/utils/generator.py
@@ -68,14 +68,9 @@ class GenerateOutput:
prompt_logprobs: list[list[float]] | None = None


- def compute_positions(attention_mask: jax.Array) -> jax.Array:
- """Compute positions from attention mask.
-
- Positions start at 0 from the first non-zero value in the attention mask
- and increment sequentially.
- """
- first_token_idx = jnp.argmax(attention_mask, axis=1, keepdims=True)
- return jnp.arange(attention_mask.shape[1])[None, :] - first_token_idx
+ def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
+ """Roll each element of a batch along its first non-batch axis (the sequence axis)."""
+ return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)


def find_string_stop_position(
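The replacement helper, `batch_roll`, vmaps `jnp.roll` so every row of a batch is shifted along the sequence axis by its own amount; positive shifts move real tokens toward the end, wrapping padding around to the front. A small self-contained sketch with made-up values:

```python
import functools
import jax
import jax.numpy as jnp

def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
    """Roll each batch element along its sequence axis by its own shift."""
    return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)

# Two right-padded rows; shifting each by its amount of padding right-aligns it.
x = jnp.array([
    [10, 11, 12, 0],
    [20, 21,  0, 0],
])
print(batch_roll(x, jnp.array([1, 2])))
# [[ 0 10 11 12]
#  [ 0  0 20 21]]
```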
@@ -137,19 +132,16 @@ def _prefill_and_decode(
prompt_logprobs: bool = False,
):
"""JIT-compiled prefill + decode loop. Fuses everything for maximum efficiency."""
- # Compute positions from attention mask
- positions = compute_positions(attention_mask)
-
- # Prefill: process full prompt
+ # Prefill: process full prompt (left-aligned, so positions start at 0)
outputs = model(
input_ids,
attention_mask=attention_mask,
- positions=positions,
adapter_indices=adapter_indices,
)

- # For left-aligned sequences, find the last real token position for each sequence
- last_token_idx = attention_mask.sum(axis=1) - 1 # Shape: [B]
+ # Compute sequence lengths and last token positions
+ seq_lengths = attention_mask.sum(axis=1) # Shape: [B]
+ last_token_idx = seq_lengths - 1
batch_idx = jnp.arange(input_ids.shape[0])

# Compute logits for sampling and optionally for prompt logprobs
@@ -164,8 +156,18 @@ def _prefill_and_decode(
last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :]
prompt_logprobs_array = None

- # Pad KV cache and attention mask
- kv_cache = outputs.kv_cache.pad_to_length(max_length)
+ # Right-align KV cache and attention mask so decoding doesn't have gaps
+ prompt_length = attention_mask.shape[1]
+ shifts = prompt_length - seq_lengths
+ kv_cache = KVCache(
+ keys=[batch_roll(k, shifts) for k in outputs.kv_cache.keys],
+ values=[batch_roll(v, shifts) for v in outputs.kv_cache.values],
+ cache_position=outputs.kv_cache.cache_position,
+ )
+ attention_mask = batch_roll(attention_mask, shifts)
+
+ # Pad KV cache and attention mask to max_length
+ kv_cache = kv_cache.pad_to_length(max_length)
decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1])))

def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]:
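Taken together, the right-alignment step is what lets the "gaps in the attention mask" comment be deleted from tx/tinker/backends/jax.py: after prefill, each row's keys, values, and mask are rolled so the prompt ends at the last cache slot, and every decoded token then lands in the immediately following slot. A toy sketch of the effect on the attention mask only (array values are illustrative, not from the PR):

```python
import functools
import jax
import jax.numpy as jnp

def batch_roll(arr, shifts):
    return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)

# Right-padded prompt masks (lengths 3 and 2) in a prompt window of 4.
attention_mask = jnp.array([
    [1, 1, 1, 0],
    [1, 1, 0, 0],
])
prompt_length = attention_mask.shape[1]
seq_lengths = attention_mask.sum(axis=1)
shifts = prompt_length - seq_lengths                   # [1, 2]

aligned = batch_roll(attention_mask, shifts)           # prompts now end at the last slot
decode_mask = jnp.pad(aligned, ((0, 0), (0, 2)))       # room for two decode steps
decode_mask = decode_mask.at[:, prompt_length].set(1)  # first decoded token
print(decode_mask)
# [[0 1 1 1 1 0]
#  [0 0 1 1 1 0]]
# With the old left-aligned layout, the first decoded token would sit after the
# right padding, leaving a run of zeros inside each row's attended span.
```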