Skip to content

Fix special case testing signbit on NaNs #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dataclasses import dataclass, field
from decimal import ROUND_HALF_EVEN, Decimal
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal
from warnings import warn

import pytest
Expand Down Expand Up @@ -544,6 +544,10 @@ class UnaryCase(Case):
"If two integers are equally close to ``x_i``, "
"the result is the even integer closest to ``x_i``"
)
r_nan_signbit = re.compile(
"If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
"the result is ``(.+)``"
)


def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
Expand Down Expand Up @@ -599,6 +603,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
)


def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase:
if signbit:
sign = -1
nan_expr = "-NaN"
float_arg = "-nan"
else:
sign = 1
nan_expr = "+NaN"
float_arg = "nan"

return UnaryCase(
cond_expr=f"x_i is {nan_expr}",
cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign,
cond_from_dtype=lambda _: st.just(float(float_arg)),
result_expr=str(expected),
check_result=lambda _, result: result == float(expected),
)


def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck:
def check_result(i: float, result: float) -> bool:
return check_just_result(result)
Expand Down Expand Up @@ -655,10 +678,14 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
cases = []
for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)
if m := r_already_int_case.search(case_str):
if r_already_int_case.search(case_str):
cases.append(already_int_case)
elif m := r_even_round_halves_case.search(case_str):
elif r_even_round_halves_case.search(case_str):
cases.append(even_round_halves_case)
elif m := r_nan_signbit.search(case_str):
signbit = parse_value(m.group(1))
expected = bool(parse_value(m.group(2)))
cases.append(make_nan_signbit_case(signbit, expected))
elif m := r_unary_case.search(case_str):
try:
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
Expand Down
Loading