From 37c2edc83350abff38ca7d2f18c1cc0df65a6e54 Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Fri, 21 Jul 2023 22:16:13 +0800 Subject: [PATCH 01/52] remove inappropriate import in __init__.py --- bmtrain/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 7c7d6c2c..1695167b 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -12,9 +12,5 @@ from . import debug from .store import save, load -from . import benchmark -from . import optim -from . import inspect -from . import lr_scheduler from . import loss from . import distributed From 15de53cc6f60b38b90ef2cdbb07b2147257a55c1 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:21:10 -0700 Subject: [PATCH 02/52] FIX tensor slice in gather --- bmtrain/synchronize.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 80d39057..7c3915ba 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -35,7 +35,10 @@ def sum_loss(loss : torch.Tensor): def gather_result(result: torch.Tensor): warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", DeprecationWarning) - + if result.storage().size() != result.numel(): + # Create a clone of the original tensor if it's a slice + result = result.clone() + output_cuda = True if not result.is_cuda: result = result.cuda() From 65535049886d71d3b1f1bf41cd02364e108d43a6 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:22:15 -0700 Subject: [PATCH 03/52] Update synchronize.py --- bmtrain/synchronize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 7c3915ba..356f32c0 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -38,7 +38,6 @@ def gather_result(result: torch.Tensor): if result.storage().size() != result.numel(): # Create a clone of the original tensor if it's a slice result = result.clone() - output_cuda = True if not result.is_cuda: result = result.cuda() From 0c6da0b1186a5d6fb97ce534d32250ae53be1e52 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:23:43 -0700 Subject: [PATCH 04/52] FIX Error : tensor slice in gather() --- bmtrain/synchronize.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 356f32c0..7c3915ba 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -38,6 +38,7 @@ def gather_result(result: torch.Tensor): if result.storage().size() != result.numel(): # Create a clone of the original tensor if it's a slice result = result.clone() + output_cuda = True if not result.is_cuda: result = result.cuda() From 7d4a5703e4483fc95177c1a95704d5a765de4350 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 03:47:36 -0700 Subject: [PATCH 05/52] Update synchronize.py --- bmtrain/synchronize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/synchronize.py b/bmtrain/synchronize.py index 7c3915ba..d562cc21 100644 --- a/bmtrain/synchronize.py +++ b/bmtrain/synchronize.py @@ -35,7 +35,7 @@ def sum_loss(loss : torch.Tensor): def gather_result(result: torch.Tensor): warnings.warn("bmtrain.gather_result is deprecated and will be removed in later version. 
Use bmtrain.distributed.all_gather instead.", DeprecationWarning) - if result.storage().size() != result.numel(): + if result.storage_offset() != 0 or result.storage().size() != result.numel(): # Create a clone of the original tensor if it's a slice result = result.clone() From 245dcaad5b7bce25c18de0181c99f0eb36e2e9c3 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 03:48:49 -0700 Subject: [PATCH 06/52] Create test_synchronize.py --- tests/test_synchronize.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_synchronize.py diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py new file mode 100644 index 00000000..73169c1e --- /dev/null +++ b/tests/test_synchronize.py @@ -0,0 +1,28 @@ +import torch +import bmtrain as bmt + +from bmtrain.global_var import config +from bmtrain import nccl, distributed +from bmtrain.synchronize import gather_result + +def test_main(): + tensor = torch.rand(5, 5) + result = bmt.gather_result(tensor) + + tensor_slice_0 = tensor[:1, :1] + result_slice_0 = bmt.gather_result(tensor_slice_0) + + tensor_slice_1 = tensor[:2, :2] + result_slice_1 = bmt.gather_result(tensor_slice_1) + + tensor_slice_2 = tensor[:3, :3] + result_slice_2 = bmt.gather_result(tensor_slice_2) + + tensor_slice_3 = tensor[:4, :4] + result_slice_3 = bmt.gather_result(tensor_slice_3) + + print(result, result_slice_1, result_slice_2, result_slice_3, sep='\n') + +if __name__ == '__main__': + bmt.init_distributed(pipe_size=1) + test_main() From 7484cfba3658c57f0b2e37fa76bc883ed912876a Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 03:49:48 -0700 Subject: [PATCH 07/52] Update test_synchronize.py --- tests/test_synchronize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py index 73169c1e..f92ff5b9 100644 --- a/tests/test_synchronize.py +++ b/tests/test_synchronize.py @@ -21,7 +21,7 @@ def test_main(): tensor_slice_3 = tensor[:4, :4] result_slice_3 = bmt.gather_result(tensor_slice_3) - print(result, result_slice_1, result_slice_2, result_slice_3, sep='\n') + print(result, result_slice_0, result_slice_1, result_slice_2, result_slice_3, sep='\n') if __name__ == '__main__': bmt.init_distributed(pipe_size=1) From e0f8c30c19e0bd5b30a17ab66525c46ec8d71d9e Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Fri, 18 Aug 2023 03:51:06 -0700 Subject: [PATCH 08/52] add test_synchronize --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index b614d3eb..ec23cf89 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -6,6 +6,7 @@ ("different_output_shape", 1), ("load_ckpt", 1), ("init_parameters", 1), + ("synchronize", 1), ("init_parameters_multi_gpu", 4), ("optim_state", 4), From 99e07067a2545e4b4eabe14c7d3ecdb5d82f72cf Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 18 Aug 2023 18:31:15 +0800 Subject: [PATCH 09/52] add ReduceScatter communication op --- bmtrain/distributed/__init__.py | 2 +- bmtrain/distributed/ops.py | 69 +++++++++++++++++++++++++++++++++ tests/test_nccl_backward.py | 13 ++++++- 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/bmtrain/distributed/__init__.py b/bmtrain/distributed/__init__.py index 9dc64bb8..84a4adf8 100644 --- a/bmtrain/distributed/__init__.py +++ 
b/bmtrain/distributed/__init__.py @@ -1 +1 @@ -from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations +from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index ef69659a..a6a3315c 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -3,6 +3,7 @@ from ..nccl import allGather as ncclAllGather, recv from ..nccl import allReduce as ncclAllReduce from ..nccl import broadcast as ncclBroadcast +from ..nccl import reduceScatter as ncclReduceScatter from ..nccl import send as ncclSend from ..nccl import recv as ncclRecv from ..nccl import commCount,commRank,NCCLCommunicator @@ -44,6 +45,7 @@ def recv_meta(prev_rank, comm): return dtype,shape class OpBroadcast(torch.autograd.Function): + @staticmethod def forward(ctx, src, root, comm = None): if comm is None: @@ -52,15 +54,19 @@ def forward(ctx, src, root, comm = None): outputs = torch.empty_like(src, dtype = src.dtype, device = src.device) ncclBroadcast(src.storage(), outputs.storage(), root, comm) return outputs + @staticmethod def backward(ctx, grad_output): res = all_reduce(grad_output, "sum", ctx.comm) return res, None, None + def broadcast(src, root, comm=None): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") return OpBroadcast.apply(src, root, comm) + class OpAllGather(torch.autograd.Function): + @staticmethod def forward(ctx, input : torch.Tensor, comm = None): if comm is None: @@ -98,6 +104,69 @@ def all_gather(x : torch.Tensor, comm = None): assert x.is_cuda return OpAllGather.apply(x, comm) +class OpReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None, dim = 0): + if comm is None: + comm = config["comm"] + ctx.comm = comm + ctx.dim = dim + if dim != 0: + assert input.shape[dim] % commCount(comm) == 0, "The dimension to reduce must be divisible by the number of communication processes" + input = input.transpose(0, dim) + if not input.is_contiguous(): + input = input.contiguous() + if input.storage_offset() != 0 or input.storage().size() != input.numel(): + input = input.clone() + output_shape = (input.shape[0] // commCount(comm), *input.shape[1:]) + output = torch.empty( output_shape, dtype=input.dtype, device=input.device ) + ncclReduceScatter( + input.storage(), + output.storage(), + op, + comm + ) + ctx.op = op + output = output.transpose(dim, 0) + if op in ["sum", "avg"]: + pass + elif op in ["max", "min"]: + ctx.save_for_backward( input != output ) + else: + ctx.save_for_backward( output / input ) + return output + + @staticmethod + def backward(ctx, grad_output): + with torch.no_grad(): + grad_output = OpAllGather.apply(grad_output, ctx.comm).transpose(1, ctx.dim).flatten(0,1).transpose(0, ctx.dim) + if ctx.op == "sum": + return grad_output, None, None + elif ctx.op == "avg": + return grad_output / commCount(ctx.comm), None, None + elif ctx.op in ["max", "min"]: + return torch.masked_fill(grad_output, ctx.saved_tensors[0], 0), None, None + else: + return grad_output * ctx.saved_tensors[0], None, None + +def reduce_scatter(x : torch.Tensor, op : str = "sum", comm = None): + """Reduces the input tensor from all processes. + + Args: + x (torch.Tensor): The input tensor of shape (world_size, ...). + op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum". 
+ + Returns: + torch.Tensor: The reduced tensor of shape (...). + + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + assert x.is_cuda + return OpReduceScatter.apply(x, op, comm) + class OpAllReduce(torch.autograd.Function): @staticmethod def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None): diff --git a/tests/test_nccl_backward.py b/tests/test_nccl_backward.py index 3dcd0560..279d9ace 100644 --- a/tests/test_nccl_backward.py +++ b/tests/test_nccl_backward.py @@ -11,10 +11,19 @@ def test_main(): ref = y for i in range(bmt.world_size()): if i != bmt.rank(): ref *= i+1 - print(x.grad) assert_eq(x.grad, ref) +def test_reducescatter(): + x = torch.ones((24,), dtype=torch.half, device="cuda").requires_grad_(True) + y = bmt.distributed.reduce_scatter(x, "sum") + loss = y.sum() + loss.backward() + ref = torch.ones((24,), dtype=torch.half, device="cuda") + print(loss) + assert y[0] == bmt.world_size() + assert_all_eq(x.grad,ref) + if __name__ == "__main__": bmt.init_distributed() - + test_reducescatter() test_main() \ No newline at end of file From 62138862d4df6adbb7539bc0d2cde439c8155f93 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 21 Aug 2023 11:48:20 +0800 Subject: [PATCH 10/52] FIX: unittest and backward for reduce scatter --- bmtrain/distributed/ops.py | 28 ++++++++++++---------------- tests/test_nccl_backward.py | 25 ++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index a6a3315c..e3504ab7 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -107,14 +107,12 @@ def all_gather(x : torch.Tensor, comm = None): class OpReduceScatter(torch.autograd.Function): @staticmethod - def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None, dim = 0): + def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None): if comm is None: comm = config["comm"] ctx.comm = comm - ctx.dim = dim - if dim != 0: - assert input.shape[dim] % commCount(comm) == 0, "The dimension to reduce must be divisible by the number of communication processes" - input = input.transpose(0, dim) + rank = commRank(comm) + assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes" if not input.is_contiguous(): input = input.contiguous() if input.storage_offset() != 0 or input.storage().size() != input.numel(): @@ -128,27 +126,25 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None, comm ) ctx.op = op - output = output.transpose(dim, 0) if op in ["sum", "avg"]: pass elif op in ["max", "min"]: - ctx.save_for_backward( input != output ) + ctx.save_for_backward( output != input[rank * input.shape[0]:(rank + 1) * input.shape[0]] ) else: - ctx.save_for_backward( output / input ) + ctx.save_for_backward( output / input[rank * input.shape[0]:(rank + 1) * input.shape[0]] ) return output @staticmethod def backward(ctx, grad_output): with torch.no_grad(): - grad_output = OpAllGather.apply(grad_output, ctx.comm).transpose(1, ctx.dim).flatten(0,1).transpose(0, ctx.dim) - if ctx.op == "sum": - return grad_output, None, None - elif ctx.op == "avg": - return grad_output / commCount(ctx.comm), None, None - elif ctx.op in ["max", "min"]: - return torch.masked_fill(grad_output, ctx.saved_tensors[0], 0), None, None + grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1) + if ctx.op in ["max", "min"]: 
+ raise NotImplementedError("max min operation now do not support backward") else: - return grad_output * ctx.saved_tensors[0], None, None + if ctx.op == "avg": + grad_output /= commCount(ctx.comm) + return grad_output, None, None + def reduce_scatter(x : torch.Tensor, op : str = "sum", comm = None): """Reduces the input tensor from all processes. diff --git a/tests/test_nccl_backward.py b/tests/test_nccl_backward.py index 279d9ace..3e95b165 100644 --- a/tests/test_nccl_backward.py +++ b/tests/test_nccl_backward.py @@ -23,7 +23,30 @@ def test_reducescatter(): assert_all_eq(x.grad,ref) +def test_reducescatter(): + world_size = bmt.world_size() + for shape in [(128,), (128,128)]: + tensors = torch.randn(world_size, *shape, dtype=torch.half, device="cuda").requires_grad_(True) + local_tensor = tensors[bmt.rank()] + x = local_tensor.detach().clone().requires_grad_(True) + y = bmt.distributed.reduce_scatter(x, "sum") + ref = tensors.sum(0) + partition = x.shape[0] // bmt.world_size() + ref_p = ref[bmt.rank() * partition:(bmt.rank() + 1) * partition] + if bmt.rank() == 0: + print(ref_p) + print(y) + assert torch.allclose(ref_p, y, atol=1e-2, rtol=1e-3) + g = torch.randn_like(y) + grad = torch.autograd.grad(y, x, g)[0] + pgrad = grad[bmt.rank() * y.shape[0]: (bmt.rank() + 1) * y.shape[0]] + ref_g = g + if bmt.rank() == 0: + print(ref_g) + print(pgrad) + assert torch.allclose(ref_g, pgrad, atol=1e-3, rtol=1e-3) + if __name__ == "__main__": bmt.init_distributed() test_reducescatter() - test_main() \ No newline at end of file + test_main() From 74700e4a7af3b24414adb2bc6ae4db0214d4c9c9 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 15:13:17 +0800 Subject: [PATCH 11/52] Refactor ZeRO, checkpoint and pipeline code (#128) * using hooks to implement ZeRO and Checkpoint --------- Co-authored-by: zhangkaihuo --- bmtrain/__init__.py | 4 + bmtrain/block_layer.py | 596 +++++------------------ bmtrain/checkpointing.py | 156 +++++- bmtrain/hook_func.py | 110 +++++ bmtrain/loss/cross_entropy.py | 9 + bmtrain/nn/__init__.py | 1 + {example/layers => bmtrain/nn}/linear.py | 24 +- bmtrain/pipe_layer.py | 507 ++++++++----------- example/layers/__init__.py | 3 +- example/layers/attention.py | 2 +- example/layers/feedforward.py | 2 +- example/train.py | 14 +- tests/test_all.py | 1 + tests/test_has_inf_nan.py | 6 +- tests/test_inspector_hidden.py | 29 +- tests/test_middle_hidden.py | 16 +- tests/test_model_wrapper.py | 4 +- tests/test_no_grad.py | 46 ++ tests/test_optim.py | 5 +- tests/test_optim_state.py | 9 +- tests/test_other_hidden.py | 13 +- tests/test_requires_grad.py | 14 +- tests/test_requires_grad_multi_gpu.py | 8 +- tests/test_training.py | 22 +- 24 files changed, 744 insertions(+), 857 deletions(-) create mode 100644 bmtrain/hook_func.py create mode 100644 bmtrain/nn/__init__.py rename {example/layers => bmtrain/nn}/linear.py (50%) create mode 100644 tests/test_no_grad.py diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index 26c3760d..ae243e65 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -18,3 +18,7 @@ from . import loss from . import distributed +from . import nn +from . import optim +from . import inspect +from . import lr_scheduler diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8647b7ff..8e52c68a 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -4,279 +4,18 @@ from .global_var import config import torch from . 
import nccl -from .synchronize import wait_loader from .parameter import DistributedParameter, OpAllGather -from .checkpointing import ScopedTensorInspectorContext -from . import debug -import copy -import inspect +from .checkpointing import ( + CheckpointBlockContext +) +from . import debug -# the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter -class OpCheckpointBlock(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, block : 'CheckpointBlock', preserve_rng_state, len_args, *args): - ctx.block = block - ctx.preserve_rng_state = preserve_rng_state - - ctx.cuda_rng_state = torch.cuda.get_rng_state() if preserve_rng_state else None - tensors = [] - others = [] - for arg in args: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - - ctx.nontensor_inputs = others - ctx.len_args = len_args - ctx.save_for_backward(*tensors) - ctx.param_dict={} - if config['zero_level'] == 2: - flag = 1 - else: - flag = 0 - with torch.no_grad(), ScopedTensorInspectorContext() as inspector, CheckpointBlockContext(block, ctx.param_dict, flag): - inp_args = args[:len_args] - inp_kwargs = {} - for k, v in zip(args[len_args::2], args[len_args + 1::2]): - inp_kwargs[k] = v - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - ctx.inspect_list = inspector.hidden_states - - if not isinstance(outputs, list) and not isinstance(outputs, tuple): - outputs = [outputs] - len_outputs = 0 - else: - outputs = list(outputs) - len_outputs = len(outputs) - return tuple([len_outputs] + outputs + [hidden_state["tensor"] for hidden_state in inspector.hidden_states]) - - @staticmethod - def backward(ctx, _, *grads): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" - " argument.") - - all_inputs = [] - input_reqires_grad = [] - len_args = ctx.len_args - for tensor, other in zip(ctx.saved_tensors, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_reqires_grad.append(False) - else: - input_reqires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=ctx.preserve_rng_state): - if ctx.preserve_rng_state: - torch.cuda.set_rng_state(ctx.cuda_rng_state) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - with torch.enable_grad(), CheckpointBlockContext(ctx.block, ctx.param_dict, flag): - inp_args = all_inputs[:len_args] - inp_kwargs = {} - for k, v in zip(all_inputs[len_args::2], all_inputs[len_args + 1::2]): - inp_kwargs[k] = v - with ScopedTensorInspectorContext() as inspector: - outputs = ctx.block._module._call_impl(*inp_args, **inp_kwargs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(outputs) + len(inspector.hidden_states) == len(grads) - - outputs_with_grad = [] - grad_of_output = [] - for i, output in enumerate(outputs): - if torch.is_tensor(output) and output.requires_grad: - outputs_with_grad.append(output) - grad_of_output.append(grads[i]) - - # calculate gradients for inputs, also for parameters - torch.autograd.backward( - outputs_with_grad + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_of_output + list(grads[len(outputs):]), - ) - assert len(ctx.inspect_list) == len(inspector.hidden_states), "Backward step changed" - for i, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.inspect_list[i]["name"], "Backward step changed" - assert it["shape"] == ctx.inspect_list[i]["shape"], "Backward step changed" - assert it["group"] == ctx.inspect_list[i]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.inspect_list[i]["tensor"] = it["tensor"] - ctx.inspect_list[i]["requires_grad"] = it["requires_grad"] - - grads = [] - for inp, requires_grad in zip(all_inputs, input_reqires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grads) - -class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, flag : int = 0, pipe = False) -> None: - self.block = block - self.ctx_dict = ctx_dict - self._param_buffer = {} - self._grad_buffer = {} - self._param_tensor = {} - self._grad_tensor = {} - self.flag = flag - self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] - def enter(self): - """ - gather parameters - """ - if self.block._ready: - return - self.block._ready = True - self._need_release = True - - wait_loader() - requires_grad = torch.is_grad_enabled() - with torch.cuda.stream(config["load_stream"]): - for kw, val in self.block._storage_info.items(): - assert self.block._storage_params[kw].is_cuda - assert kw not in self._grad_buffer - assert kw not in self._param_buffer - local_param = self.block._storage_params[kw] - - storage_type = local_param.storage_type() - if self.flag != 2: - self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) - - if 
requires_grad and local_param.requires_grad: - self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - if self.flag != 2: - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - nccl.allGather( - self.block._storage_params[kw].storage(), - self._param_buffer[kw], - self.comm - ) - nccl.groupEnd() - - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["load_stream"]) - - # set wait stream for each storage - for kw in self.block._storage_info.keys(): - if self.flag != 2: - self._param_tensor[kw].record_stream(current_stream) - if requires_grad and kw in self._grad_tensor: - self._grad_tensor[kw].record_stream(current_stream) - - # update parameters in block - for param in self.block._param_info: - kw_name = param["kw_name"] - offset = param["offset"] - shape = param["shape"] - - if self.flag != 2: - dtype = self._param_buffer[kw_name].dtype - device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) - else: - dtype = param["parameter"].data.dtype - device = param["parameter"].data.device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) - - if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) - - def __enter__(self): - self.enter() - - def exit(self): - """ - Reduce scatter gradients - """ +from . import hook_func - if not self._need_release: - return - self._need_release = False - self.block._ready = False - requires_grad = torch.is_grad_enabled() - if requires_grad: - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # accumulate previous gradient - if local_param.requires_grad: - if local_param.grad is None: - grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist - local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() - else: - self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad - - current_stream = torch.cuda.current_stream() - config["load_stream"].wait_stream(current_stream) # wait for backward - - with torch.cuda.stream(config["load_stream"]): - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # scatter gradient - if local_param.requires_grad: - nccl.reduceScatter( - self._grad_buffer[kw], - local_param.grad.storage(), - "sum", - self.comm - ) - nccl.groupEnd() - - # set wait stream for each storage - for kw in self._grad_tensor.keys(): - # grads can not be freed until reduce ops finish - self._grad_tensor[kw].record_stream(config["load_stream"]) - - # Release all parameters from buffer to block_storge - for param in self.block._param_info: - kw_name = param["kw_name"] - dtype = self.block._storage_params[kw_name].dtype - device = self.block._storage_params[kw_name].device - if "begin" not in param: - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None - continue - begin = param["begin"] - end = param["end"] - param["parameter"].data = torch.tensor([], 
dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) - if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if self.flag == 1: - for i in self._param_buffer: - self.ctx_dict[i] = self._param_buffer[i] - self._grad_tensor = {} - self._param_tensor = {} - self._grad_buffer = {} - self._param_buffer = {} - def __exit__(self, exc_type, exc_val, exc_tb): - # reduce scatter gradients - self.exit() +import copy +import inspect +from torch.utils.checkpoint import checkpoint def storage_type_cuda(storage_type): STORAGE_MAP = { @@ -310,7 +49,7 @@ def _get_param_kw(param : DistributedParameter): return type_name + grad_name + group_name class CheckpointBlock(torch.nn.Module): - """ Checkpoint a model or part of the model. + """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. Checkpoint block is used to save the occupation of GPU memory in training. @@ -318,6 +57,7 @@ class CheckpointBlock(torch.nn.Module): Args: model (torch.nn.Module): The model to be checkpointed. All kinds of modules are supported. + use_checkpoint (boolean): use checkpoint or not. Default True. Examples: >>> transformer_block = TransformerBlock(...) @@ -326,9 +66,13 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): super().__init__() self._module = inner_module + self._inputs = None + self._layer_dict = {} + self._forward_block_ctx = None + self._backward_block_ctx = None # build large parameter&grad here self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} @@ -440,23 +184,88 @@ def __init__(self, inner_module : torch.nn.Module): del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) - # clear parameter data, but keep the dtype and device setattr(param, "_in_checkpoint_block", True) for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - - def __call__(self, *args, **kwargs): - # gather here - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - all_inputs = list(args) - for kw, val in kwargs.items(): - all_inputs.append(kw) - all_inputs.append(val) - outputs = OpCheckpointBlock.apply(placeholder, self, True, len(args), *all_inputs) - len_output = outputs[0] - return outputs[1:1+len_output] if len_output > 0 else outputs[1] + + self.use_checkpoint = use_checkpoint + self._is_first_layer = True + self._is_last_layer = True + self._release_list = [True] + self._next_module = [] #save the next module of self + self._pre_module = [] #save the pre module of self + self._ref_count = 0 #incremental in forward and decreasing in backward + self._mode = "BLOCK" #BLOCK or ZERO or PIPE + self.all_input_no_grad = False + self.all_param_no_grad = False + + def set_pre_module(self, pre_module): + if pre_module is not None: + self._pre_module.append(pre_module) + pre_module._next_module.append(self) + + def pre_module(self): + assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count) + return self._pre_module[self._ref_count-1] + + def next_module(self): + assert len(self._next_module) == self._ref_count, "{} != 
{}".format(len(self._next_module), self._ref_count) + return self._next_module[self._ref_count-1] + + def backward_release(self, flag): + if self._ref_count == 1: + self._backward_block_ctx.exit(flag, True) + config['load_stream'].record_event(config['load_event']) + self._ref_count -= 1 + + def pre_hook(self, *args): + grad_tensors = [] + grad_index = [] + arg_list = list(args) + for i, arg in enumerate(args): + if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad: + grad_tensors.append(arg) + grad_index.append(i) + grad_tensors = tuple(grad_tensors) + + pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) + for i in range(len(grad_index)): + arg_list[grad_index[i]] = pre_out[i] + + if self._mode != "PIPE" and len(grad_tensors) == 0: + self.all_param_no_grad = True + for param in self._param_info: + if param['parameter'].requires_grad: + self.all_param_no_grad = False + break + self.all_input_no_grad = True + else: + self.all_input_no_grad = False + return arg_list + + def post_hook(self, out): + tuple_out = (out, ) if isinstance(out, torch.Tensor) else out + post_out = hook_func.PostHookFunc.apply(self, *tuple_out) + if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): + return post_out[0] + post_out = tuple(post_out) + return post_out + + def forward(self, *args): + arg_list = self.pre_hook(*args) + + if self.all_input_no_grad and not self.all_param_no_grad: + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) + + if self.use_checkpoint: + out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) + else: + out = self._module(*arg_list) + + return self.post_hook(out) def __getattr__(self,name:str): if name=="_module": @@ -619,6 +428,7 @@ def init_parameters(self): param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor + def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): r"""Helper method for yielding various names + members of modules.""" @@ -685,192 +495,6 @@ def eval(self): def __repr__(self): return self._module.__repr__() -class OpTransformerBlockList(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, self : 'TransformerBlockList', save_list, num_hidden, *args): - tensors = [] - others = [] - for arg in args[num_hidden:]: - if torch.is_tensor(arg): - tensors.append(arg) - others.append(None) - else: - tensors.append(None) - others.append(arg) - hidden_states = args[:num_hidden] - - ctx.nontensor_inputs = others - ctx.self = self - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - ctx.layers_dict = [{} for _ in range(len(self))] - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for i in range(len(self)): - with torch.no_grad(): - if save_list[i][0] == i: - layer_inputs += [hidden_state.detach() for hidden_state in hidden_states] - cuda_rng_state.append( torch.cuda.get_rng_state() ) - if config['zero_level']==2: - flag = 1 - else: - flag = 0 - block_ctx = CheckpointBlockContext(self._modules[str(i)], ctx.layers_dict[i], flag) - # gather parameter on load stream - block_ctx.enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_states = self._modules[str(i)]._module._call_impl(*hidden_states, *args[num_hidden:]) - if not isinstance(hidden_states, tuple): - hidden_states = (hidden_states,) - 
block_ctx.exit() - for it in inspector.hidden_states: - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - ctx.num_hidden = num_hidden - - ctx.save_for_backward(*layer_inputs, *tensors) - - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = [ - torch.stack(middle_hiddens[i::num_hidden], dim=0) - for i in range(num_hidden) - ] - else: - middle_hiddens = [None] * num_hidden - return tuple(list(hidden_states) + middle_hiddens + [it["tensor"] for inspector_hiddens in ctx.layer_inspector for it in inspector_hiddens]) - - - @staticmethod - def backward(ctx, *grads): - grad_hidden_states = grads[:ctx.num_hidden] - grad_middles = grads[ctx.num_hidden:2*ctx.num_hidden] - grad_inspectors = grads[2*ctx.num_hidden:] - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed * ctx.num_hidden] - save_args = ctx.saved_tensors[ctx.num_save_needed * ctx.num_hidden:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for i in reversed(range(len(ctx.self))): - if ctx.save_list[i][0] != i: - with torch.no_grad(): - st = ctx.save_list[i][0] - for j in range(st, i): - torch.cuda.set_rng_state(ctx.cuda_rng_state[j]) - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(j)], ctx.layers_dict[j], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - outputs = ctx.self._modules[str(j)]._module._call_impl( - layer_inputs[ctx.save_list[j][1]*ctx.num_hidden: ctx.save_list[j][1]*ctx.num_hidden+ctx.num_hidden], - *all_inputs - ) - if not isinstance(outputs, tuple): - outputs = (outputs,) - prev_ctx = block_ctx - prev_grad = False - for k, output in enumerate(outputs): - layer_inputs[ctx.save_list[j+1][1]*ctx.num_hidden + k].copy_(output) - ctx.save_list[j+1][0] = j+1 - - torch.cuda.set_rng_state(ctx.cuda_rng_state[i]) - ipts = [ - layer_inputs[ctx.save_list[i][1]*ctx.num_hidden + k].detach().requires_grad_() - for k in range(ctx.num_hidden) - ] - if config['zero_level'] == 2: - flag = 2 - else: - flag = 0 - block_ctx = CheckpointBlockContext(ctx.self._modules[str(i)], ctx.layers_dict[i], flag) - block_ctx.enter() - exit_prev(prev_ctx, prev_grad) - prev_ctx = block_ctx - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - outputs = 
ctx.self._modules[str(i)]._module._call_impl(*ipts, *all_inputs) - if not isinstance(outputs, tuple): - outputs = (outputs,) - - assert len(ctx.layer_inspector[i]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[i][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[i][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[i][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[i][j]["tensor"] = it["tensor"] - ctx.layer_inspector[i][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - list(outputs) + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - grad_hidden_states + grad_inspectors[-len(inspector.hidden_states):], - ) - grad_inspectors = grad_inspectors[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - outputs, - grad_hidden_states, - ) - grad_hidden_states = [ipt.grad for ipt in ipts] - for k in range(ctx.num_hidden): - if grad_middles[k] is not None: - grad_hidden_states[k] = grad_hidden_states[k] + grad_middles[k][i] - grad_hidden_states = tuple(grad_hidden_states) - - exit_prev(prev_ctx, prev_grad) - - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None) + tuple(grad_hidden_states) + tuple(grads) - class TransformerBlockList(torch.nn.Module): r""" TransformerBlockList is a list of CheckpointBlocks. @@ -896,12 +520,23 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) super().__init__() self._modules = {} + pre_module = None for i, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + module._mode = "ZERO" + module.set_pre_module(pre_module) + pre_module = module + self._is_first_layer = False + self._is_last_layer = False + self._modules[str(i)] = module self.add_module(str(i), module) + self._modules[str(0)]._is_first_layer = True + self._modules[str(len(modules)-1)]._is_last_layer = True + self.num_hidden = num_hidden if sqrt: @@ -928,6 +563,7 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) def __len__(self) -> int: return len(self._modules) + def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: @@ -935,9 +571,23 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, *args, return_hidden_states = False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - outputs = OpTransformerBlockList.apply(placeholder, self, self.save_list, self.num_hidden, *args) + hidden_states = [] + for i in range(len(self)): + if return_hidden_states: + for hidden_state in args[:self.num_hidden]: + hidden_states.append(hidden_state) + outputs = self._modules[str(i)]._call_impl(*args) + if not isinstance(outputs, tuple): + outputs = (outputs, ) + args = outputs + args[self.num_hidden:] + + if return_hidden_states: + hidden_states = [ + torch.stack(hidden_states[i::self.num_hidden], dim=0) + for i in range(self.num_hidden) + ] + if return_hidden_states: - return tuple(outputs[:2*self.num_hidden]) + return outputs + tuple(hidden_states) else: - return 
tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] \ No newline at end of file + return tuple(outputs[:self.num_hidden]) if self.num_hidden > 1 else outputs[0] diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index ac6a8d4f..b2c9ec07 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -1,7 +1,8 @@ import torch -from typing import Callable, TypeVar -from functools import wraps from . import debug +from . import nccl +from .global_var import config +from .synchronize import wait_loader class ScopedDebugTensorList: def __init__(self) -> None: @@ -28,3 +29,154 @@ def __exit__(self, *args): self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", [])) debug.set("_inspect_hidden_states", self.prev_hidden) self.prev_hidden = None + +class CheckpointBlockContext: + def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None: + self.block = block + self.ctx_dict = ctx_dict + self._param_buffer = {} + self._grad_buffer = {} + self._param_tensor = {} + self._grad_tensor = {} + self._need_release = False + if pipe: + self.comm = config["zero_comm"] + else: + self.comm = config["comm"] + def enter(self, flag=0, requires_grad=False): + """ + gather parameters + """ + if self.block._ready: + return + self.block._ready = True + self._need_release = True + + wait_loader() + with torch.cuda.stream(config["load_stream"]): + for kw, val in self.block._storage_info.items(): + assert self.block._storage_params[kw].is_cuda + assert kw not in self._grad_buffer + assert kw not in self._param_buffer + local_param = self.block._storage_params[kw] + + storage_type = local_param.storage_type() + if flag != 2: + self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + + if requires_grad and local_param.requires_grad: + self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() + if flag != 2: + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + self.comm + ) + nccl.groupEnd() + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["load_stream"]) + + # set wait stream for each storage + for kw in self.block._storage_info.keys(): + if flag != 2: + self._param_tensor[kw].record_stream(current_stream) + if requires_grad and kw in self._grad_tensor: + self._grad_tensor[kw].record_stream(current_stream) + + # update parameters in block + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + + if flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) + + if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, 
device=device).set_(self._grad_buffer[kw_name], offset, shape) + + def __enter__(self): + self.enter() + + def exit(self, flag=0, backward=False): + """ + Reduce scatter gradients + """ + + if not self._need_release: + return + self._need_release = False + self.block._ready = False + if backward: + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # accumulate previous gradient + if local_param.requires_grad: + if local_param.grad is None: + grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist + local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() + else: + self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad + + current_stream = torch.cuda.current_stream() + config["load_stream"].wait_stream(current_stream) # wait for backward + + with torch.cuda.stream(config["load_stream"]): + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # scatter gradient + if local_param.requires_grad: + nccl.reduceScatter( + self._grad_buffer[kw], + local_param.grad.storage(), + "sum", + self.comm + ) + nccl.groupEnd() + + # set wait stream for each storage + for kw in self._grad_tensor.keys(): + # grads can not be freed until reduce ops finish + self._grad_tensor[kw].record_stream(config["load_stream"]) + + + # Release all parameters from buffer to block_storge + for param in self.block._param_info: + kw_name = param["kw_name"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None + continue + begin = param["begin"] + end = param["end"] + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) + if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if flag == 1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] + self._grad_tensor = {} + self._param_tensor = {} + self._grad_buffer = {} + self._param_buffer = {} + + + def __exit__(self, exc_type, exc_val, exc_tb): + # reduce scatter gradients + self.exit() diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py new file mode 100644 index 00000000..6a56300e --- /dev/null +++ b/bmtrain/hook_func.py @@ -0,0 +1,110 @@ +import torch +from .global_var import config +from .checkpointing import CheckpointBlockContext + +def zero_pre_forward(module, inputs): + enter = True + pipe = False + if module._mode == "PIPE": + enter = module._micro_idx == 0 + pipe = True + if enter: + zero_level = config['zero_level'] + forward_flag = 1 if zero_level == 2 else 0 + if zero_level == 2 and module._ref_count > 1: + forward_flag = 2 # repeating forward in same layer + if module.all_param_no_grad: #only forward + forward_flag = 0 + module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe) + module._forward_block_ctx.enter(forward_flag) + +def zero_post_forward(module, inputs, outputs): + forward_flag = 1 if config['zero_level'] == 2 else 0 + if module.all_param_no_grad: + forward_flag = 0 + exit = True + if module._mode == "PIPE": + exit = module._micro_idx 
== config['micros'] - 1 + + if exit: + module._forward_block_ctx.exit(forward_flag) + module._ref_count += 1 + +def zero_pre_backward(module, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + if module._mode != "PIPE": + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag, True) + if not module._is_last_layer: + module.next_module().backward_release(backward_flag) + else: + if module._micro_idx == config['micros'] - 1: + module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) + module._backward_block_ctx.enter(backward_flag, True) + +def zero_post_backward(module, grad_inputs, grad_outputs): + backward_flag = 2 if config['zero_level'] == 2 else 0 + if module._mode != "PIPE": + if module._is_first_layer: + module.backward_release(backward_flag) + else: + if module._micro_idx == 0: + module.backward_release(backward_flag) + module._micro_idx -= 1 + +class OneStepNoGradFunc(torch.autograd.Function): + """ + requires_grad = False for all inputs + """ + @staticmethod + def forward(ctx, module, placeholder, *x): + ctx.x = x + ctx.module = module + ctx.rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + out = module._module(*x) + zero_post_forward(module, None, out) + if not isinstance(out, torch.Tensor): + return tuple(out) + return out + + @staticmethod + def backward(ctx, grads): + zero_pre_backward(ctx.module, grads) + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + torch.cuda.set_rng_state(ctx.rng_state) + x = ctx.x + with torch.enable_grad(): + out = ctx.module._module(*x) + torch.autograd.backward(out, grads) + zero_post_backward(ctx.module, grads, None) + grads = [] + for _ in x: + grads.append(None) + return None, None, *grads + + +class PreHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *x): + ctx.module = module + zero_pre_forward(module, x) + return x + + @staticmethod + def backward(ctx, *grads): + zero_post_backward(ctx.module, grads, None) + return None, *grads + +class PostHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *out): + ctx.module = module + zero_post_forward(module, None, out) + return out + + @staticmethod + def backward(ctx, *grads): + zero_pre_backward(ctx.module, grads) + return None, *grads diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 160ef421..31223640 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -185,6 +185,15 @@ def __init__(self, self.inplace = inplace def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + if self.inplace: ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor else: diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py new file mode 100644 index 00000000..67f9fdee --- /dev/null +++ b/bmtrain/nn/__init__.py @@ -0,0 +1 @@ +from .linear import Linear diff --git a/example/layers/linear.py b/bmtrain/nn/linear.py similarity index 50% rename from example/layers/linear.py rename to bmtrain/nn/linear.py index 0aa0ab00..faf0770e 100644 --- a/example/layers/linear.py +++ b/bmtrain/nn/linear.py @@ -2,6 +2,26 @@ import torch.nn.functional as F 
import bmtrain as bmt +class CustomLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None): + ctx.save_for_backward(x, weight, bias) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + grad_x = grad_weight = grad_bias = None + if x.requires_grad: + grad_x = grad_output.matmul(weight) + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(x.reshape(-1, x.shape[-1])) + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_x, grad_weight, grad_bias + class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: super().__init__() @@ -15,9 +35,9 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return F.linear(input, self.weight, self.bias) + return CustomLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( self.in_features, self.out_features, self.bias is not None - ) \ No newline at end of file + ) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 69c299bc..0a34ac46 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -5,310 +5,161 @@ from typing import Dict, Iterable, Iterator, Tuple, Union, List import torch -from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from .global_var import config from . import nccl -from .checkpointing import ScopedTensorInspectorContext +from .checkpointing import ( + CheckpointBlockContext +) from . 
import debug -from .block_layer import CheckpointBlockContext, CheckpointBlock, round_up, _get_param_kw +from .block_layer import CheckpointBlock, round_up, _get_param_kw -class OpMicroForward(torch.autograd.Function): +class PipePreFunction(torch.autograd.Function): @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', micro_idx, block_ctx_list, layers_dict, save_list, hidden_state, *args): - with PipeContext(self, hidden_state) as pipe_input: - hidden_state = pipe_input[0].detach() - tensors = [arg if torch.is_tensor(arg) else None for arg in args] - others = [arg if not torch.is_tensor(arg) else None for arg in args] - ctx.nontensor_inputs = others - ctx.self = self - ctx.micro_idx = micro_idx - ctx.block_ctx_list = block_ctx_list - ctx.layers_dict = layers_dict - ctx.save_list = copy.deepcopy(save_list) - ctx.num_save_needed = save_list[-1][1]+1 - layer_inputs = [] - layer_inspector = [] - cuda_rng_state = [] - for idx,layer_id in enumerate(self.layer_ids): - with torch.no_grad(): - if save_list[idx][0] == idx: - layer_inputs.append(hidden_state.detach()) - cuda_rng_state.append( torch.cuda.get_rng_state() ) - # gather parameter on load stream - if ctx.micro_idx == 0: - block_ctx_list[idx] = CheckpointBlockContext(self._modules[str(layer_id)], ctx.layers_dict[idx], 1, pipe=True) - block_ctx_list[idx].enter() - # call inner module directly - with ScopedTensorInspectorContext() as inspector: - hidden_state = self._modules[str(layer_id)]._module._call_impl(hidden_state, *args) - if ctx.micro_idx == config["micros"]-1: - block_ctx_list[idx].exit() - for ith, it in enumerate(inspector.hidden_states): - it["inside_pipe"] = { - "stage_id": self.stage_id, - "stages": self.stages, - "st": (layer_id==self.layer_ids[0] and ith==0), - "ed": (layer_id==self.layer_ids[-1] and ith==len(inspector.hidden_states)-1), - } - debug.append("_inspect_hidden_states", it) - layer_inspector.append(inspector.hidden_states) - - ctx.layer_inspector = layer_inspector - ctx.cuda_rng_state = cuda_rng_state - - ctx.save_for_backward(*layer_inputs, *tensors) - pipe_input[0] = hidden_state - if self.return_hidden_states: - middle_hiddens = layer_inputs - for mid in middle_hiddens: - mid.requires_grad_() - middle_hiddens = torch.stack(middle_hiddens, dim=0) - else: - middle_hiddens = None - return tuple([pipe_input[0], middle_hiddens] + [hidden_state["tensor"] for hidden_states in ctx.layer_inspector for hidden_state in hidden_states]) + def forward(ctx, hidden_state, *args): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() - @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspector): - def exit_prev(prev_ctx, prev_grad): - if prev_ctx is not None: - if prev_grad: - with torch.enable_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - else: - with torch.no_grad(): - prev_ctx.exit() - config["load_stream"].record_event(config["load_event"]) - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad() or when an `inputs` parameter" - " is passed to .backward(). 
Please use .backward() and do not pass its `inputs`" - " argument.") - all_inputs = [] - input_requires_grad = [] - - layer_inputs = ctx.saved_tensors[:ctx.num_save_needed] - save_args = ctx.saved_tensors[ctx.num_save_needed:] - for tensor, other in zip(save_args, ctx.nontensor_inputs): - if tensor is None: - all_inputs.append(other) - input_requires_grad.append(False) - else: - # detach for tensor inputs - input_requires_grad.append( tensor.requires_grad ) - nw_tensor = tensor.detach() - nw_tensor.requires_grad = tensor.requires_grad - all_inputs.append(nw_tensor) - with PipeContext(ctx.self, grad_hidden_state, backward=True) as pipe_input: - grad_hidden_state = pipe_input[0] - with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): - with torch.enable_grad(): - # overlap load and scatter here - prev_ctx = None - prev_grad = False - for idx, layer_id in list(enumerate(ctx.self.layer_ids))[::-1]: - torch.cuda.set_rng_state(ctx.cuda_rng_state[idx]) - ipt = layer_inputs[ctx.save_list[idx][1]].requires_grad_() - if ctx.micro_idx == 0: - ctx.block_ctx_list[idx] = CheckpointBlockContext(ctx.self._modules[str(layer_id)], ctx.layers_dict[idx], 2, pipe=True) - ctx.block_ctx_list[idx].enter() - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - prev_ctx = ctx.block_ctx_list[idx] - prev_grad = True - - with ScopedTensorInspectorContext() as inspector: - output = ctx.self._modules[str(layer_id)]._module._call_impl(ipt, *all_inputs) - - assert len(ctx.layer_inspector[idx]) == len(inspector.hidden_states), "Backward step changed" - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.layer_inspector[idx][j]["name"], "Backward step changed" - assert it["shape"] == ctx.layer_inspector[idx][j]["shape"], "Backward step changed" - assert it["group"] == ctx.layer_inspector[idx][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.layer_inspector[idx][j]["tensor"] = it["tensor"] - ctx.layer_inspector[idx][j]["requires_grad"] = it["requires_grad"] - if len(inspector.hidden_states) > 0: - torch.autograd.backward( - [output] + [hidden_state["tensor"] for hidden_state in inspector.hidden_states], - [grad_hidden_state] + list(grad_inspector[-len(inspector.hidden_states):]), - ) - grad_inspector = grad_inspector[:-len(inspector.hidden_states)] - else: - torch.autograd.backward( - [output], - [grad_hidden_state], - ) - grad_hidden_state = ipt.grad - if grad_middle is not None: - grad_hidden_state = grad_hidden_state + grad_middle[idx] - if ctx.micro_idx == config["micros"]-1: - exit_prev(prev_ctx, prev_grad) - for inspector_hiddens in ctx.layer_inspector: - for it in inspector_hiddens: - debug.append("_inspect_hidden_states", it) - - pipe_input[0] = grad_hidden_state - grads = [] - for inp, requires_grad in zip(all_inputs, input_requires_grad): - if requires_grad: - grads.append(inp.grad) - else: - grads.append(None) - return (None, None, None, None, None, None, pipe_input[0]) + tuple(grads) - -class OpPipeTransformerBlockList(torch.autograd.Function): - @staticmethod - def forward(ctx, placeholder, self : 'PipelineTransformerBlockList', save_list, hidden_state, *args): - num_micros = config["micros"] - ctx.self = self - ctx.num_micros = num_micros - block_ctx = [None for _ in range(len(self))] - layers_dict = [{} for _ in range(len(self))] - args_list = [[] for _ in range(num_micros)] batch_related = args[-1] batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))] batch_related_rule 
= [] args = args[:-1] + batch_size = hidden_state.shape[0] - assert (batch_size * config["pipe_size"]) % num_micros == 0, f'The batch size {(batch_size * config["pipe_size"])} must be divisible by the number of micro_batch {num_micros}' + num_micros = config["micros"] + args_list = [[] for _ in range(num_micros)] input_requires_grad = [] - inspector_hiddens = [] - ctx.inspector_hiddens_sep = [0] - ctx.micro_inspector = [] - with torch.enable_grad(): - for arg in args: - if torch.is_tensor(arg): - arg_all = all_gather(arg, config['pipe_comm']) - if arg.shape[0] == batch_size: - batch_related_rule.append(True) - arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) - arg_all = [tensor.detach().requires_grad_(arg.requires_grad) for tensor in arg_all] - else: - batch_related_rule.append(False) - # assert num_micros % self.stages == 0, "batch unrelated only support num_micros % stages == 0" - # arg_all = [arg_all[i // (num_micros // self.stages)].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - arg_all = [arg_all[0].detach().requires_grad_(arg.requires_grad) for i in range(num_micros)] - input_requires_grad.append(arg.requires_grad) + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size: + batch_related_rule.append(True) + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all] else: batch_related_rule.append(False) - arg_all = [arg for _ in range(num_micros)] - input_requires_grad.append(False) - for i in range(num_micros): - args_list[i].append(arg_all[i]) - outputs = [] - if self.return_hidden_states: - middles = [] - hidden_state_list = all_gather(hidden_state, config["pipe_comm"]).flatten(0, 1).detach().requires_grad_() - ctx.hidden_state_list = hidden_state_list - hidden_state_list = hidden_state_list.chunk(num_micros, dim=0) - for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - with ScopedTensorInspectorContext() as inspector: - micro_outputs = OpMicroForward.apply(placeholder, self, micro_idx, block_ctx, layers_dict, save_list, hidden_state, *arg) - output, middle = micro_outputs[:2] - outputs.append(output) - if self.return_hidden_states: - middles.append(middle) - for it in inspector.hidden_states: - inspector_hiddens.append(it["tensor"]) - it["tensor"] = it["tensor"].clone() - debug.append("_inspect_hidden_states", it) - ctx.inspector_hiddens_sep.append(len(inspector_hiddens)) - ctx.micro_inspector.append(inspector.hidden_states) + arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)] + input_requires_grad.append(arg.requires_grad) + else: + batch_related_rule.append(False) + arg_all = [arg for _ in range(num_micros)] + input_requires_grad.append(False) + for i in range(num_micros): + args_list[i].append(arg_all[i]) + ctx.input_requires_grad = input_requires_grad + ctx.args_list = args_list if len(batch_related) == 0: ctx.batch_related = batch_related_rule else: ctx.batch_related = batch_related_origin - ctx.args_list = args_list - ctx.input_requires_grad = input_requires_grad - ctx.output_list = outputs - if self.return_hidden_states: - ctx.middle_list = middles - - with torch.enable_grad(): - last_hidden = torch.cat(outputs, dim=0) - last_hidden_shape = last_hidden.shape - last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, 
config["pipe_comm"]) - last_hidden = last_hidden.chunk(self.stages, dim=0) - last_hidden = last_hidden[self.stage_id].clone() - - if self.return_hidden_states: - middle_hiddens = [] - with torch.enable_grad(): - for stage_id in range(self.stages): - if self.stage_id == stage_id: - middle_hidden = torch.cat(middles, dim=1) # [(layers, micro_batch, ...), ] -> (layers, full_batch, ...) - else: - middle_shape = (self.get_part_len_by_stage_id(stage_id),)+last_hidden_shape - middle_hidden = torch.zeros(middle_shape, device=last_hidden.device, dtype=last_hidden.dtype) - middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) - middle_hidden = middle_hidden.chunk(self.stages, dim=1) - middle_hidden = middle_hidden[self.stage_id].clone() - middle_hiddens.append(middle_hidden) - middle_hiddens = torch.cat(middle_hiddens, dim=0) - else: - middle_hiddens = None - - ctx.save_for_backward(*inspector_hiddens) - return tuple([last_hidden, middle_hiddens] + [it["tensor"] for inspector_hiddens in ctx.micro_inspector for it in inspector_hiddens]) - + return hidden_state_list, args_list @staticmethod - def backward(ctx, grad_hidden_state : torch.Tensor, grad_middle : torch.Tensor, *grad_inspectors): - inspector_hiddens = ctx.saved_tensors - ipt = ctx.hidden_state_list - args_list = ctx.args_list - input_requires_grad = ctx.input_requires_grad - grad_hidden_state_list = all_gather(grad_hidden_state, config["pipe_comm"]).flatten(start_dim=0, end_dim=1).chunk(ctx.num_micros, dim=0) - if ctx.self.return_hidden_states: - for stage_id in range(ctx.self.stages): - layer_range = ctx.self.get_range_by_stage_id(stage_id) - grad_middle_state = grad_middle[layer_range] - grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]).flatten(start_dim=0, end_dim=1).transpose(0, 1).chunk(ctx.num_micros, dim=1) # (layer, micro_batch, ...) 
- if ctx.self.stage_id == stage_id: - grad_middle_state_list = grad_middle_state - - for m in range(ctx.num_micros): - outputs = [ctx.output_list[m]] - grad_outputs = [grad_hidden_state_list[m]] - if ctx.self.return_hidden_states: - outputs.append(ctx.middle_list[m]) - grad_outputs.append(grad_middle_state_list[m]) - outputs += list(inspector_hiddens[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - grad_outputs += list(grad_inspectors[ctx.inspector_hiddens_sep[m]:ctx.inspector_hiddens_sep[m+1]]) - with ScopedTensorInspectorContext() as inspector: - torch.autograd.backward( - outputs, - grad_outputs, - ) - for j, it in enumerate(inspector.hidden_states): - assert it["name"] == ctx.micro_inspector[m][j]["name"], "Backward step changed" - assert it["shape"] == ctx.micro_inspector[m][j]["shape"], "Backward step changed" - assert it["group"] == ctx.micro_inspector[m][j]["group"], "Backward step changed" - - # change the tensor in placeholder - ctx.micro_inspector[m][j]["tensor"] = it["tensor"] - ctx.micro_inspector[m][j]["requires_grad"] = it["requires_grad"] - - grads = [] - for idx,requires_grad in enumerate(input_requires_grad): + def backward(ctx, grads, arg_grads): + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + arg_grads = [] + num_micros = config['micros'] + for idx,requires_grad in enumerate(ctx.input_requires_grad): if requires_grad: - grad = torch.cat([args_list[m][idx].grad for m in range(ctx.num_micros)], dim=0) + grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) grad = all_reduce(grad, "sum", config["pipe_comm"]) - split_size = ctx.self.stages if ctx.batch_related[idx] else ctx.num_micros + split_size = topo.stages if ctx.batch_related[idx] else num_micros grad = grad.chunk(split_size) if ctx.batch_related[idx]: - grads.append(grad[ctx.self.stage_id]) + arg_grads.append(grad[topo.stage_id]) else: - grads.append(grad[0]) + arg_grads.append(grad[0]) else: - grads.append(None) - grad = broadcast(ipt.grad, 0, config["pipe_comm"]).chunk(ctx.self.stages) - grad = grad[ctx.self.stage_id] + arg_grads.append(None) + arg_grads.append(None) #for append(batch_related) + return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads + +class PipePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False): + topo = config['topology'] + ctx.return_hidden_states = return_hidden_states + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(topo.stages, dim=0) + output = last_hidden[topo.stage_id] + output.requires_grad_() + + if return_hidden_states: + ctx.stage_id = topo.stage_id + ctx.stages = topo.stages + ctx.backward_stage_ranges = backward_stage_ranges + middle_hiddens = [] + for stage_id in range(ctx.stages): + if ctx.stage_id == stage_id: + middle_hidden = hidden_states + else: + middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) + middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) + middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hiddens.append(middle_hidden) + middle_hiddens = torch.cat(middle_hiddens, dim=0) + middle_hiddens.requires_grad_() + return output, middle_hiddens + else: + 
return output + + @staticmethod + def backward(ctx, grads, grad_middle=None): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + + if ctx.return_hidden_states: + for stage_id in range(ctx.stages): + layer_range = ctx.backward_stage_ranges[stage_id] + grad_middle_state = grad_middle[layer_range] + grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) + grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) + if ctx.stage_id == stage_id: + grad_hidden_state_list = grad_middle_state + return grad_list, grad_hidden_state_list, None, None, None, None + else: + return grad_list + +class StagePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_first_stage: + input = recv_activations(stage_id - 1, config['pipe_comm']) + input.requires_grad_() + return input + return input + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_first_stage: + send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + return grad_outputs, None + +class StagePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, outputs, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_last_stage: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + return outputs + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_last_stage: + pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + return pre_grad_inputs, None + return grad_outputs, None - return (None, None, None, grad) + tuple(grads) + (None,) class PipelineTransformerBlockList(torch.nn.Module): r""" @@ -331,9 +182,9 @@ class PipelineTransformerBlockList(torch.nn.Module): """ _modules: Dict[str, CheckpointBlock] - def __init__(self, modules: Iterable[CheckpointBlock]) -> None: + def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: super().__init__() - + self.num_hidden = num_hidden self._modules = {} rank = config['rank'] topo = config['topology'] @@ -345,18 +196,37 @@ def __init__(self, modules: Iterable[CheckpointBlock]) -> None: for idx, module in enumerate(modules): if not isinstance(module, CheckpointBlock): module = CheckpointBlock(module) + + module._mode = "PIPE" + module.stage_id = self.stage_id + module.stages = self.stages + self._modules[str(idx)] = module self.layer_ids = self.get_range_by_stage_id(self.stage_id) + + pre_module = None + for i,layer_id in enumerate(self.layer_ids): + module = self._modules[str(layer_id)] + module.set_pre_module(pre_module) + pre_module = module + + module._is_first_stage = True if self.stage_id == 0 else False + module._is_last_stage = True if self.stage_id == self.stages-1 else False + module._is_first_layer = False + module._is_last_layer = False + self._modules[str(self.layer_ids[0])]._is_first_layer = True + self._modules[str(self.layer_ids[-1])]._is_last_layer = True + self.partition_modules(self.layer_ids) self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = pipe_group[self.pipe_idx, 
self.stage_id - 1] if self.stage_id > 0 else -1 # self.micro_batches = config['num_micro_batches'] - + self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: - return len(self._modules) + return len(self._modules) def __iter__(self) -> Iterator[CheckpointBlock]: return iter(self._modules.values()) @@ -366,15 +236,47 @@ def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): self.return_hidden_states = return_hidden_states - placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) - args = list(args) - args.append(batch_related) - outputs = OpPipeTransformerBlockList.apply(placeholder, self, self.save_list, hidden_state, *args) - hidden_state, middle_states = outputs[:2] + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + args = args + (batch_related, ) + hidden_state.requires_grad_() + hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args) + + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) + outputs = [] + hidden_states = [] + + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): + micro_hidden_states = [] + + hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + + for idx,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._micro_idx = micro_idx + if return_hidden_states: + micro_hidden_states.append(hidden_state) + hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + + outputs.append(hidden_state) + if return_hidden_states: + hidden_states.append(torch.stack(micro_hidden_states, dim=0)) + + last_hidden = torch.cat(outputs, dim=0) + last_hidden_shape = last_hidden.shape + if return_hidden_states: - return hidden_state, middle_states + hidden_states = torch.cat(hidden_states, dim=1) + forward_stage_ranges = [] + backward_stage_ranges = [] + for stage_id in range(self.stages): + forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) + backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) + return outputs, hidden_states else: - return hidden_state + outputs = PipePostFunction.apply(last_hidden) + return outputs def get_range_by_stage_id(self, stage_id : int) -> List[int]: part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] @@ -486,32 +388,3 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for n, parameter in module._module.named_parameters(): destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']) -class PipeContext: - def __init__(self, module, hidden_state, backward=False): - self.module = module - self.stage_id = module.stage_id - self.stages = module.stages - self.next_rank = module.next_rank - self.prev_rank = module.prev_rank - self.hidden_state = [hidden_state] - self.backward = backward - self.send_buffer = {} - def enter(self): - if self.backward: - if self.stage_id != self.stages -1: - self.hidden_state[0] = recv_activations(self.stage_id + 1, config['pipe_comm']) - else: - if self.stage_id != 0: - self.hidden_state[0] = recv_activations(self.stage_id - 1, config['pipe_comm']) - return self.hidden_state - def exit(self): - if self.backward: - if self.stage_id != 0: 
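The PipeContext removed in this hunk and the new StagePreFunction/StagePostFunction autograd functions express the same stage-boundary exchange; only the autograd hookup differs. A rough sketch of the forward-direction flow, assuming send_activations/recv_activations are importable from bmtrain.distributed with the signatures used in this patch (run_layers is a placeholder for this stage's layers):

    from bmtrain.global_var import config
    from bmtrain.distributed import send_activations, recv_activations

    def stage_step(hidden, run_layers, stage_id, stages):
        # every stage except the first replaces its input with the previous stage's output
        if stage_id > 0:
            hidden = recv_activations(stage_id - 1, config['pipe_comm'])
            hidden.requires_grad_()
        hidden = run_layers(hidden)
        # every stage except the last forwards its output to the next stage
        if stage_id < stages - 1:
            send_activations(hidden.detach(), stage_id + 1, config['pipe_comm'])
        return hidden

During backward the directions flip: gradients are received from stage_id + 1 and sent to stage_id - 1, which is what StagePostFunction.backward and StagePreFunction.backward above implement.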
- send_activations(self.hidden_state[0], self.stage_id - 1, config['pipe_comm']) - else: - if self.stage_id != self.stages - 1: - send_activations(self.hidden_state[0], self.stage_id + 1, config['pipe_comm']) - def __enter__(self): - return self.enter() - def __exit__(self, exc_type, exc_val, exc_tb): - self.exit() \ No newline at end of file diff --git a/example/layers/__init__.py b/example/layers/__init__.py index 425d0a1b..ef4617c0 100644 --- a/example/layers/__init__.py +++ b/example/layers/__init__.py @@ -1,6 +1,5 @@ -from .linear import Linear from .embedding import Embedding from .feedforward import Feedforward from .layernorm import Layernorm from .attention import Attention -from .transformer import TransformerEncoder \ No newline at end of file +from .transformer import TransformerEncoder diff --git a/example/layers/attention.py b/example/layers/attention.py index 4a0eec11..243df3ea 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,7 +1,7 @@ from typing import Optional import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear import math class Attention(bmt.DistributedModule): diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 3fe935bf..99d2dc3b 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,6 +1,6 @@ import torch import bmtrain as bmt -from layers import Linear +from bmtrain.nn import Linear class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: diff --git a/example/train.py b/example/train.py index 7bc92400..1a744e20 100644 --- a/example/train.py +++ b/example/train.py @@ -2,6 +2,8 @@ import bmtrain as bmt from models import GPT import time +from bmtrain import optim +from bmtrain import inspect def main(): bmt.init_distributed( @@ -51,10 +53,10 @@ def main(): break loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) - optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) - optim_manager = bmt.optim.OptimManager(loss_scale=2**20) + optim_manager = optim.OptimManager(loss_scale=2**20) optim_manager.add_optimizer(optimizer, lr_scheduler) bmt.synchronize() @@ -66,7 +68,7 @@ def main(): # load data st = time.time() - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) logits = model( enc_input, @@ -87,13 +89,13 @@ def main(): # print parameters of the model if iteration % 100 == 0: bmt.print_rank( - bmt.inspect.format_summary( + inspect.format_summary( inspector.get_summary() ) ) bmt.print_rank( - bmt.inspect.format_summary( - bmt.inspect.inspect_model(model, "*") + inspect.format_summary( + inspect.inspect_model(model, "*") ) ) diff --git a/tests/test_all.py b/tests/test_all.py index b614d3eb..6682aa93 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -24,6 +24,7 @@ ("send_recv", 4), ("nccl_backward", 4), + ("no_grad", 1), ("training", 4), ]) diff --git a/tests/test_has_inf_nan.py b/tests/test_has_inf_nan.py index b1b9b4a9..fda85515 100644 --- a/tests/test_has_inf_nan.py +++ b/tests/test_has_inf_nan.py @@ -1,12 +1,12 @@ from utils import * import torch -import bmtrain.optim._cuda as G +import bmtrain.loss._function as F import 
random def check(x, v): out = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] - G.f_has_inf_nan(x, out) + F.has_inf_nan(x, out) assert_eq(out.item(), v) def test_main(): @@ -29,4 +29,4 @@ def test_main(): check(x, 1) if __name__ == "__main__": - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_inspector_hidden.py b/tests/test_inspector_hidden.py index 731884ad..c39de5fb 100644 --- a/tests/test_inspector_hidden.py +++ b/tests/test_inspector_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -48,7 +49,7 @@ def __init__(self, dim : int): def forward(self, x): x = self.m1(x) - bmt.inspect.record_tensor(x, "hidden") + inspect.record_tensor(x, "hidden") x = self.m2(x) return x @@ -160,10 +161,10 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): bmt.init_parameters(m) m = cls(pre, [m for m in ms], post) ret = "" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -171,32 +172,32 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len): for i in range(len(ms)//2): loss = loss + (inspector.summary[i]['tensor'] * middle_weight[i]).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" - with bmt.inspect.inspect_tensor() as inspector: + with inspect.inspect_tensor() as inspector: logits = m(inp) - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" loss = (logits * last_weight).sum() - with bmt.inspect.inspect_tensor(): + with inspect.inspect_tensor(): loss.backward() - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) + "\n" - ret += bmt.inspect.format_summary( + ret += inspect.format_summary( inspector.get_summary() ) + "\n" @@ -237,4 +238,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=2) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index f0d5c559..688cdfe5 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -3,10 +3,10 @@ import bmtrain as bmt import random import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -143,8 +143,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid loss = (logits * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += 
inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_middle: logits, hidden_states = m(inp, return_hidden_states=True) @@ -154,8 +154,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: logits, hidden_states = m(inp, return_hidden_states=True) @@ -165,8 +165,8 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_mid ]) + (logits * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad @@ -209,4 +209,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed(pipe_size=4) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_model_wrapper.py b/tests/test_model_wrapper.py index 409107e3..6f913d3c 100644 --- a/tests/test_model_wrapper.py +++ b/tests/test_model_wrapper.py @@ -164,7 +164,7 @@ def forward(self, out = input_emb for layer in self.transformers: - out = layer(out, position_bias=None, mask=mask_2d) + out = layer(out, mask_2d) out = self.layernorm(out) logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model) @@ -218,4 +218,4 @@ def test_main(): if __name__ == '__main__': bmt.init_distributed(seed=0) - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py new file mode 100644 index 00000000..3629921b --- /dev/null +++ b/tests/test_no_grad.py @@ -0,0 +1,46 @@ +import torch +import bmtrain as bmt + +class Layer(torch.nn.Module): + def __init__(self): + super(Layer, self).__init__() + self.linear = bmt.nn.Linear(32, 32) + self.count = 0 + def forward(self, x): + self.count += 1 + return self.linear(x) + +def test_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layer1 = bmt.CheckpointBlock(Layer()) + layer2 = bmt.CheckpointBlock(Layer()) + layer1.linear.weight.requires_grad_(False) + layer1.linear.bias.requires_grad_(False) + y = layer1(x) + assert y.requires_grad == False + y = layer2(y) + y.sum().backward() + assert layer1.count == 1 + assert layer2.count == 2 + +def test_all_input_no_grad(): + linear1 = bmt.nn.Linear(32, 32) + linear2 = bmt.nn.Linear(32, 32) + + x = torch.randn(32,32, device='cuda') + + linear1 = bmt.CheckpointBlock(linear1) + linear2 = bmt.CheckpointBlock(linear2) + y = linear1(x) + y = linear2(y) + y.sum().backward() + assert linear1.weight.grad is not None + assert linear1.bias.grad is not None + assert x.grad is None + +if __name__ == '__main__': + bmt.init_distributed() + + test_no_grad() + test_all_input_no_grad() diff --git a/tests/test_optim.py b/tests/test_optim.py index 81356ede..fdb64521 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,5 +1,6 @@ import torch import bmtrain as bmt +from bmtrain import optim class TestModule(torch.nn.Module): def __init__(self): @@ -29,8 +30,8 @@ def main(): model2 = model2.cuda() model3 = model3.cuda() - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = 
optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) for _ in range(100): diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index df697f49..cef06734 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -2,6 +2,7 @@ import bmtrain as bmt import os from copy import deepcopy +from bmtrain import optim class TestSubModule(bmt.DistributedModule): def __init__(self): @@ -67,10 +68,10 @@ def main(): bmt.load(model2, f"test_optim_state_model1.pt") bmt.load(model3, f"test_optim_state_model1.pt") - opt1 = bmt.optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) - optim_manager = bmt.optim.OptimManager(loss_scale=256) + optim_manager = optim.OptimManager(loss_scale=256) optim_manager.add_optimizer(opt1) optim_manager.add_optimizer(opt2) optim_manager.add_optimizer(opt3) @@ -121,4 +122,4 @@ def main(): if __name__ == "__main__": bmt.init_distributed() - main() \ No newline at end of file + main() diff --git a/tests/test_other_hidden.py b/tests/test_other_hidden.py index d1e317ad..1f6c8c65 100644 --- a/tests/test_other_hidden.py +++ b/tests/test_other_hidden.py @@ -7,6 +7,7 @@ from bmtrain.block_layer import CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -142,22 +143,22 @@ def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post loss = (pre.weight * last_weight).sum() loss.backward() ret += f"========================only last========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if only_post: loss = (post.weight * last_weight).sum() loss.backward() ret += f"========================only middle========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) if mix_test: loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum() loss.backward() ret += f"========================mix========================\n" - ret += bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + ret += inspect.format_summary( + inspect.inspect_model(m, '*') ) return ret + "\n" # replace for matching None grad with zero_grad diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 83fe8d17..943275c3 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -7,6 +7,7 @@ from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -25,19 +26,22 @@ def __init__(self, in_features : int, out_features: int, init_weight = None, ini else: self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), 
init_method=torch.nn.init.zeros_) - def forward(self, input): + def forward(self, input, other_bias): ret = F.linear(input, self.weight, self.bias) + ret += other_bias return ret def run(m, a, b): inp = torch.rand((1, 10, 256)).cuda()*100 - logits = m(inp) + bias = torch.rand(256).cuda()*100 + logits = m(inp, bias) loss = logits.sum() loss.backward() bmt.synchronize() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) + assert_eq(bias.requires_grad, False) return a.weight.grad is None, a.bias.grad is None, sm def test_main(): @@ -100,4 +104,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=1) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index ebea096e..4a2670ae 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -2,11 +2,11 @@ import bmtrain as bmt import torch -from bmtrain import config from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F +from bmtrain import inspect class Linear(bmt.DistributedModule): def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: @@ -35,8 +35,8 @@ def run(m, a, b): loss = logits.sum() loss.backward() - sm = bmt.inspect.format_summary( - bmt.inspect.inspect_model(m, '*') + sm = inspect.format_summary( + inspect.inspect_model(m, '*') ) return sm @@ -93,4 +93,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=2) test_main() - test_main_pipe() \ No newline at end of file + test_main_pipe() diff --git a/tests/test_training.py b/tests/test_training.py index 7342fe6c..1d6481c9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -6,8 +6,8 @@ import math import torch.nn.functional as F import bmtrain as bmt -from bmtrain.global_var import config import os +from bmtrain import inspect class Attention(torch.nn.Module): def __init__(self, @@ -151,6 +151,7 @@ def __init__(self, ) for _ in range(num_layers) ]) + self.run_unroll = False self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype) @@ -166,7 +167,7 @@ def forward(self, input_emb = self.pos_emb(pos) + self.word_emb(input) out = input_emb - if isinstance(self.transformers, torch.nn.ModuleList): + if isinstance(self.transformers, torch.nn.ModuleList) or self.run_unroll: for layer in self.transformers: out = layer(out, mask_2d, None) else: @@ -250,7 +251,7 @@ def sub_train_torch(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def sub_train(model, loss_func_cls, optimizer_cls): @@ -311,7 +312,7 @@ def sub_train(model, loss_func_cls, optimizer_cls): )) logs.append(global_loss) - summary = bmt.inspect.inspect_model(model, "*") + summary = inspect.inspect_model(model, "*") return logs, summary def train(model, loss_func, optimizer): @@ -376,11 +377,20 @@ def pipe_model(): bmt.load(pipe_model, ckpt_path) return model + def unroll_list_model(): + model = GPT(**kwargs) + list_model = bmt.BMTrainModelWrapper(model) + list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers]) + bmt.load(list_model, ckpt_path) + model.run_unroll = True + return model + models = { "torch": 
torch_model, "wrapper": wrap_model, "blocklist": list_model, "pipelist": pipe_model, + "unroll_blocklist": unroll_list_model, } loss_funcs = { "bmt_entropy": bmt.loss.FusedCrossEntropy, @@ -406,6 +416,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -419,6 +430,7 @@ def add_to_check_list(m, l, o): add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") if bmt.rank() == 0: os.remove(ckpt_path) check(ret) @@ -442,4 +454,4 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_main(test_fp16=True, test_fp32=True) \ No newline at end of file + test_main(test_fp16=True, test_fp32=True) From b0a0865664ca62aa7706d16f482d37481eba04e4 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Mon, 21 Aug 2023 15:14:37 +0800 Subject: [PATCH 12/52] prod backward raise error --- bmtrain/distributed/ops.py | 2 +- tests/test_nccl_backward.py | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index e3504ab7..d1b489e2 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -138,7 +138,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None) def backward(ctx, grad_output): with torch.no_grad(): grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1) - if ctx.op in ["max", "min"]: + if ctx.op in ["max", "min", "prod"]: raise NotImplementedError("max min operation now do not support backward") else: if ctx.op == "avg": diff --git a/tests/test_nccl_backward.py b/tests/test_nccl_backward.py index 3e95b165..3d45f54e 100644 --- a/tests/test_nccl_backward.py +++ b/tests/test_nccl_backward.py @@ -12,16 +12,6 @@ def test_main(): for i in range(bmt.world_size()): if i != bmt.rank(): ref *= i+1 assert_eq(x.grad, ref) -def test_reducescatter(): - x = torch.ones((24,), dtype=torch.half, device="cuda").requires_grad_(True) - y = bmt.distributed.reduce_scatter(x, "sum") - loss = y.sum() - loss.backward() - ref = torch.ones((24,), dtype=torch.half, device="cuda") - print(loss) - assert y[0] == bmt.world_size() - assert_all_eq(x.grad,ref) - def test_reducescatter(): world_size = bmt.world_size() From 3f117444dbd5baea2f170c81687f07a3397cdb6c Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 21 Aug 2023 21:27:57 +0800 Subject: [PATCH 13/52] fix inspect_model where param is None (#151) --- bmtrain/inspect/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bmtrain/inspect/model.py b/bmtrain/inspect/model.py index 871adf52..79975cc5 100644 --- a/bmtrain/inspect/model.py +++ b/bmtrain/inspect/model.py @@ -220,6 +220,8 @@ def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''): p = _gather_value(param.data, param.storage().size(), param._original_shape) else: p = param + if p is None: + continue stats = { 'name': prefix + name, 'shape': tuple(p.size()), From dc7284dfeba83713e2bb260dfc60cb2506441254 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 22 Aug 2023 15:37:37 +0800 Subject: [PATCH 14/52] fix 
is_first_layer (#152) --- bmtrain/block_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 8e52c68a..18438b8c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -528,8 +528,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) module._mode = "ZERO" module.set_pre_module(pre_module) pre_module = module - self._is_first_layer = False - self._is_last_layer = False + module._is_first_layer = False + module._is_last_layer = False self._modules[str(i)] = module self.add_module(str(i), module) From 7d62a181c248c332f1ef5223a02c63c8de6e27ca Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 25 Aug 2023 15:43:07 +0800 Subject: [PATCH 15/52] mv zero_level to CheckpointBlock (#154) --- bmtrain/block_layer.py | 3 ++- bmtrain/hook_func.py | 8 ++++---- bmtrain/init.py | 3 --- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 18438b8c..4e9a6c01 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -66,7 +66,7 @@ class CheckpointBlock(torch.nn.Module): >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3): super().__init__() self._module = inner_module self._inputs = None @@ -200,6 +200,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True): self._mode = "BLOCK" #BLOCK or ZERO or PIPE self.all_input_no_grad = False self.all_param_no_grad = False + self._zero_level = zero_level def set_pre_module(self, pre_module): if pre_module is not None: diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 6a56300e..4d91d1d0 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -9,7 +9,7 @@ def zero_pre_forward(module, inputs): enter = module._micro_idx == 0 pipe = True if enter: - zero_level = config['zero_level'] + zero_level = module._zero_level forward_flag = 1 if zero_level == 2 else 0 if zero_level == 2 and module._ref_count > 1: forward_flag = 2 # repeating forward in same layer @@ -19,7 +19,7 @@ def zero_pre_forward(module, inputs): module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): - forward_flag = 1 if config['zero_level'] == 2 else 0 + forward_flag = 1 if module._zero_level == 2 else 0 if module.all_param_no_grad: forward_flag = 0 exit = True @@ -31,7 +31,7 @@ def zero_post_forward(module, inputs, outputs): module._ref_count += 1 def zero_pre_backward(module, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) @@ -43,7 +43,7 @@ def zero_pre_backward(module, grad_outputs): module._backward_block_ctx.enter(backward_flag, True) def zero_post_backward(module, grad_inputs, grad_outputs): - backward_flag = 2 if config['zero_level'] == 2 else 0 + backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": if module._is_first_layer: module.backward_release(backward_flag) diff --git a/bmtrain/init.py b/bmtrain/init.py index 1fa0712d..f3c1faa7 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -14,7 +14,6 @@ def init_distributed( init_method : str = "env://", seed : int = 0, - zero_level: int = 
3, pipe_size: int = -1, num_micro_batches: int = None, ): @@ -24,7 +23,6 @@ def init_distributed( Args: seed (int): The random seed. - zero_level (int): The ZeRO optimization level. 2 for stage-2, 3 for stage-3. **init_distributed** reads the following environment variables: @@ -74,7 +72,6 @@ def init_distributed( config["load_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() - config["zero_level"] = zero_level config["topology"] = topology(config) config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] cpus_this_worker = None From 286297e4414527f016370318e2a7601a51f8a11c Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Sun, 27 Aug 2023 22:37:24 -0700 Subject: [PATCH 16/52] =?UTF-8?q?FIX=EF=BC=9Athe=20number=20of=20GPU=20of?= =?UTF-8?q?=20synchronize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index ec23cf89..721abebc 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -6,7 +6,7 @@ ("different_output_shape", 1), ("load_ckpt", 1), ("init_parameters", 1), - ("synchronize", 1), + ("synchronize", 4), ("init_parameters_multi_gpu", 4), ("optim_state", 4), From 69ea264f2d4fca1ce93b8081e474126dde5760d5 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Sun, 27 Aug 2023 22:43:22 -0700 Subject: [PATCH 17/52] add judgement through assert --- tests/test_synchronize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py index f92ff5b9..1aea81cc 100644 --- a/tests/test_synchronize.py +++ b/tests/test_synchronize.py @@ -6,22 +6,26 @@ from bmtrain.synchronize import gather_result def test_main(): - tensor = torch.rand(5, 5) + tensor = torch.rand(5, 5) * bmt.rank() result = bmt.gather_result(tensor) tensor_slice_0 = tensor[:1, :1] result_slice_0 = bmt.gather_result(tensor_slice_0) + assert torch.allclose(result[:1, :1], result_slice_0, atol=1e-6), "Assertion failed for tensor_slice_0" tensor_slice_1 = tensor[:2, :2] result_slice_1 = bmt.gather_result(tensor_slice_1) + assert torch.allclose(result[:2, :2], result_slice_1, atol=1e-6), "Assertion failed for tensor_slice_1" tensor_slice_2 = tensor[:3, :3] result_slice_2 = bmt.gather_result(tensor_slice_2) + assert torch.allclose(result[:3, :3], result_slice_2, atol=1e-6), "Assertion failed for tensor_slice_2" tensor_slice_3 = tensor[:4, :4] result_slice_3 = bmt.gather_result(tensor_slice_3) + assert torch.allclose(result[:4, :4], result_slice_3, atol=1e-6), "Assertion failed for tensor_slice_3" - print(result, result_slice_0, result_slice_1, result_slice_2, result_slice_3, sep='\n') + print("All slice tests passed!") if __name__ == '__main__': bmt.init_distributed(pipe_size=1) From 17973a227f5eba24dd4992ef80c68204f381e7bb Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Sun, 27 Aug 2023 23:17:20 -0700 Subject: [PATCH 18/52] Add --- tests/test_synchronize.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py index 1aea81cc..ee43d39a 100644 --- a/tests/test_synchronize.py +++ b/tests/test_synchronize.py @@ -6,26 +6,19 @@ from bmtrain.synchronize 
import gather_result def test_main(): - tensor = torch.rand(5, 5) * bmt.rank() - result = bmt.gather_result(tensor) - - tensor_slice_0 = tensor[:1, :1] - result_slice_0 = bmt.gather_result(tensor_slice_0) - assert torch.allclose(result[:1, :1], result_slice_0, atol=1e-6), "Assertion failed for tensor_slice_0" - tensor_slice_1 = tensor[:2, :2] - result_slice_1 = bmt.gather_result(tensor_slice_1) - assert torch.allclose(result[:2, :2], result_slice_1, atol=1e-6), "Assertion failed for tensor_slice_1" + ref_result = torch.rand(5 * bmt.world_size(), 5) + tensor = ref_result.chunk(bmt.world_size(), dim=0)[bmt.rank()] + real_result = bmt.gather_result(tensor) + assert torch.allclose(ref_result, real_result, atol=1e-6), "Assertion failed for real gather result error" - tensor_slice_2 = tensor[:3, :3] - result_slice_2 = bmt.gather_result(tensor_slice_2) - assert torch.allclose(result[:3, :3], result_slice_2, atol=1e-6), "Assertion failed for tensor_slice_2" + for i in range(4): + size = i + 1 + tensor_slice = tensor[:size, :size] + result_slice = bmt.gather_result(tensor_slice) + assert torch.allclose(real_result[:size, :size], result_slice, atol=1e-6), f"Assertion failed for tensor_slice_{i}" - tensor_slice_3 = tensor[:4, :4] - result_slice_3 = bmt.gather_result(tensor_slice_3) - assert torch.allclose(result[:4, :4], result_slice_3, atol=1e-6), "Assertion failed for tensor_slice_3" - - print("All slice tests passed!") + print("All tensor slice tests passed!") if __name__ == '__main__': bmt.init_distributed(pipe_size=1) From d43509706e65ba19de3ee4c8a1f2074669c4644b Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Mon, 28 Aug 2023 01:53:07 -0700 Subject: [PATCH 19/52] FIX: Revise the document path --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 538b54ba..066680f7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,7 +12,7 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath('../../..')) import recommonmark from recommonmark.transform import AutoStructify From abc7b901adfe01582a951013d456924eeea46ad2 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 28 Aug 2023 20:53:38 +0800 Subject: [PATCH 20/52] Fix middle hidden (#155) --- bmtrain/block_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 4e9a6c01..9d852df7 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -216,7 +216,7 @@ def next_module(self): return self._next_module[self._ref_count-1] def backward_release(self, flag): - if self._ref_count == 1: + if self._ref_count == 1 and self._backward_block_ctx is not None: self._backward_block_ctx.exit(flag, True) config['load_stream'].record_event(config['load_event']) self._ref_count -= 1 From df43d6d6505daff52d3e94776e34946a22ea88f7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Tue, 29 Aug 2023 11:13:55 +0800 Subject: [PATCH 21/52] Tensor Parallel (#153) --- bmtrain/block_layer.py | 44 +++++--- bmtrain/checkpointing.py | 9 +- bmtrain/init.py | 84 +++++++++++---- bmtrain/layer.py | 20 +++- bmtrain/loss/cross_entropy.py | 40 ++++--- bmtrain/nn/__init__.py | 7 +- bmtrain/nn/column_parallel_linear.py | 36 +++++++ bmtrain/nn/linear.py | 4 +- bmtrain/nn/parallel_cross_entropy_func.py | 122 ++++++++++++++++++++++ bmtrain/nn/parallel_embedding.py | 72 +++++++++++++ bmtrain/nn/parallel_linear_func.py | 108 
+++++++++++++++++++ bmtrain/nn/row_parallel_linear.py | 38 +++++++ bmtrain/param_init.py | 18 +++- bmtrain/parameter.py | 68 ++++++++++-- bmtrain/pipe_layer.py | 88 ++-------------- bmtrain/utils.py | 7 ++ example/layers/attention.py | 68 ++++++++---- example/layers/embedding.py | 8 +- example/layers/feedforward.py | 15 ++- example/models/gpt.py | 13 ++- example/train.py | 15 ++- tests/test_all.py | 2 + tests/test_column_parallel_linear.py | 55 ++++++++++ tests/test_row_parallel_linear.py | 54 ++++++++++ 24 files changed, 804 insertions(+), 191 deletions(-) create mode 100644 bmtrain/nn/column_parallel_linear.py create mode 100644 bmtrain/nn/parallel_cross_entropy_func.py create mode 100644 bmtrain/nn/parallel_embedding.py create mode 100644 bmtrain/nn/parallel_linear_func.py create mode 100644 bmtrain/nn/row_parallel_linear.py create mode 100644 tests/test_column_parallel_linear.py create mode 100644 tests/test_row_parallel_linear.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 9d852df7..e16ef6a8 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -1,6 +1,6 @@ from typing import Dict, Iterable, Iterator, Union, List -from .utils import round_up +from .utils import (round_up, tp_split_tensor) from .global_var import config import torch from . import nccl @@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev "total": 0, "storage_type": storage_type, "requires_grad": param.requires_grad, - "group": param.group + "group": param.group, + "zero_comm" : param._zero_comm } param_shape = param._original_shape @@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev offsets = {} # intialize storage buffers for kw, val in self._storage_info.items(): - val["world_size"] = config["world_size"] + comm = val['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + val["world_size"] = world_size partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] val["partition_size"] = partition_size - val["begin"] = config['rank'] * partition_size - val["end"] = (config['rank'] + 1) * partition_size + val["begin"] = rank * partition_size + val["end"] = (rank+1) * partition_size offsets[kw] = 0 @@ -302,13 +306,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] + param = it['parameter'] + tp_mode = param._tp_mode if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() - if input_param.shape != it["shape"]: + + verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape) + if input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' 
- .format(key, input_param.shape, it["shape"])) + .format(key, input_param.shape, verify_shape)) continue + param_st = it["offset"] param_end = it["offset"] + it["size"] kw_name = it["kw_name"] @@ -322,8 +331,14 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, continue # copy to buffer - assert input_param.numel() == it["size"] + verify_size = verify_shape.numel() + assert input_param.numel() == verify_size + contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous() + + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) offset_st = max(storage_st - param_st, 0) offset_end = min(storage_end - param_st, contiguous_param.numel()) @@ -331,7 +346,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, to_offset_st = offset_st + param_st - storage_st to_offset_end = offset_end + param_st - storage_st - + # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype @@ -398,7 +413,7 @@ def init_parameters(self): param = it["parameter"] if isinstance(param, DistributedParameter) and param._init_method is not None: # initialzie here - tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype) + tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype) param._init_method(tmp_tensor) param_st = it["offset"] param_end = it["offset"] + it["size"] @@ -412,16 +427,15 @@ def init_parameters(self): if param_end <= storage_st: continue + if param._tp_mode and param._tp_split_dim >= 0: + tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) # copy to buffer assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, tmp_tensor.numel()) + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - # copy to buffer # PyTorch 1.11 changed the API of storage.__getitem__ d_dtype = self._storage_params[kw_name].dtype diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py index b2c9ec07..550225be 100644 --- a/bmtrain/checkpointing.py +++ b/bmtrain/checkpointing.py @@ -39,10 +39,7 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal self._param_tensor = {} self._grad_tensor = {} self._need_release = False - if pipe: - self.comm = config["zero_comm"] - else: - self.comm = config["comm"] + def enter(self, flag=0, requires_grad=False): """ gather parameters @@ -74,7 +71,7 @@ def enter(self, flag=0, requires_grad=False): nccl.allGather( self.block._storage_params[kw].storage(), self._param_buffer[kw], - self.comm + val['zero_comm'] ) nccl.groupEnd() @@ -144,7 +141,7 @@ def exit(self, flag=0, backward=False): self._grad_buffer[kw], local_param.grad.storage(), "sum", - self.comm + val['zero_comm'] ) nccl.groupEnd() diff --git a/bmtrain/init.py b/bmtrain/init.py index f3c1faa7..a6214d78 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -10,12 +10,12 @@ from . import nccl from .synchronize import synchronize - def init_distributed( init_method : str = "env://", seed : int = 0, pipe_size: int = -1, num_micro_batches: int = None, + tp_size : int = 1, ): """Initialize distributed training. 
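For orientation, a minimal launch sketch using the new tp_size argument; the sizes are illustrative, and the process-group environment (RANK, WORLD_SIZE, master address/port) is assumed to be provided by torchrun or a similar launcher:

    import bmtrain as bmt

    # 8 GPUs split into 2 pipeline stages x 2-way tensor parallelism,
    # leaving 2-way ZeRO data parallelism inside each stage
    bmt.init_distributed(
        seed=0,
        pipe_size=2,
        tp_size=2,
    )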
This function will initialize the distributed training, set the random seed and global configurations. @@ -23,6 +23,9 @@ def init_distributed( Args: seed (int): The random seed. + pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups + num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode. + tp_size (int) : tp_size means the size of each of tensor parallel group **init_distributed** reads the following environment variables: @@ -70,10 +73,15 @@ def init_distributed( config["world_size"] = world_size config["calc_stream"] = torch.cuda.current_stream() config["load_stream"] = torch.cuda.Stream(priority=-1) + config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) + config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) config['barrier_stream'] = torch.cuda.Stream() config["load_event"] = torch.cuda.Event() + config["tp_size"] = tp_size if tp_size > 0 else 1 config["topology"] = topology(config) - config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank'] + config["zero_rank"] = config['topology'].get_group_rank("zero") + config["tp_rank"] = config['topology'].get_group_rank("tp") + config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) @@ -102,21 +110,34 @@ def init_distributed( unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) config['comm'] = nccl.commInitRank(unique_id, world_size, rank) + topo = config['topology'] + if config['pipe_enabled']: config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"] - topo = config['topology'] if topo.stage_id == 0: unique_id = nccl.getUniqueId() store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if topo.zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) - config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id) - else: - config['zero_comm'] = config['comm'] + + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + + if topo.zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) + config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id) + for i in range(world_size): if i == rank: print_dict("Initialization", { @@ -129,40 +150,57 @@ def init_distributed( "cpus": cpus_this_worker }) synchronize() + class topology: def __init__(self,config): # pipe_idx is the idx of the pipeline in the group 
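A worked example of the index arithmetic in this __init__, for 8 ranks with pipe_size=2 and tp_size=2 (leaving tp-zero data-parallel groups of size 2); the numbers follow directly from the formulas below and are only meant to make the grouping visible:

    world_size, pp_size, tp_size = 8, 2, 2
    stage_size = world_size // pp_size            # 4 ranks per pipeline stage
    for rank in range(world_size):
        stage_id = rank // stage_size             # pipeline stage of this rank
        pipe_idx = rank % stage_size              # which pipeline "column" it belongs to
        tp_id = rank % tp_size                    # position inside its tensor-parallel group
        tp_idx = rank // tp_size                  # which tensor-parallel group
        tp_zero_id = pipe_idx // tp_size          # rank inside its tp-zero (data-parallel) group
        print(f"rank {rank}: stage {stage_id}, tp group {tp_idx}, tp rank {tp_id}, dp rank {tp_zero_id}")
    # tensor-parallel groups: {0,1} {2,3} {4,5} {6,7}
    # pipeline groups:        {0,4} {1,5} {2,6} {3,7}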
self.rank = config['rank'] pp_size = config["pipe_size"] + tp_size = config["tp_size"] world_size = config["world_size"] - assert world_size % pp_size == 0, "The nums of GPUs must be divisible by the pipeline parallel size" - - dp_size = world_size // pp_size - topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda') - topo=topo.view(pp_size,dp_size) - self.pp_group=topo.transpose(0,1).reshape(-1,pp_size) - self.dp_group=topo - self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item() + assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + + dp_size = world_size // (pp_size * tp_size) + config['tp_zero_size'] = dp_size + config['zero_size'] = world_size // pp_size + topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') + topo=topo.view(pp_size,dp_size*tp_size) self.stages = config['pipe_size'] - self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes - self.zero_id = self.pipe_idx - self.zero_idx = self.stage_id + + stage_size = world_size // pp_size + for i in range(world_size): + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + self.zero_idx = self.stage_id + self.zero_id = self.pipe_idx + self.tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.tp_zero_id = self.pipe_idx // tp_size + self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 - self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() - self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist() + def get_group_id(self,group_name): if group_name == "pipe": return self.pipe_idx elif group_name == "zero": return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx def get_group_rank(self,group_name): if group_name == "pipe": return self.stage_id elif group_name == "zero": return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id def is_initialized() -> bool: return config["initialized"] diff --git a/bmtrain/layer.py b/bmtrain/layer.py index ebbef815..7de83e5e 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -1,6 +1,8 @@ import torch from .parameter import DistributedParameter +from .global_var import config import itertools +from .utils import tp_split_tensor class DistributedModule(torch.nn.Module): """ @@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module): def __getattr__(self, name: str): ret = super().__getattr__(name) # gather distributed parameters if not in CheckpointBlock - if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: + if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: return ret.gather() return ret @@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - if isinstance(param, DistributedParameter) and not param._in_checkpoint_block: - destination[prefix + name] = param.gather().detach().cpu() # sync operation + if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block: + if param._in_checkpoint_block: + destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation + else: + destination[prefix + name] = 
param.gather_all().detach().cpu() # sync operation else: destination[prefix + name] = param if keep_vars else param.detach().cpu() for name, buf in self._buffers.items(): @@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, for name, param in local_state.items(): key = prefix + name if key in state_dict: + tp_mode = param._tp_mode input_param = state_dict[key] if input_param.__class__.__name__ == "DistributedTensorWrapper": input_param = input_param.broadcast() @@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 'the shape in current model is {}.' .format(key, input_param.shape, param.shape)) continue - if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape: + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' - .format(key, input_param.shape, param.shape)) + .format(key, input_param.shape, verify_shape)) try: with torch.no_grad(): if isinstance(param, DistributedParameter): + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + input_param = tp_split_tensor(input_param, tp_split_dim) param._copy_data(input_param) else: param.copy_(input_param) diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 31223640..982a6469 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,6 +1,10 @@ from typing import Optional import torch from . import _function as F +from bmtrain.nn import parallel_cross_entropy_func +from bmtrain.global_var import config +from bmtrain.distributed import all_gather + class OpFusedCrossEntropy(torch.autograd.Function): """ CrossEntropy dim = 1 @@ -176,6 +180,7 @@ def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0, # TODO not supported yet inplace: bool = False, + parallel: bool = False, ) -> None: super().__init__() self.weight = weight @@ -183,21 +188,26 @@ def __init__(self, self.reduction = reduction self.label_smoothing = label_smoothing self.inplace = inplace + self.parallel = parallel def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - if input.dtype == torch.float32: - return torch.nn.functional.cross_entropy( - input, - target.long(), - weight=self.weight, - ignore_index=self.ignore_index, - reduction=self.reduction, - label_smoothing=self.label_smoothing) - - if self.inplace: - ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + if self.parallel: + target = all_gather(target, comm=config['tp_comm']).flatten(0,1) + ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) else: - ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + + if self.inplace: + ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor + else: + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor if self.weight is not None: if 
self.weight.dim() != 1 or self.weight.size(0) != input.size(1): @@ -208,6 +218,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = (target != self.ignore_index).int() ret = w * ret + + if self.parallel and config['tp_size'] > 1: + ret_list = ret.chunk(config['tp_size'], dim=0) + ret = ret_list[config['topology'].tp_id] + w_list = w.chunk(config['tp_size'], dim=0) + w = w_list[config['topology'].tp_id] if self.reduction == "none": return ret diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 67f9fdee..e22d8c55 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1 +1,6 @@ -from .linear import Linear +from .linear import Linear, OpLinear +from .column_parallel_linear import ColumnParallelLinear +from .row_parallel_linear import RowParallelLinear +from .parallel_embedding import ParallelEmbedding +from .parallel_cross_entropy_func import parallel_cross_entropy_func +from .parallel_linear_func import OpParallelLinear \ No newline at end of file diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py new file mode 100644 index 00000000..e8f554c8 --- /dev/null +++ b/bmtrain/nn/column_parallel_linear.py @@ -0,0 +1,36 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import ( + OpParallelLinear, + ReduceType) + +class ColumnParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.gather_input = gather_input + tp_size = config['tp_size'] + assert out_features % tp_size == 0 + self.out_features_per_partition = out_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, in_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=0, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=0, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.gather_input + split_input = False + reduce_output_type = None + return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features_per_partitions, self.bias is not None + ) + diff --git a/bmtrain/nn/linear.py b/bmtrain/nn/linear.py index faf0770e..cb04863a 100644 --- a/bmtrain/nn/linear.py +++ b/bmtrain/nn/linear.py @@ -2,7 +2,7 @@ import torch.nn.functional as F import bmtrain as bmt -class CustomLinear(torch.autograd.Function): +class OpLinear(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, bias=None): ctx.save_for_backward(x, weight, bias) @@ -35,7 +35,7 @@ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtyp self.register_parameter('bias', None) def forward(self, input): - return CustomLinear.apply(input, self.weight, self.bias) + return OpLinear.apply(input, self.weight, self.bias) def extra_repr(self) -> str: return 'in_features={}, out_features={}, bias={}'.format( diff --git 
a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py new file mode 100644 index 00000000..cd1f63bf --- /dev/null +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -0,0 +1,122 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather + +class ParallelCrossEntropyFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, target, label_smoothing=0.0): + comm = config['tp_comm'] + rank = config['topology'].tp_id + world_size = config['tp_size'] + + # local max + max_logits = torch.max(logits, dim=-1)[0] + # global max + max_logits = all_reduce(max_logits, op="max", comm=comm) + + logits = logits - max_logits.unsqueeze(dim=-1) + + partition_vocab_size = logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = logits + torch.exp(logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize and optionally smooth logits + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + vocab_size = exp_logits.size(-1) + if label_smoothing > 0: + """ + We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. + = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) + = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i + = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i + = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K + From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py + """ + assert 1.0 > label_smoothing > 0.0 + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + + # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. + log_probs = torch.log(exp_logits) + mean_log_probs = log_probs.mean(dim=-1) + loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs + + ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + # Store softmax, target-mask and masked-target for backward pass. + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. 
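        # What follows is the standard softmax cross-entropy gradient,
        #     dL/dlogits = softmax(logits) - one_hot(target),
        # restricted to the vocab shard this tensor-parallel rank owns:
        # softmax_update is non-zero only where the target id falls inside this
        # rank's partition, so exactly one rank subtracts the one-hot term for
        # each token (with an extra correction when label_smoothing > 0), and the
        # result is scaled by the incoming grad_output at the end.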
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors + label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + if label_smoothing > 0: + smoothing = label_smoothing * vocab_size / (vocab_size - 1) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update + average_grad = 1 / vocab_size + grad_2d[arange_1d, :] -= smoothing * average_grad + else: + grad_2d[arange_1d, masked_target_1d] -= softmax_update + + grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) + + return grad_input, None, None + + +def parallel_cross_entropy_func(logits, target, label_smoothing=0.0): + """ + Performs cross entropy loss when logits are split across tensor parallel ranks + + Arguments: + logits: logits split across tensor parallel ranks dimension is [batch * seq_len, hidden_size]. + target: correct vocab ids of dimseion [batch * seq_len]. + lobal_smoothing: smoothing factor, must be in range [0.0, 1.0). default is 0.0. + """ + out = ParallelCrossEntropyFunc.apply(logits.to(torch.float32), target, label_smoothing) + return out + diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py new file mode 100644 index 00000000..cd567b4e --- /dev/null +++ b/bmtrain/nn/parallel_embedding.py @@ -0,0 +1,72 @@ +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import math + +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather +from .parallel_linear_func import OpParallelLinear + +class ParallelEmbedding(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + super().__init__() + + self.dim_model = embedding_size + assert vocab_size % config['tp_size'] == 0 + self.vocab_size_per_partition = vocab_size // config['tp_size'] + self.start_index = config['topology'].tp_id * self.vocab_size_per_partition + self.end_index = (config['topology'].tp_id+1) * self.vocab_size_per_partition + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_mode=True, + tp_split_dim=0, + ) + + def forward(self, ids: torch.Tensor, gather_input=True): + """ + Args: + ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. + Return: + :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. 
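        Example (a sketch, assuming tp_size=2 and gather_input=True): every rank
            passes its own ``(batch_size, seq_len)`` slice of ids; the ids are
            gathered across the tensor-parallel group, looked up against the
            local vocab shard, summed with an all-reduce, and each rank gets back
            the ``(batch_size, seq_len, embedding_size)`` embeddings for the ids
            it passed in.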
+ """ # noqa: E501 + + if config['tp_size'] > 1: + if gather_input: + ids = all_gather(ids, comm=config['tp_comm']) + input_mask = (ids < self.start_index) | (ids >= self.end_index) + ids = ids.clone() - self.start_index + ids[input_mask] = 0 + + embeds = F.embedding(ids, self.weight) + + if config['tp_size'] > 1: + embeds[input_mask, :] = 0.0 + embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) + embed_list = embeds.chunk(config['tp_size'], dim=0) + embeds = embed_list[config['topology'].tp_id].flatten(0,1) + + return embeds.clone() + + def projection(self, x: torch.Tensor): + """ + Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. + Args: + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. + """ # noqa: E501 + gather_input = True + split_input = False + reduce_output_type = None + gather_output = False + out = OpParallelLinear.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) + return out diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py new file mode 100644 index 00000000..55741530 --- /dev/null +++ b/bmtrain/nn/parallel_linear_func.py @@ -0,0 +1,108 @@ +import torch +import torch.nn.functional as F +from bmtrain.global_var import config +from ..distributed import all_gather, all_reduce +from .. import nccl +import bmtrain as bmt +from enum import Enum + +class ReduceType(Enum): + ALL_REDUCE = 1 + REDUCE_SCATTER = 2 + +def preprocess_input(input, gather_input, split_input): + if gather_input: + input = all_gather(input, config['tp_comm']) + input = input.flatten(0, 1) + + if split_input: + all_input_list = input.chunk(config['tp_size'], dim=-1) + input = all_input_list[config['topology'].tp_id] + return input + +class OpParallelLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + if reduce_output_type is not None: + reduce_output_type = ReduceType(reduce_output_type) + + ctx.save_for_backward(input, weight, bias) + ctx.gather_output = gather_output + ctx.split_input = split_input + ctx.gather_input = gather_input + ctx.reduce_output_type = reduce_output_type + + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) + + if gather_output: + all_output_list = all_gather(out, config['tp_comm']) + all_output_list = all_output_list.chunk(config['tp_size'], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim()-1).flatten(0,1) + + if reduce_output_type is None: + return out + + if reduce_output_type == ReduceType.ALL_REDUCE: + nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) + return out + + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + shape = list(out.shape) + shape[0] = shape[0] // config['tp_size'] + reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) + return reduce_out + else: + assert False, "no support reduce type{}".format(reduce_output_type) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + gather_output = ctx.gather_output + + if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: + 
grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = grad_output.flatten(0, 1) + + if gather_output: + tp_size = config['tp_size'] + tp_id = config['topology'].tp_id + grad_output_list = grad_output.chunk(tp_size, dim=-1) + grad_output = grad_output_list[tp_id] + + grad_input = grad_weight = grad_bias = None + + if input.requires_grad or weight.requires_grad: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + + if input.requires_grad: + current_stream = torch.cuda.current_stream() + grad_all_input = grad_output.matmul(weight) + grad_input = torch.zeros_like(input) + if ctx.gather_input: + with torch.cuda.stream(config['tp_comm_stream']): + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_all_input.record_stream(config['tp_comm_stream']) + nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + else: + grad_input = grad_all_input + + if ctx.split_input: + with torch.cuda.stream(config['tp_comm_stream']): + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + grad_input = all_gather(grad_input, config['tp_comm']) + + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) + + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['tp_comm_stream']) + return grad_input, grad_weight, grad_bias, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py new file mode 100644 index 00000000..bbe94cfa --- /dev/null +++ b/bmtrain/nn/row_parallel_linear.py @@ -0,0 +1,38 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import ( + OpParallelLinear, + ReduceType) + +class RowParallelLinear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.split_input = split_input + self.all_reduce_output = all_reduce_output + tp_size = config['tp_size'] + assert in_features % tp_size == 0 + self.in_features_per_partition = in_features // tp_size + self.weight = bmt.DistributedParameter(torch.empty(self.out_features, self.in_features_per_partition, dtype=dtype, device="cuda"), init_method=torch.nn.init.xavier_normal_, tp_split_dim=1, tp_mode=True) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(self.out_features, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_, tp_split_dim=-1, tp_mode=True) + else: + self.register_parameter('bias', None) + + def forward(self, input): + gather_input = self.split_input + gather_output = False + reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER + out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type) + if self.bias is not None: + out = out + self.bias + return out + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features_per_partition, self.out_features, self.bias is not None + ) diff 
--git a/bmtrain/param_init.py b/bmtrain/param_init.py index 8b74c580..d5c86225 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -13,17 +13,27 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): continue with torch.no_grad(): partition_size = param.storage().size() - global_size = partition_size * config['world_size'] - + global_size = partition_size * config['tp_zero_size'] * config['tp_size'] tmp_storage = param.storage_type()(global_size) tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") - tmp_tensor.set_(tmp_storage, 0, param._original_shape) + tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) param._init_method(tmp_tensor) + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk(config['tp_size'], dim=param._tp_split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + + if param._tp_mode: + begin = config['tp_zero_rank'] + else: + begin = config['zero_rank'] + end = begin + 1 # Pytorch 1.11 changed the API of storage.__getitem__ torch.tensor([], dtype=param.dtype, device=param.device).set_(param.storage())[:] = \ - torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_storage)[partition_size * config['rank'] : partition_size * (config['rank'] + 1)] + torch.tensor([], dtype=param.dtype, device=param.device).set_(tmp_tensor.storage())[partition_size * begin : partition_size * end] # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) def iterate_parameters(model : torch.nn.Module): diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index f965cdac..b5e5e9ae 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -3,6 +3,7 @@ from .utils import round_up from .global_var import config from . 
import nccl +from .distributed import all_gather class DistributedParameter(torch.nn.Parameter): r""" @@ -31,7 +32,9 @@ def __new__(cls, data : torch.Tensor, requires_grad : bool = True, init_method : Optional[Callable[['DistributedParameter'], None]] = None, - group : Optional[str] = None + group : Optional[str] = None, + tp_mode : bool = False, + tp_split_dim : int = -1, ): if not config["initialized"]: raise RuntimeError("BMTrain is not initialized") @@ -39,14 +42,24 @@ def __new__(cls, num_of_elements = data.numel() cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") - cuda_storage_size = round_up(num_of_elements, config["world_size"]) // config["world_size"] + if tp_mode: + comm = config['tp_zero_comm'] + else: + comm = config['zero_comm'] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + cuda_storage_size = round_up(num_of_elements, world_size) // world_size original_shape = data.size() + tp_original_shape = original_shape + if tp_mode and tp_split_dim >= 0: + tp_original_shape = list(original_shape) + tp_original_shape[tp_split_dim] *= config['tp_size'] cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) - start_of_partition = cuda_storage_size * config["rank"] - end_of_partition = min(num_of_elements, cuda_storage_size * (config["rank"] + 1)) + start_of_partition = cuda_storage_size * rank + end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1)) # FX: cuda_tensor_size < 0 if num_of_elements is too small cuda_tensor_size = max(end_of_partition - start_of_partition, 0) @@ -60,7 +73,12 @@ def __new__(cls, setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) setattr(ret, "_in_checkpoint_block", False) - setattr(ret, "_group", group) + setattr(ret, "_group", group if not tp_mode else "tp") + + setattr(ret, "_tp_mode", tp_mode) + setattr(ret, "_zero_comm", comm) + setattr(ret, "_tp_split_dim", tp_split_dim) + setattr(ret, "_tp_original_shape", tp_original_shape) return ret @property @@ -83,24 +101,52 @@ def gather(self) -> torch.Tensor: current_stream.wait_stream(config['load_stream']) return output_tensor + def gather_all(self) -> torch.tensor: + zero_param = self.gather() + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return zero_param + + def tp_gather(self) -> torch.tensor: + if config['tp_size'] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config['tp_comm']) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config['tp_size'], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim()-1).flatten(0,1) + return output + else: + return output_tensor.flatten(0,1) + else: + return self + def _copy_data(self, data : torch.Tensor): self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) - class OpAllGather(torch.autograd.Function): @staticmethod def forward(ctx, value : DistributedParameter): assert isinstance(value, DistributedParameter) + comm = value._zero_comm #config['zero_comm'] + world_size = nccl.commCount(comm) + ctx.comm = comm + ctx.world_size = world_size partition_size = value.storage().size() - global_size = partition_size * config['world_size'] + global_size = partition_size * world_size storage = 
value.storage_type()(global_size) nccl.allGather( value.storage(), storage, - config['comm'] + comm ) output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") @@ -117,15 +163,15 @@ def backward(ctx, grad_output : torch.Tensor): grad_storage = grad_output.storage_type()(ctx.partition_size) grad_output_storage = grad_output.storage() - if grad_output_storage.size() == ctx.partition_size * config['world_size']: + if grad_output_storage.size() == ctx.partition_size * ctx.world_size: pass else: - grad_output_storage.resize_(ctx.partition_size * config['world_size']) + grad_output_storage.resize_(ctx.partition_size * ctx.world_size) nccl.reduceScatter( grad_output_storage, grad_storage, 'sum', - config['comm'] + ctx.comm ) grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 0a34ac46..c28ab87d 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -139,7 +139,11 @@ def forward(ctx, input, stage_id): def backward(ctx, grad_outputs): if not ctx.is_first_stage: send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs - send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) return grad_outputs, None class StagePostFunction(torch.autograd.Function): @@ -150,7 +154,11 @@ def forward(ctx, outputs, stage_id): ctx.is_last_stage = stage_id == config['pipe_size'] - 1 if not ctx.is_last_stage: send_data = outputs[0] if isinstance(outputs, tuple) else outputs - send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(current_stream) + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) return outputs @staticmethod @@ -189,7 +197,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: rank = config['rank'] topo = config['topology'] self.layer_ids = [] - pipe_group = topo.pp_group self.stages = topo.stages self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx @@ -218,11 +225,6 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self._modules[str(self.layer_ids[0])]._is_first_layer = True self._modules[str(self.layer_ids[-1])]._is_last_layer = True - self.partition_modules(self.layer_ids) - self.next_rank = pipe_group[self.pipe_idx, self.stage_id + 1] if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = pipe_group[self.pipe_idx, self.stage_id - 1] if self.stage_id > 0 else -1 - # self.micro_batches = config['num_micro_batches'] - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: @@ -295,76 +297,6 @@ def get_stage_by_layer_id(self, layer_id : int) -> int: else: return rest + (layer_id - rest * (part_len+1)) // part_len - def partition_modules(self, idxs) -> None: - for i in range(len(self)): - contiguous_params = {} - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - contiguous_params[kw] = storage_type(round_up(val["total"], config["world_size"] // config["pipe_size"])) - nccl.allGather( - 
self[i]._storage_params[kw].storage(), - contiguous_params[kw], - config["comm"] - ) - - if i not in idxs: - for name, param in self[i]._module.named_parameters(): - param.data = torch.tensor([], dtype = param.dtype, device = param.device) - for kw, val in self[i]._storage_info.items(): - val["begin"] = self.stage_id - val["end"] = self.stage_id + 1 - val["partition_size"] = 1 - val["total"] = val["world_size"] - dtype = self[i]._storage_params[kw].dtype - device = self[i]._storage_params[kw].device - self[i]._storage_params[kw] = \ - torch.nn.Parameter(torch.tensor([0], dtype = dtype, device=device)) - else: - for kw, val in self[i]._storage_info.items(): - storage_type = val["storage_type"] - val["world_size"] = config["world_size"] // config["pipe_size"] - partition_size = round_up(val["total"], val["world_size"]) // val["world_size"] - val["partition_size"] = partition_size - val["begin"] = config['zero_rank'] * partition_size - val["end"] = (config['zero_rank'] + 1) * partition_size - storage_param_buffer = storage_type(partition_size) - dtype = storage_param_buffer.dtype - device = storage_param_buffer.device - self[i]._storage_params[kw] = torch.nn.Parameter( - torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer) - ) - if val["requires_grad"]: - self[i]._storage_params[kw].requires_grad_(True) - else: - self[i]._storage_params[kw].requires_grad_(False) - ordered_parameters = list(self[i]._module.named_parameters()) - for idx, named_param in enumerate(ordered_parameters): - name, param = named_param - param_info = self[i]._param_info[idx] - kw_name = _get_param_kw(param) - storage_info = self[i]._storage_info[kw_name] - storage_st = storage_info["begin"] - storage_end = storage_info["end"] - param_st = param_info["offset"] - param_end = param_st + param_info["size"] - if not (param_st >= storage_end or param_end <= storage_st): - # copy offset in parameter storage - offset_st = max(storage_st - param_st, 0) - offset_end = min(storage_end - param_st, param_info["size"]) - assert offset_st < offset_end - to_offset_st = offset_st + param_st - storage_st - to_offset_end = offset_end + param_st - storage_st - d_dtype = self[i]._storage_params[kw_name].dtype - d_device = self[i]._storage_params[kw_name].device - param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self[i]._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) - param_info["begin"] = to_offset_st - param_info["end"] = (to_offset_end - to_offset_st,) - param.data[:] = \ - torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_params[kw], storage_st+to_offset_st, (to_offset_end - to_offset_st,))[:] - else: - param.data = torch.tensor([], dtype=param.dtype, device=param.device) - del contiguous_params - def _save_to_state_dict(self, destination, prefix, keep_vars): for name, module in self._modules.items(): idx = int(name) diff --git a/bmtrain/utils.py b/bmtrain/utils.py index a5687c7d..8cb87808 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -107,6 +107,13 @@ def see_memory(message, detail=False): """) torch.cuda.reset_peak_memory_stats() +def tp_split_tensor(tensor, split_dim): + tensor_list = tensor.chunk(config['tp_size'], dim=split_dim) + sub_tensor = tensor_list[config['topology'].tp_id].contiguous() + tmp_tensor = torch.empty(sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype) + tmp_tensor.copy_(sub_tensor) + return tmp_tensor + class AverageRecorder: """A utility class to record the average value of a quantity over 
time. diff --git a/example/layers/attention.py b/example/layers/attention.py index 243df3ea..8fbb7510 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -1,8 +1,14 @@ from typing import Optional import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather class Attention(bmt.DistributedModule): def __init__(self, @@ -12,11 +18,17 @@ def __init__(self, ) -> None: super().__init__() - self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) - self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) - self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) self.softmax = torch.nn.Softmax(dim=-1) self.num_heads = num_heads @@ -32,32 +44,48 @@ def forward(self, batch_size, seq_q, dim_model = hidden_q.size() seq_kv = hidden_kv.size(1) - h_q : torch.Tensor = self.project_q(hidden_q) - h_k : torch.Tensor = self.project_k(hidden_kv) - h_v : torch.Tensor = self.project_v(hidden_kv) + assert hidden_q.data_ptr() == hidden_kv.data_ptr() - h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) - h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) - h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) + + if config['tp_size'] > 1: + #batch_size will changed in TensorParallel + batch_size = h_v.shape[0] + + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) h_q = h_q.permute(0, 2, 1, 3).contiguous() h_k = h_k.permute(0, 2, 1, 3).contiguous() h_v = h_v.permute(0, 2, 1, 3).contiguous() - h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) - h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) - h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) score = torch.bmm( h_q, h_k.transpose(1, 2) ) score = score / math.sqrt(self.dim_head) - score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + score = 
score.view(batch_size, -1, seq_q, seq_kv) if position_bias is not None: - score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) - + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + if config['tp_size'] > 1: + with torch.no_grad(): + mask = all_gather(mask, config['tp_comm']).flatten(0,1) + score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -70,14 +98,14 @@ def forward(self, torch.scalar_tensor(0, device=score.device, dtype=score.dtype) ) - score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + score = score.view(-1, seq_q, seq_kv) h_out = torch.bmm( score, h_v ) - h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() - h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + h_out = h_out.view(batch_size, seq_q, -1) attn_out = self.project_out(h_out) return attn_out diff --git a/example/layers/embedding.py b/example/layers/embedding.py index 13c47384..f62151c4 100644 --- a/example/layers/embedding.py +++ b/example/layers/embedding.py @@ -77,11 +77,13 @@ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: if not projection: - return F.embedding( + out = F.embedding( input, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse) + return out else: - return F.linear(input, self.weight) / math.sqrt(self.embedding_dim) + out = F.linear(input, self.weight) + return out def extra_repr(self) -> str: s = '{num_embeddings}, {embedding_dim}' @@ -97,4 +99,4 @@ def extra_repr(self) -> str: s += ', sparse=True' return s.format(**self.__dict__) - \ No newline at end of file + diff --git a/example/layers/feedforward.py b/example/layers/feedforward.py index 99d2dc3b..e88d2495 100644 --- a/example/layers/feedforward.py +++ b/example/layers/feedforward.py @@ -1,16 +1,23 @@ import torch import bmtrain as bmt -from bmtrain.nn import Linear +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config class Feedforward(bmt.DistributedModule): def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: super().__init__() - self.w_in = Linear(dim_model, dim_ff, bias = bias, dtype=dtype) - self.w_out = Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) self.relu = torch.nn.ReLU() def forward(self, input : torch.Tensor) -> torch.Tensor: - return self.w_out(self.relu(self.w_in(input))) diff --git a/example/models/gpt.py b/example/models/gpt.py index 78d77a7d..64474ba8 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -1,6 +1,7 @@ import torch import bmtrain as bmt from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config class GPT(bmt.DistributedModule): def __init__(self, @@ -13,7 +14,10 @@ def __init__(self, self.max_distance = max_distance - self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + if config['tp_size'] > 1: + self.word_emb = 
bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ @@ -42,7 +46,10 @@ def forward(self, out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - logits = self.word_emb(out, projection=True) + if config['tp_size'] > 1: + logits = self.word_emb.projection(out) + else: + logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") - return logits \ No newline at end of file + return logits diff --git a/example/train.py b/example/train.py index 1a744e20..8aaf65e4 100644 --- a/example/train.py +++ b/example/train.py @@ -3,12 +3,13 @@ from models import GPT import time from bmtrain import optim +from bmtrain.global_var import config from bmtrain import inspect def main(): bmt.init_distributed( seed=0, - zero_level=2, + tp_size=2, ) model = GPT( @@ -24,7 +25,6 @@ def main(): ) bmt.init_parameters(model) - # print_inspect(model, "*") bmt.print_rank("Model memory") bmt.print_rank(torch.cuda.memory_summary()) @@ -52,7 +52,11 @@ def main(): if i == bmt.rank(): break - loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) @@ -77,7 +81,10 @@ def main(): ) batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) global_loss = bmt.sum_loss(loss).item() diff --git a/tests/test_all.py b/tests/test_all.py index 0ac386c0..fc9ab3e9 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,6 +26,8 @@ ("send_recv", 4), ("nccl_backward", 4), ("no_grad", 1), + ("column_parallel_linear", 2), + ("row_parallel_linear", 2), ("training", 4), ]) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py new file mode 100644 index 00000000..1c49570a --- /dev/null +++ b/tests/test_column_parallel_linear.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x[config['topology'].tp_id]) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(gather_output, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + x = torch.randn(tp_size, 8,8, device='cuda').requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(x, gather_output, ckp_path) + y2, 
weight_grad2, bias_grad2 = run_torch(x, ckp_path) + tp_rank = config['topology'].tp_id + if gather_output: + assert np.allclose(y1.detach().cpu().numpy(), y2.flatten(0,1).detach().cpu().numpy()) + else: + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=y2.dim()-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].flatten(0,1).detach().cpu().numpy()) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=0) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + bias_grad_list = bias_grad2.chunk(tp_size, dim=0) + assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + +def test_gather_output(): + run(True, 'linear.ckp') + +def test_no_gather_output(): + run(False, 'linear_no_gather.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_gather_output() + test_no_gather_output() + diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py new file mode 100644 index 00000000..f89b6dc5 --- /dev/null +++ b/tests/test_row_parallel_linear.py @@ -0,0 +1,54 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): + linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) + if use_checkpoint_block: + linear = bmt.CheckpointBlock(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(split_input, use_checkpoint_block, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + tp_rank = config['topology'].tp_id + x = torch.randn(8,8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=1) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy()) + +def test_split_input(): + run(True, False, 'row_parallel_linear.ckp') + run(True, True, 'row_parallel_linear.ckp') + +def test_no_split_input(): + run(False, False, 'row_parallel_linear_no_split.ckp') + run(False, True, 'row_parallel_linear_no_split.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_no_split_input() + test_split_input() + From 38461bc34893c293ae76b5398adfabbc2c29a5ee Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Tue, 29 Aug 2023 15:01:11 +0800 Subject: [PATCH 22/52] Add Bf16 Support (#136) --- .gitignore | 3 +- bmtrain/loss/_function.py | 74 ++++++--------- bmtrain/loss/cross_entropy.py | 30 ------ bmtrain/optim/_function.py | 81 ++++++++++++---- bmtrain/optim/adam.py | 63 +++++++------ 
bmtrain/optim/adam_offload.py | 83 +++++++++-------- bmtrain/optim/optim_manager.py | 6 +- csrc/bind.cpp | 18 ++-- csrc/cuda/adam_cuda.cu | 73 ++++++++++++++- csrc/cuda/bfloat16.cuh | 5 + csrc/cuda/cross_entropy.cu | 99 ++++++++++---------- csrc/cuda/has_inf_nan.cu | 84 ++++++++++++++--- csrc/cuda/reduce.cuh | 2 - csrc/include/adam_cpu.hpp | 163 ++++++++++++++++++++++----------- csrc/include/bind.hpp | 44 ++++++--- setup.py | 1 + tests/test_all.py | 2 + tests/test_has_inf_nan.py | 13 ++- tests/test_loss_func.py | 40 ++++---- tests/test_nccl_backward.py | 7 +- tests/test_optim.py | 87 ++++++++++++------ tests/test_optim_state.py | 2 +- 22 files changed, 613 insertions(+), 367 deletions(-) create mode 100644 csrc/cuda/bfloat16.cuh diff --git a/.gitignore b/.gitignore index 0222862f..2e8c0dcd 100644 --- a/.gitignore +++ b/.gitignore @@ -150,4 +150,5 @@ log .vscode !bmtrain/dist -tests/test_log.txt \ No newline at end of file +tests/test_log.txt +tests/*.opt \ No newline at end of file diff --git a/bmtrain/loss/_function.py b/bmtrain/loss/_function.py index 658ef242..e2b67bb8 100644 --- a/bmtrain/loss/_function.py +++ b/bmtrain/loss/_function.py @@ -2,16 +2,20 @@ from .. import C import torch CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda -def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None: - assert g_fp16.dtype == torch.float16, "g_fp16 must be a half tensor" +def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None: assert out.dtype == torch.uint8, "out must be a uint8 tensor" - assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(g_half), "g_fp16 must be contiguous and on cuda" assert CHECK_INPUT(out), "out must be contiguous and on cuda" mid = torch.zeros(1024, device=out.device, dtype=out.dtype) stream = torch.cuda.current_stream().cuda_stream - C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) - - + if g_half.dtype == torch.float16: + C.has_nan_inf_fp16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + elif g_half.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.has_nan_inf_bf16_launcher(g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + else: + raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}") def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor, softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None: @@ -19,9 +23,7 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Ten CHECK_INPUT(target) CHECK_INPUT(softmax) CHECK_INPUT(output) - assert input.dtype == torch.float16, "input must be a half tensor" assert target.dtype == torch.int32, "target must be an int tensor" - assert softmax.dtype == torch.float16, "softmax must be a half tensor" assert output.dtype == torch.float32, "output must be a float tensor" assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements" assert target.numel() == output.numel(), "target and output must have the same number of elements" @@ -30,43 +32,14 @@ def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Ten softmax_ptr = softmax.data_ptr() output_ptr = output.data_ptr() cuda_stream = torch.cuda.current_stream().cuda_stream - C.cross_entropy_forward_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) - -def 
cross_entropy_backward(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, - softmax: torch.Tensor, grad_input: torch.Tensor, ignore_index: int) -> None: - CHECK_INPUT(grad_output) - CHECK_INPUT(target) - CHECK_INPUT(softmax) - CHECK_INPUT(grad_input) - assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" - assert target.dtype == torch.int32, "target must be an int tensor" - assert softmax.dtype == torch.float16, "softmax must be a half tensor" - assert grad_input.dtype == torch.float16, "grad_input must be a half tensor" - assert grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements" - assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" - grad_output_ptr = grad_output.data_ptr() - target_ptr = target.data_ptr() - softmax_ptr = softmax.data_ptr() - grad_input_ptr = grad_input.data_ptr() - cuda_stream = torch.cuda.current_stream().cuda_stream - C.cross_entropy_backward_launcher(m, n, grad_output_ptr, target_ptr, softmax_ptr, grad_input_ptr, ignore_index, cuda_stream) - -def cross_entropy_forward_inplace(m: int, n: int, x: torch.Tensor, target: torch.Tensor, - output: torch.Tensor, ignore_index: int) -> None: - CHECK_INPUT(x) - CHECK_INPUT(target) - CHECK_INPUT(output) - assert x.dtype == torch.float16, "x must be a half tensor" - assert target.dtype == torch.int32, "target must be an int tensor" - assert output.dtype == torch.float32, "output must be a float tensor" - assert target.numel() == output.numel(), "target and output must have the same number of elements" - cuda_stream = torch.cuda.current_stream().cuda_stream - x_ptr = x.data_ptr() - output_ptr = output.data_ptr() - target_ptr = target.data_ptr() - output_ptr = output.data_ptr() - - C.cross_entropy_forward_inplace_launcher(m, n, x_ptr, target_ptr, output_ptr, ignore_index, cuda_stream) + if input.dtype == torch.float16: + C.cross_entropy_forward_fp16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + elif input.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.cross_entropy_forward_bf16_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + else: + raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}") def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, x: torch.Tensor, ignore_index: int) -> None: @@ -75,12 +48,17 @@ def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, ta CHECK_INPUT(x) assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" assert target.dtype == torch.int32, "target must be an int tensor" - assert x.dtype == torch.float16, "x must be a half tensor" assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" cuda_stream = torch.cuda.current_stream().cuda_stream grad_output_ptr = grad_output.data_ptr() target_ptr = target.data_ptr() x_ptr = x.data_ptr() - C.cross_entropy_backward_inplace_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) - + if x.dtype == torch.float16: + C.cross_entropy_backward_inplace_fp16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + elif x.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current 
GPU") + C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + else: + raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}") diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 982a6469..a2e123ad 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -36,36 +36,6 @@ def backward(ctx, grad_output : torch.Tensor): ) return (softmax, None, None) -class OpFusedCrossEntropyInplace(torch.autograd.Function): - """ - CrossEntropy dim = 1 - """ - @staticmethod - def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): - assert x.ndim == 2 - out = torch.empty(x.size(0), device=x.device, dtype=torch.float) - F.cross_entropy_forward_inplace( - x.size(0), x.size(1), - x, target, - out, - ignore_index, - ) # x is inplace modify to softmax result - ctx.ignore_index = ignore_index - ctx.save_for_backward(x, target) - return out # float tensor - - @staticmethod - def backward(ctx, grad_output : torch.Tensor): - grad_output = grad_output.contiguous() - softmax, target = ctx.saved_tensors - F.cross_entropy_backward_inplace( - softmax.size(0), softmax.size(1), - grad_output, target, - softmax, - ctx.ignore_index, - ) # softmax is inplace modify to grad_input - return (softmax, None, None) - class FusedCrossEntropy(torch.nn.Module): r"""This criterion computes the cross entropy loss between input and target. diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py index ee4b04a7..f04f9ca0 100644 --- a/bmtrain/optim/_function.py +++ b/bmtrain/optim/_function.py @@ -1,4 +1,3 @@ - from .. import C import torch CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda @@ -11,8 +10,8 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T assert m_fp32.is_contiguous(), "m_fp32 must be contiguous" assert v_fp32.is_contiguous(), "v_fp32 must be contiguous" assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" - assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor" - assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" + assert param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16, "param_fp16 must be float16/bfloat16 tensor" + assert g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16, "g_fp16 must be float16/bfloat16 tensor" assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor" assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor" @@ -26,22 +25,28 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step - C.adam_cpu_launcher( - param_fp32.numel(), - param_fp32.data_ptr(), - param_fp16.data_ptr(), - g_fp16.data_ptr(), - m_fp32.data_ptr(), - v_fp32.data_ptr(), - beta1, beta2, - eps, lr, - scale, - weight_decay, - bias_correction1, - bias_correction2, - ) + if g_fp16.dtype == torch.float16: + launcher = C.adam_cpu_fp16_launcher + elif g_fp16.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + launcher = C.adam_cpu_bf16_launcher + launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + 
g_fp16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + ) -def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, +def adam_fp16(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, weight_decay: float, step: int) -> None: assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" @@ -61,7 +66,7 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step stream = torch.cuda.current_stream().cuda_stream - C.adam_launcher( + C.adam_fp16_launcher( param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr(), @@ -76,3 +81,41 @@ def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tenso bias_correction2, stream ) + +def adam_bf16(param_fp32: torch.Tensor, param_bf16: torch.Tensor, g_bf16: torch.Tensor, m_fp32: torch.Tensor, + v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, + weight_decay: float, step: int) -> None: + assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert param_bf16.dtype == torch.bfloat16, "param_fp16 must be float16 tensor" + assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor" + assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert param_fp32.numel() == param_bf16.numel(), "param_fp32 and param_bf16 must have the same number of elements" + assert param_fp32.numel() == g_bf16.numel(), "param_fp32 and g_fp16 must have the same number of elements" + assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_m_fp32 must have the same number of elements" + assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + stream = torch.cuda.current_stream().cuda_stream + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.adam_bf16_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_bf16.data_ptr(), + g_bf16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + stream + ) \ No newline at end of file diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index b63a4f51..a3138980 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -40,8 +40,9 @@ def _on_justify_scale(self, old_scale, new_scale): if p in self.state: state = self.state[p] if len(state) > 0: - state['exp_avg'] *= delta - state['exp_avg_sq'] *= delta + if p.dtype == torch.float16: + state['exp_avg'] *= delta + state['exp_avg_sq'] *= delta @torch.no_grad() def step(self, closure=None, scale=1): @@ -63,45 +64,32 @@ def step(self, closure=None, scale=1): if p.grad is 
not None and p.requires_grad: if p.grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('Adam only supports fp32 or fp16 gradients') + if p.dtype not in [torch.float32, torch.half, torch.bfloat16]: + raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') state = self.state[p] # Lazy state initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros(p.size(), dtype=p.dtype, device=p.device) # on device + if p.dtype == torch.float16: + state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float16, device=p.device) # on device + else: + state['exp_avg'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device) # on device - - if p.dtype == torch.half: + state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device=p.device)# on device + + if p.dtype != torch.float32: state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device=p.device) # on device state['_param_fp32'].copy_(p) # update the steps for each param group update - state['step'] += 1 - if ('maximize' in group) and (group['maximize'] is True): grad = -p.grad else: grad = p.grad - if p.dtype == torch.half: - F.adam( - state["_param_fp32"], # fp32 - p, # fp16 - grad, # fp16 - state['exp_avg'], # fp16: m - state["exp_avg_sq"], # fp32: v - group['betas'][0], group['betas'][1], - group['eps'], - 0.0 if state["step"] <= self._hold_steps else group['lr'], - scale, - group['weight_decay'], - state['step'] - ) - else: + if p.dtype == torch.float32: other_kwargs = {} if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: other_kwargs['maximize'] = False @@ -116,11 +104,30 @@ def step(self, closure=None, scale=1): amsgrad=False, beta1=group['betas'][0], beta2=group['betas'][1], - lr=0.0 if state["step"] <= self._hold_steps else group['lr'], + lr=0.0 if state["step"] < self._hold_steps else group['lr'], weight_decay=group['weight_decay'], eps=group['eps'], **other_kwargs ) + state['step'] += 1 + else: + f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16 + state['step'] += 1 + f( + state["_param_fp32"], # fp32 + p, # fp16 + grad, # fp16 + state['exp_avg'], # fp16: m + state["exp_avg_sq"], # fp32: v + group['betas'][0], group['betas'][1], + group['eps'], + 0.0 if state["step"] < self._hold_steps else group['lr'], + scale, + group['weight_decay'], + state['step'] + ) + + return loss @@ -159,11 +166,11 @@ def load_state_dict(self, state_dict: dict) -> None: if k in id_map: param = id_map[k] - if param.dtype == torch.half and "_param_fp32" not in v: + if param.dtype != torch.float32 and "_param_fp32" not in v: v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device=param.device) v["_param_fp32"].copy_(param) - for name, dtype in [("exp_avg", param.dtype), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + for name, dtype in [("exp_avg", torch.float16 if param.dtype == torch.float16 else torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: if name in v: v[name] = v[name].to(param.device).to(dtype) diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index e33219bf..5b34a287 100644 --- a/bmtrain/optim/adam_offload.py +++ 
b/bmtrain/optim/adam_offload.py @@ -54,8 +54,8 @@ def step(self, closure=None, scale=1): if p.grad is not None and p.requires_grad: if p.grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('Adam only supports fp32 or fp16 gradients') + if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise RuntimeError('Adam only supports fp32, fp16 and bf16 gradients') state = self.state[p] # Lazy state initialization @@ -66,19 +66,19 @@ def step(self, closure=None, scale=1): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros(p.size(), dtype=torch.float32, device="cpu") # on host - if p.dtype == torch.half: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host + if p.dtype == torch.float32: + state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host state['_param_fp32'].copy_(p) # placeholder - state["_param_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host - state["_grad_fp16"] = torch.empty(p.size(), dtype=torch.float16, pin_memory=True) # on host + state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host else: - state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host + state['_param_fp32'] = torch.empty(p.size(), dtype=torch.float32, device="cpu") # on host state['_param_fp32'].copy_(p) # placeholder - state["_grad_fp32"] = torch.empty(p.size(), dtype=torch.float32, pin_memory=True) # on host + state["_param_fp16"] = torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host + state["_grad_fp16"] = torch.empty(p.size(), dtype=p.dtype, pin_memory=True) # on host if p not in self._events: self._events[p] = torch.cuda.Event() @@ -87,39 +87,18 @@ def step(self, closure=None, scale=1): # transfer parameters to host asynchronously for param, state, event, _, _, _, _, _ in update_params: - if param.dtype == torch.half: - state["_grad_fp16"].copy_(param.grad, non_blocking=True) - else: + if param.dtype == torch.float32: state["_grad_fp32"].copy_(param.grad, non_blocking=True) + else: + state["_grad_fp16"].copy_(param.grad, non_blocking=True) torch.cuda.current_stream().record_event(event) for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params: # wait for transfer to host event.synchronize() - state["step"] += 1 - # update parameters - if param.dtype == torch.half: - if ('maximize' in group) and (group['maximize'] is True): - grad = -state["_grad_fp16"] - else: - grad = state["_grad_fp16"] - F.adam_cpu( - state["_param_fp32"].view(-1), - state["_param_fp16"].view(-1), - grad.view(-1), - state["exp_avg"].view(-1), - state["exp_avg_sq"].view(-1), - beta1, beta2, - eps, 0.0 if state["step"] <= self._hold_steps else lr, - scale, - weight_decay, - state["step"] - ) - # transfer parameters back to device asynchronously - param.copy_(state["_param_fp16"], non_blocking=True) - else: + if param.dtype == torch.float32: state["_grad_fp32"].mul_(1.0 / scale) if ('maximize' in group) and (group['maximize'] is True): grad = -state["_grad_fp32"] @@ -139,13 +118,35 @@ def step(self, closure=None, scale=1): amsgrad=False, beta1=beta1, beta2=beta2, - lr=0.0 if state["step"] <= self._hold_steps else lr, + lr=0.0 if state["step"] < self._hold_steps else lr, weight_decay=weight_decay, eps=eps, **other_kwargs ) # transfer parameters back to 
device asynchronously param.copy_(state["_param_fp32"], non_blocking=True) + state["step"] += 1 + else: + state["step"] += 1 + if ('maximize' in group) and (group['maximize'] is True): + grad = -state["_grad_fp16"] + else: + grad = state["_grad_fp16"] + F.adam_cpu( + state["_param_fp32"].view(-1), + state["_param_fp16"].view(-1), + grad.view(-1), + state["exp_avg"].view(-1), + state["exp_avg_sq"].view(-1), + beta1, beta2, + eps, 0.0 if state["step"] < self._hold_steps else lr, + scale, + weight_decay, + state["step"] + ) + # transfer parameters back to device asynchronously + param.copy_(state["_param_fp16"], non_blocking=True) + return loss @@ -193,15 +194,14 @@ def load_state_dict(self, state_dict: dict) -> None: v[name] = v[name].to("cpu").to(dtype) state[param] = v - if param.dtype == torch.half: + if param.dtype == torch.float32: + state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() # on host # initialize placeholders - state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host - state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host + state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host else: - state[param]["_param_fp32"] = state[param]["_param_fp32"].pin_memory() - # initialize placeholders - state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host + state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host + state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host else: state[k] = v @@ -254,5 +254,4 @@ def cut_states(state): #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu def zero_grad(self, set_to_none: bool = False): - super().zero_grad(set_to_none=set_to_none) - + super().zero_grad(set_to_none=set_to_none) \ No newline at end of file diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 78ad15f8..9b7a3120 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -11,9 +11,9 @@ def check_overflow(param_groups): has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] for group in param_groups: for p in group['params']: - if p.grad is not None and p.dtype == torch.half: # TODO support other types - has_inf_nan(p.grad, has_inf_or_nan) - + if p.grad is not None: + if p.dtype != torch.float: + has_inf_nan(p.grad, has_inf_or_nan) if "comm" in config: nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"]) diff --git a/csrc/bind.cpp b/csrc/bind.cpp index 8324ba52..73f79a61 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -1,13 +1,17 @@ #include "include/bind.hpp" PYBIND11_MODULE(C, m) { - m.def("has_nan_inf_launcher",&has_nan_inf_launcher,"has nan inf"); - m.def("adam_launcher", &adam_launcher, "adam function cpu"); - m.def("adam_cpu_launcher", &adam_cpu_launcher, "adam function cpu"); - m.def("cross_entropy_forward_launcher", &cross_entropy_forward_launcher, "cross entropy forward"); - m.def("cross_entropy_backward_launcher", &cross_entropy_backward_launcher, "cross entropy backward"); - m.def("cross_entropy_forward_inplace_launcher", &cross_entropy_forward_inplace_launcher, "cross entropy forward inplace"); - m.def("cross_entropy_backward_inplace_launcher", &cross_entropy_backward_inplace_launcher, "cross entropy backward inplace"); + 
m.def("is_bf16_supported",&is_bf16_supported,"whether bf16 supported"); + m.def("has_nan_inf_fp16_launcher",&has_nan_inf_fp16_launcher,"has nan inf"); + m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf bf16"); + m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu"); + m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu"); + m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu"); + m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu"); + m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward"); + m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward"); + m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace"); + m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace"); m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); diff --git a/csrc/cuda/adam_cuda.cu b/csrc/cuda/adam_cuda.cu index 0ab55934..0510ac12 100644 --- a/csrc/cuda/adam_cuda.cu +++ b/csrc/cuda/adam_cuda.cu @@ -1,5 +1,7 @@ -#include #include +#include +#include +#include "bfloat16.cuh" namespace { // blocks , threads @@ -8,8 +10,8 @@ __global__ void adam_fp32_accum( const half *g, // (n) half *m, // (n) float *v, // (n) - float* param, // (n) - half* param_h, // (n) + float *param, // (n) + half *param_h, // (n) float beta1, float beta2, float eps, @@ -33,9 +35,45 @@ __global__ void adam_fp32_accum( m[col] = __float2half(local_m); } } + +__global__ void adam_fp32_accum_bf16( + int32_t n, + const std::uintptr_t g_ptr, // (n) + float *m, // (n) + float *v, // (n) + float *param, // (n) + std::uintptr_t param_h_ptr, // (n) + float beta1, + float beta2, + float eps, + float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* g = reinterpret_cast(g_ptr); + __nv_bfloat16* param_h = reinterpret_cast<__nv_bfloat16*>(param_h_ptr); + int32_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (col < n) { + float local_g = __bfloat162float(g[col]) / scale; // real_g + float local_m = beta1 * m[col] + (1 - beta1) * local_g; // real_m + float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v + float local_p = param[col]; + local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p; + + param_h[col] = __float2bfloat16(local_p); + param[col] = local_p; + v[col] = local_v; + m[col] = local_m; + } +#endif +} + } -void adam_launcher( +void adam_fp16_launcher( int n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -60,4 +98,29 @@ void adam_launcher( dim3 block_size = dim3(threads, 1, 1); dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); adam_fp32_accum<<(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); -} \ No newline at end of file +} + +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float 
scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +) { + if (n <= 0) return; + auto m_ptr = reinterpret_cast(m_fp32); + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + adam_fp32_accum_bf16<<(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/csrc/cuda/bfloat16.cuh b/csrc/cuda/bfloat16.cuh new file mode 100644 index 00000000..564d8bec --- /dev/null +++ b/csrc/cuda/bfloat16.cuh @@ -0,0 +1,5 @@ +#include +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +#include +#define BF16_SUPPORT +#endif \ No newline at end of file diff --git a/csrc/cuda/cross_entropy.cu b/csrc/cuda/cross_entropy.cu index c0b742ac..bdd5a08e 100644 --- a/csrc/cuda/cross_entropy.cu +++ b/csrc/cuda/cross_entropy.cu @@ -1,11 +1,12 @@ -#include #include "reduce.cuh" #include -#include +#include +#include +#include "bfloat16.cuh" namespace { // blocks , threads<1024> -__global__ void cross_entropy_forward( +__global__ void cross_entropy_forward_fp16( int64_t n, const half *input, // (m, n) const int32_t *target, // (m) @@ -42,12 +43,11 @@ __global__ void cross_entropy_forward( } // blocks , threads<1024> -__global__ void cross_entropy_backward( +__global__ void cross_entropy_backward_inplace_fp16( int64_t n, const float *grad_output, // (m) const int32_t *target, // (m) - const half *softmax, // (m, n) - half *grad_input, // (m, n) + half *x, // (m, n) int32_t ignore_index ) { int64_t base_idx = blockIdx.x * n; @@ -56,83 +56,99 @@ __global__ void cross_entropy_backward( if (t == ignore_index) { half v = __float2half(0.); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - grad_input[base_idx + i] = v; + x[base_idx + i] = v; } } else { half v = __float2half(grad_output[blockIdx.x]); + __syncthreads(); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - grad_input[base_idx + i] = i==t ? __hsub(__hmul(softmax[base_idx + i], v), v) : __hmul(softmax[base_idx + i], v); + x[base_idx + i] = i==t ? 
__hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); } } } // blocks , threads<1024> -__global__ void cross_entropy_forward_inplace( +__global__ void cross_entropy_forward_bf16( int64_t n, - half *x, // (m, n) + std::uintptr_t input_ptr, // (m, n) const int32_t *target, // (m) + std::uintptr_t softmax_ptr, // (m, n) float *output, // (m) int32_t ignore_index ) { +#ifdef BF16_SUPPORT + __nv_bfloat16* input = reinterpret_cast<__nv_bfloat16*>(input_ptr); + __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); int64_t base_idx = blockIdx.x * n; float local_max = -INFINITY; for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - local_max = fmaxf(__half2float(x[base_idx + i]), local_max); + local_max = fmaxf(__bfloat162float(input[base_idx + i]), local_max); } + local_max = fmaxf(block_allreduce_max(local_max), -1e6); float local_sum = 0; for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - local_sum += expf(__half2float(x[base_idx + i]) - local_max); + local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max); } local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(input[base_idx + i]) - local_max) / local_sum ); + } if (threadIdx.x == 0) { if (target[blockIdx.x] != ignore_index) { - output[blockIdx.x] = -__half2float(x[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); + output[blockIdx.x] = -__bfloat162float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); } else { output[blockIdx.x] = 0; } } - - __syncthreads(); - - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = __float2half( expf(__half2float(x[base_idx + i]) - local_max) / local_sum ); - } +#endif } // blocks , threads<1024> -__global__ void cross_entropy_backward_inplace( +__global__ void cross_entropy_backward_inplace_bf16( int64_t n, const float *grad_output, // (m) const int32_t *target, // (m) - half *x, // (m, n) + std::uintptr_t x_ptr, // (m, n) int32_t ignore_index ) { +#ifdef BF16_SUPPORT + __nv_bfloat16* x = reinterpret_cast<__nv_bfloat16*>(x_ptr); int64_t base_idx = blockIdx.x * n; int32_t t = target[blockIdx.x]; if (t == ignore_index) { - half v = __float2half(0.); + __nv_bfloat16 v = __float2bfloat16(0.); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { x[base_idx + i] = v; } } else { - half v = __float2half(grad_output[blockIdx.x]); + #if __CUDA_ARCH__ >= 800 + __nv_bfloat16 v = __float2bfloat16(grad_output[blockIdx.x]); __syncthreads(); for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { x[base_idx + i] = i==t ? __hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); } + #else + float v = grad_output[blockIdx.x]; + __syncthreads(); + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2bfloat16(i==t ? 
(__bfloat162float(x[base_idx + i])*v)-v : __bfloat162float(x[base_idx + i])*v); + } + #endif } +#endif } } -void cross_entropy_forward_launcher( +void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, std::uintptr_t input, std::uintptr_t target, @@ -146,48 +162,40 @@ void cross_entropy_forward_launcher( auto softmax_ptr = reinterpret_cast(softmax); auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - cross_entropy_forward<<(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); + cross_entropy_forward_fp16<<(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); } -void cross_entropy_backward_launcher( +void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, - std::uintptr_t softmax, - std::uintptr_t grad_input, + std::uintptr_t x, int32_t ignore_index, std::uintptr_t stream ) { - // auto output_ptr = grad_output.data_ptr(); auto output_ptr = reinterpret_cast(grad_output); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - auto softmax_ptr = reinterpret_cast(softmax); - auto input_ptr = reinterpret_cast(grad_input); + auto x_ptr = reinterpret_cast(x); int32_t threads = 1024; - cross_entropy_backward<<(stream)>>>(n, output_ptr, target_ptr, softmax_ptr, input_ptr, ignore_index); + cross_entropy_backward_inplace_fp16<<(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); } -void cross_entropy_forward_inplace_launcher( +void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t x, + std::uintptr_t input, std::uintptr_t target, + std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ) { - // auto x_ptr = reinterpret_cast(x.data_ptr()); - auto x_ptr = reinterpret_cast(x); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - // auto output_ptr = output.data_ptr(); auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - // auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_forward_inplace<<(stream)>>>(n, x_ptr, target_ptr, output_ptr, ignore_index); + cross_entropy_forward_bf16<<(stream)>>>(n, input, target_ptr, softmax, output_ptr, ignore_index); } -void cross_entropy_backward_inplace_launcher( +void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, @@ -195,13 +203,8 @@ void cross_entropy_backward_inplace_launcher( int32_t ignore_index, std::uintptr_t stream ) { - // auto output_ptr = grad_output.data_ptr(); auto output_ptr = reinterpret_cast(grad_output); - // auto target_ptr = target.data_ptr(); auto target_ptr = reinterpret_cast(target); - // auto x_ptr = reinterpret_cast(x.data_ptr()); - auto x_ptr = reinterpret_cast(x); int32_t threads = 1024; - // auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_backward_inplace<<(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); + cross_entropy_backward_inplace_bf16<<(stream)>>>(n, output_ptr, target_ptr, x, ignore_index); } \ No newline at end of file diff --git a/csrc/cuda/has_inf_nan.cu b/csrc/cuda/has_inf_nan.cu index b0e906ff..32bc5a5f 100644 --- a/csrc/cuda/has_inf_nan.cu +++ b/csrc/cuda/has_inf_nan.cu @@ -1,16 +1,18 @@ -#include #include +#include +#include +#include +#include "bfloat16.cuh" namespace{ __inline__ __device__ bool isnan_(half v) { -#if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 - return __hisnan(v); -#else - - return !__heq(v, v); -#endif + 
#if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 + return __hisnan(v); + #else + return !__heq(v, v); + #endif } - + __inline__ __device__ int8_t warpReduceAny(int8_t x) { for (int offset = warpSize/2; offset > 0; offset /= 2) x |= __shfl_down_sync(0xFFFFFFFF, x, offset); @@ -30,7 +32,7 @@ __inline__ __device__ float blockReduceAny(int8_t x) { } // grid , thread<1024> -__global__ void bmt_has_nan_inf_1( +__global__ void bmt_has_nan_inf_fp16( int32_t n, const half* inp, // (n,) uint8_t* mid // (1024,) @@ -53,7 +55,7 @@ __global__ void bmt_has_nan_inf_1( } // grid <1>, thread<1024> -__global__ void bmt_has_nan_inf_2( +__global__ void bmt_has_nan_inf_reduce( const uint8_t* mid, // (1024,) uint8_t* out ) { @@ -64,9 +66,39 @@ __global__ void bmt_has_nan_inf_2( } } +// grid , thread<1024> +__global__ void bmt_has_nan_inf_bf16( + int32_t n, + const uintptr_t inp, // (n,) + uint8_t* mid // (1024,) +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* bf_inp = reinterpret_cast(inp); + int32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t span = blockDim.x * gridDim.x; + + int8_t r = 0; + for (int i = gid; i < n; i += span) { + __nv_bfloat16 v = bf_inp[i]; + #if __CUDA_ARCH__ >= 800 + if (__hisinf(v) || __hisnan(v)) { + #else + if (isinf(__bfloat162float(v)) || isnan(__bfloat162float(v))) { + #endif + r = 1; + break; + } + } + r = blockReduceAny(r); + if (threadIdx.x == 0) { + mid[blockIdx.x] = r; + } +#endif +} + } -void has_nan_inf_launcher( +void has_nan_inf_fp16_launcher( int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, @@ -82,6 +114,32 @@ void has_nan_inf_launcher( dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); - bmt_has_nan_inf_1<<(stream)>>>(n, g_ptr, mid_ptr); - bmt_has_nan_inf_2<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); + bmt_has_nan_inf_fp16<<(stream)>>>(n, g_ptr, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); +} + +void has_nan_inf_bf16_launcher( + int32_t n, + std::uintptr_t g_bf16, + std::uintptr_t mid, + std::uintptr_t out, + std::uintptr_t stream +) { + if (n <= 0) return; + auto mid_ptr = reinterpret_cast(mid); + auto out_ptr = reinterpret_cast(out); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); + + bmt_has_nan_inf_bf16<<(stream)>>>(n, g_bf16, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); +} + +int is_bf16_supported() { +#ifdef BF16_SUPPORT + return 1; +#endif + return 0; } \ No newline at end of file diff --git a/csrc/cuda/reduce.cuh b/csrc/cuda/reduce.cuh index 095e8593..a9c4c15b 100644 --- a/csrc/cuda/reduce.cuh +++ b/csrc/cuda/reduce.cuh @@ -1,5 +1,3 @@ -#include - namespace { const int WARP_SZ = 32; diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index d95cf637..1e497bb3 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -2,10 +2,9 @@ #include #include #include -#include +#include #include -#include -#include +#include #include #include #include @@ -69,8 +68,7 @@ inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F } } - - +// fp32 -> fp16 inline uint16_t fp16_ieee_from_fp32_value(float f) { // const float scale_to_inf = 0x1.0p+112f; // const float scale_to_zero = 0x1.0p-110f; @@ -84,45 +82,55 @@ inline uint16_t 
fp16_ieee_from_fp32_value(float f) { float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - const uint32_t w = (uint32_t)fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } + const uint32_t w = (uint32_t)fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = (uint32_t)fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return static_cast( - (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) - ); - } + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = (uint32_t)fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) + ); +} +// fp16 -> fp32 inline float fp16_ieee_to_fp32_value(uint16_t h) { - const uint32_t w = (uint32_t)h << 16; const uint32_t sign = w & UINT32_C(0x80000000); const uint32_t two_w = w + w; const uint32_t exp_offset = UINT32_C(0xE0) << 23; const float exp_scale = 0x1.0p-112f; - const float normalized_value = - fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; const uint32_t magic_mask = UINT32_C(126) << 23; const float magic_bias = 0.5f; - const float denormalized_value = - fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; const uint32_t denormalized_cutoff = UINT32_C(1) << 27; const uint32_t result = - sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) - : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); + sign | (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +// fp32 -> bf16 +inline uint16_t bf16_from_fp32_value(float f){ + return *reinterpret_cast(&f) >> 16; +} + +// bf16 -> fp32 +inline float bf16_to_fp32_value(uint16_t h){ + uint32_t src = h; + src <<= 16; + return *reinterpret_cast(&src); } void adam_cpu_0( @@ -141,23 +149,58 @@ void adam_cpu_0( ){ int64_t span = 1; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { - for (int64_t j = start; j < end; j += span) { - for (int64_t i = j; i < end; i++) { - float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; - float m = m_fp32_ptr[i]; - float v = v_fp32_ptr[i]; - float p = param_fp32_ptr[i]; - m = beta1 * m + (1 - beta1) * g; - v = beta2 * v + (1 - beta2) * g * g; - p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; - param_fp32_ptr[i] = p; - param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); - m_fp32_ptr[i] = m; - v_fp32_ptr[i] = v; - } + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } break; // must break here } - }); + }); +} + +void adam_cpu_bf16_0( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_bf16_ptr, + uint16_t* g_bf16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + int64_t span = 1; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_bf16_ptr[i] = bf16_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } + }); } static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( @@ -223,7 +266,8 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( _mm256_storeu_ps(&m_fp32_ptr[j], m); _mm256_storeu_ps(&v_fp32_ptr[j], v); } - }}); + } + }); } static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( @@ -293,13 +337,10 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( _mm512_storeu_ps(&v_fp32_ptr[j], v); } } - }); + }); } - - - -void adam_cpu_launcher( +void adam_cpu_fp16_launcher( int64_t n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -329,4 +370,24 @@ void adam_cpu_launcher( } } - +void adam_cpu_bf16_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto m_fp32_ptr = 
reinterpret_cast(m_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); + auto param_bf16_ptr = reinterpret_cast(param_bf16); + auto g_bf16_ptr = reinterpret_cast(g_bf16); + adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp index 0929de91..94d6af95 100644 --- a/csrc/include/bind.hpp +++ b/csrc/include/bind.hpp @@ -1,19 +1,22 @@ -#include +#include #include "nccl.hpp" #include "adam_cpu.hpp" -void has_nan_inf_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); +int is_bf16_supported(); -void cross_entropy_backward_launcher( +void has_nan_inf_fp16_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); +void has_nan_inf_bf16_launcher(int32_t n,std::uintptr_t g_bf16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); + +void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, - std::uintptr_t grad_output, + std::uintptr_t input, std::uintptr_t target, std::uintptr_t softmax, - std::uintptr_t grad_input, + std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ); -void cross_entropy_backward_inplace_launcher( +void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, @@ -21,24 +24,24 @@ void cross_entropy_backward_inplace_launcher( int32_t ignore_index, std::uintptr_t stream ); - void cross_entropy_forward_inplace_launcher( +void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t x, + std::uintptr_t input, std::uintptr_t target, + std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, std::uintptr_t stream ); -void cross_entropy_forward_launcher( +void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, - std::uintptr_t input, + std::uintptr_t grad_output, std::uintptr_t target, - std::uintptr_t softmax, - std::uintptr_t output, + std::uintptr_t x, int32_t ignore_index, std::uintptr_t stream ); -void adam_launcher( +void adam_fp16_launcher( int n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, @@ -52,4 +55,19 @@ void adam_launcher( float bias_correction1, float bias_correction2, uintptr_t stream +); +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream ); \ No newline at end of file diff --git a/setup.py b/setup.py index 2bbb55d8..1bac037e 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,7 @@ def build_extension(self, ext): if os.path.exists(build_temp): shutil.rmtree(build_temp) os.makedirs(build_temp) + cmake_args += ["-DPython_ROOT_DIR=" + os.path.dirname(os.path.dirname(sys.executable))] subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) diff --git a/tests/test_all.py b/tests/test_all.py index fc9ab3e9..07be4077 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,6 +16,8 @@ ("dropout", 1), ("loss_func", 1), + ("optim", 1), + ("multi_return", 2), ("middle_hidden", 4), ("other_hidden", 4), diff --git a/tests/test_has_inf_nan.py b/tests/test_has_inf_nan.py index fda85515..93ac8118 
100644 --- a/tests/test_has_inf_nan.py +++ b/tests/test_has_inf_nan.py @@ -1,5 +1,4 @@ from utils import * - import torch import bmtrain.loss._function as F import random @@ -9,9 +8,9 @@ def check(x, v): F.has_inf_nan(x, out) assert_eq(out.item(), v) -def test_main(): +def test_main(dtype): for i in list(range(1, 100)) + [1000]*10 + [10000]*10 + [100000]*10 + [1000000]*10: - x = torch.rand((i,)).half().cuda() + x = torch.rand((i,)).to(dtype).cuda() check(x, 0) p = random.randint(0, i-1) x[p] = x[p] / 0 @@ -27,6 +26,12 @@ def test_main(): p = random.randint(0, i-1) x[p] = x[p] / 0 check(x, 1) + print("That's right") if __name__ == "__main__": - test_main() + test_main(torch.float16) + print("==============================================================================") + try: + test_main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/tests/test_loss_func.py b/tests/test_loss_func.py index a76be5f4..a448b6d1 100644 --- a/tests/test_loss_func.py +++ b/tests/test_loss_func.py @@ -27,59 +27,53 @@ def check(x, tgt, loss_func1, loss_func2, bigmodel=None): loss_2, grad_2 = run(x, tgt, loss_func2, bigmodel=bigmodel, use_float=True) assert_eq(grad_1.isnan().sum(), 0) assert_eq(grad_2.isnan().sum(), 0) + print(f"{(loss_1 - loss_2).abs().item():.6f} {(grad_1 - grad_2).abs().max().item():.6f}") assert_lt((loss_1 - loss_2).abs().item(), 1e-5) - assert_lt((grad_1 - grad_2).abs().max().item(), 1e-2) + assert_lt((grad_1 - grad_2).abs().max().item(), 1e-1) -def test_simple(): +def test_simple(dtype): loss_func1 = bmt.loss.FusedCrossEntropy() loss_func2 = torch.nn.CrossEntropyLoss() N = 32 * 512 for i in range(1, 10): C = i * 10 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) for i in range(1, 10): C = i * 100 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) for i in range(1, 31): C = i * 1000 - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() check(x, tgt, loss_func1, loss_func2) -def test_other(): +def test_other(dtype): N = 32 * 512 for i in range(1, 11): C = i * 10 weight = [i+1 for i in range(C)] random.shuffle(weight) weight = torch.tensor(weight, device="cuda") - loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().half()) + loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().to(dtype)) loss_func2 = torch.nn.CrossEntropyLoss(weight=weight.clone().float()) - x = torch.randn(N, C).cuda().half() + x = torch.randn(N, C).cuda().to(dtype) tgt = torch.randint(0, C, (N,)).cuda().long() mask = torch.randint(0, 2, (N,)).cuda().bool() tgt[mask] = -100 check(x, tgt, loss_func1, loss_func2) -def test_inplace(): - loss_func1 = bmt.loss.FusedCrossEntropy(inplace=True) - loss_func2 = torch.nn.CrossEntropyLoss() - N = 32 * 512 - - for i in range(1, 11): - C = i * 10 - bigmodel = torch.nn.Linear(5, C).cuda().half() - x = torch.randn(N, 5).cuda().half() - tgt = torch.randint(0, C, (N,)).cuda().long() - check(x, tgt, loss_func1, loss_func2, bigmodel=bigmodel) - if __name__ == "__main__": - test_other() - test_inplace() - test_simple() \ No newline at end of file + test_other(torch.float16) + test_simple(torch.float16) + print("==============================================================================") + try: + test_other(torch.bfloat16) + test_simple(torch.bfloat16) + except 
NotImplementedError: + pass \ No newline at end of file diff --git a/tests/test_nccl_backward.py b/tests/test_nccl_backward.py index 3dcd0560..5e7b22d8 100644 --- a/tests/test_nccl_backward.py +++ b/tests/test_nccl_backward.py @@ -3,8 +3,8 @@ import bmtrain as bmt import torch -def test_main(): - x = torch.full((1,), bmt.rank() + 1, dtype=torch.half, device="cuda").requires_grad_(True) +def test_main(dtype): + x = torch.full((1,), bmt.rank() + 1, dtype=dtype, device="cuda").requires_grad_(True) y = bmt.distributed.all_reduce(x, "prod").view(-1) loss = (y * y).sum() / 2 loss.backward() @@ -17,4 +17,5 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed() - test_main() \ No newline at end of file + test_main(torch.half) + test_main(torch.bfloat16) \ No newline at end of file diff --git a/tests/test_optim.py b/tests/test_optim.py index fdb64521..0aca8c31 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,3 +1,4 @@ +from utils import * import torch import bmtrain as bmt from bmtrain import optim @@ -5,55 +6,89 @@ class TestModule(torch.nn.Module): def __init__(self): super(TestModule, self).__init__() - self.fc1 = torch.nn.Linear(128, 128) + self.fc1 = torch.nn.Linear(128, 128, bias=False) self.fc2 = torch.nn.Linear(128, 128) self.fc3 = torch.nn.Linear(128, 128) self.fc4 = torch.nn.Linear(128, 128) self.fc5 = torch.nn.Linear(128, 128) self.param = torch.nn.Parameter(torch.empty(1237)) -def main(): - # FIXME: this test script is not working + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc4(x) + x = self.fc5(x) + return x + +def main(dtype): model1 = TestModule() model2 = TestModule() model3 = TestModule() + model4 = TestModule() + model5 = TestModule() state_dict = model1.state_dict() for kw in state_dict.keys(): state_dict[kw] = torch.randn_like(state_dict[kw]) - + model1.load_state_dict(state_dict) model2.load_state_dict(state_dict) model3.load_state_dict(state_dict) + model4.load_state_dict(state_dict) + model5.load_state_dict(state_dict) - model1 = model1.cuda() - model2 = model2.cuda() + model1 = model1.cuda().to(dtype) + model2 = model2.cuda().to(dtype) model3 = model3.cuda() - - opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) - opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) - opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) + model4 = model4.cuda() + model5 = model5.cuda() + + opt1 = bmt.optim.AdamOptimizer(model1.parameters(), lr=1) + opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), lr=1) + opt3 = torch.optim.Adam(model3.parameters(), lr=1) + opt4 = bmt.optim.AdamOptimizer(model4.parameters(), lr=1) + opt5 = bmt.optim.AdamOffloadOptimizer(model5.parameters(), lr=1) + + optim_manager = bmt.optim.OptimManager(loss_scale=4) + optim_manager.add_optimizer(opt1) + optim_manager.add_optimizer(opt2) + optim_manager.add_optimizer(opt3) + optim_manager.add_optimizer(opt4) + optim_manager.add_optimizer(opt5) for _ in range(100): - opt1.zero_grad() - opt2.zero_grad() - opt3.zero_grad() + optim_manager.zero_grad() - for p1, p2, p3 in zip(model1.parameters(), model2.parameters(), model3.parameters()): + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): grad = torch.randn_like(p1) - p1.grad = grad - p2.grad = grad + p1.grad = grad.to(dtype) + p2.grad = grad.to(dtype) p3.grad = grad.float() - - opt1.step() - opt2.step() - opt3.step() + p4.grad = grad.float() + p5.grad = 
grad.float() + + optim_manager.step() + torch.cuda.synchronize() - for p1, p2, p3 in zip(model1.parameters(), model2.parameters(), model3.parameters()): - diff1 = torch.abs(p1 - p2).max() - diff2 = torch.abs(p1 - p3).max() - diff3 = torch.abs(p2 - p3).max() - print(diff1, diff2, diff3) + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): + diff1 = torch.abs(p1 - p2).max().item() + diff2 = torch.abs(p1 - p3).max().item() + diff3 = torch.abs(p2 - p3).max().item() + diff4 = torch.abs(p3 - p4).max().item() + diff5 = torch.abs(p3 - p5).max().item() + print(f"{diff1:.6f}, {diff2:.6f}, {diff3:.6f}, {diff4:.6f}, {diff5:.6f}") + assert_lt(diff1, 1) + assert_lt(diff2, 1) + assert_lt(diff3, 1) + assert_eq(diff4, 0) + assert_lt(diff5, 0.00001) if __name__ == "__main__": - main() + bmt.init_distributed() + main(torch.float16) + print("==============================================================================") + try: + main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index cef06734..16833b42 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -10,7 +10,7 @@ def __init__(self): self.fc1 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 3072)) self.fc2 = bmt.BMTrainModelWrapper(torch.nn.Linear(3072, 1024)) self.fc3 = bmt.BMTrainModelWrapper(torch.nn.Linear(1024, 768)) - self.param = bmt.DistributedParameter(torch.empty(1237)) + self.param = bmt.DistributedParameter(torch.zeros(1237)) self.fc4 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 300)) self.fc5 = bmt.BMTrainModelWrapper(torch.nn.Linear(300, 768)) self.dropout = torch.nn.Dropout(0.0) From a7fb0784ac9b0b6c2159e789b76f6ed55424a72e Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 30 Aug 2023 15:31:58 +0800 Subject: [PATCH 23/52] Rename class and files (#157) --- README.md | 20 +-- bmtrain/__init__.py | 4 +- bmtrain/block_layer.py | 47 ++++--- bmtrain/checkpointing.py | 179 -------------------------- bmtrain/hook_func.py | 8 +- bmtrain/inspect/model.py | 8 +- bmtrain/layer.py | 8 +- bmtrain/param_init.py | 8 +- bmtrain/parameter.py | 4 +- bmtrain/pipe_layer.py | 22 ++-- bmtrain/store.py | 4 +- bmtrain/wrapper.py | 12 +- docs/source/notes/quickstart.md | 18 +-- example/models/gpt.py | 2 +- tests/test_column_parallel_linear.py | 2 +- tests/test_different_output_shape.py | 10 +- tests/test_dropout.py | 4 +- tests/test_inspector_hidden.py | 8 +- tests/test_load_ckpt.py | 10 +- tests/test_middle_hidden.py | 8 +- tests/test_multi_return.py | 8 +- tests/test_no_grad.py | 8 +- tests/test_other_hidden.py | 8 +- tests/test_requires_grad.py | 6 +- tests/test_requires_grad_multi_gpu.py | 6 +- tests/test_row_parallel_linear.py | 2 +- tests/test_training.py | 2 +- 27 files changed, 123 insertions(+), 303 deletions(-) delete mode 100644 bmtrain/checkpointing.py diff --git a/README.md b/README.md index d3e5d8cf..f95328f0 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ To enable ZeRO optimization, you need to make some simple replacements to the or * `torch.nn.Module` -> `bmtrain.DistributedModule` * `torch.nn.Parameter` -> `bmtrain.DistributedParameter` -And wrap the transformer blocks with `bmtrain.CheckpointBlock`. +And wrap the transformer blocks with `bmtrain.Block`. Here is an example. 
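The rename is backward compatible: the `bmtrain/__init__.py` hunk later in this patch keeps `CheckpointBlock = Block` as an alias, so code that still imports the old name keeps working; a one-line check of that alias (illustrative, not part of the patch itself):

```python
import bmtrain as bmt
assert bmt.CheckpointBlock is bmt.Block  # alias defined at the end of bmtrain/__init__.py
```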
@@ -118,9 +118,9 @@ class MyModule(bmt.DistributedModule): # changed here super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here self.module_list = torch.nn.ModuleList([ - bmt.CheckpointBlock(SomeTransformerBlock()), # changed here - bmt.CheckpointBlock(SomeTransformerBlock()), # changed here - bmt.CheckpointBlock(SomeTransformerBlock()) # changed here + bmt.Block(SomeTransformerBlock()), # changed here + bmt.Block(SomeTransformerBlock()), # changed here + bmt.Block(SomeTransformerBlock()) # changed here ]) def forward(self): @@ -151,9 +151,9 @@ class MyModule(bmt.DistributedModule): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = torch.nn.ModuleList([ - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()) + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) ]) def forward(self): @@ -174,9 +174,9 @@ class MyModule(bmt.DistributedModule): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = bmt.TransformerBlockList([ # changed here - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()) + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) ]) def forward(self): diff --git a/bmtrain/__init__.py b/bmtrain/__init__.py index ae243e65..f4ac3642 100644 --- a/bmtrain/__init__.py +++ b/bmtrain/__init__.py @@ -10,7 +10,7 @@ from .layer import DistributedModule from .param_init import init_parameters, grouped_parameters from .synchronize import synchronize, sum_loss, wait_loader, gather_result -from .block_layer import CheckpointBlock, TransformerBlockList +from .block_layer import Block, TransformerBlockList from .wrapper import BMTrainModelWrapper from .pipe_layer import PipelineTransformerBlockList from . import debug @@ -22,3 +22,5 @@ from . import optim from . import inspect from . import lr_scheduler + +CheckpointBlock = Block diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index e16ef6a8..61b335cf 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -5,15 +5,12 @@ import torch from . import nccl from .parameter import DistributedParameter, OpAllGather -from .checkpointing import ( - CheckpointBlockContext +from .zero_context import ( + ZeroContext ) -from . import debug - from . import hook_func -import copy import inspect from torch.utils.checkpoint import checkpoint @@ -48,21 +45,21 @@ def _get_param_kw(param : DistributedParameter): group_name = "_g_" + param.group return type_name + grad_name + group_name -class CheckpointBlock(torch.nn.Module): - """ A bmtrain block containing two memory-saving methods of ZeRO-2/3 and checkpoint. - - Checkpoint block is used to save the occupation of GPU memory in training. - - For details, please refer to `Checkpointing `_ . +class Block(torch.nn.Module): + """ A block containing two memory-saving methods of ZeRO and checkpoint. + For details please refer to `ZeRO `_ and + `Checkpointing `_ . Args: - model (torch.nn.Module): The model to be checkpointed. All kinds of modules are supported. + inner_module (torch.nn.Module): The module to reduce memory usage. All kinds of modules are supported. use_checkpoint (boolean): use checkpoint or not. Default True. 
+        zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across processes,
+            3 (ZeRO-3) means that the parameters are partitioned on the basis of ZeRO-2. Default 3.

    Examples:
        >>> transformer_block = TransformerBlock(...)
-        >>> checkpoint_block = CheckpointBlock(transformer_block)
-        >>> y1, ... = checkpoint_block(x)
+        >>> bmt_block = Block(transformer_block)
+        >>> y1, ... = bmt_block(x)
        >>> y2, ... = transformer_block(x)
        >>> assert torch.allclose(y1, y2)
    """
@@ -189,7 +186,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
            else:
                param.data = torch.tensor([], dtype=param.dtype, device=param.device) # clear parameter data, but keep the dtype and device
-            setattr(param, "_in_checkpoint_block", True)
+            setattr(param, "_in_block", True)
        for kw in offsets.keys():
            assert offsets[kw] == self._storage_info[kw]["total"]
@@ -289,12 +286,12 @@ def __delattr__(self, name):
        object.__delattr__(self, name)
    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        raise RuntimeError("._save_to_state_dict() of CheckpointBlock should not be called")
+        raise RuntimeError("._save_to_state_dict() of Block should not be called")
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # gather here
        with torch.no_grad():
-            with CheckpointBlockContext(self):
+            with ZeroContext(self):
                return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
@@ -358,7 +355,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                    missing_keys.append(key)
        for name, param in self.named_parameters():
-            if isinstance(param, DistributedParameter) and not param._in_checkpoint_block:
+            if isinstance(param, DistributedParameter) and not param._in_block:
                key = prefix + name
                all_keys.append(key)
                if key in state_dict:
@@ -512,7 +509,7 @@ def __repr__(self):
class TransformerBlockList(torch.nn.Module):
    r"""
-    TransformerBlockList is a list of CheckpointBlocks.
+    TransformerBlockList is a list of bmt.Block.
    This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass.
@@ -529,16 +526,16 @@ class TransformerBlockList(torch.nn.Module):
       >>> hidden_state = transformer_module_list(hidden_state, ...)
""" - _modules: Dict[str, CheckpointBlock] + _modules: Dict[str, Block] - def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: super().__init__() self._modules = {} pre_module = None for i, module in enumerate(modules): - if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module) + if not isinstance(module, Block): + module = Block(module) module._mode = "ZERO" module.set_pre_module(pre_module) @@ -579,9 +576,9 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1, sqrt=False) def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[CheckpointBlock]: + def __iter__(self) -> Iterator[Block]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: + def __getitem__(self, index: Union[int, str]) -> Block: return self._modules[str(index)] def forward(self, *args, return_hidden_states = False): diff --git a/bmtrain/checkpointing.py b/bmtrain/checkpointing.py deleted file mode 100644 index 550225be..00000000 --- a/bmtrain/checkpointing.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -from . import debug -from . import nccl -from .global_var import config -from .synchronize import wait_loader - -class ScopedDebugTensorList: - def __init__(self) -> None: - self._hidden_states = [] - - @property - def hidden_states(self): - return self._hidden_states - - def _set_hidden_states(self, hidden_states): - self._hidden_states = hidden_states - -class ScopedTensorInspectorContext: - def __init__(self): - pass - - def __enter__(self): - self.prev_hidden = debug.get("_inspect_hidden_states", []) - debug.set("_inspect_hidden_states", []) - self._local_list = ScopedDebugTensorList() - return self._local_list - - def __exit__(self, *args): - self._local_list._set_hidden_states(debug.get("_inspect_hidden_states", [])) - debug.set("_inspect_hidden_states", self.prev_hidden) - self.prev_hidden = None - -class CheckpointBlockContext: - def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = False) -> None: - self.block = block - self.ctx_dict = ctx_dict - self._param_buffer = {} - self._grad_buffer = {} - self._param_tensor = {} - self._grad_tensor = {} - self._need_release = False - - def enter(self, flag=0, requires_grad=False): - """ - gather parameters - """ - if self.block._ready: - return - self.block._ready = True - self._need_release = True - - wait_loader() - with torch.cuda.stream(config["load_stream"]): - for kw, val in self.block._storage_info.items(): - assert self.block._storage_params[kw].is_cuda - assert kw not in self._grad_buffer - assert kw not in self._param_buffer - local_param = self.block._storage_params[kw] - - storage_type = local_param.storage_type() - if flag != 2: - self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) - - if requires_grad and local_param.requires_grad: - self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) - self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() - if flag != 2: - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - nccl.allGather( - self.block._storage_params[kw].storage(), - 
self._param_buffer[kw], - val['zero_comm'] - ) - nccl.groupEnd() - - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(config["load_stream"]) - - # set wait stream for each storage - for kw in self.block._storage_info.keys(): - if flag != 2: - self._param_tensor[kw].record_stream(current_stream) - if requires_grad and kw in self._grad_tensor: - self._grad_tensor[kw].record_stream(current_stream) - - # update parameters in block - for param in self.block._param_info: - kw_name = param["kw_name"] - offset = param["offset"] - shape = param["shape"] - - if flag != 2: - dtype = self._param_buffer[kw_name].dtype - device = self._param_buffer[kw_name].device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) - else: - dtype = param["parameter"].data.dtype - device = param["parameter"].data.device - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) - - if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) - - def __enter__(self): - self.enter() - - def exit(self, flag=0, backward=False): - """ - Reduce scatter gradients - """ - - if not self._need_release: - return - self._need_release = False - self.block._ready = False - if backward: - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # accumulate previous gradient - if local_param.requires_grad: - if local_param.grad is None: - grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist - local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() - else: - self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad - - current_stream = torch.cuda.current_stream() - config["load_stream"].wait_stream(current_stream) # wait for backward - - with torch.cuda.stream(config["load_stream"]): - nccl.groupStart() - for kw, val in self.block._storage_info.items(): - local_param = self.block._storage_params[kw] - - # scatter gradient - if local_param.requires_grad: - nccl.reduceScatter( - self._grad_buffer[kw], - local_param.grad.storage(), - "sum", - val['zero_comm'] - ) - nccl.groupEnd() - - # set wait stream for each storage - for kw in self._grad_tensor.keys(): - # grads can not be freed until reduce ops finish - self._grad_tensor[kw].record_stream(config["load_stream"]) - - - # Release all parameters from buffer to block_storge - for param in self.block._param_info: - kw_name = param["kw_name"] - dtype = self.block._storage_params[kw_name].dtype - device = self.block._storage_params[kw_name].device - if "begin" not in param: - param["parameter"].data = torch.tensor([], dtype=dtype, device=device) - param["parameter"].grad = None - continue - begin = param["begin"] - end = param["end"] - param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) - if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: - param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) - if flag == 1: - for i in self._param_buffer: - self.ctx_dict[i] = self._param_buffer[i] - self._grad_tensor = {} - self._param_tensor = {} - 
self._grad_buffer = {} - self._param_buffer = {} - - - def __exit__(self, exc_type, exc_val, exc_tb): - # reduce scatter gradients - self.exit() diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 4d91d1d0..2c9a68a4 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -1,6 +1,6 @@ import torch from .global_var import config -from .checkpointing import CheckpointBlockContext +from .zero_context import ZeroContext def zero_pre_forward(module, inputs): enter = True @@ -15,7 +15,7 @@ def zero_pre_forward(module, inputs): forward_flag = 2 # repeating forward in same layer if module.all_param_no_grad: #only forward forward_flag = 0 - module._forward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=pipe) + module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe) module._forward_block_ctx.enter(forward_flag) def zero_post_forward(module, inputs, outputs): @@ -33,13 +33,13 @@ def zero_post_forward(module, inputs, outputs): def zero_pre_backward(module, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": - module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict) + module._backward_block_ctx = ZeroContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) if not module._is_last_layer: module.next_module().backward_release(backward_flag) else: if module._micro_idx == config['micros'] - 1: - module._backward_block_ctx = CheckpointBlockContext(module, module._layer_dict, pipe=True) + module._backward_block_ctx = ZeroContext(module, module._layer_dict, pipe=True) module._backward_block_ctx.enter(backward_flag, True) def zero_post_backward(module, grad_inputs, grad_outputs): diff --git a/bmtrain/inspect/model.py b/bmtrain/inspect/model.py index 79975cc5..a296d832 100644 --- a/bmtrain/inspect/model.py +++ b/bmtrain/inspect/model.py @@ -1,7 +1,7 @@ import torch from ..store import broadcast_object from ..pipe_layer import PipelineTransformerBlockList -from ..block_layer import CheckpointBlock +from ..block_layer import Block from ..parameter import DistributedParameter from .. 
import nccl from ..global_var import config @@ -117,7 +117,7 @@ def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlock return ret -def inspect_checkpoint_block(model : CheckpointBlock, param_name : str, prefix : str = ''): +def inspect_block(model : Block, param_name : str, prefix : str = ''): # fast check pass_fast_check = False for param in model._param_info: @@ -210,8 +210,8 @@ def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''): """ if isinstance(model, PipelineTransformerBlockList): return inspect_pipeline_transformer_block_list(model, param_name, prefix) - elif isinstance(model, CheckpointBlock): - return inspect_checkpoint_block(model, param_name, prefix) + elif isinstance(model, Block): + return inspect_block(model, param_name, prefix) else: ret = [] for name, param in model._parameters.items(): diff --git a/bmtrain/layer.py b/bmtrain/layer.py index 7de83e5e..cf46814b 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -12,8 +12,8 @@ class DistributedModule(torch.nn.Module): def __getattr__(self, name: str): ret = super().__getattr__(name) - # gather distributed parameters if not in CheckpointBlock - if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block: + # gather distributed parameters if not in bmt.Block + if isinstance(ret, DistributedParameter) and not ret._in_block: return ret.gather() return ret @@ -32,8 +32,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block: - if param._in_checkpoint_block: + if isinstance(param, DistributedParameter):#and not param._in_block: + if param._in_block: destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation else: destination[prefix + name] = param.gather_all().detach().cpu() # sync operation diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py index d5c86225..a46c7845 100644 --- a/bmtrain/param_init.py +++ b/bmtrain/param_init.py @@ -1,6 +1,6 @@ from typing import Generator, Iterable, List, Tuple import torch -from .block_layer import CheckpointBlock +from .block_layer import Block from .parameter import DistributedParameter from .global_var import config @@ -38,7 +38,7 @@ def init_distributed_parameter(params : Iterable[torch.nn.Parameter]): def iterate_parameters(model : torch.nn.Module): for kw, val in model._parameters.items(): - if hasattr(val,"_in_checkpoint_block") and val._in_checkpoint_block: + if hasattr(val,"_in_block") and val._in_block: return [] yield val @@ -49,7 +49,7 @@ def init_parameters(model : torch.nn.Module): modules = model.named_modules() for module_prefix, module in modules: - if isinstance(module, CheckpointBlock): + if isinstance(module, Block): module.init_parameters() else: init_distributed_parameter( iterate_parameters(module) ) @@ -65,7 +65,7 @@ def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[tor ret : List[torch.nn.Parameter] = {} for module in model.modules(): - if isinstance(module, CheckpointBlock): + if isinstance(module, Block): for kw, params in module.grouped_parameters(): if kw not in ret: ret[kw] = [] diff --git a/bmtrain/parameter.py b/bmtrain/parameter.py index b5e5e9ae..ffc27de2 100644 --- a/bmtrain/parameter.py +++ b/bmtrain/parameter.py @@ -25,7 +25,7 @@ class DistributedParameter(torch.nn.Parameter): _start_partition : int _end_partition : int _init_method : 
Optional[Callable[['DistributedParameter'], None]] - _in_checkpoint_block : bool + _in_block: bool _group : Optional[str] def __new__(cls, @@ -72,7 +72,7 @@ def __new__(cls, setattr(ret, "_start_partition", start_of_partition) setattr(ret, "_end_partition", end_of_partition) setattr(ret, "_init_method", init_method) - setattr(ret, "_in_checkpoint_block", False) + setattr(ret, "_in_block", False) setattr(ret, "_group", group if not tp_mode else "tp") setattr(ret, "_tp_mode", tp_mode) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index c28ab87d..e3913b6c 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -8,11 +8,11 @@ from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations from .global_var import config from . import nccl -from .checkpointing import ( - CheckpointBlockContext +from .zero_context import ( + ZeroContext ) from . import debug -from .block_layer import CheckpointBlock, round_up, _get_param_kw +from .block_layer import Block, round_up, _get_param_kw class PipePreFunction(torch.autograd.Function): @staticmethod @@ -171,7 +171,7 @@ def backward(ctx, grad_outputs): class PipelineTransformerBlockList(torch.nn.Module): r""" - TransformerBlockList is a list of CheckpointBlocks. + TransformerBlockList is a list of Blocks. This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass. @@ -188,9 +188,9 @@ class PipelineTransformerBlockList(torch.nn.Module): >>> hidden_state = transformer_module_list(hidden_state, ...) """ - _modules: Dict[str, CheckpointBlock] + _modules: Dict[str, Block] - def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: super().__init__() self.num_hidden = num_hidden self._modules = {} @@ -201,8 +201,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx for idx, module in enumerate(modules): - if not isinstance(module, CheckpointBlock): - module = CheckpointBlock(module) + if not isinstance(module, Block): + module = Block(module) module._mode = "PIPE" module.stage_id = self.stage_id @@ -230,10 +230,10 @@ def __init__(self, modules: Iterable[CheckpointBlock], num_hidden=1) -> None: def __len__(self) -> int: return len(self._modules) - def __iter__(self) -> Iterator[CheckpointBlock]: + def __iter__(self) -> Iterator[Block]: return iter(self._modules.values()) - def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: + def __getitem__(self, index: Union[int, str]) -> Block: return self._modules[str(index)] def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): @@ -307,7 +307,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if idx in self.layer_ids: with torch.no_grad(): - with CheckpointBlockContext(module, pipe=True): + with ZeroContext(module, pipe=True): module._module.state_dict(destination=dst, prefix=name, keep_vars=False) if config["zero_rank"] == 0: if config["rank"] == 0: diff --git a/bmtrain/store.py b/bmtrain/store.py index 0e0fd7ca..88ed7305 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -4,13 +4,13 @@ from .pipe_layer import PipelineTransformerBlockList from .global_var import config -from .block_layer import CheckpointBlock +from .block_layer import Block from . 
import nccl import io, pickle from typing import Mapping def _save_to_state_dict(model : torch.nn.Module, destination, prefix): - if isinstance(model, CheckpointBlock): + if isinstance(model, Block): if config['rank'] != 0: destination = OrderedDict() # creates an temporary ordered dict destination._metadata = OrderedDict() diff --git a/bmtrain/wrapper.py b/bmtrain/wrapper.py index 9a6f5f3a..722a8037 100644 --- a/bmtrain/wrapper.py +++ b/bmtrain/wrapper.py @@ -1,5 +1,5 @@ import torch -from .block_layer import CheckpointBlock, TransformerBlockList +from .block_layer import Block, TransformerBlockList from .layer import DistributedModule, DistributedParameter def make_distributed(model : torch.nn.Module): @@ -14,8 +14,8 @@ def make_distributed(model : torch.nn.Module): for kw in list(model._modules.keys()): if isinstance(model, torch.nn.ModuleList): - if not isinstance(model._modules[kw], CheckpointBlock): - model._modules[kw] = CheckpointBlock(model_wrapper_dispatch(model._modules[kw])) + if not isinstance(model._modules[kw], Block): + model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw])) else: model._modules[kw] = model_wrapper_dispatch(model._modules[kw]) @@ -27,7 +27,7 @@ def model_wrapper_dispatch(model : torch.nn.Module): return model elif isinstance(model, DistributedModule): return model - elif isinstance(model, CheckpointBlock): + elif isinstance(model, Block): return model else: return make_distributed(model) @@ -35,6 +35,6 @@ def model_wrapper_dispatch(model : torch.nn.Module): def BMTrainModelWrapper(model : torch.nn.Module) -> torch.nn.Module: """ Automatically wrap a model in a BMTrain model. - Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with CheckpointBlock. + Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with Block. 
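In practice the wrapper is applied to an existing torch-native model tree; a minimal sketch of that use, assuming bmt.init_distributed() has already been called. PlainNet and its sizes are invented for illustration and are not part of BMTrain.

    import torch
    import bmtrain as bmt

    class PlainNet(torch.nn.Module):                   # an ordinary PyTorch model
        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                [torch.nn.Linear(256, 256) for _ in range(4)]
            )
        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    model = bmt.BMTrainModelWrapper(PlainNet())
    for layer in model.layers:
        # each ModuleList member was wrapped in bmt.Block by the dispatcher above
        assert isinstance(layer, bmt.Block)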
""" - return model_wrapper_dispatch(model) \ No newline at end of file + return model_wrapper_dispatch(model) diff --git a/docs/source/notes/quickstart.md b/docs/source/notes/quickstart.md index 8cffb5ec..f139fd33 100644 --- a/docs/source/notes/quickstart.md +++ b/docs/source/notes/quickstart.md @@ -55,9 +55,9 @@ class MyModule(bmt.DistributedModule): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = torch.nn.ModuleList([ - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()) + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) ]) def forward(self): @@ -86,9 +86,9 @@ class MyModule(bmt.DistributedModule): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = torch.nn.ModuleList([ - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()) + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) ]) def forward(self): @@ -109,9 +109,9 @@ class MyModule(bmt.DistributedModule): super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) self.module_list = bmt.TransformerBlockList([ - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()), - bmt.CheckpointBlock(SomeTransformerBlock()) + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) ]) def forward(self): diff --git a/example/models/gpt.py b/example/models/gpt.py index 64474ba8..456dbcc9 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -21,7 +21,7 @@ def __init__(self, self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) self.transformers = bmt.TransformerBlockList([ - bmt.CheckpointBlock( + bmt.Block( TransformerEncoder( dim_model, dim_head, num_heads, dim_ff, bias, dtype ) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py index 1c49570a..c20707c2 100644 --- a/tests/test_column_parallel_linear.py +++ b/tests/test_column_parallel_linear.py @@ -5,7 +5,7 @@ def run_bmt(x, gather_output, ckp_path, tp_size=2): linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) - linear = bmt.CheckpointBlock(linear) + linear = bmt.Block(linear) bmt.init_parameters(linear) y = linear(x[config['topology'].tp_id]) y.sum().backward() diff --git a/tests/test_different_output_shape.py b/tests/test_different_output_shape.py index 437ba22c..bb8ab7fa 100644 --- a/tests/test_different_output_shape.py +++ b/tests/test_different_output_shape.py @@ -33,18 +33,18 @@ def forward(self, x): bmt.init_distributed() x = torch.tensor([1,2,3.]) - b = bmt.CheckpointBlock(Block0()) + b = bmt.Block(Block0()) y = b(x) assert isinstance(y, torch.Tensor) - b = bmt.CheckpointBlock(Block1()) + b = bmt.Block(Block1()) y = b(x) assert isinstance(y, tuple) and len(y)==1 - b = bmt.CheckpointBlock(Block2()) + b = bmt.Block(Block2()) y = b(x) assert isinstance(y, tuple) and len(y)==2 - b = bmt.CheckpointBlock(Block10()) + b = bmt.Block(Block10()) y = b(x) - assert isinstance(y, tuple) and len(y)==10 \ No newline at end of file + assert isinstance(y, tuple) and len(y)==10 diff --git a/tests/test_dropout.py b/tests/test_dropout.py index f02da4e3..d29240a1 100644 --- a/tests/test_dropout.py +++ b/tests/test_dropout.py @@ -17,7 +17,7 @@ def __init__(self) 
-> None: super().__init__() self.blk = bmt.TransformerBlockList([ - bmt.CheckpointBlock(InnerModule()) + bmt.Block(InnerModule()) for _ in range(5) ]) @@ -42,4 +42,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed() - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_inspector_hidden.py b/tests/test_inspector_hidden.py index c39de5fb..62e03656 100644 --- a/tests/test_inspector_hidden.py +++ b/tests/test_inspector_hidden.py @@ -4,7 +4,7 @@ import random import torch from bmtrain import config -from bmtrain.block_layer import CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F from bmtrain import inspect @@ -58,7 +58,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = TransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post @@ -77,7 +77,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = PipelineTransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post @@ -96,7 +96,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = torch.nn.ModuleList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post diff --git a/tests/test_load_ckpt.py b/tests/test_load_ckpt.py index 632f0e8e..0eb4f95f 100644 --- a/tests/test_load_ckpt.py +++ b/tests/test_load_ckpt.py @@ -41,7 +41,7 @@ def test_main(): ckpt_path = "test_ckpt.pt" # Transformer BlockList m = Linear_Normal(256, 256).cuda() - m2 = bmt.TransformerBlockList([bmt.CheckpointBlock(Linear_BMT(256, 256))]) + m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))]) if bmt.rank() == 0: torch.save(m.state_dict(), ckpt_path) dic2 = m.state_dict() @@ -56,13 +56,13 @@ def test_main(): os.remove(ckpt_path) print("Transformer Blocklist load_state_dict and state_dict test passed") - # CheckpointBlock - m3 = bmt.CheckpointBlock(Linear_BMT(256, 256)) + # Block + m3 = bmt.Block(Linear_BMT(256, 256)) m3.load_state_dict(m.state_dict()) for key in m.state_dict(): assert key in m3.state_dict(), "wrong key in bmtrain model" assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" - print("CheckpointBlock load_state_dict and state_dict test passed") + print("Block load_state_dict and state_dict test passed") # normal Distributed module m4 = Linear_BMT(256, 256) @@ -75,4 +75,4 @@ def test_main(): if __name__ == "__main__": bmt.init_distributed() - test_main() \ No newline at end of file + test_main() diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index 688cdfe5..667d23b2 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -3,7 +3,7 @@ import bmtrain as bmt import random import torch -from bmtrain.block_layer import CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F from bmtrain import inspect @@ -34,7 +34,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = TransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post @@ -53,7 +53,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = PipelineTransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post 
@@ -72,7 +72,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = torch.nn.ModuleList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post diff --git a/tests/test_multi_return.py b/tests/test_multi_return.py index 8c408a1f..cc891fdf 100644 --- a/tests/test_multi_return.py +++ b/tests/test_multi_return.py @@ -4,7 +4,7 @@ import torch import random from bmtrain import config -from bmtrain.block_layer import CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F @@ -19,7 +19,7 @@ class Model_ZERO(torch.nn.Module): def __init__(self, ms) -> None: super().__init__() self.ms = TransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ], num_hidden=3) @@ -31,7 +31,7 @@ class Model_PIPE(torch.nn.Module): def __init__(self, ms) -> None: super().__init__() self.ms = PipelineTransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ], num_hidden=3) @@ -43,7 +43,7 @@ class Model_BLOCK(torch.nn.Module): def __init__(self, ms) -> None: super().__init__() self.ms = torch.nn.ModuleList([ - CheckpointBlock(m) + Block(m) for m in ms ]) diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py index 3629921b..73231acc 100644 --- a/tests/test_no_grad.py +++ b/tests/test_no_grad.py @@ -13,8 +13,8 @@ def forward(self, x): def test_no_grad(): x = torch.randn(32, 32, device='cuda') - layer1 = bmt.CheckpointBlock(Layer()) - layer2 = bmt.CheckpointBlock(Layer()) + layer1 = bmt.Block(Layer()) + layer2 = bmt.Block(Layer()) layer1.linear.weight.requires_grad_(False) layer1.linear.bias.requires_grad_(False) y = layer1(x) @@ -30,8 +30,8 @@ def test_all_input_no_grad(): x = torch.randn(32,32, device='cuda') - linear1 = bmt.CheckpointBlock(linear1) - linear2 = bmt.CheckpointBlock(linear2) + linear1 = bmt.Block(linear1) + linear2 = bmt.Block(linear2) y = linear1(x) y = linear2(y) y.sum().backward() diff --git a/tests/test_other_hidden.py b/tests/test_other_hidden.py index 1f6c8c65..aab9c678 100644 --- a/tests/test_other_hidden.py +++ b/tests/test_other_hidden.py @@ -4,7 +4,7 @@ import random import torch from bmtrain import config -from bmtrain.block_layer import CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList import torch.nn.functional as F from bmtrain import inspect @@ -35,7 +35,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = TransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post @@ -54,7 +54,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = PipelineTransformerBlockList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post @@ -73,7 +73,7 @@ def __init__(self, pre, ms, post) -> None: super().__init__() self.pre = pre self.ms = torch.nn.ModuleList([ - CheckpointBlock(m) + Block(m) for m in ms ]) self.post = post diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index 943275c3..f8a11ec9 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -3,7 +3,7 @@ import bmtrain as bmt import torch from bmtrain import config -from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList from typing 
import List import torch.nn.functional as F @@ -47,7 +47,7 @@ def run(m, a, b): def test_main(): a = Linear(256, 256) b = Linear(256, 256) - m = TransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)]) + m = TransformerBlockList([Block(a), Block(b)]) bmt.init_parameters(m) a.bias.requires_grad_(False) @@ -75,7 +75,7 @@ def test_main(): def test_main_pipe(): a = Linear(256, 256) b = Linear(256, 256) - m = PipelineTransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)]) + m = PipelineTransformerBlockList([Block(a), Block(b)]) bmt.init_parameters(m) a.bias.requires_grad_(False) diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index 4a2670ae..66b4880c 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -2,7 +2,7 @@ import bmtrain as bmt import torch -from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList +from bmtrain.block_layer import Block, TransformerBlockList from bmtrain.pipe_layer import PipelineTransformerBlockList from typing import List import torch.nn.functional as F @@ -43,7 +43,7 @@ def run(m, a, b): def test_main(): a = Linear(256, 256) b = Linear(256, 256) - m = TransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)]) + m = TransformerBlockList([Block(a), Block(b)]) bmt.init_parameters(m) a.bias.requires_grad_(False) @@ -67,7 +67,7 @@ def test_main(): def test_main_pipe(): a = Linear(256, 256) b = Linear(256, 256) - m = PipelineTransformerBlockList([CheckpointBlock(a), CheckpointBlock(b)]) + m = PipelineTransformerBlockList([Block(a), Block(b)]) bmt.init_parameters(m) a.bias.requires_grad_(False) diff --git a/tests/test_row_parallel_linear.py b/tests/test_row_parallel_linear.py index f89b6dc5..23dce8b2 100644 --- a/tests/test_row_parallel_linear.py +++ b/tests/test_row_parallel_linear.py @@ -6,7 +6,7 @@ def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) if use_checkpoint_block: - linear = bmt.CheckpointBlock(linear) + linear = bmt.Block(linear) bmt.init_parameters(linear) y = linear(x) y.sum().backward() diff --git a/tests/test_training.py b/tests/test_training.py index 1d6481c9..223c27df 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -372,7 +372,7 @@ def pipe_model(): model = GPT(**kwargs) pipe_model = bmt.BMTrainModelWrapper(model) for m in pipe_model.transformers: - assert isinstance(m, bmt.CheckpointBlock) + assert isinstance(m, bmt.Block) pipe_model.transformers = bmt.PipelineTransformerBlockList([m for m in pipe_model.transformers]) bmt.load(pipe_model, ckpt_path) return model From 90492fd2cb045f14f4c270c0af925a4fb5ae97c1 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 30 Aug 2023 20:17:56 +0800 Subject: [PATCH 24/52] Add zero_context.py (#160) --- bmtrain/zero_context.py | 151 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 bmtrain/zero_context.py diff --git a/bmtrain/zero_context.py b/bmtrain/zero_context.py new file mode 100644 index 00000000..653f40fa --- /dev/null +++ b/bmtrain/zero_context.py @@ -0,0 +1,151 @@ +import torch +from . 
import nccl +from .global_var import config +from .synchronize import wait_loader + +class ZeroContext: + def __init__(self, block : 'Block', ctx_dict : dict = None, pipe = False) -> None: + self.block = block + self.ctx_dict = ctx_dict + self._param_buffer = {} + self._grad_buffer = {} + self._param_tensor = {} + self._grad_tensor = {} + self._need_release = False + + def enter(self, flag=0, requires_grad=False): + """ + gather parameters + """ + if self.block._ready: + return + self.block._ready = True + self._need_release = True + + wait_loader() + with torch.cuda.stream(config["load_stream"]): + for kw, val in self.block._storage_info.items(): + assert self.block._storage_params[kw].is_cuda + assert kw not in self._grad_buffer + assert kw not in self._param_buffer + local_param = self.block._storage_params[kw] + + storage_type = local_param.storage_type() + if flag != 2: + self._param_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._param_tensor[kw] = torch.tensor([], dtype=self._param_buffer[kw].dtype, device=self._param_buffer[kw].device).set_(self._param_buffer[kw]) + + if requires_grad and local_param.requires_grad: + self._grad_buffer[kw] = storage_type(val["partition_size"] * val["world_size"]) + self._grad_tensor[kw] = torch.tensor([], dtype=self._grad_buffer[kw].dtype, device=self._grad_buffer[kw].device).set_(self._grad_buffer[kw]).zero_() + if flag != 2: + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + val['zero_comm'] + ) + nccl.groupEnd() + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["load_stream"]) + + # set wait stream for each storage + for kw in self.block._storage_info.keys(): + if flag != 2: + self._param_tensor[kw].record_stream(current_stream) + if requires_grad and kw in self._grad_tensor: + self._grad_tensor[kw].record_stream(current_stream) + + # update parameters in block + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + + if flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.ctx_dict[kw_name], offset, shape) + + if requires_grad and kw_name in self._grad_buffer and param["parameter"].requires_grad: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self._grad_buffer[kw_name], offset, shape) + + def __enter__(self): + self.enter() + + def exit(self, flag=0, backward=False): + """ + Reduce scatter gradients + """ + + if not self._need_release: + return + self._need_release = False + self.block._ready = False + if backward: + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # accumulate previous gradient + if local_param.requires_grad: + if local_param.grad is None: + grad_storage = val["storage_type"](val["partition_size"]) # initialize gradient if not exist + local_param.grad = torch.tensor([], dtype=grad_storage.dtype, device=grad_storage.device).set_(grad_storage).zero_() + else: + self._grad_tensor[kw][val["begin"]:val["end"]] += local_param.grad + + current_stream = torch.cuda.current_stream() + 
config["load_stream"].wait_stream(current_stream) # wait for backward + + with torch.cuda.stream(config["load_stream"]): + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # scatter gradient + if local_param.requires_grad: + nccl.reduceScatter( + self._grad_buffer[kw], + local_param.grad.storage(), + "sum", + val['zero_comm'] + ) + nccl.groupEnd() + + # set wait stream for each storage + for kw in self._grad_tensor.keys(): + # grads can not be freed until reduce ops finish + self._grad_tensor[kw].record_stream(config["load_stream"]) + + + # Release all parameters from buffer to block_storge + for param in self.block._param_info: + kw_name = param["kw_name"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None + continue + begin = param["begin"] + end = param["end"] + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].storage(), begin, end) + if param["parameter"].requires_grad and self.block._storage_params[kw_name].grad is not None: + param["parameter"].grad = torch.tensor([], dtype=dtype, device=device).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if flag == 1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] + self._grad_tensor = {} + self._param_tensor = {} + self._grad_buffer = {} + self._param_buffer = {} + + def __exit__(self, exc_type, exc_val, exc_tb): + # reduce scatter gradients + self.exit() From 9bbd27999adc40fbeb94feb2be791135b923491a Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Thu, 31 Aug 2023 01:15:35 -0700 Subject: [PATCH 25/52] FIX error of multi-gpus of test_synchronize.py --- tests/test_synchronize.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py index ee43d39a..c5ceed3d 100644 --- a/tests/test_synchronize.py +++ b/tests/test_synchronize.py @@ -7,17 +7,18 @@ def test_main(): - ref_result = torch.rand(5 * bmt.world_size(), 5) + ref_result = torch.ones(5 * bmt.world_size(), 5) tensor = ref_result.chunk(bmt.world_size(), dim=0)[bmt.rank()] real_result = bmt.gather_result(tensor) assert torch.allclose(ref_result, real_result, atol=1e-6), "Assertion failed for real gather result error" for i in range(4): - size = i + 1 + size = i + 1 tensor_slice = tensor[:size, :size] result_slice = bmt.gather_result(tensor_slice) - assert torch.allclose(real_result[:size, :size], result_slice, atol=1e-6), f"Assertion failed for tensor_slice_{i}" - + test_slice = torch.chunk(result_slice, bmt.world_size(), dim=0)[0] + assert torch.allclose(tensor_slice, test_slice), f"Assertion failed for tensor_slice_{i}" + print("All tensor slice tests passed!") if __name__ == '__main__': From f256db67a6365f5f7dbf873e373dd25f44c1175b Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 31 Aug 2023 19:55:53 +0800 Subject: [PATCH 26/52] Refactor communicate groups and Block (#159) --- bmtrain/block_layer.py | 136 ++++++++++++++------------ bmtrain/hook_func.py | 11 +-- bmtrain/init.py | 39 +++++--- bmtrain/inspect/model.py | 4 +- bmtrain/pipe_layer.py | 29 ++---- example/models/gpt.py | 25 +++-- tests/test_middle_hidden.py | 2 +- tests/test_multi_return.py | 2 +- tests/test_no_grad.py | 38 +++++++ 
tests/test_requires_grad_multi_gpu.py | 2 +- tests/test_send_recv.py | 6 +- tests/test_training.py | 2 +- 12 files changed, 179 insertions(+), 117 deletions(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 61b335cf..816fb22c 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -5,12 +5,8 @@ import torch from . import nccl from .parameter import DistributedParameter, OpAllGather -from .zero_context import ( - ZeroContext -) - +from .zero_context import ZeroContext from . import hook_func - import inspect from torch.utils.checkpoint import checkpoint @@ -55,26 +51,51 @@ class Block(torch.nn.Module): use_checkpoint (boolean): use checkpoint or not. Default True. zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across the process, 3 (ZeRO-3) means that the parameters are partitioned one the basis of ZeRO-2. Default 3. + initialized (bool): initialized parameter storage. Default False. + mode (str): the mode shouled be "PIPE" when runing in pipeline mode, otherwise mode="BLOCK". Default "BLOCK" Examples: >>> transformer_block = TransformerBlock(...) - >>> bmt_block = Block(transformer_block) - >>> y1, ... = bmt_block(x) + >>> block = Block(transformer_block) + >>> y1, ... = block(x) >>> y2, ... = transformer_block(x) >>> assert torch.allclose(y1, y2) """ - def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3): + def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_level=3, initialized=False, mode="BLOCK"): super().__init__() self._module = inner_module self._inputs = None self._layer_dict = {} self._forward_block_ctx = None self._backward_block_ctx = None - # build large parameter&grad here + self._param_info = [] self._storage_params : Dict[str, torch.nn.Parameter] = {} self._storage_info = {} self._ready = False + + self._use_checkpoint = use_checkpoint + self._is_first_layer = True + self._is_last_layer = True + self._need_release = True + self._next_module = None #save the next module of self + self._pre_module = None #save the pre module of self + self._mode = mode #BLOCK or PIPE + self.all_input_no_grad = False + self.all_param_no_grad = False + self._zero_level = zero_level + if not initialized: + self.init_param_storage() + + def reference(self, block): + self._param_info = block._param_info + self._storage_params = block._storage_params + self._storage_info = block._storage_info + self._layer_dict = block._layer_dict + self._initialized = True + self._need_release = False + + def init_param_storage(self): # sort parameters by name ordered_parameters = list(self._module.named_parameters()) @@ -87,12 +108,21 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev kw_name = _get_param_kw(param) if kw_name not in self._storage_info: + if self._mode == "PIPE" and param._tp_mode: + zero_comm = config["pp_tp_zero_comm"] + elif self._mode != "PIPE" and param._tp_mode: + zero_comm = config["tp_zero_comm"] + elif self._mode == "PIPE" and not param._tp_mode: + zero_comm = config["pp_zero_comm"] + else: + zero_comm = config["zero_comm"] + self._storage_info[kw_name] = { "total": 0, "storage_type": storage_type, "requires_grad": param.requires_grad, "group": param.group, - "zero_comm" : param._zero_comm + "zero_comm" : zero_comm } param_shape = param._original_shape @@ -106,7 +136,7 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev offsets = {} # intialize storage buffers for kw, val in 
self._storage_info.items(): - comm = val['zero_comm'] + comm = val["zero_comm"] world_size = nccl.commCount(comm) rank = nccl.commRank(comm) val["world_size"] = world_size @@ -191,36 +221,25 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev for kw in offsets.keys(): assert offsets[kw] == self._storage_info[kw]["total"] - self.use_checkpoint = use_checkpoint - self._is_first_layer = True - self._is_last_layer = True - self._release_list = [True] - self._next_module = [] #save the next module of self - self._pre_module = [] #save the pre module of self - self._ref_count = 0 #incremental in forward and decreasing in backward - self._mode = "BLOCK" #BLOCK or ZERO or PIPE - self.all_input_no_grad = False - self.all_param_no_grad = False - self._zero_level = zero_level - def set_pre_module(self, pre_module): if pre_module is not None: - self._pre_module.append(pre_module) - pre_module._next_module.append(self) + self._pre_module = pre_module + pre_module._next_module = self def pre_module(self): - assert len(self._pre_module) == self._ref_count, "{} != {}".format(len(self._pre_module), self._ref_count) - return self._pre_module[self._ref_count-1] + return self._pre_module if not self._is_first_layer else None def next_module(self): - assert len(self._next_module) == self._ref_count, "{} != {}".format(len(self._next_module), self._ref_count) - return self._next_module[self._ref_count-1] + return self._next_module if not self._is_last_layer else None + + def release_next_module(self, flag): + if self.next_module() is not None: + self.next_module().release(flag) - def backward_release(self, flag): - if self._ref_count == 1 and self._backward_block_ctx is not None: + def release(self, flag): + if self._need_release and self._backward_block_ctx is not None: self._backward_block_ctx.exit(flag, True) config['load_stream'].record_event(config['load_event']) - self._ref_count -= 1 def pre_hook(self, *args): grad_tensors = [] @@ -262,7 +281,7 @@ def forward(self, *args): placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) - if self.use_checkpoint: + if self._use_checkpoint: out = checkpoint(self._module, *arg_list, use_reentrant=not self.all_input_no_grad) else: out = self._module(*arg_list) @@ -506,6 +525,24 @@ def eval(self): def __repr__(self): return self._module.__repr__() + +def _block_wrapper(module, module_dict:dict, mode="BLOCK"): + if not isinstance(module, Block): + in_block = id(module) in module_dict + new_module = Block(module, initialized=in_block, mode=mode) + if in_block: + new_module.reference(modules[id(module)]) + else: + module_dict[id(module)] = new_module + else: + if mode == "PIPE" and module._mode != "PIPE": + assert False, "You must be set mode=\"PIPE\" in bmt.Block when use PipelineTransformerBlockList!" + if id(module._module) in module_dict: + assert False, "Duplicate bmt.Block not supported in same block list!" 
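The identity check above, together with the reference() path earlier in this patch, is what allows weight tying: passing the same inner module twice yields two Block wrappers that share one parameter storage, while passing the same bmt.Block instance twice is rejected. A small sketch of both cases, assuming bmt.init_distributed() has been called; the 32-unit linear layers are illustrative only.

    import bmtrain as bmt

    tied = bmt.nn.Linear(32, 32)                       # one inner module reused twice
    blocks = bmt.TransformerBlockList([tied, tied])
    assert blocks[0]._module is blocks[1]._module      # both wrappers reference the same module

    dup = bmt.Block(bmt.nn.Linear(32, 32))
    try:
        bmt.TransformerBlockList([dup, dup])           # duplicate bmt.Block in one list
    except AssertionError:
        pass                                           # rejected by the assertion above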
+ else: + new_module = module + module_dict[id(module._module)] = new_module + return new_module class TransformerBlockList(torch.nn.Module): r""" @@ -528,21 +565,19 @@ class TransformerBlockList(torch.nn.Module): """ _modules: Dict[str, Block] - def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: + def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: super().__init__() self._modules = {} pre_module = None + module_dict = {} + module_dict = {} for i, module in enumerate(modules): - if not isinstance(module, Block): - module = Block(module) - - module._mode = "ZERO" + module = _block_wrapper(module, module_dict) module.set_pre_module(pre_module) pre_module = module module._is_first_layer = False module._is_last_layer = False - self._modules[str(i)] = module self.add_module(str(i), module) @@ -550,34 +585,13 @@ def __init__(self, modules: Iterable[Block], num_hidden=1, sqrt=False) -> None: self._modules[str(len(modules)-1)]._is_last_layer = True self.num_hidden = num_hidden - - if sqrt: - length = len(self) - num_save_needed = 0 - num_freed = 0 - save_list = [None]*length - for i in range(length-1, -1, -1): - if num_freed == 0 or i == 0: - num_save_needed += 1 - save_list[i] = [1, -num_save_needed] - num_freed = num_save_needed - else: - num_freed -= 1 - save_list[i] = [0, -(num_save_needed - num_freed)] - for i in range(length-1, -1, -1): - save_list[i][1] += num_save_needed - for i in range(0, length): - save_list[i][0] = i if save_list[i][0]==1 else save_list[i-1][0] - - self.save_list = save_list - else: - self.save_list = [(i, i) for i in range(len(self))] def __len__(self) -> int: return len(self._modules) def __iter__(self) -> Iterator[Block]: return iter(self._modules.values()) + def __getitem__(self, index: Union[int, str]) -> Block: return self._modules[str(index)] diff --git a/bmtrain/hook_func.py b/bmtrain/hook_func.py index 2c9a68a4..2c6108b0 100644 --- a/bmtrain/hook_func.py +++ b/bmtrain/hook_func.py @@ -11,7 +11,7 @@ def zero_pre_forward(module, inputs): if enter: zero_level = module._zero_level forward_flag = 1 if zero_level == 2 else 0 - if zero_level == 2 and module._ref_count > 1: + if zero_level == 2 and not module._need_release: forward_flag = 2 # repeating forward in same layer if module.all_param_no_grad: #only forward forward_flag = 0 @@ -28,15 +28,13 @@ def zero_post_forward(module, inputs, outputs): if exit: module._forward_block_ctx.exit(forward_flag) - module._ref_count += 1 def zero_pre_backward(module, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": module._backward_block_ctx = ZeroContext(module, module._layer_dict) module._backward_block_ctx.enter(backward_flag, True) - if not module._is_last_layer: - module.next_module().backward_release(backward_flag) + module.release_next_module(backward_flag) else: if module._micro_idx == config['micros'] - 1: module._backward_block_ctx = ZeroContext(module, module._layer_dict, pipe=True) @@ -46,10 +44,10 @@ def zero_post_backward(module, grad_inputs, grad_outputs): backward_flag = 2 if module._zero_level == 2 else 0 if module._mode != "PIPE": if module._is_first_layer: - module.backward_release(backward_flag) + module.release(backward_flag) else: if module._micro_idx == 0: - module.backward_release(backward_flag) + module.release(backward_flag) module._micro_idx -= 1 class OneStepNoGradFunc(torch.autograd.Function): @@ -84,7 +82,6 @@ def backward(ctx, grads): grads.append(None) return None, None, *grads - class 
PreHookFunc(torch.autograd.Function): @staticmethod def forward(ctx, module, *x): diff --git a/bmtrain/init.py b/bmtrain/init.py index a6214d78..cd304f89 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -130,13 +130,21 @@ def init_distributed( unique_id = nccl.getUniqueId() store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) - if topo.zero_id == 0: + if topo.pp_zero_id == 0: unique_id = nccl.getUniqueId() - store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode()) - config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id) + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) + config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) + + if topo.pp_tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) + config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + + config ['zero_comm'] = config['comm'] for i in range(world_size): if i == rank: @@ -163,8 +171,6 @@ def __init__(self,config): dp_size = world_size // (pp_size * tp_size) config['tp_zero_size'] = dp_size config['zero_size'] = world_size // pp_size - topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda') - topo=topo.view(pp_size,dp_size*tp_size) self.stages = config['pipe_size'] stage_size = world_size // pp_size @@ -173,13 +179,18 @@ def __init__(self,config): self.stage_id = self.rank // stage_size self.tp_id = self.rank % tp_size self.tp_idx = self.rank // tp_size - self.zero_idx = self.stage_id - self.zero_id = self.pipe_idx - self.tp_zero_idx = self.stage_id * tp_size + self.tp_id - self.tp_zero_id = self.pipe_idx // tp_size - - self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1 - self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 + #pp->zero + self.pp_zero_idx = self.stage_id + self.pp_zero_id = self.pipe_idx + #tp->zero + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx + #pp->tp->zero + self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.pp_tp_zero_id = self.pipe_idx // tp_size + #only zero + self.zero_idx = 0 + self.zero_id = self.rank def get_group_id(self,group_name): diff --git a/bmtrain/inspect/model.py b/bmtrain/inspect/model.py index a296d832..fc54f0d6 100644 --- a/bmtrain/inspect/model.py +++ b/bmtrain/inspect/model.py @@ -63,13 +63,13 @@ def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlock nccl.allGather( model._storage_params[kw].storage(), _param_buffer[kw], - config["zero_comm"] + val["zero_comm"] ) if model._storage_params[kw].grad is not None: nccl.allGather( model._storage_params[kw].grad.storage(), _grad_buffer[kw], - config["zero_comm"] + val["zero_comm"] ) nccl.groupEnd() diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 
e3913b6c..2b94c2b5 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -12,7 +12,7 @@ ZeroContext ) from . import debug -from .block_layer import Block, round_up, _get_param_kw +from .block_layer import Block, round_up, _get_param_kw, _block_wrapper class PipePreFunction(torch.autograd.Function): @staticmethod @@ -180,7 +180,7 @@ class PipelineTransformerBlockList(torch.nn.Module): Example: >>> module_list = [ ... ] >>> normal_module_list = torch.nn.ModuleList(module_list) - >>> transformer_module_list = TransformerBlockList(module_list) + >>> transformer_module_list = PipelineTransformerBlockList(module_list) >>> # Calling normal module list >>> for layer in normal_module_list: >>> hidden_state = layer.forward(hidden_state, ...) @@ -190,24 +190,18 @@ class PipelineTransformerBlockList(torch.nn.Module): """ _modules: Dict[str, Block] - def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: + def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None: super().__init__() self.num_hidden = num_hidden self._modules = {} - rank = config['rank'] - topo = config['topology'] self.layer_ids = [] + topo = config["topology"] self.stages = topo.stages self.stage_id = topo.stage_id self.pipe_idx = topo.pipe_idx + module_dict = {} for idx, module in enumerate(modules): - if not isinstance(module, Block): - module = Block(module) - - module._mode = "PIPE" - module.stage_id = self.stage_id - module.stages = self.stages - + module = _block_wrapper(module, module_dict, "PIPE") self._modules[str(idx)] = module self.layer_ids = self.get_range_by_stage_id(self.stage_id) @@ -217,15 +211,11 @@ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: module = self._modules[str(layer_id)] module.set_pre_module(pre_module) pre_module = module - - module._is_first_stage = True if self.stage_id == 0 else False - module._is_last_stage = True if self.stage_id == self.stages-1 else False module._is_first_layer = False module._is_last_layer = False + self._modules[str(self.layer_ids[0])]._is_first_layer = True self._modules[str(self.layer_ids[-1])]._is_last_layer = True - - self.save_list = [(i, i) for i in range(len(self.layer_ids))] def __len__(self) -> int: return len(self._modules) @@ -309,7 +299,8 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): with torch.no_grad(): with ZeroContext(module, pipe=True): module._module.state_dict(destination=dst, prefix=name, keep_vars=False) - if config["zero_rank"] == 0: + + if config["topology"].pp_zero_id == 0: if config["rank"] == 0: destination.update(dst) else: @@ -318,5 +309,5 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): send_activations(tensor.cuda(), 0, config['pipe_comm']) if config['rank'] == 0 and idx not in self.layer_ids: for n, parameter in module._module.named_parameters(): - destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']) + destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() diff --git a/example/models/gpt.py b/example/models/gpt.py index 456dbcc9..4596167c 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -20,14 +20,25 @@ def __init__(self, self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) - self.transformers = bmt.TransformerBlockList([ - bmt.Block( - TransformerEncoder( - dim_model, dim_head, num_heads, dim_ff, bias, dtype + if config['pipe_size'] > 1: + self.transformers = 
bmt.PipelineTransformerBlockList([ + bmt.Block( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + , mode="PIPE" ) - ) - for _ in range(num_layers) - ]) + for _ in range(num_layers) + ]) + else: + self.transformers = bmt.TransformerBlockList([ + bmt.Block( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + ) + for _ in range(num_layers) + ]) self.layernorm = Layernorm(dim_model, dtype=dtype) diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index 667d23b2..6a1f476f 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -207,6 +207,6 @@ def test_main(): assert_eq(w, w2) if __name__ == "__main__": - bmt.init_distributed(pipe_size=4) + bmt.init_distributed(pipe_size=1) test_main() diff --git a/tests/test_multi_return.py b/tests/test_multi_return.py index cc891fdf..f4a5d79f 100644 --- a/tests/test_multi_return.py +++ b/tests/test_multi_return.py @@ -121,6 +121,6 @@ def test_main(): assert_lt((r[i]-r2[i]).abs().max(), 1e-5) if __name__ == "__main__": - bmt.init_distributed(pipe_size=2) + bmt.init_distributed(pipe_size=1) test_main() diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py index 73231acc..ab4830ee 100644 --- a/tests/test_no_grad.py +++ b/tests/test_no_grad.py @@ -24,6 +24,23 @@ def test_no_grad(): assert layer1.count == 1 assert layer2.count == 2 +def test_multi_layer_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layers = [] + for i in range(10): + layer = bmt.Block(Layer()) + layer.linear.weight.requires_grad_(i > 4) + layer.linear.bias.requires_grad_(i > 4) + layers.append(layer) + + y = x + for layer in layers: + y = layer(y) + y.sum().backward() + for i in range(len(layers)): + assert layers[i].count == (1 if i <=4 else 2) + def test_all_input_no_grad(): linear1 = bmt.nn.Linear(32, 32) linear2 = bmt.nn.Linear(32, 32) @@ -39,8 +56,29 @@ def test_all_input_no_grad(): assert linear1.bias.grad is not None assert x.grad is None +def test_no_grad_error(): + layer = bmt.Block(Layer()) + + try: + block_list = bmt.TransformerBlockList([layer, layer]) + raise ValueError("test failed") + except AssertionError as e: + return + +def test_no_grad_error2(): + layer = bmt.Block(Layer()) + + try: + block_list = bmt.PipelineTransformerBlockList([layer]) + raise ValueError("test failed") + except AssertionError as e: + return + if __name__ == '__main__': bmt.init_distributed() test_no_grad() + test_multi_layer_no_grad() test_all_input_no_grad() + test_no_grad_error() + test_no_grad_error2() diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index 66b4880c..cd2e27b2 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -90,7 +90,7 @@ def test_main_pipe(): assert_eq(sm2.split('\n')[2], sm3.split('\n')[2]) if __name__ == "__main__": - bmt.init_distributed(pipe_size=2) + bmt.init_distributed(pipe_size=1) test_main() test_main_pipe() diff --git a/tests/test_send_recv.py b/tests/test_send_recv.py index 95c9c1e5..f933b0c2 100644 --- a/tests/test_send_recv.py +++ b/tests/test_send_recv.py @@ -6,12 +6,12 @@ def test_send_recv(): if config["topology"].stage_id == 0: - a = torch.ones((2,1)) * (config["zero_rank"]+1) + a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) a = a.cuda() print(f"send {a}") bmt.distributed.send_activations(a, 1, config["pipe_comm"]) else: - ref = torch.ones((2,1)) * (config["zero_rank"]+1) + ref = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) a = 
bmt.distributed.recv_activations(0, config["pipe_comm"]) print(f"recv {a}") assert_all_eq(a, ref.cuda()) @@ -19,4 +19,4 @@ def test_send_recv(): if __name__ == '__main__': bmt.init_distributed(pipe_size=2) - test_send_recv() \ No newline at end of file + test_send_recv() diff --git a/tests/test_training.py b/tests/test_training.py index 223c27df..c1b3e4bd 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -452,6 +452,6 @@ def check_param(info1, info2): assert_lt(abs(v1-v2), 1e-2) if __name__ == '__main__': - bmt.init_distributed(pipe_size=2) + bmt.init_distributed(pipe_size=1) test_main(test_fp16=True, test_fp32=True) From 0ba1e3a3c23a5826d45028394bc7a9564a35b898 Mon Sep 17 00:00:00 2001 From: Congrui Yin <88324880+JerryYin777@users.noreply.github.com> Date: Thu, 31 Aug 2023 04:56:18 -0700 Subject: [PATCH 27/52] Update test_synchronize.py (#161) --- tests/test_synchronize.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_synchronize.py b/tests/test_synchronize.py index c5ceed3d..bea48b04 100644 --- a/tests/test_synchronize.py +++ b/tests/test_synchronize.py @@ -16,10 +16,10 @@ def test_main(): size = i + 1 tensor_slice = tensor[:size, :size] result_slice = bmt.gather_result(tensor_slice) - test_slice = torch.chunk(result_slice, bmt.world_size(), dim=0)[0] + test_slice = torch.chunk(result_slice, bmt.world_size(), dim=0)[i] assert torch.allclose(tensor_slice, test_slice), f"Assertion failed for tensor_slice_{i}" - - print("All tensor slice tests passed!") + +print("All test passed") if __name__ == '__main__': bmt.init_distributed(pipe_size=1) From 985a3c3032e0423fc7a0600405033de5d88ed331 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Thu, 31 Aug 2023 20:48:25 +0800 Subject: [PATCH 28/52] Fix cross (#162) --- bmtrain/nn/parallel_cross_entropy_func.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py index cd1f63bf..8e4548ae 100644 --- a/bmtrain/nn/parallel_cross_entropy_func.py +++ b/bmtrain/nn/parallel_cross_entropy_func.py @@ -103,7 +103,7 @@ def backward(ctx, grad_output): else: grad_2d[arange_1d, masked_target_1d] -= softmax_update - grad_input.mul_(grad_output.flatten(0,1).unsqueeze(dim=-1)) + grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) return grad_input, None, None From 206e11dce1999907c2e4653d111df7f508b04a70 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Mon, 4 Sep 2023 16:07:06 +0800 Subject: [PATCH 29/52] Fix block wrapper (#166) --- bmtrain/block_layer.py | 2 +- tests/test_no_grad.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 816fb22c..ab027fc9 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -531,7 +531,7 @@ def _block_wrapper(module, module_dict:dict, mode="BLOCK"): in_block = id(module) in module_dict new_module = Block(module, initialized=in_block, mode=mode) if in_block: - new_module.reference(modules[id(module)]) + new_module.reference(module_dict[id(module)]) else: module_dict[id(module)] = new_module else: diff --git a/tests/test_no_grad.py b/tests/test_no_grad.py index ab4830ee..7851c670 100644 --- a/tests/test_no_grad.py +++ b/tests/test_no_grad.py @@ -56,6 +56,11 @@ def test_all_input_no_grad(): assert linear1.bias.grad is not None assert x.grad is None +def test_same_layer(): + layer = Layer() + block_list = bmt.TransformerBlockList([layer, layer]) + assert id(block_list[0]) != 
id(block_list[1]) + def test_no_grad_error(): layer = bmt.Block(Layer()) @@ -80,5 +85,6 @@ def test_no_grad_error2(): test_no_grad() test_multi_layer_no_grad() test_all_input_no_grad() + test_same_layer() test_no_grad_error() test_no_grad_error2() From 696eb5ca2d188c6feb4839db95e54e275e58a508 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 6 Sep 2023 09:49:29 +0800 Subject: [PATCH 30/52] Fix async send in pipe mode (#167) --- bmtrain/pipe_layer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bmtrain/pipe_layer.py b/bmtrain/pipe_layer.py index 2b94c2b5..4d3b17ad 100644 --- a/bmtrain/pipe_layer.py +++ b/bmtrain/pipe_layer.py @@ -142,7 +142,7 @@ def backward(ctx, grad_outputs): current_stream = torch.cuda.current_stream() with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) - send_data.record_stream(current_stream) + send_data.record_stream(config['pp_comm_stream']) send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) return grad_outputs, None @@ -157,7 +157,7 @@ def forward(ctx, outputs, stage_id): current_stream = torch.cuda.current_stream() with torch.cuda.stream(config['pp_comm_stream']): config['pp_comm_stream'].wait_stream(current_stream) - send_data.record_stream(current_stream) + send_data.record_stream(config['pp_comm_stream']) send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) return outputs @@ -202,6 +202,7 @@ def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None: module_dict = {} for idx, module in enumerate(modules): module = _block_wrapper(module, module_dict, "PIPE") + module._zero_level = 2 #currently, only support ZeRO-2 in pipeline mode self._modules[str(idx)] = module self.layer_ids = self.get_range_by_stage_id(self.stage_id) From 535279e66b17465849494f2e94f664397c7887b7 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 6 Sep 2023 10:38:04 +0800 Subject: [PATCH 31/52] Faster parallel linear (#165) --- bmtrain/nn/column_parallel_linear.py | 5 +- bmtrain/nn/parallel_linear_func.py | 200 +++++++++++++++++++++++++-- bmtrain/nn/row_parallel_linear.py | 5 +- 3 files changed, 197 insertions(+), 13 deletions(-) diff --git a/bmtrain/nn/column_parallel_linear.py b/bmtrain/nn/column_parallel_linear.py index e8f554c8..a432d798 100644 --- a/bmtrain/nn/column_parallel_linear.py +++ b/bmtrain/nn/column_parallel_linear.py @@ -8,12 +8,13 @@ ReduceType) class ColumnParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True) -> None: + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, gather_output=False, gather_input=True, async_gather_chunks=2) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.gather_output = gather_output self.gather_input = gather_input + self.async_gather_chunks = async_gather_chunks tp_size = config['tp_size'] assert out_features % tp_size == 0 self.out_features_per_partition = out_features // tp_size @@ -27,7 +28,7 @@ def forward(self, input): gather_input = self.gather_input split_input = False reduce_output_type = None - return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type) + return OpParallelLinear.apply(input, self.weight, self.bias, gather_input, self.gather_output, split_input, reduce_output_type, self.async_gather_chunks) def extra_repr(self) -> str: 
return 'in_features={}, out_features={}, bias={}'.format( diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 55741530..325bda23 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -20,9 +20,169 @@ def preprocess_input(input, gather_input, split_input): input = all_input_list[config['topology'].tp_id] return input +def async_all_gather_linear_func(input, weight, bias, async_chunks=2): + dim = input.dim() + shape = list(input.shape) + if dim > 2: + input = input.view(-1, input.shape[-1]) + tp_size = config['tp_size'] + current_stream = torch.cuda.current_stream() + comm_stream = config['tp_comm_stream'] + + rounds = async_chunks + inputs = input.chunk(rounds, dim=0) + comm_stream.wait_stream(current_stream) + outputs = [None] * tp_size * rounds + + input = all_gather(inputs[0], config['tp_comm']) + input = input.flatten(0, 1) + out = F.linear(input, weight, bias) + outs = out.chunk(tp_size, dim=0) + for i in range(tp_size): + outputs[i * rounds] = outs[i] + + #async all_gather and overalap with linear + for i in range(rounds-1): + with torch.cuda.stream(comm_stream): + inputs[i+1].record_stream(comm_stream) + input = all_gather(inputs[i+1], config['tp_comm']) + input = input.flatten(0, 1) + + current_stream.wait_stream(comm_stream) + out = F.linear(input, weight, bias) + outs = out.chunk(tp_size, dim=0) + for j in range(tp_size): + outputs[(i + 1) + j * rounds] = outs[j] + + out = torch.cat(outputs, dim=0) + if dim > 2: + out_shape = list(out.shape) + shape[-1] = out_shape[-1] + shape[0] = shape[0] * tp_size + out = out.view(shape) + return out + +def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2): + tp_size = config['tp_size'] + comm_stream = config['tp_comm_stream'] + rounds = async_chunks + input_shape = list(input.shape) + dim = input.dim() + if dim > 2: + input = input.view(-1, input.shape[-1]) + inputs = input.chunk(rounds*tp_size, dim=0) + current_stream = torch.cuda.current_stream() + + outputs = [None] * rounds + for i in range(rounds): + input = [None] * tp_size + for j in range(tp_size): + input[j] = inputs[j*rounds + i] + input = torch.cat(input, dim=0) + out = F.linear(input, weight, bias) + with torch.cuda.stream(comm_stream): + comm_stream.wait_stream(current_stream) + out.record_stream(comm_stream) + shape = list(out.shape) + shape[0] = shape[0] // config['tp_size'] + outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter(out.storage(), outputs[i].storage(), "sum", config['tp_comm']) + + current_stream.wait_stream(comm_stream) + out = torch.cat(outputs, dim=0) + if dim > 2: + out_shape = list(out.shape) + input_shape[-1] = out_shape[-1] + input_shape[0] = input_shape[0] // tp_size + out = out.view(input_shape) + + return out + +def async_all_gather_linear_backward_func(grad_out, input, weight, bias, async_chunks=2): + tp_size = config['tp_size'] + current_stream = torch.cuda.current_stream() + comm_stream = config['tp_comm_stream'] + input_require_grad = input.requires_grad + dim = input.dim() + input_shape = input.shape + if dim > 2: + input = input.view(-1, input_shape[-1]) + grad_out = grad_out.view(-1, grad_out.shape[-1]) + + rounds = async_chunks + grad_inputs = [None] * tp_size * rounds + grad_weights = [None] * tp_size * rounds + grad_outs = [None] * tp_size * rounds + local_grad_outs = grad_out.chunk(rounds, dim=0) + + inputs = [None] * rounds + comm_stream.wait_stream(current_stream) + if weight.requires_grad: + with 
torch.cuda.stream(comm_stream): + input.record_stream(comm_stream) + input_list = [None] * tp_size * rounds + tp_inputs = input.chunk(tp_size, dim=0) + for i in range(tp_size): + chunk_inputs = tp_inputs[i].chunk(rounds, dim=0) + for j in range(rounds): + input_list[j * tp_size + i] = chunk_inputs[j] + start = 0 + end = tp_size + for i in range(rounds): + inputs[i] = torch.cat(input_list[start:end], dim=0) + start = end + end += tp_size + + grad_input = grad_weight = grad_bias = None + + grad_out = all_gather(local_grad_outs[0], config['tp_comm']) + for j in range(tp_size): + grad_outs[j * rounds] = grad_out[j] + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + if input_require_grad: + grad_input = grad_out.matmul(weight) # (tp_size * (m/rounds), n) * (n, k/tp_size) + tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) + for j in range(tp_size): + grad_inputs[j * rounds] = tmp_grad_inputs[j] + + if weight.requires_grad: + grad_weight = grad_out.reshape(-1, + grad_out.shape[-1]).t().matmul(inputs[0].reshape(-1, inputs[0].shape[-1])) + + #async all_gather and overalap with matmul + for i in range(rounds-1): + with torch.cuda.stream(comm_stream): + local_grad_outs[i+1].record_stream(comm_stream) + grad_out = all_gather(local_grad_outs[i+1], config['tp_comm']) + for j in range(tp_size): + grad_outs[j * rounds + i+1] = grad_out[j] + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + + current_stream.wait_stream(comm_stream) + if input_require_grad: + grad_input = grad_out.matmul(weight) # (tp_size * (m/rounds), n) * (n, k/tp_size) + tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) + for j in range(tp_size): + grad_inputs[j * rounds + i+1] = tmp_grad_inputs[j] + + if weight.requires_grad: + dim = grad_out.dim() + grad_weight += grad_out.reshape(-1, + grad_out.shape[-1]).t().matmul(inputs[i+1].reshape(-1, inputs[i+1].shape[-1])) + + if input_require_grad: + grad_input = torch.cat(grad_inputs, dim=0) + grad_input = grad_input.view(input_shape) + + if bias is not None and bias.requires_grad: + grad_out = torch.cat(grad_outs, dim=0) + grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(0) + + return grad_input, grad_weight, grad_bias + class OpParallelLinear(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None): + def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=False, split_input=False, reduce_output_type=None, async_gather_chunks=2): if reduce_output_type is not None: reduce_output_type = ReduceType(reduce_output_type) @@ -31,9 +191,15 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal ctx.split_input = split_input ctx.gather_input = gather_input ctx.reduce_output_type = reduce_output_type + ctx.async_gather_chunks = async_gather_chunks - all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) - out = F.linear(all_input, weight, bias) + if gather_input and config['tp_size'] > 1 and async_gather_chunks > 1: + out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks) + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + return async_reduce_scatter_linear_func(input, weight, bias, async_gather_chunks) + else: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) if gather_output: all_output_list = all_gather(out, config['tp_comm']) @@ -62,8 +228,12 @@ def backward(ctx, grad_output): gather_output 
= ctx.gather_output if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: - grad_output = all_gather(grad_output, config['tp_comm']) - grad_output = grad_output.flatten(0, 1) + if input.requires_grad or weight.requires_grad: + grad_input, grad_weight, grad_bias = async_all_gather_linear_backward_func(grad_output, input, weight, bias, ctx.async_gather_chunks) + return grad_input, grad_weight, grad_bias, None, None, None, None, None + else: + grad_output = all_gather(grad_output, config['tp_comm']) + grad_output = grad_output.flatten(0, 1) if gather_output: tp_size = config['tp_size'] @@ -73,14 +243,24 @@ def backward(ctx, grad_output): grad_input = grad_weight = grad_bias = None + current_stream = torch.cuda.current_stream() if input.requires_grad or weight.requires_grad: - all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + if ctx.gather_input: + # async the all_gather + with torch.cuda.stream(config['tp_comm_stream']): + input.record_stream(config['tp_comm_stream']) + config['tp_comm_stream'].wait_stream(current_stream) + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + #use event to solve two streams waiting for each other + gather_event = config['tp_comm_stream'].record_event() + else: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) if input.requires_grad: - current_stream = torch.cuda.current_stream() grad_all_input = grad_output.matmul(weight) grad_input = torch.zeros_like(input) if ctx.gather_input: + # async the reduce_scatter with torch.cuda.stream(config['tp_comm_stream']): config['tp_comm_stream'].wait_stream(current_stream) grad_input.record_stream(config['tp_comm_stream']) @@ -95,8 +275,10 @@ def backward(ctx, grad_output): grad_input.record_stream(config['tp_comm_stream']) grad_input = all_gather(grad_input, config['tp_comm']) + # wait all_gather + if ctx.gather_input: + current_stream.wait_event(gather_event) if weight.requires_grad: - dim = grad_output.dim() grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(all_input.reshape(-1, all_input.shape[-1])) @@ -105,4 +287,4 @@ def backward(ctx, grad_output): current_stream = torch.cuda.current_stream() current_stream.wait_stream(config['tp_comm_stream']) - return grad_input, grad_weight, grad_bias, None, None, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None, None diff --git a/bmtrain/nn/row_parallel_linear.py b/bmtrain/nn/row_parallel_linear.py index bbe94cfa..7451e7d3 100644 --- a/bmtrain/nn/row_parallel_linear.py +++ b/bmtrain/nn/row_parallel_linear.py @@ -8,12 +8,13 @@ ReduceType) class RowParallelLinear(bmt.DistributedModule): - def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False) -> None: + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None, split_input=False, all_reduce_output=False, async_chunks=2) -> None: super().__init__() self.in_features = in_features self.out_features = out_features self.split_input = split_input self.all_reduce_output = all_reduce_output + self.async_chunks = async_chunks tp_size = config['tp_size'] assert in_features % tp_size == 0 self.in_features_per_partition = in_features // tp_size @@ -27,7 +28,7 @@ def forward(self, input): gather_input = self.split_input gather_output = False reduce_output_type = ReduceType.ALL_REDUCE if self.all_reduce_output else ReduceType.REDUCE_SCATTER - out = OpParallelLinear.apply(input, self.weight, None, gather_input, 
gather_output, self.split_input, reduce_output_type) + out = OpParallelLinear.apply(input, self.weight, None, gather_input, gather_output, self.split_input, reduce_output_type, self.async_chunks) if self.bias is not None: out = out + self.bias return out From 95417b52bd73bdd1d0341a5dda3304d6e783a69b Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Sat, 16 Sep 2023 16:07:08 +0800 Subject: [PATCH 32/52] fix async row linear not support split_input (#168) --- .gitignore | 3 ++- bmtrain/nn/parallel_linear_func.py | 9 +-------- tests/test_column_parallel_linear.py | 16 +++++++++------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/.gitignore b/.gitignore index 2e8c0dcd..75138102 100644 --- a/.gitignore +++ b/.gitignore @@ -151,4 +151,5 @@ log !bmtrain/dist tests/test_log.txt -tests/*.opt \ No newline at end of file +tests/*.opt +tests/*.ckp \ No newline at end of file diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index 325bda23..c3eec380 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -193,7 +193,7 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal ctx.reduce_output_type = reduce_output_type ctx.async_gather_chunks = async_gather_chunks - if gather_input and config['tp_size'] > 1 and async_gather_chunks > 1: + if gather_input and config['tp_size'] > 1 and async_gather_chunks > 1 and split_input == False: out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks) elif reduce_output_type == ReduceType.REDUCE_SCATTER: return async_reduce_scatter_linear_func(input, weight, bias, async_gather_chunks) @@ -212,13 +212,6 @@ def forward(ctx, input, weight, bias=None, gather_input=False, gather_output=Fal if reduce_output_type == ReduceType.ALL_REDUCE: nccl.allReduce(out.storage(), out.storage(), "sum", config['tp_comm']) return out - - elif reduce_output_type == ReduceType.REDUCE_SCATTER: - shape = list(out.shape) - shape[0] = shape[0] // config['tp_size'] - reduce_out = torch.empty(shape, dtype=out.dtype, device=out.device) - nccl.reduceScatter(out.storage(), reduce_out.storage(), "sum", config['tp_comm']) - return reduce_out else: assert False, "no support reduce type{}".format(reduce_output_type) diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py index c20707c2..6dce1abd 100644 --- a/tests/test_column_parallel_linear.py +++ b/tests/test_column_parallel_linear.py @@ -7,7 +7,7 @@ def run_bmt(x, gather_output, ckp_path, tp_size=2): linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) linear = bmt.Block(linear) bmt.init_parameters(linear) - y = linear(x[config['topology'].tp_id]) + y = linear(x) y.sum().backward() bmt.save(linear, ckp_path) bmt.synchronize() @@ -24,17 +24,19 @@ def run_torch(x, ckp_path): return y, linear.weight.grad, linear.bias.grad def run(gather_output, ckp_path): - tp_size = bmt.config['tp_size'] torch.cuda.manual_seed(100) - x = torch.randn(tp_size, 8,8, device='cuda').requires_grad_() - y1, weight_grad1, bias_grad1 = run_bmt(x, gather_output, ckp_path) + tp_size = config["tp_size"] + tp_rank = config['topology'].tp_id + x = torch.randn(8, 8, 8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, gather_output, ckp_path) y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) tp_rank = config['topology'].tp_id if gather_output: - assert np.allclose(y1.detach().cpu().numpy(), 
y2.flatten(0,1).detach().cpu().numpy()) + assert np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy()) else: - torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=y2.dim()-1) - assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].flatten(0,1).detach().cpu().numpy()) + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].detach().cpu().numpy()) weight_grad_list = weight_grad2.chunk(tp_size, dim=0) assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) From 25e36718c73237fc12e532aebf45aec77280e7cc Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 27 Sep 2023 20:55:39 +0800 Subject: [PATCH 33/52] add _save_to_infer_model (#170) --- bmtrain/init.py | 1 + bmtrain/layer.py | 12 +++++++++--- bmtrain/store.py | 49 ++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index cd304f89..b7224f94 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -82,6 +82,7 @@ def init_distributed( config["zero_rank"] = config['topology'].get_group_rank("zero") config["tp_rank"] = config['topology'].get_group_rank("tp") config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero") + config["save_param_to_cpu"] = True cpus_this_worker = None all_available_cpus = sorted(list(os.sched_getaffinity(0))) diff --git a/bmtrain/layer.py b/bmtrain/layer.py index cf46814b..e071e01b 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -34,11 +34,17 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if param is not None: if isinstance(param, DistributedParameter):#and not param._in_block: if param._in_block: - destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation + destination[prefix + name] = param.tp_gather().detach() # sync operation else: - destination[prefix + name] = param.gather_all().detach().cpu() # sync operation + destination[prefix + name] = param.gather_all().detach() # sync operation + if config['save_param_to_cpu']: + destination[prefix + name] = destination[prefix + name].cpu() else: - destination[prefix + name] = param if keep_vars else param.detach().cpu() + if config['save_param_to_cpu']: + destination[prefix + name] = param if keep_vars else param.detach().cpu() + else: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: destination[prefix + name] = buf if keep_vars else buf.detach() diff --git a/bmtrain/store.py b/bmtrain/store.py index 88ed7305..254213bd 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -3,31 +3,48 @@ import torch from .pipe_layer import PipelineTransformerBlockList +from .block_layer import TransformerBlockList from .global_var import config from .block_layer import Block from . 
import nccl import io, pickle from typing import Mapping -def _save_to_state_dict(model : torch.nn.Module, destination, prefix): +def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): - if config['rank'] != 0: + if rank != 0: destination = OrderedDict() # creates an temporary ordered dict destination._metadata = OrderedDict() model.state_dict(destination=destination, prefix=prefix, keep_vars=False) else: - if config['rank'] != 0: + if rank != 0: destination = OrderedDict() # creates an temporary ordered dict destination._metadata = OrderedDict() model._save_to_state_dict(destination, prefix, False) +def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''): + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, config['local_rank'], destination, prefix) + for name, module in model._modules.items(): + if module is not None: + _save_to_local_rank0(module, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''): if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) if not isinstance(model, PipelineTransformerBlockList): - _save_to_state_dict(model, destination, prefix) + _save_to_state_dict(model, config['rank'], destination, prefix) for name, module in model._modules.items(): if module is not None: _save_to_rank0(module, destination, prefix + name + '.') @@ -38,6 +55,30 @@ def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''): else: model._save_to_state_dict(destination, prefix, False) return destination + +def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''): + config['save_param_to_cpu'] = False + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, config['local_rank'], destination, prefix) + for name, module in model._modules.items(): + if module is not None: + if isinstance(module, TransformerBlockList): + for local_name, local_module in module._modules.items(): + local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." 
+ local_name + '.') + if config['local_rank'] == 0: + infer_model.load_layer_state_dict(local_state_dict) + else: + _save_to_infer_model(module, infer_model, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + + if config['local_rank'] == 0: + infer_model.load_layer_state_dict(destination) From 290386a76d3268f3c42d5d1a5e8bd1186b26c1e6 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 27 Sep 2023 20:56:14 +0800 Subject: [PATCH 34/52] Async save state_dict to file (#171) --- bmtrain/store.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/bmtrain/store.py b/bmtrain/store.py index 254213bd..7279ac53 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -9,6 +9,7 @@ from . import nccl import io, pickle from typing import Mapping +import threading def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): @@ -81,8 +82,12 @@ def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, infer_model.load_layer_state_dict(destination) +def async_save_to_file(state_dict, file_path): + torch.save(state_dict, file_path) + config['finish_save'] = True + print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. @@ -90,6 +95,8 @@ def save(model : torch.nn.Module, file_name : str): Args: model (torch.nn.Module): The model to be saved. file_name (str): The file name of the checkpoint. + non_blocking (bool): Whether to asynchronously save state_dict to file + Examples: >>> bmtrain.save(model, "model.pt") @@ -97,7 +104,18 @@ def save(model : torch.nn.Module, file_name : str): torch.cuda.synchronize() state_dict = _save_to_rank0(model) if config["rank"] == 0: - torch.save(state_dict, file_name) + if non_blocking is False: + torch.save(state_dict, file_name) + else: + if 'finish_save' not in config: + config['finish_save'] = True + + if config['finish_save'] is False: + config['save_thread'].join() + + config['finish_save'] = False + config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) + config['save_thread'].start() DTYPE_LIST = [ torch.float64, From cf6ad5516970ed0189bcccfb3708a3ba03a95880 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 8 Oct 2023 16:03:21 +0800 Subject: [PATCH 35/52] init tp comm when tp_size>1 --- bmtrain/init.py | 49 ++++++++++++++++++++----------------- example/layers/attention.py | 22 +++++++++-------- 2 files changed, 38 insertions(+), 33 deletions(-) diff --git a/bmtrain/init.py b/bmtrain/init.py index cd304f89..b3802f99 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -120,29 +120,32 @@ def init_distributed( unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) - if topo.tp_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) - unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) - config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) - - if topo.tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) - 
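# Every communicator in this function is set up with the same rendezvous
# pattern: rank 0 of a group publishes an NCCL unique id through the shared
# key-value store, all members read it back, and the group is created with
# commInitRank. A minimal sketch of that pattern, assuming a store exposing
# set()/get() and the nccl bindings already imported here (the helper name
# _init_group_comm is illustrative only, not a BMTrain API):
def _init_group_comm(store, key, group_size, group_rank):
    if group_rank == 0:
        # only the group leader creates and publishes the id
        unique_id = nccl.getUniqueId()
        store.set(key, unique_id.hex())
    # every member, including the leader, reads the id back before joining
    unique_id = bytes.fromhex(store.get(key).decode())
    return nccl.commInitRank(unique_id, group_size, group_rank)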
unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) - config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) - - if topo.pp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) - config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) - - if topo.pp_tp_zero_id == 0: - unique_id = nccl.getUniqueId() - store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) - unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) - config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) + if topo.pp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()) + config['pp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['pipe_size'], topo.pp_zero_id) + + if config['tp_size'] > 1: + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()) + config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//config['tp_size'], topo.tp_zero_id) + + + if config['pipe_size'] > 1 and config['tp_size'] > 1: + if topo.pp_tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex() ) + unique_id = bytes.fromhex(store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()) + config['pp_tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.pp_tp_zero_id) config ['zero_comm'] = config['comm'] diff --git a/example/layers/attention.py b/example/layers/attention.py index 8fbb7510..32497bcb 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -46,19 +46,21 @@ def forward(self, assert hidden_q.data_ptr() == hidden_kv.data_ptr() - hidden_q = bmt.nn.OpParallelLinear.apply( - hidden_q, - torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), - torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), - True, False, - False, None - ) - - h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) - if config['tp_size'] > 1: + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) #batch_size will changed in TensorParallel batch_size = h_v.shape[0] + else: + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) From 
abcd7c8f77a4bf96330b638d89e4058143ca80f2 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Sun, 8 Oct 2023 20:46:19 +0800 Subject: [PATCH 36/52] fix tp --- bmtrain/nn/parallel_linear_func.py | 6 +++++ tests/test_column_parallel_linear.py | 33 +++++++++++++++++++++------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/bmtrain/nn/parallel_linear_func.py b/bmtrain/nn/parallel_linear_func.py index c3eec380..e242f7ed 100644 --- a/bmtrain/nn/parallel_linear_func.py +++ b/bmtrain/nn/parallel_linear_func.py @@ -259,6 +259,12 @@ def backward(ctx, grad_output): grad_input.record_stream(config['tp_comm_stream']) grad_all_input.record_stream(config['tp_comm_stream']) nccl.reduceScatter(grad_all_input.storage(), grad_input.storage(), "sum", config['tp_comm']) + elif ctx.reduce_output_type is None: + with torch.cuda.stream(config['tp_comm_stream']): + config['tp_comm_stream'].wait_stream(current_stream) + grad_input.record_stream(config['tp_comm_stream']) + nccl.allReduce(grad_all_input.storage(), grad_all_input.storage(), "sum", config['tp_comm']) + grad_input = grad_all_input else: grad_input = grad_all_input diff --git a/tests/test_column_parallel_linear.py b/tests/test_column_parallel_linear.py index 6dce1abd..5f2fdad1 100644 --- a/tests/test_column_parallel_linear.py +++ b/tests/test_column_parallel_linear.py @@ -3,8 +3,8 @@ from bmtrain.global_var import config import numpy as np -def run_bmt(x, gather_output, ckp_path, tp_size=2): - linear = bmt.nn.ColumnParallelLinear(8,8, gather_output=gather_output) +def run_bmt(x, gather_input, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_input=gather_input, gather_output=gather_output) linear = bmt.Block(linear) bmt.init_parameters(linear) y = linear(x) @@ -23,13 +23,19 @@ def run_torch(x, ckp_path): y.sum().backward() return y, linear.weight.grad, linear.bias.grad -def run(gather_output, ckp_path): +def run(gather_input, gather_output, ckp_path): torch.cuda.manual_seed(100) tp_size = config["tp_size"] tp_rank = config['topology'].tp_id - x = torch.randn(8, 8, 8, device='cuda').requires_grad_() - rank_x = x.chunk(tp_size, dim=0)[tp_rank] - y1, weight_grad1, bias_grad1 = run_bmt(rank_x, gather_output, ckp_path) + x = torch.randn(8, 8, 8, device='cuda') + bmt_x = x.clone() + if gather_input: + rank_x = bmt_x.chunk(tp_size, dim=0)[tp_rank] + else: + rank_x = bmt_x + rank_x.requires_grad_() + x.requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, gather_input, gather_output, ckp_path) y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) tp_rank = config['topology'].tp_id if gather_output: @@ -44,14 +50,25 @@ def run(gather_output, ckp_path): bias_grad_list = bias_grad2.chunk(tp_size, dim=0) assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + if gather_input: + x_grad_list = x.grad.chunk(tp_size, dim=0) + np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x_grad_list[tp_rank].cpu().numpy(), atol=1e-4, rtol=1e-4) + else: + np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x.grad.cpu().numpy(), atol=1e-4, rtol=1e-4) + def test_gather_output(): - run(True, 'linear.ckp') + run(True, True, 'linear.ckp') def test_no_gather_output(): - run(False, 'linear_no_gather.ckp') + run(True, False, 'linear_no_gather.ckp') + +def test_no_gather_input(): + run(False, True, 'linear.ckp') + if __name__ == "__main__": bmt.init_distributed(tp_size=2) test_gather_output() test_no_gather_output() + test_no_gather_input() From 
9dc781134745b95cac20af083ab0c58feb5ee868 Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Fri, 10 Nov 2023 10:54:07 +0800 Subject: [PATCH 37/52] fix adam bf16 load changed to fp16 (#175) --- bmtrain/optim/adam_offload.py | 4 ++-- bmtrain/store.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index 5b34a287..d7910aec 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -200,8 +200,8 @@ def load_state_dict(self, state_dict: dict) -> None: state[param]["_grad_fp32"] = torch.empty(param.size(), dtype=torch.float32, pin_memory=True) # on host else: # initialize placeholders - state[param]["_param_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host - state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=torch.float16, pin_memory=True) # on host + state[param]["_param_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host + state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host else: state[k] = v diff --git a/bmtrain/store.py b/bmtrain/store.py index 7279ac53..8f731190 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -87,7 +87,7 @@ def async_save_to_file(state_dict, file_path): config['finish_save'] = True print("finish save state_dict to ", file_path) -def save(model : torch.nn.Module, file_name : str, non_blocking : bool=True): +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): """Saves the model to the file. Similar to torch.save, but it used for distributed modules. From ed90d1f9aa68a44fc18f668f989ed588d6e883ec Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Thu, 30 Nov 2023 15:10:30 +0800 Subject: [PATCH 38/52] Fix parallel_for when grain_size > 0 (#179) This do not affect current BMTrain since BMTrain pass grain_size=0 in all parallel_for. --- csrc/include/adam_cpu.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index 1e497bb3..81a8ec9a 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -37,7 +37,7 @@ inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F int64_t num_threads = 1; // Default to serial execution if (grain_size > 0) { - num_threads = std::max((numiter+num_threads-1) / grain_size, static_cast(1)); + num_threads = std::max((numiter+grain_size-1) / grain_size, static_cast(1)); } else{ cpu_set_t cpu_set; From 76aeeff878f1baf962af5f8f43c3e6adc51f1801 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 1 Dec 2023 13:29:30 +0800 Subject: [PATCH 39/52] fix test_training --- tests/test_training.py | 117 +++++++++++++++++++++++++++++++---------- 1 file changed, 88 insertions(+), 29 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index c1b3e4bd..ada41cc9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -9,6 +9,50 @@ import os from bmtrain import inspect +def clip_grad_norm(loss_scale, param_groups, max_norm, norm_type=2, eps=1e-6, is_torch=False): + """Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized. 
+ max_norm (float or int): max norm of the gradients. + norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm. + eps (float): epsilon used to avoid zero division. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + scale = loss_scale + grads = [] + parameters = [p for group in param_groups for p in group['params']] + for p in parameters: + if p.grad is not None: + grads.append(p.grad.data) + else: + grads.append(torch.zeros_like(p.data)) + + if norm_type == 'inf': + total_norm_cuda = max(g.data.abs().max() for g in grads).detach() + if not is_torch: + bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", bmt.config["comm"]) + total_norm = total_norm_cuda + else: + norm_type = float(norm_type) + total_norm_cuda = torch.cuda.FloatTensor([0]) + for index, g in enumerate(grads): + param_norm = g.data.float().norm(norm_type) + total_norm_cuda += param_norm ** norm_type + if not is_torch: + bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", bmt.config["comm"]) + total_norm = total_norm_cuda[0] ** (1. / norm_type) + clip_coef = float(max_norm * scale) / (total_norm + eps) + if clip_coef < 1: + for p in parameters: + if p.grad is not None: + p.grad.data.mul_(clip_coef) + return total_norm / scale + class Attention(torch.nn.Module): def __init__(self, dim_model : int, dim_head : int, @@ -196,6 +240,7 @@ def sub_train_torch(model, loss_func_cls, optimizer_cls): enc_inputs = [] targetss = [] masks = [] + inps = [] for i in range(bmt.world_size()): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() @@ -213,34 +258,32 @@ def sub_train_torch(model, loss_func_cls, optimizer_cls): enc_inputs.append(enc_input) targetss.append(targets) masks.append(mask) - - sent = torch.cat(sents, dim=0) - enc_length = torch.cat(enc_lengths, dim=0) - enc_input = torch.cat(enc_inputs, dim=0) - targets = torch.cat(targetss, dim=0) - mask = torch.cat(masks, dim=0) - - pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + inps.append((sent,enc_length,enc_input,targets,mask)) logs = [] for iter in range(100): + optim_manager.zero_grad() + global_loss = 0 + for inp in inps: + sent, enc_length, enc_input, targets, mask = inp + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model(enc_input, pos, pos < enc_length[:, None]) - logits = model(enc_input, pos, pos < enc_length[:, None]) - - batch, seq_len, vocab_out_size = logits.size() + batch, seq_len, vocab_out_size = logits.size() - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) / len(inps) - global_loss = loss.item() + global_loss += loss.item() - loss = optim_manager.loss_scale * loss - loss.backward() - - grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=10.0) + loss = optim_manager.loss_scale * loss + loss.backward() + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(bmt.config['load_stream']) + grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0, is_torch = True) optim_manager.step() - + bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format( iter, global_loss, @@ -266,7 +309,11 @@ def sub_train(model, loss_func_cls, optimizer_cls): 
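# sub_train is the BMTrain side of this equivalence check: each rank keeps only
# its own micro-batch (sents[bmt.rank()] below), while sub_train_torch above
# feeds the same world_size micro-batches to a single model and divides every
# loss by len(inps), so the data-parallel averaging matches and the logged
# losses and clipped grad norms can be compared at the end of the test.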
torch.manual_seed(2333) batch_size = 2 seq_len = 512 - + sents = [] + enc_lengths = [] + enc_inputs = [] + targetss = [] + masks = [] for i in range(bmt.world_size()): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() @@ -279,15 +326,27 @@ def sub_train(model, loss_func_cls, optimizer_cls): torch.full_like(targets, -100, dtype=torch.long) ) - if i == bmt.rank(): - break - - pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + sents.append(sent) + enc_lengths.append(enc_length) + enc_inputs.append(enc_input) + targetss.append(targets) + masks.append(mask) + # sent = torch.cat(sents, dim=0) + # enc_length = torch.cat(enc_lengths, dim=0) + # enc_input = torch.cat(enc_inputs, dim=0) + # targets = torch.cat(targetss, dim=0) + # mask = torch.cat(masks, dim=0) + sent = sents[bmt.rank()] + enc_length = enc_lengths[bmt.rank()] + enc_input = enc_inputs[bmt.rank()] + targets = targetss[bmt.rank()] + mask = masks[bmt.rank()] + logs = [] + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) for iter in range(100): optim_manager.zero_grad() - logits = model(enc_input, pos, pos < enc_length[:, None]) batch, seq_len, vocab_out_size = logits.size() @@ -297,11 +356,9 @@ def sub_train(model, loss_func_cls, optimizer_cls): global_loss = bmt.sum_loss(loss).item() optim_manager.backward(loss) - - grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=10.0) + grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0) optim_manager.step() - bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format( iter, global_loss, @@ -439,8 +496,9 @@ def check(ret): if bmt.rank() == 0: for k1, v1 in ret.items(): for k2, v2 in ret.items(): - print(f"checking {k1} vs. {k2}") - check_param(v1[1], v2[1]) + if k1 != k2: + print(f"checking {k1} vs. 
{k2}") + check_param(v1[1], v2[1]) bmt.synchronize() ret.clear() @@ -454,4 +512,5 @@ def check_param(info1, info2): if __name__ == '__main__': bmt.init_distributed(pipe_size=1) + test_main(test_fp16=True, test_fp32=True) From 5aeba096886620db41484b8540854b8a5e7ac417 Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Fri, 1 Dec 2023 19:45:22 +0800 Subject: [PATCH 40/52] tp cross entropy (#180) --- bmtrain/loss/_function.py | 45 ++++++ bmtrain/loss/cross_entropy.py | 89 +++++++++-- bmtrain/nn/__init__.py | 3 +- bmtrain/nn/parallel_cross_entropy_func.py | 122 --------------- bmtrain/nn/parallel_embedding.py | 64 ++++---- csrc/bind.cpp | 4 + csrc/cuda/cross_entropy.cu | 171 +++++++++++++++++----- csrc/include/bind.hpp | 30 +++- tests/test_all.py | 2 +- tests/test_middle_hidden.py | 2 +- tests/test_other_hidden.py | 2 +- tests/test_parallel_projection.py | 55 +++++++ tests/test_requires_grad.py | 2 +- tests/test_requires_grad_multi_gpu.py | 2 +- tests/test_training.py | 6 +- 15 files changed, 384 insertions(+), 215 deletions(-) delete mode 100644 bmtrain/nn/parallel_cross_entropy_func.py create mode 100644 tests/test_parallel_projection.py diff --git a/bmtrain/loss/_function.py b/bmtrain/loss/_function.py index e2b67bb8..4ac02f5d 100644 --- a/bmtrain/loss/_function.py +++ b/bmtrain/loss/_function.py @@ -62,3 +62,48 @@ def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, ta C.cross_entropy_backward_inplace_bf16_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) else: raise ValueError(f"cross_entropy_backward not supported for dtype {input.dtype}") + +def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor: + CHECK_INPUT(logits) + CHECK_INPUT(max_logits) + assert max_logits.dtype == torch.float32, "max_logits must be float tensor" + assert max_logits.size(0) == logits.size(0), "max_logits must have same size(0) as logits" + sum_exp_logits = torch.empty(logits.size(0), dtype=torch.float32, device=logits.device) + m = logits.size(0) + n = logits.size(1) + cuda_stream = torch.cuda.current_stream().cuda_stream + logits_ptr = logits.data_ptr() + max_logits_ptr = max_logits.data_ptr() + sum_exp_logits_ptr = sum_exp_logits.data_ptr() + if logits.dtype == torch.float16: + C.fused_sumexp_fp16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + elif logits.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.fused_sumexp_bf16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + else: + raise ValueError(f"fused_sumexp not supported for dtype {logits.dtype}") + return sum_exp_logits + +def fused_softmax_inplace(logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor) -> None: + CHECK_INPUT(logits) + CHECK_INPUT(max_logits) + CHECK_INPUT(sum_exp_logits) + assert max_logits.dtype == torch.float32, "max_logits must be float tensor" + assert sum_exp_logits.dtype == torch.float32, "sum_exp_logits must be float tensor" + assert max_logits.size(0) == logits.size(0), "max_logits must have same size(0) as logits" + assert sum_exp_logits.size(0) == logits.size(0), "sum_exp_logits must have same size(0) as logits" + m = logits.size(0) + n = logits.size(1) + cuda_stream = torch.cuda.current_stream().cuda_stream + logits_ptr = logits.data_ptr() + max_logits_ptr = max_logits.data_ptr() + sum_exp_logits_ptr = sum_exp_logits.data_ptr() + if logits.dtype == torch.float16: + 
C.fused_softmax_inplace_fp16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + elif logits.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.fused_softmax_inplace_bf16_launcher(m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream) + else: + raise ValueError(f"fused_softmax_inplace not supported for dtype {logits.dtype}") \ No newline at end of file diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index a2e123ad..5be07665 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,9 +1,8 @@ from typing import Optional import torch from . import _function as F -from bmtrain.nn import parallel_cross_entropy_func from bmtrain.global_var import config -from bmtrain.distributed import all_gather +from bmtrain.distributed import all_gather, all_reduce class OpFusedCrossEntropy(torch.autograd.Function): """ @@ -36,6 +35,76 @@ def backward(ctx, grad_output : torch.Tensor): ) return (softmax, None, None) +class VPFusedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits : torch.Tensor, target : torch.Tensor): + comm = config['tp_comm'] + rank = config['tp_rank'] + world_size = config['tp_size'] + + max_logits = torch.max(logits, dim=-1)[0].float() + max_logits = all_reduce(max_logits, op="max", comm=comm) + + partition_vocab_size = logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].contiguous() # (-1,) + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 # if target=-100, it will also be 0 + + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits.float(), op="sum", comm=comm) + predicted_logits = predicted_logits - max_logits + # Sum of exponential of logits along vocab dimension across all GPUs. + + sum_exp_logits = torch.empty(logits.size(0), device=logits.device, dtype=torch.float) + sum_exp_logits = F.fused_sumexp(logits, max_logits) # float + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + 1e-10 # avoid nan + + softmax = logits.clone() + F.fused_softmax_inplace(softmax, max_logits, sum_exp_logits) # logits -> softmax + # logits = logits.float() - max_logits.unsqueeze(dim=-1).float() + # exp_logits = logits + # torch.exp(logits, out=exp_logits) + # sum_exp_logits = exp_logits.sum(dim=-1) + # exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize + ctx.save_for_backward(softmax, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. 
+ partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + grad_2d[arange_1d, masked_target_1d] -= softmax_update + grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) + + return grad_input, None + class FusedCrossEntropy(torch.nn.Module): r"""This criterion computes the cross entropy loss between input and target. @@ -149,7 +218,6 @@ def __init__(self, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0, # TODO not supported yet - inplace: bool = False, parallel: bool = False, ) -> None: super().__init__() @@ -157,13 +225,11 @@ def __init__(self, self.ignore_index = ignore_index self.reduction = reduction self.label_smoothing = label_smoothing - self.inplace = inplace self.parallel = parallel def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.parallel: - target = all_gather(target, comm=config['tp_comm']).flatten(0,1) - ret = parallel_cross_entropy_func(input, target.long(), self.ignore_index) + ret = VPFusedCrossEntropy.apply(input, target.long()) else: if input.dtype == torch.float32: return torch.nn.functional.cross_entropy( @@ -174,10 +240,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduction=self.reduction, label_smoothing=self.label_smoothing) - if self.inplace: - ret = OpFusedCrossEntropyInplace.apply(input, target.int(), self.ignore_index) # return float tensor - else: - ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor if self.weight is not None: if self.weight.dim() != 1 or self.weight.size(0) != input.size(1): @@ -188,12 +251,6 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = (target != self.ignore_index).int() ret = w * ret - - if self.parallel and config['tp_size'] > 1: - ret_list = ret.chunk(config['tp_size'], dim=0) - ret = ret_list[config['topology'].tp_id] - w_list = w.chunk(config['tp_size'], dim=0) - w = w_list[config['topology'].tp_id] if self.reduction == "none": return ret diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index e22d8c55..85f45a3e 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,6 +1,5 @@ from .linear import Linear, OpLinear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear -from .parallel_embedding import ParallelEmbedding -from .parallel_cross_entropy_func import parallel_cross_entropy_func +from .parallel_embedding import Projection, VPProjection from .parallel_linear_func import OpParallelLinear \ No newline at end of file diff --git a/bmtrain/nn/parallel_cross_entropy_func.py b/bmtrain/nn/parallel_cross_entropy_func.py deleted file mode 100644 index 8e4548ae..00000000 --- a/bmtrain/nn/parallel_cross_entropy_func.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import bmtrain as bmt -from bmtrain.global_var import config -from bmtrain.distributed import all_reduce, all_gather - -class ParallelCrossEntropyFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, logits, target, label_smoothing=0.0): - comm = config['tp_comm'] - rank = config['topology'].tp_id - world_size = config['tp_size'] - - # local max - max_logits = 
torch.max(logits, dim=-1)[0] - # global max - max_logits = all_reduce(max_logits, op="max", comm=comm) - - logits = logits - max_logits.unsqueeze(dim=-1) - - partition_vocab_size = logits.size()[-1] - vocab_start_index = rank * partition_vocab_size - vocab_end_index = (rank + 1) * partition_vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - logits_2d = logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - - # All reduce is needed to get the chunks from other GPUs. - predicted_logits = all_reduce(predicted_logits, op="sum", comm=comm) - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = logits - torch.exp(logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits - - # Normalize and optionally smooth logits - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - - vocab_size = exp_logits.size(-1) - if label_smoothing > 0: - """ - We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth. - = (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt}) - = (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i - = (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i - = (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K - From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py - """ - assert 1.0 > label_smoothing > 0.0 - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - - # Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs. - log_probs = torch.log(exp_logits) - mean_log_probs = log_probs.mean(dim=-1) - loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs - - ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - - # Store softmax, target-mask and masked-target for backward pass. - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - - return loss - - @staticmethod - def backward(ctx, grad_output): - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. 
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) - - softmax_update = 1.0 - target_mask.view(-1).float() - - if label_smoothing > 0: - smoothing = label_smoothing * vocab_size / (vocab_size - 1) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update - average_grad = 1 / vocab_size - grad_2d[arange_1d, :] -= smoothing * average_grad - else: - grad_2d[arange_1d, masked_target_1d] -= softmax_update - - grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) - - return grad_input, None, None - - -def parallel_cross_entropy_func(logits, target, label_smoothing=0.0): - """ - Performs cross entropy loss when logits are split across tensor parallel ranks - - Arguments: - logits: logits split across tensor parallel ranks dimension is [batch * seq_len, hidden_size]. - target: correct vocab ids of dimseion [batch * seq_len]. - lobal_smoothing: smoothing factor, must be in range [0.0, 1.0). default is 0.0. - """ - out = ParallelCrossEntropyFunc.apply(logits.to(torch.float32), target, label_smoothing) - return out - diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index cd567b4e..562b2cab 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -8,7 +8,7 @@ from bmtrain.distributed import all_reduce, all_gather from .parallel_linear_func import OpParallelLinear -class ParallelEmbedding(bmt.DistributedModule): +class Projection(bmt.DistributedModule): def __init__( self, vocab_size: int, @@ -20,41 +20,44 @@ def __init__( super().__init__() self.dim_model = embedding_size - assert vocab_size % config['tp_size'] == 0 - self.vocab_size_per_partition = vocab_size // config['tp_size'] - self.start_index = config['topology'].tp_id * self.vocab_size_per_partition - self.end_index = (config['topology'].tp_id+1) * self.vocab_size_per_partition self.weight = bmt.DistributedParameter( - torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + torch.empty(vocab_size, embedding_size, dtype=dtype), init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), - tp_mode=True, - tp_split_dim=0, ) - def forward(self, ids: torch.Tensor, gather_input=True): + def projection(self, x: torch.Tensor): """ + Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. Args: - ids (:obj:`torch.Tensor` of shape ``(batch_size, seq_len)``): Indices of input sequence tokens. - Return: - :obj:`torch.Tensor` of shape ``(batch_size, seq_len, embedding_size)``: The embedding output. + x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection + Returns: + :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. 
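+        Example (illustrative; the sizes and the ``hidden`` tensor below are assumed):
+
+            >>> proj = Projection(vocab_size=32000, embedding_size=4096)
+            >>> bmt.init_parameters(proj)
+            >>> logits = proj.projection(hidden)  # hidden: (batch, seq_len, 4096) -> logits: (batch, seq_len, 32000)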
""" # noqa: E501 + logits = F.linear(x, self.weight) + return logits - if config['tp_size'] > 1: - if gather_input: - ids = all_gather(ids, comm=config['tp_comm']) - input_mask = (ids < self.start_index) | (ids >= self.end_index) - ids = ids.clone() - self.start_index - ids[input_mask] = 0 - - embeds = F.embedding(ids, self.weight) +class VPProjection(bmt.DistributedModule): + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + super().__init__() - if config['tp_size'] > 1: - embeds[input_mask, :] = 0.0 - embeds = all_reduce(embeds, op="sum", comm=config['tp_comm']) - embed_list = embeds.chunk(config['tp_size'], dim=0) - embeds = embed_list[config['topology'].tp_id].flatten(0,1) - - return embeds.clone() + self.dim_model = embedding_size + assert vocab_size % bmt.config["tp_size"] == 0 + self.vocab_size_per_partition = vocab_size // bmt.config["tp_size"] + self.start_index = bmt.config["tp_rank"] * self.vocab_size_per_partition + self.end_index = (bmt.config["tp_rank"] + 1) * self.vocab_size_per_partition + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), + tp_split_dim=0, + tp_mode=True, + ) def projection(self, x: torch.Tensor): """ @@ -64,9 +67,4 @@ def projection(self, x: torch.Tensor): Returns: :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. """ # noqa: E501 - gather_input = True - split_input = False - reduce_output_type = None - gather_output = False - out = OpParallelLinear.apply(x , self.weight, None, gather_input, gather_output, split_input, reduce_output_type) - return out + return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1) \ No newline at end of file diff --git a/csrc/bind.cpp b/csrc/bind.cpp index 73f79a61..047da896 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -12,6 +12,10 @@ PYBIND11_MODULE(C, m) { m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward"); m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace"); m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace"); + m.def("fused_sumexp_fp16_launcher", &fused_sumexp_fp16_launcher, "sum exp"); + m.def("fused_sumexp_bf16_launcher", &fused_sumexp_bf16_launcher, "sum exp"); + m.def("fused_softmax_inplace_fp16_launcher", &fused_softmax_inplace_fp16_launcher, "softmax inplace"); + m.def("fused_softmax_inplace_bf16_launcher", &fused_softmax_inplace_bf16_launcher, "softmax inplace"); m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); diff --git a/csrc/cuda/cross_entropy.cu b/csrc/cuda/cross_entropy.cu index bdd5a08e..177c3b77 100644 --- a/csrc/cuda/cross_entropy.cu +++ b/csrc/cuda/cross_entropy.cu @@ -53,32 +53,23 @@ __global__ void cross_entropy_backward_inplace_fp16( int64_t base_idx = blockIdx.x * n; int32_t t = target[blockIdx.x]; - if (t == ignore_index) { - half v = __float2half(0.); - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = v; - } - } - else { - half v = __float2half(grad_output[blockIdx.x]); - 
__syncthreads(); - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = i==t ? __hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); - } + float v = grad_output[blockIdx.x]; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2half(i==t ? (__half2float(x[base_idx + i])-1)*v : __half2float(x[base_idx + i])*v); } } // blocks , threads<1024> __global__ void cross_entropy_forward_bf16( int64_t n, - std::uintptr_t input_ptr, // (m, n) + const std::uintptr_t input_ptr, // (m, n) const int32_t *target, // (m) std::uintptr_t softmax_ptr, // (m, n) float *output, // (m) int32_t ignore_index ) { #ifdef BF16_SUPPORT - __nv_bfloat16* input = reinterpret_cast<__nv_bfloat16*>(input_ptr); + const __nv_bfloat16* input = reinterpret_cast(input_ptr); __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); int64_t base_idx = blockIdx.x * n; @@ -122,30 +113,90 @@ __global__ void cross_entropy_backward_inplace_bf16( int64_t base_idx = blockIdx.x * n; int32_t t = target[blockIdx.x]; - if (t == ignore_index) { - __nv_bfloat16 v = __float2bfloat16(0.); - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = v; - } + float v = grad_output[blockIdx.x]; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2bfloat16(i==t ? (__bfloat162float(x[base_idx + i])-1)*v : __bfloat162float(x[base_idx + i])*v); } - else { - #if __CUDA_ARCH__ >= 800 - __nv_bfloat16 v = __float2bfloat16(grad_output[blockIdx.x]); - __syncthreads(); - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = i==t ? __hsub(__hmul(x[base_idx + i], v), v) : __hmul(x[base_idx + i], v); - } - #else - float v = grad_output[blockIdx.x]; - __syncthreads(); - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { - x[base_idx + i] = __float2bfloat16(i==t ? 
(__bfloat162float(x[base_idx + i])*v)-v : __bfloat162float(x[base_idx + i])*v); - } - #endif +#endif +} + +// blocks , threads<1024> +__global__ void fused_sumexp_fp16( + int64_t n, + const half *input, // (m, n) + const float *global_max, // (m) + float *global_sum // (m) +) { + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__half2float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum); + if (threadIdx.x == 0) { + global_sum[blockIdx.x] = local_sum; + } +} + +// blocks , threads<1024> +__global__ void fused_sumexp_bf16( + int64_t n, + const std::uintptr_t input_ptr, // (m, n) + const float *global_max, // (m) + float *global_sum // (m) +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* input = reinterpret_cast(input_ptr); + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum); + if (threadIdx.x == 0) { + global_sum[blockIdx.x] = local_sum; } #endif } +// blocks , threads<1024> +__global__ void fused_softmax_inplace_fp16( + int64_t n, + half *softmax, // (m, n) + const float *global_max, // (m) + const float *global_sum // (m) +) { + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + float local_sum = global_sum[blockIdx.x]; + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2half( expf(__half2float(softmax[base_idx + i]) - local_max) / local_sum ); + } +} + +// blocks , threads<1024> +__global__ void fused_softmax_inplace_bf16( + int64_t n, + std::uintptr_t softmax_ptr, // (m, n) + const float *global_max, // (m) + const float *global_sum // (m) +) { +#ifdef BF16_SUPPORT + __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + float local_sum = global_sum[blockIdx.x]; + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(softmax[base_idx + i]) - local_max) / local_sum ); + } +#endif +} } void cross_entropy_forward_fp16_launcher( @@ -207,4 +258,58 @@ void cross_entropy_backward_inplace_bf16_launcher( auto target_ptr = reinterpret_cast(target); int32_t threads = 1024; cross_entropy_backward_inplace_bf16<<(stream)>>>(n, output_ptr, target_ptr, x, ignore_index); +} + +void fused_sumexp_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto logits_ptr = reinterpret_cast(logits); + auto max_logits_ptr = reinterpret_cast(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast(sum_exp_logits); + int32_t threads = 1024; + fused_sumexp_fp16<<(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_sumexp_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto max_logits_ptr = reinterpret_cast(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast(sum_exp_logits); + int32_t threads = 1024; + fused_sumexp_bf16<<(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_softmax_inplace_fp16_launcher( + int32_t m, 
int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto logits_ptr = reinterpret_cast(logits); + auto max_logits_ptr = reinterpret_cast(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast(sum_exp_logits); + int32_t threads = 1024; + fused_softmax_inplace_fp16<<(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_softmax_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto max_logits_ptr = reinterpret_cast(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast(sum_exp_logits); + int32_t threads = 1024; + fused_softmax_inplace_bf16<<(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr); } \ No newline at end of file diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp index 94d6af95..a9067a2a 100644 --- a/csrc/include/bind.hpp +++ b/csrc/include/bind.hpp @@ -41,6 +41,34 @@ void cross_entropy_backward_inplace_bf16_launcher( int32_t ignore_index, std::uintptr_t stream ); +void fused_sumexp_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_sumexp_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_softmax_inplace_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_softmax_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); void adam_fp16_launcher( int n, std::uintptr_t param_fp32, @@ -70,4 +98,4 @@ void adam_bf16_launcher( float bias_correction1, float bias_correction2, uintptr_t stream -); \ No newline at end of file +); diff --git a/tests/test_all.py b/tests/test_all.py index 07be4077..db5d2dd4 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -21,7 +21,6 @@ ("multi_return", 2), ("middle_hidden", 4), ("other_hidden", 4), - ("inspector_hidden", 2), ("model_wrapper", 4), @@ -30,6 +29,7 @@ ("no_grad", 1), ("column_parallel_linear", 2), ("row_parallel_linear", 2), + ("parallel_projection", 4), ("training", 4), ]) diff --git a/tests/test_middle_hidden.py b/tests/test_middle_hidden.py index 6a1f476f..2a93efe0 100644 --- a/tests/test_middle_hidden.py +++ b/tests/test_middle_hidden.py @@ -185,7 +185,7 @@ def test_main(): ret["normal"] = run("normal", Model_NORMAL) ret["block"] = run("block", Model_BLOCK) ret["zero"] = run("zero", Model_ZERO) - ret["pipe"] = run("pipe", Model_PIPE) + # ret["pipe"] = run("pipe", Model_PIPE) for k, r in ret.items(): bmt.print_rank(f"============={k}============") bmt.print_rank(r) diff --git a/tests/test_other_hidden.py b/tests/test_other_hidden.py index aab9c678..27736aa7 100644 --- a/tests/test_other_hidden.py +++ b/tests/test_other_hidden.py @@ -177,7 +177,7 @@ def test_main(): ret.append( run("normal", Model_NORMAL) ) ret.append( run("block", Model_BLOCK) ) ret.append( run("zero", Model_ZERO) ) - ret.append( run("pipe", Model_PIPE) ) + # ret.append( run("pipe", Model_PIPE) ) for r in ret: bmt.print_rank(r) for r in ret: diff --git a/tests/test_parallel_projection.py b/tests/test_parallel_projection.py new file mode 100644 index 00000000..98de30a0 --- /dev/null +++ 
b/tests/test_parallel_projection.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np +import os + +def run_normal(x, t, ckp_path, dtype): + proj = bmt.nn.Projection(100, 64, dtype=dtype) + bmt.init_parameters(proj) + bmt.save(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) + y = proj.projection(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run_vp(x, t, ckp_path, dtype): + proj = bmt.nn.VPProjection(100, 64, dtype=dtype) + bmt.load(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + y = proj.projection(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run(dtype): + ckp_path = 'embedding.pt' + torch.cuda.manual_seed(100) + tp_size = config["tp_size"] + tp_rank = config['tp_rank'] + x = torch.randn(110, 64, device='cuda', dtype=dtype) + t = torch.cat([torch.arange(100).view(10, 10), torch.ones((10, 1))*-100], dim=-1).view(110).int().cuda() + y1, loss1, grad1 = run_normal(x, t, ckp_path, dtype) + y2, loss2, grad2 = run_vp(x, t, ckp_path, dtype) + y1 = y1.chunk(tp_size, dim=-1)[tp_rank] + grad1 = grad1.chunk(tp_size, dim=-1)[tp_rank] + for r in range(tp_size): + if bmt.rank() == r: + print((y1-y2).abs().max()) + print((loss1-loss2).abs().max()) + print((grad1-grad2).abs().max()) + assert (y1-y2).abs().max() < 1e-4 + assert (loss1-loss2).abs().max() < 1e-4 + assert (grad1-grad2).abs().max() < 1e-4 + bmt.synchronize() + if bmt.rank() == 0: + os.remove(f"embedding.pt") + +if __name__ == "__main__": + bmt.init_distributed(tp_size=4) + run(torch.half) + run(torch.bfloat16) + diff --git a/tests/test_requires_grad.py b/tests/test_requires_grad.py index f8a11ec9..9a443bd3 100644 --- a/tests/test_requires_grad.py +++ b/tests/test_requires_grad.py @@ -104,4 +104,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=1) test_main() - test_main_pipe() + # test_main_pipe() diff --git a/tests/test_requires_grad_multi_gpu.py b/tests/test_requires_grad_multi_gpu.py index cd2e27b2..2eedf7b6 100644 --- a/tests/test_requires_grad_multi_gpu.py +++ b/tests/test_requires_grad_multi_gpu.py @@ -93,4 +93,4 @@ def test_main_pipe(): bmt.init_distributed(pipe_size=1) test_main() - test_main_pipe() + # test_main_pipe() diff --git a/tests/test_training.py b/tests/test_training.py index ada41cc9..46389802 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -446,7 +446,7 @@ def unroll_list_model(): "torch": torch_model, "wrapper": wrap_model, "blocklist": list_model, - "pipelist": pipe_model, + # "pipelist": pipe_model, "unroll_blocklist": unroll_list_model, } loss_funcs = { @@ -470,7 +470,7 @@ def add_to_check_list(m, l, o): add_to_check_list("torch", "bmt_entropy", "bmt_adam") add_to_check_list("wrapper", "bmt_entropy", "bmt_adam") add_to_check_list("blocklist", "bmt_entropy", "bmt_adam") - add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") + # add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload") add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") @@ -484,7 +484,7 @@ def add_to_check_list(m, l, o): add_to_check_list("torch", "torch_entropy", "bmt_adam") add_to_check_list("wrapper", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam") - 
add_to_check_list("pipelist", "torch_entropy", "bmt_adam") + # add_to_check_list("pipelist", "torch_entropy", "bmt_adam") add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") add_to_check_list("blocklist", "torch_entropy", "torch_adam") add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") From 3c80ce48ce3c9292c676665b427f4d1e3c04f8f4 Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Wed, 27 Dec 2023 09:56:35 +0800 Subject: [PATCH 41/52] Update optim_manager.py (#181) fix loss scale for tp --- bmtrain/optim/optim_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 9b7a3120..19283cd0 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -85,7 +85,7 @@ def add_optimizer( def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: - return loss * (self.loss_scale / config['world_size']) # loss scale + return loss * (self.loss_scale / (config['world_size']//(config['tp_size']*config['pipe_size']))) # loss scale def backward(self, loss : torch.Tensor): """ From 6abcf772aa1e120192f7656e55c4adbcde53c886 Mon Sep 17 00:00:00 2001 From: William <323163497@qq.com> Date: Wed, 27 Dec 2023 09:57:13 +0800 Subject: [PATCH 42/52] Feat optim manager state (#176) * fix adam bf16 load changed to fp16 * support save optim_manager.state (including optimizer, lr_scheduler, loss_scale) --- bmtrain/optim/optim_manager.py | 17 ++++++++++++++++ bmtrain/store.py | 2 ++ tests/test_optim_state.py | 36 ++++++++++++++++++++++------------ 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 19283cd0..088f0e78 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -203,3 +203,20 @@ def _justify_scale(self, scale): self.loss_scale = scale self.steps_since_last_scale = 0 + def state_dict(self) -> dict: + return { + "optimizers": [opt.state_dict() for opt in self.optimizers], + "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers], + "loss_scale": self.loss_scale, + "loss_scale_enabled": self.loss_scale_enabled, + } + + def load_state_dict(self, state_dict: dict) -> None: + assert len(self.optimizers) == len(state_dict["optimizers"]) + assert len(self.lr_schedulers) == len(state_dict["lr_schedulers"]) + for opt, opt_st in zip(self.optimizers, state_dict["optimizers"]): + opt.load_state_dict(opt_st) + for lrs, lrs_st in zip(self.lr_schedulers, state_dict["lr_schedulers"]): + lrs.load_state_dict(lrs_st) + self.loss_scale = state_dict["loss_scale"] + self.loss_scale_enabled = state_dict["loss_scale_enabled"] diff --git a/bmtrain/store.py b/bmtrain/store.py index 8f731190..49c596a5 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -10,6 +10,7 @@ import io, pickle from typing import Mapping import threading +import bmtrain as bmt def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): if isinstance(model, Block): @@ -116,6 +117,7 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): config['finish_save'] = False config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) config['save_thread'].start() + bmt.synchronize() DTYPE_LIST = [ torch.float64, diff --git a/tests/test_optim_state.py b/tests/test_optim_state.py index 16833b42..57d5d0e3 100644 --- a/tests/test_optim_state.py +++ b/tests/test_optim_state.py @@ -2,7 +2,7 @@ import bmtrain as 
bmt import os from copy import deepcopy -from bmtrain import optim +from bmtrain import optim, lr_scheduler class TestSubModule(bmt.DistributedModule): def __init__(self): @@ -71,10 +71,13 @@ def main(): opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) + lrs1 = lr_scheduler.Noam(opt1, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) + lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) + lrs3 = lr_scheduler.Noam(opt3, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) optim_manager = optim.OptimManager(loss_scale=256) - optim_manager.add_optimizer(opt1) - optim_manager.add_optimizer(opt2) - optim_manager.add_optimizer(opt3) + optim_manager.add_optimizer(opt1, lrs1) + optim_manager.add_optimizer(opt2, lrs2) + optim_manager.add_optimizer(opt3, lrs3) train(model1, model2, model3, optim_manager) @@ -82,21 +85,30 @@ def main(): bmt.save(model2, f"test_optim_state_model2.pt") bmt.save(model3, f"test_optim_state_model3.pt") - torch.save(opt1.state_dict(), f"test_optim_state_opt1_{bmt.rank()}.opt") - torch.save(opt2.state_dict(), f"test_optim_state_opt2_{bmt.rank()}.opt") - torch.save(opt3.state_dict(), f"test_optim_state_opt3_{bmt.rank()}.opt") + torch.save(optim_manager.state_dict(), f"test_optim_manager_{bmt.rank()}.opt") manual_seed() train(model1, model2, model3, optim_manager) state_2 = deepcopy([list(model1.parameters()), list(model2.parameters()), list(model3.parameters())]) + model1 = TestModule() + model2 = TestModule() + model3 = TestModule() bmt.load(model1, f"test_optim_state_model1.pt") bmt.load(model2, f"test_optim_state_model2.pt") bmt.load(model3, f"test_optim_state_model3.pt") - opt1.load_state_dict(torch.load(f"test_optim_state_opt1_{bmt.rank()}.opt")) - opt2.load_state_dict(torch.load(f"test_optim_state_opt2_{bmt.rank()}.opt")) - opt3.load_state_dict(torch.load(f"test_optim_state_opt3_{bmt.rank()}.opt")) + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-8, betas=(0.3, 0.333), eps=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-7, betas=(0.4, 0.456), eps=1e-1) + opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-6, betas=(0.9, 0.777), eps=1e-2) + lrs1 = lr_scheduler.Noam(opt1, start_lr=200, warmup_iter=30, end_iter=500, num_iter=3) + lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=40, end_iter=600, num_iter=1) + lrs3 = lr_scheduler.Noam(opt3, start_lr=10, warmup_iter=50, end_iter=700, num_iter=2) + optim_manager = optim.OptimManager(loss_scale=10485760) + optim_manager.add_optimizer(opt1, lrs1) + optim_manager.add_optimizer(opt2, lrs2) + optim_manager.add_optimizer(opt3, lrs3) + optim_manager.load_state_dict(torch.load(f"test_optim_manager_{bmt.rank()}.opt")) manual_seed() train(model1, model2, model3, optim_manager) @@ -116,9 +128,7 @@ def main(): os.remove(f"test_optim_state_model1.pt") os.remove(f"test_optim_state_model2.pt") os.remove(f"test_optim_state_model3.pt") - os.remove(f"test_optim_state_opt1_{bmt.rank()}.opt") - os.remove(f"test_optim_state_opt2_{bmt.rank()}.opt") - os.remove(f"test_optim_state_opt3_{bmt.rank()}.opt") + os.remove(f"test_optim_manager_{bmt.rank()}.opt") if __name__ == "__main__": bmt.init_distributed() From b7d26e2d1e8e3b0978c94de0c5baa038398c805e Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 20 Feb 2024 13:22:45 +0800 Subject: [PATCH 43/52] FIX: allgather_object 
stuck --- bmtrain/store.py | 45 +++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/bmtrain/store.py b/bmtrain/store.py index 49c596a5..2a3ee02c 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -133,26 +133,31 @@ def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): _pickler = pickle.Pickler _unpickler = pickle.Unpickler -def allgather_object(obj, comm): - f = io.BytesIO() - _pickler(f).dump(obj) - byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) - # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. - # Otherwise, it will casue 100X slowdown. - # See: - byte_tensor = torch.ByteTensor(byte_storage).cuda() - all_bytes_tensors = torch.empty(byte_tensor.numel() * nccl.commCount(comm), dtype=torch.uint8, device="cuda") - nccl.allGather( - byte_tensor.storage(), - all_bytes_tensors.storage(), - comm - ) - obj_list = [] - for i in range(nccl.commCount(comm)): - buf = all_bytes_tensors[i*byte_tensor.numel():(i+1)*byte_tensor.numel()].cpu().numpy().tobytes() - obj = _unpickler(io.BytesIO(buf)).load() - obj_list.append(obj) - return obj_list + +def allgather_objects(obj): + if bmt.world_size() == 1: + return [obj] + + with torch.no_grad(): + data_bytes: bytes = pickle.dumps(obj) + data_length: int = len(data_bytes) + + gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) + gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu() + max_data_length = gathered_length.max().item() + + gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda") + byte_storage = torch.ByteStorage.from_buffer(data_bytes) + gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage) + + gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu() + + ret = [] + for i in range(gathered_data.size(0)): + data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes() + ret.append(pickle.loads(data_bytes)) + return ret + def broadcast_object(obj, comm, src = 0): if nccl.commRank(comm) == src: f = io.BytesIO() From 3b7b6b9a4b0a33394b6d6521da19d43cf9d38ee9 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 20 Feb 2024 15:04:51 +0800 Subject: [PATCH 44/52] add test for allgather object --- tests/test_store.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/test_store.py diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 00000000..cb427d5b --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,13 @@ +import bmtrain as bmt +from bmtrain.store import allgather_object + +def test_allgather_object(): + + res = allgather_object(bmt.rank(), bmt.config["comm"]) + ref = [i for i in range(bmt.world_size())] + assert res == ref + +if __name__ == "__main__": + bmt.init_distributed() + test_allgather_object() + From 281860a074f2cc17922d264052fa253e69591b68 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Tue, 20 Feb 2024 17:26:29 +0800 Subject: [PATCH 45/52] fix projection interface --- bmtrain/nn/parallel_embedding.py | 4 ++-- tests/test_parallel_projection.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 562b2cab..39aa147b 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -25,7 +25,7 @@ def __init__( init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, 
std=init_std), ) - def projection(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): """ Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. Args: @@ -59,7 +59,7 @@ def __init__( tp_mode=True, ) - def projection(self, x: torch.Tensor): + def forward(self, x: torch.Tensor): """ Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. Args: diff --git a/tests/test_parallel_projection.py b/tests/test_parallel_projection.py index 98de30a0..dc1e874d 100644 --- a/tests/test_parallel_projection.py +++ b/tests/test_parallel_projection.py @@ -9,7 +9,7 @@ def run_normal(x, t, ckp_path, dtype): bmt.init_parameters(proj) bmt.save(proj, ckp_path) loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) - y = proj.projection(x) + y = proj(x) y = y.detach().requires_grad_() loss = loss_func(y, t) loss.backward() @@ -19,7 +19,7 @@ def run_vp(x, t, ckp_path, dtype): proj = bmt.nn.VPProjection(100, 64, dtype=dtype) bmt.load(proj, ckp_path) loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) - y = proj.projection(x) + y = proj(x) y = y.detach().requires_grad_() loss = loss_func(y, t) loss.backward() From d933ee90ead05449542ccd19a01d2155ef451c76 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 22 Feb 2024 19:00:28 +0800 Subject: [PATCH 46/52] Optimizer load gathered state and record delta feature are supported now --- bmtrain/block_layer.py | 4 + bmtrain/optim/_distributed.py | 29 ++++++ bmtrain/optim/_function.py | 15 ++- bmtrain/optim/adam.py | 8 ++ bmtrain/optim/adam_offload.py | 66 +++++++++++-- bmtrain/optim/optim_manager.py | 4 +- csrc/bind.cpp | 25 +++-- csrc/include/adam_cpu.hpp | 175 +++++++++++++++++++++++++++++++-- csrc/include/bind.hpp | 24 +++-- 9 files changed, 315 insertions(+), 35 deletions(-) create mode 100644 bmtrain/optim/_distributed.py diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index ab027fc9..98200465 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -210,11 +210,15 @@ def init_param_storage(self): param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,)) self._param_info[-1]["begin"] = to_offset_st self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) + setattr(param, "_start_partition", offset_st) + setattr(param, "_end_partition", offset_end) param.data[:] = \ torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:] del contiguous_param else: param.data = torch.tensor([], dtype=param.dtype, device=param.device) + setattr(param, "_start_partition", None) + setattr(param, "_end_partition", 0) # clear parameter data, but keep the dtype and device setattr(param, "_in_block", True) diff --git a/bmtrain/optim/_distributed.py b/bmtrain/optim/_distributed.py new file mode 100644 index 00000000..11daa2b0 --- /dev/null +++ b/bmtrain/optim/_distributed.py @@ -0,0 +1,29 @@ +import torch +from ..distributed import all_reduce, all_gather + +def state_dict_gather(state_dict): + param_key = [p for param_group in state_dict['param_groups'] for p in param_group['params'] ] + for k, v in state_dict['state'].items(): + if "step" in v: + step = v['step'] + + for k in param_key: + if k not in state_dict['state']: + state_dict['state'][k] = { + 
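+                # Ranks that hold no shard of this parameter contribute empty
+                # placeholder states, so every rank still participates in the
+                # padded all_gather below and the concatenation stays aligned.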
'exp_avg' : torch.tensor([], device="cuda", dtype=torch.float32), + 'exp_avg_sq' : torch.tensor([], device="cuda", dtype=torch.float32), + '_param_fp32' : torch.tensor([], device="cuda", dtype=torch.float32), + 'step' : step + } + v = state_dict['state'][k] + for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: + if name in v: + with torch.no_grad(): + numel = torch.tensor(v[name].numel(), device="cuda", dtype=torch.long) + max_numel = all_reduce(numel, op="max") + v_p = torch.nn.functional.pad(v[name], (0, max_numel - numel), value=-1e15) + if max_numel > 0: + whole_state = all_gather(v_p.cuda()).flatten() + whole_state = whole_state[whole_state != -1e15] + v[name] = whole_state.contiguous().cpu() + return state_dict \ No newline at end of file diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py index f04f9ca0..b72ea8e0 100644 --- a/bmtrain/optim/_function.py +++ b/bmtrain/optim/_function.py @@ -1,7 +1,18 @@ from .. import C import torch CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda -def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor, + +def bf16_from_fp32(param_fp32): + param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16) + C.to_bf16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr()) + return param_bf16 + +def fp16_from_fp32(param_fp32): + param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16) + C.to_fp16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr()) + return param_fp16 + +def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor, v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, weight_decay: float, step: int) -> None: assert param_fp32.is_contiguous(), "param_fp32 must be contiguous" @@ -19,6 +30,7 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor" assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor" assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor" + #TODO check avg_delta and var_delta assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements" @@ -35,6 +47,7 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr(), + delta_info.data_ptr() if delta_info is not None else 0, g_fp16.data_ptr(), m_fp32.data_ptr(), v_fp32.data_ptr(), diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index a3138980..d412b80e 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -131,6 +131,14 @@ def step(self, closure=None, scale=1): return loss + def get_avg_delta(): + + raise NotImplementedError("get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer") + + def get_var_delta(): + + raise NotImplementedError("get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer") + def load_state_dict(self, state_dict: dict) -> None: r"""Loads the optimizer state. 
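A minimal usage sketch of the delta-recording path added in this commit (only AdamOffloadOptimizer implements it; the model, learning rate, loss and batch below are assumed for illustration):

    optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-3, record_delta=True)
    optim_manager = bmt.optim.OptimManager(loss_scale=1024)
    optim_manager.add_optimizer(optimizer)
    optim_manager.backward(loss_func(model(batch), target))   # assumed training step
    optim_manager.step()
    avg_delta, var_delta = optimizer.get_avg_delta(), optimizer.get_var_delta()
    # a gathered (whole) optimizer state can also be collected for saving:
    state = optim_manager.state_dict(gather_opt=True)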
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index d7910aec..7bbc7e46 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -7,6 +7,7 @@ from copy import deepcopy from itertools import chain from collections import defaultdict +from ._distributed import state_dict_gather class AdamOffloadOptimizer(torch.optim.Optimizer): """ @@ -14,7 +15,7 @@ class AdamOffloadOptimizer(torch.optim.Optimizer): """ _bmtrain_optimizer = True - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0, record_delta=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -25,12 +26,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - + self.avg_delta = 0 + self.var_delta = 0 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) - self._hold_steps = hold_steps self._events = {} + self.record_delta = record_delta + if self.record_delta: + for group in self.param_groups: + for p in group['params']: + setattr(p, "_delta_info", ( torch.tensor([0 for i in range(4)], dtype=torch.float32, device="cpu") )) @torch.no_grad() def step(self, closure=None, scale=1): @@ -92,7 +98,9 @@ def step(self, closure=None, scale=1): else: state["_grad_fp16"].copy_(param.grad, non_blocking=True) torch.cuda.current_stream().record_event(event) - + sum_delta = 0 + sum_sq_delta = 0 + total_numel = 0 for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params: # wait for transfer to host event.synchronize() @@ -135,6 +143,7 @@ def step(self, closure=None, scale=1): F.adam_cpu( state["_param_fp32"].view(-1), state["_param_fp16"].view(-1), + param._delta_info if self.record_delta else None, grad.view(-1), state["exp_avg"].view(-1), state["exp_avg_sq"].view(-1), @@ -144,12 +153,25 @@ def step(self, closure=None, scale=1): weight_decay, state["step"] ) + total_numel += state["_param_fp16"].numel() + if self.record_delta: + sum_delta += param._delta_info[2].item(); + sum_sq_delta += param._delta_info[3].item(); # transfer parameters back to device asynchronously param.copy_(state["_param_fp16"], non_blocking=True) + if self.record_delta: + self.avg_delta = sum_delta / total_numel + self.var_delta = sum_sq_delta / total_numel - self.avg_delta ** 2 return loss + def get_avg_delta(self) -> None: + return self.avg_delta if self.record_delta else 0 + + def get_var_delta(self) -> None: + return self.var_delta if self.record_delta else 0 + def load_state_dict(self, state_dict: dict) -> None: r"""Loads the optimizer state. @@ -158,6 +180,9 @@ def load_state_dict(self, state_dict: dict) -> None: from a call to :meth:`state_dict`. """ # deepcopy, to be consistent with module API + + + state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups @@ -177,13 +202,27 @@ def load_state_dict(self, state_dict: dict) -> None: zip(chain.from_iterable((g['params'] for g in saved_groups)), chain.from_iterable((g['params'] for g in groups)))} + # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups)) # Copy state assigned to params (and cast tensors to appropriate types). 
# State that is not assigned to params is copied as is (needed for # backward compatibility). state = defaultdict(dict) + is_whole = False if "is_whole" not in state_dict else state_dict['is_whole'] + pop_key = [] for k, v in state_dict['state'].items(): if k in id_map: param = id_map[k] + if is_whole and param._start_partition is not None: + for key in ['_param_fp32', 'exp_avg_sq', 'exp_avg']: + if key in v: + v[key] = v[key][param._start_partition:param._end_partition] + elif is_whole and param._start_partition is None: + pop_key.append(param) + + if "_param_fp32" not in v: + with torch.no_grad(): + v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu") + v["_param_fp32"].copy_(param) if "_param_fp32" not in v: v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu") @@ -204,7 +243,8 @@ def load_state_dict(self, state_dict: dict) -> None: state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True) # on host else: state[k] = v - + for k in pop_key: + state.pop(k) # Update parameter groups, setting their 'params' value def update_group(group, new_group): new_group['params'] = group['params'] @@ -212,8 +252,10 @@ def update_group(group, new_group): param_groups = [ update_group(g, ng) for g, ng in zip(groups, saved_groups)] self.__setstate__({'state': state, 'param_groups': param_groups}) + + - def state_dict(self) -> dict: + def state_dict(self, gather=False) -> dict: r"""Returns the state of the optimizer as a :class:`dict`. It contains two entries: @@ -223,6 +265,7 @@ def state_dict(self) -> dict: * param_groups - a list containing all parameter groups where each parameter group is a dict """ + # Save order indices instead of Tensors param_mappings = {} start_index = 0 @@ -247,11 +290,18 @@ def cut_states(state): # Remap state to use order indices as keys packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v) for k, v in self.state.items()} - return { + states = { 'state': packed_state, 'param_groups': param_groups, } + if gather: + states = state_dict_gather(states) + states['is_whole'] = True + else: + states['is_whole'] = False + + return states #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu def zero_grad(self, set_to_none: bool = False): - super().zero_grad(set_to_none=set_to_none) \ No newline at end of file + super().zero_grad(set_to_none=set_to_none) diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index 088f0e78..7aa1bb81 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -203,9 +203,9 @@ def _justify_scale(self, scale): self.loss_scale = scale self.steps_since_last_scale = 0 - def state_dict(self) -> dict: + def state_dict(self, gather_opt=False) -> dict: return { - "optimizers": [opt.state_dict() for opt in self.optimizers], + "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers], "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers], "loss_scale": self.loss_scale, "loss_scale_enabled": self.loss_scale_enabled, diff --git a/csrc/bind.cpp b/csrc/bind.cpp index 047da896..db5aedc8 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -1,9 +1,11 @@ #include "include/bind.hpp" PYBIND11_MODULE(C, m) { - m.def("is_bf16_supported",&is_bf16_supported,"whether bf16 supported"); - m.def("has_nan_inf_fp16_launcher",&has_nan_inf_fp16_launcher,"has nan inf"); - m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf 
bf16"); + m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert"); + m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert"); + m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported"); + m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf"); + m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16"); m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu"); m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu"); m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu"); @@ -26,8 +28,17 @@ PYBIND11_MODULE(C, m) { m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter"); m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start"); m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end"); - m.def("ncclSend",&pyNCCLSend,"nccl send"); - m.def("ncclRecv",&pyNCCLRecv,"nccl recv"); - m.def("ncclCommCount",&pyNCCLCommCount,"nccl comm count"); - m.def("ncclCommUserRank",&pyNCCLCommUserRank,"nccl comm user rank"); + m.def("ncclSend", &pyNCCLSend, "nccl send"); + m.def("ncclRecv", &pyNCCLRecv, "nccl recv"); + m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count"); + m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank"); + + py::class_(m, "CUDAEventScope") + .def(py::init(&CUDAEventScope::create)) + .def("recordStart", &CUDAEventScope::recordStart) + .def("recordEnd", &CUDAEventScope::recordEnd); + + py::class_(m, "WatchDog") + .def(py::init(&PyWatchDog::create)) + .def("watch", &PyWatchDog::watch); } diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index 81a8ec9a..4e3c9235 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -11,6 +11,16 @@ #include "cpu_info.h" #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +static inline float _mm256_reduce_add_ps(__m256 x) { + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} inline float fp32_from_bits(uint32_t w) { union { @@ -121,11 +131,47 @@ inline float fp16_ieee_to_fp32_value(uint16_t h) { return fp32_from_bits(result); } -// fp32 -> bf16 inline uint16_t bf16_from_fp32_value(float f){ return *reinterpret_cast(&f) >> 16; } +// fp32 -> bf16 +void bf16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16 +){ + int span = 1; + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto param_bf16_ptr = reinterpret_cast(param_bf16); + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float p = param_fp32_ptr[i]; + param_bf16_ptr[i] = bf16_from_fp32_value(p); + } + break; // must break here + } + }); +} +void fp16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16 +){ + int span = 1; + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto param_fp16_ptr = reinterpret_cast(param_fp16); + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float p = 
param_fp32_ptr[i]; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + } + break; // must break here + } + }); +} // bf16 -> fp32 inline float bf16_to_fp32_value(uint16_t h){ uint32_t src = h; @@ -137,6 +183,7 @@ void adam_cpu_0( int64_t n, float* param_fp32_ptr, uint16_t* param_fp16_ptr, + float* delta_info_ptr, uint16_t* g_fp16_ptr, float* m_fp32_ptr, float* v_fp32_ptr, @@ -148,7 +195,12 @@ void adam_cpu_0( float bias_correction2 ){ int64_t span = 1; + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; for (int64_t j = start; j < end; j += span) { for (int64_t i = j; i < end; i++) { float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; @@ -157,6 +209,11 @@ void adam_cpu_0( float p = param_fp32_ptr[i]; m = beta1 * m + (1 - beta1) * g; v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; param_fp32_ptr[i] = p; param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); @@ -165,13 +222,24 @@ void adam_cpu_0( } break; // must break here } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } }); + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; } void adam_cpu_bf16_0( int64_t n, float* param_fp32_ptr, uint16_t* param_bf16_ptr, + float* delta_info_ptr, uint16_t* g_bf16_ptr, float* m_fp32_ptr, float* v_fp32_ptr, @@ -183,7 +251,12 @@ void adam_cpu_bf16_0( float bias_correction2 ){ int64_t span = 1; + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; for (int64_t j = start; j < end; j += span) { for (int64_t i = j; i < end; i++) { float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale; @@ -192,6 +265,11 @@ void adam_cpu_bf16_0( float p = param_fp32_ptr[i]; m = beta1 * m + (1 - beta1) * g; v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; param_fp32_ptr[i] = p; param_bf16_ptr[i] = bf16_from_fp32_value(p); @@ -200,13 +278,26 @@ void adam_cpu_bf16_0( } break; // must break here } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } } static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( int64_t n, float* param_fp32_ptr, uint16_t* param_fp16_ptr, + float* delta_info_ptr, uint16_t* g_fp16_ptr, float* m_fp32_ptr, float* v_fp32_ptr, @@ -217,6 +308,9 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( float 
bias_correction1, float bias_correction2 ){ + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; auto avx_beta1 = _mm256_set1_ps(beta1); auto avx_beta2 = _mm256_set1_ps(beta2); auto avx_beta1_1 = _mm256_set1_ps(1 - beta1); @@ -229,6 +323,8 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2); int64_t span = 8; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; for (int64_t j = start; j < end; j += span) { if (j + span > end) { for (int64_t i = j; i < end; i++) { @@ -238,6 +334,11 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( float p = param_fp32_ptr[i]; m = beta1 * m + (1 - beta1) * g; v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; param_fp32_ptr[i] = p; param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); @@ -252,6 +353,17 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( auto p = _mm256_loadu_ps(¶m_fp32_ptr[j]); m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g)); v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g))); + if (delta_info_ptr != NULL){ + auto delta_256 = _mm256_add_ps( + _mm256_div_ps( + _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm256_mul_ps(avx_weight_decay, p) // weight_decay * p + ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p + sum_delta_i += _mm256_reduce_add_ps(delta_256); + sum_sq_delta_i += _mm256_reduce_add_ps(_mm256_mul_ps(delta_256, delta_256)); + } p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p p = _mm256_fmadd_ps( avx_neg_lr, @@ -267,13 +379,26 @@ static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( _mm256_storeu_ps(&v_fp32_ptr[j], v); } } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } } static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( int64_t n, float* param_fp32_ptr, uint16_t* param_fp16_ptr, + float* delta_info_ptr, uint16_t* g_fp16_ptr, float* m_fp32_ptr, float* v_fp32_ptr, @@ -284,6 +409,9 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( float bias_correction1, float bias_correction2 ){ + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; auto avx_beta1 = _mm512_set1_ps(beta1); auto avx_beta2 = _mm512_set1_ps(beta2); auto avx_beta1_1 = _mm512_set1_ps(1 - beta1); @@ -296,6 +424,8 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2); int64_t span = 16; parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; 
for (int64_t j = start; j < end; j += span) { if (j + span > end) { for (int64_t i = j; i < end; i++) { @@ -305,6 +435,11 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( float p = param_fp32_ptr[i]; m = beta1 * m + (1 - beta1) * g; v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; param_fp32_ptr[i] = p; param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); @@ -319,6 +454,17 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( auto p = _mm512_loadu_ps(¶m_fp32_ptr[j]); m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g)); v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g))); + if (delta_info_ptr != NULL){ + auto delta_512 = _mm512_add_ps( + _mm512_div_ps( + _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm512_add_ps(_mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm512_mul_ps(avx_weight_decay, p) // weight_decay * p + ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p + sum_delta_i += _mm512_reduce_add_ps(delta_512); + sum_sq_delta_i += _mm512_reduce_add_ps(_mm512_mul_ps(delta_512, delta_512)); + } p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p p = _mm512_fmadd_ps( avx_neg_lr, @@ -337,13 +483,26 @@ static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( _mm512_storeu_ps(&v_fp32_ptr[j], v); } } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } } void adam_cpu_fp16_launcher( int64_t n, std::uintptr_t param_fp32, std::uintptr_t param_fp16, + std::uintptr_t delta_info, std::uintptr_t g_fp16, std::uintptr_t m_fp32, std::uintptr_t v_fp32, @@ -354,7 +513,7 @@ void adam_cpu_fp16_launcher( float bias_correction1, float bias_correction2 ) { - + auto delta_info_ptr = reinterpret_cast(delta_info); auto param_fp32_ptr = reinterpret_cast(param_fp32); auto m_fp32_ptr = reinterpret_cast(m_fp32); auto v_fp32_ptr = reinterpret_cast(v_fp32); @@ -362,11 +521,11 @@ void adam_cpu_fp16_launcher( auto g_fp16_ptr = reinterpret_cast(g_fp16); int cpu_level = get_cpu_level(); if (cpu_level == 0 ){ - adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); }else if(cpu_level == 1){ - adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); 
}else{ - adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); } } @@ -374,6 +533,7 @@ void adam_cpu_bf16_launcher( int64_t n, std::uintptr_t param_fp32, std::uintptr_t param_bf16, + std::uintptr_t delta_info, std::uintptr_t g_bf16, std::uintptr_t m_fp32, std::uintptr_t v_fp32, @@ -384,10 +544,11 @@ void adam_cpu_bf16_launcher( float bias_correction1, float bias_correction2 ) { - auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto delta_info_ptr = reinterpret_cast(delta_info); auto m_fp32_ptr = reinterpret_cast(m_fp32); auto v_fp32_ptr = reinterpret_cast(v_fp32); + auto param_fp32_ptr = reinterpret_cast(param_fp32); auto param_bf16_ptr = reinterpret_cast(param_bf16); auto g_bf16_ptr = reinterpret_cast(g_bf16); - adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, delta_info_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); } diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp index a9067a2a..bcfcb14e 100644 --- a/csrc/include/bind.hpp +++ b/csrc/include/bind.hpp @@ -4,9 +4,17 @@ int is_bf16_supported(); -void has_nan_inf_fp16_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); -void has_nan_inf_bf16_launcher(int32_t n,std::uintptr_t g_bf16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); +void has_nan_inf_fp16_launcher(int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream); +void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream); +void fp16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16); +void bf16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16); void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, std::uintptr_t input, @@ -14,16 +22,14 @@ void cross_entropy_forward_fp16_launcher( std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, - std::uintptr_t stream -); + std::uintptr_t stream); void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, std::uintptr_t x, int32_t ignore_index, - std::uintptr_t stream -); + std::uintptr_t stream); void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, std::uintptr_t input, @@ -31,8 +37,7 @@ void cross_entropy_forward_bf16_launcher( std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, - std::uintptr_t stream -); + std::uintptr_t stream); void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, @@ -82,8 +87,7 @@ void adam_fp16_launcher( float weight_decay, float bias_correction1, float bias_correction2, - uintptr_t stream -); + uintptr_t stream); void adam_bf16_launcher( int n, std::uintptr_t param_fp32, From 8cbb576f9fc5b680175bac04bc48332614c54878 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 10:57:12 +0800 Subject: [PATCH 47/52] fix adam offload 
return_delta and formatting code --- bmtrain/optim/_function.py | 6 +++++- bmtrain/optim/adam_offload.py | 4 ---- csrc/bind.cpp | 9 --------- csrc/include/adam_cpu.hpp | 10 ++++++---- csrc/include/bind.hpp | 18 ++++++++++++------ 5 files changed, 23 insertions(+), 24 deletions(-) diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py index b72ea8e0..d4584457 100644 --- a/bmtrain/optim/_function.py +++ b/bmtrain/optim/_function.py @@ -30,11 +30,15 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: tor assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor" assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor" assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor" - #TODO check avg_delta and var_delta assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements" assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + if delta_info is not None: + assert delta_info.is_contiguous(), "delta_info must be contiguous" + assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor" + assert delta_info.device == torch.device("cpu"), "delta_info must be a cpu tensor" + assert delta_info.numel() == 4, "delta_info have a length of 4" bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step if g_fp16.dtype == torch.float16: diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index 7bbc7e46..c088a5ee 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -223,10 +223,6 @@ def load_state_dict(self, state_dict: dict) -> None: with torch.no_grad(): v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu") v["_param_fp32"].copy_(param) - - if "_param_fp32" not in v: - v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu") - v["_param_fp32"].copy_(param) for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]: if name in v: diff --git a/csrc/bind.cpp b/csrc/bind.cpp index db5aedc8..b8f6fa85 100644 --- a/csrc/bind.cpp +++ b/csrc/bind.cpp @@ -32,13 +32,4 @@ PYBIND11_MODULE(C, m) { m.def("ncclRecv", &pyNCCLRecv, "nccl recv"); m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count"); m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank"); - - py::class_(m, "CUDAEventScope") - .def(py::init(&CUDAEventScope::create)) - .def("recordStart", &CUDAEventScope::recordStart) - .def("recordEnd", &CUDAEventScope::recordEnd); - - py::class_(m, "WatchDog") - .def(py::init(&PyWatchDog::create)) - .def("watch", &PyWatchDog::watch); } diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index 4e3c9235..da6bfe8b 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -229,10 +229,12 @@ void adam_cpu_0( delta_mutex.unlock(); } }); - delta_info_ptr[0] = sum_delta / n; - delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 - delta_info_ptr[2] = sum_delta; - delta_info_ptr[3] = sum_sq_delta; + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - 
E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } } void adam_cpu_bf16_0( diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp index bcfcb14e..3ff967fd 100644 --- a/csrc/include/bind.hpp +++ b/csrc/include/bind.hpp @@ -10,11 +10,13 @@ void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t void fp16_from_fp32_value_launcher( int64_t n, std::uintptr_t param_fp32, - std::uintptr_t param_fp16); + std::uintptr_t param_fp16 +); void bf16_from_fp32_value_launcher( int64_t n, std::uintptr_t param_fp32, - std::uintptr_t param_bf16); + std::uintptr_t param_bf16 +); void cross_entropy_forward_fp16_launcher( int32_t m, int32_t n, std::uintptr_t input, @@ -22,14 +24,16 @@ void cross_entropy_forward_fp16_launcher( std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, - std::uintptr_t stream); + std::uintptr_t stream +); void cross_entropy_backward_inplace_fp16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, std::uintptr_t target, std::uintptr_t x, int32_t ignore_index, - std::uintptr_t stream); + std::uintptr_t stream +); void cross_entropy_forward_bf16_launcher( int32_t m, int32_t n, std::uintptr_t input, @@ -37,7 +41,8 @@ void cross_entropy_forward_bf16_launcher( std::uintptr_t softmax, std::uintptr_t output, int32_t ignore_index, - std::uintptr_t stream); + std::uintptr_t stream +); void cross_entropy_backward_inplace_bf16_launcher( int32_t m, int32_t n, std::uintptr_t grad_output, @@ -87,7 +92,8 @@ void adam_fp16_launcher( float weight_decay, float bias_correction1, float bias_correction2, - uintptr_t stream); + uintptr_t stream +); void adam_bf16_launcher( int n, std::uintptr_t param_fp32, From fef8a7aa0f883cfdf25083363122b6f8b99f2d80 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 12:05:58 +0800 Subject: [PATCH 48/52] workflow test --- .github/workflows/build.yml | 33 ++++++++++++++++++++ .github/workflows/build_whl.yml | 55 +++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/build_whl.yml diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000..2f2b3342 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,33 @@ +name: Build wheel + +on: + pull_request: + types: [opened, reopened] + branches: + - 'dev' + +jobs: + build-upload-wheel: + + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@workflow + secrets: + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + + publish: + needs: build + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Download distribution files + uses: actions/download-artifact@v2 + with: + name: dist + path: dist diff --git a/.github/workflows/build_whl.yml b/.github/workflows/build_whl.yml new file mode 100644 index 00000000..b8bbfcb9 --- /dev/null +++ b/.github/workflows/build_whl.yml @@ -0,0 +1,55 @@ +name: Build wheel + +on: + workflow_call: + secrets: + DOCKERHUB_TOKEN: + required: true + DOCKERHUB_USERNAME: + required: true + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['37', '38', '39', '310', '311'] + + + steps: + + - name: Check the disk space and clear unnecessary library + run: | + rm -rf /home/runner/work/BMTrain/BMTrain/dist + sudo rm -rf /usr/share/dotnet + 
sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -hl + + - name: Checkout code + uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Pull Docker image + run: docker pull pytorch/manylinux-cuda113:latest + + - name: Run Docker image and execute script + run: | + version=${{ matrix.python-version }} + docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" + + - name: Archive distribution files + uses: actions/upload-artifact@v2 + with: + name: dist + path: | + dist/*.tar.gz + dist/*.whl + + \ No newline at end of file From 75b84acb69c7e13b8da5350c8c16dd15a285d9ef Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 12:09:06 +0800 Subject: [PATCH 49/52] update the action config yaml --- .github/workflows/build.yml | 11 ++++---- .github/workflows/build_whl.yml | 2 +- .github/workflows/publish.yaml | 47 +++++---------------------------- .github/workflows/release.yml | 44 ++++-------------------------- 4 files changed, 18 insertions(+), 86 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2f2b3342..a6ce9c93 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,13 +1,14 @@ -name: Build wheel +name: Build on: pull_request: types: [opened, reopened] - branches: - - 'dev' + branches: + - 'dev' + - 'main' jobs: - build-upload-wheel: + build-archive-wheel: uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@workflow secrets: @@ -15,7 +16,7 @@ jobs: DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} publish: - needs: build + needs: build-archive-wheel runs-on: ubuntu-latest steps: - name: Checkout code diff --git a/.github/workflows/build_whl.yml b/.github/workflows/build_whl.yml index b8bbfcb9..e612b6e8 100644 --- a/.github/workflows/build_whl.yml +++ b/.github/workflows/build_whl.yml @@ -1,4 +1,4 @@ -name: Build wheel +name: Build wheels in docker and archive on: workflow_call: diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 94361768..98e579fb 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -7,50 +7,15 @@ on: - "v*.*.*" jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['37', '38', '39', '310', '311'] - - - steps: - - - name: Check the disk space and clear unnecessary library - run: | - rm -rf /home/runner/work/BMTrain/BMTrain/dist - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - df -hl - - - name: Checkout code - uses: actions/checkout@v3 - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Pull Docker image - run: docker pull pytorch/manylinux-cuda113:latest - - name: 
Run Docker image and execute script - run: | - version=${{ matrix.python-version }} - docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" - - - name: Archive distribution files - uses: actions/upload-artifact@v2 - with: - name: dist - path: | - dist/*.tar.gz - dist/*.whl + build-archive-wheel: + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@workflow + secrets: + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} publish: - needs: build + needs: build-archive-wheel runs-on: ubuntu-latest steps: - name: Set Up the Python diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 964d6c27..fff3bcfd 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,47 +6,13 @@ on: - "v*.*.*" jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['37', '38', '39', '310', '311'] - - steps: - - name: Check the disk space and clear unnecessary library - run: | - rm -rf /home/runner/work/BMTrain/BMTrain/dist - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" - df -hl + build-archive-wheel: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Login to DockerHub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Pull Docker image - run: docker pull pytorch/manylinux-cuda113:latest - - - name: Run Docker image and execute script - run: | - version=${{ matrix.python-version }} - docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build;/opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" - - - name: Archive distribution files - uses: actions/upload-artifact@v2 - with: - name: dist - path: | - dist/*.tar.gz - dist/*.whl + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@workflow + secrets: + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} publish: needs: build From 341caa0a1c709705604dd995e2927f800e09ad4c Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 13:11:33 +0800 Subject: [PATCH 50/52] add trigger for pr synchronize --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a6ce9c93..50f2b4d4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,7 
+2,7 @@ name: Build on: pull_request: - types: [opened, reopened] + types: [opened, reopened, synchronize] branches: - 'dev' - 'main' From 119e5cfb25723020d4adabe03ae8e9e2e348a40b Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 13:25:02 +0800 Subject: [PATCH 51/52] fix c++ include --- csrc/include/adam_cpu.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp index da6bfe8b..52575d69 100644 --- a/csrc/include/adam_cpu.hpp +++ b/csrc/include/adam_cpu.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include From f915f9424cf350eb39a28a612352321ba3ed5ae0 Mon Sep 17 00:00:00 2001 From: Maydomine <1583143678@qq.com> Date: Fri, 23 Feb 2024 17:18:57 +0800 Subject: [PATCH 52/52] Vocal parallel Embedding impl and make example work when tp_size > 1 --- bmtrain/nn/__init__.py | 4 +-- bmtrain/nn/parallel_embedding.py | 46 +++++++------------------------- example/layers/attention.py | 16 +++++------ example/layers/transformer.py | 2 +- example/models/gpt.py | 14 +++++----- example/train.py | 8 +++--- 6 files changed, 31 insertions(+), 59 deletions(-) diff --git a/bmtrain/nn/__init__.py b/bmtrain/nn/__init__.py index 85f45a3e..60fed663 100644 --- a/bmtrain/nn/__init__.py +++ b/bmtrain/nn/__init__.py @@ -1,5 +1,5 @@ from .linear import Linear, OpLinear from .column_parallel_linear import ColumnParallelLinear from .row_parallel_linear import RowParallelLinear -from .parallel_embedding import Projection, VPProjection -from .parallel_linear_func import OpParallelLinear \ No newline at end of file +from .parallel_embedding import VPEmbedding +from .parallel_linear_func import OpParallelLinear diff --git a/bmtrain/nn/parallel_embedding.py b/bmtrain/nn/parallel_embedding.py index 39aa147b..43e7397d 100644 --- a/bmtrain/nn/parallel_embedding.py +++ b/bmtrain/nn/parallel_embedding.py @@ -8,35 +8,8 @@ from bmtrain.distributed import all_reduce, all_gather from .parallel_linear_func import OpParallelLinear -class Projection(bmt.DistributedModule): - def __init__( - self, - vocab_size: int, - embedding_size: int, - dtype: torch.dtype = torch.half, - init_mean: float = 0.0, - init_std: float = 1, - ): - super().__init__() - - self.dim_model = embedding_size - self.weight = bmt.DistributedParameter( - torch.empty(vocab_size, embedding_size, dtype=dtype), - init_method=bmt.ParameterInitializer(torch.nn.init.normal_, mean=init_mean, std=init_std), - ) - - def forward(self, x: torch.Tensor): - """ - Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. - Args: - x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection - Returns: - :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. - """ # noqa: E501 - logits = F.linear(x, self.weight) - return logits -class VPProjection(bmt.DistributedModule): +class VPEmbedding(bmt.DistributedModule): def __init__( self, vocab_size: int, @@ -59,12 +32,11 @@ def __init__( tp_mode=True, ) - def forward(self, x: torch.Tensor): - """ - Projection based on embedding's weight. For example, embedding map vocab_size to embed_size, than projection map embed_size back to vocab_size. - Args: - x (:obj:`torch.Tensor` of shape ``(batch, seq_len, dim_model)``): Input of projection - Returns: - :obj:`torch.Tensor` of shape ``(batch, seq_len, vocab_output_size)``: The projection output. 
- """ # noqa: E501 - return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1) \ No newline at end of file + def forward(self, x: torch.Tensor, projection=False): + if not projection: + weight = all_gather(self.weight, comm=config['tp_comm']).flatten(0,1) + out = F.embedding(x, weight) + return out + else: + x = bmt.distributed.all_gather(x, comm=bmt.config['tp_comm']).view(x.shape[0], -1, x.shape[-1]) + return bmt.nn.OpParallelLinear.apply(x, self.weight, None, False, False, False, None, 1) diff --git a/example/layers/attention.py b/example/layers/attention.py index 32497bcb..0f5155d4 100644 --- a/example/layers/attention.py +++ b/example/layers/attention.py @@ -41,8 +41,7 @@ def forward(self, mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) ) -> torch.Tensor: - batch_size, seq_q, dim_model = hidden_q.size() - seq_kv = hidden_kv.size(1) + batch_size = hidden_q.size()[0] assert hidden_q.data_ptr() == hidden_kv.data_ptr() @@ -54,14 +53,16 @@ def forward(self, True, False, False, None ) + hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1]) h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) - #batch_size will changed in TensorParallel - batch_size = h_v.shape[0] else: h_q : torch.Tensor = self.project_q(hidden_q) h_k : torch.Tensor = self.project_k(hidden_kv) h_v : torch.Tensor = self.project_v(hidden_kv) + seq_q = h_q.size()[1] + seq_kv = h_k.size(1) + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) @@ -84,10 +85,6 @@ def forward(self, if position_bias is not None: score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) - if config['tp_size'] > 1: - with torch.no_grad(): - mask = all_gather(mask, config['tp_comm']).flatten(0,1) - score = torch.where( mask.view(batch_size, 1, seq_q, seq_kv), score, @@ -108,8 +105,11 @@ def forward(self, h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) h_out = h_out.permute(0, 2, 1, 3).contiguous() h_out = h_out.view(batch_size, seq_q, -1) + if config['tp_size'] > 1: + h_out = h_out.view(h_out.shape[0] * bmt.config["tp_size"], -1, h_out.shape[-1]) attn_out = self.project_out(h_out) + return attn_out diff --git a/example/layers/transformer.py b/example/layers/transformer.py index 7cda1bb9..4cbff59b 100644 --- a/example/layers/transformer.py +++ b/example/layers/transformer.py @@ -28,7 +28,7 @@ def forward(self, x = self.ln_ff(hidden) x = self.ff(x) - hidden = hidden + x + hidden = hidden + x.view_as(hidden) return hidden diff --git a/example/models/gpt.py b/example/models/gpt.py index 4596167c..ed604382 100644 --- a/example/models/gpt.py +++ b/example/models/gpt.py @@ -14,8 +14,8 @@ def __init__(self, self.max_distance = max_distance - if config['tp_size'] > 1: - self.word_emb = bmt.nn.ParallelEmbedding(vocab_size, dim_model, dtype=dtype) + if config["tp_size"] > 1: + self.word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype) else: self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) @@ -50,17 +50,15 @@ def forward(self, mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) - + if config["tp_size"] > 1: + input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]] + pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]] out = self.pos_emb(pos) + 
self.word_emb(input) # for layer in self.transformers: out = self.transformers(out, mask_2d, None) out = self.layernorm(out) - - if config['tp_size'] > 1: - logits = self.word_emb.projection(out) - else: - logits = self.word_emb(out, projection=True) + logits = self.word_emb(out, projection=True) bmt.inspect.record_tensor(logits, "logits") return logits diff --git a/example/train.py b/example/train.py index 8aaf65e4..d5906a06 100644 --- a/example/train.py +++ b/example/train.py @@ -36,8 +36,10 @@ def main(): batch_size = 2 seq_len = 512 + world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"] + r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"] - for i in range(bmt.world_size()): + for i in range(world_size): sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() enc_input = sent[:, :-1].long().cuda() @@ -49,7 +51,7 @@ def main(): torch.full_like(targets, -100, dtype=torch.long) ) - if i == bmt.rank(): + if i == r: break if config['tp_size'] > 1: @@ -82,7 +84,7 @@ def main(): batch, seq_len, vocab_out_size = logits.size() if config['tp_size'] > 1: - loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets) + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) else: loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
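
The bind.cpp hunk above also exposes the new bulk fp32-to-fp16/bf16 converters (fp16_from_fp32_value_launcher and bf16_from_fp32_value_launcher) to Python. A minimal sketch of driving to_fp16_from_fp32 directly is shown below; it assumes the compiled extension is importable as bmtrain.C (the module name declared in PYBIND11_MODULE), and the tensor names are illustrative rather than part of the library.

import torch
from bmtrain import C  # compiled extension; import path assumed from PYBIND11_MODULE(C, m)

n = 1024
param_fp32 = torch.randn(n, dtype=torch.float32)   # contiguous CPU source
param_fp16 = torch.empty(n, dtype=torch.float16)   # contiguous CPU destination
# The launcher takes (n, src_address, dst_address) and converts element-wise on the CPU.
C.to_fp16_from_fp32(n, param_fp32.data_ptr(), param_fp16.data_ptr())
# C.to_bf16_from_fp32 follows the same calling convention for bfloat16 targets.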
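
The delta_info buffer threaded through the adam_cpu launchers above collects four statistics of the per-element update delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p: its mean, its variance (computed as E(x^2) - E(x)^2, matching the C++ comments), the raw sum, and the raw sum of squares. The Python-side asserts added in _function.py require the buffer to be a contiguous float32 CPU tensor with exactly four elements, or None to skip the bookkeeping. The sketch below is a plain-PyTorch reference for those four values, not the library implementation; the function name and argument order are illustrative.

import torch

def reference_delta_info(m, v, p, bias_correction1, bias_correction2, eps, weight_decay):
    # m, v are the already-updated Adam moments; p is the fp32 parameter before the step.
    delta = m / bias_correction1 / (torch.sqrt(v / bias_correction2) + eps) + weight_decay * p
    n = delta.numel()
    sum_delta = delta.sum()
    sum_sq_delta = (delta * delta).sum()
    mean = sum_delta / n
    var = sum_sq_delta / n - mean * mean   # var = E(x^2) - E(x)^2
    return torch.stack([mean, var, sum_delta, sum_sq_delta])

# Buffer handed to the launcher; it must satisfy the asserts added in _function.py.
delta_info = torch.zeros(4, dtype=torch.float32, device="cpu")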
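
The final patch (52/52) folds the old Projection/VPProjection pair into a single VPEmbedding whose weight is sharded over the vocabulary: with projection=False it all-gathers the shard and performs an ordinary embedding lookup, and with projection=True it reuses the same sharded weight as the output projection through OpParallelLinear, which is how example/models/gpt.py now produces logits. A minimal usage sketch under those assumptions follows; it presumes bmt.init_distributed has been called with tp_size > 1, and the exact output shapes depend on the tensor-parallel layout.

import torch
import bmtrain as bmt

bmt.init_distributed(tp_size=2)  # assumed launch configuration

emb = bmt.nn.VPEmbedding(10240, 2560, dtype=torch.half)   # (vocab_size, embedding_size)

ids = torch.randint(0, 10240, (2, 512)).cuda()
# As in example/models/gpt.py, each tensor-parallel rank embeds its own chunk of the sequence.
local_ids = ids.chunk(bmt.config["tp_size"], dim=1)[bmt.config["tp_rank"]]

hidden = emb(local_ids)                  # lookup over the gathered, full-vocabulary weight
logits = emb(hidden, projection=True)    # per-rank logits over the local vocabulary shard

Tying the lookup and the projection to one vocabulary-sharded weight keeps the input and output embeddings shared while splitting the largest matrix across the tensor-parallel group.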