diff options
author | Ran Benita <ran@unusedvar.com> | 2020-10-27 16:07:03 +0200 |
---|---|---|
committer | Ran Benita <ran@unusedvar.com> | 2020-10-31 12:40:25 +0200 |
commit | 531416cc5a85e7e90c03ad75962fa5caf92fcf36 (patch) | |
tree | de1378929e4a765f0085d2a299d47830d1270402 /src | |
parent | 6506f016acf77415b7d682bf15cac865ab39273f (diff) | |
download | pytest-531416cc5a85e7e90c03ad75962fa5caf92fcf36.tar.gz |
code: simplify Code construction
Diffstat (limited to 'src')
-rw-r--r-- | src/_pytest/_code/code.py | 14 | ||||
-rw-r--r-- | src/_pytest/_code/source.py | 26 | ||||
-rw-r--r-- | src/_pytest/python.py | 2 |
3 files changed, 22 insertions, 20 deletions
diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py index 430e45242..423069330 100644 --- a/src/_pytest/_code/code.py +++ b/src/_pytest/_code/code.py @@ -56,12 +56,12 @@ class Code: __slots__ = ("raw",) - def __init__(self, rawcode) -> None: - if not hasattr(rawcode, "co_filename"): - rawcode = getrawcode(rawcode) - if not isinstance(rawcode, CodeType): - raise TypeError(f"not a code object: {rawcode!r}") - self.raw = rawcode + def __init__(self, obj: CodeType) -> None: + self.raw = obj + + @classmethod + def from_function(cls, obj: object) -> "Code": + return cls(getrawcode(obj)) def __eq__(self, other): return self.raw == other.raw @@ -1196,7 +1196,7 @@ def getfslineno(obj: object) -> Tuple[Union[str, py.path.local], int]: obj = obj.place_as # type: ignore[attr-defined] try: - code = Code(obj) + code = Code.from_function(obj) except TypeError: try: fn = inspect.getsourcefile(obj) or inspect.getfile(obj) # type: ignore[arg-type] diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index c63a42360..6f54057c0 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -2,6 +2,7 @@ import ast import inspect import textwrap import tokenize +import types import warnings from bisect import bisect_right from typing import Iterable @@ -29,8 +30,11 @@ class Source: elif isinstance(obj, str): self.lines = deindent(obj.split("\n")) else: - rawcode = getrawcode(obj) - src = inspect.getsource(rawcode) + try: + rawcode = getrawcode(obj) + src = inspect.getsource(rawcode) + except TypeError: + src = inspect.getsource(obj) # type: ignore[arg-type] self.lines = deindent(src.split("\n")) def __eq__(self, other: object) -> bool: @@ -122,19 +126,17 @@ def findsource(obj) -> Tuple[Optional[Source], int]: return source, lineno -def getrawcode(obj, trycall: bool = True): +def getrawcode(obj: object, trycall: bool = True) -> types.CodeType: """Return code object for given function.""" try: - return obj.__code__ + return obj.__code__ # type: ignore[attr-defined,no-any-return] except AttributeError: - obj = getattr(obj, "f_code", obj) - obj = getattr(obj, "__code__", obj) - if trycall and not hasattr(obj, "co_firstlineno"): - if hasattr(obj, "__call__") and not inspect.isclass(obj): - x = getrawcode(obj.__call__, trycall=False) - if hasattr(x, "co_firstlineno"): - return x - return obj + pass + if trycall: + call = getattr(obj, "__call__", None) + if call and not isinstance(obj, type): + return getrawcode(call, trycall=False) + raise TypeError(f"could not get code object for {obj!r}") def deindent(lines: Iterable[str]) -> List[str]: diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 35797cc07..e477b8b45 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -1647,7 +1647,7 @@ class Function(PyobjMixin, nodes.Item): def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None: if hasattr(self, "_obj") and not self.config.getoption("fulltrace", False): - code = _pytest._code.Code(get_real_func(self.obj)) + code = _pytest._code.Code.from_function(get_real_func(self.obj)) path, firstlineno = code.path, code.firstlineno traceback = excinfo.traceback ntraceback = traceback.cut(path=path, firstlineno=firstlineno) |