Skip to content

Commit

Permalink
Distinquish between np typehins from Py ones (#4953)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver authored Dec 18, 2023
1 parent 8930861 commit 7364cf5
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
20 changes: 18 additions & 2 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
return p_sig


def _is_from_np_type(param_types:set[type], np_type_char: str) -> bool:
""" Determine if the given numpy type char comes for a numpy type in the given set of parameter type annotations"""
for t in param_types:
if issubclass(t, np.generic) and np.dtype(t).char == np_type_char:
return True
return False


def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
""" Convert a single argument to the type specified by the annotation """
if arg is None:
Expand Down Expand Up @@ -279,13 +287,21 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
else:
raise DHError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
else:
return np.dtype(param.int_char).type(arg)
# return a numpy integer instance only if the annotation is a numpy type
if _is_from_np_type(param.orig_types, param.int_char):
return np.dtype(param.int_char).type(arg)
else:
return arg
elif param.floating_char and isinstance(arg, float):
if isinstance(arg, float):
if arg == dh_null:
return np.nan if "N" not in param.encoded_types else None
else:
return np.dtype(param.floating_char).type(arg)
# return a numpy floating instance only if the annotation is a numpy type
if _is_from_np_type(param.orig_types, param.floating_char):
return np.dtype(param.floating_char).type(arg)
else:
return arg
elif t == "?" and isinstance(arg, bool):
return arg
elif t == "M":
Expand Down
35 changes: 35 additions & 0 deletions py/server/tests/test_udf_numpy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,41 @@ def f31(p1: Optional[np.bool_], p2=None) -> bool:
t2 = t.update(["X1 = f31(null, Y)"])
self.assertEqual(10, t2.to_string("X1").count("true"))

def test_non_np_typehints(self):
py_types = {"int", "float"}

for p_type in py_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
return type(x) == {p_type}
"""
exec(func_str, globals())
t = empty_table(1).update(["X = i", f"Y = f(({p_type})X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))


np_int_types = {"np.int8", "np.int16", "np.int32", "np.int64"}
for p_type in np_int_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
return type(x) == {p_type}
"""
exec(func_str, globals())
t = empty_table(1).update(["X = i", f"Y = f(X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))

np_floating_types = {"np.float32", "np.float64"}
for p_type in np_floating_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
return type(x) == {p_type}
"""
exec(func_str, globals())
t = empty_table(1).update(["X = i", f"Y = f((float)X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))

if __name__ == "__main__":
unittest.main()

0 comments on commit 7364cf5

Please sign in to comment.