Skip to content

Commit

Permalink
Pass weakref to model in the SIGINT handler to free up model post tra…
Browse files Browse the repository at this point in the history
…in function (#1581)

* Pass weakref to model in the SIGINT handler to free up model post train()

* Fix lint issues

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
  • Loading branch information
chiragjn and winglian authored May 3, 2024
1 parent b9bb169 commit dde02fc
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
Expand Down Expand Up @@ -127,14 +128,20 @@ def train(
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:

def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
_model = model_weakref()
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
sys.exit(0)

_model_weakref = weakref.ref(model)
signal.signal(
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
signal.SIGINT,
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
)

badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
Expand Down

0 comments on commit dde02fc

Please sign in to comment.