diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py
index 7f8c9f44..da8760d7 100644
--- a/src/nanotron/helpers.py
+++ b/src/nanotron/helpers.py
@@ -1,4 +1,5 @@
 import contextlib
+import csv
 import gc
 import math
 import os
@@ -412,27 +413,23 @@ def test_all_pair_to_pair(
     )


-def log_throughput(
+def create_table_log(
     config: Config,
     parallel_context: ParallelContext,
-    model_tflops=0,
-    hardware_tflops=0,
-    tokens_per_sec=0,
-    bandwidth=0,
+    model_tflops,
+    hardware_tflops,
+    tokens_per_sec,
+    bandwidth,
+    slurm_job_id,
 ):
-    micro_batch_size = config.micro_batch_size
-    n_micro_batches_per_batch = config.batch_accumulation_per_replica
-    global_batch_size = micro_batch_size * n_micro_batches_per_batch * parallel_context.dp_pg.size()
-    sequence_length = config.sequence_length
-    slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")
-    csv_filename = config.benchmark_csv_path
-    table_log = [
-        LogItem("model_name", config.model_name, "s"),
+    return [
+        LogItem("job_id", slurm_job_id, "s"),
+        LogItem("name", config.general.run, "s"),
         LogItem("nodes", math.ceil(parallel_context.world_pg.size() / 8), "d"),
-        LogItem("seq_len", (sequence_length), "d"),
-        LogItem("mbs", micro_batch_size, "d"),
-        LogItem("batch_accum", n_micro_batches_per_batch, "d"),
-        LogItem("gbs", global_batch_size, "d"),
+        LogItem("seq_len", config.tokens.sequence_length, "d"),
+        LogItem("mbs", config.tokens.micro_batch_size, "d"),
+        LogItem("batch_accum", config.tokens.batch_accumulation_per_replica, "d"),
+        LogItem("gbs", config.global_batch_size, "d"),
         LogItem("mTFLOPs", model_tflops, ".2f"),
         LogItem("hTFLOPs", hardware_tflops, ".2f"),
         LogItem("tok/s/gpu", tokens_per_sec / parallel_context.world_pg.size(), ".2f"),
@@ -441,7 +438,8 @@ def log_throughput(
         LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"),
     ]

-    column_widths = [max(len(item.tag), len(f"{item.scalar_value:{item.log_format}}")) for item in table_log]
+
+def create_table_output(table_log, column_widths):
     header_row = "| " + " | ".join([item.tag.ljust(width) for item, width in zip(table_log, column_widths)]) + " |"
     separator_row = "| " + " | ".join(["-" * width for width in column_widths]) + " |"
     data_row = (
@@ -451,7 +449,48 @@ def log_throughput(
         )
         + " |"
     )
-    table_output = f"{header_row}\n{separator_row}\n{data_row}"
+    return f"{header_row}\n{separator_row}\n{data_row}"
+
+
+def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id):
+    if not os.path.exists(csv_filename):
+        os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
+        with open(csv_filename, mode="w") as fo:
+            writer = csv.writer(fo)
+            writer.writerow([item.tag for item in table_log])
+            writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
+    # elif model_tflops > 0:
+    #     # replace line with same job_id
+    #     with open(csv_filename, mode="r") as fi:
+    #         lines = fi.readlines()
+    #     with open(csv_filename, mode="w") as fo:
+    #         writer = csv.writer(fo)
+    #         for line in lines:
+    #             if line.startswith(slurm_job_id):
+    #                 writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
+    #             else:
+    #                 fo.write(line)
+    else:
+        with open(csv_filename, mode="a") as fo:
+            writer = csv.writer(fo)
+            writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
+
+
+def log_throughput(
+    config: Config,
+    parallel_context: ParallelContext,
+    model_tflops=0,
+    hardware_tflops=0,
+    tokens_per_sec=0,
+    bandwidth=0,
+):
+    slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")
+
+    table_log = create_table_log(
+        config, parallel_context, model_tflops, hardware_tflops, tokens_per_sec, bandwidth, slurm_job_id
+    )
+    column_widths = [max(len(item.tag), len(f"{item.scalar_value:{item.log_format}}")) for item in table_log]
+    table_output = create_table_output(table_log, column_widths)

     log_rank(
         table_output,
@@ -460,26 +499,5 @@ def log_throughput(
         rank=0,
     )

-    import csv
-
     if dist.get_rank(parallel_context.world_pg) == 0:
-        if not os.path.exists(csv_filename):
-            with open(csv_filename, mode="w") as fo:
-                writer = csv.writer(fo)
-                writer.writerow([item.tag for item in table_log])
-                writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
-        elif model_tflops > 0:
-            # replace line with same job_id
-            with open(csv_filename, mode="r") as fi:
-                lines = fi.readlines()
-            with open(csv_filename, mode="w") as fo:
-                writer = csv.writer(fo)
-                for line in lines:
-                    if line.startswith(slurm_job_id):
-                        writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
-                    else:
-                        fo.write(line)
-        else:
-            with open(csv_filename, mode="a") as fo:
-                writer = csv.writer(fo)
-                writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
+        write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)
diff --git a/src/nanotron/logging.py b/src/nanotron/logging.py
index 7617bdd1..da1f94ec 100644
--- a/src/nanotron/logging.py
+++ b/src/nanotron/logging.py
@@ -21,6 +21,7 @@
 from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, NOTSET, WARNING, Formatter, Logger
 from typing import List, Optional, Union

+import torch
 from torch import distributed as torch_dist

 from nanotron import distributed as dist
@@ -233,6 +234,18 @@ def human_format(num: float, billions: bool = False, divide_by_1024: bool = Fals
     return "{}{}".format("{:f}".format(num).rstrip("0").rstrip("."), SIZES[magnitude])


+def log_memory(logger: logging.Logger):
+    log_rank(
+        f" Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
+        f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
+        f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
+        logger=logger,
+        level=logging.INFO,
+        rank=0,
+    )
+    torch.cuda.reset_peak_memory_stats()
+
+
 @dataclass
 class LogItem:
     tag: str
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index c7622506..3ba58d01 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -29,7 +29,7 @@
     log_throughput,
     lr_scheduler_builder,
 )
-from nanotron.logging import LoggerWriter, LogItem, human_format, log_rank, set_logger_verbosity_format
+from nanotron.logging import LoggerWriter, LogItem, human_format, log_memory, log_rank, set_logger_verbosity_format
 from nanotron.models import NanotronModel, build_model
 from nanotron.models.base import check_model_has_grad
 from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
@@ -301,16 +301,7 @@ def training_step(
         before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)

         if self.iteration_step < 5:
-            log_rank(
-                f"[Before train batch iter] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
-                f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
-                f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
-                logger=logger,
-                level=logging.INFO,
-                group=self.parallel_context.world_pg,
-                rank=0,
-            )
-            torch.cuda.reset_peak_memory_stats()
+            log_memory(logger=logger)

         outputs = self.pipeline_engine.train_batch_iter(
             model=self.model,
@@ -321,16 +312,7 @@ def training_step(
         )

         if self.iteration_step < 5:
-            log_rank(
-                f"[After train batch iter] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB."
-                f" Peak allocated {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB."
-                f" Peak reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f}MiB",
-                logger=logger,
-                level=logging.INFO,
-                group=self.parallel_context.world_pg,
-                rank=0,
-            )
-            torch.cuda.reset_peak_memory_stats()
+            log_memory(logger=logger)

         after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator)