From 99177b59ed50006cea0ac07c8cc9849329c14aa3 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 13:06:33 -0700 Subject: [PATCH 01/19] fix --- llmfoundry/data/utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 21c28d9183..a5f0b61f62 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -118,15 +118,7 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: loss_generating_tokens = None if 'labels' in batch: loss_generating_tokens = int( - torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(), - ) - - # Subtract one for each example in the batch that starts with a non -100, - # because those will be shifted off - loss_generating_tokens -= int( - torch.sum( - batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX, - ).item(), + torch.sum(batch['labels'][...,1:] != CROSS_ENTROPY_IGNORE_INDEX).item(), ) # For encoder decoder models only From 14751e26918c8da2666705742f009946fa2ed5fb Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 13:13:43 -0700 Subject: [PATCH 02/19] try --- llmfoundry/data/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index a5f0b61f62..2247f4f9ac 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -117,9 +117,7 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: loss_generating_tokens = None if 'labels' in batch: - loss_generating_tokens = int( - torch.sum(batch['labels'][...,1:] != CROSS_ENTROPY_IGNORE_INDEX).item(), - ) + loss_generating_tokens = int((batch['labels'][...,1:] != CROSS_ENTROPY_IGNORE_INDEX).sum()) # For encoder decoder models only decoder_input_ids_tokens = 0 From 81f61744dd4ab5de2fcf3318a00bed560cf6c0bd Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 13:36:50 -0700 Subject: [PATCH 03/19] fix --- llmfoundry/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 2247f4f9ac..44a3bcfb21 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -117,7 +117,7 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: loss_generating_tokens = None if 'labels' in batch: - loss_generating_tokens = int((batch['labels'][...,1:] != CROSS_ENTROPY_IGNORE_INDEX).sum()) + loss_generating_tokens = (batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) # For encoder decoder models only decoder_input_ids_tokens = 0 From 9d5579809efaef3708aa0b1ff82c6f6065656bb1 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 13:54:23 -0700 Subject: [PATCH 04/19] rm --- llmfoundry/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 44a3bcfb21..4f8caf9681 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -116,8 +116,8 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: input_ids_tokens = batch['input_ids'].numel() loss_generating_tokens = None - if 'labels' in batch: - loss_generating_tokens = (batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) + # if 'labels' in batch: + # loss_generating_tokens = (batch['labels'].shape[0] * 
(batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) # For encoder decoder models only decoder_input_ids_tokens = 0 From 0f3cbc08564ab152d4129b512f7cbcf26560c0de Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 13:58:50 -0700 Subject: [PATCH 05/19] put back --- llmfoundry/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 4f8caf9681..44a3bcfb21 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -116,8 +116,8 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: input_ids_tokens = batch['input_ids'].numel() loss_generating_tokens = None - # if 'labels' in batch: - # loss_generating_tokens = (batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) + if 'labels' in batch: + loss_generating_tokens = (batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) # For encoder decoder models only decoder_input_ids_tokens = 0 From 12e4f02a45dd7da7be7557b878716ddd1f4390ea Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 14:35:12 -0700 Subject: [PATCH 06/19] fix --- llmfoundry/data/utils.py | 44 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 44a3bcfb21..378f32c755 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -19,6 +19,41 @@ log = logging.getLogger(__name__) +class LossGeneratingTokensCollatorWrapper: + """Collator wrapper to add sequence_id to batch.""" + + def __init__( + self, + base_collator: Callable, + ): + self.base_collator = base_collator + + def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: + batch = self.base_collator(examples) + + # Add token counts to batch + output = { + 'total_tokens': [], + 'loss_generating_tokens': [], + } + num_rows = batch['input_ids'].shape[0] + for row in num_rows: + row_batch = { + 'input_ids': batch['input_ids'][row], + } + if 'attention_mask' in batch: + row_batch['attention_mask'] = batch['attention_mask'][row] + if 'labels' in batch: + row_batch['labels'] = batch['labels'][row] + + num_tokens = get_tokens_per_batch_func()(row_batch) + output['total_tokens'].append(num_tokens['total']) + output['loss_generating_tokens'].append(num_tokens['loss_generating']) + + batch['total_tokens'] = output['total_tokens'] + batch['loss_generating_tokens'] = output['loss_generating_tokens'] + + return batch def _validate_cfg( dataset_cfg: dict[str, Any], @@ -108,6 +143,13 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: raise ValueError( 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key', ) + + # Short cut if the dataloader has already calculated the number of tokens + if 'total_tokens' in batch and 'loss_generating_tokens' in batch: + return { + 'total': sum(batch['total_tokens']), + 'loss_generating': sum(batch['loss_generating_tokens']), + } # Count number of non padding tokens in batch if 'attention_mask' in batch: @@ -161,6 +203,8 @@ def get_text_collator( bos_token_id=bos_token_id, ) + collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn) + return collate_fn, dataset_batch_size From 4a425b5f65c53ab1ce1c4fb4e1bdb523f7d3ed8a Mon Sep 17 00:00:00 2001 From: 
Daniel King Date: Sat, 2 Nov 2024 14:39:59 -0700 Subject: [PATCH 07/19] tensors --- llmfoundry/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 378f32c755..bfc3f9a322 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -50,8 +50,8 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: output['total_tokens'].append(num_tokens['total']) output['loss_generating_tokens'].append(num_tokens['loss_generating']) - batch['total_tokens'] = output['total_tokens'] - batch['loss_generating_tokens'] = output['loss_generating_tokens'] + batch['total_tokens'] = torch.tensor(output['total_tokens']) + batch['loss_generating_tokens'] = torch.tensor(output['loss_generating_tokens']) return batch From 3d207a0a97c1770d8f64362b5286cb239edba33d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 14:42:13 -0700 Subject: [PATCH 08/19] fix --- llmfoundry/data/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index bfc3f9a322..e1709167ab 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -37,7 +37,7 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: 'loss_generating_tokens': [], } num_rows = batch['input_ids'].shape[0] - for row in num_rows: + for row in range(num_rows): row_batch = { 'input_ids': batch['input_ids'][row], } @@ -50,8 +50,8 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: output['total_tokens'].append(num_tokens['total']) output['loss_generating_tokens'].append(num_tokens['loss_generating']) - batch['total_tokens'] = torch.tensor(output['total_tokens']) - batch['loss_generating_tokens'] = torch.tensor(output['loss_generating_tokens']) + batch['total_tokens'] = output['total_tokens'] + batch['loss_generating_tokens'] = output['loss_generating_tokens'] return batch From fd86108a56424168a7b8463132a94b18979f3342 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 14:44:34 -0700 Subject: [PATCH 09/19] fix --- llmfoundry/data/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index e1709167ab..4d507c09cd 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -39,12 +39,12 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: num_rows = batch['input_ids'].shape[0] for row in range(num_rows): row_batch = { - 'input_ids': batch['input_ids'][row], + 'input_ids': batch['input_ids'][row].unsqueeze(0), } if 'attention_mask' in batch: - row_batch['attention_mask'] = batch['attention_mask'][row] + row_batch['attention_mask'] = batch['attention_mask'][row].unsqueeze(0) if 'labels' in batch: - row_batch['labels'] = batch['labels'][row] + row_batch['labels'] = batch['labels'][row].unsqueeze(0) num_tokens = get_tokens_per_batch_func()(row_batch) output['total_tokens'].append(num_tokens['total']) From a21c781dd1998ceded05f6dea98acfc278953aff Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 14:55:46 -0700 Subject: [PATCH 10/19] update --- llmfoundry/data/utils.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 4d507c09cd..68160280bb 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -19,6 +19,7 @@ log = logging.getLogger(__name__) + class 
LossGeneratingTokensCollatorWrapper: """Collator wrapper to add sequence_id to batch.""" @@ -30,7 +31,7 @@ def __init__( def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: batch = self.base_collator(examples) - + # Add token counts to batch output = { 'total_tokens': [], @@ -42,19 +43,30 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: 'input_ids': batch['input_ids'][row].unsqueeze(0), } if 'attention_mask' in batch: - row_batch['attention_mask'] = batch['attention_mask'][row].unsqueeze(0) + row_batch['attention_mask'] = batch['attention_mask'][ + row].unsqueeze(0) if 'labels' in batch: row_batch['labels'] = batch['labels'][row].unsqueeze(0) + if 'decoder_attention_mask' in batch: + row_batch['decoder_attention_mask'] = batch[ + 'decoder_attention_mask'][row].unsqueeze(0) num_tokens = get_tokens_per_batch_func()(row_batch) - output['total_tokens'].append(num_tokens['total']) - output['loss_generating_tokens'].append(num_tokens['loss_generating']) + if isinstance(num_tokens, dict): + output['total_tokens'].append(num_tokens['total']) + output['loss_generating_tokens'].append( + num_tokens['loss_generating'] + ) + else: + output['total_tokens'].append(num_tokens) + output['loss_generating_tokens'].append(num_tokens) batch['total_tokens'] = output['total_tokens'] batch['loss_generating_tokens'] = output['loss_generating_tokens'] return batch + def _validate_cfg( dataset_cfg: dict[str, Any], tokenizer: PreTrainedTokenizerBase, @@ -143,7 +155,7 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: raise ValueError( 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key', ) - + # Short cut if the dataloader has already calculated the number of tokens if 'total_tokens' in batch and 'loss_generating_tokens' in batch: return { @@ -159,7 +171,11 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: loss_generating_tokens = None if 'labels' in batch: - loss_generating_tokens = (batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)) - torch.count_nonzero(torch.eq(batch['labels'][...,1:], CROSS_ENTROPY_IGNORE_INDEX)) + loss_generating_tokens = ( + batch['labels'].shape[0] * (batch['labels'].shape[1] - 1) + ) - torch.count_nonzero( + torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX) + ) # For encoder decoder models only decoder_input_ids_tokens = 0 From b4b9349b0d8ebf9f8a05b1f9db3431e241d2acac Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 14:58:56 -0700 Subject: [PATCH 11/19] add to ft --- llmfoundry/data/utils.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 68160280bb..fc8d2a8f9d 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -55,7 +55,7 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: if isinstance(num_tokens, dict): output['total_tokens'].append(num_tokens['total']) output['loss_generating_tokens'].append( - num_tokens['loss_generating'] + num_tokens['loss_generating'], ) else: output['total_tokens'].append(num_tokens) @@ -174,7 +174,7 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]: loss_generating_tokens = ( batch['labels'].shape[0] * (batch['labels'].shape[1] - 1) ) - torch.count_nonzero( - torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX) + torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX), ) # For encoder decoder models only @@ -199,7 +199,8 
@@ def get_text_collator( tokenizer: PreTrainedTokenizerBase, dataset_batch_size: int, ) -> tuple[Union[transformers.DataCollatorForLanguageModeling, - ConcatenatedSequenceCollatorWrapper], int]: + ConcatenatedSequenceCollatorWrapper, + LossGeneratingTokensCollatorWrapper], int]: dataset_cfg = dataloader_cfg.get('dataset') assert isinstance(dataset_cfg, dict) eos_token_id = dataset_cfg.get('eos_token_id', None) @@ -228,5 +229,10 @@ def get_finetuning_collator( dataloader_cfg: dict[str, Any], tokenizer: PreTrainedTokenizerBase, dataset_batch_size: int, -) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]: - return build_collate_fn(dataloader_cfg, tokenizer, dataset_batch_size) +) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator, + LossGeneratingTokensCollatorWrapper], int]: + collate_fn, dataset_batch_size = build_collate_fn( + dataloader_cfg, tokenizer, dataset_batch_size + ) + collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn) + return collate_fn, dataset_batch_size From 2f75f19b04db4fe835465fc4cb60ae6c30aee3a0 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 15:06:43 -0700 Subject: [PATCH 12/19] pc --- llmfoundry/data/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index fc8d2a8f9d..6ccacbf73f 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -232,7 +232,9 @@ def get_finetuning_collator( ) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator, LossGeneratingTokensCollatorWrapper], int]: collate_fn, dataset_batch_size = build_collate_fn( - dataloader_cfg, tokenizer, dataset_batch_size + dataloader_cfg, + tokenizer, + dataset_batch_size, ) collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn) return collate_fn, dataset_batch_size From d92d535271c2e26298b4c139c99dcd6a68de7ce7 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 15:36:05 -0700 Subject: [PATCH 13/19] pc --- llmfoundry/data/packing.py | 4 ++++ llmfoundry/data/utils.py | 2 +- tests/data/test_packing.py | 4 +++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index e3c19cc91c..5eacced549 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -514,6 +514,10 @@ def profile_packing( big_batch = next(iter(train_dataloader)) # Cut everything down to size + if 'total_tokens' in big_batch: + del big_batch['total_tokens'] + if 'loss_generating_tokens' in big_batch: + del big_batch['loss_generating_tokens'] sizes, trimmed_examples = _trim_batch(big_batch) def profile(raw_batch_size: int) -> tuple[Optional[float], Optional[float]]: diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 6ccacbf73f..526d3dfb9f 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -32,7 +32,7 @@ def __init__( def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: batch = self.base_collator(examples) - # Add token counts to batch + # Add token counts to batch as a list, one for each row, so that microbatch splitting works output = { 'total_tokens': [], 'loss_generating_tokens': [], diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py index 0fad6c0d53..48713f8a19 100644 --- a/tests/data/test_packing.py +++ b/tests/data/test_packing.py @@ -16,6 +16,7 @@ from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset from llmfoundry.data.packing import 
BinPackCollator, auto_packing_ratio +from llmfoundry.data.utils import LossGeneratingTokensCollatorWrapper from llmfoundry.utils.builders import build_tokenizer @@ -253,7 +254,8 @@ def test_packing_with_dataloader(packing_ratio: Any): ).dataloader assert isinstance(loader, DataLoader) - pack_collator = loader.collate_fn + assert isinstance(loader.collate_fn, LossGeneratingTokensCollatorWrapper) + pack_collator = loader.collate_fn.base_collator assert isinstance(pack_collator, BinPackCollator) batch_ix = 0 From 120853fdadbcfbeaac7d77144d0ffb52f18296ef Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 15:40:15 -0700 Subject: [PATCH 14/19] fix embedding code --- llmfoundry/data/contrastive_pairs/dataloader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/llmfoundry/data/contrastive_pairs/dataloader.py b/llmfoundry/data/contrastive_pairs/dataloader.py index d9760aa926..888898da04 100644 --- a/llmfoundry/data/contrastive_pairs/dataloader.py +++ b/llmfoundry/data/contrastive_pairs/dataloader.py @@ -249,6 +249,11 @@ def collate_fn_without_labels(batch: list[Any]) -> dict[str, torch.Tensor]: processed_batch: dict[str, torch.Tensor] = collate_fn(batch) if 'labels' in processed_batch: del processed_batch['labels'] + + if 'total_tokens' in processed_batch: + del processed_batch['total_tokens'] + if 'loss_generating_tokens' in processed_batch: + del processed_batch['loss_generating_tokens'] return processed_batch dl = DataLoader( From ad55425afdd9210b133c6f12b2dcee268d210d6c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 16:05:38 -0700 Subject: [PATCH 15/19] fix test and clean up; --- llmfoundry/data/utils.py | 8 +++++--- tests/a_scripts/inference/test_convert_composer_to_hf.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 526d3dfb9f..3523d708ff 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -26,8 +26,10 @@ class LossGeneratingTokensCollatorWrapper: def __init__( self, base_collator: Callable, + token_counting_func: Callable, ): self.base_collator = base_collator + self.token_counting_func = token_counting_func def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: batch = self.base_collator(examples) @@ -51,7 +53,7 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: row_batch['decoder_attention_mask'] = batch[ 'decoder_attention_mask'][row].unsqueeze(0) - num_tokens = get_tokens_per_batch_func()(row_batch) + num_tokens = self.token_counting_func(row_batch) if isinstance(num_tokens, dict): output['total_tokens'].append(num_tokens['total']) output['loss_generating_tokens'].append( @@ -220,7 +222,7 @@ def get_text_collator( bos_token_id=bos_token_id, ) - collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn) + collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn, get_tokens_per_batch_func()) return collate_fn, dataset_batch_size @@ -236,5 +238,5 @@ def get_finetuning_collator( tokenizer, dataset_batch_size, ) - collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn) + collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn, get_tokens_per_batch_func()) return collate_fn, dataset_batch_size diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py index f599ebbc16..809babece9 100644 --- a/tests/a_scripts/inference/test_convert_composer_to_hf.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -1573,7 +1573,8 @@ def 
test_mptmoe_huggingface_conversion_callback( # Check output equivalence loaded_model = loaded_model.cuda().bfloat16() # type: ignore for k, v in batch.items(): - batch[k] = v.cuda() + if isinstance(v, torch.Tensor): + batch[k] = v.cuda() loaded_model_logits = loaded_model( input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), From a40242288d4ab85bb89d9d7853bf9abfec0ed8bc Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sat, 2 Nov 2024 16:16:37 -0700 Subject: [PATCH 16/19] pc --- llmfoundry/data/utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 3523d708ff..98e2804bad 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -21,12 +21,12 @@ class LossGeneratingTokensCollatorWrapper: - """Collator wrapper to add sequence_id to batch.""" + """Collator wrapper to add loss generating token counts to batch.""" def __init__( self, base_collator: Callable, - token_counting_func: Callable, + token_counting_func: Callable[[Batch], Union[int, dict[str, int]]], ): self.base_collator = base_collator self.token_counting_func = token_counting_func @@ -222,7 +222,10 @@ def get_text_collator( bos_token_id=bos_token_id, ) - collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn, get_tokens_per_batch_func()) + collate_fn = LossGeneratingTokensCollatorWrapper( + collate_fn, + get_tokens_per_batch_func(), + ) return collate_fn, dataset_batch_size @@ -238,5 +241,8 @@ def get_finetuning_collator( tokenizer, dataset_batch_size, ) - collate_fn = LossGeneratingTokensCollatorWrapper(collate_fn, get_tokens_per_batch_func()) + collate_fn = LossGeneratingTokensCollatorWrapper( + collate_fn, + get_tokens_per_batch_func(), + ) return collate_fn, dataset_batch_size From b0ddad7720f8842fe15aacda65074a8307a6372c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 3 Nov 2024 16:01:01 -0800 Subject: [PATCH 17/19] pr comments --- llmfoundry/data/utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 98e2804bad..7206e1c5a9 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -31,6 +31,8 @@ def __init__( self.base_collator = base_collator self.token_counting_func = token_counting_func + self._token_count_batch_keys = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'] + def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: batch = self.base_collator(examples) @@ -41,17 +43,10 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: } num_rows = batch['input_ids'].shape[0] for row in range(num_rows): - row_batch = { - 'input_ids': batch['input_ids'][row].unsqueeze(0), - } - if 'attention_mask' in batch: - row_batch['attention_mask'] = batch['attention_mask'][ - row].unsqueeze(0) - if 'labels' in batch: - row_batch['labels'] = batch['labels'][row].unsqueeze(0) - if 'decoder_attention_mask' in batch: - row_batch['decoder_attention_mask'] = batch[ - 'decoder_attention_mask'][row].unsqueeze(0) + row_batch = {} + for key in self._token_count_batch_keys: + if key in batch: + row_batch[key] = batch[key][row:row+1] num_tokens = self.token_counting_func(row_batch) if isinstance(num_tokens, dict): From 7a7055ba3a53f64a82d91369529d71789f01068d Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 3 Nov 2024 16:10:46 -0800 Subject: [PATCH 18/19] pc --- llmfoundry/data/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 
deletions(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 7206e1c5a9..62683266f7 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -31,7 +31,9 @@ def __init__( self.base_collator = base_collator self.token_counting_func = token_counting_func - self._token_count_batch_keys = ['input_ids', 'attention_mask', 'labels', 'decoder_attention_mask'] + self._token_count_batch_keys = [ + 'input_ids', 'attention_mask', 'labels', 'decoder_attention_mask' + ] def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: batch = self.base_collator(examples) @@ -46,7 +48,7 @@ def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]: row_batch = {} for key in self._token_count_batch_keys: if key in batch: - row_batch[key] = batch[key][row:row+1] + row_batch[key] = batch[key][row:row + 1] num_tokens = self.token_counting_func(row_batch) if isinstance(num_tokens, dict): From ec2109da009ae94c5d835173e4bb7f2696b5b7ff Mon Sep 17 00:00:00 2001 From: Daniel King Date: Sun, 3 Nov 2024 16:17:33 -0800 Subject: [PATCH 19/19] pc --- llmfoundry/data/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py index 62683266f7..8038430259 100644 --- a/llmfoundry/data/utils.py +++ b/llmfoundry/data/utils.py @@ -32,7 +32,10 @@ def __init__( self.token_counting_func = token_counting_func self._token_count_batch_keys = [ - 'input_ids', 'attention_mask', 'labels', 'decoder_attention_mask' + 'input_ids', + 'attention_mask', + 'labels', + 'decoder_attention_mask', ] def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]:
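
For readers following the series, the sketch below restates the counting scheme these patches converge on: wrap the base collator and, for each row of the collated batch, record the total token count and the number of loss-generating tokens, i.e. labels from position 1 onward that are not the ignore index (the first label of each row is shifted off inside the model and never generates loss). This is a minimal standalone sketch, not the llm-foundry code itself; the -100 ignore index, the toy base collator, and the count_row_tokens helper are illustrative assumptions.

# Minimal sketch (not the llm-foundry implementation) of the per-row token
# counting that the LossGeneratingTokensCollatorWrapper in this series performs.
# Assumptions: labels use -100 as the ignore index (a stand-in for
# CROSS_ENTROPY_IGNORE_INDEX) and the base collator returns 'input_ids',
# 'attention_mask', and 'labels' tensors of shape [batch, seq_len].
from typing import Any, Callable

import torch

IGNORE_INDEX = -100  # stand-in for CROSS_ENTROPY_IGNORE_INDEX


def count_row_tokens(row: dict[str, torch.Tensor]) -> dict[str, int]:
    """Count total and loss-generating tokens for a single [1, seq_len] row."""
    if 'attention_mask' in row:
        total = int(row['attention_mask'].sum().item())
    else:
        total = row['input_ids'].numel()
    # Labels are shifted by one inside the model, so count non-ignored labels
    # from position 1 onward only.
    loss_generating = int((row['labels'][..., 1:] != IGNORE_INDEX).sum().item())
    return {'total': total, 'loss_generating': loss_generating}


class TokenCountingCollator:
    """Wraps a base collator and attaches per-row token counts to the batch."""

    def __init__(self, base_collator: Callable[[list[Any]], dict[str, torch.Tensor]]):
        self.base_collator = base_collator

    def __call__(self, examples: list[Any]) -> dict[str, Any]:
        batch = self.base_collator(examples)
        totals, loss_gen = [], []
        for row in range(batch['input_ids'].shape[0]):
            # Slice with row:row + 1 to keep a leading batch dimension of 1.
            row_batch = {
                k: batch[k][row:row + 1]
                for k in ('input_ids', 'attention_mask', 'labels')
                if k in batch
            }
            counts = count_row_tokens(row_batch)
            totals.append(counts['total'])
            loss_gen.append(counts['loss_generating'])
        # Stored as plain lists, one entry per row, so they can be sliced
        # alongside the tensors when the batch is split into microbatches.
        batch['total_tokens'] = totals
        batch['loss_generating_tokens'] = loss_gen
        return batch


if __name__ == '__main__':
    def toy_collator(examples: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
        return {
            'input_ids': torch.tensor([e['input_ids'] for e in examples]),
            'attention_mask': torch.tensor([e['attention_mask'] for e in examples]),
            'labels': torch.tensor([e['labels'] for e in examples]),
        }

    collator = TokenCountingCollator(toy_collator)
    batch = collator([
        {'input_ids': [1, 2, 3, 4], 'attention_mask': [1, 1, 1, 1], 'labels': [-100, 2, 3, 4]},
        {'input_ids': [1, 2, 3, 0], 'attention_mask': [1, 1, 1, 0], 'labels': [-100, -100, 3, -100]},
    ])
    print(batch['total_tokens'])            # [4, 3]
    print(batch['loss_generating_tokens'])  # [3, 1]

Keeping the counts as plain Python lists with one entry per row mirrors the choice made in the series: per-row lists can be split along with the tensors during microbatch splitting, whereas a single scalar total could not.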