Skip to content

Commit

Permalink
Workaround for shape error in fftconv
Browse files Browse the repository at this point in the history
  • Loading branch information
guyjacob committed May 2, 2024
1 parent 2e4b64c commit eca76fc
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions nemo/collections/nlp/modules/common/hyena/hyena.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,17 @@ def forward(self, u, *args, **kwargs):

uc = self.short_filter(u)[..., :l_filter]

uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
z=self.num_blocks,
ho=self.num_heads,
v=self.head_dim * (self.order + 1)
)
# Workaround for shape error in fftconv, based on:
# https://github.com/HazyResearch/safari/issues/26#issuecomment-1589018138

*x, v = uc.split(self.d_model, dim=2)
# uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
# z=self.num_blocks,
# ho=self.num_heads,
# v=self.head_dim * (self.order + 1)
# )

# *x, v = uc.split(self.d_model, dim=2)
*x, v = uc.split(self.d_model, dim=1)
k = self.filter_fn.filter(l_filter)

# `c` is always 1 by default
Expand All @@ -370,15 +374,17 @@ def forward(self, u, *args, **kwargs):
v = self.dropout(v * x_i)

# the bias term is broadcasted. Last dimension (l) is handled by fftconv
v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
# v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])

if self.post_order_ffn:
w = self.ord_proj_w[o]
v = mul_sum(
rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
)

y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
# y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
y = self.activation((v * x[0]).transpose(-2, -1))
y = self.out_proj(y)

# Convert back to sequence-first for MCore
Expand Down

0 comments on commit eca76fc

Please sign in to comment.