From e377439cc5fdb41ca29aeda80e6bca61e414852d Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 1 Mar 2024 16:52:55 -0800
Subject: [PATCH 1/4] fix

---
 llmfoundry/data/finetuning/tasks.py      | 10 +++++++++-
 tests/data/test_template_tokenization.py | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 126ed43812..ecae2c54a5 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -246,7 +246,15 @@ def _tokenize_prompt_response_formatted_example(
             f'Unable to tokenize example because {response_key} was not a string. {example=}'
         )
 
-    return tokenizer(text=prompt, text_target=response)
+    tokenized_sample = tokenizer(text=prompt, text_target=response)
+
+    # Remove the BOS token from the start of the labels if it was automatically added
+    if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
+        if tokenizer.bos_token_id is not None and tokenized_sample['labels'][
+                0] == tokenizer.bos_token_id:
+            tokenized_sample['labels'] = tokenized_sample['labels'][1:]
+
+    return tokenized_sample
 
 
 def tokenize_formatted_example(
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 5491b94521..40c5b5969f 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -163,3 +163,23 @@ def test_tokenize_instruct_example_well_formed():
     tokenized_example = tokenize_formatted_example(example, tokenizer)
     assert 'input_ids' in tokenized_example
     assert 'labels' in tokenized_example
+
+
+def test_tokenize_no_labels_bos():
+    # This tokenizer automatically adds bos tokens
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        'mistralai/Mixtral-8x7B-v0.1')
+
+    example = {'prompt': 'prompt', 'response': 'response'}
+
+    tokenized_example = tokenize_formatted_example(example, tokenizer)
+
+    assert len(tokenized_example['labels']) == 1
+    assert tokenized_example['labels'][0] != ''
+
+    # This tokenizer does not have the add_bos_token attribute
+    tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
+
+    tokenized_example = tokenize_formatted_example(example, tokenizer)
+
+    assert len(tokenized_example['labels']) == 1
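The guard added above is self-contained and can be exercised without loading a real model. A minimal sketch, assuming a stub object in place of a Mixtral-style tokenizer and illustrative token ids (nothing below is part of the patch itself):

    from types import SimpleNamespace

    # Stub standing in for a tokenizer that prepends BOS; ids are hypothetical.
    tokenizer = SimpleNamespace(add_bos_token=True, bos_token_id=1)
    tokenized_sample = {'input_ids': [1, 3513], 'labels': [1, 2933]}

    # The same guard the patch adds: strip BOS from the labels only.
    if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
        if tokenizer.bos_token_id is not None and tokenized_sample['labels'][
                0] == tokenizer.bos_token_id:
            tokenized_sample['labels'] = tokenized_sample['labels'][1:]

    assert tokenized_sample['labels'] == [2933]  # BOS removed from labels
    assert tokenized_sample['input_ids'][0] == 1  # input_ids keeps BOS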
From e7f3c1891b1a790f73cff00bae450eab2ea5a2a1 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 4 Mar 2024 12:44:17 -0800
Subject: [PATCH 2/4] Deprecate triton, prefix lm, llama attention patch, and
 text denoising; Make ComposerHFT5 experimental (#1007)

* Deprecate features and mark experimental

* fix typo

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
---
 llmfoundry/data/denoising.py               |  5 +++++
 llmfoundry/models/hf/hf_causal_lm.py       |  6 ++++++
 llmfoundry/models/hf/hf_t5.py              |  4 ++++
 llmfoundry/models/hf/model_wrapper.py      |  6 ++++++
 llmfoundry/models/mpt/configuration_mpt.py | 11 +++++++++++
 llmfoundry/utils/warnings.py               | 13 +++++++++++++
 6 files changed, 45 insertions(+)

diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index 9c14f21751..303c9298bb 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -6,6 +6,7 @@
 import logging
 import random
 import sys
+import warnings
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
 
 import numpy as np
@@ -20,6 +21,7 @@
 from llmfoundry.data.text_data import (StreamingTextDataset,
                                        get_tokens_per_batch_func)
 from llmfoundry.models import utils
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 __all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']
 
@@ -429,6 +431,9 @@ def build_text_denoising_dataloader(
         padding/waste rates for different `cfg.dataset.packing_ratio` choices,
         given a starting workload YAML.
     """
+    warnings.warn(
+        VersionedDeprecationWarning('Text denoising is deprecated.',
+                                    remove_version='0.7.0'))
     assert cfg.name == 'text_denoising', f'Tried to build_denoising text dataloader with cfg.name={cfg.name}'
 
     collate_fn = MixtureOfDenoisersCollator(
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index dd766c99af..e3eaf3ad0c 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -28,6 +28,7 @@
 from llmfoundry.models.layers.attention import is_flash_v2_installed
 from llmfoundry.models.utils import init_empty_weights
 from llmfoundry.utils.config_utils import pop_config
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 if TYPE_CHECKING:
     from peft import PeftConfig
@@ -285,6 +286,11 @@ def _patch_attention_type(model: PreTrainedModel,
             f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
         )
 
+    warnings.warn(
+        VersionedDeprecationWarning(
+            'Attention patches for Llama models are deprecated. We recommend `use_flash_attention_2: True` for Llama models.',
+            remove_version='0.7.0'))
+
     log.debug(
         f'Patching llama attention with {attention_patch_type} attention')
     from transformers.models.llama.modeling_llama import LlamaAttention
diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py
index 690a0de447..6c3976d072 100644
--- a/llmfoundry/models/hf/hf_t5.py
+++ b/llmfoundry/models/hf/hf_t5.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import Mapping
 
 from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
@@ -17,6 +18,7 @@
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
 from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
                                      init_empty_weights)
+from llmfoundry.utils.warnings import ExperimentalWarning
 
 __all__ = ['ComposerHFT5']
 
@@ -57,6 +59,8 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):
 
     def __init__(self, om_model_config: DictConfig,
                  tokenizer: PreTrainedTokenizerBase):
+        warnings.warn(ExperimentalWarning(feature_name='ComposerHFT5'))
+
         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
             trust_remote_code=om_model_config.get('trust_remote_code', True),
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index 58bd11a55f..e9f4c89796 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import warnings
 from collections import UserDict
 from typing import TYPE_CHECKING, List, Mapping, Optional
 
@@ -16,6 +17,7 @@
 from transformers.utils.generic import ModelOutput
 
 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
+from llmfoundry.utils.warnings import VersionedDeprecationWarning
 
 if TYPE_CHECKING:
     from peft import PeftConfig
@@ -93,6 +95,10 @@ def loss(self, outputs: ModelOutput, batch: Mapping):
         if self.z_loss == 0.0:
             return loss
 
+        warnings.warn(
+            VersionedDeprecationWarning('z-loss is deprecated.',
+                                        remove_version='0.7.0'))
+
         # Add a z_loss to the standard loss
         logits_flat = logits.view(-1, logits.size(-1))
         labels_flat = batch['labels'].view(-1)
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 20f435631f..68fe5befd5 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -214,6 +214,17 @@ def _validate_config(self) -> None:
         if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
             raise ValueError(
                 f"Unknown attn_impl={self.attn_config['attn_impl']}")
+        if self.attn_config['prefix_lm']:
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'Support for Prefix Language Models is deprecated.',
+                    remove_version='0.7.0'))
+        if self.attn_config['attn_impl'] == 'triton':
+            warnings.warn(
+                VersionedDeprecationWarning(
+                    'Support for triton attention is deprecated. Please use torch or flash attention.',
+                    remove_version='0.7.0'))
+
         if self.attn_config['prefix_lm'] and self.attn_config[
                 'attn_impl'] not in ['torch', 'triton']:
             raise NotImplementedError(
diff --git a/llmfoundry/utils/warnings.py b/llmfoundry/utils/warnings.py
index 2584ada601..b589ffdef0 100644
--- a/llmfoundry/utils/warnings.py
+++ b/llmfoundry/utils/warnings.py
@@ -25,3 +25,16 @@ class VersionedDeprecationWarning(DeprecationWarning):
     def __init__(self, message: str, remove_version: str) -> None:
         super().__init__(message +
                          f' It will be removed in version {remove_version}.')
+
+
+class ExperimentalWarning(Warning):
+    """A warning for experimental features.
+
+    Attributes:
+        feature_name (str): The name of the experimental feature.
+    """
+
+    def __init__(self, feature_name: str) -> None:
+        super().__init__(
+            f'{feature_name} is experimental and may change with future versions.'
+        )
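Both warning classes compose their full message in the constructor, so call sites stay one-liners. Usage mirroring the call sites in this patch (the message strings are taken from the diffs above):

    import warnings

    from llmfoundry.utils.warnings import (ExperimentalWarning,
                                           VersionedDeprecationWarning)

    # Emits: 'Text denoising is deprecated. It will be removed in version 0.7.0.'
    warnings.warn(
        VersionedDeprecationWarning('Text denoising is deprecated.',
                                    remove_version='0.7.0'))

    # Emits: 'ComposerHFT5 is experimental and may change with future versions.'
    warnings.warn(ExperimentalWarning(feature_name='ComposerHFT5'))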
""" prompt, response = _slice_chat_formatted_example(example, tokenizer) - return tokenizer(text=prompt, text_target=response) + return _tokenize_with_bos_removal( + tokenizer=tokenizer, + text=prompt, + text_target=response, + ) def _tokenize_prompt_response_formatted_example( @@ -246,15 +273,11 @@ def _tokenize_prompt_response_formatted_example( f'Unable to tokenize example because {response_key} was not a string. {example=}' ) - tokenized_sample = tokenizer(text=prompt, text_target=response) - - # Remove the BOS token from the start of the labels if it was automatically added - if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token: - if tokenizer.bos_token_id is not None and tokenized_sample['labels'][ - 0] == tokenizer.bos_token_id: - tokenized_sample['labels'] = tokenized_sample['labels'][1:] - - return tokenized_sample + return _tokenize_with_bos_removal( + tokenizer=tokenizer, + text=prompt, + text_target=response, + ) def tokenize_formatted_example( diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 40c5b5969f..7b4f290abe 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -165,21 +165,63 @@ def test_tokenize_instruct_example_well_formed(): assert 'labels' in tokenized_example -def test_tokenize_no_labels_bos(): +def test_tokenize_no_labels_bos_pr(): # This tokenizer automatically adds bos tokens tokenizer = transformers.AutoTokenizer.from_pretrained( 'mistralai/Mixtral-8x7B-v0.1') example = {'prompt': 'prompt', 'response': 'response'} + assert tokenizer.add_bos_token == True + tokenized_example = tokenize_formatted_example(example, tokenizer) assert len(tokenized_example['labels']) == 1 - assert tokenized_example['labels'][0] != '' + assert tokenized_example['labels'][0] != tokenizer.bos_token_id + assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id # This tokenizer does not have the add_bos_token attribute tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') + assert not hasattr(tokenizer, 'add_bos_token') + tokenized_example = tokenize_formatted_example(example, tokenizer) assert len(tokenized_example['labels']) == 1 + assert tokenized_example['labels'][0] != tokenizer.bos_token_id + assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id + + +def test_tokenize_no_labels_bos_chat(): + # This tokenizer automatically adds bos tokens + tokenizer = transformers.AutoTokenizer.from_pretrained( + 'mistralai/Mixtral-8x7B-v0.1') + + example = { + 'messages': [{ + 'role': 'user', + 'content': 'Hello, GPT' + }, { + 'role': 'assistant', + 'content': 'response' + }] + } + + assert tokenizer.add_bos_token == True + + tokenized_example = tokenize_formatted_example(example, tokenizer) + + assert len(tokenized_example['labels']) == 4 + assert tokenized_example['labels'][0] != tokenizer.bos_token_id + assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id + + # This tokenizer does not have the add_bos_token attribute + tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b') + + assert not hasattr(tokenizer, 'add_bos_token') + + tokenized_example = tokenize_formatted_example(example, tokenizer) + + assert len(tokenized_example['labels']) == 2 + assert tokenized_example['labels'][0] != tokenizer.bos_token_id + assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id From a58eb86e5a8871c52fa31eeecfc77332614b8403 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 5 Mar 2024 18:10:49 -0800 Subject: 
From a58eb86e5a8871c52fa31eeecfc77332614b8403 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Tue, 5 Mar 2024 18:10:49 -0800
Subject: [PATCH 4/4] remove chat changes

---
 llmfoundry/data/finetuning/tasks.py      |  6 +---
 tests/data/test_template_tokenization.py | 35 ------------------------
 2 files changed, 1 insertion(+), 40 deletions(-)

diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 6fe48084af..8faddb2825 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -231,11 +231,7 @@ def _tokenize_chat_formatted_example(
         TokenizedExample: The tokenized example.
     """
     prompt, response = _slice_chat_formatted_example(example, tokenizer)
-    return _tokenize_with_bos_removal(
-        tokenizer=tokenizer,
-        text=prompt,
-        text_target=response,
-    )
+    return tokenizer(text=prompt, text_target=response)
 
 
 def _tokenize_prompt_response_formatted_example(
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index 7b4f290abe..fdaf30ccc5 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -190,38 +190,3 @@ def test_tokenize_no_labels_bos_pr():
     assert len(tokenized_example['labels']) == 1
     assert tokenized_example['labels'][0] != tokenizer.bos_token_id
     assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id
-
-
-def test_tokenize_no_labels_bos_chat():
-    # This tokenizer automatically adds bos tokens
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        'mistralai/Mixtral-8x7B-v0.1')
-
-    example = {
-        'messages': [{
-            'role': 'user',
-            'content': 'Hello, GPT'
-        }, {
-            'role': 'assistant',
-            'content': 'response'
-        }]
-    }
-
-    assert tokenizer.add_bos_token == True
-
-    tokenized_example = tokenize_formatted_example(example, tokenizer)
-
-    assert len(tokenized_example['labels']) == 4
-    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
-    assert tokenized_example['input_ids'][0] == tokenizer.bos_token_id
-
-    # This tokenizer does not have the add_bos_token attribute
-    tokenizer = transformers.AutoTokenizer.from_pretrained('mosaicml/mpt-7b')
-
-    assert not hasattr(tokenizer, 'add_bos_token')
-
-    tokenized_example = tokenize_formatted_example(example, tokenizer)
-
-    assert len(tokenized_example['labels']) == 2
-    assert tokenized_example['labels'][0] != tokenizer.bos_token_id
-    assert tokenized_example['input_ids'][0] != tokenizer.bos_token_id
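Because VersionedDeprecationWarning subclasses DeprecationWarning, downstream code can escalate anything this series schedules for removal into a hard failure. A sketch of one way to do that (the filter configuration is a suggestion, not part of the patches):

    import warnings

    from llmfoundry.utils.warnings import VersionedDeprecationWarning

    # Turn every versioned deprecation into an exception, e.g. in CI.
    warnings.filterwarnings('error', category=VersionedDeprecationWarning)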