log GPU memory usage
tmm1 committed Aug 9, 2023
1 parent 176b888 commit e303d64
Showing 6 changed files with 63 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -19,3 +19,4 @@ evaluate==0.4.0
 rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
+nvidia-ml-py3
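
Note that the distribution name differs from the import name: installing nvidia-ml-py3 provides the pynvml module used in bench.py below. A minimal smoke test of the binding, assuming an NVIDIA driver is present (illustrative, not part of the commit):

import pynvml

pynvml.nvmlInit()  # raises pynvml.NVMLError if no NVIDIA driver is available
print("driver:", pynvml.nvmlSystemGetDriverVersion())
print("gpu count:", pynvml.nvmlDeviceGetCount())
pynvml.nvmlShutdown()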
3 changes: 3 additions & 0 deletions scripts/finetune.py
@@ -18,6 +18,7 @@
 from transformers import GenerationConfig, TextStreamer

 from axolotl.logging_config import configure_logging
+from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -250,6 +251,8 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return

+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
+
     # Load the model and tokenizer
     LOG.info("loading model and peft_config...")
     model, peft_config = load_model(
23 changes: 23 additions & 0 deletions src/axolotl/utils/bench.py
@@ -0,0 +1,23 @@
+"""Benchmarking and measurement utilities"""
+
+import pynvml
+import torch
+
+
+def gpu_memory_usage(device):
+    if isinstance(device, torch.device):
+        device = device.index
+    if isinstance(device, str) and device.startswith("cuda:"):
+        device = int(device[5:])
+
+    # NB torch.cuda.memory_usage returns zero so we use lower level api
+    pynvml.nvmlInit()
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+    return info.used / 1024.0**3
+
+
+def log_gpu_memory_usage(log, msg, device):
+    log.info(
+        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+    )
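
A quick sketch of calling the new helpers, not part of the commit and assuming a machine with at least one CUDA GPU; gpu_memory_usage accepts a bare index, a "cuda:N" string, or a torch.device:

import logging

import torch

from axolotl.utils.bench import gpu_memory_usage, log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("axolotl")

# All three spellings resolve to the same NVML device index.
print(gpu_memory_usage(0))
print(gpu_memory_usage("cuda:0"))
print(gpu_memory_usage(torch.device("cuda", 0)))

log_gpu_memory_usage(log, "after warmup", "cuda:0")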
27 changes: 27 additions & 0 deletions src/axolotl/utils/callbacks.py
@@ -1,5 +1,6 @@
 """Callbacks for Trainer class"""

+import logging
 import os

 from optimum.bettertransformer import BetterTransformer
@@ -11,6 +12,10 @@
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

+from axolotl.utils.bench import log_gpu_memory_usage
+
+LOG = logging.getLogger("axolotl.callbacks")
+

 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
     """Callback to save the PEFT adapter"""
@@ -67,3 +72,25 @@ def on_step_end(
         # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
         control.should_save = False
         return control
+
+
+class PrintGPUStatsCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods disable=unused-argument
+    """Callback to print GPU utilization"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.logged = False
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        if not self.logged:
+            log_gpu_memory_usage(LOG, "while training", self.cfg.device)
+            self.logged = True
+        return control
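
A sketch of wiring the new callback into a stock Hugging Face Trainer, outside axolotl's own setup_trainer; the model, train_dataset, and cfg stand-ins here are illustrative, not from the commit:

from types import SimpleNamespace

from transformers import Trainer, TrainingArguments

from axolotl.utils.callbacks import PrintGPUStatsCallback

cfg = SimpleNamespace(device="cuda:0")  # stand-in for axolotl's config object

trainer = Trainer(
    model=model,                  # assumed to be defined elsewhere
    args=TrainingArguments(output_dir="./out"),
    train_dataset=train_dataset,  # assumed to be defined elsewhere
    callbacks=[PrintGPUStatsCallback(cfg)],
)
trainer.train()  # logs GPU memory once, at the end of the first training step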
7 changes: 7 additions & 0 deletions src/axolotl/utils/models.py
@@ -22,6 +22,7 @@
 )

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
+from axolotl.utils.bench import log_gpu_memory_usage

 LOG = logging.getLogger("axolotl")
@@ -324,6 +325,9 @@ def load_model(
         )
         model.config.max_position_embeddings = cfg.sequence_len

+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after model load", model.device)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -360,6 +364,9 @@
                 module.scales = module.scales.half()
                 module.bias = module.bias.half()

+    if model.device.type == "cuda":
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
+
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
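
Because the two log points bracket adapter attachment, the difference between the "after model load" and "after adapters" readings approximates the adapters' memory footprint. The same delta can be computed directly with the bench helper (a sketch; attach_adapters is a hypothetical stand-in for the LoRA/QLoRA setup above, and model is assumed to be on a CUDA device):

from axolotl.utils.bench import gpu_memory_usage

before = gpu_memory_usage(model.device.index)
model = attach_adapters(model)  # hypothetical stand-in for the adapter setup
after = gpu_memory_usage(model.device.index)
print(f"adapter overhead: {after - before:.03f} GB")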
2 changes: 2 additions & 0 deletions src/axolotl/utils/trainer.py
@@ -18,6 +18,7 @@
 from transformers.trainer_pt_utils import get_parameter_names

 from axolotl.utils.callbacks import (
+    PrintGPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)

     callbacks = []
+    callbacks.append(PrintGPUStatsCallback(cfg))
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(
