Skip to content

Commit

Permalink
fix: fix function lookup in coverage tool (#744)
Browse files Browse the repository at this point in the history
  • Loading branch information
srikrishnak authored Nov 18, 2024
1 parent a64a102 commit 3d2ff77
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 13 deletions.
2 changes: 1 addition & 1 deletion tests/cases/arithmetic_decimal/power.test
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### SUBSTRAIT_SCALAR_TEST: v1.0
### SUBSTRAIT_INCLUDE: 'extensions/functions_arithmetic_decimal.yaml'
### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic_decimal.yaml'

# basic: Basic examples without any special cases
power(8::dec<38, 0>, 2::dec<38, 0>) = 64::fp64
Expand Down
15 changes: 10 additions & 5 deletions tests/coverage/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,27 @@ def update_test_count(test_case_files: list, function_registry: FunctionRegistry
for test_file in test_case_files:
for test_case in test_file.testcases:
function_variant = function_registry.get_function(
test_case.func_name, test_case.get_arg_types()
test_case.func_name,
test_file.include,
test_case.get_arg_types(),
test_case.get_return_type(),
)
if function_variant:
if (
function_variant.return_type != test_case.get_return_type()
and not test_case.is_return_type_error()
not test_case.is_return_type_error()
and not function_registry.is_same_type(
function_variant.return_type, test_case.get_return_type()
)
):
error(
f"Return type mismatch in function {test_case.func_name}: "
f"Return type mismatch in function {test_case.get_signature()}: "
f"{function_variant.return_type} != {test_case.get_return_type()}"
)
num_tests_with_no_matching_function += 1
continue
function_variant.increment_test_count()
else:
error(f"Function not found: {test_case.func_name}({test_case.args})")
error(f"Function not found: {test_case.get_signature()}")
num_tests_with_no_matching_function += 1
return num_tests_with_no_matching_function

Expand Down
54 changes: 48 additions & 6 deletions tests/coverage/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import yaml

from tests.coverage.antlr_parser.FuncTestCaseLexer import FuncTestCaseLexer
from tests.coverage.nodes import SubstraitError

enable_debug = False

Expand Down Expand Up @@ -122,11 +123,10 @@ def get_supported_kernels_from_impls(func):
return overloads

@staticmethod
def add_functions_to_map(func_list, function_map, suffix, extension):
def add_functions_to_map(func_list, function_map, suffix, extension, uri):
dup_idx = 0
for func in func_list:
name = func["name"]
uri = extension[5:] # strip the ../..
if name in function_map:
debug(
f"Duplicate function name: {name} renaming to {name}_{suffix} extension: {extension}"
Expand Down Expand Up @@ -163,25 +163,35 @@ def read_substrait_extensions(dir_path: str):
suffix = suffix[
suffix.rfind("/") + 1 :
] # strip the path and get the name of the extension
uri = f"/extensions/{suffix}.yaml"
suffix = suffix[suffix.find("_") + 1 :] # get the suffix after the last _

dependencies[suffix] = Extension.get_base_uri() + extension
dependencies[suffix] = Extension.get_base_uri() + uri
with open(extension, "r") as fh:
data = yaml.load(fh, Loader=yaml.FullLoader)
if "scalar_functions" in data:
Extension.add_functions_to_map(
data["scalar_functions"], scalar_functions, suffix, extension
data["scalar_functions"],
scalar_functions,
suffix,
extension,
uri,
)
if "aggregate_functions" in data:
Extension.add_functions_to_map(
data["aggregate_functions"],
aggregate_functions,
suffix,
extension,
uri,
)
if "window_functions" in data:
Extension.add_functions_to_map(
data["window_functions"], scalar_functions, suffix, extension
data["window_functions"],
scalar_functions,
suffix,
extension,
uri,
)

return FunctionRegistry(
Expand Down Expand Up @@ -263,13 +273,45 @@ def add_functions(self, functions, func_type):
fun_arr.append(function)
self.registry[f_name] = fun_arr

def get_function(self, name: str, args: object) -> [FunctionVariant]:
@staticmethod
def is_type_any(func_arg_type):
return func_arg_type[:3] == "any"

@staticmethod
def is_same_type(func_arg_type, arg_type):
arg_type_base = arg_type.split("<")[0]
if func_arg_type == arg_type_base:
return True
return FunctionRegistry.is_type_any(func_arg_type)

def get_function(
self, name: str, uri: str, args: object, return_type
) -> [FunctionVariant]:
functions = self.registry.get(name, None)
if functions is None:
return None
for function in functions:
if uri != function.uri:
continue
if not isinstance(return_type, SubstraitError) and not self.is_same_type(
function.return_type, return_type
):
continue
if function.args == args:
return function
if len(function.args) != len(args) and not (
function.variadic and len(args) >= len(function.args)
):
continue
is_match = True
for i, arg in enumerate(args):
j = i if i < len(function.args) else len(function.args) - 1
if not self.is_same_type(function.args[j], arg):
is_match = False
break
if is_match:
return function
return None

def get_extension_list(self):
return list(self.extensions)
Expand Down
2 changes: 1 addition & 1 deletion tests/coverage/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_arg_types(self):
return [arg.get_base_type() for arg in self.args]

def get_signature(self):
return f"{self.func_name}({', '.join([arg.type for arg in self.args])})"
return f"{self.func_name}({', '.join([arg.type for arg in self.args])}) = {self.get_return_type()}"


@dataclass
Expand Down
51 changes: 51 additions & 0 deletions tests/coverage/test_coverage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import os

import pytest
from antlr4 import InputStream
from tests.coverage.case_file_parser import parse_stream, parse_one_file
from tests.coverage.extensions import Extension
from tests.coverage.visitor import ParseError
from tests.coverage.nodes import CaseLiteral

Expand Down Expand Up @@ -411,6 +414,10 @@ def test_parse_errors_with_bad_aggregate_testcases(input_func_test, expected_mes
"f37('1991-01-01T01:02:03.123456'::pts<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64",
"f38('1991-01-01T01:02:03.456+05:30'::ptstz<3>) = '1991-01-01T00:00:00+15:30'::precision_timestamp_tz<3>",
"f39('1991-01-01T01:02:03.123456+05:30'::ptstz<6>) = '1991-01-01T00:00:00+15:30'::precision_timestamp_tz<6>",
"logb(10::fp64, -inf::fp64) [on_domain_error:ERROR] = <!ERROR>",
"bitwise_and(-31766::dec<5, 0>, 900::dec<3, 0>) = 896::dec<5, 0>",
"or(true::bool, true::bool) = true::bool",
"between(5::i8, 0::i8, 127::i8) = true::bool",
],
)
def test_parse_various_scalar_func_argument_types(input_func_test):
Expand All @@ -436,3 +443,47 @@ def test_parse_various_aggregate_scalar_func_argument_types(input_func_test):
)
test_file = parse_string(header + input_func_test + "\n")
assert len(test_file.testcases) == 1


@pytest.mark.parametrize(
"func_name, func_args, func_ret, func_uri, expected_failure",
[
# lt for i8 with correct uri
("lt", ["i8", "i8"], "bool", "/extensions/functions_comparison.yaml", False),
("add", ["i8", "i8"], "i8", "/extensions/functions_arithmetic.yaml", False),
(
"add",
["dec", "dec"],
"dec",
"/extensions/functions_arithmetic_decimal.yaml",
False,
),
(
"bitwise_xor",
["dec", "dec"],
"dec",
"/extensions/functions_arithmetic_decimal.yaml",
False,
),
# negative case, lt for i8 with wrong uri
("lt", ["i8", "i8"], "bool", "/extensions/functions_datetime.yaml", True),
(
"add",
["i8", "i8"],
"i8",
"/extensions/functions_arithmetic_decimal.yaml",
True,
),
("add", ["dec", "dec"], "dec", "/extensions/functions_arithmetic.yaml", True),
("max", ["dec", "dec"], "dec", "/extensions/functions_arithmetic.yaml", True),
],
)
def test_uri_match_in_get_function(
func_name, func_args, func_ret, func_uri, expected_failure
):
script_dir = os.path.dirname(os.path.abspath(__file__))
extensions_path = os.path.join(script_dir, "../../extensions")
registry = Extension.read_substrait_extensions(extensions_path)

function = registry.get_function(func_name, func_uri, func_args, func_ret)
assert (function is None) == expected_failure

0 comments on commit 3d2ff77

Please sign in to comment.