There seems to be an issue with the input parameters in sandbox.py. When I run it, I get the following error:
Traceback (most recent call last):
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/sandbox.py", line 163, in <module>
    prediction = model(src, tgt, src_mask, tgt_mask)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/transformer_timeseries.py", line 226, in forward
    decoder_output = self.decoder(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 369, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 716, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 725, in _sa_block
    x = self.self_attn(x, x, x,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/functional.py", line 5251, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([48, 48]), but should be (128, 128).
(victor_env) (base) anshumansinha@Anshumans-MacBook-Pro-3 influenza_transformer-main %
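One possible explanation, not confirmed against sandbox.py or transformer_timeseries.py: PyTorch's decoder layers default to `batch_first=False`, so `nn.MultiheadAttention` reads dimension 0 of `tgt` as the sequence length. If `tgt` is actually batch-first with shape (128, 48, d_model), a 48 x 48 causal mask gets checked against 128 x 128, which matches the message above. The sketch below reproduces that mismatch under those assumptions; the shapes, `D_MODEL`, `N_HEADS`, and the `make_decoder` helper are hypothetical and only chosen to mirror the traceback.

```python
import torch
import torch.nn as nn

# Hypothetical shapes mirroring the traceback: 128 sequences per batch,
# 48 target time steps, and a small model dimension (assumed, not from the repo).
BATCH, SEQ_LEN, D_MODEL, N_HEADS = 128, 48, 16, 4

memory = torch.randn(BATCH, SEQ_LEN, D_MODEL)  # stand-in for encoder output, batch-first
tgt = torch.randn(BATCH, SEQ_LEN, D_MODEL)     # decoder input, batch-first

# Causal (look-ahead) mask sized for the target length: (SEQ_LEN, SEQ_LEN) = (48, 48).
tgt_mask = torch.triu(torch.full((SEQ_LEN, SEQ_LEN), float("-inf")), diagonal=1)


def make_decoder(batch_first: bool) -> nn.TransformerDecoder:
    # Hypothetical helper: a one-layer decoder with the requested tensor layout.
    layer = nn.TransformerDecoderLayer(d_model=D_MODEL, nhead=N_HEADS, batch_first=batch_first)
    return nn.TransformerDecoder(layer, num_layers=1)


# With batch_first=False (the default), dim 0 of tgt (128) is treated as the
# sequence length, so the 48 x 48 mask is rejected with a RuntimeError.
try:
    make_decoder(batch_first=False)(tgt, memory, tgt_mask=tgt_mask)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

# Option 1: declare the layers batch-first so dim 1 (48) is the sequence length.
out_batch_first = make_decoder(batch_first=True)(tgt, memory, tgt_mask=tgt_mask)

# Option 2: keep batch_first=False and pass (seq_len, batch, d_model) tensors instead.
out_seq_first = make_decoder(batch_first=False)(
    tgt.transpose(0, 1), memory.transpose(0, 1), tgt_mask=tgt_mask
)

print(out_batch_first.shape)  # torch.Size([128, 48, 16])
print(out_seq_first.shape)    # torch.Size([48, 128, 16])
```

If this is the cause, either option should make the mask and the input layout agree; which one fits depends on how the model in transformer_timeseries.py constructs its layers and how sandbox.py shapes `src` and `tgt`, which the traceback alone does not show.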