From 7493fa08afff1ad303479fdabde3918b94eb6ce8 Mon Sep 17 00:00:00 2001 From: Zhuocheng Xu Date: Tue, 20 Jan 2026 13:21:08 -0800 Subject: [PATCH] perf(rollout): compress loss masks with run-length encoding Reduce memory usage and network overhead by compressing binary loss masks using run-length encoding (RLE) when transferring data between rollout and training components. - Add compress_loss_mask and decompress_loss_mask utilities - Compress masks in RolloutManager before sending - Decompress in FSDP and Megatron data processing - Add comprehensive tests for RLE functions --- slime/backends/fsdp_utils/data_packing.py | 7 +- slime/backends/megatron_utils/data.py | 6 +- slime/ray/rollout.py | 6 +- slime/utils/mask_utils.py | 78 ++++++++++++++++++++++ tests/utils/test_mask_utils.py | 81 ++++++++++++++++++++++- 5 files changed, 169 insertions(+), 9 deletions(-) diff --git a/slime/backends/fsdp_utils/data_packing.py b/slime/backends/fsdp_utils/data_packing.py index 241889855..bae3b6d90 100644 --- a/slime/backends/fsdp_utils/data_packing.py +++ b/slime/backends/fsdp_utils/data_packing.py @@ -4,12 +4,13 @@ import torch +from slime.utils.mask_utils import CompressedLossMask, decompress_loss_mask from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions def pack_sequences( tokens: list[list[int]], - loss_masks: list[list[int]], + loss_masks: list[CompressedLossMask], rewards: list[float], raw_rewards: list, response_lengths: list[int], @@ -25,7 +26,7 @@ def pack_sequences( Args: tokens: List of token sequences - loss_masks: List of loss masks + loss_masks: List of compressed loss masks (run-length encoded) rewards: List of rewards per sequence raw_rewards: List of raw rewards per sequence response_lengths: List of response lengths per sequence @@ -72,7 +73,7 @@ def pack_sequences( for i in indices: seq_tokens = tokens[i] - seq_mask = loss_masks[i] + seq_mask = decompress_loss_mask(loss_masks[i]) seq_positionids = list(range(len(seq_tokens))) flat_tokens.extend(seq_tokens) diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py index a6f83c06f..cb09c4482 100644 --- a/slime/backends/megatron_utils/data.py +++ b/slime/backends/megatron_utils/data.py @@ -11,6 +11,7 @@ from slime.utils import train_metric_utils from slime.utils.data import get_minimum_num_micro_batch_size +from slime.utils.mask_utils import decompress_loss_mask from slime.utils.flops_utils import calculate_fwd_flops from slime.utils.metric_utils import compute_pass_rate, compute_rollout_step from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions @@ -102,15 +103,16 @@ def get_batch( batch["tokens"] = tokens batch["packed_seq_params"] = packed_seq_params - # loss masks + # loss masks - decompress from run-length encoding loss_masks = [] - for loss_mask, total_length, response_length in zip( + for compressed_loss_mask, total_length, response_length in zip( batch["loss_masks"], batch["total_lengths"], batch["response_lengths"], strict=True, ): prompt_length = total_length - response_length + loss_mask = torch.tensor(decompress_loss_mask(compressed_loss_mask), dtype=torch.int) loss_mask = F.pad(loss_mask, (prompt_length - 1, 1), value=0) loss_mask = slice_with_cp(loss_mask, 0, qkv_format, max_seqlen) loss_masks.append(loss_mask) diff --git a/slime/ray/rollout.py b/slime/ray/rollout.py index 75cb053c1..b22ef6ca6 100644 --- a/slime/ray/rollout.py +++ b/slime/ray/rollout.py @@ -16,6 +16,7 @@ from slime.rollout.base_types import call_rollout_fn from slime.utils import logging_utils from slime.utils.health_monitor import RolloutHealthMonitor +from slime.utils.mask_utils import compress_loss_mask from slime.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client from slime.utils.logging_utils import configure_logger, init_tracking from slime.utils.metric_utils import ( @@ -350,8 +351,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl "sample_indices": [sample.index for sample in samples], } - # loss mask - # TODO: compress the loss mask + # loss mask - compressed using run-length encoding to reduce memory and network overhead loss_masks = [] for sample in samples: # always instantiate loss_mask if not provided @@ -363,7 +363,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl ), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}" if sample.remove_sample: sample.loss_mask = [0] * sample.response_length - loss_masks.append(sample.loss_mask) + loss_masks.append(compress_loss_mask(sample.loss_mask)) train_data["loss_masks"] = loss_masks # overwriting the raw reward diff --git a/slime/utils/mask_utils.py b/slime/utils/mask_utils.py index 0ddb3a141..bbacfec7e 100644 --- a/slime/utils/mask_utils.py +++ b/slime/utils/mask_utils.py @@ -1,6 +1,84 @@ from transformers import AutoTokenizer +# Run-length encoding (RLE) for compressing binary loss masks. +# Reference: https://en.wikipedia.org/wiki/Run-length_encoding +# Type alias for compressed loss mask: (run_lengths, starting_value) +# Example: [0,0,0,1,1,1,1,1] -> ([3, 5], 0) meaning "3 zeros, then 5 ones" +CompressedLossMask = tuple[list[int], int] + + +def compress_loss_mask(mask: list[int]) -> CompressedLossMask: + """Compress a binary loss mask using run-length encoding. + + Args: + mask: A list of 0s and 1s representing the loss mask. + + Returns: + A tuple of (run_lengths, starting_value) where: + - run_lengths: list of consecutive run lengths + - starting_value: the value (0 or 1) of the first run + + Examples: + >>> compress_loss_mask([0, 0, 0, 1, 1, 1, 1, 1]) + ([3, 5], 0) + >>> compress_loss_mask([1, 1, 1, 1]) + ([4], 1) + >>> compress_loss_mask([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]) + ([4, 4, 4, 7], 0) + >>> compress_loss_mask([]) + ([], 0) + """ + if not mask: + return ([], 0) + + runs = [] + starting_value = mask[0] + current_value = starting_value + current_run = 0 + + for val in mask: + if val == current_value: + current_run += 1 + else: + runs.append(current_run) + current_value = val + current_run = 1 + + runs.append(current_run) + return (runs, starting_value) + + +def decompress_loss_mask(compressed: CompressedLossMask) -> list[int]: + """Decompress a run-length encoded loss mask back to a list of 0s and 1s. + + Args: + compressed: A tuple of (run_lengths, starting_value). + + Returns: + The original loss mask as a list of 0s and 1s. + + Examples: + >>> decompress_loss_mask(([3, 5], 0)) + [0, 0, 0, 1, 1, 1, 1, 1] + >>> decompress_loss_mask(([4], 1)) + [1, 1, 1, 1] + >>> decompress_loss_mask(([], 0)) + [] + """ + runs, starting_value = compressed + if not runs: + return [] + + mask = [] + current_value = starting_value + for run_length in runs: + mask.extend([current_value] * run_length) + current_value = 1 - current_value # Toggle between 0 and 1 + + return mask + + def get_response_lengths(loss_masks: list[list[int]]) -> list[int]: # return the lengths starting from the first occurrence of 1 to the end of each loss mask return [len(mask[mask.index(1) :]) if 1 in mask else 0 for mask in loss_masks] diff --git a/tests/utils/test_mask_utils.py b/tests/utils/test_mask_utils.py index c0823f0bd..28368cf8d 100644 --- a/tests/utils/test_mask_utils.py +++ b/tests/utils/test_mask_utils.py @@ -1,6 +1,85 @@ from transformers import AutoTokenizer -from slime.utils.mask_utils import MultiTurnLossMaskGenerator +from slime.utils.mask_utils import ( + MultiTurnLossMaskGenerator, + compress_loss_mask, + decompress_loss_mask, +) + + +def test_compress_decompress_all_ones(): + """Test compression/decompression with all 1s (common default case).""" + mask = [1] * 1000 + compressed = compress_loss_mask(mask) + assert compressed == ([1000], 1), f"Expected ([1000], 1), got {compressed}" + decompressed = decompress_loss_mask(compressed) + assert decompressed == mask, "Decompressed mask doesn't match original" + + +def test_compress_decompress_all_zeros(): + """Test compression/decompression with all 0s (remove_sample case).""" + mask = [0] * 500 + compressed = compress_loss_mask(mask) + assert compressed == ([500], 0), f"Expected ([500], 0), got {compressed}" + decompressed = decompress_loss_mask(compressed) + assert decompressed == mask, "Decompressed mask doesn't match original" + + +def test_compress_decompress_prefix_zeros(): + """Test compression/decompression with prefix zeros (common multi-turn pattern).""" + mask = [0] * 100 + [1] * 200 + compressed = compress_loss_mask(mask) + assert compressed == ([100, 200], 0), f"Expected ([100, 200], 0), got {compressed}" + decompressed = decompress_loss_mask(compressed) + assert decompressed == mask, "Decompressed mask doesn't match original" + + +def test_compress_decompress_alternating(): + """Test compression/decompression with alternating pattern.""" + mask = [0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1] + compressed = compress_loss_mask(mask) + assert compressed == ([3, 5, 2, 3], 0), f"Expected ([3, 5, 2, 3], 0), got {compressed}" + decompressed = decompress_loss_mask(compressed) + assert decompressed == mask, "Decompressed mask doesn't match original" + + +def test_compress_decompress_empty(): + """Test compression/decompression with empty mask.""" + mask = [] + compressed = compress_loss_mask(mask) + assert compressed == ([], 0), f"Expected ([], 0), got {compressed}" + decompressed = decompress_loss_mask(compressed) + assert decompressed == mask, "Decompressed mask doesn't match original" + + +def test_compress_decompress_single_element(): + """Test compression/decompression with single element masks.""" + mask_one = [1] + compressed = compress_loss_mask(mask_one) + assert compressed == ([1], 1), f"Expected ([1], 1), got {compressed}" + assert decompress_loss_mask(compressed) == mask_one + + mask_zero = [0] + compressed = compress_loss_mask(mask_zero) + assert compressed == ([1], 0), f"Expected ([1], 0), got {compressed}" + assert decompress_loss_mask(compressed) == mask_zero + + +def test_compression_efficiency(): + """Test that compression actually reduces size for typical patterns.""" + import sys + + # All 1s case (8192 tokens) + mask = [1] * 8192 + compressed = compress_loss_mask(mask) + # Compressed: ([8192], 1) - just 2 elements + assert len(compressed[0]) == 1, f"Expected 1 run, got {len(compressed[0])}" + + # Prefix zeros case (common in multi-turn) + mask = [0] * 2000 + [1] * 6000 + compressed = compress_loss_mask(mask) + # Compressed: ([2000, 6000], 0) - just 2 elements + assert len(compressed[0]) == 2, f"Expected 2 runs, got {len(compressed[0])}" def test_loss_mask_qwen3_simple(model_name: str = "Qwen/Qwen3-8B"):