diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index e2481229..d7be6b47 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -19,7 +19,7 @@ from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal from warnings import warn import pytest @@ -544,6 +544,10 @@ class UnaryCase(Case): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +r_nan_signbit = re.compile( + "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " + "the result is ``(.+)``" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase: + if signbit: + sign = -1 + nan_expr = "-NaN" + float_arg = "-nan" + else: + sign = 1 + nan_expr = "+NaN" + float_arg = "nan" + + return UnaryCase( + cond_expr=f"x_i is {nan_expr}", + cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign, + cond_from_dtype=lambda _: st.just(float(float_arg)), + result_expr=str(expected), + check_result=lambda _, result: result == float(expected), + ) + + def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck: def check_result(i: float, result: float) -> bool: return check_just_result(result) @@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) - if m := r_already_int_case.search(case_str): + if r_already_int_case.search(case_str): cases.append(already_int_case) - elif m := r_even_round_halves_case.search(case_str): + elif r_even_round_halves_case.search(case_str): cases.append(even_round_halves_case) + elif m := r_nan_signbit.search(case_str): + signbit = parse_value(m.group(1)) + expected = bool(parse_value(m.group(2))) + cases.append(make_nan_signbit_case(signbit, expected)) elif m := r_unary_case.search(case_str): try: cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))