Conversation

@NickCheng0921

torch.compile() seems to fail on the model because the current RoPE implementation re-accesses the position tensor.
Calling forward() after compile() gives the following error:

RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace:
  File "/path/to/modeling_llada.py", line 1431, in forward
    outputs = self.model.forward(
  ...
  File "/path/to/modeling_llada.py", line 397, in get_rotary_embedding
    pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
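
For context, the workaround the error message points at would look roughly like this (a generic sketch, not the change in this PR; model and prompts are placeholders):

import torch

compiled = torch.compile(model, mode="reduce-overhead")  # the mode that runs CUDA graphs

for prompt in prompts:
    # Mark the start of a new iteration so outputs of the previous run may be
    # overwritten by the CUDA graphs runtime without triggering this error.
    torch.compiler.cudagraph_mark_step_begin()
    out = compiled(prompt)
    # Alternatively, keep a copy of anything needed across iterations: saved = out.clone()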

Rewriting get_rotary_embedding to use unsqueeze instead of a view makes the model compilable and results in a 19% speedup on forward() calls for a 3K-token prompt with 512 generated tokens, measured after warmup runs (a sketch of the change is below the hardware note).

  • My hardware is a single RTX 4090 with a Ryzen 7 7700X.
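
For reference, a minimal sketch of what the rewritten lines would look like (the surrounding cache logic in get_rotary_embedding is omitted, and positions is assumed to be the 2D tensor from the trace above):

# Replace the indexing view from the trace,
#   pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
# with explicit unsqueezes to the same (1, 1, seq_len, dim // 2) shape:
pos_sin = positions.sin().unsqueeze(0).unsqueeze(0)
pos_cos = positions.cos().unsqueeze(0).unsqueeze(0)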

There's one potential issue with this implementation, however: torch.allclose() between the sin and cos tensors fails if the RoPE cache of the current implementation is enabled, but passes if the cache is disabled. The deviation between this implementation and the current version is about 1e-4 (atol) and affects roughly 2% of the encodings.

I'm unsure where the deviation is coming from.
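
For concreteness, the comparison behind those numbers is along these lines (a sketch; sin_new and sin_old are placeholders for the sin tensors produced by this implementation and by the current cached one):

# Element-wise comparison of the two RoPE implementations at atol=1e-4.
close = torch.allclose(sin_new, sin_old, atol=1e-4)
frac_off = (~torch.isclose(sin_new, sin_old, atol=1e-4)).float().mean().item()
print(close, frac_off)  # frac_off is the fraction of affected encodings, ~2% with the cache on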

# Per-head dimension and RoPE inverse frequencies.
dim = self.config.d_model // self.config.n_heads
inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
# Position indices crossed with inverse frequencies: shape (seq_len, dim // 2).
seq = torch.arange(seq_len, device=device, dtype=torch.float)
freqs = einsum("i , j -> i j", seq, inv_freq)
@NickCheng0921 (Author)

Using torch.outer here instead of einsum gives about 3% faster forward() calls in my testing, though it doesn't make a big difference.
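
For illustration, a self-contained check that the einsum and outer-product formulations agree (the dim, theta, and sequence-length values below are placeholders, not the actual config):

import torch

# Placeholders standing in for self.config.d_model // self.config.n_heads,
# self.rope_theta, and seq_len.
dim = 4096 // 32
theta = 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
seq = torch.arange(3584, dtype=torch.float)

freqs_einsum = torch.einsum("i , j -> i j", seq, inv_freq)
freqs_outer = torch.outer(seq, inv_freq)  # same (seq_len, dim // 2) result
print(torch.allclose(freqs_einsum, freqs_outer))  # True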
