
Commit 0197303

committed Jul 4, 2024
feat(asr): add example ASR script and setup configuration
1 parent d4fb66f commit 0197303

3 files changed: +101 / -32 lines
 

‎example_asr.py

+42 lines (new file)
from mmlm.model import MMLM
from mmlm.utility import MMLMUtility
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, AutoModel

# Backbone language model and its tokenizer
lm_model = AutoModelForCausalLM.from_pretrained('voidful/phi-1_5_chat_128k')
lm_tokenizer = AutoTokenizer.from_pretrained('voidful/phi-1_5_chat_128k')
# Speech encoder (loaded here but not passed to MMLM; this example trains on discrete speech tokens)
audio_model = AutoModel.from_pretrained('ntu-spml/distilhubert')

# MMLM wrapper around the LM; audio_config=8 sets the number of discrete audio token scales
mmlm = MMLM('voidful/phi-1_5_chat_128k', lm_model=lm_model, lm_tokenizer=lm_tokenizer, audio_config=8)
mmlu = MMLMUtility(mmlm)

# ASR dataset whose audio is already encoded as discrete speech-tokenizer units
dataset = load_dataset("voidful/cv_13_tw_speech_tokenizer_asr")

# Tokenize one example at a time with the MMLM utility
tokenized_datasets = dataset.map(mmlu.tokenize_function, batched=False)

# Data collator from the MMLM utility helpers
dc = mmlu.MMLMDataCollator(mmlm.tokenizer)

# The tokenizer has no pad token by default, so reuse EOS for padding
mmlm.tokenizer.pad_token = mmlm.tokenizer.eos_token
training_args = TrainingArguments(
    output_dir='./results_asr',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    logging_steps=1,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir='./logs',
)

# Initialize the Trainer
trainer = Trainer(
    model=mmlm,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=lm_tokenizer,
    data_collator=dc
)

trainer.train()
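After trainer.train() finishes, the generate method added to MMLM in mmlm/model.py (a greedy decoding loop with signature generate(input_ids, audio_feature=None, max_length=50)) could be used for a quick smoke test. A minimal sketch, assuming the objects from the script above are still in scope and that generate returns the accumulated token-id tensor; the prompt string is purely illustrative and not part of the commit:

# Hypothetical follow-up, not part of the commit: greedy decoding with the trained model.
prompt_ids = lm_tokenizer("transcribe:", return_tensors="pt").input_ids.to(mmlm.device)
output_ids = mmlm.generate(prompt_ids, audio_feature=None, max_length=50)  # greedy loop defined in mmlm/model.py
print(lm_tokenizer.decode(output_ids[0], skip_special_tokens=True))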

‎mmlm/model.py

+35 / -32 lines
@@ -11,16 +11,16 @@

 class MMLM(nn.Module):
     def __init__(
-        self,
-        lm_config,
-        lm_model=None,
-        lm_tokenizer=None,
-        audio_config=1,
-        audio_model=None,
-        audio_adapter_config=None,
-        visual_config=1,
-        visual_model=None,
-        visual_adapter_config=None,
+            self,
+            lm_config,
+            lm_model=None,
+            lm_tokenizer=None,
+            audio_config=1,
+            audio_model=None,
+            audio_adapter_config=None,
+            visual_config=1,
+            visual_model=None,
+            visual_adapter_config=None,
     ):
         super().__init__()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -86,26 +86,25 @@ def _setup_continuous_feature_processing(self, config, model, adapter_config, mo
         )

     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        audio_features=None,
-        vision_features=None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+            self,
+            input_ids: torch.LongTensor = None,
+            audio_features=None,
+            vision_features=None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:

         output_attentions = output_attentions if output_attentions is not None else self.lm_model.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.lm_model.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.lm_model.config.use_return_dict
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)

         if inputs_embeds is None:
             embeder = self.lm_model.get_input_embeddings()

@@ -130,16 +129,19 @@ def forward(
                     text_ids.append(i)
             if len(audio_discrete_token) > 0:
                 audio_discrete_token = audio_discrete_token[
-                    :len(audio_discrete_token) // self.audio_config * self.audio_config]
+                    :len(audio_discrete_token) // self.audio_config * self.audio_config
+                ]
                 discrete_audio_input_id = torch.tensor(audio_discrete_token, dtype=torch.long).view(
-                    self.audio_config, -1)
+                    self.audio_config, -1
+                )
                 discrete_audio_input_ids = []
                 for i in range(self.audio_config):
                     input_scale = embeder(discrete_audio_input_id[i, :].to(self.device))
                     discrete_audio_input_ids.append(input_scale)
                 weighted_discrete_inputs_embeds = torch.mul(
                     torch.stack(discrete_audio_input_ids, dim=0).to(self.device),
-                    F.softmax(self.audio_learnable_weight, dim=0).to(self.device))
+                    F.softmax(self.audio_learnable_weight, dim=0).to(self.device)
+                )
                 weighted_discrete_inputs_embeds = torch.sum(weighted_discrete_inputs_embeds, dim=0)
                 if discrete_audio_input_ids:
                     input_embeds.append(weighted_discrete_inputs_embeds)

@@ -152,7 +154,8 @@ def forward(
                     discrete_visual_input_ids.append(input_scale)
                 weighted_discrete_inputs_embeds = torch.mul(
                     torch.stack(discrete_visual_input_ids, dim=0).to(self.device),
-                    F.softmax(self.visual_learnable_weight, dim=0).to(self.device))
+                    F.softmax(self.visual_learnable_weight, dim=0).to(self.device)
+                )
                 weighted_discrete_inputs_embeds = torch.sum(weighted_discrete_inputs_embeds, dim=0)
                 if discrete_visual_input_ids:
                     input_embeds.append(weighted_discrete_inputs_embeds)

@@ -181,7 +184,7 @@ def forward(
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
             )
-        elif self.audio_config or self.visual_config: # repack input_embeds
+        elif self.audio_config or self.visual_config:
             for batch_num, batch_input in enumerate(input_ids):
                 vision_features_id = 0
                 audio_features_id = 0

@@ -190,13 +193,14 @@ def forward(
                         audio_feature = self.audio_adapter(audio_features[batch_num][audio_features_id]).to(self.device)
                         audio_features_id += 1
                         inputs_embeds = torch.cat(
-                            (inputs_embeds[:, :pos, :], audio_feature, inputs_embeds[:, pos + 1:, :]), dim=1).to(
-                            self.device)
+                            (inputs_embeds[:, :pos, :], audio_feature, inputs_embeds[:, pos + 1:, :]), dim=1
+                        ).to(self.device)
                     if self.continue_visual_feature_type_ids[0] < ids < self.continue_visual_feature_type_ids[1]:
                         vision_features = self.visual_adapter(vision_features[batch_num][vision_features_id])
                         vision_features_id += 1
                         inputs_embeds = torch.cat(
-                            (inputs_embeds[:, :pos, :], vision_features, inputs_embeds[:, pos + 1:, :]), dim=1)
+                            (inputs_embeds[:, :pos, :], vision_features, inputs_embeds[:, pos + 1:, :]), dim=1
+                        )
         outputs = self.lm_model(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,

@@ -234,7 +238,6 @@ def generate(self, input_ids, audio_feature=None, max_length=50):
         for _ in range(max_length):
             outputs = self.forward(input_ids=generated, audio_features=audio_feature)
             next_token_logits = outputs.logits[:, -1, :]
-            next_token_logits = next_token_logits
             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
             generated = torch.cat((generated, next_token), dim=-1)
             if next_token.item() == self.tokenizer.eos_token_id:
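Several of the reformatted hunks above touch one pattern: the audio_config discrete-token scales are embedded separately, stacked, and mixed with softmax-normalized learnable weights. A standalone sketch of just that mixing step, with illustrative tensor sizes and a stand-in parameter (the shapes and names below are assumptions for illustration, not taken from the commit):

import torch
import torch.nn.functional as F

audio_config, seq_len, hidden = 8, 16, 2048                                    # illustrative sizes only
scales = torch.randn(audio_config, seq_len, hidden)                            # stand-in for the stacked per-scale embeddings
audio_learnable_weight = torch.nn.Parameter(torch.zeros(audio_config, 1, 1))   # stand-in for the module parameter
mixed = torch.mul(scales, F.softmax(audio_learnable_weight, dim=0))            # broadcast weights across tokens and channels
mixed = torch.sum(mixed, dim=0)                                                # (seq_len, hidden): one embedding per position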

‎setup.py

+24 lines (new file)
from setuptools import setup, find_packages

# Read runtime dependencies, skipping any direct (VCS/URL) requirements
with open('requirements.txt') as f:
    required = f.read().splitlines()
required = [i for i in required if "@" not in i]

setup(
    name='mmlm',
    version='0.0.1',
    description='',
    url='https://github.com/voidful/MMLM',
    author='Voidful',
    author_email='voidful.stack@gmail.com',
    long_description=open("README.md", encoding="utf8").read(),
    long_description_content_type="text/markdown",
    classifiers=[
        'Development Status :: 4 - Beta',
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python"
    ],
    packages=find_packages(),
    install_requires=required,
    zip_safe=False,
)
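With this configuration in place, the package can presumably be installed from the repository root in editable mode (for example, pip install -e .) so that the mmlm imports in example_asr.py resolve; the '@' filter above simply drops direct VCS/URL entries from requirements.txt before they reach install_requires.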
