class MMLM(nn.Module):
    def __init__(
-        self,
-        lm_config,
-        lm_model=None,
-        lm_tokenizer=None,
-        audio_config=1,
-        audio_model=None,
-        audio_adapter_config=None,
-        visual_config=1,
-        visual_model=None,
-        visual_adapter_config=None,
+        self,
+        lm_config,
+        lm_model=None,
+        lm_tokenizer=None,
+        audio_config=1,
+        audio_model=None,
+        audio_adapter_config=None,
+        visual_config=1,
+        visual_model=None,
+        visual_adapter_config=None,
    ):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -86,26 +86,25 @@ def _setup_continuous_feature_processing(self, config, model, adapter_config, mo
        )

    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        audio_features=None,
-        vision_features=None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        audio_features=None,
+        vision_features=None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.lm_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.lm_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.lm_model.config.use_return_dict
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)

        if inputs_embeds is None:
            embeder = self.lm_model.get_input_embeddings()
@@ -130,16 +129,19 @@ def forward(
                text_ids.append(i)
            if len(audio_discrete_token) > 0:
                audio_discrete_token = audio_discrete_token[
-                    :len(audio_discrete_token) // self.audio_config * self.audio_config]
+                    :len(audio_discrete_token) // self.audio_config * self.audio_config
+                ]
                discrete_audio_input_id = torch.tensor(audio_discrete_token, dtype=torch.long).view(
-                    self.audio_config, -1)
+                    self.audio_config, -1
+                )
                discrete_audio_input_ids = []
                for i in range(self.audio_config):
                    input_scale = embeder(discrete_audio_input_id[i, :].to(self.device))
                    discrete_audio_input_ids.append(input_scale)
                weighted_discrete_inputs_embeds = torch.mul(
                    torch.stack(discrete_audio_input_ids, dim=0).to(self.device),
-                    F.softmax(self.audio_learnable_weight, dim=0).to(self.device))
+                    F.softmax(self.audio_learnable_weight, dim=0).to(self.device)
+                )
                weighted_discrete_inputs_embeds = torch.sum(weighted_discrete_inputs_embeds, dim=0)
                if discrete_audio_input_ids:
                    input_embeds.append(weighted_discrete_inputs_embeds)
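
For reference, the discrete-audio branch above is a softmax-weighted mixture over per-codebook embeddings. A minimal standalone sketch, where n_codebooks stands in for self.audio_config and the weight shape (n_codebooks, 1, 1) is an assumption about self.audio_learnable_weight:

    import torch
    import torch.nn.functional as F

    n_codebooks, vocab, dim = 4, 100, 8
    embeder = torch.nn.Embedding(vocab, dim)          # stand-in for the LM input embeddings
    audio_learnable_weight = torch.nn.Parameter(torch.zeros(n_codebooks, 1, 1))

    tokens = torch.randint(0, vocab, (26,)).tolist()                    # raw discrete audio tokens
    tokens = tokens[:len(tokens) // n_codebooks * n_codebooks]          # drop the ragged tail (26 -> 24)
    ids = torch.tensor(tokens, dtype=torch.long).view(n_codebooks, -1)  # (4, 6): one row per codebook

    per_codebook = embeder(ids)                         # (4, 6, dim)
    weights = F.softmax(audio_learnable_weight, dim=0)  # normalized across codebooks
    mixed = torch.sum(per_codebook * weights, dim=0)    # (6, dim): one embedding per step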
@@ -152,7 +154,8 @@ def forward(
                    discrete_visual_input_ids.append(input_scale)
                weighted_discrete_inputs_embeds = torch.mul(
                    torch.stack(discrete_visual_input_ids, dim=0).to(self.device),
-                    F.softmax(self.visual_learnable_weight, dim=0).to(self.device))
+                    F.softmax(self.visual_learnable_weight, dim=0).to(self.device)
+                )
                weighted_discrete_inputs_embeds = torch.sum(weighted_discrete_inputs_embeds, dim=0)
                if discrete_visual_input_ids:
                    input_embeds.append(weighted_discrete_inputs_embeds)
@@ -181,7 +184,7 @@ def forward(
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
-        elif self.audio_config or self.visual_config:  # repack input_embeds
+        elif self.audio_config or self.visual_config:
            for batch_num, batch_input in enumerate(input_ids):
                vision_features_id = 0
                audio_features_id = 0
@@ -190,13 +193,14 @@ def forward(
                    audio_feature = self.audio_adapter(audio_features[batch_num][audio_features_id]).to(self.device)
                    audio_features_id += 1
                    inputs_embeds = torch.cat(
-                        (inputs_embeds[:, :pos, :], audio_feature, inputs_embeds[:, pos + 1:, :]), dim=1).to(
-                        self.device)
+                        (inputs_embeds[:, :pos, :], audio_feature, inputs_embeds[:, pos + 1:, :]), dim=1
+                    ).to(self.device)
                if self.continue_visual_feature_type_ids[0] < ids < self.continue_visual_feature_type_ids[1]:
                    vision_features = self.visual_adapter(vision_features[batch_num][vision_features_id])
                    vision_features_id += 1
                    inputs_embeds = torch.cat(
-                        (inputs_embeds[:, :pos, :], vision_features, inputs_embeds[:, pos + 1:, :]), dim=1)
+                        (inputs_embeds[:, :pos, :], vision_features, inputs_embeds[:, pos + 1:, :]), dim=1
+                    )
            outputs = self.lm_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
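
The continuous branch splices each adapter output into inputs_embeds by replacing a single placeholder position, which changes the sequence length whenever the feature spans more than one frame. A minimal sketch of that torch.cat pattern with illustrative shapes (B, T, D, and pos are hypothetical):

    import torch

    B, T, D = 1, 5, 8
    inputs_embeds = torch.zeros(B, T, D)   # placeholder token sits at position pos
    audio_feature = torch.ones(B, 3, D)    # adapter output spanning 3 frames
    pos = 2

    # Drop the one placeholder embedding, insert the 3-frame feature.
    inputs_embeds = torch.cat(
        (inputs_embeds[:, :pos, :], audio_feature, inputs_embeds[:, pos + 1:, :]), dim=1
    )
    # inputs_embeds.shape is now (1, 7, 8): length grew by 3 - 1 = 2.

Since the length grows, attention_mask (and labels, if given) would presumably need the same adjustment; the hunk above does not show one.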
@@ -234,7 +238,6 @@ def generate(self, input_ids, audio_feature=None, max_length=50):
        for _ in range(max_length):
            outputs = self.forward(input_ids=generated, audio_features=audio_feature)
            next_token_logits = outputs.logits[:, -1, :]
-            next_token_logits = next_token_logits
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            generated = torch.cat((generated, next_token), dim=-1)
            if next_token.item() == self.tokenizer.eos_token_id:
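
Stripped of the multimodal plumbing, generate() is plain greedy decoding: re-run forward on the growing sequence, take the argmax token, stop at EOS. An equivalent sketch, where model and eos_id are placeholders for self.forward and self.tokenizer.eos_token_id:

    import torch

    def greedy_generate(model, input_ids, eos_id, max_length=50):
        generated = input_ids
        for _ in range(max_length):
            logits = model(input_ids=generated).logits  # full re-encode each step; no KV cache is passed
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            generated = torch.cat((generated, next_token), dim=-1)
            if next_token.item() == eos_id:             # .item() assumes batch size 1
                break
        return generated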