From e1a92a7f891e25a1b11eeaded821e927321f856c Mon Sep 17 00:00:00 2001 From: Jin Hai Date: Mon, 25 Mar 2024 16:45:27 +0800 Subject: [PATCH] HTTP API: search finished (#874) ### What problem does this PR solve? ``` curl --request GET --url localhost:23820/databases/default/tables/my_table/docs --header 'accept: application/json' --header 'content-type: application/json' --data ' { "output": [ "num", "body", "vec" ], "fusion": { "method": "rrf", "match": { "fields": "body", "query": "bloom", "operator": "topn=3" }, "knn": { "fields": "vec", "query_vector": [3.0, 2.8, 2.7, 3.1], "element_type": "float", "top_k": 3, "metric_type": "inner_product" } } } ' ``` ``` { "error_code":0, "output":[{"body":"Office for Harmful Blooms","num":"2","vec":"4.000000, 4.200000, 4.300000, 4.500000"},{"body":"unnecessary and harmful","num":"1","vec":"1.000000, 1.200000, 0.800000, 0.900000"}] } ``` Issue link:#779 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Signed-off-by: Jin Hai --- docs/references/http_api_reference.md | 6 +- python/README.md | 3 +- python/infinity/db.py | 3 +- src/common/stl.cppm | 2 + src/network/http/http_search.cpp | 61 ++++++++++++--- src/storage/column_vector/value.cpp | 108 ++++++++++++++++++++++---- src/storage/column_vector/value.cppm | 48 ++++++++---- 7 files changed, 183 insertions(+), 48 deletions(-) diff --git a/docs/references/http_api_reference.md b/docs/references/http_api_reference.md index bf37ce40f4..2a8676a949 100644 --- a/docs/references/http_api_reference.md +++ b/docs/references/http_api_reference.md @@ -670,9 +670,9 @@ curl --request GET \ "method": "rrf", "match": { - "fields": "title", - "query": "rock fire", - "operator": "and" + "fields": "body", + "query": "bloom", + "operator": "topn=1" } "knn": { diff --git a/python/README.md b/python/README.md index 480680ded0..3fca78419d 100644 --- a/python/README.md +++ b/python/README.md @@ -27,10 +27,11 @@ Note that pypi allow a version of a package [be uploaded only once](https://pypi ```python import infinity from infinity.common import REMOTE_HOST +from infinity.common import ConflictType infinity_obj = infinity.connect(REMOTE_HOST) db = infinity_obj.get_database("default") -db.drop_table("my_table", if_exists=True) +db.drop_table("my_table", ConflictType.Ignore) table = db.create_table( "my_table", {"num": "integer", "body": "varchar", "vec": "vector,5,float"}, None) table.insert( diff --git a/python/infinity/db.py b/python/infinity/db.py index 03356ce61b..27244f6140 100644 --- a/python/infinity/db.py +++ b/python/infinity/db.py @@ -14,7 +14,6 @@ from abc import ABC, abstractmethod - class Database(ABC): @abstractmethod @@ -22,7 +21,7 @@ def create_table(self, table_name, schema, options): pass # implement create table logic here @abstractmethod - def drop_table(self, table_name, if_exists=True): + def drop_table(self, table_name): pass # implement drop table logic here @abstractmethod diff --git a/src/common/stl.cppm b/src/common/stl.cppm index f926069648..75a0089227 100644 --- a/src/common/stl.cppm +++ b/src/common/stl.cppm @@ -205,6 +205,8 @@ export namespace std { using std::function; using std::monostate; using std::thread; + + using std::is_same_v; } // namespace std namespace infinity { diff --git a/src/network/http/http_search.cpp b/src/network/http/http_search.cpp index 7b675c54eb..9c3ee4993b 100644 --- a/src/network/http/http_search.cpp +++ b/src/network/http/http_search.cpp @@ -24,6 +24,7 @@ import parsed_expr; import knn_expr; import match_expr; import search_expr; +import fusion_expr; import column_expr; import defer_op; import expr_parser; @@ -51,7 +52,7 @@ void HTTPSearch::Process(Infinity *infinity_ptr, Vector *output_columns{nullptr}; ParsedExpr *filter{nullptr}; - Vector *fusion_exprs{nullptr}; + FusionExpr *fusion_expr{nullptr}; KnnExpr *knn_expr{nullptr}; MatchExpr *match_expr{nullptr}; SearchExpr *search_expr = new SearchExpr(); @@ -67,12 +68,9 @@ void HTTPSearch::Process(Infinity *infinity_ptr, delete filter; filter = nullptr; } - if (fusion_exprs != nullptr) { - for (auto &expr : *fusion_exprs) { - delete expr; - } - delete fusion_exprs; - fusion_exprs = nullptr; + if (fusion_expr != nullptr) { + delete fusion_expr; + fusion_expr = nullptr; } if (knn_expr != nullptr) { delete knn_expr; @@ -122,15 +120,51 @@ void HTTPSearch::Process(Infinity *infinity_ptr, filter = ParseFilter(filter_json, http_status, response); } else if (IsEqual(key, "fusion")) { - if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) { + if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) { response["error_code"] = ErrorCode::kInvalidExpression; response["error_message"] = "There are more than one fusion expressions, Or fusion expression coexists with knn / match expression "; return; } - + auto &fusion_children = elem.value(); + for (const auto &expression : fusion_children.items()) { + String key = expression.key(); + ToLower(key); + + if (IsEqual(key, "knn")) { + auto &knn_json = expression.value(); + if (!knn_json.is_object()) { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = "KNN field should be object"; + return; + } + knn_expr = ParseKnn(knn_json, http_status, response); + search_expr->AddExpr(knn_expr); + knn_expr = nullptr; + } else if (IsEqual(key, "match")) { + auto &match_json = expression.value(); + match_expr = ParseMatch(match_json, http_status, response); + search_expr->AddExpr(match_expr); + match_expr = nullptr; + } else if (IsEqual(key, "method")) { + if (fusion_expr != nullptr && !fusion_expr->method_.empty()) { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = "Method is already given"; + return; + } + fusion_expr = new FusionExpr(); + fusion_expr->method_ = expression.value(); + search_expr->AddExpr(fusion_expr); + fusion_expr = nullptr; + } else { + response["error_code"] = ErrorCode::kInvalidExpression; + response["error_message"] = "Error fusion clause"; + return; + } + } + search_expr->Validate(); } else if (IsEqual(key, "knn")) { - if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) { + if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) { response["error_code"] = ErrorCode::kInvalidExpression; response["error_message"] = "There are more than one fusion expressions, Or fusion expression coexists with knn / match expression "; @@ -146,13 +180,16 @@ void HTTPSearch::Process(Infinity *infinity_ptr, search_expr->AddExpr(knn_expr); knn_expr = nullptr; } else if (IsEqual(key, "match")) { - if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) { + if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) { response["error_code"] = ErrorCode::kInvalidExpression; response["error_message"] = "There are more than one fusion expressions, Or fusion expression coexists with knn / match expression "; return; } - + auto &match_json = elem.value(); + match_expr = ParseMatch(match_json, http_status, response); + search_expr->AddExpr(match_expr); + match_expr = nullptr; } else { response["error_code"] = ErrorCode::kInvalidExpression; response["error_message"] = "Unknown expression: " + key; diff --git a/src/storage/column_vector/value.cpp b/src/storage/column_vector/value.cpp index 018990cfb5..b7638e4bac 100644 --- a/src/storage/column_vector/value.cpp +++ b/src/storage/column_vector/value.cpp @@ -725,57 +725,133 @@ void Value::Reset() { String Value::ToString() const { switch (type_.type()) { - case kBoolean: { + case LogicalType::kBoolean: { return value_.boolean ? "true" : "false"; } - case kTinyInt: { + case LogicalType::kTinyInt: { return std::to_string(value_.tiny_int); } - case kSmallInt: { + case LogicalType::kSmallInt: { return std::to_string(value_.small_int); } - case kInteger: { + case LogicalType::kInteger: { return std::to_string(value_.integer); } - case kBigInt: { + case LogicalType::kBigInt: { return std::to_string(value_.big_int); } - case kHugeInt: { + case LogicalType::kHugeInt: { return value_.huge_int.ToString(); } - case kFloat: { + case LogicalType::kFloat: { return std::to_string(value_.float32); } - case kDouble: { + case LogicalType::kDouble: { return std::to_string(value_.float64); } - case kDate: { + case LogicalType::kDate: { return value_.date.ToString(); } - case kTime: { + case LogicalType::kTime: { return value_.time.ToString(); } - case kDateTime: { + case LogicalType::kDateTime: { return value_.datetime.ToString(); } - case kTimestamp: { + case LogicalType::kTimestamp: { return value_.timestamp.ToString(); } - case kInterval: { + case LogicalType::kInterval: { return value_.interval.ToString(); } - case kRowID: { + case LogicalType::kRowID: { return value_.row.ToString(); } - case kVarchar: { + case LogicalType::kVarchar: { return value_info_->Get().GetString(); } + case LogicalType::kEmbedding: { + EmbeddingInfo* embedding_info = static_cast(type_.type_info().get()); + return value_info_->Get().GetString(embedding_info); + } default: { UnrecoverableError(fmt::format("Value::ToString() not implemented for type {}", type_.ToString())); return {}; } } - return ""; + return {}; } +String EmbeddingValueInfo::GetString(EmbeddingInfo* embedding_info) { + String res; + SizeT count = embedding_info->Dimension(); + char* ptr = data_.data(); + switch(embedding_info->Type()) { + case EmbeddingDataType::kElemBit: { + UnrecoverableError("Not implemented embedding data type: bit."); + break; + } + case EmbeddingDataType::kElemInt8: { + for(SizeT i = 0; i < count - 1; ++ i) { + i8 element = ((i8*)ptr)[i]; + res += std::to_string(element) + ", "; + } + i8 element = ((i8*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + case EmbeddingDataType::kElemInt16: { + for(SizeT i = 0; i < count - 1; ++ i) { + i16 element = ((i16*)ptr)[i]; + res += std::to_string(element) + ", "; + } + i16 element = ((i16*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + case EmbeddingDataType::kElemInt32: { + for(SizeT i = 0; i < count - 1; ++ i) { + i32 element = ((i32*)ptr)[i]; + res += std::to_string(element) + ", "; + } + i32 element = ((i32*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + case EmbeddingDataType::kElemInt64: { + for(SizeT i = 0; i < count - 1; ++ i) { + i64 element = ((i64*)ptr)[i]; + res += std::to_string(element) + ", "; + } + i64 element = ((i64*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + case EmbeddingDataType::kElemFloat: { + for(SizeT i = 0; i < count - 1; ++ i) { + f32 element = ((f32*)ptr)[i]; + res += std::to_string(element) + ", "; + } + f32 element = ((f32*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + case EmbeddingDataType::kElemDouble: { + for(SizeT i = 0; i < count - 1; ++ i) { + f64 element = ((f64*)ptr)[i]; + res += std::to_string(element) + ", "; + } + f64 element = ((f64*)ptr)[count - 1]; + res += std::to_string(element); + break; + } + default: { + UnrecoverableError("Not supported embedding data type."); + break; + } + } + return res; +} + + } // namespace infinity diff --git a/src/storage/column_vector/value.cppm b/src/storage/column_vector/value.cppm index f58bf244a0..5b0460ec1a 100644 --- a/src/storage/column_vector/value.cppm +++ b/src/storage/column_vector/value.cppm @@ -15,7 +15,6 @@ module; export module value; - import stl; import type_info; import logical_type; @@ -23,6 +22,7 @@ import infinity_exception; import internal_types; import embedding_info; import data_type; +import knn_expr; namespace infinity { @@ -95,6 +95,8 @@ public: std::memcpy(data_.data(), values_p.data(), len); } + String GetString(EmbeddingInfo* embedding_info); + Pair GetData() const { return MakePair(data_.data(), data_.size()); } protected: @@ -147,17 +149,17 @@ public: static Value MakeBox(BoxT input); -// static Value MakePath(PathT input); -// -// static Value MakePolygon(PolygonT input); + // static Value MakePath(PathT input); + // + // static Value MakePolygon(PolygonT input); static Value MakeCircle(CircleT input); -// static Value MakeBitmap(BitmapT input); + // static Value MakeBitmap(BitmapT input); static Value MakeUuid(UuidT input); -// static Value MakeBlob(BlobT input); + // static Value MakeBlob(BlobT input); static Value MakeRow(RowID input); @@ -172,6 +174,24 @@ public: auto embedding_info_ptr = EmbeddingInfo::Make(ToEmbeddingDataType(), vec.size()); Value value(LogicalType::kEmbedding, embedding_info_ptr); value.value_info_ = MakeShared(vec); + if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemBit, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemInt8, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemInt16, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemInt32, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemInt64, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemFloat, vec.size())); + } else if constexpr (std::is_same_v) { + value.type_ = DataType(LogicalType::kEmbedding, EmbeddingInfo::Make(EmbeddingDataType::kElemDouble, vec.size())); + } else { + UnrecoverableError("Not supported embedding data type."); + } + return value; } @@ -297,23 +317,23 @@ LineSegT Value::GetValue() const; template <> BoxT Value::GetValue() const; -//template <> -//PathT Value::GetValue() const; +// template <> +// PathT Value::GetValue() const; // -//template <> -//PolygonT Value::GetValue() const; +// template <> +// PolygonT Value::GetValue() const; template <> CircleT Value::GetValue() const; -//template <> -//BitmapT Value::GetValue() const; +// template <> +// BitmapT Value::GetValue() const; template <> UuidT Value::GetValue() const; -//template <> -//BlobT Value::GetValue() const; +// template <> +// BlobT Value::GetValue() const; template <> RowID Value::GetValue() const;