removed unit test enforcing pretokenized datasets with paddingfree
Signed-off-by: 1000850000 user <aaron.chew1@ibm.com>
achew010 committed Sep 4, 2024
1 parent 53d1a8c commit 8f1c9ea
Showing 1 changed file with 0 additions and 55 deletions.
tests/acceleration/test_acceleration_framework.py (55 changes: 0 additions & 55 deletions)
@@ -532,61 +532,6 @@ def test_framework_initialize_and_trains_with_aadp():
     assert spy["augmentation_calls"] == 1
     assert spy["get_ready_for_train_calls"] == 1
 
-
-@pytest.mark.skipif(
-    not is_fms_accelerate_available(plugins="aadp"),
-    reason="Only runs if fms-accelerate is installed along with \
-        attention_and_distributed_packing plugin",
-)
-def test_padding_free_plugin_raises_error_with_untokenized_dataset():
-    """
-    Currently sft_trainer uses DataCollatorForCompletionOnlyLM for unformatted,
-    untokenized datasets. It uses a DataCollatorForSeq2Seq as default for pretokenized
-    datasets.
-    Ensure that padding free plugin will raise an error when an untokenized
-    dataset is passed to the padding-free plugin when it checks the data collator.
-    """
-
-    with tempfile.TemporaryDirectory() as tempdir:
-
-        model_args = copy.deepcopy(MODEL_ARGS)
-        model_args.model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
-        train_args = copy.deepcopy(TRAIN_ARGS)
-        train_args.output_dir = tempdir
-        train_args.save_strategy = "no"
-        data_args = copy.deepcopy(DATA_ARGS)
-        data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
-        data_args.response_template = "\n### Response:"
-        data_args.dataset_text_field = "output"
-
-        # initialize a config
-        aadp_config = AttentionAndDistributedPackingConfig(
-            padding_free=PaddingFree(method="huggingface")
-        )
-
-        with pytest.raises(
-            TypeError,
-            match="The padding-free plugin currently only works with a \
-                `DataCollatorForSeq2Seq` collate_fn",
-        ):
-            with build_framework_and_maybe_instantiate(
-                [
-                    (
-                        ["training.attention.padding_free"],
-                        PaddingFreeAccelerationPlugin,
-                    ),
-                ],
-                instantiate=False,
-            ):
-                with instantiate_model_patcher():
-                    sft_trainer.train(
-                        model_args,
-                        data_args,
-                        train_args,
-                        attention_and_distributed_packing_config=aadp_config,
-                    )
-
-
 def test_error_raised_with_paddingfree_and_flash_attn_disabled():
     """Ensure error raised when padding-free is not used with flash attention"""
     with pytest.raises(
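For context, the removed test exercised the plugin's data-collator check: per its docstring, sft_trainer pairs untokenized datasets with DataCollatorForCompletionOnlyLM and pretokenized datasets with DataCollatorForSeq2Seq, and the padding-free plugin accepts only the latter. A minimal sketch of that kind of guard, assuming a hypothetical helper name (`_ensure_seq2seq_collator` is illustrative, not the plugin's actual API):

```python
from transformers import DataCollatorForSeq2Seq


def _ensure_seq2seq_collator(collate_fn):
    # Hypothetical sketch: the padding-free plugin supports only pretokenized
    # datasets, which sft_trainer collates with DataCollatorForSeq2Seq. Any
    # other collator (e.g. trl's DataCollatorForCompletionOnlyLM, used for
    # untokenized data) is rejected with the TypeError the removed test matched.
    if not isinstance(collate_fn, DataCollatorForSeq2Seq):
        raise TypeError(
            "The padding-free plugin currently only works with a "
            "`DataCollatorForSeq2Seq` collate_fn"
        )
```

Called on the trainer's collate_fn before training starts, a guard like this fails fast instead of producing malformed packed batches later in the run.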
