summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorRan Benita <ran@unusedvar.com>2020-10-27 16:07:03 +0200
committerRan Benita <ran@unusedvar.com>2020-10-31 12:40:25 +0200
commit531416cc5a85e7e90c03ad75962fa5caf92fcf36 (patch)
treede1378929e4a765f0085d2a299d47830d1270402 /src
parent6506f016acf77415b7d682bf15cac865ab39273f (diff)
downloadpytest-531416cc5a85e7e90c03ad75962fa5caf92fcf36.tar.gz
code: simplify Code construction
Diffstat (limited to 'src')
-rw-r--r--src/_pytest/_code/code.py14
-rw-r--r--src/_pytest/_code/source.py26
-rw-r--r--src/_pytest/python.py2
3 files changed, 22 insertions, 20 deletions
diff --git a/src/_pytest/_code/code.py b/src/_pytest/_code/code.py
index 430e45242..423069330 100644
--- a/src/_pytest/_code/code.py
+++ b/src/_pytest/_code/code.py
@@ -56,12 +56,12 @@ class Code:
__slots__ = ("raw",)
- def __init__(self, rawcode) -> None:
- if not hasattr(rawcode, "co_filename"):
- rawcode = getrawcode(rawcode)
- if not isinstance(rawcode, CodeType):
- raise TypeError(f"not a code object: {rawcode!r}")
- self.raw = rawcode
+ def __init__(self, obj: CodeType) -> None:
+ self.raw = obj
+
+ @classmethod
+ def from_function(cls, obj: object) -> "Code":
+ return cls(getrawcode(obj))
def __eq__(self, other):
return self.raw == other.raw
@@ -1196,7 +1196,7 @@ def getfslineno(obj: object) -> Tuple[Union[str, py.path.local], int]:
obj = obj.place_as # type: ignore[attr-defined]
try:
- code = Code(obj)
+ code = Code.from_function(obj)
except TypeError:
try:
fn = inspect.getsourcefile(obj) or inspect.getfile(obj) # type: ignore[arg-type]
diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py
index c63a42360..6f54057c0 100644
--- a/src/_pytest/_code/source.py
+++ b/src/_pytest/_code/source.py
@@ -2,6 +2,7 @@ import ast
import inspect
import textwrap
import tokenize
+import types
import warnings
from bisect import bisect_right
from typing import Iterable
@@ -29,8 +30,11 @@ class Source:
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
else:
- rawcode = getrawcode(obj)
- src = inspect.getsource(rawcode)
+ try:
+ rawcode = getrawcode(obj)
+ src = inspect.getsource(rawcode)
+ except TypeError:
+ src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
def __eq__(self, other: object) -> bool:
@@ -122,19 +126,17 @@ def findsource(obj) -> Tuple[Optional[Source], int]:
return source, lineno
-def getrawcode(obj, trycall: bool = True):
+def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
"""Return code object for given function."""
try:
- return obj.__code__
+ return obj.__code__ # type: ignore[attr-defined,no-any-return]
except AttributeError:
- obj = getattr(obj, "f_code", obj)
- obj = getattr(obj, "__code__", obj)
- if trycall and not hasattr(obj, "co_firstlineno"):
- if hasattr(obj, "__call__") and not inspect.isclass(obj):
- x = getrawcode(obj.__call__, trycall=False)
- if hasattr(x, "co_firstlineno"):
- return x
- return obj
+ pass
+ if trycall:
+ call = getattr(obj, "__call__", None)
+ if call and not isinstance(obj, type):
+ return getrawcode(call, trycall=False)
+ raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> List[str]:
diff --git a/src/_pytest/python.py b/src/_pytest/python.py
index 35797cc07..e477b8b45 100644
--- a/src/_pytest/python.py
+++ b/src/_pytest/python.py
@@ -1647,7 +1647,7 @@ class Function(PyobjMixin, nodes.Item):
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
if hasattr(self, "_obj") and not self.config.getoption("fulltrace", False):
- code = _pytest._code.Code(get_real_func(self.obj))
+ code = _pytest._code.Code.from_function(get_real_func(self.obj))
path, firstlineno = code.path, code.firstlineno
traceback = excinfo.traceback
ntraceback = traceback.cut(path=path, firstlineno=firstlineno)