diff --git a/client/config.go b/client/config.go index 2c3a95e..5e25b77 100644 --- a/client/config.go +++ b/client/config.go @@ -26,7 +26,7 @@ import ( // Constants and default values for the package bce const ( - SdkVersion = "0.0.1" + SdkVersion = "2.0.1" DefaultProtocol = "http" DefaultRegion = "bj" DefaultContentType = "application/json;charset=utf-8" diff --git a/example/example.go b/example/example.go index 354b3cd..66c390c 100644 --- a/example/example.go +++ b/example/example.go @@ -59,7 +59,6 @@ func (m *MochowTest) clearEnv() error { } } } - } // drop database if existed @@ -115,6 +114,11 @@ func (m *MochowTest) createDatabaseAndTable() error { NotNull: true, Dimension: 3, }, + { + FieldName: "segment", + FieldType: "TEXT", + NotNull: true, + }, } // Indexes @@ -138,6 +142,16 @@ func (m *MochowTest) createDatabaseAndTable() error { AutoBuild: true, AutoBuildPolicy: autoBuildPolicy.Params(), }, + { + IndexName: "book_segment_inverted_idx", + IndexType: api.InvertedIndex, + Fields: []string{"segment"}, + FieldAttributes: []api.InvertedIndexFieldAttribute{api.Analyzed}, + Params: api.InvertedIndexParams{ + "analyzer": api.ChineseAnalyzer, + "parseMode": api.FineMode, + }, + }, } // create table @@ -180,6 +194,7 @@ func (m *MochowTest) upsertData() error { "author": "吴承恩", "page": 21, "vector": []float32{0.2123, 0.21, 0.213}, + "segment": "富贵功名,前缘分定,为人切莫欺心。", }, }, { @@ -189,6 +204,7 @@ func (m *MochowTest) upsertData() error { "author": "吴承恩", "page": 22, "vector": []float32{0.2123, 0.22, 0.213}, + "segment": "正大光明,忠良善果弥深。些些狂妄天加谴,眼前不遇待时临。", }, }, { @@ -198,6 +214,7 @@ func (m *MochowTest) upsertData() error { "author": "罗贯中", "page": 23, "vector": []float32{0.2123, 0.23, 0.213}, + "segment": "细作探知这个消息,飞报吕布。", }, }, { @@ -207,6 +224,17 @@ func (m *MochowTest) upsertData() error { "author": "罗贯中", "page": 24, "vector": []float32{0.2123, 0.24, 0.213}, + "segment": "布大惊,与陈宫商议。宫曰:“闻刘玄德新领徐州,可往投之。” 布从其言,竟投徐州来。有人报知玄德。", + }, + }, + { + Fields: map[string]interface{}{ + "id": "0005", + "bookName": "三国演义", + "author": "罗贯中", + "page": 25, + "vector": []float32{0.2123, 0.24, 0.213}, + "segment": "玄德曰:“布乃当今英勇之士,可出迎之。”糜竺曰:“吕布乃虎狼之徒,不可收留;收则伤人矣。", }, }, } @@ -276,6 +304,7 @@ func (m *MochowTest) updateData() error { "bookName": "红楼梦", "author": "曹雪芹", "page": 100, + "segment": "满纸荒唐言,一把辛酸泪", }, } err := m.client.UpdateRow(updateArgs) @@ -286,7 +315,7 @@ func (m *MochowTest) updateData() error { return nil } -func (m *MochowTest) searchData() error { +func (m *MochowTest) topkSearch() error { // rebuild vector index if err := m.client.RebuildIndex(m.database, m.table, "vector_idx"); err != nil { log.Fatalf("Fail to rebuild index due to error: %v", err) @@ -302,50 +331,124 @@ func (m *MochowTest) searchData() error { } // search - hnswParams := api.NewSearchParams() - hnswParams.AddEf(200) - hnswParams.AddLimit(10) + vector := api.FloatVector([]float32{0.3123, 0.43, 0.213}) - // single ann search params - searchArgs := &api.SearchRowArgs{ + searchArgs := &api.VectorSearchArgs{ Database: m.database, Table: m.table, - ANNS: &api.ANNSearchParams{ - VectorField: "vector", - VectorFloats: []float32{0.3123, 0.43, 0.213}, - Filter: "bookName='三国演义'", - Params: hnswParams, - }, - RetrieveVector: false, + Request: api.VectorTopkSearchRequest{}. + New("vector", vector, 5). + Filter("bookName='三国演义'"). + Config(api.VectorSearchConfig{}.New().Ef(200)), } - searchResult, err := m.client.SearchRow(searchArgs) + + searchResult, err := m.client.VectorSearch(searchArgs) if err != nil { log.Fatalf("Fail to search row due to error: %v", err) return err } - log.Printf("search result: %+v", searchResult) - // batch ann search params - batchSearchArgs := &api.BatchSearchRowArgs{ + log.Printf("topk search result: %+v", searchResult.Rows) + return nil +} + +func (m *MochowTest) rangeSearch() error { + vector := api.FloatVector([]float32{0.3123, 0.43, 0.213}) + + searchArgs := &api.VectorSearchArgs{ Database: m.database, Table: m.table, - ANNS: &api.BatchANNSearchParams{ - VectorField: "vector", - VectorFloats: [][]float32{ - {0.3123, 0.43, 0.213}, - {0.5512, 0.33, 0.43}, - }, - Filter: "bookName='三国演义'", - Params: hnswParams, - }, - RetrieveVector: false, + Request: api.VectorRangeSearchRequest{}. + New("vector", vector, api.DistanceRange{Min: 0, Max: 20}). + Filter("bookName='三国演义'"). + Limit(15). + Config(api.VectorSearchConfig{}.New().Ef(200)), } - batchSearchResult, err := m.client.BatchSearchRow(batchSearchArgs) + + searchResult, err := m.client.VectorSearch(searchArgs) + if err != nil { + log.Fatalf("Fail to search row due to error: %v", err) + return err + } + + log.Printf("range search result: %+v", searchResult.Rows) + return nil +} + +func (m *MochowTest) batchSearch() error { + vectors := []api.Vector{ + api.FloatVector{0.3123, 0.43, 0.213}, + api.FloatVector{0.5512, 0.33, 0.43}, + } + + searchArgs := &api.VectorSearchArgs{ + Database: m.database, + Table: m.table, + Request: api.VectorBatchSearchRequest{}.New("vector", vectors). + Filter("bookName='三国演义'"). + Limit(10). + Config(api.VectorSearchConfig{}.New().Ef(200)). + Projections([]string{"id", "bookName", "author", "page"}), + } + + searchResult, err := m.client.VectorSearch(searchArgs) if err != nil { log.Fatalf("Fail to batch search row due to error: %v", err) return err } - log.Printf("batch search result: %+v", batchSearchResult) + + log.Printf("batch search result: %+v", searchResult.BatchRows) + + return nil +} + +func (m *MochowTest) bm25Search() error { + searchArgs := &api.BM25SearchArgs{ + Database: m.database, + Table: m.table, + Request: api.BM25SearchRequest{}. + New("book_segment_inverted_idx", "吕布"). + Limit(12). + Filter("bookName='三国演义'"). + ReadConsistency("STRONG"). + Projections([]string{"id", "vector"}), + } + + searchResult, err := m.client.BM25Search(searchArgs) + if err != nil { + log.Fatalf("Fail to search row due to error: %v", err) + return err + } + + log.Printf("bm25 search result: %+v", searchResult.Rows) + return nil +} + +func (m *MochowTest) hybridSearch() error { + vector := api.FloatVector([]float32{0.3123, 0.43, 0.213}) + + request := api.HybridSearchRequest{}. + New( + api.VectorTopkSearchRequest{}.New("vector", vector, 15).Config(api.VectorSearchConfig{}.New().Ef(200)), /*vector search args*/ + api.BM25SearchRequest{}.New("book_segment_inverted_idx", "吕布"), /*BM25 search args*/ + 0.4, /*vector search weight*/ + 0.6 /*BM25 search weight*/). + Filter("bookName='三国演义'"). + Limit(15) + + searchArgs := &api.HybridSearchArgs{ + Database: m.database, + Table: m.table, + Request: request, + } + + searchResult, err := m.client.HybridSearch(searchArgs) + if err != nil { + log.Fatalf("Fail to search row due to error: %v", err) + return err + } + + log.Printf("hybrid search result: %+v", searchResult.Rows) return nil } @@ -388,7 +491,7 @@ func (m *MochowTest) dropAndCreateVIndex() error { _, err := m.client.DescIndex(m.database, m.table, "vector_idx") if realErr, ok := err.(*client.BceServiceError); ok { if realErr.Code == int(api.IndexNotExist) { - log.Printf("Index already dropped") + log.Print("Index already dropped") break } } @@ -473,8 +576,24 @@ func main() { log.Println("Update row success.") // Search data - if err := mochowTest.searchData(); err != nil { - log.Printf("Fail to search row, err:%v", err) + if err := mochowTest.topkSearch(); err != nil { + log.Printf("Fail to topk search, err:%v", err) + return + } + if err := mochowTest.rangeSearch(); err != nil { + log.Printf("Fail to range search, err:%v", err) + return + } + if err := mochowTest.batchSearch(); err != nil { + log.Printf("Fail to batch search, err:%v", err) + return + } + if err := mochowTest.bm25Search(); err != nil { + log.Printf("Fail to bm25 search, err:%v", err) + return + } + if err := mochowTest.hybridSearch(); err != nil { + log.Printf("Fail to hybrid search, err:%v", err) return } log.Println("Search row success.") diff --git a/mochow/api/entity.go b/mochow/api/entity.go index 84f3442..45ef750 100644 --- a/mochow/api/entity.go +++ b/mochow/api/entity.go @@ -18,6 +18,7 @@ package api import ( "bytes" + "fmt" "github.com/bytedance/sonic" "github.com/bytedance/sonic/decoder" @@ -60,19 +61,23 @@ func (f *FieldSchema) MarshalJSON() ([]byte, error) { return field, nil } +type IndexParams interface{} type VectorIndexParams map[string]interface{} +type InvertedIndexParams map[string]interface{} type AutoBuildParams map[string]interface{} type IndexSchema struct { - IndexName string `json:"indexName,omitempty"` - IndexType IndexType `json:"indexType,omitempty"` - MetricType MetricType `json:"metricType,omitempty"` - Params VectorIndexParams `json:"params,omitempty"` - Field string `json:"field,omitempty"` - State IndexState `json:"state,omitempty"` - AutoBuild bool `json:"autoBuild,omitempty"` - AutoBuildPolicy AutoBuildParams `json:"autoBuildPolicy,omitempty"` + IndexName string `json:"indexName,omitempty"` + IndexType IndexType `json:"indexType,omitempty"` + MetricType MetricType `json:"metricType,omitempty"` + Params IndexParams `json:"params,omitempty"` + Field string `json:"field,omitempty"` + Fields []string `json:"fields,omitempty"` + FieldAttributes []InvertedIndexFieldAttribute `json:"fieldsIndexAttributes,omitempty"` + State IndexState `json:"state,omitempty"` + AutoBuild bool `json:"autoBuild,omitempty"` + AutoBuildPolicy AutoBuildParams `json:"autoBuildPolicy,omitempty"` } type TableSchema struct { @@ -167,6 +172,612 @@ type BatchANNSearchParams struct { Filter string `json:"filter,omitempty"` } +type searchRequest interface { + requestType() string + toDict() map[string]interface{} + isBatch() bool +} + +/* +Optional configurable params for vector search. + +For each index algorithm, the params that could be set are: + +IndexType: HNSW +Params: ef, pruning + +IndexType: HNSWPQ +Params: ef, pruning + +IndexType: PUCK +Params: searchCoarseCount + +IndexType: FLAT +Params: None +*/ +type VectorSearchConfig struct { + params map[string]interface{} +} + +func (h VectorSearchConfig) New() *VectorSearchConfig { + return &VectorSearchConfig{ + params: make(map[string]interface{}), + } +} + +func (h *VectorSearchConfig) Ef(ef uint32) *VectorSearchConfig { + h.params["ef"] = ef + return h +} + +func (h *VectorSearchConfig) Pruning(pruning bool) *VectorSearchConfig { + h.params["pruning"] = pruning + return h +} + +func (h *VectorSearchConfig) SearchCoarseCount(searchCoarseCount uint32) *VectorSearchConfig { + h.params["searchCoarseCount"] = searchCoarseCount + return h +} + +type vectorSearchRequest interface { + searchRequest + + // Make sure user not pass e.g. 'BM25SearchRequest' to VectorSearch api + vectorSearchRequestDummyInterface() +} + +type Vector interface { + name() string + representation() interface{} +} + +type FloatVector []float32 + +func (v FloatVector) name() string { + return "vectorFloats" +} + +func (v FloatVector) representation() interface{} { + return []float32(v) +} + +type request struct { + set map[string]bool +} + +func (r request) mark(key string) { + r.set[key] = true +} + +func (r request) isMarked(key string) bool { + _, ok := r.set[key] + return ok +} + +type searchCommonFields struct { + request + partitionKey map[string]interface{} + projections []string + readConsistency string + limit uint32 + filter string +} + +func searchCommonFieldsToMap(r *searchCommonFields) map[string]interface{} { + fields := make(map[string]interface{}) + if r.isMarked("partitionKey") { + fields["partitionKey"] = r.partitionKey + } + if r.isMarked("projections") { + fields["projections"] = r.projections + } + if r.isMarked("readConsistency") { + fields["readConsistency"] = r.readConsistency + } + if r.isMarked("filter") { + fields["filter"] = r.filter + } + if r.isMarked("limit") { + fields["limit"] = r.limit + } + return fields +} + +type vectorSearchFields struct { + searchCommonFields + vectorField string + vector Vector + vectors []Vector + distanceNear float64 + distanceFar float64 + config *VectorSearchConfig +} + +func (r vectorSearchFields) fillSearchFields(fields *map[string]interface{}) { + anns := make(map[string]interface{}) + if r.isMarked("vectorField") { + anns["vectorField"] = r.vectorField + } + if r.isMarked("vector") { + anns[r.vector.name()] = r.vector.representation() + } + if r.isMarked("vectors") && len(r.vectors) != 0 { + vectors := make([]interface{}, 0, len(r.vectors)) + for _, vec := range r.vectors { + vectors = append(vectors, vec.representation()) + } + anns[r.vectors[0].name()] = vectors + } + if r.isMarked("filter") { + anns["filter"] = r.filter + } + + params := make(map[string]interface{}) + if r.isMarked("config") { + for k, v := range r.config.params { + params[k] = v + } + } + if r.isMarked("distanceNear") { + params["distanceNear"] = r.distanceNear + } + if r.isMarked("distanceFar") { + params["distanceFar"] = r.distanceFar + } + if r.isMarked("limit") { + params["limit"] = r.limit + } + if len(params) != 0 { + anns["params"] = params + } + + if len(anns) != 0 { + (*fields)["anns"] = anns + } + + for k, v := range searchCommonFieldsToMap(&r.searchCommonFields) { + if k == "filter" || k == "limit" { // in "anns" + continue + } + (*fields)[k] = v + } +} + +/**** Vector Topk Search ****/ + +type VectorTopkSearchRequest struct { + vectorSearchRequest // interface + vectorSearchFields // common fields +} + +func (r VectorTopkSearchRequest) New(vectorField string, vector Vector, limit uint32) *VectorTopkSearchRequest { + r.set = make(map[string]bool, 0) + + r.mark("vectorField") + r.vectorField = vectorField + + r.mark("vector") + r.vector = vector + + r.mark("limit") + r.limit = limit + return &r +} + +func (r *VectorTopkSearchRequest) String() string { + return fmt.Sprintf("VectorTopkSearchRequest:%v", r.toDict()) +} + +func (r *VectorTopkSearchRequest) PartitionKey(partitionKey map[string]interface{}) *VectorTopkSearchRequest { + r.mark("partitionKey") + r.partitionKey = partitionKey + return r +} + +func (r *VectorTopkSearchRequest) ReadConsistency(readConsistency string) *VectorTopkSearchRequest { + r.mark("readConsistency") + r.readConsistency = readConsistency + return r +} + +func (r *VectorTopkSearchRequest) Projections(projections []string) *VectorTopkSearchRequest { + r.mark("projections") + r.projections = projections + return r +} + +func (r *VectorTopkSearchRequest) Filter(filter string) *VectorTopkSearchRequest { + r.mark("filter") + r.filter = filter + return r +} + +func (r *VectorTopkSearchRequest) Config(config *VectorSearchConfig) *VectorTopkSearchRequest { + r.mark("config") + r.config = config + return r +} + +func (r *VectorTopkSearchRequest) requestType() string { + return "search" +} + +func (r *VectorTopkSearchRequest) isBatch() bool { + return false +} + +func (r *VectorTopkSearchRequest) toDict() map[string]interface{} { + fields := make(map[string]interface{}) + r.fillSearchFields(&fields) + return fields +} + +func (r *VectorTopkSearchRequest) vectorSearchRequestDummyInterface() { +} + +/**** Vector Range Search ****/ +type DistanceRange struct { + Min, Max float64 +} + +type VectorRangeSearchRequest struct { + vectorSearchRequest // interface + vectorSearchFields // common fields +} + +func (r VectorRangeSearchRequest) New(vectorField string, vector Vector, distanceRange DistanceRange) *VectorRangeSearchRequest { + r.set = make(map[string]bool, 0) + + r.mark("vectorField") + r.vectorField = vectorField + + r.mark("vector") + r.vector = vector + + r.mark("distanceNear") + r.distanceNear = distanceRange.Min + + r.mark("distanceFar") + r.distanceFar = distanceRange.Max + return &r +} + +func (r *VectorRangeSearchRequest) String() string { + return fmt.Sprintf("VectorRangeSearchRequest:%v", r.toDict()) +} + +func (r *VectorRangeSearchRequest) PartitionKey(partitionKey map[string]interface{}) *VectorRangeSearchRequest { + r.mark("partitionKey") + r.partitionKey = partitionKey + return r +} + +func (r *VectorRangeSearchRequest) ReadConsistency(readConsistency string) *VectorRangeSearchRequest { + r.mark("readConsistency") + r.readConsistency = readConsistency + return r +} + +func (r *VectorRangeSearchRequest) Projections(projections []string) *VectorRangeSearchRequest { + r.mark("projections") + r.projections = projections + return r +} + +func (r *VectorRangeSearchRequest) Limit(limit uint32) *VectorRangeSearchRequest { + r.mark("limit") + r.limit = limit + return r +} + +func (r *VectorRangeSearchRequest) Filter(filter string) *VectorRangeSearchRequest { + r.mark("filter") + r.filter = filter + return r +} + +func (r *VectorRangeSearchRequest) Config(config *VectorSearchConfig) *VectorRangeSearchRequest { + r.mark("config") + r.config = config + return r +} + +func (r *VectorRangeSearchRequest) requestType() string { + return "search" +} + +func (r *VectorRangeSearchRequest) isBatch() bool { + return false +} + +func (r *VectorRangeSearchRequest) toDict() map[string]interface{} { + fields := make(map[string]interface{}) + r.fillSearchFields(&fields) + return fields +} + +func (r *VectorRangeSearchRequest) vectorSearchRequestDummyInterface() { +} + +/**** Vector Batch Search ****/ +type VectorBatchSearchRequest struct { + vectorSearchRequest // interface + vectorSearchFields // common fields +} + +func (r VectorBatchSearchRequest) New(vectorField string, vectors []Vector) *VectorBatchSearchRequest { + r.set = make(map[string]bool, 0) + + r.mark("vectorField") + r.vectorField = vectorField + + r.mark("vectors") + r.vectors = vectors + return &r +} + +func (r *VectorBatchSearchRequest) String() string { + return fmt.Sprintf("VectorBatchSearchRequest:%v", r.toDict()) +} + +func (r *VectorBatchSearchRequest) PartitionKey(partitionKey map[string]interface{}) *VectorBatchSearchRequest { + r.mark("partitionKey") + r.partitionKey = partitionKey + return r +} + +func (r *VectorBatchSearchRequest) ReadConsistency(readConsistency string) *VectorBatchSearchRequest { + r.mark("readConsistency") + r.readConsistency = readConsistency + return r +} + +func (r *VectorBatchSearchRequest) Projections(projections []string) *VectorBatchSearchRequest { + r.mark("projections") + r.projections = projections + return r +} + +func (r *VectorBatchSearchRequest) Limit(limit uint32) *VectorBatchSearchRequest { + r.mark("limit") + r.limit = limit + return r +} + +func (r *VectorBatchSearchRequest) DistanceRange(distanceRange DistanceRange) *VectorBatchSearchRequest { + r.mark("distanceNear") + r.distanceNear = distanceRange.Min + + r.mark("distanceFar") + r.distanceFar = distanceRange.Max + return r +} + +func (r *VectorBatchSearchRequest) Filter(filter string) *VectorBatchSearchRequest { + r.mark("filter") + r.filter = filter + return r +} + +func (r *VectorBatchSearchRequest) Config(config *VectorSearchConfig) *VectorBatchSearchRequest { + r.mark("config") + r.config = config + return r +} + +func (r *VectorBatchSearchRequest) isBatch() bool { + return true +} + +func (r *VectorBatchSearchRequest) requestType() string { + return "batchSearch" +} + +func (r *VectorBatchSearchRequest) toDict() map[string]interface{} { + fields := make(map[string]interface{}) + r.fillSearchFields(&fields) + return fields +} + +func (r *VectorBatchSearchRequest) vectorSearchRequestDummyInterface() { +} + +/**** BM25 Search ****/ +type bm25SearchRequest interface { + searchRequest + + // Make sure user not pass e.g. 'VectorSearchRequest' to BM25Search api + bm25SearchRequestDummyInterface() +} + +type BM25SearchRequest struct { + bm25SearchRequest // interface + searchCommonFields // common fields + indexName string + searchText string +} + +func (r BM25SearchRequest) New(indexName string, searchText string) *BM25SearchRequest { + r.set = make(map[string]bool, 0) + + r.indexName = indexName + r.searchText = searchText + return &r +} + +func (r *BM25SearchRequest) String() string { + return fmt.Sprintf("BM25SearchRequest:%v", r.toDict()) +} + +func (r *BM25SearchRequest) PartitionKey(partitionKey map[string]interface{}) *BM25SearchRequest { + r.mark("partitionKey") + r.partitionKey = partitionKey + return r +} + +func (r *BM25SearchRequest) ReadConsistency(readConsistency string) *BM25SearchRequest { + r.mark("readConsistency") + r.readConsistency = readConsistency + return r +} + +func (r *BM25SearchRequest) Projections(projections []string) *BM25SearchRequest { + r.mark("projections") + r.projections = projections + return r +} + +func (r *BM25SearchRequest) Limit(limit uint32) *BM25SearchRequest { + r.mark("limit") + r.limit = limit + return r +} + +func (r *BM25SearchRequest) Filter(filter string) *BM25SearchRequest { + r.mark("filter") + r.filter = filter + return r +} + +func (r *BM25SearchRequest) toDict() map[string]interface{} { + fields := make(map[string]interface{}) + for k, v := range searchCommonFieldsToMap(&r.searchCommonFields) { + fields[k] = v + } + + bm25Params := make(map[string]interface{}) + bm25Params["indexName"] = r.indexName + bm25Params["searchText"] = r.searchText + fields["BM25SearchParams"] = bm25Params + + return fields +} + +func (r *BM25SearchRequest) requestType() string { + return "search" +} + +func (r *BM25SearchRequest) isBatch() bool { + return false +} + +func (r *BM25SearchRequest) bm25SearchRequestDummyInterface() { +} + +/**** Hybrid Search ****/ + +type hybridSearchRequest interface { + searchRequest + + // Make sure user not pass e.g. 'VectorSearchRequest' to HybridSearch api + hybridSearchRequestDummyInterface() +} + +type HybridSearchRequest struct { + hybridSearchRequest // interface + searchCommonFields // common fields + + vectorRequest vectorSearchRequest + bm25Request bm25SearchRequest + vectorWeight float32 + bm25Weight float32 +} + +/* +Note: 'limit' and 'filter' are global settings, and they will +apply to both vector search and BM25 search. Avoid setting them in +'bm25Request' or 'vectorRequest'. Any settings in 'vectorRequest' +or 'bm25Request' for 'limit' or 'filter' will be overridden by the +general settings. +*/ +func (r HybridSearchRequest) New( + vectorRequest vectorSearchRequest, + bm25Request bm25SearchRequest, + vectorWeight float32, + bm25Weight float32, +) *HybridSearchRequest { + r.set = make(map[string]bool, 0) + + r.vectorRequest = vectorRequest + r.bm25Request = bm25Request + r.vectorWeight = vectorWeight + r.bm25Weight = bm25Weight + return &r +} + +func (r *HybridSearchRequest) String() string { + return fmt.Sprintf("HybridSearchRequest:%v", r.toDict()) +} + +func (r *HybridSearchRequest) PartitionKey(partitionKey map[string]interface{}) *HybridSearchRequest { + r.mark("partitionKey") + r.partitionKey = partitionKey + return r +} + +func (r *HybridSearchRequest) ReadConsistency(readConsistency string) *HybridSearchRequest { + r.mark("readConsistency") + r.readConsistency = readConsistency + return r +} + +func (r *HybridSearchRequest) Projections(projections []string) *HybridSearchRequest { + r.mark("projections") + r.projections = projections + return r +} + +func (r *HybridSearchRequest) Limit(limit uint32) *HybridSearchRequest { + r.mark("limit") + r.limit = limit + return r +} + +func (r *HybridSearchRequest) Filter(filter string) *HybridSearchRequest { + r.mark("filter") + r.filter = filter + return r +} + +func (r *HybridSearchRequest) toDict() map[string]interface{} { + fields := make(map[string]interface{}) + + for k, v := range r.bm25Request.toDict() { + fields[k] = v + } + for k, v := range r.vectorRequest.toDict() { + fields[k] = v + } + + for k, v := range searchCommonFieldsToMap(&r.searchCommonFields) { + fields[k] = v + } + + _, ok := fields["anns"] + if ok { + fields["anns"].(map[string]interface{})["weight"] = r.vectorWeight + } + + _, ok = fields["BM25SearchParams"] + if ok { + fields["BM25SearchParams"].(map[string]interface{})["weight"] = r.bm25Weight + } + + return fields +} + +func (r *HybridSearchRequest) isBatch() bool { + return false +} + +func (r *HybridSearchRequest) requestType() string { + return "search" +} + +func (r *HybridSearchRequest) hybridSearchRequestDummyInterface() { +} + type AutoBuildPolicy interface { Params() map[string]interface{} AddTiming(timing string) diff --git a/mochow/api/enum.go b/mochow/api/enum.go index 6487fac..949e36a 100644 --- a/mochow/api/enum.go +++ b/mochow/api/enum.go @@ -35,6 +35,31 @@ const ( // scalar index type SecondaryIndex IndexType = "SECONDARY" + + // inverted index type + InvertedIndex IndexType = "INVERTED" +) + +type InvertedIndexAnalyzer string + +const ( + EnglishAnalyzer InvertedIndexAnalyzer = "ENGLISH_ANALYZER" + ChineseAnalyzer InvertedIndexAnalyzer = "CHINESE_ANALYZER" + DefaultAnalyzer InvertedIndexAnalyzer = "DEFAULT_ANALYZER" +) + +type InvertedIndexParseMode string + +const ( + CoarseMode InvertedIndexParseMode = "COARSE_MODE" + FineMode InvertedIndexParseMode = "FINE_MODE" +) + +type InvertedIndexFieldAttribute string + +const ( + NotAnalyzed InvertedIndexFieldAttribute = "ATTRIBUTE_NOT_ANALYZED" + Analyzed InvertedIndexFieldAttribute = "ATTRIBUTE_ANALYZED" ) type FieldType string diff --git a/mochow/api/model.go b/mochow/api/model.go index 0741812..6a278ac 100644 --- a/mochow/api/model.go +++ b/mochow/api/model.go @@ -144,6 +144,7 @@ type QueryRowResult struct { Row Row `json:"row,omitempty"` } +// Deprecated type SearchRowArgs struct { Database string `json:"database"` Table string `json:"table"` @@ -157,6 +158,7 @@ type SearchRowArgs struct { type RowResult struct { Row Row `json:"row"` Distance float64 `json:"distance"` + Score float64 `json:"score"` } type SearchRowResult struct { @@ -164,6 +166,33 @@ type SearchRowResult struct { Rows []RowResult `json:"rows,omitempty"` } +// vector topk search, range search and batch search +type VectorSearchArgs struct { + Database string + Table string + Request vectorSearchRequest +} + +// BM25 search +type BM25SearchArgs struct { + Database string + Table string + Request bm25SearchRequest +} + +// hybrid search (vector + BM25) +type HybridSearchArgs struct { + Database string + Table string + Request hybridSearchRequest +} + +type SearchResult struct { + IsBatch bool + Rows *SearchRowResult // for single search + BatchRows *BatchSearchRowResult // for batch search +} + type UpdateRowArgs struct { Database string `json:"database"` Table string `json:"table"` diff --git a/mochow/api/row.go b/mochow/api/row.go index 5c1e31e..a9d35fc 100644 --- a/mochow/api/row.go +++ b/mochow/api/row.go @@ -140,6 +140,62 @@ func QueryRow(cli client.Client, args *QueryRowArgs) (*QueryRowResult, error) { return result, nil } +func VectorSearch(cli client.Client, args *VectorSearchArgs) (*SearchResult, error) { + return search(cli, args.Database, args.Table, args.Request) +} + +func BM25Search(cli client.Client, args *BM25SearchArgs) (*SearchResult, error) { + return search(cli, args.Database, args.Table, args.Request) +} + +func HybridSearch(cli client.Client, args *HybridSearchArgs) (*SearchResult, error) { + return search(cli, args.Database, args.Table, args.Request) +} + +func search(cli client.Client, database string, table string, request searchRequest) (*SearchResult, error) { + args := request.toDict() + args["database"] = database + args["table"] = table + + jsonBytes, err := sonic.Marshal(args) + if err != nil { + return nil, err + } + body, err := client.NewBodyFromBytes(jsonBytes) + if err != nil { + return nil, err + } + + req := &client.BceRequest{} + req.SetURI(getRowURI()) + req.SetMethod(http.Post) + req.SetParam(request.requestType(), "") + req.SetBody(body) + + resp := &client.BceResponse{} + if err := cli.SendRequest(req, resp); err != nil { + return nil, err + } + if resp.IsFail() { + return nil, resp.ServiceError() + } + result := &SearchResult{IsBatch: request.isBatch()} + if result.IsBatch { + result.BatchRows = &BatchSearchRowResult{} + if err := resp.ParseJSONBody(result.BatchRows); err != nil { + return nil, err + } + } else { + result.Rows = &SearchRowResult{} + if err := resp.ParseJSONBody(result.Rows); err != nil { + return nil, err + } + } + + return result, nil +} + +// Deprecated: you should use VectorSearch with VectorTopkSearchRequest or VectorRangeSearchRequest instead. func SearchRow(cli client.Client, args *SearchRowArgs) (*SearchRowResult, error) { req := &client.BceRequest{} req.SetURI(getRowURI()) @@ -227,6 +283,7 @@ func SelectRow(cli client.Client, args *SelectRowArgs) (*SelectRowResult, error) return result, nil } +// Deprecated: you should use VectorSearch with VectorBatchSearchRequest instead. func BatchSearchRow(cli client.Client, args *BatchSearchRowArgs) (*BatchSearchRowResult, error) { req := &client.BceRequest{} req.SetURI(getRowURI()) diff --git a/mochow/client.go b/mochow/client.go index e2295ce..64efe28 100644 --- a/mochow/client.go +++ b/mochow/client.go @@ -217,10 +217,23 @@ func (c *Client) QueryRow(args *api.QueryRowArgs) (*api.QueryRowResult, error) { return api.QueryRow(c, args) } +// Deprecated: you should use VectorSearch with VectorTopkSearchRequest or VectorRangeSearchRequest instead. func (c *Client) SearchRow(args *api.SearchRowArgs) (*api.SearchRowResult, error) { return api.SearchRow(c, args) } +func (c *Client) VectorSearch(args *api.VectorSearchArgs) (*api.SearchResult, error) { + return api.VectorSearch(c, args) +} + +func (c *Client) BM25Search(args *api.BM25SearchArgs) (*api.SearchResult, error) { + return api.BM25Search(c, args) +} + +func (c *Client) HybridSearch(args *api.HybridSearchArgs) (*api.SearchResult, error) { + return api.HybridSearch(c, args) +} + func (c *Client) UpdateRow(args *api.UpdateRowArgs) error { return api.UpdateRow(c, args) } @@ -229,6 +242,7 @@ func (c *Client) SelectRow(args *api.SelectRowArgs) (*api.SelectRowResult, error return api.SelectRow(c, args) } +// Deprecated: you should use VectorSearch with VectorBatchSearchRequest instead. func (c *Client) BatchSearchRow(args *api.BatchSearchRowArgs) (*api.BatchSearchRowResult, error) { return api.BatchSearchRow(c, args) }