From eca76fcdeedb3b8684e644160d25d309adf16453 Mon Sep 17 00:00:00 2001
From: Guy Jacob
Date: Wed, 24 Jan 2024 11:34:39 +0200
Subject: [PATCH] Workaround for shape error in fftconv

See: https://github.com/HazyResearch/safari/issues/26#issuecomment-1589018138
---
 .../nlp/modules/common/hyena/hyena.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/hyena/hyena.py b/nemo/collections/nlp/modules/common/hyena/hyena.py
index f90ae680db311..4e4711e41eca2 100644
--- a/nemo/collections/nlp/modules/common/hyena/hyena.py
+++ b/nemo/collections/nlp/modules/common/hyena/hyena.py
@@ -345,13 +345,17 @@ def forward(self, u, *args, **kwargs):
         uc = self.short_filter(u)[..., :l_filter]
 
-        uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
-                       z=self.num_blocks,
-                       ho=self.num_heads,
-                       v=self.head_dim * (self.order + 1)
-                       )
+        # Workaround for shape error in fftconv, based on:
+        # https://github.com/HazyResearch/safari/issues/26#issuecomment-1589018138
 
-        *x, v = uc.split(self.d_model, dim=2)
+        # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+        #                z=self.num_blocks,
+        #                ho=self.num_heads,
+        #                v=self.head_dim * (self.order + 1)
+        #                )
+
+        # *x, v = uc.split(self.d_model, dim=2)
+        *x, v = uc.split(self.d_model, dim=1)
 
         k = self.filter_fn.filter(l_filter)
 
         # `c` is always 1 by default
@@ -370,7 +374,8 @@ def forward(self, u, *args, **kwargs):
             v = self.dropout(v * x_i)
 
             # the bias term is broadcasted. Last dimension (l) is handled by fftconv
-            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            # v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+            v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
 
             if self.post_order_ffn:
                 w = self.ord_proj_w[o]
@@ -378,7 +383,8 @@ def forward(self, u, *args, **kwargs):
                     rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'),
                     rearrange(v, 'b h v z l -> b h 1 v z l')
                 )
-        y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        # y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+        y = self.activation((v * x[0]).transpose(-2, -1))
         y = self.out_proj(y)
 
         # Convert back to sequence-first for MCore