@@ -1014,18 +1014,9 @@ def forward(
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, AriaModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
@@ -1037,7 +1028,7 @@ def forward(
                 vision_feature_layer=self.config.vision_feature_layer,
             )
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            special_image_mask = self._get_image_mask(
+            special_image_mask = self.get_placeholder_mask(
                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
             )
             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -1048,9 +1039,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=True,
             cache_position=cache_position,
             **kwargs,
         )
@@ -1156,9 +1144,6 @@ def forward(
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
@@ -1223,12 +1208,6 @@ def forward(
         >>> print(generated_texts[1])
         Assistant: The bridge is in San Francisco.
         ```"""
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
@@ -1238,9 +1217,6 @@ def forward(
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
             cache_position=cache_position,
             **kwargs,
         )