remove run_distributed flag and peft_saving callback
fabianlim committed Feb 28, 2024
1 parent a5aceba commit 7196479
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions tuning/sft_trainer.py
@@ -29,17 +29,6 @@
from tuning.utils.data_type_utils import get_torch_dtype


-class PeftSavingCallback(TrainerCallback):
-    def on_save(self, args, state, control, **kwargs):
-        checkpoint_path = os.path.join(
-            args.output_dir, f"checkpoint-{state.global_step}"
-        )
-        kwargs["model"].save_pretrained(checkpoint_path)
-
-        if "pytorch_model.bin" in os.listdir(checkpoint_path):
-            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class FileLoggingCallback(TrainerCallback):
"""Exports metrics, e.g., training loss to a file in the checkpoint directory."""

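The removed callback re-saved the PEFT adapter into every checkpoint directory and deleted the full pytorch_model.bin there. Assuming a reasonably recent transformers/TRL stack (an assumption, not stated in this commit), the Trainer already writes the adapter files at each checkpoint on its own, which would make the callback redundant. A minimal, hypothetical check of that assumption:

# Hypothetical helper, not part of this commit: confirm a checkpoint directory
# already contains the PEFT adapter artifacts the removed callback used to write.
import os

def checkpoint_has_adapter(checkpoint_path: str) -> bool:
    files = set(os.listdir(checkpoint_path))
    has_weights = bool(files & {"adapter_model.safetensors", "adapter_model.bin"})
    return has_weights and "adapter_config.json" in files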
@@ -103,7 +92,6 @@ def train(
        None for fine tuning
            The peft configuration to pass to trainer
    """
-    run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1

    logger = logging.get_logger("sft_trainer")

@@ -204,8 +192,7 @@

    aim_callback = get_aimstack_callback()
    file_logger_callback = FileLoggingCallback(logger)
-    peft_saving_callback = PeftSavingCallback()
-    callbacks = [aim_callback, peft_saving_callback, file_logger_callback]
+    callbacks = [aim_callback, file_logger_callback]

    if train_args.packing:
        logger.info("Packing is set to True")
@@ -246,7 +233,7 @@ def train(
        peft_config=peft_config,
    )

-    if run_distributed and peft_config is not None:
+    if trainer.is_fsdp_enabled and peft_config is not None:
        trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
            model
        )
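The remaining hunks swap the WORLD_SIZE-derived flag for trainer.is_fsdp_enabled. The old check was true for any multi-process launch, including plain DDP, whereas is_fsdp_enabled is set by the Trainer only when accelerate has an FSDP plugin configured, so the FSDP auto-wrap policy is now applied only when FSDP is actually in use. A minimal sketch of the two gates, assuming a transformers Trainer instance:

# Sketch contrasting the removed gate with the one introduced by this commit.
import os

def old_gate() -> bool:
    # True for any distributed launch, whether or not FSDP is in use.
    return int(os.environ.get("WORLD_SIZE", "1")) > 1

def new_gate(trainer) -> bool:
    # True only when the Trainer has been configured with an FSDP plugin.
    return bool(getattr(trainer, "is_fsdp_enabled", False))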
