Skip to content

Commit

Permalink
log GPU memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
tmm1 committed Aug 9, 2023
1 parent 176b888 commit 8787fec
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
nvidia-ml-py3
3 changes: 3 additions & 0 deletions scripts/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformers import GenerationConfig, TextStreamer

from axolotl.logging_config import configure_logging
from axolotl.utils.bench import gpu_utilization
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
Expand Down Expand Up @@ -250,6 +251,8 @@ def train(
LOG.info("Finished preparing dataset. Exiting...")
return

LOG.info(f"GPU memory baseline: {gpu_utilization(cfg.device)} MB.")

# Load the model and tokenizer
LOG.info("loading model and peft_config...")
model, peft_config = load_model(
Expand Down
12 changes: 12 additions & 0 deletions src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Benchmarking and measurement utilities"""

import pynvml


def gpu_utilization(device):
    """Return the memory currently in use on a CUDA device, in whole MiB.

    The value is read from NVML through ``pynvml`` and converted from bytes
    with ``// 1024**2``.

    Args:
        device: Device string of the form ``"cuda:<index>"``. Any other
            value (a non-CUDA string such as ``"cpu"``, or a non-string
            such as ``None``) is treated as "no GPU to measure".

    Returns:
        The used memory as an ``int`` number of MiB, or ``None`` when
        ``device`` does not name an indexed CUDA device.

    Raises:
        ValueError: If the text after ``"cuda:"`` is not an integer.
    """
    if not isinstance(device, str) or not device.startswith("cuda:"):
        return None

    # Parse the index before touching NVML so a malformed string cannot
    # leave the library initialised.
    # NOTE(review): NVML enumerates physical devices; this index may differ
    # from the framework's CUDA index when CUDA_VISIBLE_DEVICES is set --
    # confirm against callers.
    index = int(device[5:])

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.used // 1024**2
    finally:
        # NVML reference-counts initialisations; release ours on every path.
        pynvml.nvmlShutdown()
29 changes: 29 additions & 0 deletions src/axolotl/utils/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Callbacks for Trainer class"""

import logging
import os

from optimum.bettertransformer import BetterTransformer
Expand All @@ -11,6 +12,10 @@
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy

from axolotl.utils.bench import gpu_utilization

LOG = logging.getLogger("axolotl.callbacks")


class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
"""Callback to save the PEFT adapter"""
Expand Down Expand Up @@ -67,3 +72,27 @@ def on_step_end(
# the trainer will raise an exception since it can't save a BetterTransformer wrapped model
control.should_save = False
return control


class PrintGPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """Trainer callback that logs GPU memory usage a single time.

    The measurement is taken the first time ``on_step_end`` is invoked;
    later invocations are no-ops.
    """

    def __init__(self, cfg):
        # `logged` flips to True once the one-off message has been emitted.
        self.logged = False
        self.cfg = cfg

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Log memory use for ``cfg.device`` once, then pass ``control`` through."""
        if self.logged:
            return control

        used_mb = gpu_utilization(self.cfg.device)
        LOG.info(f"GPU memory while training: {used_mb} MB.")
        # Only mark as done after the message was actually logged.
        self.logged = True
        return control
11 changes: 11 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
from axolotl.utils.bench import gpu_utilization

LOG = logging.getLogger("axolotl")

Expand Down Expand Up @@ -324,6 +325,11 @@ def load_model(
)
model.config.max_position_embeddings = cfg.sequence_len

if model.device.type == "cuda":
LOG.info(
f"GPU memory after model load: {gpu_utilization(model.device.type+':'+str(model.device.index))} MB."
)

if not cfg.gptq and (
(cfg.adapter == "lora" and load_in_8bit)
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
Expand Down Expand Up @@ -360,6 +366,11 @@ def load_model(
module.scales = module.scales.half()
module.bias = module.bias.half()

if model.device.type == "cuda":
LOG.info(
f"GPU memory after adapters: {gpu_utilization(model.device.type+':'+str(model.device.index))} MB."
)

if (
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from transformers.trainer_pt_utils import get_parameter_names

from axolotl.utils.callbacks import (
PrintGPUStatsCallback,
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
Expand Down Expand Up @@ -292,6 +293,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)

callbacks = []
callbacks.append(PrintGPUStatsCallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
Expand Down

0 comments on commit 8787fec

Please sign in to comment.