summaryrefslogtreecommitdiff
path: root/src/_pytest/python_api.py
diff options
context:
space:
mode:
authorRan Benita <ran234@gmail.com>2019-07-10 20:12:41 +0300
committerRan Benita <ran234@gmail.com>2019-07-14 14:28:24 +0300
commit14bf4cdf44be4d8e2482b1f2b9cafeba06c03550 (patch)
tree07db3ac81a850d5992c9052a332936c5847993a0 /src/_pytest/python_api.py
parent56dcc9e1f884dc9f5f699c975a303cb0a97ccfa9 (diff)
downloadpytest-14bf4cdf44be4d8e2482b1f2b9cafeba06c03550.tar.gz
Make ExceptionInfo generic in the exception type
This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException.
Diffstat (limited to 'src/_pytest/python_api.py')
-rw-r--r--src/_pytest/python_api.py33
1 files changed, 20 insertions, 13 deletions
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py
index 9ede24df6..7ca545878 100644
--- a/src/_pytest/python_api.py
+++ b/src/_pytest/python_api.py
@@ -10,10 +10,13 @@ from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
+from typing import TypeVar
from typing import Union
from more_itertools.more import always_iterable
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper
+_E = TypeVar("_E", bound=BaseException)
+
@overload
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
-) -> "RaisesContext":
+) -> "RaisesContext[_E]":
... # pragma: no cover
@overload
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
-) -> Optional[_pytest._code.ExceptionInfo]:
+) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover
def raises(
- expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
-) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
+) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
@@ -703,28 +708,30 @@ def raises(
try:
func(*args[1:], **kwargs)
except expected_exception:
- return _pytest._code.ExceptionInfo.from_current()
+ # Cast to narrow the type to expected_exception (_E).
+ return cast(
+ _pytest._code.ExceptionInfo[_E],
+ _pytest._code.ExceptionInfo.from_current(),
+ )
fail(message)
raises.Exception = fail.Exception # type: ignore
-class RaisesContext:
+class RaisesContext(Generic[_E]):
def __init__(
self,
- expected_exception: Union[
- "Type[BaseException]", Tuple["Type[BaseException]", ...]
- ],
+ expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
- self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
+ self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
- def __enter__(self) -> _pytest._code.ExceptionInfo:
+ def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo