diff options
Diffstat (limited to 'src/_pytest')
-rw-r--r-- | src/_pytest/python_api.py | 33 |
1 files changed, 27 insertions, 6 deletions
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index bae207689..81ce4f895 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -15,9 +15,14 @@ from typing import overload from typing import Pattern from typing import Tuple from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +if TYPE_CHECKING: + from numpy import ndarray + + import _pytest._code from _pytest.compat import final from _pytest.compat import STRING_TYPES @@ -232,10 +237,11 @@ class ApproxScalar(ApproxBase): def __eq__(self, actual) -> bool: """Return whether the given value is equal to the expected value within the pre-specified tolerance.""" - if _is_numpy_array(actual): + asarray = _as_numpy_array(actual) + if asarray is not None: # Call ``__eq__()`` manually to prevent infinite-recursion with # numpy<1.13. See #3748. - return all(self.__eq__(a) for a in actual.flat) + return all(self.__eq__(a) for a in asarray.flat) # Short-circuit exact equality. if actual == self.expected: @@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: elif isinstance(expected, Mapping): cls = ApproxMapping elif _is_numpy_array(expected): + expected = _as_numpy_array(expected) cls = ApproxNumpy elif ( isinstance(expected, Iterable) @@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase: def _is_numpy_array(obj: object) -> bool: - """Return true if the given object is a numpy array. + """ + Return true if the given object is implicitly convertible to ndarray, + and numpy is already imported. + """ + return _as_numpy_array(obj) is not None + - A special effort is made to avoid importing numpy unless it's really necessary. +def _as_numpy_array(obj: object) -> Optional["ndarray"]: + """ + Return an ndarray if the given object is implicitly convertible to ndarray, + and numpy is already imported, otherwise None. """ import sys np: Any = sys.modules.get("numpy") if np is not None: - return isinstance(obj, np.ndarray) - return False + # avoid infinite recursion on numpy scalars, which have __array__ + if np.isscalar(obj): + return None + elif isinstance(obj, np.ndarray): + return obj + elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"): + return np.asarray(obj) + return None # builtin pytest.raises helper |