From d584d242d652b31159d332057c09340d1f949bc3 Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Thu, 18 Jul 2024 15:18:10 +0800 Subject: [PATCH] support embedding normalization --- angle_emb/angle.py | 6 +++++- tests/test_loadding.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 3235e7b..b9e26cc 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -1550,7 +1550,8 @@ def encode(self, embedding_start: int = 0, embedding_size: Optional[int] = None, device: Optional[Any] = None, - prompt: Optional[str] = None): + prompt: Optional[str] = None, + normalize_embedding: bool = False): """ encode texts. @@ -1563,6 +1564,7 @@ def encode(self, The embeddings from embedding_start to embedding_start+embedding_size will be returned. :param device: Optional[Any]. Default None. :param prompt: Optional[str]. Default None. + :param normalize_embedding: bool. Default False. """ if layer_index != -1 and self.full_backbone is None: self.full_backbone = copy.deepcopy(self.backbone) @@ -1605,6 +1607,8 @@ def encode(self, layer_index=layer_index, embedding_start=embedding_start, embedding_size=embedding_size) + if normalize_embedding: + output = nn.functional.normalize(output, p=2, dim=-1) if to_numpy: return output.float().detach().cpu().numpy() return output diff --git a/tests/test_loadding.py b/tests/test_loadding.py index 46e31c8..4686410 100644 --- a/tests/test_loadding.py +++ b/tests/test_loadding.py @@ -25,3 +25,12 @@ def test_2dmse_loadding(): assert isinstance(vecs, np.ndarray) vecs = angle.encode(['hello world', 'hi theređź‘‹'], layer_index=20, embedding_size=512) assert isinstance(vecs, np.ndarray) + + +def test_normalize_embedding(): + import numpy as np + from angle_emb import AnglE + + angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1') + vecs = angle.encode('hello world', normalize_embedding=True) + assert isinstance(vecs, np.ndarray)