summaryrefslogtreecommitdiff
path: root/src/_pytest
diff options
context:
space:
mode:
Diffstat (limited to 'src/_pytest')
-rw-r--r--src/_pytest/python_api.py33
1 files changed, 27 insertions, 6 deletions
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py
index bae207689..81ce4f895 100644
--- a/src/_pytest/python_api.py
+++ b/src/_pytest/python_api.py
@@ -15,9 +15,14 @@ from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
+if TYPE_CHECKING:
+ from numpy import ndarray
+
+
import _pytest._code
from _pytest.compat import final
from _pytest.compat import STRING_TYPES
@@ -232,10 +237,11 @@ class ApproxScalar(ApproxBase):
def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value
within the pre-specified tolerance."""
- if _is_numpy_array(actual):
+ asarray = _as_numpy_array(actual)
+ if asarray is not None:
# Call ``__eq__()`` manually to prevent infinite-recursion with
# numpy<1.13. See #3748.
- return all(self.__eq__(a) for a in actual.flat)
+ return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality.
if actual == self.expected:
@@ -521,6 +527,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
+ expected = _as_numpy_array(expected)
cls = ApproxNumpy
elif (
isinstance(expected, Iterable)
@@ -536,16 +543,30 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
def _is_numpy_array(obj: object) -> bool:
- """Return true if the given object is a numpy array.
+ """
+ Return true if the given object is implicitly convertible to ndarray,
+ and numpy is already imported.
+ """
+ return _as_numpy_array(obj) is not None
+
- A special effort is made to avoid importing numpy unless it's really necessary.
+def _as_numpy_array(obj: object) -> Optional["ndarray"]:
+ """
+ Return an ndarray if the given object is implicitly convertible to ndarray,
+ and numpy is already imported, otherwise None.
"""
import sys
np: Any = sys.modules.get("numpy")
if np is not None:
- return isinstance(obj, np.ndarray)
- return False
+ # avoid infinite recursion on numpy scalars, which have __array__
+ if np.isscalar(obj):
+ return None
+ elif isinstance(obj, np.ndarray):
+ return obj
+ elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
+ return np.asarray(obj)
+ return None
# builtin pytest.raises helper