Bugfix: wrong attention mask calculation resulted in wrong embeddings #14496

Merged
10 changes: 7 additions & 3 deletions python/sparknlp/annotator/embeddings/bge_embeddings.py
@@ -21,7 +21,8 @@ class BGEEmbeddings(AnnotatorModel,
HasCaseSensitiveProperties,
HasStorageRef,
HasBatchedAnnotate,
HasMaxSentenceLengthLimit):
HasMaxSentenceLengthLimit,
HasClsTokenProperties):
"""Sentence embeddings using BGE.

BGE, or BAAI General Embeddings, a model that can map any text to a low-dimensional dense
@@ -60,6 +61,8 @@ class BGEEmbeddings(AnnotatorModel,
Max sentence length to process, by default 512
configProtoBytes
ConfigProto from tensorflow, serialized into byte array.
useCLSToken
Whether to use the CLS token for sentence embeddings, by default True

References
----------
@@ -148,6 +151,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.BGEEmbeddings", ja
batchSize=8,
maxSentenceLength=512,
caseSensitive=False,
useCLSToken=True
)

@staticmethod
@@ -171,13 +175,13 @@ def loadSavedModel(folder, spark_session):
return BGEEmbeddings(java_model=jModel)

@staticmethod
def pretrained(name="bge_base", lang="en", remote_loc=None):
def pretrained(name="bge_small_en_v1.5", lang="en", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
----------
name : str, optional
Name of the pretrained model, by default "bge_base"
Name of the pretrained model, by default "bge_small_en_v1.5"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
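Taken together, this file's changes add the `useCLSToken` parameter (default `True`) and move the default pretrained name to `bge_small_en_v1.5`. A minimal usage sketch in Scala; the column names are illustrative, and `setUseCLSToken` is the setter this PR introduces:

```scala
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.BGEEmbeddings
import org.apache.spark.ml.Pipeline

val documentAssembler = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")

// bge_small_en_v1.5 is the new default pretrained name from this diff
val embeddings = BGEEmbeddings
  .pretrained("bge_small_en_v1.5", "en")
  .setInputCols("document")
  .setOutputCol("bge_embeddings")
  .setUseCLSToken(true) // true (default): CLS pooling; false: attention-masked average pooling

val pipeline = new Pipeline().setStages(Array(documentAssembler, embeddings))
```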
27 changes: 27 additions & 0 deletions python/sparknlp/common/properties.py
@@ -67,6 +67,33 @@ def getCaseSensitive(self):
return self.getOrDefault(self.caseSensitive)


class HasClsTokenProperties:
useCLSToken = Param(Params._dummy(),
"useCLSToken",
"Whether to use CLS token for pooling (true) or attention-based average pooling (false)",
typeConverter=TypeConverters.toBoolean)

def setUseCLSToken(self, value):
"""Sets whether to ignore case in tokens for embeddings matching.

Parameters
----------
value : bool
Whether to use CLS token for pooling (true) or attention-based average pooling (false)
"""
return self._set(useCLSToken=value)

def getUseCLSToken(self):
"""Gets whether to use CLS token for pooling (true) or attention-based average pooling (false)

Returns
-------
bool
Whether to use CLS token for pooling (true) or attention-based average pooling (false)
"""
return self.getOrDefault(self.useCLSToken)


class HasClassifierActivationProperties:
activation = Param(Params._dummy(),
"activation",
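The Python trait above presumably mirrors a Scala-side `HasClsTokenProperties`; assuming that trait exists with the same getter/setter names (a hypothetical, since it is not shown in this diff), the round-trip would look like:

```scala
// Hypothetical: assumes BGEEmbeddings mixes in a Scala HasClsTokenProperties
// exposing the same Param name as the Python property above.
val bge = BGEEmbeddings.pretrained("bge_small_en_v1.5", "en")
bge.setUseCLSToken(false)   // opt into attention-based average pooling
assert(!bge.getUseCLSToken) // getter returns the overridden value, not the default
```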
6 changes: 0 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/Albert.scala
@@ -160,11 +160,7 @@ private[johnsnowlabs] class Albert(
segmentTensors.close()
}


case Openvino.name =>



val batchLength = batch.length
val shape = Array(batchLength, maxSentenceLength)
val (tokenTensors, maskTensors) =
@@ -192,8 +188,6 @@ private[johnsnowlabs] class Albert(
throw e
}



case _ =>
val tensors = new TensorResources()

73 changes: 44 additions & 29 deletions src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala
@@ -17,7 +17,7 @@
package com.johnsnowlabs.ml.ai

import ai.onnxruntime.{OnnxTensor, TensorInfo}
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import breeze.linalg.DenseMatrix
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.openvino.OpenvinoWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
@@ -71,12 +71,14 @@ private[johnsnowlabs] class BGE(
* @return
* sentence embeddings
*/
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = {
private def getSentenceEmbedding(
batch: Seq[Array[Int]],
useCLSToken: Boolean): Array[Array[Float]] = {
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength))
val embeddings = detectedEngine match {
case ONNX.name =>
getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength)
getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength, useCLSToken)

case Openvino.name =>
getSentenceEmbeddingFromOv(paddedBatch, maxSentenceLength)
@@ -168,22 +170,17 @@
sentenceEmbeddingsFloatsArray
}



private def getSentenceEmbeddingFromOv(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Array[Float]] = {

batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Array[Float]] = {

val batchLength = batch.length
val shape = Array(batchLength, maxSentenceLength)
val tokenTensors =
new org.intel.openvino.Tensor(shape, batch.flatMap(x => x.map(xx => xx.toLong)).toArray)
val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray
val attentionMask = batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray

val maskTensors = new org.intel.openvino.Tensor(
shape,
attentionMask.flatten)
val maskTensors = new org.intel.openvino.Tensor(shape, attentionMask.flatten)

val segmentTensors = new Tensor(shape, Array.fill(batchLength * maxSentenceLength)(0L))
val inferRequest = openvinoWrapper.get.getCompiledModel().create_infer_request()
@@ -198,7 +195,7 @@ private[johnsnowlabs] class BGE(
val lastHiddenState = inferRequest
.get_tensor("last_hidden_state")
val shape = lastHiddenState.get_shape().map(_.toLong)
val flattenEmbeddings = lastHiddenState
val flattenEmbeddings = lastHiddenState
.data()
val embeddings = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape)
val normalizedEmbeddings = LinAlg.l2Normalize(embeddings)
@@ -215,13 +212,13 @@

}
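The one-character mask change above is the core of the bugfix: `getSentenceEmbedding` pads every sentence with zeros (`padArrayWithZeros`), and the pad id `0` is non-negative, so the old test `x < 0L` marked every padded position as a real token. A small illustration (token ids are made up):

```scala
// [CLS] token token [SEP] pad pad -- ids are illustrative
val padded = Array(101L, 2023L, 2045L, 102L, 0L, 0L)
val oldMask = padded.map(x => if (x < 0L) 0L else 1L)  // 1, 1, 1, 1, 1, 1 -- padding counted
val newMask = padded.map(x => if (x == 0L) 0L else 1L) // 1, 1, 1, 1, 0, 0 -- padding masked out
```

With the old mask, `LinAlg.avgPooling` divided by the full padded length and mixed pad-position hidden states into the sentence embedding, which is exactly the wrong-embeddings symptom in the PR title.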


private def getSentenceEmbeddingFromOnnx(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Array[Float]] = {
maxSentenceLength: Int,
useCLSToken: Boolean): Array[Array[Float]] = {

val inputIds = batch.map(x => x.map(x => x.toLong)).toArray
val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray
val attentionMask = batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray

val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

@@ -238,27 +235,44 @@
// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
val lastHiddenState = results.get("last_hidden_state").get()
val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
val shape = info.getShape
try {
val lastHiddenState = results.get("last_hidden_state").get()
val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo]
val shape = info.getShape

// shape is [batch_size, sequence_length, hidden_size]
val thirdDim = shape.last.toInt // hidden_size dimension
val secondDim = shape(1).toInt // sequence_length dimension

val flattenEmbeddings = lastHiddenState
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val embeddings = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape)
val normalizedEmbeddings = LinAlg.l2Normalize(embeddings)
val normalizedEmbeddings = if (!useCLSToken) {
// Average pooling strategy
val pooledEmbeddings = LinAlg.avgPooling(flattenEmbeddings, attentionMask, shape)
LinAlg.l2Normalize(pooledEmbeddings)
} else {
// CLS token pooling strategy
val embeddings = flattenEmbeddings
.grouped(thirdDim)
.toArray
.grouped(secondDim)
.toArray

LinAlg.l2Normalize(DenseMatrix(LinAlg.clsPooling(embeddings, attentionMask): _*))
}

LinAlg.denseMatrixToArray(normalizedEmbeddings)
} finally if (results != null) results.close()
} finally {
if (results != null) results.close()
}
} catch {
case e: Exception =>
// Handle exceptions by logging or other means.
e.printStackTrace()
Array.empty[Array[Float]] // Return an empty array or appropriate error handling
logger.error("Error during sentence embedding computation", e)
Array.empty[Array[Float]]
} finally {
// Close tensors outside the try-catch to avoid repeated null checks.
// These resources are initialized before the try-catch, so they should be closed here.
tokenTensors.close()
maskTensors.close()
segmentTensors.close()
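The new CLS branch reshapes the flat ONNX buffer of length `batch * seqLen * hidden` into `[batch][seqLen][hidden]` before pooling: `grouped(thirdDim)` cuts it into per-position hidden vectors and `grouped(secondDim)` groups those into sentences. A sketch of the equivalent reshape plus CLS selection, assuming `LinAlg.clsPooling` picks the hidden state at position 0 of each sentence:

```scala
// Sketch only: mirrors the grouped/grouped reshape above; hidden and seqLen
// play the roles of thirdDim and secondDim from the diff.
def clsVectors(flat: Array[Float], seqLen: Int, hidden: Int): Array[Array[Float]] =
  flat
    .grouped(hidden).toArray        // one hidden vector per token position
    .grouped(seqLen).toArray        // one Array of positions per sentence
    .map(sentence => sentence.head) // the CLS token sits at position 0
```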
@@ -280,7 +294,8 @@
sentences: Seq[Annotation],
tokenizedSentences: Seq[WordpieceTokenizedSentence],
batchSize: Int,
maxSentenceLength: Int): Seq[Annotation] = {
maxSentenceLength: Int,
useCLSToken: Boolean): Seq[Annotation] = {

tokenizedSentences
.zip(sentences)
Expand All @@ -294,7 +309,7 @@ private[johnsnowlabs] class BGE(
.map(y => y.pieceId)
.take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId))

val sentenceEmbeddings = getSentenceEmbedding(tokens)
val sentenceEmbeddings = getSentenceEmbedding(tokens, useCLSToken)

batch.zip(sentenceEmbeddings).map { case (sentence, vectors) =>
Annotation(