Skip to content

Commit

Permalink
Change acceptance function names per feedback
Browse files Browse the repository at this point in the history
`_acceptance_fn_default1` and `_acceptance_fn_default2` are now
`_acceptance_fn_default_unary` and `_acceptance_fn_default_binary`
  • Loading branch information
ndgrigorian committed Nov 29, 2023
1 parent 0018dfa commit f8536c1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
8 changes: 4 additions & 4 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default1,
_acceptance_fn_default2,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
Expand Down Expand Up @@ -95,7 +95,7 @@ def __init__(
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default1
self.acceptance_fn_ = _acceptance_fn_default_unary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default2
self.acceptance_fn_ = _acceptance_fn_default_binary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _to_device_supported_dtype(dt, dev):
return dt


def _acceptance_fn_default1(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
return True


Expand Down Expand Up @@ -187,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
raise RuntimeError


def _acceptance_fn_default2(
def _acceptance_fn_default_binary(
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
return True
Expand Down Expand Up @@ -254,8 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
"_find_buf_dtype",
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default1",
"_acceptance_fn_default_unary",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default2",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
]
8 changes: 6 additions & 2 deletions dpctl/tests/elementwise/test_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _denier_fn(dt):
dev = MockDevice(fp16, fp64)
arg_dt = dpt.float64
r = tu._find_buf_dtype(
arg_dt, _denier_fn, dev, tu._acceptance_fn_default1
arg_dt, _denier_fn, dev, tu._acceptance_fn_default_unary
)
assert r == (
None,
Expand Down Expand Up @@ -159,7 +159,11 @@ def _denier_fn(dt1, dt2):
arg1_dt = dpt.float64
arg2_dt = dpt.complex64
r = tu._find_buf_dtype2(
arg1_dt, arg2_dt, _denier_fn, dev, tu._acceptance_fn_default2
arg1_dt,
arg2_dt,
_denier_fn,
dev,
tu._acceptance_fn_default_binary,
)
assert r == (
None,
Expand Down

0 comments on commit f8536c1

Please sign in to comment.