Commit f8002fb: fix comment
li126com committed Jul 9, 2024 · 1 parent 00c2481
Showing 6 changed files with 56 additions and 59 deletions.
1 change: 1 addition & 0 deletions internlm/core/context/parallel_context.py
@@ -159,6 +159,7 @@ def __init__(self):
self.virtual_pipeline_parallel_rank = None
self._expert_parallel_group_names = []
self.is_evaluating = False
self.recompute_forward_no_comm = False

@property
def config(self):
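The new runtime flag complements the existing `checkpoint_tp_no_comm` model option: the config option says the optimization is allowed at all, while `gpc.recompute_forward_no_comm` is toggled by the activation-checkpoint backward pass (last file in this commit) so that decoder blocks know the current recomputed forward should skip its tensor-parallel communication. Below is a hedged sketch of how the option might appear in a training config; only `checkpoint_tp_no_comm` and the `parallel.tensor` size/mode keys come from this diff, the other fields and the mode name are illustrative.

```python
# Hypothetical config excerpt (field values illustrative): the runtime flag above
# is only set at run time when this model-level option survives the sanity check
# in internlm/initialize/launch.py below.
model = dict(
    checkpoint=True,                  # activation checkpointing enabled
    checkpoint_tp_no_comm=True,       # allow skipping TP comm in the recomputed forward
)

parallel = dict(
    tensor=dict(size=4, mode="msp"),  # a non-ISP tensor/sequence-parallel mode (name assumed)
)
```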
4 changes: 3 additions & 1 deletion internlm/initialize/launch.py
@@ -415,7 +415,9 @@ def args_sanity_check():
), "only support interleaved pipeline scheduler with overlap"

# when not use tp or sp, checkpoint_tp_no_comm should always be False
if gpc.config.parallel["tensor"]["size"] <= 1 and getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
if (gpc.config.parallel["tensor"]["mode"] == "isp" or gpc.config.parallel["tensor"]["size"] <= 1) and getattr(
gpc.config.model, "checkpoint_tp_no_comm", False
):
gpc.config.model.checkpoint_tp_no_comm = False

# monitoring default config
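The sanity check now clears `checkpoint_tp_no_comm` not only when tensor parallelism is effectively off (`size <= 1`) but also when the tensor-parallel mode is `isp`; the rest of the diff suggests the no-comm recompute path targets the Megatron-style TP/SP communication and is simply never taken under ISP. A minimal standalone sketch of the same condition over a plain dict config (the real check reads and mutates `gpc.config`; the function name here is made up):

```python
# Standalone sketch of the updated check; `config` is a plain dict, whereas the
# real code operates on gpc.config in place.
def sanity_check_checkpoint_tp_no_comm(config: dict) -> dict:
    tensor_cfg = config["parallel"]["tensor"]
    no_tp_or_sp = tensor_cfg["size"] <= 1       # no tensor/sequence parallelism at all
    is_isp = tensor_cfg["mode"] == "isp"        # ISP: the optimization does not apply
    if (is_isp or no_tp_or_sp) and config["model"].get("checkpoint_tp_no_comm", False):
        config["model"]["checkpoint_tp_no_comm"] = False
    return config

# ISP mode forces the flag off even though the user requested it.
cfg = {"parallel": {"tensor": {"size": 8, "mode": "isp"}},
       "model": {"checkpoint_tp_no_comm": True}}
assert sanity_check_checkpoint_tp_no_comm(cfg)["model"]["checkpoint_tp_no_comm"] is False
```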
29 changes: 6 additions & 23 deletions internlm/model/modeling_internlm.py
@@ -10,7 +10,6 @@
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_output_attr_to_module
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import new_linear
@@ -22,10 +21,11 @@
convert_attn_kwargs_to_args,
internlm1_mha_pre_load_convert,
internlm1_mha_save_convert,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

@@ -181,8 +181,6 @@ def _forward(self, hidden_states, *args, **kwargs):
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
indexes: the length of index is same as hidden states, which stand for the current position
"""
no_communication = args[4] if len(args) > 4 else False
args = args[:4]

def _dropout_and_norm_attn(_hidden_states):
_dropped = self.dropout1(_hidden_states)
@@ -215,28 +213,13 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
if self.residual_in_fp32:
residual = residual.to(torch.float32)

no_communication = gpc.recompute_forward_no_comm

hidden_states = self.mlp(hidden_states, no_communication=no_communication)

# pad residual
if no_communication and is_using_sequence_parallel() and not is_using_isp():
requires_grad = residual.requires_grad
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)

residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad)
if no_communication and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual

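The block-level change is the same in this file and in `modeling_internlm2.py` below: `no_communication` is now read from `gpc.recompute_forward_no_comm` instead of being passed as an extra positional argument (so the `args[:4]` trimming disappears), the inline zero-padding is replaced by the shared `padding_residual` helper added in `internlm/model/utils.py` below, and the `not is_using_isp()` guard is dropped because the updated launch check already forces `checkpoint_tp_no_comm` off under ISP, so the flag should never be set in that mode. A short annotated mirror of the new tail of `_forward`, as a sketch only; the names come from the diff, and `self.feed_forward` plays the role of `self.mlp` in `modeling_internlm2.py`.

```python
def _forward_tail_sketch(self, hidden_states, residual):
    """Annotated mirror of the new tail of _forward; not the actual method, just
    the relevant lines from the diff with explanatory comments."""
    no_communication = gpc.recompute_forward_no_comm    # set only during checkpoint recompute

    # Ask the MLP to skip its final tensor-parallel communication for this pass.
    hidden_states = self.mlp(hidden_states, no_communication=no_communication)

    # Under sequence parallelism the residual then covers only this rank's
    # sequence chunk, so zero-pad it to the gathered length before the add.
    if no_communication and is_using_sequence_parallel():
        residual = padding_residual(residual)

    return hidden_states + residual
```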
32 changes: 7 additions & 25 deletions internlm/model/modeling_internlm2.py
@@ -7,7 +7,6 @@

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.initialize.initialize_tensor import (
normal_,
scaled_init_method_normal,
@@ -22,10 +21,11 @@
from internlm.model.utils import (
convert_attn_args_to_kwargs,
convert_attn_kwargs_to_args,
padding_residual,
)
from internlm.solver.activation_checkpoint import activation_checkpoint
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
from internlm.utils.parallel import is_using_sequence_parallel

logger = get_logger(__file__)

@@ -218,8 +218,6 @@ def _forward(self, hidden_states, residual, *args, **kwargs):
cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
indexes: the length of index is same as hidden states, which stand for the current position
"""
no_communication = args[4] if len(args) > 4 else False
args = args[:4]
if self.prenorm:

def _dropout_and_norm_attn(_residual, _hidden_states):
@@ -259,30 +257,14 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):

if self.residual_in_fp32:
residual = residual.to(torch.float32)

no_communication = gpc.recompute_forward_no_comm

hidden_states = self.feed_forward(hidden_states, no_communication=no_communication)

# pad residual
if no_communication and is_using_sequence_parallel() and not is_using_isp():
requires_grad = residual.requires_grad
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)

residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(
requires_grad
)
if no_communication and is_using_sequence_parallel():
residual = padding_residual(residual)

return hidden_states + residual
else:
27 changes: 27 additions & 0 deletions internlm/model/utils.py
@@ -1,5 +1,10 @@
from typing import Any, Dict, List

import torch

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.parallel.comm.tensor import _GATHER_DIM
from internlm.model.modules.mha import MHA


@@ -51,3 +56,25 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
kwargs["max_seqlen"] = args[3]

return kwargs


def padding_residual(residual):
requires_grad = residual.requires_grad
pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM]
pad_after = (
gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1
) * residual.shape[_GATHER_DIM]

pad_before_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
pad_after_tensor = torch.zeros(
(*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]),
dtype=residual.dtype,
device=residual.device,
)
residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad)

return residual
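`padding_residual` zero-pads the local residual chunk along the gather dimension so that it sits in this rank's slot of the full-sequence tensor; the pad sizes are derived from the tensor-parallel rank and world size. A self-contained toy version with the parallel state passed in explicitly (the real helper reads it from `gpc`), showing the shape arithmetic:

```python
import torch

def padding_residual_toy(residual: torch.Tensor, rank: int, world_size: int, gather_dim: int = 1):
    """Toy stand-in for padding_residual with the parallel state passed explicitly;
    the real helper reads the rank and world size from gpc."""
    chunk = residual.shape[gather_dim]
    pad_before = rank * chunk
    pad_after = (world_size - rank - 1) * chunk

    def zeros(length):
        shape = list(residual.shape)
        shape[gather_dim] = length
        return torch.zeros(*shape, dtype=residual.dtype, device=residual.device)

    return torch.cat([zeros(pad_before), residual, zeros(pad_after)], dim=gather_dim)

# Rank 1 of 4 with a 3-token local chunk lands in positions 3..5 of the 12-token result.
local = torch.ones(2, 3, 8)                                  # (batch, local_seq, hidden)
full = padding_residual_toy(local, rank=1, world_size=4)
assert full.shape == (2, 12, 8)
assert full[:, 3:6].eq(1).all() and full[:, :3].eq(0).all()
```

One detail worth noting: the real helper builds its pad tensors along `_GATHER_DIM` but concatenates with a hard-coded `dim=1`, so it implicitly assumes the gather dimension is 1; the toy version above keeps the dimension symbolic.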
22 changes: 12 additions & 10 deletions internlm/solver/activation_checkpoint.py
@@ -17,7 +17,7 @@
sync_states,
)
from internlm.core.parallel.comm.tensor import _GATHER_DIM, all_gather_raw
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
from internlm.utils.parallel import is_using_sequence_parallel

from ..utils.common import get_current_device

@@ -128,19 +128,18 @@ def backward(ctx, *args):
inputs[idx] = tensors[i]

# no_communication
no_communication = False
if getattr(gpc.config.model, "checkpoint_tp_no_comm", False):
no_communication = True
inputs.append(True)
no_communication = getattr(gpc.config.model, "checkpoint_tp_no_comm", False)

detached_inputs = detach_variable(tuple(inputs))

handle = None
if no_communication and is_using_sequence_parallel() and not is_using_isp():
grad_output = args[0]
grad_output, handle = all_gather_raw(
grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM
)
if no_communication:
gpc.recompute_forward_no_comm = True
if is_using_sequence_parallel():
grad_output = args[0]
grad_output, handle = all_gather_raw(
grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM
)

if ctx.had_autocast_in_fwd:
with torch.enable_grad(), internlm_accelerator.amp.autocast():
@@ -149,6 +148,9 @@
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)

if gpc.recompute_forward_no_comm:
gpc.recompute_forward_no_comm = False

if handle:
handle.wait()
args = list(args)
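The backward pass now drives the whole mechanism: it reads `checkpoint_tp_no_comm` from the config instead of appending an extra input, sets `gpc.recompute_forward_no_comm` for the duration of the recomputed forward and clears it afterwards, and, under sequence parallelism, overlaps an asynchronous all-gather of `grad_output` with that recompute. A minimal sketch of the overlap pattern, using plain `torch.distributed` in place of the repo's `all_gather_raw` and assuming an initialized tensor-parallel process group; the function name and arguments are illustrative.

```python
import torch
import torch.distributed as dist

from internlm.core.context.parallel_context import global_context as gpc

def checkpoint_backward_overlap_sketch(grad_output, detached_inputs, run_function,
                                       tp_group, gather_dim=1):
    """Sketch of the comm/compute overlap in backward(): start the grad_output
    all-gather asynchronously, recompute the forward with TP communication
    disabled via the global flag, then wait for the gather only when needed."""
    gpc.recompute_forward_no_comm = True                    # recomputed forward skips its comm

    world_size = dist.get_world_size(tp_group)
    gathered = [torch.empty_like(grad_output) for _ in range(world_size)]
    handle = dist.all_gather(gathered, grad_output.contiguous(), group=tp_group, async_op=True)

    with torch.enable_grad():
        outputs = run_function(*detached_inputs)            # recompute overlaps the all-gather

    gpc.recompute_forward_no_comm = False                   # restore the default right away

    handle.wait()                                           # gather finished during the recompute
    full_grad_output = torch.cat(gathered, dim=gather_dim)
    return outputs, full_grad_output
```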
