summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRan Benita <ran234@gmail.com>2019-07-10 20:12:41 +0300
committerRan Benita <ran234@gmail.com>2019-07-14 14:28:24 +0300
commit14bf4cdf44be4d8e2482b1f2b9cafeba06c03550 (patch)
tree07db3ac81a850d5992c9052a332936c5847993a0
parent56dcc9e1f884dc9f5f699c975a303cb0a97ccfa9 (diff)
downloadpytest-14bf4cdf44be4d8e2482b1f2b9cafeba06c03550.tar.gz
Make ExceptionInfo generic in the exception type
This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException.
-rw-r--r--src/_pytest/_code/code.py21
-rw-r--r--src/_pytest/python_api.py33
2 files changed, 33 insertions, 21 deletions
diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py
index d9b06ffd9..203e90287 100644
--- a/src/_pytest/_code/code.py
+++ b/src/_pytest/_code/code.py
@@ -6,9 +6,11 @@ from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import TracebackType
+from typing import Generic
from typing import Optional
from typing import Pattern
from typing import Tuple
+from typing import TypeVar
from typing import Union
from weakref import ref
@@ -379,22 +381,25 @@ co_equal = compile(
)
+_E = TypeVar("_E", bound=BaseException)
+
+
@attr.s(repr=False)
-class ExceptionInfo:
+class ExceptionInfo(Generic[_E]):
""" wraps sys.exc_info() objects and offers
help for navigating the traceback.
"""
_assert_start_repr = "AssertionError('assert "
- _excinfo = attr.ib(
- type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
- )
+ _excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
_striptext = attr.ib(type=str, default="")
_traceback = attr.ib(type=Optional[Traceback], default=None)
@classmethod
- def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
+ def from_current(
+ cls, exprinfo: Optional[str] = None
+ ) -> "ExceptionInfo[BaseException]":
"""returns an ExceptionInfo matching the current traceback
.. warning::
@@ -422,13 +427,13 @@ class ExceptionInfo:
return cls(tup, _striptext)
@classmethod
- def for_later(cls) -> "ExceptionInfo":
+ def for_later(cls) -> "ExceptionInfo[_E]":
"""return an unfilled ExceptionInfo
"""
return cls(None)
@property
- def type(self) -> "Type[BaseException]":
+ def type(self) -> "Type[_E]":
"""the exception class"""
assert (
self._excinfo is not None
@@ -436,7 +441,7 @@ class ExceptionInfo:
return self._excinfo[0]
@property
- def value(self) -> BaseException:
+ def value(self) -> _E:
"""the exception value"""
assert (
self._excinfo is not None
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py
index 9ede24df6..7ca545878 100644
--- a/src/_pytest/python_api.py
+++ b/src/_pytest/python_api.py
@@ -10,10 +10,13 @@ from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
+from typing import TypeVar
from typing import Union
from more_itertools.more import always_iterable
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper
+_E = TypeVar("_E", bound=BaseException)
+
@overload
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
-) -> "RaisesContext":
+) -> "RaisesContext[_E]":
... # pragma: no cover
@overload
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
-) -> Optional[_pytest._code.ExceptionInfo]:
+) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
-) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
+) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
@@ -703,28 +708,30 @@ def raises(
try:
func(*args[1:], **kwargs)
except expected_exception:
- return _pytest._code.ExceptionInfo.from_current()
+ # Cast to narrow the type to expected_exception (_E).
+ return cast(
+ _pytest._code.ExceptionInfo[_E],
+ _pytest._code.ExceptionInfo.from_current(),
+ )
fail(message)
raises.Exception = fail.Exception # type: ignore
-class RaisesContext:
+class RaisesContext(Generic[_E]):
def __init__(
self,
- expected_exception: Union[
- "Type[BaseException]", Tuple["Type[BaseException]", ...]
- ],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
- self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
+ self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
- def __enter__(self) -> _pytest._code.ExceptionInfo:
+ def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo