diff --git a/requirements.txt b/requirements.txt
index 98a57c66a..f49ec035b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ evaluate==0.4.0
 rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
+nvidia-ml-py3
diff --git a/scripts/finetune.py b/scripts/finetune.py
index ddf1992d6..47623a518 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -18,6 +18,7 @@
 from transformers import GenerationConfig, TextStreamer
 
 from axolotl.logging_config import configure_logging
+from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -250,6 +251,8 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
 
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
     model, peft_config = load_model(
diff --git a/src/axolotl/utils/bench.py b/src/axolotl/utils/bench.py
new file mode 100644
index 000000000..759fb6e21
--- /dev/null
+++ b/src/axolotl/utils/bench.py
@@ -0,0 +1,23 @@
+"""Benchmarking and measurement utilities"""
+
+import pynvml
+import torch
+
+
+def gpu_memory_usage(device):
+    if isinstance(device, torch.device):
+        device = device.index
+    if isinstance(device, str) and device.startswith("cuda:"):
+        device = int(device[5:])
+
+    # NB torch.cuda.memory_usage returns zero so we use lower level api
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    return info.used / 1024.0**3
+
+
+def log_gpu_memory_usage(log, msg, device):
+    log.info(
+        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+    )
diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 526121f2e..f06762b6b 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -1,5 +1,6 @@
 """Callbacks for Trainer class"""
 
+import logging
 import os
 
 from optimum.bettertransformer import BetterTransformer
@@ -11,6 +12,10 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
+from axolotl.utils.bench import log_gpu_memory_usage
+
+LOG = logging.getLogger("axolotl.callbacks")
+
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
     """Callback to save the PEFT adapter"""
@@ -67,3 +72,25 @@ def on_step_end(
         # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
         control.should_save = False
         return control
+
+
+class PrintGPUStatsCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods disable=unused-argument
+    """Callback to print GPU utilization"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if not self.logged:
+            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
+            self.logged = True
+        return control
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7501878ba..31e211953 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -22,6 +22,7 @@
 )
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+from axolotl.utils.bench import log_gpu_memory_usage
 
 LOG = logging.getLogger("axolotl")
 
@@ -324,6 +325,9 @@ def load_model(
         )
         model.config.max_position_embeddings = cfg.sequence_len
 
+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after model load", model.device)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -360,6 +364,9 @@
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()
 
+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2144e6b02..a5d2ea74e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
+    PrintGPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
     callbacks = []
+    callbacks.append(PrintGPUStatsCallback(cfg))
    # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(