Commit 504ff7b
Showing 2 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions src/transformers/modeling_bart.py
@@ -1328,8 +1328,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):

     def __init__(self, num_positions, embedding_dim, padding_idx=None):
         super().__init__(num_positions, embedding_dim)
-        if embedding_dim % 2 != 0:
-            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
         self.weight = self._init_weight(self.weight)

     @staticmethod
@@ -1342,10 +1340,11 @@ def _init_weight(out: nn.Parameter):
         position_enc = np.array(
             [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
         )
-        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # This line breaks for odd n_pos
-        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
+        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
         out.detach_()
-        out.requires_grad = False
         return out

     @torch.no_grad()
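
For reference, a minimal standalone sketch (not part of the commit) of why the fixed dim // 2 split breaks for odd embedding_dim and how the sentinel repairs it. The variable names mirror _init_weight; everything else here is illustrative only:

# Sketch: for odd dim, the even-indexed columns (sin) outnumber the odd-indexed
# columns (cos) by one, so the split point must be ceil(dim / 2), not dim // 2.
import numpy as np
import torch
import torch.nn as nn

n_pos, dim = 4, 5  # odd embedding_dim, matching the updated test case
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
print(position_enc[:, 0::2].shape)  # (4, 3) -> sin features (one extra column when dim is odd)
print(position_enc[:, 1::2].shape)  # (4, 2) -> cos features

out = nn.Parameter(torch.empty(n_pos, dim))
out.requires_grad = False  # disable autograd before the in-place slice assignments
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1  # 3 here, i.e. ceil(dim / 2)
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
print(out.shape)  # torch.Size([4, 5]); the old split at dim // 2 would raise a shape mismatch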
4 changes: 2 additions & 2 deletions tests/test_modeling_bart.py
@@ -620,8 +620,8 @@ def test_positional_emb_cache_logic(self):
         self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())

     def test_odd_embed_dim(self):
-        with self.assertRaises(NotImplementedError):
-            SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
+        # odd embedding_dim is allowed
+        SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)

         # odd num_positions is allowed
         SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
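
A quick sanity check outside the test suite, as a sketch: the import path assumes the source tree at this commit (src/transformers/modeling_bart.py), and the constructor arguments are the same ones used in test_odd_embed_dim.

from transformers.modeling_bart import SinusoidalPositionalEmbedding

# Constructing with an odd embedding_dim no longer raises after this commit.
emb = SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0)
print(emb.weight.shape)          # torch.Size([4, 5])
print(emb.weight.requires_grad)  # False, set inside _init_weight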
