Skip to content

Commit

Permalink
GPT recipes to use full te spec (NVIDIA#11119)
Browse files Browse the repository at this point in the history
* use full te spec

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>

* more recipe fix

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>

* rm dropout_ffn, default num modes

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

---------

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
  • Loading branch information
3 people authored and HuiyingLi committed Nov 15, 2024
1 parent d6e5841 commit 895e256
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
16 changes: 13 additions & 3 deletions nemo/collections/llm/gpt/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ class GPTConfig126M(GPTConfig):
hidden_size: int = 768
ffn_hidden_size: int = 3072
num_attention_heads: int = 12
bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True


@dataclass
Expand All @@ -264,9 +267,9 @@ class GPTConfig5B(GPTConfig):
hidden_size: int = 4096
ffn_hidden_size: int = 16384
num_attention_heads: int = 32

bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True


@dataclass
Expand All @@ -276,6 +279,9 @@ class GPTConfig7B(GPTConfig):
hidden_size: int = 4096
ffn_hidden_size: int = 10880
num_attention_heads: int = 32
bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True


@dataclass
Expand All @@ -285,9 +291,9 @@ class GPTConfig20B(GPTConfig):
hidden_size: int = 6144
ffn_hidden_size: int = 24576
num_attention_heads: int = 48

bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True


@dataclass
Expand All @@ -297,6 +303,9 @@ class GPTConfig40B(GPTConfig):
hidden_size: int = 8192
ffn_hidden_size: int = 32768
num_attention_heads: int = 64
bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True


@dataclass
Expand All @@ -308,9 +317,10 @@ class GPTConfig175B(GPTConfig):
num_attention_heads: int = 96
hidden_dropout: float = 0.0
attention_dropout: float = 0.0
ffn_dropout: float = 0.0
bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True
use_transformer_engine_full_layer_spec: bool = True
layernorm_zero_centered_gamma: bool = True


class GPTModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
Expand Down
9 changes: 7 additions & 2 deletions nemo/collections/llm/recipes/llama3_70b.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
def finetune_recipe(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_nodes: int = None,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = 'lora',
seq_length: Optional[int] = None,
Expand Down Expand Up @@ -293,11 +293,16 @@ def finetune_recipe(
if seq_length is None:
seq_length = 4096 if packed_sequence else 2048

if num_nodes is None:
if peft_scheme is None or peft_scheme.lower() == 'none':
num_nodes = 4
elif peft_scheme.lower() == 'lora':
num_nodes = 1

recipe = default_finetune_recipe(
model(), "meta-llama/Meta-Llama-3-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence
)
if peft_scheme is None or peft_scheme.lower() == 'none':
assert num_nodes >= 4
recipe.trainer.strategy.tensor_model_parallel_size = 8
recipe.trainer.strategy.pipeline_model_parallel_size = 4
recipe.optim.config.lr = 5e-6
Expand Down

0 comments on commit 895e256

Please sign in to comment.