diff --git a/eva/executor/create_mat_view_executor.py b/eva/executor/create_mat_view_executor.py index 3661792b44..d736110e5f 100644 --- a/eva/executor/create_mat_view_executor.py +++ b/eva/executor/create_mat_view_executor.py @@ -33,10 +33,17 @@ def exec(self): """Create materialized view executor""" if not handle_if_not_exists(self.node.view, self.node.if_not_exists): child = self.children[0] + project_cols = None # only support seq scan based materialization - if child.node.opr_type != PlanOprType.SEQUENTIAL_SCAN: - err_msg = "Invalid query {}, expected {}".format( - child.node.opr_type, PlanOprType.SEQUENTIAL_SCAN + if child.node.opr_type == PlanOprType.SEQUENTIAL_SCAN: + project_cols = child.project_expr + elif child.node.opr_type == PlanOprType.PROJECT: + project_cols = child.target_list + else: + err_msg = "Invalid query {}, expected {} or {}".format( + child.node.opr_type, + PlanOprType.SEQUENTIAL_SCAN, + PlanOprType.PROJECT, ) logger.error(err_msg) @@ -44,7 +51,7 @@ def exec(self): # gather child projected column objects child_objs = [] - for child_col in child.project_expr: + for child_col in project_cols: if child_col.etype == ExpressionType.TUPLE_VALUE: child_objs.append(child_col.col_object) elif child_col.etype == ExpressionType.FUNCTION_EXPRESSION: diff --git a/eva/executor/storage_executor.py b/eva/executor/storage_executor.py index 68f09950b9..f38ca60425 100644 --- a/eva/executor/storage_executor.py +++ b/eva/executor/storage_executor.py @@ -29,6 +29,10 @@ def validate(self): def exec(self) -> Iterator[Batch]: if self.node.video.is_video: - return VideoStorageEngine.read(self.node.video, self.node.batch_mem_size) + return VideoStorageEngine.read( + self.node.video, + self.node.batch_mem_size, + predicate=self.node.predicate, + ) else: return StorageEngine.read(self.node.video, self.node.batch_mem_size) diff --git a/eva/expression/expression_utils.py b/eva/expression/expression_utils.py index 56f494543e..290b7f631a 100644 --- a/eva/expression/expression_utils.py +++ b/eva/expression/expression_utils.py @@ -12,7 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from eva.expression.abstract_expression import ExpressionType + +from typing import List + +from eva.expression.abstract_expression import AbstractExpression, ExpressionType +from eva.expression.comparison_expression import ComparisonExpression +from eva.expression.constant_value_expression import ConstantValueExpression +from eva.expression.logical_expression import LogicalExpression +from eva.expression.tuple_value_expression import TupleValueExpression def expression_tree_to_conjunction_list(expression_tree): @@ -28,3 +35,234 @@ def expression_tree_to_conjunction_list(expression_tree): expression_list.append(expression_tree) return expression_list + + +def conjuction_list_to_expression_tree( + expression_list: List[AbstractExpression], +) -> AbstractExpression: + """Convert expression list to expression tree wuing conjuction connector + + Args: + expression_list (List[AbstractExpression]): list of conjunctives + + Returns: + AbstractExpression: expression tree + """ + if len(expression_list) == 0: + return None + prev_expr = expression_list[0] + for expr in expression_list[1:]: + prev_expr = LogicalExpression(ExpressionType.LOGICAL_AND, prev_expr, expr) + return prev_expr + + +def extract_range_list_from_comparison_expr( + expr: ComparisonExpression, lower_bound: int, upper_bound: int +) -> List: + """Extracts the valid range from the comparison expression. + The expression needs to be amongst <, >, <=, >=, =, !=. + + Args: + expr (ComparisonExpression): comparison expression with two children + that are leaf expression nodes. If the input doesnot match, + the function return False + lower_bound (int): lower bound of the comparison predicate + upper_bound (int): upper bound of the comparison predicate + + Returns: + List[Tuple(int)]: list of valid ranges + + Raises: + RuntimeError: Invalid expression + + Example: + extract_range_from_comparison_expr(id < 10, 0, inf): True, [(0,9)] + """ + + if not isinstance(expr, ComparisonExpression): + raise RuntimeError(f"Expected Comparision Expression, got {type(expr)}") + left = expr.children[0] + right = expr.children[1] + expr_type = expr.etype + val = None + const_first = False + if isinstance(left, TupleValueExpression) and isinstance( + right, ConstantValueExpression + ): + val = right.value + elif isinstance(left, ConstantValueExpression) and isinstance( + right, TupleValueExpression + ): + val = left.value + const_first = True + else: + raise RuntimeError( + f"Only supports extracting range from Comparision Expression \ + with two children TupleValueExpression and \ + ConstantValueExpression, got {left} and {right}" + ) + + if const_first: + if expr_type is ExpressionType.COMPARE_GREATER: + expr_type = ExpressionType.COMPARE_LESSER + elif expr_type is ExpressionType.COMPARE_LESSER: + expr_type = ExpressionType.COMPARE_GREATER + elif expr_type is ExpressionType.COMPARE_GEQ: + expr_type = ExpressionType.COMPARE_LEQ + elif expr_type is ExpressionType.COMPARE_LEQ: + expr_type = ExpressionType.COMPARE_GEQ + + valid_ranges = [] + if expr_type == ExpressionType.COMPARE_EQUAL: + valid_ranges.append((val, val)) + elif expr_type == ExpressionType.COMPARE_NEQ: + valid_ranges.append((lower_bound, val - 1)) + valid_ranges.append((val + 1, upper_bound)) + elif expr_type == ExpressionType.COMPARE_GREATER: + valid_ranges.append((val + 1, upper_bound)) + elif expr_type == ExpressionType.COMPARE_GEQ: + valid_ranges.append((val, upper_bound)) + elif expr_type == ExpressionType.COMPARE_LESSER: + valid_ranges.append((lower_bound, val - 1)) + elif expr_type == ExpressionType.COMPARE_LEQ: + valid_ranges.append((lower_bound, val)) + else: + raise RuntimeError(f"Unsupported Expression Type {expr_type}") + return valid_ranges + + +def extract_range_list_from_predicate( + predicate: AbstractExpression, lower_bound: int, upper_bound: int +) -> List: + """The function converts the range predicate on the column in the + `predicate` to a list of [(start_1, end_1), ... ] pairs. + Assumes the predicate contains conditions on only one column + + Args: + predicate (AbstractExpression): Input predicate to extract + valid ranges. The predicate should contain conditions on + only one columns, else it raise error. + lower_bound (int): lower bound of the comparison predicate + upper_bound (int): upper bound of the comparison predicate + + Returns: + List[Tuple]: list of (start, end) pairs of valid ranges + + Example: + id < 10 : [(0, 9)] + id > 5 AND id < 10 : [(6, 9)] + id < 10 OR id >20 : [(0, 9), (21, Inf)] + """ + + def overlap(x, y): + overlap = (max(x[0], y[0]), min(x[1], y[1])) + if overlap[0] <= overlap[1]: + return overlap + + def union(ranges: List): + # union all the ranges + reduced_list = [] + for begin, end in sorted(ranges): + if reduced_list and reduced_list[-1][1] >= begin - 1: + reduced_list[-1] = ( + reduced_list[-1][0], + max(reduced_list[-1][1], end), + ) + else: + reduced_list.append((begin, end)) + return reduced_list + + if predicate.etype == ExpressionType.LOGICAL_AND: + left_ranges = extract_range_list_from_predicate( + predicate.children[0], lower_bound, upper_bound + ) + right_ranges = extract_range_list_from_predicate( + predicate.children[1], lower_bound, upper_bound + ) + valid_overlaps = [] + for left_range in left_ranges: + for right_range in right_ranges: + over = overlap(left_range, right_range) + if over: + valid_overlaps.append(over) + return union(valid_overlaps) + + elif predicate.etype == ExpressionType.LOGICAL_OR: + left_ranges = extract_range_list_from_predicate( + predicate.children[0], lower_bound, upper_bound + ) + right_ranges = extract_range_list_from_predicate( + predicate.children[1], lower_bound, upper_bound + ) + return union(left_ranges + right_ranges) + + elif isinstance(predicate, ComparisonExpression): + return union( + extract_range_list_from_comparison_expr(predicate, lower_bound, upper_bound) + ) + + else: + raise RuntimeError(f"Contains unsuporrted expression {type(predicate)}") + + +def contains_single_column(predicate: AbstractExpression, column: str = None) -> bool: + """Checks if predicate contains conditions on single predicate + + Args: + predicate (AbstractExpression): predicate expression + column_alias (str): check if the single column matches + the input column_alias + Returns: + bool: True, if contains single predicate, else False + if predicate is None, return False + """ + + def get_columns(predicate): + if isinstance(predicate, TupleValueExpression): + return set([predicate.col_alias]) + cols = set() + for child in predicate.children: + child_cols = get_columns(child) + if len(child_cols): + cols.update(child_cols) + return cols + + if not predicate: + return False + + cols = get_columns(predicate) + if len(cols) == 1: + if column is None: + return True + pred_col = cols.pop() + if pred_col == column: + return True + return False + + +def is_simple_predicate(predicate: AbstractExpression) -> bool: + """Checks if conditions in the predicate are on a single column and + only contains LogicalExpression, ComparisonExpression, + TupleValueExpression or ConstantValueExpression + + Args: + predicate (AbstractExpression): predicate expression to check + + Returns: + bool: True, if it is a simple predicate, lese False + """ + + def _has_simple_expressions(expr): + simple = type(expr) in simple_expressions + for child in expr.children: + simple = simple and _has_simple_expressions(child) + return simple + + simple_expressions = [ + LogicalExpression, + ComparisonExpression, + TupleValueExpression, + ConstantValueExpression, + ] + + return _has_simple_expressions(predicate) and contains_single_column(predicate) diff --git a/eva/optimizer/optimizer_utils.py b/eva/optimizer/optimizer_utils.py index 0f8ae4f391..7998ac30d8 100644 --- a/eva/optimizer/optimizer_utils.py +++ b/eva/optimizer/optimizer_utils.py @@ -16,7 +16,12 @@ from eva.catalog.catalog_manager import CatalogManager from eva.expression.abstract_expression import AbstractExpression, ExpressionType -from eva.expression.expression_utils import expression_tree_to_conjunction_list +from eva.expression.expression_utils import ( + conjuction_list_to_expression_tree, + contains_single_column, + expression_tree_to_conjunction_list, + is_simple_predicate, +) from eva.parser.create_statement import ColumnDefinition from eva.utils.logging_manager import logger @@ -80,3 +85,37 @@ def extract_equi_join_keys( right_join_keys.append(left_child) return (left_join_keys, right_join_keys) + + +def extract_pushdown_predicate( + predicate: AbstractExpression, column_alias: str +) -> Tuple[AbstractExpression, AbstractExpression]: + """Decompose the predicate into pushdown predicate and remaining predicate + + Args: + predicate (AbstractExpression): predicate that needs to be decomposed + column (str): column_alias to extract predicate + Returns: + Tuple[AbstractExpression, AbstractExpression]: (pushdown predicate, + remaining predicate) + """ + if predicate is None: + return None, None + + if contains_single_column(predicate, column_alias): + if is_simple_predicate(predicate): + return predicate, None + + pushdown_preds = [] + rem_pred = [] + pred_list = expression_tree_to_conjunction_list(predicate) + for pred in pred_list: + if contains_single_column(pred, column_alias) and is_simple_predicate(pred): + pushdown_preds.append(pred) + else: + rem_pred.append(pred) + + return ( + conjuction_list_to_expression_tree(pushdown_preds), + conjuction_list_to_expression_tree(rem_pred), + ) diff --git a/eva/optimizer/rules/rules.py b/eva/optimizer/rules/rules.py index dd748186c5..0ef979bc8b 100644 --- a/eva/optimizer/rules/rules.py +++ b/eva/optimizer/rules/rules.py @@ -18,7 +18,13 @@ from enum import Flag, IntEnum, auto from typing import TYPE_CHECKING -from eva.optimizer.optimizer_utils import extract_equi_join_keys +from eva.optimizer.optimizer_utils import ( + extract_equi_join_keys, + extract_pushdown_predicate, +) +from eva.optimizer.rules.pattern import Pattern +from eva.parser.types import JoinType +from eva.planner.create_mat_view_plan import CreateMaterializedViewPlan from eva.planner.hash_join_build_plan import HashJoinBuildPlan from eva.planner.predicate_plan import PredicatePlan from eva.planner.project_plan import ProjectPlan @@ -53,9 +59,6 @@ Operator, OperatorType, ) -from eva.optimizer.rules.pattern import Pattern -from eva.parser.types import JoinType -from eva.planner.create_mat_view_plan import CreateMaterializedViewPlan from eva.planner.create_plan import CreatePlan from eva.planner.create_udf_plan import CreateUDFPlan from eva.planner.drop_plan import DropPlan @@ -238,22 +241,44 @@ def __init__(self): def promise(self): return Promise.EMBED_FILTER_INTO_GET - def check(self, before: Operator, context: OptimizerContext): - # nothing else to check if logical match found return true - return True + def check(self, before: LogicalFilter, context: OptimizerContext): + # System supports predicate pushdown only while reading video data + predicate = before.predicate + lget: LogicalGet = before.children[0] + if predicate and lget.dataset_metadata.is_video: + # System only supports pushing basic range predicates on id + video_alias = lget.video.alias + col_alias = f"{video_alias}.id" + pushdown_pred, _ = extract_pushdown_predicate(predicate, col_alias) + if pushdown_pred: + return True + return False def apply(self, before: LogicalFilter, context: OptimizerContext): predicate = before.predicate lget = before.children[0] - new_get_opr = LogicalGet( - lget.video, - lget.dataset_metadata, - alias=lget.alias, - predicate=predicate, - target_list=lget.target_list, - children=lget.children, + # System only supports pushing basic range predicates on id + video_alias = lget.video.alias + col_alias = f"{video_alias}.id" + pushdown_pred, unsupported_pred = extract_pushdown_predicate( + predicate, col_alias ) - return new_get_opr + if pushdown_pred: + new_get_opr = LogicalGet( + lget.video, + lget.dataset_metadata, + alias=lget.alias, + predicate=pushdown_pred, + target_list=lget.target_list, + children=lget.children, + ) + if unsupported_pred: + unsupported_opr = LogicalFilter(unsupported_pred) + unsupported_opr.append_child(new_get_opr) + return unsupported_opr + return new_get_opr + else: + return before class EmbedProjectIntoGet(Rule): @@ -604,9 +629,13 @@ def apply(self, before: LogicalGet, context: OptimizerContext): ) if config_batch_mem_size: batch_mem_size = config_batch_mem_size - after = SeqScanPlan(before.predicate, before.target_list, before.alias) + after = SeqScanPlan(None, before.target_list, before.alias) after.append_child( - StoragePlan(before.dataset_metadata, batch_mem_size=batch_mem_size) + StoragePlan( + before.dataset_metadata, + batch_mem_size=batch_mem_size, + predicate=before.predicate, + ) ) return after @@ -777,7 +806,10 @@ def apply(self, join_node: LogicalJoin, context: OptimizerContext): build_plan = HashJoinBuildPlan(join_node.join_type, a_join_keys) build_plan.append_child(a) probe_side = HashJoinProbePlan( - join_node.join_type, b_join_keys, join_predicates, join_node.join_project + join_node.join_type, + b_join_keys, + join_predicates, + join_node.join_project, ) probe_side.append_child(build_plan) probe_side.append_child(b) @@ -798,7 +830,9 @@ def check(self, grp_id: int, context: OptimizerContext): def apply(self, before: LogicalCreateMaterializedView, context: OptimizerContext): after = CreateMaterializedViewPlan( - before.view, columns=before.col_list, if_not_exists=before.if_not_exists + before.view, + columns=before.col_list, + if_not_exists=before.if_not_exists, ) for child in before.children: after.append_child(child) @@ -878,10 +912,10 @@ def __init__(self): self._rewrite_rules = [ EmbedFilterIntoGet(), - EmbedFilterIntoDerivedGet(), + # EmbedFilterIntoDerivedGet(), PushdownFilterThroughSample(), EmbedProjectIntoGet(), - EmbedProjectIntoDerivedGet(), + # EmbedProjectIntoDerivedGet(), PushdownProjectThroughSample(), ] diff --git a/eva/planner/storage_plan.py b/eva/planner/storage_plan.py index b1fefe274a..0488ff8d08 100644 --- a/eva/planner/storage_plan.py +++ b/eva/planner/storage_plan.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from eva.catalog.models.df_metadata import DataFrameMetadata +from eva.expression.abstract_expression import AbstractExpression from eva.planner.abstract_plan import AbstractPlan from eva.planner.types import PlanOprType @@ -41,6 +42,7 @@ def __init__( limit: int = None, total_shards: int = 0, curr_shard: int = 0, + predicate: AbstractExpression = None, ): super().__init__(PlanOprType.STORAGE_PLAN) self._video = video @@ -50,6 +52,7 @@ def __init__( self._limit = limit self._total_shards = total_shards self._curr_shard = curr_shard + self._predicate = predicate @property def video(self): @@ -79,6 +82,10 @@ def total_shards(self): def curr_shard(self): return self._curr_shard + @property + def predicate(self): + return self._predicate + def __hash__(self) -> int: return hash( ( @@ -90,5 +97,6 @@ def __hash__(self) -> int: self.limit, self.total_shards, self.curr_shard, + self.predicate, ) ) diff --git a/eva/readers/opencv_reader.py b/eva/readers/opencv_reader.py index f620d99953..fa096752d4 100644 --- a/eva/readers/opencv_reader.py +++ b/eva/readers/opencv_reader.py @@ -16,38 +16,40 @@ import cv2 +from eva.expression.abstract_expression import AbstractExpression +from eva.expression.expression_utils import extract_range_list_from_predicate from eva.readers.abstract_reader import AbstractReader from eva.utils.logging_manager import logger class OpenCVReader(AbstractReader): - def __init__(self, *args, start_frame_id=0, **kwargs): - """ - Reads video using OpenCV and yields frame data. - It will use the `start_frame_id` while annotating the - frames. The first frame will be annotated with `start_frame_id` - Attributes: - start_frame_id (int): id assigned to first read frame - eg: start_frame_id=10, returned Iterator will be - [{10, frame1}, {11, frame2} ...] - It is different from offset. Offset defines where in video - should we start reading. And start_frame_id defines the id - we assign to first read frame. + def __init__(self, *args, predicate: AbstractExpression = None, **kwargs): + """Read frames from the disk + + Args: + predicate (AbstractExpression, optional): If only subset of frames + need to be read. The predicate should be only on single column and + can be converted to ranges. Defaults to None. """ - self._start_frame_id = start_frame_id + self._predicate = predicate super().__init__(*args, **kwargs) def _read(self) -> Iterator[Dict]: video = cv2.VideoCapture(self.file_url) - video_offset = self.offset if self.offset else 0 - video.set(cv2.CAP_PROP_POS_FRAMES, video_offset) - + num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + if self._predicate: + range_list = extract_range_list_from_predicate( + self._predicate, 0, num_frames - 1 + ) + else: + range_list = [(0, num_frames - 1)] logger.debug("Reading frames") - _, frame = video.read() - frame_id = self._start_frame_id - - while frame is not None: - yield {"id": frame_id, "data": frame} + for (begin, end) in range_list: + video.set(cv2.CAP_PROP_POS_FRAMES, begin) _, frame = video.read() - frame_id += 1 + frame_id = begin + while frame is not None and frame_id <= end: + yield {"id": frame_id, "data": frame} + _, frame = video.read() + frame_id += 1 diff --git a/eva/storage/opencv_storage_engine.py b/eva/storage/opencv_storage_engine.py index 2f7cc4d1bf..6156a83aed 100644 --- a/eva/storage/opencv_storage_engine.py +++ b/eva/storage/opencv_storage_engine.py @@ -19,6 +19,7 @@ from eva.catalog.models.df_metadata import DataFrameMetadata from eva.configuration.configuration_manager import ConfigurationManager +from eva.expression.abstract_expression import AbstractExpression from eva.models.storage.batch import Batch from eva.readers.opencv_reader import OpenCVReader from eva.storage.abstract_storage_engine import AbstractStorageEngine @@ -59,13 +60,18 @@ def write(self, table: DataFrameMetadata, rows: Batch): pass def read( - self, table: DataFrameMetadata, batch_mem_size: int, predicate_func=None + self, + table: DataFrameMetadata, + batch_mem_size: int, + predicate: AbstractExpression = None, ) -> Iterator[Batch]: metadata_file = Path(table.file_url) / self.metadata video_file_name = self._get_video_file_path(metadata_file) video_file = Path(table.file_url) / video_file_name - reader = OpenCVReader(str(video_file), batch_mem_size=batch_mem_size) + reader = OpenCVReader( + str(video_file), batch_mem_size=batch_mem_size, predicate=predicate + ) for batch in reader.read(): yield batch @@ -88,7 +94,10 @@ def _create_video_metadata(self, dir_path, video_file): file_path_bytes = str(video_file).encode() length = len(file_path_bytes) data = struct.pack( - "!HH%ds" % (length,), self.curr_version, length, file_path_bytes + "!HH%ds" % (length,), + self.curr_version, + length, + file_path_bytes, ) f.write(data) diff --git a/test/expression/test_arithmetic.py b/test/expression/test_arithmetic.py index 2437be5579..d81a01249e 100644 --- a/test/expression/test_arithmetic.py +++ b/test/expression/test_arithmetic.py @@ -82,5 +82,5 @@ def test_aaequality(self): self.assertNotEqual(cmpr_exp2, cmpr_exp3) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/expression/test_expression_utils.py b/test/expression/test_expression_utils.py new file mode 100644 index 0000000000..f16d5655a4 --- /dev/null +++ b/test/expression/test_expression_utils.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2018-2022 EVA +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest.mock import Mock + +from eva.expression.abstract_expression import ExpressionType +from eva.expression.arithmetic_expression import ArithmeticExpression +from eva.expression.comparison_expression import ComparisonExpression +from eva.expression.constant_value_expression import ConstantValueExpression +from eva.expression.expression_utils import ( + conjuction_list_to_expression_tree, + contains_single_column, + extract_range_list_from_comparison_expr, + extract_range_list_from_predicate, + is_simple_predicate, +) +from eva.expression.logical_expression import LogicalExpression +from eva.expression.tuple_value_expression import TupleValueExpression + + +class ExpressionUtilsTest(unittest.TestCase): + def gen_cmp_expr( + self, + val, + expr_type=ExpressionType.COMPARE_GREATER, + name="id", + const_first=False, + ): + constexpr = ConstantValueExpression(val) + colname = TupleValueExpression(col_name=name, col_alias=f"T.{name}") + if const_first: + return ComparisonExpression(expr_type, constexpr, colname) + return ComparisonExpression(expr_type, colname, constexpr) + + def test_extract_range_list_from_comparison_expr(self): + expr_types = [ + ExpressionType.COMPARE_NEQ, + ExpressionType.COMPARE_EQUAL, + ExpressionType.COMPARE_GREATER, + ExpressionType.COMPARE_LESSER, + ExpressionType.COMPARE_GEQ, + ExpressionType.COMPARE_LEQ, + ] + results = [] + for expr_type in expr_types: + cmpr_exp = self.gen_cmp_expr(10, expr_type, const_first=True) + results.append(extract_range_list_from_comparison_expr(cmpr_exp, 0, 100)) + expected = [ + [(0, 9), (11, 100)], + [(10, 10)], + [(0, 9)], + [(11, 100)], + [(0, 10)], + [(10, 100)], + ] + self.assertEqual(results, expected) + + results = [] + for expr_type in expr_types: + cmpr_exp = self.gen_cmp_expr(10, expr_type) + results.append(extract_range_list_from_comparison_expr(cmpr_exp, 0, 100)) + expected = [ + [(0, 9), (11, 100)], + [(10, 10)], + [(11, 100)], + [(0, 9)], + [(10, 100)], + [(0, 10)], + ] + self.assertEqual(results, expected) + + with self.assertRaises(RuntimeError): + cmpr_exp = LogicalExpression(ExpressionType.LOGICAL_AND, Mock(), Mock()) + extract_range_list_from_comparison_expr(cmpr_exp, 0, 100) + + with self.assertRaises(RuntimeError): + cmpr_exp = self.gen_cmp_expr(10, ExpressionType.COMPARE_CONTAINS) + extract_range_list_from_comparison_expr(cmpr_exp, 0, 100) + with self.assertRaises(RuntimeError): + cmpr_exp = self.gen_cmp_expr(10, ExpressionType.COMPARE_IS_CONTAINED) + extract_range_list_from_comparison_expr(cmpr_exp, 0, 100) + + def test_extract_range_list_from_predicate(self): + # id > 10 AND id > 20 -> (21, 100) + expr = LogicalExpression( + ExpressionType.LOGICAL_AND, + self.gen_cmp_expr(10), + self.gen_cmp_expr(20), + ) + self.assertEqual(extract_range_list_from_predicate(expr, 0, 100), [(21, 100)]) + # id > 10 OR id > 20 -> (11, 100) + expr = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10), + self.gen_cmp_expr(20), + ) + self.assertEqual(extract_range_list_from_predicate(expr, 0, 100), [(11, 100)]) + + # (id > 10 OR id > 20) OR (id > 10 AND id < 5) -> (11, 100) + expr1 = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10), + self.gen_cmp_expr(20), + ) + expr2 = LogicalExpression( + ExpressionType.LOGICAL_AND, + self.gen_cmp_expr(10), + self.gen_cmp_expr(5, ExpressionType.COMPARE_LESSER), + ) + expr = LogicalExpression(ExpressionType.LOGICAL_OR, expr1, expr2) + self.assertEqual(extract_range_list_from_predicate(expr, 0, 100), [(11, 100)]) + + # (id > 10 OR id > 20) AND (id > 10 AND id < 5) -> [] + expr = LogicalExpression(ExpressionType.LOGICAL_AND, expr1, expr2) + self.assertEqual(extract_range_list_from_predicate(expr, 0, 100), []) + + # (id < 10 OR id > 20) OR (id > 25 OR id < 5) -> [(0,9), (21,100)] + expr1 = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10, ExpressionType.COMPARE_LESSER), + self.gen_cmp_expr(20), + ) + expr2 = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(25), + self.gen_cmp_expr(5, ExpressionType.COMPARE_LESSER), + ) + expr = LogicalExpression(ExpressionType.LOGICAL_OR, expr1, expr2) + self.assertEqual( + extract_range_list_from_predicate(expr, 0, 100), + [(0, 9), (21, 100)], + ) + + with self.assertRaises(RuntimeError): + expr = ArithmeticExpression( + ExpressionType.AGGREGATION_COUNT, Mock(), Mock() + ) + extract_range_list_from_predicate(expr, 0, 100) + + def test_predicate_contains_single_column(self): + self.assertTrue(contains_single_column(self.gen_cmp_expr(10))) + expr1 = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "x"), + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "x"), + ) + self.assertTrue(contains_single_column(expr1)) + expr2 = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "x"), + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "y"), + ) + self.assertFalse(contains_single_column(expr2)) + expr = LogicalExpression(ExpressionType.LOGICAL_OR, expr1, expr2) + self.assertFalse(contains_single_column(expr)) + + def test_is_simple_predicate(self): + self.assertTrue(is_simple_predicate(self.gen_cmp_expr(10))) + + expr = ArithmeticExpression(ExpressionType.AGGREGATION_COUNT, Mock(), Mock()) + self.assertFalse(is_simple_predicate(expr)) + + expr = LogicalExpression( + ExpressionType.LOGICAL_OR, + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "x"), + self.gen_cmp_expr(10, ExpressionType.COMPARE_GREATER, "y"), + ) + self.assertFalse(is_simple_predicate(expr)) + + def test_conjuction_list_to_expression_tree(self): + expr1 = self.gen_cmp_expr(10) + expr2 = self.gen_cmp_expr(20) + new_expr = conjuction_list_to_expression_tree([expr1, expr2]) + self.assertEqual(new_expr.etype, ExpressionType.LOGICAL_AND) + self.assertEqual(new_expr.children[0], expr1) + self.assertEqual(new_expr.children[1], expr2) diff --git a/test/integration_tests/test_select_executor.py b/test/integration_tests/test_select_executor.py index ace80c90a0..ef037372af 100644 --- a/test/integration_tests/test_select_executor.py +++ b/test/integration_tests/test_select_executor.py @@ -325,7 +325,3 @@ def test_hash_join_with_multiple_tables(self): expected_batch.sort_orderby(["table1.a0"]), actual_batch.sort_orderby(["table1.a0"]), ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/optimizer/rules/test_rules.py b/test/optimizer/rules/test_rules.py index 84f94b5a9b..c98b885d0d 100644 --- a/test/optimizer/rules/test_rules.py +++ b/test/optimizer/rules/test_rules.py @@ -125,10 +125,10 @@ def test_supported_rules(self): # adding/removing rules should update this test supported_rewrite_rules = [ EmbedFilterIntoGet(), - EmbedFilterIntoDerivedGet(), + # EmbedFilterIntoDerivedGet(), PushdownFilterThroughSample(), EmbedProjectIntoGet(), - EmbedProjectIntoDerivedGet(), + # EmbedProjectIntoDerivedGet(), PushdownProjectThroughSample(), ] self.assertEqual( diff --git a/test/optimizer/test_optimizer_task.py b/test/optimizer/test_optimizer_task.py index cd3a07ca04..4860d239cb 100644 --- a/test/optimizer/test_optimizer_task.py +++ b/test/optimizer/test_optimizer_task.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest +from unittest.mock import patch from mock import MagicMock @@ -27,6 +28,8 @@ from eva.optimizer.optimizer_tasks import BottomUpRewrite, OptimizeGroup, TopDownRewrite from eva.optimizer.property import PropertyType from eva.optimizer.rules.rules import RulesManager +from eva.planner.predicate_plan import PredicatePlan +from eva.planner.project_plan import ProjectPlan from eva.planner.seq_scan_plan import SeqScanPlan @@ -62,74 +65,96 @@ def implement_group(self, root_grp_id, opt_cxt): def test_simple_top_down_rewrite(self): predicate = MagicMock() - child_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) - root_opr = LogicalFilter(predicate, [child_opr]) + video = MagicMock() + with patch("eva.optimizer.rules.rules.extract_pushdown_predicate") as mock: + mock.return_value = (predicate, None) + child_opr = LogicalGet(video, MagicMock(), MagicMock()) + root_opr = LogicalFilter(predicate, [child_opr]) - opt_cxt, root_grp_id = self.top_down_rewrite(root_opr) + opt_cxt, root_grp_id = self.top_down_rewrite(root_opr) - grp_expr = opt_cxt.memo.groups[root_grp_id].logical_exprs[0] + grp_expr = opt_cxt.memo.groups[root_grp_id].logical_exprs[0] - self.assertEqual(type(grp_expr.opr), LogicalGet) - self.assertEqual(grp_expr.opr.predicate, predicate) - self.assertEqual(grp_expr.opr.children, []) + self.assertEqual(type(grp_expr.opr), LogicalGet) + self.assertEqual(grp_expr.opr.predicate, predicate) + self.assertEqual(grp_expr.opr.children, []) def test_nested_top_down_rewrite(self): child_predicate = MagicMock() root_predicate = MagicMock() - - child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) - child_filter_opr = LogicalFilter(child_predicate, [child_get_opr]) - child_project_opr = LogicalProject([MagicMock()], [child_filter_opr]) - root_derived_get_opr = LogicalQueryDerivedGet( - MagicMock(), children=[child_project_opr] - ) - root_filter_opr = LogicalFilter(root_predicate, [root_derived_get_opr]) - root_project_opr = LogicalProject([MagicMock()], [root_filter_opr]) - - opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) - - grp_expr = opt_cxt.memo.groups[root_grp_id].logical_exprs[0] - - self.assertEqual(type(grp_expr.opr), LogicalProject) - self.assertEqual(len(grp_expr.children), 1) - - child_grp_id = grp_expr.children[0] - child_expr = opt_cxt.memo.groups[child_grp_id].logical_exprs[0] - self.assertEqual(type(child_expr.opr), LogicalQueryDerivedGet) - self.assertEqual(child_expr.opr.predicate, root_predicate) - self.assertEqual(len(child_expr.children), 1) - - child_grp_id = child_expr.children[0] - child_expr = opt_cxt.memo.groups[child_grp_id].logical_exprs[0] - self.assertEqual(type(child_expr.opr), LogicalGet) - self.assertEqual(child_expr.opr.predicate, child_predicate) + with patch("eva.optimizer.rules.rules.extract_pushdown_predicate") as mock: + mock.side_effect = [ + (root_predicate, None), + (root_predicate, None), + (child_predicate, None), + (child_predicate, None), + ] + child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) + child_filter_opr = LogicalFilter(child_predicate, [child_get_opr]) + child_project_opr = LogicalProject([MagicMock()], [child_filter_opr]) + root_derived_get_opr = LogicalQueryDerivedGet( + MagicMock(), children=[child_project_opr] + ) + root_filter_opr = LogicalFilter(root_predicate, [root_derived_get_opr]) + root_project_opr = LogicalProject([MagicMock()], [root_filter_opr]) + + opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) + + expected_expr_order = [ + LogicalProject, + LogicalFilter, + LogicalQueryDerivedGet, + LogicalProject, + LogicalGet, + ] + curr_grp_id = root_grp_id + idx = 0 + while True: + grp_expr = opt_cxt.memo.groups[curr_grp_id].logical_exprs[0] + self.assertEqual(type(grp_expr.opr), expected_expr_order[idx]) + idx += 1 + if idx == len(expected_expr_order): + break + curr_grp_id = grp_expr.children[0] def test_nested_bottom_up_rewrite(self): child_predicate = MagicMock() root_predicate = MagicMock() - - child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) - child_filter_opr = LogicalFilter(child_predicate, [child_get_opr]) - child_project_opr = LogicalProject([MagicMock()], [child_filter_opr]) - root_derived_get_opr = LogicalQueryDerivedGet( - MagicMock(), children=[child_project_opr] - ) - root_filter_opr = LogicalFilter(root_predicate, [root_derived_get_opr]) - root_project_opr = LogicalProject([MagicMock()], children=[root_filter_opr]) - - opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) - opt_cxt, root_grp_id = self.bottom_up_rewrite(root_grp_id, opt_cxt) - - grp_expr = opt_cxt.memo.groups[root_grp_id].logical_exprs[0] - - self.assertEqual(type(grp_expr.opr), LogicalQueryDerivedGet) - self.assertEqual(len(grp_expr.children), 1) - self.assertEqual(grp_expr.opr.predicate, root_predicate) - - child_grp_id = grp_expr.children[0] - child_expr = opt_cxt.memo.groups[child_grp_id].logical_exprs[0] - self.assertEqual(type(child_expr.opr), LogicalGet) - self.assertEqual(child_expr.opr.predicate, child_predicate) + with patch("eva.optimizer.rules.rules.extract_pushdown_predicate") as mock: + mock.side_effect = [ + (root_predicate, None), + (root_predicate, None), + (child_predicate, None), + (child_predicate, None), + ] + + child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) + child_filter_opr = LogicalFilter(child_predicate, [child_get_opr]) + child_project_opr = LogicalProject([MagicMock()], [child_filter_opr]) + root_derived_get_opr = LogicalQueryDerivedGet( + MagicMock(), children=[child_project_opr] + ) + root_filter_opr = LogicalFilter(root_predicate, [root_derived_get_opr]) + root_project_opr = LogicalProject([MagicMock()], children=[root_filter_opr]) + + opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) + opt_cxt, root_grp_id = self.bottom_up_rewrite(root_grp_id, opt_cxt) + + expected_expr_order = [ + LogicalProject, + LogicalFilter, + LogicalQueryDerivedGet, + LogicalGet, + ] + curr_grp_id = root_grp_id + idx = 0 + while True: + grp_expr = opt_cxt.memo.groups[curr_grp_id].logical_exprs[0] + self.assertEqual(type(grp_expr.opr), expected_expr_order[idx]) + idx += 1 + if idx == len(expected_expr_order): + break + curr_grp_id = grp_expr.children[0] def test_simple_implementation(self): predicate = MagicMock() @@ -143,36 +168,47 @@ def test_simple_implementation(self): root_grp = opt_cxt.memo.groups[root_grp_id] best_root_grp_expr = root_grp.get_best_expr(PropertyType.DEFAULT) - self.assertEqual(type(best_root_grp_expr.opr), SeqScanPlan) - self.assertEqual(best_root_grp_expr.opr.predicate, predicate) + self.assertEqual(type(best_root_grp_expr.opr), PredicatePlan) def test_nested_implementation(self): child_predicate = MagicMock() root_predicate = MagicMock() - - child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) - child_filter_opr = LogicalFilter(child_predicate, children=[child_get_opr]) - child_project_opr = LogicalProject([MagicMock()], children=[child_filter_opr]) - root_derived_get_opr = LogicalQueryDerivedGet( - MagicMock(), children=[child_project_opr] - ) - root_filter_opr = LogicalFilter(root_predicate, children=[root_derived_get_opr]) - root_project_opr = LogicalProject([MagicMock()], children=[root_filter_opr]) - - opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) - opt_cxt, root_grp_id = self.bottom_up_rewrite(root_grp_id, opt_cxt) - opt_cxt, root_grp_id = self.implement_group(root_grp_id, opt_cxt) - - root_grp = opt_cxt.memo.groups[root_grp_id] - best_root_grp_expr = root_grp.get_best_expr(PropertyType.DEFAULT) - - root_opr = best_root_grp_expr.opr - self.assertEqual(type(root_opr), SeqScanPlan) - self.assertEqual(root_opr.predicate, root_predicate) - - child_grp_id = best_root_grp_expr.children[0] - child_grp = opt_cxt.memo.groups[child_grp_id] - best_child_grp_expr = child_grp.get_best_expr(PropertyType.DEFAULT) - child_opr = best_child_grp_expr.opr - self.assertEqual(type(child_opr), SeqScanPlan) - self.assertEqual(child_opr.predicate, child_predicate) + with patch("eva.optimizer.rules.rules.extract_pushdown_predicate") as mock: + mock.side_effect = [ + (child_predicate, None), + (root_predicate, None), + ] + + child_get_opr = LogicalGet(MagicMock(), MagicMock(), MagicMock()) + child_filter_opr = LogicalFilter(child_predicate, children=[child_get_opr]) + child_project_opr = LogicalProject( + [MagicMock()], children=[child_filter_opr] + ) + root_derived_get_opr = LogicalQueryDerivedGet( + MagicMock(), children=[child_project_opr] + ) + root_filter_opr = LogicalFilter( + root_predicate, children=[root_derived_get_opr] + ) + root_project_opr = LogicalProject([MagicMock()], children=[root_filter_opr]) + + opt_cxt, root_grp_id = self.top_down_rewrite(root_project_opr) + opt_cxt, root_grp_id = self.bottom_up_rewrite(root_grp_id, opt_cxt) + opt_cxt, root_grp_id = self.implement_group(root_grp_id, opt_cxt) + + expected_expr_order = [ + ProjectPlan, + PredicatePlan, + SeqScanPlan, + SeqScanPlan, + ] + curr_grp_id = root_grp_id + idx = 0 + while True: + root_grp = opt_cxt.memo.groups[curr_grp_id] + best_root_grp_expr = root_grp.get_best_expr(PropertyType.DEFAULT) + self.assertEqual(type(best_root_grp_expr.opr), expected_expr_order[idx]) + idx += 1 + if idx == len(expected_expr_order): + break + curr_grp_id = best_root_grp_expr.children[0] diff --git a/test/readers/test_opencv_reader.py b/test/readers/test_opencv_reader.py index 3ce514b0bf..3d90e80b4c 100644 --- a/test/readers/test_opencv_reader.py +++ b/test/readers/test_opencv_reader.py @@ -70,28 +70,3 @@ def test_should_skip_first_two_frames_and_batch_size_equal_to_no_of_frames(self) batches = list(video_loader.read()) expected = list(create_dummy_batches(filters=[i for i in range(2, NUM_FRAMES)])) self.assertTrue(batches, expected) - - def test_should_start_frame_number_from_two(self): - video_loader = OpenCVReader( - file_url=os.path.join(PATH_PREFIX, "dummy.avi"), - batch_mem_size=FRAME_SIZE * NUM_FRAMES, - start_frame_id=2, - ) - batches = list(video_loader.read()) - expected = list( - create_dummy_batches(filters=[i for i in range(0, NUM_FRAMES)], start_id=2) - ) - self.assertTrue(batches, expected) - - def test_should_start_frame_number_from_two_and_offset_from_one(self): - video_loader = OpenCVReader( - file_url=os.path.join(PATH_PREFIX, "dummy.avi"), - batch_mem_size=FRAME_SIZE * NUM_FRAMES, - offset=1, - start_frame_id=2, - ) - batches = list(video_loader.read()) - expected = list( - create_dummy_batches(filters=[i for i in range(1, NUM_FRAMES)], start_id=2) - ) - self.assertTrue(batches, expected) diff --git a/test/udfs/test_fastrcnn_object_detector.py b/test/udfs/test_fastrcnn_object_detector.py index 59e95d4a0b..51ee4149a7 100644 --- a/test/udfs/test_fastrcnn_object_detector.py +++ b/test/udfs/test_fastrcnn_object_detector.py @@ -58,9 +58,7 @@ def test_should_return_batches_equivalent_to_number_of_frames(self): frame_dog = { "id": 1, - "data": self._load_image( - os.path.join(self.base_path, "data", "dog.jpeg") - ), + "data": self._load_image(os.path.join(self.base_path, "data", "dog.jpeg")), } frame_dog_cat = { "id": 2,