Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions slime/backends/fsdp_utils/data_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import torch

from slime.utils.mask_utils import CompressedLossMask, decompress_loss_mask
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions


def pack_sequences(
tokens: list[list[int]],
loss_masks: list[list[int]],
loss_masks: list[CompressedLossMask],
rewards: list[float],
raw_rewards: list,
response_lengths: list[int],
Expand All @@ -25,7 +26,7 @@ def pack_sequences(

Args:
tokens: List of token sequences
loss_masks: List of loss masks
loss_masks: List of compressed loss masks (run-length encoded)
rewards: List of rewards per sequence
raw_rewards: List of raw rewards per sequence
response_lengths: List of response lengths per sequence
Expand Down Expand Up @@ -72,7 +73,7 @@ def pack_sequences(

for i in indices:
seq_tokens = tokens[i]
seq_mask = loss_masks[i]
seq_mask = decompress_loss_mask(loss_masks[i])
seq_positionids = list(range(len(seq_tokens)))

flat_tokens.extend(seq_tokens)
Expand Down
6 changes: 4 additions & 2 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from slime.utils import train_metric_utils
from slime.utils.data import get_minimum_num_micro_batch_size
from slime.utils.mask_utils import decompress_loss_mask
from slime.utils.flops_utils import calculate_fwd_flops
from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
Expand Down Expand Up @@ -102,15 +103,16 @@ def get_batch(
batch["tokens"] = tokens
batch["packed_seq_params"] = packed_seq_params

# loss masks
# loss masks - decompress from run-length encoding
loss_masks = []
for loss_mask, total_length, response_length in zip(
for compressed_loss_mask, total_length, response_length in zip(
batch["loss_masks"],
batch["total_lengths"],
batch["response_lengths"],
strict=True,
):
prompt_length = total_length - response_length
loss_mask = torch.tensor(decompress_loss_mask(compressed_loss_mask), dtype=torch.int)
loss_mask = F.pad(loss_mask, (prompt_length - 1, 1), value=0)
loss_mask = slice_with_cp(loss_mask, 0, qkv_format, max_seqlen)
loss_masks.append(loss_mask)
Expand Down
6 changes: 3 additions & 3 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from slime.rollout.base_types import call_rollout_fn
from slime.utils import logging_utils
from slime.utils.health_monitor import RolloutHealthMonitor
from slime.utils.mask_utils import compress_loss_mask
from slime.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client
from slime.utils.logging_utils import configure_logger, init_tracking
from slime.utils.metric_utils import (
Expand Down Expand Up @@ -350,8 +351,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
"sample_indices": [sample.index for sample in samples],
}

# loss mask
# TODO: compress the loss mask
# loss mask - compressed using run-length encoding to reduce memory and network overhead
loss_masks = []
for sample in samples:
# always instantiate loss_mask if not provided
Expand All @@ -363,7 +363,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}"
if sample.remove_sample:
sample.loss_mask = [0] * sample.response_length
loss_masks.append(sample.loss_mask)
loss_masks.append(compress_loss_mask(sample.loss_mask))
train_data["loss_masks"] = loss_masks

# overwriting the raw reward
Expand Down
78 changes: 78 additions & 0 deletions slime/utils/mask_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,84 @@
from transformers import AutoTokenizer


# Run-length encoding (RLE) for compressing binary loss masks.
# Reference: https://en.wikipedia.org/wiki/Run-length_encoding
# Type alias for compressed loss mask: (run_lengths, starting_value)
# Example: [0,0,0,1,1,1,1,1] -> ([3, 5], 0) meaning "3 zeros, then 5 ones"
CompressedLossMask = tuple[list[int], int]


def compress_loss_mask(mask: list[int]) -> CompressedLossMask:
    """Compress a binary loss mask using run-length encoding.

    Args:
        mask: A list of 0s and 1s representing the loss mask.

    Returns:
        A tuple of (run_lengths, starting_value) where:
        - run_lengths: list of consecutive run lengths
        - starting_value: the value (0 or 1) of the first run

    Examples:
        >>> compress_loss_mask([0, 0, 0, 1, 1, 1, 1, 1])
        ([3, 5], 0)
        >>> compress_loss_mask([1, 1, 1, 1])
        ([4], 1)
        >>> compress_loss_mask([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])
        ([4, 4, 4, 7], 0)
        >>> compress_loss_mask([])
        ([], 0)
    """
    # Local import so the module's top-level dependencies stay unchanged.
    from itertools import groupby

    if not mask:
        return ([], 0)

    # groupby yields maximal runs of equal adjacent values, which is exactly
    # run-length encoding; count each run's length without materializing it.
    runs = [sum(1 for _ in group) for _, group in groupby(mask)]
    return (runs, mask[0])


def decompress_loss_mask(compressed: CompressedLossMask) -> list[int]:
    """Expand a run-length encoded loss mask back into an explicit 0/1 list.

    Args:
        compressed: A tuple of (run_lengths, starting_value).

    Returns:
        The original loss mask as a list of 0s and 1s.

    Examples:
        >>> decompress_loss_mask(([3, 5], 0))
        [0, 0, 0, 1, 1, 1, 1, 1]
        >>> decompress_loss_mask(([4], 1))
        [1, 1, 1, 1]
        >>> decompress_loss_mask(([], 0))
        []
    """
    run_lengths, first_value = compressed
    if not run_lengths:
        return []

    expanded: list[int] = []
    value = first_value
    for length in run_lengths:
        expanded += [value] * length
        # Consecutive runs always alternate, so flip 0 <-> 1 after each run.
        value = 1 - value

    return expanded


def get_response_lengths(loss_masks: list[list[int]]) -> list[int]:
# return the lengths starting from the first occurrence of 1 to the end of each loss mask
return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks]
Expand Down
81 changes: 80 additions & 1 deletion tests/utils/test_mask_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,85 @@
from transformers import AutoTokenizer

from slime.utils.mask_utils import MultiTurnLossMaskGenerator
from slime.utils.mask_utils import (
MultiTurnLossMaskGenerator,
compress_loss_mask,
decompress_loss_mask,
)


def test_compress_decompress_all_ones():
    """Round-trip a mask that is entirely 1s (the common default case)."""
    original = [1] * 1000
    compressed = compress_loss_mask(original)
    assert compressed == ([1000], 1), f"Expected ([1000], 1), got {compressed}"
    decompressed = decompress_loss_mask(compressed)
    assert decompressed == original, "Decompressed mask doesn't match original"


def test_compress_decompress_all_zeros():
    """Round-trip a mask that is entirely 0s (the remove_sample case)."""
    original = [0] * 500
    compressed = compress_loss_mask(original)
    assert compressed == ([500], 0), f"Expected ([500], 0), got {compressed}"
    decompressed = decompress_loss_mask(compressed)
    assert decompressed == original, "Decompressed mask doesn't match original"


def test_compress_decompress_prefix_zeros():
    """Round-trip a zeros-then-ones mask (typical multi-turn prompt pattern)."""
    original = [0] * 100 + [1] * 200
    compressed = compress_loss_mask(original)
    assert compressed == ([100, 200], 0), f"Expected ([100, 200], 0), got {compressed}"
    decompressed = decompress_loss_mask(compressed)
    assert decompressed == original, "Decompressed mask doesn't match original"


def test_compress_decompress_alternating():
    """Round-trip a mask with several alternating runs of 0s and 1s."""
    original = [0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1]
    compressed = compress_loss_mask(original)
    assert compressed == ([3, 5, 2, 3], 0), f"Expected ([3, 5, 2, 3], 0), got {compressed}"
    decompressed = decompress_loss_mask(compressed)
    assert decompressed == original, "Decompressed mask doesn't match original"


def test_compress_decompress_empty():
    """An empty mask compresses to ([], 0) and round-trips back to []."""
    original = []
    compressed = compress_loss_mask(original)
    assert compressed == ([], 0), f"Expected ([], 0), got {compressed}"
    decompressed = decompress_loss_mask(compressed)
    assert decompressed == original, "Decompressed mask doesn't match original"


def test_compress_decompress_single_element():
    """Round-trip the two possible one-element masks."""
    ones_only = [1]
    compressed = compress_loss_mask(ones_only)
    assert compressed == ([1], 1), f"Expected ([1], 1), got {compressed}"
    assert decompress_loss_mask(compressed) == ones_only

    zeros_only = [0]
    compressed = compress_loss_mask(zeros_only)
    assert compressed == ([1], 0), f"Expected ([1], 0), got {compressed}"
    assert decompress_loss_mask(compressed) == zeros_only


def test_compression_efficiency():
    """Test that compression actually reduces size for typical patterns.

    Only run counts are asserted: the compressed form stores one integer per
    run instead of one per token, so few runs implies a small representation.
    """
    # All 1s case (8192 tokens) collapses to a single run: ([8192], 1).
    mask = [1] * 8192
    compressed = compress_loss_mask(mask)
    assert len(compressed[0]) == 1, f"Expected 1 run, got {len(compressed[0])}"

    # Prefix-zeros case (common in multi-turn) collapses to two runs:
    # ([2000, 6000], 0).
    mask = [0] * 2000 + [1] * 6000
    compressed = compress_loss_mask(mask)
    assert len(compressed[0]) == 2, f"Expected 2 runs, got {len(compressed[0])}"


def test_loss_mask_qwen3_simple(model_name: str = "Qwen/Qwen3-8B"):
Expand Down