Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(api): Support k-nearest neighbor (kNN) search #385

Merged
merged 9 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ jobs:
matrix:
java: [ "11", "17" ]
scala: [ "2.12.18", "2.13.12", "3.3.1" ]
elasticsearch: ["7.x", "8.x"]
steps:
- name: Checkout current branch
uses: actions/checkout@v4.1.1
Expand All @@ -55,7 +54,7 @@ jobs:
- name: Run tests
run: ./sbt ++${{ matrix.scala }}! library/test
- name: Run test container
run: docker-compose -f docker/elasticsearch-${{ matrix.elasticsearch }}.yml up -d
run: docker-compose -f docker/elasticsearch-8.x.yml up -d
- name: Run integration tests
run: ./sbt ++${{ matrix.scala }}! integration/test

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,61 @@ object HttpExecutorSpec extends IntegrationSpec {
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
)
),
suite("kNN search")(
test("search for top two results") {
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
for {
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
firstDocumentUpdated = firstDocument.copy(vectorField = List(1, 5, -20))
secondDocumentUpdated = secondDocument.copy(vectorField = List(42, 8, -15))
thirdDocumentUpdated = thirdDocument.copy(vectorField = List(15, 11, 23))
req1 = ElasticRequest.create(firstSearchIndex, firstDocumentId, firstDocumentUpdated)
req2 = ElasticRequest.create(firstSearchIndex, secondDocumentId, secondDocumentUpdated)
req3 = ElasticRequest.create(firstSearchIndex, thirdDocumentId, thirdDocumentUpdated)
_ <- Executor.execute(ElasticRequest.bulk(req1, req2, req3).refreshTrue)
query = ElasticQuery.kNN(TestDocument.vectorField, 2, 3, Chunk(-5.0, 9.0, -12.0))
res <- Executor.execute(ElasticRequest.knnSearch(firstSearchIndex, query)).documentAs[TestDocument]
} yield (assert(res)(equalTo(Chunk(firstDocumentUpdated, thirdDocumentUpdated))))
}
} @@ around(
Executor.execute(
ElasticRequest.createIndex(
firstSearchIndex,
"""{ "mappings": { "properties": { "vectorField": { "type": "dense_vector", "dims": 3, "similarity": "l2_norm", "index": true } } } }"""
)
),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("search for top two results with filters") {
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
for {
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
firstDocumentUpdated = firstDocument.copy(intField = 15, vectorField = List(1, 5, -20))
secondDocumentUpdated = secondDocument.copy(intField = 21, vectorField = List(42, 8, -15))
thirdDocumentUpdated = thirdDocument.copy(intField = 4, vectorField = List(15, 11, 23))
req1 = ElasticRequest.create(firstSearchIndex, firstDocumentId, firstDocumentUpdated)
req2 = ElasticRequest.create(firstSearchIndex, secondDocumentId, secondDocumentUpdated)
req3 = ElasticRequest.create(firstSearchIndex, thirdDocumentId, thirdDocumentUpdated)
_ <- Executor.execute(ElasticRequest.bulk(req1, req2, req3).refreshTrue)
query = ElasticQuery.kNN(TestDocument.vectorField, 2, 3, Chunk(-5.0, 9.0, -12.0))
filter = ElasticQuery.range(TestDocument.intField).gt(10)
res <- Executor
.execute(ElasticRequest.knnSearch(firstSearchIndex, query).filter(filter))
.documentAs[TestDocument]
} yield (assert(res)(equalTo(Chunk(firstDocumentUpdated, secondDocumentUpdated))))
}
} @@ around(
Executor.execute(
ElasticRequest.createIndex(
firstSearchIndex,
"""{ "mappings": { "properties": { "vectorField": { "type": "dense_vector", "dims": 3, "similarity": "l2_norm", "index": true } } } }"""
)
),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
)
) @@ shrinks(0),
suite("searching for documents")(
test("search for a document using a boosting query") {
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,16 @@ trait IntegrationSpec extends ZIOSpecDefault {
doubleField <- Gen.double(100, 2000)
booleanField <- Gen.boolean
geoPointField <- genGeoPoint
vectorField <- Gen.listOfN(5)(Gen.int(-10, 10))
} yield TestDocument(
stringField = stringField,
dateField = dateField,
subDocumentList = subDocumentList,
intField = intField,
doubleField = doubleField,
booleanField = booleanField,
geoPointField = geoPointField
geoPointField = geoPointField,
vectorField = vectorField
)

def genTestSubDocument: Gen[Any, TestSubDocument] = for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,46 @@ object ElasticQuery {
final def ids(value: String, values: String*): IdsQuery[Any] =
Ids(values = Chunk.fromIterable(value +: values))

/**
* Constructs a type-safe instance of [[zio.elasticsearch.query.KNNQuery]] using the specified parameters.
* [[zio.elasticsearch.query.KNNQuery]] is used to perform a k-nearest neighbor (kNN) search and returns the matching
* documents.
*
* @param field
* the type-safe field for which query is specified for
* @param k
* number of nearest neighbors to return as top hits (must be less than `numCandidates`)
* @param numCandidates
* number of nearest neighbor candidates to consider per shard
* @param queryVector
* query vector
* @tparam S
* document for which field query is executed
* @return
* an instance of [[zio.elasticsearch.query.KNNQuery]] that represents the kNN query to be performed.
*/
final def kNN[S](field: Field[S, _], k: Int, numCandidates: Int, queryVector: Chunk[Double]): KNNQuery[S] =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could rename k to something more meaningful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess k makes more sense to users using it. Additionally, it's a kNN search.

KNN(field = field.toString, k = k, numCandidates = numCandidates, queryVector = queryVector, similarity = None)

/**
* Constructs an instance of [[zio.elasticsearch.query.KNNQuery]] using the specified parameters.
* [[zio.elasticsearch.query.KNNQuery]] is used to perform a k-nearest neighbor (kNN) search and returns the matching
* documents.
*
* @param field
* the field for which query is specified for
* @param k
* number of nearest neighbors to return as top hits (must be less than `numCandidates`)
* @param numCandidates
* number of nearest neighbor candidates to consider per shard
* @param queryVector
* query vector
* @return
* an instance of [[zio.elasticsearch.query.KNNQuery]] that represents the kNN query to be performed.
*/
final def kNN(field: String, k: Int, numCandidates: Int, queryVector: Chunk[Double]): KNNQuery[Any] =
KNN(field = field, k = k, numCandidates = numCandidates, queryVector = queryVector, similarity = None)

/**
* Constructs an instance of [[zio.elasticsearch.query.MatchAllQuery]] used for matching all documents.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import zio.elasticsearch.IndexSelector.IndexNameSyntax
import zio.elasticsearch.aggregation.ElasticAggregation
import zio.elasticsearch.executor.response.BulkResponse
import zio.elasticsearch.highlights.Highlights
import zio.elasticsearch.query.ElasticQuery
import zio.elasticsearch.query.sort.Sort
import zio.elasticsearch.query.{ElasticQuery, KNNQuery}
import zio.elasticsearch.request._
import zio.elasticsearch.request.options._
import zio.elasticsearch.result.{
AggregateResult,
GetResult,
KNNSearchResult,
SearchAndAggregateResult,
SearchResult,
UpdateByQueryResult
Expand Down Expand Up @@ -215,6 +216,20 @@ object ElasticRequest {
final def getById(index: IndexName, id: DocumentId): GetByIdRequest =
GetById(index = index, id = id, refresh = None, routing = None)

/**
* Constructs an instance of [[KNNRequest]] used for performing a k-nearest neighbour (kNN) search. Given a query
* vector, it finds the k closest vectors and returns those documents as search hits.
*
* @param selectors
* the name of the index or more indices to search in
* @param query
* an instance of [[zio.elasticsearch.query.KNNQuery]] to run
* @return
* an instance of [[KNNRequest]] that represents k-nearest neighbour (kNN) operation to be performed.
*/
final def knnSearch[I: IndexSelector](selectors: I, query: KNNQuery[_]): KNNRequest =
KNN(knn = query, selectors = selectors.toSelector, filter = None, routing = None)

/**
* Constructs an instance of [[RefreshRequest]] used for refreshing an index with the specified name.
*
Expand Down Expand Up @@ -593,6 +608,40 @@ object ElasticRequest {
self.copy(routing = Some(value))
}

sealed trait KNNRequest extends ElasticRequest[KNNSearchResult] with HasRouting[KNNRequest] {

/**
* Adds an [[zio.elasticsearch.ElasticQuery]] to the [[zio.elasticsearch.ElasticRequest.KNNRequest]] to filter the
* documents that can match. If not provided, all documents are allowed to match.
*
* @param query
* the Elastic query to be added
drmarjanovic marked this conversation as resolved.
Show resolved Hide resolved
* @return
* an instance of a [[zio.elasticsearch.ElasticRequest.KNNRequest]] that represents the kNN search operation
* enriched with filter query to be performed.
drmarjanovic marked this conversation as resolved.
Show resolved Hide resolved
*/
def filter(query: ElasticQuery[_]): KNNRequest
}

private[elasticsearch] final case class KNN(
knn: KNNQuery[_],
selectors: String,
filter: Option[ElasticQuery[_]],
routing: Option[Routing]
) extends KNNRequest { self =>

def filter(query: ElasticQuery[_]): KNNRequest =
self.copy(filter = Some(query))

def routing(value: Routing): KNNRequest =
self.copy(routing = Some(value))

private[elasticsearch] def toJson: Json = {
val filterJson: Json = filter.fold(Obj())(f => Obj("filter" -> f.toJson(None)))
Obj("knn" -> knn.toJson) merge filterJson
}
}

sealed trait RefreshRequest extends ElasticRequest[Boolean]

private[elasticsearch] final case class Refresh(selectors: String) extends RefreshRequest
Expand All @@ -612,7 +661,7 @@ object ElasticRequest {
* [[zio.elasticsearch.ElasticRequest.SearchRequest]].
*
* @param aggregation
* the elastic aggregation to be added
* the Elastic aggregation to be added
drmarjanovic marked this conversation as resolved.
Show resolved Hide resolved
* @return
* an instance of a [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] that represents search and
* aggregate operations to be performed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig
case r: DeleteIndex => executeDeleteIndex(r)
case r: Exists => executeExists(r)
case r: GetById => executeGetById(r)
case r: KNN => executeKnn(r)
case r: Refresh => executeRefresh(r)
case r: Search => executeSearch(r)
case r: SearchAndAggregate => executeSearchAndAggregate(r)
Expand Down Expand Up @@ -372,6 +373,31 @@ private[elasticsearch] final class HttpExecutor private (esConfig: ElasticConfig
}
}

private def executeKnn(r: KNN): Task[KNNSearchResult] = {
val uri = uri"${esConfig.uri}/${r.selectors}/_knn_search".withParams(
getQueryParams(Chunk(("routing", r.routing)))
)

sendRequestWithCustomResponse[SearchWithAggregationsResponse](
baseRequest
.post(uri)
.response(asJson[SearchWithAggregationsResponse])
.contentType(ApplicationJson)
.body(r.toJson)
).flatMap { response =>
response.code match {
case HttpOk =>
response.body.fold(
e => ZIO.fail(new ElasticException(s"Exception occurred: ${e.getMessage}")),
value =>
ZIO.succeed(new KNNSearchResult(itemsFromDocumentsWithHighlights(value.resultsWithHighlightsAndSort)))
)
case _ =>
ZIO.fail(handleFailuresFromCustomResponse(response))
}
}
}

private def executeRefresh(r: Refresh): Task[Boolean] =
sendRequest(baseRequest.get(uri"${esConfig.uri}/${r.selectors}/$Refresh")).flatMap { response =>
response.code match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,44 @@ private[elasticsearch] final case class Exists[S](field: String, boost: Option[D
)
}

sealed trait KNNQuery[-S] { self =>

/**
* Sets the `similarity` parameter for the [[zio.elasticsearch.query.KNNQuery]]. The `similarity` parameter is the
* required minimum similarity for a vector to be considered a match.
dbulaja98 marked this conversation as resolved.
Show resolved Hide resolved
*
* @param value
* a non-negative real number used for the `similarity`
* @return
* an instance of [[zio.elasticsearch.query.KNNQuery]] enriched with the `similarity` parameter.
*/
def similarity(value: Double): KNNQuery[S]

private[elasticsearch] def toJson: Json
}

private[elasticsearch] final case class KNN[S](
field: String,
k: Int,
numCandidates: Int,
queryVector: Chunk[Double],
similarity: Option[Double]
) extends KNNQuery[S] { self =>

def similarity(value: Double): KNN[S] =
self.copy(similarity = Some(value))

private[elasticsearch] def toJson: Json = {
val similarityJson = similarity.fold(Obj())(s => Obj("similarity" -> s.toJson))
Obj(
"field" -> field.toJson,
"query_vector" -> Arr(queryVector.map(_.toJson)),
"k" -> k.toJson,
"num_candidates" -> numCandidates.toJson
) merge similarityJson
}
}

sealed trait FunctionScoreQuery[S] extends ElasticQuery[S] with HasBoost[FunctionScoreQuery[S]] {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ final class GetResult private[elasticsearch] (private val doc: Option[Item]) ext
})
}

final class KNNSearchResult private[elasticsearch] (private val hits: Chunk[Item]) extends DocumentResult[Chunk] {

def documentAs[A: Schema]: IO[DecodingException, Chunk[A]] =
ZIO.fromEither {
ZValidation.validateAll(hits.map(item => ZValidation.fromEither(item.documentAs))).toEitherWith { errors =>
DecodingException(s"Could not parse all documents successfully: ${errors.map(_.message).mkString(", ")}")
}
}

lazy val items: UIO[Chunk[Item]] = ZIO.succeed(hits)
}

final class SearchResult private[elasticsearch] (
private val hits: Chunk[Item],
private val fullResponse: SearchWithAggregationsResponse
Expand Down
Loading