Skip to content

Commit

Permalink
fix validation logic for dates in duckdb and postgres
Browse files Browse the repository at this point in the history
This fix is a prerequisite for using test cases from Substrait directly.

Substrait test cases have examples whose expected value contains only the
date part, but the Python client returns a datetime object.

Converting these to strings and comparing them is not quite right.
E.g., '2024-03-01' == '2024-03-01 00:00:00' will fail if we do a string compare.

The fix is to not compare the empty fields (hours, minutes, seconds, etc.) if
the expected value contains only the date part.
  • Loading branch information
srikrishnak committed Dec 10, 2024
1 parent 621bcb2 commit 9dea9ce
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
22 changes: 9 additions & 13 deletions bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type
from bft.utils.utils import type_to_dialect_type, datetype_value_equal

type_map = {
"i8": "TINYINT",
Expand Down Expand Up @@ -53,7 +53,6 @@ def is_string_type(arg):
def is_datetype(arg):
    """Return True when ``arg`` is exactly a datetime, date, or timedelta.

    The test is made on ``type(arg)`` itself, so an instance of a subclass
    of one of these three classes is not recognised.
    """
    date_like_types = (datetime.datetime, datetime.date, datetime.timedelta)
    arg_type = type(arg)
    return arg_type in date_like_types


class DuckDBRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
Expand All @@ -62,31 +61,26 @@ def __init__(self, dialect):
def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

try:
max_args = len(case.args) + 1
if case.function == 'regexp_replace':
max_args = 3
if case.function == 'regexp_match_substring':
max_args = 2
arg_defs = [
f"arg{idx} {type_to_duckdb_type(arg.type)}"
for idx, arg in enumerate(case.args[:max_args])
for idx, arg in enumerate(case.args)
]
schema = ",".join(arg_defs)
self.conn.execute(f"CREATE TABLE my_table({schema});")
self.conn.execute(f"SET TimeZone='UTC';")

arg_names = [f"arg{idx}" for idx in range(len(case.args[:max_args]))]
arg_names = [f"arg{idx}" for idx in range(len(case.args))]
joined_arg_names = ",".join(arg_names)
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
if is_string_type(arg):
arg_vals_list.append("'" + literal_to_str(arg.value) + "'")
else:
arg_vals_list.append(literal_to_str(arg.value))
arg_vals = ", ".join(arg_vals_list)
if mapping.aggregate:
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
arg_vals = ""
for value in arg.value:
if is_string_type(arg):
Expand Down Expand Up @@ -119,7 +113,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down Expand Up @@ -147,7 +141,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
5 changes: 4 additions & 1 deletion bft/testers/postgres/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import datetype_value_equal

type_map = {
"i16": "smallint",
Expand Down Expand Up @@ -152,7 +153,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
22 changes: 22 additions & 0 deletions bft/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
import datetime


def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
Expand All @@ -24,3 +25,24 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
return type_val
# transform parameterized type name to have dialect type
return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")")

def has_only_date(value: datetime.datetime):
    """Report whether every time-of-day field of ``value`` is zero.

    Returns True when ``hour``, ``minute``, ``second`` and ``microsecond``
    all equal 0, and False as soon as one of them does not.
    """
    time_fields = ("hour", "minute", "second", "microsecond")
    # The generator is evaluated lazily, so fields are read in order and
    # checking stops at the first non-zero one.
    return all(getattr(value, field) == 0 for field in time_fields)

def datetype_value_equal(result, case_result):
    """Compare a date-like ``result`` against the expected ``case_result``.

    The two are equal when ``str(result)`` matches ``case_result``. A
    ``datetime.datetime`` for which ``has_only_date`` is true is also
    considered equal when the string form of its date part matches
    ``case_result``.

    Returns True on a match, False otherwise.
    """
    if str(result) == case_result:
        return True
    # Only full datetimes get the date-only fallback comparison.
    if not isinstance(result, datetime.datetime):
        return False
    return has_only_date(result) and str(result.date()) == case_result

0 comments on commit 9dea9ce

Please sign in to comment.