diff options
author | Ran Benita <ran234@gmail.com> | 2019-07-10 20:12:41 +0300 |
---|---|---|
committer | Ran Benita <ran234@gmail.com> | 2019-07-14 14:28:24 +0300 |
commit | 14bf4cdf44be4d8e2482b1f2b9cafeba06c03550 (patch) | |
tree | 07db3ac81a850d5992c9052a332936c5847993a0 /src/_pytest/python_api.py | |
parent | 56dcc9e1f884dc9f5f699c975a303cb0a97ccfa9 (diff) | |
download | pytest-14bf4cdf44be4d8e2482b1f2b9cafeba06c03550.tar.gz |
Make ExceptionInfo generic in the exception type
This way, in
with pytest.raises(ValueError) as cm:
...
cm.value is a ValueError and not a BaseException.
Diffstat (limited to 'src/_pytest/python_api.py')
-rw-r--r-- | src/_pytest/python_api.py | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 9ede24df6..7ca545878 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -10,10 +10,13 @@ from numbers import Number from types import TracebackType from typing import Any from typing import Callable +from typing import cast +from typing import Generic from typing import Optional from typing import overload from typing import Pattern from typing import Tuple +from typing import TypeVar from typing import Union from more_itertools.more import always_iterable @@ -537,33 +540,35 @@ def _is_numpy_array(obj): # builtin pytest.raises helper +_E = TypeVar("_E", bound=BaseException) + @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *, match: Optional[Union[str, Pattern]] = ... -) -> "RaisesContext": +) -> "RaisesContext[_E]": ... # pragma: no cover @overload def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], func: Callable, *args: Any, match: Optional[str] = ..., **kwargs: Any -) -> Optional[_pytest._code.ExceptionInfo]: +) -> Optional[_pytest._code.ExceptionInfo[_E]]: ... # pragma: no cover def raises( - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], *args: Any, match: Optional[Union[str, Pattern]] = None, **kwargs: Any -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: r""" Assert that a code block/function call raises ``expected_exception`` or raise a failure exception otherwise. @@ -703,28 +708,30 @@ def raises( try: func(*args[1:], **kwargs) except expected_exception: - return _pytest._code.ExceptionInfo.from_current() + # Cast to narrow the type to expected_exception (_E). + return cast( + _pytest._code.ExceptionInfo[_E], + _pytest._code.ExceptionInfo.from_current(), + ) fail(message) raises.Exception = fail.Exception # type: ignore -class RaisesContext: +class RaisesContext(Generic[_E]): def __init__( self, - expected_exception: Union[ - "Type[BaseException]", Tuple["Type[BaseException]", ...] - ], + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], message: str, match_expr: Optional[Union[str, Pattern]] = None, ) -> None: self.expected_exception = expected_exception self.message = message self.match_expr = match_expr - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] - def __enter__(self) -> _pytest._code.ExceptionInfo: + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: self.excinfo = _pytest._code.ExceptionInfo.for_later() return self.excinfo |