[LLM] support causal LM model
Mddct committed Jun 1, 2024
1 parent 97ffee4 commit c86d48b
Showing 11 changed files with 295 additions and 49 deletions.
161 changes: 161 additions & 0 deletions wenet/LLM/decoder.py
@@ -0,0 +1,161 @@
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE

from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES, WENET_MLP_CLASSES,
WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):

def __init__(
self,
n_kv_head: int,
head_dim: int,
hidden_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
normalize_before: bool = True,
query_bias: bool = False,
key_bias: bool = False,
value_bias: bool = False,
mlp_bias: bool = False,
activation_type: str = "gelu",
gelu_approximate: Union[str, None] = None,
max_position_embeding: int = 8192,
mlp_type: str = 'gated',
layer_norm_type: str = 'rms_norm',
norm_eps: float = 1e-5,
rms_norm_offset: bool = True,
selfattention_layer_type: str = "rope_abs_selfattn",
use_sdpa: bool = False,
gradient_checkpointing: bool = False,
rope_theta: float = 10000.0,
rope_style: str = 'google',
scale_embed: bool = True,
) -> None:
super().__init__()

assert selfattention_layer_type in ['rope_abs_selfattn']
self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
hidden_size,
head_dim,
max_len=max_position_embeding,
dropout_rate=positional_dropout_rate,
rope_theta=rope_theta,
scale=scale_embed)
if activation_type == "gelu" and gelu_approximate is not None:
activation = WENET_ACTIVATION_CLASSES['gelu'](
approximate=gelu_approximate)
else:
activation = WENET_ACTIVATION_CLASSES[activation_type]()

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.num_blocks = num_blocks
# TODO: support lora & refactor lora
self.decoders = torch.nn.ModuleList([
TransformerEncoderLayer(
hidden_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads,
hidden_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
style=rope_style),
mlp_class(hidden_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
rms_norm_offset=rms_norm_offset,
) for _ in range(self.num_blocks)
])
self.pre_norm = normalize_before
self.final_norm: Optional[torch.nn.Module] = None
if self.pre_norm:
norm_class = WENET_NORM_CLASSES[layer_norm_type]
if layer_norm_type == "rms_norm":
norm_class = partial(
norm_class,
add_unit_offset=rms_norm_offset,
)
self.final_norm = norm_class(hidden_size, eps=norm_eps)

self.n_kv_head = n_kv_head
self.head_dim = head_dim
self._hidden_size = hidden_size
self.use_sdpa = use_sdpa
self.gradient_checkpointing = gradient_checkpointing

def forward(
self,
input: torch.Tensor,
att_mask: torch.Tensor,
input_position: Union[int, torch.Tensor] = 0,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
xs, pos_emb = self.pos_enc(input, offset=input_position)
if self.use_sdpa:
att_mask = mask_to_bias(att_mask, xs.dtype)

if self.gradient_checkpointing and self.training:
xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
else:
xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
kv_caches)
if self.pre_norm and self.final_norm is not None:
xs = self.final_norm(xs)
return xs, kv_caches

def forward_layers(
self,
xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
if self.training:
for (i, layer) in enumerate(self.decoders):
xs, _, _, _ = layer(xs, att_mask, pos_emb)
new_kv_caches = kv_caches
else:
assert kv_caches is not None
new_kv_caches = []
for (i, layer) in enumerate(self.decoders):
xs, _, new_kv_cache, _ = layer(xs,
att_mask,
pos_emb,
att_cache=(kv_caches[i][0],
kv_caches[i][1]))
new_kv_caches.append(new_kv_cache)

return xs, new_kv_caches

@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor) -> torch.Tensor:
for layer in self.decoders:
xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
pos_emb)
return xs

@property
def hidden_size(self):
return self._hidden_size
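
A minimal usage sketch for the new DecoderOnly module (not part of the commit; the sizes, the causal-mask construction, and the empty per-layer caches are illustrative assumptions). It runs a single prefill pass in eval mode on pre-computed token embeddings:

```python
# Hedged sketch: assumes `input` is a (batch, time, hidden_size) tensor of
# token embeddings and that a lower-triangular boolean mask is a valid
# causal attention mask here.
import torch
from wenet.LLM.decoder import DecoderOnly

decoder = DecoderOnly(n_kv_head=1, head_dim=64, hidden_size=256,
                      attention_heads=4, linear_units=1024, num_blocks=2)
decoder.eval()

batch, seq, hidden = 2, 16, 256
embeds = torch.rand(batch, seq, hidden)
att_mask = torch.tril(torch.ones(batch, seq, seq, dtype=torch.bool))
# one empty (key, value) cache per layer for the first (prefill) step
kv_caches = [(torch.zeros(0, 0, 0, 0), torch.zeros(0, 0, 0, 0))
             for _ in range(decoder.num_blocks)]

with torch.no_grad():
    out, kv_caches = decoder(embeds, att_mask,
                             input_position=0, kv_caches=kv_caches)
print(out.shape)  # expected: torch.Size([2, 16, 256])
```

Subsequent decoding steps would feed one new embedding at a time, advance input_position by the number of tokens already seen, and pass the returned kv_caches back in.
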
3 changes: 3 additions & 0 deletions wenet/transformer/asr_model.py
@@ -133,6 +133,9 @@ def forward(
"th_accuracy": acc_att,
}

def tie_or_clone_weights(self, jit_mode: bool = True):
self.decoder.tie_or_clone_weights(jit_mode)

@torch.jit.unused
def _forward_ctc(
self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
69 changes: 46 additions & 23 deletions wenet/transformer/attention.py
Expand Up @@ -21,7 +21,7 @@
import torch
from torch import nn

from wenet.utils.rope_utils import llama_apply_rotary_emb
from wenet.utils.rope_utils import WENET_APPLY_ROTARY_EMB

T_CACHE = Tuple[torch.Tensor, torch.Tensor]

@@ -80,7 +80,10 @@ def __init__(self,
self.use_sdpa = use_sdpa
self.dropout_rate = dropout_rate

def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:
def _forward_linearx(self,
name: str,
x: torch.Tensor,
head_first: bool = True) -> torch.Tensor:
assert x.ndim >= 3
if name == 'query':
x = self.linear_q(x)
@@ -98,7 +101,9 @@ def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:

# split last dim
x = x.view(x_shape)
x = x.transpose(-3, -2) # (batch, ..., head or head_kv, time, d_k)
if head_first:
x = x.transpose(-3,
-2) # (batch, ..., head or head_kv, time, d_k)
return x

def forward_qkv(
@@ -173,9 +178,15 @@ def forward_attention(
return self.linear_out(x) # (batch, ..., time1, d_model)

def _update_kv_and_cache(
self, k: torch.Tensor, v: torch.Tensor,
cache: T_CACHE) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]:
self,
k: torch.Tensor,
v: torch.Tensor,
cache: T_CACHE,
head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, T_CACHE]:
new_cache = cache
seq_axis = -2 if head_first else -3
head_axis = -3 if head_first else -2
if not self.training:
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
@@ -195,9 +206,9 @@ def _update_kv_and_cache(
# >>> torch.equal(d[0], d[1]) # True
key_cache, value_cache = cache
if key_cache.size(0) > 0:
k = torch.cat([key_cache, k], dim=2)
k = torch.cat([key_cache, k], dim=seq_axis)
if value_cache.size(0) > 0:
v = torch.cat([value_cache, v], dim=2)
v = torch.cat([value_cache, v], dim=seq_axis)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
# new_cache = torch.cat((k, v), dim=-1) if not self.training else cache
@@ -218,17 +229,19 @@ def _update_kv_and_cache(
# )
n_repeat = self.h // self.h_kv
k_shape = k.size()
k = k.unsqueeze(-3).expand(
k_shape[:-2] + torch.Size([n_repeat]) +
k_shape[-2:]).reshape(k_shape[:-3] +
torch.Size([self.h_kv * n_repeat]) +
k_shape[-2:])
repeat_axis = head_axis + 1
k = k.unsqueeze(head_axis).expand(
k_shape[:repeat_axis] + torch.Size([n_repeat]) +
k_shape[repeat_axis:]).reshape(
k_shape[:head_axis] + torch.Size([self.h_kv * n_repeat]) +
k_shape[repeat_axis:])
v_shape = v.size()
v = v.unsqueeze(-3).expand(
v_shape[:-2] + torch.Size([n_repeat]) +
v_shape[-2:]).reshape(v_shape[:-3] +
torch.Size([self.h_kv * n_repeat]) +
v_shape[-2:])
v = v.unsqueeze(head_axis).expand(
v_shape[:repeat_axis] + torch.Size([n_repeat]) +
v_shape[(repeat_axis):]).reshape(
v_shape[:head_axis] + torch.Size([self.h_kv * n_repeat]) +
v_shape[repeat_axis:])

return k, v, new_cache

def forward(
@@ -594,9 +607,11 @@ def __init__(self,
value_bias: bool = True,
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
head_dim: Optional[int] = None,
style='google'):
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa, n_kv_head, head_dim)
self.style = style

def forward(
self,
@@ -637,14 +652,22 @@ def forward(
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = self._forward_linearx('query', query, head_first=False)
k = self._forward_linearx('key', key, head_first=False)
v = self._forward_linearx('value', value, head_first=False)
# NOTE(Mddct): In order to make the code easier to read,
# these two lines are not placed in MultiHeadedAttention.
q = llama_apply_rotary_emb(q, pos_emb)
k = llama_apply_rotary_emb(k, pos_emb)
# see above
k, v, new_cache = self._update_kv_and_cache(k, v, cache)
q = WENET_APPLY_ROTARY_EMB[self.style](q, pos_emb)
k = WENET_APPLY_ROTARY_EMB[self.style](k, pos_emb)

k, v, new_cache = self._update_kv_and_cache(k,
v,
cache,
head_first=False)

q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
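
The head_first flag exists because rotary embeddings are now applied while q, k and v are still laid out as (batch, time, head, d_k); the KV cache therefore stays in that time-first layout, and both the cache concatenation axis and the grouped-query head-repeat axis move accordingly, with the transpose to head-first happening only right before the score computation. A standalone sketch of the two layouts (shapes are illustrative, not taken from the commit):

```python
# Illustrative only: how the cache concat axis depends on the layout.
import torch

batch, n_head, time, d_k = 2, 4, 8, 64

# head-first layout (B, H, T, D): the cache grows along dim -2
k_hf = torch.rand(batch, n_head, time, d_k)
step_hf = torch.rand(batch, n_head, 1, d_k)
print(torch.cat([k_hf, step_hf], dim=-2).shape)   # torch.Size([2, 4, 9, 64])

# time-first layout (B, T, H, D): the cache grows along dim -3
k_tf = torch.rand(batch, time, n_head, d_k)
step_tf = torch.rand(batch, 1, n_head, d_k)
print(torch.cat([k_tf, step_tf], dim=-3).shape)   # torch.Size([2, 9, 4, 64])
```
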
2 changes: 2 additions & 0 deletions wenet/transformer/decoder.py
@@ -286,6 +286,8 @@ def tie_or_clone_weights(self, jit_mode: bool = True):
rank = int(os.environ.get('RANK', 0))
if not self.use_output_layer:
return
if not self.tie_word_embedding:
return
if jit_mode:
if rank == 0:
logging.info("clone emb.weight to output.weight")
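
The added guard simply skips this step when tie_word_embedding is disabled. For context, a rough sketch of what tying amounts to in plain PyTorch (generic modules, not wenet's exact code; the jit/export branch in the surrounding method clones the values instead of aliasing the parameter, per its log messages):

```python
# Rough illustration of input/output weight tying (assumed generic modules).
import torch

vocab_size, hidden = 32000, 256
embed = torch.nn.Embedding(vocab_size, hidden)
output_layer = torch.nn.Linear(hidden, vocab_size, bias=False)

# tie: both modules share one Parameter, so gradients update a single matrix
output_layer.weight = embed.weight
assert output_layer.weight.data_ptr() == embed.weight.data_ptr()
```
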
10 changes: 6 additions & 4 deletions wenet/transformer/embedding.py
@@ -205,13 +205,15 @@ def __init__(self,
head_dim: int,
dropout_rate: float,
max_len: int = 1500,
rope_theta=10000.0):
rope_theta=10000.0,
scale: bool = True):
super().__init__(d_model, dropout_rate=dropout_rate, max_len=max_len)
delattr(self, 'pe')
self.max_len = max_len * 2
pe = precompute_freqs_cis(head_dim, self.max_len, rope_theta)
self.register_buffer("pe", torch.view_as_real(pe.unsqueeze(0)))
self.dropout_rate = dropout_rate
self.scale = scale

def forward(
self,
@@ -220,10 +222,10 @@
torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:

pos_emb = self.position_encoding(offset, x.size(1), True)
pos_emb = pos_emb.unsqueeze(1) # [1, 1, seq, head_dim//2]
pos_emb = pos_emb.unsqueeze(2) # [1,seq, 1, head_dim//2]
# NOTE(Mddct): some model don't scale
# TODO(Mddct): fix
x = x * self.xscale
if self.scale:
x = x * self.xscale
return self.dropout(x), pos_emb

def position_encoding(self,
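
Two things change in the RoPE embedding: pos_emb is now unsqueezed on dim 2 so it broadcasts against the time-first (batch, time, head, head_dim) query/key layout, and the new scale flag lets a model opt out of the sqrt(d_model) embedding scaling. A small sketch of what the flag controls (assuming xscale == sqrt(d_model), as in wenet's PositionalEncoding base class):

```python
# Sketch of the `scale` switch; sizes are illustrative.
import math
import torch

d_model = 256
x = torch.rand(2, 16, d_model)   # token embeddings
xscale = math.sqrt(d_model)

x_scaled = x * xscale            # scale=True: classic Transformer embedding scaling
x_unscaled = x                   # scale=False: checkpoints that fold the scaling
                                 # into the embedding weights themselves
```
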
13 changes: 11 additions & 2 deletions wenet/transformer/encoder_layer.py
@@ -15,6 +15,7 @@
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from functools import partial
from typing import Optional, Tuple

import torch
@@ -49,14 +50,22 @@ def __init__(
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
rms_norm_offset: bool = True,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
assert layer_norm_type in ['layer_norm', 'rms_norm']
self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)

norm_class = WENET_NORM_CLASSES[layer_norm_type]
if layer_norm_type == "rms_norm":
norm_class = partial(
norm_class,
add_unit_offset=rms_norm_offset,
)
self.norm1 = norm_class(size, eps=norm_eps)
self.norm2 = norm_class(size, eps=norm_eps)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
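
When layer_norm_type is 'rms_norm', the partial just pins the extra add_unit_offset argument before norm1 and norm2 are built. A sketch of the resolved call (assuming the class exported for 'rms_norm' is RMSNorm from wenet/transformer/norm.py, shown below):

```python
# Sketch: what norm_class resolves to for rms_norm (class name is an assumption).
from functools import partial

from wenet.transformer.norm import RMSNorm

norm_class = partial(RMSNorm, add_unit_offset=False)
norm1 = norm_class(256, eps=1e-5)   # same as RMSNorm(256, eps=1e-5, add_unit_offset=False)
```
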
7 changes: 6 additions & 1 deletion wenet/transformer/norm.py
@@ -9,14 +9,19 @@ def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
self.add_unit_offset = add_unit_offset

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
x = self._norm(x.float()).type_as(x)
return x * self.weight
if self.add_unit_offset:
return x * (1 + self.weight)
else:
return x * self.weight
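
The unit offset matters for checkpoints (Gemma-style, for example) that store the RMSNorm scale as (gamma - 1): the forward then multiplies by (1 + weight), whereas classic RMSNorm multiplies by the weight directly. A standalone numerical sketch of the switch (not wenet code):

```python
# Standalone sketch of the add_unit_offset switch.
import torch

def rms_norm(x, weight, add_unit_offset, eps=1e-6):
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    normed = normed.type_as(x)
    return normed * (1 + weight) if add_unit_offset else normed * weight

x = torch.rand(2, 8)
w = torch.zeros(8)                            # a checkpoint storing (gamma - 1)
y1 = rms_norm(x, w, add_unit_offset=True)     # behaves like gamma == 1
y2 = rms_norm(x, w + 1.0, add_unit_offset=False)
print(torch.allclose(y1, y2))                 # True: the two conventions agree
```
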