From 94967548f41ebca5939b2184c9ca2e14a0d3fa38 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 23 Jan 2026 22:46:18 -0800
Subject: [PATCH 1/4] [tx] Right align prompts for decoding

---
 skyrl-tx/tx/models/llama3.py   |  4 ++--
 skyrl-tx/tx/models/qwen3.py    |  4 ++--
 skyrl-tx/tx/utils/generator.py | 36 ++++++++++++++++++----------------
 3 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/skyrl-tx/tx/models/llama3.py b/skyrl-tx/tx/models/llama3.py
index b7eb14d52..f39bc12e2 100644
--- a/skyrl-tx/tx/models/llama3.py
+++ b/skyrl-tx/tx/models/llama3.py
@@ -9,7 +9,7 @@
 from tx.layers.layernorm import RMSNorm
 from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
 from tx.models.types import CausalLMOutput, ModelOutput
-from tx.utils.generator import GeneratorMixin, KVCache, compute_positions
+from tx.utils.generator import GeneratorMixin, KVCache
 
 
 class Llama3Attention(nnx.Module):
@@ -303,7 +303,7 @@ def __call__(
         kv_cache: KVCache | None = None,
     ) -> CausalLMOutput:
         if positions is None:
-            positions = compute_positions(attention_mask)
+            positions = jnp.arange(attention_mask.shape[1])[None, :]
 
         outputs = self.model(
             input_ids,
diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index fdf68ee48..236aa1193 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -10,7 +10,7 @@
 from tx.models.configs import Qwen3Config
 from tx.layers.layernorm import RMSNorm
 from tx.models.types import CausalLMOutput, ModelOutput
-from tx.utils.generator import GeneratorMixin, KVCache, compute_positions
+from tx.utils.generator import GeneratorMixin, KVCache
 
 
 class Qwen3Attention(nnx.Module):
@@ -418,7 +418,7 @@ def __call__(
         kv_cache: KVCache | None = None,
     ) -> CausalLMOutput:
         if positions is None:
-            positions = compute_positions(attention_mask)
+            positions = jnp.arange(attention_mask.shape[1])[None, :]
 
         outputs = self.model(
             input_ids,
diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py
index 5f97ccd17..69306e71e 100644
--- a/skyrl-tx/tx/utils/generator.py
+++ b/skyrl-tx/tx/utils/generator.py
@@ -68,14 +68,9 @@ class GenerateOutput:
     prompt_logprobs: list[list[float]] | None = None
 
 
-def compute_positions(attention_mask: jax.Array) -> jax.Array:
-    """Compute positions from attention mask.
-
-    Positions start at 0 from the first non-zero value in the attention mask
-    and increment sequentially.
-    """
-    first_token_idx = jnp.argmax(attention_mask, axis=1, keepdims=True)
-    return jnp.arange(attention_mask.shape[1])[None, :] - first_token_idx
+def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
+    """Roll each row of arr by its corresponding shift amount along axis 0."""
+    return jax.vmap(jnp.roll)(arr, shifts)
 
 
 def find_string_stop_position(
@@ -137,19 +132,16 @@
     prompt_logprobs: bool = False,
 ):
     """JIT-compiled prefill + decode loop. Fuses everything for maximum efficiency."""
-    # Compute positions from attention mask
-    positions = compute_positions(attention_mask)
-
-    # Prefill: process full prompt
+    # Prefill: process full prompt (left-aligned, so positions start at 0)
     outputs = model(
         input_ids,
         attention_mask=attention_mask,
-        positions=positions,
         adapter_indices=adapter_indices,
     )
 
-    # For left-aligned sequences, find the last real token position for each sequence
-    last_token_idx = attention_mask.sum(axis=1) - 1  # Shape: [B]
+    # Compute sequence lengths and last token positions
+    seq_lengths = attention_mask.sum(axis=1)  # Shape: [B]
+    last_token_idx = seq_lengths - 1
     batch_idx = jnp.arange(input_ids.shape[0])
 
     # Compute logits for sampling and optionally for prompt logprobs
@@ -164,8 +156,18 @@
         last_logits = model.compute_logits(last_hidden, adapter_indices)[:, 0, :]
         prompt_logprobs_array = None
 
-    # Pad KV cache and attention mask
-    kv_cache = outputs.kv_cache.pad_to_length(max_length)
+    # Right-align KV cache and attention mask so decoding doesn't have gaps
+    prompt_length = attention_mask.shape[1]
+    shifts = prompt_length - seq_lengths
+    kv_cache = KVCache(
+        keys=[batch_roll(k, shifts) for k in outputs.kv_cache.keys],
+        values=[batch_roll(v, shifts) for v in outputs.kv_cache.values],
+        cache_position=outputs.kv_cache.cache_position,
+    )
+    attention_mask = batch_roll(attention_mask, shifts)
+
+    # Pad KV cache and attention mask to max_length
+    kv_cache = kv_cache.pad_to_length(max_length)
     decode_attention_mask = jnp.pad(attention_mask, ((0, 0), (0, max_length - attention_mask.shape[1])))
 
     def decode_fn(s: DecodeState, step: jax.Array) -> tuple[DecodeState, tuple[jax.Array, jax.Array]]:

From 013ca4af209647477ef554c81f996c69628b7dc1 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 23 Jan 2026 22:47:03 -0800
Subject: [PATCH 2/4] update

---
 skyrl-tx/tx/tinker/backends/jax.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index e57bc7b7e..2ff3a0a99 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -725,7 +725,6 @@ def sample(
 
         # Pad sequences to same length within the batch to minimize memory usage.
         # Also bin it so the JIT has to compile fewer kernels.
-        # Use right-padding, which means during decoding there will be "gaps" in the attention mask.
         max_len = round_up_seq_len(max((len(seq) for seq in batch_prompts), default=0))
         input_ids = pad_batch(batch_prompts, max_len, np.int32)
         attention_mask = pad_batch([[1] * len(seq) for seq in batch_prompts], max_len, np.int32)

From b05f70a429d85fea02d7ff13624b24a339564810 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 23 Jan 2026 22:55:05 -0800
Subject: [PATCH 3/4] Update skyrl-tx/tx/utils/generator.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 skyrl-tx/tx/utils/generator.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/skyrl-tx/tx/utils/generator.py b/skyrl-tx/tx/utils/generator.py
index 69306e71e..8c0ade170 100644
--- a/skyrl-tx/tx/utils/generator.py
+++ b/skyrl-tx/tx/utils/generator.py
@@ -69,8 +69,8 @@ class GenerateOutput:
 
 
 def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
-    """Roll each row of arr by its corresponding shift amount along axis 0."""
-    return jax.vmap(jnp.roll)(arr, shifts)
+    """Roll each element of a batch along its first non-batch axis (the sequence axis)."""
+    return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)
 
 
 def find_string_stop_position(

From bb06675aedf03f8aaeedb76c73a79823fd390058 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 23 Jan 2026 23:45:14 -0800
Subject: [PATCH 4/4] fix OOM in CI

---
 skyrl-tx/tests/models/test_qwen3_generate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skyrl-tx/tests/models/test_qwen3_generate.py b/skyrl-tx/tests/models/test_qwen3_generate.py
index 704df770f..07d13f6b5 100644
--- a/skyrl-tx/tests/models/test_qwen3_generate.py
+++ b/skyrl-tx/tests/models/test_qwen3_generate.py
@@ -44,7 +44,7 @@ def test_qwen3_generate():
     with tempfile.TemporaryDirectory() as tmp:
         hf_model.save_pretrained(tmp, safe_serialization=True)
         base_config = PretrainedConfig.from_pretrained(model_name)
-        config = Qwen3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
+        config = Qwen3Config(base_config, max_lora_adapters=2, max_lora_rank=32, shard_attention_heads=True)
 
         mesh = jax.make_mesh((1, 1), ("fsdp", "tp"))
         with jax.set_mesh(mesh):
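
Below is a minimal, self-contained sketch of the right-alignment trick these
patches introduce; it is not part of the patch series, and the toy inputs are
invented for illustration. After prefill over right-padded (left-aligned)
prompts, each row of the KV cache and attention mask is rolled so its real
tokens end at the right edge; newly decoded tokens then append contiguously,
leaving no gaps in the decode-time attention mask, and prefill positions can
simply be jnp.arange(...). The batch_roll here mirrors the final version from
PATCH 3/4:

import functools

import jax
import jax.numpy as jnp


def batch_roll(arr: jax.Array, shifts: jax.Array) -> jax.Array:
    """Roll each element of a batch along its first non-batch axis (the sequence axis)."""
    return jax.vmap(functools.partial(jnp.roll, axis=0))(arr, shifts)


# Two right-padded prompts of lengths 3 and 1, padded to a common length of 4.
input_ids = jnp.array([[11, 12, 13, 0],
                       [21, 0, 0, 0]])
attention_mask = jnp.array([[1, 1, 1, 0],
                            [1, 0, 0, 0]])

# Shift each row so its real tokens end at the right edge.
seq_lengths = attention_mask.sum(axis=1)        # [3, 1]
shifts = attention_mask.shape[1] - seq_lengths  # [1, 3]

print(batch_roll(input_ids, shifts))
# [[ 0 11 12 13]
#  [ 0  0  0 21]]
print(batch_roll(attention_mask, shifts))
# [[0 1 1 1]
#  [0 0 0 1]]

Rolling each row independently requires jax.vmap because jnp.roll only accepts
a single shift; functools.partial pins axis=0 so that multi-axis KV cache
entries (e.g. [seq, heads, head_dim] per batch element) are rolled only along
the sequence axis, which is what the PATCH 3 revision of batch_roll fixes.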