
Question about the log2feats function #28

Open
toyoululu opened this issue Dec 21, 2022 · 2 comments

Comments

@toyoululu

attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev))
The shape of seqs should be (batch_size, seq_len, embedding), so how can (tl, tl) guarantee that batch_size = seq_len?
Also, why is the transpose in seqs = torch.transpose(seqs, 0, 1) needed?
Looking forward to your reply.

@pmixer
Owner

pmixer commented Dec 22, 2022

@toyoululu Two things. First, nothing requires batch_size = seq_len, and (tl, tl) does not provide any such guarantee. Second, the transpose is there because torch's MHA layer expects the time dimension to come first. The confusion most likely comes from unfamiliarity with the multi-head attention (MHA) layer; I recommend watching https://www.bilibili.com/video/BV1J441137V6/
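For later readers, here is a minimal sketch of both points (not the repo's exact code; the tensor sizes and the single-head setting are made up for illustration): the causal mask only relates query positions to key positions, so it is (tl, tl) independently of batch_size, and the transpose exists because torch's nn.MultiheadAttention defaults to batch_first=False, i.e. it expects inputs shaped (seq_len, batch, embed).

import torch
import torch.nn as nn

batch_size, tl, hidden = 4, 50, 64        # hypothetical sizes
dev = torch.device('cpu')

seqs = torch.randn(batch_size, tl, hidden, device=dev)   # (batch, seq_len, embed)

# Causal mask: True marks positions that may NOT be attended to.
# Shape is (seq_len, seq_len); batch_size never enters the mask.
attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=dev))

# nn.MultiheadAttention with batch_first=False (the default) expects
# (seq_len, batch, embed), hence the transpose before the call.
mha = nn.MultiheadAttention(embed_dim=hidden, num_heads=1)
seqs_t = torch.transpose(seqs, 0, 1)                      # (seq_len, batch, embed)
out, _ = mha(seqs_t, seqs_t, seqs_t, attn_mask=attention_mask)
out = torch.transpose(out, 0, 1)                          # back to (batch, seq_len, embed)
print(out.shape)  # torch.Size([4, 50, 64])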

@toyoululu
Author

Thanks for the answer. I went through the API and the code carefully and found nothing wrong; I had simply misunderstood it earlier.
