# Source listing: 1108 lines, 42 KiB, Python
"""Rewrite assertion AST to produce nice error messages."""
|
|
import ast
|
|
import errno
|
|
import functools
|
|
import importlib.abc
|
|
import importlib.machinery
|
|
import importlib.util
|
|
import io
|
|
import itertools
|
|
import marshal
|
|
import os
|
|
import struct
|
|
import sys
|
|
import tokenize
|
|
import types
|
|
from pathlib import Path
|
|
from pathlib import PurePath
|
|
from typing import Callable
|
|
from typing import Dict
|
|
from typing import IO
|
|
from typing import Iterable
|
|
from typing import Iterator
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Set
|
|
from typing import Tuple
|
|
from typing import TYPE_CHECKING
|
|
from typing import Union
|
|
|
|
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
|
from _pytest._io.saferepr import saferepr
|
|
from _pytest._version import version
|
|
from _pytest.assertion import util
|
|
from _pytest.assertion.util import ( # noqa: F401
|
|
format_explanation as _format_explanation,
|
|
)
|
|
from _pytest.config import Config
|
|
from _pytest.main import Session
|
|
from _pytest.pathlib import absolutepath
|
|
from _pytest.pathlib import fnmatch_ex
|
|
from _pytest.stash import StashKey
|
|
|
|
if TYPE_CHECKING:
|
|
from _pytest.assertion import AssertionState
|
|
|
|
|
|
# Stash key under which the assertion state object is stored on the config.
assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
# The tag combines the interpreter's cache tag with the pytest version.
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
# ".pyc" normally, ".pyo" when ``__debug__`` is false.
PYC_EXT = ".py" + ("c" if __debug__ else "o")
# Suffix appended to a module's base name to form its cache file name.
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
|
|
|
|
|
|
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts."""

    def __init__(self, config: Config) -> None:
        self.config = config
        # Use the stock test-file globs when the ini option cannot be read.
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session: Optional[Session] = None
        # module name -> source path of every module executed by this loader
        self._rewritten_names: Dict[str, Path] = {}
        # import names explicitly registered through mark_rewrite()
        self._must_rewrite: Set[str] = set()
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        self._basenames_to_check_rewrite = {"conftest"}
        self._marked_for_rewrite_cache: Dict[str, bool] = {}
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Attach (or detach) the session and forget the initial-path scan."""
        self._session_paths_checked = False
        self.session = session

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec using this hook as loader, or None to decline the import."""
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore

        # the import machinery could not find a file to import
        if spec is None:
            return None
        # this is a namespace package (without `__init__.py`)
        # there's nothing to rewrite there
        if spec.origin is None:
            return None
        # we can only rewrite source files
        if not isinstance(spec.loader, importlib.machinery.SourceFileLoader):
            return None
        # if the file doesn't exist, we can't rewrite it
        if not os.path.exists(spec.origin):
            return None

        origin = spec.origin
        if not self._should_rewrite(name, origin, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            origin,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        """Defer to the default module creation."""
        return None  # default behaviour is fine

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute *module* from rewritten code, using the pyc cache when possible."""
        assert module.__spec__ is not None
        assert module.__spec__.origin is not None
        source_path = Path(module.__spec__.origin)
        state = self.config.stash[assertstate_key]

        self._rewritten_names[module.__name__] = source_path

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        may_write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(source_path)
        if may_write and not try_makedirs(cache_dir):
            may_write = False
            state.trace(f"read only directory: {cache_dir}")

        pyc = cache_dir / (source_path.name[:-3] + PYC_TAIL)
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        code = _read_pyc(source_path, pyc, state.trace)
        if code is not None:
            state.trace(f"found cached rewritten pyc for {source_path}")
        else:
            state.trace(f"rewriting {source_path!r}")
            source_stat, code = _rewrite_test(source_path, self.config)
            if may_write:
                self._writing_pyc = True
                try:
                    _write_pyc(state, code, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(code, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Keep only the last path component, e.g.
                # c:/projects/my_project/path.py -> 'path.py', and register
                # its extension-less form ('path') as a basename to check.
                last_component = str(initial_path).split(os.path.sep)[-1]
                basename = os.path.splitext(last_component)[0]
                self._basenames_to_check_rewrite.add(basename)

        # Note: conftest already by default in _basenames_to_check_rewrite.
        name_parts = name.split(".")
        if name_parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        as_path = PurePath(*name_parts).with_suffix(".py")

        for pattern in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pattern) or fnmatch_ex(pattern, as_path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Decide whether the module *name* located at *fn* gets rewritten."""
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        if self.session is not None and self.session.isinitpath(absolutepath(fn)):
            state.trace(f"matched test file (was specified on cmdline): {fn!r}")
            return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        candidate = PurePath(fn)
        for pattern in self.fnpats:
            if fnmatch_ex(pattern, candidate):
                state.trace(f"matched test file {fn!r}")
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return whether *name* is, or lives under, a name given to mark_rewrite()."""
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(f"matched marked file {name!r} (from {marked!r})")
                self._marked_for_rewrite_cache[name] = True
                return True

        self._marked_for_rewrite_cache[name] = False
        return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            # Modules that opted out, or that this hook loaded itself, need no warning.
            if AssertionRewriter.is_rewrite_disabled(mod.__doc__ or ""):
                continue
            if isinstance(mod.__loader__, type(self)):
                continue
            self._warn_already_imported(name)
        self._must_rewrite.update(names)
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a config-time warning that *name* was imported before being marked."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        warning = PytestAssertRewriteWarning(
            "Module already imported so cannot be rewritten: %s" % name
        )
        self.config.issue_config_time_warning(warning, stacklevel=5)

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as stream:
            content = stream.read()
        return content

    if sys.version_info >= (3, 10):

        def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources:  # type: ignore
            """Return a FileReader rooted at the recorded source path of *name*."""
            if sys.version_info >= (3, 11):
                from importlib.resources.readers import FileReader
            else:
                from importlib.readers import FileReader

            return FileReader(  # type:ignore[no-any-return]
                types.SimpleNamespace(path=self._rewritten_names[name])
            )
|
|
|
|
|
|
def _write_pyc_fp(
|
|
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
|
|
) -> None:
|
|
# Technically, we don't have to have the same pyc format as
|
|
# (C)Python, since these "pycs" should never be seen by builtin
|
|
# import. However, there's little reason to deviate.
|
|
fp.write(importlib.util.MAGIC_NUMBER)
|
|
# https://www.python.org/dev/peps/pep-0552/
|
|
flags = b"\x00\x00\x00\x00"
|
|
fp.write(flags)
|
|
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
|
|
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
|
|
size = source_stat.st_size & 0xFFFFFFFF
|
|
# "<LL" stands for 2 unsigned longs, little-endian.
|
|
fp.write(struct.pack("<LL", mtime, size))
|
|
fp.write(marshal.dumps(co))
|
|
|
|
|
|
def _write_pyc(
|
|
state: "AssertionState",
|
|
co: types.CodeType,
|
|
source_stat: os.stat_result,
|
|
pyc: Path,
|
|
) -> bool:
|
|
proc_pyc = f"{pyc}.{os.getpid()}"
|
|
try:
|
|
with open(proc_pyc, "wb") as fp:
|
|
_write_pyc_fp(fp, source_stat, co)
|
|
except OSError as e:
|
|
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
|
|
return False
|
|
|
|
try:
|
|
os.replace(proc_pyc, pyc)
|
|
except OSError as e:
|
|
state.trace(f"error writing pyc file at {pyc}: {e}")
|
|
# we ignore any failure to write the cache file
|
|
# there are many reasons, permission-denied, pycache dir being a
|
|
# file etc.
|
|
return False
|
|
return True
|
|
|
|
|
|
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    The stat result is taken before the file is read and returned alongside
    the compiled code.
    """
    source_stat = os.stat(fn)
    raw_source = fn.read_bytes()
    filename = str(fn)
    module_ast = ast.parse(raw_source, filename=filename)
    rewrite_asserts(module_ast, raw_source, filename, config)
    code = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, code
|
|
|
|
|
|
def _read_pyc(
|
|
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
|
|
) -> Optional[types.CodeType]:
|
|
"""Possibly read a pytest pyc containing rewritten code.
|
|
|
|
Return rewritten code if successful or None if not.
|
|
"""
|
|
try:
|
|
fp = open(pyc, "rb")
|
|
except OSError:
|
|
return None
|
|
with fp:
|
|
try:
|
|
stat_result = os.stat(source)
|
|
mtime = int(stat_result.st_mtime)
|
|
size = stat_result.st_size
|
|
data = fp.read(16)
|
|
except OSError as e:
|
|
trace(f"_read_pyc({source}): OSError {e}")
|
|
return None
|
|
# Check for invalid or out of date pyc file.
|
|
if len(data) != (16):
|
|
trace("_read_pyc(%s): invalid pyc (too short)" % source)
|
|
return None
|
|
if data[:4] != importlib.util.MAGIC_NUMBER:
|
|
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
|
|
return None
|
|
if data[4:8] != b"\x00\x00\x00\x00":
|
|
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
|
|
return None
|
|
mtime_data = data[8:12]
|
|
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
|
|
trace("_read_pyc(%s): out of date" % source)
|
|
return None
|
|
size_data = data[12:16]
|
|
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
|
|
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
|
|
return None
|
|
try:
|
|
co = marshal.load(fp)
|
|
except Exception as e:
|
|
trace(f"_read_pyc({source}): marshal.load error {e}")
|
|
return None
|
|
if not isinstance(co, types.CodeType):
|
|
trace("_read_pyc(%s): not a code object" % source)
|
|
return None
|
|
return co
|
|
|
|
|
|
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in mod.

    The module node is modified in place by an ``AssertionRewriter`` built
    from *module_path*, *config* and *source*.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
|
|
|
|
|
|
def _saferepr(obj: object) -> str:
    r"""Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially '\n{' and '\n}' are likely to be present in
    JSON reprs.
    """
    size_limit = _get_maxsize_for_saferepr(util._config)
    text = saferepr(obj, maxsize=size_limit)
    # Turn real newlines into the two-character sequence backslash + "n".
    return text.replace("\n", "\\n")
|
|
|
|
|
|
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
    """Get `maxsize` configuration for saferepr based on the given config object.

    Without a config the verbosity counts as 0. Verbosity 2 or higher
    returns None, verbosity 1 returns ten times the default size, and
    anything lower returns the default size.
    """
    if config is None:
        verbosity = 0
    else:
        verbosity = config.getoption("verbose")

    if verbosity >= 2:
        return None
    return DEFAULT_REPR_MAX_SIZE * 10 if verbosity >= 1 else DEFAULT_REPR_MAX_SIZE
|
|
|
|
|
|
def _format_assertmsg(obj: object) -> str:
|
|
r"""Format the custom assertion message given.
|
|
|
|
For strings this simply replaces newlines with '\n~' so that
|
|
util.format_explanation() will preserve them instead of escaping
|
|
newlines. For other objects saferepr() is used first.
|
|
"""
|
|
# reprlib appears to have a bug which means that if a string
|
|
# contains a newline it gets escaped, however if an object has a
|
|
# .__repr__() which contains newlines it does not get escaped.
|
|
# However in either case we want to preserve the newline.
|
|
replaces = [("\n", "\n~"), ("%", "%%")]
|
|
if not isinstance(obj, str):
|
|
obj = saferepr(obj)
|
|
replaces.append(("\\n", "\n~"))
|
|
|
|
for r1, r2 in replaces:
|
|
obj = obj.replace(r1, r2)
|
|
|
|
return obj
|
|
|
|
|
|
def _should_repr_global_name(obj: object) -> bool:
|
|
if callable(obj):
|
|
return False
|
|
|
|
try:
|
|
return not hasattr(obj, "__name__")
|
|
except Exception:
|
|
return True
|
|
|
|
|
|
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
|
|
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
|
|
return explanation.replace("%", "%%")
|
|
|
|
|
|
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the first comparison that did not hold.

    The comparisons are scanned in order until one has a falsy result, or a
    result whose truth value cannot be evaluated. If ``util._reprcompare``
    is set and returns something other than None for that comparison, that
    value is returned; otherwise the matching entry of *expls* is.
    """
    for i, (_, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # A result whose truth value raises counts as the failing one.
            failed = True
        if failed:
            break
    # When every comparison held, ``i``/``expl`` refer to the last one.
    if util._reprcompare is not None:
        custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
|
|
|
|
|
|
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passing assertion to ``util._assertion_pass`` when it is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)
|
|
|
|
|
|
def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive)."""
    return bool(util._assertion_pass)
|
|
|
|
|
|
# %-format templates for unary operators; "%s" receives the operand text.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source spelling of each binary and comparison operator node.
BINOP_MAP = {
    # bitwise and shift operators
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # arithmetic operators
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    # comparisons
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.Pow: "**",
    # identity and membership tests
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.MatMult: "@",
}
|
|
|
|
|
|
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Yield node and all its children in depth-first (pre-order) order.

    An explicit stack is used instead of recursion, so arbitrarily deep
    trees do not hit the interpreter's recursion limit and each node is
    produced without passing through a chain of nested generators.
    """
    pending: List[ast.AST] = [node]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so the first child is popped (visited) first.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
