Skip to content

Commit

Permalink
Add sort/orderby for infinity-sdk and infinity-embedded-sdk (#1944)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
- add sort/orderby for infinity-sdk and infinity-embedded-sdk

TODO: add sort/orderby for http

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Test cases
- [x] Python SDK impacted, Need to update PyPI
  • Loading branch information
Ami11111 authored Oct 8, 2024
1 parent c9d3f47 commit ae40b5a
Show file tree
Hide file tree
Showing 22 changed files with 331 additions and 20 deletions.
2 changes: 1 addition & 1 deletion benchmark/local_infinity/fulltext/fulltext_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ void BenchmarkQuery(SharedPtr<Infinity> infinity, const String &db_name, const S
output_columns->emplace_back(select_rowid_expr);
output_columns->emplace_back(select_score_expr);
}
infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns);
infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns, nullptr);
/*
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, output_columns);
{
Expand Down
2 changes: 1 addition & 1 deletion benchmark/local_infinity/infinity_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ int main() {
output_columns->emplace_back(col2);

[[maybe_unused]] auto ignored =
infinity->Search("default_db", "benchmark_test", nullptr, nullptr, nullptr, nullptr, output_columns);
infinity->Search("default_db", "benchmark_test", nullptr, nullptr, nullptr, nullptr, output_columns, nullptr);
});
results.push_back(fmt::format("-> Select QPS: {}", total_times / tims_costing_second));
}
Expand Down
2 changes: 1 addition & 1 deletion benchmark/local_infinity/knn/knn_query_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ int main(int argc, char *argv[]) {
auto select_rowid_expr = new FunctionExpr();
select_rowid_expr->func_name_ = "row_id";
output_columns->emplace_back(select_rowid_expr);
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns);
auto result = infinity->Search(db_name, table_name, search_expr, nullptr, nullptr, nullptr, output_columns, nullptr);
{
auto &cv = result.result_table_->GetDataBlockById(0)->column_vectors;
auto &column = *cv[0];
Expand Down
3 changes: 3 additions & 0 deletions python/infinity_embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class ConflictType(object):
Error = 1
Replace = 2

class SortType(object):
Asc = 0
Desc = 1

class InfinityException(Exception):
def __init__(self, error_code=0, error_message=None):
Expand Down
5 changes: 3 additions & 2 deletions python/infinity_embedded/local_infinity/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,13 @@ def export_data(self, db_name: str, table_name: str, file_name: str, export_opti
return self.convert_res(self.client.Export(db_name, table_name, columns, file_name, export_options))

def select(self, db_name: str, table_name: str, select_list: list[WrapParsedExpr], search_expr,
where_expr, limit_expr, offset_expr, group_by_list=None):
where_expr, limit_expr, offset_expr, order_by_list: list[WrapOrderByExpr], group_by_list=None):
if self.client is None:
raise Exception("Local infinity is not connected")
return self.convert_res(self.client.Search(db_name, table_name, select_list,
order_by_list=order_by_list,
wrap_search_expr=search_expr, where_expr=where_expr,
limit_expr=limit_expr, offset_expr=offset_expr),
limit_expr=limit_expr, offset_expr=offset_expr, ),
has_result_data=True)

def explain(self, db_name: str, table_name: str, explain_type, select_list, search_expr,
Expand Down
75 changes: 74 additions & 1 deletion python/infinity_embedded/local_infinity/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ def __init__(
filter: Optional[WrapParsedExpr],
limit: Optional[WrapParsedExpr],
offset: Optional[WrapParsedExpr],
sort: Optional[WrapOrderByExpr]
):
self.columns = columns
self.search = search
self.filter = filter
self.limit = limit
self.offset = offset
self.sort = sort


class ExplainQuery(Query):
Expand All @@ -44,7 +46,7 @@ def __init__(
offset: Optional[WrapParsedExpr],
explain_type: Optional[BaseExplainType],
):
super().__init__(columns, search, filter, limit, offset)
super().__init__(columns, search, filter, limit, offset, None)
self.explain_type = explain_type


Expand All @@ -56,13 +58,15 @@ def __init__(self, table):
self._filter = None
self._limit = None
self._offset = None
self._sort = []

def reset(self):
self._columns = None
self._search = None
self._filter = None
self._limit = None
self._offset = None
self._sort = []

def match_dense(
self,
Expand Down Expand Up @@ -434,13 +438,82 @@ def output(self, columns: Optional[list]) -> InfinityLocalQueryBuilder:
self._columns = select_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityLocalQueryBuilder:
sort_list: List[WrapOrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
order_by_expr[0] = order_by_expr[0].lower()

match order_by_expr[0]:
case "*":
column_expr = WrapColumnExpr()
column_expr.star = True

parsed_expr = WrapParsedExpr(ParsedExprType.kColumn)
parsed_expr.column_expr = column_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_row_id":
func_expr = WrapFunctionExpr()
func_expr.func_name = "row_id"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_score":
func_expr = WrapFunctionExpr()
func_expr.func_name = "score"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_similarity":
func_expr = WrapFunctionExpr()
func_expr.func_name = "similarity"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case "_distance":
func_expr = WrapFunctionExpr()
func_expr.func_name = "distance"
func_expr.arguments = []

expr_type = ParsedExprType(ParsedExprType.kFunction)
parsed_expr = WrapParsedExpr(expr_type)
parsed_expr.function_expr = func_expr

order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
order_by_expr = WrapOrderByExpr(parsed_expr, order_by_expr[1])
sort_list.append(order_by_expr)

self._sort = sort_list
return self

def to_result(self):
query = Query(
columns=self._columns,
search=self._search,
filter=self._filter,
limit=self._limit,
offset=self._offset,
sort=self._sort,
)
self.reset()
return self._table._execute_query(query)
Expand Down
18 changes: 16 additions & 2 deletions python/infinity_embedded/local_infinity/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from infinity_embedded.embedded_infinity_ext import ConflictType as LocalConflictType
from infinity_embedded.embedded_infinity_ext import WrapIndexInfo, ImportOptions, CopyFileType, WrapParsedExpr, \
ParsedExprType, WrapUpdateExpr, ExportOptions, WrapOptimizeOptions
from infinity_embedded.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN
from infinity_embedded.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN, SortType
from infinity_embedded.common import INSERT_DATA, VEC, SparseVector, InfinityException
from infinity_embedded.errors import ErrorCode
from infinity_embedded.index import IndexInfo
Expand Down Expand Up @@ -357,6 +357,19 @@ def limit(self, limit: Optional[int]):
def offset(self, offset: Optional[int]):
self.query_builder.offset(offset)
return self

def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else :
order_by_expr[1] = False
self.query_builder.sort(order_by_expr_list)
return self

def to_df(self):
return self.query_builder.to_df()
Expand Down Expand Up @@ -398,7 +411,8 @@ def _execute_query(self, query: Query):
where_expr=query.filter,
group_by_list=None,
limit_expr=query.limit,
offset_expr=query.offset)
offset_expr=query.offset,
order_by_list=query.sort)

# process the results
if res.error_code == ErrorCode.OK:
Expand Down
3 changes: 3 additions & 0 deletions python/infinity_sdk/infinity/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ class ConflictType(object):
Error = 1
Replace = 2

class SortType(object):
Asc = 0
Desc = 1

class InfinityException(Exception):
def __init__(self, error_code=0, error_message=None):
Expand Down
3 changes: 2 additions & 1 deletion python/infinity_sdk/infinity/remote_thrift/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def export_data(self, db_name: str, table_name: str, file_name: str, export_opti
export_option=export_options))

def select(self, db_name: str, table_name: str, select_list, search_expr,
where_expr, group_by_list, limit_expr, offset_expr):
where_expr, group_by_list, limit_expr, offset_expr, order_by_list):
return self.client.Select(SelectRequest(session_id=self.session_id,
db_name=db_name,
table_name=table_name,
Expand All @@ -209,6 +209,7 @@ def select(self, db_name: str, table_name: str, select_list, search_expr,
group_by_list=group_by_list,
limit_expr=limit_expr,
offset_expr=offset_expr,
order_by_list=order_by_list
))

def explain(self, db_name: str, table_name: str, select_list, search_expr,
Expand Down
54 changes: 52 additions & 2 deletions python/infinity_sdk/infinity/remote_thrift/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pyarrow import Table
from sqlglot import condition, maybe_parse

from infinity.common import VEC, SparseVector, InfinityException
from infinity.common import VEC, SparseVector, InfinityException, SortType
from infinity.errors import ErrorCode
from infinity.remote_thrift.infinity_thrift_rpc.ttypes import *
from infinity.remote_thrift.types import (
Expand All @@ -45,12 +45,14 @@ def __init__(
filter: Optional[ParsedExpr],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
sort: Optional[List[OrderByExpr]],
):
self.columns = columns
self.search = search
self.filter = filter
self.limit = limit
self.offset = offset
self.sort = sort


class ExplainQuery(Query):
Expand All @@ -61,9 +63,10 @@ def __init__(
filter: Optional[ParsedExpr],
limit: Optional[ParsedExpr],
offset: Optional[ParsedExpr],
#sort: Optional[List[OrderByExpr]],
explain_type: Optional[ExplainType],
):
super().__init__(columns, search, filter, limit, offset)
super().__init__(columns, search, filter, limit, offset, None)
self.explain_type = explain_type


Expand All @@ -75,13 +78,15 @@ def __init__(self, table):
self._filter = None
self._limit = None
self._offset = None
self._sort = None

def reset(self):
self._columns = None
self._search = None
self._filter = None
self._limit = None
self._offset = None
self._sort = None

def match_dense(
self,
Expand Down Expand Up @@ -340,6 +345,50 @@ def output(self, columns: Optional[list]) -> InfinityThriftQueryBuilder:

self._columns = select_list
return self

def sort(self, order_by_expr_list: Optional[List[list[str, bool]]]) -> InfinityThriftQueryBuilder:
sort_list: List[OrderByExpr] = []
for order_by_expr in order_by_expr_list:
if isinstance(order_by_expr[0], str):
order_by_expr[0] = order_by_expr[0].lower()

match order_by_expr[0]:
case "*":
column_expr = ColumnExpr(star=True, column_name=[])
expr_type = ParsedExprType(column_expr=column_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_row_id":
func_expr = FunctionExpr(function_name="row_id", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_score":
func_expr = FunctionExpr(function_name="score", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_similarity":
func_expr = FunctionExpr(function_name="similarity", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case "_distance":
func_expr = FunctionExpr(function_name="distance", arguments=[])
expr_type = ParsedExprType(function_expr=func_expr)
parsed_expr = ParsedExpr(type=expr_type)
order_by_expr = OrderByExpr(expr = parsed_expr, asc = order_by_expr[1])
sort_list.append(order_by_expr)
case _:
parsed_expr = parse_expr(maybe_parse(order_by_expr[0]))
sort_list.append(OrderByExpr(expr = parsed_expr, asc = order_by_expr[1]))

self._sort = sort_list
return self

def to_result(self) -> tuple[dict[str, list[Any]], dict[str, Any]]:
query = Query(
Expand All @@ -348,6 +397,7 @@ def to_result(self) -> tuple[dict[str, list[Any]], dict[str, Any]]:
filter=self._filter,
limit=self._limit,
offset=self._offset,
sort=self._sort,
)
self.reset()
return self._table._execute_query(query)
Expand Down
18 changes: 16 additions & 2 deletions python/infinity_sdk/infinity/remote_thrift/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
get_ordinary_info,
)
from infinity.table import ExplainType
from infinity.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN
from infinity.common import ConflictType, DEFAULT_MATCH_VECTOR_TOPN, SortType
from infinity.utils import deprecated_api


Expand Down Expand Up @@ -376,6 +376,19 @@ def limit(self, limit: Optional[int]):
def offset(self, offset: Optional[int]):
self.query_builder.offset(offset)
return self

def sort(self, order_by_expr_list: Optional[List[list[str, SortType]]]):
for order_by_expr in order_by_expr_list:
if len(order_by_expr) != 2:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"order_by_expr_list must be a list of [column_name, sort_type]")
if order_by_expr[1] not in [SortType.Asc, SortType.Desc]:
raise InfinityException(ErrorCode.INVALID_PARAMETER, f"sort_type must be SortType.Asc or SortType.Desc")
if order_by_expr[1] == SortType.Asc:
order_by_expr[1] = True
else :
order_by_expr[1] = False
self.query_builder.sort(order_by_expr_list)
return self

def to_result(self):
return self.query_builder.to_result()
Expand Down Expand Up @@ -421,7 +434,8 @@ def _execute_query(self, query: Query) -> tuple[dict[str, list[Any]], dict[str,
where_expr=query.filter,
group_by_list=None,
limit_expr=query.limit,
offset_expr=query.offset)
offset_expr=query.offset,
order_by_list=query.sort)

# process the results
if res.error_code == ErrorCode.OK:
Expand Down
Loading

0 comments on commit ae40b5a

Please sign in to comment.