diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md
index 79715953c..1a56d9513 100644
--- a/examples/nemo_run/qat/README.md
+++ b/examples/nemo_run/qat/README.md
@@ -56,18 +56,21 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar
 
 You can run the example either locally or on a [Slurm cluster](ADVANCED.md).
 
-To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container.
+To run the example locally, first clone the `TensorRT-Model-Optimizer` repository, then mount it into a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09. After launching the Docker container, set your HuggingFace token so that datasets and models can be downloaded.
+
+Set up the repo:
 
 - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git`
-- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a`
 
-Example docker command:
+Run the Docker command below (modify the paths for your setup) and export your HuggingFace token:
 
 ```bash
-docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash
+docker run -v /home/user/:/home/user/ -v /home/user/TensorRT-Model-Optimizer/:/opt/TensorRT-Model-Optimizer/ --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash
+
+export HF_TOKEN=
 ```
 
-You will also need to set your Huggingface token with `export HF_TOKEN=`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
+You may also need to give the Docker container write access to the `examples/nemo_run` folder by running `chmod 777 nemo_run` so that logs can be written.
 
 ### Running the Flow Locally
@@ -92,7 +95,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p
 To perform QAD training, run:
 
 ```bash
-python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment
+python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4
 ```
 
 ## Supported models
diff --git a/examples/nemo_run/qat/nemo_qat_flow.py b/examples/nemo_run/qat/nemo_qat_flow.py
index df921bd19..af5a602ec 100644
--- a/examples/nemo_run/qat/nemo_qat_flow.py
+++ b/examples/nemo_run/qat/nemo_qat_flow.py
@@ -228,7 +228,7 @@ def main(args):
         global_batch_size=GBS,
         micro_batch_size=MBS,
         use_hf_tokenizer_chat_template=True,
-        num_workers=2,
+        num_workers=1,
         persistent_workers=True,
     )
     if args.distill:
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 3ef014d65..f987efcd6 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -26,7 +26,7 @@
 
 from modelopt.torch.opt.searcher import ForwardLoop
 from modelopt.torch.utils import print_rank_0
-from modelopt.torch.utils.distributed import ParallelState
+from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
 from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
 
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,6 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         return
 
     def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel group."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
                 sync_quantizer_amax_across_dp(_q, parallel_state)
@@ -94,7 +95,6 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state):
 
         for child in module.children():
             if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                sync_quantizer_amax_across_dp(child, module.parallel_state)
-
    # TP sync:
    # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
@@ -114,6 +114,7 @@ def sync_quantizer_amax_across_tp(
        axes_for_sync: list,
        parallel_state: ParallelState,
    ):
+        # Syncing amax across TP for sequential quantizer
        if isinstance(quantizer, SequentialQuantizer):
            for _q in quantizer:
                sync_quantizer_amax_across_tp(
@@ -598,19 +599,37 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)
 
+    def sync_act_scale_across_dp(module, data_parallel_group):
+        """Sync activation scale across Data Parallel (DP)."""
+        if data_parallel_group.is_initialized():
+            dist.all_reduce(
+                module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
+            )
+
     for name, module in model.named_modules():
         if (
             is_quantized_linear(module)
             and hasattr(module, "awq_lite")
             and module.awq_lite.num_cache_steps > 0
         ):
+            # Hack: MoEs forward all tokens through all experts if _if_calib is True
+            module._if_calib = True
             module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps
-            if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
+
+            has_nan_local = torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any(
                 torch.isnan(module.awq_lite.weight_scale)
-            ):
+            )
+            has_nan = DistributedProcessGroup.get_dist_syncd_obj(
+                has_nan_local, module.parallel_state.data_parallel_group, lambda objs: any(objs)
+            )
+
+            if has_nan:
                 module.awq_lite.is_enabled = False
-                # Hack: MoEs forward all tokens through all experts if _if_calib is True
-                module._if_calib = True
+            else:
+                sync_act_scale_across_dp(
+                    module,
+                    module.parallel_state.data_parallel_group,
+                )
 
     AWQLiteHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 1cf9416ec..85784d2fe 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -15,6 +15,7 @@
 
 """Support quantization for megatron linear layers."""
 
+import logging
 import warnings
 from typing import Any
 
@@ -22,6 +23,7 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -38,6 +40,8 @@
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
+logger = logging.getLogger(__name__)
+
 __all__ = []
 
@@ -217,8 +221,14 @@ class _MegatronParallelLinear(_ParallelLinear):
     ]
 
     def _setup(self):
+        data_parallel_group = None
+        try:
+            data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+        except AssertionError:
+            logger.warning("Context parallel group is not initialized, using data parallel group")
+            data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
-            getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(),
+            data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
         )
         super()._setup()
diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py
index 76965dc0e..f11a736db 100644
--- a/modelopt/torch/utils/distributed.py
+++ b/modelopt/torch/utils/distributed.py
@@ -247,7 +247,10 @@ def __init__(
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
 
     def __repr__(self) -> str:
-        return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}"
+        return (
+            f"data_parallel_group: {self.data_parallel_group}, "
+            f"tensor_parallel_group: {self.tensor_parallel_group}"
+        )
 
 
 def get_group(ranks: list[int]):
diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py
index 99d5715ee..ca6b9bff7 100644
--- a/tests/_test_utils/torch_dist/plugins/megatron_common.py
+++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -84,9 +84,12 @@ class MegatronModel(MegatronModule):
-    def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
+    def __init__(
+        self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False, tp_group=None
+    ):
         config = TransformerConfig(
             tensor_model_parallel_size=tp_size,
+            context_parallel_size=cp_size,
             pipeline_model_parallel_size=1,
             normalization="LayerNorm",
             # Unused parameters below are set to avoid ZeroDivisionError in __post_init__
@@ -104,6 +107,7 @@ def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
             gather_output=False,
             skip_bias_add=True,
             is_expert=False,
+            tp_group=tp_group,
         )
         self.activation = nn.ReLU()
         if use_te_norm:
@@ -118,6 +122,7 @@ def __init__(self, tp_size: int = 1, use_te_norm: bool = False):
             skip_bias_add=True,
             input_is_parallel=True,
             is_expert=False,
+            tp_group=tp_group,
         )
 
     def forward(self, x):
@@ -127,7 +132,11 @@ def forward(self, x):
             x = x[0]
         return x
 
-    def get_dummy_input(self) -> torch.Tensor:
+    def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
+        if seed is not None:
+            gen = torch.Generator()
+            gen.manual_seed(seed)
+            return torch.randn(1, 4, 32, generator=gen)
         return torch.randn(1, 4, 32)
 
@@ -390,13 +399,20 @@ def run_mcore_inference_with_dummy_input(
 
 def initialize_for_megatron(
-    tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
+    tensor_model_parallel_size=1,
+    pipeline_model_parallel_size=1,
+    seed=1234,
+    context_parallel_size=1,
 ):
     """Initialize Megatron model parallelism.
 
     NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`.
     """
-    initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
+    initialize_model_parallel(
+        tensor_model_parallel_size,
+        pipeline_model_parallel_size,
+        context_parallel_size=context_parallel_size,
+    )
     model_parallel_cuda_manual_seed(seed)
 
diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py
index 505eac2b6..8647aaa00 100644
--- a/tests/_test_utils/torch_quantization/quantize_common.py
+++ b/tests/_test_utils/torch_quantization/quantize_common.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+from unittest.mock import patch
 
 import pytest
 import torch
@@ -22,7 +23,9 @@
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+import modelopt.torch.quantization.model_calib as model_calib_module  # needed for patching awq_lite
 from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer
 from modelopt.torch.quantization.utils import is_quantized_linear
 from modelopt.torch.utils import torch_to
 
@@ -116,38 +119,95 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)
 
 
-def tensor_parallel_test_helper(model, config, tp_group, dp_group):
-    # The input to fist layer, the column parallel should be the same across all tp ranks
-    calib_data = model.get_dummy_input().cuda()
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[]):
+    quantizer_attr = getattr(quantizer, attr).clone()
+    for group in groups:
+        if group is not None:
+            dist.all_reduce(quantizer_attr, op=op, group=group)
+    assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
 
-    def forward_loop(model):
-        model(calib_data)
 
-    model = mtq.quantize(model, config, forward_loop)
+original_awq_lite = model_calib_module.awq_lite
 
-    # Sanity check
-    forward_loop(model)
 
-    if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]:
-        # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks
-        activation_amax = model.fc2.input_quantizer.amax.clone()
-        dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax)
 
+def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True, **kwargs):
+    """Mock awq_lite to always run with debug=True for testing."""
+    return original_awq_lite(model, forward_loop, alpha_step, debug=True, **kwargs)
 
-    # Lets check the row parallel weight amax; it should be the same across all tp ranks
-    weight_amax = model.fc2.weight_quantizer.amax.clone()
-    dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group)
-    assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax)
 
-    if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
-        # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks
-        input_quantizer = model.fc1.input_quantizer
-        pre_quant_scale = input_quantizer.pre_quant_scale.clone()
-        dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group)
-        assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale)
 
+@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
+def data_tensor_context_parallel_test_helper(
+    model, config, mock_awq_lite, dp_group=None, tp_group=None, test_pre_quant_scale=True
+):
+    # Calib data should be different across each DP rank
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
+
+    if tp_group is not None:
+        # The input to the first layer (column parallel) should be the same across all tp ranks
+        dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
 
-    dist.destroy_process_group()
+    def forward_loop(model):
+        model(calib_data)
+
+    model = mtq.quantize(model, config, forward_loop)
+
+    # Input quantizer amax
+    if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
+        _distributed_attr_check(
+            model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+        _distributed_attr_check(
+            model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Per-tensor quantization (FP8/NVFP4) expects the same amax across row and column parallel ranks
+    # Channel-wise (INT8) only expects the same amax across row parallel ranks
+    # Block-wise quantization does not expect the same amax across row and column parallel ranks
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    if config in [
+        mtq.FP8_DEFAULT_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+        mtq.INT8_DEFAULT_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+    ]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _distributed_attr_check(
+                    quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+                )
+        else:
+            _distributed_attr_check(
+                model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+            )
+
+    # Check the column parallel pre_quant_scale; it should be the same across all tp ranks.
+    # It differs across DP/CP ranks since the input is different on each rank.
+    if (
+        test_pre_quant_scale
+        and tp_group
+        and config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]
+    ):
+        input_quantizer = model.fc1.input_quantizer
+        _distributed_attr_check(
+            input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, groups=[dp_group, tp_group]
+        )
+
+    # Check act scale
+    if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
+        _distributed_attr_check(
+            model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group]
+        )
 
 
 def auto_quantize_helper(model):
diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py
index 208fb2287..f32065bce 100644
--- a/tests/gpu/torch/conftest.py
+++ b/tests/gpu/torch/conftest.py
@@ -34,6 +34,12 @@ def need_2_gpus():
         pytest.skip("Need at least 2 GPUs to run this test")
 
 
+@pytest.fixture
+def need_8_gpus():
+    if torch.cuda.device_count() < 8:
+        pytest.skip("Need at least 8 GPUs to run this test")
+
+
 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
     orig_dtype = torch.get_default_dtype()
diff --git a/tests/gpu/torch/quantization/plugins/test_apex.py b/tests/gpu/torch/quantization/plugins/test_apex.py
index 6fd9d501f..af20e19f5 100644
--- a/tests/gpu/torch/quantization/plugins/test_apex.py
+++ b/tests/gpu/torch/quantization/plugins/test_apex.py
@@ -23,7 +23,7 @@
 from _test_utils.torch_quantization.models import RegularQuantModelForTP
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
-    tensor_parallel_test_helper,
+    data_tensor_context_parallel_test_helper,
 )
 
 import modelopt.torch.quantization as mtq
@@ -58,7 +58,11 @@ def forward(self, x):
             x = x[0]
         return x
 
-    def get_dummy_input(self):
+    def get_dummy_input(self, seed: int | None = None):
+        if seed is not None:
+            gen = torch.Generator()
+            gen.manual_seed(seed)
+            return torch.randn(1, 4, 32, generator=gen)
         return torch.randn(1, 4, 32)
 
@@ -106,8 +110,11 @@ def _test_tensor_parallel_helper(config, rank, size):
     model_parallel_cuda_manual_seed(SEED)
     model = ApexModel().cuda()
 
-    tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
+    data_tensor_context_parallel_test_helper(
+        model,
+        config,
+        tp_group=get_tensor_model_parallel_group(),
+        dp_group=get_data_parallel_group(),
     )
 
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py
index 226403ea2..b63462ef3 100644
--- a/tests/gpu/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -31,7 +31,7 @@
 from _test_utils.torch_quantization.quant_utils import get_model_size
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
-    tensor_parallel_test_helper,
+    data_tensor_context_parallel_test_helper,
 )
 
 skip_if_no_megatron()
@@ -90,12 +90,51 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     destroy_model_parallel()
 
 
-def _test_tensor_parallel_helper(config, rank, size):
-    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
-    model = MegatronModel(size).cuda()
+# Unified parallelism test helper
+def _test_parallelism_helper(
+    config,
+    rank,
+    size,
+    tensor_model_parallel_size=1,
+    context_parallel_size=1,
+    use_rank_in_seed=False,
+    test_pre_quant_scale=True,
+):
+    """
+    Unified helper for testing different parallelism configurations.
+
+    Args:
+        config: Quantization config to test
+        rank: Current rank in distributed setup
+        size: Total number of processes
+        tensor_model_parallel_size: Size of tensor model parallel group (default: 1)
+        context_parallel_size: Size of context parallel group (default: 1)
+        use_rank_in_seed: Whether to add rank to seed for different data across ranks (default: False)
+        test_pre_quant_scale: Whether to check pre_quant_scale consistency (default: True)
+    """
+    seed = SEED + rank if use_rank_in_seed else SEED
+    initialize_for_megatron(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        context_parallel_size=context_parallel_size,
+        seed=seed,
+    )
 
-    tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group(), get_data_parallel_group()
+    # Determine if we need tp_group and dp_group
+    tp_group = get_tensor_model_parallel_group() if tensor_model_parallel_size > 1 else None
+    dp_group = get_data_parallel_group(with_context_parallel=True)
+
+    # Create model with appropriate parallelism settings
+    model = MegatronModel(
+        tp_size=tensor_model_parallel_size,
+        cp_size=context_parallel_size,
+        tp_group=tp_group,
+    ).cuda()
+
+    # Call the test helper with appropriate groups
+    data_tensor_context_parallel_test_helper(
+        model,
+        config,
+        dp_group=dp_group,
+        tp_group=tp_group,
+        test_pre_quant_scale=test_pre_quant_scale,
     )
 
 
@@ -113,7 +152,78 @@ def _test_tensor_parallel_helper(config, rank, size):
 )
 def test_tensor_parallel(need_2_gpus, config):
     spawn_multiprocess_job(
-        size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
+        size=2,
+        job=partial(_test_parallelism_helper, config, tensor_model_parallel_size=2),
+        backend="nccl",
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2,
+        job=partial(_test_parallelism_helper, config, use_rank_in_seed=True),
+        backend="nccl",
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_context_parallel(need_2_gpus, config):
+    spawn_multiprocess_job(
+        size=2,
+        job=partial(
+            _test_parallelism_helper, config, context_parallel_size=2, use_rank_in_seed=True
+        ),
+        backend="nccl",
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.INT8_DEFAULT_CFG,
+        mtq.FP8_DEFAULT_CFG,
+        mtq.W4A8_AWQ_BETA_CFG,
+        mtq.INT8_SMOOTHQUANT_CFG,
+        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_DEFAULT_CFG,
+    ],
+)
+def test_data_tensor_context_parallel(need_8_gpus, config):
+    spawn_multiprocess_job(
+        size=8,
+        job=partial(
+            _test_parallelism_helper,
+            config,
+            tensor_model_parallel_size=2,
+            context_parallel_size=2,
+            use_rank_in_seed=True,
+            test_pre_quant_scale=False,
+        ),
+        backend="nccl",
     )
 
 
@@ -126,7 +236,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic
         tensor_model_parallel_size=tp_size,
         num_layers=4,
         ffn_hidden_size=None,
-        num_attention_heads=4,
+        num_attention_heads=8,
         activation_func="squared_relu",
         transformer_impl="local",
         hidden_size=hidden_size,
@@ -138,7 +248,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic
         tensor_model_parallel_size=tp_size,
         num_layers=4,
         ffn_hidden_size=None,
-        num_attention_heads=4,
+        num_attention_heads=8,
         activation_func="squared_relu",
         transformer_impl="local",
         hidden_size=hidden_size,
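
For readers following the `model_calib.py` change above, here is a minimal, self-contained sketch of the data-parallel synchronization pattern the patch introduces for AWQ-lite activation scales. It is not the patch's implementation: it uses plain `torch.distributed` collectives instead of ModelOpt's `DistributedProcessGroup`/`ParallelState` wrappers, and the names `act_scale` and `dp_group` are illustrative stand-ins, assuming an already-initialized NCCL process group (required for `ReduceOp.AVG`).

```python
# Minimal standalone sketch (illustrative names, plain torch.distributed).
import torch
import torch.distributed as dist


def sync_act_scale_across_dp(act_scale: torch.Tensor, dp_group) -> torch.Tensor | None:
    """Average act_scale across DP ranks, or return None if any rank saw NaNs."""
    # 1. Check locally whether calibration produced NaNs.
    has_nan_local = bool(torch.isnan(act_scale).any())

    # 2. Agree on the NaN status across the DP group so every rank makes the
    #    same enable/disable decision and the replicas do not diverge.
    flags = [None] * dist.get_world_size(group=dp_group)
    dist.all_gather_object(flags, has_nan_local, group=dp_group)
    if any(flags):
        return None  # caller disables AWQ-lite for this module on every rank

    # 3. Otherwise average the cached activation scale across DP ranks so each
    #    replica searches AWQ parameters with identical statistics.
    dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=dp_group)
    return act_scale
```

This mirrors the flow in `model_calib.py` above: if any DP rank observed NaNs in `act_scale` or `weight_scale`, AWQ-lite is disabled on every rank to keep replicas consistent; otherwise the cached activation scale is averaged so all ranks search AWQ parameters over the same statistics.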