Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix validation logic for date datatype in duckdb and postgres #101

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions bft/testers/duckdb/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import type_to_dialect_type
from bft.utils.utils import type_to_dialect_type, datetype_value_equal

type_map = {
"i8": "TINYINT",
Expand Down Expand Up @@ -53,7 +53,6 @@ def is_string_type(arg):
def is_datetype(arg):
    """Return True when *arg* is a date, datetime, or timedelta value.

    Uses ``isinstance`` rather than an exact ``type(...)`` match so that
    subclasses of the standard date/time types are recognised as well.

    Args:
        arg: Any value, typically a result produced by running a SQL case.

    Returns:
        bool: True for ``datetime.datetime``, ``datetime.date`` and
        ``datetime.timedelta`` instances (including subclasses), else False.
    """
    return isinstance(
        arg, (datetime.datetime, datetime.date, datetime.timedelta)
    )


class DuckDBRunner(SqlCaseRunner):
def __init__(self, dialect):
super().__init__(dialect)
Expand All @@ -62,31 +61,26 @@ def __init__(self, dialect):
def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:

try:
max_args = len(case.args) + 1
if case.function == 'regexp_replace':
max_args = 3
if case.function == 'regexp_match_substring':
max_args = 2
arg_defs = [
f"arg{idx} {type_to_duckdb_type(arg.type)}"
for idx, arg in enumerate(case.args[:max_args])
for idx, arg in enumerate(case.args)
]
schema = ",".join(arg_defs)
self.conn.execute(f"CREATE TABLE my_table({schema});")
self.conn.execute(f"SET TimeZone='UTC';")

arg_names = [f"arg{idx}" for idx in range(len(case.args[:max_args]))]
arg_names = [f"arg{idx}" for idx in range(len(case.args))]
joined_arg_names = ",".join(arg_names)
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
if is_string_type(arg):
arg_vals_list.append("'" + literal_to_str(arg.value) + "'")
else:
arg_vals_list.append(literal_to_str(arg.value))
arg_vals = ", ".join(arg_vals_list)
if mapping.aggregate:
arg_vals_list = list()
for arg in case.args[:max_args]:
for arg in case.args:
arg_vals = ""
for value in arg.value:
if is_string_type(arg):
Expand Down Expand Up @@ -119,7 +113,7 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
if len(arg_names) != 2:
raise Exception(f"Extract function with {len(arg_names)} args")
expr = f"SELECT {mapping.local_name}({arg_vals_list[0]} FROM {arg_names[1]}) FROM my_table;"
elif mapping.local_name == 'count(*)':
elif mapping.local_name == "count(*)":
expr = f"SELECT {mapping.local_name} FROM my_table;"
elif mapping.aggregate:
if len(arg_names) < 1:
Expand Down Expand Up @@ -147,7 +141,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
5 changes: 4 additions & 1 deletion bft/testers/postgres/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bft.cases.runner import SqlCaseResult, SqlCaseRunner
from bft.cases.types import Case
from bft.dialects.types import SqlMapping
from bft.utils.utils import datetype_value_equal

type_map = {
"i16": "smallint",
Expand Down Expand Up @@ -152,7 +153,9 @@ def run_sql_case(self, case: Case, mapping: SqlMapping) -> SqlCaseResult:
else:
if result == case.result.value:
return SqlCaseResult.success()
elif is_datetype(result) and str(result) == case.result.value:
elif is_datetype(result) and datetype_value_equal(
result, case.result.value
):
return SqlCaseResult.success()
else:
return SqlCaseResult.mismatch(str(result))
Expand Down
22 changes: 22 additions & 0 deletions bft/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict
import datetime


def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
Expand All @@ -24,3 +25,24 @@ def type_to_dialect_type(type: str, type_map: Dict[str, str])->str:
return type_val
# transform parameterized type name to have dialect type
return type.replace(type_to_check, type_val).replace("<", "(").replace(">", ")")

def has_only_date(value: datetime.datetime) -> bool:
    """Return True when *value* has no time-of-day component.

    A value counts as "date only" when its hour, minute, second and
    microsecond fields are all zero, i.e. it sits exactly at midnight.

    Args:
        value: The datetime to inspect.

    Returns:
        bool: True if every time-of-day field is zero, otherwise False.
    """
    time_fields = (value.hour, value.minute, value.second, value.microsecond)
    return time_fields == (0, 0, 0, 0)

def datetype_value_equal(result, case_result):
    """Check whether a date/time result matches the expected case value.

    The comparison is textual: ``case_result`` is compared against string
    renderings of ``result``.

    Args:
        result: The value produced by running the case.
        case_result: The expected value from the test case, compared as text.

    Returns:
        bool: True when ``str(result)`` equals ``case_result``, or when
        ``result`` is a ``datetime.datetime`` exactly at midnight whose date
        part, rendered with ``str``, equals ``case_result``. False otherwise.
    """
    if str(result) == case_result:
        return True
    # NOTE(review): in the long run we need to make sure that we're testing
    # types as well as values. Accepting a midnight timestamp for an expected
    # plain date is fine for now as a workaround.
    return (
        isinstance(result, datetime.datetime)
        and has_only_date(result)
        and str(result.date()) == case_result
    )
Loading