[WIP][wenet/LLM] support LLMs #2460

Status: Open. Wants to merge 43 commits into base: main.

Changes from all commits (43 commits):
8cdc9ed  add casual model (Mddct, Apr 7, 2024)
8559f93  fix typo (Mddct, Apr 7, 2024)
9f3dd76  rm ckpt (Mddct, Apr 7, 2024)
9958a55  add topk topp sampler (Mddct, Apr 11, 2024)
1de7240  fix positoin (Mddct, Apr 12, 2024)
a90d336  Merge branch 'main' into Mddct-llm (Mddct, Apr 12, 2024)
6568552  add generate (Mddct, Apr 12, 2024)
984d481  add toto (Mddct, Apr 12, 2024)
b36b3ad  support sft & pretrain training forward (Mddct, Apr 12, 2024)
cc57164  gemm conversion works (Mddct, Apr 13, 2024)
4180661  support init casual model (Mddct, Apr 14, 2024)
3fabb2b  Merge branch 'main' into Mddct-llm (Mddct, Apr 14, 2024)
7bbb2d7  Merge branch 'main' into Mddct-llm (Mddct, Apr 14, 2024)
e6a6d02  all gemma model works (Mddct, Apr 14, 2024)
fbe519f  fix ut (Mddct, Apr 14, 2024)
ed38698  merge main (Mddct, Apr 14, 2024)
50458c3  Merge branch 'main' into Mddct-llm (Mddct, Apr 15, 2024)
acd42c7  merge main (Mddct, Apr 17, 2024)
25f5ef3  fix cache (Mddct, Apr 18, 2024)
135b9c0  Merge branch 'main' into Mddct-llm (Mddct, Apr 18, 2024)
33a55d5  generate works! (Mddct, Apr 18, 2024)
05a2579  unify chat pattern (Mddct, Apr 19, 2024)
126d740  convert llama3 works (Mddct, Apr 19, 2024)
bd6a6e6  merge main (Mddct, Apr 19, 2024)
34eecb2  fix w1 w2 w3 in feedforward (Mddct, Apr 19, 2024)
dabcdf2  add 70b temporarily (Mddct, Apr 20, 2024)
72c0f23  mv LLM to wenet (Mddct, Apr 21, 2024)
e92b207  support llm dataset (Mddct, Apr 21, 2024)
b892c44  unify config (Mddct, Apr 22, 2024)
d01c3ec  add dataset yaml in script (Mddct, Apr 22, 2024)
7f683a9  support llm dataset (Mddct, Apr 23, 2024)
592fb69  dynamic static bucket works (Mddct, Apr 24, 2024)
a8cbf23  merge main (Mddct, Apr 24, 2024)
38330d1  training works (Mddct, Apr 24, 2024)
bfe0628  pretrain works (Mddct, Apr 26, 2024)
79bafa3  refactor covert (Mddct, Apr 27, 2024)
a9a7f7b  fix flash att in generate (Mddct, Apr 28, 2024)
e5e36fc  llama works (Mddct, Apr 28, 2024)
0e81840  fix llama3 (Mddct, Apr 30, 2024)
9805ed6  fix speed (Mddct, Apr 30, 2024)
32853c6  try fix ut (Mddct, Apr 30, 2024)
e81b110  support stop tokens in gen and support ppl (Mddct, Apr 30, 2024)
e246769  support stop tokens in gen and support ppl (Mddct, Apr 30, 2024)
test/wenet/dataset/test_datapipes.py (10 changes: 6 additions & 4 deletions)

@@ -7,9 +7,10 @@
 from wenet.dataset.datapipes import (RepeatDatapipe, SortDataPipe,
                                      WenetRawDatasetSource,
                                      WenetTarShardDatasetSource)
-from wenet.dataset.processor import (DynamicBatchWindow, decode_wav, padding,
-                                     parse_json, compute_fbank,
-                                     detect_language, detect_task)
+from wenet.dataset.processor import (DynamicBatchWindow, decode_wav,
+                                     feats_length_fn, padding, parse_json,
+                                     compute_fbank, detect_language,
+                                     detect_task)


 @pytest.mark.parametrize("data_list", [
@@ -106,7 +107,8 @@ def test_dynamic_batch_datapipe(data_list):
     max_frames_in_batch = 10000
     dataset = dataset.dynamic_batch(
         window_class=DynamicBatchWindow(max_frames_in_batch),
-        wrapper_class=padding)
+        wrapper_class=padding,
+        elem_size_fn=feats_length_fn)

     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=None,
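The new elem_size_fn argument lets dynamic_batch measure each element with a caller-supplied function rather than assuming acoustic feature length. Below is a minimal sketch of such a size function; it assumes each sample is a dict carrying a 'feat' tensor of shape (frames, dim), and the real feats_length_fn in wenet.dataset.processor may differ in detail.

import torch

def example_feats_length_fn(sample: dict) -> int:
    # Hypothetical size function: the number of feature frames decides how
    # much of the max_frames_in_batch budget this sample consumes.
    return sample['feat'].size(0)

# Quick check with a fake 100-frame, 80-dim fbank sample.
assert example_feats_length_fn({'feat': torch.zeros(100, 80)}) == 100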
wenet/LLM/causal_model.py (new file, 208 additions)

@@ -0,0 +1,208 @@
from typing import Dict, List, Optional, Union
import torch
from wenet.LLM.decoder import DecoderOnly
from wenet.LLM.sampler import sampler
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.utils.common import IGNORE_ID, th_accuracy
from wenet.utils.mask import make_pad_mask, subsequent_mask


class CausalLM(torch.nn.Module):

def __init__(
self,
vocab_size: int,
decoder: DecoderOnly,
special_tokens: dict,
tie_word_embedding: bool = False,
linear_bias: bool = False,
ignore_id: int = IGNORE_ID,
lsm_weight: float = 0.0,
length_normalized_loss: bool = False,
) -> None:
super().__init__()
del special_tokens

self.embed = torch.nn.Embedding(vocab_size, decoder.hidden_size)
self.out = torch.nn.Linear(decoder.hidden_size,
vocab_size,
bias=linear_bias)

self.decoder = decoder
self.vocab_size = vocab_size
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
self.tie_word_embedding = tie_word_embedding
self.ignore_id = ignore_id

@torch.jit.unused
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
""" Forward for training
"""
text = batch['feats'].to(device)
target = batch['target'].to(device)
text_length = batch['feats_lengths'].to(device)

mask = ~make_pad_mask(text_length, max_len=text.size(1)).unsqueeze(
1) # (B,1,L)
causal_mask = subsequent_mask(
mask.size(-1), device=mask.device).unsqueeze(0) # (1,L,L)
att_mask = causal_mask & mask # (B, L, L)

embeding = self.embed(text)
decoder_out = self.out(self.decoder(embeding,
att_mask)[0]) # (B, L, vocab_size)
loss = self.criterion_att(decoder_out, target)
acc = th_accuracy(decoder_out.view(-1, self.vocab_size),
target,
ignore_label=self.ignore_id)

return {
"loss": loss,
"ppl": torch.exp(loss.detach()),
"th_accuracy": acc
}

def tie_or_clone_weights(self, jit_mode: bool):
if not self.tie_word_embedding:
return
if jit_mode:
self.out.weight = torch.nn.Parameter(self.embed.weight.clone())
else:
self.out.weight = self.embed.weight
        # TODO(Mddct): decide whether to handle the bias term for other LLM models

@torch.jit.unused
@torch.inference_mode()
def generate(
self,
prompts_tokens: List[List[int]],
device: torch.device,
stop_tokens: List[int],
dtype: torch.dtype = torch.float32,
output_len: int = 100,
temperature: Union[float, None] = 0.95,
top_p: float = 1.0,
top_k: int = 100,
) -> List[List[int]]:
"""Generates responses for given prompts using Gemma model."""
# If a single prompt is provided, treat it as a batch of 1.
batch_size = len(prompts_tokens)
min_prompt_len = min(len(p) for p in prompts_tokens)
max_prompt_len = max(len(p) for p in prompts_tokens)
max_seq_len = max_prompt_len + output_len
assert max_seq_len <= self.decoder.pos_enc.max_len

# build KV caches
kv_caches = []
for _ in range(len(self.decoder.decoders)):
size = (batch_size, 0, self.decoder.n_kv_head,
self.decoder.head_dim)
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))

# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
IGNORE_ID,
dtype=torch.int64,
device=device)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
IGNORE_ID,
dtype=torch.int64,
device=device)
# right padding
for i, p in enumerate(prompts_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])

prompt_mask_tensor = token_ids_tensor != IGNORE_ID
input_positions_tensor = torch.arange(0,
min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.ones((1, 1, max_seq_len, max_seq_len),
dtype=torch.bool)
mask_tensor = torch.tril(mask_tensor).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
att_mask = curr_mask_tensor.squeeze(
1)[:, :min_prompt_len, :min_prompt_len]
output_positions_tensor = torch.LongTensor([min_prompt_len - 1
]).to(device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len,
dtype=torch.int64).to(device)

input_token_embeding = self.embed(input_token_ids_tensor)
offset = torch.tensor([0] * len(prompts_tokens)).to(device)
input_offset = offset

stop_tokens_tensor = torch.tensor(stop_tokens, device=device)
# Prefill up to min_prompt_len tokens, then treat other prefill as
# decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
decoder_out, kv_caches, = self.decoder(
input_token_embeding,
att_mask,
input_offset,
kv_caches,
)
decoder_out = self.out(decoder_out)
decoder_out = decoder_out.index_select(1, output_positions_tensor)
next_token_ids = sampler(
decoder_out,
temperatures_tensor,
top_ps_tensor,
top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)

input_token_ids_tensor = output_token_ids
input_token_embeding = self.embed(input_token_ids_tensor)

input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(
2, input_positions_tensor)
att_mask = curr_mask_tensor.squeeze(1)[:, :output_index +
1, :output_index + 1]

output_positions_tensor = torch.tensor(
0, dtype=torch.int64).to(device)
input_offset = offset + output_index.unsqueeze(-1)
output_index = output_index + 1

if all(torch.isin(next_token_ids, stop_tokens_tensor)):
break

token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompts_tokens[i]
):len(prompts_tokens[i]) + output_len]
for stop_token in stop_tokens:
try:
eos_index = trimmed_output.index(stop_token)
trimmed_output = trimmed_output[:eos_index]
break
except Exception:
continue
results.append(trimmed_output)

return results
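As a usage illustration of the generate API above, here is a minimal sketch. The configuration values, prompt token ids, and stop-token id are placeholders rather than values from this PR; it assumes a tokenizer has already produced the prompt ids, and whether it runs unmodified depends on the surrounding wenet revision (real Gemma/Llama weights would normally come through the PR's conversion path).

import torch
from wenet.LLM.decoder import DecoderOnly
from wenet.LLM.causal_model import CausalLM

# Small, illustrative configuration only.
decoder = DecoderOnly(n_kv_head=2, head_dim=64, hidden_size=256,
                      attention_heads=4, linear_units=1024, num_blocks=2)
model = CausalLM(vocab_size=32000, decoder=decoder, special_tokens={})
model.eval()

device = torch.device('cpu')
prompts = [[1, 15, 27, 9], [1, 42]]   # token ids from some tokenizer
results = model.generate(prompts,
                         device=device,
                         stop_tokens=[2],   # placeholder EOS id
                         output_len=16,
                         temperature=0.95,
                         top_p=0.9,
                         top_k=50)
print(results)   # one list of generated token ids per prompt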
wenet/LLM/decoder.py (new file, 161 additions)

@@ -0,0 +1,161 @@
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint as ckpt
from wenet.transformer.attention import T_CACHE

from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.utils.class_utils import (WENET_ACTIVATION_CLASSES,
WENET_ATTENTION_CLASSES,
WENET_EMB_CLASSES, WENET_MLP_CLASSES,
WENET_NORM_CLASSES)
from wenet.utils.common import mask_to_bias


class DecoderOnly(torch.nn.Module):

def __init__(
self,
n_kv_head: int,
head_dim: int,
hidden_size: int,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
normalize_before: bool = True,
query_bias: bool = False,
key_bias: bool = False,
value_bias: bool = False,
mlp_bias: bool = False,
activation_type: str = "gelu",
gelu_approximate: Union[str, None] = None,
max_position_embeding: int = 8192,
mlp_type: str = 'gated',
layer_norm_type: str = 'rms_norm',
norm_eps: float = 1e-5,
rms_norm_offset: bool = True,
selfattention_layer_type: str = "rope_abs_selfattn",
use_sdpa: bool = False,
gradient_checkpointing: bool = False,
rope_theta: float = 10000.0,
rope_style: str = 'google',
scale_embed: bool = True,
) -> None:
super().__init__()

assert selfattention_layer_type in ['rope_abs_selfattn']
self.pos_enc = WENET_EMB_CLASSES["rope_pos"](
hidden_size,
head_dim,
max_len=max_position_embeding,
dropout_rate=positional_dropout_rate,
rope_theta=rope_theta,
scale=scale_embed)
if activation_type == "gelu" and gelu_approximate is not None:
activation = WENET_ACTIVATION_CLASSES['gelu'](
approximate=gelu_approximate)
else:
activation = WENET_ACTIVATION_CLASSES[activation_type]()

mlp_class = WENET_MLP_CLASSES[mlp_type]
self.num_blocks = num_blocks
# TODO: support lora & refactor lora
self.decoders = torch.nn.ModuleList([
TransformerEncoderLayer(
hidden_size,
WENET_ATTENTION_CLASSES[selfattention_layer_type](
attention_heads,
hidden_size,
attention_dropout_rate,
query_bias,
key_bias,
value_bias,
use_sdpa,
n_kv_head,
head_dim,
style=rope_style),
mlp_class(hidden_size, linear_units, dropout_rate, activation,
mlp_bias),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
rms_norm_offset=rms_norm_offset,
) for _ in range(self.num_blocks)
])
self.pre_norm = normalize_before
self.final_norm: Optional[torch.nn.Module] = None
if self.pre_norm:
norm_class = WENET_NORM_CLASSES[layer_norm_type]
if layer_norm_type == "rms_norm":
norm_class = partial(
norm_class,
add_unit_offset=rms_norm_offset,
)
self.final_norm = norm_class(hidden_size, eps=norm_eps)

self.n_kv_head = n_kv_head
self.head_dim = head_dim
self._hidden_size = hidden_size
self.use_sdpa = use_sdpa
self.gradient_checkpointing = gradient_checkpointing

def forward(
self,
input: torch.Tensor,
att_mask: torch.Tensor,
input_position: Union[int, torch.Tensor] = 0,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
xs, pos_emb = self.pos_enc(input, offset=input_position)
if self.use_sdpa:
att_mask = mask_to_bias(att_mask, xs.dtype)

if self.gradient_checkpointing and self.training:
xs = self.forward_layers_checkpointed(xs, att_mask, pos_emb)
else:
xs, kv_caches = self.forward_layers(xs, att_mask, pos_emb,
kv_caches)
if self.pre_norm and self.final_norm is not None:
xs = self.final_norm(xs)
return xs, kv_caches

def forward_layers(
self,
xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor,
kv_caches: Optional[List[T_CACHE]] = None,
) -> Tuple[torch.Tensor, Union[List[T_CACHE], None]]:
if self.training:
for (i, layer) in enumerate(self.decoders):
xs, _, _, _ = layer(xs, att_mask, pos_emb)
new_kv_caches = kv_caches
else:
assert kv_caches is not None
new_kv_caches = []
for (i, layer) in enumerate(self.decoders):
xs, _, new_kv_cache, _ = layer(xs,
att_mask,
pos_emb,
att_cache=(kv_caches[i][0],
kv_caches[i][1]))
new_kv_caches.append(new_kv_cache)

return xs, new_kv_caches

@torch.jit.ignore(drop=True)
def forward_layers_checkpointed(self, xs: torch.Tensor,
att_mask: torch.Tensor,
pos_emb: torch.Tensor) -> torch.Tensor:
for layer in self.decoders:
xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, att_mask,
pos_emb)
return xs

@property
def hidden_size(self):
return self._hidden_size
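To make the training-time contract of DecoderOnly concrete (mirroring how CausalLM.forward drives it above), the sketch below builds the combined padding and causal mask and runs one forward pass without KV caches. The sizes are illustrative placeholders, and running it unmodified depends on the surrounding wenet utilities.

import torch
from wenet.LLM.decoder import DecoderOnly
from wenet.utils.mask import make_pad_mask, subsequent_mask

decoder = DecoderOnly(n_kv_head=2, head_dim=64, hidden_size=256,
                      attention_heads=4, linear_units=1024, num_blocks=2)
decoder.train()

batch, length = 2, 8
xs = torch.randn(batch, length, 256)    # already-embedded token sequence
lengths = torch.tensor([8, 5])          # true length of each sample

pad_mask = ~make_pad_mask(lengths, max_len=length).unsqueeze(1)       # (B, 1, L)
causal_mask = subsequent_mask(length, device=xs.device).unsqueeze(0)  # (1, L, L)
att_mask = causal_mask & pad_mask                                     # (B, L, L)

# Training path: kv_caches stays None; at inference time a list of per-layer
# (k, v) caches is threaded through instead.
out, _ = decoder(xs, att_mask)
print(out.shape)                        # torch.Size([2, 8, 256])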