From bdd21098f97f19cd86c7b47eefc867470cc4cb89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=BF=94?= <276699683@qq.com> Date: Sat, 9 Apr 2022 10:10:53 +0800 Subject: [PATCH] =?UTF-8?q?fix(nn):=20=E4=BF=AE=E5=A4=8DRoPE=20repeat?= =?UTF-8?q?=E6=AD=A5=E9=AA=A4=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ark_nlp/nn/layer/global_pointer_block.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ark_nlp/nn/layer/global_pointer_block.py b/ark_nlp/nn/layer/global_pointer_block.py index cac630d..038b2da 100644 --- a/ark_nlp/nn/layer/global_pointer_block.py +++ b/ark_nlp/nn/layer/global_pointer_block.py @@ -75,8 +75,8 @@ def forward(self, inputs, mask=None): # RoPE编码 if self.RoPE: pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs) - cos_pos = pos[..., None, 1::2].repeat(1, 1, 1, 2) - sin_pos = pos[..., None, ::2].repeat(1, 1, 1, 2) + cos_pos = pos[..., None, 1::2].repeat_interleave(2, dim=-1) + sin_pos = pos[..., None, ::2].repeat_interleave(2, dim=-1) qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 4) qw2 = torch.reshape(qw2, qw.shape) qw = qw * cos_pos + qw2 * sin_pos @@ -112,8 +112,8 @@ def forward(self, inputs, mask=None): # RoPE编码 if self.RoPE: pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs) - cos_pos = pos[..., 1::2].repeat(1, 1, 2) - sin_pos = pos[..., ::2].repeat(1, 1, 2) + cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1) + sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1) qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3) qw2 = torch.reshape(qw2, qw.shape) qw = qw * cos_pos + qw2 * sin_pos