Auto packing fixes #783

Merged · 8 commits · Dec 11, 2023
37 changes: 29 additions & 8 deletions llmfoundry/data/packing.py
@@ -1,13 +1,18 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)


class BinPackCollator:
"""Utility collator for packing to reduce padding."""
@@ -289,8 +294,13 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# Set the seed so that auto packing is deterministic.
reproducibility.seed_all(0)

max_seq_len = dataloader_cfg.dataset.max_seq_len
# If max_seq_len is very small, skip profiling and select packing ratio of 1.
if max_seq_len <= 100:
return 1

min_ratio = 1
max_ratio = dataloader_cfg.dataset.max_seq_len / 100
max_ratio = max_seq_len / 100
profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
max_ratio, num_packing_ratios,
device_batch_size)
@@ -299,7 +309,7 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,
# profiling_results are sorted from smallest to largest packing_ratio.
packing_ratio = 1
for packing_ratio_candidate, _, waste in profiling_results:
if waste > 0:
if waste is None or waste > 0:
break
packing_ratio = packing_ratio_candidate

@@ -318,9 +328,10 @@ def auto_packing_ratio(dataloader_cfg: DictConfig,


def profile_packing(
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int) -> Iterable[Tuple[float, float, float]]:
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
min_ratio: float, max_ratio: float, num_packing_ratios: int,
device_batch_size: int
) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
"""Generator function that profiles example packing across packing ratios.

Args:
@@ -350,6 +361,10 @@ def profile_packing(
dataloader_cfg.prefetch_factor = None
dataloader_cfg.persistent_workers = False

# If streaming dataset, use a temporary local folder for profiling
if dataloader_cfg.dataset.get('remote') is not None:
dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name

# Determine the packing_ratio values we'll try
packing_ratios, raw_batch_sizes = [], []
for packing_ratio in np.linspace(min_ratio,
@@ -382,7 +397,7 @@ def split_big_batch(raw_batch_size: int) -> List:
batches[idx].update({key: split})
return batches

def profile(raw_batch_size: int) -> Tuple[float, float]:
def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
packer = BinPackCollator(
collator=lambda x: x,
target_batch_size=device_batch_size,
@@ -395,9 +410,15 @@ def profile(raw_batch_size: int) -> Tuple[float, float]:
for batch in split_big_batch(raw_batch_size):
if batch['input_ids'].shape[0] < device_batch_size:
continue
_ = packer.pack(batch)
packer.pack(batch)

if packer.n_packed_examples == 0:
log.debug(
'No examples packed during profiling. Dataset is smaller than device batch size.'
)
return None, None

# Return the padding / waste stats over that bunch of data
# Return the padding and waste stats over that bunch of data
padding_percent = 100 * (1 - packer.efficiency)
waste_percent = 100 * packer.waste
return padding_percent, waste_percent
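For context, a minimal sketch of the selection logic the hunks above modify: `auto_packing_ratio` walks the profiling results from smallest to largest packing ratio and keeps the largest candidate with zero waste, now treating a `None` waste (no examples packed) the same as positive waste. The helper name and sample numbers below are illustrative, not code from the PR.

```python
from typing import Iterable, Optional, Tuple

def select_packing_ratio(
    profiling_results: Iterable[Tuple[float, Optional[float], Optional[float]]]
) -> float:
    """Pick the largest candidate ratio whose profiled waste is exactly zero."""
    packing_ratio = 1.0
    for candidate, _, waste in profiling_results:
        # Stop at the first candidate that wastes data or produced no stats.
        if waste is None or waste > 0:
            break
        packing_ratio = candidate
    return packing_ratio

# Example with made-up (ratio, padding%, waste%) tuples: 2.0 is selected,
# since 3.0 is the first candidate with nonzero waste.
assert select_packing_ratio([(1.0, 40.0, 0.0), (2.0, 10.0, 0.0), (3.0, 5.0, 2.0)]) == 2.0
```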
2 changes: 1 addition & 1 deletion scripts/misc/profile_packing.py
@@ -26,7 +26,7 @@ def parse_args() -> Namespace:
help='Path to the YAML that defines the workload to profile.')
parser.add_argument('--num-devices',
type=int,
default=None,
required=True,
help='How many devices your run will use.')
parser.add_argument('--min',
type=float,
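The change above makes `--num-devices` a required argument for the profiling script. For reference, a hedged sketch of how `profile_packing` results can be consumed and reported per candidate ratio; the config construction, argument values, and printed format here are assumptions, not code from this script.

```python
from typing import Any

from llmfoundry.data.packing import profile_packing

def report_packing(dataloader_cfg: Any, tokenizer: Any, device_batch_size: int) -> None:
    """Print padding and waste for each profiled packing ratio."""
    results = profile_packing(
        dataloader_cfg,
        tokenizer,
        min_ratio=1,
        max_ratio=dataloader_cfg.dataset.max_seq_len / 100,
        num_packing_ratios=10,
        device_batch_size=device_batch_size,
    )
    for ratio, padding, waste in results:
        if padding is None or waste is None:
            # None stats mean no examples were packed at this ratio.
            print(f'packing_ratio={ratio:.2f}: no examples packed during profiling')
        else:
            print(f'packing_ratio={ratio:.2f}: padding={padding:.1f}%, waste={waste:.1f}%')
```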
40 changes: 40 additions & 0 deletions tests/data/test_packing.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Any, Dict, List
from unittest.mock import Mock, patch

@@ -9,6 +10,7 @@
from composer.utils import dist, reproducibility
from omegaconf import DictConfig
from pytest import approx
from streaming import MDSWriter
from torch.utils.data import DataLoader

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
@@ -149,6 +151,44 @@ def patched_packing_ratio(*args: Any, **kwargs: Any):
return auto_packing_ratio(*args, **kwargs, num_packing_ratios=4)


@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)
def test_auto_packing_with_streaming_dataloader(tmp_path: Path):
columns = {'prompt': 'str', 'response': 'str'}
tokenizer = build_tokenizer('gpt2', {})
remote_dir = str(tmp_path / 'remote')
local_dir = str(tmp_path / 'local')
with MDSWriter(out=remote_dir, columns=columns, compression=None) as out:
out.write({'prompt': 'HELLO', 'response': 'WORLD'})
cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'remote': remote_dir,
'local': local_dir,
'packing_ratio': 'auto',
'max_seq_len': 200,
'decoder_only_format': True
},
'drop_last': False,
# Need to test with 0 num_workers because the packing collator object
# gets copied per worker and we cannot check the waste for child processes.
'num_workers': 0,
'pin_memory': False,
'prefetch_factor': None,
'persistent_workers': False,
'timeout': 0,
})

loader = build_finetuning_dataloader(cfg, tokenizer,
device_batch_size=6).dataloader

batch_ix = 0
for _ in loader:
batch_ix += 1
if batch_ix >= 3:
break


@pytest.mark.parametrize('packing_ratio', ['auto', 2.0])
@patch('llmfoundry.data.finetuning.dataloader.auto_packing_ratio',
patched_packing_ratio)