From 472b119865075a04ceab139eb21b62de1d85ceaa Mon Sep 17 00:00:00 2001 From: Anush Date: Tue, 19 Sep 2023 21:04:14 +0530 Subject: [PATCH] chore: revert 824abe6b9e3fa4c28cdb4bb49b1344ad0aa748e3 --- src/fastembed.ts | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/fastembed.ts b/src/fastembed.ts index da242a1..c5d61f9 100644 --- a/src/fastembed.ts +++ b/src/fastembed.ts @@ -289,7 +289,21 @@ export class FlagEmbedding extends Embedding { output.last_hidden_state.dims as number[] ); - const embeddings = lastHiddenState.map((sentence) => sentence[0]); + const embeddings = lastHiddenState.map((layer, layerIdx) => { + const weightedSum = layer.reduce((acc, tokenEmbedding, idx) => { + const attentionWeight = maskArray[layerIdx][idx]; + return acc.map( + (val, i) => val + tokenEmbedding[i] * Number(attentionWeight) + ); + }, new Array(layer[0].length).fill(0)); + + const inputMaskSum = maskArray[layerIdx].reduce( + (acc, attentionWeight) => acc + Number(attentionWeight), + 0 + ); + + return weightedSum.map((val) => val / (inputMaskSum + 1e-9)); + }); yield embeddings.map(normalize); }