diff --git a/Dockerfile.ci b/Dockerfile.ci index 3d9a9d9b08a1..33490a6d9079 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -31,6 +31,10 @@ EOF WORKDIR /workspace +RUN pip install hatchling # needed to install nemo-run +ARG NEMU_RUN_TAG=34259bd3e752fef94045a9a019e4aaf62bd11ce2 +RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMU_RUN_TAG} + # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.15.0 diff --git a/examples/llm/auto_configurator/README.md b/examples/llm/auto_configurator/README.md new file mode 100644 index 000000000000..26cf5cd75263 --- /dev/null +++ b/examples/llm/auto_configurator/README.md @@ -0,0 +1,85 @@ +> [!IMPORTANT] +> This is an early version of the Auto Configurator, and the code base can be modified as it will be integrated into the CLI. + +Use Auto Configurator to Find the Optimal Configuration +------------------------------------------------------- + +Auto Configurator searches for hyperparameters (HPs) that achieve the maximum highest training throughput when working with Large Language Models (LLMs) utilizing the NeMo Framework. + +> [!NOTE] +> Auto Configurator is only supported now for GPT-based models: GPT3, LLama, Mixtral, Mistral, Gemma and Nemotron. + +Auto Configurator Capabilities +------------------------------ + +Auto Configurator is intended to iterate over different model configurations quickly and find the best configuration, that is, the configuration that minimizes both time and financial expenditure. It offers a range of features to facilitate this, as detailed in the list below. + +- **Model size recommendation**: finds the optimal model size if the parameter is not specified. +- **Training time estimation**: estimates model training time based on input parameters. +- **Base configuration generation**: returns a basic model configuration. +- **Hyperparameters recommendation**: finds the optimal list of hyperparameters to be trained. +- **Optimal configuration recommendation**: calculates the performance after a short training of candidate configurations and finds the optimal model configuration. + +Model Size Recommendation +------------------------- + +If you have not decided what model size you want to train, Auto Configurator can recommend a model size for your use case. If you know the number of GPUs, TFLOPS per GPU, the maximum time to train, and the number of tokens to train for, it can recommend a model size that can be trained with the specified hardware and time constraints. + +For example, if you had 20 NVIDIA DGX nodes available (in 80 GB GPU memory), and wanted to train a GPT model for a maximum of 5 days, Auto Configurator would recommend using a 5B parameter GPT model. + +Training Time Estimation +------------------------ + +Auto Configurator calculates the estimated training time for your model. It provides a projection of the training time in days, based on the input dataset and parameters you provide. + +Base Configuration Generation +----------------------------- + +When you provide the model size, or Auto Configurator has suggested one, it generates a base configuration for the target model. The base configuration is a valid configuration in NeMo 2.0 format. The optimization of throughput, however, is conducted in the next step. + +Hyperparameters Recommendation +------------------------------ + +After Auto Configurator generates the base configuration, it searches over four critical hyperparameters that have a great impact on training throughput but do not affect model convergence. These hyperparameters include Tensor Parallelism (TP), Pipeline Parallelism (PP), Context Parallelism (CP), Expert Parallelism (EP), Micro Batch Size (MBS), and Activation Checkpointing Layers (ActCkpt). Auto Configurator will also provide optimal Global Batch Size (GBS) if it's not specified. + +Auto Configurator initially applies heuristics to identify suitable candidates for the four key parameters, subsequently generating a grid of candidate configurations. It returns all of the candidate configurations in NeMo 2.0 format. + +> [!NOTE] +> Some of the candidate configurations may not work due to high-memory usage or other issues. + +Once the candidate configurations are generated, you can use NeMo Framework to launch the most promising candidates. + +When running the candidates on the cluster, you can limit job time and job max steps by using ``max_minutes_per_run`` and ``max_steps_per_run`` parameters. During this search, the jobs will run with the number of nodes specified in the configuration files, using the ``num_nodes`` parameter. Once all of the jobs have finished running, you'll need to run compare_throughput.py to get a ``.csv`` table with performance results for each succeeded job. + +Optimal Configuration Recommendation +------------------------------------ + +After all of the candidate jobs are done, Auto Configurator calculates performance parameters for each of the candidates. +Auto Configurator generates two ``.csv`` files: one detailing the performance measures of the candidates and another listing the candidates that failed due to out-of-memory errors. + +End-To-End Example +------------------ + +The following list shows the required input parameters for the Auto Configurator runner: + +- ``model``: model configuration based on NeMo 2.0. +- ``num_nodes``: number of nodes to be used for the training. +- ``seq_length``: sequence length to be used for the training. +- ``data_paths``: dataset to be used for the training. +- ``tokenizer_path``: path to tokenizer model if custom tokenizer will be used. + +The following list shows the optional parameters for the Auto Configurator runner: + +- ``global_batch_size``: global batch size to be used. +- ``tensor_parallel_sizes``: a list, such as ``[1, 2, 4]``. +- ``pipeline_parallel_sizes``: a list, such as ``[1, 2, 4]``. +- ``context_parallel_sizes``: a list, such as ``[1, 2, 4]``. +- ``expert_parallel_sizes``: a list, such as ``[1, 2, 4]``. +- ``micro_batch_sizes``: a list, such as ``[1, 2, 4]``. +- ``min_model_parallel_size``: a value for the minimum desired parallelism. +- ``max_model_parallel_size``: a value for the maximum desired parallelism. + +For each of the optional parameters, Auto Configurator will find the optimal value if the parameter is not specified. To view the full list of parameters, please visit [this page](https://github.com/NVIDIA/NeMo/blob/dpykhtar/nemo_autoconf/nemo/collections/llm/tools/auto_configurator/runner.py#L51). + +To view an end-to-end example of how to generate candidate configs, train them, and calculate the performance using Auto Configurator with NeMo Framework, please visit [this page](https://github.com/NVIDIA/NeMo/blob/dpykhtar/nemo_autoconf/examples/llm/auto_configurator/auto_config.py). + diff --git a/examples/llm/auto_configurator/auto_config.py b/examples/llm/auto_configurator/auto_config.py new file mode 100644 index 000000000000..c202d4d33325 --- /dev/null +++ b/examples/llm/auto_configurator/auto_config.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import fiddle as fdl +import nemo_run as run + +from nemo.collections.llm import GPTConfig126M +from nemo.collections.llm.tools.auto_configurator import AutoConfigurator, generate_configs, get_results + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--run_number", type=int, help="Number of config to run") + parser.add_argument("--logs_dir", type=str, help="Path where to save training logs") + parser.add_argument("--data_path", type=str, help="Path to the dataset") + parser.add_argument("--get_results", action="store_true") + + return parser.parse_args() + + +def train_config(args): + # GPT-3 126M + # This example will generate 3 configs. + # It is expected that this script will be run 3 times with changing --run_number flag for each run from 0 to 2. + # After all configurations are trained, please trigger the script using --get_results flag. + runner = AutoConfigurator( + model=run.Config(GPTConfig126M), + num_nodes=1, + gpus_per_node=1, + gpu_memory_gb=40, + global_batch_size=16, + seq_length=512, + tensor_parallel_sizes=[1], + pipeline_parallel_sizes=[1], + micro_batch_sizes=[1, 2, 4], + max_training_days=1, + max_steps_per_run=25, + num_tokens_in_b=10, + vocab_size=51200, + data_paths=args.data_path, + path_to_logs=args.logs_dir, + ) + + base_cfg, configs = generate_configs(runner) + if not args.get_results: + # Get generated configs + partials = list(configs.values()) + names = list(configs.keys()) + + # Run pre-training + partial = partials[args.run_number - 1] + partial.log.dir = os.path.join(args.logs_dir, names[args.run_number - 1]) + pretrain = fdl.build(partial) + pretrain() + else: + # # Get Auto Configurator results + get_results(base_cfg, runner, args.logs_dir) + print(f"The results were successfully saved to {args.logs_dir}.") + + +def main(): + args = get_args() + train_config(args) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index a5ce0c82a0e0..614af0df400c 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -51,6 +51,12 @@ GemmaConfig7B, GemmaModel, GPTConfig, + GPTConfig5B, + GPTConfig7B, + GPTConfig20B, + GPTConfig40B, + GPTConfig126M, + GPTConfig175B, GPTModel, Llama2Config7B, Llama2Config13B, diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 9785889aaf92..aa3615b3ddfd 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -15,6 +15,12 @@ from nemo.collections.llm.gpt.model.baichuan import Baichuan2Config, Baichuan2Config7B, Baichuan2Model from nemo.collections.llm.gpt.model.base import ( GPTConfig, + GPTConfig5B, + GPTConfig7B, + GPTConfig20B, + GPTConfig40B, + GPTConfig126M, + GPTConfig175B, GPTModel, MaskedTokenLossReduction, gpt_data_step, diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index a6b53f4e859d..e0d752bf3411 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -182,6 +182,60 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": ) +@dataclass +class GPTConfig126M(GPTConfig): + seq_length: int = 2048 + num_layers: int = 12 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 12 + + +@dataclass +class GPTConfig5B(GPTConfig): + seq_length: int = 2048 + num_layers: int = 24 + hidden_size: int = 4096 + ffn_hidden_size: int = 16384 + num_attention_heads: int = 32 + + +@dataclass +class GPTConfig7B(GPTConfig): + seq_length: int = 2048 + num_layers: int = 32 + hidden_size: int = 4096 + ffn_hidden_size: int = 10880 + num_attention_heads: int = 32 + + +@dataclass +class GPTConfig20B(GPTConfig): + seq_length: int = 2048 + num_layers: int = 44 + hidden_size: int = 6144 + ffn_hidden_size: int = 24576 + num_attention_heads: int = 48 + + +@dataclass +class GPTConfig40B(GPTConfig): + seq_length: int = 2048 + num_layers: int = 48 + hidden_size: int = 8192 + ffn_hidden_size: int = 32768 + num_attention_heads: int = 64 + + +@dataclass +class GPTConfig175B(GPTConfig): + seq_length: int = 2048 + num_layers: int = 96 + hidden_size: int = 12288 + ffn_hidden_size: int = 49152 + num_attention_heads: int = 96 + + class GPTModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): def __init__( self, diff --git a/nemo/collections/llm/tools/auto_configurator/__init__.py b/nemo/collections/llm/tools/auto_configurator/__init__.py new file mode 100644 index 000000000000..5c6bde2c285a --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/__init__.py @@ -0,0 +1,2 @@ +from nemo.collections.llm.tools.auto_configurator.core.calculate_performance import get_results +from nemo.collections.llm.tools.auto_configurator.runner import AutoConfigurator, generate_configs diff --git a/nemo/collections/llm/tools/auto_configurator/core/__init__.py b/nemo/collections/llm/tools/auto_configurator/core/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/tools/auto_configurator/core/base_config.py b/nemo/collections/llm/tools/auto_configurator/core/base_config.py new file mode 100644 index 000000000000..ee1579f6f6e8 --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/core/base_config.py @@ -0,0 +1,367 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import TensorBoardLogger + +from nemo import lightning as nl +from nemo.collections.common.tokenizers import AutoTokenizer, SentencePieceTokenizer +from nemo.collections.llm import PreTrainingDataModule +from nemo.collections.llm.utils import Config +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class BaseConfig: + def __init__(self, config=None): + """ + Args: + config (AutoConfigurator): auto configurator runner config. + """ + + self.config = config + + self.model = self.get_model() + self.optim = self.get_optim() + self.trainer = self.get_trainer() + self.data = self.get_data() + self.log = self.get_logger() + self.run = self.get_run_config() + self.tokenizer = self.get_tokenizer(config.tokenizer_type, config.tokenizer_path) + + def get_model(self): + """Function that returns model config. + + Returns: + Config: model config. + """ + + self.config.model.seq_length = self.config.seq_length + + return self.config.model + + def get_optim(self) -> Config[OptimizerConfig]: + """Function that returns optimizer config. + + Returns: + Config[OptimizerConfig]: optimizer config. + """ + optim_params = { + "optimizer": "adam", + "lr": 1e-4, + "min_lr": 1e-5, + "use_distributed_optimizer": True, + "bf16": True, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "overlap_grad_reduce": True, + "overlap_param_gather": True, + "clip_grad": 1.0, + "adam_eps": 1e-5, + } + + optim_config = Config( + OptimizerConfig, + **optim_params, + ) + + sched = Config( + CosineAnnealingScheduler, + warmup_steps=10, + constant_steps=0, + min_lr=optim_config.min_lr, + ) + + return Config( + MegatronOptimizerModule, + config=optim_config, + lr_scheduler=sched, + ) + + def get_trainer(self) -> Config[nl.Trainer]: + """Function that returns config for PTL trainer. + + Returns: + Config[nl.Trainer]: trainer config. + """ + + trainer_config = { + "accelerator": "gpu", + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_epochs": None, + "log_every_n_steps": 1, + "limit_val_batches": 1, + "limit_test_batches": 1, + "accumulate_grad_batches": 1, + "num_nodes": self.config.num_nodes, + "devices": self.config.num_gpus, + "max_steps": self.config.max_steps_per_run, + "val_check_interval": self.config.max_steps_per_run, + } + + strategy = Config( + nl.MegatronStrategy, + pipeline_dtype=torch.bfloat16, + ) + + return Config( + nl.Trainer, + **trainer_config, + strategy=strategy, + plugins=Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + callbacks=[Config(TimingCallback)], + ) + + def get_tokenizer(self, tokenizer_type: str, tokenizer_path: str) -> Config: + """Function that returns the tokenizer config. + + Args: + tokenizer_type (str): tokenizer type. + tokenizer_path (str): path to the tokenizer. + + Returns: + Config: tokenizer config. + """ + + if tokenizer_type == "sentencepiece": + return Config(SentencePieceTokenizer, model_path=tokenizer_path) + else: + return Config(AutoTokenizer, pretrained_model_name=tokenizer_path) + + def get_data(self) -> Config[PreTrainingDataModule]: + """Function that returns dataset config. + + Returns: + Config[PreTrainingDataModule]: data config. + """ + + # Data config + data_config = { + "paths": self.config.data_paths, + "seq_length": self.config.seq_length, + "global_batch_size": self.config.global_batch_size, + "num_workers": 2, + "index_mapping_dir": None, + } + + # Define the tokenizer + tokenizer = self.get_tokenizer( + self.config.tokenizer_type, + self.config.tokenizer_path, + ) + + return Config( + PreTrainingDataModule, + **data_config, + tokenizer=tokenizer, + ) + + def get_logger(self) -> Config[nl.NeMoLogger]: + """Function that returns the training strategy. + + Returns: + Config[nl.NeMoLogger]: NeMo Logger config. + """ + + # Define TensorBoard Logger + tb_logger = Config(TensorBoardLogger, save_dir="tb_logs") + + ckpt = Config( + nl.ModelCheckpoint, + monitor="reduced_train_loss", + save_last=False, + save_top_k=0, + ) + + return Config( + nl.NeMoLogger, + ckpt=ckpt, + tensorboard=tb_logger, + wandb=None, + dir=self.config.path_to_logs, + ) + + def get_run_config(self) -> dict: + """Function that returns config for cluster job. + + Returns: + dict: cluster job config. + """ + + run_config = { + "name": self.config.model.__class__.__name__, + "time_limit": f"0-00:{self.config.max_minutes_per_run}:00", + } + + return run_config + + +def calculate_model_size( + gpu_count: int, + max_training_days: float, + model_size_in_b: float = None, + tflops_per_gpu: int = 140, + num_tokens_in_b: int = 300, + model_name: str = "gpt3", +) -> float: + """Estimates a model size to be trained given the constraints. If the + model_size is provided, it estimates the time to train it with the given + constraints. + + Example: + output 5B params to train for 7 days with 160 GPUs. + + Args: + gpu_count (int): number of gpus to use (num_nodes * gpus_per_node). + max_training_days (float): number of days to train the model for. + model_size_in_b (float): number of parameters in the model, if known. + tflops_per_gpu (int): estimated number of TFLOPS/s per GPU. + num_tokens_in_b (int): number of tokens to train the model for. + model_name (str): name of the model. + + Returns: + float: number of parameters to use for training. + """ + + # Model size is not known, must be estimated. + if model_size_in_b is None: + model_size_in_b = _estimate_model_size( + max_training_days=max_training_days, + gpu_count=gpu_count, + tflops_per_gpu=tflops_per_gpu, + num_tokens_in_b=num_tokens_in_b, + model_name=model_name, + ) + # Model size is known, so only time to train estimate is needed. + else: + max_training_days = _estimate_training_time( + model_size_in_b=model_size_in_b, + gpu_count=gpu_count, + tflops_per_gpu=tflops_per_gpu, + num_tokens_in_b=num_tokens_in_b, + model_name=model_name, + ) + + print( + f"You can train a {model_size_in_b}B parameter model in " + f"{max_training_days} days using {gpu_count} GPUs. This result assumes " + f"you are training to {num_tokens_in_b}B tokens, and each GPU achieves " + f"{tflops_per_gpu} TFLOPS." + ) + return model_size_in_b + + +def _estimate_model_size( + max_training_days: float, + gpu_count: int, + tflops_per_gpu: int, + num_tokens_in_b: int, + model_name: str, +) -> float: + """Estimates model size given time and hardware constraints. It's only used if the model size is not provided by the user. + + Args: + max_training_days (float): number of days to train the model for. + gpu_count (int): number of gpus to use (num_nodes * gpus_per_node). + tflops_per_gpu (int): estimated number of TFLOPS/s per GPU. + num_tokens_in_b (int): number of tokens to train the model for. + model_name (str): name of the model, such as gpt3, t5, mt5... + + Returns: + float: number of parameters to use for training. + + Raises: + NotImplementedError: if the model_name is not one of the supported models. + """ + + model_penalty = 0.87 if model_name == "mt5" else 1.0 + valid_models = ["gpt3", "t5", "mt5", "bert", "llama", "mixtral", "mistral", "gemma", "nemotron"] + try: + if model_name in valid_models: + return round( + model_penalty + * (max_training_days * 3600 * 24 * gpu_count * tflops_per_gpu * 1e12) + / (8 * num_tokens_in_b * 1e9) + / 1e9, + 2, + ) + else: + raise NotImplementedError + except ValueError as err: + print(f"Input values were not valid: {err}") + except ZeroDivisionError as err: + print(f"Cannot divide by zero. This can happen if num_tokens_in_b is zero: {err}") + except NotImplementedError as err: + print(f"Model size estimation is only available for {valid_models}: {err}") + return None + + +def _estimate_training_time( + model_size_in_b: float, + gpu_count: int, + tflops_per_gpu: int, + num_tokens_in_b: int, + model_name: str, +) -> float: + """Estimates training time for a given model size and hardware constraint. To be used when a model size is provided by the user. + + Args: + model_size_in_b (float): number of parameters to use for training. + gpu_count (int): number of gpus to use (num_nodes * gpus_per_node). + tflops_per_gpu (int): estimated number of TFLOPS/s per GPU. + num_tokens_in_b (int): number of tokens to train the model for. + model_name (str): name of the model, such as gpt3, t5, mt5... + + Returns: + float: number of days it will take to train the model. + + Raises: + NotImplementedError: if the model_name is not one of the supported models. + """ + + model_penalty = 1.15 if model_name == "mt5" else 1.0 + valid_models = ["gpt3", "t5", "mt5", "bert", "llama", "mixtral", "mistral", "gemma", "nemotron"] + try: + if model_name in valid_models: + return round( + model_penalty + * (model_size_in_b * 1e9 * 8 * num_tokens_in_b * 1e9) + / (3600 * 24 * gpu_count * tflops_per_gpu * 1e12), + 2, + ) + else: + raise NotImplementedError + except ValueError as err: + print(f"Input values were not valid: {err}") + except ZeroDivisionError as err: + print(f"Cannot divide by zero. This can happen if gpu_count or tflops_per_gpu are zero: {err}") + except NotImplementedError as err: + print(f"Training time estimation is only available for {valid_models}: {err}") + return None diff --git a/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py b/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py new file mode 100644 index 000000000000..5b7ac0ebc4d3 --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/core/calculate_performance.py @@ -0,0 +1,334 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from typing import Optional + +import pandas as pd +from tensorboard.backend.event_processing import event_accumulator + + +def get_results( + base_config=None, + train_config=None, + path_to_save: str = None, + output_top_n: Optional[int] = 10, +): + """Generates performance results. + + Args: + config (AutoConfigurator): auto configurator runner config. + path_to_save (str): path where to save performance results. + output_top_n (Optional[int]): Number of configs to be printed out as best configs. + """ + + # Define needed variables + model_name = train_config.model_type + model_size = train_config.model_size_in_b + global_batch_size = base_config.data.global_batch_size + seq_length = base_config.data.seq_length + + vocab_size = train_config.vocab_size + num_nodes = train_config.num_nodes + gpus_per_node = train_config.gpus_per_node + + layers = base_config.model.num_layers + hs = base_config.model.hidden_size + ffn_hs = base_config.model.ffn_hidden_size + + training_logs = path_to_save + final_result_logs = path_to_save + + result_columns = [ + "Model Name", + "Model Size", + "Seq Length", + "TP", + "PP", + "CP", + "EP", + "MBS", + "Act Ckpt Layers", + "Act Ckpt Micro Bathes", + "Act Ckpt Layers per Pipeline", + "Num Layers", + "Hidden Size", + "FFN Hidden Size", + "GBS", + "Nodes", + "GPUs per Node", + "Time per Step", + "Samples per Second", + "Model TFLOPS / GPU", + "Model TFLOPS Aggregate", + ] + error_columns = [ + "Model Name", + "Model Size", + "Seq Length", + "TP", + "PP", + "CP", + "EP", + "MBS", + "Act Ckpt Layers", + "Act Ckpt Micro Bathes", + "Act Ckpt Layers per Pipeline", + "Num Layers", + "Hidden Size", + "FFN Hidden Size", + "GBS", + "Nodes", + "GPUs per Node", + "Error Message", + ] + result = [] + errors = [] + dirs = [f.path for f in os.scandir(training_logs) if f.is_dir()] + + for candidate_dir in dirs: + logs_dir = os.path.join(training_logs, candidate_dir, "tb_logs/lightning_logs") + logs_folder = [f.path for f in os.scandir(logs_dir) if f.is_dir()][0] + tp, pp, cp, ep, mbs, act_ckpt, num_mbs_act, act_per_pipe = get_config(candidate_dir) + + for f in os.listdir(logs_folder): + if f.endswith("0.txt"): + error_file = os.path.join(logs_folder, f) + error = find_error(error_file) + if error: + errors.append( + [ + model_name, + model_size, + seq_length, + tp, + pp, + cp, + ep, + mbs, + act_ckpt, + num_mbs_act, + act_per_pipe, + layers, + hs, + ffn_hs, + global_batch_size, + num_nodes, + gpus_per_node, + error, + ] + ) + + files = os.listdir(logs_folder) + for f in files: + if f.startswith("events"): + event_file = os.path.join(logs_folder, f) + ea = event_accumulator.EventAccumulator(event_file) + ea.Reload() + try: + timing_list = ea.Scalars("train_step_timing in s") + if len(timing_list) <= 6: + continue + timing_list = [x.value for x in timing_list[5:]] + avg_global_step_time = round(sum(timing_list) / len(timing_list), 4) + samples_per_s = round(global_batch_size / avg_global_step_time, 2) + m_tflops, m_tflops_gpu = calculate_tflops( + model_name=model_name, + gbs=global_batch_size, + enc_seq_len=seq_length, + dec_seq_len=seq_length, + hs=hs, + ffn_hs=ffn_hs, + layers=layers, + vocab=vocab_size, + nodes=num_nodes, + gpus_per_node=gpus_per_node, + time_per_step=avg_global_step_time, + ) + config_name = f"tp{tp}_pp{pp}_cp{cp}_ep{ep}_mbs{mbs}_act_{act_ckpt}_num_mbs_act_{num_mbs_act}_act_per_pipe_{act_per_pipe}" + result.append( + [ + model_name, + model_size, + seq_length, + tp, + pp, + cp, + ep, + mbs, + act_ckpt, + num_mbs_act, + act_per_pipe, + layers, + hs, + ffn_hs, + global_batch_size, + num_nodes, + gpus_per_node, + avg_global_step_time, + samples_per_s, + m_tflops_gpu, + m_tflops, + ] + ) + finally: + continue + result.sort(key=lambda x: x[17]) + print(f"Top {min(output_top_n, len(result))} configs sorted from fastest to slowest:") + for i, res in enumerate(result): + print(f"Config #{i+1}: {res[-1]} with {res[17]:.4f}s per global step.") + if i + 1 == output_top_n: + break + + top_config = f"{model_name}_{model_size}b_{num_nodes}nodes_tp_{result[0][3]}_pp_{result[0][4]}_cp_{result[0][5]}_ep_{result[0][6]}_mbs_{result[0][7]}_act_ckpt_{result[0][8]}_num_mbs_act_{result[0][9]}_act_per_pipe_{result[0][10]}" + print("\n==================================================") + print(f"Optimal config: {top_config} with {result[0][17]:.4f}s per global step.") + print("==================================================\n") + + # Save results as a CSV file. + os.makedirs(final_result_logs, exist_ok=True) + result_df = pd.DataFrame(result, columns=result_columns) + result_df.to_csv(os.path.join(final_result_logs, f"final_summary_{num_nodes}nodes.csv"), index=False) + + error_df = pd.DataFrame(errors, columns=error_columns) + error_df.to_csv(os.path.join(final_result_logs, f"failed_jobs_{num_nodes}nodes.csv"), index=False) + + +def calculate_tflops( + model_name, + gbs, + enc_seq_len, + dec_seq_len, + hs, + ffn_hs, + layers, + vocab, + nodes, + gpus_per_node, + time_per_step, +): + """Calculates model and hardware TFLOPS for each model. + + GPT-3 Formulas: + Model FLOPs = (24𝐵𝑠ℎ^2 + 4𝐵��^2ℎ) x (3 x num_layers) + 6𝐵𝑠ℎ + T5/mT5 Formula: + Model FLOPs = + Bert Formula: + Model FLOPs = 72BLsh^2 * ( 1 + (s/6h) + (v/12hL)) + """ + + if model_name in ["gpt3", "llama", "baichuan2", "chatglm", "qwen2", "mixtral"]: + # Model FLOPS calculation + model_flops = ( + (24 * gbs * enc_seq_len * hs * hs + 4 * gbs * enc_seq_len * enc_seq_len * hs) * (3 * layers) + + (6 * gbs * enc_seq_len * hs * vocab) + ) / time_per_step + model_flops_per_gpu = model_flops / (nodes * gpus_per_node) + + model_tflops = model_flops / 1e12 + model_tflops_per_gpu = model_flops_per_gpu / 1e12 + + elif model_name == "bert": + model_flops = ( + 72 * gbs * layers * enc_seq_len * hs * hs * (1 + (enc_seq_len / (6 * hs)) + (vocab / (12 * hs * layers))) + ) / time_per_step + model_flops_per_gpu = model_flops / (nodes * gpus_per_node) + model_tflops = model_flops / 1e12 + model_tflops_per_gpu = model_flops_per_gpu / 1e12 + + elif model_name in ["t5", "mt5"]: + # Encoder Layer FLOPS: include self attention + MLP + flops_self_attn_enc = 8 * gbs * enc_seq_len * hs * hs + 4 * gbs * enc_seq_len * enc_seq_len * hs + flops_mlp_enc = 6 * gbs * enc_seq_len * hs * ffn_hs # geglu needs two gemms for h -> ffn_h + flops_enc_layer = flops_self_attn_enc + flops_mlp_enc + + # Decoder Layer FLOPS: inlcude self_attn + cross_attn + MLP + flops_self_attn_dec = 8 * gbs * dec_seq_len * hs * hs + 4 * gbs * dec_seq_len * dec_seq_len * hs + flops_cross_attn_dec = ( + 4 * gbs * enc_seq_len * hs * hs + + 4 * gbs * dec_seq_len * hs * hs + + 4 * gbs * enc_seq_len * dec_seq_len * hs + ) + flops_mlp_dec = 6 * gbs * dec_seq_len * hs * ffn_hs # geglu needs two gemms for h -> ffn_h + flops_dec_layer = flops_self_attn_dec + flops_cross_attn_dec + flops_mlp_dec + + # FLOPs of logits layer in the head + flops_logits = 2 * gbs * dec_seq_len * hs * vocab + + # FLOPs of fprop + flops_fprop = (flops_enc_layer + flops_dec_layer) * (layers // 2) + flops_logits + + # FLOPs of each train step (FLOPs of bprop is 2*fprop) + model_flops = 3 * flops_fprop / time_per_step + model_flops_per_gpu = model_flops / (nodes * gpus_per_node) + model_tflops = model_flops / 1e12 + model_tflops_per_gpu = model_flops_per_gpu / 1e12 + + else: + raise NotImplementedError("Model type not supported.") + return round(model_tflops, 2), round(model_tflops_per_gpu, 2) + + +def find_error(error_file: str, errors: list = ["CUDA out of memory"]): + """Function that finds the error among job output. + + Args: + errors (list): list of "popular" errors. + error_file (str): path to the job output. + + Returns: + str: serror message if job has been failed because of one of listed errors or None if not. + """ + + error = None + with open(error_file, "r") as f: + output = f.read() + for e in errors: + if e in output: + error = e + return error + + +def get_config(run_name: str) -> tuple: + """Function that extract model parallelism parameters + + Args: + run_name (str): name of the run. + + Returns: + tuple: model parallelism parameters. + """ + pattern = r'_(tp|pp|cp|ep|mbs|act_ckpt|num_mbs_act|act_per_pipe)_([^_]+)' + + # Find all matches in the input string + matches = re.findall(pattern, run_name) + + # Convert matches to a dictionary + params = {param: value for param, value in matches} + + return ( + params["tp"], + params["pp"], + params["cp"], + params["ep"], + params["mbs"], + params["act_ckpt"], + params["num_mbs_act"], + params["act_per_pipe"], + ) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/llm/tools/auto_configurator/core/training_config.py b/nemo/collections/llm/tools/auto_configurator/core/training_config.py new file mode 100644 index 000000000000..087bf3c6fb0e --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/core/training_config.py @@ -0,0 +1,892 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Tuple + +from nemo.collections.llm.tools.auto_configurator.core import utils + + +GPT_BASED_MODELS = [ + "gpt3", + "bert", + "llama", + "baichuan2", + "chatglm", + "qwen2", + "mixtral", + "mistral", + "gemma", + "nemotron", +] + + +def generate_grid_search_configs( + base_cfg: dict, + train_cfg: dict, +) -> Tuple[dict, dict]: + """Generates the grid of all possible configurations for the given model, and stores each different configuration in a yaml file. + + Args: + base_cfg (dict): base configuration of the model to be trained. + train_cfg (dict): train configuration of the model to be trained. + + Returns: + dict: base config. + dict: generated configs. + """ + + model_name = train_cfg.model_type + model_size_in_b = train_cfg.model_size_in_b + + # 2 * num_layers is needed because of encoder/decoder architecture. + multiplier = 1 if model_name in GPT_BASED_MODELS else 2 + + seq_length = base_cfg.model.seq_length + num_layers = base_cfg.model.num_layers if model_name in GPT_BASED_MODELS else base_cfg.model.encoder.num_layers + + if model_name in GPT_BASED_MODELS: + act_method = None + else: + act_method = base_cfg.model.encoder.activations_checkpoint_method + + params = _calculate_tp_pp_mbs_grid( + model_size_in_b=model_size_in_b, + num_layers=num_layers, + model_name=model_name, + seq_length=seq_length, + train_cfg=train_cfg, + ) + + max_minutes = train_cfg.max_minutes_per_run + max_steps = train_cfg.max_steps_per_run + num_nodes = train_cfg.num_nodes + + valid_tp_pp_list = [] + for tp in params.tp: + for pp in params.pp: + for cp in params.cp: + for ep in params.ep: + for mbs in params.mbs: + num_gpus = base_cfg.trainer.num_nodes * base_cfg.trainer.devices + base_cfg.data.global_batch_size = params.gbs + if model_name in GPT_BASED_MODELS: + att_heads = base_cfg.model.num_attention_heads + num_layers = base_cfg.model.num_layers + else: + att_heads = base_cfg.model.encoder.num_attention_heads + num_layers = base_cfg.model.encoder.num_layers + model_parallelism = (tp * pp * cp * ep) if (cp and ep) else (tp * pp) + mod_gbs = params.gbs % (mbs * num_gpus / model_parallelism) + mod_att_heads = att_heads % tp + mod_layers = (multiplier * num_layers) % pp + mod_cp = cp if cp else 1 + mod_ep = ep if ep else 1 + if ( + mod_gbs == 0 + and mod_att_heads == 0 + and mod_layers == 0 + and (tp, pp, cp, ep) not in valid_tp_pp_list + and (mod_cp // mod_ep == mod_cp or mod_ep // mod_cp == mod_ep) + and params.min_model_parallel <= model_parallelism <= params.max_model_parallel + ): + valid_tp_pp_list.append((tp, pp, cp, ep)) + + # Generate grid search configs. + configs = {} + for tp, pp, cp, ep in valid_tp_pp_list: + ( + virtual_pipelines, + act_ckpt_layers, + num_micro_batches_partial_act_ckpt, + act_ckpt_layers_per_pipeline, + ) = _set_activations_checkpoint_params( + tp, + pp, + cp, + ep, + num_layers, + act_method, + multiplier, + model_size_in_b, + model_name, + ) + for mbs in params.mbs: + kwargs = { + "base_cfg": base_cfg, + "act": None, + "num_mbs_act": None, + "act_per_pipe": None, + "tp": tp, + "pp": pp, + "cp": cp, + "ep": ep, + "virtual_pipelines": virtual_pipelines, + "mbs": mbs, + "max_minutes": max_minutes, + "max_steps": max_steps, + "num_nodes": num_nodes, + "model_name": model_name, + "model_size": model_size_in_b, + } + if act_ckpt_layers[0] is not None: + if act_layers is not None and act_layers != "auto": + act_ckpt_layers = act_layers + for act in act_ckpt_layers: + for num_mbs_act in num_micro_batches_partial_act_ckpt: + for act_per_pipe in act_ckpt_layers_per_pipeline: + kwargs["act"] = act + kwargs["num_mbs_act"] = num_mbs_act + kwargs["act_per_pipe"] = act_per_pipe + new_cfg = utils.modify_cfg(**kwargs) + if new_cfg: # Save candidate cfg. + configs[new_cfg["run"]["name"]] = new_cfg + else: + new_cfg = utils.modify_cfg(**kwargs) + if new_cfg: # Save candidate cfg. + config_name = new_cfg["run"]["name"] + new_cfg.pop("run") + configs[config_name] = new_cfg + + print(f"\nAll candidate configurations created correctly. Total number of configs: {len(configs)}.\n") + return base_cfg, configs + + +def _set_activations_checkpoint_params( + tp, pp, cp, ep, num_layers, act_method, multiplier, model_size_in_b, model_name +): + act_multiple = 4 // pp + if act_method == "block": + if 1.0 <= model_size_in_b < 11.3: + act_multiple = 8 // pp + elif 11.3 <= model_size_in_b < 26.0: + act_multiple = 16 // pp + elif 26.0 <= model_size_in_b < 60.0: + act_multiple = 16 // pp + elif 60.0 <= model_size_in_b: + act_multiple = 32 // pp + act_multiple = max(act_multiple, 1) + + virtual_pipelines = None + # Num micro batches with partial act ckpt + min_micro_b = 0 # 0 will not be used, minimum will be set to 1 later in the code. + max_micro_b = pp + interval_micro_b = 1 + # Act ckpt layers per pipeline + min_layers_per_pipe = 0 + max_layers_per_pipe = num_layers + interval_layers_per_pipe = act_multiple + if model_name in GPT_BASED_MODELS and pp > 2: # Interleaved pipeline scheduling. + virtual_pipelines = num_layers // pp # TODO: verify that this is the best value. + act_multiple = 1 + max_micro_b = pp * (virtual_pipelines - 1) + (pp - 1) * 2 + 1 + interval_micro_b = virtual_pipelines * 8 + max_layers_per_pipe = multiplier * num_layers // pp // virtual_pipelines + 1 + + ( + act_ckpt_layers, + num_micro_batches_partial_act_ckpt, + act_ckpt_layers_per_pipeline, + ) = ([None], [None], [None]) + if act_method == "block": + # Act ckpt num layers + if virtual_pipelines is None: + act_ckpt_layers = range(0, multiplier * num_layers // pp + 1, act_multiple) + else: + act_ckpt_layers = range(0, multiplier * num_layers // pp // virtual_pipelines + 1, act_multiple) + + if pp > 1 and model_name in GPT_BASED_MODELS: + # Num micro batches with partial act ckpt + num_micro_batches_partial_act_ckpt = list(range(min_micro_b, max_micro_b + 1, interval_micro_b)) + if num_micro_batches_partial_act_ckpt[0] == 0: + num_micro_batches_partial_act_ckpt[0] = 1 + + # Act ckpt layers per pipeline + act_ckpt_layers_per_pipeline = range( + min_layers_per_pipe, max_layers_per_pipe + 1, interval_layers_per_pipe + ) + + return ( + virtual_pipelines, + act_ckpt_layers, + num_micro_batches_partial_act_ckpt, + act_ckpt_layers_per_pipeline, + ) + + +@dataclass +class GPT3GridSearch: + """Selects grid search space for TP, PP, CP, EP, MBS parameters for GPT-3 and 80GB GPUs. + + Args: + model_size_in_b (float): number of parameters in the model. + valid_pp (List[int]): list of valid Pipeline Parallelism (PP) values for this config. + seq_length (int): sequence length to use for training. + gpu_memory_gb (int): size of GPU memory in GB. + """ + + model_size_in_b: int + valid_pp: List[int] + seq_length: int + gpu_memory_gb: int + + tp = [1, 2, 4, 8] + pp = [1] + cp = [1] + ep = [1] + mbs = [1, 2, 4, 8] + + gbs: int = 1024 + min_model_parallel: int = 1 + max_model_parallel: int = 8 + + def init_params(self): + model_size_in_b = self.model_size_in_b + gpu_memory_gb = self.gpu_memory_gb + seq_length = self.seq_length + + if gpu_memory_gb == 80: + if seq_length == 2048: + if model_size_in_b <= 1.0: + self.tp = [1, 2] + self.gbs = 256 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.gbs = 1024 + elif model_size_in_b <= 8.0: + self.tp = [1, 2, 4] + self.gbs = 2048 + elif model_size_in_b <= 13.0: + self.tp = [1, 2, 4, 8] + self.gbs = 2048 + elif model_size_in_b <= 23.0: + self.tp = [1, 2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 4] + self.mbs = [1, 2, 4] + self.min_model_parallel = 4 + self.max_model_parallel = 8 + self.gbs = 2048 + elif model_size_in_b <= 45.0: + self.tp = [2, 4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 4] + self.mbs = [1, 2, 4] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 2048 + elif model_size_in_b <= 95: + self.tp = [2, 4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 8] + self.mbs = [1, 2, 4, 8] + self.min_model_parallel = 8 + self.max_model_parallel = 64 + self.gbs = 2048 + elif model_size_in_b <= 130.0: + self.tp = [2, 4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 16] + self.mbs = [1, 2, 4, 8] + self.min_model_parallel = 16 + self.max_model_parallel = 128 + self.gbs = 2048 + elif model_size_in_b <= 195.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 16] + self.mbs = [1, 2, 4] + self.min_model_parallel = 32 + self.max_model_parallel = 256 + self.gbs = 2048 + elif model_size_in_b <= 395.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 8 <= x <= 32] + self.mbs = [1, 2, 4] + self.min_model_parallel = 64 + self.max_model_parallel = 512 + self.gbs = 2048 + elif model_size_in_b <= 790.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 8 <= x <= 100] + self.mbs = [1, 2, 4] + self.min_model_parallel = 128 + self.max_model_parallel = 1024 + self.gbs = 2048 + elif model_size_in_b <= 1100.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 16 <= x <= 130] + self.mbs = [1, 2, 4] + self.min_model_parallel = 256 + self.max_model_parallel = 2048 + self.gbs = 2048 + elif seq_length == 4096: + if model_size_in_b <= 1.0: + self.tp = [1, 2, 4] + self.mbs = [1, 2, 4, 8] + self.gbs = 128 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.mbs = [1, 2, 4, 8] + self.gbs = 512 + elif model_size_in_b <= 8.0: + self.tp = [1, 2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2, 4] + self.gbs = 1024 + elif model_size_in_b <= 13.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2, 4] + self.gbs = 1024 + elif model_size_in_b <= 23.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.gbs = 1024 + elif model_size_in_b <= 45.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 4] + self.mbs = [1, 2] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 1024 + elif model_size_in_b <= 95: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 8] + self.mbs = [1, 2] + self.min_model_parallel = 8 + self.max_model_parallel = 64 + self.gbs = 1024 + elif seq_length == 8192: + if model_size_in_b <= 1.0: + self.tp = [1, 2] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2, 4] + self.gbs = 64 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2, 4] + self.gbs = 128 + elif model_size_in_b <= 8.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2] + self.gbs = 256 + elif model_size_in_b <= 13.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2] + self.gbs = 256 + elif model_size_in_b <= 23.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 4] + self.mbs = [1] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 256 + elif model_size_in_b <= 45.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 8] + self.mbs = [1] + self.min_model_parallel = 32 + self.max_model_parallel = 64 + self.gbs = 256 + elif seq_length == 16384: + if model_size_in_b <= 1.0: + self.tp = [2, 4] + self.mbs = [1, 2] + self.gbs = 32 + elif model_size_in_b <= 4.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1] + self.gbs = 64 + elif model_size_in_b <= 8.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1] + self.gbs = 128 + elif model_size_in_b <= 13.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1] + self.gbs = 128 + elif model_size_in_b <= 23.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 4] + self.mbs = [1] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 128 + elif seq_length == 32768: + if model_size_in_b <= 1.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1] + self.gbs = 16 + elif model_size_in_b <= 4.0: + self.tp = [2, 4] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1] + self.gbs = 32 + elif model_size_in_b <= 8.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.mbs = [1] + self.gbs = 64 + elif model_size_in_b <= 13.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.mbs = [1] + self.gbs = 64 + elif model_size_in_b <= 23.0: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 4] + self.mbs = [1] + self.min_model_parallel = 16 + self.max_model_parallel = 32 + self.gbs = 64 + elif gpu_memory_gb == 40: + if model_size_in_b <= 1.0: + self.tp = [1, 2, 4] + self.mbs = [1, 2, 4, 8] + self.gbs = 256 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4, 8] + self.mbs = [1, 2, 4, 8] + self.gbs = 1024 + elif model_size_in_b <= 8.0: + self.tp = [2, 4, 8] + self.pp = [1, 2] + self.mbs = [1, 2, 4] + self.min_model_parallel = 2 + self.gbs = 2048 + elif model_size_in_b <= 13.0: + self.tp = [4, 8] + self.pp = [1, 2, 4] + self.mbs = [1, 2, 4] + self.min_model_parallel = 4 + self.max_model_parallel = 32 + self.gbs = 2048 + elif model_size_in_b <= 23.0: + self.tp = [2, 4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 8] + self.min_model_parallel = 8 + self.max_model_parallel = 64 + self.gbs = 2048 + elif model_size_in_b <= 45.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 12] + self.mbs = [1, 2, 4] + self.min_model_parallel = 16 + self.max_model_parallel = 128 + self.gbs = 2048 + elif model_size_in_b <= 95: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 16] + self.mbs = [1, 2, 4] + self.min_model_parallel = 16 + self.max_model_parallel = 256 + self.gbs = 2048 + elif model_size_in_b <= 130.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 26] + self.mbs = [1, 2] + self.min_model_parallel = 32 + self.max_model_parallel = 512 + self.gbs = 2048 + elif model_size_in_b <= 195.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 32] + self.mbs = [1, 2] + self.min_model_parallel = 64 + self.max_model_parallel = 1024 + self.gbs = 2048 + elif model_size_in_b <= 395.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 64] + self.mbs = [1, 2] + self.min_model_parallel = 128 + self.max_model_parallel = 2048 + self.gbs = 2048 + elif model_size_in_b <= 790.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 8 <= x <= 128] + self.mbs = [1, 2] + self.min_model_parallel = 256 + self.max_model_parallel = 4096 + self.gbs = 2048 + elif model_size_in_b <= 1100.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 8 <= x <= 192] + self.mbs = [1, 2] + self.min_model_parallel = 512 + self.max_model_parallel = 8192 + self.gbs = 2048 + + +@dataclass +class T5GridSearch: + """Selects grid search space for TP, PP, MBS parameters for T5/mT5 and 80GB GPUs. + + Args: + model_size_in_b (float): number of parameters in the model. + valid_pp (List[int]): list of valid Pipeline Parallelism (PP) values for this config. + seq_length (int): sequence length to use for training. + gpu_memory_gb (int): size of GPU memory in GB. + """ + + model_size_in_b: int + seq_length: int + gpu_memory_gb: int + valid_pp: List[int] + + tp = [1, 2, 4, 8] + pp = [1] + cp = [None] + ep = [None] + mbs = [1, 2, 4, 6, 8, 12, 16] + + gbs: int = 1920 + min_model_parallel: int = 1 + max_model_parallel: int = 8 + + def init_params(self): + model_size_in_b = self.model_size_in_b + gpu_memory_gb = self.gpu_memory_gb + seq_length = self.seq_length + + if gpu_memory_gb == 80: + if model_size_in_b <= 1.0: + self.tp = [1, 2] + self.mbs = [16, 32, 64, 128] + self.gbs = 2048 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.mbs = [4, 6, 8, 12, 16, 24, 32, 48] + self.gbs = 1920 + elif model_size_in_b <= 8.0: + self.tp = [2, 4, 8] + self.mbs = [4, 6, 8, 12, 16, 24, 32] + self.gbs = 1920 + elif model_size_in_b <= 14.5: + self.tp = [4, 8] + self.mbs = [2, 4, 6, 8, 12, 16, 24] + self.gbs = 1920 + elif model_size_in_b <= 25.9: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.gbs = 1920 + elif model_size_in_b <= 43.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 4] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 1920 + elif model_size_in_b <= 85.5: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 8] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 16 + self.max_model_parallel = 64 + self.gbs = 1920 + elif model_size_in_b <= 165.5: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 16] + self.mbs = [1, 2, 4, 6] + self.min_model_parallel = 32 + self.max_model_parallel = 128 + self.gbs = 1920 + elif model_size_in_b <= 250: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 32] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 64 + self.max_model_parallel = 256 + self.gbs = 1920 + elif gpu_memory_gb == 40: + if model_size_in_b <= 1.0: + self.tp = [1, 2] + self.mbs = [16, 32, 64, 128] + self.gbs = 2048 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.mbs = [4, 8, 12, 16, 24, 32, 48] + self.gbs = 1920 + elif model_size_in_b <= 8.0: + self.tp = [2, 4, 8] + self.mbs = [4, 6, 8, 12, 16, 24] + self.gbs = 1920 + elif model_size_in_b <= 14.5: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 2] + self.mbs = [2, 4, 6, 8, 12, 16] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.gbs = 1920 + elif model_size_in_b <= 25.9: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 8] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 1920 + elif model_size_in_b <= 43.0: + self.tp = [4, 8] + self.pp = [x for x in self.valid_pp if 1 <= x <= 8] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 16 + self.max_model_parallel = 32 + self.gbs = 1920 + elif model_size_in_b <= 85.5: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 2 <= x <= 8] + self.mbs = [1, 2, 4, 6, 8] + self.min_model_parallel = 32 + self.max_model_parallel = 64 + self.gbs = 1920 + elif model_size_in_b <= 165.5: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 4 <= x <= 32] + self.mbs = [1, 2, 4] + self.min_model_parallel = 64 + self.max_model_parallel = 128 + self.gbs = 1920 + elif model_size_in_b <= 250: + self.tp = [8] + self.pp = [x for x in self.valid_pp if 8 <= x <= 64] + self.mbs = [1, 2, 4] + self.min_model_parallel = 128 + self.max_model_parallel = 256 + self.gbs = 1920 + + +@dataclass +class BertGridSearch: + """Selects grid search space for TP, PP, MBS parameters for BERT and 80GB GPUs. + + Args: + model_size_in_b (float): number of parameters in the model. + valid_pp (List[int]): list of valid Pipeline Parallelism (PP) values for this config. + seq_length (int): sequence length to use for training. + gpu_memory_gb (int): size of GPU memory in GB. + """ + + model_size_in_b: int + seq_length: int + gpu_memory_gb: int + valid_pp: List[int] + + tp = [1, 2, 4, 8] + pp = [1] + cp = [None] + ep = [None] + mbs = [1, 2, 4, 6, 8, 12, 16] + + gbs: int = 1920 + min_model_parallel: int = 1 + max_model_parallel: int = 8 + + def init_params(self): + model_size_in_b = self.model_size_in_b + gpu_memory_gb = self.gpu_memory_gb + seq_length = self.seq_length + + if gpu_memory_gb == 80: + if model_size_in_b <= 1.0: + self.tp = [1, 2] + self.gbs = 256 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4] + self.gbs = 1024 + elif model_size_in_b <= 8.0: + self.tp = [2, 4, 8] + self.min_model_parallel = 2 + self.gbs = 2048 + elif model_size_in_b <= 13.0: + self.tp = [2, 4, 8] + self.mbs = [1, 2, 3, 4, 6] + self.min_model_parallel = 2 + self.gbs = 2048 + elif model_size_in_b <= 25.0: + self.tp = [4, 8] + self.mbs = [1, 2, 3, 4] + self.min_model_parallel = 4 + self.gbs = 2048 + elif model_size_in_b <= 46.5: + self.tp = [4, 8] + self.pp = [1, 2, 4] + self.mbs = [1, 2, 3, 4] + self.min_model_parallel = 4 + self.max_model_parallel = 16 + self.gbs = 2048 + elif model_size_in_b <= 87.5: + self.tp = [4, 8] + self.pp = [2, 4, 6, 8] + self.mbs = [1, 2, 3, 4] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 2048 + elif model_size_in_b <= 165.5: + self.tp = [4, 8] + self.pp = [4, 6, 8, 16] + self.mbs = [2, 4, 6, 8] + self.min_model_parallel = 16 + self.max_model_parallel = 128 + self.gbs = 2048 + elif model_size_in_b <= 250.5: + self.tp = [8] + self.pp = [4, 8, 16, 32] + self.mbs = [1, 2, 3, 4] + self.min_model_parallel = 32 + self.max_model_parallel = 256 + self.gbs = 2048 + else: + raise ValueError("No BERT model larger than 250B parameters is supported.") + elif gpu_memory_gb == 40: + if model_size_in_b <= 1.0: + self.tp = [1, 2, 4] + self.gbs = 256 + elif model_size_in_b <= 4.0: + self.tp = [1, 2, 4, 8] + self.gbs = 1024 + elif model_size_in_b <= 8.0: + self.tp = [2, 4, 8] + self.mbs = [1, 2, 4] + self.gbs = 2048 + elif model_size_in_b <= 13.0: + self.tp = [2, 4, 8] + self.mbs = [1, 2, 4] + self.gbs = 2048 + elif model_size_in_b <= 25.0: + self.tp = [2, 4, 8] + self.pp = [1, 2] + self.mbs = [1, 2, 4] + self.min_model_parallel = 2 + self.max_model_parallel = 16 + self.gbs = 2048 + elif model_size_in_b <= 46.5: + self.tp = [4, 8] + self.pp = [1, 2, 4, 8] + self.mbs = [1, 2, 3] + self.min_model_parallel = 8 + self.max_model_parallel = 32 + self.gbs = 2048 + elif model_size_in_b <= 87.5: + self.tp = [4, 8] + self.pp = [2, 4, 6, 8] + self.mbs = [1, 2, 3] + self.min_model_parallel = 16 + self.max_model_parallel = 64 + self.gbs = 2048 + elif model_size_in_b <= 165.5: + self.tp = [8] + self.pp = [4, 6, 8, 16] + self.mbs = [1, 2] + self.min_model_parallel = 32 + self.max_model_parallel = 256 + self.gbs = 2048 + elif model_size_in_b <= 250.5: + self.tp = [8] + self.pp = [8, 16, 32] + self.mbs = [1, 2] + self.min_model_parallel = 64 + self.max_model_parallel = 512 + self.gbs = 2048 + else: + raise ValueError("No BERT model larger than 250B parameters is supported.") + + +def _calculate_tp_pp_mbs_grid( + model_size_in_b: float, + num_layers: int, + model_name: str, + seq_length: int, + train_cfg: dict, +) -> Tuple[int, int, int]: + """Selects grid search space for TP, PP, MBS parameters for any model, and calls the necessary heuristics function accordingly. + + Args: + model_size_in_b (float): number of parameters in the model. + num_layers (int): number of layers in the model config. + model_name (str): name of the model to be used, such as gpt3, t5, mt5... + seq_length (int): sequence length to use for training. + train_cfg (dict): config of the model that will be launched. + + Returns: + dataclass object with model parallelism parameters. + + Raises: + NotImplementedError: if the model_name is not one of the supported models. + """ + + tp_sizes = train_cfg.tensor_parallel_sizes + pp_sizes = train_cfg.pipeline_parallel_sizes + cp_sizes = train_cfg.context_parallel_sizes + ep_sizes = train_cfg.expert_parallel_sizes + min_model_parallel_size = train_cfg.min_model_parallel_size + max_model_parallel_size = train_cfg.max_model_parallel_size + mbs_sizes = train_cfg.micro_batch_sizes + gbs_size = train_cfg.global_batch_size + gpu_memory_gb = train_cfg.gpu_memory_gb + multiplier = 1 if model_name in GPT_BASED_MODELS else 2 + init_pp = [] if model_name in GPT_BASED_MODELS else [1] + valid_pp = init_pp + [ + multiplier * x for x in range(1, num_layers + 1) if num_layers % x == 0 + ] # Only divisors of num_layers are possible. + + kwargs = { + "model_size_in_b": model_size_in_b, + "valid_pp": valid_pp, + "seq_length": seq_length, + "gpu_memory_gb": gpu_memory_gb, + } + + if model_name in GPT_BASED_MODELS: + search_class = GPT3GridSearch + elif model_name in ["t5", "mt5"]: + search_class = T5GridSearch + elif model_name == "bert": + search_class = BertGridSearch + else: + raise NotImplementedError("Model name not implemented.") + + params = search_class(**kwargs) + params.init_params() + + # Override the tp, pp, mbs search if indicated in the config params. + if tp_sizes is not None and tp_sizes != "auto": + params.tp = tp_sizes + if pp_sizes is not None and pp_sizes != "auto": + params.pp = pp_sizes + if cp_sizes is not None and cp_sizes != "auto": + params.cp = cp_sizes + if ep_sizes is not None and ep_sizes != "auto": + params.ep = ep_sizes + if mbs_sizes is not None and mbs_sizes != "auto": + params.mbs = mbs_sizes + if gbs_size is not None and gbs_size != "auto": + params.gbs = gbs_size + if min_model_parallel_size is not None and min_model_parallel_size != "auto": + params.min_model_parallel = min_model_parallel_size + if max_model_parallel_size is not None and max_model_parallel_size != "auto": + params.max_model_parallel = max_model_parallel_size + return params diff --git a/nemo/collections/llm/tools/auto_configurator/core/utils.py b/nemo/collections/llm/tools/auto_configurator/core/utils.py new file mode 100644 index 000000000000..3441c7cdbf9b --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/core/utils.py @@ -0,0 +1,470 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + + +GPT_BASED_MODELS = [ + "gpt3", + "bert", + "llama", + "baichuan2", + "chatglm", + "qwen2", + "mixtral", + "mistral", + "gemma", + "nemotron", +] + + +@dataclass +class ModelSizeParams: + """Calculates the parameters that affect model_size: hidden size, attention heads, KV channels, and FFN size. It also calculates the learning rate. + + Args: + model_size_in_b (float): number of parameters in the desired model config, in billions. + vocab_size (int): size of the vocabulary to use for training. + seq_length (int): sequence length to be used during training. + model_name (str): name of the model to be trained, i.e. gpt3, t5, mt5... + + Raises: + ValueError: if the model size is larger than the max supported model size. + NotImplementedError: if the model name is not supported. + """ + + model_size_in_b: float + vocab_size: int + seq_length: int + model_name: str + + # Model size params + layers: int = None + hs: int = None + att_h: int = None + ffn: int = None + kv: int = None + lr: float = None + + def init_params(self): + model_name = self.model_name + model_size_in_b = self.model_size_in_b + if model_name in GPT_BASED_MODELS: + if model_size_in_b < 0.25: + self.hs, self.att_h, self.lr = 768, 12, 6e-4 + elif model_size_in_b < 0.5: + self.hs, self.att_h, self.lr = 1024, 16, 3e-4 + elif model_size_in_b < 1: + self.hs, self.att_h, self.lr = 1536, 16, 2.5e-4 + elif model_size_in_b < 2: + self.hs, self.att_h, self.lr = 2048, 16, 2e-4 + elif model_size_in_b < 3: + self.hs, self.att_h, self.lr = 2560, 32, 1.6e-4 + elif model_size_in_b < 4.5: + self.hs, self.att_h, self.lr = 3072, 32, 1.4e-4 + elif model_size_in_b < 8: + self.hs, self.att_h, self.lr = 4096, 32, 1.2e-4 + elif model_size_in_b < 15: + self.hs, self.att_h, self.lr = 5120, 40, 1e-4 + elif model_size_in_b < 25: + self.hs, self.att_h, self.lr = 6144, 48, 1e-4 + elif model_size_in_b < 52: + self.hs, self.att_h, self.lr = 8192, 64, 0.8e-4 + elif model_size_in_b < 105: + self.hs, self.att_h, self.lr = 10240, 80, 0.7e-4 + elif model_size_in_b < 205: + self.hs, self.att_h, self.lr = 12288, 96, 0.6e-4 + elif model_size_in_b < 405: + self.hs, self.att_h, self.lr = 20480, 128, 0.5e-4 + elif model_size_in_b < 805: + self.hs, self.att_h, self.lr = 20480, 128, 0.4e-4 + elif model_size_in_b < 1105: + self.hs, self.att_h, self.lr = 25600, 160, 0.3e-4 + else: + raise ValueError("Model_size for GPT-3 must be smaller than 1.1T parameters.") + elif model_name == "t5": + self.kv, self.lr = 64, 1e-4 + if model_size_in_b < 0.1: + self.hs, self.att_h, self.ffn = 512, 6, 1024 + elif model_size_in_b < 0.4: + self.hs, self.att_h, self.ffn = 768, 12, 2048 + elif model_size_in_b < 1: + self.hs, self.att_h, self.ffn = 1024, 16, 2816 + elif model_size_in_b < 5: + self.hs, self.att_h, self.ffn = 2048, 32, 5120 + elif model_size_in_b < 15: + self.hs, self.att_h, self.ffn = 4096, 64, 10240 + elif model_size_in_b < 25.9: + self.hs, self.att_h, self.ffn = 5120, 80, 10880 + elif model_size_in_b < 43.0: + self.hs, self.att_h, self.ffn = 6144, 96, 10880 + elif model_size_in_b <= 85.5: + self.hs, self.att_h, self.ffn = 6144, 96, 16384 + elif model_size_in_b <= 165.5: + self.hs, self.att_h, self.ffn, kv = 7680, 96, 20480, 128 + elif model_size_in_b <= 250: + self.hs, self.att_h, self.ffn, kv = 12288, 96, 32768, 128 + else: + raise ValueError("Model_size for T5 must be smaller than 250B parameters.") + elif model_name == "mt5": + self.kv, self.lr = 64, 1e-4 + if model_size_in_b < 0.25: + self.hs, self.att_h, self.ffn = 512, 6, 1024 + elif model_size_in_b < 0.5: + self.hs, self.att_h, self.ffn = 768, 12, 2048 + elif model_size_in_b < 1.2: + self.hs, self.att_h, self.ffn = 1024, 16, 2816 + elif model_size_in_b < 5: + self.hs, self.att_h, self.ffn = 2048, 32, 5120 + elif model_size_in_b < 15: + self.hs, self.att_h, self.ffn = 4096, 64, 10240 + elif model_size_in_b < 25.9: + self.hs, self.att_h, self.ffn = 5120, 80, 10880 + elif model_size_in_b < 43.0: + self.hs, self.att_h, self.ffn = 6144, 96, 10880 + elif model_size_in_b <= 85.5: + self.hs, self.att_h, self.ffn = 6144, 96, 16384 + elif model_size_in_b <= 165.5: + self.hs, self.att_h, self.ffn, kv = 7680, 96, 20480, 128 + elif model_size_in_b <= 250: + self.hs, self.att_h, self.ffn, kv = 12288, 96, 32768, 128 + else: + raise ValueError("Model_size for mT5 must be smaller than 250B parameters.") + elif model_name == "bert": + self.lr = 1e-4 + if model_size_in_b < 0.25: + self.hs, self.att_h, self.lr = 768, 12, 2e-4 + elif model_size_in_b < 0.5: + self.hs, self.att_h, self.lr = 1024, 16, 2e-4 + elif model_size_in_b < 1: + self.hs, self.att_h = 1536, 16 + elif model_size_in_b < 2: + self.hs, self.att_h = 2048, 16 + elif model_size_in_b < 3: + self.hs, self.att_h = 2560, 32 + elif model_size_in_b < 4.5: + self.hs, self.att_h = 2560, 32 + elif model_size_in_b < 8: + self.hs, self.att_h = 4096, 32 + elif model_size_in_b < 15: + self.hs, self.att_h = 5120, 40 + elif model_size_in_b <= 25: + self.hs, self.att_h = 6144, 48 + elif model_size_in_b <= 46.5: + self.hs, self.att_h = 7680, 48 + elif model_size_in_b <= 87.5: + self.hs, self.att_h = 9216, 96 + elif model_size_in_b <= 165.5: + self.hs, self.att_h = 9216, 96 + elif model_size_in_b <= 250.5: + self.hs, self.att_h = 12288, 96 + else: + raise ValueError("Model_size for BERT must be smaller than 25B parameters.") + self.ffn = 4 * self.hs + else: + raise NotImplementedError("Model name is not valid.") + + # Try powers of 2 + margin = 0.01 + for attempt in range(0, 10): + for layers in (2**p for p in range(1, 10)): + out_size = _calculate_model_size( + vocab_size=self.vocab_size, + seq_length=self.seq_length, + hidden_size=self.hs, + num_layers=layers, + ffn_size=self.ffn, + kv_channels=self.kv, + att_heads=self.att_h, + model_name=self.model_name, + ) + if model_size_in_b * (1.0 - margin) < out_size < model_size_in_b * (1.0 + margin) and not self.layers: + self.layers = layers + margin += 0.01 # Double margin of acceptable model sizes. + + # Try multiples of 16 + margin = 0.01 + for attempt in range(0, 6): + for layers in range(16, 201, 16): + out_size = _calculate_model_size( + vocab_size=self.vocab_size, + seq_length=self.seq_length, + hidden_size=self.hs, + num_layers=layers, + ffn_size=self.ffn, + kv_channels=self.kv, + att_heads=self.att_h, + model_name=self.model_name, + ) + if model_size_in_b * (1.0 - margin) < out_size < model_size_in_b * (1.0 + margin) and not self.layers: + self.layers = layers + margin += 0.01 # Double margin of acceptable model sizes. + + # Try multiples of 2 + margin = 0.01 + for attempt in range(0, 6): + for layers in range(2, 201, 2): + out_size = _calculate_model_size( + vocab_size=self.vocab_size, + seq_length=self.seq_length, + hidden_size=self.hs, + num_layers=layers, + ffn_size=self.ffn, + kv_channels=self.kv, + att_heads=self.att_h, + model_name=self.model_name, + ) + if model_size_in_b * (1.0 - margin) < out_size < model_size_in_b * (1.0 + margin) and not self.layers: + self.layers = layers + margin += 0.01 # Double margin of acceptable model sizes. + + # Try multiples of 5 + margin = 0.01 + for attempt in range(0, 6): + for layers in range(5, 201, 5): + out_size = _calculate_model_size( + vocab_size=self.vocab_size, + seq_length=self.seq_length, + hidden_size=self.hs, + num_layers=layers, + ffn_size=self.ffn, + kv_channels=self.kv, + att_heads=self.att_h, + model_name=self.model_name, + ) + if model_size_in_b * (1.0 - margin) < out_size < model_size_in_b * (1.0 + margin) and not self.layers: + self.layers = layers + margin += 0.01 # Double margin of acceptable model sizes. + + # Try any valid number + margin = 0.01 + for attempt in range(0, 10): + for layers in range(1, 200): + out_size = _calculate_model_size( + vocab_size=self.vocab_size, + seq_length=self.seq_length, + hidden_size=self.hs, + num_layers=layers, + ffn_size=self.ffn, + kv_channels=self.kv, + att_heads=self.att_h, + model_name=self.model_name, + ) + if model_size_in_b * (1.0 - margin) < out_size < model_size_in_b * (1.0 + margin) and not self.layers: + self.layers = layers + margin += 0.01 # Double margin of acceptable model sizes. + + if not self.layers: + raise Exception("Number of layers not found, config is not possible.") + + +def _calculate_model_size( + vocab_size: int = None, + seq_length: int = None, + hidden_size: int = None, + num_layers: int = None, + ffn_size: int = None, + kv_channels: int = None, + att_heads: int = None, + model_name: str = "gpt3", +): + """Calculates the model size (number of parameters in billions), given the model parameters and name. + + Args: + vocab_size (int): vocabulary size to be used during training. + seq_length (int): input sequence length to be used during training. + hidden_size (int): size of the hidden layers of the model. + num_layers (int): number of layers in the model. + ffn_size (int): FFN size of the model. + kv_channels (int): number of KV channels in the transformer layers. + att_heads (int): number of attention heads in the transformer layers. + model_name (str): name of the model, i.e gpt3, t5, mt5... + + Returns: + float: size of the model in billions of parameters. + + Raises: + NotImplementedError: if the model name is not valid. + """ + + if model_name in GPT_BASED_MODELS: + model_size = ( + 12 + * num_layers + * hidden_size**2 + * (1 + (13 / (12 * hidden_size)) + ((vocab_size + seq_length) / (12 * num_layers * hidden_size))) + / 1e9 + ) + elif model_name in ["t5", "mt5"]: + # 2 L F + 3 L P + H (2 + 4 L F + L (21 + 12 P) + 1 S + 1 V) + proj_size = att_heads * kv_channels + model_size = ( + 2 * num_layers * 1.5 * ffn_size + + 3 * num_layers * proj_size + + hidden_size + * (2 + 4 * num_layers * 1.5 * ffn_size + num_layers * (21 + 12 * proj_size) + seq_length + vocab_size) + ) / 1e9 + elif model_name == "bert": + model_size = ( + num_layers * (ffn_size + hidden_size * (4 * hidden_size + 3 * att_heads + 2 * ffn_size + 6)) + + hidden_size * (vocab_size + seq_length + hidden_size + 5) + ) / 1e9 + + else: + raise NotImplementedError("Model name is not valid.") + + return model_size + + +def generic_base_config(config) -> dict: + """Generates a base config dictionary from a base config python file. + + Args: + config (AutoConfigurator): config object for the Auto Configurator tool. + + Returns: + BaseConfig: base configuration for the model. + AutoConfigurator: config object for the Auto Configurator tool. + """ + + from nemo.collections.llm.tools.auto_configurator.core.base_config import BaseConfig, calculate_model_size + + default_model = False if config.model_size_in_b else True + + model_size_in_b = calculate_model_size( + config.gpu_count, + config.max_training_days, + config.model_size_in_b, + config.tflops_per_gpu, + config.num_tokens_in_b, + config.model_type, + ) + base_cfg = BaseConfig(config) + + if default_model: + params = ModelSizeParams( + model_size_in_b, + config.vocab_size, + config.seq_length, + config.model_type, + ) + params.init_params() + + if config.model_type in GPT_BASED_MODELS: + base_cfg.model.num_layers = params.layers + base_cfg.model.hidden_size = params.hs + base_cfg.model.num_attention_heads = params.att_h + base_cfg.model.kv_channels = params.kv + if not params.ffn: + base_cfg.model.ffn_hidden_size = params.hs * 4 + else: + base_cfg.model.ffn_hidden_size = params.ffn + + config.model_size_in_b = model_size_in_b + + return base_cfg, config + + +def modify_cfg( + base_cfg: dict, + act: int, + num_mbs_act: int, + act_per_pipe: int, + tp: int, + pp: int, + cp: int, + ep: int, + virtual_pipelines: int, + mbs: int, + max_minutes: int, + max_steps: int, + num_nodes: int, + model_name: str, + model_size, +) -> dict: + """Modify the base configuration for the model with the new parameters that are specific to the current model, which the Auto Configurator tool heuristics selected. + + Args: + base_cfg (dict): base configuration for the current model, which will be modified in this function. + act (int): number of activation checkpointing layers to use for the model. + num_mbs_act (int): sets the number of micro-batches where only a partial number of Transformer layers get checkpointed and recomputed within a window of micro-batches. + act_per_pipe (int): sets the number of Transformer layers to skip checkpointing at later pipeline stages. + tp (int): Tensor Parallelism (TP) value to be set for the model. + pp (int): Pipeline Parallelism (PP) value to be set for the model. + cp (int): Context Parallelism (CP) value to be set for the model. + ep (int): Expert Parallelism (EP) value to be set for the model. + virtual_pipelines (int): Virtual Pipelines value to be set for the model. + mbs (int): Micro Batch Size (MBS) value to be set for the model. + max_minutes (int): maximum amount of time to run this model for. + max_steps (int): maximum number of steps to run this model for. + num_nodes (int): number of nodes to use for the training run. + model_name (str): name of the model, i.e. gpt3, t5, mt5... + + Returns: + dict: dictionary containing the updated model configuration parameters. + """ + + if model_name in GPT_BASED_MODELS: + att_heads = base_cfg.model.num_attention_heads + num_layers = base_cfg.model.num_layers + else: + att_heads = base_cfg.model.encoder.num_attention_heads + num_layers = base_cfg.model.encoder.num_layers + + # gbs = mbs * num_gpus * accumulate_grad_batches / (tp * pp) + num_gpus = base_cfg.trainer.num_nodes * base_cfg.trainer.devices + gbs = base_cfg.data.global_batch_size + seq_len = base_cfg.model.seq_length + + new_cfg = dict(run=base_cfg.run) + if act is not None: + if model_name in GPT_BASED_MODELS: + new_cfg["activations_checkpoint_num_layers"] = act + else: + new_cfg["encoder"]["activations_checkpoint_num_layers"] = act // 2 + new_cfg["decoder"]["activations_checkpoint_num_layers"] = act // 2 + + if num_mbs_act is not None and model_name in GPT_BASED_MODELS: + new_cfg["num_micro_batches_with_partial_activation_checkpoints"] = num_mbs_act + + if act_per_pipe is not None and model_name in GPT_BASED_MODELS: + new_cfg["activations_checkpoint_layers_per_pipeline"] = act_per_pipe + + if virtual_pipelines is not None and model_name in GPT_BASED_MODELS: + new_cfg["virtual_pipeline_model_parallel_size"] = virtual_pipelines + + new_cfg["tensor_model_parallel_size"] = tp + new_cfg["pipeline_model_parallel_size"] = pp + new_cfg["micro_batch_size"] = mbs + new_cfg["global_batch_size"] = gbs + + if cp is not None: + new_cfg["context_parallel_size"] = cp + + if ep is not None: + new_cfg["expert_model_parallel_size"] = ep + + mod_gbs = gbs % (mbs * num_gpus / (tp * pp)) + mod_att_heads = att_heads % tp + mod_layers = num_layers % pp + if mod_gbs == 0 and mod_att_heads == 0 and mod_layers == 0: + # Valid config + new_cfg["run"][ + "name" + ] = f"{model_name}_{str(model_size)}b_{num_nodes}nodes_tp_{tp}_pp_{pp}_cp_{cp}_ep_{ep}_mbs_{mbs}_act_ckpt_{act}_num_mbs_act_{num_mbs_act}_act_per_pipe_{act_per_pipe}" + print( + f"Valid config: SeqLen={seq_len}, GBS={gbs}, MBS={mbs}, TP={tp}, PP={pp}, CP={cp}, EP={ep}, act_ckpt_layers={act}, num_mbs_act={num_mbs_act}, act_per_pipe={act_per_pipe}. Adding to directory." + ) + return new_cfg + return None diff --git a/nemo/collections/llm/tools/auto_configurator/runner.py b/nemo/collections/llm/tools/auto_configurator/runner.py new file mode 100644 index 000000000000..0c80c9a21a9e --- /dev/null +++ b/nemo/collections/llm/tools/auto_configurator/runner.py @@ -0,0 +1,246 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the License); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import re + +from typing import List, Optional + +from nemo.collections.llm import GPTModel +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.tools.auto_configurator.core.training_config import generate_grid_search_configs +from nemo.collections.llm.tools.auto_configurator.core.utils import generic_base_config +from nemo.collections.llm.utils import Config, Partial +from nemo.utils import logging + +SUPPORTED_MODELS = [ + "gpt3", + "llama", + "mixtral", + "mistral", + "gemma", + "nemotron", +] + +SUPPORTED_TOKENIZERS = [ + "autotokenizer", + "sentencepiece", + "huggingface", +] + + +class AutoConfigurator: + """Auto Configurator runner config class.""" + + def __init__( + self, + model: Config = None, + num_nodes: int = None, + data_paths: List = None, + path_to_logs: str = None, + tokenizer_type: Optional[str] = "autotokenizer", + tokenizer_path: Optional[str] = "GPT2BPETokenizer", + gpus_per_node: Optional[int] = 8, + gpu_memory_gb: Optional[int] = 80, + seq_length: Optional[int] = 2048, + global_batch_size: Optional[int] = "auto", + tensor_parallel_sizes: Optional[List[int]] = "auto", + pipeline_parallel_sizes: Optional[List[int]] = "auto", + micro_batch_sizes: Optional[List[int]] = "auto", + context_parallel_sizes: Optional[List[int]] = [1], + expert_parallel_sizes: Optional[List[int]] = [1], + min_model_parallel_size: Optional[int] = "auto", + max_model_parallel_size: Optional[int] = "auto", + num_tokens_in_b: Optional[int] = 300, + tflops_per_gpu: Optional[int] = 140, + max_minutes_per_run: Optional[int] = 30, + max_training_days: Optional[int] = 2, + max_steps_per_run: Optional[int] = 50, + vocab_size: Optional[int] = 51200, + ): + """ + Args: + model_type (Config): model type to be used for training. + num_nodes (int): number of nodes to be used for training. + data_paths (List): list of datafiles to be used for training. + path_to_logs (str): path to the directory where the logs will be stored. + tokenizer_type (Optional[str]): tokenizer type. + tokenizer_path (Optional[str]): path to the tokenizer model. + model_size (Optional[int]): size of model to be trained. + gpus_per_node (Optional[int]): number of GPUs per node to be used. + gpu_memory_gb (Optional[int]): memory per GPU, in GB. Currently 40GB and 80GB A100s/H100s supported. + seq_length (Optional[int]): model sequence length. Available seq_length list for GPT-based models: [2048, 4096, 8192, 16384, 32768]. + global_batch_size (Optional[int]): model global batch size. Set to "auto" if you want auto configurator to find optimal gbs. + tensor_parallel_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. + pipeline_parallel_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. + micro_batch_sizes (Optional[List[int]]): set to "auto" to use our recommendation, or a list, such as [1, 2, 4, 8]. + context_parallel_sizes (Optional[List[int]]): model context parallel size. A list, such as [1, 2, 4, 8]. + expert_parallel_sizes (Optional[List[int]]): model expert parallel size. A list, such as [1, 2, 4, 8]. + min_model_parallel_size (Optional[int]): set to "auto" to use our recommendation, or a value for the minimum desired parallelism. + max_model_parallel_size (Optional[int]): set to "auto" to use our recommendation, or a value for the maximum desired parallelism. + num_tokens_in_b (Optional[int]): number of tokens in billions in train dataset. + tflops_per_gpu (Optional[int]): estimated tflops per GPU. + max_minutes_per_run (Optional[int]): maximum number of minutes per run for the grid search. + max_training_days (Optional[int]): number of days expected model to be trained. + max_steps_per_run (Optional[int]): maximum number of steps per run for the grid search. + vocab_size (Optional[int]): size of tokenizer vocabulary. + """ + + # Print out the config + config = locals() + config.pop('self') + for key, value in config.items(): + setattr(self, key, value) + logging.info(self._get_message(config)) + + model_type = self._get_model_type(model) + assert model_type in SUPPORTED_MODELS, f"model_type must be set to one of {SUPPORTED_MODELS}." + assert tokenizer_type in SUPPORTED_TOKENIZERS, f"tokenizer_type must be set to one of {SUPPORTED_TOKENIZERS}." + assert num_nodes, "num_nodes value must be specified." + assert data_paths, "training data must be specified." + assert path_to_logs, f"path_to_logs parameter must be specified." + gpu_count = num_nodes * gpus_per_node + assert gpu_count > 0, "num_nodes * gpus_per_node must be an int larger than zero." + assert gpu_memory_gb in ( + 40, + 80, + ), "gpu_memory_gb can only be 40 or 80." + assert max_minutes_per_run >= 10, "max_minutes_per_run must be an int and be at least 10 minutes." + + self.model_type = model_type + self.model_size_in_b = self._get_model_size(model) + self.gpu_count = gpu_count + self.num_gpus = gpus_per_node + + def _get_message(self, config: dict) -> str: + """ + Function that returns runner config line by line. + + Args: + config (dict): runner config. + + Returns: + str: runner config params. + """ + + message = "AutoConfigurator runner config:\n" + for key, value in config.items(): + message += f"{key}: {value}\n" + + return message + + def _get_model_type(self, model: Config) -> str: + """ + Function that returns model type from model class name. + + Args: + models (Config): model object. + + Returns: + str: model type. + """ + + match = re.search(r"\w+\d+[MB]", str(model)) + if match: + model = match.group(0) + + if "GPT" in model: + return "gpt3" + elif "Llama" in model: + return "llama" + elif "Mixtral" in model: + return "mixtral" + elif "Mistral" in model: + return "mistral" + elif "Gemma" in model: + return "gemma" + elif "Nemotron" in model: + return "nemotron" + else: + return None + + def _get_model_size(self, model: Config) -> int: + """ + Function that returns model size from model class name. + + Args: + model (Config): model class name. + + Returns: + int: model size. + """ + match = re.search(r'(\d+)([BM])', str(model)) + if match: + size = int(match.group(1)) + measure = match.group(2) + if measure == 'B': + return size + elif measure == 'M': + return size / 1000 # Convert millions to billions + return None + + +def generate_configs(runner_config: AutoConfigurator = None) -> dict: + """ + Function that returns a dictionary of Partial configs. + + Args: + config (AutoConfigurator): Auto Configurator object. + + Returns: + dict: dictionary of Partial configs. + """ + + # Generate base config for the given model size + base_cfg, train_cfg = generic_base_config(runner_config) + + # Launch grid search for training constraints + base_config, train_configs = generate_grid_search_configs(base_cfg, train_cfg) + + tokenizer = base_config.tokenizer + model = Config(GPTModel, config=base_config.model, tokenizer=tokenizer) + + configs = {} + for name, config in train_configs.items(): + trainer = copy.deepcopy(base_config.trainer) + data = copy.deepcopy(base_config.data) + log = copy.deepcopy(base_config.log) + + # Set data params + data.micro_batch_size = config.get("micro_batch_size") + data.global_batch_size = config.get("global_batch_size") + + # Set strategy params + trainer.strategy.tensor_model_parallel_size = config.get("tensor_model_parallel_size") + trainer.strategy.pipeline_model_parallel_size = config.get("pipeline_model_parallel_size") + trainer.strategy.context_parallel_size = config.get("context_parallel_size") + trainer.strategy.expert_model_parallel_size = config.get("expert_model_parallel_size") + trainer.strategy.virtual_pipeline_model_parallel_size = config.get( + "virtual_pipeline_model_parallel_size", None + ) + if config.get("tensor_model_parallel_size") > 1: + trainer.strategy.sequence_parallel = True + + # Set the directory where to save the logs + configs[name] = Partial( + pretrain, + model=model, + trainer=trainer, + data=data, + optim=base_config.optim, + log=log, + resume=None, + ) + + return base_cfg, configs diff --git a/tests/collections/llm/auto_conf/test_autoconf_utils.py b/tests/collections/llm/auto_conf/test_autoconf_utils.py new file mode 100644 index 000000000000..0faa86c13016 --- /dev/null +++ b/tests/collections/llm/auto_conf/test_autoconf_utils.py @@ -0,0 +1,131 @@ +from nemo.collections.llm.tools.auto_configurator.core.base_config import _estimate_training_time, calculate_model_size + + +class TestUtils: + def test_calculate_model_size(self): + # GPT + model_size = calculate_model_size( + 8, + 7, + None, + 140, + 300, + "gpt3", + ) + assert model_size == 0.28, f"expected model_size is 0.28 but got {model_size}." + + # Llama + model_size = calculate_model_size( + 128, + 30, + None, + 100, + 3000, + "llama", + ) + assert model_size == 1.38, f"expected model_size is 1.38 but got {model_size}." + + # Mixtral + model_size = calculate_model_size( + 256, + 20, + None, + 140, + 600, + "mixtral", + ) + assert model_size == 12.9, f"expected model_size is 12.9 but got {model_size}." + + # Mistral + model_size = calculate_model_size( + 1028, + 30, + None, + 240, + 100, + "mistral", + ) + assert model_size == 799.37, f"expected model_size is 799.37 but got {model_size}." + + # Gemma + model_size = calculate_model_size( + 512, + 30, + None, + 240, + 100, + "gemma", + ) + assert model_size == 398.13, f"expected model_size is 398.13 but got {model_size}." + + # Nemotron + model_size = calculate_model_size( + 256, + 15, + None, + 240, + 120, + "gemma", + ) + assert model_size == 82.94, f"expected model_size is 82.94 but got {model_size}." + + def test_calculate_train_time(self): + # GPT + train_time = _estimate_training_time( + 175, + 1024, + 140, + 300, + "gpt3", + ) + assert train_time == 33.91, f"expected train_time is 33.91 but got {train_time}." + + # Llama + train_time = _estimate_training_time( + 35, + 512, + 60, + 3000, + "llama", + ) + assert train_time == 316.48, f"expected train_time is 316.48 but got {train_time}." + + # Mixtral + train_time = _estimate_training_time( + 0.8, + 128, + 140, + 1000, + "mixtral", + ) + assert train_time == 4.13, f"expected train_time is 4.13 but got {train_time}." + + # Mistral + train_time = _estimate_training_time( + 11, + 24, + 60, + 250, + "mistral", + ) + assert train_time == 176.83, f"expected train_time is 176.83 but got {train_time}." + + # Gemma + train_time = _estimate_training_time( + 7, + 8, + 55, + 100, + "gemma", + ) + assert train_time == 147.31, f"expected train_time is 147.31 but got {train_time}." + + # Nemotron + train_time = _estimate_training_time( + 14, + 12, + 11, + 55, + "nemotron", + ) + assert train_time == 540.12, f"expected train_time is 540.12 but got {train_time}." diff --git a/tests/collections/llm/auto_conf/test_base_configs.py b/tests/collections/llm/auto_conf/test_base_configs.py new file mode 100644 index 000000000000..46ee49ae0629 --- /dev/null +++ b/tests/collections/llm/auto_conf/test_base_configs.py @@ -0,0 +1,341 @@ +import nemo_run as run +import torch + +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import TensorBoardLogger + +from nemo import lightning as nl +from nemo.collections.common.tokenizers import AutoTokenizer +from nemo.collections.llm import ( + GemmaConfig2B, + GPTConfig126M, + Llama3Config8B, + MistralConfig7B, + MixtralConfig8x3B, + Nemotron4Config22B, + PreTrainingDataModule, +) +from nemo.collections.llm.tools.auto_configurator import AutoConfigurator +from nemo.collections.llm.tools.auto_configurator.core.base_config import BaseConfig +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def get_tokenizer() -> run.Config: + return run.Config(AutoTokenizer, pretrained_model_name="GPT2BPETokenizer") + + +def get_data(seq_length, global_batch_size) -> run.Config[PreTrainingDataModule]: + config = { + "paths": "/", + "seq_length": seq_length, + "global_batch_size": global_batch_size, + "num_workers": 2, + "index_mapping_dir": None, + } + + return run.Config( + PreTrainingDataModule, + **config, + tokenizer=get_tokenizer(), + ) + + +def get_trainer(num_nodes) -> run.Config[nl.Trainer]: + trainer_config = { + "accelerator": "gpu", + "enable_checkpointing": False, + "use_distributed_sampler": False, + "max_epochs": None, + "log_every_n_steps": 1, + "limit_val_batches": 1, + "limit_test_batches": 1, + "accumulate_grad_batches": 1, + "num_nodes": num_nodes, + "devices": 8, + "max_steps": 50, + "val_check_interval": 50, + } + + strategy = run.Config( + nl.MegatronStrategy, + pipeline_dtype=torch.bfloat16, + ) + + return run.Config( + nl.Trainer, + **trainer_config, + strategy=strategy, + plugins=run.Config(nl.MegatronMixedPrecision, precision="bf16-mixed"), + callbacks=[run.Config(TimingCallback)], + ) + + +def get_optim() -> run.Config[OptimizerConfig]: + optim_params = { + "optimizer": "adam", + "lr": 1e-4, + "min_lr": 1e-5, + "use_distributed_optimizer": True, + "bf16": True, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "overlap_grad_reduce": True, + "overlap_param_gather": True, + "clip_grad": 1.0, + "adam_eps": 1e-5, + } + + optim_config = run.Config( + OptimizerConfig, + **optim_params, + ) + + sched = run.Config( + CosineAnnealingScheduler, + warmup_steps=10, + constant_steps=0, + min_lr=optim_config.min_lr, + ) + + return run.Config( + MegatronOptimizerModule, + config=optim_config, + lr_scheduler=sched, + ) + + +def get_logger() -> run.Config[nl.NeMoLogger]: + tb_logger = run.Config(TensorBoardLogger, save_dir="tb_logs") + + ckpt = run.Config( + nl.ModelCheckpoint, + monitor="reduced_train_loss", + save_last=False, + save_top_k=0, + ) + + return run.Config( + nl.NeMoLogger, + ckpt=ckpt, + tensorboard=tb_logger, + wandb=None, + dir="/", + ) + + +class TestBaseConfigs: + def test_gpt3_base_config(self): + # GPT3 7B + model_config = run.Config(GPTConfig126M) + runner = AutoConfigurator(model=model_config, num_nodes=8, path_to_logs="/", data_paths="/") + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(2048, 'auto') + trainer_config = get_trainer(8) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 0.126, f"0.126 is expected size for {model_config} but got {model_size}" + assert model_type == "gpt3", f"gpt3 is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" + + def test_llama_base_config(self): + # Llama3 8B + model_config = run.Config(Llama3Config8B) + runner = AutoConfigurator( + model=model_config, + num_nodes=16, + path_to_logs="/", + data_paths="/", + seq_length=8192, + global_batch_size=2048, + ) + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(8192, 2048) + trainer_config = get_trainer(16) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 8, f"8 is expected size for {model_config} but got {model_size}" + assert model_type == "llama", f"llama is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" + + def test_mistral_base_config(self): + # Mistral 7B + model_config = run.Config(MistralConfig7B) + runner = AutoConfigurator( + model=model_config, + num_nodes=16, + path_to_logs="/", + data_paths="/", + seq_length=32768, + global_batch_size=2048, + ) + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(32768, 2048) + trainer_config = get_trainer(16) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 7, f"7 is expected size for {model_config} but got {model_size}" + assert model_type == "mistral", f"mistral is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" + + def test_mixtral_base_config(self): + # Mixtral 8x3B + model_config = run.Config(MixtralConfig8x3B) + runner = AutoConfigurator( + model=model_config, + num_nodes=16, + path_to_logs="/", + data_paths="/", + seq_length=4096, + global_batch_size=2048, + ) + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(4096, 2048) + trainer_config = get_trainer(16) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 3, f"3 is expected size for {model_config} but got {model_size}" + assert model_type == "mixtral", f"mixtral is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" + + def test_gemma_base_config(self): + # Gemma 2B + model_config = run.Config(GemmaConfig2B) + runner = AutoConfigurator( + model=model_config, + num_nodes=8, + path_to_logs="/", + data_paths="/", + seq_length=4096, + global_batch_size=1024, + ) + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(4096, 1024) + trainer_config = get_trainer(8) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 2, f"2 is expected size for {model_config} but got {model_size}" + assert model_type == "gemma", f"gemma is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" + + def test_nemotron_base_config(self): + # Nemotron 22B + model_config = run.Config(Nemotron4Config22B) + runner = AutoConfigurator( + model=model_config, + num_nodes=64, + path_to_logs="/", + data_paths="/", + seq_length=4096, + global_batch_size=2048, + ) + base_config = BaseConfig(runner) + model_size = runner._get_model_size(model_config) + model_type = runner._get_model_type(model_config) + data_config = get_data(4096, 2048) + trainer_config = get_trainer(64) + optim_config = get_optim() + logger_config = get_logger() + + assert ( + base_config.model == model_config + ), f"{model_config} is expected class object but got {base_config.model}" + assert model_size == 22, f"22 is expected size for {model_config} but got {model_size}" + assert model_type == "nemotron", f"nemotron is expected model type for {model_config} but got {model_type}" + assert ( + base_config.data == data_config + ), f"f{data_config} is expected data config for {model_config} but got {base_config.data}" + assert ( + base_config.trainer == trainer_config + ), f"f{trainer_config} is expected trainer config for {model_config} but got {base_config.trainer}" + assert ( + base_config.optim == optim_config + ), f"f{optim_config} is expected trainer config for {model_config} but got {base_config.optim}" + assert ( + base_config.log == logger_config + ), f"f{logger_config} is expected trainer config for {model_config} but got {logger_config}" diff --git a/tests/collections/llm/auto_conf/test_generate_configs.py b/tests/collections/llm/auto_conf/test_generate_configs.py new file mode 100644 index 000000000000..efb3bcf9a0ba --- /dev/null +++ b/tests/collections/llm/auto_conf/test_generate_configs.py @@ -0,0 +1,307 @@ +import nemo_run as run + +from nemo.collections.llm import ( + GemmaConfig7B, + GPTConfig5B, + Llama3Config70B, + MistralConfig7B, + MixtralConfig8x22B, + Nemotron3Config8B, +) +from nemo.collections.llm.tools.auto_configurator import AutoConfigurator, generate_configs + + +def get_auto_configs(configs): + auto_configs = [] + for run_name, config in configs.items(): + auto_configs.append( + [ + config.trainer.strategy.tensor_model_parallel_size, + config.trainer.strategy.pipeline_model_parallel_size, + config.trainer.strategy.context_parallel_size, + config.trainer.strategy.expert_model_parallel_size, + config.data.micro_batch_size, + ] + ) + + return auto_configs + + +class TestGenerateConfgis: + def test_gpt_model(self): + # GPT3 126M + runner = AutoConfigurator( + model=run.Config(GPTConfig5B), + num_nodes=16, + seq_length=2048, + global_batch_size=2048, + tensor_parallel_sizes=[4], + pipeline_parallel_sizes=[2], + micro_batch_sizes=[1, 2], + context_parallel_sizes=[1], + expert_parallel_sizes=[1], + min_model_parallel_size=8, + max_model_parallel_size=8, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 2] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 2048 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 2, f"{len(configs)} configurations were generated but 2 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 4, + 2, + 1, + 1, + 1, + ], f"[4, 2, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 4, + 2, + 1, + 1, + 2, + ], f"[4, 2, 1, 1, 2] is expected configuration output but got {auto_configs[1]}." + + def test_llama_model(self): + # Llama3 70B + runner = AutoConfigurator( + model=run.Config(Llama3Config70B), + num_nodes=128, + seq_length=8192, + global_batch_size=2048, + tensor_parallel_sizes="auto", + pipeline_parallel_sizes="auto", + micro_batch_sizes=[1], + context_parallel_sizes=[1, 2, 4], + expert_parallel_sizes=[1], + min_model_parallel_size=16, + max_model_parallel_size=64, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 1, 1] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 8192 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 3, f"{len(configs)} configurations were generated but 3 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 4, + 1, + 4, + 1, + 1, + ], f"[4, 1, 4, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 8, + 1, + 2, + 1, + 1, + ], f"[8, 1, 2, 1, 1] is expected configuration output but got {auto_configs[1]}." + + assert auto_configs[2] == [ + 8, + 1, + 4, + 1, + 1, + ], f"[8, 1, 4, 1, 1] is expected configuration output but got {auto_configs[2]}." + + def test_mistral_model(self): + # Mistral 7B + runner = AutoConfigurator( + model=run.Config(MistralConfig7B), + num_nodes=16, + seq_length=4096, + global_batch_size=2048, + tensor_parallel_sizes=[4], + pipeline_parallel_sizes=[1, 2], + micro_batch_sizes=[1], + context_parallel_sizes=[1], + expert_parallel_sizes=[1], + min_model_parallel_size=4, + max_model_parallel_size=8, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 1] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 4096 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 2, f"{len(configs)} configurations were generated but 3 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 4, + 1, + 1, + 1, + 1, + ], f"[4, 1, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 4, + 2, + 1, + 1, + 1, + ], f"[4, 2, 1, 1, 1] is expected configuration output but got {auto_configs[1]}." + + def test_mixtral_model(self): + # Mixtral 8x22B + runner = AutoConfigurator( + model=run.Config(MixtralConfig8x22B), + num_nodes=16, + seq_length=4096, + global_batch_size=2048, + tensor_parallel_sizes=[4], + pipeline_parallel_sizes=[1], + micro_batch_sizes=[1], + context_parallel_sizes=[1], + expert_parallel_sizes=[1, 2], + min_model_parallel_size=4, + max_model_parallel_size=8, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 1] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 4096 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 2, f"{len(configs)} configurations were generated but 3 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 4, + 1, + 1, + 1, + 1, + ], f"[4, 1, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 4, + 1, + 1, + 2, + 1, + ], f"[4, 1, 1, 2, 1] is expected configuration output but got {auto_configs[1]}." + + def test_gemma_model(self): + # Gemma 7B + runner = AutoConfigurator( + model=run.Config(GemmaConfig7B), + num_nodes=16, + seq_length=8192, + global_batch_size=2048, + tensor_parallel_sizes=[2], + pipeline_parallel_sizes=[2], + micro_batch_sizes=[1, 2], + context_parallel_sizes=[1], + expert_parallel_sizes=[1], + min_model_parallel_size=4, + max_model_parallel_size=8, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 2] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 8192 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 2, f"{len(configs)} configurations were generated but 3 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 2, + 2, + 1, + 1, + 1, + ], f"[2, 2, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 2, + 2, + 1, + 1, + 2, + ], f"[2, 2, 1, 1, 2] is expected configuration output but got {auto_configs[1]}." + + def test_nemotron_model(self): + # Nemotron3 8B + runner = AutoConfigurator( + model=run.Config(Nemotron3Config8B), + num_nodes=16, + seq_length=4096, + global_batch_size=2048, + tensor_parallel_sizes=[1], + pipeline_parallel_sizes=[4], + micro_batch_sizes=[1, 2], + context_parallel_sizes=[1], + expert_parallel_sizes=[1], + min_model_parallel_size=4, + max_model_parallel_size=8, + data_paths="/", + path_to_logs="/", + ) + + _, configs = generate_configs(runner) + + mbs = [1, 2] + for run_name, config, mb in zip(configs.keys(), configs.values(), mbs): + assert config.data.micro_batch_size == mb + assert config.data.seq_length == 4096 + assert config.data.global_batch_size == 2048 + + assert len(configs) == 2, f"{len(configs)} configurations were generated but 3 were expected." + + auto_configs = get_auto_configs(configs) + assert auto_configs[0] == [ + 1, + 4, + 1, + 1, + 1, + ], f"[2, 2, 1, 1, 1] is expected configuration output but got {auto_configs[0]}." + + assert auto_configs[1] == [ + 1, + 4, + 1, + 1, + 2, + ], f"[2, 2, 1, 1, 2] is expected configuration output but got {auto_configs[1]}."