fix 3.1 rope init for compile (#1544)
ebsmothers authored Sep 11, 2024
1 parent 377abc0 commit 221031a
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions torchtune/models/llama3_1/_position_embeddings.py
@@ -52,16 +52,11 @@ def __init__(
         self.base = base
         self.max_seq_len = max_seq_len
 
-        self.is_cache_built = False
-
         self.scale_factor = scale_factor
         self.low_freq_factor = low_freq_factor
         self.high_freq_factor = high_freq_factor
         self.old_context_len = old_context_len
-
-    # TODO: delete this once all our recipes are moved off of FSDP1 since we
-    # no longer need to explicitly name our param init method reset_parameters
-    def reset_parameters(self):
+        self.is_cache_built = False
         self.rope_init()
 
     def rope_init(self):
@@ -73,6 +68,12 @@ def rope_init(self):
             self.base
             ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
         )
+
+        # If we're on meta device return early.
+        # We can't apply scaling until freqs is filled with real data
+        if freqs.is_meta:
+            return
+
         theta = self.apply_scaling(
             freqs,
             self.scale_factor,
@@ -146,12 +147,15 @@ def forward(
             - s: sequence length
             - n_h: num heads
             - h_d: head dim
+
+        Raises:
+            RuntimeError: if RoPE cache is not initialized prior to forward call
         """
 
-        # TODO: remove once our distributed recipes are on FSDP2
         if not self.is_cache_built:
-            with torch.device(x.device):
-                self.rope_init()
+            raise RuntimeError(
+                "RoPE cache is not built. Please call rope_init() first."
+            )
 
         # input tensor has shape [b, s, n_h, h_d]
         seq_len = x.size(1)
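For context, below is a minimal sketch of the init flow this change supports. It is an illustration rather than code from the commit: it assumes the Llama3ScaledRoPE class defined in this file (with rope_init() running during construction), and the dims and shapes are arbitrary.

import torch

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

# Deferred init: construct under the meta device, as FSDP2-style recipes do.
# When rope_init() runs during construction, the freqs.is_meta check added in
# this commit makes it return early instead of failing on meta tensors.
with torch.device("meta"):
    rope = Llama3ScaledRoPE(dim=128, max_seq_len=8192)

# Once the module is materialized on a real device (e.g. after to_empty() and
# load_state_dict() in a recipe), build the RoPE cache explicitly before the
# first forward call.
rope.rope_init()

# forward() now uses the prebuilt cache. With this change, calling it before
# rope_init() raises the RuntimeError documented above instead of lazily
# building the cache inside forward.
x = torch.randn(2, 16, 8, 128)  # [b, s, n_h, h_d]
out = rope(x)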
