Skip to content

Commit

Permalink
fix import
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
  • Loading branch information
fabianlim committed Aug 29, 2024
1 parent 00d17e7 commit 53d1a8c
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tests/acceleration/test_acceleration_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
from typing import Annotated
from unittest.mock import patch
import copy
import os
import tempfile

# Third Party
import pytest
import torch

# First Party
from tests.data import TWITTER_COMPLAINTS_JSON_FORMAT, TWITTER_COMPLAINTS_TOKENIZED
from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS

# Local
Expand Down Expand Up @@ -53,6 +53,16 @@
)
from tuning.utils.import_utils import is_fms_accelerate_available

# for some reason the CI will raise an import error if we try to import
# these from tests.data
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(
os.path.dirname(__file__), "../data/twitter_complaints_json.json"
)
TWITTER_COMPLAINTS_TOKENIZED = os.path.join(
os.path.dirname(__file__),
"../data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json",
)

# pylint: disable=import-error
if is_fms_accelerate_available():

Expand Down Expand Up @@ -491,7 +501,7 @@ def test_framework_initialize_and_trains_with_aadp():
data_args.dataset_text_field = None

# initialize a config
attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
aadp_config = AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
)

Expand All @@ -514,7 +524,7 @@ def test_framework_initialize_and_trains_with_aadp():
model_args,
data_args,
train_args,
attention_and_distributed_packing_config=attention_and_distributed_packing_config,
attention_and_distributed_packing_config=aadp_config,
)

# spy inside the train to ensure that the ilab plugin is called
Expand Down Expand Up @@ -550,7 +560,7 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
data_args.dataset_text_field = "output"

# initialize a config
attention_and_distributed_packing_config = AttentionAndDistributedPackingConfig(
aadp_config = AttentionAndDistributedPackingConfig(
padding_free=PaddingFree(method="huggingface")
)

Expand All @@ -573,7 +583,7 @@ def test_padding_free_plugin_raises_error_with_untokenized_dataset():
model_args,
data_args,
train_args,
attention_and_distributed_packing_config=attention_and_distributed_packing_config,
attention_and_distributed_packing_config=aadp_config,
)


Expand Down

0 comments on commit 53d1a8c

Please sign in to comment.