
Is there a PyTorch version? #2

Open
nullxjx opened this issue Dec 24, 2021 · 3 comments

Comments


nullxjx commented Dec 24, 2021

As the title asks: is there a PyTorch version?

@nku-shengzheliu
Owner

You can refer to the PyTorch-Paddle operator mapping table to adapt the code yourself. If I have time later, I will convert the corresponding PyTorch model files and upload them here.
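
For orientation, the Paddle ops that appear in the attention snippet later in this thread all have direct PyTorch counterparts; a rough, unofficial correspondence (my reading, not the mapping table itself) is:

# Paddle                                  ->  PyTorch (assumed equivalents)
paddle.matmul(q, k, transpose_y=True)     #  torch.matmul(q, k.transpose(-2, -1))
paddle.multiply(a, b)                     #  torch.mul(a, b), or simply a * b
x.sum(-1).sqrt().unsqueeze(3)             #  x.sum(-1).sqrt().unsqueeze(3)  (same chained tensor API)
paddle.clip(x, min=1e-6)                  #  torch.clamp(x, min=1e-6)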

@nku-shengzheliu
Owner

You can refer to item 2, "Attention computation", in the code change notes; it corresponds to lines 343-348 of swin_transformer.py.

@nku-shengzheliu
Owner

qk = paddle.matmul(q, k, transpose_y=True)  # q·k^T per window/head; shape [bs*num_window, num_heads, 49, 49], i.e. [1*64,4,49,49] -> [1*16,8,49,49] -> [1*4,16,49,49] -> [1*1,32,49,49] across the four stages
q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)  # L2 norm of each query, shape [..., 49, 1]
k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)  # L2 norm of each key, shape [..., 49, 1]
attn = qk / paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)  # cosine similarity; the clip avoids division by zero

These four lines compute the cosine similarity between q_i and k_j, so the resulting attn values already lie in the range [-1, 1].
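
For a PyTorch port, a minimal sketch of the same computation might look like the following (an assumed translation, not code from this repository; cosine_attention is a hypothetical helper name, and q, k are assumed to keep the layout [bs*num_window, num_heads, 49, head_dim]):

import torch

def cosine_attention(q, k, eps=1e-6):
    # raw dot products: [bs*num_window, num_heads, 49, 49]
    qk = torch.matmul(q, k.transpose(-2, -1))
    # per-token L2 norms, kept as [..., 49, 1] so their outer product is [..., 49, 49]
    q_norm = q.norm(dim=-1, keepdim=True)
    k_norm = k.norm(dim=-1, keepdim=True)
    # cosine similarity; clamp the denominator to avoid division by zero
    return qk / torch.matmul(q_norm, k_norm.transpose(-2, -1)).clamp(min=eps)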
