Skip to content

Commit

Permalink
HTTP API: search finished (#874)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

```
curl --request GET --url localhost:23820/databases/default/tables/my_table/docs --header 'accept: application/json' --header 'content-type: application/json' --data '
{
    "output":
    [
        "num",
        "body",
        "vec"
    ],
    "fusion": 
    {
        "method": "rrf",
        "match":
        {
            "fields": "body", 
            "query": "bloom",
            "operator": "topn=3"
        },
        "knn":
        {
            "fields": "vec",
            "query_vector": [3.0, 2.8, 2.7, 3.1],
            "element_type": "float",
            "top_k": 3,
            "metric_type": "inner_product"
        }
    }
} '
```

```
{
"error_code":0,
"output":[{"body":"Office for Harmful Blooms","num":"2","vec":"4.000000, 4.200000, 4.300000, 4.500000"},{"body":"unnecessary and harmful","num":"1","vec":"1.000000, 1.200000, 0.800000, 0.900000"}]
}
```

Issue link:#779

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
  • Loading branch information
JinHai-CN authored Mar 25, 2024
1 parent 99299ef commit e1a92a7
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 48 deletions.
6 changes: 3 additions & 3 deletions docs/references/http_api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,9 @@ curl --request GET \
"method": "rrf",
"match":
{
"fields": "title",
"query": "rock fire",
"operator": "and"
"fields": "body",
"query": "bloom",
"operator": "topn=1"
},
"knn":
{
Expand Down
3 changes: 2 additions & 1 deletion python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ Note that pypi allow a version of a package [be uploaded only once](https://pypi
```python
import infinity
from infinity.common import REMOTE_HOST
from infinity.common import ConflictType

infinity_obj = infinity.connect(REMOTE_HOST)
db = infinity_obj.get_database("default")
db.drop_table("my_table", if_exists=True)
db.drop_table("my_table", ConflictType.Ignore)
table = db.create_table(
"my_table", {"num": "integer", "body": "varchar", "vec": "vector,5,float"}, None)
table.insert(
Expand Down
3 changes: 1 addition & 2 deletions python/infinity/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@

from abc import ABC, abstractmethod


class Database(ABC):

@abstractmethod
def create_table(self, table_name, schema, options):
pass # implement create table logic here

@abstractmethod
def drop_table(self, table_name, if_exists=True):
def drop_table(self, table_name):
pass # implement drop table logic here

@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ export namespace std {
using std::function;
using std::monostate;
using std::thread;

using std::is_same_v;
} // namespace std

namespace infinity {
Expand Down
61 changes: 49 additions & 12 deletions src/network/http/http_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import parsed_expr;
import knn_expr;
import match_expr;
import search_expr;
import fusion_expr;
import column_expr;
import defer_op;
import expr_parser;
Expand Down Expand Up @@ -51,7 +52,7 @@ void HTTPSearch::Process(Infinity *infinity_ptr,

Vector<ParsedExpr *> *output_columns{nullptr};
ParsedExpr *filter{nullptr};
Vector<ParsedExpr *> *fusion_exprs{nullptr};
FusionExpr *fusion_expr{nullptr};
KnnExpr *knn_expr{nullptr};
MatchExpr *match_expr{nullptr};
SearchExpr *search_expr = new SearchExpr();
Expand All @@ -67,12 +68,9 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
delete filter;
filter = nullptr;
}
if (fusion_exprs != nullptr) {
for (auto &expr : *fusion_exprs) {
delete expr;
}
delete fusion_exprs;
fusion_exprs = nullptr;
if (fusion_expr != nullptr) {
delete fusion_expr;
fusion_expr = nullptr;
}
if (knn_expr != nullptr) {
delete knn_expr;
Expand Down Expand Up @@ -122,15 +120,51 @@ void HTTPSearch::Process(Infinity *infinity_ptr,

filter = ParseFilter(filter_json, http_status, response);
} else if (IsEqual(key, "fusion")) {
if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) {
if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
}

auto &fusion_children = elem.value();
for (const auto &expression : fusion_children.items()) {
String key = expression.key();
ToLower(key);

if (IsEqual(key, "knn")) {
auto &knn_json = expression.value();
if (!knn_json.is_object()) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "KNN field should be object";
return;
}
knn_expr = ParseKnn(knn_json, http_status, response);
search_expr->AddExpr(knn_expr);
knn_expr = nullptr;
} else if (IsEqual(key, "match")) {
auto &match_json = expression.value();
match_expr = ParseMatch(match_json, http_status, response);
search_expr->AddExpr(match_expr);
match_expr = nullptr;
} else if (IsEqual(key, "method")) {
if (fusion_expr != nullptr && !fusion_expr->method_.empty()) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Method is already given";
return;
}
fusion_expr = new FusionExpr();
fusion_expr->method_ = expression.value();
search_expr->AddExpr(fusion_expr);
fusion_expr = nullptr;
} else {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Error fusion clause";
return;
}
}
search_expr->Validate();
} else if (IsEqual(key, "knn")) {
if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) {
if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
Expand All @@ -146,13 +180,16 @@ void HTTPSearch::Process(Infinity *infinity_ptr,
search_expr->AddExpr(knn_expr);
knn_expr = nullptr;
} else if (IsEqual(key, "match")) {
if (fusion_exprs != nullptr or knn_expr != nullptr or match_expr != nullptr) {
if (fusion_expr != nullptr or knn_expr != nullptr or match_expr != nullptr) {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] =
"There are more than one fusion expressions, Or fusion expression coexists with knn / match expression ";
return;
}

auto &match_json = elem.value();
match_expr = ParseMatch(match_json, http_status, response);
search_expr->AddExpr(match_expr);
match_expr = nullptr;
} else {
response["error_code"] = ErrorCode::kInvalidExpression;
response["error_message"] = "Unknown expression: " + key;
Expand Down
108 changes: 92 additions & 16 deletions src/storage/column_vector/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,57 +725,133 @@ void Value::Reset() {

String Value::ToString() const {
switch (type_.type()) {
case kBoolean: {
case LogicalType::kBoolean: {
return value_.boolean ? "true" : "false";
}
case kTinyInt: {
case LogicalType::kTinyInt: {
return std::to_string(value_.tiny_int);
}
case kSmallInt: {
case LogicalType::kSmallInt: {
return std::to_string(value_.small_int);
}
case kInteger: {
case LogicalType::kInteger: {
return std::to_string(value_.integer);
}
case kBigInt: {
case LogicalType::kBigInt: {
return std::to_string(value_.big_int);
}
case kHugeInt: {
case LogicalType::kHugeInt: {
return value_.huge_int.ToString();
}
case kFloat: {
case LogicalType::kFloat: {
return std::to_string(value_.float32);
}
case kDouble: {
case LogicalType::kDouble: {
return std::to_string(value_.float64);
}
case kDate: {
case LogicalType::kDate: {
return value_.date.ToString();
}
case kTime: {
case LogicalType::kTime: {
return value_.time.ToString();
}
case kDateTime: {
case LogicalType::kDateTime: {
return value_.datetime.ToString();
}
case kTimestamp: {
case LogicalType::kTimestamp: {
return value_.timestamp.ToString();
}
case kInterval: {
case LogicalType::kInterval: {
return value_.interval.ToString();
}
case kRowID: {
case LogicalType::kRowID: {
return value_.row.ToString();
}
case kVarchar: {
case LogicalType::kVarchar: {
return value_info_->Get<StringValueInfo>().GetString();
}
case LogicalType::kEmbedding: {
EmbeddingInfo* embedding_info = static_cast<EmbeddingInfo*>(type_.type_info().get());
return value_info_->Get<EmbeddingValueInfo>().GetString(embedding_info);
}
default: {
UnrecoverableError(fmt::format("Value::ToString() not implemented for type {}", type_.ToString()));
return {};
}
}
return "";
return {};
}

namespace {

// Formats `count` consecutive elements of type T, read from `raw`, as a
// ", "-separated list using std::to_string. Returns an empty string when
// `count` is zero, so callers never evaluate `count - 1` on an unsigned value.
template <typename T>
String JoinEmbeddingElements(const char *raw, SizeT count) {
    const T *elements = reinterpret_cast<const T *>(raw);
    String joined;
    for (SizeT idx = 0; idx < count; ++idx) {
        if (idx != 0) {
            joined += ", ";
        }
        joined += std::to_string(elements[idx]);
    }
    return joined;
}

} // namespace

// Renders the embedding held in data_ as text: each element is converted with
// std::to_string and the elements are joined with ", ".
//
// embedding_info supplies the element count (Dimension()) and the element
// type (Type()) used to reinterpret the bytes of data_.
//
// Bit embeddings and unknown element types are reported through
// UnrecoverableError; if that call returns, an empty string is returned.
// A zero-dimension embedding yields an empty string instead of reading
// outside the buffer.
String EmbeddingValueInfo::GetString(EmbeddingInfo* embedding_info) {
    const SizeT count = embedding_info->Dimension();
    const char *ptr = data_.data();
    switch (embedding_info->Type()) {
        case EmbeddingDataType::kElemBit: {
            UnrecoverableError("Not implemented embedding data type: bit.");
            break;
        }
        case EmbeddingDataType::kElemInt8: {
            return JoinEmbeddingElements<i8>(ptr, count);
        }
        case EmbeddingDataType::kElemInt16: {
            return JoinEmbeddingElements<i16>(ptr, count);
        }
        case EmbeddingDataType::kElemInt32: {
            return JoinEmbeddingElements<i32>(ptr, count);
        }
        case EmbeddingDataType::kElemInt64: {
            return JoinEmbeddingElements<i64>(ptr, count);
        }
        case EmbeddingDataType::kElemFloat: {
            return JoinEmbeddingElements<f32>(ptr, count);
        }
        case EmbeddingDataType::kElemDouble: {
            return JoinEmbeddingElements<f64>(ptr, count);
        }
        default: {
            UnrecoverableError("Not supported embedding data type.");
            break;
        }
    }
    return {};
}


} // namespace infinity
Loading

0 comments on commit e1a92a7

Please sign in to comment.