|
|
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
|
|
"""Return a mapping from {lineno: "assertion test expression"}."""
|
|
ret: Dict[int, str] = {}
|
|
|
|
depth = 0
|
|
lines: List[str] = []
|
|
assert_lineno: Optional[int] = None
|
|
seen_lines: Set[int] = set()
|
|
|
|
def _write_and_reset() -> None:
|
|
nonlocal depth, lines, assert_lineno, seen_lines
|
|
assert assert_lineno is not None
|
|
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
|
|
depth = 0
|
|
lines = []
|
|
assert_lineno = None
|
|
seen_lines = set()
|
|
|
|
tokens = tokenize.tokenize(io.BytesIO(src).readline)
|
|
for tp, source, (lineno, offset), _, line in tokens:
|
|
if tp == tokenize.NAME and source == "assert":
|
|
assert_lineno = lineno
|
|
elif assert_lineno is not None:
|
|
# keep track of depth for the assert-message `,` lookup
|
|
if tp == tokenize.OP and source in "([{":
|
|
depth += 1
|
|
elif tp == tokenize.OP and source in ")]}":
|
|
depth -= 1
|
|
|
|
if not lines:
|
|
lines.append(line[offset:])
|
|
seen_lines.add(lineno)
|
|
# a non-nested comma separates the expression from the message
|
|
elif depth == 0 and tp == tokenize.OP and source == ",":
|
|
# one line assert with message
|
|
if lineno in seen_lines and len(lines) == 1:
|
|
offset_in_trimmed = offset + len(lines[-1]) - len(line)
|
|
lines[-1] = lines[-1][:offset_in_trimmed]
|
|
# multi-line assert with message
|
|
elif lineno in seen_lines:
|
|
lines[-1] = lines[-1][:offset]
|
|
# multi line assert with escapd newline before message
|
|
else:
|
|
lines.append(line[:offset])
|
|
_write_and_reset()
|
|
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
|
|
_write_and_reset()
|
|
elif lines and lineno not in seen_lines:
|
|
lines.append(line)
|
|
seen_lines.add(lineno)
|
|
|
|
return ret
|
|
|
|
|
|
class AssertionRewriter(ast.NodeVisitor):
|
|
"""Assertion rewriting implementation.
|
|
|
|
The main entrypoint is to call .run() with an ast.Module instance,
|
|
this will then find all the assert statements and rewrite them to
|
|
provide intermediate values and a detailed assertion error. See
|
|
http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
|
|
for an overview of how this works.
|
|
|
|
The entry point here is .run() which will iterate over all the
|
|
statements in an ast.Module and for each ast.Assert statement it
|
|
finds call .visit() with it. Then .visit_Assert() takes over and
|
|
is responsible for creating new ast statements to replace the
|
|
original assert statement: it rewrites the test of an assertion
|
|
to provide intermediate values and replace it with an if statement
|
|
which raises an assertion error with a detailed explanation in
|
|
case the expression is false and calls pytest_assertion_pass hook
|
|
if expression is true.
|
|
|
|
For this .visit_Assert() uses the visitor pattern to visit all the
|
|
AST nodes of the ast.Assert.test field, each visit call returning
|
|
an AST node and the corresponding explanation string. During this
|
|
state is kept in several instance attributes:
|
|
|
|
:statements: All the AST statements which will replace the assert
|
|
statement.
|
|
|
|
:variables: This is populated by .variable() with each variable
|
|
used by the statements so that they can all be set to None at
|
|
the end of the statements.
|
|
|
|
:variable_counter: Counter to create new unique variables needed
|
|
by statements. Variables are created using .variable() and
|
|
have the form of "@py_assert0".
|
|
|
|
:expl_stmts: The AST statements which will be executed to get
|
|
data from the assertion. This is the code which will construct
|
|
the detailed assertion message that is used in the AssertionError
|
|
or for the pytest_assertion_pass hook.
|
|
|
|
:explanation_specifiers: A dict filled by .explanation_param()
|
|
with %-formatting placeholders and their corresponding
|
|
expressions to use in the building of an assertion message.
|
|
This is used by .pop_format_context() to build a message.
|
|
|
|
:stack: A stack of the explanation_specifiers dicts maintained by
|
|
.push_format_context() and .pop_format_context() which allows
|
|
to build another %-formatted string while already building one.
|
|
|
|
This state is reset on every new assert statement visited and used
|
|
by the other visitors.
|
|
"""
|
|
|
|
def __init__(
    self, module_path: Optional[str], config: Optional[Config], source: bytes
) -> None:
    """Store the rewrite inputs and read the assertion-pass-hook setting.

    ``enable_assertion_pass_hook`` comes from the ini option of the same
    name, or is False when no *config* is given.
    """
    super().__init__()
    self.module_path = module_path
    self.config = config
    self.enable_assertion_pass_hook = (
        config.getini("enable_assertion_pass_hook") if config is not None else False
    )
    self.source = source
