diff --git a/src/graphql/pyutils/is_nullish.py b/src/graphql/pyutils/is_nullish.py index 3e4f2a0d..6ccd42d9 100644 --- a/src/graphql/pyutils/is_nullish.py +++ b/src/graphql/pyutils/is_nullish.py @@ -1,3 +1,4 @@ +import math from typing import Any from ..error import INVALID @@ -7,4 +8,8 @@ def is_nullish(value: Any) -> bool: """Return true if a value is null, undefined, or NaN.""" - return value is None or value is INVALID or value != value + return ( + value is None + or value is INVALID + or (isinstance(value, float) and math.isnan(value)) + ) diff --git a/tests/pyutils/test_is_nullish.py b/tests/pyutils/test_is_nullish.py index 9e71b00a..7bad7763 100644 --- a/tests/pyutils/test_is_nullish.py +++ b/tests/pyutils/test_is_nullish.py @@ -4,6 +4,22 @@ from graphql.pyutils import is_nullish +class FakeNumpyArray: + def __eq__(self, other): + # Numpy arrays return an array when compared with another numpy array + # containing the pointwise equality of the two + if isinstance(other, FakeNumpyArray): + return FakeNumpyArray() + else: + return False + + def __bool__(self): + raise TypeError( + "The truth value of an array with more than one element is " + "ambiguous. Use a.any() or a.all()" + ) + + def describe_is_nullish(): def null_is_nullish(): assert is_nullish(None) is True @@ -29,3 +45,6 @@ def undefined_is_nullish(): def nan_is_nullish(): assert is_nullish(nan) + + def numpy_arrays_are_not_nullish(): + assert is_nullish(FakeNumpyArray()) is False