Commit 1aeb2e7

Commit message: style

1 parent 2f18707 commit 1aeb2e7

File tree

1 file changed: +34 -34 lines changed

src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py

Lines changed: 34 additions & 34 deletions
@@ -1610,6 +1610,40 @@ def forward(
         return inputs_embeds
 
 
+class Phi4MultimodalRotaryEmbedding(nn.Module):
+    def __init__(self, config: Phi4MultimodalConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 PHI4_MULTIMODAL_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -1663,40 +1697,6 @@ def _init_weights(self, module):
             module.sub_img_feature_extensor.data.zero_()
 
 
-class Phi4MultimodalRotaryEmbedding(nn.Module):
-    def __init__(self, config: Phi4MultimodalConfig, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
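
For context, the relocated class computes standard rotary position embeddings (RoPE): the inverse frequencies are multiplied by the position ids, the result is duplicated along the last dimension, and cos/sin tables are returned in the input dtype. Below is a minimal standalone sketch of that same math in plain torch, using illustrative values (head_dim=128, rope_theta=10000.0, no rope scaling); these are assumptions for the example, not values taken from this commit.

import torch

# Illustrative hyperparameters; in the model they come from Phi4MultimodalConfig.
head_dim = 128
rope_theta = 10000.0

# Default RoPE init (what rope_init_fn produces for rope_type="default"):
# inv_freq[i] = 1 / theta^(2i / head_dim), one entry per rotated pair of channels.
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

batch_size, seq_len = 2, 16
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

# Same shape gymnastics as the forward above:
# (batch, head_dim/2, 1) @ (batch, 1, seq_len) -> transpose -> (batch, seq_len, head_dim/2)
inv_freq_expanded = inv_freq[None, :, None].float().expand(batch_size, -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

emb = torch.cat((freqs, freqs), dim=-1)  # (batch, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()          # attention_scaling is 1.0 for the default rope type

print(cos.shape, sin.shape)  # torch.Size([2, 16, 128]) torch.Size([2, 16, 128])

The attention layers then use these cos/sin tables to rotate the query and key projections before computing attention scores.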