|
|
|
|
def run(self, mod: ast.Module) -> None:
    """Find all assert statements in *mod* and rewrite them.

    *mod* is modified in place. Nothing happens when the module body is
    empty or when its docstring contains the marker checked by
    ``is_rewrite_disabled``. Otherwise two helper imports are inserted
    after the leading docstring and ``__future__`` imports, and every
    ``ast.Assert`` found in a statement list is replaced by the statements
    returned from ``self.visit``.
    """
    if not mod.body:
        # Nothing to do.
        return

    # We'll insert some special imports at the top of the module, but after any
    # docstrings and __future__ imports, so first figure out where that is.
    # Use a `docstring` attribute on the module node when one is present;
    # otherwise look for a leading string expression in the body below.
    doc = getattr(mod, "docstring", None)
    expect_docstring = doc is None
    if doc is not None and self.is_rewrite_disabled(doc):
        return
    # `pos` counts the leading docstring/__future__ items, i.e. the index at
    # which the helper imports are inserted.
    pos = 0
    lineno = 1
    for item in mod.body:
        if (
            expect_docstring
            and isinstance(item, ast.Expr)
            and isinstance(item.value, ast.Str)
        ):
            doc = item.value.s
            if self.is_rewrite_disabled(doc):
                return
            # Only the first string expression counts as the docstring.
            expect_docstring = False
        elif (
            isinstance(item, ast.ImportFrom)
            and item.level == 0
            and item.module == "__future__"
        ):
            pass
        else:
            break
        pos += 1
    # Special case: for a decorated function, set the lineno to that of the
    # first decorator, not the `def`. Issue #4984.
    if isinstance(item, ast.FunctionDef) and item.decorator_list:
        lineno = item.decorator_list[0].lineno
    else:
        lineno = item.lineno
    # Now actually insert the special imports.
    # On Python 3.10+ the alias nodes are created with explicit positions.
    if sys.version_info >= (3, 10):
        aliases = [
            ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
            ast.alias(
                "_pytest.assertion.rewrite",
                "@pytest_ar",
                lineno=lineno,
                col_offset=0,
            ),
        ]
    else:
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
    # One `import ... as ...` statement per alias.
    imports = [
        ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
    ]
    mod.body[pos:pos] = imports

    # Collect asserts.
    # `nodes` is a worklist of AST nodes whose fields still need scanning.
    nodes: List[ast.AST] = [mod]
    while nodes:
        node = nodes.pop()
        for name, field in ast.iter_fields(node):
            if isinstance(field, list):
                # Rebuild the list, expanding each assert into its replacement.
                new: List[ast.AST] = []
                for i, child in enumerate(field):
                    if isinstance(child, ast.Assert):
                        # Transform assert.
                        new.extend(self.visit(child))
                    else:
                        new.append(child)
                        if isinstance(child, ast.AST):
                            nodes.append(child)
                setattr(node, name, new)
            elif (
                isinstance(field, ast.AST)
                # Don't recurse into expressions as they can't contain
                # asserts.
                and not isinstance(field, ast.expr)
            ):
                nodes.append(field)
