Skip to content

Commit

Permalink
fix matryoshka norm
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Jan 13, 2025
1 parent 7c3d9d5 commit 1cf0578
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/transformers/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def forward(self, q_reps, p_reps):
if len(self.embedding_matryoshka_dims) > 0:
loss = 0.0
for dim in self.embedding_matryoshka_dims:
reduced_q_reps = q_reps[:, :dim]
reduced_q_reps = q_reps[:, :dim].astype("float32")
reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)

reduced_p_reps = p_reps[:, :dim]
reduced_p_reps = p_reps[:, :dim].astype("float32")
reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)

dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)
Expand Down

0 comments on commit 1cf0578

Please sign in to comment.