From fddbce037fc65220483822fbd317c51f6a5e8a0f Mon Sep 17 00:00:00 2001 From: ir2718 Date: Fri, 18 Oct 2024 00:08:47 +0200 Subject: [PATCH] fix mean pooling --- angle_emb/angle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 77ff58e..99d004b 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -272,8 +272,7 @@ def get_pooling(outputs: torch.Tensor, sequence_lengths = -1 if padding_side == 'left' else inputs["attention_mask"].sum(dim=1) - 1 outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths] elif pooling_strategy == 'avg': - outputs = torch.sum( - outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"]) + outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1) elif pooling_strategy == 'max': outputs, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1) elif pooling_strategy == 'all':