|
|
|
|
@staticmethod
|
|
def is_rewrite_disabled(docstring: str) -> bool:
|
|
return "PYTEST_DONT_REWRITE" in docstring
|
|
|
|
def variable(self) -> str:
|
|
"""Get a new variable."""
|
|
# Use a character invalid in python identifiers to avoid clashing.
|
|
name = "@py_assert" + str(next(self.variable_counter))
|
|
self.variables.append(name)
|
|
return name
|
|
|
|
def assign(self, expr: ast.expr) -> ast.Name:
|
|
"""Give *expr* a name."""
|
|
name = self.variable()
|
|
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
|
|
return ast.Name(name, ast.Load())
|
|
|
|
def display(self, expr: ast.expr) -> ast.expr:
|
|
"""Call saferepr on the expression."""
|
|
return self.helper("_saferepr", expr)
|
|
|
|
def helper(self, name: str, *args: ast.expr) -> ast.expr:
    """Call a helper in this module.

    Builds the call ``@pytest_ar.<name>(*args)`` as an AST node.
    """
    module_ref = ast.Name("@pytest_ar", ast.Load())
    func = ast.Attribute(module_ref, name, ast.Load())
    positional = list(args)
    return ast.Call(func, positional, [])
|
|
|
|
def builtin(self, name: str) -> ast.Attribute:
    """Return the builtin called *name*.

    The node is an attribute lookup of *name* on ``@py_builtins``.
    """
    namespace = ast.Name("@py_builtins", ast.Load())
    lookup = ast.Attribute(namespace, name, ast.Load())
    return lookup
