diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 7fc75a6e2e61..07ebf71af0fb 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -797,8 +797,6 @@ def __init__(
         params_dtype: torch.dtype = torch.float32,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = False,
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         sequence_parallel: bool = False,
@@ -828,8 +826,6 @@ def __init__(
             params_dtype=params_dtype,
             get_rng_state_tracker=get_rng_state_tracker,
             fuse_wgrad_accumulation=fuse_wgrad_accumulation,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32,
             seq_length=seq_length,
             micro_batch_size=micro_batch_size,
             sequence_parallel=sequence_parallel,
@@ -1082,7 +1078,6 @@ def build_layer(layer_number):
                     params_dtype=torch.float32,  # dtype params are initialized in
                     get_rng_state_tracker=tensor_parallel.random.get_cuda_rng_tracker,
                     fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
-                    apply_query_key_layer_scaling=apply_query_key_layer_scaling,
                     seq_length=None,  # used for jit warmup
                     micro_batch_size=None,  # used for jit warmup
                     sequence_parallel=config.sequence_parallel,
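The change stops forwarding apply_query_key_layer_scaling and attention_softmax_in_fp32 to Transformer Engine's TransformerLayer. As a minimal sketch, assuming one wanted a wrapper that tolerates Transformer Engine versions that do or do not accept these constructor arguments, keyword arguments could be filtered against the installed signature before the call. The helper filter_supported_kwargs and the stand-in class below are hypothetical and for illustration only; they are not NeMo or Transformer Engine code.

# Minimal compatibility sketch (assumption, not part of this change): drop any
# keyword arguments that the target layer's __init__ signature does not accept,
# such as apply_query_key_layer_scaling / attention_softmax_in_fp32 on newer
# Transformer Engine releases.
import inspect


def filter_supported_kwargs(layer_cls, kwargs):
    """Keep only kwargs that appear in layer_cls.__init__'s signature."""
    accepted = inspect.signature(layer_cls.__init__).parameters
    return {name: value for name, value in kwargs.items() if name in accepted}


class StandInTransformerLayer:  # stands in for transformer_engine.pytorch.TransformerLayer
    def __init__(self, hidden_size, num_attention_heads, seq_length=None):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads


requested = {
    "hidden_size": 1024,
    "num_attention_heads": 16,
    "apply_query_key_layer_scaling": False,  # silently dropped if unsupported
    "attention_softmax_in_fp32": False,  # silently dropped if unsupported
}
layer = StandInTransformerLayer(**filter_supported_kwargs(StandInTransformerLayer, requested))

This PR instead takes the simpler route of removing the arguments outright from the wrapper signature and both call sites, so no filtering shim is needed.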