From c86d48b9ec296ba865719760933b62efb1426b8f Mon Sep 17 00:00:00 2001
From: Mddct
Date: Sat, 1 Jun 2024 09:09:09 +0800
Subject: [PATCH] [LLM] support causal LM model

---
 wenet/LLM/decoder.py               | 161 +++++++++++++++++++++++++++++
 wenet/transformer/asr_model.py     |   3 +
 wenet/transformer/attention.py     |  69 ++++++++-----
 wenet/transformer/decoder.py       |   2 +
 wenet/transformer/embedding.py     |  10 +-
 wenet/transformer/encoder_layer.py |  13 ++-
 wenet/transformer/norm.py          |   7 +-
 wenet/utils/fsdp_utils.py          |   3 +
 wenet/utils/init_model.py          |  45 ++++++--
 wenet/utils/rope_utils.py          |   6 ++
 wenet/utils/train_utils.py         |  25 +++--
 11 files changed, 295 insertions(+), 49 deletions(-)
 create mode 100644 wenet/LLM/decoder.py

diff --git a/wenet/LLM/decoder.py b/wenet/LLM/decoder.py
new file mode 100644
index 0000000000..b25ee75dd6
--- /dev/null
+++ b/wenet/LLM/decoder.py
@@ -0,0 +1,161 @@
+from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE

from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
                                     WENET_ATTENTION_CLASSES,
                                     WENET_EMB_CLASSES, WENET_MLP_CLASSES,
                                     WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):

    def __init__(
        self,
        n_kv_head: int,
        head_dim: int,
        hidden_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        query_bias: bool = False,
        key_bias: bool = False,
        value_bias: bool = False,
        mlp_bias: bool = False,
        activation_type: str = "gelu",
        gelu_approximate: Union[str, None] = None,
        max_position_embeding: int = 8192,
        mlp_type: str = 'gated',
        layer_norm_type: str = 'rms_norm',
        norm_eps: float = 1e-5,
        rms_norm_offset: bool = True,
        selfattention_layer_type: str = "rope_abs_selfattn",
        use_sdpa: bool = False,
        gradient_checkpointing: bool = False,
        rope_theta: float = 10000.0,
        rope_style: str = 'google',
        scale_embed: bool = True,
    ) -> None:
        super().__init__()

        assert selfattention_layer_type in ['rope_abs_selfattn']
        self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
            hidden_size,
            head_dim,
            max_len=max_position_embeding,
            dropout_rate=positional_dropout_rate,
            rope_theta=rope_theta,
            scale=scale_embed)
        if activation_type == "gelu" and gelu_approximate is not None:
            activation = WENET_ACTIVATION_CLASSES['gelu'](
                approximate=gelu_approximate)
        else:
            activation = WENET_ACTIVATION_CLASSES[activation_type]()

        mlp_class = WENET_MLP_CLASSES[mlp_type]
        self.num_blocks = num_blocks
        # TODO: support lora & refactor lora
        self.decoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                hidden_size,
                WENET_ATTENTION_CLASSES[selfattention_layer_type](
                    attention_heads,
                    hidden_size,
                    attention_dropout_rate,
                    query_bias,
                    key_bias,
                    value_bias,
                    use_sdpa,
                    n_kv_head,
                    head_dim,
                    style=rope_style),
                mlp_class(hidden_size, linear_units, dropout_rate, activation,
                          mlp_bias),
                dropout_rate,
                normalize_before,
                layer_norm_type=layer_norm_type,
                norm_eps=norm_eps,
                rms_norm_offset=rms_norm_offset,
            ) for _ in range(self.num_blocks)
        ])
        self.pre_norm = normalize_before
        self.final_norm: Optional[torch.nn.Module] = None
        if self.pre_norm:
            norm_class = WENET_NORM_CLASSES[layer_norm_type]
            if layer_norm_type == 
"rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.final_norm = norm_class(hidden_size, eps=norm_eps) + + self.n_kv_head = n_kv_head + self.head_dim = head_dim + self._hidden_size = hidden_size + self.use_sdpa = use_sdpa + self.gradient_checkpointing = gradient_checkpointing + + def forward( + self, + input: torch.Tensor, + att_mask: torch.Tensor, + input_position: Union[int, torch.Tensor] = 0, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + xs, pos_emb = self.pos_enc(input, offset=input_position) + if self.use_sdpa: + att_mask = mask_to_bias(att_mask, xs.dtype) + + if self.gradient_checkpointing and self.training: + xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb) + else: + xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb, + kv_caches) + if self.pre_norm and self.final_norm is not None: + xs = self.final_norm(xs) + return xs, kv_caches + + def forward_layers( + self, + xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor, + kv_caches: Optional[List[T_CACHE]] = None, + ) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]: + if self.training: + for (i, layer) in enumerate(self.decoders): + xs, _, _, _ = layer(xs, att_mask, pos_emb) + new_kv_caches = kv_caches + else: + assert kv_caches is not None + new_kv_caches = [] + for (i, layer) in enumerate(self.decoders): + xs, _, new_kv_cache, _ = layer(xs, + att_mask, + pos_emb, + att_cache=(kv_caches[i][0], + kv_caches[i][1])) + new_kv_caches.append(new_kv_cache) + + return xs, new_kv_caches + + @torch.jit.ignore(drop=True) + def forward_layers_checkpointed(self, xs: torch.Tensor, + att_mask: torch.Tensor, + pos_emb: torch.Tensor) -> torch.Tensor: + for layer in self.decoders: + xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask, + pos_emb) + return xs + + @property + def hidden_size(self): + return self._hidden_size diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 4099947b8f..a9cd77a455 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -133,6 +133,9 @@ def forward( "th_accuracy": acc_att, } + def tie_or_clone_weights(self, jit_mode: bool = True): + self.decoder.tie_or_clone_weights(jit_mode) + @torch.jit.unused def _forward_ctc( self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor, diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index 3ac4883aab..0fe042db4d 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -21,7 +21,7 @@ import torch from torch import nn -from wenet.utils.rope_utils import llama_apply_rotary_emb +from wenet.utils.rope_utils import WENET_APPLY_ROTARY_EMB T_CACHE = Tuple[torch.Tensor, torch.Tensor] @@ -80,7 +80,10 @@ def __init__(self, self.use_sdpa = use_sdpa self.dropout_rate = dropout_rate - def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: + def _forward_linearx(self, + name: str, + x: torch.Tensor, + head_first: bool = True) -> torch.Tensor: assert x.ndim >= 3 if name == 'query': x = self.linear_q(x) @@ -98,7 +101,9 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor: # split last dim x = x.view(x_shape) - x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k) + if head_first: + x = x.transpose(-3, + -2) # (batch, ..., head or head_kv, time, d_k) return x def forward_qkv( @@ -173,9 +178,15 @@ def forward_attention( return self.linear_out(x) # (batch, ..., time1, d_model) def 
_update_kv_and_cache( - self, k: torch.Tensor, v: torch.Tensor, - cache: T_CACHE) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]: + self, + k: torch.Tensor, + v: torch.Tensor, + cache: T_CACHE, + head_first: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]: new_cache = cache + seq_axis = -2 if head_first else -3 + head_axis = -3 if head_first else -2 if not self.training: # NOTE(xcsong): # when export onnx model, for 1st chunk, we feed @@ -195,9 +206,9 @@ def _update_kv_and_cache( # >>> torch.equal(d[0], d[1]) # True key_cache, value_cache = cache if key_cache.size(0) > 0: - k = torch.cat([key_cache, k], dim=2) + k = torch.cat([key_cache, k], dim=seq_axis) if value_cache.size(0) > 0: - v = torch.cat([value_cache, v], dim=2) + v = torch.cat([value_cache, v], dim=seq_axis) # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's # non-trivial to calculate `next_cache_start` here. # new_cache = torch.cat((k, v), dim=-1) if not self.training else cache @@ -218,17 +229,19 @@ def _update_kv_and_cache( # ) n_repeat = self.h // self.h_kv k_shape = k.size() - k = k.unsqueeze(-3).expand( - k_shape[:-2] + torch.Size([n_repeat]) + - k_shape[-2:]).reshape(k_shape[:-3] + - torch.Size([self.h_kv * n_repeat]) + - k_shape[-2:]) + repeat_axis = head_axis + 1 + k = k.unsqueeze(head_axis).expand( + k_shape[:repeat_axis] + torch.Size([n_repeat]) + + k_shape[repeat_axis:]).reshape( + k_shape[:head_axis] + torch.Size([self.h_kv * n_repeat]) + + k_shape[repeat_axis:]) v_shape = v.size() - v = v.unsqueeze(-3).expand( - v_shape[:-2] + torch.Size([n_repeat]) + - v_shape[-2:]).reshape(v_shape[:-3] + - torch.Size([self.h_kv * n_repeat]) + - v_shape[-2:]) + v = v.unsqueeze(head_axis).expand( + v_shape[:repeat_axis] + torch.Size([n_repeat]) + + v_shape[(repeat_axis):]).reshape( + v_shape[:head_axis] + torch.Size([self.h_kv * n_repeat]) + + v_shape[repeat_axis:]) + return k, v, new_cache def forward( @@ -594,9 +607,11 @@ def __init__(self, value_bias: bool = True, use_sdpa: bool = False, n_kv_head: Optional[int] = None, - head_dim: Optional[int] = None): + head_dim: Optional[int] = None, + style='google'): super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias, value_bias, use_sdpa, n_kv_head, head_dim) + self.style = style def forward( self, @@ -637,14 +652,22 @@ def forward( and `head * d_k == size` """ - q, k, v = self.forward_qkv(query, key, value) + q = self._forward_linearx('query', query, head_first=False) + k = self._forward_linearx('key', key, head_first=False) + v = self._forward_linearx('value', value, head_first=False) # NOTE(Mddct): In order to make the code easier to read, # these two lines are not placed in MultiHeadedAttention. 
- q = llama_apply_rotary_emb(q, pos_emb) - k = llama_apply_rotary_emb(k, pos_emb) - # see above - k, v, new_cache = self._update_kv_and_cache(k, v, cache) + q = WENET_APPLY_ROTARY_EMB[self.style](q, pos_emb) + k = WENET_APPLY_ROTARY_EMB[self.style](k, pos_emb) + + k, v, new_cache = self._update_kv_and_cache(k, + v, + cache, + head_first=False) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if not self.use_sdpa: scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) return self.forward_attention(v, scores, mask), new_cache diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py index ba31edffc7..0c4fab62af 100644 --- a/wenet/transformer/decoder.py +++ b/wenet/transformer/decoder.py @@ -286,6 +286,8 @@ def tie_or_clone_weights(self, jit_mode: bool = True): rank = int(os.environ.get('RANK', 0)) if not self.use_output_layer: return + if not self.tie_word_embedding: + return if jit_mode: if rank == 0: logging.info("clone emb.weight to output.weight") diff --git a/wenet/transformer/embedding.py b/wenet/transformer/embedding.py index db8a41333b..dcf717da3e 100644 --- a/wenet/transformer/embedding.py +++ b/wenet/transformer/embedding.py @@ -205,13 +205,15 @@ def __init__(self, head_dim: int, dropout_rate: float, max_len: int = 1500, - rope_theta=10000.0): + rope_theta=10000.0, + scale: bool = True): super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len) delattr(self, 'pe') self.max_len = max_len * 2 pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta) self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0))) self.dropout_rate = dropout_rate + self.scale = scale def forward( self, @@ -220,10 +222,10 @@ def forward( torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]: pos_emb = self.position_encoding(offset, x.size(1), True) - pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2] + pos_emb = pos_emb.unsqueeze(2) # [1,seq, 1, head_dim//2] # NOTE(Mddct): some model don't scale - # TODO(Mddct): fix - x = x * self.xscale + if self.scale: + x = x * self.xscale return self.dropout(x), pos_emb def position_encoding(self, diff --git a/wenet/transformer/encoder_layer.py b/wenet/transformer/encoder_layer.py index bf2cdb2c2c..7228dfbf16 100644 --- a/wenet/transformer/encoder_layer.py +++ b/wenet/transformer/encoder_layer.py @@ -15,6 +15,7 @@ # Modified from ESPnet(https://github.com/espnet/espnet) """Encoder self-attention layer definition.""" +from functools import partial from typing import Optional, Tuple import torch @@ -49,14 +50,22 @@ def __init__( normalize_before: bool = True, layer_norm_type: str = 'layer_norm', norm_eps: float = 1e-5, + rms_norm_offset: bool = True, ): """Construct an EncoderLayer object.""" super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward assert layer_norm_type in ['layer_norm', 'rms_norm'] - self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) - self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps) + + norm_class = WENET_NORM_CLASSES[layer_norm_type] + if layer_norm_type == "rms_norm": + norm_class = partial( + norm_class, + add_unit_offset=rms_norm_offset, + ) + self.norm1 = norm_class(size, eps=norm_eps) + self.norm2 = norm_class(size, eps=norm_eps) self.dropout = nn.Dropout(dropout_rate) self.size = size self.normalize_before = normalize_before diff --git a/wenet/transformer/norm.py b/wenet/transformer/norm.py index 2c3756f13f..8039228630 100644 --- a/wenet/transformer/norm.py +++ b/wenet/transformer/norm.py @@ -9,14 +9,19 @@ 
def __init__( self, dim: int, eps: float = 1e-6, + add_unit_offset: bool = True, ): super().__init__() self.eps = eps self.weight = torch.nn.Parameter(torch.ones(dim)) + self.add_unit_offset = add_unit_offset def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): x = self._norm(x.float()).type_as(x) - return x * self.weight + if self.add_unit_offset: + return x * (1 + self.weight) + else: + return x * self.weight diff --git a/wenet/utils/fsdp_utils.py b/wenet/utils/fsdp_utils.py index 33871f6f0c..77ca195953 100644 --- a/wenet/utils/fsdp_utils.py +++ b/wenet/utils/fsdp_utils.py @@ -5,6 +5,7 @@ from torch.distributed.fsdp.wrap import (lambda_auto_wrap_policy, transformer_auto_wrap_policy) +from wenet.LLM.decoder import DecoderOnly from wenet.branchformer.encoder_layer import BranchformerEncoderLayer from wenet.e_branchformer.encoder_layer import EBranchformerEncoderLayer from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer @@ -91,6 +92,8 @@ def check_gradient_checkpoint(model): if model.decoder.gradient_checkpointing: model.decoder.gradient_checkpointing = False ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values()) + if isinstance(model.decoder, DecoderOnly): + ckpt_laye_types += [DecoderOnly] return tuple(ckpt_laye_types) diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index ce8c12eeaf..c31fdbbddc 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -20,6 +20,8 @@ from wenet.paraformer.cif import Cif from wenet.paraformer.layers import SanmDecoder, SanmEncoder from wenet.paraformer.paraformer import Paraformer, Predictor +from wenet.LLM.causallm_model import CausalLM +from wenet.LLM.decoder import DecoderOnly from wenet.transducer.joint import TransducerJoint from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor, RNNPredictor) @@ -84,11 +86,11 @@ "k2_model": K2Model, "transducer": Transducer, 'paraformer': Paraformer, + 'causal_llm': CausalLM, } -def init_model(args, configs): - +def init_speech_model(args, configs): # TODO(xcsong): Forcefully read the 'cmvn' attribute. if configs.get('cmvn', None) == 'global_cmvn': mean, istd = load_cmvn(configs['cmvn_conf']['cmvn_file'], @@ -168,6 +170,32 @@ def init_model(args, configs): special_tokens=configs.get('tokenizer_conf', {}).get('special_tokens', None), **configs['model_conf']) + return model, configs + + +def init_causal_llm(configs): + vocab_size = configs['output_dim'] + assert configs['decoder'] == 'decoder_only' + assert configs['model'] == 'causal_lm' + decoder_only = DecoderOnly(**configs['decoder_conf']) + + model = CausalLM( + vocab_size, + decoder_only, + **configs['model_conf'], + special_tokens=configs.get('tokenizer_conf', + {}).get('special_tokens', None), + ) + return model, configs + + +def init_model(args, configs): + + model_type = configs.get('model', 'asr_model') + if model_type == 'causal_lm': + model, configs = init_causal_llm(configs) + else: + model, configs = init_speech_model(args, configs) # If specify checkpoint, load some info from checkpoint if hasattr(args, 'checkpoint') and args.checkpoint is not None: @@ -178,16 +206,17 @@ def init_model(args, configs): infos = {} configs["init_infos"] = infos + print(configs) + # Trye to tie some weights + if hasattr(model, 'tie_or_clone_weights'): + if not hasattr(args, 'jit'): + args.jit = True # i.e. 
export onnx/jit/ipex + model.tie_or_clone_weights(args.jit) + if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora: mark_only_lora_as_trainable(model, bias='lora_only') if int(os.environ.get('RANK', 0)) == 0: print(configs) - # Tie emb.weight to decoder.output_layer.weight - if model.decoder.tie_word_embedding: - if not hasattr(args, 'jit'): - args.jit = True # i.e. export onnx/jit/ipex - model.decoder.tie_or_clone_weights(jit_mode=args.jit) - return model, configs diff --git a/wenet/utils/rope_utils.py b/wenet/utils/rope_utils.py index e80bf9ace7..54f13c47b8 100644 --- a/wenet/utils/rope_utils.py +++ b/wenet/utils/rope_utils.py @@ -31,3 +31,9 @@ def llama_apply_rotary_emb(x: torch.Tensor, x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) return x_out.type_as(x) + + +WENET_APPLY_ROTARY_EMB = { + 'google': google_apply_rotary_emb, + 'llama': llama_apply_rotary_emb, +} diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py index cdf6da2b3a..8abb10a9df 100644 --- a/wenet/utils/train_utils.py +++ b/wenet/utils/train_utils.py @@ -282,20 +282,23 @@ def check_modify_and_save_config(args, configs, symbol_table): configs['encoder_conf']['lora_alpha'] = args.lora_alpha configs['encoder_conf']['lora_dropout'] = args.lora_dropout - if 'input_dim' not in configs: - if 'fbank_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins'] - elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: - input_dim = configs['dataset_conf']['log_mel_spectrogram_conf'][ - 'num_mel_bins'] + if configs["model"] == 'asr': + if 'input_dim' not in configs: + if 'fbank_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf']['fbank_conf'][ + 'num_mel_bins'] + elif 'log_mel_spectrogram_conf' in configs['dataset_conf']: + input_dim = configs['dataset_conf'][ + 'log_mel_spectrogram_conf']['num_mel_bins'] + else: + input_dim = configs['dataset_conf']['mfcc_conf'][ + 'num_mel_bins'] else: - input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins'] - else: - input_dim = configs['input_dim'] + input_dim = configs['input_dim'] - configs, _ = get_blank_id(configs, symbol_table) + configs['input_dim'] = input_dim - configs['input_dim'] = input_dim + configs, _ = get_blank_id(configs, symbol_table) configs['output_dim'] = configs['vocab_size'] configs['train_engine'] = args.train_engine
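
A minimal usage sketch (not part of the patch itself): the snippet below drives the new DecoderOnly module from wenet/LLM/decoder.py above in training mode. Only DecoderOnly and its constructor arguments come from this diff; the toy sizes, the stand-in nn.Embedding, and the lower-triangular mask construction are illustrative assumptions. The real token embedding and output projection belong to the CausalLM wrapper (wenet/LLM/causallm_model.py), which this patch imports but does not include.

import torch

from wenet.LLM.decoder import DecoderOnly

# Toy sizes, chosen only for illustration.
decoder = DecoderOnly(
    n_kv_head=2,        # grouped-query attention: 2 KV heads shared by 4 query heads
    head_dim=64,
    hidden_size=256,
    attention_heads=4,
    linear_units=1024,
    num_blocks=2,
)

batch, seq, vocab = 2, 16, 1000
# DecoderOnly consumes embeddings, not token ids; a plain nn.Embedding stands in
# here for the embedding owned by the (not included) CausalLM wrapper.
embed = torch.nn.Embedding(vocab, 256)
tokens = torch.randint(0, vocab, (batch, seq))

# Causal (lower-triangular) self-attention mask, shape (batch, seq, seq).
att_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
att_mask = att_mask.unsqueeze(0).expand(batch, -1, -1)

decoder.train()
out, _ = decoder(embed(tokens), att_mask)  # training path: no KV caches needed
print(out.shape)                           # torch.Size([2, 16, 256])

At inference time, forward additionally takes input_position and one (key, value) cache tuple per layer (T_CACHE), so tokens can be fed incrementally. At the config level, init_model dispatches to this path when a config sets model: causal_lm and decoder: decoder_only, as handled by init_causal_llm in the diff above.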