Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miscellaneous improvements to approx() #3741

Merged
merged 16 commits into from
Aug 2, 2018
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/3473.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Raise immediately if ``approx()`` is given an expected value of a type it doesn't understand (e.g. strings, nested dicts, etc.).
1 change: 1 addition & 0 deletions changelog/3712.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Correctly represent the dimensions of an numpy array when calling ``repr()`` on ``approx()``.
109 changes: 80 additions & 29 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import math
import pprint
import sys
from numbers import Number
from decimal import Decimal

import py
from six.moves import zip, filterfalse
Expand Down Expand Up @@ -30,6 +33,15 @@ def _cmp_raises_type_error(self, other):
)


def _non_numeric_type_error(value, at):
at_str = " at {}".format(at) if at else ""
return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
value, at_str
)
)


# builtin pytest.approx helper


Expand All @@ -39,15 +51,17 @@ class ApproxBase(object):
or sequences of numbers.
"""

# Tell numpy to use our `__eq__` operator instead of its
# Tell numpy to use our `__eq__` operator instead of its.
__array_ufunc__ = None
__array_priority__ = 100

def __init__(self, expected, rel=None, abs=None, nan_ok=False):
__tracebackhide__ = True
self.expected = expected
self.abs = abs
self.rel = rel
self.nan_ok = nan_ok
self._check_type()

def __repr__(self):
raise NotImplementedError
Expand Down Expand Up @@ -75,20 +89,31 @@ def _yield_comparisons(self, actual):
"""
raise NotImplementedError

def _check_type(self):
"""
Raise a TypeError if the expected value is not a valid type.
"""
# This is only a concern if the expected value is a sequence. In every
# other case, the approx() function ensures that the expected value has
# a numeric type. For this reason, the default is to do nothing. The
# classes that deal with sequences should reimplement this method to
# raise if there are any non-numeric elements in the sequence.
pass


class ApproxNumpy(ApproxBase):
"""
Perform approximate comparisons for numpy arrays.
Perform approximate comparisons where the expected value is numpy array.
"""

def __repr__(self):
# It might be nice to rewrite this function to account for the
# shape of the array...
import numpy as np
def recursive_map(f, x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this helper is independent, lets move it out

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

if isinstance(x, list):
return list(recursive_map(f, xi) for xi in x)
else:
return f(x)

list_scalars = []
for x in np.ndindex(self.expected.shape):
list_scalars.append(self._approx_scalar(np.asscalar(self.expected[x])))
list_scalars = recursive_map(self._approx_scalar, self.expected.tolist())

return "approx({!r})".format(list_scalars)

Expand Down Expand Up @@ -128,8 +153,8 @@ def _yield_comparisons(self, actual):

class ApproxMapping(ApproxBase):
"""
Perform approximate comparisons for mappings where the values are numbers
(the keys can be anything).
Perform approximate comparisons where the expected value is a mapping with
numeric values (the keys can be anything).
"""

def __repr__(self):
Expand All @@ -147,10 +172,20 @@ def _yield_comparisons(self, actual):
for k in self.expected.keys():
yield actual[k], self.expected[k]

def _check_type(self):
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lovely 👍

elif not isinstance(value, Number):
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))


class ApproxSequence(ApproxBase):
"""
Perform approximate comparisons for sequences of numbers.
Perform approximate comparisons where the expected value is a sequence of
numbers.
"""

def __repr__(self):
Expand All @@ -169,10 +204,21 @@ def __eq__(self, actual):
def _yield_comparisons(self, actual):
return zip(actual, self.expected)

def _check_type(self):
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
elif not isinstance(x, Number):
raise _non_numeric_type_error(
self.expected, at="index {}".format(index)
)


class ApproxScalar(ApproxBase):
"""
Perform approximate comparisons for single numbers only.
Perform approximate comparisons where the expected value is a single number.
"""

DEFAULT_ABSOLUTE_TOLERANCE = 1e-12
Expand Down Expand Up @@ -286,7 +332,9 @@ def set_default(x, default):


class ApproxDecimal(ApproxScalar):
from decimal import Decimal
"""
Perform approximate comparisons where the expected value is a decimal.
"""

DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
Expand Down Expand Up @@ -445,32 +493,35 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
"""

from decimal import Decimal

# Delegate the comparison to a class that knows how to deal with the type
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
#
# This architecture is really driven by the need to support numpy arrays.
# The only way to override `==` for arrays without requiring that approx be
# the left operand is to inherit the approx object from `numpy.ndarray`.
# But that can't be a general solution, because it requires (1) numpy to be
# installed and (2) the expected value to be a numpy array. So the general
# solution is to delegate each type of expected value to a different class.
# The primary responsibility of these classes is to implement ``__eq__()``
# and ``__repr__()``. The former is used to actually check if some
# "actual" value is equivalent to the given expected value within the
# allowed tolerance. The latter is used to show the user the expected
# value and tolerance, in the case that a test failed.
#
# This has the advantage that it made it easy to support mapping types
# (i.e. dict). The old code accepted mapping types, but would only compare
# their keys, which is probably not what most people would expect.
# The actual logic for making approximate comparisons can be found in
# ApproxScalar, which is used to compare individual numbers. All of the
# other Approx classes eventually delegate to this class. The ApproxBase
# class provides some convenient methods and overloads, but isn't really
# essential.

if _is_numpy_array(expected):
cls = ApproxNumpy
__tracebackhide__ = True

if isinstance(expected, Decimal):
cls = ApproxDecimal
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES):
cls = ApproxSequence
elif isinstance(expected, Decimal):
cls = ApproxDecimal
elif _is_numpy_array(expected):
cls = ApproxNumpy
else:
cls = ApproxScalar
raise _non_numeric_type_error(expected, at=None)

return cls(expected, rel, abs, nan_ok)

Expand Down
31 changes: 21 additions & 10 deletions testing/python/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,21 @@ def test_repr_string(self, plus_minus):
),
)

def test_repr_0d_array(self, plus_minus):
@pytest.mark.parametrize(
"value, repr_string",
[
(5., "approx(5.0 {pm} 5.0e-06)"),
([5.], "approx([5.0 {pm} 5.0e-06])"),
([[5.]], "approx([[5.0 {pm} 5.0e-06]])"),
([[5., 6.]], "approx([[5.0 {pm} 5.0e-06, 6.0 {pm} 6.0e-06]])"),
([[5.], [6.]], "approx([[5.0 {pm} 5.0e-06], [6.0 {pm} 6.0e-06]])"),
],
)
def test_repr_nd_array(self, plus_minus, value, repr_string):
"""Make sure that arrays of all different dimensions are repr'd correctly."""
np = pytest.importorskip("numpy")
np_array = np.array(5.)
assert approx(np_array) == 5.0
string_expected = "approx([5.0 {} 5.0e-06])".format(plus_minus)

assert repr(approx(np_array)) == string_expected

np_array = np.array([5.])
assert approx(np_array) == 5.0
assert repr(approx(np_array)) == string_expected
np_array = np.array(value)
assert repr(approx(np_array)) == repr_string.format(pm=plus_minus)

def test_operator_overloading(self):
assert 1 == approx(1, rel=1e-6, abs=1e-12)
Expand Down Expand Up @@ -439,6 +443,13 @@ def test_foo():
["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="]
)

@pytest.mark.parametrize(
"x", [None, "string", ["string"], [[1]], {"key": "string"}, {"key": {"key": 1}}]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i believe test ids would help here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

)
def test_expected_value_type_error(self, x):
with pytest.raises(TypeError):
approx(x)

@pytest.mark.parametrize(
"op",
[
Expand Down