Move loss generating token counting to the dataloader #1632

Merged
19 commits merged on Nov 4, 2024
5 changes: 5 additions & 0 deletions llmfoundry/data/contrastive_pairs/dataloader.py
@@ -249,6 +249,11 @@ def collate_fn_without_labels(batch: list[Any]) -> dict[str, torch.Tensor]:
    processed_batch: dict[str, torch.Tensor] = collate_fn(batch)
    if 'labels' in processed_batch:
        del processed_batch['labels']
+
+    if 'total_tokens' in processed_batch:
+        del processed_batch['total_tokens']
+    if 'loss_generating_tokens' in processed_batch:
+        del processed_batch['loss_generating_tokens']
    return processed_batch

dl = DataLoader(
4 changes: 4 additions & 0 deletions llmfoundry/data/packing.py
@@ -514,6 +514,10 @@ def profile_packing(
    big_batch = next(iter(train_dataloader))

    # Cut everything down to size
+    if 'total_tokens' in big_batch:
+        del big_batch['total_tokens']
+    if 'loss_generating_tokens' in big_batch:
+        del big_batch['loss_generating_tokens']
    sizes, trimmed_examples = _trim_batch(big_batch)

    def profile(raw_batch_size: int) -> tuple[Optional[float], Optional[float]]:
92 changes: 79 additions & 13 deletions llmfoundry/data/utils.py
@@ -20,6 +20,55 @@
log = logging.getLogger(__name__)


+class LossGeneratingTokensCollatorWrapper:
+    """Collator wrapper to add loss generating token counts to batch."""
+
+    def __init__(
+        self,
+        base_collator: Callable,
+        token_counting_func: Callable[[Batch], Union[int, dict[str, int]]],
+    ):
+        self.base_collator = base_collator
+        self.token_counting_func = token_counting_func
+
+        self._token_count_batch_keys = [
+            'input_ids',
+            'attention_mask',
+            'labels',
+            'decoder_attention_mask',
+        ]
+
+    def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]:
+        batch = self.base_collator(examples)
+
+        # Add token counts to batch as a list, one for each row, so that microbatch splitting works
+        output = {
+            'total_tokens': [],
+            'loss_generating_tokens': [],
+        }
+        num_rows = batch['input_ids'].shape[0]
+        for row in range(num_rows):
+            row_batch = {}
+            for key in self._token_count_batch_keys:
+                if key in batch:
+                    row_batch[key] = batch[key][row:row + 1]
+
+            num_tokens = self.token_counting_func(row_batch)
+            if isinstance(num_tokens, dict):
+                output['total_tokens'].append(num_tokens['total'])
+                output['loss_generating_tokens'].append(
+                    num_tokens['loss_generating'],
+                )
+            else:
+                output['total_tokens'].append(num_tokens)
+                output['loss_generating_tokens'].append(num_tokens)
+
+        batch['total_tokens'] = output['total_tokens']
+        batch['loss_generating_tokens'] = output['loss_generating_tokens']
+
+        return batch


def _validate_cfg(
    dataset_cfg: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
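For orientation, here is a minimal sketch of how the new wrapper is meant to be used, assuming a Hugging Face tokenizer and the stock language-modeling collator; the tokenizer choice and example texts are illustrative, not taken from this PR:

# Illustrative sketch only: wraps a base collator so each batch also carries
# per-row 'total_tokens' and 'loss_generating_tokens' lists.
import transformers

from llmfoundry.data.utils import (
    LossGeneratingTokensCollatorWrapper,
    get_tokens_per_batch_func,
)

tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')  # assumed model
tokenizer.pad_token = tokenizer.eos_token

base_collator = transformers.DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)
collator = LossGeneratingTokensCollatorWrapper(
    base_collator,
    get_tokens_per_batch_func(),
)

examples = [tokenizer('hello world'), tokenizer('a slightly longer example')]
batch = collator(examples)
# Alongside input_ids / attention_mask / labels, the batch now carries one
# count per row, so the counts can be split along with the rows later.
print(batch['total_tokens'], batch['loss_generating_tokens'])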
Expand Down Expand Up @@ -109,6 +158,13 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
            'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key',
        )

+    # Short cut if the dataloader has already calculated the number of tokens
+    if 'total_tokens' in batch and 'loss_generating_tokens' in batch:
+        return {
+            'total': sum(batch['total_tokens']),
+            'loss_generating': sum(batch['loss_generating_tokens']),
+        }
+
    # Count number of non padding tokens in batch
    if 'attention_mask' in batch:
        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
@@ -117,16 +173,10 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:

    loss_generating_tokens = None
    if 'labels' in batch:
-        loss_generating_tokens = int(
-            torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),
-        )
-
-        # Subtract one for each example in the batch that starts with a non -100,
-        # because those will be shifted off
-        loss_generating_tokens -= int(
-            torch.sum(
-                batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX,
-            ).item(),
+        loss_generating_tokens = (
+            batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)
+        ) - torch.count_nonzero(
+            torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX),
        )

    # For encoder decoder models only
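As a quick sanity check on the new shifted-label arithmetic, here is a hedged, standalone example; the -100 ignore value is assumed to be what CROSS_ENTROPY_IGNORE_INDEX resolves to:

# Standalone check of the count above: with labels shifted by one, only
# positions 1..T-1 can generate loss, minus any that hold the ignore index.
import torch

CROSS_ENTROPY_IGNORE_INDEX = -100  # assumed value of the llmfoundry constant
labels = torch.tensor([[-100, 5, 7], [3, -100, 9]])

loss_generating = (
    labels.shape[0] * (labels.shape[1] - 1)
) - torch.count_nonzero(torch.eq(labels[..., 1:], CROSS_ENTROPY_IGNORE_INDEX))

# 2 rows * 2 shifted positions = 4 candidates; one shifted label is -100,
# so 3 tokens actually contribute to the loss.
assert int(loss_generating) == 3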
@@ -151,7 +201,8 @@ def get_text_collator(
    tokenizer: PreTrainedTokenizerBase,
    dataset_batch_size: int,
) -> tuple[Union[transformers.DataCollatorForLanguageModeling,
-              ConcatenatedSequenceCollatorWrapper], int]:
+              ConcatenatedSequenceCollatorWrapper,
+              LossGeneratingTokensCollatorWrapper], int]:
    dataset_cfg = dataloader_cfg.get('dataset')
    assert isinstance(dataset_cfg, dict)
    eos_token_id = dataset_cfg.get('eos_token_id', None)
@@ -171,12 +222,27 @@ def get_text_collator(
        bos_token_id=bos_token_id,
    )

+    collate_fn = LossGeneratingTokensCollatorWrapper(
+        collate_fn,
+        get_tokens_per_batch_func(),
+    )
+
    return collate_fn, dataset_batch_size


def get_finetuning_collator(
    dataloader_cfg: dict[str, Any],
    tokenizer: PreTrainedTokenizerBase,
    dataset_batch_size: int,
-) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
-    return build_collate_fn(dataloader_cfg, tokenizer, dataset_batch_size)
+) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator,
+           LossGeneratingTokensCollatorWrapper], int]:
+    collate_fn, dataset_batch_size = build_collate_fn(
+        dataloader_cfg,
+        tokenizer,
+        dataset_batch_size,
+    )
+    collate_fn = LossGeneratingTokensCollatorWrapper(
+        collate_fn,
+        get_tokens_per_batch_func(),
+    )
+    return collate_fn, dataset_batch_size
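The wrapper stores counts as per-row lists rather than batch totals so that they split along with the rows when a batch is divided into microbatches; a plain-Python sketch of that idea follows, with a hypothetical splitting helper that is not Composer's actual API:

# Hypothetical sketch: row-wise slicing works the same for tensors and lists,
# so the token counts stay aligned with their rows after microbatch splitting.
import torch

def split_rows(batch: dict, microbatch_size: int) -> list[dict]:
    num_rows = batch['input_ids'].shape[0]
    return [
        {k: v[start:start + microbatch_size] for k, v in batch.items()}
        for start in range(0, num_rows, microbatch_size)
    ]

batch = {
    'input_ids': torch.zeros(4, 8, dtype=torch.long),
    'total_tokens': [8, 8, 6, 5],
    'loss_generating_tokens': [7, 7, 5, 4],
}
for mb in split_rows(batch, microbatch_size=2):
    # Each microbatch recovers its own totals by summing its slice, which is
    # exactly what the shortcut added to get_num_tokens_in_batch does.
    print(sum(mb['total_tokens']), sum(mb['loss_generating_tokens']))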
3 changes: 2 additions & 1 deletion tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1573,7 +1573,8 @@ def test_mptmoe_huggingface_conversion_callback(
    # Check output equivalence
    loaded_model = loaded_model.cuda().bfloat16() # type: ignore
    for k, v in batch.items():
-        batch[k] = v.cuda()
+        if isinstance(v, torch.Tensor):
+            batch[k] = v.cuda()
    loaded_model_logits = loaded_model(
        input_ids=batch.get('input_ids', None),
        attention_mask=batch.get('attention_mask', None),
4 changes: 3 additions & 1 deletion tests/data/test_packing.py
@@ -16,6 +16,7 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
+from llmfoundry.data.utils import LossGeneratingTokensCollatorWrapper
from llmfoundry.utils.builders import build_tokenizer


@@ -253,7 +254,8 @@ def test_packing_with_dataloader(packing_ratio: Any):
    ).dataloader

    assert isinstance(loader, DataLoader)
-    pack_collator = loader.collate_fn
+    assert isinstance(loader.collate_fn, LossGeneratingTokensCollatorWrapper)
+    pack_collator = loader.collate_fn.base_collator
    assert isinstance(pack_collator, BinPackCollator)

    batch_ix = 0