Skip to content

Commit

Permalink
Remove some useless code.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Feb 6, 2025
1 parent f1059b0 commit 14880e6
Showing 1 changed file with 0 additions and 55 deletions.
55 changes: 0 additions & 55 deletions comfy/ldm/lumina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,25 +352,6 @@ def forward(self, x, c):
return x


class RopeEmbedder:
def __init__(
self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
):
super().__init__()
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

def __call__(self, ids: torch.Tensor):
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = ids[:, :, i:i+1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)


class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
Expand Down Expand Up @@ -481,7 +462,6 @@ def __init__(
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
# self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
Expand Down Expand Up @@ -609,7 +589,6 @@ def patchify_and_embed(

return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis


# def forward(self, x, t, cap_feats, cap_mask):
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
Expand Down Expand Up @@ -638,37 +617,3 @@ def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwar

return -x

@staticmethod
def precompute_freqs_cis(
dim: List[int],
end: List[int],
theta: float = 10000.0,
):
"""
Precompute the frequency tensor for complex exponentials (cis) with
given dimensions.
This function calculates a frequency tensor with complex exponentials
using the given dimension 'dim' and the end index 'end'. The 'theta'
parameter scales the frequencies. The returned tensor contains complex
values in complex64 data type.
Args:
dim (list): Dimension of the frequency tensor.
end (list): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation.
Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex
exponentials.
"""
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)

return freqs_cis

0 comments on commit 14880e6

Please sign in to comment.