diff --git a/README.md b/README.md index c3ed09e70..81c697683 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,8 @@ A modern text indexing library in go * Conjunction, Disjunction, Boolean (must/should/must_not) * Term Range, Numeric Range, Date Range * [Geo Spatial](https://github.com/blevesearch/bleve/blob/master/geo/README.md) - * Simple [query string syntax](http://www.blevesearch.com/docs/Query-String-Query/) for human entry + * Simple [query string syntax](http://www.blevesearch.com/docs/Query-String-Query/) + * [Vector Search](https://github.com/blevesearch/bleve/blob/master/vectors.md) * [tf-idf](https://en.wikipedia.org/wiki/Tf-idf) Scoring * Query time boosting * Search result match highlighting with document fragments diff --git a/document/field_vector.go b/document/field_vector.go index 59ac02026..b019361cb 100644 --- a/document/field_vector.go +++ b/document/field_vector.go @@ -35,12 +35,13 @@ func init() { const DefaultVectorIndexingOptions = index.IndexField type VectorField struct { - name string - dims int // Dimensionality of the vector - similarity string // Similarity metric to use for scoring - options index.FieldIndexingOptions - value []float32 - numPlainTextBytes uint64 + name string + dims int // Dimensionality of the vector + similarity string // Similarity metric to use for scoring + options index.FieldIndexingOptions + value []float32 + numPlainTextBytes uint64 + vectorIndexOptimizedFor string // Optimization applied to this index. } func (n *VectorField) Size() int { @@ -95,25 +96,27 @@ func (n *VectorField) GoString() string { // For the sake of not polluting the API, we are keeping arrayPositions as a // parameter, but it is not used. func NewVectorField(name string, arrayPositions []uint64, - vector []float32, dims int, similarity string) *VectorField { + vector []float32, dims int, similarity, vectorIndexOptimizedFor string) *VectorField { return NewVectorFieldWithIndexingOptions(name, arrayPositions, - vector, dims, similarity, DefaultVectorIndexingOptions) + vector, dims, similarity, vectorIndexOptimizedFor, + DefaultVectorIndexingOptions) } // For the sake of not polluting the API, we are keeping arrayPositions as a // parameter, but it is not used. 
func NewVectorFieldWithIndexingOptions(name string, arrayPositions []uint64, - vector []float32, dims int, similarity string, + vector []float32, dims int, similarity, vectorIndexOptimizedFor string, options index.FieldIndexingOptions) *VectorField { options = options | DefaultVectorIndexingOptions return &VectorField{ - name: name, - dims: dims, - similarity: similarity, - options: options, - value: vector, - numPlainTextBytes: numBytesFloat32s(vector), + name: name, + dims: dims, + similarity: similarity, + options: options, + value: vector, + numPlainTextBytes: numBytesFloat32s(vector), + vectorIndexOptimizedFor: vectorIndexOptimizedFor, } } @@ -136,3 +139,7 @@ func (n *VectorField) Dims() int { func (n *VectorField) Similarity() string { return n.similarity } + +func (n *VectorField) IndexOptimizedFor() string { + return n.vectorIndexOptimizedFor +} diff --git a/error.go b/error.go index 7dd21194c..2d2751cd4 100644 --- a/error.go +++ b/error.go @@ -26,6 +26,7 @@ const ( ErrorUnknownIndexType ErrorEmptyID ErrorIndexReadInconsistency + ErrorTwoPhaseSearchInconsistency ) // Error represents a more strongly typed bleve error for detecting @@ -37,14 +38,15 @@ func (e Error) Error() string { } var errorMessages = map[Error]string{ - ErrorIndexPathExists: "cannot create new index, path already exists", - ErrorIndexPathDoesNotExist: "cannot open index, path does not exist", - ErrorIndexMetaMissing: "cannot open index, metadata missing", - ErrorIndexMetaCorrupt: "cannot open index, metadata corrupt", - ErrorIndexClosed: "index is closed", - ErrorAliasMulti: "cannot perform single index operation on multiple index alias", - ErrorAliasEmpty: "cannot perform operation on empty alias", - ErrorUnknownIndexType: "unknown index type", - ErrorEmptyID: "document ID cannot be empty", - ErrorIndexReadInconsistency: "index read inconsistency detected", + ErrorIndexPathExists: "cannot create new index, path already exists", + ErrorIndexPathDoesNotExist: "cannot open index, path does not exist", + ErrorIndexMetaMissing: "cannot open index, metadata missing", + ErrorIndexMetaCorrupt: "cannot open index, metadata corrupt", + ErrorIndexClosed: "index is closed", + ErrorAliasMulti: "cannot perform single index operation on multiple index alias", + ErrorAliasEmpty: "cannot perform operation on empty alias", + ErrorUnknownIndexType: "unknown index type", + ErrorEmptyID: "document ID cannot be empty", + ErrorIndexReadInconsistency: "index read inconsistency detected", + ErrorTwoPhaseSearchInconsistency: "2-phase search failed, likely due to an overlapping topology change", } diff --git a/go.mod b/go.mod index 5a922d8d0..078ac98e1 100644 --- a/go.mod +++ b/go.mod @@ -1,17 +1,17 @@ module github.com/blevesearch/bleve/v2 -go 1.19 +go 1.20 require ( github.com/RoaringBitmap/roaring v1.2.3 github.com/bits-and-blooms/bitset v1.2.0 - github.com/blevesearch/bleve_index_api v1.0.6 - github.com/blevesearch/geo v0.1.18 + github.com/blevesearch/bleve_index_api v1.1.6 + github.com/blevesearch/geo v0.1.20 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/blevesearch/go-porterstemmer v1.0.3 github.com/blevesearch/goleveldb v1.0.1 github.com/blevesearch/gtreap v0.1.1 - github.com/blevesearch/scorch_segment_api/v2 v2.1.6 + github.com/blevesearch/scorch_segment_api/v2 v2.2.9 github.com/blevesearch/segment v0.9.1 github.com/blevesearch/snowball v0.6.1 github.com/blevesearch/snowballstem v0.9.0 @@ -23,6 +23,7 @@ require ( github.com/blevesearch/zapx/v13 v13.3.10 github.com/blevesearch/zapx/v14 v14.3.10 
github.com/blevesearch/zapx/v15 v15.3.13 + github.com/blevesearch/zapx/v16 v16.0.12 github.com/couchbase/moss v0.2.0 github.com/golang/protobuf v1.3.2 github.com/spf13/cobra v1.7.0 @@ -31,6 +32,7 @@ require ( ) require ( + github.com/blevesearch/go-faiss v1.0.13 // indirect github.com/blevesearch/mmap-go v1.0.4 // indirect github.com/couchbase/ghistogram v0.1.0 // indirect github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 // indirect @@ -39,5 +41,5 @@ require ( github.com/json-iterator/go v0.0.0-20171115153421-f7279a603ede // indirect github.com/mschoch/smat v0.2.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/sys v0.5.0 // indirect + golang.org/x/sys v0.13.0 // indirect ) diff --git a/go.sum b/go.sum index 90ebfb29b..2cba39cd0 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,12 @@ github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVO github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/bits-and-blooms/bitset v1.2.0 h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjLyS07ChA= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/blevesearch/bleve_index_api v1.0.6 h1:gyUUxdsrvmW3jVhhYdCVL6h9dCjNT/geNU7PxGn37p8= -github.com/blevesearch/bleve_index_api v1.0.6/go.mod h1:YXMDwaXFFXwncRS8UobWs7nvo0DmusriM1nztTlj1ms= -github.com/blevesearch/geo v0.1.18 h1:Np8jycHTZ5scFe7VEPLrDoHnnb9C4j636ue/CGrhtDw= -github.com/blevesearch/geo v0.1.18/go.mod h1:uRMGWG0HJYfWfFJpK3zTdnnr1K+ksZTuWKhXeSokfnM= +github.com/blevesearch/bleve_index_api v1.1.6 h1:orkqDFCBuNU2oHW9hN2YEJmet+TE9orml3FCGbl1cKk= +github.com/blevesearch/bleve_index_api v1.1.6/go.mod h1:PbcwjIcRmjhGbkS/lJCpfgVSMROV6TRubGGAODaK1W8= +github.com/blevesearch/geo v0.1.20 h1:paaSpu2Ewh/tn5DKn/FB5SzvH0EWupxHEIwbCk/QPqM= +github.com/blevesearch/geo v0.1.20/go.mod h1:DVG2QjwHNMFmjo+ZgzrIq2sfCh6rIHzy9d9d0B59I6w= +github.com/blevesearch/go-faiss v1.0.13 h1:zfFs7ZYD0NqXVSY37j0JZjZT1BhE9AE4peJfcx/NB4A= +github.com/blevesearch/go-faiss v1.0.13/go.mod h1:jrxHrbl42X/RnDPI+wBoZU8joxxuRwedrxqswQ3xfU8= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:kDy+zgJFJJoJYBvdfBSiZYBbdsUL0XcjHYWezpQBGPA= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= @@ -17,8 +19,8 @@ github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.1.6 h1:CdekX/Ob6YCYmeHzD72cKpwzBjvkOGegHOqhAkXp6yA= -github.com/blevesearch/scorch_segment_api/v2 v2.1.6/go.mod h1:nQQYlp51XvoSVxcciBjtvuHPIVjlWrN1hX4qwK2cqdc= +github.com/blevesearch/scorch_segment_api/v2 v2.2.9 h1:3nBaSBRFokjE4FtPW3eUDgcAu3KphBg1GP07zy/6Uyk= +github.com/blevesearch/scorch_segment_api/v2 v2.2.9/go.mod h1:ckbeb7knyOOvAdZinn/ASbB7EA3HoagnJkmEV3J7+sg= github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A= @@ -41,6 +43,8 @@ github.com/blevesearch/zapx/v14 
v14.3.10 h1:SG6xlsL+W6YjhX5N3aEiL/2tcWh3DO75Bnz7 github.com/blevesearch/zapx/v14 v14.3.10/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= github.com/blevesearch/zapx/v15 v15.3.13 h1:6EkfaZiPlAxqXz0neniq35my6S48QI94W/wyhnpDHHQ= github.com/blevesearch/zapx/v15 v15.3.13/go.mod h1:Turk/TNRKj9es7ZpKK95PS7f6D44Y7fAFy8F4LXQtGg= +github.com/blevesearch/zapx/v16 v16.0.12 h1:Uccxvjmn+hQ6ywQP+wIiTpdq9LnAviGoryJOmGwAo/I= +github.com/blevesearch/zapx/v16 v16.0.12/go.mod h1:MYnOshRfSm4C4drxx1LGRI+MVFByykJ2anDY1fxdk9Q= github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps= github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o= @@ -89,8 +93,8 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181221143128-b4a75ba826a6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= diff --git a/index/scorch/introducer.go b/index/scorch/introducer.go index 123e71d63..2cb1398ec 100644 --- a/index/scorch/introducer.go +++ b/index/scorch/introducer.go @@ -30,6 +30,7 @@ type segmentIntroduction struct { obsoletes map[uint64]*roaring.Bitmap ids []string internal map[string][]byte + stats *fieldStats applied chan error persisted chan error @@ -146,7 +147,9 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error { newss := &SegmentSnapshot{ id: root.segment[i].id, segment: root.segment[i].segment, + stats: root.segment[i].stats, cachedDocs: root.segment[i].cachedDocs, + cachedMeta: root.segment[i].cachedMeta, creator: root.segment[i].creator, } @@ -154,7 +157,11 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error { if root.segment[i].deleted == nil { newss.deleted = delta } else { - newss.deleted = roaring.Or(root.segment[i].deleted, delta) + if delta.IsEmpty() { + newss.deleted = root.segment[i].deleted + } else { + newss.deleted = roaring.Or(root.segment[i].deleted, delta) + } } if newss.deleted.IsEmpty() { newss.deleted = nil @@ -188,7 +195,9 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error { newSegmentSnapshot := &SegmentSnapshot{ id: next.id, segment: next.data, // take ownership of next.data's ref-count + stats: next.stats, cachedDocs: &cachedDocs{cache: nil}, + cachedMeta: &cachedMeta{meta: nil}, creator: "introduceSegment", } newSnapshot.segment = append(newSnapshot.segment, newSegmentSnapshot) @@ -275,7 +284,9 @@ func (s *Scorch) introducePersist(persist *persistIntroduction) { id: segmentSnapshot.id, segment: replacement, deleted: segmentSnapshot.deleted, + stats: segmentSnapshot.stats, cachedDocs: segmentSnapshot.cachedDocs, + cachedMeta: segmentSnapshot.cachedMeta, creator: "introducePersist", mmaped: 1, 
} @@ -374,7 +385,9 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) { id: root.segment[i].id, segment: root.segment[i].segment, deleted: root.segment[i].deleted, + stats: root.segment[i].stats, cachedDocs: root.segment[i].cachedDocs, + cachedMeta: root.segment[i].cachedMeta, creator: root.segment[i].creator, }) root.segment[i].segment.AddRef() @@ -394,7 +407,6 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) { } } } - // before the newMerge introduction, need to clean the newly // merged segment wrt the current root segments, hence // applying the obsolete segment contents to newly merged segment @@ -415,12 +427,19 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) { if nextMerge.new != nil && nextMerge.new.Count() > newSegmentDeleted.GetCardinality() { + stats := newFieldStats() + if fsr, ok := nextMerge.new.(segment.FieldStatsReporter); ok { + fsr.UpdateFieldStats(stats) + } + // put new segment at end newSnapshot.segment = append(newSnapshot.segment, &SegmentSnapshot{ id: nextMerge.id, segment: nextMerge.new, // take ownership for nextMerge.new's ref-count deleted: newSegmentDeleted, + stats: stats, cachedDocs: &cachedDocs{cache: nil}, + cachedMeta: &cachedMeta{meta: nil}, creator: "introduceMerge", mmaped: nextMerge.mmaped, }) diff --git a/index/scorch/merge.go b/index/scorch/merge.go index 92adc3fd4..339ec5969 100644 --- a/index/scorch/merge.go +++ b/index/scorch/merge.go @@ -290,7 +290,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context, atomic.AddUint64(&s.stats.TotFileMergePlanTasksSegments, uint64(len(task.Segments))) - oldMap := make(map[uint64]*SegmentSnapshot) + oldMap := make(map[uint64]*SegmentSnapshot, len(task.Segments)) newSegmentID := atomic.AddUint64(&s.nextSegmentID, 1) segmentsToMerge := make([]segment.Segment, 0, len(task.Segments)) docsToDrop := make([]*roaring.Bitmap, 0, len(task.Segments)) @@ -357,7 +357,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context, totalBytesRead := seg.BytesRead() + prevBytesReadTotal seg.ResetBytesRead(totalBytesRead) - oldNewDocNums = make(map[uint64][]uint64) + oldNewDocNums = make(map[uint64][]uint64, len(newDocNums)) for i, segNewDocNums := range newDocNums { oldNewDocNums[task.Segments[i].Id()] = segNewDocNums } @@ -485,8 +485,8 @@ func (s *Scorch) mergeSegmentBases(snapshot *IndexSnapshot, sm := &segmentMerge{ id: newSegmentID, - old: make(map[uint64]*SegmentSnapshot), - oldNewDocNums: make(map[uint64][]uint64), + old: make(map[uint64]*SegmentSnapshot, len(sbsIndexes)), + oldNewDocNums: make(map[uint64][]uint64, len(sbsIndexes)), new: seg, notifyCh: make(chan *mergeTaskIntroStatus), } diff --git a/index/scorch/optimize.go b/index/scorch/optimize.go index 3c7969fa9..968a744ac 100644 --- a/index/scorch/optimize.go +++ b/index/scorch/optimize.go @@ -16,10 +16,11 @@ package scorch import ( "fmt" + "sync/atomic" + "github.com/RoaringBitmap/roaring" index "github.com/blevesearch/bleve_index_api" segment "github.com/blevesearch/scorch_segment_api/v2" - "sync/atomic" ) var OptimizeConjunction = true diff --git a/index/scorch/optimize_knn.go b/index/scorch/optimize_knn.go new file mode 100644 index 000000000..330e214f3 --- /dev/null +++ b/index/scorch/optimize_knn.go @@ -0,0 +1,187 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package scorch + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" + segment_api "github.com/blevesearch/scorch_segment_api/v2" +) + +type OptimizeVR struct { + ctx context.Context + snapshot *IndexSnapshot + totalCost uint64 + // maps field to vector readers + vrs map[string][]*IndexSnapshotVectorReader +} + +// This setting _MUST_ only be changed during init and not after. +var BleveMaxKNNConcurrency = 10 + +func (o *OptimizeVR) invokeSearcherEndCallback() { + if o.ctx != nil { + if cb := o.ctx.Value(search.SearcherEndCallbackKey); cb != nil { + if cbF, ok := cb.(search.SearcherEndCallbackFn); ok { + if o.totalCost > 0 { + // notify the callback that the searcher creation etc. is finished + // and report back the total cost for it to track and take actions + // appropriately. + _ = cbF(o.totalCost) + } + } + } + } +} + +func (o *OptimizeVR) Finish() error { + // for each field, get the vector index --> invoke the zap func. + // for each VR, populate postings list and iterators + // by passing the obtained vector index and getting similar vectors. + // defer close index - just once. + var errorsM sync.Mutex + var errors []error + + defer o.invokeSearcherEndCallback() + + wg := sync.WaitGroup{} + semaphore := make(chan struct{}, BleveMaxKNNConcurrency) + // Launch goroutines to get vector index for each segment + for i, seg := range o.snapshot.segment { + if sv, ok := seg.segment.(segment_api.VectorSegment); ok { + wg.Add(1) + semaphore <- struct{}{} // Acquire a semaphore slot + go func(index int, segment segment_api.VectorSegment, origSeg *SegmentSnapshot) { + defer func() { + <-semaphore // Release the semaphore slot + wg.Done() + }() + for field, vrs := range o.vrs { + vecIndex, err := segment.InterpretVectorIndex(field) + if err != nil { + errorsM.Lock() + errors = append(errors, err) + errorsM.Unlock() + return + } + + // update the vector index size as a meta value in the segment snapshot + vectorIndexSize := vecIndex.Size() + origSeg.cachedMeta.updateMeta(field, vectorIndexSize) + for _, vr := range vrs { + // for each VR, populate postings list and iterators + // by passing the obtained vector index and getting similar vectors. 
+ pl, err := vecIndex.Search(vr.vector, vr.k, origSeg.deleted) + if err != nil { + errorsM.Lock() + errors = append(errors, err) + errorsM.Unlock() + go vecIndex.Close() + return + } + + atomic.AddUint64(&o.snapshot.parent.stats.TotKNNSearches, uint64(1)) + + // postings and iterators are already alloc'ed when + // IndexSnapshotVectorReader is created + vr.postings[index] = pl + vr.iterators[index] = pl.Iterator(vr.iterators[index]) + } + go vecIndex.Close() + } + }(i, sv, seg) + } + } + wg.Wait() + close(semaphore) + if len(errors) > 0 { + return errors[0] + } + return nil +} + +func (s *IndexSnapshotVectorReader) VectorOptimize(ctx context.Context, + octx index.VectorOptimizableContext) (index.VectorOptimizableContext, error) { + + if s.snapshot.parent.segPlugin.Version() < VectorSearchSupportedSegmentVersion { + return nil, fmt.Errorf("vector search not supported for this index, "+ + "index's segment version %v, supported segment version for vector search %v", + s.snapshot.parent.segPlugin.Version(), VectorSearchSupportedSegmentVersion) + } + + if octx == nil { + octx = &OptimizeVR{snapshot: s.snapshot, + vrs: make(map[string][]*IndexSnapshotVectorReader), + } + } + + o, ok := octx.(*OptimizeVR) + if !ok { + return octx, nil + } + o.ctx = ctx + + if o.snapshot != s.snapshot { + o.invokeSearcherEndCallback() + return nil, fmt.Errorf("tried to optimize KNN across different snapshots") + } + + // for every searcher creation, consult the segment snapshot for the + // vector index size; since this vector index is going to be used + // to perform the search as part of Finish(), check whether the + // searcher creation (and the downstream Finish() logic) should be + // allowed to proceed at all. + var sumVectorIndexSize uint64 + for _, seg := range o.snapshot.segment { + vecIndexSize := seg.cachedMeta.fetchMeta(s.field) + if vecIndexSize != nil { + sumVectorIndexSize += vecIndexSize.(uint64) + } + } + + if o.ctx != nil { + if cb := o.ctx.Value(search.SearcherStartCallbackKey); cb != nil { + if cbF, ok := cb.(search.SearcherStartCallbackFn); ok { + err := cbF(sumVectorIndexSize) + if err != nil { + // it's important to invoke the end callback at this point, since + // if the earlier searchers of this optimize struct were successful + // the cost corresponding to them would already have been incremented; + // if the current searcher fails the check we end up erroring out + // the overall optimized searcher creation, so the cost needs to be + // handled appropriately. + o.invokeSearcherEndCallback() + return nil, err + } + } + } + } + + // total cost is essentially the sum of the vector indexes' size across all the + // searchers - all of them end up reading and maintaining a vector index. + // misaccounting this value would end up calling the "end" callback with a value + // not equal to the value passed to the "start" callback.
+ o.totalCost += sumVectorIndexSize + o.vrs[s.field] = append(o.vrs[s.field], s) + return o, nil +} diff --git a/index/scorch/persister.go b/index/scorch/persister.go index 217582fe1..afd518dde 100644 --- a/index/scorch/persister.go +++ b/index/scorch/persister.go @@ -17,6 +17,7 @@ package scorch import ( "bytes" "encoding/binary" + "encoding/json" "fmt" "io" "log" @@ -424,6 +425,7 @@ func (s *Scorch) persistSnapshotMaybeMerge(snapshot *IndexSnapshot) ( id: newSegmentID, segment: segment.segment, deleted: nil, // nil since merging handled deletions + stats: nil, }) break } @@ -602,6 +604,18 @@ func prepareBoltSnapshot(snapshot *IndexSnapshot, tx *bolt.Tx, path string, return nil, nil, err } } + + // store segment stats + if segmentSnapshot.stats != nil { + b, err := json.Marshal(segmentSnapshot.stats.Fetch()) + if err != nil { + return nil, nil, err + } + err = snapshotSegmentBucket.Put(boltStatsKey, b) + if err != nil { + return nil, nil, err + } + } } return filenames, newSegmentPaths, nil @@ -634,7 +648,7 @@ func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot) (err error) { // the newly populated boltdb snapshotBucket above if len(newSegmentPaths) > 0 { // now try to open all the new snapshots - newSegments := make(map[uint64]segment.Segment) + newSegments := make(map[uint64]segment.Segment, len(newSegmentPaths)) defer func() { for _, s := range newSegments { if s != nil { @@ -704,6 +718,7 @@ var boltMetaDataKey = []byte{'m'} var boltMetaDataSegmentTypeKey = []byte("type") var boltMetaDataSegmentVersionKey = []byte("version") var boltMetaDataTimeStamp = []byte("timeStamp") +var boltStatsKey = []byte("stats") var TotBytesWrittenKey = []byte("TotBytesWritten") func (s *Scorch) loadFromBolt() error { @@ -858,6 +873,7 @@ func (s *Scorch) loadSegment(segmentBucket *bolt.Bucket) (*SegmentSnapshot, erro rv := &SegmentSnapshot{ segment: segment, cachedDocs: &cachedDocs{cache: nil}, + cachedMeta: &cachedMeta{meta: nil}, } deletedBytes := segmentBucket.Get(boltDeletedKey) if deletedBytes != nil { @@ -872,6 +888,18 @@ func (s *Scorch) loadSegment(segmentBucket *bolt.Bucket) (*SegmentSnapshot, erro rv.deleted = deletedBitmap } } + statBytes := segmentBucket.Get(boltStatsKey) + if statBytes != nil { + var statsMap map[string]map[string]uint64 + + err := json.Unmarshal(statBytes, &statsMap) + stats := &fieldStats{statMap: statsMap} + if err != nil { + _ = segment.Close() + return nil, fmt.Errorf("error reading stat bytes: %v", err) + } + rv.stats = stats + } return rv, nil } diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go index f30d795e9..2e6435ee0 100644 --- a/index/scorch/scorch.go +++ b/index/scorch/scorch.go @@ -428,6 +428,8 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) { var newSegment segment.Segment var bufBytes uint64 + stats := newFieldStats() + if len(analysisResults) > 0 { newSegment, bufBytes, err = s.segPlugin.New(analysisResults) if err != nil { @@ -438,11 +440,14 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) { segB.BytesWritten()) } atomic.AddUint64(&s.iStats.newSegBufBytesAdded, bufBytes) + if fsr, ok := newSegment.(segment.FieldStatsReporter); ok { + fsr.UpdateFieldStats(stats) + } } else { atomic.AddUint64(&s.stats.TotBatchesEmpty, 1) } - err = s.prepareSegment(newSegment, ids, batch.InternalOps, batch.PersistedCallback()) + err = s.prepareSegment(newSegment, ids, batch.InternalOps, batch.PersistedCallback(), stats) if err != nil { if newSegment != nil { _ = newSegment.Close() @@ -462,15 +467,15 @@ func (s *Scorch) Batch(batch 
*index.Batch) (err error) { } func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string, - internalOps map[string][]byte, persistedCallback index.BatchCallback) error { + internalOps map[string][]byte, persistedCallback index.BatchCallback, stats *fieldStats) error { // new introduction introduction := &segmentIntroduction{ id: atomic.AddUint64(&s.nextSegmentID, 1), data: newSegment, ids: ids, - obsoletes: make(map[uint64]*roaring.Bitmap), internal: internalOps, + stats: stats, applied: make(chan error), persistedCallback: persistedCallback, } @@ -487,6 +492,8 @@ func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string, defer func() { _ = root.DecRef() }() + introduction.obsoletes = make(map[uint64]*roaring.Bitmap, len(root.segment)) + for _, seg := range root.segment { delta, err := seg.segment.DocNumbers(ids) if err != nil { @@ -617,6 +624,8 @@ func (s *Scorch) StatsMap() map[string]interface{} { m["index_time"] = m["TotIndexTime"] m["term_searchers_started"] = m["TotTermSearchersStarted"] m["term_searchers_finished"] = m["TotTermSearchersFinished"] + m["knn_searches"] = m["TotKNNSearches"] + m["num_bytes_read_at_query_time"] = m["TotBytesReadAtQueryTime"] m["num_plain_text_bytes_indexed"] = m["TotIndexedPlainTextBytes"] m["num_bytes_written_at_index_time"] = m["TotBytesWrittenAtIndexTime"] @@ -638,6 +647,20 @@ func (s *Scorch) StatsMap() map[string]interface{} { m["num_persister_nap_merger_break"] = m["TotPersisterMergerNapBreak"] m["total_compaction_written_bytes"] = m["TotFileMergeWrittenBytes"] + // calculate the aggregate of all the segment's field stats + aggFieldStats := newFieldStats() + for _, segmentSnapshot := range indexSnapshot.Segments() { + if segmentSnapshot.stats != nil { + aggFieldStats.Aggregate(segmentSnapshot.stats) + } + } + + aggFieldStatsMap := aggFieldStats.Fetch() + for statName, stats := range aggFieldStatsMap { + for fieldName, val := range stats { + m["field:"+fieldName+":"+statName] = val + } + } return m } @@ -762,3 +785,50 @@ func parseToInteger(i interface{}) (int, error) { return 0, fmt.Errorf("expects int or float64 value") } } + +// Holds Zap's field level stats at a segment level +type fieldStats struct { + // StatName -> FieldName -> value + statMap map[string]map[string]uint64 +} + +// Add the data into the map after checking if the statname is valid +func (fs *fieldStats) Store(statName, fieldName string, value uint64) { + if _, exists := fs.statMap[statName]; !exists { + fs.statMap[statName] = make(map[string]uint64) + } + fs.statMap[statName][fieldName] = value +} + +// Combine the given stats map with the existing map +func (fs *fieldStats) Aggregate(stats segment.FieldStats) { + + statMap := stats.Fetch() + if statMap == nil { + return + } + for statName, statMap := range statMap { + if _, exists := fs.statMap[statName]; !exists { + fs.statMap[statName] = make(map[string]uint64) + } + for fieldName, val := range statMap { + if _, exists := fs.statMap[statName][fieldName]; !exists { + fs.statMap[statName][fieldName] = 0 + } + fs.statMap[statName][fieldName] += val + } + } +} + +// Returns the stats map +func (fs *fieldStats) Fetch() map[string]map[string]uint64 { + return fs.statMap +} + +// Initializes an empty stats map +func newFieldStats() *fieldStats { + rv := &fieldStats{ + statMap: map[string]map[string]uint64{}, + } + return rv +} diff --git a/index/scorch/scorch_test.go b/index/scorch/scorch_test.go index f9064b67f..b8f493fe8 100644 --- a/index/scorch/scorch_test.go +++ b/index/scorch/scorch_test.go @@ 
-2663,3 +2663,24 @@ func TestReadOnlyIndex(t *testing.T) { t.Errorf("Expected document count to be %d got %d", 1, docCount) } } + +func BenchmarkAggregateFieldStats(b *testing.B) { + + fieldStatsArray := make([]*fieldStats, 1000) + + for i := range fieldStatsArray { + fieldStatsArray[i] = newFieldStats() + + fieldStatsArray[i].Store("num_vectors", "vector", uint64(rand.Intn(1000))) + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + aggFieldStats := newFieldStats() + + for _, fs := range fieldStatsArray { + aggFieldStats.Aggregate(fs) + } + } +} diff --git a/index/scorch/segment_plugin.go b/index/scorch/segment_plugin.go index a84d2d55f..b3b9ba01f 100644 --- a/index/scorch/segment_plugin.go +++ b/index/scorch/segment_plugin.go @@ -28,6 +28,7 @@ import ( zapv13 "github.com/blevesearch/zapx/v13" zapv14 "github.com/blevesearch/zapx/v14" zapv15 "github.com/blevesearch/zapx/v15" + zapv16 "github.com/blevesearch/zapx/v16" ) // SegmentPlugin represents the essential functions required by a package to plug in @@ -73,7 +74,8 @@ var defaultSegmentPlugin SegmentPlugin func init() { ResetSegmentPlugins() - RegisterSegmentPlugin(&zapv15.ZapPlugin{}, true) + RegisterSegmentPlugin(&zapv16.ZapPlugin{}, true) + RegisterSegmentPlugin(&zapv15.ZapPlugin{}, false) RegisterSegmentPlugin(&zapv14.ZapPlugin{}, false) RegisterSegmentPlugin(&zapv13.ZapPlugin{}, false) RegisterSegmentPlugin(&zapv12.ZapPlugin{}, false) diff --git a/index/scorch/snapshot_index_vr.go b/index/scorch/snapshot_index_vr.go index 9c4de2560..04a9e0e6d 100644 --- a/index/scorch/snapshot_index_vr.go +++ b/index/scorch/snapshot_index_vr.go @@ -28,6 +28,8 @@ import ( segment_api "github.com/blevesearch/scorch_segment_api/v2" ) +const VectorSearchSupportedSegmentVersion = 16 + var reflectStaticSizeIndexSnapshotVectorReader int func init() { diff --git a/index/scorch/snapshot_segment.go b/index/scorch/snapshot_segment.go index 0b76ec746..1c14af726 100644 --- a/index/scorch/snapshot_segment.go +++ b/index/scorch/snapshot_segment.go @@ -39,6 +39,9 @@ type SegmentSnapshot struct { segment segment.Segment deleted *roaring.Bitmap creator string + stats *fieldStats + + cachedMeta *cachedMeta cachedDocs *cachedDocs } @@ -282,3 +285,30 @@ func (c *cachedDocs) visitDoc(localDocNum uint64, c.m.Unlock() } + +// the purpose of the cachedMeta is to simply allow the user of this type to record +// and cache certain metadata (specific to the segment) that can be reused +// across calls to avoid recomputing it. +// for example, searcher creations on the same index snapshot can use this struct +// to fetch the backing index size, which can be used in the memory usage +// calculation that decides whether to allow a query or not.
+type cachedMeta struct { + m sync.RWMutex + meta map[string]interface{} +} + +func (c *cachedMeta) updateMeta(field string, val interface{}) { + c.m.Lock() + if c.meta == nil { + c.meta = make(map[string]interface{}) + } + c.meta[field] = val + c.m.Unlock() +} + +func (c *cachedMeta) fetchMeta(field string) (rv interface{}) { + c.m.RLock() + rv = c.meta[field] + c.m.RUnlock() + return rv +} diff --git a/index/scorch/snapshot_vector_index.go b/index/scorch/snapshot_vector_index.go index 86aa6df54..9d6f0700e 100644 --- a/index/scorch/snapshot_vector_index.go +++ b/index/scorch/snapshot_vector_index.go @@ -42,16 +42,7 @@ func (is *IndexSnapshot) VectorReader(ctx context.Context, vector []float32, rv.iterators = make([]segment_api.VecPostingsIterator, len(is.segment)) } - for i, seg := range is.segment { - if sv, ok := seg.segment.(segment_api.VectorSegment); ok { - pl, err := sv.SimilarVectors(field, vector, k, seg.deleted) - if err != nil { - return nil, err - } - rv.postings[i] = pl - rv.iterators[i] = pl.Iterator(rv.iterators[i]) - } - } + // initialize postings and iterators within the OptimizeVR's Finish() return rv, nil } diff --git a/index/scorch/stats.go b/index/scorch/stats.go index dc74d9f29..269ae2f63 100644 --- a/index/scorch/stats.go +++ b/index/scorch/stats.go @@ -51,6 +51,8 @@ type Stats struct { TotTermSearchersStarted uint64 TotTermSearchersFinished uint64 + TotKNNSearches uint64 + TotEventTriggerStarted uint64 TotEventTriggerCompleted uint64 diff --git a/index_alias_impl.go b/index_alias_impl.go index ccb52f244..057d76b73 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -21,6 +21,8 @@ import ( "github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/search" + "github.com/blevesearch/bleve/v2/search/collector" + "github.com/blevesearch/bleve/v2/search/query" index "github.com/blevesearch/bleve_index_api" ) @@ -160,13 +162,88 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest if len(i.indexes) < 1 { return nil, ErrorAliasEmpty } + if _, ok := ctx.Value(search.PreSearchKey).(bool); ok { + // since presearchKey is set, it means that the request + // is being executed as part of a presearch, which + // indicates that this index alias is set as an Index + // in another alias, so we need to run the presearch + // phase here and NOT a real search + return preSearchDataSearch(ctx, req, i.indexes...) + } + + // at this point we know we are doing a real search + // either after a presearch is done, or directly + // on the alias + + // check if the request has preSearchData, which would indicate that the + // request has already been preSearched and the preSearch step + // can be skipped now; in that case, call an optional function to + // redistribute the preSearchData to the individual indexes + // if necessary + var preSearchData map[string]map[string]interface{} + if req.PreSearchData != nil { + if requestHasKNN(req) { + var err error + preSearchData, err = redistributeKNNPreSearchData(req, i.indexes) + if err != nil { + return nil, err + } + } + } // short circuit the simple case if len(i.indexes) == 1 { + if preSearchData != nil { + req.PreSearchData = preSearchData[i.indexes[0].Name()] + } return i.indexes[0].SearchInContext(ctx, req) } - return MultiSearch(ctx, req, i.indexes...)
+ // at this stage we know we have multiple indexes + // check if preSearchData needs to be gathered from all indexes + // before executing the query + var err error + // only perform presearch if + // - the request does not already have preSearchData + // - the request requires presearch + var preSearchDuration time.Duration + var sr *SearchResult + if req.PreSearchData == nil && preSearchRequired(req) { + searchStart := time.Now() + preSearchResult, err := preSearch(ctx, req, i.indexes...) + if err != nil { + return nil, err + } + // check if the presearch result has any errors and if so + // return the search result as is without executing the query + // so that the errors are not lost + if preSearchResult.Status.Failed > 0 { + return preSearchResult, nil + } + + // if there are no errors, then merge the data in the presearch result + preSearchResult = mergePreSearchResult(req, preSearchResult, i.indexes) + if requestSatisfiedByPreSearch(req) { + sr = finalizeSearchResult(req, preSearchResult) + // no need to run the 2nd phase MultiSearch(..) + } else { + preSearchData, err = constructPreSearchData(req, preSearchResult, i.indexes) + if err != nil { + return nil, err + } + } + preSearchDuration = time.Since(searchStart) + } + + // check if search result was generated as part of presearch itself + if sr == nil { + sr, err = MultiSearch(ctx, req, preSearchData, i.indexes...) + if err != nil { + return nil, err + } + } + sr.Took += preSearchDuration + return sr, nil } func (i *indexAliasImpl) Fields() ([]string, error) { @@ -429,8 +506,8 @@ func (i *indexAliasImpl) Swap(in, out []Index) { // the actual final results. // Perhaps that part needs to be optional, // could be slower in remote usages. -func createChildSearchRequest(req *SearchRequest) *SearchRequest { - return copySearchRequest(req) +func createChildSearchRequest(req *SearchRequest, preSearchData map[string]interface{}) *SearchRequest { + return copySearchRequest(req, preSearchData) } type asyncSearchResult struct { @@ -439,20 +516,109 @@ type asyncSearchResult struct { Err error } -// MultiSearch executes a SearchRequest across multiple Index objects, -// then merges the results. The indexes must honor any ctx deadline. -func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*SearchResult, error) { +func preSearchRequired(req *SearchRequest) bool { + return requestHasKNN(req) +} - searchStart := time.Now() - asyncResults := make(chan *asyncSearchResult, len(indexes)) +func preSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*SearchResult, error) { + // create a dummy request with a match none query + // since we only care about the preSearchData in PreSearch + dummyRequest := &SearchRequest{ + Query: query.NewMatchNoneQuery(), + } + newCtx := context.WithValue(ctx, search.PreSearchKey, true) + if requestHasKNN(req) { + addKnnToDummyRequest(dummyRequest, req) + } + return preSearchDataSearch(newCtx, dummyRequest, indexes...) 
+} + +func tagHitsWithIndexName(sr *SearchResult, indexName string) { + for _, hit := range sr.Hits { + hit.IndexNames = append(hit.IndexNames, indexName) + } +} +// if the request is satisfied by just the presearch result, +// finalize the result and return it directly without +// performing multi search +func finalizeSearchResult(req *SearchRequest, preSearchResult *SearchResult) *SearchResult { + if preSearchResult == nil { + return nil + } + + // global values across all hits irrespective of pagination settings + preSearchResult.Total = uint64(preSearchResult.Hits.Len()) + maxScore := float64(0) + for i, hit := range preSearchResult.Hits { + // since we are now using the presearch result as the final result + // we can discard the indexNames from the hits as they are no longer + // relevant. + hit.IndexNames = nil + if hit.Score > maxScore { + maxScore = hit.Score + } + hit.HitNumber = uint64(i) + } + preSearchResult.MaxScore = maxScore + // now apply pagination settings var reverseQueryExecution bool if req.SearchBefore != nil { reverseQueryExecution = true req.Sort.Reverse() req.SearchAfter = req.SearchBefore - req.SearchBefore = nil } + if req.SearchAfter != nil { + preSearchResult.Hits = collector.FilterHitsBySearchAfter(preSearchResult.Hits, req.Sort, req.SearchAfter) + } + preSearchResult.Hits = hitsInCurrentPage(req, preSearchResult.Hits) + if reverseQueryExecution { + // reverse the sort back to the original + req.Sort.Reverse() + // resort using the original order + mhs := newSearchHitSorter(req.Sort, preSearchResult.Hits) + req.SortFunc()(mhs) + req.SearchAfter = nil + } + + if req.Explain { + preSearchResult.Request = req + } + return preSearchResult +} + +func mergePreSearchResult(req *SearchRequest, res *SearchResult, + indexes []Index) *SearchResult { + if requestHasKNN(req) { + res.Hits = mergeKNNDocumentMatches(req, res.Hits) + } + return res +} + +func requestSatisfiedByPreSearch(req *SearchRequest) bool { + if requestHasKNN(req) && isKNNrequestSatisfiedByPreSearch(req) { + return true + } + return false +} + +func constructPreSearchData(req *SearchRequest, preSearchResult *SearchResult, indexes []Index) (map[string]map[string]interface{}, error) { + mergedOut := make(map[string]map[string]interface{}, len(indexes)) + for _, index := range indexes { + mergedOut[index.Name()] = make(map[string]interface{}) + } + var err error + if requestHasKNN(req) { + mergedOut, err = constructKnnPresearchData(mergedOut, preSearchResult, indexes) + if err != nil { + return nil, err + } + } + return mergedOut, nil +} + +func preSearchDataSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*SearchResult, error) { + asyncResults := make(chan *asyncSearchResult, len(indexes)) // run search on each index in separate go routine var waitGroup sync.WaitGroup @@ -466,7 +632,7 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se waitGroup.Add(len(indexes)) for _, in := range indexes { - go searchChildIndex(in, createChildSearchRequest(req)) + go searchChildIndex(in, createChildSearchRequest(req, nil)) } // on another go routine, close after finished @@ -483,8 +649,10 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se if sr == nil { // first result sr = asr.Result + tagHitsWithIndexName(sr, asr.Name) } else { // merge with previous + tagHitsWithIndexName(asr.Result, asr.Name) sr.Merge(asr.Result) } } else { @@ -504,25 +672,121 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se } } + // 
in presearch partial results are not allowed as it can lead to + // the real search giving incorrect results, and hence the search + // result is reset. + // discard partial hits if some child index has failed or + // if some child alias has returned partial results. + if len(indexErrors) > 0 || sr.Status.Failed > 0 { + sr = &SearchResult{ + Status: sr.Status, + } + if sr.Status.Errors == nil { + sr.Status.Errors = make(map[string]error) + } + for indexName, indexErr := range indexErrors { + sr.Status.Errors[indexName] = indexErr + sr.Status.Total++ + sr.Status.Failed++ + } + } + + return sr, nil +} + +// hitsInCurrentPage returns the hits in the current page +// using the From and Size parameters in the request +func hitsInCurrentPage(req *SearchRequest, hits []*search.DocumentMatch) []*search.DocumentMatch { sortFunc := req.SortFunc() // sort all hits with the requested order if len(req.Sort) > 0 { - sorter := newSearchHitSorter(req.Sort, sr.Hits) + sorter := newSearchHitSorter(req.Sort, hits) sortFunc(sorter) } - // now skip over the correct From - if req.From > 0 && len(sr.Hits) > req.From { - sr.Hits = sr.Hits[req.From:] + if req.From > 0 && len(hits) > req.From { + hits = hits[req.From:] } else if req.From > 0 { - sr.Hits = search.DocumentMatchCollection{} + hits = search.DocumentMatchCollection{} } - // now trim to the correct size - if req.Size > 0 && len(sr.Hits) > req.Size { - sr.Hits = sr.Hits[0:req.Size] + if req.Size > 0 && len(hits) > req.Size { + hits = hits[0:req.Size] + } + return hits +} + +// MultiSearch executes a SearchRequest across multiple Index objects, +// then merges the results. The indexes must honor any ctx deadline. +func MultiSearch(ctx context.Context, req *SearchRequest, preSearchData map[string]map[string]interface{}, indexes ...Index) (*SearchResult, error) { + + searchStart := time.Now() + asyncResults := make(chan *asyncSearchResult, len(indexes)) + + var reverseQueryExecution bool + if req.SearchBefore != nil { + reverseQueryExecution = true + req.Sort.Reverse() + req.SearchAfter = req.SearchBefore + req.SearchBefore = nil + } + + // run search on each index in separate go routine + var waitGroup sync.WaitGroup + + var searchChildIndex = func(in Index, childReq *SearchRequest) { + rv := asyncSearchResult{Name: in.Name()} + rv.Result, rv.Err = in.SearchInContext(ctx, childReq) + asyncResults <- &rv + waitGroup.Done() + } + + waitGroup.Add(len(indexes)) + for _, in := range indexes { + var payload map[string]interface{} + if preSearchData != nil { + payload = preSearchData[in.Name()] + } + go searchChildIndex(in, createChildSearchRequest(req, payload)) + } + + // on another go routine, close after finished + go func() { + waitGroup.Wait() + close(asyncResults) + }() + + var sr *SearchResult + indexErrors := make(map[string]error) + + for asr := range asyncResults { + if asr.Err == nil { + if sr == nil { + // first result + sr = asr.Result + } else { + // merge with previous + sr.Merge(asr.Result) + } + } else { + indexErrors[asr.Name] = asr.Err + } } + // merge just concatenated all the hits + // now lets clean it up + + // handle case where no results were successful + if sr == nil { + sr = &SearchResult{ + Status: &SearchStatus{ + Errors: make(map[string]error), + }, + } + } + + sr.Hits = hitsInCurrentPage(req, sr.Hits) + // fix up facets for name, fr := range req.Facets { sr.Facets.Fixup(name, fr.Size) @@ -533,14 +797,16 @@ func MultiSearch(ctx context.Context, req *SearchRequest, indexes ...Index) (*Se req.Sort.Reverse() // resort using the original 
order mhs := newSearchHitSorter(req.Sort, sr.Hits) - sortFunc(mhs) + req.SortFunc()(mhs) // reset request req.SearchBefore = req.SearchAfter req.SearchAfter = nil } // fix up original request - sr.Request = req + if req.Explain { + sr.Request = req + } searchDuration := time.Since(searchStart) sr.Took = searchDuration diff --git a/index_alias_impl_test.go b/index_alias_impl_test.go index 1b6ae55f4..623dbc623 100644 --- a/index_alias_impl_test.go +++ b/index_alias_impl_test.go @@ -494,8 +494,7 @@ func TestIndexAliasMulti(t *testing.T) { Successful: 2, Errors: make(map[string]error), }, - Request: sr, - Total: 2, + Total: 2, Hits: search.DocumentMatchCollection{ { ID: "b", @@ -575,8 +574,7 @@ func TestMultiSearchNoError(t *testing.T) { Successful: 2, Errors: make(map[string]error), }, - Request: sr, - Total: 2, + Total: 2, Hits: search.DocumentMatchCollection{ { Index: "2", @@ -594,7 +592,7 @@ func TestMultiSearchNoError(t *testing.T) { MaxScore: 2.0, } - results, err := MultiSearch(context.Background(), sr, ei1, ei2) + results, err := MultiSearch(context.Background(), sr, nil, ei1, ei2) if err != nil { t.Error(err) } @@ -625,7 +623,7 @@ func TestMultiSearchSomeError(t *testing.T) { }} ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")} sr := NewSearchRequest(NewTermQuery("test")) - res, err := MultiSearch(context.Background(), sr, ei1, ei2) + res, err := MultiSearch(context.Background(), sr, nil, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -652,7 +650,7 @@ func TestMultiSearchAllError(t *testing.T) { ei1 := &stubIndex{name: "ei1", err: fmt.Errorf("deliberate error")} ei2 := &stubIndex{name: "ei2", err: fmt.Errorf("deliberate error")} sr := NewSearchRequest(NewTermQuery("test")) - res, err := MultiSearch(context.Background(), sr, ei1, ei2) + res, err := MultiSearch(context.Background(), sr, nil, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -708,7 +706,7 @@ func TestMultiSearchSecondPage(t *testing.T) { checkRequest: checkRequest, } sr := NewSearchRequestOptions(NewTermQuery("test"), 10, 10, false) - _, err := MultiSearch(context.Background(), sr, ei1, ei2) + _, err := MultiSearch(context.Background(), sr, nil, ei1, ei2) if err != nil { t.Errorf("unexpected error %v", err) } @@ -786,7 +784,7 @@ func TestMultiSearchTimeout(t *testing.T) { defer cancel() query := NewTermQuery("test") sr := NewSearchRequest(query) - res, err := MultiSearch(ctx, sr, ei1, ei2) + res, err := MultiSearch(ctx, sr, nil, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -806,7 +804,7 @@ func TestMultiSearchTimeout(t *testing.T) { // now run a search again with an absurdly low timeout (should timeout) ctx, cancel = context.WithTimeout(context.Background(), 1*time.Microsecond) defer cancel() - res, err = MultiSearch(ctx, sr, ei1, ei2) + res, err = MultiSearch(ctx, sr, nil, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -833,7 +831,7 @@ func TestMultiSearchTimeout(t *testing.T) { // now run a search again with a normal timeout, but cancel it first ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) cancel() - res, err = MultiSearch(ctx, sr, ei1, ei2) + res, err = MultiSearch(ctx, sr, nil, ei1, ei2) if err != nil { t.Errorf("expected no error, got %v", err) } @@ -949,8 +947,7 @@ func TestMultiSearchTimeoutPartial(t *testing.T) { "ei3": context.DeadlineExceeded, }, }, - Request: sr, - Total: 2, + Total: 2, Hits: search.DocumentMatchCollection{ { Index: "2", @@ -968,7 +965,7 @@ func 
TestMultiSearchTimeoutPartial(t *testing.T) { MaxScore: 2.0, } - res, err := MultiSearch(ctx, sr, ei1, ei2, ei3) + res, err := MultiSearch(ctx, sr, nil, ei1, ei2, ei3) if err != nil { t.Fatalf("expected no err, got %v", err) } @@ -1105,8 +1102,7 @@ func TestIndexAliasMultipleLayer(t *testing.T) { "ei3": context.DeadlineExceeded, }, }, - Request: sr, - Total: 2, + Total: 2, Hits: search.DocumentMatchCollection{ { Index: "4", @@ -1184,6 +1180,7 @@ func TestMultiSearchCustomSort(t *testing.T) { }} sr := NewSearchRequest(NewTermQuery("test")) + sr.Explain = true sr.SortBy([]string{"name"}) expected := &SearchResult{ Status: &SearchStatus{ @@ -1222,7 +1219,7 @@ func TestMultiSearchCustomSort(t *testing.T) { MaxScore: 3.0, } - results, err := MultiSearch(context.Background(), sr, ei1, ei2) + results, err := MultiSearch(context.Background(), sr, nil, ei1, ei2) if err != nil { t.Error(err) } diff --git a/index_impl.go b/index_impl.go index 5c9538822..5b407154b 100644 --- a/index_impl.go +++ b/index_impl.go @@ -433,6 +433,25 @@ func memNeededForSearch(req *SearchRequest, return uint64(estimate) } +func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader index.IndexReader) (*SearchResult, error) { + var knnHits []*search.DocumentMatch + var err error + if requestHasKNN(req) { + knnHits, err = i.runKnnCollector(ctx, req, reader, true) + if err != nil { + return nil, err + } + } + + return &SearchResult{ + Status: &SearchStatus{ + Total: 1, + Successful: 1, + }, + Hits: knnHits, + }, nil +} + // SearchInContext executes a search request operation within the provided // Context. Returns a SearchResult object or an error. func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr *SearchResult, err error) { @@ -445,6 +464,25 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr return nil, ErrorIndexClosed } + // open a reader for this search + indexReader, err := i.i.Reader() + if err != nil { + return nil, fmt.Errorf("error opening index reader %v", err) + } + defer func() { + if cerr := indexReader.Close(); err == nil && cerr != nil { + err = cerr + } + }() + + if _, ok := ctx.Value(search.PreSearchKey).(bool); ok { + preSearchResult, err := i.preSearch(ctx, req, indexReader) + if err != nil { + return nil, err + } + return preSearchResult, nil + } + var reverseQueryExecution bool if req.SearchBefore != nil { reverseQueryExecution = true @@ -460,16 +498,31 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr coll = collector.NewTopNCollector(req.Size, req.From, req.Sort) } - // open a reader for this search - indexReader, err := i.i.Reader() - if err != nil { - return nil, fmt.Errorf("error opening index reader %v", err) + var knnHits []*search.DocumentMatch + var ok bool + var skipKnnCollector bool + if req.PreSearchData != nil { + for k, v := range req.PreSearchData { + switch k { + case search.KnnPreSearchDataKey: + if v != nil { + knnHits, ok = v.([]*search.DocumentMatch) + if !ok { + return nil, fmt.Errorf("knn preSearchData must be of type []*search.DocumentMatch") + } + } + skipKnnCollector = true + } + } } - defer func() { - if cerr := indexReader.Close(); err == nil && cerr != nil { - err = cerr + if !skipKnnCollector && requestHasKNN(req) { + knnHits, err = i.runKnnCollector(ctx, req, indexReader, false) + if err != nil { + return nil, err } - }() + } + + setKnnHitsInCollector(knnHits, req, coll) // This callback and variable handles the tracking of bytes read // 1. 
as part of creation of tfr and its Next() calls which is @@ -496,11 +549,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey, search.GeoBufferPoolCallbackFunc(getBufferPool)) - // Using a disjunction query to get union of results from KNN query - // and the original query - searchQuery := disjunctQueryWithKNN(req) - searcher, err := searchQuery.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ + searcher, err := req.Query.Searcher(ctx, indexReader, i.m, search.SearcherOptions{ Explain: req.Explain, IncludeTermVectors: req.IncludeLocations || req.Highlight != nil, Score: req.Score, @@ -544,14 +594,14 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if dateTimeParser == nil { return nil, fmt.Errorf("no date time parser named `%s` registered", dateTimeParserName) } - start, end, startLayout, endLayout, err := dr.ParseDates(dateTimeParser) + start, end, err := dr.ParseDates(dateTimeParser) if err != nil { return nil, fmt.Errorf("ParseDates err: %v, using date time parser named %s", err, dateTimeParserName) } if start.IsZero() && end.IsZero() { return nil, fmt.Errorf("date range query must specify either start, end or both for date range name '%s'", dr.Name) } - facetBuilder.AddRange(dr.Name, start, end, startLayout, endLayout) + facetBuilder.AddRange(dr.Name, start, end) } facetsBuilder.Add(facetName, facetBuilder) } else { @@ -609,7 +659,9 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var storedFieldsCost uint64 for _, hit := range hits { - if i.name != "" { + // KNN documents will already have their Index value set as part of the knn collector output + // so check if the index is empty and set it to the current index name + if i.name != "" && hit.Index == "" { hit.Index = i.name } err, storedFieldsBytes := LoadAndHighlightFields(hit, req, i.name, indexReader, highlighter) @@ -642,18 +694,23 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr req.SearchAfter = nil } - return &SearchResult{ + rv := &SearchResult{ Status: &SearchStatus{ Total: 1, Successful: 1, }, - Request: req, Hits: hits, Total: coll.Total(), MaxScore: coll.MaxScore(), Took: searchDuration, Facets: coll.FacetResults(), - }, nil + } + + if req.Explain { + rv.Request = req + } + + return rv, nil } func LoadAndHighlightFields(hit *search.DocumentMatch, req *SearchRequest, @@ -663,7 +720,7 @@ func LoadAndHighlightFields(hit *search.DocumentMatch, req *SearchRequest, if len(req.Fields) > 0 || highlighter != nil { doc, err := r.Document(hit.ID) if err == nil && doc != nil { - if len(req.Fields) > 0 { + if len(req.Fields) > 0 && hit.Fields == nil { totalStoredFieldsBytes = doc.StoredFieldsBytes() fieldsToLoad := deDuplicate(req.Fields) for _, f := range fieldsToLoad { diff --git a/index_test.go b/index_test.go index f62b34669..c6cc0793e 100644 --- a/index_test.go +++ b/index_test.go @@ -304,7 +304,7 @@ func TestBytesWritten(t *testing.T) { typeFieldMapping.DocValues = false documentMapping.AddFieldMappingsAt("type", typeFieldMapping) - err = checkStatsOnIndexedBatch(tmpIndexPath, indexMapping, 37767) + err = checkStatsOnIndexedBatch(tmpIndexPath, indexMapping, 57273) if err != nil { t.Fatal(err) } @@ -313,7 +313,7 @@ func TestBytesWritten(t *testing.T) { contentFieldMapping.Store = true tmpIndexPath1 := createTmpIndexPath(t) - err := checkStatsOnIndexedBatch(tmpIndexPath1, indexMapping, 56582) + err := 
checkStatsOnIndexedBatch(tmpIndexPath1, indexMapping, 76069) if err != nil { t.Fatal(err) } @@ -323,7 +323,7 @@ func TestBytesWritten(t *testing.T) { contentFieldMapping.IncludeInAll = true tmpIndexPath2 := createTmpIndexPath(t) - err = checkStatsOnIndexedBatch(tmpIndexPath2, indexMapping, 44714) + err = checkStatsOnIndexedBatch(tmpIndexPath2, indexMapping, 68875) if err != nil { t.Fatal(err) } @@ -333,7 +333,7 @@ func TestBytesWritten(t *testing.T) { contentFieldMapping.IncludeTermVectors = true tmpIndexPath3 := createTmpIndexPath(t) - err = checkStatsOnIndexedBatch(tmpIndexPath3, indexMapping, 59479) + err = checkStatsOnIndexedBatch(tmpIndexPath3, indexMapping, 78985) if err != nil { t.Fatal(err) } @@ -343,7 +343,7 @@ func TestBytesWritten(t *testing.T) { contentFieldMapping.DocValues = true tmpIndexPath4 := createTmpIndexPath(t) - err = checkStatsOnIndexedBatch(tmpIndexPath4, indexMapping, 44722) + err = checkStatsOnIndexedBatch(tmpIndexPath4, indexMapping, 64228) if err != nil { t.Fatal(err) } @@ -401,8 +401,8 @@ func TestBytesRead(t *testing.T) { } stats, _ := idx.StatsMap()["index"].(map[string]interface{}) prevBytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64) - if prevBytesRead != 36066 && res.Cost == prevBytesRead { - t.Fatalf("expected bytes read for query string 32349, got %v", + if prevBytesRead != 21639 && res.Cost == prevBytesRead { + t.Fatalf("expected bytes read for query string 21639, got %v", prevBytesRead) } @@ -504,33 +504,6 @@ func TestBytesRead(t *testing.T) { } } -func getBatchFromData(idx Index, fileName string) (*Batch, error) { - pwd, err := os.Getwd() - if err != nil { - return nil, err - } - path := filepath.Join(pwd, "data", "test", fileName) - batch := idx.NewBatch() - var dataset []map[string]interface{} - fileContent, err := os.ReadFile(path) - if err != nil { - return nil, err - } - err = json.Unmarshal(fileContent, &dataset) - if err != nil { - return nil, err - } - - for _, doc := range dataset { - err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc) - if err != nil { - return nil, err - } - } - - return batch, err -} - func TestBytesReadStored(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath) @@ -580,8 +553,8 @@ func TestBytesReadStored(t *testing.T) { stats, _ := idx.StatsMap()["index"].(map[string]interface{}) bytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64) - if bytesRead != 25928 && bytesRead == res.Cost { - t.Fatalf("expected the bytes read stat to be around 25928, got %v", bytesRead) + if bytesRead != 11501 && bytesRead == res.Cost { + t.Fatalf("expected the bytes read stat to be around 11501, got %v", bytesRead) } prevBytesRead := bytesRead @@ -651,8 +624,8 @@ func TestBytesReadStored(t *testing.T) { stats, _ = idx1.StatsMap()["index"].(map[string]interface{}) bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64) - if bytesRead != 18114 && bytesRead == res.Cost { - t.Fatalf("expected the bytes read stat to be around 18114, got %v", bytesRead) + if bytesRead != 3687 && bytesRead == res.Cost { + t.Fatalf("expected the bytes read stat to be around 3687, got %v", bytesRead) } prevBytesRead = bytesRead @@ -680,6 +653,33 @@ func TestBytesReadStored(t *testing.T) { } } +func getBatchFromData(idx Index, fileName string) (*Batch, error) { + pwd, err := os.Getwd() + if err != nil { + return nil, err + } + path := filepath.Join(pwd, "data", "test", fileName) + batch := idx.NewBatch() + var dataset []map[string]interface{} + fileContent, err := os.ReadFile(path) + if err != nil { 
+ return nil, err + } + err = json.Unmarshal(fileContent, &dataset) + if err != nil { + return nil, err + } + + for _, doc := range dataset { + err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc) + if err != nil { + return nil, err + } + } + + return batch, err +} + func TestIndexCreateNewOverExisting(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) diff --git a/mapping/document.go b/mapping/document.go index 9f5aea581..73bb124db 100644 --- a/mapping/document.go +++ b/mapping/document.go @@ -50,7 +50,8 @@ type DocumentMapping struct { StructTagKey string `json:"struct_tag_key,omitempty"` } -func (dm *DocumentMapping) Validate(cache *registry.Cache) error { +func (dm *DocumentMapping) Validate(cache *registry.Cache, + parentName string, fieldAliasCtx map[string]*FieldMapping) error { var err error if dm.DefaultAnalyzer != "" { _, err := cache.AnalyzerNamed(dm.DefaultAnalyzer) @@ -58,8 +59,12 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error { return err } } - for _, property := range dm.Properties { - err = property.Validate(cache) + for propertyName, property := range dm.Properties { + newParent := propertyName + if parentName != "" { + newParent = fmt.Sprintf("%s.%s", parentName, propertyName) + } + err = property.Validate(cache, newParent, fieldAliasCtx) if err != nil { return err } @@ -78,21 +83,24 @@ func (dm *DocumentMapping) Validate(cache *registry.Cache) error { } } - err := validateFieldType(field.Type) + err := validateFieldMapping(field, parentName, fieldAliasCtx) if err != nil { return err } - - if field.Type == "vector" { - err := validateVectorField(field) - if err != nil { - return err - } - } } return nil } +func validateFieldType(field *FieldMapping) error { + switch field.Type { + case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": + return nil + default: + return fmt.Errorf("field: '%s', unknown field type: '%s'", + field.Name, field.Type) + } +} + // analyzerNameForPath attempts to first find the field // described by this path, then returns the analyzer // configured for that field @@ -148,15 +156,20 @@ func (dm *DocumentMapping) fieldDescribedByPath(path string) *FieldMapping { return nil } -// documentMappingForPath returns the EXACT and closest matches for a sub +// documentMappingForPathElements returns the EXACT and closest matches for a sub // document or for an explicitly mapped field; the closest most specific // document mapping could be one that matches part of the provided path. 
-func (dm *DocumentMapping) documentMappingForPath(path string) ( +func (dm *DocumentMapping) documentMappingForPathElements(pathElements []string) ( *DocumentMapping, *DocumentMapping) { - pathElements := decodePath(path) + var pathElementsCopy []string + if len(pathElements) == 0 { + pathElementsCopy = []string{""} + } else { + pathElementsCopy = pathElements + } current := dm OUTER: - for i, pathElement := range pathElements { + for i, pathElement := range pathElementsCopy { if subDocMapping, exists := current.Properties[pathElement]; exists { current = subDocMapping continue OUTER @@ -164,7 +177,7 @@ OUTER: // no subDocMapping matches this pathElement // only if this is the last element check for field name - if i == len(pathElements)-1 { + if i == len(pathElementsCopy)-1 { for _, field := range current.Fields { if field.Name == pathElement { break @@ -177,6 +190,15 @@ OUTER: return current, current } +// documentMappingForPath returns the EXACT and closest matches for a sub +// document or for an explicitly mapped field; the closest most specific +// document mapping could be one that matches part of the provided path. +func (dm *DocumentMapping) documentMappingForPath(path string) ( + *DocumentMapping, *DocumentMapping) { + pathElements := decodePath(path) + return dm.documentMappingForPathElements(pathElements) +} + // NewDocumentMapping returns a new document mapping // with all the default values. func NewDocumentMapping() *DocumentMapping { @@ -395,9 +417,8 @@ func (dm *DocumentMapping) walkDocument(data interface{}, path []string, indexes } func (dm *DocumentMapping) processProperty(property interface{}, path []string, indexes []uint64, context *walkContext) { - pathString := encodePath(path) // look to see if there is a mapping for this field - subDocMapping, closestDocMapping := dm.documentMappingForPath(pathString) + subDocMapping, closestDocMapping := dm.documentMappingForPathElements(path) // check to see if we even need to do further processing if subDocMapping != nil && !subDocMapping.Enabled { @@ -409,6 +430,8 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, // cannot do anything with the zero value return } + + pathString := encodePath(path) propertyType := propertyValue.Type() switch propertyType.Kind() { case reflect.String: @@ -509,12 +532,20 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, dm.walkDocument(property, path, indexes, context) } case reflect.Map, reflect.Slice: + var isPropertyVector bool + var isPropertyVectorInitialized bool if subDocMapping != nil { for _, fieldMapping := range subDocMapping.Fields { switch fieldMapping.Type { case "vector": - fieldMapping.processVector(property, pathString, path, + processed := fieldMapping.processVector(property, pathString, path, indexes, context) + if !isPropertyVectorInitialized { + isPropertyVector = processed + isPropertyVectorInitialized = true + } else { + isPropertyVector = isPropertyVector && processed + } case "geopoint": fieldMapping.processGeoPoint(property, pathString, path, indexes, context) case "IP": @@ -527,7 +558,9 @@ func (dm *DocumentMapping) processProperty(property interface{}, path []string, } } } - dm.walkDocument(property, path, indexes, context) + if !isPropertyVector { + dm.walkDocument(property, path, indexes, context) + } case reflect.Ptr: if !propertyValue.IsNil() { switch property := property.(type) { diff --git a/mapping/field.go b/mapping/field.go index 41aeb1512..f4339b384 100644 --- a/mapping/field.go +++ 
b/mapping/field.go @@ -75,8 +75,11 @@ type FieldMapping struct { // Similarity is the similarity algorithm used for scoring // vector fields. - // See: util.DefaultSimilarityMetric & util.SupportedSimilarityMetrics + // See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics Similarity string `json:"similarity,omitempty"` + + // Applicable to vector fields only - optimization string + VectorIndexOptimizedFor string `json:"vector_index_optimized_for,omitempty"` } // NewTextFieldMapping returns a default field mapping for text @@ -466,6 +469,11 @@ func (fm *FieldMapping) UnmarshalJSON(data []byte) error { if err != nil { return err } + case "vector_index_optimized_for": + err := json.Unmarshal(v, &fm.VectorIndexOptimizedFor) + if err != nil { + return err + } default: invalidKeys = append(invalidKeys, k) } diff --git a/mapping/index.go b/mapping/index.go index 1c08bc589..171ee1a72 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -174,12 +174,14 @@ func (im *IndexMappingImpl) Validate() error { if err != nil { return err } - err = im.DefaultMapping.Validate(im.cache) + + fieldAliasCtx := make(map[string]*FieldMapping) + err = im.DefaultMapping.Validate(im.cache, "", fieldAliasCtx) if err != nil { return err } for _, docMapping := range im.TypeMapping { - err = docMapping.Validate(im.cache) + err = docMapping.Validate(im.cache, "", fieldAliasCtx) if err != nil { return err } diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go index f4987596a..f9f35f57c 100644 --- a/mapping/mapping_no_vectors.go +++ b/mapping/mapping_no_vectors.go @@ -17,30 +17,19 @@ package mapping -import "fmt" - func NewVectorFieldMapping() *FieldMapping { return nil } func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, - pathString string, path []string, indexes []uint64, context *walkContext) { - + pathString string, path []string, indexes []uint64, context *walkContext) bool { + return false } // ----------------------------------------------------------------------------- // document validation functions -func validateVectorField(fieldMapping *FieldMapping) error { - return nil -} - -func validateFieldType(fieldType string) error { - switch fieldType { - case "text", "datetime", "number", "boolean", "geopoint", "geoshape", "IP": - default: - return fmt.Errorf("unknown field type: '%s'", fieldType) - } - - return nil +func validateFieldMapping(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + return validateFieldType(field) } diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index a39820d96..a0b712608 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -23,6 +23,13 @@ import ( "github.com/blevesearch/bleve/v2/document" "github.com/blevesearch/bleve/v2/util" + index "github.com/blevesearch/bleve_index_api" +) + +// Min and Max allowed dimensions for a vector field +const ( + MinVectorDims = 1 + MaxVectorDims = 2048 ) func NewVectorFieldMapping() *FieldMapping { @@ -36,58 +43,133 @@ func NewVectorFieldMapping() *FieldMapping { } } -func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, - pathString string, path []string, indexes []uint64, context *walkContext) { - propertyVal := reflect.ValueOf(propertyMightBeVector) - if !propertyVal.IsValid() { - return - } - - // Validating the length of the vector is required here, in order to - // help zapx in deciding the shape of the batch of vectors to be indexed. 
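For orientation, here is a minimal sketch of how a vector field could be declared and validated using the mapping options introduced above. It assumes bleve is built with the `vectors` tag (the no-vectors stub of NewVectorFieldMapping returns nil), and the field name and optimization string are purely illustrative:

```go
package main

import (
	"fmt"

	"github.com/blevesearch/bleve/v2/mapping"
)

func main() {
	// Declare a 3-dimensional vector field. Dims must fall within
	// [MinVectorDims, MaxVectorDims]; an empty similarity or an unsupported
	// optimization string is replaced with the defaults during validation.
	vecField := mapping.NewVectorFieldMapping()
	vecField.Dims = 3
	vecField.Similarity = "l2_norm"
	vecField.VectorIndexOptimizedFor = "recall" // assumed example value

	im := mapping.NewIndexMapping()
	im.DefaultMapping.AddFieldMappingsAt("cityVec", vecField)

	// Validate threads a shared field-alias context through all document
	// mappings and reports dimension, similarity and alias conflicts.
	if err := im.Validate(); err != nil {
		fmt.Println("invalid mapping:", err)
	}
}
```

At indexing time, a value for such a field may be either a flat slice of dims numbers or a nested slice ([][]float32, nesting depth at most 2), which processVector flattens into dims-sized chunks.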
- if propertyVal.Kind() == reflect.Slice && propertyVal.Len() == fm.Dims { - vector := make([]float32, propertyVal.Len()) - isVectorValid := true - for i := 0; i < propertyVal.Len(); i++ { - item := propertyVal.Index(i) - if item.CanInterface() { - itemVal := item.Interface() - itemFloat, ok := util.ExtractNumericValFloat32(itemVal) - if !ok { - isVectorValid = false - break - } - vector[i] = itemFloat - } +// validate and process a flat vector +func processFlatVector(vecV reflect.Value, dims int) ([]float32, bool) { + if vecV.Len() != dims { + return nil, false + } + + rv := make([]float32, dims) + for i := 0; i < vecV.Len(); i++ { + item := vecV.Index(i) + if !item.CanInterface() { + return nil, false + } + itemI := item.Interface() + itemFloat, ok := util.ExtractNumericValFloat32(itemI) + if !ok { + return nil, false } - // Even if one of the vector elements is not a float32, we do not index - // this field and return silently - if !isVectorValid { - return + rv[i] = itemFloat + } + + return rv, true +} + +// validate and process a vector +// max supported depth of nesting is 2 ([][]float32) +func processVector(vecI interface{}, dims int) ([]float32, bool) { + vecV := reflect.ValueOf(vecI) + if !vecV.IsValid() || vecV.Kind() != reflect.Slice || vecV.Len() == 0 { + return nil, false + } + + // Let's examine the first element (head) of the vector. + // If head is a slice, then vector is nested, otherwise flat. + head := vecV.Index(0) + if !head.CanInterface() { + return nil, false + } + headI := head.Interface() + headV := reflect.ValueOf(headI) + if !headV.IsValid() { + return nil, false + } + if headV.Kind() != reflect.Slice { // vector is flat + return processFlatVector(vecV, dims) + } + + // # process nested vector + + // pre-allocate memory for the flattened vector + // so that we can use copy() later + rv := make([]float32, dims*vecV.Len()) + + for i := 0; i < vecV.Len(); i++ { + subVec := vecV.Index(i) + if !subVec.CanInterface() { + return nil, false + } + subVecI := subVec.Interface() + subVecV := reflect.ValueOf(subVecI) + if !subVecV.IsValid() { + return nil, false + } + + if subVecV.Kind() != reflect.Slice { + return nil, false + } + + flatVector, ok := processFlatVector(subVecV, dims) + if !ok { + return nil, false } - fieldName := getFieldName(pathString, path, fm) - options := fm.Options() - field := document.NewVectorFieldWithIndexingOptions(fieldName, - indexes, vector, fm.Dims, fm.Similarity, options) - context.doc.AddField(field) + copy(rv[i*dims:(i+1)*dims], flatVector) + } + + return rv, true +} - // "_all" composite field is not applicable for vector field - context.excludedFromAll = append(context.excludedFromAll, fieldName) +func (fm *FieldMapping) processVector(propertyMightBeVector interface{}, + pathString string, path []string, indexes []uint64, context *walkContext) bool { + vector, ok := processVector(propertyMightBeVector, fm.Dims) + // Don't add field to document if vector is invalid + if !ok { + return false } + + fieldName := getFieldName(pathString, path, fm) + options := fm.Options() + field := document.NewVectorFieldWithIndexingOptions(fieldName, indexes, vector, + fm.Dims, fm.Similarity, fm.VectorIndexOptimizedFor, options) + context.doc.AddField(field) + + // "_all" composite field is not applicable for vector field + context.excludedFromAll = append(context.excludedFromAll, fieldName) + return true } // ----------------------------------------------------------------------------- // document validation functions -func validateVectorField(field 
*FieldMapping) error { - if field.Dims <= 0 || field.Dims > 2048 { - return fmt.Errorf("invalid vector dimension,"+ - " value should be in range (%d, %d)", 0, 2048) +func validateFieldMapping(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + switch field.Type { + case "vector": + return validateVectorFieldAlias(field, parentName, fieldAliasCtx) + default: // non-vector field + return validateFieldType(field) + } +} + +func validateVectorFieldAlias(field *FieldMapping, parentName string, + fieldAliasCtx map[string]*FieldMapping) error { + + if field.Name == "" { + field.Name = parentName } if field.Similarity == "" { - field.Similarity = util.DefaultSimilarityMetric + field.Similarity = index.DefaultSimilarityMetric + } + + if field.VectorIndexOptimizedFor == "" { + field.VectorIndexOptimizedFor = index.DefaultIndexOptimization + } + if _, exists := index.SupportedVectorIndexOptimizations[field.VectorIndexOptimizedFor]; !exists { + // if an unsupported config is provided, override to default + field.VectorIndexOptimizedFor = index.DefaultIndexOptimization } // following fields are not applicable for vector @@ -98,21 +180,40 @@ func validateVectorField(field *FieldMapping) error { field.DocValues = false field.SkipFreqNorm = true - if _, ok := util.SupportedSimilarityMetrics[field.Similarity]; !ok { - return fmt.Errorf("invalid similarity metric: '%s', "+ - "valid metrics are: %+v", field.Similarity, - reflect.ValueOf(util.SupportedSimilarityMetrics).MapKeys()) + // # If alias is present, validate the field options as per the alias + // note: reading from a nil map is safe + if fieldAlias, ok := fieldAliasCtx[field.Name]; ok { + if field.Dims != fieldAlias.Dims { + return fmt.Errorf("field: '%s', invalid alias "+ + "(different dimensions %d and %d)", fieldAlias.Name, field.Dims, + fieldAlias.Dims) + } + + if field.Similarity != fieldAlias.Similarity { + return fmt.Errorf("field: '%s', invalid alias "+ + "(different similarity values %s and %s)", fieldAlias.Name, + field.Similarity, fieldAlias.Similarity) + } + + return nil } - return nil -} + // # Validate field options + + if field.Dims < MinVectorDims || field.Dims > MaxVectorDims { + return fmt.Errorf("field: '%s', invalid vector dimension: %d,"+ + " value should be in range (%d, %d)", field.Name, field.Dims, + MinVectorDims, MaxVectorDims) + } + + if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok { + return fmt.Errorf("field: '%s', invalid similarity "+ + "metric: '%s', valid metrics are: %+v", field.Name, field.Similarity, + reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys()) + } -func validateFieldType(fieldType string) error { - switch fieldType { - case "text", "datetime", "number", "boolean", "geopoint", "geoshape", - "IP", "vector": - default: - return fmt.Errorf("unknown field type: '%s'", fieldType) + if fieldAliasCtx != nil { // writing to a nil map is unsafe + fieldAliasCtx[field.Name] = field } return nil diff --git a/mapping/mapping_vectors_test.go b/mapping/mapping_vectors_test.go new file mode 100644 index 000000000..8c04d03ac --- /dev/null +++ b/mapping/mapping_vectors_test.go @@ -0,0 +1,297 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package mapping + +import "testing" + +func TestVectorFieldAliasValidation(t *testing.T) { + tests := []struct { + // input + name string // name of the test + mappingStr string // index mapping json string + + // expected output + expValidity bool // validity of the mapping + errMsg string // error message, given expValidity is false + }{ + { + name: "test1", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3 + }, + { + "name": "cityVec", + "type": "vector", + "dims": 4 + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'cityVec', invalid alias (different dimensions 4 and 3)`, + }, + { + name: "test2", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3, + "similarity": "l2_norm" + }, + { + "name": "cityVec", + "type": "vector", + "dims": 3, + "similarity": "dot_product" + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'cityVec', invalid alias (different similarity values dot_product and l2_norm)`, + }, + { + name: "test3", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "type": "vector", + "dims": 3 + }, + { + "name": "cityVec", + "type": "vector", + "dims": 3 + } + ] + } + } + } + }`, + expValidity: true, + errMsg: "", + }, + { + name: "test4", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 4 + } + ] + }, + "countryVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 3 + } + ] + } + } + } + }`, + expValidity: false, + errMsg: `field: 'vecData', invalid alias (different dimensions 3 and 4)`, + }, + { + name: "test5", + mappingStr: ` + { + "default_mapping": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 3 + } + ] + } + } + }, + "types": { + "type1": { + "properties": { + "cityVec": { + "fields": [ + { + "name": "vecData", + "type": "vector", + "dims": 4 + } + ] + } + } + } + } + }`, + expValidity: false, + errMsg: `field: 'vecData', invalid alias (different dimensions 4 and 3)`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + im := NewIndexMapping() + err := im.UnmarshalJSON([]byte(test.mappingStr)) + if err != nil { + t.Fatalf("failed to unmarshal index mapping: %v", err) + } + + err = im.Validate() + isValid := err == nil + if test.expValidity != isValid { + t.Fatalf("validity mismatch, expected: %v, got: %v", + test.expValidity, isValid) + } + + if !isValid && err.Error() != test.errMsg { + t.Fatalf("invalid error message, expected: %v, got: %v", + test.errMsg, err.Error()) + } + }) + } +} + +// A test case for processVector function +type vectorTest struct { + // Input + + ipVec interface{} // input vector + dims int // dimensionality of input vector + + // Expected Output + + expValidity bool // expected validity of the input + expOpVec []float32 // expected output vector, given the 
input is valid +} + +func TestProcessVector(t *testing.T) { + // Note: while creating vectors, we are using []any instead of []float32, + // this is done to enhance our test coverage. + // When we unmarshal a vector from a JSON, we get []any, not []float32. + tests := []vectorTest{ + // # Flat vectors + + // ## numeric cases + // (all numeric elements) + {[]any{1, 2.2, 3}, 3, true, []float32{1, 2.2, 3}}, // len==dims + {[]any{1, 2.2, 3}, 2, false, nil}, // len>dims + {[]any{1, 2.2, 3}, 4, false, nil}, // lendims + {[]any{[]any{1, 2, 3}}, 2, false, nil}, // len 0 { - if sr.Request.Size > 0 { + if sr.Request != nil && sr.Request.Size > 0 { rv = fmt.Sprintf("%d matches, showing %d through %d, took %s\n", sr.Total, sr.Request.From+1, sr.Request.From+len(sr.Hits), sr.Took) for i, hit := range sr.Hits { rv += fmt.Sprintf("%5d. %s (%f)\n", i+sr.Request.From+1, hit.ID, hit.Score) diff --git a/search/collector.go b/search/collector.go index 38e34fe7c..e81219e54 100644 --- a/search/collector.go +++ b/search/collector.go @@ -44,9 +44,15 @@ type MakeDocumentMatchHandlerKeyType string var MakeDocumentMatchHandlerKey = MakeDocumentMatchHandlerKeyType( "MakeDocumentMatchHandlerKey") +var MakeKNNDocumentMatchHandlerKey = MakeDocumentMatchHandlerKeyType( + "MakeKNNDocumentMatchHandlerKey") + // MakeDocumentMatchHandler is an optional DocumentMatchHandler // builder function which the applications can pass to bleve. // These builder methods gives a DocumentMatchHandler function // to bleve, which it will invoke on every document matches. type MakeDocumentMatchHandler func(ctx *SearchContext) ( callback DocumentMatchHandler, loadID bool, err error) + +type MakeKNNDocumentMatchHandler func(ctx *SearchContext) ( + callback DocumentMatchHandler, err error) diff --git a/search/collector/heap.go b/search/collector/heap.go index 9503f0060..cd662bcf9 100644 --- a/search/collector/heap.go +++ b/search/collector/heap.go @@ -69,6 +69,10 @@ func (c *collectStoreHeap) Final(skip int, fixup collectorFixup) (search.Documen return rv, nil } +func (c *collectStoreHeap) Internal() search.DocumentMatchCollection { + return c.heap +} + // heap interface implementation func (c *collectStoreHeap) Len() int { diff --git a/search/collector/knn.go b/search/collector/knn.go new file mode 100644 index 000000000..465bf6927 --- /dev/null +++ b/search/collector/knn.go @@ -0,0 +1,262 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build vectors +// +build vectors + +package collector + +import ( + "context" + "time" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +type collectStoreKNN struct { + internalHeaps []collectorStore + kValues []int64 + allHits map[*search.DocumentMatch]struct{} + ejectedDocs map[*search.DocumentMatch]struct{} +} + +func newStoreKNN(internalHeaps []collectorStore, kValues []int64) *collectStoreKNN { + return &collectStoreKNN{ + internalHeaps: internalHeaps, + kValues: kValues, + ejectedDocs: make(map[*search.DocumentMatch]struct{}), + allHits: make(map[*search.DocumentMatch]struct{}), + } +} + +// Adds a document to the collector store and returns the documents that were ejected +// from the store. The documents that were ejected from the store are the ones that +// were not in the top K documents for any of the heaps. +// These document are put back into the pool document match pool in the KNN Collector. +func (c *collectStoreKNN) AddDocument(doc *search.DocumentMatch) []*search.DocumentMatch { + for heapIdx := 0; heapIdx < len(c.internalHeaps); heapIdx++ { + if _, ok := doc.ScoreBreakdown[heapIdx]; !ok { + continue + } + ejectedDoc := c.internalHeaps[heapIdx].AddNotExceedingSize(doc, int(c.kValues[heapIdx])) + if ejectedDoc != nil { + delete(ejectedDoc.ScoreBreakdown, heapIdx) + c.ejectedDocs[ejectedDoc] = struct{}{} + } + } + var rv []*search.DocumentMatch + for doc := range c.ejectedDocs { + if len(doc.ScoreBreakdown) == 0 { + rv = append(rv, doc) + } + // clear out the ejectedDocs map to reuse it in the next AddDocument call + delete(c.ejectedDocs, doc) + } + return rv +} + +func (c *collectStoreKNN) Final(fixup collectorFixup) (search.DocumentMatchCollection, error) { + for _, heap := range c.internalHeaps { + for _, doc := range heap.Internal() { + // duplicates may be present across the internal heaps + // meaning the same document match may be in the top K + // for multiple KNN queries. 
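To make the ejection rule in AddDocument concrete, here is a toy, self-contained sketch of the bookkeeping; a plain map stands in for a hit's per-KNN score breakdown:

```go
package main

import "fmt"

func main() {
	// A hit currently held by per-KNN heaps 0 and 1. It may be recycled into
	// the DocumentMatchPool only once every heap that held it has ejected it,
	// i.e. once its score breakdown has no entries left.
	breakdown := map[int]float64{0: 1.2, 1: 0.7}

	delete(breakdown, 0)             // ejected from heap 0
	fmt.Println(len(breakdown) == 0) // false: heap 1 still references it
	delete(breakdown, 1)             // ejected from heap 1 as well
	fmt.Println(len(breakdown) == 0) // true: safe to release the match
}
```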
+ c.allHits[doc] = struct{}{} + } + } + size := len(c.allHits) + if size <= 0 { + return make(search.DocumentMatchCollection, 0), nil + } + rv := make(search.DocumentMatchCollection, size) + i := 0 + for doc := range c.allHits { + if fixup != nil { + err := fixup(doc) + if err != nil { + return nil, err + } + } + rv[i] = doc + i++ + } + return rv, nil +} + +func MakeKNNDocMatchHandler(ctx *search.SearchContext) (search.DocumentMatchHandler, error) { + var hc *KNNCollector + var ok bool + if hc, ok = ctx.Collector.(*KNNCollector); ok { + return func(d *search.DocumentMatch) error { + if d == nil { + return nil + } + toRelease := hc.knnStore.AddDocument(d) + for _, doc := range toRelease { + ctx.DocumentMatchPool.Put(doc) + } + return nil + }, nil + } + return nil, nil +} + +func GetNewKNNCollectorStore(kArray []int64) *collectStoreKNN { + internalHeaps := make([]collectorStore, len(kArray)) + for knnIdx, k := range kArray { + // TODO - Check if the datatype of k can be made into an int instead of int64 + idx := knnIdx + internalHeaps[idx] = getOptimalCollectorStore(int(k), 0, func(i, j *search.DocumentMatch) int { + if i.ScoreBreakdown[idx] < j.ScoreBreakdown[idx] { + return 1 + } + return -1 + }) + } + return newStoreKNN(internalHeaps, kArray) +} + +// implements Collector interface +type KNNCollector struct { + knnStore *collectStoreKNN + size int + total uint64 + took time.Duration + results search.DocumentMatchCollection + maxScore float64 +} + +func NewKNNCollector(kArray []int64, size int64) *KNNCollector { + return &KNNCollector{ + knnStore: GetNewKNNCollectorStore(kArray), + size: int(size), + } +} + +func (hc *KNNCollector) Collect(ctx context.Context, searcher search.Searcher, reader index.IndexReader) error { + startTime := time.Now() + var err error + var next *search.DocumentMatch + + // pre-allocate enough space in the DocumentMatchPool + // unless the sum of K is too large, then cap it + // everything should still work, just allocates DocumentMatches on demand + backingSize := hc.size + if backingSize > PreAllocSizeSkipCap { + backingSize = PreAllocSizeSkipCap + 1 + } + searchContext := &search.SearchContext{ + DocumentMatchPool: search.NewDocumentMatchPool(backingSize+searcher.DocumentMatchPoolSize(), 0), + Collector: hc, + IndexReader: reader, + } + + dmHandlerMakerKNN := MakeKNNDocMatchHandler + if cv := ctx.Value(search.MakeKNNDocumentMatchHandlerKey); cv != nil { + dmHandlerMakerKNN = cv.(search.MakeKNNDocumentMatchHandler) + } + // use the application given builder for making the custom document match + // handler and perform callbacks/invocations on the newly made handler. + dmHandler, err := dmHandlerMakerKNN(searchContext) + if err != nil { + return err + } + select { + case <-ctx.Done(): + search.RecordSearchCost(ctx, search.AbortM, 0) + return ctx.Err() + default: + next, err = searcher.Next(searchContext) + } + for err == nil && next != nil { + if hc.total%CheckDoneEvery == 0 { + select { + case <-ctx.Done(): + search.RecordSearchCost(ctx, search.AbortM, 0) + return ctx.Err() + default: + } + } + hc.total++ + + err = dmHandler(next) + if err != nil { + break + } + + next, err = searcher.Next(searchContext) + } + if err != nil { + return err + } + + // help finalize/flush the results in case + // of custom document match handlers. 
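Putting the pieces together, a caller could drive the new collector roughly as in the sketch below (vectors build tag assumed). The searcher over the KNN disjuncts and the open index reader are assumed to exist elsewhere, and the second argument to NewKNNCollector only sizes the pre-allocated match pool:

```go
package collectorexample // illustrative package name

import (
	"context"

	"github.com/blevesearch/bleve/v2/search"
	"github.com/blevesearch/bleve/v2/search/collector"
	index "github.com/blevesearch/bleve_index_api"
)

// collectKNNHits gathers the top-k hits for two KNN sub-queries (k=3 and k=5)
// and returns the merged, de-duplicated matches produced by the KNN store.
func collectKNNHits(ctx context.Context, s search.Searcher,
	r index.IndexReader) (search.DocumentMatchCollection, error) {
	kc := collector.NewKNNCollector([]int64{3, 5}, 0)
	if err := kc.Collect(ctx, s, r); err != nil {
		return nil, err
	}
	return kc.Results(), nil
}
```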
+ err = dmHandler(nil) + if err != nil { + return err + } + + // compute search duration + hc.took = time.Since(startTime) + + // finalize actual results + err = hc.finalizeResults(reader) + if err != nil { + return err + } + return nil +} + +func (hc *KNNCollector) finalizeResults(r index.IndexReader) error { + var err error + hc.results, err = hc.knnStore.Final(func(doc *search.DocumentMatch) error { + if doc.ID == "" { + // look up the id since we need it for lookup + var err error + doc.ID, err = r.ExternalID(doc.IndexInternalID) + if err != nil { + return err + } + } + return nil + }) + return err +} + +func (hc *KNNCollector) Results() search.DocumentMatchCollection { + return hc.results +} + +func (hc *KNNCollector) Total() uint64 { + return hc.total +} + +func (hc *KNNCollector) MaxScore() float64 { + return hc.maxScore +} + +func (hc *KNNCollector) Took() time.Duration { + return hc.took +} + +func (hc *KNNCollector) SetFacetsBuilder(facetsBuilder *search.FacetsBuilder) { + // facet unsupported for vector search +} + +func (hc *KNNCollector) FacetResults() search.FacetResults { + // facet unsupported for vector search + return nil +} diff --git a/search/collector/list.go b/search/collector/list.go index 20d4c9d01..f73505e7d 100644 --- a/search/collector/list.go +++ b/search/collector/list.go @@ -81,6 +81,16 @@ func (c *collectStoreList) Final(skip int, fixup collectorFixup) (search.Documen return search.DocumentMatchCollection{}, nil } +func (c *collectStoreList) Internal() search.DocumentMatchCollection { + rv := make(search.DocumentMatchCollection, c.results.Len()) + i := 0 + for e := c.results.Front(); e != nil; e = e.Next() { + rv[i] = e.Value.(*search.DocumentMatch) + i++ + } + return rv +} + func (c *collectStoreList) len() int { return c.results.Len() } diff --git a/search/collector/slice.go b/search/collector/slice.go index b38d9abc4..07534e693 100644 --- a/search/collector/slice.go +++ b/search/collector/slice.go @@ -72,6 +72,10 @@ func (c *collectStoreSlice) Final(skip int, fixup collectorFixup) (search.Docume return search.DocumentMatchCollection{}, nil } +func (c *collectStoreSlice) Internal() search.DocumentMatchCollection { + return c.slice +} + func (c *collectStoreSlice) len() int { return len(c.slice) } diff --git a/search/collector/topn.go b/search/collector/topn.go index 270d5f924..fc338f54e 100644 --- a/search/collector/topn.go +++ b/search/collector/topn.go @@ -39,6 +39,9 @@ type collectorStore interface { AddNotExceedingSize(doc *search.DocumentMatch, size int) *search.DocumentMatch Final(skip int, fixup collectorFixup) (search.DocumentMatchCollection, error) + + // Provide access the internal heap implementation + Internal() search.DocumentMatchCollection } // PreAllocSizeSkipCap will cap preallocation to this amount when @@ -72,6 +75,9 @@ type TopNCollector struct { updateFieldVisitor index.DocValueVisitor dvReader index.DocValueReader searchAfter *search.DocumentMatch + + knnHits map[string]*search.DocumentMatch + computeNewScoreExpl search.ScoreExplCorrectionCallbackFunc } // CheckDoneEvery controls how frequently we check the context deadline @@ -89,27 +95,66 @@ func NewTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector // ordering hits by the provided sort order func NewTopNCollectorAfter(size int, sort search.SortOrder, after []string) *TopNCollector { rv := newTopNCollector(size, 0, sort) - rv.searchAfter = &search.DocumentMatch{ - Sort: after, + rv.searchAfter = createSearchAfterDocument(sort, after) + return rv +} + +func 
newTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector { + hc := &TopNCollector{size: size, skip: skip, sort: sort} + + hc.store = getOptimalCollectorStore(size, skip, func(i, j *search.DocumentMatch) int { + return hc.sort.Compare(hc.cachedScoring, hc.cachedDesc, i, j) + }) + + // these lookups traverse an interface, so do once up-front + if sort.RequiresDocID() { + hc.needDocIds = true } + hc.neededFields = sort.RequiredFields() + hc.cachedScoring = sort.CacheIsScore() + hc.cachedDesc = sort.CacheDescending() + return hc +} + +func createSearchAfterDocument(sort search.SortOrder, after []string) *search.DocumentMatch { + rv := &search.DocumentMatch{ + Sort: after, + } for pos, ss := range sort { if ss.RequiresDocID() { - rv.searchAfter.ID = after[pos] + rv.ID = after[pos] } if ss.RequiresScoring() { if score, err := strconv.ParseFloat(after[pos], 64); err == nil { - rv.searchAfter.Score = score + rv.Score = score } } } - return rv } -func newTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector { - hc := &TopNCollector{size: size, skip: skip, sort: sort} +// Filter document matches based on the SearchAfter field in the SearchRequest. +func FilterHitsBySearchAfter(hits []*search.DocumentMatch, sort search.SortOrder, after []string) []*search.DocumentMatch { + if len(hits) == 0 { + return hits + } + // create a search after document + searchAfter := createSearchAfterDocument(sort, after) + // filter the hits + idx := 0 + cachedScoring := sort.CacheIsScore() + cachedDesc := sort.CacheDescending() + for _, hit := range hits { + if sort.Compare(cachedScoring, cachedDesc, hit, searchAfter) > 0 { + hits[idx] = hit + idx++ + } + } + return hits[:idx] +} +func getOptimalCollectorStore(size, skip int, comparator collectorCompare) collectorStore { // pre-allocate space on the store to avoid reslicing // unless the size + skip is too large, then cap it // everything should still work, just reslices as necessary @@ -119,24 +164,10 @@ func newTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector } if size+skip > 10 { - hc.store = newStoreHeap(backingSize, func(i, j *search.DocumentMatch) int { - return hc.sort.Compare(hc.cachedScoring, hc.cachedDesc, i, j) - }) + return newStoreHeap(backingSize, comparator) } else { - hc.store = newStoreSlice(backingSize, func(i, j *search.DocumentMatch) int { - return hc.sort.Compare(hc.cachedScoring, hc.cachedDesc, i, j) - }) + return newStoreSlice(backingSize, comparator) } - - // these lookups traverse an interface, so do once up-front - if sort.RequiresDocID() { - hc.needDocIds = true - } - hc.neededFields = sort.RequiredFields() - hc.cachedScoring = sort.CacheIsScore() - hc.cachedDesc = sort.CacheDescending() - - return hc } func (hc *TopNCollector) Size() int { @@ -215,7 +246,12 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, } } - err = hc.prepareDocumentMatch(searchContext, reader, next) + err = hc.adjustDocumentMatch(searchContext, reader, next) + if err != nil { + break + } + + err = hc.prepareDocumentMatch(searchContext, reader, next, false) if err != nil { break } @@ -227,6 +263,23 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, next, err = searcher.Next(searchContext) } + if err != nil { + return err + } + if hc.knnHits != nil { + // we may have some knn hits left that did not match any of the top N tf-idf hits + // we need to add them to the collector store to consider them as well. 
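The newly exported FilterHitsBySearchAfter can also be applied to hits computed elsewhere (for instance KNN hits from a presearch phase), as long as they use the same sort order; a small sketch with made-up scores and a cursor of "1.5":

```go
package main

import (
	"fmt"

	"github.com/blevesearch/bleve/v2/search"
	"github.com/blevesearch/bleve/v2/search/collector"
)

func main() {
	hits := []*search.DocumentMatch{
		{ID: "a", Score: 2.0},
		{ID: "b", Score: 1.0},
	}

	// With a descending score sort, "after" a cursor score of 1.5 means
	// strictly lower scores, so only "b" survives the filter.
	sortOrder := search.ParseSortOrderStrings([]string{"-_score"})
	filtered := collector.FilterHitsBySearchAfter(hits, sortOrder, []string{"1.5"})
	fmt.Println(len(filtered)) // 1
}
```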
+ for _, knnDoc := range hc.knnHits { + err = hc.prepareDocumentMatch(searchContext, reader, knnDoc, true) + if err != nil { + return err + } + err = dmHandler(knnDoc) + if err != nil { + return err + } + } + } statsCallbackFn := ctx.Value(search.SearchIOStatsCallbackKey) if statsCallbackFn != nil { @@ -258,12 +311,40 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, var sortByScoreOpt = []string{"_score"} -func (hc *TopNCollector) prepareDocumentMatch(ctx *search.SearchContext, +func (hc *TopNCollector) adjustDocumentMatch(ctx *search.SearchContext, reader index.IndexReader, d *search.DocumentMatch) (err error) { + if hc.knnHits != nil { + d.ID, err = reader.ExternalID(d.IndexInternalID) + if err != nil { + return err + } + if knnHit, ok := hc.knnHits[d.ID]; ok { + d.Score, d.Expl = hc.computeNewScoreExpl(d, knnHit) + delete(hc.knnHits, d.ID) + } + } + return nil +} + +func (hc *TopNCollector) prepareDocumentMatch(ctx *search.SearchContext, + reader index.IndexReader, d *search.DocumentMatch, isKnnDoc bool) (err error) { // visit field terms for features that require it (sort, facets) - if len(hc.neededFields) > 0 { - err = hc.visitFieldTerms(reader, d) + if !isKnnDoc && len(hc.neededFields) > 0 { + err = hc.visitFieldTerms(reader, d, hc.updateFieldVisitor) + if err != nil { + return err + } + } else if isKnnDoc && hc.facetsBuilder != nil { + // we need to visit the field terms for the knn document + // only for those fields that are required for faceting + // and not for sorting. This is because the knn document's + // sort value is already computed in the knn collector. + err = hc.visitFieldTerms(reader, d, func(field string, term []byte) { + if hc.facetsBuilder != nil { + hc.facetsBuilder.UpdateVisitor(field, term) + } + }) if err != nil { return err } @@ -277,9 +358,14 @@ func (hc *TopNCollector) prepareDocumentMatch(ctx *search.SearchContext, if d.Score > hc.maxScore { hc.maxScore = d.Score } + // early exit as the document match had its sort value calculated in the knn + // collector itself + if isKnnDoc { + return nil + } // see if we need to load ID (at this early stage, for example to sort on it) - if hc.needDocIds { + if hc.needDocIds && d.ID == "" { d.ID, err = reader.ExternalID(d.IndexInternalID) if err != nil { return err @@ -314,6 +400,7 @@ func MakeTopNDocumentMatchHandler( // but we want to allow for exact match, so we pretend hc.searchAfter.HitNumber = d.HitNumber if hc.sort.Compare(hc.cachedScoring, hc.cachedDesc, d, hc.searchAfter) <= 0 { + ctx.DocumentMatchPool.Put(d) return nil } } @@ -353,12 +440,21 @@ func MakeTopNDocumentMatchHandler( // visitFieldTerms is responsible for visiting the field terms of the // search hit, and passing visited terms to the sort and facet builder -func (hc *TopNCollector) visitFieldTerms(reader index.IndexReader, d *search.DocumentMatch) error { +func (hc *TopNCollector) visitFieldTerms(reader index.IndexReader, d *search.DocumentMatch, v index.DocValueVisitor) error { if hc.facetsBuilder != nil { hc.facetsBuilder.StartDoc() } + if d.ID != "" && d.IndexInternalID == nil { + // this document may have been sent over as preSearchData and + // we need to look up the internal id to visit the doc values for it + var err error + d.IndexInternalID, err = reader.InternalID(d.ID) + if err != nil { + return err + } + } - err := hc.dvReader.VisitDocValues(d.IndexInternalID, hc.updateFieldVisitor) + err := hc.dvReader.VisitDocValues(d.IndexInternalID, v) if hc.facetsBuilder != nil { hc.facetsBuilder.EndDoc() } @@ 
-435,3 +531,11 @@ func (hc *TopNCollector) FacetResults() search.FacetResults { } return nil } + +func (hc *TopNCollector) SetKNNHits(knnHits search.DocumentMatchCollection, newScoreExplComputer search.ScoreExplCorrectionCallbackFunc) { + hc.knnHits = make(map[string]*search.DocumentMatch, len(knnHits)) + for _, hit := range knnHits { + hc.knnHits[hit.ID] = hit + } + hc.computeNewScoreExpl = newScoreExplComputer +} diff --git a/search/facet/facet_builder_datetime.go b/search/facet/facet_builder_datetime.go index c272396b7..ff5167f21 100644 --- a/search/facet/facet_builder_datetime.go +++ b/search/facet/facet_builder_datetime.go @@ -17,7 +17,6 @@ package facet import ( "reflect" "sort" - "strconv" "time" "github.com/blevesearch/bleve/v2/numeric" @@ -36,10 +35,8 @@ func init() { } type dateTimeRange struct { - start time.Time - end time.Time - startLayout string - endLayout string + start time.Time + end time.Time } type DateTimeFacetBuilder struct { @@ -78,12 +75,10 @@ func (fb *DateTimeFacetBuilder) Size() int { return sizeInBytes } -func (fb *DateTimeFacetBuilder) AddRange(name string, start, end time.Time, startLayout string, endLayout string) { +func (fb *DateTimeFacetBuilder) AddRange(name string, start, end time.Time) { r := dateTimeRange{ - start: start, - end: end, - startLayout: startLayout, - endLayout: endLayout, + start: start, + end: end, } fb.ranges[name] = &r } @@ -139,23 +134,11 @@ func (fb *DateTimeFacetBuilder) Result() *search.FacetResult { Count: count, } if !dateRange.start.IsZero() { - var start string - if dateRange.startLayout == "" { - // layout not set probably means it is probably a timestamp - start = strconv.FormatInt(dateRange.start.UnixNano(), 10) - } else { - start = dateRange.start.Format(dateRange.startLayout) - } + start := dateRange.start.Format(time.RFC3339Nano) tf.Start = &start } if !dateRange.end.IsZero() { - var end string - if dateRange.endLayout == "" { - // layout not set probably means it is probably a timestamp - end = strconv.FormatInt(dateRange.end.UnixNano(), 10) - } else { - end = dateRange.end.Format(dateRange.endLayout) - } + end := dateRange.end.Format(time.RFC3339Nano) tf.End = &end } rv.DateRanges = append(rv.DateRanges, tf) diff --git a/search/query/disjunction.go b/search/query/disjunction.go index e008a042a..b307865f3 100644 --- a/search/query/disjunction.go +++ b/search/query/disjunction.go @@ -27,10 +27,15 @@ import ( ) type DisjunctionQuery struct { - Disjuncts []Query `json:"disjuncts"` - BoostVal *Boost `json:"boost,omitempty"` - Min float64 `json:"min"` - queryStringMode bool + Disjuncts []Query `json:"disjuncts"` + BoostVal *Boost `json:"boost,omitempty"` + Min float64 `json:"min"` + retrieveScoreBreakdown bool + queryStringMode bool +} + +func (q *DisjunctionQuery) RetrieveScoreBreakdown(b bool) { + q.retrieveScoreBreakdown = b } // NewDisjunctionQuery creates a new compound Query. 
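SetKNNHits (defined just above) pairs first-phase KNN hits with a correction callback that recomputes a hit's score and explanation when the same document also matches the regular query; the callback's shape is inferred from its use in adjustDocumentMatch. A hedged sketch of a simple additive correction:

```go
package main

import (
	"github.com/blevesearch/bleve/v2/search"
	"github.com/blevesearch/bleve/v2/search/collector"
)

func main() {
	// Stub KNN hit; in practice these come from the KNN collector phase.
	knnHits := search.DocumentMatchCollection{{ID: "doc1", Score: 0.8}}

	// Hypothetical correction: add the KNN contribution to the tf-idf score
	// and keep the tf-idf explanation unchanged.
	sumScores := func(queryMatch, knnMatch *search.DocumentMatch) (float64, *search.Explanation) {
		return queryMatch.Score + knnMatch.Score, queryMatch.Expl
	}

	tc := collector.NewTopNCollector(10, 0, search.ParseSortOrderStrings([]string{"-_score"}))
	tc.SetKNNHits(knnHits, sumScores)
	_ = tc // tc.Collect(...) would then merge scores for overlapping documents
}
```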
@@ -86,7 +91,9 @@ func (q *DisjunctionQuery) Searcher(ctx context.Context, i index.IndexReader, m return searcher.NewMatchNoneSearcher(i) } - return searcher.NewDisjunctionSearcher(ctx, i, ss, q.Min, options) + nctx := context.WithValue(ctx, search.IncludeScoreBreakdownKey, q.retrieveScoreBreakdown) + + return searcher.NewDisjunctionSearcher(nctx, i, ss, q.Min, options) } func (q *DisjunctionQuery) Validate() error { diff --git a/search/query/knn.go b/search/query/knn.go index c485b4a12..030483e54 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -19,11 +19,11 @@ package query import ( "context" + "fmt" "github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/search" "github.com/blevesearch/bleve/v2/search/searcher" - "github.com/blevesearch/bleve/v2/util" index "github.com/blevesearch/bleve_index_api" ) @@ -64,9 +64,11 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader, fieldMapping := m.FieldMappingForPath(q.VectorField) similarityMetric := fieldMapping.Similarity if similarityMetric == "" { - similarityMetric = util.DefaultSimilarityMetric + similarityMetric = index.DefaultSimilarityMetric + } + if q.K <= 0 || len(q.Vector) == 0 { + return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty") } - return searcher.NewKNNSearcher(ctx, i, m, options, q.VectorField, q.Vector, q.K, q.BoostVal.Value(), similarityMetric) } diff --git a/search/query/query.go b/search/query/query.go index eb7b34adb..f0172b76b 100644 --- a/search/query/query.go +++ b/search/query/query.go @@ -65,6 +65,36 @@ type ValidatableQuery interface { Validate() error } +// ParseQuery deserializes a JSON representation of +// a PreSearchData object. +func ParsePreSearchData(input []byte) (map[string]interface{}, error) { + var rv map[string]interface{} + + var tmp map[string]json.RawMessage + err := util.UnmarshalJSON(input, &tmp) + if err != nil { + return nil, err + } + + for k, v := range tmp { + switch k { + case search.KnnPreSearchDataKey: + var value []*search.DocumentMatch + if v != nil { + err := util.UnmarshalJSON(v, &value) + if err != nil { + return nil, err + } + } + if rv == nil { + rv = make(map[string]interface{}) + } + rv[search.KnnPreSearchDataKey] = value + } + } + return rv, nil +} + // ParseQuery deserializes a JSON representation of // a Query object. 
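Callers that need the per-child scores kept (as the KNN merge path does) opt in on the disjunction before the searcher is built; a sketch with the two child queries left as placeholders:

```go
package main

import "github.com/blevesearch/bleve/v2/search/query"

func main() {
	// The two KNN child queries are assumed to be built elsewhere; any
	// query.Query values work for the purposes of this sketch.
	var knnQueryA, knnQueryB query.Query

	dq := query.NewDisjunctionQuery([]query.Query{knnQueryA, knnQueryB})
	dq.RetrieveScoreBreakdown(true)
	// After the search runs, each hit's ScoreBreakdown maps the child's
	// position (0 or 1) to that child's score contribution.
}
```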
func ParseQuery(input []byte) (Query, error) { diff --git a/search/scorer/scorer_constant.go b/search/scorer/scorer_constant.go index fc36fd5bf..10190bd85 100644 --- a/search/scorer/scorer_constant.go +++ b/search/scorer/scorer_constant.go @@ -37,6 +37,7 @@ type ConstantScorer struct { queryNorm float64 queryWeight float64 queryWeightExplanation *search.Explanation + includeScore bool } func (s *ConstantScorer) Size() int { @@ -51,10 +52,11 @@ func (s *ConstantScorer) Size() int { func NewConstantScorer(constant float64, boost float64, options search.SearcherOptions) *ConstantScorer { rv := ConstantScorer{ - options: options, - queryWeight: 1.0, - constant: constant, - boost: boost, + options: options, + queryWeight: 1.0, + constant: constant, + boost: boost, + includeScore: options.Score != "none", } return &rv @@ -92,35 +94,38 @@ func (s *ConstantScorer) SetQueryNorm(qnorm float64) { func (s *ConstantScorer) Score(ctx *search.SearchContext, id index.IndexInternalID) *search.DocumentMatch { var scoreExplanation *search.Explanation - score := s.constant + rv := ctx.DocumentMatchPool.Get() + rv.IndexInternalID = id - if s.options.Explain { - scoreExplanation = &search.Explanation{ - Value: score, - Message: fmt.Sprintf("ConstantScore()"), - } - } + if s.includeScore { + score := s.constant - // if the query weight isn't 1, multiply - if s.queryWeight != 1.0 { - score = score * s.queryWeight if s.options.Explain { - childExplanations := make([]*search.Explanation, 2) - childExplanations[0] = s.queryWeightExplanation - childExplanations[1] = scoreExplanation scoreExplanation = &search.Explanation{ - Value: score, - Message: fmt.Sprintf("weight(^%f), product of:", s.boost), - Children: childExplanations, + Value: score, + Message: fmt.Sprintf("ConstantScore()"), } } - } - rv := ctx.DocumentMatchPool.Get() - rv.IndexInternalID = id - rv.Score = score - if s.options.Explain { - rv.Expl = scoreExplanation + // if the query weight isn't 1, multiply + if s.queryWeight != 1.0 { + score = score * s.queryWeight + if s.options.Explain { + childExplanations := make([]*search.Explanation, 2) + childExplanations[0] = s.queryWeightExplanation + childExplanations[1] = scoreExplanation + scoreExplanation = &search.Explanation{ + Value: score, + Message: fmt.Sprintf("weight(^%f), product of:", s.boost), + Children: childExplanations, + } + } + } + + rv.Score = score + if s.options.Explain { + rv.Expl = scoreExplanation + } } return rv diff --git a/search/scorer/scorer_disjunction.go b/search/scorer/scorer_disjunction.go index 054e76fd4..fe319bbeb 100644 --- a/search/scorer/scorer_disjunction.go +++ b/search/scorer/scorer_disjunction.go @@ -81,3 +81,43 @@ func (s *DisjunctionQueryScorer) Score(ctx *search.SearchContext, constituents [ return rv } + +// This method is used only when disjunction searcher is used over multiple +// KNN searchers, where only the score breakdown and the optional explanation breakdown +// is required. The final score and explanation is set when we finalize the KNN hits. 
+func (s *DisjunctionQueryScorer) ScoreAndExplBreakdown(ctx *search.SearchContext, constituents []*search.DocumentMatch, + matchingIdxs []int, originalPositions []int, countTotal int) *search.DocumentMatch { + + scoreBreakdown := make(map[int]float64) + var childrenExplanations []*search.Explanation + if s.options.Explain { + // since we want to notify which expl belongs to which matched searcher within the disjunction searcher + childrenExplanations = make([]*search.Explanation, countTotal) + } + + for i, docMatch := range constituents { + var index int + if originalPositions != nil { + // scorer used in disjunction slice searcher + index = originalPositions[matchingIdxs[i]] + } else { + // scorer used in disjunction heap searcher + index = matchingIdxs[i] + } + scoreBreakdown[index] = docMatch.Score + if s.options.Explain { + childrenExplanations[index] = docMatch.Expl + } + } + var explBreakdown *search.Explanation + if s.options.Explain { + explBreakdown = &search.Explanation{Children: childrenExplanations} + } + + rv := constituents[0] + rv.ScoreBreakdown = scoreBreakdown + rv.Expl = explBreakdown + rv.FieldTermLocations = search.MergeFieldTermLocations( + rv.FieldTermLocations, constituents[1:]) + return rv +} diff --git a/search/scorer/scorer_knn.go b/search/scorer/scorer_knn.go index 511a47ecb..e7f0a5569 100644 --- a/search/scorer/scorer_knn.go +++ b/search/scorer/scorer_knn.go @@ -18,10 +18,12 @@ package scorer import ( + "fmt" + "math" "reflect" "github.com/blevesearch/bleve/v2/search" - "github.com/blevesearch/bleve/v2/util" + "github.com/blevesearch/bleve/v2/size" index "github.com/blevesearch/bleve_index_api" ) @@ -33,61 +35,93 @@ func init() { } type KNNQueryScorer struct { - queryVector []float32 - queryField string - queryWeight float64 - queryBoost float64 - queryNorm float64 - docTerm uint64 - docTotal uint64 - options search.SearcherOptions - includeScore bool - similarityMetric string + queryVector []float32 + queryField string + queryWeight float64 + queryBoost float64 + queryNorm float64 + options search.SearcherOptions + similarityMetric string + queryWeightExplanation *search.Explanation +} + +func (s *KNNQueryScorer) Size() int { + sizeInBytes := reflectStaticSizeKNNQueryScorer + size.SizeOfPtr + + (len(s.queryVector) * size.SizeOfFloat32) + len(s.queryField) + + if s.queryWeightExplanation != nil { + sizeInBytes += s.queryWeightExplanation.Size() + } + + return sizeInBytes } func NewKNNQueryScorer(queryVector []float32, queryField string, queryBoost float64, - docTerm uint64, docTotal uint64, options search.SearcherOptions, + options search.SearcherOptions, similarityMetric string) *KNNQueryScorer { return &KNNQueryScorer{ queryVector: queryVector, queryField: queryField, queryBoost: queryBoost, queryWeight: 1.0, - docTerm: docTerm, - docTotal: docTotal, options: options, - includeScore: options.Score != "none", similarityMetric: similarityMetric, } } +// Score used when the knnMatch.Score = 0 -> +// the query and indexed vector are exactly the same. 
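The distance-to-score conversion described here can be summarized with a tiny, self-contained numeric sketch (the distances are arbitrary):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// l2_norm yields a distance, so smaller is better: an exact match
	// (distance 0) maps to the maximum score, anything else to 1/distance.
	for _, distance := range []float64{0, 0.5, 2.0} {
		var score float64
		if distance == 0 {
			score = math.MaxFloat64 // what maxKNNScore is defined as below
		} else {
			score = 1.0 / distance
		}
		fmt.Printf("distance=%.1f -> score=%g\n", distance, score)
	}
}
```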
+const maxKNNScore = math.MaxFloat64 + func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext, knnMatch *index.VectorDoc) *search.DocumentMatch { rv := ctx.DocumentMatchPool.Get() - - if sqs.includeScore || sqs.options.Explain { - var scoreExplanation *search.Explanation - score := knnMatch.Score - if sqs.similarityMetric == util.EuclideanDistance { - // eucliden distances need to be inverted to work + var scoreExplanation *search.Explanation + score := knnMatch.Score + if sqs.similarityMetric == index.EuclideanDistance { + // in case of euclidean distance being the distance metric, + // an exact vector (perfect match), would return distance = 0 + if score == 0 { + score = maxKNNScore + } else { + // euclidean distances need to be inverted to work with // tf-idf scoring score = 1.0 / score } - - // if the query weight isn't 1, multiply - if sqs.queryWeight != 1.0 { - score = score * sqs.queryWeight - } - - if sqs.includeScore { - rv.Score = score + } + if sqs.options.Explain { + scoreExplanation = &search.Explanation{ + Value: score, + Message: fmt.Sprintf("fieldWeight(%s in doc %s), score of:", + sqs.queryField, knnMatch.ID), + Children: []*search.Explanation{ + { + Value: score, + Message: fmt.Sprintf("vector(field(%s:%s) with similarity_metric(%s)=%e", + sqs.queryField, knnMatch.ID, sqs.similarityMetric, score), + }, + }, } - + } + // if the query weight isn't 1, multiply + if sqs.queryWeight != 1.0 && score != maxKNNScore { + score = score * sqs.queryWeight if sqs.options.Explain { - rv.Expl = scoreExplanation + scoreExplanation = &search.Explanation{ + Value: score, + // Product of score * weight + // Avoid adding the query vector to the explanation since vectors + // can get quite large. + Message: fmt.Sprintf("weight(%s:query Vector^%f in %s), product of:", + sqs.queryField, sqs.queryBoost, knnMatch.ID), + Children: []*search.Explanation{sqs.queryWeightExplanation, scoreExplanation}, + } } } - + rv.Score = score + if sqs.options.Explain { + rv.Expl = scoreExplanation + } rv.IndexInternalID = append(rv.IndexInternalID, knnMatch.ID...) return rv } @@ -101,4 +135,22 @@ func (sqs *KNNQueryScorer) SetQueryNorm(qnorm float64) { // update the query weight sqs.queryWeight = sqs.queryBoost * sqs.queryNorm + + if sqs.options.Explain { + childrenExplanations := make([]*search.Explanation, 2) + childrenExplanations[0] = &search.Explanation{ + Value: sqs.queryBoost, + Message: "boost", + } + childrenExplanations[1] = &search.Explanation{ + Value: sqs.queryNorm, + Message: "queryNorm", + } + sqs.queryWeightExplanation = &search.Explanation{ + Value: sqs.queryWeight, + Message: fmt.Sprintf("queryWeight(%s:query Vector^%f), product of:", + sqs.queryField, sqs.queryBoost), + Children: childrenExplanations, + } + } } diff --git a/search/scorer/scorer_knn_test.go b/search/scorer/scorer_knn_test.go new file mode 100644 index 000000000..46a1de9b1 --- /dev/null +++ b/search/scorer/scorer_knn_test.go @@ -0,0 +1,181 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package scorer + +import ( + "reflect" + "testing" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +func TestKNNScorerExplanation(t *testing.T) { + var queryVector []float32 + // arbitrary vector of dims: 64 + for i := 0; i < 64; i++ { + queryVector = append(queryVector, float32(i)) + } + + var resVector []float32 + // arbitrary res vector. + for i := 0; i < 64; i++ { + resVector = append(resVector, float32(i)) + } + + tests := []struct { + vectorMatch *index.VectorDoc + scorer *KNNQueryScorer + norm float64 + result *search.DocumentMatch + }{ + { + vectorMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.5, + Vector: resVector, + }, + norm: 1.0, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.EuclideanDistance), + // Specifically testing EuclideanDistance since that involves score inversion. + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.5, + Expl: &search.Explanation{ + Value: 1 / 0.5, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + { + Value: 1 / 0.5, + Message: "vector(field(desc:one) with similarity_metric(l2_norm)=2.000000e+00", + }, + }, + }, + }, + }, + { + vectorMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.0, + // Result vector is an exact match of an existing vector. + Vector: queryVector, + }, + norm: 1.0, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.EuclideanDistance), + // Specifically testing EuclideanDistance with 0 score. 
+ result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.0, + Expl: &search.Explanation{ + Value: maxKNNScore, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + { + Value: maxKNNScore, + Message: "vector(field(desc:one) with similarity_metric(l2_norm)=1.797693e+308", + }, + }, + }, + }, + }, + { + vectorMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.5, + Vector: resVector, + }, + norm: 1.0, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.CosineSimilarity), + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.5, + Expl: &search.Explanation{ + Value: 0.5, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + { + Value: 0.5, + Message: "vector(field(desc:one) with similarity_metric(dot_product)=5.000000e-01", + }, + }, + }, + }, + }, + { + vectorMatch: &index.VectorDoc{ + ID: index.IndexInternalID("one"), + Score: 0.25, + Vector: resVector, + }, + norm: 0.5, + scorer: NewKNNQueryScorer(queryVector, "desc", 1.0, + search.SearcherOptions{Explain: true}, index.CosineSimilarity), + result: &search.DocumentMatch{ + IndexInternalID: index.IndexInternalID("one"), + Score: 0.25, + Expl: &search.Explanation{ + Value: 0.125, + Message: "weight(desc:query Vector^1.000000 in one), product of:", + Children: []*search.Explanation{ + { + Value: 0.5, + Message: "queryWeight(desc:query Vector^1.000000), product of:", + Children: []*search.Explanation{ + { + Value: 1, + Message: "boost", + }, + { + Value: 0.5, + Message: "queryNorm", + }, + }, + }, + { + Value: 0.25, + Message: "fieldWeight(desc in doc one), score of:", + Children: []*search.Explanation{ + { + Value: 0.25, + Message: "vector(field(desc:one) with similarity_metric(dot_product)=2.500000e-01", + }, + }, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + ctx := &search.SearchContext{ + DocumentMatchPool: search.NewDocumentMatchPool(1, 0), + } + test.scorer.SetQueryNorm(test.norm) + actual := test.scorer.Score(ctx, test.vectorMatch) + actual.Complete(nil) + + if !reflect.DeepEqual(actual.Expl, test.result.Expl) { + t.Errorf("expected %#v got %#v for %#v", test.result.Expl, + actual.Expl, test.vectorMatch) + } + } +} diff --git a/search/search.go b/search/search.go index b7a3c42ae..515a320f7 100644 --- a/search/search.go +++ b/search/search.go @@ -147,7 +147,7 @@ type DocumentMatch struct { Index string `json:"index,omitempty"` ID string `json:"id"` IndexInternalID index.IndexInternalID `json:"-"` - Score float64 `json:"score"` + Score float64 `json:"score,omitempty"` Expl *Explanation `json:"explanation,omitempty"` Locations FieldTermLocationMap `json:"locations,omitempty"` Fragments FieldFragmentMap `json:"fragments,omitempty"` @@ -173,6 +173,22 @@ type DocumentMatch struct { // not all sub-queries matched // if false, all the sub-queries matched PartialMatch bool `json:"partial_match,omitempty"` + + // used to indicate the sub-scores that combined to form the + // final score for this document match. This is only populated + // when the search request's query is a DisjunctionQuery + // or a ConjunctionQuery. The map key is the index of the sub-query + // in the DisjunctionQuery or ConjunctionQuery. The map value is the + // sub-score for that sub-query. 
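Since ScoreBreakdown is part of the serialized DocumentMatch, a hit produced with breakdown enabled round-trips through JSON as sketched below (integer map keys become JSON object keys):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/blevesearch/bleve/v2/search"
)

func main() {
	// A hand-built hit standing in for a real search result.
	hit := &search.DocumentMatch{
		ID:             "doc1",
		Score:          1.9,
		ScoreBreakdown: map[int]float64{0: 1.2, 1: 0.7},
	}

	b, err := json.Marshal(hit)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b)) // ..."score_breakdown":{"0":1.2,"1":0.7}...
}
```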
+ ScoreBreakdown map[int]float64 `json:"score_breakdown,omitempty"` + + // internal variable used in PreSearch phase of search in alias + // to indicate the name of the index that this match came from. + // used in knn search. + // it is a stack of index names, the top of the stack is the name + // of the index that this match came from + // of the current alias view, used in alias of aliases scenario + IndexNames []string `json:"index_names,omitempty"` } func (dm *DocumentMatch) AddFieldValue(name string, value interface{}) { @@ -334,7 +350,7 @@ func (dm *DocumentMatch) Complete(prealloc []Location) []Location { } func (dm *DocumentMatch) String() string { - return fmt.Sprintf("[%s-%f]", string(dm.IndexInternalID), dm.Score) + return fmt.Sprintf("[%s-%f]", dm.ID, dm.Score) } type DocumentMatchCollection []*DocumentMatch diff --git a/search/searcher/optimize_knn.go b/search/searcher/optimize_knn.go new file mode 100644 index 000000000..efe262b5b --- /dev/null +++ b/search/searcher/optimize_knn.go @@ -0,0 +1,53 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package searcher + +import ( + "context" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +func optimizeKNN(ctx context.Context, indexReader index.IndexReader, + qsearchers []search.Searcher) error { + var octx index.VectorOptimizableContext + var err error + + for _, searcher := range qsearchers { + // Only applicable to KNN Searchers. + o, ok := searcher.(index.VectorOptimizable) + if !ok { + continue + } + + octx, err = o.VectorOptimize(ctx, octx) + if err != nil { + return err + } + } + + // No KNN searchers. + if octx == nil { + return nil + } + + // Postings lists and iterators replaced in the pointer to the + // vector reader + return octx.Finish() +} diff --git a/search/searcher/optimize_no_knn.go b/search/searcher/optimize_no_knn.go new file mode 100644 index 000000000..bd5d91fb9 --- /dev/null +++ b/search/searcher/optimize_no_knn.go @@ -0,0 +1,31 @@ +// Copyright (c) 2023 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
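+
+// This file is the counterpart of optimize_knn.go: when bleve is built
+// without the "vectors" tag, optimizeKNN below degrades to a no-op, so
+// callers need no build-tag awareness of their own.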
+ +//go:build !vectors +// +build !vectors + +package searcher + +import ( + "context" + + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +func optimizeKNN(ctx context.Context, indexReader index.IndexReader, + qsearchers []search.Searcher) error { + // No-op + return nil +} diff --git a/search/searcher/ordered_searchers_list.go b/search/searcher/ordered_searchers_list.go index f3e646e9d..ac9da563d 100644 --- a/search/searcher/ordered_searchers_list.go +++ b/search/searcher/ordered_searchers_list.go @@ -33,3 +33,23 @@ func (otrl OrderedSearcherList) Less(i, j int) bool { func (otrl OrderedSearcherList) Swap(i, j int) { otrl[i], otrl[j] = otrl[j], otrl[i] } + +type OrderedPositionalSearcherList struct { + searchers []search.Searcher + index []int +} + +// sort.Interface + +func (otrl OrderedPositionalSearcherList) Len() int { + return len(otrl.searchers) +} + +func (otrl OrderedPositionalSearcherList) Less(i, j int) bool { + return otrl.searchers[i].Count() < otrl.searchers[j].Count() +} + +func (otrl OrderedPositionalSearcherList) Swap(i, j int) { + otrl.searchers[i], otrl.searchers[j] = otrl.searchers[j], otrl.searchers[i] + otrl.index[i], otrl.index[j] = otrl.index[j], otrl.index[i] +} diff --git a/search/searcher/search_conjunction.go b/search/searcher/search_conjunction.go index 19ef199ac..25e661075 100644 --- a/search/searcher/search_conjunction.go +++ b/search/searcher/search_conjunction.go @@ -35,7 +35,7 @@ func init() { type ConjunctionSearcher struct { indexReader index.IndexReader - searchers OrderedSearcherList + searchers []search.Searcher queryNorm float64 currs []*search.DocumentMatch maxIDIdx int @@ -88,6 +88,20 @@ func NewConjunctionSearcher(ctx context.Context, indexReader index.IndexReader, return &rv, nil } +func (s *ConjunctionSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + func (s *ConjunctionSearcher) Size() int { sizeInBytes := reflectStaticSizeConjunctionSearcher + size.SizeOfPtr + s.scorer.Size() @@ -105,20 +119,6 @@ func (s *ConjunctionSearcher) Size() int { return sizeInBytes } -func (s *ConjunctionSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } -} - func (s *ConjunctionSearcher) initSearchers(ctx *search.SearchContext) error { var err error // get all searchers pointing at their first match diff --git a/search/searcher/search_disjunction.go b/search/searcher/search_disjunction.go index 606a157ae..d165ec027 100644 --- a/search/searcher/search_disjunction.go +++ b/search/searcher/search_disjunction.go @@ -46,15 +46,31 @@ func optionsDisjunctionOptimizable(options search.SearcherOptions) bool { func newDisjunctionSearcher(ctx context.Context, indexReader index.IndexReader, qsearchers []search.Searcher, min float64, options search.SearcherOptions, limit bool) (search.Searcher, error) 
{ - // attempt the "unadorned" disjunction optimization only when we - // do not need extra information like freq-norm's or term vectors - // and the requested min is simple - if len(qsearchers) > 1 && min <= 1 && - optionsDisjunctionOptimizable(options) { - rv, err := optimizeCompositeSearcher(ctx, "disjunction:unadorned", - indexReader, qsearchers, options) - if err != nil || rv != nil { - return rv, err + + var disjOverKNN bool + if ctx != nil { + disjOverKNN, _ = ctx.Value(search.IncludeScoreBreakdownKey).(bool) + } + if disjOverKNN { + // The KNN Searcher optimization is a necessary pre-req for the KNN Searchers, + // not an optional optimization like for, say term searchers. + // It's an optimization to repeat search an open vector index when applicable, + // rather than individually opening and searching a vector index. + err := optimizeKNN(ctx, indexReader, qsearchers) + if err != nil { + return nil, err + } + } else { + // attempt the "unadorned" disjunction optimization only when we + // do not need extra information like freq-norm's or term vectors + // and the requested min is simple + if len(qsearchers) > 1 && min <= 1 && + optionsDisjunctionOptimizable(options) { + rv, err := optimizeCompositeSearcher(ctx, "disjunction:unadorned", + indexReader, qsearchers, options) + if err != nil || rv != nil { + return rv, err + } } } diff --git a/search/searcher/search_disjunction_heap.go b/search/searcher/search_disjunction_heap.go index d36e30131..89bcd498f 100644 --- a/search/searcher/search_disjunction_heap.go +++ b/search/searcher/search_disjunction_heap.go @@ -39,22 +39,25 @@ func init() { } type SearcherCurr struct { - searcher search.Searcher - curr *search.DocumentMatch + searcher search.Searcher + curr *search.DocumentMatch + matchingIdx int } type DisjunctionHeapSearcher struct { indexReader index.IndexReader - numSearchers int - scorer *scorer.DisjunctionQueryScorer - min int - queryNorm float64 - initialized bool - searchers []search.Searcher - heap []*SearcherCurr + numSearchers int + scorer *scorer.DisjunctionQueryScorer + min int + queryNorm float64 + retrieveScoreBreakdown bool + initialized bool + searchers []search.Searcher + heap []*SearcherCurr matching []*search.DocumentMatch + matchingIdxs []int matchingCurrs []*SearcherCurr bytesRead uint64 @@ -67,22 +70,42 @@ func newDisjunctionHeapSearcher(ctx context.Context, indexReader index.IndexRead if limit && tooManyClauses(len(searchers)) { return nil, tooManyClausesErr("", len(searchers)) } + var retrieveScoreBreakdown bool + if ctx != nil { + retrieveScoreBreakdown, _ = ctx.Value(search.IncludeScoreBreakdownKey).(bool) + } // build our searcher rv := DisjunctionHeapSearcher{ - indexReader: indexReader, - searchers: searchers, - numSearchers: len(searchers), - scorer: scorer.NewDisjunctionQueryScorer(options), - min: int(min), - matching: make([]*search.DocumentMatch, len(searchers)), - matchingCurrs: make([]*SearcherCurr, len(searchers)), - heap: make([]*SearcherCurr, 0, len(searchers)), + indexReader: indexReader, + searchers: searchers, + numSearchers: len(searchers), + scorer: scorer.NewDisjunctionQueryScorer(options), + min: int(min), + matching: make([]*search.DocumentMatch, len(searchers)), + matchingCurrs: make([]*SearcherCurr, len(searchers)), + matchingIdxs: make([]int, len(searchers)), + retrieveScoreBreakdown: retrieveScoreBreakdown, + heap: make([]*SearcherCurr, 0, len(searchers)), } rv.computeQueryNorm() return &rv, nil } +func (s *DisjunctionHeapSearcher) computeQueryNorm() { + // first calculate sum of 
squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + func (s *DisjunctionHeapSearcher) Size() int { sizeInBytes := reflectStaticSizeDisjunctionHeapSearcher + size.SizeOfPtr + s.scorer.Size() @@ -101,24 +124,11 @@ func (s *DisjunctionHeapSearcher) Size() int { // since searchers and document matches already counted above sizeInBytes += len(s.matchingCurrs) * reflectStaticSizeSearcherCurr sizeInBytes += len(s.heap) * reflectStaticSizeSearcherCurr + sizeInBytes += len(s.matchingIdxs) * size.SizeOfInt return sizeInBytes } -func (s *DisjunctionHeapSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } -} - func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error { // alloc a single block of SearcherCurrs block := make([]SearcherCurr, len(s.searchers)) @@ -132,6 +142,7 @@ func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error if curr != nil { block[i].searcher = searcher block[i].curr = curr + block[i].matchingIdx = i heap.Push(s, &block[i]) } } @@ -147,6 +158,7 @@ func (s *DisjunctionHeapSearcher) initSearchers(ctx *search.SearchContext) error func (s *DisjunctionHeapSearcher) updateMatches() error { matching := s.matching[:0] matchingCurrs := s.matchingCurrs[:0] + matchingIdxs := s.matchingIdxs[:0] if len(s.heap) > 0 { @@ -154,17 +166,20 @@ func (s *DisjunctionHeapSearcher) updateMatches() error { next := heap.Pop(s).(*SearcherCurr) matching = append(matching, next.curr) matchingCurrs = append(matchingCurrs, next) + matchingIdxs = append(matchingIdxs, next.matchingIdx) // now as long as top of heap matches, keep popping for len(s.heap) > 0 && bytes.Compare(next.curr.IndexInternalID, s.heap[0].curr.IndexInternalID) == 0 { next = heap.Pop(s).(*SearcherCurr) matching = append(matching, next.curr) matchingCurrs = append(matchingCurrs, next) + matchingIdxs = append(matchingIdxs, next.matchingIdx) } } s.matching = matching s.matchingCurrs = matchingCurrs + s.matchingIdxs = matchingIdxs return nil } @@ -197,10 +212,16 @@ func (s *DisjunctionHeapSearcher) Next(ctx *search.SearchContext) ( for !found && len(s.matching) > 0 { if len(s.matching) >= s.min { found = true - partialMatch := len(s.matching) != len(s.searchers) - // score this match - rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) - rv.PartialMatch = partialMatch + if s.retrieveScoreBreakdown { + // just return score and expl breakdown here, since it is a disjunction over knn searchers, + // and the final score and expl is calculated in the knn collector + rv = s.scorer.ScoreAndExplBreakdown(ctx, s.matching, s.matchingIdxs, nil, s.numSearchers) + } else { + // score this match + partialMatch := len(s.matching) != len(s.searchers) + rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) + rv.PartialMatch = partialMatch + } } // invoke next on all the matching searchers diff --git 
a/search/searcher/search_disjunction_slice.go b/search/searcher/search_disjunction_slice.go index 0969c8cf3..81b00cc22 100644 --- a/search/searcher/search_disjunction_slice.go +++ b/search/searcher/search_disjunction_slice.go @@ -34,17 +34,19 @@ func init() { } type DisjunctionSliceSearcher struct { - indexReader index.IndexReader - searchers OrderedSearcherList - numSearchers int - queryNorm float64 - currs []*search.DocumentMatch - scorer *scorer.DisjunctionQueryScorer - min int - matching []*search.DocumentMatch - matchingIdxs []int - initialized bool - bytesRead uint64 + indexReader index.IndexReader + searchers []search.Searcher + originalPos []int + numSearchers int + queryNorm float64 + retrieveScoreBreakdown bool + currs []*search.DocumentMatch + scorer *scorer.DisjunctionQueryScorer + min int + matching []*search.DocumentMatch + matchingIdxs []int + initialized bool + bytesRead uint64 } func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexReader, @@ -54,21 +56,45 @@ func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexRea if limit && tooManyClauses(len(qsearchers)) { return nil, tooManyClausesErr("", len(qsearchers)) } - // build the downstream searchers - searchers := make(OrderedSearcherList, len(qsearchers)) - for i, searcher := range qsearchers { - searchers[i] = searcher + + var searchers OrderedSearcherList + var originalPos []int + var retrieveScoreBreakdown bool + if ctx != nil { + retrieveScoreBreakdown, _ = ctx.Value(search.IncludeScoreBreakdownKey).(bool) + } + + if retrieveScoreBreakdown { + // needed only when kNN is in picture + sortedSearchers := &OrderedPositionalSearcherList{ + searchers: make([]search.Searcher, len(qsearchers)), + index: make([]int, len(qsearchers)), + } + for i, searcher := range qsearchers { + sortedSearchers.searchers[i] = searcher + sortedSearchers.index[i] = i + } + sort.Sort(sortedSearchers) + searchers = sortedSearchers.searchers + originalPos = sortedSearchers.index + } else { + searchers = make(OrderedSearcherList, len(qsearchers)) + for i, searcher := range qsearchers { + searchers[i] = searcher + } + sort.Sort(searchers) } - // sort the searchers - sort.Sort(sort.Reverse(searchers)) - // build our searcher + rv := DisjunctionSliceSearcher{ - indexReader: indexReader, - searchers: searchers, - numSearchers: len(searchers), - currs: make([]*search.DocumentMatch, len(searchers)), - scorer: scorer.NewDisjunctionQueryScorer(options), - min: int(min), + indexReader: indexReader, + searchers: searchers, + originalPos: originalPos, + numSearchers: len(searchers), + currs: make([]*search.DocumentMatch, len(searchers)), + scorer: scorer.NewDisjunctionQueryScorer(options), + min: int(min), + retrieveScoreBreakdown: retrieveScoreBreakdown, + matching: make([]*search.DocumentMatch, len(searchers)), matchingIdxs: make([]int, len(searchers)), } @@ -76,6 +102,20 @@ func newDisjunctionSliceSearcher(ctx context.Context, indexReader index.IndexRea return &rv, nil } +func (s *DisjunctionSliceSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + func (s *DisjunctionSliceSearcher) Size() int { sizeInBytes := 
reflectStaticSizeDisjunctionSliceSearcher + size.SizeOfPtr + s.scorer.Size() @@ -97,24 +137,11 @@ func (s *DisjunctionSliceSearcher) Size() int { } sizeInBytes += len(s.matchingIdxs) * size.SizeOfInt + sizeInBytes += len(s.originalPos) * size.SizeOfInt return sizeInBytes } -func (s *DisjunctionSliceSearcher) computeQueryNorm() { - // first calculate sum of squared weights - sumOfSquaredWeights := 0.0 - for _, searcher := range s.searchers { - sumOfSquaredWeights += searcher.Weight() - } - // now compute query norm from this - s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) - // finally tell all the downstream searchers the norm - for _, searcher := range s.searchers { - searcher.SetQueryNorm(s.queryNorm) - } -} - func (s *DisjunctionSliceSearcher) initSearchers(ctx *search.SearchContext) error { var err error // get all searchers pointing at their first match @@ -197,10 +224,16 @@ func (s *DisjunctionSliceSearcher) Next(ctx *search.SearchContext) ( for !found && len(s.matching) > 0 { if len(s.matching) >= s.min { found = true - partialMatch := len(s.matching) != len(s.searchers) - // score this match - rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) - rv.PartialMatch = partialMatch + if s.retrieveScoreBreakdown { + // just return score and expl breakdown here, since it is a disjunction over knn searchers, + // and the final score and expl is calculated in the knn collector + rv = s.scorer.ScoreAndExplBreakdown(ctx, s.matching, s.matchingIdxs, s.originalPos, s.numSearchers) + } else { + // score this match + partialMatch := len(s.matching) != len(s.searchers) + rv = s.scorer.Score(ctx, s.matching, len(s.matching), s.numSearchers) + rv.PartialMatch = partialMatch + } } // invoke next on all the matching searchers diff --git a/search/searcher/search_knn.go b/search/searcher/search_knn.go index 7dd59967e..8f146b3e8 100644 --- a/search/searcher/search_knn.go +++ b/search/searcher/search_knn.go @@ -19,13 +19,22 @@ package searcher import ( "context" + "reflect" "github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/search" "github.com/blevesearch/bleve/v2/search/scorer" + "github.com/blevesearch/bleve/v2/size" index "github.com/blevesearch/bleve_index_api" ) +var reflectStaticSizeKNNSearcher int + +func init() { + var ks KNNSearcher + reflectStaticSizeKNNSearcher = int(reflect.TypeOf(ks).Size()) +} + type KNNSearcher struct { field string vector []float32 @@ -41,16 +50,13 @@ func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMap options search.SearcherOptions, field string, vector []float32, k int64, boost float64, similarityMetric string) (search.Searcher, error) { if vr, ok := i.(index.VectorIndexReader); ok { - vectorReader, _ := vr.VectorReader(ctx, vector, field, k) - - count, err := i.DocCount() + vectorReader, err := vr.VectorReader(ctx, vector, field, k) if err != nil { - _ = vectorReader.Close() return nil, err } knnScorer := scorer.NewKNNQueryScorer(vector, field, boost, - vectorReader.Count(), count, options, similarityMetric) + options, similarityMetric) return &KNNSearcher{ indexReader: i, vectorReader: vectorReader, @@ -63,6 +69,16 @@ func NewKNNSearcher(ctx context.Context, i index.IndexReader, m mapping.IndexMap return nil, nil } +func (s *KNNSearcher) VectorOptimize(ctx context.Context, octx index.VectorOptimizableContext) ( + index.VectorOptimizableContext, error) { + o, ok := s.vectorReader.(index.VectorOptimizable) + if ok { + return o.VectorOptimize(ctx, octx) + } + + return nil, nil +} + func (s 
*KNNSearcher) Advance(ctx *search.SearchContext, ID index.IndexInternalID) ( *search.DocumentMatch, error) { knnMatch, err := s.vectorReader.Next(s.vd.Reset()) @@ -115,7 +131,10 @@ func (s *KNNSearcher) SetQueryNorm(qnorm float64) { } func (s *KNNSearcher) Size() int { - return 0 + return reflectStaticSizeKNNSearcher + size.SizeOfPtr + + s.vectorReader.Size() + + s.vd.Size() + + s.scorer.Size() } func (s *KNNSearcher) Weight() float64 { diff --git a/search/util.go b/search/util.go index b2cb62a2d..6472803d1 100644 --- a/search/util.go +++ b/search/util.go @@ -106,6 +106,7 @@ const ( const SearchIncrementalCostKey = "_search_incremental_cost_key" const QueryTypeKey = "_query_type_key" const FuzzyMatchPhraseKey = "_fuzzy_match_phrase_key" +const IncludeScoreBreakdownKey = "_include_score_breakdown_key" func RecordSearchCost(ctx context.Context, msg SearchIncrementalCostCallbackMsg, bytes uint64) { @@ -133,3 +134,15 @@ const MaxGeoBufPoolSize = 24 * 1024 const MinGeoBufPoolSize = 24 type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool + +const KnnPreSearchDataKey = "_knn_pre_search_data_key" + +const PreSearchKey = "_presearch_key" + +type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) + +type SearcherStartCallbackFn func(size uint64) error +type SearcherEndCallbackFn func(size uint64) error + +const SearcherStartCallbackKey = "_searcher_start_callback_key" +const SearcherEndCallbackKey = "_searcher_end_callback_key" diff --git a/search_knn.go b/search_knn.go index a2f8d343c..ccb7fb2ea 100644 --- a/search_knn.go +++ b/search_knn.go @@ -18,13 +18,22 @@ package bleve import ( + "context" "encoding/json" + "fmt" "sort" "github.com/blevesearch/bleve/v2/search" + "github.com/blevesearch/bleve/v2/search/collector" "github.com/blevesearch/bleve/v2/search/query" + index "github.com/blevesearch/bleve_index_api" ) +type knnOperator string + +// Must be updated only at init +var BleveMaxK = int64(10000) + type SearchRequest struct { Query query.Query `json:"query"` Size int `json:"size"` @@ -39,7 +48,20 @@ type SearchRequest struct { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` - KNN []*KNNRequest `json:"knn"` + KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator"` + + // PreSearchData will be a map that will be used + // in the second phase of any 2-phase search, to provide additional + // context to the second phase. This is useful in the case of index + // aliases where the first phase will gather the PreSearchData from all + // the indexes in the alias, and the second phase will use that + // PreSearchData to perform the actual search. 
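+	// Concretely, for KNN over an alias: phase one runs only the KNN portion
+	// of the request against every member index, and the merged top hits are
+	// then handed to phase two through this map.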
+ // The currently accepted map configuration is: + // + // "_knn_pre_search_data_key": []*search.DocumentMatch + + PreSearchData map[string]interface{} `json:"pre_search_data,omitempty"` sortFunc func(sort.Interface) } @@ -61,6 +83,10 @@ func (r *SearchRequest) AddKNN(field string, vector []float32, k int64, boost fl }) } +func (r *SearchRequest) AddKNNOperator(operator knnOperator) { + r.KNNOperator = operator +} + // UnmarshalJSON deserializes a JSON representation of // a SearchRequest func (r *SearchRequest) UnmarshalJSON(input []byte) error { @@ -78,6 +104,8 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` KNN []*KNNRequest `json:"knn"` + KNNOperator knnOperator `json:"knn_operator"` + PreSearchData json.RawMessage `json:"pre_search_data"` } err := json.Unmarshal(input, &temp) @@ -120,6 +148,17 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { } r.KNN = temp.KNN + r.KNNOperator = temp.KNNOperator + if r.KNNOperator == "" { + r.KNNOperator = knnOperatorOr + } + + if temp.PreSearchData != nil { + r.PreSearchData, err = query.ParsePreSearchData(temp.PreSearchData) + if err != nil { + return err + } + } return nil @@ -127,7 +166,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { // ----------------------------------------------------------------------------- -func copySearchRequest(req *SearchRequest) *SearchRequest { +func copySearchRequest(req *SearchRequest, preSearchData map[string]interface{}) *SearchRequest { rv := SearchRequest{ Query: req.Query, Size: req.Size + req.From, @@ -142,24 +181,333 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { SearchAfter: req.SearchAfter, SearchBefore: req.SearchBefore, KNN: req.KNN, + KNNOperator: req.KNNOperator, + PreSearchData: preSearchData, } return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { - if len(req.KNN) > 0 { - disjuncts := []query.Query{req.Query} +var ( + knnOperatorAnd = knnOperator("and") + knnOperatorOr = knnOperator("or") +) + +func createKNNQuery(req *SearchRequest) (query.Query, []int64, int64, error) { + if requestHasKNN(req) { + // first perform validation + err := validateKNN(req) + if err != nil { + return nil, nil, 0, err + } + var subQueries []query.Query + kArray := make([]int64, 0, len(req.KNN)) + sumOfK := int64(0) for _, knn := range req.KNN { - if knn != nil { - knnQuery := query.NewKNNQuery(knn.Vector) - knnQuery.SetFieldVal(knn.Field) - knnQuery.SetK(knn.K) - knnQuery.SetBoost(knn.Boost.Value()) - disjuncts = append(disjuncts, knnQuery) + knnQuery := query.NewKNNQuery(knn.Vector) + knnQuery.SetFieldVal(knn.Field) + knnQuery.SetK(knn.K) + knnQuery.SetBoost(knn.Boost.Value()) + subQueries = append(subQueries, knnQuery) + kArray = append(kArray, knn.K) + sumOfK += knn.K + } + rv := query.NewDisjunctionQuery(subQueries) + rv.RetrieveScoreBreakdown(true) + return rv, kArray, sumOfK, nil + } + return nil, nil, 0, nil +} + +func validateKNN(req *SearchRequest) error { + if req.KNN != nil && + req.KNNOperator != "" && + req.KNNOperator != knnOperatorOr && + req.KNNOperator != knnOperatorAnd { + return fmt.Errorf("unknown knn operator: %s", req.KNNOperator) + } + for _, q := range req.KNN { + if q == nil { + return fmt.Errorf("knn query cannot be nil") + } + if q.K <= 0 || len(q.Vector) == 0 { + return fmt.Errorf("k must be greater than 0 and vector must be non-empty") + } + if q.K > BleveMaxK { + return fmt.Errorf("k must be less than %d", BleveMaxK) + } + } + 
switch req.KNNOperator { + case knnOperatorAnd, knnOperatorOr, "": + // Valid cases, do nothing + default: + return fmt.Errorf("knn_operator must be either 'and' / 'or'") + } + return nil +} + +func addSortAndFieldsToKNNHits(req *SearchRequest, knnHits []*search.DocumentMatch, reader index.IndexReader, name string) (err error) { + requiredSortFields := req.Sort.RequiredFields() + var dvReader index.DocValueReader + var updateFieldVisitor index.DocValueVisitor + if len(requiredSortFields) > 0 { + dvReader, err = reader.DocValueReader(requiredSortFields) + if err != nil { + return err + } + updateFieldVisitor = func(field string, term []byte) { + req.Sort.UpdateVisitor(field, term) + } + } + for _, hit := range knnHits { + if len(requiredSortFields) > 0 { + err = dvReader.VisitDocValues(hit.IndexInternalID, updateFieldVisitor) + if err != nil { + return err + } + } + req.Sort.Value(hit) + err, _ = LoadAndHighlightFields(hit, req, "", reader, nil) + if err != nil { + return err + } + hit.Index = name + } + return nil +} + +func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, reader index.IndexReader, preSearch bool) ([]*search.DocumentMatch, error) { + KNNQuery, kArray, sumOfK, err := createKNNQuery(req) + if err != nil { + return nil, err + } + knnSearcher, err := KNNQuery.Searcher(ctx, reader, i.m, search.SearcherOptions{ + Explain: req.Explain, + }) + if err != nil { + return nil, err + } + knnCollector := collector.NewKNNCollector(kArray, sumOfK) + err = knnCollector.Collect(ctx, knnSearcher, reader) + if err != nil { + return nil, err + } + knnHits := knnCollector.Results() + if !preSearch { + knnHits = finalizeKNNResults(req, knnHits, len(req.KNN)) + } + // at this point, irrespective of whether it is a presearch or not, + // the knn hits are populated with Sort and Fields. + // it must be ensured downstream that the Sort and Fields are not + // re-evaluated, for these hits. + // also add the index names to the hits, so that when early + // exit takes place after the first phase, the hits will have + // a valid value for Index. + err = addSortAndFieldsToKNNHits(req, knnHits, reader, i.name) + if err != nil { + return nil, err + } + return knnHits, nil +} + +func setKnnHitsInCollector(knnHits []*search.DocumentMatch, req *SearchRequest, coll *collector.TopNCollector) { + if len(knnHits) > 0 { + newScoreExplComputer := func(queryMatch *search.DocumentMatch, knnMatch *search.DocumentMatch) (float64, *search.Explanation) { + totalScore := queryMatch.Score + knnMatch.Score + if !req.Explain { + // exit early as we don't need to compute the explanation + return totalScore, nil + } + return totalScore, &search.Explanation{Value: totalScore, Message: "sum of:", Children: []*search.Explanation{queryMatch.Expl, knnMatch.Expl}} + } + coll.SetKNNHits(knnHits, search.ScoreExplCorrectionCallbackFunc(newScoreExplComputer)) + } +} + +func finalizeKNNResults(req *SearchRequest, knnHits []*search.DocumentMatch, numKNNQueries int) []*search.DocumentMatch { + // if the KNN operator is AND, then we need to filter out the hits that + // do not have match the KNN queries. + if req.KNNOperator == knnOperatorAnd { + idx := 0 + for _, hit := range knnHits { + if len(hit.ScoreBreakdown) == numKNNQueries { + knnHits[idx] = hit + idx++ + } + } + knnHits = knnHits[:idx] + } + // fix the score using score breakdown now + // if the score is none, then we need to set the score to 0.0 + // if req.Explain is true, then we need to use the expl breakdown to + // finalize the correct explanation. 
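+	// e.g. a hit with ScoreBreakdown {0: 0.6, 1: 0.3} ends up with Score 0.9,
+	// or 0.0 when the request's score is "none" (illustrative values only).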
+ for _, hit := range knnHits { + hit.Score = 0.0 + if req.Score != "none" { + for _, score := range hit.ScoreBreakdown { + hit.Score += score + } + } + if req.Explain { + childrenExpl := make([]*search.Explanation, 0, len(hit.ScoreBreakdown)) + for i := range hit.ScoreBreakdown { + childrenExpl = append(childrenExpl, hit.Expl.Children[i]) + } + hit.Expl = &search.Explanation{Value: hit.Score, Message: "sum of:", Children: childrenExpl} + } + // we don't need the score breakdown anymore + // so we can set it to nil + hit.ScoreBreakdown = nil + } + return knnHits +} + +func mergeKNNDocumentMatches(req *SearchRequest, knnHits []*search.DocumentMatch) []*search.DocumentMatch { + kArray := make([]int64, len(req.KNN)) + for i, knnReq := range req.KNN { + kArray[i] = knnReq.K + } + knnStore := collector.GetNewKNNCollectorStore(kArray) + for _, hit := range knnHits { + knnStore.AddDocument(hit) + } + // passing nil as the document fixup function, because we don't need to + // fixup the document, since this was already done in the first phase. + // hence error is always nil. + mergedKNNhits, _ := knnStore.Final(nil) + return finalizeKNNResults(req, mergedKNNhits, len(req.KNN)) +} + +// when we are setting KNN hits in the preSearchData, we need to make sure that +// the KNN hit goes to the right index. This is because the KNN hits are +// collected from all the indexes in the alias, but the preSearchData is +// specific to each index. If alias A1 contains indexes I1 and I2 and +// the KNN hits collected from both I1 and I2, and merged to get top K +// hits, then the top K hits need to be distributed to I1 and I2, +// so that the preSearchData for I1 contains the top K hits from I1 and +// the preSearchData for I2 contains the top K hits from I2. +func validateAndDistributeKNNHits(knnHits []*search.DocumentMatch, indexes []Index) (map[string][]*search.DocumentMatch, error) { + // create a set of all the index names of this alias + indexNames := make(map[string]struct{}, len(indexes)) + for _, index := range indexes { + indexNames[index.Name()] = struct{}{} + } + segregatedKnnHits := make(map[string][]*search.DocumentMatch) + for _, hit := range knnHits { + // for each hit, we need to perform a validation check to ensure that the stack + // is still valid. + // + // if the stack is empty, then we have an inconsistency/abnormality + // since any hit with an empty stack is supposed to land on a leaf index, + // and not an alias. This cannot happen in normal circumstances. But + // performing this check to be safe. Since we extract the stack top + // in the following steps. + if len(hit.IndexNames) == 0 { + return nil, ErrorTwoPhaseSearchInconsistency + } + // since the stack is not empty, we need to check if the top of the stack + // is a valid index name, of an index that is part of this alias. If not, + // then we have an inconsistency that could be caused due to a topology + // change. + stackTopIdx := len(hit.IndexNames) - 1 + top := hit.IndexNames[stackTopIdx] + if _, exists := indexNames[top]; !exists { + return nil, ErrorTwoPhaseSearchInconsistency + } + if stackTopIdx == 0 { + // if the stack consists of only one index, then popping the top + // would result in an empty slice, and handle this case by setting + // indexNames to nil. So that the final search results will not + // contain the indexNames field. 
+ hit.IndexNames = nil + } else { + hit.IndexNames = hit.IndexNames[:stackTopIdx] + } + segregatedKnnHits[top] = append(segregatedKnnHits[top], hit) + } + return segregatedKnnHits, nil +} + +func requestHasKNN(req *SearchRequest) bool { + return len(req.KNN) > 0 +} + +// returns true if the search request contains a KNN request that can be +// satisfied by just performing a presearch, completely bypassing the +// actual search. +func isKNNrequestSatisfiedByPreSearch(req *SearchRequest) bool { + // if req.Query is not match_none => then we need to go to phase 2 + // to perform the actual query. + if _, ok := req.Query.(*query.MatchNoneQuery); !ok { + return false + } + // req.Query is a match_none query + // + // if request contains facets, we need to perform phase 2 to calculate + // the facet result. Since documents were removed as part of the + // merging process after phase 1, if the facet results were to be calculated + // during phase 1, then they will be now be incorrect, since merging would + // remove some documents. + if req.Facets != nil { + return false + } + // the request is a match_none query and does not contain any facets + // so we can satisfy the request using just the preSearch result. + return true +} + +func constructKnnPresearchData(mergedOut map[string]map[string]interface{}, preSearchResult *SearchResult, + indexes []Index) (map[string]map[string]interface{}, error) { + + distributedHits, err := validateAndDistributeKNNHits([]*search.DocumentMatch(preSearchResult.Hits), indexes) + if err != nil { + return nil, err + } + for _, index := range indexes { + mergedOut[index.Name()][search.KnnPreSearchDataKey] = distributedHits[index.Name()] + } + return mergedOut, nil +} + +func addKnnToDummyRequest(dummyReq *SearchRequest, realReq *SearchRequest) { + dummyReq.KNN = realReq.KNN + dummyReq.KNNOperator = knnOperatorOr + dummyReq.Explain = realReq.Explain + dummyReq.Fields = realReq.Fields + dummyReq.Sort = realReq.Sort +} + +// the preSearchData for KNN is a list of DocumentMatch objects +// that need to be redistributed to the right index. +// This is used only in the case of an alias tree, where the indexes +// are at the leaves of the tree, and the master alias is at the root. +// At each level of the tree, the preSearchData needs to be redistributed +// to the indexes/aliases at that level. Because the preSearchData is +// specific to each final index at the leaf. +func redistributeKNNPreSearchData(req *SearchRequest, indexes []Index) (map[string]map[string]interface{}, error) { + knnHits, ok := req.PreSearchData[search.KnnPreSearchDataKey].([]*search.DocumentMatch) + if !ok { + return nil, fmt.Errorf("request does not have knn preSearchData for redistribution") + } + segregatedKnnHits, err := validateAndDistributeKNNHits(knnHits, indexes) + if err != nil { + return nil, err + } + + rv := make(map[string]map[string]interface{}) + for _, index := range indexes { + rv[index.Name()] = make(map[string]interface{}) + } + + for _, index := range indexes { + for k, v := range req.PreSearchData { + switch k { + case search.KnnPreSearchDataKey: + rv[index.Name()][k] = segregatedKnnHits[index.Name()] + default: + rv[index.Name()][k] = v } } - return query.NewDisjunctionQuery(disjuncts) } - return req.Query + return rv, nil } diff --git a/search_knn_test.go b/search_knn_test.go new file mode 100644 index 000000000..b54ce5a93 --- /dev/null +++ b/search_knn_test.go @@ -0,0 +1,1184 @@ +// Copyright (c) 2023 Couchbase, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package bleve + +import ( + "archive/zip" + "encoding/json" + "fmt" + "math" + "math/rand" + "sort" + "strconv" + "sync" + "testing" + + "github.com/blevesearch/bleve/v2/analysis/lang/en" + "github.com/blevesearch/bleve/v2/index/scorch" + "github.com/blevesearch/bleve/v2/mapping" + "github.com/blevesearch/bleve/v2/search" + "github.com/blevesearch/bleve/v2/search/query" + index "github.com/blevesearch/bleve_index_api" +) + +const testInputCompressedFile = "test/knn/knn_dataset_queries.zip" +const testDatasetFileName = "knn_dataset.json" +const testQueryFileName = "knn_queries.json" + +const testDatasetDims = 384 + +var knnOperators []knnOperator = []knnOperator{knnOperatorAnd, knnOperatorOr} + +func TestSimilaritySearchPartitionedIndex(t *testing.T) { + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = en.AnalyzerName + + vecFieldMappingL2 := mapping.NewVectorFieldMapping() + vecFieldMappingL2.Dims = testDatasetDims + vecFieldMappingL2.Similarity = index.EuclideanDistance + + indexMappingL2Norm := NewIndexMapping() + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2) + + vecFieldMappingDot := mapping.NewVectorFieldMapping() + vecFieldMappingDot.Dims = testDatasetDims + vecFieldMappingDot.Similarity = index.CosineSimilarity + + indexMappingDotProduct := NewIndexMapping() + indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingDot) + + type testCase struct { + testType string + queryIndex int + numIndexPartitions int + mapping mapping.IndexMapping + } + + testCases := []testCase{ + // l2 norm similarity + { + testType: "multi_partition:match_none:oneKNNreq:k=3", + queryIndex: 0, + numIndexPartitions: 4, + mapping: indexMappingL2Norm, + }, + { + testType: "multi_partition:match_none:oneKNNreq:k=2", + queryIndex: 0, + numIndexPartitions: 10, + mapping: indexMappingL2Norm, + }, + { + testType: "multi_partition:match:oneKNNreq:k=2", + queryIndex: 1, + numIndexPartitions: 5, + mapping: indexMappingL2Norm, + }, + { + testType: "multi_partition:disjunction:twoKNNreq:k=2,2", + queryIndex: 2, + numIndexPartitions: 4, + mapping: indexMappingL2Norm, + }, + // dot product similarity + { + testType: "multi_partition:match_none:oneKNNreq:k=3", + queryIndex: 0, + numIndexPartitions: 4, + mapping: indexMappingDotProduct, + }, + { + testType: "multi_partition:match_none:oneKNNreq:k=2", + queryIndex: 0, + numIndexPartitions: 10, + mapping: indexMappingDotProduct, + }, + { + testType: "multi_partition:match:oneKNNreq:k=2", + queryIndex: 1, + numIndexPartitions: 5, + mapping: 
indexMappingDotProduct, + }, + { + testType: "multi_partition:disjunction:twoKNNreq:k=2,2", + queryIndex: 2, + numIndexPartitions: 4, + mapping: indexMappingDotProduct, + }, + } + + index := NewIndexAlias() + for testCaseNum, testCase := range testCases { + for _, operator := range knnOperators { + index.indexes = make([]Index, 0) + + query := searchRequests[testCase.queryIndex] + query.AddKNNOperator(operator) + query.Sort = search.SortOrder{&search.SortScore{Desc: true}, &search.SortDocID{Desc: true}, &search.SortField{Desc: true, Field: "content"}} + query.Explain = true + + nameToIndex := createPartitionedIndex(documents, index, 1, testCase.mapping, t, false) + controlResult, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(controlResult.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected control result hits to have valid `Index`", testCaseNum) + } + cleanUp(t, nameToIndex) + index.indexes = make([]Index, 0) + nameToIndex = createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, false) + experimentalResult, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(experimentalResult.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + verifyResult(t, controlResult, experimentalResult, testCaseNum, true) + cleanUp(t, nameToIndex) + + index.indexes = make([]Index, 0) + nameToIndex = createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, true) + multiLevelIndexResult, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(multiLevelIndexResult.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + verifyResult(t, multiLevelIndexResult, experimentalResult, testCaseNum, false) + cleanUp(t, nameToIndex) + } + } + + var facets = map[string]*FacetRequest{ + "content": { + Field: "content", + Size: 10, + }, + } + + var sort = search.SortOrder{&search.SortScore{Desc: true}, &search.SortField{Desc: false, Field: "content"}} + + index = NewIndexAlias() + for testCaseNum, testCase := range testCases { + index.indexes = make([]Index, 0) + nameToIndex := createPartitionedIndex(documents, index, testCase.numIndexPartitions, testCase.mapping, t, false) + originalRequest := searchRequests[testCase.queryIndex] + for _, operator := range knnOperators { + + from, size := originalRequest.From, originalRequest.Size + query := copySearchRequest(searchRequests[testCase.queryIndex], nil) + query.AddKNNOperator(operator) + query.Explain = true + query.From = from + query.Size = size + + // Three types of queries to run wrt sort and facet fields that require fields. + // 1. Sort And Facet are there + // 2. Sort is there, Facet is not there + // 3. Sort is not there, Facet is there + // The case where both sort and facet are not there is already covered in the previous tests. + + // 1. 
Sort And Facet are there + query.Facets = facets + query.Sort = sort + + res1, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(res1.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + + facetRes1 := res1.Facets + facetRes1Str, err := json.Marshal(facetRes1) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + + // 2. Sort is there, Facet is not there + query.Facets = nil + query.Sort = sort + + res2, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(res2.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + + // 3. Sort is not there, Facet is there + query.Facets = facets + query.Sort = nil + res3, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(res3.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + + facetRes3 := res3.Facets + facetRes3Str, err := json.Marshal(facetRes3) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + + // Verify the facet results + if string(facetRes1Str) != string(facetRes3Str) { + fmt.Println(operator) + fmt.Println(string(facetRes1Str)) + fmt.Println(string(facetRes3Str)) + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected facet results to be equal", testCaseNum) + } + + // Verify the results + verifyResult(t, res1, res2, testCaseNum, false) + verifyResult(t, res2, res3, testCaseNum, true) + + // Test early exit fail case -> matchNone + facetRequest + query.Query = NewMatchNoneQuery() + query.Sort = sort + // control case + query.Facets = nil + res4Ctrl, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(res4Ctrl.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected control Result hits to have valid `Index`", testCaseNum) + } + + // experimental case + query.Facets = facets + res4Exp, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(res4Exp.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental Result hits to have valid `Index`", testCaseNum) + } + + if !(operator == knnOperatorAnd && res4Ctrl.Total == 0 && res4Exp.Total == 0) { + // catch case where no hits are returned + // due to matchNone query with a KNN request with operator AND + // where no hits are part of the intersection in multi knn request + verifyResult(t, res4Ctrl, res4Exp, testCaseNum, false) + } + } + cleanUp(t, nameToIndex) + } + + // Test Pagination with multi partitioned index + index = NewIndexAlias() + index.indexes = make([]Index, 0) + nameToIndex := createPartitionedIndex(documents, index, 8, indexMappingL2Norm, t, true) + + // Test From + Size pagination for Hybrid Search (2-Phase) + query := copySearchRequest(searchRequests[4], nil) + query.Sort = sort + query.Facets = facets + query.Explain = true + + testFromSizePagination(t, query, index, nameToIndex) + + // Test From + Size pagination for Early Exit Hybrid Search (1-Phase) + query = copySearchRequest(searchRequests[4], nil) + query.Query = 
NewMatchNoneQuery() + query.Sort = sort + query.Facets = nil + query.Explain = true + + testFromSizePagination(t, query, index, nameToIndex) + + cleanUp(t, nameToIndex) +} + +func testFromSizePagination(t *testing.T, query *SearchRequest, index Index, nameToIndex map[string]Index) { + query.From = 0 + query.Size = 30 + + resCtrl, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + + ctrlHitIds := make([]string, len(resCtrl.Hits)) + for i, doc := range resCtrl.Hits { + ctrlHitIds[i] = doc.ID + } + // experimental case + + fromValues := []int{0, 5, 10, 15, 20, 25} + size := 5 + for fromIdx := 0; fromIdx < len(fromValues); fromIdx++ { + from := fromValues[fromIdx] + query.From = from + query.Size = size + resExp, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if from >= len(ctrlHitIds) { + if len(resExp.Hits) != 0 { + cleanUp(t, nameToIndex) + t.Fatalf("expected 0 hits, got %d", len(resExp.Hits)) + } + continue + } + numHitsExp := len(resExp.Hits) + numHitsCtrl := min(len(ctrlHitIds)-from, size) + if numHitsExp != numHitsCtrl { + cleanUp(t, nameToIndex) + t.Fatalf("expected %d hits, got %d", numHitsCtrl, numHitsExp) + } + for i := 0; i < numHitsExp; i++ { + doc := resExp.Hits[i] + startOffset := from + i + if doc.ID != ctrlHitIds[startOffset] { + cleanUp(t, nameToIndex) + t.Fatalf("expected %s at index %d, got %s", ctrlHitIds[startOffset], i, doc.ID) + } + } + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +type testDocument struct { + ID string `json:"id"` + Content string `json:"content"` + Vector []float64 `json:"vector"` +} + +func readDatasetAndQueries(fileName string) ([]testDocument, []*SearchRequest, error) { + // Open the zip archive for reading + r, err := zip.OpenReader(fileName) + if err != nil { + return nil, nil, err + } + var dataset []testDocument + var queries []*SearchRequest + + defer r.Close() + for _, f := range r.File { + jsonFile, err := f.Open() + if err != nil { + return nil, nil, err + } + defer jsonFile.Close() + if f.Name == testDatasetFileName { + err = json.NewDecoder(jsonFile).Decode(&dataset) + if err != nil { + return nil, nil, err + } + } else if f.Name == testQueryFileName { + err = json.NewDecoder(jsonFile).Decode(&queries) + if err != nil { + return nil, nil, err + } + } + } + return dataset, queries, nil +} + +func makeDatasetIntoDocuments(dataset []testDocument) []map[string]interface{} { + documents := make([]map[string]interface{}, len(dataset)) + for i := 0; i < len(dataset); i++ { + document := make(map[string]interface{}) + document["id"] = dataset[i].ID + document["content"] = dataset[i].Content + document["vector"] = dataset[i].Vector + documents[i] = document + } + return documents +} + +func cleanUp(t *testing.T, nameToIndex map[string]Index) { + for path, childIndex := range nameToIndex { + err := childIndex.Close() + if err != nil { + t.Fatal(err) + } + cleanupTmpIndexPath(t, path) + } +} + +func createChildIndex(docs []map[string]interface{}, mapping mapping.IndexMapping, t *testing.T, nameToIndex map[string]Index) Index { + tmpIndexPath := createTmpIndexPath(t) + index, err := New(tmpIndexPath, mapping) + if err != nil { + t.Fatal(err) + } + nameToIndex[index.Name()] = index + batch := index.NewBatch() + for _, doc := range docs { + err := batch.Index(doc["id"].(string), doc) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + } + err = index.Batch(batch) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + 
} + return index +} + +func createPartitionedIndex(documents []map[string]interface{}, index *indexAliasImpl, numPartitions int, + mapping mapping.IndexMapping, t *testing.T, multiLevel bool) map[string]Index { + + partitionSize := len(documents) / numPartitions + extraDocs := len(documents) % numPartitions + numDocsPerPartition := make([]int, numPartitions) + for i := 0; i < numPartitions; i++ { + numDocsPerPartition[i] = partitionSize + if extraDocs > 0 { + numDocsPerPartition[i]++ + extraDocs-- + } + } + docsPerPartition := make([][]map[string]interface{}, numPartitions) + prevCutoff := 0 + for i := 0; i < numPartitions; i++ { + docsPerPartition[i] = make([]map[string]interface{}, numDocsPerPartition[i]) + for j := 0; j < numDocsPerPartition[i]; j++ { + docsPerPartition[i][j] = documents[prevCutoff+j] + } + prevCutoff += numDocsPerPartition[i] + } + + rv := make(map[string]Index) + if !multiLevel { + // all indexes are at the same level + for i := 0; i < numPartitions; i++ { + index.Add(createChildIndex(docsPerPartition[i], mapping, t, rv)) + } + } else { + // alias tree + indexes := make([]Index, numPartitions) + for i := 0; i < numPartitions; i++ { + indexes[i] = createChildIndex(docsPerPartition[i], mapping, t, rv) + } + numAlias := int(math.Ceil(float64(numPartitions) / 2.0)) + aliases := make([]IndexAlias, numAlias) + for i := 0; i < numAlias; i++ { + aliases[i] = NewIndexAlias() + aliases[i].SetName(fmt.Sprintf("alias%d", i)) + for j := 0; j < 2; j++ { + if i*2+j < numPartitions { + aliases[i].Add(indexes[i*2+j]) + } + } + } + for i := 0; i < numAlias; i++ { + index.Add(aliases[i]) + } + } + return rv +} + +func createMultipleSegmentsIndex(documents []map[string]interface{}, index Index, numSegments int) error { + // create multiple batches to simulate more than one segment + numBatches := numSegments + + batches := make([]*Batch, numBatches) + numDocsPerBatch := len(documents) / numBatches + extraDocs := len(documents) % numBatches + + docsPerBatch := make([]int, numBatches) + for i := 0; i < numBatches; i++ { + docsPerBatch[i] = numDocsPerBatch + if extraDocs > 0 { + docsPerBatch[i]++ + extraDocs-- + } + } + prevCutoff := 0 + for i := 0; i < numBatches; i++ { + batches[i] = index.NewBatch() + for j := prevCutoff; j < prevCutoff+docsPerBatch[i]; j++ { + doc := documents[j] + err := batches[i].Index(doc["id"].(string), doc) + if err != nil { + return err + } + } + prevCutoff += docsPerBatch[i] + } + errMutex := sync.Mutex{} + var errors []error + wg := sync.WaitGroup{} + wg.Add(len(batches)) + for i, batch := range batches { + go func(ix int, batchx *Batch) { + defer wg.Done() + err := index.Batch(batchx) + if err != nil { + errMutex.Lock() + errors = append(errors, err) + errMutex.Unlock() + } + }(i, batch) + } + wg.Wait() + if len(errors) > 0 { + return errors[0] + } + return nil +} + +func truncateScore(score float64) float64 { + return float64(int(score*1e6)) / 1e6 +} + +// Function to compare two Explanation structs recursively +func compareExplanation(a, b *search.Explanation) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + if truncateScore(a.Value) != truncateScore(b.Value) || len(a.Children) != len(b.Children) { + return false + } + + // Sort the children slices before comparison + sortChildren(a.Children) + sortChildren(b.Children) + + for i := range a.Children { + if !compareExplanation(a.Children[i], b.Children[i]) { + return false + } + } + return true +} + +// Function to sort the children slices +func 
sortChildren(children []*search.Explanation) { + sort.Slice(children, func(i, j int) bool { + return children[i].Value < children[j].Value + }) +} + +// All hits from a hybrid search/knn search should not have +// index names or score breakdown. +func finalHitsOmitKNNMetadata(hits []*search.DocumentMatch) bool { + for _, hit := range hits { + if hit.IndexNames != nil || hit.ScoreBreakdown != nil { + fmt.Println(len(hit.IndexNames)) + return false + } + } + return true +} + +func finalHitsHaveValidIndex(hits []*search.DocumentMatch, indexes map[string]Index) bool { + for _, hit := range hits { + if hit.Index == "" { + return false + } + var idx Index + var ok bool + if idx, ok = indexes[hit.Index]; !ok { + return false + } + if idx == nil { + return false + } + var doc index.Document + doc, err = idx.Document(hit.ID) + if err != nil { + return false + } + if doc == nil { + return false + } + } + return true +} + +func verifyResult(t *testing.T, controlResult *SearchResult, experimentalResult *SearchResult, testCaseNum int, verifyOnlyDocIDs bool) { + if controlResult.Hits.Len() == 0 || experimentalResult.Hits.Len() == 0 { + t.Fatalf("test case #%d failed: 0 hits returned", testCaseNum) + } + if len(controlResult.Hits) != len(experimentalResult.Hits) { + t.Fatalf("test case #%d failed: expected %d results, got %d", testCaseNum, len(controlResult.Hits), len(experimentalResult.Hits)) + } + if controlResult.Total != experimentalResult.Total { + t.Fatalf("test case #%d failed: expected total hits to be %d, got %d", testCaseNum, controlResult.Total, experimentalResult.Total) + } + // KNN Metadata -> Score Breakdown and IndexNames MUST be omitted from the final hits + if !finalHitsOmitKNNMetadata(controlResult.Hits) || !finalHitsOmitKNNMetadata(experimentalResult.Hits) { + t.Fatalf("test case #%d failed: expected no KNN metadata in hits", testCaseNum) + } + if controlResult.Took == 0 || experimentalResult.Took == 0 { + t.Fatalf("test case #%d failed: expected non-zero took time", testCaseNum) + } + if controlResult.Request == nil || experimentalResult.Request == nil { + t.Fatalf("test case #%d failed: expected non-nil request", testCaseNum) + } + if verifyOnlyDocIDs { + // in multi partitioned index, we cannot be sure of the score or the ordering of the hits as the tf-idf scores are localized to each partition + // so we only check the ids + controlMap := make(map[string]struct{}) + experimentalMap := make(map[string]struct{}) + for _, hit := range controlResult.Hits { + controlMap[hit.ID] = struct{}{} + } + for _, hit := range experimentalResult.Hits { + experimentalMap[hit.ID] = struct{}{} + } + if len(controlMap) != len(experimentalMap) { + t.Fatalf("test case #%d failed: expected %d results, got %d", testCaseNum, len(controlMap), len(experimentalMap)) + } + for id := range controlMap { + if _, ok := experimentalMap[id]; !ok { + t.Fatalf("test case #%d failed: expected id %s to be in experimental result", testCaseNum, id) + } + } + return + } + for i := 0; i < len(controlResult.Hits); i++ { + if controlResult.Hits[i].ID != experimentalResult.Hits[i].ID { + t.Fatalf("test case #%d failed: expected hit %d to have id %s, got %s", testCaseNum, i, controlResult.Hits[i].ID, experimentalResult.Hits[i].ID) + } + // Truncate to 6 decimal places + actualScore := truncateScore(experimentalResult.Hits[i].Score) + expectScore := truncateScore(controlResult.Hits[i].Score) + if expectScore != actualScore { + t.Fatalf("test case #%d failed: expected hit %d to have score %f, got %f", testCaseNum, i, expectScore, 
actualScore) + } + if !compareExplanation(controlResult.Hits[i].Expl, experimentalResult.Hits[i].Expl) { + t.Fatalf("test case #%d failed: expected hit %d to have explanation %v, got %v", testCaseNum, i, controlResult.Hits[i].Expl, experimentalResult.Hits[i].Expl) + } + } + if truncateScore(controlResult.MaxScore) != truncateScore(experimentalResult.MaxScore) { + t.Fatalf("test case #%d: expected maxScore to be %f, got %f", testCaseNum, controlResult.MaxScore, experimentalResult.MaxScore) + } +} + +func TestSimilaritySearchMultipleSegments(t *testing.T) { + // using scorch options to prevent merges during the course of this test + // so that the knnCollector can be accurately tested + scorch.DefaultMemoryPressurePauseThreshold = 0 + scorch.DefaultMinSegmentsForInMemoryMerge = math.MaxInt + dataset, searchRequests, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = en.AnalyzerName + + vecFieldMappingL2 := mapping.NewVectorFieldMapping() + vecFieldMappingL2.Dims = testDatasetDims + vecFieldMappingL2.Similarity = index.EuclideanDistance + + vecFieldMappingDot := mapping.NewVectorFieldMapping() + vecFieldMappingDot.Dims = testDatasetDims + vecFieldMappingDot.Similarity = index.CosineSimilarity + + indexMappingL2Norm := NewIndexMapping() + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2) + + indexMappingDotProduct := NewIndexMapping() + indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMappingDotProduct.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingDot) + + testCases := []struct { + numSegments int + queryIndex int + mapping mapping.IndexMapping + scoreValue string + }{ + // L2 norm similarity + { + numSegments: 6, + queryIndex: 0, + mapping: indexMappingL2Norm, + }, + { + numSegments: 7, + queryIndex: 1, + mapping: indexMappingL2Norm, + }, + { + numSegments: 8, + queryIndex: 2, + mapping: indexMappingL2Norm, + }, + { + numSegments: 9, + queryIndex: 3, + mapping: indexMappingL2Norm, + }, + { + numSegments: 10, + queryIndex: 4, + mapping: indexMappingL2Norm, + }, + { + numSegments: 11, + queryIndex: 5, + mapping: indexMappingL2Norm, + }, + // dot_product similarity + { + numSegments: 6, + queryIndex: 0, + mapping: indexMappingDotProduct, + }, + { + numSegments: 7, + queryIndex: 1, + mapping: indexMappingDotProduct, + }, + { + numSegments: 8, + queryIndex: 2, + mapping: indexMappingDotProduct, + }, + { + numSegments: 9, + queryIndex: 3, + mapping: indexMappingDotProduct, + }, + { + numSegments: 10, + queryIndex: 4, + mapping: indexMappingDotProduct, + }, + { + numSegments: 11, + queryIndex: 5, + mapping: indexMappingDotProduct, + }, + // score none test + { + numSegments: 3, + queryIndex: 0, + mapping: indexMappingL2Norm, + scoreValue: "none", + }, + { + numSegments: 7, + queryIndex: 1, + mapping: indexMappingL2Norm, + scoreValue: "none", + }, + { + numSegments: 8, + queryIndex: 2, + mapping: indexMappingL2Norm, + scoreValue: "none", + }, + { + numSegments: 3, + queryIndex: 0, + mapping: indexMappingDotProduct, + scoreValue: "none", + }, + { + numSegments: 7, + queryIndex: 1, + mapping: indexMappingDotProduct, + scoreValue: "none", + }, + { + numSegments: 8, + queryIndex: 2, + mapping: indexMappingDotProduct, + scoreValue: "none", + }, + } + for 
testCaseNum, testCase := range testCases { + for _, operator := range knnOperators { + // run single segment test first + tmpIndexPath := createTmpIndexPath(t) + index, err := New(tmpIndexPath, testCase.mapping) + if err != nil { + t.Fatal(err) + } + query := searchRequests[testCase.queryIndex] + query.Sort = search.SortOrder{&search.SortScore{Desc: true}, &search.SortDocID{Desc: true}, &search.SortField{Desc: false, Field: "content"}} + query.AddKNNOperator(operator) + query.Explain = true + + nameToIndex := make(map[string]Index) + nameToIndex[index.Name()] = index + + err = createMultipleSegmentsIndex(documents, index, 1) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + controlResult, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(controlResult.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected control result hits to have valid `Index`", testCaseNum) + } + if testCase.scoreValue == "none" { + query.Score = testCase.scoreValue + expectedResultScoreNone, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(expectedResultScoreNone.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected score none hits to have valid `Index`", testCaseNum) + } + verifyResult(t, controlResult, expectedResultScoreNone, testCaseNum, true) + query.Score = "" + } + cleanUp(t, nameToIndex) + + // run multiple segments test + tmpIndexPath = createTmpIndexPath(t) + index, err = New(tmpIndexPath, testCase.mapping) + if err != nil { + t.Fatal(err) + } + nameToIndex = make(map[string]Index) + nameToIndex[index.Name()] = index + err = createMultipleSegmentsIndex(documents, index, testCase.numSegments) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + experimentalResult, err := index.Search(query) + if err != nil { + cleanUp(t, nameToIndex) + t.Fatal(err) + } + if !finalHitsHaveValidIndex(experimentalResult.Hits, nameToIndex) { + cleanUp(t, nameToIndex) + t.Fatalf("test case #%d failed: expected experimental result hits to have valid `Index`", testCaseNum) + } + verifyResult(t, controlResult, experimentalResult, testCaseNum, false) + cleanUp(t, nameToIndex) + } + } +} + +// Test to see if KNN operators get added correctly to the query. +func TestKNNOperator(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + const dims = 5 + getRandomVector := func() []float32 { + vec := make([]float32, dims) + for i := 0; i < dims; i++ { + vec[i] = rand.Float32() + } + return vec + } + + dataset := make([]map[string]interface{}, 0, 10) + + // Indexing just a few docs to populate the index.
+ for i := 0; i < 10; i++ { + dataset = append(dataset, map[string]interface{}{ + "type": "vectorStuff", + "content": strconv.Itoa(i), + "vector": getRandomVector(), + }) + } + + indexMapping := NewIndexMapping() + indexMapping.TypeField = "type" + indexMapping.DefaultAnalyzer = "en" + documentMapping := NewDocumentMapping() + indexMapping.AddDocumentMapping("vectorStuff", documentMapping) + + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Index = true + contentFieldMapping.Store = true + documentMapping.AddFieldMappingsAt("content", contentFieldMapping) + + vecFieldMapping := mapping.NewVectorFieldMapping() + vecFieldMapping.Index = true + vecFieldMapping.Dims = 5 + vecFieldMapping.Similarity = "dot_product" + documentMapping.AddFieldMappingsAt("vector", vecFieldMapping) + + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch := index.NewBatch() + for i := 0; i < len(dataset); i++ { + batch.Index(strconv.Itoa(i), dataset[i]) + } + + err = index.Batch(batch) + if err != nil { + t.Fatal(err) + } + + termQuery := query.NewTermQuery("2") + + searchRequest := NewSearchRequest(termQuery) + searchRequest.AddKNN("vector", getRandomVector(), 3, 2.0) + searchRequest.AddKNN("vector", getRandomVector(), 2, 1.5) + searchRequest.Fields = []string{"content", "vector"} + + // Conjunction + searchRequest.AddKNNOperator(knnOperatorAnd) + conjunction, _, _, err := createKNNQuery(searchRequest) + if err != nil { + t.Fatalf("unexpected error for AND knn operator") + } + + conj, ok := conjunction.(*query.DisjunctionQuery) + if !ok { + t.Fatalf("expected disjunction query") + } + + if len(conj.Disjuncts) != 2 { + t.Fatalf("expected 2 disjuncts") + } + + // Disjunction + searchRequest.AddKNNOperator(knnOperatorOr) + disjunction, _, _, err := createKNNQuery(searchRequest) + if err != nil { + t.Fatalf("unexpected error for OR knn operator") + } + + disj, ok := disjunction.(*query.DisjunctionQuery) + if !ok { + t.Fatalf("expected disjunction query") + } + + if len(disj.Disjuncts) != 2 { + t.Fatalf("expected 2 disjuncts") + } + + // Incorrect operator. 
+ searchRequest.AddKNNOperator("bs_op") + searchRequest.Query, _, _, err = createKNNQuery(searchRequest) + if err == nil { + t.Fatalf("expected error for incorrect knn operator") + } +} + +// ----------------------------------------------------------------------------- +// Test nested vectors + +func TestNestedVectors(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + const dims = 3 + const k = 1 // one nearest neighbor + const vecFieldName = "vecData" + + dataset := map[string]map[string]interface{}{ // docID -> Doc + "doc1": { + vecFieldName: []float32{100, 100, 100}, + }, + "doc2": { + vecFieldName: [][]float32{{0, 0, 0}, {1000, 1000, 1000}}, + }, + } + + // Index mapping + indexMapping := NewIndexMapping() + vm := mapping.NewVectorFieldMapping() + vm.Dims = dims + vm.Similarity = "l2_norm" + indexMapping.DefaultMapping.AddFieldMappingsAt(vecFieldName, vm) + + // Create index and upload documents + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch := index.NewBatch() + for docID, doc := range dataset { + batch.Index(docID, doc) + } + + err = index.Batch(batch) + if err != nil { + t.Fatal(err) + } + + // Run searches + + tests := []struct { + queryVec []float32 + expectedDocID string + }{ + { + queryVec: []float32{100, 100, 100}, + expectedDocID: "doc1", + }, + { + queryVec: []float32{0, 0, 0}, + expectedDocID: "doc2", + }, + { + queryVec: []float32{1000, 1000, 1000}, + expectedDocID: "doc2", + }, + } + + for _, test := range tests { + searchReq := NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN(vecFieldName, test.queryVec, k, 1000) + + res, err := index.Search(searchReq) + if err != nil { + t.Fatal(err) + } + + if len(res.Hits) != 1 { + t.Fatalf("expected 1 hit, got %d", len(res.Hits)) + } + + if res.Hits[0].ID != test.expectedDocID { + t.Fatalf("expected docID %s, got %s", test.expectedDocID, + res.Hits[0].ID) + } + } +} + +func TestNumVecsStat(t *testing.T) { + + dataset, _, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents := makeDatasetIntoDocuments(dataset) + + indexMapping := NewIndexMapping() + + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = en.AnalyzerName + indexMapping.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + + vecFieldMapping1 := mapping.NewVectorFieldMapping() + vecFieldMapping1.Dims = testDatasetDims + vecFieldMapping1.Similarity = index.EuclideanDistance + indexMapping.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMapping1) + + tmpIndexPath := createTmpIndexPath(t) + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + for i := 0; i < 10; i++ { + batch := index.NewBatch() + for j := 0; j < 3; j++ { + for k := 0; k < 10; k++ { + err := batch.Index(fmt.Sprintf("%d", i*30+j*10+k), documents[j*10+k]) + if err != nil { + t.Fatal(err) + } + } + } + err = index.Batch(batch) + if err != nil { + t.Fatal(err) + } + } + + statsMap := index.StatsMap() + + if indexStats, exists := statsMap["index"]; exists { + if indexStatsMap, ok := indexStats.(map[string]interface{}); ok { + v1, ok := indexStatsMap["field:vector:num_vectors"].(uint64) + if !ok || v1 != uint64(300) { + t.Fatalf("mismatch in the number of vectors, expected 300, got %d", 
indexStatsMap["field:vector:num_vectors"]) + } + } + } +} diff --git a/search_no_knn.go b/search_no_knn.go index fb3814911..9d8dd56e2 100644 --- a/search_no_knn.go +++ b/search_no_knn.go @@ -18,11 +18,14 @@ package bleve import ( + "context" "encoding/json" "sort" "github.com/blevesearch/bleve/v2/search" + "github.com/blevesearch/bleve/v2/search/collector" "github.com/blevesearch/bleve/v2/search/query" + index "github.com/blevesearch/bleve_index_api" ) // A SearchRequest describes all the parameters @@ -60,6 +63,18 @@ type SearchRequest struct { SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` + // PreSearchData will be a map that will be used + // in the second phase of any 2-phase search, to provide additional + // context to the second phase. This is useful in the case of index + // aliases where the first phase will gather the PreSearchData from all + // the indexes in the alias, and the second phase will use that + // PreSearchData to perform the actual search. + // The currently accepted map configuration is: + // + // "_knn_pre_search_data_key": []*search.DocumentMatch + + PreSearchData map[string]interface{} `json:"pre_search_data,omitempty"` + sortFunc func(sort.Interface) } @@ -79,6 +94,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { Score string `json:"score"` SearchAfter []string `json:"search_after"` SearchBefore []string `json:"search_before"` + PreSearchData json.RawMessage `json:"pre_search_data"` } err := json.Unmarshal(input, &temp) @@ -119,6 +135,12 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { if r.From < 0 { r.From = 0 } + if temp.PreSearchData != nil { + r.PreSearchData, err = query.ParsePreSearchData(temp.PreSearchData) + if err != nil { + return err + } + } return nil @@ -126,7 +148,7 @@ func (r *SearchRequest) UnmarshalJSON(input []byte) error { // ----------------------------------------------------------------------------- -func copySearchRequest(req *SearchRequest) *SearchRequest { +func copySearchRequest(req *SearchRequest, preSearchData map[string]interface{}) *SearchRequest { rv := SearchRequest{ Query: req.Query, Size: req.Size + req.From, @@ -140,10 +162,42 @@ func copySearchRequest(req *SearchRequest) *SearchRequest { Score: req.Score, SearchAfter: req.SearchAfter, SearchBefore: req.SearchBefore, + PreSearchData: preSearchData, } return &rv } -func disjunctQueryWithKNN(req *SearchRequest) query.Query { - return req.Query +func validateKNN(req *SearchRequest) error { + return nil +} + +func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, reader index.IndexReader, preSearch bool) ([]*search.DocumentMatch, error) { + return nil, nil +} + +func setKnnHitsInCollector(knnHits []*search.DocumentMatch, req *SearchRequest, coll *collector.TopNCollector) { +} + +func requestHasKNN(req *SearchRequest) bool { + return false +} + +func addKnnToDummyRequest(dummyReq *SearchRequest, realReq *SearchRequest) { +} + +func mergeKNNDocumentMatches(req *SearchRequest, knnHits []*search.DocumentMatch) []*search.DocumentMatch { + return nil +} + +func redistributeKNNPreSearchData(req *SearchRequest, indexes []Index) (map[string]map[string]interface{}, error) { + return nil, nil +} + +func isKNNrequestSatisfiedByPreSearch(req *SearchRequest) bool { + return false +} + +func constructKnnPresearchData(mergedOut map[string]map[string]interface{}, preSearchResult *SearchResult, + indexes []Index) (map[string]map[string]interface{}, error) { + return mergedOut, nil } diff --git 
a/search_test.go b/search_test.go index 37da8da0a..7f76978bd 100644 --- a/search_test.go +++ b/search_test.go @@ -17,6 +17,7 @@ package bleve import ( "encoding/json" "fmt" + "math" "reflect" "strconv" "strings" @@ -26,6 +27,7 @@ import ( "github.com/blevesearch/bleve/v2/analysis" "github.com/blevesearch/bleve/v2/analysis/analyzer/custom" "github.com/blevesearch/bleve/v2/analysis/analyzer/keyword" + "github.com/blevesearch/bleve/v2/analysis/analyzer/simple" "github.com/blevesearch/bleve/v2/analysis/analyzer/standard" html_char_filter "github.com/blevesearch/bleve/v2/analysis/char/html" regexp_char_filter "github.com/blevesearch/bleve/v2/analysis/char/regexp" @@ -125,6 +127,70 @@ func TestSortedFacetedQuery(t *testing.T) { } } +func TestMatchAllScorer(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + indexMapping := NewIndexMapping() + indexMapping.TypeField = "type" + indexMapping.DefaultAnalyzer = "en" + documentMapping := NewDocumentMapping() + + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Index = true + contentFieldMapping.Store = true + documentMapping.AddFieldMappingsAt("content", contentFieldMapping) + + index, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + err := index.Close() + if err != nil { + t.Fatal(err) + } + }() + + index.Index("1", map[string]interface{}{ + "country": "india", + "content": "k", + }) + index.Index("2", map[string]interface{}{ + "country": "india", + "content": "l", + }) + index.Index("3", map[string]interface{}{ + "country": "india", + "content": "k", + }) + + d, err := index.DocCount() + if err != nil { + t.Fatal(err) + } + if d != 3 { + t.Errorf("expected 3, got %d", d) + } + + searchRequest := NewSearchRequest(NewMatchAllQuery()) + searchRequest.Score = "none" + searchResults, err := index.Search(searchRequest) + if err != nil { + t.Fatal(err) + } + + if searchResults.Total != 3 { + t.Fatalf("expected all the 3 docs in the index, got %v", searchResults.Total) + } + + for _, hit := range searchResults.Hits { + if hit.Score != 0.0 { + t.Fatalf("expected 0 score since score = none, got %v", hit.Score) + } + } +} + func TestSearchResultString(t *testing.T) { tests := []struct { @@ -2773,8 +2839,7 @@ func TestDateRangeStringQuery(t *testing.T) { } } } - -func TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { +func TestDateRangeFacetQueriesWithCustomDateTimeParser(t *testing.T) { idxMapping := NewIndexMapping() err := idxMapping.AddCustomDateTimeParser("customDT", map[string]interface{}{ @@ -2873,8 +2938,8 @@ func TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { end: "2001-08-20 18:10:00", result: testFacetResult{ name: "test", - start: "2001-08-20 18:00:00", - end: "2001-08-20 18:10:00", + start: "2001-08-20T18:00:00Z", + end: "2001-08-20T18:10:00Z", count: 2, err: nil, }, @@ -2886,8 +2951,8 @@ func TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { parser: "queryDT", result: testFacetResult{ name: "test", - start: "20/08/2001 6:00PM", - end: "20/08/2001 6:10PM", + start: "2001-08-20T18:00:00Z", + end: "2001-08-20T18:10:00Z", count: 2, err: nil, }, @@ -2899,8 +2964,8 @@ func TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { parser: "customDT", result: testFacetResult{ name: "test", - start: "20/08/2001 15:00:00", - end: "2001/08/20 6:10PM", + start: "2001-08-20T15:00:00Z", + end: "2001-08-20T18:10:00Z", count: 2, err: nil, }, @@ -2911,7 +2976,7 @@ func 
TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { parser: "customDT", result: testFacetResult{ name: "test", - end: "2001/08/20 6:15PM", + end: "2001-08-20T18:15:00Z", count: 3, err: nil, }, @@ -2922,7 +2987,7 @@ func TestDateRangeFaceQueriesWithCustomDateTimeParser(t *testing.T) { parser: "queryDT", result: testFacetResult{ name: "test", - start: "20/08/2001 6:15PM", + start: "2001-08-20T18:15:00Z", count: 2, err: nil, }, @@ -3376,3 +3441,161 @@ func TestPercentAndIsoStyleDates(t *testing.T) { } } } + +func roundToDecimalPlace(num float64, decimalPlaces int) float64 { + precision := math.Pow(10, float64(decimalPlaces)) + return math.Round(num*precision) / precision +} + +func TestScoreBreakdown(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + imap := mapping.NewIndexMapping() + textField := mapping.NewTextFieldMapping() + textField.Analyzer = simple.Name + imap.DefaultMapping.AddFieldMappingsAt("text", textField) + + documents := map[string]map[string]interface{}{ + "doc1": { + "text": "lorem ipsum dolor sit amet consectetur adipiscing elit do eiusmod tempor", + }, + "doc2": { + "text": "lorem dolor amet adipiscing sed eiusmod", + }, + "doc3": { + "text": "ipsum sit consectetur elit do tempor", + }, + "doc4": { + "text": "lorem ipsum sit amet adipiscing elit do eiusmod", + }, + } + + idx, err := New(tmpIndexPath, imap) + if err != nil { + t.Fatal(err) + } + defer func() { + err = idx.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch := idx.NewBatch() + for docID, doc := range documents { + err := batch.Index(docID, doc) + if err != nil { + t.Fatal(err) + } + } + err = idx.Batch(batch) + if err != nil { + t.Fatal(err) + } + + type testResult struct { + docID string // doc ID of the hit + score float64 + scoreBreakdown map[int]float64 + } + type testStruct struct { + query string + typ string + expectHits []testResult + } + testQueries := []testStruct{ + { + // trigger disjunction heap searcher (>10 searchers) + // expect score breakdown to have a 0 at BLANK + query: `{"disjuncts":[{"term":"lorem","field":"text"},{"term":"blank","field":"text"},{"term":"ipsum","field":"text"},{"term":"blank","field":"text"},{"term":"blank","field":"text"},{"term":"dolor","field":"text"},{"term":"sit","field":"text"},{"term":"amet","field":"text"},{"term":"consectetur","field":"text"},{"term":"blank","field":"text"},{"term":"adipiscing","field":"text"},{"term":"blank","field":"text"},{"term":"elit","field":"text"},{"term":"sed","field":"text"},{"term":"do","field":"text"},{"term":"eiusmod","field":"text"},{"term":"tempor","field":"text"},{"term":"blank","field":"text"},{"term":"blank","field":"text"}]}`, + typ: "disjunction", + expectHits: []testResult{ + { + docID: "doc1", + score: 0.3034548543819603, + scoreBreakdown: map[int]float64{0: 0.040398807605268316, 2: 0.040398807605268316, 5: 0.0669862776967768, 6: 0.040398807605268316, 7: 0.040398807605268316, 8: 0.0669862776967768, 10: 0.040398807605268316, 12: 0.040398807605268316, 14: 0.040398807605268316, 15: 0.040398807605268316, 16: 0.0669862776967768}, + }, + { + docID: "doc2", + score: 0.14725661652397853, + scoreBreakdown: map[int]float64{0: 0.05470024557900147, 5: 0.09069985124905133, 7: 0.05470024557900147, 10: 0.05470024557900147, 13: 0.15681178542754148, 15: 0.05470024557900147}, + }, + { + docID: "doc3", + score: 0.12637916362550797, + scoreBreakdown: map[int]float64{2: 0.05470024557900147, 6: 0.05470024557900147, 8: 0.09069985124905133, 12: 0.05470024557900147, 14: 
0.05470024557900147, 16: 0.09069985124905133}, + }, + { + docID: "doc4", + score: 0.15956816751152955, + scoreBreakdown: map[int]float64{0: 0.04737179972998534, 2: 0.04737179972998534, 6: 0.04737179972998534, 7: 0.04737179972998534, 10: 0.04737179972998534, 12: 0.04737179972998534, 14: 0.04737179972998534, 15: 0.04737179972998534}, + }, + }, + }, + { + // trigger disjunction slice searcher (< 10 searchers) + // expect BLANK to give a 0 in score breakdown + query: `{"disjuncts":[{"term":"blank","field":"text"},{"term":"lorem","field":"text"},{"term":"ipsum","field":"text"},{"term":"blank","field":"text"},{"term":"blank","field":"text"},{"term":"dolor","field":"text"},{"term":"sit","field":"text"},{"term":"blank","field":"text"}]}`, + typ: "disjunction", + expectHits: []testResult{ + { + docID: "doc1", + score: 0.1340684440934241, + scoreBreakdown: map[int]float64{1: 0.05756326446708409, 2: 0.05756326446708409, 5: 0.09544709478559595, 6: 0.05756326446708409}, + }, + { + docID: "doc2", + score: 0.05179425287147191, + scoreBreakdown: map[int]float64{1: 0.0779410306721006, 5: 0.129235980813787}, + }, + { + docID: "doc3", + score: 0.0389705153360503, + scoreBreakdown: map[int]float64{2: 0.0779410306721006, 6: 0.0779410306721006}, + }, + { + docID: "doc4", + score: 0.07593627256602972, + scoreBreakdown: map[int]float64{1: 0.06749890894758198, 2: 0.06749890894758198, 6: 0.06749890894758198}, + }, + }, + }, + } + for _, dtq := range testQueries { + var q query.Query + var rv query.DisjunctionQuery + err := json.Unmarshal([]byte(dtq.query), &rv) + if err != nil { + t.Fatal(err) + } + rv.RetrieveScoreBreakdown(true) + q = &rv + sr := NewSearchRequest(q) + sr.SortBy([]string{"_id"}) + sr.Explain = true + res, err := idx.Search(sr) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != len(dtq.expectHits) { + t.Fatalf("expected %d hits, got %d", len(dtq.expectHits), len(res.Hits)) + } + for i, hit := range res.Hits { + if hit.ID != dtq.expectHits[i].docID { + t.Fatalf("expected docID %s, got %s", dtq.expectHits[i].docID, hit.ID) + } + if len(hit.ScoreBreakdown) != len(dtq.expectHits[i].scoreBreakdown) { + t.Fatalf("expected %d score breakdown, got %d", len(dtq.expectHits[i].scoreBreakdown), len(hit.ScoreBreakdown)) + } + for j, score := range hit.ScoreBreakdown { + actualScore := roundToDecimalPlace(score, 3) + expectScore := roundToDecimalPlace(dtq.expectHits[i].scoreBreakdown[j], 3) + if actualScore != expectScore { + t.Fatalf("expected score breakdown %f, got %f", dtq.expectHits[i].scoreBreakdown[j], score) + } + } + } + } + +} diff --git a/test/knn/knn_dataset_queries.zip b/test/knn/knn_dataset_queries.zip new file mode 100644 index 000000000..d840ded2f Binary files /dev/null and b/test/knn/knn_dataset_queries.zip differ diff --git a/test/tests/facet/searches.json b/test/tests/facet/searches.json index 33ac39775..6752282a4 100644 --- a/test/tests/facet/searches.json +++ b/test/tests/facet/searches.json @@ -129,12 +129,12 @@ { "name": "new", "count": 9, - "start": "2012-01-01" + "start": "2012-01-01T00:00:00Z" }, { "name": "old", "count": 1, - "end": "2012-01-01" + "end": "2012-01-01T00:00:00Z" } ] } diff --git a/util/extract.go b/util/extract.go index f8e61546a..e963d0c3a 100644 --- a/util/extract.go +++ b/util/extract.go @@ -15,6 +15,7 @@ package util import ( + "math" "reflect" ) @@ -24,13 +25,13 @@ func ExtractNumericValFloat64(v interface{}) (float64, bool) { if !val.IsValid() { return 0, false } - typ := val.Type() - switch typ.Kind() { - case reflect.Float32, reflect.Float64: + + switch { 
+ case val.CanFloat(): return val.Float(), true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case val.CanInt(): return float64(val.Int()), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + case val.CanUint(): return float64(val.Uint()), true } @@ -43,13 +44,17 @@ func ExtractNumericValFloat32(v interface{}) (float32, bool) { if !val.IsValid() { return 0, false } - typ := val.Type() - switch typ.Kind() { - case reflect.Float32, reflect.Float64: - return float32(val.Float()), true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + + switch { + case val.CanFloat(): + floatVal := val.Float() + if floatVal > math.MaxFloat32 { + return 0, false + } + return float32(floatVal), true + case val.CanInt(): return float32(val.Int()), true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + case val.CanUint(): return float32(val.Uint()), true } diff --git a/util/knn.go b/util/knn.go deleted file mode 100644 index e50ff01da..000000000 --- a/util/knn.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2023 Couchbase, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build vectors -// +build vectors - -package util - -const ( - EuclideanDistance = "l2_norm" - - // dotProduct(vecA, vecB) = vecA . vecB = |vecA| * |vecB| * cos(theta); - // where, theta is the angle between vecA and vecB - // If vecA and vecB are normalized (unit magnitude), then - // vecA . vecB = cos(theta), which is the cosine similarity. - // Thus, we don't need a separate similarity type for cosine similarity - CosineSimilarity = "dot_product" -) - -const DefaultSimilarityMetric = EuclideanDistance - -// Supported similarity metrics for vector fields -var SupportedSimilarityMetrics = map[string]struct{}{ - EuclideanDistance: {}, - CosineSimilarity: {}, -} - diff --git a/vectors.md b/vectors.md new file mode 100644 index 000000000..b4264d291 --- /dev/null +++ b/vectors.md @@ -0,0 +1,116 @@ +# bleve@v2.4.0+ + +* *v2.4.0* (and after) comes with support for **vector indexing and search**. +* We've achieved this by embedding [FAISS](https://github.com/facebookresearch/faiss) indexes within our bleve indexes. +* A new zap file format, [v16](https://github.com/blevesearch/zapx/blob/master/zap.md), which will be the default going forward. Here we co-locate text and vector indexes as neighbors within segments, continuing to conform to the segmented architecture of *scorch*. + +## Pre-requisite(s) + +* Incorporation of [FAISS](https://github.com/blevesearch/faiss) into our ecosystem. +* FAISS is a C++ library that needs to be compiled, and its shared libraries need to be available at a path accessible to your application. +* The `vectors` Go build tag needs to be set for bleve to include all the supporting code. This tag must be set only after the FAISS shared library is made available. Failing to do either will prevent you from using this feature.
+* Please follow the [setup instructions](#setup-instructions) below for help with this. + +## Indexing + +```go +doc := struct{ + Id string + Text string + Vec []float32 +}{ + Id: "example", + Text: "hello from united states", + Vec: []float32{0,1,2,3,4,5,6,7,8,9}, +} + +textFieldMapping := mapping.NewTextFieldMapping() +vectorFieldMapping := mapping.NewVectorFieldMapping() +vectorFieldMapping.Dims = 10 +vectorFieldMapping.Similarity = "l2_norm" // euclidean distance + +bleveMapping := bleve.NewIndexMapping() +bleveMapping.DefaultMapping.Dynamic = false +bleveMapping.DefaultMapping.AddFieldMappingsAt("text", textFieldMapping) +bleveMapping.DefaultMapping.AddFieldMappingsAt("vec", vectorFieldMapping) + +index, err := bleve.New("example.bleve", bleveMapping) +if err != nil { + panic(err) +} +index.Index(doc.Id, doc) +``` + +## Querying + +```go +searchRequest := NewSearchRequest(query.NewMatchNoneQuery()) +searchRequest.AddKNN( + "vec", // vector field name + []float32{10,11,12,13,14,15,16,17,18,19}, // query vector (same dims) + 5, // k + 0, // boost +) +searchResult, err := index.Search(searchRequest) +if err != nil { + panic(err) +} +fmt.Println(searchResult.Hits) +``` + +## Caveats + +* The `vector` field type is an array that may hold only float32 values. +* Currently supported similarity metrics are: [`"l2_norm"`, `"dot_product"`]. +* Supported dimensionality is between 1 and 2048 at the moment. +* Vectors from documents that do not conform to the index mapping dimensionality are simply discarded at index time. +* The dimensionality of the query vector must match the dimensionality of the indexed vectors to obtain any results. +* Pure kNN searches can be performed, but the `query` attribute within the search request must still be set - to `{"match_none": {}}` in this case. +* Hybrid searches are supported, where results from `query` are unioned (for now) with results from `knn`. The tf-idf scores from exact searches are simply summed with the similarity distances to determine the aggregate scores (a short sketch of such a request appears at the end of this document). +``` +aggregate_score = (query_boost * query_hit_score) + (knn_boost * knn_hit_distance) +``` +* Multi kNN searches are supported - the `knn` object within the search request accepts an array of requests. These sub-objects are unioned by default, but this behavior can be overridden by setting `knn_operator` to `"and"`. +* Previously supported pagination settings continue to work as before, with size/limit applied over the top-K hits combined with any exact search hits. + +## Setup Instructions + +* Using `cmake` is the approach recommended by the FAISS authors. +* More details here - [faiss/INSTALL](https://github.com/blevesearch/faiss/blob/main/INSTALL.md). + +### Linux + +Also documented here - [go-faiss/README](https://github.com/blevesearch/go-faiss/blob/master/README.md). + +``` +git clone https://github.com/blevesearch/faiss.git +cd faiss +cmake -B build -DFAISS_ENABLE_GPU=OFF -DFAISS_ENABLE_C_API=ON -DBUILD_SHARED_LIBS=ON . +make -C build +sudo make -C build install +``` + +Building will produce the dynamic library `faiss_c`. You will need to install it in a place where your system will find it (e.g. /usr/local/lib). You can do this with: +``` +sudo cp build/c_api/libfaiss_c.so /usr/local/lib +``` + +### OSX + +While x86_64 macOS should not need anything different, on aarch64 some of the instructions need adjusting (see [facebookresearch/faiss#2111](https://github.com/facebookresearch/faiss/issues/2111)):
+ +``` +LDFLAGS="-L/opt/homebrew/opt/llvm/lib" CPPFLAGS="-I/opt/homebrew/opt/llvm/include" CXX=/opt/homebrew/opt/llvm/bin/clang++ CC=/opt/homebrew/opt/llvm/bin/clang cmake -B build -DFAISS_ENABLE_GPU=OFF -DFAISS_ENABLE_C_API=ON -DBUILD_SHARED_LIBS=ON -DFAISS_ENABLE_PYTHON=OFF . +make -C build +sudo make -C build install +sudo cp build/c_api/libfaiss_c.dylib /usr/local/lib +``` + +### Sanity check + +Once the supporting library is built and made available, a sanity run is recommended to make sure all unit tests and especially those accessing the vectors' code pass. Here's how I do on mac - + +``` +export DYLD_LIBRARY_PATH=/usr/local/lib +go test -v ./... --tags=vectors +```