|
|
|
|
def explanation_param(self, expr: ast.expr) -> str:
|
|
"""Return a new named %-formatting placeholder for expr.
|
|
|
|
This creates a %-formatting placeholder for expr in the
|
|
current formatting context, e.g. ``%(py0)s``. The placeholder
|
|
and expr are placed in the current format context so that it
|
|
can be used on the next call to .pop_format_context().
|
|
"""
|
|
specifier = "py" + str(next(self.variable_counter))
|
|
self.explanation_specifiers[specifier] = expr
|
|
return "%(" + specifier + ")s"
|
|
|
|
def push_format_context(self) -> None:
|
|
"""Create a new formatting context.
|
|
|
|
The format context is used for when an explanation wants to
|
|
have a variable value formatted in the assertion message. In
|
|
this case the value required can be added using
|
|
.explanation_param(). Finally .pop_format_context() is used
|
|
to format a string of %-formatted values as added by
|
|
.explanation_param().
|
|
"""
|
|
self.explanation_specifiers: Dict[str, ast.expr] = {}
|
|
self.stack.append(self.explanation_specifiers)
|
|
|
|
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
|
|
"""Format the %-formatted string with current format context.
|
|
|
|
The expl_expr should be an str ast.expr instance constructed from
|
|
the %-placeholders created by .explanation_param(). This will
|
|
add the required code to format said string to .expl_stmts and
|
|
return the ast.Name instance of the formatted string.
|
|
"""
|
|
current = self.stack.pop()
|
|
if self.stack:
|
|
self.explanation_specifiers = self.stack[-1]
|
|
keys = [ast.Str(key) for key in current.keys()]
|
|
format_dict = ast.Dict(keys, list(current.values()))
|
|
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
|
|
name = "@py_format" + str(next(self.variable_counter))
|
|
if self.enable_assertion_pass_hook:
|
|
self.format_variables.append(name)
|
|
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
|
|
return ast.Name(name, ast.Load())
|
|
|
|
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
|
|
"""Handle expressions we don't have custom code for."""
|
|
assert isinstance(node, ast.expr)
|
|
res = self.assign(node)
|
|
return res, self.explanation_param(self.display(res))
|
|
|
|
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
    """Return the AST statements to replace the ast.Assert instance.

    This rewrites the test of an assertion to provide
    intermediate values and replace it with an if statement which
    raises an assertion error with a detailed explanation in case
    the expression is false.

    When ``enable_assertion_pass_hook`` is set, the generated ``if`` also
    gets an ``else`` branch that calls the ``_call_assertion_pass`` helper.
    A warning is emitted when the test is a non-empty tuple.
    """
    if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
        from _pytest.warning_types import PytestAssertRewriteWarning
        import warnings

        # TODO: This assert should not be needed.
        assert self.module_path is not None
        warnings.warn_explicit(
            PytestAssertRewriteWarning(
                "assertion is always true, perhaps remove parentheses?"
            ),
            category=None,
            filename=self.module_path,
            lineno=assert_.lineno,
        )

    # Fresh per-assert state used by the other visitor methods.
    self.statements: List[ast.stmt] = []
    self.variables: List[str] = []
    self.variable_counter = itertools.count()

    if self.enable_assertion_pass_hook:
        self.format_variables: List[str] = []

    self.stack: List[Dict[str, ast.expr]] = []
    self.expl_stmts: List[ast.stmt] = []
    self.push_format_context()
    # Rewrite assert into a bunch of statements.
    top_condition, explanation = self.visit(assert_.test)

    # The generated `if` runs its body when the assertion is false.
    negation = ast.UnaryOp(ast.Not(), top_condition)

    if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
        msg = self.pop_format_context(ast.Str(explanation))

        # Failed
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            gluestr = "\n>assert "
        else:
            assertmsg = ast.Str("")
            gluestr = "assert "
        # Message expression: assertmsg + (gluestr + msg)
        err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
        err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
        err_name = ast.Name("AssertionError", ast.Load())
        fmt = self.helper("_format_explanation", err_msg)
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)
        statements_fail = []
        statements_fail.extend(self.expl_stmts)
        statements_fail.append(raise_)

        # Passed
        fmt_pass = self.helper("_format_explanation", msg)
        # Source text of the asserted expression, looked up by line number.
        orig = _get_assertion_exprs(self.source)[assert_.lineno]
        hook_call_pass = ast.Expr(
            self.helper(
                "_call_assertion_pass",
                ast.Num(assert_.lineno),
                ast.Str(orig),
                fmt_pass,
            )
        )
        # If any hooks implement assert_pass hook
        hook_impl_test = ast.If(
            self.helper("_check_if_assertion_pass_impl"),
            self.expl_stmts + [hook_call_pass],
            [],
        )
        statements_pass = [hook_impl_test]

        # Test for assertion condition
        main_test = ast.If(negation, statements_fail, statements_pass)
        self.statements.append(main_test)
        # Set the "@py_format" variables to None after the test.
        if self.format_variables:
            variables = [
                ast.Name(name, ast.Store()) for name in self.format_variables
            ]
            clear_format = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear_format)

    else:  # Original assertion rewriting
        # Create failure message.
        # `body` is the same list object as self.expl_stmts, so the
        # statements appended below end up inside the `if` created here.
        body = self.expl_stmts
        self.statements.append(ast.If(negation, body, []))
        if assert_.msg:
            assertmsg = self.helper("_format_assertmsg", assert_.msg)
            explanation = "\n>assert " + explanation
        else:
            assertmsg = ast.Str("")
            explanation = "assert " + explanation
        template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
        msg = self.pop_format_context(template)
        fmt = self.helper("_format_explanation", msg)
        err_name = ast.Name("AssertionError", ast.Load())
        exc = ast.Call(err_name, [fmt], [])
        raise_ = ast.Raise(exc, None)

        body.append(raise_)

    # Clear temporary variables by setting them to None.
    if self.variables:
        variables = [ast.Name(name, ast.Store()) for name in self.variables]
        clear = ast.Assign(variables, ast.NameConstant(None))
        self.statements.append(clear)
    # Fix locations (line numbers/column offsets).
    # Every generated node takes its location from the original assert.
    for stmt in self.statements:
        for node in traverse_node(stmt):
            ast.copy_location(node, assert_)
    return self.statements
