@@ -1610,6 +1610,40 @@ def forward(
         return inputs_embeds
 
 
+class Phi4MultimodalRotaryEmbedding(nn.Module):
+    def __init__(self, config: Phi4MultimodalConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 PHI4_MULTIMODAL_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -1663,40 +1697,6 @@ def _init_weights(self, module):
             module.sub_img_feature_extensor.data.zero_()
 
 
-class Phi4MultimodalRotaryEmbedding(nn.Module):
-    def __init__(self, config: Phi4MultimodalConfig, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
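For context, a minimal standalone sketch of how the cos/sin pair returned by this module's forward is typically consumed. This is not part of the commit and not the transformers API: the helper names (rotary_cos_sin, rotate_half, apply_rotary) and the head_dim/base values are illustrative assumptions, and it hard-codes the "default" rope_type, where attention_scaling is 1.0.

import torch

def rotary_cos_sin(position_ids, head_dim=64, base=10000.0):
    # Same math as the forward above for the "default" rope_type:
    # inv_freq[i] = base ** (-2i / head_dim); attention_scaling == 1.0.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
    freqs = (inv_freq_expanded @ position_ids[:, None, :].float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq, head_dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    # Rotate the second half of the hidden dims into the first half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # q, k: (batch, num_heads, seq, head_dim); broadcast cos/sin over heads.
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = k = torch.randn(1, 8, 16, 64)
cos, sin = rotary_cos_sin(torch.arange(16)[None])
q_rot, k_rot = apply_rotary(q, k, cos, sin)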