From f8536c1c3db21d13f33afd9520d29cbae66d59f5 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 29 Nov 2023 13:57:32 -0800 Subject: [PATCH] Change acceptance function names per feedback `_acceptance_fn_default1` and `_acceptance_fn_default2` are now `_acceptance_fn_default_unary` and `_acceptance_fn_default_binary` --- dpctl/tensor/_elementwise_common.py | 8 ++++---- dpctl/tensor/_type_utils.py | 8 ++++---- dpctl/tests/elementwise/test_type_utils.py | 8 ++++++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/dpctl/tensor/_elementwise_common.py b/dpctl/tensor/_elementwise_common.py index a6e9d6043c..f638ebc3c1 100644 --- a/dpctl/tensor/_elementwise_common.py +++ b/dpctl/tensor/_elementwise_common.py @@ -28,8 +28,8 @@ from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK from ._type_utils import ( - _acceptance_fn_default1, - _acceptance_fn_default2, + _acceptance_fn_default_binary, + _acceptance_fn_default_unary, _all_data_types, _find_buf_dtype, _find_buf_dtype2, @@ -95,7 +95,7 @@ def __init__( if callable(acceptance_fn): self.acceptance_fn_ = acceptance_fn else: - self.acceptance_fn_ = _acceptance_fn_default1 + self.acceptance_fn_ = _acceptance_fn_default_unary def __str__(self): return f"<{self.__name__} '{self.name_}'>" @@ -526,7 +526,7 @@ def __init__( if callable(acceptance_fn): self.acceptance_fn_ = acceptance_fn else: - self.acceptance_fn_ = _acceptance_fn_default2 + self.acceptance_fn_ = _acceptance_fn_default_binary def __str__(self): return f"<{self.__name__} '{self.name_}'>" diff --git a/dpctl/tensor/_type_utils.py b/dpctl/tensor/_type_utils.py index dcf8617b3f..bacd488226 100644 --- a/dpctl/tensor/_type_utils.py +++ b/dpctl/tensor/_type_utils.py @@ -132,7 +132,7 @@ def _to_device_supported_dtype(dt, dev): return dt -def _acceptance_fn_default1(arg_dtype, ret_buf_dt, res_dt, sycl_dev): +def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev): return True @@ -187,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev): raise RuntimeError -def _acceptance_fn_default2( +def _acceptance_fn_default_binary( arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev ): return True @@ -254,8 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn): "_find_buf_dtype", "_find_buf_dtype2", "_to_device_supported_dtype", - "_acceptance_fn_default1", + "_acceptance_fn_default_unary", "_acceptance_fn_reciprocal", - "_acceptance_fn_default2", + "_acceptance_fn_default_binary", "_acceptance_fn_divide", ] diff --git a/dpctl/tests/elementwise/test_type_utils.py b/dpctl/tests/elementwise/test_type_utils.py index 93290447ce..9407195bd9 100644 --- a/dpctl/tests/elementwise/test_type_utils.py +++ b/dpctl/tests/elementwise/test_type_utils.py @@ -125,7 +125,7 @@ def _denier_fn(dt): dev = MockDevice(fp16, fp64) arg_dt = dpt.float64 r = tu._find_buf_dtype( - arg_dt, _denier_fn, dev, tu._acceptance_fn_default1 + arg_dt, _denier_fn, dev, tu._acceptance_fn_default_unary ) assert r == ( None, @@ -159,7 +159,11 @@ def _denier_fn(dt1, dt2): arg1_dt = dpt.float64 arg2_dt = dpt.complex64 r = tu._find_buf_dtype2( - arg1_dt, arg2_dt, _denier_fn, dev, tu._acceptance_fn_default2 + arg1_dt, + arg2_dt, + _denier_fn, + dev, + tu._acceptance_fn_default_binary, ) assert r == ( None,