diff --git a/.gitignore b/.gitignore
index 4f61b8cbb..6246eec82 100644
--- a/.gitignore
+++ b/.gitignore
@@ -101,3 +101,4 @@
 ENV/
 .mypy_cache/
 .DS_Store
+.idea
diff --git a/examples/gpt/hybrid_parallel/README.md b/examples/gpt/hybrid_parallel/README.md
index 05ccfcd81..22d3a7f37 100644
--- a/examples/gpt/hybrid_parallel/README.md
+++ b/examples/gpt/hybrid_parallel/README.md
@@ -95,6 +95,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
   num_train_epochs: 1
   seed: 1024
   use_recompute: False
+  recompute_granularity:
   batch_size:
     global_batch_size: 8
     local_batch_size: 8
@@ -113,6 +114,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
     save_steps: 1000
     output_dir: ./output
     ckpt_dir:
+  fused_linear: False
 ```
 
 Parameter descriptions:
@@ -124,6 +126,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
 | num_train_epochs | Number of training epochs |
 | seed | Random seed, to make training reproducible |
 | use_recompute | Whether to train with recompute |
+| recompute_granularity | Granularity of recompute; either `full` or `only_attn`. `full` recomputes the whole transformer layer, while `only_attn` recomputes only the self-attention part |
 | global_batch_size | Global batch size, i.e. the effective batch size of one parameter update |
 | local_batch_size | Batch size trained by each process |
 | micro_batch_size | Batch size of each forward pass |
@@ -138,6 +141,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
 | save_steps | Interval (in steps) between model saves |
 | output_dir | Output directory |
 | ckpt_dir | Directory to load checkpoints from |
+| fused_linear | Whether to replace the regular Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or above |
 
 ### Parallelism dimensions
 
diff --git a/examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml b/examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml
index 0297c58a9..d4a4f2407 100644
--- a/examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml
+++ b/examples/gpt/hybrid_parallel/configs_1.3B_dp8.yaml
@@ -5,6 +5,7 @@ PreTraining:
   num_train_epochs: 1
   seed: 1024
   use_recompute: True
+  recompute_granularity: 'only_attn'
   batch_size:
     global_batch_size: 64
     local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
   save_load:
     save_steps: 1000
     output_dir: ./output
-    ckpt_dir:
+    ckpt_dir:
+  fused_linear: True
 
 Model:
   vocab_size: 50304
diff --git a/examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml b/examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml
index 713b3dae4..3b1042223 100644
--- a/examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml
+++ b/examples/gpt/hybrid_parallel/configs_175B_mp8_pp16.yaml
@@ -5,6 +5,7 @@ PreTraining:
   num_train_epochs: 1
   seed: 1024
   use_recompute: True
+  recompute_granularity: 'only_attn'
   batch_size:
     global_batch_size: 1536
     local_batch_size: 1536
@@ -22,7 +23,8 @@ PreTraining:
   save_load:
     save_steps: 1000
     output_dir: ./output
-    ckpt_dir:
+    ckpt_dir:
+  fused_linear: True
 
 Model:
   vocab_size: 51200
diff --git a/examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml b/examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml
index 8d8745cc8..fb34d46ea 100644
--- a/examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml
+++ b/examples/gpt/hybrid_parallel/configs_6.7B_sharding16.yaml
@@ -5,6 +5,7 @@ PreTraining:
   num_train_epochs: 1
   seed: 1024
   use_recompute: True
+  recompute_granularity: 'only_attn'
   batch_size:
     global_batch_size: 128
     local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
   save_load:
     save_steps: 1000
     output_dir: ./output
-    ckpt_dir:
+    ckpt_dir:
+  fused_linear: True
 
 Model:
   vocab_size: 50304
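Not part of the patch, but useful context for the `fused_linear` switch configured above: `paddle.incubate.nn.FusedLinear`, which the single-card model imports below, is a drop-in replacement for `nn.Linear` that runs the matmul and the bias-add as one `fused_gemm_epilogue` kernel. A minimal sketch, assuming a GPU build of Paddle compiled against CUDA 11.6 or newer; the shapes are arbitrary:

```python
import paddle
import paddle.nn as nn
from paddle.incubate.nn import FusedLinear

x = paddle.randn([8, 1024])

fc = nn.Linear(1024, 4096)          # regular path: separate matmul and bias-add kernels
fused_fc = FusedLinear(1024, 4096)  # fused path: single fused_gemm_epilogue kernel

# Both layers compute x @ W + b and accept the same weight_attr / bias_attr
# arguments, which is why the models below can select the class with a
# one-line switch.
print(fc(x).shape, fused_fc(x).shape)  # [8, 4096] [8, 4096]
```

In the hybrid-parallel model the same effect comes from passing `fuse_matmul_bias=True` to the `ColumnParallelLinear` / `RowParallelLinear` layers, as the diff further down shows.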
diff --git a/examples/gpt/single/README.md b/examples/gpt/single/README.md
index ba4f5565e..22f3c0828 100644
--- a/examples/gpt/single/README.md
+++ b/examples/gpt/single/README.md
@@ -85,6 +85,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
   num_train_epochs: 1
   seed: 1024
   use_recompute: False
+  recompute_granularity:
   batch_size:
     global_batch_size: 8
     local_batch_size: 8
@@ -103,6 +104,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
     save_steps: 1000
     output_dir: ./output
     ckpt_dir:
+  fused_linear: False
 ```
 
 Parameter descriptions:
@@ -114,6 +116,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
 | num_train_epochs | Number of training epochs |
 | seed | Random seed, to make training reproducible |
 | use_recompute | Whether to train with recompute |
+| recompute_granularity | Granularity of recompute; either `full` or `only_attn`. `full` recomputes the whole transformer layer, while `only_attn` recomputes only the self-attention part |
 | global_batch_size | Global batch size, i.e. the effective batch size of one parameter update |
 | local_batch_size | Batch size trained by each process |
 | micro_batch_size | Batch size of each forward pass |
@@ -128,6 +131,7 @@ GPT training uses the AdamW optimizer and cosine learning rate decay by default; this is configured via
 | save_steps | Interval (in steps) between model saves |
 | output_dir | Output directory |
 | ckpt_dir | Directory to load checkpoints from |
+| fused_linear | Whether to replace the regular Linear with fused_linear to speed up training. Note: this requires Paddle compiled with CUDA 11.6 or above |
 
 ## How to run
 
diff --git a/examples/gpt/single/configs_1.3B_single_card.yaml b/examples/gpt/single/configs_1.3B_single_card.yaml
index b0cdabc19..7c00b1b08 100644
--- a/examples/gpt/single/configs_1.3B_single_card.yaml
+++ b/examples/gpt/single/configs_1.3B_single_card.yaml
@@ -5,6 +5,7 @@ PreTraining:
   num_train_epochs: 1
   seed: 1024
   use_recompute: True
+  recompute_granularity: 'only_attn'
   batch_size:
     global_batch_size: 8
     local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
   save_load:
     save_steps: 1000
     output_dir: ./output
-    ckpt_dir:
+    ckpt_dir:
+  fused_linear: True
 
 Model:
   vocab_size: 50304
diff --git a/examples/gpt/single/configs_345m_single_card.yaml b/examples/gpt/single/configs_345m_single_card.yaml
index 27c9cf9ac..07c8754ce 100644
--- a/examples/gpt/single/configs_345m_single_card.yaml
+++ b/examples/gpt/single/configs_345m_single_card.yaml
@@ -5,6 +5,7 @@ PreTraining:
   num_train_epochs: 1
   seed: 1024
   use_recompute: False
+  recompute_granularity:
   batch_size:
     global_batch_size: 8
     local_batch_size: 8
@@ -22,7 +23,8 @@ PreTraining:
   save_load:
     save_steps: 1000
     output_dir: ./output
-    ckpt_dir:
+    ckpt_dir:
+  fused_linear: True
 
 Model:
   vocab_size: 50304
diff --git a/examples/gpt/tools.py b/examples/gpt/tools.py
index d1e1488da..fcf1dfff4 100644
--- a/examples/gpt/tools.py
+++ b/examples/gpt/tools.py
@@ -16,12 +16,14 @@
 from __future__ import division
 from __future__ import print_function
 
+import logging
 import os
 import sys
 import yaml
 import paddle
 import paddle.distributed as dist
+from paddle.fluid import core
 import argparse
 
 from fleetx.datasets.gpt import create_pretrained_dataset, get_train_data_file
@@ -49,6 +51,13 @@ def process_batch_size(args):
     assert args.local_batch_size % args.micro_batch_size == 0
 
 
+def is_fused_matmul_bias_supported():
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        return hasattr(core.ops, 'fused_gemm_epilogue')
+    else:
+        return False
+
+
 def model_size(args):
     """
     get model size for transformer
@@ -84,6 +93,22 @@ def add_dict(config, k, v):
 
     args.test_iters = args.eval_iters * 10
 
+    if args.fused_linear and not is_fused_matmul_bias_supported():
+        args.fused_linear = False
+        logging.warning("The flag fused_linear is only valid for CUDA 11.6 and newer, "
+                        "but the installed Paddle was compiled with CUDA " + paddle.version.cuda())
+
+    if args.recompute:
+        assert args.recompute_granularity is None or \
+            isinstance(args.recompute_granularity, str), \
+            "recompute_granularity must be None or a string"
+        if args.recompute_granularity is None:
+            args.recompute_granularity = "full"
+        else:
+            assert args.recompute_granularity in ["full", "only_attn"], \
+                "recompute_granularity can only be chosen from 'full' or " \
+                "'only_attn', but received " + args.recompute_granularity
+
     # process batch size
     process_batch_size(args)
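The capability check above is what silently turns `fused_linear` off on unsupported installations. Not part of the patch: the same probe can be run standalone to see why a particular environment falls back, for example when Paddle was built against an older CUDA toolkit or for ROCm.

```python
# Mirrors is_fused_matmul_bias_supported() from tools.py above.
import paddle
from paddle.fluid import core


def supports_fused_linear():
    # The fused matmul+bias kernel (fused_gemm_epilogue) is only registered in
    # CUDA (non-ROCm) builds of Paddle compiled against CUDA 11.6 or newer.
    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
        return hasattr(core.ops, 'fused_gemm_epilogue')
    return False


print("Paddle compiled with CUDA:", paddle.version.cuda())
print("fused_linear usable:", supports_fused_linear())
```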
diff --git a/fleetx/models/gpt_model/modeling.py b/fleetx/models/gpt_model/modeling.py
index 812a6e0dd..83fff2bcc 100644
--- a/fleetx/models/gpt_model/modeling.py
+++ b/fleetx/models/gpt_model/modeling.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 
 import collections
+import logging
+
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -24,6 +26,7 @@
 import paddle.incubate as incubate
 from paddle.distributed.fleet.utils import recompute
 from .config import configurable
+from paddle.incubate.nn import FusedLinear
 
 
 class MultiHeadAttention(nn.Layer):
@@ -46,7 +49,8 @@ def __init__(self,
                  need_weights=False,
                  weight_attr=None,
                  bias_attr=None,
-                 fuse=True):
+                 fuse=True,
+                 fused_linear=False):
         super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -59,19 +63,21 @@ def __init__(self,
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
+        Linear = FusedLinear if fused_linear else nn.Linear
+
         if self.fuse:
             assert self.kdim == embed_dim
             assert self.vdim == embed_dim
-            self.qkv_proj = nn.Linear(
+            self.qkv_proj = Linear(
                 embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
         else:
-            self.q_proj = nn.Linear(
+            self.q_proj = Linear(
                 embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
-            self.k_proj = nn.Linear(
+            self.k_proj = Linear(
                 self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
-            self.v_proj = nn.Linear(
+            self.v_proj = Linear(
                 self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
-        self.out_proj = nn.Linear(
+        self.out_proj = Linear(
             embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
 
     def _fuse_prepare_qkv(self, query):
@@ -221,13 +227,15 @@ def __init__(self,
                  num_layers,
                  norm=None,
                  hidden_size=None,
-                 use_recompute=False):
+                 use_recompute=False,
+                 recompute_granularity="full"):
         super(TransformerDecoder, self).__init__()
 
         self.num_layers = num_layers
         self.layers = decoder_layers
         self.norm = norm
         self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity
         if norm == "LayerNorm":
             self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
         elif norm is not None:
@@ -258,9 +266,10 @@ def forward(self,
                                             cache=cache)
                     new_caches.append(new_cache)
                 else:
-                    output = recompute(mod, output, memory, tgt_mask, use_cache, cache) if self.use_recompute \
-                        else mod(output, memory, tgt_mask, use_cache, cache)
-
+                    if self.use_recompute and self.recompute_granularity == "full":
+                        output = recompute(mod, output, memory, tgt_mask, use_cache, cache)
+                    else:
+                        output = mod(output, memory, tgt_mask, use_cache, cache)
             else:
                 output, new_cache = mod(output,
                                         memory,
@@ -304,7 +313,9 @@
                  act_dropout=None,
                  normalize_before=True,
                  weight_attr=None,
-                 bias_attr=None):
+                 bias_attr=None,
+                 fused_linear=False,
+                 recompute_attn=False):
         self._config = locals()
         self._config.pop("self")
         self._config.pop("__class__", None)  # py3
@@ -313,19 +324,23 @@
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
+        self.recompute_attn = recompute_attn
 
         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
 
+        Linear = FusedLinear if fused_linear else nn.Linear
+
         self.self_attn = MultiHeadAttention(
             d_model,
             nhead,
             dropout=attn_dropout,
             weight_attr=weight_attrs[0],
-            bias_attr=bias_attrs[0])
-        self.linear1 = nn.Linear(
+            bias_attr=bias_attrs[0],
+            fused_linear=fused_linear)
+        self.linear1 = Linear(
             d_model, dim_feedforward, weight_attrs[2], bias_attr=bias_attrs[2])
-        self.linear2 = nn.Linear(
+        self.linear2 = Linear(
             dim_feedforward, d_model, weight_attrs[2], bias_attr=bias_attrs[2])
 
         self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
@@ -341,7 +356,10 @@ def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
             tgt = self.norm1(tgt)
 
         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
+            if self.recompute_attn:
+                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
+            else:
+                tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
             tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
                                                     use_cache, cache)
@@ -421,10 +439,14 @@
                  max_position_embeddings=512,
                  type_vocab_size=16,
                  use_recompute=False,
-                 initializer_range=0.02):
+                 initializer_range=0.02,
+                 fused_linear=False,
+                 recompute_granularity="full"):
 
         super(GPTModel, self).__init__()
 
+        recompute_attn = use_recompute and recompute_granularity == "only_attn"
+
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
@@ -447,14 +469,17 @@
                     weight_attr=paddle.ParamAttr(
                         initializer=nn.initializer.Normal(
                             mean=0.0, std=self.initializer_range)),
-                    bias_attr=None))
+                    bias_attr=None,
+                    fused_linear=fused_linear,
+                    recompute_attn=recompute_attn))
 
         self.decoder = TransformerDecoder(
             decoder_layers,
             num_layers,
             norm="LayerNorm",
             hidden_size=hidden_size,
-            use_recompute=use_recompute)
+            use_recompute=use_recompute,
+            recompute_granularity=recompute_granularity)
 
     @classmethod
     def from_config(cls, cfg):
@@ -469,7 +494,9 @@
             "max_position_embeddings": cfg.max_position_embeddings,
             "type_vocab_size": cfg.type_vocab_size,
             "initializer_range": cfg.initializer_range,
-            "use_recompute": cfg.use_recompute
+            "use_recompute": cfg.use_recompute,
+            "fused_linear": cfg.fused_linear,
+            "recompute_granularity": cfg.recompute_granularity
         }
 
     def forward(self,
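Not part of the patch: a minimal sketch of what `recompute_granularity="only_attn"` changes inside each decoder layer. `paddle.nn.MultiHeadAttention` stands in here for the `MultiHeadAttention` defined above, and the shapes are arbitrary; the point is that only the attention call is wrapped in `recompute`, so its activations are recomputed during the backward pass while the feed-forward part of the layer keeps its activations.

```python
import paddle
from paddle.distributed.fleet.utils import recompute

attn = paddle.nn.MultiHeadAttention(embed_dim=256, num_heads=8)
x = paddle.randn([2, 16, 256])
x.stop_gradient = False

# Activations produced inside `attn` are dropped after the forward pass and
# re-computed during backward, trading extra compute for lower peak memory.
y = recompute(attn, x, x, x)
y.mean().backward()
print(y.shape)  # [2, 16, 256]
```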
diff --git a/fleetx/models/gpt_model/modeling_hybrid.py b/fleetx/models/gpt_model/modeling_hybrid.py
index b46b5566e..578dd9cee 100644
--- a/fleetx/models/gpt_model/modeling_hybrid.py
+++ b/fleetx/models/gpt_model/modeling_hybrid.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 
 import collections
+import logging
+
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -76,7 +78,8 @@ def __init__(self,
                  weight_attr=None,
                  bias_attr=None,
                  fuse=True,
-                 num_partitions=1):
+                 num_partitions=1,
+                 fused_linear=False):
         super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -102,35 +105,40 @@
                 3 * embed_dim,
                 weight_attr=weight_attr,
                 has_bias=True,
-                gather_output=False)
+                gather_output=False,
+                fuse_matmul_bias=fused_linear)
         else:
             self.q_proj = fleet.meta_parallel.ColumnParallelLinear(
                 embed_dim,
                 embed_dim,
                 weight_attr=weight_attr,
                 has_bias=True,
-                gather_output=False)
+                gather_output=False,
+                fuse_matmul_bias=fused_linear)
 
             self.k_proj = fleet.meta_parallel.ColumnParallelLinear(
                 self.kdim,
                 embed_dim,
                 weight_attr=weight_attr,
                 has_bias=True,
-                gather_output=False)
+                gather_output=False,
+                fuse_matmul_bias=fused_linear)
 
             self.v_proj = fleet.meta_parallel.ColumnParallelLinear(
                 self.vdim,
                 embed_dim,
                 weight_attr=weight_attr,
                 has_bias=True,
-                gather_output=False)
+                gather_output=False,
+                fuse_matmul_bias=fused_linear)
 
         self.out_proj = fleet.meta_parallel.RowParallelLinear(
             embed_dim,
             embed_dim,
             weight_attr=weight_attr,
             has_bias=True,
-            input_is_parallel=True)
+            input_is_parallel=True,
+            fuse_matmul_bias=fused_linear)
 
     def _fuse_prepare_qkv(self, query):
         mix_layer = self.qkv_proj(query)
@@ -280,13 +288,15 @@ def __init__(self,
                  num_layers,
                  norm=None,
                  hidden_size=None,
-                 use_recompute=False):
+                 use_recompute=False,
+                 recompute_granularity="full"):
         super(TransformerDecoder, self).__init__()
 
         self.num_layers = num_layers
         self.layers = decoder_layers
         self.norm = norm
         self.use_recompute = use_recompute
+        self.recompute_granularity = recompute_granularity
         if norm == "LayerNorm":
             self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
         elif norm is not None:
@@ -317,8 +327,10 @@ def forward(self,
                                             cache=cache)
                     new_caches.append(new_cache)
                 else:
-                    output = recompute(mod, output, memory, tgt_mask, use_cache, cache) if self.use_recompute \
-                        else mod(output, memory, tgt_mask, use_cache, cache)
+                    if self.use_recompute and self.recompute_granularity == "full":
+                        output = recompute(mod, output, memory, tgt_mask, use_cache, cache)
+                    else:
+                        output = mod(output, memory, tgt_mask, use_cache, cache)
 
             else:
                 output, new_cache = mod(output,
@@ -364,7 +376,9 @@
                  normalize_before=True,
                  weight_attr=None,
                  bias_attr=None,
-                 num_partitions=1):
+                 num_partitions=1,
+                 fused_linear=False,
+                 recompute_attn=False):
         self._config = locals()
         self._config.pop("self")
         self._config.pop("__class__", None)  # py3
@@ -373,6 +387,7 @@
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
+        self.recompute_attn = recompute_attn
 
         weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
         bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
@@ -383,21 +398,24 @@
             dropout=attn_dropout,
             weight_attr=weight_attrs[0],
             bias_attr=bias_attrs[0],
-            num_partitions=num_partitions)
+            num_partitions=num_partitions,
+            fused_linear=fused_linear)
 
         self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
             d_model,
             dim_feedforward,
             weight_attr=weight_attrs[2],
             gather_output=False,
-            has_bias=True)
+            has_bias=True,
+            fuse_matmul_bias=fused_linear)
 
         self.linear2 = fleet.meta_parallel.RowParallelLinear(
             dim_feedforward,
             d_model,
             weight_attr=weight_attrs[2],
             input_is_parallel=True,
-            has_bias=True)
+            has_bias=True,
+            fuse_matmul_bias=fused_linear)
 
         self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
         self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
@@ -417,7 +435,10 @@ def forward(self,
             tgt = self.norm1(tgt)
 
         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
+            if self.recompute_attn:
+                tgt = recompute(self.self_attn, tgt, None, None, tgt_mask, use_cache, cache)
+            else:
+                tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
             tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
                                                     use_cache, cache)
@@ -505,10 +526,14 @@
                  type_vocab_size=16,
                  initializer_range=0.02,
                  num_partitions=1,
-                 use_recompute=False):
+                 use_recompute=False,
+                 fused_linear=False,
+                 recompute_granularity="full"):
 
         super(GPTModel, self).__init__()
 
+        recompute_attn = use_recompute and recompute_granularity == "only_attn"
+
         self.initializer_range = initializer_range
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
@@ -532,14 +557,17 @@
                         initializer=nn.initializer.Normal(
                             mean=0.0, std=self.initializer_range)),
                     bias_attr=None,
-                    num_partitions=num_partitions))
+                    num_partitions=num_partitions,
+                    fused_linear=fused_linear,
+                    recompute_attn=recompute_attn))
 
         self.decoder = TransformerDecoder(
             decoder_layers,
             num_layers,
             norm="LayerNorm",
             hidden_size=hidden_size,
-            use_recompute=use_recompute)
+            use_recompute=use_recompute,
+            recompute_granularity=recompute_granularity)
 
     @classmethod
     def from_config(cls, cfg):
@@ -555,7 +583,9 @@
             "type_vocab_size": cfg.type_vocab_size,
             "initializer_range": cfg.initializer_range,
             "num_partitions": cfg.mp_degree,
-            "use_recompute": cfg.use_recompute
+            "use_recompute": cfg.use_recompute,
+            "fused_linear": cfg.fused_linear,
+            "recompute_granularity": cfg.recompute_granularity
         }
 
     def forward(self,
@@ -727,7 +757,11 @@ def __init__(self,
                  initializer_range=0.02,
                  num_partitions=1,
                  topology=None,
-                 use_recompute=False):
+                 use_recompute=False,
+                 fused_linear=False,
+                 recompute_granularity="full"):
+
+        recompute_attn = use_recompute and recompute_granularity == "only_attn"
 
         # forward desc
         self.descs = []
@@ -759,7 +793,9 @@
                     initializer=nn.initializer.Normal(
                         mean=0.0, std=initializer_range)),
                 bias_attr=None,
-                num_partitions=num_partitions))
+                num_partitions=num_partitions,
+                fused_linear=fused_linear,
+                recompute_attn=recompute_attn))
 
         self.descs.append(
             LayerDesc(
@@ -781,12 +817,16 @@ def _logits_helper(embedding, output):
                 type_vocab_size=type_vocab_size,
                 initializer_range=0.02))
 
+        recompute_interval = 0
+        if use_recompute and not recompute_attn:
+            recompute_interval = 1
+
         super().__init__(
             layers=self.descs,
             loss_fn=GPTPretrainingCriterionPipe(),
             topology=topology,
             seg_method="layer:TransformerDecoderLayer",
-            recompute_interval=1 if use_recompute else 0,
+            recompute_interval=recompute_interval,
             recompute_partition=False,
             recompute_offload=False)
@@ -805,5 +845,7 @@ def from_config(cls, cfg):
             "initializer_range": cfg.initializer_range,
             "num_partitions": cfg.mp_degree,
             "use_recompute": cfg.use_recompute,
-            "topology": cfg.topology
+            "topology": cfg.topology,
+            "fused_linear": cfg.fused_linear,
+            "recompute_granularity": cfg.recompute_granularity
         }
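Taken together (not part of the patch, just a restatement of the logic added in tools.py and the two modeling files above): `use_recompute` still turns recompute on or off, while `recompute_granularity` decides where it happens. A small sketch of how the derived values fall out; the variable names mirror the ones used in the diff.

```python
use_recompute = True
recompute_granularity = "only_attn"   # config value; None is normalized to "full" in tools.py

if recompute_granularity is None:
    recompute_granularity = "full"
assert recompute_granularity in ["full", "only_attn"]

# "only_attn": each TransformerDecoderLayer wraps just its self-attention call
# in recompute(); "full": the whole layer is recomputed by TransformerDecoder.
recompute_attn = use_recompute and recompute_granularity == "only_attn"

# Pipeline-parallel model: PipelineLayer-level recompute (recompute_interval=1)
# is only used for the "full" granularity; attention-only recompute is handled
# inside each layer instead.
recompute_interval = 1 if (use_recompute and not recompute_attn) else 0

print(recompute_attn, recompute_interval)   # True 0 for this configuration
```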