diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py
index 31243c0fd..6015cf8cd 100644
--- a/tests/acceleration/test_acceleration_framework.py
+++ b/tests/acceleration/test_acceleration_framework.py
@@ -17,6 +17,7 @@
 from typing import Annotated
 from unittest.mock import patch
 import copy
+import os
 import tempfile
 
 # Third Party
@@ -24,7 +25,6 @@
 import torch
 
 # First Party
-from tests.data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED
 from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS
 
 # Local
@@ -53,6 +53,16 @@
 )
 from tuning.utils.import_utils import is_fms_accelerate_available
 
+# for some reason the CI will raise an import error if we try to import
+# these from tests.data
+TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(
+    os.path.dirname(__file__), "../data/twitter_complaints_json.json"
+)
+TWITTER_COMPLAINTS_TOKENIZED = os.path.join(
+    os.path.dirname(__file__),
+    "../data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json",
+)
+
 # pylint: disable=import-error
 if is_fms_accelerate_available():
 
@@ -491,7 +501,7 @@ def test_framework_initialize_and_trains_with_aadp():
     data_args.dataset_text_field = None
 
     # initialize a config
-    attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
+    aadp_config = AttentionAndDistributedPackingConfig(
         padding_free=PaddingFree(method="huggingface")
     )
 
@@ -514,7 +524,7 @@
                 model_args,
                 data_args,
                 train_args,
-                attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+                attention_and_distributed_packing_config=aadp_config,
             )
 
             # spy inside the train to ensure that the ilab plugin is called
@@ -550,7 +560,7 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
     data_args.dataset_text_field = "output"
 
     # initialize a config
-    attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
+    aadp_config = AttentionAndDistributedPackingConfig(
         padding_free=PaddingFree(method="huggingface")
     )
 
@@ -573,7 +583,7 @@
                 model_args,
                 data_args,
                 train_args,
-                attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+                attention_and_distributed_packing_config=aadp_config,
             )