diff --git a/requirements.txt b/requirements.txt
index 98a57c66a9..f49ec035b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ evaluate==0.4.0
 rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
+nvidia-ml-py3
diff --git a/scripts/finetune.py b/scripts/finetune.py
index ddf1992d6c..2836a5f7e1 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -18,6 +18,7 @@
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
+from axolotl.utils.bench import gpu_utilization
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -250,6 +251,8 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
+    LOG.info(f"GPU memory baseline: {gpu_utilization(cfg.device)} MB.")
+
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
     model, peft_config = load_model(
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
new file mode 100644
index 0000000000..5cde43d5a5
--- /dev/null
+++ b/src/axolotl/utils/bench.py
@@ -0,0 +1,13 @@
+"""Benchmarking and measurement utilities"""
+
+import pynvml
+
+
+def gpu_utilization(device):
+    """Memory used on `device` (e.g. "cuda:0") in MB; None for non-CUDA devices"""
+    if not device.startswith("cuda:"):
+        return None
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(int(device[5:]))
+    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    return info.used // 1024**2
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 526121f2e3..692098756c 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -1,5 +1,6 @@
 """Callbacks for Trainer class"""
 
+import logging
 import os
 
 from optimum.bettertransformer import BetterTransformer
@@ -11,6 +12,10 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+from axolotl.utils.bench import gpu_utilization
+
+LOG = logging.getLogger("axolotl.callbacks")
+
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
     """Callback to save the PEFT adapter"""
@@ -67,3 +72,27 @@ def on_step_end(
             # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
             control.should_save = False
         return control
+
+
+class PrintGPUStatsCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods,unused-argument
+    """Callback to print GPU utilization"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if not self.logged:
+            LOG.info(
+                f"GPU memory while training: {gpu_utilization(self.cfg.device)} MB."
+            )
+            self.logged = True
+        return control
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7501878ba8..9f4cf2960f 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -22,6 +22,7 @@
 )
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+from axolotl.utils.bench import gpu_utilization
 
 LOG = logging.getLogger("axolotl")
 
@@ -324,6 +325,11 @@ def load_model(
         )
         model.config.max_position_embeddings = cfg.sequence_len
+    if model.device.type == "cuda":
+        LOG.info(
+            f"GPU memory after model load: {gpu_utilization(model.device.type+':'+str(model.device.index))} MB."
+        )
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
     ):
@@ -360,6 +366,11 @@
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()
 
+    if model.device.type == "cuda":
+        LOG.info(
+            f"GPU memory after adapters: {gpu_utilization(model.device.type+':'+str(model.device.index))} MB."
+        )
+
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2144e6b023..a5d2ea74ed 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
+    PrintGPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
+    callbacks.append(PrintGPUStatsCallback(cfg))
    # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
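For reference, a minimal sketch of how the new helper behaves, outside the trainer. This is illustrative only and not part of the changeset above; it assumes a machine with at least one CUDA device and the nvidia-ml-py3 package installed:

# smoke_test_bench.py -- hypothetical snippet, not included in this diff
from axolotl.utils.bench import gpu_utilization

print(gpu_utilization("cuda:0"))  # e.g. 1234 -- MB in use on GPU 0 (device-wide, as reported by NVML)
print(gpu_utilization("cpu"))     # None -- the helper only handles "cuda:<index>" device strings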