RoPE implementation details #47

Closed
Forest-Scorpio opened this issue Apr 7, 2022 · 4 comments · Fixed by #49

Forest-Scorpio commented Apr 7, 2022

# RoPE encoding
if self.RoPE:
    pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
    # cos_pos = pos[..., 1::2].repeat(1, 1, 2)
    # sin_pos = pos[..., ::2].repeat(1, 1, 2)
    cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # modified
    sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)  # modified

Hi, is there a problem with your RoPE implementation? According to Su Jianlin's blog post, it should be the modified code above, right?
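The difference between the two calls is easiest to see on a tiny tensor. The sketch below is illustrative only: `half` is a made-up stand-in for `pos[..., 1::2]`, not data from the project.

```python
import torch

# Hypothetical stand-in for the cosine slice pos[..., 1::2] with head_size = 6.
half = torch.tensor([[[10., 20., 30.]]])  # (batch=1, seq_len=1, head_size // 2)

# repeat(1, 1, 2) tiles the whole last dimension end to end.
tiled = half.repeat(1, 1, 2)

# repeat_interleave(2, dim=-1) duplicates each element in place.
interleaved = half.repeat_interleave(2, dim=-1)

print(tiled.flatten().tolist())        # [10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
print(interleaved.flatten().tolist())  # [10.0, 10.0, 20.0, 20.0, 30.0, 30.0]
```

Only the interleaved layout puts the same frequency on both members of each adjacent pair of dimensions.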

@jimme0421
Collaborator

    if self.RoPE:
        pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
        cos_pos = pos[..., 1::2].repeat(1, 1, 2)
        sin_pos = pos[..., ::2].repeat(1, 1, 2)
        qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)

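Take a look at how qw2 is implemented: the negative and positive terms are kept separate. Compared with formula (13) in the blog post, only the order differs; the overall result is the same.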

@Forest-Scorpio
Author

Forest-Scorpio commented Apr 8, 2022

# RoPE encoding
if self.RoPE:
    pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
    cos_pos = pos[..., 1::2].repeat(1, 1, 2)
    sin_pos = pos[..., ::2].repeat(1, 1, 2)
    qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
    qw2 = torch.reshape(qw2, qw.shape)
    qw = qw * cos_pos + qw2 * sin_pos

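After the reshape, doesn't qw2 end up alternating one negative term, one positive term? If the negative and positive terms were instead kept separate, wouldn't the qw on the left-hand side of formula (13) also need its odd and even terms separated, so that every position lines up and the overall inner product is unchanged?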
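The layout question can be checked on a toy tensor. In this sketch the values of `qw` double as index labels, and `freq` is a hypothetical tensor of frequency indices standing in for the cos/sin slices; neither comes from the project.

```python
import torch

# Toy query with head_size = 4; each value equals its own index.
qw = torch.tensor([[[0., 1., 2., 3.]]])  # (batch, seq_len, head_size)

# Same two lines as in the snippet above.
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)  # (1, 1, 2, 2)
qw2 = torch.reshape(qw2, qw.shape)
qw2_layout = qw2.flatten().tolist()
print(qw2_layout)  # [-1.0, 0.0, -3.0, 2.0]

# Which frequency index lands on each slot under the two repetition schemes.
freq = torch.tensor([[[0, 1]]])
tiled_freq = freq.repeat(1, 1, 2).flatten().tolist()
interleaved_freq = freq.repeat_interleave(2, dim=-1).flatten().tolist()
print(tiled_freq)        # [0, 1, 0, 1]
print(interleaved_freq)  # [0, 0, 1, 1]
```

Since qw2 pairs slots (0, 1) and (2, 3), each pair needs a single shared frequency, which is what the interleaved layout provides and the tiled layout does not.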

@jimme0421
Collaborator

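Initial testing confirms that the problem you described does exist.

Thank you for pointing it out. We will fix it after a round of unified testing.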
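One way such a test could look is sketched below; it is not the maintainers' actual test. RoPE should make the q·k score depend only on the relative offset between positions, so two position pairs with the same offset should give the same score. `sinusoidal`, `rope`, and `score` are hypothetical helpers written for this check, and the sin-at-even, cos-at-odd layout is an assumption inferred from the slicing in the snippets above.

```python
import torch

def sinusoidal(positions, dim):
    """Hypothetical helper: sin at even indices, cos at odd indices; shape (len, dim)."""
    inv_freq = torch.pow(10000.0, -torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = positions[:, None].to(torch.float64) * inv_freq[None, :]  # (len, dim // 2)
    return torch.stack([angles.sin(), angles.cos()], dim=-1).reshape(len(positions), dim)

def rope(x, pos, interleave):
    """Apply the rotation from the thread with either repetition scheme."""
    cos_half, sin_half = pos[..., 1::2], pos[..., ::2]
    if interleave:
        cos_pos = cos_half.repeat_interleave(2, dim=-1)
        sin_pos = sin_half.repeat_interleave(2, dim=-1)
    else:
        cos_pos = cos_half.repeat(1, 2)
        sin_pos = sin_half.repeat(1, 2)
    x2 = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape(x.shape)
    return x * cos_pos + x2 * sin_pos

torch.manual_seed(0)
dim = 8
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)

def score(m, n, interleave):
    """Inner product of q rotated to position m and k rotated to position n."""
    pos = sinusoidal(torch.tensor([m, n]), dim)
    qm = rope(q[None, :], pos[:1], interleave)
    kn = rope(k[None, :], pos[1:], interleave)
    return (qm * kn).sum().item()

# Both position pairs have offset 2, so a correct RoPE gives equal scores.
gap_interleaved = abs(score(3, 1, True) - score(7, 5, True))
gap_tiled = abs(score(3, 1, False) - score(7, 5, False))
print("interleaved gap:", gap_interleaved)
print("tiled gap:", gap_tiled)
```

Under these assumptions the interleaved gap should vanish up to floating-point error, while the tiled gap should not.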

jimme0421 added the bug label Apr 8, 2022
@jimme0421
Collaborator

We will fix this bug in the next release and thank you again in the commit message.

xiangking mentioned this issue Apr 9, 2022
xiangking self-assigned this Apr 9, 2022