RuntimeError: u must have shape (batch_size, H, L) #26
Comments
Hey, I think the code needs the following changes:

@@ -314,13 +314,13 @@ class HyenaOperator(nn.Module):
uc = self.short_filter(u)[...,:l_filter]
- uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
- z=self.num_blocks,
- ho=self.num_heads,
- v=self.head_dim * (self.order + 1)
- )
+ # uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l',
+ # z=self.num_blocks,
+ # ho=self.num_heads,
+ # v=self.head_dim * (self.order + 1)
+ # )
- *x, v = uc.split(self.d_model, dim=2)
+ *x, v = uc.split(self.d_model, dim=1)
k = self.filter_fn.filter(l_filter)
# `c` is always 1 by default
@@ -339,7 +339,7 @@ class HyenaOperator(nn.Module):
v = self.dropout(v * x_i)
# the bias term is broadcasted. Last dimension (l) is handled by fftconv
- v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None])
+ v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o])
if self.post_order_ffn:
w = self.ord_proj_w[o]
@@ -347,7 +347,10 @@ class HyenaOperator(nn.Module):
rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l')
)
- y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads))
+ y = self.activation(
+ (v * x[0]).transpose(-2, -1),
+ # rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)
+ )
y = self.out_proj(y)
if self.return_state:
@@ -356,4 +359,4 @@ class HyenaOperator(nn.Module):
@property
def d_output(self):
- return self.d_model
\ No newline at end of file
+ return self.d_model
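For context on why the workaround above matters: the fused kernel convolves along the last (length) dimension and checks that its input is exactly 3-D. Below is a generic, minimal FFT-based causal convolution sketch in plain PyTorch, not the repo's fused CUDA kernel; the shapes assumed for u, k and bias are for illustration only.

```python
import torch

def fftconv_ref_sketch(u: torch.Tensor, k: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Causal convolution via FFT. u: (B, H, L), k: (H, L), bias: (H,)."""
    L = u.shape[-1]
    fft_size = 2 * L  # zero-pad so the circular convolution becomes a linear (causal) one
    u_f = torch.fft.rfft(u.float(), n=fft_size)
    k_f = torch.fft.rfft(k.float(), n=fft_size)
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :L]  # convolution acts on the last dim only
    return (y + u * bias.unsqueeze(-1)).to(u.dtype)      # per-channel bias term, broadcast over L

B, H, L = 1, 768, 2048
out = fftconv_ref_sketch(torch.randn(B, H, L), torch.randn(H, L), torch.randn(H))
print(out.shape)  # torch.Size([1, 768, 2048])
```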
Hi, is there any update on multi-head support for fftconv?
The module already supports multi-head - you can find an example in the H3 code: https://github.com/HazyResearch/safari/blob/main/src/models/sequence/h3.py#L160. In H3, the names of the three branches (what Hyena calls x1, x2, and v) are q, k, and v. Passing in head_dim > 1 uses the multi-head code path.
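If it helps, here is a rough usage sketch for the multi-head case. Everything in it is an assumption rather than verified against the current code: the import path and the constructor arguments (d_model, l_max, order, num_heads) are only inferred from the attributes (self.num_heads, self.head_dim, self.num_blocks) referenced in the diff above.

```python
import torch
from src.models.sequence.hyena import HyenaOperator  # hypothetical import path

batch, seq_len, d_model = 1, 2048, 768
# num_heads > 1 is what should exercise the multi-head path (assumed kwarg name)
layer = HyenaOperator(d_model=d_model, l_max=seq_len, order=2, num_heads=8)

u = torch.randn(batch, seq_len, d_model)  # (B, L, d_model) layout
y = layer(u)
print(y.shape)  # expected: torch.Size([1, 2048, 768])
```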
See: HazyResearch/safari#26 (comment)
Signed-off-by: Guy Jacob <guyj@nvidia.com>
* Initial reference code commit, unchanged
* Hyena code changes for NeMo compatibility
* MCore spec override functionality + example config w. hyena
* Additional changes - now working on char-level TinyShakespeare
  * Add missing input LayerNorm to spec (in the default attention spec it's fused with the projection Linear layer, so not explicitly defined)
  * Shape conversion at start and end of Hyena forward
* Add fftconv cuda impl from safari
* Workaround for shape error in fftconv (see: HazyResearch/safari#26 (comment))
* Explicitly convert kernel to FP32 (torch.fft doesn't support bf16)
* Working run configs
* Remove sharded_state_dict from HyenaOperator (made redundant by the default implementation in Megatron)
* Update configs
* Testing TE Linear classes in HyenaOperator
* Revert to FusedDense for in/out projections after merging with 24.01.01
* Fix bug (use fused LNorm+Linear), bring back TE layers
* Configs rename + cleanup
* FlashFFTConv, Multi-head, some cleanup
* Bug fix - init FlashFFTConv with 2*seq_len
* ModuleSpec + replace nn.Conv1d with causal_conv1d
* Remove unneeded arguments
* More cleanup, remove fftconv ref functions
* Refactor HyenaFilter + more cleanup, in the spirit of the implementation in the MAD-Lab repo: https://github.com/athms/mad-lab/blob/main/mad/model/layers/hyena.py
* Add missing attributions
* Remove fftconv sources
* Bug fixes
* Remove d_model from external API, take from TransformerConfig
* Cleanup config
* Remove spec override logic (possibly push separately)
* Add tests
* Keep only megatron_gpt_config_hyena (w. 153m parameters)
* Black + isort formatting changes
* Fixes following PR review
  * Clearer names + more documentation for config params
  * Clearer README
  * Check seq len < 8K with safari-fftconv
  * Avoid 0*bias op during forward
* Fix tests following param name changes

---------
Signed-off-by: Guy Jacob <guyj@nvidia.com>
Hello,
I am trying to run the benchmark here with fused_fft_conv enabled, but I am getting a RuntimeError: u must have shape (batch_size, H, L) error. In this case the shape of u is [1, 1, 768, 1, 2048], but it expects [1, 1, 768]. Normally, fftconv handles the last dimension, but in this case the shape check fails.

Log:
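To make the mismatch concrete (check_u below is a hypothetical stand-in for the fused kernel's shape check, not the actual code): the operator hands the kernel a 5-D (b, h, v, z, l) tensor, while the check only accepts a 3-D (batch_size, H, L) one. Folding the extra dimensions into the channel axis shows how the two layouts line up; this is only an illustration, not the fix applied in the repo.

```python
import torch
from einops import rearrange

u = torch.randn(1, 1, 768, 1, 2048)  # (b, h, v, z, l), the shape from the error report

def check_u(u: torch.Tensor) -> None:
    """Hypothetical stand-in for the fused kernel's input check."""
    if u.ndim != 3:
        raise RuntimeError("u must have shape (batch_size, H, L)")

try:
    check_u(u)                                 # 5-D input reproduces the error
except RuntimeError as e:
    print(e)

u3 = rearrange(u, "b h v z l -> b (h v z) l")  # fold extra dims -> (1, 768, 2048)
check_u(u3)                                    # passes the 3-D shape check
```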