|
|
|
|
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
    """Rewrite a bare name, returning the node unchanged plus its explanation.

    The explanation expression built here picks, at runtime, between
    ``self.display(name)`` and the plain identifier text: the displayed form
    is used when the identifier is found in ``locals()`` or when the
    ``_should_repr_global_name`` helper returns a true value.

    :param name: The name node being visited.
    :returns:
        The original ``name`` node and the string produced by
        ``self.explanation_param`` for the explanation expression.
    """
    # ``ast.Constant`` is used instead of the deprecated ``ast.Str`` alias,
    # which emits DeprecationWarning on Python 3.12+ and was later removed.
    locs = ast.Call(self.builtin("locals"), [], [])
    inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
    dorepr = self.helper("_should_repr_global_name", name)
    test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
    expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
    return name, self.explanation_param(expr)
|
|
|
|
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` expression while keeping short-circuiting.

    Every operand after the first is evaluated inside an ``if`` nested in
    the previous operand's guard, so it only runs when the earlier operands
    did not already decide the result. The same nesting is mirrored in
    ``self.expl_stmts`` so that only explanations of evaluated operands are
    appended to the runtime explanation list.

    :param boolop: The boolean operation node being visited.
    :returns:
        A load of the variable holding the overall result, and the string
        produced by ``self.explanation_param`` for the ``_format_boolop``
        helper call.
    """
    res_var = self.variable()
    # Runtime list collecting one explanation per evaluated operand.
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    is_or = int(isinstance(boolop.op, ast.Or))
    # Remember the outer statement lists; they are swapped for nested ones
    # while operands are processed and restored afterwards.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    # Process each operand, short-circuiting if needed.
    for i, v in enumerate(boolop.values):
        if i:
            fail_inner: List[ast.stmt] = []
            # cond is set in a prior loop iteration below
            self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
            self.expl_stmts = fail_inner
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        # ``ast.Constant`` replaces the deprecated ``ast.Str`` alias.
        expl_format = self.pop_format_context(ast.Constant(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # Continue to the next operand only while the result is still
            # undecided: truthy so far for ``and``, falsy so far for ``or``.
            cond: ast.expr = res
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner: List[ast.stmt] = []
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    self.statements = save
    self.expl_stmts = fail_save
    # ``ast.Constant`` replaces the deprecated ``ast.Num`` alias.
    expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
|
|
|
|
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
    """Rewrite a unary operation applied to an already-rewritten operand.

    :param unary: The unary operation node being visited.
    :returns:
        The result of ``self.assign`` for the rebuilt operation, and the
        operand's explanation substituted into the ``UNARY_MAP`` pattern
        for this operator.
    """
    # Look the pattern up first so an unknown operator fails before any
    # operand statements are emitted.
    template = UNARY_MAP[type(unary.op)]
    inner_node, inner_text = self.visit(unary.operand)
    rebuilt = ast.UnaryOp(unary.op, inner_node)
    result = self.assign(rebuilt)
    return result, template % (inner_text,)
|
|
|
|
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
    """Rewrite a binary operation from its rewritten operands.

    :param binop: The binary operation node being visited.
    :returns:
        The result of ``self.assign`` for the rebuilt operation, and a
        parenthesised ``(left symbol right)`` explanation string.
    """
    # Look the symbol up first so an unknown operator fails before any
    # operand statements are emitted.
    operator_text = BINOP_MAP[type(binop.op)]
    lhs, lhs_text = self.visit(binop.left)
    rhs, rhs_text = self.visit(binop.right)
    description = "({} {} {})".format(lhs_text, operator_text, rhs_text)
    rebuilt = ast.BinOp(lhs, binop.op, rhs)
    return self.assign(rebuilt), description
|
|
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
    """Rewrite a call, visiting the callee and every argument in order.

    :param call: The call node being visited.
    :returns:
        The result of ``self.assign`` for the rebuilt call, and an
        explanation showing the displayed result followed by a braced
        ``result = func(args...)`` line.
    """
    callee, callee_text = self.visit(call.func)

    # Positional arguments: visit in source order, then split the pairs.
    visited_args = [self.visit(argument) for argument in call.args]
    positional = [node for node, _ in visited_args]
    pieces = [text for _, text in visited_args]

    named = []
    for kw in call.keywords:
        value_node, value_text = self.visit(kw.value)
        named.append(ast.keyword(kw.arg, value_node))
        # **args have `arg` keywords with an .arg of None
        if kw.arg:
            pieces.append(kw.arg + "=" + value_text)
        else:
            pieces.append("**" + value_text)

    call_text = "{}({})".format(callee_text, ", ".join(pieces))
    result = self.assign(ast.Call(callee, positional, named))
    shown = self.explanation_param(self.display(result))
    return result, f"{shown}\n{{{shown} = {call_text}\n}}"
|
|
|
|
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
    """Rewrite a starred expression, such as ``*args`` in a function call.

    :param starred: The starred node being visited.
    :returns:
        A new ``ast.Starred`` wrapping the rewritten value with the original
        context, and the value's explanation prefixed with ``*``.
    """
    inner_node, inner_text = self.visit(starred.value)
    return ast.Starred(inner_node, starred.ctx), "*" + inner_text
|
|
|
|
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
    """Rewrite an attribute read; other contexts go to ``generic_visit``.

    :param attr: The attribute node being visited.
    :returns:
        For a load context, the result of ``self.assign`` for the rebuilt
        attribute access and an explanation showing the displayed result
        followed by a braced ``result = value.attr`` line. Otherwise,
        whatever ``self.generic_visit`` returns.
    """
    if isinstance(attr.ctx, ast.Load):
        owner, owner_text = self.visit(attr.value)
        access = ast.Attribute(owner, attr.attr, ast.Load())
        result = self.assign(access)
        shown = self.explanation_param(self.display(result))
        text = f"{shown}\n{{{shown} = {owner_text}.{attr.attr}\n}}"
        return result, text
    return self.generic_visit(attr)
|
|
|
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison into pairwise comparisons.

    Each ``left op right`` link of the chain is evaluated separately and
    stored in its own result variable. The symbols, per-link results,
    per-link explanations and operand values are handed to the
    ``_call_reprcompare`` helper to build the explanation.

    :param comp: The comparison node being visited.
    :returns:
        The single result name for a one-operator comparison, or an ``and``
        of all per-link result names for a chain, together with the string
        produced by ``self.explanation_param`` for the helper call.
    """
    self.push_format_context()
    left_res, left_expl = self.visit(comp.left)
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"
    # One result variable per operator in the chain.
    res_variables = [self.variable() for _ in comp.ops]
    load_names: List[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    expls = []
    syms = []
    results = [left_res]
    for store_name, op, next_operand in zip(store_names, comp.ops, comp.comparators):
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[op.__class__]
        # ``ast.Constant`` replaces the deprecated ``ast.Str`` alias.
        syms.append(ast.Constant(sym))
        expl = f"{left_expl} {sym} {next_expl}"
        expls.append(ast.Constant(expl))
        res_expr = ast.Compare(left_res, [op], [next_res])
        self.statements.append(ast.Assign([store_name], res_expr))
        # The right operand of this link is the left operand of the next.
        left_res, left_expl = next_res, next_expl
    # Use pytest.assertion.util._reprcompare if that's available.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        res: ast.expr = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]
    return res, self.explanation_param(self.pop_format_context(expl_call))
|
|
|
|
|
|
def try_makedirs(cache_dir: Path) -> bool:
    """Attempt to create the given directory, including missing parents.

    Returns True if the directory was created or already exists, and False
    when creation is not possible (a path component is not a directory,
    permission is denied, or the file system is read-only). Any other
    ``OSError`` is re-raised.
    """
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except (
        # One of the path components was not a directory:
        # - we're in a zip file
        # - it is a file
        FileNotFoundError,
        NotADirectoryError,
        FileExistsError,
        # Not allowed to create the directory.
        PermissionError,
    ):
        return False
    except OSError as exc:
        # as of now, EROFS doesn't have an equivalent OSError-subclass
        if exc.errno != errno.EROFS:
            raise
        return False
    else:
        return True
|
|
|
|
|
|
def get_cache_dir(file_path: Path) -> Path:
    """Return the cache directory to write .pyc files for the given .py file path."""
    prefix = sys.pycache_prefix if sys.version_info >= (3, 8) else None
    if not prefix:
        # classic pycache directory
        return file_path.parent / "__pycache__"
    # Mirror the file's directory (minus its first path part) under the
    # prefix, e.g.:
    #   prefix = '/tmp/pycs'
    #   path = '/home/user/proj/test_app.py'
    # gives:
    #   '/tmp/pycs/home/user/proj'
    directory_parts = file_path.parts[1:-1]
    return Path(prefix).joinpath(*directory_parts)
|