Compare commits

..

2 Commits

Author SHA1 Message Date
Isaac Shoebottom
a50f49d2c8 Add python venv 2022-10-31 10:10:52 -03:00
Isaac Shoebottom
fb1a0435c1 Update Lab 14 venv 2022-10-31 10:10:27 -03:00
915 changed files with 287883 additions and 2 deletions

View File

@ -4,7 +4,7 @@
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="Python 3.10 (python-venv)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (L14)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (python-venv)" project-jdk-type="Python SDK" />
</project>

4
utils/python-venv/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
# created by virtualenv automatically
# Commit venv because it is shared between all of python environments for this class
# *

View File

@ -0,0 +1,166 @@
import sys
import os
import re
import importlib
import warnings
import contextlib
is_pypy = '__pypy__' in sys.builtin_module_names
warnings.filterwarnings('ignore',
r'.+ distutils\b.+ deprecated',
DeprecationWarning)
def warn_distutils_present():
if 'distutils' not in sys.modules:
return
if is_pypy and sys.version_info < (3, 7):
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
return
warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools "
"also replaces the `distutils` module in `sys.modules`. This may lead "
"to undesirable behaviors or errors. To avoid these issues, avoid "
"using distutils directly, ensure that setuptools is installed in the "
"traditional way (e.g. not an editable install), and/or make sure "
"that setuptools is always imported before distutils.")
def clear_distutils():
if 'distutils' not in sys.modules:
return
warnings.warn("Setuptools is replacing distutils.")
mods = [name for name in sys.modules if re.match(r'distutils\b', name)]
for name in mods:
del sys.modules[name]
def enabled():
"""
Allow selection of distutils by environment variable.
"""
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
return which == 'local'
def ensure_local_distutils():
clear_distutils()
# With the DistutilsMetaFinder in place,
# perform an import to cause distutils to be
# loaded from setuptools._distutils. Ref #2906.
with shim():
importlib.import_module('distutils')
# check that submodules load as expected
core = importlib.import_module('distutils.core')
assert '_distutils' in core.__file__, core.__file__
def do_override():
"""
Ensure that the local copy of distutils is preferred over stdlib.
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
for more motivation.
"""
if enabled():
warn_distutils_present()
ensure_local_distutils()
class DistutilsMetaFinder:
def find_spec(self, fullname, path, target=None):
if path is not None:
return
method_name = 'spec_for_{fullname}'.format(**locals())
method = getattr(self, method_name, lambda: None)
return method()
def spec_for_distutils(self):
import importlib.abc
import importlib.util
try:
mod = importlib.import_module('setuptools._distutils')
except Exception:
# There are a couple of cases where setuptools._distutils
# may not be present:
# - An older Setuptools without a local distutils is
# taking precedence. Ref #2957.
# - Path manipulation during sitecustomize removes
# setuptools from the path but only after the hook
# has been loaded. Ref #2980.
# In either case, fall back to stdlib behavior.
return
class DistutilsLoader(importlib.abc.Loader):
def create_module(self, spec):
return mod
def exec_module(self, module):
pass
return importlib.util.spec_from_loader('distutils', DistutilsLoader())
def spec_for_pip(self):
"""
Ensure stdlib distutils when running under pip.
See pypa/pip#8761 for rationale.
"""
if self.pip_imported_during_build():
return
clear_distutils()
self.spec_for_distutils = lambda: None
@classmethod
def pip_imported_during_build(cls):
"""
Detect if pip is being imported in a build script. Ref #2355.
"""
import traceback
return any(
cls.frame_file_is_setup(frame)
for frame, line in traceback.walk_stack(None)
)
@staticmethod
def frame_file_is_setup(frame):
"""
Return True if the indicated frame suggests a setup.py file.
"""
# some frames may not have __file__ (#2940)
return frame.f_globals.get('__file__', '').endswith('setup.py')
DISTUTILS_FINDER = DistutilsMetaFinder()
def add_shim():
DISTUTILS_FINDER in sys.meta_path or insert_shim()
@contextlib.contextmanager
def shim():
insert_shim()
try:
yield
finally:
remove_shim()
def insert_shim():
sys.meta_path.insert(0, DISTUTILS_FINDER)
def remove_shim():
try:
sys.meta_path.remove(DISTUTILS_FINDER)
except ValueError:
pass

View File

@ -0,0 +1 @@
__import__('_distutils_hack').do_override()

View File

@ -0,0 +1,9 @@
__all__ = ["__version__", "version_tuple"]
try:
from ._version import version as __version__, version_tuple
except ImportError: # pragma: no cover
# broken installation, we don't even try
# unknown only works because we do poor mans version compare
__version__ = "unknown"
version_tuple = (0, 0, "unknown") # type:ignore[assignment]

View File

@ -0,0 +1,116 @@
"""Allow bash-completion for argparse with argcomplete if installed.
Needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
to find the magic string, so _ARGCOMPLETE env. var is never set, and
this does not need special code).
Function try_argcomplete(parser) should be called directly before
the call to ArgumentParser.parse_args().
The filescompleter is what you normally would use on the positional
arguments specification, in order to get "dirname/" after "dirn<TAB>"
instead of the default "dirname ":
optparser.add_argument(Config._file_or_dir, nargs='*').completer=filescompleter
Other, application specific, completers should go in the file
doing the add_argument calls as they need to be specified as .completer
attributes as well. (If argcomplete is not installed, the function the
attribute points to will not be used).
SPEEDUP
=======
The generic argcomplete script for bash-completion
(/etc/bash_completion.d/python-argcomplete.sh)
uses a python program to determine startup script generated by pip.
You can speed up completion somewhat by changing this script to include
# PYTHON_ARGCOMPLETE_OK
so the python-argcomplete-check-easy-install-script does not
need to be called to find the entry point of the code and see if that is
marked with PYTHON_ARGCOMPLETE_OK.
INSTALL/DEBUGGING
=================
To include this support in another application that has setup.py generated
scripts:
- Add the line:
# PYTHON_ARGCOMPLETE_OK
near the top of the main python entry point.
- Include in the file calling parse_args():
from _argcomplete import try_argcomplete, filescompleter
Call try_argcomplete just before parse_args(), and optionally add
filescompleter to the positional arguments' add_argument().
If things do not work right away:
- Switch on argcomplete debugging with (also helpful when doing custom
completers):
export _ARC_DEBUG=1
- Run:
python-argcomplete-check-easy-install-script $(which appname)
echo $?
will echo 0 if the magic line has been found, 1 if not.
- Sometimes it helps to find early on errors using:
_ARGCOMPLETE=1 _ARC_DEBUG=1 appname
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
"""
import argparse
import os
import sys
from glob import glob
from typing import Any
from typing import List
from typing import Optional
class FastFilesCompleter:
"""Fast file completer class."""
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> List[str]:
# Only called on non option completions.
if os.path.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.path.sep)
else:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.path.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix))
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
# Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
if os.environ.get("_ARGCOMPLETE"):
try:
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)
else:
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
pass
filescompleter = None

View File

@ -0,0 +1,22 @@
"""Python inspection/code generation API."""
from .code import Code
from .code import ExceptionInfo
from .code import filter_traceback
from .code import Frame
from .code import getfslineno
from .code import Traceback
from .code import TracebackEntry
from .source import getrawcode
from .source import Source
__all__ = [
"Code",
"ExceptionInfo",
"filter_traceback",
"Frame",
"getfslineno",
"getrawcode",
"Traceback",
"TracebackEntry",
"Source",
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,217 @@
import ast
import inspect
import textwrap
import tokenize
import types
import warnings
from bisect import bisect_right
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Tuple
from typing import Union
class Source:
"""An immutable object holding a source code fragment.
When using Source(...), the source lines are deindented.
"""
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines: List[str] = []
elif isinstance(obj, Source):
self.lines = obj.lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
else:
try:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@overload
def __getitem__(self, key: int) -> str:
...
@overload
def __getitem__(self, key: slice) -> "Source":
...
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:
if isinstance(key, int):
return self.lines[key]
else:
if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step")
newsource = Source()
newsource.lines = self.lines[key.start : key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int:
return len(self.lines)
def strip(self) -> "Source":
"""Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
start += 1
while end > start and not self.lines[end - 1].strip():
end -= 1
source = Source()
source.lines[:] = self.lines[start:end]
return source
def indent(self, indent: str = " " * 4) -> "Source":
"""Return a copy of the source object with all lines indented by the
given indent-string."""
newsource = Source()
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> "Source":
"""Return Source statement which contains the given linenumber
(counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region
which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> "Source":
"""Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
return newsource
def __str__(self) -> str:
return "\n".join(self.lines)
#
# helper functions
#
def findsource(obj) -> Tuple[Optional[Source], int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
return None, -1
source = Source()
source.lines = [line.rstrip() for line in sourcelines]
return source, lineno
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
"""Return code object for given function."""
try:
return obj.__code__ # type: ignore[attr-defined,no-any-return]
except AttributeError:
pass
if trycall:
call = getattr(obj, "__call__", None)
if call and not isinstance(obj, type):
return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
# Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1.
values: List[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# Before Python 3.8, the lineno of a decorated class or function pointed at the decorator.
# Since Python 3.8, the lineno points to the class/def, so need to include the decorators.
if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
for d in x.decorator_list:
values.append(d.lineno - 1)
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val: Optional[List[ast.stmt]] = getattr(x, name, None)
if val:
# Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
start = values[insert_index - 1]
if insert_index >= len(values):
end = None
else:
end = values[insert_index]
return start, end
def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:
# Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end:
# - ast-parsing strips comments
# - there might be empty lines
# - we might have lesser indented code blocks at the end
if end is None:
end = len(source.lines)
if end > start + 1:
# Make sure we don't span differently indented code blocks
# by using the BlockFinder helper used which inspect.getsource() uses itself.
block_finder = inspect.BlockFinder()
# If we start with an indented line, put blockfinder to "started" mode.
block_finder.started = source.lines[start][0].isspace()
it = ((x + "\n") for x in source.lines[start:end])
try:
for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok)
except (inspect.EndOfBlock, IndentationError):
end = block_finder.last + start
except Exception:
pass
# The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
end -= 1
else:
break
return astnode, start, end

View File

@ -0,0 +1,8 @@
from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter
__all__ = [
"TerminalWriter",
"get_terminal_width",
]

View File

@ -0,0 +1,180 @@
import pprint
import reprlib
from typing import Any
from typing import Dict
from typing import IO
from typing import Optional
def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
return f'{type(obj).__name__}("{obj}")'
def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(exc)})"
return "<[{} raised in repr()] {} object at 0x{:x}>".format(
exc_info, type(obj).__name__, id(obj)
)
def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :]
return s
class SafeRepr(reprlib.Repr):
"""
repr.Repr that limits the resulting size of repr() and includes
information on exceptions raised during the call.
"""
def __init__(self, maxsize: Optional[int], use_ascii: bool = False) -> None:
"""
:param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis
somewhere in the middle to hide the extra text.
If None, will not impose any size limits on the returning repr.
"""
super().__init__()
# ``maxstring`` is used by the superclass, and needs to be an int; using a
# very large number in case maxsize is None, meaning we want to disable
# truncation.
self.maxstring = maxsize if maxsize is not None else 1_000_000_000
self.maxsize = maxsize
self.use_ascii = use_ascii
def repr(self, x: object) -> str:
try:
if self.use_ascii:
s = ascii(x)
else:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def safeformat(obj: object) -> str:
"""Return a pretty printed string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info.
"""
try:
return pprint.pformat(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)
# Maximum size of overall repr of objects to display during assertion errors.
DEFAULT_REPR_MAX_SIZE = 240
def saferepr(
obj: object, maxsize: Optional[int] = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str:
"""Return a size-limited safe repr-string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
care to never raise exceptions itself.
This function is a wrapper around the Repr/reprlib functionality of the
stdlib.
"""
return SafeRepr(maxsize, use_ascii).repr(obj)
def saferepr_unlimited(obj: object, use_ascii: bool = True) -> str:
"""Return an unlimited-size safe repr-string for the given object.
As with saferepr, failing __repr__ functions of user instances
will be represented with a short exception info.
This function is a wrapper around simple repr.
Note: a cleaner solution would be to alter ``saferepr``this way
when maxsize=None, but that might affect some other code.
"""
try:
if use_ascii:
return ascii(obj)
return repr(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)
class AlwaysDispatchingPrettyPrinter(pprint.PrettyPrinter):
"""PrettyPrinter that always dispatches (regardless of width)."""
def _format(
self,
object: object,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, Any],
level: int,
) -> None:
# Type ignored because _dispatch is private.
p = self._dispatch.get(type(object).__repr__, None) # type: ignore[attr-defined]
objid = id(object)
if objid in context or p is None:
# Type ignored because _format is private.
super()._format( # type: ignore[misc]
object,
stream,
indent,
allowance,
context,
level,
)
return
context[objid] = 1
p(self, object, stream, indent, allowance, context, level + 1)
del context[objid]
def _pformat_dispatch(
object: object,
indent: int = 1,
width: int = 80,
depth: Optional[int] = None,
*,
compact: bool = False,
) -> str:
return AlwaysDispatchingPrettyPrinter(
indent=indent, width=width, depth=depth, compact=compact
).pformat(object)

View File

@ -0,0 +1,233 @@
"""Helper functions for writing to terminals and files."""
import os
import shutil
import sys
from typing import Optional
from typing import Sequence
from typing import TextIO
from .wcwidth import wcswidth
from _pytest.compat import final
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
def get_terminal_width() -> int:
width, _ = shutil.get_terminal_size(fallback=(80, 24))
# The Windows get_terminal_size may be bogus, let's sanify a bit.
if width < 40:
width = 80
return width
def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1":
return True
if os.environ.get("PY_COLORS") == "0":
return False
if "NO_COLOR" in os.environ:
return False
if "FORCE_COLOR" in os.environ:
return True
return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
)
@final
class TerminalWriter:
_esctable = dict(
black=30,
red=31,
green=32,
yellow=33,
blue=34,
purple=35,
cyan=36,
white=37,
Black=40,
Red=41,
Green=42,
Yellow=43,
Blue=44,
Purple=45,
Cyan=46,
White=47,
bold=1,
light=2,
blink=5,
invert=7,
)
def __init__(self, file: Optional[TextIO] = None) -> None:
if file is None:
file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
try:
import colorama
except ImportError:
pass
else:
file = colorama.AnsiToWin32(file).stream
assert file is not None
self._file = file
self.hasmarkup = should_do_markup(file)
self._current_line = ""
self._terminal_width: Optional[int] = None
self.code_highlight = True
@property
def fullwidth(self) -> int:
if self._terminal_width is not None:
return self._terminal_width
return get_terminal_width()
@fullwidth.setter
def fullwidth(self, value: int) -> None:
self._terminal_width = value
@property
def width_of_current_line(self) -> int:
"""Return an estimate of the width so far in the current line."""
return wcswidth(self._current_line)
def markup(self, text: str, **markup: bool) -> str:
for name in markup:
if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}")
if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on]
if esc:
text = "".join("\x1b[%sm" % cod for cod in esc) + text + "\x1b[0m"
return text
def sep(
self,
sepchar: str,
title: Optional[str] = None,
fullwidth: Optional[int] = None,
**markup: bool,
) -> None:
if fullwidth is None:
fullwidth = self.fullwidth
# The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth.
if sys.platform == "win32":
# If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this
# (we may not know the exact line width).
# So let's be defensive to avoid empty lines in the output.
fullwidth -= 1
if title is not None:
# we want 2 + 2*len(fill) + len(title) <= fullwidth
# i.e. 2 + 2*len(sepchar)*N + len(title) <= fullwidth
# 2*len(sepchar)*N <= fullwidth - len(title) - 2
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N
line = f"{fill} {title} {fill}"
else:
# we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar)
line = sepchar * (fullwidth // len(sepchar))
# In some situations there is room for an extra sepchar at the right,
# in particular if we consider that with a sepchar like "_ " the
# trailing space is not important at the end of the line.
if len(line) + len(sepchar.rstrip()) <= fullwidth:
line += sepchar.rstrip()
self.line(line, **markup)
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg:
current_line = msg.rsplit("\n", 1)[-1]
if "\n" in msg:
self._current_line = current_line
else:
self._current_line += current_line
msg = self.markup(msg, **markup)
try:
self._file.write(msg)
except UnicodeEncodeError:
# Some environments don't support printing general Unicode
# strings, due to misconfiguration or otherwise; in that case,
# print the string escaped to ASCII.
# When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475
# for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii")
self._file.write(msg)
if flush:
self.flush()
def line(self, s: str = "", **markup: bool) -> None:
self.write(s, **markup)
self.write("\n")
def flush(self) -> None:
self._file.flush()
def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
"""Write lines of source code possibly highlighted.
Keeping this private for now because the API is clunky. We should discuss how
to evolve the terminal writer so we can have more precise color support, for example
being able to write part of a line in one color and the rest in another, and so on.
"""
if indents and len(indents) != len(lines):
raise ValueError(
"indents size ({}) should have same size as lines ({})".format(
len(indents), len(lines)
)
)
if not indents:
indents = [""] * len(lines)
source = "\n".join(lines)
new_lines = self._highlight(source).splitlines()
for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line)
def _highlight(self, source: str) -> str:
"""Highlight the given source code if we have markup support."""
from _pytest.config.exceptions import UsageError
if not self.hasmarkup or not self.code_highlight:
return source
try:
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexers.python import PythonLexer
from pygments import highlight
import pygments.util
except ImportError:
return source
else:
try:
highlighted: str = highlight(
source,
PythonLexer(),
TerminalFormatter(
bg=os.getenv("PYTEST_THEME_MODE", "dark"),
style=os.getenv("PYTEST_THEME"),
),
)
return highlighted
except pygments.util.ClassNotFound:
raise UsageError(
"PYTEST_THEME environment variable had an invalid value: '{}'. "
"Only valid pygment styles are allowed.".format(
os.getenv("PYTEST_THEME")
)
)
except pygments.util.OptionError:
raise UsageError(
"PYTEST_THEME_MODE environment variable had an invalid value: '{}'. "
"The only allowed values are 'dark' and 'light'.".format(
os.getenv("PYTEST_THEME_MODE")
)
)

View File

@ -0,0 +1,55 @@
import unicodedata
from functools import lru_cache
@lru_cache(100)
def wcwidth(c: str) -> int:
"""Determine how many columns are needed to display a character in a terminal.
Returns -1 if the character is not printable.
Returns 0, 1 or 2 for other characters.
"""
o = ord(c)
# ASCII fast path.
if 0x20 <= o < 0x07F:
return 1
# Some Cf/Zp/Zl characters which should be zero-width.
if (
o == 0x0000
or 0x200B <= o <= 0x200F
or 0x2028 <= o <= 0x202E
or 0x2060 <= o <= 0x2063
):
return 0
category = unicodedata.category(c)
# Control characters.
if category == "Cc":
return -1
# Combining characters with zero width.
if category in ("Me", "Mn"):
return 0
# Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"):
return 2
return 1
def wcswidth(s: str) -> int:
"""Determine how many columns are needed to display a string in a terminal.
Returns -1 if the string contains non-printable characters.
"""
width = 0
for c in unicodedata.normalize("NFC", s):
wc = wcwidth(c)
if wc < 0:
return -1
width += wc
return width

View File

@ -0,0 +1,109 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
import errno
import os
import sys
from typing import Callable
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
)
return s
_winerrnomap = {
2: errno.ENOENT,
3: errno.ENOENT,
17: errno.EEXIST,
18: errno.EXDEV,
13: errno.EBUSY, # empty cd drive, but ENOMEDIUM seems unavailiable
22: errno.ENOTDIR,
20: errno.ENOTDIR,
267: errno.ENOTDIR,
5: errno.EACCES, # anything better?
}
class ErrorMaker:
"""lazily provides Exception classes for each possible POSIX errno
(as defined per the 'errno' module). All such instances
subclass EnvironmentError.
"""
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
setattr(self, name, cls)
return cls
def _geterrnoclass(self, eno: int) -> type[Error]:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, "UnknownErrno%d" % (eno,))
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
try:
return func(*args, **kwargs)
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
raise
errno = value.errno
if sys.platform == "win32":
try:
cls = self._geterrnoclass(_winerrnomap[errno])
except KeyError:
raise value
else:
# we are not on Windows, or we got a proper OSError
cls = self._geterrnoclass(errno)
raise cls(f"{func.__name__}{args!r}")
_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call
def __getattr__(attr: str) -> type[Error]:
return getattr(_error_maker, attr) # type: ignore[no-any-return]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,5 @@
# coding: utf-8
# file generated by setuptools_scm
# don't change, don't track in version control
__version__ = version = '7.2.0'
__version_tuple__ = version_tuple = (7, 2, 0)

View File

@ -0,0 +1,181 @@
"""Support for presenting detailed information in failing assertions."""
import sys
from typing import Any
from typing import Generator
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
action="store",
dest="assertmode",
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
help=(
"Control assertion debugging tools.\n"
"'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information."
),
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
)
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
the package will get their assert statements rewritten.
Thus you should make sure to call this before the module is
actually imported, usually in your __init__.py if you are a plugin
using a package.
:param names: The module names to register.
"""
for name in names:
if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
importhook = hook
break
else:
# TODO(typing): Add a protocol for mark_rewrite() and use it
# for importhook and for PytestPluginManager.rewrite_hook.
importhook = DummyRewriteHook() # type: ignore
importhook.mark_rewrite(*names)
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook: Optional[rewrite.AssertionRewritingHook] = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite")
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook")
def undo() -> None:
hook = config.stash[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
config.add_cleanup(undo)
return hook
def pytest_collection(session: "Session") -> None:
# This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist
# (which does not collect test modules).
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
@hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
The rewrite module will use util._reprcompare if it exists to use custom
reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> Optional[str]:
"""Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
following:
* Overly verbose explanations are truncated unless configured otherwise
(eg. if running in verbose mode).
* Embedded newlines are escaped to help util.format_explanation()
later.
* If the rewrite mode is used embedded %-characters are replaced
to protect later % formatting.
The result can be formatted by util.format_explanation() for
pretty printing.
"""
hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None
saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
util._config = item.config
if ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
yield
util._reprcompare, util._assertion_pass = saved_assert_hooks
util._config = None
def pytest_sessionfinish(session: "Session") -> None:
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any
) -> Optional[List[str]]:
return util.assertrepr_compare(config=config, op=op, left=left, right=right)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,94 @@
"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI.
"""
from typing import List
from typing import Optional
from _pytest.assertion import util
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None
) -> List[str]:
"""Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item):
return _truncate_explanation(explanation)
return explanation
def _should_truncate_item(item: Item) -> bool:
"""Whether or not this test item is eligible for truncation."""
verbose = item.config.option.verbose
return verbose < 2 and not util.running_on_ci()
def _truncate_explanation(
input_lines: List[str],
max_lines: Optional[int] = None,
max_chars: Optional[int] = None,
) -> List[str]:
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches
first. The remaining lines will be replaced by a usage message.
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_chars is None:
max_chars = DEFAULT_MAX_CHARS
# Check if truncation required
input_char_count = len("".join(input_lines))
if len(input_lines) <= max_lines and input_char_count <= max_chars:
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if max_chars
# is exceeded.
truncated_explanation = input_lines[:max_lines]
truncated_explanation = _truncate_by_char_count(truncated_explanation, max_chars)
# Add ellipsis to final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
# Append useful message to explanation
truncated_line_count = len(input_lines) - len(truncated_explanation)
truncated_line_count += 1 # Account for the part-truncated final line
msg = "...Full output truncated"
if truncated_line_count == 1:
msg += f" ({truncated_line_count} line hidden)"
else:
msg += f" ({truncated_line_count} lines hidden)"
msg += f", {USAGE_MSG}"
truncated_explanation.extend(["", str(msg)])
return truncated_explanation
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):
if iterated_char_count + len(input_line) > max_chars:
break
iterated_char_count += len(input_line)
# Create truncated explanation with modified final line
truncated_result = input_lines[:iterated_index]
final_line = input_lines[iterated_index]
if final_line:
final_line_truncate_point = max_chars - iterated_char_count
final_line = final_line[:final_line_truncate_point]
truncated_result.append(final_line)
return truncated_result

View File

@ -0,0 +1,522 @@
"""Utilities for assertion debugging."""
import collections.abc
import os
import pprint
from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from unicodedata import normalize
import _pytest._code
from _pytest import outcomes
from _pytest._io.saferepr import _pformat_dispatch
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass: Optional[Callable[[int, str, str], None]] = None
# Config object which is assigned during pytest_runtest_protocol.
_config: Optional[Config] = None
def format_explanation(explanation: str) -> str:
r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
return "\n".join(result)
def _split_explanation(explanation: str) -> List[str]:
r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
raw_lines = (explanation or "").split("\n")
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
lines.append(values)
else:
lines[-1] += "\\n" + values
return lines
def _format_lines(lines: Sequence[str]) -> List[str]:
"""Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting
language with the proper 'where ...', 'and ...' and ' + ...' text, taking
care of indentation along the way.
Return a list of formatted lines.
"""
result = list(lines[:1])
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if stackcnt[-1]:
s = "and "
else:
s = "where "
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line[0] in ["~", ">"]
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
result.append(" " * indent + line[1:])
assert len(stack) == 1
return result
def issequence(x: Any) -> bool:
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
def istext(x: Any) -> bool:
return isinstance(x, str)
def isdict(x: Any) -> bool:
return isinstance(x, dict)
def isset(x: Any) -> bool:
return isinstance(x, (set, frozenset))
def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None
def isiterable(obj: Any) -> bool:
try:
iter(obj)
return not istext(obj)
except TypeError:
return False
def has_default_eq(
obj: object,
) -> bool:
"""Check if an instance of an object contains the default eq
First, we check if the object's __eq__ attribute has __code__,
if so, we check the equally of the method code filename (__code__.co_filename)
to the default one generated by the dataclass and attr module
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs generated eq" in code_filename
return code_filename == "<string>" # data class
return True
def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> Optional[List[str]]:
"""Return specialised explanations for some operators/operands."""
verbose = config.getoption("verbose")
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
use_ascii = (
isinstance(left, str)
and isinstance(right, str)
and normalize("NFD", left) == normalize("NFD", right)
)
if verbose > 1:
left_repr = saferepr_unlimited(left, use_ascii=use_ascii)
right_repr = saferepr_unlimited(right, use_ascii=use_ascii)
else:
# XXX: "15 chars indentation" is wrong
# ("E AssertionError: assert "); should use term width.
maxsize = (
80 - 15 - len(op) - 2
) // 2 # 15 chars indentation, 1 space around op
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
explanation = None
try:
if op == "==":
explanation = _compare_eq_any(left, right, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
except outcomes.Exit:
raise
except Exception:
explanation = [
"(pytest_assertion plugin: representation of details failed: {}.".format(
_pytest._code.ExceptionInfo.from_current()._getreprcrash()
),
" Probably an object has a faulty __repr__.)",
]
if not explanation:
return None
return [summary] + explanation
def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose)
else:
from _pytest.python_api import ApproxBase
if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
# Although the common order should be obtained == expected, this ensures both ways
approx_side = left if isinstance(left, ApproxBase) else right
other_side = right if isinstance(left, ApproxBase) else left
explanation = approx_side._repr_compare(other_side)
elif type(left) == type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
# Note: unlike dataclasses/attrs, namedtuples compare only the
# field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available.
explanation = _compare_eq_cls(left, right, verbose)
elif issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, verbose)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, verbose)
explanation.extend(expl)
return explanation
def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
characters which are identical to keep the diff minimal.
"""
from difflib import ndiff
explanation: List[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
for i in range(min(len(left), len(right))):
if left[i] != right[i]:
break
if i > 42:
i -= 10 # Provide some context
explanation = [
"Skipping %s identical leading characters in diff, use -v to show" % i
]
left = left[i:]
right = right[i:]
if len(left) == len(right):
for i in range(len(left)):
if left[-i] != right[-i]:
break
if i > 42:
i -= 10 # Provide some context
explanation += [
"Skipping {} identical trailing "
"characters in diff, use -v to show".format(i)
]
left = left[:-i]
right = right[:-i]
keepends = True
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation += [
line.strip("\n")
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
]
return explanation
def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
"""Move opening/closing parenthesis/bracket to own lines."""
opening = lines[0][:1]
if opening in ["(", "[", "{"]:
lines[0] = " " + lines[0][1:]
lines[:] = [opening] + lines
closing = lines[-1][-1:]
if closing in [")", "]", "}"]:
lines[-1] = lines[-1][:-1] + ","
lines[:] = lines + [closing]
def _compare_eq_iterable(
left: Iterable[Any], right: Iterable[Any], verbose: int = 0
) -> List[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
# dynamic import to speedup pytest
import difflib
left_formatting = pprint.pformat(left).splitlines()
right_formatting = pprint.pformat(right).splitlines()
# Re-format for different output lengths.
lines_left = len(left_formatting)
lines_right = len(right_formatting)
if lines_left != lines_right:
left_formatting = _pformat_dispatch(left).splitlines()
right_formatting = _pformat_dispatch(right).splitlines()
if lines_left > 1 or lines_right > 1:
_surrounding_parens_on_own_lines(left_formatting)
_surrounding_parens_on_own_lines(right_formatting)
explanation = ["Full diff:"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
)
return explanation
def _compare_eq_sequence(
left: Sequence[Any], right: Sequence[Any], verbose: int = 0
) -> List[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: List[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
if left[i] != right[i]:
if comparing_bytes:
# when comparing bytes, we want to see their ascii representation
# instead of their numeric values (#5260)
# using a slice gives us the ascii representation:
# >>> s = b'foo'
# >>> s[0]
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
else:
left_value = left[i]
right_value = right[i]
explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
break
if comparing_bytes:
# when comparing bytes, it doesn't help to show the "sides contain one or more
# items" longer explanation, so skip it
return explanation
len_diff = len_left - len_right
if len_diff:
if len_diff > 0:
dir_with_more = "Left"
extra = saferepr(left[len_right])
else:
len_diff = 0 - len_diff
dir_with_more = "Right"
extra = saferepr(right[len_left])
if len_diff == 1:
explanation += [f"{dir_with_more} contains one more item: {extra}"]
else:
explanation += [
"%s contains %d more items, first extra item: %s"
% (dir_with_more, len_diff, extra)
]
return explanation
def _compare_eq_set(
left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
) -> List[str]:
explanation = []
diff_left = left - right
diff_right = right - left
if diff_left:
explanation.append("Extra items in the left set:")
for item in diff_left:
explanation.append(saferepr(item))
if diff_right:
explanation.append("Extra items in the right set:")
for item in diff_right:
explanation.append(saferepr(item))
return explanation
def _compare_eq_dict(
left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
) -> List[str]:
explanation: List[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
elif same:
explanation += ["Common items:"]
explanation += pprint.pformat(same).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
for k in diff:
explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
explanation.append(
"Left contains %d more item%s:"
% (len_extra_left, "" if len_extra_left == 1 else "s")
)
explanation.extend(
pprint.pformat({k: left[k] for k in extra_left}).splitlines()
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
if len_extra_right:
explanation.append(
"Right contains %d more item%s:"
% (len_extra_right, "" if len_extra_right == 1 else "s")
)
explanation.extend(
pprint.pformat({k: right[k] for k in extra_right}).splitlines()
)
return explanation
def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
import dataclasses
all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
elif isnamedtuple(left):
fields_to_check = left._fields
else:
assert False
indent = " "
same = []
diff = []
for field in fields_to_check:
if getattr(left, field) == getattr(right, field):
same.append(field)
else:
diff.append(field)
explanation = []
if same or diff:
explanation += [""]
if same and verbose < 2:
explanation.append("Omitting %s identical items, use -vv to show" % len(same))
elif same:
explanation += ["Matching attributes:"]
explanation += pprint.pformat(same).splitlines()
if diff:
explanation += ["Differing attributes:"]
explanation += pprint.pformat(diff).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
"Drill down into differing attribute %s:" % field,
("%s%s: %r != %r") % (indent, field, field_left, field_right),
]
explanation += [
indent + line
for line in _compare_eq_any(field_left, field_right, verbose)
]
return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
correct_text = head + tail
diff = _diff_text(text, correct_text, verbose)
newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
for line in diff:
if line.startswith("Skipping"):
continue
if line.startswith("- "):
continue
if line.startswith("+ "):
newdiff.append(" " + line[2:])
else:
newdiff.append(line)
return newdiff
def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)

View File

@ -0,0 +1,580 @@
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
import json
import os
from pathlib import Path
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Union
import attr
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.python import Module
from _pytest.python import Package
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
**Do not** commit this to version control.
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
"""
CACHEDIR_TAG_CONTENT = b"""\
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
# https://bford.info/cachedir/spec.html
"""
@final
@attr.s(init=False, auto_attribs=True)
class Cache:
_cachedir: Path = attr.ib(repr=False)
_config: Config = attr.ib(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
# Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v"
def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._cachedir = cachedir
self._config = config
@classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> "Cache":
"""Create the Cache instance for a Config.
:meta private:
"""
check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True)
@classmethod
def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
"""Clear the sub-directories used to hold cached directories and values.
:meta private:
"""
check_ispytest(_ispytest)
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
if d.is_dir():
rm_rf(d)
@staticmethod
def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
"""Get the path to the cache directory for a Config.
:meta private:
"""
check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning.
:meta private:
"""
check_ispytest(_ispytest)
import warnings
from _pytest.warning_types import PytestCacheWarning
warnings.warn(
PytestCacheWarning(fmt.format(**args) if args else fmt),
self._config.hook,
stacklevel=3,
)
def mkdir(self, name: str) -> Path:
"""Return a directory path object with the given name.
If the directory does not yet exist, it will be created. You can use
it to manage files to e.g. store/retrieve database dumps across test
sessions.
.. versionadded:: 7.0
:param name:
Must be a string not containing a ``/`` separator.
Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users.
"""
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True)
return res
def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key: str, default):
"""Return the cached value for the given key.
If no value was yet cached or the value cannot be read, the specified
default is returned.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param default:
The value to return in case of a cache-miss or invalid cache value.
"""
path = self._getvaluepath(key)
try:
with path.open("r", encoding="UTF-8") as f:
return json.load(f)
except (ValueError, OSError):
return default
def set(self, key: str, value: object) -> None:
"""Save value for the given key.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param value:
Must be of any combination of basic python types,
including nested types like lists of dictionaries.
"""
path = self._getvaluepath(key)
try:
if path.parent.is_dir():
cache_dir_exists_already = True
else:
cache_dir_exists_already = self._cachedir.exists()
path.parent.mkdir(exist_ok=True, parents=True)
except OSError:
self.warn("could not create cache path {path}", path=path, _ispytest=True)
return
if not cache_dir_exists_already:
self._ensure_supporting_files()
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
except OSError:
self.warn("cache could not write path {path}", path=path, _ispytest=True)
else:
with f:
f.write(data)
def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT, encoding="UTF-8")
gitignore_path = self._cachedir.joinpath(".gitignore")
msg = "# Created by pytest automatically.\n*\n"
gitignore_path.write_text(msg, encoding="UTF-8")
cachedir_tag_path = self._cachedir.joinpath("CACHEDIR.TAG")
cachedir_tag_path.write_bytes(CACHEDIR_TAG_CONTENT)
class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@hookimpl(hookwrapper=True)
def pytest_make_collect_report(self, collector: nodes.Collector):
if isinstance(collector, Session):
out = yield
res: CollectReport = out.get_result()
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
res.result = sorted(
res.result,
# use stable sort to priorize last failed
key=lambda x: x.path in lf_paths,
reverse=True,
)
return
elif isinstance(collector, Module):
if collector.path in self.lfplugin._last_failed_paths:
out = yield
res = out.get_result()
result = res.result
lastfailed = self.lfplugin.lastfailed
# Only filter with known failures.
if not self._collected_at_least_one_failure:
if not any(x.nodeid in lastfailed for x in result):
return
self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
)
self._collected_at_least_one_failure = True
session = collector.session
result[:] = [
x
for x in result
if x.nodeid in lastfailed
# Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path)
# Keep all sub-collectors.
or isinstance(x, nodes.Collector)
]
return
yield
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
@hookimpl
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Optional[CollectReport]:
# Packages are Modules, but _last_failed_paths only contains
# test-bearing paths and doesn't try to include the paths of their
# packages, so don't filter them.
if isinstance(collector, Module) and not isinstance(collector, Package):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[]
)
return None
class LFPlugin:
"""Plugin which implements the --lf (run last-failing) option."""
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
self.lastfailed: Dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: Optional[int] = None
self._report_status: Optional[str] = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper"
)
def get_last_failed_paths(self) -> Set[Path]:
"""Return a set with all Paths()s of the previously failed nodeids."""
rootpath = self.config.rootpath
result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed}
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status
return None
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid)
self.lastfailed.update((item.nodeid, True) for item in report.result)
else:
self.lastfailed[report.nodeid] = True
@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item]
) -> Generator[None, None, None]:
yield
if not self.active:
return
if self.lastfailed:
previously_failed = []
previously_passed = []
for item in items:
if item.nodeid in self.lastfailed:
previously_failed.append(item)
else:
previously_passed.append(item)
self._previously_failed_count = len(previously_failed)
if not previously_failed:
# Running a subset of all tests with recorded failures
# only outside of it.
self._report_status = "%d known failures not in selected tests" % (
len(self.lastfailed),
)
else:
if self.config.getoption("lf"):
items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst
items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures"
suffix = " first" if self.config.getoption("failedfirst") else ""
self._report_status = "rerun previous {count} {noun}{suffix}".format(
count=self._previously_failed_count, suffix=suffix, noun=noun
)
if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files"
self._report_status += " (skipped {files} {files_noun})".format(
files=self._skipped_files, files_noun=files_noun
)
else:
self._report_status = "no previously failed tests, "
if self.config.getoption("last_failed_no_failures") == "none":
self._report_status += "deselecting all items."
config.hook.pytest_deselected(items=items[:])
items[:] = []
else:
self._report_status += "not deselecting items."
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
class NFPlugin:
"""Plugin which implements the --nf (run new-first) option."""
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, items: List[nodes.Item]
) -> Generator[None, None, None]:
yield
if self.active:
new_items: Dict[str, nodes.Item] = {}
other_items: Dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
else:
other_items[item.nodeid] = item
items[:] = self._get_increasing_order(
new_items.values()
) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items)
else:
self.cached_nodeids.update(item.nodeid for item in items)
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True) # type: ignore[no-any-return]
def pytest_sessionfinish(self) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
if config.getoption("collectonly"):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--lf",
"--last-failed",
action="store_true",
dest="lf",
help="Rerun only the tests that failed "
"at the last run (or all if none failed)",
)
group.addoption(
"--ff",
"--failed-first",
action="store_true",
dest="failedfirst",
help="Run all tests, but run the last failures first. "
"This may re-order tests and thus lead to "
"repeated fixture setup/teardown.",
)
group.addoption(
"--nf",
"--new-first",
action="store_true",
dest="newfirst",
help="Run tests from new files first, then the rest of the tests "
"sorted by file mtime",
)
group.addoption(
"--cache-show",
action="append",
nargs="?",
dest="cacheshow",
help=(
"Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')."
),
)
group.addoption(
"--cache-clear",
action="store_true",
dest="cacheclear",
help="Remove all cache contents at start of test run",
)
cache_dir_default = ".pytest_cache"
if "TOX_ENV_DIR" in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
group.addoption(
"--lfnf",
"--last-failed-no-failures",
action="store",
dest="last_failed_no_failures",
choices=("all", "none"),
default="all",
help="Which tests to run with no previously (known) failures",
)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.cacheshow:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
return None
@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
@fixture
def cache(request: FixtureRequest) -> Cache:
"""Return a cache object that can persist state between testing sessions.
cache.get(key, default)
cache.set(key, value)
Keys must be ``/`` separated strings, where the first part is usually the
name of your plugin or application to avoid clashes with other cache users.
Values can be any object handled by the json stdlib module.
"""
assert request.config.cache is not None
return request.config.cache
def pytest_report_header(config: Config) -> Optional[str]:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
try:
displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
return f"cachedir: {displaypath}"
return None
def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
tw.line("cache is empty")
return 0
glob = config.option.cacheshow[0]
if glob is None:
glob = "*"
dummy = object()
basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", "cache values for %r" % glob)
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line("%s contains unreadable content, will be ignored" % key)
else:
tw.line("%s contains:" % key)
for line in pformat(val).splitlines():
tw.line(" " + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir():
contents = sorted(ddir.rglob(glob))
tw.sep("-", "cache directories for %r" % glob)
for p in contents:
# if p.is_dir():
# print("%s/" % p.relative_to(basedir))
if p.is_file():
key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size:d}")
return 0

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,417 @@
"""Python version compatibility code."""
import enum
import functools
import inspect
import os
import sys
from inspect import Parameter
from inspect import signature
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generic
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
import py
# fmt: off
# Workaround for https://github.com/sphinx-doc/sphinx/issues/10351.
# If `overload` is imported from `compat` instead of from `typing`,
# Sphinx doesn't recognize it as `overload` and the API docs for
# overloaded functions look good again. But type checkers handle
# it fine.
# fmt: on
if True:
from typing import overload as overload
if TYPE_CHECKING:
from typing_extensions import Final
_T = TypeVar("_T")
_S = TypeVar("_S")
#: constant to prepare valuing pylib path replacements/lazy proxies later on
# intended for removal in pytest 8.0 or 9.0
# fmt: off
# intentional space to create a fake difference for the verification
LEGACY_PATH = py.path. local
# fmt: on
def legacy_path(path: Union[str, "os.PathLike[str]"]) -> LEGACY_PATH:
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
return LEGACY_PATH(path)
# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET: "Final" = NotSetType.token # noqa: E305
# fmt: on
if sys.version_info >= (3, 8):
import importlib.metadata
importlib_metadata = importlib.metadata
else:
import importlib_metadata as importlib_metadata # noqa: F401
def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func))
def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func: object) -> bool:
"""Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@asyncio.coroutine.
Note: copied and modified from Python 3.5's builtin couroutines.py to avoid
importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8).
"""
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def is_async_function(func: object) -> bool:
"""Return True if the given function seems to be an async function or
an async generator."""
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
def getlocation(function, curdir: Optional[str] = None) -> str:
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
if curdir is not None:
try:
relfn = fn.relative_to(curdir)
except ValueError:
pass
else:
return "%s:%d" % (relfn, lineno + 1)
return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
return len(
[
p
for p in patchings
if not p.attribute_name
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
]
)
def getfuncargnames(
function: Callable[..., Any],
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None,
) -> Tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Should return the names of all function arguments that:
* Aren't bound to an instance or type as in instance or class methods.
* Don't have default values.
* Aren't bound with functools.partial.
* Aren't replaced with mocks.
The is_method and cls arguments indicate that the function should
be treated as a bound method even though it's not unless, only in
the case of cls, the function is a static method.
The name parameter should be the original name in which the function was collected.
"""
# TODO(RonnyPfannschmidt): This function should be refactored when we
# revisit fixtures. The fixture mechanism should ask the node for
# the fixture names, and not try to obtain directly from the
# function object well after collection has occurred.
# The parameters attribute of a Signature object contains an
# ordered mapping of parameter names to Parameter instances. This
# creates a tuple of the names of the parameters that don't have
# defaults.
try:
parameters = signature(function).parameters
except (ValueError, TypeError) as e:
from _pytest.outcomes import fail
fail(
f"Could not determine arguments of {function!r}: {e}",
pytrace=False,
)
arg_names = tuple(
p.name
for p in parameters.values()
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
)
if not name:
name = function.__name__
# If this function should be treated as a bound method even though
# it's passed as an unbound method or function, remove the first
# parameter name.
if is_method or (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
and not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod
)
):
arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function) :]
return arg_names
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of
# getfuncargnames, to get the arguments which were excluded from its result
# because they had default values.
return tuple(
p.name
for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
and p.default is not Parameter.empty
)
_non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
)
def _translate_non_printable(s: str) -> str:
return s.translate(_non_printable_ascii_translate_table)
STRING_TYPES = bytes, str
def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace")
def ascii_escaped(val: Union[bytes, str]) -> str:
r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
and escapes unicode objects into a sequence of escaped unicode
ids, e.g.:
r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
Note:
The obvious "v.decode('unicode-escape')" will return
valid UTF-8 unicode if it finds them in bytes, but we
want to return escaped bytes for any byte, even if they match
a UTF-8 string.
"""
if isinstance(val, bytes):
ret = _bytes_to_ascii(val)
else:
ret = val.encode("unicode_escape").decode("ascii")
return _translate_non_printable(ret)
@attr.s
class _PytestWrapper:
"""Dummy wrapper around a function object for internal use only.
Used to correctly unwrap the underlying function object when we are
creating fixtures, because we wrap the function object ourselves with a
decorator to issue warnings when the fixture function is called directly.
"""
obj = attr.ib()
def get_real_func(obj):
"""Get the real function object of the (possibly) wrapped object by
functools.wraps or functools.partial."""
start_obj = obj
for i in range(100):
# __pytest_wrapped__ is set by @pytest.fixture when wrapping the fixture function
# to trigger a warning if it gets called directly instead of by pytest: we don't
# want to unwrap further than this otherwise we lose useful wrappings like @mock.patch (#3774)
new_obj = getattr(obj, "__pytest_wrapped__", None)
if isinstance(new_obj, _PytestWrapper):
obj = new_obj.obj
break
new_obj = getattr(obj, "__wrapped__", None)
if new_obj is None:
break
obj = new_obj
else:
from _pytest._io.saferepr import saferepr
raise ValueError(
("could not find real function of {start}\nstopped at {current}").format(
start=saferepr(start_obj), current=saferepr(obj)
)
)
if isinstance(obj, functools.partial):
obj = obj.func
return obj
def get_real_method(obj, holder):
"""Attempt to obtain the real function object that might be wrapping
``obj``, while at the same time returning a bound method to ``holder`` if
the original object was a bound method."""
try:
is_method = hasattr(obj, "__func__")
obj = get_real_func(obj)
except Exception: # pragma: no cover
return obj
if is_method and hasattr(obj, "__get__") and callable(obj.__get__):
obj = obj.__get__(holder)
return obj
def getimfunc(func):
try:
return func.__func__
except AttributeError:
return func
def safe_getattr(object: Any, name: str, default: Any) -> Any:
"""Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
See issue #214.
It catches OutcomeException because of #2490 (issue #580), new outcomes
are derived from BaseException instead of Exception (for more details
check #2707).
"""
from _pytest.outcomes import TEST_OUTCOME
try:
return getattr(object, name, default)
except TEST_OUTCOME:
return default
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
except Exception:
return False
if TYPE_CHECKING:
if sys.version_info >= (3, 8):
from typing import final as final
else:
from typing_extensions import final as final
elif sys.version_info >= (3, 8):
from typing import final as final
else:
def final(f):
return f
if sys.version_info >= (3, 8):
from functools import cached_property as cached_property
else:
from typing import Type
class cached_property(Generic[_S, _T]):
__slots__ = ("func", "__doc__")
def __init__(self, func: Callable[[_S], _T]) -> None:
self.func = func
self.__doc__ = func.__doc__
@overload
def __get__(
self, instance: None, owner: Optional[Type[_S]] = ...
) -> "cached_property[_S, _T]":
...
@overload
def __get__(self, instance: _S, owner: Optional[Type[_S]] = ...) -> _T:
...
def __get__(self, instance, owner=None):
if instance is None:
return self
value = instance.__dict__[self.func.__name__] = self.func(instance)
return value
# Perform exhaustiveness checking.
#
# Consider this example:
#
# MyUnion = Union[int, str]
#
# def handle(x: MyUnion) -> int {
# if isinstance(x, int):
# return 1
# elif isinstance(x, str):
# return 2
# else:
# raise Exception('unreachable')
#
# Now suppose we add a new variant:
#
# MyUnion = Union[int, str, bytes]
#
# After doing this, we must remember ourselves to go and update the handle
# function to handle the new variant.
#
# With `assert_never` we can do better:
#
# // raise Exception('unreachable')
# return assert_never(x)
#
# Now, if we forget to handle the new variant, the type-checker will emit a
# compile-time error, instead of the runtime error we would have gotten
# previously.
#
# This also work for Enums (if you use `is` to compare) and Literals.
def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})"

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,551 @@
import argparse
import os
import sys
import warnings
from gettext import gettext
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import _pytest._io
from _pytest.compat import final
from _pytest.config.exceptions import UsageError
from _pytest.deprecated import ARGUMENT_PERCENT_DEFAULT
from _pytest.deprecated import ARGUMENT_TYPE_STR
from _pytest.deprecated import ARGUMENT_TYPE_STR_CHOICE
from _pytest.deprecated import check_ispytest
if TYPE_CHECKING:
from typing_extensions import Literal
FILE_OR_DIR = "file_or_dir"
@final
class Parser:
"""Parser for command line arguments and ini-file values.
:ivar extra_info: Dict of generic param -> value to display in case
there's an error processing the command line arguments.
"""
prog: Optional[str] = None
def __init__(
self,
usage: Optional[str] = None,
processopt: Optional[Callable[["Argument"], None]] = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
self._groups: List[OptionGroup] = []
self._processopt = processopt
self._usage = usage
self._inidict: Dict[str, Tuple[str, Optional[str], Any]] = {}
self._ininames: List[str] = []
self.extra_info: Dict[str, Any] = {}
def processoption(self, option: "Argument") -> None:
if self._processopt:
if option.dest:
self._processopt(option)
def getgroup(
self, name: str, description: str = "", after: Optional[str] = None
) -> "OptionGroup":
"""Get (or create) a named option Group.
:param name: Name of the option group.
:param description: Long description for --help output.
:param after: Name of another group, used for ordering --help output.
:returns: The option group.
The returned group object has an ``addoption`` method with the same
signature as :func:`parser.addoption <pytest.Parser.addoption>` but
will be shown in the respective group in the output of
``pytest --help``.
"""
for group in self._groups:
if group.name == name:
return group
group = OptionGroup(name, description, parser=self, _ispytest=True)
i = 0
for i, grp in enumerate(self._groups):
if grp.name == after:
break
self._groups.insert(i + 1, group)
return group
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Register a command line option.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :py:func:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
After command line parsing, options are available on the pytest config
object via ``config.option.NAME`` where ``NAME`` is usually set
by passing a ``dest`` attribute, for example
``addoption("--long", dest="NAME", ...)``.
"""
self._anonymous.addoption(*opts, **attrs)
def parse(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete
self.optparser = self._getparser()
try_argcomplete(self.optparser)
strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> "MyOptionParser":
from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
groups = self._groups + [self._anonymous]
for group in groups:
if group.options:
desc = group.description or group.name
arggroup = optparser.add_argument_group(desc)
for option in group.options:
n = option.names()
a = option.attrs()
arggroup.add_argument(*n, **a)
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*")
# bash like autocompletion for dirs (appending '/')
# Type ignored because typeshed doesn't know about argcomplete.
file_or_dir_arg.completer = filescompleter # type: ignore
return optparser
def parse_setoption(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
option: argparse.Namespace,
namespace: Optional[argparse.Namespace] = None,
) -> List[str]:
parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items():
setattr(option, name, value)
return cast(List[str], getattr(parsedoption, FILE_OR_DIR))
def parse_known_args(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
"""Parse the known arguments at this point.
:returns: An argparse namespace object.
"""
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(
self,
args: Sequence[Union[str, "os.PathLike[str]"]],
namespace: Optional[argparse.Namespace] = None,
) -> Tuple[argparse.Namespace, List[str]]:
"""Parse the known arguments at this point, and also return the
remaining unknown arguments.
:returns:
A tuple containing an argparse namespace object for the known
arguments, and a list of the unknown arguments.
"""
optparser = self._getparser()
strargs = [os.fspath(x) for x in args]
return optparser.parse_known_args(strargs, namespace=namespace)
def addini(
self,
name: str,
help: str,
type: Optional[
"Literal['string', 'paths', 'pathlist', 'args', 'linelist', 'bool']"
] = None,
default: Any = None,
) -> None:
"""Register an ini-file option.
:param name:
Name of the ini-variable.
:param type:
Type of the variable. Can be:
* ``string``: a string
* ``bool``: a boolean
* ``args``: a list of strings, separated as in a shell
* ``linelist``: a list of strings, separated by line breaks
* ``paths``: a list of :class:`pathlib.Path`, separated as in a shell
* ``pathlist``: a list of ``py.path``, separated as in a shell
.. versionadded:: 7.0
The ``paths`` variable type.
Defaults to ``string`` if ``None`` or not passed.
:param default:
Default value if no ini-file option exists but is queried.
The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <pytest.Config.getini>`.
"""
assert type in (None, "string", "paths", "pathlist", "args", "linelist", "bool")
self._inidict[name] = (help, type, default)
self._ininames.append(name)
class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or
inconsistent arguments."""
def __init__(self, msg: str, option: Union["Argument", str]) -> None:
self.msg = msg
self.option_id = str(option)
def __str__(self) -> str:
if self.option_id:
return f"option {self.option_id}: {self.msg}"
else:
return self.msg
class Argument:
"""Class that mimics the necessary behaviour of optparse.Option.
It's currently a least effort implementation and ignoring choices
and integer prefixes.
https://docs.python.org/3/library/optparse.html#optparse-standard-option-types
"""
_typ_map = {"int": int, "string": str, "float": float, "complex": complex}
def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument."""
self._attrs = attrs
self._short_opts: List[str] = []
self._long_opts: List[str] = []
if "%default" in (attrs.get("help") or ""):
warnings.warn(ARGUMENT_PERCENT_DEFAULT, stacklevel=3)
try:
typ = attrs["type"]
except KeyError:
pass
else:
# This might raise a keyerror as well, don't want to catch that.
if isinstance(typ, str):
if typ == "choice":
warnings.warn(
ARGUMENT_TYPE_STR_CHOICE.format(typ=typ, names=names),
stacklevel=4,
)
# argparse expects a type here take it from
# the type of the first element
attrs["type"] = type(attrs["choices"][0])
else:
warnings.warn(
ARGUMENT_TYPE_STR.format(typ=typ, names=names), stacklevel=4
)
attrs["type"] = Argument._typ_map[typ]
# Used in test_parseopt -> test_parse_defaultgetter.
self.type = attrs["type"]
else:
self.type = typ
try:
# Attribute existence is tested in Config._processopt.
self.default = attrs["default"]
except KeyError:
pass
self._set_opt_strings(names)
dest: Optional[str] = attrs.get("dest")
if dest:
self.dest = dest
elif self._long_opts:
self.dest = self._long_opts[0][2:].replace("-", "_")
else:
try:
self.dest = self._short_opts[0][1:]
except IndexError as e:
self.dest = "???" # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e
def names(self) -> List[str]:
return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]:
# Update any attributes set by processopt.
attrs = "default dest help".split()
attrs.append(self.dest)
for attr in attrs:
try:
self._attrs[attr] = getattr(self, attr)
except AttributeError:
pass
if self._attrs.get("help"):
a = self._attrs["help"]
a = a.replace("%default", "%(default)s")
# a = a.replace('%prog', '%(prog)s')
self._attrs["help"] = a
return self._attrs
def _set_opt_strings(self, opts: Sequence[str]) -> None:
"""Directly from optparse.
Might not be necessary as this is passed to argparse later on.
"""
for opt in opts:
if len(opt) < 2:
raise ArgumentError(
"invalid option string %r: "
"must be at least two characters long" % opt,
self,
)
elif len(opt) == 2:
if not (opt[0] == "-" and opt[1] != "-"):
raise ArgumentError(
"invalid short option string %r: "
"must be of the form -x, (x any non-dash char)" % opt,
self,
)
self._short_opts.append(opt)
else:
if not (opt[0:2] == "--" and opt[2] != "-"):
raise ArgumentError(
"invalid long option string %r: "
"must start with --, followed by non-dash" % opt,
self,
)
self._long_opts.append(opt)
def __repr__(self) -> str:
args: List[str] = []
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts:
args += ["_long_opts: " + repr(self._long_opts)]
args += ["dest: " + repr(self.dest)]
if hasattr(self, "type"):
args += ["type: " + repr(self.type)]
if hasattr(self, "default"):
args += ["default: " + repr(self.default)]
return "Argument({})".format(", ".join(args))
class OptionGroup:
"""A group of options shown in its own section."""
def __init__(
self,
name: str,
description: str = "",
parser: Optional[Parser] = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.name = name
self.description = description
self.options: List[Argument] = []
self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Add an option to this group.
If a shortened version of a long option is specified, it will
be suppressed in the help. ``addoption('--twowords', '--two-words')``
results in help showing ``--two-words`` only, but ``--twowords`` gets
accepted **and** the automatic destination is in ``args.twowords``.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :py:func:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
"""
conflict = set(opts).intersection(
name for opt in self.options for name in opt.names()
)
if conflict:
raise ValueError("option names %s already added" % conflict)
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=False)
def _addoption(self, *opts: str, **attrs: Any) -> None:
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: "Argument", shortupper: bool = False) -> None:
if not shortupper:
for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower():
raise ValueError("lowercase shortoptions reserved")
if self.parser:
self.parser.processoption(option)
self.options.append(option)
class MyOptionParser(argparse.ArgumentParser):
def __init__(
self,
parser: Parser,
extra_info: Optional[Dict[str, Any]] = None,
prog: Optional[str] = None,
) -> None:
self._parser = parser
super().__init__(
prog=prog,
usage=parser._usage,
add_help=False,
formatter_class=DropShorterLongHelpFormatter,
allow_abbrev=False,
)
# extra_info is a dict of (param -> value) to display if there's
# an usage error to provide more contextual information to the user.
self.extra_info = extra_info if extra_info else {}
def error(self, message: str) -> NoReturn:
"""Transform argparse error message into UsageError."""
msg = f"{self.prog}: error: {message}"
if hasattr(self._parser, "_config_source_hint"):
# Type ignored because the attribute is set dynamically.
msg = f"{msg} ({self._parser._config_source_hint})" # type: ignore
raise UsageError(self.format_usage() + msg)
# Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore
self,
args: Optional[Sequence[str]] = None,
namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
"""Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized:
for arg in unrecognized:
if arg and arg[0] == "-":
lines = ["unrecognized arguments: %s" % (" ".join(unrecognized))]
for k, v in sorted(self.extra_info.items()):
lines.append(f" {k}: {v}")
self.error("\n".join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed
if sys.version_info[:2] < (3, 9): # pragma: no cover
# Backport of https://github.com/python/cpython/pull/14316 so we can
# disable long --argument abbreviations without breaking short flags.
def _parse_optional(
self, arg_string: str
) -> Optional[Tuple[Optional[argparse.Action], str, Optional[str]]]:
if not arg_string:
return None
if not arg_string[0] in self.prefix_chars:
return None
if arg_string in self._option_string_actions:
action = self._option_string_actions[arg_string]
return action, arg_string, None
if len(arg_string) == 1:
return None
if "=" in arg_string:
option_string, explicit_arg = arg_string.split("=", 1)
if option_string in self._option_string_actions:
action = self._option_string_actions[option_string]
return action, option_string, explicit_arg
if self.allow_abbrev or not arg_string.startswith("--"):
option_tuples = self._get_option_tuples(arg_string)
if len(option_tuples) > 1:
msg = gettext(
"ambiguous option: %(option)s could match %(matches)s"
)
options = ", ".join(option for _, option, _ in option_tuples)
self.error(msg % {"option": arg_string, "matches": options})
elif len(option_tuples) == 1:
(option_tuple,) = option_tuples
return option_tuple
if self._negative_number_matcher.match(arg_string):
if not self._has_negative_number_optionals:
return None
if " " in arg_string:
return None
return None, arg_string, None
class DropShorterLongHelpFormatter(argparse.HelpFormatter):
"""Shorten help for long options that differ only in extra hyphens.
- Collapse **long** options that are the same except for extra hyphens.
- Shortcut if there are only two options and one of them is a short one.
- Cache result on the action object as this is called at least 2 times.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Use more accurate terminal width.
if "width" not in kwargs:
kwargs["width"] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments
return orgstr
res: Optional[str] = getattr(action, "_formatted_action_invocation", None)
if res:
return res
options = orgstr.split(", ")
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
# a shortcut for '-h, --help' or '--abc', '-a'
action._formatted_action_invocation = orgstr # type: ignore
return orgstr
return_list = []
short_long: Dict[str, str] = {}
for option in options:
if len(option) == 2 or option[2] == " ":
continue
if not option.startswith("--"):
raise ArgumentError(
'long optional argument without "--": [%s]' % (option), option
)
xxoption = option[2:]
shortened = xxoption.replace("-", "")
if shortened not in short_long or len(short_long[shortened]) < len(
xxoption
):
short_long[shortened] = xxoption
# now short_long has been filled out to the longest with dashes
# **and** we keep the right option ordering from add_argument
for option in options:
if len(option) == 2 or option[2] == " ":
return_list.append(option)
if option[2:] == short_long.get(option.replace("-", "")):
return_list.append(option.replace(" ", "=", 1))
formatted_action_invocation = ", ".join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation
def _split_lines(self, text, width):
"""Wrap lines after splitting on original newlines.
This allows to have explicit line breaks in the help text.
"""
import textwrap
lines = []
for line in text.splitlines():
lines.extend(textwrap.wrap(line.strip(), width))
return lines

View File

@ -0,0 +1,71 @@
import functools
import warnings
from pathlib import Path
from typing import Optional
from ..compat import LEGACY_PATH
from ..compat import legacy_path
from ..deprecated import HOOK_LEGACY_PATH_ARG
from _pytest.nodes import _check_path
# hookname: (Path, LEGACY_PATH)
imply_paths_hooks = {
"pytest_ignore_collect": ("collection_path", "path"),
"pytest_collect_file": ("file_path", "path"),
"pytest_pycollect_makemodule": ("module_path", "path"),
"pytest_report_header": ("start_path", "startdir"),
"pytest_report_collectionfinish": ("start_path", "startdir"),
}
class PathAwareHookProxy:
"""
this helper wraps around hook callers
until pluggy supports fixingcalls, this one will do
it currently doesn't return full hook caller proxies for fixed hooks,
this may have to be changed later depending on bugs
"""
def __init__(self, hook_caller):
self.__hook_caller = hook_caller
def __dir__(self):
return dir(self.__hook_caller)
def __getattr__(self, key, _wraps=functools.wraps):
hook = getattr(self.__hook_caller, key)
if key not in imply_paths_hooks:
self.__dict__[key] = hook
return hook
else:
path_var, fspath_var = imply_paths_hooks[key]
@_wraps(hook)
def fixed_hook(**kw):
path_value: Optional[Path] = kw.pop(path_var, None)
fspath_value: Optional[LEGACY_PATH] = kw.pop(fspath_var, None)
if fspath_value is not None:
warnings.warn(
HOOK_LEGACY_PATH_ARG.format(
pylib_path_arg=fspath_var, pathlib_path_arg=path_var
),
stacklevel=2,
)
if path_value is not None:
if fspath_value is not None:
_check_path(path_value, fspath_value)
else:
fspath_value = legacy_path(path_value)
else:
assert fspath_value is not None
path_value = Path(fspath_value)
kw[path_var] = path_value
kw[fspath_var] = fspath_value
return hook(**kw)
fixed_hook.__name__ = key
self.__dict__[key] = fixed_hook
return fixed_hook

View File

@ -0,0 +1,11 @@
from _pytest.compat import final
@final
class UsageError(Exception):
"""Error in pytest usage or invocation."""
class PrintHelp(Exception):
"""Raised when pytest should print its help to skip the rest of the
argument parsing and validation."""

View File

@ -0,0 +1,218 @@
import os
import sys
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import iniconfig
from .exceptions import UsageError
from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath
from _pytest.pathlib import commonpath
if TYPE_CHECKING:
from . import Config
def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
"""Parse the given generic '.ini' file using legacy IniConfig parser, returning
the parsed object.
Raise UsageError if the file cannot be parsed.
"""
try:
return iniconfig.IniConfig(str(path))
except iniconfig.ParseError as exc:
raise UsageError(str(exc)) from exc
def load_config_dict_from_file(
filepath: Path,
) -> Optional[Dict[str, Union[str, List[str]]]]:
"""Load pytest configuration from the given file path, if supported.
Return None if the file does not contain valid pytest configuration.
"""
# Configuration from ini files are obtained from the [pytest] section, if present.
if filepath.suffix == ".ini":
iniconfig = _parse_ini_config(filepath)
if "pytest" in iniconfig:
return dict(iniconfig["pytest"].items())
else:
# "pytest.ini" files are always the source of configuration, even if empty.
if filepath.name == "pytest.ini":
return {}
# '.cfg' files are considered if they contain a "[tool:pytest]" section.
elif filepath.suffix == ".cfg":
iniconfig = _parse_ini_config(filepath)
if "tool:pytest" in iniconfig.sections:
return dict(iniconfig["tool:pytest"].items())
elif "pytest" in iniconfig.sections:
# If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
# plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False)
# '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
elif filepath.suffix == ".toml":
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
toml_text = filepath.read_text(encoding="utf-8")
try:
config = tomllib.loads(toml_text)
except tomllib.TOMLDecodeError as exc:
raise UsageError(f"{filepath}: {exc}") from exc
result = config.get("tool", {}).get("pytest", {}).get("ini_options", None)
if result is not None:
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
# however we need to convert all scalar values to str for compatibility with the rest
# of the configuration system, which expects strings only.
def make_scalar(v: object) -> Union[str, List[str]]:
return v if isinstance(v, list) else str(v)
return {k: make_scalar(v) for k, v in result.items()}
return None
def locate_config(
args: Iterable[Path],
) -> Tuple[Optional[Path], Optional[Path], Dict[str, Union[str, List[str]]]]:
"""Search in the list of arguments for a valid ini-file for pytest,
and return a tuple of (rootdir, inifile, cfg-dict)."""
config_names = [
"pytest.ini",
".pytest.ini",
"pyproject.toml",
"tox.ini",
"setup.cfg",
]
args = [x for x in args if not str(x).startswith("-")]
if not args:
args = [Path.cwd()]
for arg in args:
argpath = absolutepath(arg)
for base in (argpath, *argpath.parents):
for config_name in config_names:
p = base / config_name
if p.is_file():
ini_config = load_config_dict_from_file(p)
if ini_config is not None:
return base, p, ini_config
return None, None, {}
def get_common_ancestor(paths: Iterable[Path]) -> Path:
common_ancestor: Optional[Path] = None
for path in paths:
if not path.exists():
continue
if common_ancestor is None:
common_ancestor = path
else:
if common_ancestor in path.parents or path == common_ancestor:
continue
elif path in common_ancestor.parents:
common_ancestor = path
else:
shared = commonpath(path, common_ancestor)
if shared is not None:
common_ancestor = shared
if common_ancestor is None:
common_ancestor = Path.cwd()
elif common_ancestor.is_file():
common_ancestor = common_ancestor.parent
return common_ancestor
def get_dirs_from_args(args: Iterable[str]) -> List[Path]:
def is_option(x: str) -> bool:
return x.startswith("-")
def get_file_part_from_node_id(x: str) -> str:
return x.split("::")[0]
def get_dir_from_path(path: Path) -> Path:
if path.is_dir():
return path
return path.parent
def safe_exists(path: Path) -> bool:
# This can throw on paths that contain characters unrepresentable at the OS level,
# or with invalid syntax on Windows (https://bugs.python.org/issue35306)
try:
return path.exists()
except OSError:
return False
# These look like paths but may not exist
possible_paths = (
absolutepath(get_file_part_from_node_id(arg))
for arg in args
if not is_option(arg)
)
return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
def determine_setup(
inifile: Optional[str],
args: Sequence[str],
rootdir_cmd_arg: Optional[str] = None,
config: Optional["Config"] = None,
) -> Tuple[Path, Optional[Path], Dict[str, Union[str, List[str]]]]:
rootdir = None
dirs = get_dirs_from_args(args)
if inifile:
inipath_ = absolutepath(inifile)
inipath: Optional[Path] = inipath_
inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None:
rootdir = inipath_.parent
else:
ancestor = get_common_ancestor(dirs)
rootdir, inipath, inicfg = locate_config([ancestor])
if rootdir is None and rootdir_cmd_arg is None:
for possible_rootdir in (ancestor, *ancestor.parents):
if (possible_rootdir / "setup.py").is_file():
rootdir = possible_rootdir
break
else:
if dirs != [ancestor]:
rootdir, inipath, inicfg = locate_config(dirs)
if rootdir is None:
if config is not None:
cwd = config.invocation_params.dir
else:
cwd = Path.cwd()
rootdir = get_common_ancestor([cwd, ancestor])
is_fs_root = os.path.splitdrive(str(rootdir))[1] == "/"
if is_fs_root:
rootdir = ancestor
if rootdir_cmd_arg:
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
if not rootdir.is_dir():
raise UsageError(
"Directory '{}' not found. Check your '--rootdir' option.".format(
rootdir
)
)
assert rootdir is not None
return rootdir, inipath, inicfg or {}

View File

@ -0,0 +1,391 @@
"""Interactive debugging with PDB, the Python Debugger."""
import argparse
import functools
import sys
import types
import unittest
from typing import Any
from typing import Callable
from typing import Generator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from _pytest import outcomes
from _pytest._code import ExceptionInfo
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.config.exceptions import UsageError
from _pytest.nodes import Node
from _pytest.reports import BaseReport
if TYPE_CHECKING:
from _pytest.capture import CaptureManager
from _pytest.runner import CallInfo
def _validate_usepdb_cls(value: str) -> Tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
modname, classname = value.split(":")
except ValueError as e:
raise argparse.ArgumentTypeError(
f"{value!r} is not in the format 'modname:classname'"
) from e
return (modname, classname)
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"--pdb",
dest="usepdb",
action="store_true",
help="Start the interactive Python debugger on errors or KeyboardInterrupt",
)
group._addoption(
"--pdbcls",
dest="usepdb_cls",
metavar="modulename:classname",
type=_validate_usepdb_cls,
help="Specify a custom interactive Python debugger for use with --pdb."
"For example: --pdbcls=IPython.terminal.debugger:TerminalPdb",
)
group._addoption(
"--trace",
dest="trace",
action="store_true",
help="Immediately break when running each test",
)
def pytest_configure(config: Config) -> None:
import pdb
if config.getvalue("trace"):
config.pluginmanager.register(PdbTrace(), "pdbtrace")
if config.getvalue("usepdb"):
config.pluginmanager.register(PdbInvoke(), "pdbinvoke")
pytestPDB._saved.append(
(pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config)
)
pdb.set_trace = pytestPDB.set_trace
pytestPDB._pluginmanager = config.pluginmanager
pytestPDB._config = config
# NOTE: not using pytest_unconfigure, since it might get called although
# pytest_configure was not (if another plugin raises UsageError).
def fin() -> None:
(
pdb.set_trace,
pytestPDB._pluginmanager,
pytestPDB._config,
) = pytestPDB._saved.pop()
config.add_cleanup(fin)
class pytestPDB:
"""Pseudo PDB that defers to the real pdb."""
_pluginmanager: Optional[PytestPluginManager] = None
_config: Optional[Config] = None
_saved: List[
Tuple[Callable[..., None], Optional[PytestPluginManager], Optional[Config]]
] = []
_recursive_debug = 0
_wrapped_pdb_cls: Optional[Tuple[Type[Any], Type[Any]]] = None
@classmethod
def _is_capturing(cls, capman: Optional["CaptureManager"]) -> Union[str, bool]:
if capman:
return capman.is_capturing()
return False
@classmethod
def _import_pdb_cls(cls, capman: Optional["CaptureManager"]):
if not cls._config:
import pdb
# Happens when using pytest.set_trace outside of a test.
return pdb.Pdb
usepdb_cls = cls._config.getvalue("usepdb_cls")
if cls._wrapped_pdb_cls and cls._wrapped_pdb_cls[0] == usepdb_cls:
return cls._wrapped_pdb_cls[1]
if usepdb_cls:
modname, classname = usepdb_cls
try:
__import__(modname)
mod = sys.modules[modname]
# Handle --pdbcls=pdb:pdb.Pdb (useful e.g. with pdbpp).
parts = classname.split(".")
pdb_cls = getattr(mod, parts[0])
for part in parts[1:]:
pdb_cls = getattr(pdb_cls, part)
except Exception as exc:
value = ":".join((modname, classname))
raise UsageError(
f"--pdbcls: could not import {value!r}: {exc}"
) from exc
else:
import pdb
pdb_cls = pdb.Pdb
wrapped_cls = cls._get_pdb_wrapper_class(pdb_cls, capman)
cls._wrapped_pdb_cls = (usepdb_cls, wrapped_cls)
return wrapped_cls
@classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman: Optional["CaptureManager"]):
import _pytest.config
# Type ignored because mypy doesn't support "dynamic"
# inheritance like this.
class PytestPdbWrapper(pdb_cls): # type: ignore[valid-type,misc]
_pytest_capman = capman
_continued = False
def do_debug(self, arg):
cls._recursive_debug += 1
ret = super().do_debug(arg)
cls._recursive_debug -= 1
return ret
def do_continue(self, arg):
ret = super().do_continue(arg)
if cls._recursive_debug == 0:
assert cls._config is not None
tw = _pytest.config.create_terminal_writer(cls._config)
tw.line()
capman = self._pytest_capman
capturing = pytestPDB._is_capturing(capman)
if capturing:
if capturing == "global":
tw.sep(">", "PDB continue (IO-capturing resumed)")
else:
tw.sep(
">",
"PDB continue (IO-capturing resumed for %s)"
% capturing,
)
assert capman is not None
capman.resume()
else:
tw.sep(">", "PDB continue")
assert cls._pluginmanager is not None
cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self)
self._continued = True
return ret
do_c = do_cont = do_continue
def do_quit(self, arg):
"""Raise Exit outcome when quit command is used in pdb.
This is a bit of a hack - it would be better if BdbQuit
could be handled, but this would require to wrap the
whole pytest run, and adjust the report etc.
"""
ret = super().do_quit(arg)
if cls._recursive_debug == 0:
outcomes.exit("Quitting debugger")
return ret
do_q = do_quit
do_exit = do_quit
def setup(self, f, tb):
"""Suspend on setup().
Needed after do_continue resumed, and entering another
breakpoint again.
"""
ret = super().setup(f, tb)
if not ret and self._continued:
# pdb.setup() returns True if the command wants to exit
# from the interaction: do not suspend capturing then.
if self._pytest_capman:
self._pytest_capman.suspend_global_capture(in_=True)
return ret
def get_stack(self, f, t):
stack, i = super().get_stack(f, t)
if f is None:
# Find last non-hidden frame.
i = max(0, len(stack) - 1)
while i and stack[i][0].f_locals.get("__tracebackhide__", False):
i -= 1
return stack, i
return PytestPdbWrapper
@classmethod
def _init_pdb(cls, method, *args, **kwargs):
"""Initialize PDB debugging, dropping any IO capturing."""
import _pytest.config
if cls._pluginmanager is None:
capman: Optional[CaptureManager] = None
else:
capman = cls._pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend(in_=True)
if cls._config:
tw = _pytest.config.create_terminal_writer(cls._config)
tw.line()
if cls._recursive_debug == 0:
# Handle header similar to pdb.set_trace in py37+.
header = kwargs.pop("header", None)
if header is not None:
tw.sep(">", header)
else:
capturing = cls._is_capturing(capman)
if capturing == "global":
tw.sep(">", f"PDB {method} (IO-capturing turned off)")
elif capturing:
tw.sep(
">",
"PDB %s (IO-capturing turned off for %s)"
% (method, capturing),
)
else:
tw.sep(">", f"PDB {method}")
_pdb = cls._import_pdb_cls(capman)(**kwargs)
if cls._pluginmanager:
cls._pluginmanager.hook.pytest_enter_pdb(config=cls._config, pdb=_pdb)
return _pdb
@classmethod
def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
frame = sys._getframe().f_back
_pdb = cls._init_pdb("set_trace", *args, **kwargs)
_pdb.set_trace(frame)
class PdbInvoke:
def pytest_exception_interact(
self, node: Node, call: "CallInfo[Any]", report: BaseReport
) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
sys.stdout.write(out)
sys.stdout.write(err)
assert call.excinfo is not None
if not isinstance(call.excinfo.value, unittest.SkipTest):
_enter_pdb(node, call.excinfo, report)
def pytest_internalerror(self, excinfo: ExceptionInfo[BaseException]) -> None:
tb = _postmortem_traceback(excinfo)
post_mortem(tb)
class PdbTrace:
@hookimpl(hookwrapper=True)
def pytest_pyfunc_call(self, pyfuncitem) -> Generator[None, None, None]:
wrap_pytest_function_for_tracing(pyfuncitem)
yield
def wrap_pytest_function_for_tracing(pyfuncitem):
"""Change the Python function object of the given Function item by a
wrapper which actually enters pdb before calling the python function
itself, effectively leaving the user in the pdb prompt in the first
statement of the function."""
_pdb = pytestPDB._init_pdb("runcall")
testfunction = pyfuncitem.obj
# we can't just return `partial(pdb.runcall, testfunction)` because (on
# python < 3.7.4) runcall's first param is `func`, which means we'd get
# an exception if one of the kwargs to testfunction was called `func`.
@functools.wraps(testfunction)
def wrapper(*args, **kwargs):
func = functools.partial(testfunction, *args, **kwargs)
_pdb.runcall(func)
pyfuncitem.obj = wrapper
def maybe_wrap_pytest_function_for_tracing(pyfuncitem):
"""Wrap the given pytestfunct item for tracing support if --trace was given in
the command line."""
if pyfuncitem.config.getvalue("trace"):
wrap_pytest_function_for_tracing(pyfuncitem)
def _enter_pdb(
node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport
) -> BaseReport:
# XXX we re-use the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles
# for not completely clear reasons.
tw = node.config.pluginmanager.getplugin("terminalreporter")._tw
tw.line()
showcapture = node.config.option.showcapture
for sectionname, content in (
("stdout", rep.capstdout),
("stderr", rep.capstderr),
("log", rep.caplog),
):
if showcapture in (sectionname, "all") and content:
tw.sep(">", "captured " + sectionname)
if content[-1:] == "\n":
content = content[:-1]
tw.line(content)
tw.sep(">", "traceback")
rep.toterminal(tw)
tw.sep(">", "entering PDB")
tb = _postmortem_traceback(excinfo)
rep._pdbshown = True # type: ignore[attr-defined]
post_mortem(tb)
return rep
def _postmortem_traceback(excinfo: ExceptionInfo[BaseException]) -> types.TracebackType:
from doctest import UnexpectedException
if isinstance(excinfo.value, UnexpectedException):
# A doctest.UnexpectedException is not useful for post_mortem.
# Use the underlying exception instead:
return excinfo.value.exc_info[2]
elif isinstance(excinfo.value, ConftestImportFailure):
# A config.ConftestImportFailure is not useful for post_mortem.
# Use the underlying exception instead:
return excinfo.value.excinfo[2]
else:
assert excinfo._excinfo is not None
return excinfo._excinfo[2]
def post_mortem(t: types.TracebackType) -> None:
p = pytestPDB._init_pdb("post_mortem")
p.reset()
p.interaction(None, t)
if p.quitting:
outcomes.exit("Quitting debugger")

View File

@ -0,0 +1,146 @@
"""Deprecation messages and bits of code used elsewhere in the codebase that
is planned to be removed in the next pytest release.
Keeping it in a central location makes it easy to track what is deprecated and should
be removed when the time comes.
All constants defined in this module should be either instances of
:class:`PytestWarning`, or :class:`UnformattedWarning`
in case of warnings which need to format their messages.
"""
from warnings import warn
from _pytest.warning_types import PytestDeprecationWarning
from _pytest.warning_types import PytestRemovedIn8Warning
from _pytest.warning_types import UnformattedWarning
# set of plugins which have been integrated into the core; we use this list to ignore
# them during registration to avoid conflicts
DEPRECATED_EXTERNAL_PLUGINS = {
"pytest_catchlog",
"pytest_capturelog",
"pytest_faulthandler",
}
NOSE_SUPPORT = UnformattedWarning(
PytestRemovedIn8Warning,
"Support for nose tests is deprecated and will be removed in a future release.\n"
"{nodeid} is using nose method: `{method}` ({stage})\n"
"See docs: https://docs.pytest.org/en/stable/deprecations.html#support-for-tests-written-for-nose",
)
NOSE_SUPPORT_METHOD = UnformattedWarning(
PytestRemovedIn8Warning,
"Support for nose tests is deprecated and will be removed in a future release.\n"
"{nodeid} is using nose-specific method: `{method}(self)`\n"
"To remove this warning, rename it to `{method}_method(self)`\n"
"See docs: https://docs.pytest.org/en/stable/deprecations.html#support-for-tests-written-for-nose",
)
# This can be* removed pytest 8, but it's harmless and common, so no rush to remove.
# * If you're in the future: "could have been".
YIELD_FIXTURE = PytestDeprecationWarning(
"@pytest.yield_fixture is deprecated.\n"
"Use @pytest.fixture instead; they are the same."
)
WARNING_CMDLINE_PREPARSE_HOOK = PytestRemovedIn8Warning(
"The pytest_cmdline_preparse hook is deprecated and will be removed in a future release. \n"
"Please use pytest_load_initial_conftests hook instead."
)
FSCOLLECTOR_GETHOOKPROXY_ISINITPATH = PytestRemovedIn8Warning(
"The gethookproxy() and isinitpath() methods of FSCollector and Package are deprecated; "
"use self.session.gethookproxy() and self.session.isinitpath() instead. "
)
STRICT_OPTION = PytestRemovedIn8Warning(
"The --strict option is deprecated, use --strict-markers instead."
)
# This deprecation is never really meant to be removed.
PRIVATE = PytestDeprecationWarning("A private pytest class or function was used.")
ARGUMENT_PERCENT_DEFAULT = PytestRemovedIn8Warning(
'pytest now uses argparse. "%default" should be changed to "%(default)s"',
)
ARGUMENT_TYPE_STR_CHOICE = UnformattedWarning(
PytestRemovedIn8Warning,
"`type` argument to addoption() is the string {typ!r}."
" For choices this is optional and can be omitted, "
" but when supplied should be a type (for example `str` or `int`)."
" (options: {names})",
)
ARGUMENT_TYPE_STR = UnformattedWarning(
PytestRemovedIn8Warning,
"`type` argument to addoption() is the string {typ!r}, "
" but when supplied should be a type (for example `str` or `int`)."
" (options: {names})",
)
HOOK_LEGACY_PATH_ARG = UnformattedWarning(
PytestRemovedIn8Warning,
"The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n"
"see https://docs.pytest.org/en/latest/deprecations.html"
"#py-path-local-arguments-for-hooks-replaced-with-pathlib-path",
)
NODE_CTOR_FSPATH_ARG = UnformattedWarning(
PytestRemovedIn8Warning,
"The (fspath: py.path.local) argument to {node_type_name} is deprecated. "
"Please use the (path: pathlib.Path) argument instead.\n"
"See https://docs.pytest.org/en/latest/deprecations.html"
"#fspath-argument-for-node-constructors-replaced-with-pathlib-path",
)
WARNS_NONE_ARG = PytestRemovedIn8Warning(
"Passing None has been deprecated.\n"
"See https://docs.pytest.org/en/latest/how-to/capture-warnings.html"
"#additional-use-cases-of-warnings-in-tests"
" for alternatives in common use cases."
)
KEYWORD_MSG_ARG = UnformattedWarning(
PytestRemovedIn8Warning,
"pytest.{func}(msg=...) is now deprecated, use pytest.{func}(reason=...) instead",
)
INSTANCE_COLLECTOR = PytestRemovedIn8Warning(
"The pytest.Instance collector type is deprecated and is no longer used. "
"See https://docs.pytest.org/en/latest/deprecations.html#the-pytest-instance-collector",
)
HOOK_LEGACY_MARKING = UnformattedWarning(
PytestDeprecationWarning,
"The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n"
"Please use the pytest.hook{type}({hook_opts}) decorator instead\n"
" to configure the hooks.\n"
" See https://docs.pytest.org/en/latest/deprecations.html"
"#configuring-hook-specs-impls-using-markers",
)
# You want to make some `__init__` or function "private".
#
# def my_private_function(some, args):
# ...
#
# Do this:
#
# def my_private_function(some, args, *, _ispytest: bool = False):
# check_ispytest(_ispytest)
# ...
#
# Change all internal/allowed calls to
#
# my_private_function(some, args, _ispytest=True)
#
# All other calls will get the default _ispytest=False and trigger
# the warning (possibly error in the future).
def check_ispytest(ispytest: bool) -> None:
if not ispytest:
warn(PRIVATE, stacklevel=3)

View File

@ -0,0 +1,752 @@
"""Discover and run doctests in modules and test files."""
import bdb
import inspect
import os
import platform
import sys
import traceback
import types
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from _pytest import outcomes
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprFileLocation
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import skip
from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import import_path
from _pytest.python import Module
from _pytest.python_api import approx
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
import doctest
DOCTEST_REPORT_CHOICE_NONE = "none"
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
DOCTEST_REPORT_CHOICE_NDIFF = "ndiff"
DOCTEST_REPORT_CHOICE_UDIFF = "udiff"
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = "only_first_failure"
DOCTEST_REPORT_CHOICES = (
DOCTEST_REPORT_CHOICE_NONE,
DOCTEST_REPORT_CHOICE_CDIFF,
DOCTEST_REPORT_CHOICE_NDIFF,
DOCTEST_REPORT_CHOICE_UDIFF,
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE,
)
# Lazy definition of runner class
RUNNER_CLASS = None
# Lazy definition of output checker class
CHECKER_CLASS: Optional[Type["doctest.OutputChecker"]] = None
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"doctest_optionflags",
"Option flags for doctests",
type="args",
default=["ELLIPSIS"],
)
parser.addini(
"doctest_encoding", "Encoding used for doctest files", default="utf-8"
)
group = parser.getgroup("collect")
group.addoption(
"--doctest-modules",
action="store_true",
default=False,
help="Run doctests in all .py modules",
dest="doctestmodules",
)
group.addoption(
"--doctest-report",
type=str.lower,
default="udiff",
help="Choose another output format for diffs on doctest failure",
choices=DOCTEST_REPORT_CHOICES,
dest="doctestreport",
)
group.addoption(
"--doctest-glob",
action="append",
default=[],
metavar="pat",
help="Doctests file matching pattern, default: test*.txt",
dest="doctestglob",
)
group.addoption(
"--doctest-ignore-import-errors",
action="store_true",
default=False,
help="Ignore doctest ImportErrors",
dest="doctest_ignore_import_errors",
)
group.addoption(
"--doctest-continue-on-failure",
action="store_true",
default=False,
help="For a given doctest, continue to run after the first failure",
dest="doctest_continue_on_failure",
)
def pytest_unconfigure() -> None:
global RUNNER_CLASS
RUNNER_CLASS = None
def pytest_collect_file(
file_path: Path,
parent: Collector,
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
config = parent.config
if file_path.suffix == ".py":
if config.option.doctestmodules and not any(
(_is_setup_py(file_path), _is_main_py(file_path))
):
mod: DoctestModule = DoctestModule.from_parent(parent, path=file_path)
return mod
elif _is_doctest(config, file_path, parent):
txt: DoctestTextfile = DoctestTextfile.from_parent(parent, path=file_path)
return txt
return None
def _is_setup_py(path: Path) -> bool:
if path.name != "setup.py":
return False
contents = path.read_bytes()
return b"setuptools" in contents or b"distutils" in contents
def _is_doctest(config: Config, path: Path, parent: Collector) -> bool:
if path.suffix in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
return any(fnmatch_ex(glob, path) for glob in globs)
def _is_main_py(path: Path) -> bool:
return path.name == "__main__.py"
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
) -> None:
self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None:
for reprlocation, lines in self.reprlocation_lines:
for line in lines:
tw.line(line)
reprlocation.toterminal(tw)
class MultipleDoctestFailures(Exception):
def __init__(self, failures: Sequence["doctest.DocTestFailure"]) -> None:
super().__init__()
self.failures = failures
def _init_runner_class() -> Type["doctest.DocTestRunner"]:
import doctest
class PytestDoctestRunner(doctest.DebugRunner):
"""Runner to collect failures.
Note that the out variable in this case is a list instead of a
stdout-like object.
"""
def __init__(
self,
checker: Optional["doctest.OutputChecker"] = None,
verbose: Optional[bool] = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
super().__init__(checker=checker, verbose=verbose, optionflags=optionflags)
self.continue_on_failure = continue_on_failure
def report_failure(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
def report_unexpected_exception(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
exc_info: Tuple[Type[BaseException], BaseException, types.TracebackType],
) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
outcomes.exit("Quitting debugger")
failure = doctest.UnexpectedException(test, example, exc_info)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
return PytestDoctestRunner
def _get_runner(
checker: Optional["doctest.OutputChecker"] = None,
verbose: Optional[bool] = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> "doctest.DocTestRunner":
# We need this in order to do a lazy import on doctest
global RUNNER_CLASS
if RUNNER_CLASS is None:
RUNNER_CLASS = _init_runner_class()
# Type ignored because the continue_on_failure argument is only defined on
# PytestDoctestRunner, which is lazily defined so can't be used as a type.
return RUNNER_CLASS( # type: ignore
checker=checker,
verbose=verbose,
optionflags=optionflags,
continue_on_failure=continue_on_failure,
)
class DoctestItem(Item):
def __init__(
self,
name: str,
parent: "Union[DoctestTextfile, DoctestModule]",
runner: Optional["doctest.DocTestRunner"] = None,
dtest: Optional["doctest.DocTest"] = None,
) -> None:
super().__init__(name, parent)
self.runner = runner
self.dtest = dtest
self.obj = None
self.fixture_request: Optional[FixtureRequest] = None
@classmethod
def from_parent( # type: ignore
cls,
parent: "Union[DoctestTextfile, DoctestModule]",
*,
name: str,
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest",
):
# incompatible signature due to imposed limits on subclass
"""The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def setup(self) -> None:
if self.dtest is not None:
self.fixture_request = _setup_fixtures(self)
globs = dict(getfixture=self.fixture_request.getfixturevalue)
for name, value in self.fixture_request.getfixturevalue(
"doctest_namespace"
).items():
globs[name] = value
self.dtest.globs.update(globs)
def runtest(self) -> None:
assert self.dtest is not None
assert self.runner is not None
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
failures: List["doctest.DocTestFailure"] = []
# Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
if failures:
raise MultipleDoctestFailures(failures)
def _disable_output_capturing_for_darwin(self) -> None:
"""Disable output capturing. Otherwise, stdout is lost to doctest (#985)."""
if platform.system() != "Darwin":
return
capman = self.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
sys.stdout.write(out)
sys.stderr.write(err)
# TODO: Type ignored -- breaks Liskov Substitution.
def repr_failure( # type: ignore[override]
self,
excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]:
import doctest
failures: Optional[
Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]
] = None
if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
):
failures = [excinfo.value]
elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures
if failures is None:
return super().repr_failure(excinfo)
reprlocation_lines = []
for failure in failures:
example = failure.example
test = failure.test
filename = test.filename
if test.lineno is None:
lineno = None
else:
lineno = test.lineno + example.lineno + 1
message = type(failure).__name__
# TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type]
checker = _get_checker()
report_choice = _get_report_choice(self.config.getoption("doctestreport"))
if lineno is not None:
assert failure.test.docstring is not None
lines = failure.test.docstring.splitlines(False)
# add line numbers to the left of the error message
assert test.lineno is not None
lines = [
"%03d %s" % (i + test.lineno + 1, x) for (i, x) in enumerate(lines)
]
# trim docstring error lines to 10
lines = lines[max(example.lineno - 9, 0) : example.lineno + 1]
else:
lines = [
"EXAMPLE LOCATION UNKNOWN, not showing all tests of that example"
]
indent = ">>>"
for line in example.source.splitlines():
lines.append(f"??? {indent} {line}")
indent = "..."
if isinstance(failure, doctest.DocTestFailure):
lines += checker.output_difference(
example, failure.got, report_choice
).split("\n")
else:
inner_excinfo = ExceptionInfo.from_exc_info(failure.exc_info)
lines += ["UNEXPECTED EXCEPTION: %s" % repr(inner_excinfo.value)]
lines += [
x.strip("\n") for x in traceback.format_exception(*failure.exc_info)
]
reprlocation_lines.append((reprlocation, lines))
return ReprFailDoctest(reprlocation_lines)
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]:
assert self.dtest is not None
return self.path, self.dtest.lineno, "[doctest] %s" % self.name
def _get_flag_lookup() -> Dict[str, int]:
import doctest
return dict(
DONT_ACCEPT_TRUE_FOR_1=doctest.DONT_ACCEPT_TRUE_FOR_1,
DONT_ACCEPT_BLANKLINE=doctest.DONT_ACCEPT_BLANKLINE,
NORMALIZE_WHITESPACE=doctest.NORMALIZE_WHITESPACE,
ELLIPSIS=doctest.ELLIPSIS,
IGNORE_EXCEPTION_DETAIL=doctest.IGNORE_EXCEPTION_DETAIL,
COMPARISON_FLAGS=doctest.COMPARISON_FLAGS,
ALLOW_UNICODE=_get_allow_unicode_flag(),
ALLOW_BYTES=_get_allow_bytes_flag(),
NUMBER=_get_number_flag(),
)
def get_optionflags(parent):
optionflags_str = parent.config.getini("doctest_optionflags")
flag_lookup_table = _get_flag_lookup()
flag_acc = 0
for flag in optionflags_str:
flag_acc |= flag_lookup_table[flag]
return flag_acc
def _get_continue_on_failure(config):
continue_on_failure = config.getvalue("doctest_continue_on_failure")
if continue_on_failure:
# We need to turn off this if we use pdb since we should stop at
# the first failure.
if config.getvalue("usepdb"):
continue_on_failure = False
return continue_on_failure
class DoctestTextfile(Module):
obj = None
def collect(self) -> Iterable[DoctestItem]:
import doctest
# Inspired by doctest.testfile; ideally we would use it directly,
# but it doesn't support passing a custom checker.
encoding = self.config.getini("doctest_encoding")
text = self.path.read_text(encoding)
filename = str(self.path)
name = self.path.name
globs = {"__name__": "__main__"}
optionflags = get_optionflags(self)
runner = _get_runner(
verbose=False,
optionflags=optionflags,
checker=_get_checker(),
continue_on_failure=_get_continue_on_failure(self.config),
)
parser = doctest.DocTestParser()
test = parser.get_doctest(text, globs, name, filename, 0)
if test.examples:
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
)
def _check_all_skipped(test: "doctest.DocTest") -> None:
"""Raise pytest.skip() if all examples in the given DocTest have the SKIP
option set."""
import doctest
all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples)
if all_skipped:
skip("all tests skipped by +SKIP option")
def _is_mocked(obj: object) -> bool:
"""Return if an object is possibly a mock object by checking the
existence of a highly improbable attribute."""
return (
safe_getattr(obj, "pytest_mock_example_attribute_that_shouldnt_exist", None)
is not None
)
@contextmanager
def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
"""Context manager which replaces ``inspect.unwrap`` with a version
that's aware of mock objects and doesn't recurse into them."""
real_unwrap = inspect.unwrap
def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
) -> Any:
try:
if stop is None or stop is _is_mocked:
return real_unwrap(func, stop=_is_mocked)
_stop = stop
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
"Got %r when unwrapping %r. This is usually caused "
"by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080" % (e, func),
PytestWarning,
)
raise
inspect.unwrap = _mock_aware_unwrap
try:
yield
finally:
inspect.unwrap = real_unwrap
class DoctestModule(Module):
def collect(self) -> Iterable[DoctestItem]:
import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder):
"""A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.
https://github.com/pytest-dev/pytest/issues/3456
https://bugs.python.org/issue25532
"""
def _find_lineno(self, obj, source_lines):
"""Doctest code does not take into account `@property`, this
is a hackish way to fix it. https://bugs.python.org/issue17446
Wrapped Doctests will need to be unwrapped so the correct
line number is returned. This will be reported upstream. #8796
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
if hasattr(obj, "__wrapped__"):
# Get the main obj in case of it being wrapped
obj = inspect.unwrap(obj)
# Type ignored because this is a private function.
return super()._find_lineno( # type:ignore[misc]
obj,
source_lines,
)
def _find(
self, tests, obj, name, module, source_lines, globs, seen
) -> None:
if _is_mocked(obj):
return
with _patch_unwrap_mock_aware():
# Type ignored because this is a private function.
super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen
)
if self.path.name == "conftest.py":
module = self.config.pluginmanager._importconftest(
self.path,
self.config.getoption("importmode"),
rootpath=self.config.rootpath,
)
else:
try:
module = import_path(
self.path,
root=self.config.rootpath,
mode=self.config.getoption("importmode"),
)
except ImportError:
if self.config.getvalue("doctest_ignore_import_errors"):
skip("unable to import module %r" % self.path)
else:
raise
# Uses internal doctest module parsing mechanism.
finder = MockAwareDocTestFinder()
optionflags = get_optionflags(self)
runner = _get_runner(
verbose=False,
optionflags=optionflags,
checker=_get_checker(),
continue_on_failure=_get_continue_on_failure(self.config),
)
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
)
def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest:
"""Used by DoctestTextfile and DoctestItem to setup fixture information."""
def func() -> None:
pass
doctest_item.funcargs = {} # type: ignore[attr-defined]
fm = doctest_item.session._fixturemanager
doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined]
node=doctest_item, func=func, cls=None, funcargs=False
)
fixture_request = FixtureRequest(doctest_item, _ispytest=True)
fixture_request._fillfixtures()
return fixture_request
def _init_checker_class() -> Type["doctest.OutputChecker"]:
import doctest
import re
class LiteralsOutputChecker(doctest.OutputChecker):
# Based on doctest_nose_plugin.py from the nltk project
# (https://github.com/nltk/nltk) and on the "numtest" doctest extension
# by Sebastien Boisgerault (https://github.com/boisgera/numtest).
_unicode_literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)
_bytes_literal_re = re.compile(r"(\W|^)[bB]([rR]?[\'\"])", re.UNICODE)
_number_re = re.compile(
r"""
(?P<number>
(?P<mantissa>
(?P<integer1> [+-]?\d*)\.(?P<fraction>\d+)
|
(?P<integer2> [+-]?\d+)\.
)
(?:
[Ee]
(?P<exponent1> [+-]?\d+)
)?
|
(?P<integer3> [+-]?\d+)
(?:
[Ee]
(?P<exponent2> [+-]?\d+)
)
)
""",
re.VERBOSE,
)
def check_output(self, want: str, got: str, optionflags: int) -> bool:
if super().check_output(want, got, optionflags):
return True
allow_unicode = optionflags & _get_allow_unicode_flag()
allow_bytes = optionflags & _get_allow_bytes_flag()
allow_number = optionflags & _get_number_flag()
if not allow_unicode and not allow_bytes and not allow_number:
return False
def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
if allow_unicode:
want = remove_prefixes(self._unicode_literal_re, want)
got = remove_prefixes(self._unicode_literal_re, got)
if allow_bytes:
want = remove_prefixes(self._bytes_literal_re, want)
got = remove_prefixes(self._bytes_literal_re, got)
if allow_number:
got = self._remove_unwanted_precision(want, got)
return super().check_output(want, got, optionflags)
def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got))
if len(wants) != len(gots):
return got
offset = 0
for w, g in zip(wants, gots):
fraction: Optional[str] = w.group("fraction")
exponent: Optional[str] = w.group("exponent1")
if exponent is None:
exponent = w.group("exponent2")
precision = 0 if fraction is None else len(fraction)
if exponent is not None:
precision -= int(exponent)
if float(w.group()) == approx(float(g.group()), abs=10**-precision):
# They're close enough. Replace the text we actually
# got with the text we want, so that it will match when we
# check the string literally.
got = (
got[: g.start() + offset] + w.group() + got[g.end() + offset :]
)
offset += w.end() - w.start() - (g.end() - g.start())
return got
return LiteralsOutputChecker
def _get_checker() -> "doctest.OutputChecker":
"""Return a doctest.OutputChecker subclass that supports some
additional options:
* ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b''
prefixes (respectively) in string literals. Useful when the same
doctest should run in Python 2 and Python 3.
* NUMBER to ignore floating-point differences smaller than the
precision of the literal number in the doctest.
An inner class is used to avoid importing "doctest" at the module
level.
"""
global CHECKER_CLASS
if CHECKER_CLASS is None:
CHECKER_CLASS = _init_checker_class()
return CHECKER_CLASS()
def _get_allow_unicode_flag() -> int:
"""Register and return the ALLOW_UNICODE flag."""
import doctest
return doctest.register_optionflag("ALLOW_UNICODE")
def _get_allow_bytes_flag() -> int:
"""Register and return the ALLOW_BYTES flag."""
import doctest
return doctest.register_optionflag("ALLOW_BYTES")
def _get_number_flag() -> int:
"""Register and return the NUMBER flag."""
import doctest
return doctest.register_optionflag("NUMBER")
def _get_report_choice(key: str) -> int:
"""Return the actual `doctest` module flag value.
We want to do it as late as possible to avoid importing `doctest` and all
its dependencies when parsing options, as it adds overhead and breaks tests.
"""
import doctest
return {
DOCTEST_REPORT_CHOICE_UDIFF: doctest.REPORT_UDIFF,
DOCTEST_REPORT_CHOICE_CDIFF: doctest.REPORT_CDIFF,
DOCTEST_REPORT_CHOICE_NDIFF: doctest.REPORT_NDIFF,
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE: doctest.REPORT_ONLY_FIRST_FAILURE,
DOCTEST_REPORT_CHOICE_NONE: 0,
}[key]
@fixture(scope="session")
def doctest_namespace() -> Dict[str, Any]:
"""Fixture that returns a :py:class:`dict` that will be injected into the
namespace of doctests.
Usually this fixture is used in conjunction with another ``autouse`` fixture:
.. code-block:: python
@pytest.fixture(autouse=True)
def add_np(doctest_namespace):
doctest_namespace["np"] = numpy
For more details: :ref:`doctest_namespace`.
"""
return dict()

View File

@ -0,0 +1,97 @@
import io
import os
import sys
from typing import Generator
from typing import TextIO
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
from _pytest.stash import StashKey
fault_handler_stderr_key = StashKey[TextIO]()
fault_handler_originally_enabled_key = StashKey[bool]()
def pytest_addoption(parser: Parser) -> None:
help = (
"Dump the traceback of all threads if a test takes "
"more than TIMEOUT seconds to finish"
)
parser.addini("faulthandler_timeout", help, default=0.0)
def pytest_configure(config: Config) -> None:
import faulthandler
stderr_fd_copy = os.dup(get_stderr_fileno())
config.stash[fault_handler_stderr_key] = open(stderr_fd_copy, "w")
config.stash[fault_handler_originally_enabled_key] = faulthandler.is_enabled()
faulthandler.enable(file=config.stash[fault_handler_stderr_key])
def pytest_unconfigure(config: Config) -> None:
import faulthandler
faulthandler.disable()
# Close the dup file installed during pytest_configure.
if fault_handler_stderr_key in config.stash:
config.stash[fault_handler_stderr_key].close()
del config.stash[fault_handler_stderr_key]
if config.stash.get(fault_handler_originally_enabled_key, False):
# Re-enable the faulthandler if it was originally enabled.
faulthandler.enable(file=get_stderr_fileno())
def get_stderr_fileno() -> int:
try:
fileno = sys.stderr.fileno()
# The Twisted Logger will return an invalid file descriptor since it is not backed
# by an FD. So, let's also forward this to the same code path as with pytest-xdist.
if fileno == -1:
raise AttributeError()
return fileno
except (AttributeError, io.UnsupportedOperation):
# pytest-xdist monkeypatches sys.stderr with an object that is not an actual file.
# https://docs.python.org/3/library/faulthandler.html#issue-with-file-descriptors
# This is potentially dangerous, but the best we can do.
return sys.__stderr__.fileno()
def get_timeout_config_value(config: Config) -> float:
return float(config.getini("faulthandler_timeout") or 0.0)
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
timeout = get_timeout_config_value(item.config)
stderr = item.config.stash[fault_handler_stderr_key]
if timeout > 0 and stderr is not None:
import faulthandler
faulthandler.dump_traceback_later(timeout, file=stderr)
try:
yield
finally:
faulthandler.cancel_dump_traceback_later()
else:
yield
@pytest.hookimpl(tryfirst=True)
def pytest_enter_pdb() -> None:
"""Cancel any traceback dumping due to timeout before entering pdb."""
import faulthandler
faulthandler.cancel_dump_traceback_later()
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact() -> None:
"""Cancel any traceback dumping due to an interactive exception being
raised."""
import faulthandler
faulthandler.cancel_dump_traceback_later()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,44 @@
"""Provides a function to report all internal modules for using freezing
tools."""
import types
from typing import Iterator
from typing import List
from typing import Union
def freeze_includes() -> List[str]:
"""Return a list of module names used by pytest that should be
included by cx_freeze."""
import _pytest
result = list(_iter_all_modules(_pytest))
return result
def _iter_all_modules(
package: Union[str, types.ModuleType],
prefix: str = "",
) -> Iterator[str]:
"""Iterate over the names of all modules that can be found in the given
package, recursively.
>>> import _pytest
>>> list(_iter_all_modules(_pytest))
['_pytest._argcomplete', '_pytest._code.code', ...]
"""
import os
import pkgutil
if isinstance(package, str):
path = package
else:
# Type ignored because typeshed doesn't define ModuleType.__path__
# (only defined on packages).
package_path = package.__path__ # type: ignore[attr-defined]
path, prefix = package_path[0], package.__name__ + "."
for _, name, is_package in pkgutil.iter_modules([path]):
if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."):
yield prefix + m
else:
yield prefix + name

View File

@ -0,0 +1,265 @@
"""Version info, help messages, tracing configuration."""
import os
import sys
from argparse import Action
from typing import List
from typing import Optional
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PrintHelp
from _pytest.config.argparsing import Parser
class HelpAction(Action):
"""An argparse Action that will raise an exception in order to skip the
rest of the argument parsing when --help is passed.
This prevents argparse from quitting due to missing required arguments
when any are defined, for example by ``pytest_addoption``.
This is similar to the way that the builtin argparse --help option is
implemented by raising SystemExit.
"""
def __init__(self, option_strings, dest=None, default=False, help=None):
super().__init__(
option_strings=option_strings,
dest=dest,
const=True,
default=default,
nargs=0,
help=help,
)
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, self.const)
# We should only skip the rest of the parsing after preparse is done.
if getattr(parser._parser, "after_preparse", False):
raise PrintHelp
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--version",
"-V",
action="count",
default=0,
dest="version",
help="Display pytest version and information about plugins. "
"When given twice, also display information about plugins.",
)
group._addoption(
"-h",
"--help",
action=HelpAction,
dest="help",
help="Show help message and configuration info",
)
group._addoption(
"-p",
action="append",
dest="plugins",
default=[],
metavar="name",
help="Early-load given plugin module name or entry point (multi-allowed). "
"To avoid loading of plugins, use the `no:` prefix, e.g. "
"`no:doctest`.",
)
group.addoption(
"--traceconfig",
"--trace-config",
action="store_true",
default=False,
help="Trace considerations of conftest.py files",
)
group.addoption(
"--debug",
action="store",
nargs="?",
const="pytestdebug.log",
dest="debug",
metavar="DEBUG_FILE_NAME",
help="Store internal tracing debug information in this log file. "
"This file is opened with 'w' and truncated as a result, care advised. "
"Default: pytestdebug.log.",
)
group._addoption(
"-o",
"--override-ini",
dest="override_ini",
action="append",
help='Override ini option with "option=value" style, '
"e.g. `-o xfail_strict=True -o cache_dir=cache`.",
)
@pytest.hookimpl(hookwrapper=True)
def pytest_cmdline_parse():
outcome = yield
config: Config = outcome.get_result()
if config.option.debug:
# --debug | --debug <file.log> was provided.
path = config.option.debug
debugfile = open(path, "w")
debugfile.write(
"versions pytest-%s, "
"python-%s\ncwd=%s\nargs=%s\n\n"
% (
pytest.__version__,
".".join(map(str, sys.version_info)),
os.getcwd(),
config.invocation_params.args,
)
)
config.trace.root.setwriter(debugfile.write)
undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write("writing pytest debug information to %s\n" % path)
def unset_tracing() -> None:
debugfile.close()
sys.stderr.write("wrote pytest debug information to %s\n" % debugfile.name)
config.trace.root.setwriter(None)
undo_tracing()
config.add_cleanup(unset_tracing)
def showversion(config: Config) -> None:
if config.option.version > 1:
sys.stdout.write(
"This is pytest version {}, imported from {}\n".format(
pytest.__version__, pytest.__file__
)
)
plugininfo = getpluginversioninfo(config)
if plugininfo:
for line in plugininfo:
sys.stdout.write(line + "\n")
else:
sys.stdout.write(f"pytest {pytest.__version__}\n")
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.version > 0:
showversion(config)
return 0
elif config.option.help:
config._do_configure()
showhelp(config)
config._ensure_unconfigure()
return 0
return None
def showhelp(config: Config) -> None:
import textwrap
reporter = config.pluginmanager.get_plugin("terminalreporter")
tw = reporter._tw
tw.write(config._parser.optparser.format_help())
tw.line()
tw.line(
"[pytest] ini-options in the first pytest.ini|tox.ini|setup.cfg file found:"
)
tw.line()
columns = tw.fullwidth # costly call
indent_len = 24 # based on argparse's max_help_position=24
indent = " " * indent_len
for name in config._parser._ininames:
help, type, default = config._parser._inidict[name]
if type is None:
type = "string"
if help is None:
raise TypeError(f"help argument cannot be None for {name}")
spec = f"{name} ({type}):"
tw.write(" %s" % spec)
spec_len = len(spec)
if spec_len > (indent_len - 3):
# Display help starting at a new line.
tw.line()
helplines = textwrap.wrap(
help,
columns,
initial_indent=indent,
subsequent_indent=indent,
break_on_hyphens=False,
)
for line in helplines:
tw.line(line)
else:
# Display help starting after the spec, following lines indented.
tw.write(" " * (indent_len - spec_len - 2))
wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False)
if wrapped:
tw.line(wrapped[0])
for line in wrapped[1:]:
tw.line(indent + line)
tw.line()
tw.line("Environment variables:")
vars = [
("PYTEST_ADDOPTS", "Extra command line options"),
("PYTEST_PLUGINS", "Comma-separated plugins to load during startup"),
("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "Set to disable plugin auto-loading"),
("PYTEST_DEBUG", "Set to enable debug tracing of pytest's internals"),
]
for name, help in vars:
tw.line(f" {name:<24} {help}")
tw.line()
tw.line()
tw.line("to see available markers type: pytest --markers")
tw.line("to see available fixtures type: pytest --fixtures")
tw.line(
"(shown according to specified file_or_dir or current dir "
"if not specified; fixtures with leading '_' are only shown "
"with the '-v' option"
)
for warningreport in reporter.stats.get("warnings", []):
tw.line("warning : " + warningreport.message, red=True)
return
conftest_options = [("pytest_plugins", "list of plugin names to load")]
def getpluginversioninfo(config: Config) -> List[str]:
lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo:
lines.append("setuptools registered plugins:")
for plugin, dist in plugininfo:
loc = getattr(plugin, "__file__", repr(plugin))
content = f"{dist.project_name}-{dist.version} at {loc}"
lines.append(" " + content)
return lines
def pytest_report_header(config: Config) -> List[str]:
lines = []
if config.option.debug or config.option.traceconfig:
lines.append(f"using: pytest-{pytest.__version__}")
verinfo = getpluginversioninfo(config)
if verinfo:
lines.extend(verinfo)
if config.option.traceconfig:
lines.append("active plugins:")
items = config.pluginmanager.list_name_plugin()
for name, plugin in items:
if hasattr(plugin, "__file__"):
r = plugin.__file__
else:
r = repr(plugin)
lines.append(f" {name:<20}: {r}")
return lines

View File

@ -0,0 +1,972 @@
"""Hook specifications for pytest plugins which are invoked by pytest itself
and by builtin plugins."""
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from pluggy import HookspecMarker
from _pytest.deprecated import WARNING_CMDLINE_PREPARSE_HOOK
if TYPE_CHECKING:
import pdb
import warnings
from typing_extensions import Literal
from _pytest._code.code import ExceptionRepr
from _pytest.code import ExceptionInfo
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PytestPluginManager
from _pytest.config import _PluggyPlugin
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
from _pytest.main import Session
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import Exit
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Metafunc
from _pytest.python import Module
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from _pytest.terminal import TerminalReporter
from _pytest.compat import LEGACY_PATH
hookspec = HookspecMarker("pytest")
# -------------------------------------------------------------------------
# Initialization hooks called for every plugin
# -------------------------------------------------------------------------
@hookspec(historic=True)
def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
"""Called at plugin registration time to allow adding new hooks via a call to
``pluginmanager.add_hookspecs(module_or_class, prefix)``.
:param pytest.PytestPluginManager pluginmanager: The pytest plugin manager.
.. note::
This hook is incompatible with ``hookwrapper=True``.
"""
@hookspec(historic=True)
def pytest_plugin_registered(
plugin: "_PluggyPlugin", manager: "PytestPluginManager"
) -> None:
"""A new pytest plugin got registered.
:param plugin: The plugin module or instance.
:param pytest.PytestPluginManager manager: pytest plugin manager.
.. note::
This hook is incompatible with ``hookwrapper=True``.
"""
@hookspec(historic=True)
def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None:
"""Register argparse-style options and ini-style config values,
called once at the beginning of a test run.
.. note::
This function should be implemented only in plugins or ``conftest.py``
files situated at the tests root directory due to how pytest
:ref:`discovers plugins during startup <pluginorder>`.
:param pytest.Parser parser:
To add command line options, call
:py:func:`parser.addoption(...) <pytest.Parser.addoption>`.
To add ini-file values call :py:func:`parser.addini(...)
<pytest.Parser.addini>`.
:param pytest.PytestPluginManager pluginmanager:
The pytest plugin manager, which can be used to install :py:func:`hookspec`'s
or :py:func:`hookimpl`'s and allow one plugin to call another plugin's hooks
to change how command line options are added.
Options can later be accessed through the
:py:class:`config <pytest.Config>` object, respectively:
- :py:func:`config.getoption(name) <pytest.Config.getoption>` to
retrieve the value of a command line option.
- :py:func:`config.getini(name) <pytest.Config.getini>` to retrieve
a value read from an ini-style file.
The config object is passed around on many internal objects via the ``.config``
attribute or can be retrieved as the ``pytestconfig`` fixture.
.. note::
This hook is incompatible with ``hookwrapper=True``.
"""
@hookspec(historic=True)
def pytest_configure(config: "Config") -> None:
"""Allow plugins and conftest files to perform initial configuration.
This hook is called for every plugin and initial conftest file
after command line options have been parsed.
After that, the hook is called for other conftest files as they are
imported.
.. note::
This hook is incompatible with ``hookwrapper=True``.
:param pytest.Config config: The pytest config object.
"""
# -------------------------------------------------------------------------
# Bootstrapping hooks called for plugins registered early enough:
# internal and 3rd party plugins.
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_cmdline_parse(
pluginmanager: "PytestPluginManager", args: List[str]
) -> Optional["Config"]:
"""Return an initialized :class:`~pytest.Config`, parsing the specified args.
Stops at first non-None result, see :ref:`firstresult`.
.. note::
This hook will only be called for plugin classes passed to the
``plugins`` arg when using `pytest.main`_ to perform an in-process
test run.
:param pluginmanager: The pytest plugin manager.
:param args: List of arguments passed on the command line.
:returns: A pytest config object.
"""
@hookspec(warn_on_impl=WARNING_CMDLINE_PREPARSE_HOOK)
def pytest_cmdline_preparse(config: "Config", args: List[str]) -> None:
"""(**Deprecated**) modify command line arguments before option parsing.
This hook is considered deprecated and will be removed in a future pytest version. Consider
using :hook:`pytest_load_initial_conftests` instead.
.. note::
This hook will not be called for ``conftest.py`` files, only for setuptools plugins.
:param config: The pytest config object.
:param args: Arguments passed on the command line.
"""
@hookspec(firstresult=True)
def pytest_cmdline_main(config: "Config") -> Optional[Union["ExitCode", int]]:
"""Called for performing the main command line action. The default
implementation will invoke the configure hooks and runtest_mainloop.
Stops at first non-None result, see :ref:`firstresult`.
:param config: The pytest config object.
:returns: The exit code.
"""
def pytest_load_initial_conftests(
early_config: "Config", parser: "Parser", args: List[str]
) -> None:
"""Called to implement the loading of initial conftest files ahead
of command line option parsing.
.. note::
This hook will not be called for ``conftest.py`` files, only for setuptools plugins.
:param early_config: The pytest config object.
:param args: Arguments passed on the command line.
:param parser: To add command line options.
"""
# -------------------------------------------------------------------------
# collection hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_collection(session: "Session") -> Optional[object]:
"""Perform the collection phase for the given session.
Stops at first non-None result, see :ref:`firstresult`.
The return value is not used, but only stops further processing.
The default collection phase is this (see individual hooks for full details):
1. Starting from ``session`` as the initial collector:
1. ``pytest_collectstart(collector)``
2. ``report = pytest_make_collect_report(collector)``
3. ``pytest_exception_interact(collector, call, report)`` if an interactive exception occurred
4. For each collected node:
1. If an item, ``pytest_itemcollected(item)``
2. If a collector, recurse into it.
5. ``pytest_collectreport(report)``
2. ``pytest_collection_modifyitems(session, config, items)``
1. ``pytest_deselected(items)`` for any deselected items (may be called multiple times)
3. ``pytest_collection_finish(session)``
4. Set ``session.items`` to the list of collected items
5. Set ``session.testscollected`` to the number of collected items
You can implement this hook to only perform some action before collection,
for example the terminal plugin uses it to start displaying the collection
counter (and returns `None`).
:param session: The pytest session object.
"""
def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"]
) -> None:
"""Called after collection has been performed. May filter or re-order
the items in-place.
:param session: The pytest session object.
:param config: The pytest config object.
:param items: List of item objects.
"""
def pytest_collection_finish(session: "Session") -> None:
"""Called after collection has been performed and modified.
:param session: The pytest session object.
"""
@hookspec(firstresult=True)
def pytest_ignore_collect(
collection_path: Path, path: "LEGACY_PATH", config: "Config"
) -> Optional[bool]:
"""Return True to prevent considering this path for collection.
This hook is consulted for all files and directories prior to calling
more specific hooks.
Stops at first non-None result, see :ref:`firstresult`.
:param collection_path: The path to analyze.
:param path: The path to analyze (deprecated).
:param config: The pytest config object.
.. versionchanged:: 7.0.0
The ``collection_path`` parameter was added as a :class:`pathlib.Path`
equivalent of the ``path`` parameter. The ``path`` parameter
has been deprecated.
"""
def pytest_collect_file(
file_path: Path, path: "LEGACY_PATH", parent: "Collector"
) -> "Optional[Collector]":
"""Create a :class:`~pytest.Collector` for the given path, or None if not relevant.
The new node needs to have the specified ``parent`` as a parent.
:param file_path: The path to analyze.
:param path: The path to collect (deprecated).
.. versionchanged:: 7.0.0
The ``file_path`` parameter was added as a :class:`pathlib.Path`
equivalent of the ``path`` parameter. The ``path`` parameter
has been deprecated.
"""
# logging hooks for collection
def pytest_collectstart(collector: "Collector") -> None:
"""Collector starts collecting.
:param collector:
The collector.
"""
def pytest_itemcollected(item: "Item") -> None:
"""We just collected a test item.
:param item:
The item.
"""
def pytest_collectreport(report: "CollectReport") -> None:
"""Collector finished collecting.
:param report:
The collect report.
"""
def pytest_deselected(items: Sequence["Item"]) -> None:
"""Called for deselected test items, e.g. by keyword.
May be called multiple times.
:param items:
The items.
"""
@hookspec(firstresult=True)
def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]":
"""Perform :func:`collector.collect() <pytest.Collector.collect>` and return
a :class:`~pytest.CollectReport`.
Stops at first non-None result, see :ref:`firstresult`.
:param collector:
The collector.
"""
# -------------------------------------------------------------------------
# Python test function related hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_pycollect_makemodule(
module_path: Path, path: "LEGACY_PATH", parent
) -> Optional["Module"]:
"""Return a :class:`pytest.Module` collector or None for the given path.
This hook will be called for each matching test module path.
The :hook:`pytest_collect_file` hook needs to be used if you want to
create test modules for files that do not match as a test module.
Stops at first non-None result, see :ref:`firstresult`.
:param module_path: The path of the module to collect.
:param path: The path of the module to collect (deprecated).
.. versionchanged:: 7.0.0
The ``module_path`` parameter was added as a :class:`pathlib.Path`
equivalent of the ``path`` parameter.
The ``path`` parameter has been deprecated in favor of ``fspath``.
"""
@hookspec(firstresult=True)
def pytest_pycollect_makeitem(
collector: Union["Module", "Class"], name: str, obj: object
) -> Union[None, "Item", "Collector", List[Union["Item", "Collector"]]]:
"""Return a custom item/collector for a Python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult`.
:param collector:
The module/class collector.
:param name:
The name of the object in the module/class.
:param obj:
The object.
:returns:
The created items/collectors.
"""
@hookspec(firstresult=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
"""Call underlying test function.
Stops at first non-None result, see :ref:`firstresult`.
:param pyfuncitem:
The function item.
"""
def pytest_generate_tests(metafunc: "Metafunc") -> None:
"""Generate (multiple) parametrized calls to a test function.
:param metafunc:
The :class:`~pytest.Metafunc` helper for the test function.
"""
@hookspec(firstresult=True)
def pytest_make_parametrize_id(
config: "Config", val: object, argname: str
) -> Optional[str]:
"""Return a user-friendly string representation of the given ``val``
that will be used by @pytest.mark.parametrize calls, or None if the hook
doesn't know about ``val``.
The parameter name is available as ``argname``, if required.
Stops at first non-None result, see :ref:`firstresult`.
:param config: The pytest config object.
:param val: The parametrized value.
:param str argname: The automatic parameter name produced by pytest.
"""
# -------------------------------------------------------------------------
# runtest related hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_runtestloop(session: "Session") -> Optional[object]:
"""Perform the main runtest loop (after collection finished).
The default hook implementation performs the runtest protocol for all items
collected in the session (``session.items``), unless the collection failed
or the ``collectonly`` pytest option is set.
If at any point :py:func:`pytest.exit` is called, the loop is
terminated immediately.
If at any point ``session.shouldfail`` or ``session.shouldstop`` are set, the
loop is terminated after the runtest protocol for the current item is finished.
:param session: The pytest session object.
Stops at first non-None result, see :ref:`firstresult`.
The return value is not used, but only stops further processing.
"""
@hookspec(firstresult=True)
def pytest_runtest_protocol(
item: "Item", nextitem: "Optional[Item]"
) -> Optional[object]:
"""Perform the runtest protocol for a single test item.
The default runtest protocol is this (see individual hooks for full details):
- ``pytest_runtest_logstart(nodeid, location)``
- Setup phase:
- ``call = pytest_runtest_setup(item)`` (wrapped in ``CallInfo(when="setup")``)
- ``report = pytest_runtest_makereport(item, call)``
- ``pytest_runtest_logreport(report)``
- ``pytest_exception_interact(call, report)`` if an interactive exception occurred
- Call phase, if the the setup passed and the ``setuponly`` pytest option is not set:
- ``call = pytest_runtest_call(item)`` (wrapped in ``CallInfo(when="call")``)
- ``report = pytest_runtest_makereport(item, call)``
- ``pytest_runtest_logreport(report)``
- ``pytest_exception_interact(call, report)`` if an interactive exception occurred
- Teardown phase:
- ``call = pytest_runtest_teardown(item, nextitem)`` (wrapped in ``CallInfo(when="teardown")``)
- ``report = pytest_runtest_makereport(item, call)``
- ``pytest_runtest_logreport(report)``
- ``pytest_exception_interact(call, report)`` if an interactive exception occurred
- ``pytest_runtest_logfinish(nodeid, location)``
:param item: Test item for which the runtest protocol is performed.
:param nextitem: The scheduled-to-be-next test item (or None if this is the end my friend).
Stops at first non-None result, see :ref:`firstresult`.
The return value is not used, but only stops further processing.
"""
def pytest_runtest_logstart(
nodeid: str, location: Tuple[str, Optional[int], str]
) -> None:
"""Called at the start of running the runtest protocol for a single item.
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
:param nodeid: Full node ID of the item.
:param location: A tuple of ``(filename, lineno, testname)``.
"""
def pytest_runtest_logfinish(
nodeid: str, location: Tuple[str, Optional[int], str]
) -> None:
"""Called at the end of running the runtest protocol for a single item.
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
:param nodeid: Full node ID of the item.
:param location: A tuple of ``(filename, lineno, testname)``.
"""
def pytest_runtest_setup(item: "Item") -> None:
"""Called to perform the setup phase for a test item.
The default implementation runs ``setup()`` on ``item`` and all of its
parents (which haven't been setup yet). This includes obtaining the
values of fixtures required by the item (which haven't been obtained
yet).
:param item:
The item.
"""
def pytest_runtest_call(item: "Item") -> None:
"""Called to run the test for test item (the call phase).
The default implementation calls ``item.runtest()``.
:param item:
The item.
"""
def pytest_runtest_teardown(item: "Item", nextitem: Optional["Item"]) -> None:
"""Called to perform the teardown phase for a test item.
The default implementation runs the finalizers and calls ``teardown()``
on ``item`` and all of its parents (which need to be torn down). This
includes running the teardown phase of fixtures required by the item (if
they go out of scope).
:param item:
The item.
:param nextitem:
The scheduled-to-be-next test item (None if no further test item is
scheduled). This argument is used to perform exact teardowns, i.e.
calling just enough finalizers so that nextitem only needs to call
setup functions.
"""
@hookspec(firstresult=True)
def pytest_runtest_makereport(
item: "Item", call: "CallInfo[None]"
) -> Optional["TestReport"]:
"""Called to create a :class:`~pytest.TestReport` for each of
the setup, call and teardown runtest phases of a test item.
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
:param item: The item.
:param call: The :class:`~pytest.CallInfo` for the phase.
Stops at first non-None result, see :ref:`firstresult`.
"""
def pytest_runtest_logreport(report: "TestReport") -> None:
"""Process the :class:`~pytest.TestReport` produced for each
of the setup, call and teardown runtest phases of an item.
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
"""
@hookspec(firstresult=True)
def pytest_report_to_serializable(
config: "Config",
report: Union["CollectReport", "TestReport"],
) -> Optional[Dict[str, Any]]:
"""Serialize the given report object into a data structure suitable for
sending over the wire, e.g. converted to JSON.
:param config: The pytest config object.
:param report: The report.
"""
@hookspec(firstresult=True)
def pytest_report_from_serializable(
config: "Config",
data: Dict[str, Any],
) -> Optional[Union["CollectReport", "TestReport"]]:
"""Restore a report object previously serialized with
:hook:`pytest_report_to_serializable`.
:param config: The pytest config object.
"""
# -------------------------------------------------------------------------
# Fixture related hooks
# -------------------------------------------------------------------------
@hookspec(firstresult=True)
def pytest_fixture_setup(
fixturedef: "FixtureDef[Any]", request: "SubRequest"
) -> Optional[object]:
"""Perform fixture setup execution.
:param fixturdef:
The fixture definition object.
:param request:
The fixture request object.
:returns:
The return value of the call to the fixture function.
Stops at first non-None result, see :ref:`firstresult`.
.. note::
If the fixture function returns None, other implementations of
this hook function will continue to be called, according to the
behavior of the :ref:`firstresult` option.
"""
def pytest_fixture_post_finalizer(
fixturedef: "FixtureDef[Any]", request: "SubRequest"
) -> None:
"""Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not
``None``).
:param fixturdef:
The fixture definition object.
:param request:
The fixture request object.
"""
# -------------------------------------------------------------------------
# test session related hooks
# -------------------------------------------------------------------------
def pytest_sessionstart(session: "Session") -> None:
"""Called after the ``Session`` object has been created and before performing collection
and entering the run test loop.
:param session: The pytest session object.
"""
def pytest_sessionfinish(
session: "Session",
exitstatus: Union[int, "ExitCode"],
) -> None:
"""Called after whole test run finished, right before returning the exit status to the system.
:param session: The pytest session object.
:param exitstatus: The status which pytest will return to the system.
"""
def pytest_unconfigure(config: "Config") -> None:
"""Called before test process is exited.
:param config: The pytest config object.
"""
# -------------------------------------------------------------------------
# hooks for customizing the assert methods
# -------------------------------------------------------------------------
def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object
) -> Optional[List[str]]:
"""Return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list
of strings. The strings will be joined by newlines but any newlines
*in* a string will be escaped. Note that all but the first line will
be indented slightly, the intention is for the first line to be a summary.
:param config: The pytest config object.
:param op: The operator, e.g. `"=="`, `"!="`, `"not in"`.
:param left: The left operand.
:param right: The right operand.
"""
def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> None:
"""Called whenever an assertion passes.
.. versionadded:: 5.0
Use this hook to do some processing after a passing assertion.
The original assertion information is available in the `orig` string
and the pytest introspected assertion information is available in the
`expl` string.
This hook must be explicitly enabled by the ``enable_assertion_pass_hook``
ini-file option:
.. code-block:: ini
[pytest]
enable_assertion_pass_hook=true
You need to **clean the .pyc** files in your project directory and interpreter libraries
when enabling this option, as assertions will require to be re-written.
:param item: pytest item object of current test.
:param lineno: Line number of the assert statement.
:param orig: String with the original assertion.
:param expl: String with the assert explanation.
"""
# -------------------------------------------------------------------------
# Hooks for influencing reporting (invoked from _pytest_terminal).
# -------------------------------------------------------------------------
def pytest_report_header(
config: "Config", start_path: Path, startdir: "LEGACY_PATH"
) -> Union[str, List[str]]:
"""Return a string or list of strings to be displayed as header info for terminal reporting.
:param config: The pytest config object.
:param start_path: The starting dir.
:param startdir: The starting dir (deprecated).
.. note::
Lines returned by a plugin are displayed before those of plugins which
ran before it.
If you want to have your line(s) displayed first, use
:ref:`trylast=True <plugin-hookorder>`.
.. note::
This function should be implemented only in plugins or ``conftest.py``
files situated at the tests root directory due to how pytest
:ref:`discovers plugins during startup <pluginorder>`.
.. versionchanged:: 7.0.0
The ``start_path`` parameter was added as a :class:`pathlib.Path`
equivalent of the ``startdir`` parameter. The ``startdir`` parameter
has been deprecated.
"""
def pytest_report_collectionfinish(
config: "Config",
start_path: Path,
startdir: "LEGACY_PATH",
items: Sequence["Item"],
) -> Union[str, List[str]]:
"""Return a string or list of strings to be displayed after collection
has finished successfully.
These strings will be displayed after the standard "collected X items" message.
.. versionadded:: 3.2
:param config: The pytest config object.
:param start_path: The starting dir.
:param startdir: The starting dir (deprecated).
:param items: List of pytest items that are going to be executed; this list should not be modified.
.. note::
Lines returned by a plugin are displayed before those of plugins which
ran before it.
If you want to have your line(s) displayed first, use
:ref:`trylast=True <plugin-hookorder>`.
.. versionchanged:: 7.0.0
The ``start_path`` parameter was added as a :class:`pathlib.Path`
equivalent of the ``startdir`` parameter. The ``startdir`` parameter
has been deprecated.
"""
@hookspec(firstresult=True)
def pytest_report_teststatus(
report: Union["CollectReport", "TestReport"], config: "Config"
) -> Tuple[str, str, Union[str, Mapping[str, bool]]]:
"""Return result-category, shortletter and verbose word for status
reporting.
The result-category is a category in which to count the result, for
example "passed", "skipped", "error" or the empty string.
The shortletter is shown as testing progresses, for example ".", "s",
"E" or the empty string.
The verbose word is shown as testing progresses in verbose mode, for
example "PASSED", "SKIPPED", "ERROR" or the empty string.
pytest may style these implicitly according to the report outcome.
To provide explicit styling, return a tuple for the verbose word,
for example ``"rerun", "R", ("RERUN", {"yellow": True})``.
:param report: The report object whose status is to be returned.
:param config: The pytest config object.
:returns: The test status.
Stops at first non-None result, see :ref:`firstresult`.
"""
def pytest_terminal_summary(
terminalreporter: "TerminalReporter",
exitstatus: "ExitCode",
config: "Config",
) -> None:
"""Add a section to terminal summary reporting.
:param terminalreporter: The internal terminal reporter object.
:param exitstatus: The exit status that will be reported back to the OS.
:param config: The pytest config object.
.. versionadded:: 4.2
The ``config`` parameter.
"""
@hookspec(historic=True)
def pytest_warning_recorded(
warning_message: "warnings.WarningMessage",
when: "Literal['config', 'collect', 'runtest']",
nodeid: str,
location: Optional[Tuple[str, int, str]],
) -> None:
"""Process a warning captured by the internal pytest warnings plugin.
:param warning_message:
The captured warning. This is the same object produced by :py:func:`warnings.catch_warnings`, and contains
the same attributes as the parameters of :py:func:`warnings.showwarning`.
:param when:
Indicates when the warning was captured. Possible values:
* ``"config"``: during pytest configuration/initialization stage.
* ``"collect"``: during test collection.
* ``"runtest"``: during test execution.
:param nodeid:
Full id of the item.
:param location:
When available, holds information about the execution context of the captured
warning (filename, linenumber, function). ``function`` evaluates to <module>
when the execution context is at the module level.
.. versionadded:: 6.0
"""
# -------------------------------------------------------------------------
# Hooks for influencing skipping
# -------------------------------------------------------------------------
def pytest_markeval_namespace(config: "Config") -> Dict[str, Any]:
"""Called when constructing the globals dictionary used for
evaluating string conditions in xfail/skipif markers.
This is useful when the condition for a marker requires
objects that are expensive or impossible to obtain during
collection time, which is required by normal boolean
conditions.
.. versionadded:: 6.2
:param config: The pytest config object.
:returns: A dictionary of additional globals to add.
"""
# -------------------------------------------------------------------------
# error handling and internal debugging hooks
# -------------------------------------------------------------------------
def pytest_internalerror(
excrepr: "ExceptionRepr",
excinfo: "ExceptionInfo[BaseException]",
) -> Optional[bool]:
"""Called for internal errors.
Return True to suppress the fallback handling of printing an
INTERNALERROR message directly to sys.stderr.
:param excrepr: The exception repr object.
:param excinfo: The exception info.
"""
def pytest_keyboard_interrupt(
excinfo: "ExceptionInfo[Union[KeyboardInterrupt, Exit]]",
) -> None:
"""Called for keyboard interrupt.
:param excinfo: The exception info.
"""
def pytest_exception_interact(
node: Union["Item", "Collector"],
call: "CallInfo[Any]",
report: Union["CollectReport", "TestReport"],
) -> None:
"""Called when an exception was raised which can potentially be
interactively handled.
May be called during collection (see :hook:`pytest_make_collect_report`),
in which case ``report`` is a :class:`CollectReport`.
May be called during runtest of an item (see :hook:`pytest_runtest_protocol`),
in which case ``report`` is a :class:`TestReport`.
This hook is not called if the exception that was raised is an internal
exception like ``skip.Exception``.
:param node:
The item or collector.
:param call:
The call information. Contains the exception.
:param report:
The collection or test report.
"""
def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
"""Called upon pdb.set_trace().
Can be used by plugins to take special action just before the python
debugger enters interactive mode.
:param config: The pytest config object.
:param pdb: The Pdb instance.
"""
def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
"""Called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python
debugger leaves interactive mode.
:param config: The pytest config object.
:param pdb: The Pdb instance.
"""

View File

@ -0,0 +1,699 @@
"""Report test results in JUnit-XML format, for use with Jenkins and build
integration servers.
Based on initial code from Ross Lawley.
Output conforms to
https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
"""
import functools
import os
import platform
import re
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Callable
from typing import Dict
from typing import List
from typing import Match
from typing import Optional
from typing import Tuple
from typing import Union
import pytest
from _pytest import nodes
from _pytest import timing
from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprFileLocation
from _pytest.config import Config
from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
xml_key = StashKey["LogXML"]()
def bin_xml_escape(arg: object) -> str:
r"""Visually escape invalid XML characters.
For example, transforms
'hello\aworld\b'
into
'hello#x07world#x08'
Note that the #xABs are *not* XML escapes - missing the ampersand &#xAB.
The idea is to escape visually for the user rather than for XML itself.
"""
def repl(matchobj: Match[str]) -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return "#x%02X" % i
else:
return "#x%04X" % i
# The spec range of valid chars is:
# Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
# For an unknown(?) reason, we disallow #x7F (DEL) as well.
illegal_xml_re = (
"[^\u0009\u000A\u000D\u0020-\u007E\u0080-\uD7FF\uE000-\uFFFD\u10000-\u10FFFF]"
)
return re.sub(illegal_xml_re, repl, str(arg))
def merge_family(left, right) -> None:
result = {}
for kl, vl in left.items():
for kr, vr in right.items():
if not isinstance(vl, list):
raise TypeError(type(vl))
result[kl] = vl + vr
left.update(result)
families = {}
families["_base"] = {"testcase": ["classname", "name"]}
families["_base_legacy"] = {"testcase": ["file", "line", "url"]}
# xUnit 1.x inherits legacy attributes.
families["xunit1"] = families["_base"].copy()
merge_family(families["xunit1"], families["_base_legacy"])
# xUnit 2.x uses strict base attributes.
families["xunit2"] = families["_base"]
class _NodeReporter:
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0.0
self.properties: List[Tuple[str, str]] = []
self.nodes: List[ET.Element] = []
self.attrs: Dict[str, str] = {}
def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag)
self.nodes.append(node)
def add_property(self, name: str, value: object) -> None:
self.properties.append((str(name), bin_xml_escape(value)))
def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self) -> Optional[ET.Element]:
"""Return a Junit node containing custom properties, if any."""
if self.properties:
properties = ET.Element("properties")
for name, value in self.properties:
properties.append(ET.Element("property", name=name, value=value))
return properties
return None
def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs
classnames = names[:-1]
if self.xml.prefix:
classnames.insert(0, self.xml.prefix)
attrs: Dict[str, str] = {
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
}
if testreport.location[1] is not None:
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
self.attrs = attrs
self.attrs.update(existing_attrs) # Restore any user-defined attributes.
# Preserve legacy testcase behavior.
if self.family == "xunit1":
return
# Filter out attributes not permitted by this test family.
# Including custom attributes because they are not valid here.
temp_attrs = {}
for key in self.attrs.keys():
if key in families[self.family]["testcase"]:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
def to_xml(self) -> ET.Element:
testcase = ET.Element("testcase", self.attrs, time="%.3f" % self.duration)
properties = self.make_properties_node()
if properties is not None:
testcase.append(properties)
testcase.extend(self.nodes)
return testcase
def _add_simple(self, tag: str, message: str, data: Optional[str] = None) -> None:
node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data)
self.append(node)
def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed:
return
content_out = report.capstdout
content_log = report.caplog
content_err = report.capstderr
if self.xml.logging == "no":
return
content_all = ""
if self.xml.logging in ["log", "all"]:
content_all = self._prepare_content(content_log, " Captured Log ")
if self.xml.logging in ["system-out", "out-err", "all"]:
content_all += self._prepare_content(content_out, " Captured Out ")
self._write_content(report, content_all, "system-out")
content_all = ""
if self.xml.logging in ["system-err", "out-err", "all"]:
content_all += self._prepare_content(content_err, " Captured Err ")
self._write_content(report, content_all, "system-err")
content_all = ""
if content_all:
self._write_content(report, content_all, "system-out")
def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = ET.Element(jheader)
tag.text = bin_xml_escape(content)
self.append(tag)
def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple("skipped", "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr(
report.longrepr, "reprcrash", None
)
if reprcrash is not None:
message = reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
self._add_simple("failure", message, str(report.longrepr))
def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self._add_simple("error", "collection failure", str(report.longrepr))
def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple("skipped", "collection skipped", str(report.longrepr))
def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None
reprcrash: Optional[ReprFileLocation] = getattr(
report.longrepr, "reprcrash", None
)
if reprcrash is not None:
reason = reprcrash.message
else:
reason = str(report.longrepr)
if report.when == "teardown":
msg = f'failed on teardown with "{reason}"'
else:
msg = f'failed on setup with "{reason}"'
self._add_simple("error", bin_xml_escape(msg), str(report.longrepr))
def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
xfailreason = xfailreason[8:]
xfailreason = bin_xml_escape(xfailreason)
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
self.append(skipped)
else:
assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
details = f"{filename}:{lineno}: {skipreason}"
skipped = ET.Element("skipped", type="pytest.skip", message=skipreason)
skipped.text = bin_xml_escape(details)
self.append(skipped)
self.write_captured_output(report)
def finalize(self) -> None:
data = self.to_xml()
self.__dict__.clear()
# Type ignored because mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: data # type: ignore[assignment]
def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
) -> None:
"""Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions."""
from _pytest.warning_types import PytestWarning
xml = request.config.stash.get(xml_key, None)
if xml is not None and xml.family not in ("xunit1", "legacy"):
request.node.warn(
PytestWarning(
"{fixture_name} is incompatible with junit_family '{family}' (use 'legacy' or 'xunit1')".format(
fixture_name=fixture_name, family=xml.family
)
)
)
@pytest.fixture
def record_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Add extra properties to the calling test.
User properties become part of the test report and are available to the
configured reporters, like JUnit XML.
The fixture is callable with ``name, value``. The value is automatically
XML-encoded.
Example::
def test_function(record_property):
record_property("example_key", 1)
"""
_warn_incompatibility_with_xunit2(request, "record_property")
def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
return append_property
@pytest.fixture
def record_xml_attribute(request: FixtureRequest) -> Callable[[str, object], None]:
"""Add extra xml attributes to the tag for the calling test.
The fixture is callable with ``name, value``. The value is
automatically XML-encoded.
"""
from _pytest.warning_types import PytestExperimentalApiWarning
request.node.warn(
PytestExperimentalApiWarning("record_xml_attribute is an experimental feature")
)
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop
def add_attr_noop(name: str, value: object) -> None:
pass
attr_func = add_attr_noop
xml = request.config.stash.get(xml_key, None)
if xml is not None:
node_reporter = xml.node_reporter(request.node.nodeid)
attr_func = node_reporter.add_attribute
return attr_func
def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper
type."""
__tracebackhide__ = True
if not isinstance(v, str):
msg = "{param} parameter needs to be a string, but {g} given" # type: ignore[unreachable]
raise TypeError(msg.format(param=param, g=type(v).__name__))
@pytest.fixture(scope="session")
def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Record a new ``<property>`` tag as child of the root ``<testsuite>``.
This is suitable to writing global information regarding the entire test
suite, and is compatible with ``xunit2`` JUnit family.
This is a ``session``-scoped fixture which is called with ``(name, value)``. Example:
.. code-block:: python
def test_foo(record_testsuite_property):
record_testsuite_property("ARCH", "PPC")
record_testsuite_property("STORAGE_TYPE", "CEPH")
:param name:
The property name.
:param value:
The property value. Will be converted to a string.
.. warning::
Currently this fixture **does not work** with the
`pytest-xdist <https://github.com/pytest-dev/pytest-xdist>`__ plugin. See
:issue:`7767` for details.
"""
__tracebackhide__ = True
def record_func(name: str, value: object) -> None:
"""No-op function in case --junitxml was not passed in the command-line."""
__tracebackhide__ = True
_check_record_param_type("name", name)
xml = request.config.stash.get(xml_key, None)
if xml is not None:
record_func = xml.add_global_property # noqa
return record_func
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--junitxml",
"--junit-xml",
action="store",
dest="xmlpath",
metavar="path",
type=functools.partial(filename_arg, optname="--junitxml"),
default=None,
help="Create junit-xml style report file at given path",
)
group.addoption(
"--junitprefix",
"--junit-prefix",
action="store",
metavar="str",
default=None,
help="Prepend prefix to classnames in junit-xml output",
)
parser.addini(
"junit_suite_name", "Test suite name for JUnit report", default="pytest"
)
parser.addini(
"junit_logging",
"Write captured log messages to JUnit report: "
"one of no|log|system-out|system-err|out-err|all",
default="no",
)
parser.addini(
"junit_log_passing_tests",
"Capture log information for passing tests to JUnit report: ",
type="bool",
default=True,
)
parser.addini(
"junit_duration_report",
"Duration time to report: one of total|call",
default="total",
) # choices=['total', 'call'])
parser.addini(
"junit_family",
"Emit XML for schema: one of legacy|xunit1|xunit2",
default="xunit2",
)
def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath
# Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"):
junit_family = config.getini("junit_family")
config.stash[xml_key] = LogXML(
xmlpath,
config.option.junitprefix,
config.getini("junit_suite_name"),
config.getini("junit_logging"),
config.getini("junit_duration_report"),
junit_family,
config.getini("junit_log_passing_tests"),
)
config.pluginmanager.register(config.stash[xml_key])
def pytest_unconfigure(config: Config) -> None:
xml = config.stash.get(xml_key, None)
if xml:
del config.stash[xml_key]
config.pluginmanager.unregister(xml)
def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
# Convert file path to dotted path.
names[0] = names[0].replace(nodes.SEP, ".")
names[0] = re.sub(r"\.py$", "", names[0])
# Put any params back.
names[-1] += possible_open_bracket + params
return names
class LogXML:
def __init__(
self,
logfile,
prefix: Optional[str],
suite_name: str = "pytest",
logging: str = "no",
report_duration: str = "total",
family="xunit1",
log_passing_tests: bool = True,
) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix
self.suite_name = suite_name
self.logging = logging
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
self.stats: Dict[str, int] = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0
)
self.node_reporters: Dict[
Tuple[Union[str, TestReport], object], _NodeReporter
] = {}
self.node_reporters_ordered: List[_NodeReporter] = []
self.global_properties: List[Tuple[str, str]] = []
# List of reports that failed on call but teardown is pending.
self.open_reports: List[TestReport] = []
self.cnt_double_fail_tests = 0
# Replaces convenience family with real family.
if self.family == "legacy":
self.family = "xunit1"
def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
reporter = self.node_reporters.pop((nodeid, workernode))
if reporter is not None:
reporter.finalize()
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
nodeid: Union[str, TestReport] = getattr(report, "nodeid", report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
key = nodeid, workernode
if key in self.node_reporters:
# TODO: breaks for --dist=each
return self.node_reporters[key]
reporter = _NodeReporter(nodeid, self)
self.node_reporters[key] = reporter
self.node_reporters_ordered.append(reporter)
return reporter
def add_stats(self, key: str) -> None:
if key in self.stats:
self.stats[key] += 1
def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report)
reporter.record_testreport(report)
return reporter
def pytest_runtest_logreport(self, report: TestReport) -> None:
"""Handle a setup/call/teardown report, generating the appropriate
XML tags as necessary.
Note: due to plugins like xdist, this hook may be called in interlaced
order with reports from other nodes. For example:
Usual call order:
-> setup node1
-> call node1
-> teardown node1
-> setup node2
-> call node2
-> teardown node2
Possible call order in xdist:
-> setup node1
-> call node1
-> setup node2
-> call node2
-> teardown node2
-> teardown node1
"""
close_report = None
if report.passed:
if report.when == "call": # ignore setup/teardown
reporter = self._opentestcase(report)
reporter.append_pass(report)
elif report.failed:
if report.when == "teardown":
# The following vars are needed when xdist plugin is used.
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
)
),
None,
)
if close_report:
# We need to open new testcase in case we have failure in
# call and error in teardown in order to follow junit
# schema.
self.finalize(close_report)
self.cnt_double_fail_tests += 1
reporter = self._opentestcase(report)
if report.when == "call":
reporter.append_failure(report)
self.open_reports.append(report)
if not self.log_passing_tests:
reporter.write_captured_output(report)
else:
reporter.append_error(report)
elif report.skipped:
reporter = self._opentestcase(report)
reporter.append_skipped(report)
self.update_testcase_duration(report)
if report.when == "teardown":
reporter = self._opentestcase(report)
reporter.write_captured_output(report)
for propname, propvalue in report.user_properties:
reporter.add_property(propname, str(propvalue))
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
)
),
None,
)
if close_report:
self.open_reports.remove(close_report)
def update_testcase_duration(self, report: TestReport) -> None:
"""Accumulate total duration for nodeid from given report and update
the Junit.testcase with the new total if already created."""
if self.report_duration == "total" or report.when == self.report_duration:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
reporter = self._opentestcase(report)
if report.failed:
reporter.append_collect_error(report)
else:
reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple("error", "internal error", str(excrepr))
def pytest_sessionstart(self) -> None:
self.suite_start_time = timing.time()
def pytest_sessionfinish(self) -> None:
dirname = os.path.dirname(os.path.abspath(self.logfile))
if not os.path.isdir(dirname):
os.makedirs(dirname)
with open(self.logfile, "w", encoding="utf-8") as logfile:
suite_stop_time = timing.time()
suite_time_delta = suite_stop_time - self.suite_start_time
numtests = (
self.stats["passed"]
+ self.stats["failure"]
+ self.stats["skipped"]
+ self.stats["error"]
- self.cnt_double_fail_tests
)
logfile.write('<?xml version="1.0" encoding="utf-8"?>')
suite_node = ET.Element(
"testsuite",
name=self.suite_name,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
skipped=str(self.stats["skipped"]),
tests=str(numtests),
time="%.3f" % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
)
global_properties = self._get_global_properties_node()
if global_properties is not None:
suite_node.append(global_properties)
for node_reporter in self.node_reporters_ordered:
suite_node.append(node_reporter.to_xml())
testsuites = ET.Element("testsuites")
testsuites.append(suite_node)
logfile.write(ET.tostring(testsuites, encoding="unicode"))
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", f"generated xml file: {self.logfile}")
def add_global_property(self, name: str, value: object) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self) -> Optional[ET.Element]:
"""Return a Junit node containing custom properties, if any."""
if self.global_properties:
properties = ET.Element("properties")
for name, value in self.global_properties:
properties.append(ET.Element("property", name=name, value=value))
return properties
return None

View File

@ -0,0 +1,479 @@
"""Add backward compatibility support for the legacy py path type."""
import shlex
import subprocess
from pathlib import Path
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
import attr
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache
from _pytest.compat import final
from _pytest.compat import LEGACY_PATH
from _pytest.compat import legacy_path
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.pytester import HookRecorder
from _pytest.pytester import Pytester
from _pytest.pytester import RunResult
from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory
if TYPE_CHECKING:
from typing_extensions import Final
import pexpect
@final
class Testdir:
"""
Similar to :class:`Pytester`, but this class works with legacy legacy_path objects instead.
All methods just forward to an internal :class:`Pytester` instance, converting results
to `legacy_path` objects as necessary.
"""
__test__ = False
CLOSE_STDIN: "Final" = Pytester.CLOSE_STDIN
TimeoutExpired: "Final" = Pytester.TimeoutExpired
def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._pytester = pytester
@property
def tmpdir(self) -> LEGACY_PATH:
"""Temporary directory where tests are executed."""
return legacy_path(self._pytester.path)
@property
def test_tmproot(self) -> LEGACY_PATH:
return legacy_path(self._pytester._test_tmproot)
@property
def request(self):
return self._pytester._request
@property
def plugins(self):
return self._pytester.plugins
@plugins.setter
def plugins(self, plugins):
self._pytester.plugins = plugins
@property
def monkeypatch(self) -> MonkeyPatch:
return self._pytester._monkeypatch
def make_hook_recorder(self, pluginmanager) -> HookRecorder:
"""See :meth:`Pytester.make_hook_recorder`."""
return self._pytester.make_hook_recorder(pluginmanager)
def chdir(self) -> None:
"""See :meth:`Pytester.chdir`."""
return self._pytester.chdir()
def finalize(self) -> None:
"""See :meth:`Pytester._finalize`."""
return self._pytester._finalize()
def makefile(self, ext, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makefile`."""
if ext and not ext.startswith("."):
# pytester.makefile is going to throw a ValueError in a way that
# testdir.makefile did not, because
# pathlib.Path is stricter suffixes than py.path
# This ext arguments is likely user error, but since testdir has
# allowed this, we will prepend "." as a workaround to avoid breaking
# testdir usage that worked before
ext = "." + ext
return legacy_path(self._pytester.makefile(ext, *args, **kwargs))
def makeconftest(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makeconftest`."""
return legacy_path(self._pytester.makeconftest(source))
def makeini(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makeini`."""
return legacy_path(self._pytester.makeini(source))
def getinicfg(self, source: str) -> SectionWrapper:
"""See :meth:`Pytester.getinicfg`."""
return self._pytester.getinicfg(source)
def makepyprojecttoml(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makepyprojecttoml`."""
return legacy_path(self._pytester.makepyprojecttoml(source))
def makepyfile(self, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makepyfile`."""
return legacy_path(self._pytester.makepyfile(*args, **kwargs))
def maketxtfile(self, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.maketxtfile`."""
return legacy_path(self._pytester.maketxtfile(*args, **kwargs))
def syspathinsert(self, path=None) -> None:
"""See :meth:`Pytester.syspathinsert`."""
return self._pytester.syspathinsert(path)
def mkdir(self, name) -> LEGACY_PATH:
"""See :meth:`Pytester.mkdir`."""
return legacy_path(self._pytester.mkdir(name))
def mkpydir(self, name) -> LEGACY_PATH:
"""See :meth:`Pytester.mkpydir`."""
return legacy_path(self._pytester.mkpydir(name))
def copy_example(self, name=None) -> LEGACY_PATH:
"""See :meth:`Pytester.copy_example`."""
return legacy_path(self._pytester.copy_example(name))
def getnode(self, config: Config, arg) -> Optional[Union[Item, Collector]]:
"""See :meth:`Pytester.getnode`."""
return self._pytester.getnode(config, arg)
def getpathnode(self, path):
"""See :meth:`Pytester.getpathnode`."""
return self._pytester.getpathnode(path)
def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]:
"""See :meth:`Pytester.genitems`."""
return self._pytester.genitems(colitems)
def runitem(self, source):
"""See :meth:`Pytester.runitem`."""
return self._pytester.runitem(source)
def inline_runsource(self, source, *cmdlineargs):
"""See :meth:`Pytester.inline_runsource`."""
return self._pytester.inline_runsource(source, *cmdlineargs)
def inline_genitems(self, *args):
"""See :meth:`Pytester.inline_genitems`."""
return self._pytester.inline_genitems(*args)
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
"""See :meth:`Pytester.inline_run`."""
return self._pytester.inline_run(
*args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc
)
def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
"""See :meth:`Pytester.runpytest_inprocess`."""
return self._pytester.runpytest_inprocess(*args, **kwargs)
def runpytest(self, *args, **kwargs) -> RunResult:
"""See :meth:`Pytester.runpytest`."""
return self._pytester.runpytest(*args, **kwargs)
def parseconfig(self, *args) -> Config:
"""See :meth:`Pytester.parseconfig`."""
return self._pytester.parseconfig(*args)
def parseconfigure(self, *args) -> Config:
"""See :meth:`Pytester.parseconfigure`."""
return self._pytester.parseconfigure(*args)
def getitem(self, source, funcname="test_func"):
"""See :meth:`Pytester.getitem`."""
return self._pytester.getitem(source, funcname)
def getitems(self, source):
"""See :meth:`Pytester.getitems`."""
return self._pytester.getitems(source)
def getmodulecol(self, source, configargs=(), withinit=False):
"""See :meth:`Pytester.getmodulecol`."""
return self._pytester.getmodulecol(
source, configargs=configargs, withinit=withinit
)
def collect_by_name(
self, modcol: Collector, name: str
) -> Optional[Union[Item, Collector]]:
"""See :meth:`Pytester.collect_by_name`."""
return self._pytester.collect_by_name(modcol, name)
def popen(
self,
cmdargs,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=CLOSE_STDIN,
**kw,
):
"""See :meth:`Pytester.popen`."""
return self._pytester.popen(cmdargs, stdout, stderr, stdin, **kw)
def run(self, *cmdargs, timeout=None, stdin=CLOSE_STDIN) -> RunResult:
"""See :meth:`Pytester.run`."""
return self._pytester.run(*cmdargs, timeout=timeout, stdin=stdin)
def runpython(self, script) -> RunResult:
"""See :meth:`Pytester.runpython`."""
return self._pytester.runpython(script)
def runpython_c(self, command):
"""See :meth:`Pytester.runpython_c`."""
return self._pytester.runpython_c(command)
def runpytest_subprocess(self, *args, timeout=None) -> RunResult:
"""See :meth:`Pytester.runpytest_subprocess`."""
return self._pytester.runpytest_subprocess(*args, timeout=timeout)
def spawn_pytest(
self, string: str, expect_timeout: float = 10.0
) -> "pexpect.spawn":
"""See :meth:`Pytester.spawn_pytest`."""
return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> "pexpect.spawn":
"""See :meth:`Pytester.spawn`."""
return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
def __repr__(self) -> str:
return f"<Testdir {self.tmpdir!r}>"
def __str__(self) -> str:
return str(self.tmpdir)
class LegacyTestdirPlugin:
@staticmethod
@fixture
def testdir(pytester: Pytester) -> Testdir:
"""
Identical to :fixture:`pytester`, and provides an instance whose methods return
legacy ``LEGACY_PATH`` objects instead when applicable.
New code should avoid using :fixture:`testdir` in favor of :fixture:`pytester`.
"""
return Testdir(pytester, _ispytest=True)
@final
@attr.s(init=False, auto_attribs=True)
class TempdirFactory:
"""Backward compatibility wrapper that implements :class:`py.path.local`
for :class:`TempPathFactory`.
.. note::
These days, it is preferred to use ``tmp_path_factory``.
:ref:`About the tmpdir and tmpdir_factory fixtures<tmpdir and tmpdir_factory>`.
"""
_tmppath_factory: TempPathFactory
def __init__(
self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._tmppath_factory = tmppath_factory
def mktemp(self, basename: str, numbered: bool = True) -> LEGACY_PATH:
"""Same as :meth:`TempPathFactory.mktemp`, but returns a :class:`py.path.local` object."""
return legacy_path(self._tmppath_factory.mktemp(basename, numbered).resolve())
def getbasetemp(self) -> LEGACY_PATH:
"""Same as :meth:`TempPathFactory.getbasetemp`, but returns a :class:`py.path.local` object."""
return legacy_path(self._tmppath_factory.getbasetemp().resolve())
class LegacyTmpdirPlugin:
@staticmethod
@fixture(scope="session")
def tmpdir_factory(request: FixtureRequest) -> TempdirFactory:
"""Return a :class:`pytest.TempdirFactory` instance for the test session."""
# Set dynamically by pytest_configure().
return request.config._tmpdirhandler # type: ignore
@staticmethod
@fixture
def tmpdir(tmp_path: Path) -> LEGACY_PATH:
"""Return a temporary directory path object which is unique to each test
function invocation, created as a sub directory of the base temporary
directory.
By default, a new base temporary directory is created each test session,
and old bases are removed after 3 sessions, to aid in debugging. If
``--basetemp`` is used then it is cleared each session. See :ref:`base
temporary directory`.
The returned object is a `legacy_path`_ object.
.. note::
These days, it is preferred to use ``tmp_path``.
:ref:`About the tmpdir and tmpdir_factory fixtures<tmpdir and tmpdir_factory>`.
.. _legacy_path: https://py.readthedocs.io/en/latest/path.html
"""
return legacy_path(tmp_path)
def Cache_makedir(self: Cache, name: str) -> LEGACY_PATH:
"""Return a directory path object with the given name.
Same as :func:`mkdir`, but returns a legacy py path instance.
"""
return legacy_path(self.mkdir(name))
def FixtureRequest_fspath(self: FixtureRequest) -> LEGACY_PATH:
"""(deprecated) The file system path of the test module which collected this test."""
return legacy_path(self.path)
def TerminalReporter_startdir(self: TerminalReporter) -> LEGACY_PATH:
"""The directory from which pytest was invoked.
Prefer to use ``startpath`` which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(self.startpath)
def Config_invocation_dir(self: Config) -> LEGACY_PATH:
"""The directory from which pytest was invoked.
Prefer to use :attr:`invocation_params.dir <InvocationParams.dir>`,
which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(str(self.invocation_params.dir))
def Config_rootdir(self: Config) -> LEGACY_PATH:
"""The path to the :ref:`rootdir <rootdir>`.
Prefer to use :attr:`rootpath`, which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(str(self.rootpath))
def Config_inifile(self: Config) -> Optional[LEGACY_PATH]:
"""The path to the :ref:`configfile <configfiles>`.
Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
:type: Optional[LEGACY_PATH]
"""
return legacy_path(str(self.inipath)) if self.inipath else None
def Session_stardir(self: Session) -> LEGACY_PATH:
"""The path from which pytest was invoked.
Prefer to use ``startpath`` which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(self.startpath)
def Config__getini_unknown_type(
self, name: str, type: str, value: Union[str, List[str]]
):
if type == "pathlist":
# TODO: This assert is probably not valid in all cases.
assert self.inipath is not None
dp = self.inipath.parent
input_values = shlex.split(value) if isinstance(value, str) else value
return [legacy_path(str(dp / x)) for x in input_values]
else:
raise ValueError(f"unknown configuration type: {type}", value)
def Node_fspath(self: Node) -> LEGACY_PATH:
"""(deprecated) returns a legacy_path copy of self.path"""
return legacy_path(self.path)
def Node_fspath_set(self: Node, value: LEGACY_PATH) -> None:
self.path = Path(value)
@hookimpl(tryfirst=True)
def pytest_load_initial_conftests(early_config: Config) -> None:
"""Monkeypatch legacy path attributes in several classes, as early as possible."""
mp = MonkeyPatch()
early_config.add_cleanup(mp.undo)
# Add Cache.makedir().
mp.setattr(Cache, "makedir", Cache_makedir, raising=False)
# Add FixtureRequest.fspath property.
mp.setattr(FixtureRequest, "fspath", property(FixtureRequest_fspath), raising=False)
# Add TerminalReporter.startdir property.
mp.setattr(
TerminalReporter, "startdir", property(TerminalReporter_startdir), raising=False
)
# Add Config.{invocation_dir,rootdir,inifile} properties.
mp.setattr(Config, "invocation_dir", property(Config_invocation_dir), raising=False)
mp.setattr(Config, "rootdir", property(Config_rootdir), raising=False)
mp.setattr(Config, "inifile", property(Config_inifile), raising=False)
# Add Session.startdir property.
mp.setattr(Session, "startdir", property(Session_stardir), raising=False)
# Add pathlist configuration type.
mp.setattr(Config, "_getini_unknown_type", Config__getini_unknown_type)
# Add Node.fspath property.
mp.setattr(Node, "fspath", property(Node_fspath, Node_fspath_set), raising=False)
@hookimpl
def pytest_configure(config: Config) -> None:
"""Installs the LegacyTmpdirPlugin if the ``tmpdir`` plugin is also installed."""
if config.pluginmanager.has_plugin("tmpdir"):
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
# Create TmpdirFactory and attach it to the config object.
#
# This is to comply with existing plugins which expect the handler to be
# available at pytest_configure time, but ideally should be moved entirely
# to the tmpdir_factory session fixture.
try:
tmp_path_factory = config._tmp_path_factory # type: ignore[attr-defined]
except AttributeError:
# tmpdir plugin is blocked.
pass
else:
_tmpdirhandler = TempdirFactory(tmp_path_factory, _ispytest=True)
mp.setattr(config, "_tmpdirhandler", _tmpdirhandler, raising=False)
config.pluginmanager.register(LegacyTmpdirPlugin, "legacypath-tmpdir")
@hookimpl
def pytest_plugin_registered(plugin: object, manager: PytestPluginManager) -> None:
# pytester is not loaded by default and is commonly loaded from a conftest,
# so checking for it in `pytest_configure` is not enough.
is_pytester = plugin is manager.get_plugin("pytester")
if is_pytester and not manager.is_registered(LegacyTestdirPlugin):
manager.register(LegacyTestdirPlugin, "legacypath-pytester")

View File

@ -0,0 +1,830 @@
"""Access and control log capturing."""
import io
import logging
import os
import re
from contextlib import contextmanager
from contextlib import nullcontext
from io import StringIO
from pathlib import Path
from typing import AbstractSet
from typing import Dict
from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.compat import final
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
if TYPE_CHECKING:
logging_StreamHandler = logging.StreamHandler[StringIO]
from typing_extensions import Literal
else:
logging_StreamHandler = logging.StreamHandler
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
DEFAULT_LOG_DATE_FORMAT = "%H:%M:%S"
_ANSI_ESCAPE_SEQ = re.compile(r"\x1b\[[\d;]+m")
caplog_handler_key = StashKey["LogCaptureHandler"]()
caplog_records_key = StashKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
class ColoredLevelFormatter(logging.Formatter):
"""A logging formatter which colorizes the %(levelname)..s part of the
log format passed to __init__."""
LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = {
logging.CRITICAL: {"red"},
logging.ERROR: {"red", "bold"},
logging.WARNING: {"yellow"},
logging.WARN: {"yellow"},
logging.INFO: {"green"},
logging.DEBUG: {"purple"},
logging.NOTSET: set(),
}
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*(?:\.\d+)?s)")
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._terminalwriter = terminalwriter
self._original_fmt = self._style._fmt
self._level_to_fmt_mapping: Dict[int, str] = {}
for level, color_opts in self.LOGLEVEL_COLOROPTS.items():
self.add_color_level(level, *color_opts)
def add_color_level(self, level: int, *color_opts: str) -> None:
"""Add or update color opts for a log level.
:param level:
Log level to apply a style to, e.g. ``logging.INFO``.
:param color_opts:
ANSI escape sequence color options. Capitalized colors indicates
background color, i.e. ``'green', 'Yellow', 'bold'`` will give bold
green text on yellow background.
.. warning::
This is an experimental API.
"""
assert self._fmt is not None
levelname_fmt_match = self.LEVELNAME_FMT_REGEX.search(self._fmt)
if not levelname_fmt_match:
return
levelname_fmt = levelname_fmt_match.group()
formatted_levelname = levelname_fmt % {"levelname": logging.getLevelName(level)}
# add ANSI escape sequences around the formatted levelname
color_kwargs = {name: True for name in color_opts}
colorized_formatted_levelname = self._terminalwriter.markup(
formatted_levelname, **color_kwargs
)
self._level_to_fmt_mapping[level] = self.LEVELNAME_FMT_REGEX.sub(
colorized_formatted_levelname, self._fmt
)
def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt
return super().format(record)
class PercentStyleMultiline(logging.PercentStyle):
"""A logging style with special support for multiline messages.
If the message of a record consists of multiple lines, this style
formats the message as if each line were logged separately.
"""
def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int:
"""Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in
extra={"auto_indent": [value]} to the call to logging.log() or
using a --log-auto-indent [value] command line or the
log_auto_indent [value] config option.
Default behavior is auto-indent off.
Using the string "True" or "on" or the boolean True as the value
turns auto indent on, using the string "False" or "off" or the
boolean False or the int 0 turns it off, and specifying a
positive integer fixes the indentation position to the value
specified.
Any other values for the option are invalid, and will silently be
converted to the default.
:param None|bool|int|str auto_indent_option:
User specified option for indentation from command line, config
or extra kwarg. Accepts int, bool or str. str option accepts the
same range of values as boolean config options, as well as
positive integers represented in str form.
:returns:
Indentation value, which can be
-1 (automatically determine indentation) or
0 (auto-indent turned off) or
>0 (explicitly set indentation position).
"""
if auto_indent_option is None:
return 0
elif isinstance(auto_indent_option, bool):
if auto_indent_option:
return -1
else:
return 0
elif isinstance(auto_indent_option, int):
return int(auto_indent_option)
elif isinstance(auto_indent_option, str):
try:
return int(auto_indent_option)
except ValueError:
pass
try:
if _strtobool(auto_indent_option):
return -1
except ValueError:
return 0
return 0
def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
# Passed in from the "extra={}" kwarg on the call to logging.log().
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined]
else:
auto_indent = self._auto_indent
if auto_indent:
lines = record.message.splitlines()
formatted = self._fmt % {**record.__dict__, "message": lines[0]}
if auto_indent < 0:
indentation = _remove_ansi_escape_sequences(formatted).find(
lines[0]
)
else:
# Optimizes logging by allowing a fixed indentation.
indentation = auto_indent
lines[0] = formatted
return ("\n" + " " * indentation).join(lines)
return self._fmt % record.__dict__
def get_option_ini(config: Config, *names: str):
for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None:
ret = config.getini(name)
if ret:
return ret
def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing."""
group = parser.getgroup("logging")
def add_option_ini(option, dest, default=None, type=None, **kwargs):
parser.addini(
dest, default=default, type=type, help="Default value for " + option
)
group.addoption(option, dest=dest, **kwargs)
add_option_ini(
"--log-level",
dest="log_level",
default=None,
metavar="LEVEL",
help=(
"Level of messages to catch/display."
" Not set by default, so it depends on the root/parent log handler's"
' effective level, where it is "WARNING" by default.'
),
)
add_option_ini(
"--log-format",
dest="log_format",
default=DEFAULT_LOG_FORMAT,
help="Log format used by the logging module",
)
add_option_ini(
"--log-date-format",
dest="log_date_format",
default=DEFAULT_LOG_DATE_FORMAT,
help="Log date format used by the logging module",
)
parser.addini(
"log_cli",
default=False,
type="bool",
help='Enable log display during test run (also known as "live logging")',
)
add_option_ini(
"--log-cli-level", dest="log_cli_level", default=None, help="CLI logging level"
)
add_option_ini(
"--log-cli-format",
dest="log_cli_format",
default=None,
help="Log format used by the logging module",
)
add_option_ini(
"--log-cli-date-format",
dest="log_cli_date_format",
default=None,
help="Log date format used by the logging module",
)
add_option_ini(
"--log-file",
dest="log_file",
default=None,
help="Path to a file when logging will be written to",
)
add_option_ini(
"--log-file-level",
dest="log_file_level",
default=None,
help="Log file logging level",
)
add_option_ini(
"--log-file-format",
dest="log_file_format",
default=DEFAULT_LOG_FORMAT,
help="Log format used by the logging module",
)
add_option_ini(
"--log-file-date-format",
dest="log_file_date_format",
default=DEFAULT_LOG_DATE_FORMAT,
help="Log date format used by the logging module",
)
add_option_ini(
"--log-auto-indent",
dest="log_auto_indent",
default=None,
help="Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.",
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons.
class catching_logs:
"""Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level")
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
self.handler = handler
self.level = level
def __enter__(self):
root_logger = logging.getLogger()
if self.level is not None:
self.handler.setLevel(self.level)
root_logger.addHandler(self.handler)
if self.level is not None:
self.orig_level = root_logger.level
root_logger.setLevel(min(self.orig_level, self.level))
return self.handler
def __exit__(self, type, value, traceback):
root_logger = logging.getLogger()
if self.level is not None:
root_logger.setLevel(self.orig_level)
root_logger.removeHandler(self.handler)
class LogCaptureHandler(logging_StreamHandler):
"""A logging handler that stores log records and the log text."""
def __init__(self) -> None:
"""Create a new log handler."""
super().__init__(StringIO())
self.records: List[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text."""
self.records.append(record)
super().emit(record)
def reset(self) -> None:
self.records = []
self.stream = StringIO()
def clear(self) -> None:
self.records.clear()
self.stream = StringIO()
def handleError(self, record: logging.LogRecord) -> None:
if logging.raiseExceptions:
# Fail the test if the log message is bad (emit failed).
# The default behavior of logging is to print "Logging error"
# to stderr with the call stack and some extra details.
# pytest wants to make such mistakes visible during testing.
raise
@final
class LogCaptureFixture:
"""Provides access and control of log capturing."""
def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._item = item
self._initial_handler_level: Optional[int] = None
# Dict of log name -> log level.
self._initial_logger_levels: Dict[Optional[str], int] = {}
def _finalize(self) -> None:
"""Finalize the fixture.
This restores the log levels changed by :meth:`set_level`.
"""
# Restore log levels.
if self._initial_handler_level is not None:
self.handler.setLevel(self._initial_handler_level)
for logger_name, level in self._initial_logger_levels.items():
logger = logging.getLogger(logger_name)
logger.setLevel(level)
@property
def handler(self) -> LogCaptureHandler:
"""Get the logging handler used by the fixture."""
return self._item.stash[caplog_handler_key]
def get_records(
self, when: "Literal['setup', 'call', 'teardown']"
) -> List[logging.LogRecord]:
"""Get the logging records for one of the possible test phases.
:param when:
Which test phase to obtain the records from.
Valid values are: "setup", "call" and "teardown".
:returns: The list of captured records at the given stage.
.. versionadded:: 3.4
"""
return self._item.stash[caplog_records_key].get(when, [])
@property
def text(self) -> str:
"""The formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
def records(self) -> List[logging.LogRecord]:
"""The list of log records."""
return self.handler.records
@property
def record_tuples(self) -> List[Tuple[str, int, str]]:
"""A list of a stripped down version of log records intended
for use in assertion comparison.
The format of the tuple is:
(logger_name, log_level, message)
"""
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
def messages(self) -> List[str]:
"""A list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for
interpolation, log messages in this list are all interpolated.
Unlike 'text', which contains the output from the handler, log
messages in this list are unadorned with levels, timestamps, etc,
making exact comparisons more reliable.
Note that traceback or stack info (from :func:`logging.exception` or
the `exc_info` or `stack_info` arguments to the logging functions) is
not included, as this is added by the formatter in the handler.
.. versionadded:: 3.7
"""
return [r.getMessage() for r in self.records]
def clear(self) -> None:
"""Reset the list of log records and the captured log text."""
self.handler.clear()
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
"""Set the level of a logger for the duration of a test.
.. versionchanged:: 3.4
The levels of the loggers changed by this function will be
restored to their initial values at the end of the test.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
logger_obj = logging.getLogger(logger)
# Save the original log-level to restore it during teardown.
self._initial_logger_levels.setdefault(logger, logger_obj.level)
logger_obj.setLevel(level)
if self._initial_handler_level is None:
self._initial_handler_level = self.handler.level
self.handler.setLevel(level)
@contextmanager
def at_level(
self, level: Union[int, str], logger: Optional[str] = None
) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After
the end of the 'with' statement the level is restored to its original
value.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
logger_obj = logging.getLogger(logger)
orig_level = logger_obj.level
logger_obj.setLevel(level)
handler_orig_level = self.handler.level
self.handler.setLevel(level)
try:
yield
finally:
logger_obj.setLevel(orig_level)
self.handler.setLevel(handler_orig_level)
@fixture
def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
"""Access and control log capturing.
Captured logs are available through the following properties/methods::
* caplog.messages -> list of format-interpolated log messages
* caplog.text -> string containing formatted log output
* caplog.records -> list of logging.LogRecord instances
* caplog.record_tuples -> list of (logger_name, level, message) tuples
* caplog.clear() -> clear captured records and formatted log output string
"""
result = LogCaptureFixture(request.node, _ispytest=True)
yield result
result._finalize()
def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[int]:
for setting_name in setting_names:
log_level = config.getoption(setting_name)
if log_level is None:
log_level = config.getini(setting_name)
if log_level:
break
else:
return None
if isinstance(log_level, str):
log_level = log_level.upper()
try:
return int(getattr(logging, log_level, log_level))
except ValueError as e:
# Python logging does not recognise this as a logging level
raise UsageError(
"'{}' is not recognized as a logging level name for "
"'{}'. Please consider passing the "
"logging level num instead.".format(log_level, setting_name)
) from e
# run after terminalreporter/capturemanager are configured
@hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin")
class LoggingPlugin:
"""Attaches to the logging module and captures log messages for each test."""
def __init__(self, config: Config) -> None:
"""Create a new plugin to capture log messages.
The formatter can be safely shared across all handlers so
create a single one for the entire test session here.
"""
self._config = config
# Report logging.
self.formatter = self._create_formatter(
get_option_ini(config, "log_format"),
get_option_ini(config, "log_date_format"),
get_option_ini(config, "log_auto_indent"),
)
self.log_level = get_log_level_for_setting(config, "log_level")
self.caplog_handler = LogCaptureHandler()
self.caplog_handler.setFormatter(self.formatter)
self.report_handler = LogCaptureHandler()
self.report_handler.setFormatter(self.formatter)
# File logging.
self.log_file_level = get_log_level_for_setting(config, "log_file_level")
log_file = get_option_ini(config, "log_file") or os.devnull
if log_file != os.devnull:
directory = os.path.dirname(os.path.abspath(log_file))
if not os.path.isdir(directory):
os.makedirs(directory)
self.log_file_handler = _FileHandler(log_file, mode="w", encoding="UTF-8")
log_file_format = get_option_ini(config, "log_file_format", "log_format")
log_file_date_format = get_option_ini(
config, "log_file_date_format", "log_date_format"
)
log_file_formatter = logging.Formatter(
log_file_format, datefmt=log_file_date_format
)
self.log_file_handler.setFormatter(log_file_formatter)
# CLI/live logging.
self.log_cli_level = get_log_level_for_setting(
config, "log_cli_level", "log_level"
)
if self._log_cli_enabled():
terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
capture_manager = config.pluginmanager.get_plugin("capturemanager")
# if capturemanager plugin is disabled, live logging still works.
self.log_cli_handler: Union[
_LiveLoggingStreamHandler, _LiveLoggingNullHandler
] = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
else:
self.log_cli_handler = _LiveLoggingNullHandler()
log_cli_formatter = self._create_formatter(
get_option_ini(config, "log_cli_format", "log_format"),
get_option_ini(config, "log_cli_date_format", "log_date_format"),
get_option_ini(config, "log_auto_indent"),
)
self.log_cli_handler.setFormatter(log_cli_formatter)
def _create_formatter(self, log_format, log_date_format, auto_indent):
# Color option doesn't exist if terminal plugin is disabled.
color = getattr(self._config.option, "color", "no")
if color != "no" and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format
):
formatter: logging.Formatter = ColoredLevelFormatter(
create_terminal_writer(self._config), log_format, log_date_format
)
else:
formatter = logging.Formatter(log_format, log_date_format)
formatter._style = PercentStyleMultiline(
formatter._style._fmt, auto_indent=auto_indent
)
return formatter
def set_log_path(self, fname: str) -> None:
"""Set the filename parameter for Logging.FileHandler().
Creates parent directory if it does not exist.
.. warning::
This is an experimental API.
"""
fpath = Path(fname)
if not fpath.is_absolute():
fpath = self._config.rootpath / fpath
if not fpath.parent.exists():
fpath.parent.mkdir(exist_ok=True, parents=True)
# https://github.com/python/mypy/issues/11193
stream: io.TextIOWrapper = fpath.open(mode="w", encoding="UTF-8") # type: ignore[assignment]
old_stream = self.log_file_handler.setStream(stream)
if old_stream:
old_stream.close()
def _log_cli_enabled(self):
"""Return whether live logging is enabled."""
enabled = self._config.getoption(
"--log-cli-level"
) is not None or self._config.getini("log_cli")
if not enabled:
return False
terminal_reporter = self._config.pluginmanager.get_plugin("terminalreporter")
if terminal_reporter is None:
# terminal reporter is disabled e.g. by pytest-xdist.
return False
return True
@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_sessionstart(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionstart")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("collection")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
@hookimpl(hookwrapper=True)
def pytest_runtestloop(self, session: Session) -> Generator[None, None, None]:
if session.config.option.collectonly:
yield
return
if self._log_cli_enabled() and self._config.getoption("verbose") < 1:
# The verbose flag is needed to avoid messy test progress output.
self._config.option.verbose = 1
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
yield # Run all the tests.
@hookimpl
def pytest_runtest_logstart(self) -> None:
self.log_cli_handler.reset()
self.log_cli_handler.set_when("start")
@hookimpl
def pytest_runtest_logreport(self) -> None:
self.log_cli_handler.set_when("logreport")
def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None, None, None]:
"""Implement the internals of the pytest_runtest_xxx() hooks."""
with catching_logs(
self.caplog_handler,
level=self.log_level,
) as caplog_handler, catching_logs(
self.report_handler,
level=self.log_level,
) as report_handler:
caplog_handler.reset()
report_handler.reset()
item.stash[caplog_records_key][when] = caplog_handler.records
item.stash[caplog_handler_key] = caplog_handler
yield
log = report_handler.stream.getvalue().strip()
item.add_report_section(when, "log", log)
@hookimpl(hookwrapper=True)
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("setup")
empty: Dict[str, List[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty
yield from self._runtest_for(item, "setup")
@hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("call")
yield from self._runtest_for(item, "call")
@hookimpl(hookwrapper=True)
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("teardown")
yield from self._runtest_for(item, "teardown")
del item.stash[caplog_records_key]
del item.stash[caplog_handler_key]
@hookimpl
def pytest_runtest_logfinish(self) -> None:
self.log_cli_handler.set_when("finish")
@hookimpl(hookwrapper=True, tryfirst=True)
def pytest_sessionfinish(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionfinish")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
yield
@hookimpl
def pytest_unconfigure(self) -> None:
# Close the FileHandler explicitly.
# (logging.shutdown might have lost the weakref?!)
self.log_file_handler.close()
class _FileHandler(logging.FileHandler):
"""A logging FileHandler with pytest tweaks."""
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass
class _LiveLoggingStreamHandler(logging_StreamHandler):
"""A logging StreamHandler used by the live logging feature: it will
write a newline before the first log message in each test.
During live logging we must also explicitly disable stdout/stderr
capturing otherwise it will get captured and won't appear in the
terminal.
"""
# Officially stream needs to be a IO[str], but TerminalReporter
# isn't. So force it.
stream: TerminalReporter = None # type: ignore
def __init__(
self,
terminal_reporter: TerminalReporter,
capture_manager: Optional[CaptureManager],
) -> None:
super().__init__(stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager
self.reset()
self.set_when(None)
self._test_outcome_written = False
def reset(self) -> None:
"""Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False
def set_when(self, when: Optional[str]) -> None:
"""Prepare for the given test phase (setup/call/teardown)."""
self._when = when
self._section_name_shown = False
if when == "start":
self._test_outcome_written = False
def emit(self, record: logging.LogRecord) -> None:
ctx_manager = (
self.capture_manager.global_and_fixture_disabled()
if self.capture_manager
else nullcontext()
)
with ctx_manager:
if not self._first_record_emitted:
self.stream.write("\n")
self._first_record_emitted = True
elif self._when in ("teardown", "finish"):
if not self._test_outcome_written:
self._test_outcome_written = True
self.stream.write("\n")
if not self._section_name_shown and self._when:
self.stream.section("live log " + self._when, sep="-", bold=True)
self._section_name_shown = True
super().emit(record)
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass
class _LiveLoggingNullHandler(logging.NullHandler):
"""A logging handler used when live logging is disabled."""
def reset(self) -> None:
pass
def set_when(self, when: str) -> None:
pass
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass

View File

@ -0,0 +1,902 @@
"""Core implementation of the testing process: init, session, runtest loop."""
import argparse
import fnmatch
import functools
import importlib
import os
import sys
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import attr
import _pytest._code
from _pytest import nodes
from _pytest.compat import final
from _pytest.compat import overload
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit
from _pytest.pathlib import absolutepath
from _pytest.pathlib import bestrelpath
from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import visit
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.runner import collect_one_node
from _pytest.runner import SetupState
if TYPE_CHECKING:
from typing_extensions import Literal
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"Directory patterns to avoid for recursion",
type="args",
default=[
"*.egg",
".*",
"_darcs",
"build",
"CVS",
"dist",
"node_modules",
"venv",
"{arch}",
],
)
parser.addini(
"testpaths",
"Directories to search for tests when no files or directories are given on the "
"command line",
type="args",
default=[],
)
group = parser.getgroup("general", "Running and selection options")
group._addoption(
"-x",
"--exitfirst",
action="store_const",
dest="maxfail",
const=1,
help="Exit instantly on first error or failed test",
)
group = parser.getgroup("pytest-warnings")
group.addoption(
"-W",
"--pythonwarnings",
action="append",
help="Set which warnings to report, see -W option of Python itself",
)
parser.addini(
"filterwarnings",
type="linelist",
help="Each line specifies a pattern for "
"warnings.filterwarnings. "
"Processed after -W/--pythonwarnings.",
)
group._addoption(
"--maxfail",
metavar="num",
action="store",
type=int,
dest="maxfail",
default=0,
help="Exit after first num failures or errors",
)
group._addoption(
"--strict-config",
action="store_true",
help="Any warnings encountered while parsing the `pytest` section of the "
"configuration file raise errors",
)
group._addoption(
"--strict-markers",
action="store_true",
help="Markers not registered in the `markers` section of the configuration "
"file raise errors",
)
group._addoption(
"--strict",
action="store_true",
help="(Deprecated) alias to --strict-markers",
)
group._addoption(
"-c",
metavar="file",
type=str,
dest="inifilename",
help="Load configuration from `file` instead of trying to locate one of the "
"implicit configuration files",
)
group._addoption(
"--continue-on-collection-errors",
action="store_true",
default=False,
dest="continue_on_collection_errors",
help="Force test execution even if collection errors occur",
)
group._addoption(
"--rootdir",
action="store",
dest="rootdir",
help="Define root directory for tests. Can be relative path: 'root_dir', './root_dir', "
"'root_dir/another_dir/'; absolute path: '/home/user/root_dir'; path with variables: "
"'$HOME/root_dir'.",
)
group = parser.getgroup("collect", "collection")
group.addoption(
"--collectonly",
"--collect-only",
"--co",
action="store_true",
help="Only collect tests, don't execute them",
)
group.addoption(
"--pyargs",
action="store_true",
help="Try to interpret all arguments as Python packages",
)
group.addoption(
"--ignore",
action="append",
metavar="path",
help="Ignore path during collection (multi-allowed)",
)
group.addoption(
"--ignore-glob",
action="append",
metavar="path",
help="Ignore path pattern during collection (multi-allowed)",
)
group.addoption(
"--deselect",
action="append",
metavar="nodeid_prefix",
help="Deselect item (via node id prefix) during collection (multi-allowed)",
)
group.addoption(
"--confcutdir",
dest="confcutdir",
default=None,
metavar="dir",
type=functools.partial(directory_arg, optname="--confcutdir"),
help="Only load conftest.py's relative to specified dir",
)
group.addoption(
"--noconftest",
action="store_true",
dest="noconftest",
default=False,
help="Don't load any conftest.py files",
)
group.addoption(
"--keepduplicates",
"--keep-duplicates",
action="store_true",
dest="keepduplicates",
default=False,
help="Keep duplicate tests",
)
group.addoption(
"--collect-in-virtualenv",
action="store_true",
dest="collect_in_virtualenv",
default=False,
help="Don't ignore tests in a local virtualenv directory",
)
group.addoption(
"--import-mode",
default="prepend",
choices=["prepend", "append", "importlib"],
dest="importmode",
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group = parser.getgroup("debugconfig", "test session debugging and configuration")
group.addoption(
"--basetemp",
dest="basetemp",
default=None,
type=validate_basetemp,
metavar="dir",
help=(
"Base temporary directory for this test run. "
"(Warning: this directory is removed if it exists.)"
),
)
def validate_basetemp(path: str) -> str:
# GH 7119
msg = "basetemp must not be empty, the current working directory or any parent directory of it"
# empty path
if not path:
raise argparse.ArgumentTypeError(msg)
def is_ancestor(base: Path, query: Path) -> bool:
"""Return whether query is an ancestor of base."""
if base == query:
return True
return query in base.parents
# check if path is an ancestor of cwd
if is_ancestor(Path.cwd(), Path(path).absolute()):
raise argparse.ArgumentTypeError(msg)
# check symlinks for ancestors
if is_ancestor(Path.cwd().resolve(), Path(path).resolve()):
raise argparse.ArgumentTypeError(msg)
return path
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
session.exitstatus = doit(config, session) or 0
except UsageError:
session.exitstatus = ExitCode.USAGE_ERROR
raise
except Failed:
session.exitstatus = ExitCode.TESTS_FAILED
except (KeyboardInterrupt, exit.Exception):
excinfo = _pytest._code.ExceptionInfo.from_current()
exitstatus: Union[int, ExitCode] = ExitCode.INTERRUPTED
if isinstance(excinfo.value, exit.Exception):
if excinfo.value.returncode is not None:
exitstatus = excinfo.value.returncode
if initstate < 2:
sys.stderr.write(f"{excinfo.typename}: {excinfo.value.msg}\n")
config.hook.pytest_keyboard_interrupt(excinfo=excinfo)
session.exitstatus = exitstatus
except BaseException:
session.exitstatus = ExitCode.INTERNAL_ERROR
excinfo = _pytest._code.ExceptionInfo.from_current()
try:
config.notify_exception(excinfo, config.option)
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n")
else:
if isinstance(excinfo.value, SystemExit):
sys.stderr.write("mainloop: caught unexpected SystemExit!\n")
finally:
# Explicitly break reference cycle.
excinfo = None # type: ignore
os.chdir(session.startpath)
if initstate >= 2:
try:
config.hook.pytest_sessionfinish(
session=session, exitstatus=session.exitstatus
)
except exit.Exception as exc:
if exc.returncode is not None:
session.exitstatus = exc.returncode
sys.stderr.write(f"{type(exc).__name__}: {exc}\n")
config._ensure_unconfigure()
return session.exitstatus
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main)
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
config.hook.pytest_runtestloop(session=session)
if session.testsfailed:
return ExitCode.TESTS_FAILED
elif session.testscollected == 0:
return ExitCode.NO_TESTS_COLLECTED
return None
def pytest_collection(session: "Session") -> None:
session.perform_collect()
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
if session.shouldfail:
raise session.Failed(session.shouldfail)
if session.shouldstop:
raise session.Interrupted(session.shouldstop)
return True
def _in_venv(path: Path) -> bool:
"""Attempt to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script."""
bindir = path.joinpath("Scripts" if sys.platform.startswith("win") else "bin")
try:
if not bindir.is_dir():
return False
except OSError:
return False
activates = (
"activate",
"activate.csh",
"activate.fish",
"Activate",
"Activate.bat",
"Activate.ps1",
)
return any(fname.name in activates for fname in bindir.iterdir())
def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[bool]:
ignore_paths = config._getconftest_pathlist(
"collect_ignore", path=collection_path.parent, rootpath=config.rootpath
)
ignore_paths = ignore_paths or []
excludeopt = config.getoption("ignore")
if excludeopt:
ignore_paths.extend(absolutepath(x) for x in excludeopt)
if collection_path in ignore_paths:
return True
ignore_globs = config._getconftest_pathlist(
"collect_ignore_glob", path=collection_path.parent, rootpath=config.rootpath
)
ignore_globs = ignore_globs or []
excludeglobopt = config.getoption("ignore_glob")
if excludeglobopt:
ignore_globs.extend(absolutepath(x) for x in excludeglobopt)
if any(fnmatch.fnmatch(str(collection_path), str(glob)) for glob in ignore_globs):
return True
allow_in_venv = config.getoption("collect_in_virtualenv")
if not allow_in_venv and _in_venv(collection_path):
return True
return None
def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes:
return
remaining = []
deselected = []
for colitem in items:
if colitem.nodeid.startswith(deselect_prefixes):
deselected.append(colitem)
else:
remaining.append(colitem)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = remaining
class FSHookProxy:
def __init__(self, pm: PytestPluginManager, remove_mods) -> None:
self.pm = pm
self.remove_mods = remove_mods
def __getattr__(self, name: str):
x = self.pm.subset_hook_caller(name, remove_plugins=self.remove_mods)
self.__dict__[name] = x
return x
class Interrupted(KeyboardInterrupt):
"""Signals that the test run was interrupted."""
__module__ = "builtins" # For py3.
class Failed(Exception):
"""Signals a stop as failed test run."""
@attr.s(slots=True, auto_attribs=True)
class _bestrelpath_cache(Dict[Path, str]):
path: Path
def __missing__(self, path: Path) -> str:
r = bestrelpath(self.path, path)
self[path] = r
return r
@final
class Session(nodes.FSCollector):
Interrupted = Interrupted
Failed = Failed
# Set on the session by runner.pytest_sessionstart.
_setupstate: SetupState
# Set on the session by fixtures.pytest_sessionstart.
_fixturemanager: FixtureManager
exitstatus: Union[int, ExitCode]
def __init__(self, config: Config) -> None:
super().__init__(
path=config.rootpath,
fspath=None,
parent=None,
config=config,
session=self,
nodeid="",
)
self.testsfailed = 0
self.testscollected = 0
self.shouldstop: Union[bool, str] = False
self.shouldfail: Union[bool, str] = False
self.trace = config.trace.root.get("collection")
self._initialpaths: FrozenSet[Path] = frozenset()
self._bestrelpathcache: Dict[Path, str] = _bestrelpath_cache(config.rootpath)
self.config.pluginmanager.register(self, name="session")
@classmethod
def from_config(cls, config: Config) -> "Session":
session: Session = cls._create(config=config)
return session
def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__,
self.name,
getattr(self, "exitstatus", "<UNSET>"),
self.testsfailed,
self.testscollected,
)
@property
def startpath(self) -> Path:
"""The path from which pytest was invoked.
.. versionadded:: 7.0.0
"""
return self.config.invocation_params.dir
def _node_location_to_relpath(self, node_path: Path) -> str:
# bestrelpath is a quite slow function.
return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True)
def pytest_collectstart(self) -> None:
if self.shouldfail:
raise self.Failed(self.shouldfail)
if self.shouldstop:
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(
self, report: Union[TestReport, CollectReport]
) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
if maxfail and self.testsfailed >= maxfail:
self.shouldfail = "stopping after %d failures" % (self.testsfailed)
pytest_collectreport = pytest_runtest_logreport
def isinitpath(self, path: Union[str, "os.PathLike[str]"]) -> bool:
# Optimization: Path(Path(...)) is much slower than isinstance.
path_ = path if isinstance(path, Path) else Path(path)
return path_ in self._initialpaths
def gethookproxy(self, fspath: "os.PathLike[str]"):
# Optimization: Path(Path(...)) is much slower than isinstance.
path = fspath if isinstance(fspath, Path) else Path(fspath)
pm = self.config.pluginmanager
# Check if we have the common case of running
# hooks with all conftest.py files.
my_conftestmodules = pm._getconftestmodules(
path,
self.config.getoption("importmode"),
rootpath=self.config.rootpath,
)
remove_mods = pm._conftest_plugins.difference(my_conftestmodules)
if remove_mods:
# One or more conftests are not in use at this fspath.
from .config.compat import PathAwareHookProxy
proxy = PathAwareHookProxy(FSHookProxy(pm, remove_mods))
else:
# All plugins are active for this fspath.
proxy = self.config.hook
return proxy
def _recurse(self, direntry: "os.DirEntry[str]") -> bool:
if direntry.name == "__pycache__":
return False
fspath = Path(direntry.path)
ihook = self.gethookproxy(fspath.parent)
if ihook.pytest_ignore_collect(collection_path=fspath, config=self.config):
return False
norecursepatterns = self.config.getini("norecursedirs")
if any(fnmatch_ex(pat, fspath) for pat in norecursepatterns):
return False
return True
def _collectfile(
self, fspath: Path, handle_dupes: bool = True
) -> Sequence[nodes.Collector]:
assert (
fspath.is_file()
), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
fspath, fspath.is_dir(), fspath.exists(), fspath.is_symlink()
)
ihook = self.gethookproxy(fspath)
if not self.isinitpath(fspath):
if ihook.pytest_ignore_collect(collection_path=fspath, config=self.config):
return ()
if handle_dupes:
keepduplicates = self.config.getoption("keepduplicates")
if not keepduplicates:
duplicate_paths = self.config.pluginmanager._duplicatepaths
if fspath in duplicate_paths:
return ()
else:
duplicate_paths.add(fspath)
return ihook.pytest_collect_file(file_path=fspath, parent=self) # type: ignore[no-any-return]
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
) -> Sequence[nodes.Item]:
...
@overload
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
...
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
"""Perform the collection phase for this session.
This is called by the default :hook:`pytest_collection` hook
implementation; see the documentation of this hook for more details.
For testing purposes, it may also be called directly on a fresh
``Session``.
This function normally recursively expands any collectors collected
from the session to their items, and only items are returned. For
testing purposes, this may be suppressed by passing ``genitems=False``,
in which case the return value contains these collectors unexpanded,
and ``session.items`` is empty.
"""
if args is None:
args = self.config.args
self.trace("perform_collect", self, args)
self.trace.root.indent += 1
self._notfound: List[Tuple[str, Sequence[nodes.Collector]]] = []
self._initial_parts: List[Tuple[Path, List[str]]] = []
self.items: List[nodes.Item] = []
hook = self.config.hook
items: Sequence[Union[nodes.Item, nodes.Collector]] = self.items
try:
initialpaths: List[Path] = []
for arg in args:
fspath, parts = resolve_collection_argument(
self.config.invocation_params.dir,
arg,
as_pypath=self.config.option.pyargs,
)
self._initial_parts.append((fspath, parts))
initialpaths.append(fspath)
self._initialpaths = frozenset(initialpaths)
rep = collect_one_node(self)
self.ihook.pytest_collectreport(report=rep)
self.trace.root.indent -= 1
if self._notfound:
errors = []
for arg, collectors in self._notfound:
if collectors:
errors.append(
f"not found: {arg}\n(no name {arg!r} in any of {collectors!r})"
)
else:
errors.append(f"found no collectors for {arg}")
raise UsageError(*errors)
if not genitems:
items = rep.result
else:
if rep.passed:
for node in rep.result:
self.items.extend(self.genitems(node))
self.config.pluginmanager.check_pending()
hook.pytest_collection_modifyitems(
session=self, config=self.config, items=items
)
finally:
hook.pytest_collection_finish(session=self)
self.testscollected = len(items)
return items
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package
# Keep track of any collected nodes in here, so we don't duplicate fixtures.
node_cache1: Dict[Path, Sequence[nodes.Collector]] = {}
node_cache2: Dict[Tuple[Type[nodes.Collector], Path], nodes.Collector] = {}
# Keep track of any collected collectors in matchnodes paths, so they
# are not collected more than once.
matchnodes_cache: Dict[Tuple[Type[nodes.Collector], str], CollectReport] = {}
# Dirnames of pkgs with dunder-init files.
pkg_roots: Dict[str, Package] = {}
for argpath, names in self._initial_parts:
self.trace("processing argument", (argpath, names))
self.trace.root.indent += 1
# Start with a Session root, and delve to argpath item (dir or file)
# and stack all Packages found on the way.
# No point in finding packages when collecting doctests.
if not self.config.getoption("doctestmodules", False):
pm = self.config.pluginmanager
for parent in (argpath, *argpath.parents):
if not pm._is_in_confcutdir(argpath):
break
if parent.is_dir():
pkginit = parent / "__init__.py"
if pkginit.is_file() and pkginit not in node_cache1:
col = self._collectfile(pkginit, handle_dupes=False)
if col:
if isinstance(col[0], Package):
pkg_roots[str(parent)] = col[0]
node_cache1[col[0].path] = [col[0]]
# If it's a directory argument, recurse and look for any Subpackages.
# Let the Package collector deal with subnodes, don't collect here.
if argpath.is_dir():
assert not names, f"invalid arg {(argpath, names)!r}"
seen_dirs: Set[Path] = set()
for direntry in visit(str(argpath), self._recurse):
if not direntry.is_file():
continue
path = Path(direntry.path)
dirpath = path.parent
if dirpath not in seen_dirs:
# Collect packages first.
seen_dirs.add(dirpath)
pkginit = dirpath / "__init__.py"
if pkginit.exists():
for x in self._collectfile(pkginit):
yield x
if isinstance(x, Package):
pkg_roots[str(dirpath)] = x
if str(dirpath) in pkg_roots:
# Do not collect packages here.
continue
for x in self._collectfile(path):
key2 = (type(x), x.path)
if key2 in node_cache2:
yield node_cache2[key2]
else:
node_cache2[key2] = x
yield x
else:
assert argpath.is_file()
if argpath in node_cache1:
col = node_cache1[argpath]
else:
collect_root = pkg_roots.get(str(argpath.parent), self)
col = collect_root._collectfile(argpath, handle_dupes=False)
if col:
node_cache1[argpath] = col
matching = []
work: List[
Tuple[Sequence[Union[nodes.Item, nodes.Collector]], Sequence[str]]
] = [(col, names)]
while work:
self.trace("matchnodes", col, names)
self.trace.root.indent += 1
matchnodes, matchnames = work.pop()
for node in matchnodes:
if not matchnames:
matching.append(node)
continue
if not isinstance(node, nodes.Collector):
continue
key = (type(node), node.nodeid)
if key in matchnodes_cache:
rep = matchnodes_cache[key]
else:
rep = collect_one_node(node)
matchnodes_cache[key] = rep
if rep.passed:
submatchnodes = []
for r in rep.result:
# TODO: Remove parametrized workaround once collection structure contains
# parametrization.
if (
r.name == matchnames[0]
or r.name.split("[")[0] == matchnames[0]
):
submatchnodes.append(r)
if submatchnodes:
work.append((submatchnodes, matchnames[1:]))
else:
# Report collection failures here to avoid failing to run some test
# specified in the command line because the module could not be
# imported (#134).
node.ihook.pytest_collectreport(report=rep)
self.trace("matchnodes finished -> ", len(matching), "nodes")
self.trace.root.indent -= 1
if not matching:
report_arg = "::".join((str(argpath), *names))
self._notfound.append((report_arg, col))
continue
# If __init__.py was the only file requested, then the matched
# node will be the corresponding Package (by default), and the
# first yielded item will be the __init__ Module itself, so
# just use that. If this special case isn't taken, then all the
# files in the package will be yielded.
if argpath.name == "__init__.py" and isinstance(matching[0], Package):
try:
yield next(iter(matching[0].collect()))
except StopIteration:
# The package collects nothing with only an __init__.py
# file in it, which gets ignored by the default
# "python_files" option.
pass
continue
yield from matching
self.trace.root.indent -= 1
def genitems(
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)
yield node
else:
assert isinstance(node, nodes.Collector)
rep = collect_one_node(node)
if rep.passed:
for subnode in rep.result:
yield from self.genitems(subnode)
node.ihook.pytest_collectreport(report=rep)
def search_pypath(module_name: str) -> str:
"""Search sys.path for the given a dotted module name, and return its file system path."""
try:
spec = importlib.util.find_spec(module_name)
# AttributeError: looks like package module, but actually filename
# ImportError: module does not exist
# ValueError: not a module name
except (AttributeError, ImportError, ValueError):
return module_name
if spec is None or spec.origin is None or spec.origin == "namespace":
return module_name
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
else:
return spec.origin
def resolve_collection_argument(
invocation_path: Path, arg: str, *, as_pypath: bool = False
) -> Tuple[Path, List[str]]:
"""Parse path arguments optionally containing selection parts and return (fspath, names).
Command-line arguments can point to files and/or directories, and optionally contain
parts for specific tests selection, for example:
"pkg/tests/test_foo.py::TestClass::test_foo"
This function ensures the path exists, and returns a tuple:
(Path("/full/path/to/pkg/tests/test_foo.py"), ["TestClass", "test_foo"])
When as_pypath is True, expects that the command-line argument actually contains
module paths instead of file-system paths:
"pkg.tests.test_foo::TestClass::test_foo"
In which case we search sys.path for a matching module, and then return the *path* to the
found module.
If the path doesn't exist, raise UsageError.
If the path is a directory and selection parts are present, raise UsageError.
"""
base, squacket, rest = str(arg).partition("[")
strpath, *parts = base.split("::")
if parts:
parts[-1] = f"{parts[-1]}{squacket}{rest}"
if as_pypath:
strpath = search_pypath(strpath)
fspath = invocation_path / strpath
fspath = absolutepath(fspath)
if not fspath.exists():
msg = (
"module or package not found: {arg} (missing __init__.py?)"
if as_pypath
else "file or directory not found: {arg}"
)
raise UsageError(msg.format(arg=arg))
if parts and fspath.is_dir():
msg = (
"package argument cannot contain :: selection parts: {arg}"
if as_pypath
else "directory argument cannot contain :: selection parts: {arg}"
)
raise UsageError(msg.format(arg=arg))
return fspath, parts

View File

@ -0,0 +1,266 @@
"""Generic mechanism for marking and selecting python functions."""
from typing import AbstractSet
from typing import Collection
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
import attr
from .expression import Expression
from .expression import ParseError
from .structures import EMPTY_PARAMETERSET_OPTION
from .structures import get_empty_parameterset_mark
from .structures import Mark
from .structures import MARK_GEN
from .structures import MarkDecorator
from .structures import MarkGenerator
from .structures import ParameterSet
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
if TYPE_CHECKING:
from _pytest.nodes import Item
__all__ = [
"MARK_GEN",
"Mark",
"MarkDecorator",
"MarkGenerator",
"ParameterSet",
"get_empty_parameterset_mark",
]
old_mark_config_key = StashKey[Optional[Config]]()
def param(
*values: object,
marks: Union[MarkDecorator, Collection[Union[MarkDecorator, Mark]]] = (),
id: Optional[str] = None,
) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
.. code-block:: python
@pytest.mark.parametrize(
"test_input,expected",
[
("3+5", 8),
pytest.param("6*9", 42, marks=pytest.mark.xfail),
],
)
def test_eval(test_input, expected):
assert eval(test_input) == expected
:param values: Variable args of the values of the parameter set, in order.
:param marks: A single mark or a list of marks to be applied to this parameter set.
:param id: The id to attribute to this parameter set.
"""
return ParameterSet.param(*values, marks=marks, id=id)
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"-k",
action="store",
dest="keyword",
default="",
metavar="EXPRESSION",
help="Only run tests which match the given substring expression. "
"An expression is a Python evaluatable expression "
"where all names are substring-matched against test names "
"and their parent classes. Example: -k 'test_method or test_"
"other' matches all test functions and classes whose name "
"contains 'test_method' or 'test_other', while -k 'not test_method' "
"matches those that don't contain 'test_method' in their names. "
"-k 'not test_method and not test_other' will eliminate the matches. "
"Additionally keywords are matched to classes and functions "
"containing extra names in their 'extra_keyword_matches' set, "
"as well as functions which have names assigned directly to them. "
"The matching is case-insensitive.",
)
group._addoption(
"-m",
action="store",
dest="markexpr",
default="",
metavar="MARKEXPR",
help="Only run tests matching given mark expression. "
"For example: -m 'mark1 and not mark2'.",
)
group.addoption(
"--markers",
action="store_true",
help="show markers (builtin, plugin and per-project ones).",
)
parser.addini("markers", "Markers for test functions", "linelist")
parser.addini(EMPTY_PARAMETERSET_OPTION, "Default marker for empty parametersets")
@hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
import _pytest.config
if config.option.markers:
config._do_configure()
tw = _pytest.config.create_terminal_writer(config)
for line in config.getini("markers"):
parts = line.split(":", 1)
name = parts[0]
rest = parts[1] if len(parts) == 2 else ""
tw.write("@pytest.mark.%s:" % name, bold=True)
tw.line(rest)
tw.line()
config._ensure_unconfigure()
return 0
return None
@attr.s(slots=True, auto_attribs=True)
class KeywordMatcher:
"""A matcher for keywords.
Given a list of names, matches any substring of one of these names. The
string inclusion check is case-insensitive.
Will match on the name of colitem, including the names of its parents.
Only matches names of items which are either a :class:`Class` or a
:class:`Function`.
Additionally, matches on names in the 'extra_keyword_matches' set of
any item, as well as names directly assigned to test functions.
"""
_names: AbstractSet[str]
@classmethod
def from_item(cls, item: "Item") -> "KeywordMatcher":
mapped_names = set()
# Add the names of the current item and any parent items.
import pytest
for node in item.listchain():
if not isinstance(node, pytest.Session):
mapped_names.add(node.name)
# Add the names added as extra keywords to current or parent items.
mapped_names.update(item.listextrakeywords())
# Add the names attached to the current function through direct assignment.
function_obj = getattr(item, "function", None)
if function_obj:
mapped_names.update(function_obj.__dict__)
# Add the markers to the keywords as we no longer handle them correctly.
mapped_names.update(mark.name for mark in item.iter_markers())
return cls(mapped_names)
def __call__(self, subname: str) -> bool:
subname = subname.lower()
names = (name.lower() for name in self._names)
for name in names:
if subname in name:
return True
return False
def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
keywordexpr = config.option.keyword.lstrip()
if not keywordexpr:
return
expr = _parse_expression(keywordexpr, "Wrong expression passed to '-k'")
remaining = []
deselected = []
for colitem in items:
if not expr.evaluate(KeywordMatcher.from_item(colitem)):
deselected.append(colitem)
else:
remaining.append(colitem)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = remaining
@attr.s(slots=True, auto_attribs=True)
class MarkMatcher:
"""A matcher for markers which are present.
Tries to match on any marker names, attached to the given colitem.
"""
own_mark_names: AbstractSet[str]
@classmethod
def from_item(cls, item: "Item") -> "MarkMatcher":
mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names)
def __call__(self, name: str) -> bool:
return name in self.own_mark_names
def deselect_by_mark(items: "List[Item]", config: Config) -> None:
matchexpr = config.option.markexpr
if not matchexpr:
return
expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")
remaining: List[Item] = []
deselected: List[Item] = []
for item in items:
if expr.evaluate(MarkMatcher.from_item(item)):
remaining.append(item)
else:
deselected.append(item)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = remaining
def _parse_expression(expr: str, exc_message: str) -> Expression:
try:
return Expression.compile(expr)
except ParseError as e:
raise UsageError(f"{exc_message}: {expr}: {e}") from None
def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None:
deselect_by_keyword(items, config)
deselect_by_mark(items, config)
def pytest_configure(config: Config) -> None:
config.stash[old_mark_config_key] = MARK_GEN._config
MARK_GEN._config = config
empty_parameterset = config.getini(EMPTY_PARAMETERSET_OPTION)
if empty_parameterset not in ("skip", "xfail", "fail_at_collect", None, ""):
raise UsageError(
"{!s} must be one of skip, xfail or fail_at_collect"
" but it is {!r}".format(EMPTY_PARAMETERSET_OPTION, empty_parameterset)
)
def pytest_unconfigure(config: Config) -> None:
MARK_GEN._config = config.stash.get(old_mark_config_key, None)

View File

@ -0,0 +1,222 @@
r"""Evaluate match expressions, as used by `-k` and `-m`.
The grammar is:
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are:
- Empty expression evaluates to False.
- ident evaluates to True of False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
"""
import ast
import enum
import re
import types
from typing import Callable
from typing import Iterator
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Sequence
import attr
__all__ = [
"Expression",
"ParseError",
]
class TokenType(enum.Enum):
LPAREN = "left parenthesis"
RPAREN = "right parenthesis"
OR = "or"
AND = "and"
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class Token:
type: TokenType
value: str
pos: int
class ParseError(Exception):
"""The expression contains invalid syntax.
:param column: The column in the line where the error occurred (1-based).
:param message: A description of the error.
"""
def __init__(self, column: int, message: str) -> None:
self.column = column
self.message = message
def __str__(self) -> str:
return f"at column {self.column}: {self.message}"
class Scanner:
__slots__ = ("tokens", "current")
def __init__(self, input: str) -> None:
self.tokens = self.lex(input)
self.current = next(self.tokens)
def lex(self, input: str) -> Iterator[Token]:
pos = 0
while pos < len(input):
if input[pos] in (" ", "\t"):
pos += 1
elif input[pos] == "(":
yield Token(TokenType.LPAREN, "(", pos)
pos += 1
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
pos += 1
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match:
value = match.group(0)
if value == "or":
yield Token(TokenType.OR, value, pos)
elif value == "and":
yield Token(TokenType.AND, value, pos)
elif value == "not":
yield Token(TokenType.NOT, value, pos)
else:
yield Token(TokenType.IDENT, value, pos)
pos += len(value)
else:
raise ParseError(
pos + 1,
f'unexpected character "{input[pos]}"',
)
yield Token(TokenType.EOF, "", pos)
def accept(self, type: TokenType, *, reject: bool = False) -> Optional[Token]:
if self.current.type is type:
token = self.current
if token.type is not TokenType.EOF:
self.current = next(self.tokens)
return token
if reject:
self.reject((type,))
return None
def reject(self, expected: Sequence[TokenType]) -> NoReturn:
raise ParseError(
self.current.pos + 1,
"expected {}; got {}".format(
" OR ".join(type.value for type in expected),
self.current.type.value,
),
)
# True, False and None are legal match expression identifiers,
# but illegal as Python identifiers. To fix this, this prefix
# is added to identifiers in the conversion to Python AST.
IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF):
ret: ast.expr = ast.NameConstant(False)
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)
return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner) -> ast.expr:
ret = and_expr(s)
while s.accept(TokenType.OR):
rhs = and_expr(s)
ret = ast.BoolOp(ast.Or(), [ret, rhs])
return ret
def and_expr(s: Scanner) -> ast.expr:
ret = not_expr(s)
while s.accept(TokenType.AND):
rhs = not_expr(s)
ret = ast.BoolOp(ast.And(), [ret, rhs])
return ret
def not_expr(s: Scanner) -> ast.expr:
if s.accept(TokenType.NOT):
return ast.UnaryOp(ast.Not(), not_expr(s))
if s.accept(TokenType.LPAREN):
ret = expr(s)
s.accept(TokenType.RPAREN, reject=True)
return ret
ident = s.accept(TokenType.IDENT)
if ident:
return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
class MatcherAdapter(Mapping[str, bool]):
"""Adapts a matcher function to a locals mapping as required by eval()."""
def __init__(self, matcher: Callable[[str], bool]) -> None:
self.matcher = matcher
def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :])
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
def __len__(self) -> int:
raise NotImplementedError()
class Expression:
"""A compiled match expression as used by -k and -m.
The expression can be evaluated against different matchers.
"""
__slots__ = ("code",)
def __init__(self, code: types.CodeType) -> None:
self.code = code
@classmethod
def compile(self, input: str) -> "Expression":
"""Compile a match expression.
:param input: The input expression - one line.
"""
astexpr = expression(Scanner(input))
code: types.CodeType = compile(
astexpr,
filename="<pytest match expression>",
mode="eval",
)
return Expression(code)
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
"""Evaluate the match expression.
:param matcher:
Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
:returns: Whether the expression matches or not.
"""
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
return ret

View File

@ -0,0 +1,613 @@
import collections.abc
import inspect
import warnings
from typing import Any
from typing import Callable
from typing import Collection
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import final
from ..compat import NOTSET
from ..compat import NotSetType
from _pytest.config import Config
from _pytest.deprecated import check_ispytest
from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning
if TYPE_CHECKING:
from ..nodes import Node
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
def istestfunc(func) -> bool:
return callable(func) and getattr(func, "__name__", "<lambda>") != "<lambda>"
def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func
) -> "MarkDecorator":
from ..nodes import Collector
fs, lineno = getfslineno(func)
reason = "got empty parameter set %r, function %s at %s:%d" % (
argnames,
func.__name__,
fs,
lineno,
)
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
if requested_mark in ("", None, "skip"):
mark = MARK_GEN.skip(reason=reason)
elif requested_mark == "xfail":
mark = MARK_GEN.xfail(reason=reason, run=False)
elif requested_mark == "fail_at_collect":
f_name = func.__name__
_, lineno = getfslineno(func)
raise Collector.CollectError(
"Empty parameter set in '%s' at line %d" % (f_name, lineno + 1)
)
else:
raise LookupError(requested_mark)
return mark
class ParameterSet(NamedTuple):
values: Sequence[Union[object, NotSetType]]
marks: Collection[Union["MarkDecorator", "Mark"]]
id: Optional[str]
@classmethod
def param(
cls,
*values: object,
marks: Union["MarkDecorator", Collection[Union["MarkDecorator", "Mark"]]] = (),
id: Optional[str] = None,
) -> "ParameterSet":
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
assert isinstance(marks, collections.abc.Collection)
if id is not None:
if not isinstance(id, str):
raise TypeError(f"Expected id to be a string, got {type(id)}: {id!r}")
id = ascii_escaped(id)
return cls(values, marks, id)
@classmethod
def extract_from(
cls,
parameterset: Union["ParameterSet", Sequence[object], object],
force_tuple: bool = False,
) -> "ParameterSet":
"""Extract from an object or objects.
:param parameterset:
A legacy style parameterset that may or may not be a tuple,
and may or may not be wrapped into a mess of mark objects.
:param force_tuple:
Enforce tuple wrapping so single argument tuple values
don't get decomposed and break tests.
"""
if isinstance(parameterset, cls):
return parameterset
if force_tuple:
return cls.param(parameterset)
else:
# TODO: Refactor to fix this type-ignore. Currently the following
# passes type-checking but crashes:
#
# @pytest.mark.parametrize(('x', 'y'), [1, 2])
# def test_foo(x, y): pass
return cls(parameterset, marks=[], id=None) # type: ignore[arg-type]
@staticmethod
def _parse_parametrize_args(
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
*args,
**kwargs,
) -> Tuple[Sequence[str], bool]:
if isinstance(argnames, str):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1
else:
force_tuple = False
return argnames, force_tuple
@staticmethod
def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
force_tuple: bool,
) -> List["ParameterSet"]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]
@classmethod
def _for_parametrize(
cls,
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
func,
config: Config,
nodeid: str,
) -> Tuple[Sequence[str], List["ParameterSet"]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
if parameters:
# Check all parameter sets have the correct number of values.
for param in parameters:
if len(param.values) != len(argnames):
msg = (
'{nodeid}: in "parametrize" the number of names ({names_len}):\n'
" {names}\n"
"must be equal to the number of values ({values_len}):\n"
" {values}"
)
fail(
msg.format(
nodeid=nodeid,
values=param.values,
names=argnames,
names_len=len(argnames),
values_len=len(param.values),
),
pytrace=False,
)
else:
# Empty parameter set (likely computed at runtime): create a single
# parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
mark = get_empty_parameterset_mark(config, argnames, func)
parameters.append(
ParameterSet(values=(NOTSET,) * len(argnames), marks=[mark], id=None)
)
return argnames, parameters
@final
@attr.s(frozen=True, init=False, auto_attribs=True)
class Mark:
#: Name of the mark.
name: str
#: Positional arguments of the mark decorator.
args: Tuple[Any, ...]
#: Keyword arguments of the mark decorator.
kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks.
_param_ids_from: Optional["Mark"] = attr.ib(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Optional[Sequence[str]] = attr.ib(default=None, repr=False)
def __init__(
self,
name: str,
args: Tuple[Any, ...],
kwargs: Mapping[str, Any],
param_ids_from: Optional["Mark"] = None,
param_ids_generated: Optional[Sequence[str]] = None,
*,
_ispytest: bool = False,
) -> None:
""":meta private:"""
check_ispytest(_ispytest)
# Weirdness to bypass frozen=True.
object.__setattr__(self, "name", name)
object.__setattr__(self, "args", args)
object.__setattr__(self, "kwargs", kwargs)
object.__setattr__(self, "_param_ids_from", param_ids_from)
object.__setattr__(self, "_param_ids_generated", param_ids_generated)
def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4
def combined_with(self, other: "Mark") -> "Mark":
"""Return a new Mark which is a combination of this
Mark and another Mark.
Combines by appending args and merging kwargs.
:param Mark other: The mark to combine with.
:rtype: Mark
"""
assert self.name == other.name
# Remember source of ids with parametrize Marks.
param_ids_from: Optional[Mark] = None
if self.name == "parametrize":
if other._has_param_ids():
param_ids_from = other
elif self._has_param_ids():
param_ids_from = self
return Mark(
self.name,
self.args + other.args,
dict(self.kwargs, **other.kwargs),
param_ids_from=param_ids_from,
_ispytest=True,
)
# A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class.
# Note: a lambda is not allowed, but this can't be represented.
Markable = TypeVar("Markable", bound=Union[Callable[..., object], type])
@attr.s(init=False, auto_attribs=True)
class MarkDecorator:
"""A decorator for applying a mark on test functions and classes.
``MarkDecorators`` are created with ``pytest.mark``::
mark1 = pytest.mark.NAME # Simple MarkDecorator
mark2 = pytest.mark.NAME(name1=value) # Parametrized MarkDecorator
and can then be applied as decorators to test functions::
@mark2
def test_function():
pass
When a ``MarkDecorator`` is called, it does the following:
1. If called with a single class as its only positional argument and no
additional keyword arguments, it attaches the mark to the class so it
gets applied automatically to all test cases found in that class.
2. If called with a single function as its only positional argument and
no additional keyword arguments, it attaches the mark to the function,
containing all the arguments already stored internally in the
``MarkDecorator``.
3. When called in any other case, it returns a new ``MarkDecorator``
instance with the original ``MarkDecorator``'s content updated with
the arguments passed to this call.
Note: The rules above prevent a ``MarkDecorator`` from storing only a
single function or class reference as its positional argument with no
additional keyword or positional arguments. You can work around this by
using `with_args()`.
"""
mark: Mark
def __init__(self, mark: Mark, *, _ispytest: bool = False) -> None:
""":meta private:"""
check_ispytest(_ispytest)
self.mark = mark
@property
def name(self) -> str:
"""Alias for mark.name."""
return self.mark.name
@property
def args(self) -> Tuple[Any, ...]:
"""Alias for mark.args."""
return self.mark.args
@property
def kwargs(self) -> Mapping[str, Any]:
"""Alias for mark.kwargs."""
return self.mark.kwargs
@property
def markname(self) -> str:
""":meta private:"""
return self.name # for backward-compat (2.4.1 had this attr)
def with_args(self, *args: object, **kwargs: object) -> "MarkDecorator":
"""Return a MarkDecorator with extra arguments added.
Unlike calling the MarkDecorator, with_args() can be used even
if the sole argument is a callable/class.
"""
mark = Mark(self.name, args, kwargs, _ispytest=True)
return MarkDecorator(self.mark.combined_with(mark), _ispytest=True)
# Type ignored because the overloads overlap with an incompatible
# return type. Not much we can do about that. Thankfully mypy picks
# the first match so it works out even if we break the rules.
@overload
def __call__(self, arg: Markable) -> Markable: # type: ignore[misc]
pass
@overload
def __call__(self, *args: object, **kwargs: object) -> "MarkDecorator":
pass
def __call__(self, *args: object, **kwargs: object):
"""Call the MarkDecorator."""
if args and not kwargs:
func = args[0]
is_class = inspect.isclass(func)
if len(args) == 1 and (istestfunc(func) or is_class):
store_mark(func, self.mark)
return func
return self.with_args(*args, **kwargs)
def get_unpacked_marks(
obj: Union[object, type],
*,
consider_mro: bool = True,
) -> List[Mark]:
"""Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to
this class and all of its super-classes in MRO order. If consider_mro
is false, only return marks applied directly to this class.
"""
if isinstance(obj, type):
if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])]
else:
mark_lists = [x.__dict__.get("pytestmark", []) for x in obj.__mro__]
mark_list = []
for item in mark_lists:
if isinstance(item, list):
mark_list.extend(item)
else:
mark_list.append(item)
else:
mark_attribute = getattr(obj, "pytestmark", [])
if isinstance(mark_attribute, list):
mark_list = mark_attribute
else:
mark_list = [mark_attribute]
return list(normalize_mark_list(mark_list))
def normalize_mark_list(
mark_list: Iterable[Union[Mark, MarkDecorator]]
) -> Iterable[Mark]:
"""
Normalize an iterable of Mark or MarkDecorator objects into a list of marks
by retrieving the `mark` attribute on MarkDecorator instances.
:param mark_list: marks to normalize
:returns: A new list of the extracted Mark objects
"""
for mark in mark_list:
mark_obj = getattr(mark, "mark", mark)
if not isinstance(mark_obj, Mark):
raise TypeError(f"got {repr(mark_obj)} instead of Mark")
yield mark_obj
def store_mark(obj, mark: Mark) -> None:
"""Store a Mark on an object.
This is used to implement the Mark declarations/decorators correctly.
"""
assert isinstance(mark, Mark), mark
# Always reassign name to avoid updating pytestmark in a reference that
# was only borrowed.
obj.pytestmark = [*get_unpacked_marks(obj, consider_mro=False), mark]
# Typing for builtin pytest marks. This is cheating; it gives builtin marks
# special privilege, and breaks modularity. But practicality beats purity...
if TYPE_CHECKING:
from _pytest.scope import _ScopeName
class _SkipMarkDecorator(MarkDecorator):
@overload # type: ignore[override,misc,no-overload-impl]
def __call__(self, arg: Markable) -> Markable:
...
@overload
def __call__(self, reason: str = ...) -> "MarkDecorator":
...
class _SkipifMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
condition: Union[str, bool] = ...,
*conditions: Union[str, bool],
reason: str = ...,
) -> MarkDecorator:
...
class _XfailMarkDecorator(MarkDecorator):
@overload # type: ignore[override,misc,no-overload-impl]
def __call__(self, arg: Markable) -> Markable:
...
@overload
def __call__(
self,
condition: Union[str, bool] = ...,
*conditions: Union[str, bool],
reason: str = ...,
run: bool = ...,
raises: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = ...,
strict: bool = ...,
) -> MarkDecorator:
...
class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
argnames: Union[str, Sequence[str]],
argvalues: Iterable[Union[ParameterSet, Sequence[object], object]],
*,
indirect: Union[bool, Sequence[str]] = ...,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[Any], Optional[object]],
]
] = ...,
scope: Optional[_ScopeName] = ...,
) -> MarkDecorator:
...
class _UsefixturesMarkDecorator(MarkDecorator):
def __call__(self, *fixtures: str) -> MarkDecorator: # type: ignore[override]
...
class _FilterwarningsMarkDecorator(MarkDecorator):
def __call__(self, *filters: str) -> MarkDecorator: # type: ignore[override]
...
@final
class MarkGenerator:
"""Factory for :class:`MarkDecorator` objects - exposed as
a ``pytest.mark`` singleton instance.
Example::
import pytest
@pytest.mark.slowtest
def test_function():
pass
applies a 'slowtest' :class:`Mark` on ``test_function``.
"""
# See TYPE_CHECKING above.
if TYPE_CHECKING:
skip: _SkipMarkDecorator
skipif: _SkipifMarkDecorator
xfail: _XfailMarkDecorator
parametrize: _ParametrizeMarkDecorator
usefixtures: _UsefixturesMarkDecorator
filterwarnings: _FilterwarningsMarkDecorator
def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._config: Optional[Config] = None
self._markers: Set[str] = set()
def __getattr__(self, name: str) -> MarkDecorator:
"""Generate a new :class:`MarkDecorator` with the given name."""
if name[0] == "_":
raise AttributeError("Marker name must NOT start with underscore")
if self._config is not None:
# We store a set of markers as a performance optimisation - if a mark
# name is in the set we definitely know it, but a mark may be known and
# not in the set. We therefore start by updating the set!
if name not in self._markers:
for line in self._config.getini("markers"):
# example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
marker = line.split(":")[0].split("(")[0].strip()
self._markers.add(marker)
# If the name is not in the set of known marks after updating,
# then it really is time to issue a warning or an error.
if name not in self._markers:
if self._config.option.strict_markers or self._config.option.strict:
fail(
f"{name!r} not found in `markers` configuration option",
pytrace=False,
)
# Raise a specific error for common misspellings of "parametrize".
if name in ["parameterize", "parametrise", "parameterise"]:
__tracebackhide__ = True
fail(f"Unknown '{name}' mark, did you mean 'parametrize'?")
warnings.warn(
"Unknown pytest.mark.%s - is this a typo? You can register "
"custom marks to avoid this warning - for details, see "
"https://docs.pytest.org/en/stable/how-to/mark.html" % name,
PytestUnknownMarkWarning,
2,
)
return MarkDecorator(Mark(name, (), {}, _ispytest=True), _ispytest=True)
MARK_GEN = MarkGenerator(_ispytest=True)
@final
class NodeKeywords(MutableMapping[str, Any]):
__slots__ = ("node", "parent", "_markers")
def __init__(self, node: "Node") -> None:
self.node = node
self.parent = node.parent
self._markers = {node.name: True}
def __getitem__(self, key: str) -> Any:
try:
return self._markers[key]
except KeyError:
if self.parent is None:
raise
return self.parent.keywords[key]
def __setitem__(self, key: str, value: Any) -> None:
self._markers[key] = value
# Note: we could've avoided explicitly implementing some of the methods
# below and use the collections.abc fallback, but that would be slow.
def __contains__(self, key: object) -> bool:
return (
key in self._markers
or self.parent is not None
and key in self.parent.keywords
)
def update( # type: ignore[override]
self,
other: Union[Mapping[str, Any], Iterable[Tuple[str, Any]]] = (),
**kwds: Any,
) -> None:
self._markers.update(other)
self._markers.update(kwds)
def __delitem__(self, key: str) -> None:
raise ValueError("cannot delete key in keywords dict")
def __iter__(self) -> Iterator[str]:
# Doesn't need to be fast.
yield from self._markers
if self.parent is not None:
for keyword in self.parent.keywords:
# self._marks and self.parent.keywords can have duplicates.
if keyword not in self._markers:
yield keyword
def __len__(self) -> int:
# Doesn't need to be fast.
return sum(1 for keyword in self)
def __repr__(self) -> str:
return f"<NodeKeywords for node {self.node}>"

View File

@ -0,0 +1,416 @@
"""Monkeypatching and mocking functionality."""
import os
import re
import sys
import warnings
from contextlib import contextmanager
from typing import Any
from typing import Generator
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TypeVar
from typing import Union
from _pytest.compat import final
from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$")
K = TypeVar("K")
V = TypeVar("V")
@fixture
def monkeypatch() -> Generator["MonkeyPatch", None, None]:
"""A convenient fixture for monkey-patching.
The fixture provides these methods to modify objects, dictionaries, or
:data:`os.environ`:
* :meth:`monkeypatch.setattr(obj, name, value, raising=True) <pytest.MonkeyPatch.setattr>`
* :meth:`monkeypatch.delattr(obj, name, raising=True) <pytest.MonkeyPatch.delattr>`
* :meth:`monkeypatch.setitem(mapping, name, value) <pytest.MonkeyPatch.setitem>`
* :meth:`monkeypatch.delitem(obj, name, raising=True) <pytest.MonkeyPatch.delitem>`
* :meth:`monkeypatch.setenv(name, value, prepend=None) <pytest.MonkeyPatch.setenv>`
* :meth:`monkeypatch.delenv(name, raising=True) <pytest.MonkeyPatch.delenv>`
* :meth:`monkeypatch.syspath_prepend(path) <pytest.MonkeyPatch.syspath_prepend>`
* :meth:`monkeypatch.chdir(path) <pytest.MonkeyPatch.chdir>`
* :meth:`monkeypatch.context() <pytest.MonkeyPatch.context>`
All modifications will be undone after the requesting test function or
fixture has finished. The ``raising`` parameter determines if a :class:`KeyError`
or :class:`AttributeError` will be raised if the set/deletion operation does not have the
specified target.
To undo modifications done by the fixture in a contained scope,
use :meth:`context() <pytest.MonkeyPatch.context>`.
"""
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
def resolve(name: str) -> object:
# Simplified from zope.dottedname.
parts = name.split(".")
used = parts.pop(0)
found: object = __import__(used)
for part in parts:
used += "." + part
try:
found = getattr(found, part)
except AttributeError:
pass
else:
continue
# We use explicit un-nesting of the handling block in order
# to avoid nested exceptions.
try:
__import__(used)
except ImportError as ex:
expected = str(ex).split()[-1]
if expected == used:
raise
else:
raise ImportError(f"import error in {used}: {ex}") from ex
found = annotated_getattr(found, part, used)
return found
def annotated_getattr(obj: object, name: str, ann: str) -> object:
try:
obj = getattr(obj, name)
except AttributeError as e:
raise AttributeError(
"{!r} object at {} has no attribute {!r}".format(
type(obj).__name__, ann, name
)
) from e
return obj
def derive_importpath(import_path: str, raising: bool) -> Tuple[str, object]:
if not isinstance(import_path, str) or "." not in import_path:
raise TypeError(f"must be absolute import path string, not {import_path!r}")
module, attr = import_path.rsplit(".", 1)
target = resolve(module)
if raising:
annotated_getattr(target, attr, ann=module)
return attr, target
class Notset:
def __repr__(self) -> str:
return "<notset>"
notset = Notset()
@final
class MonkeyPatch:
"""Helper to conveniently monkeypatch attributes/items/environment
variables/syspath.
Returned by the :fixture:`monkeypatch` fixture.
.. versionchanged:: 6.2
Can now also be used directly as `pytest.MonkeyPatch()`, for when
the fixture is not available. In this case, use
:meth:`with MonkeyPatch.context() as mp: <context>` or remember to call
:meth:`undo` explicitly.
"""
def __init__(self) -> None:
self._setattr: List[Tuple[object, str, object]] = []
self._setitem: List[Tuple[MutableMapping[Any, Any], object, object]] = []
self._cwd: Optional[str] = None
self._savesyspath: Optional[List[str]] = None
@classmethod
@contextmanager
def context(cls) -> Generator["MonkeyPatch", None, None]:
"""Context manager that returns a new :class:`MonkeyPatch` object
which undoes any patching done inside the ``with`` block upon exit.
Example:
.. code-block:: python
import functools
def test_partial(monkeypatch):
with monkeypatch.context() as m:
m.setattr(functools, "partial", 3)
Useful in situations where it is desired to undo some patches before the test ends,
such as mocking ``stdlib`` functions that might break pytest itself if mocked (for examples
of this see :issue:`3290`).
"""
m = cls()
try:
yield m
finally:
m.undo()
@overload
def setattr(
self,
target: str,
name: object,
value: Notset = ...,
raising: bool = ...,
) -> None:
...
@overload
def setattr(
self,
target: object,
name: str,
value: object,
raising: bool = ...,
) -> None:
...
def setattr(
self,
target: Union[str, object],
name: Union[object, str],
value: object = notset,
raising: bool = True,
) -> None:
"""
Set attribute value on target, memorizing the old value.
For example:
.. code-block:: python
import os
monkeypatch.setattr(os, "getcwd", lambda: "/")
The code above replaces the :func:`os.getcwd` function by a ``lambda`` which
always returns ``"/"``.
For convenience, you can specify a string as ``target`` which
will be interpreted as a dotted import path, with the last part
being the attribute name:
.. code-block:: python
monkeypatch.setattr("os.getcwd", lambda: "/")
Raises :class:`AttributeError` if the attribute does not exist, unless
``raising`` is set to False.
**Where to patch**
``monkeypatch.setattr`` works by (temporarily) changing the object that a name points to with another one.
There can be many names pointing to any individual object, so for patching to work you must ensure
that you patch the name used by the system under test.
See the section :ref:`Where to patch <python:where-to-patch>` in the :mod:`unittest.mock`
docs for a complete explanation, which is meant for :func:`unittest.mock.patch` but
applies to ``monkeypatch.setattr`` as well.
"""
__tracebackhide__ = True
import inspect
if isinstance(value, Notset):
if not isinstance(target, str):
raise TypeError(
"use setattr(target, name, value) or "
"setattr(target, value) with target being a dotted "
"import string"
)
value = name
name, target = derive_importpath(target, raising)
else:
if not isinstance(name, str):
raise TypeError(
"use setattr(target, name, value) with name being a string or "
"setattr(target, value) with target being a dotted "
"import string"
)
oldval = getattr(target, name, notset)
if raising and oldval is notset:
raise AttributeError(f"{target!r} has no attribute {name!r}")
# avoid class descriptors like staticmethod/classmethod
if inspect.isclass(target):
oldval = target.__dict__.get(name, notset)
self._setattr.append((target, name, oldval))
setattr(target, name, value)
def delattr(
self,
target: Union[object, str],
name: Union[str, Notset] = notset,
raising: bool = True,
) -> None:
"""Delete attribute ``name`` from ``target``.
If no ``name`` is specified and ``target`` is a string
it will be interpreted as a dotted import path with the
last part being the attribute name.
Raises AttributeError it the attribute does not exist, unless
``raising`` is set to False.
"""
__tracebackhide__ = True
import inspect
if isinstance(name, Notset):
if not isinstance(target, str):
raise TypeError(
"use delattr(target, name) or "
"delattr(target) with target being a dotted "
"import string"
)
name, target = derive_importpath(target, raising)
if not hasattr(target, name):
if raising:
raise AttributeError(name)
else:
oldval = getattr(target, name, notset)
# Avoid class descriptors like staticmethod/classmethod.
if inspect.isclass(target):
oldval = target.__dict__.get(name, notset)
self._setattr.append((target, name, oldval))
delattr(target, name)
def setitem(self, dic: MutableMapping[K, V], name: K, value: V) -> None:
"""Set dictionary entry ``name`` to value."""
self._setitem.append((dic, name, dic.get(name, notset)))
dic[name] = value
def delitem(self, dic: MutableMapping[K, V], name: K, raising: bool = True) -> None:
"""Delete ``name`` from dict.
Raises ``KeyError`` if it doesn't exist, unless ``raising`` is set to
False.
"""
if name not in dic:
if raising:
raise KeyError(name)
else:
self._setitem.append((dic, name, dic.get(name, notset)))
del dic[name]
def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None:
"""Set environment variable ``name`` to ``value``.
If ``prepend`` is a character, read the current environment variable
value and prepend the ``value`` adjoined with the ``prepend``
character.
"""
if not isinstance(value, str):
warnings.warn( # type: ignore[unreachable]
PytestWarning(
"Value of environment variable {name} type should be str, but got "
"{value!r} (type: {type}); converted to str implicitly".format(
name=name, value=value, type=type(value).__name__
)
),
stacklevel=2,
)
value = str(value)
if prepend and name in os.environ:
value = value + prepend + os.environ[name]
self.setitem(os.environ, name, value)
def delenv(self, name: str, raising: bool = True) -> None:
"""Delete ``name`` from the environment.
Raises ``KeyError`` if it does not exist, unless ``raising`` is set to
False.
"""
environ: MutableMapping[str, str] = os.environ
self.delitem(environ, name, raising=raising)
def syspath_prepend(self, path) -> None:
"""Prepend ``path`` to ``sys.path`` list of import locations."""
if self._savesyspath is None:
self._savesyspath = sys.path[:]
sys.path.insert(0, str(path))
# https://github.com/pypa/setuptools/blob/d8b901bc/docs/pkg_resources.txt#L162-L171
# this is only needed when pkg_resources was already loaded by the namespace package
if "pkg_resources" in sys.modules:
from pkg_resources import fixup_namespace_packages
fixup_namespace_packages(str(path))
# A call to syspathinsert() usually means that the caller wants to
# import some dynamically created files, thus with python3 we
# invalidate its import caches.
# This is especially important when any namespace package is in use,
# since then the mtime based FileFinder cache (that gets created in
# this case already) gets not invalidated when writing the new files
# quickly afterwards.
from importlib import invalidate_caches
invalidate_caches()
def chdir(self, path: Union[str, "os.PathLike[str]"]) -> None:
"""Change the current working directory to the specified path.
:param path:
The path to change into.
"""
if self._cwd is None:
self._cwd = os.getcwd()
os.chdir(path)
def undo(self) -> None:
"""Undo previous changes.
This call consumes the undo stack. Calling it a second time has no
effect unless you do more monkeypatching after the undo call.
There is generally no need to call `undo()`, since it is
called automatically during tear-down.
.. note::
The same `monkeypatch` fixture is used across a
single test function invocation. If `monkeypatch` is used both by
the test function itself and one of the test fixtures,
calling `undo()` will undo all of the changes made in
both functions.
Prefer to use :meth:`context() <pytest.MonkeyPatch.context>` instead.
"""
for obj, name, value in reversed(self._setattr):
if value is not notset:
setattr(obj, name, value)
else:
delattr(obj, name)
self._setattr[:] = []
for dictionary, key, value in reversed(self._setitem):
if value is notset:
try:
del dictionary[key]
except KeyError:
pass # Was already deleted, so we have the desired state.
else:
dictionary[key] = value
self._setitem[:] = []
if self._savesyspath is not None:
sys.path[:] = self._savesyspath
self._savesyspath = None
if self._cwd is not None:
os.chdir(self._cwd)
self._cwd = None

View File

@ -0,0 +1,771 @@
import os
import warnings
from inspect import signature
from pathlib import Path
from typing import Any
from typing import Callable
from typing import cast
from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import _pytest._code
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest.compat import cached_property
from _pytest.compat import LEGACY_PATH
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.deprecated import FSCOLLECTOR_GETHOOKPROXY_ISINITPATH
from _pytest.deprecated import NODE_CTOR_FSPATH_ARG
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath
from _pytest.pathlib import commonpath
from _pytest.stash import Stash
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
# Imported here due to circular import.
from _pytest.main import Session
from _pytest._code.code import _TracebackStyle
SEP = "/"
tracebackcutdir = Path(_pytest.__file__).parent
def iterparentnodeids(nodeid: str) -> Iterator[str]:
"""Return the parent node IDs of a given node ID, inclusive.
For the node ID
"testing/code/test_excinfo.py::TestFormattedExcinfo::test_repr_source"
the result would be
""
"testing"
"testing/code"
"testing/code/test_excinfo.py"
"testing/code/test_excinfo.py::TestFormattedExcinfo"
"testing/code/test_excinfo.py::TestFormattedExcinfo::test_repr_source"
Note that / components are only considered until the first ::.
"""
pos = 0
first_colons: Optional[int] = nodeid.find("::")
if first_colons == -1:
first_colons = None
# The root Session node - always present.
yield ""
# Eagerly consume SEP parts until first colons.
while True:
at = nodeid.find(SEP, pos, first_colons)
if at == -1:
break
if at > 0:
yield nodeid[:at]
pos = at + len(SEP)
# Eagerly consume :: parts.
while True:
at = nodeid.find("::", pos)
if at == -1:
break
if at > 0:
yield nodeid[:at]
pos = at + len("::")
# The node ID itself.
if nodeid:
yield nodeid
def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
if Path(fspath) != path:
raise ValueError(
f"Path({fspath!r}) != {path!r}\n"
"if both path and fspath are given they need to be equal"
)
def _imply_path(
node_type: Type["Node"],
path: Optional[Path],
fspath: Optional[LEGACY_PATH],
) -> Path:
if fspath is not None:
warnings.warn(
NODE_CTOR_FSPATH_ARG.format(
node_type_name=node_type.__name__,
),
stacklevel=6,
)
if path is not None:
if fspath is not None:
_check_path(path, fspath)
return path
else:
assert fspath is not None
return Path(fspath)
_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(type):
def __call__(self, *k, **kw):
msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
"See "
"https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent"
" for more details."
).format(name=f"{self.__module__}.{self.__name__}")
fail(msg, pytrace=False)
def _create(self, *k, **kw):
try:
return super().__call__(*k, **kw)
except TypeError:
sig = signature(getattr(self, "__init__"))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning
warnings.warn(
PytestDeprecationWarning(
f"{self} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
"See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details."
)
)
return super().__call__(*k, **known_kw)
class Node(metaclass=NodeMeta):
"""Base class for Collector and Item, the components of the test
collection tree.
Collector subclasses have children; Items are leaf nodes.
"""
# Implemented in the legacypath plugin.
#: A ``LEGACY_PATH`` copy of the :attr:`path` attribute. Intended for usage
#: for methods not migrated to ``pathlib.Path`` yet, such as
#: :meth:`Item.reportinfo`. Will be deprecated in a future release, prefer
#: using :attr:`path` instead.
fspath: LEGACY_PATH
# Use __slots__ to make attribute access faster.
# Note that __dict__ is still available.
__slots__ = (
"name",
"parent",
"config",
"session",
"path",
"_nodeid",
"_store",
"__dict__",
)
def __init__(
self,
name: str,
parent: "Optional[Node]" = None,
config: Optional[Config] = None,
session: "Optional[Session]" = None,
fspath: Optional[LEGACY_PATH] = None,
path: Optional[Path] = None,
nodeid: Optional[str] = None,
) -> None:
#: A unique name within the scope of the parent node.
self.name: str = name
#: The parent collector node.
self.parent = parent
if config:
#: The pytest config object.
self.config: Config = config
else:
if not parent:
raise TypeError("config or parent must be provided")
self.config = parent.config
if session:
#: The pytest session this node is part of.
self.session: Session = session
else:
if not parent:
raise TypeError("session or parent must be provided")
self.session = parent.session
if path is None and fspath is None:
path = getattr(parent, "path", None)
#: Filesystem path where this node was collected from (can be None).
self.path: Path = _imply_path(type(self), path, fspath=fspath)
# The explicit annotation is to avoid publicly exposing NodeKeywords.
#: Keywords/markers collected from all scopes.
self.keywords: MutableMapping[str, Any] = NodeKeywords(self)
#: The marker objects belonging to this node.
self.own_markers: List[Mark] = []
#: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: Set[str] = set()
if nodeid is not None:
assert "::()" not in nodeid
self._nodeid = nodeid
else:
if not self.parent:
raise TypeError("nodeid or parent must be provided")
self._nodeid = self.parent.nodeid + "::" + self.name
#: A place where plugins can store information on the node for their
#: own use.
self.stash: Stash = Stash()
# Deprecated alias. Was never public. Can be removed in a few releases.
self._store = self.stash
@classmethod
def from_parent(cls, parent: "Node", **kw):
"""Public constructor for Nodes.
This indirection got introduced in order to enable removing
the fragile logic from the node constructors.
Subclasses can use ``super().from_parent(...)`` when overriding the
construction.
:param parent: The parent node of this Node.
"""
if "config" in kw:
raise TypeError("config is not a valid argument for from_parent")
if "session" in kw:
raise TypeError("session is not a valid argument for from_parent")
return cls._create(parent=parent, **kw)
@property
def ihook(self):
"""fspath-sensitive hook proxy used to call pytest hooks."""
return self.session.gethookproxy(self.path)
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
def warn(self, warning: Warning) -> None:
"""Issue a warning for this Node.
Warnings will be displayed after the test session, unless explicitly suppressed.
:param Warning warning:
The warning instance to issue.
:raises ValueError: If ``warning`` instance is not a subclass of Warning.
Example usage:
.. code-block:: python
node.warn(PytestWarning("some message"))
node.warn(UserWarning("some message"))
.. versionchanged:: 6.2
Any subclass of :class:`Warning` is now accepted, rather than only
:class:`PytestWarning <pytest.PytestWarning>` subclasses.
"""
# enforce type checks here to avoid getting a generic type error later otherwise.
if not isinstance(warning, Warning):
raise ValueError(
"warning must be an instance of Warning or subclass, got {!r}".format(
warning
)
)
path, lineno = get_fslocation_from_item(self)
assert lineno is not None
warnings.warn_explicit(
warning,
category=None,
filename=str(path),
lineno=lineno + 1,
)
# Methods for ordering nodes.
@property
def nodeid(self) -> str:
"""A ::-separated string denoting its collection tree address."""
return self._nodeid
def __hash__(self) -> int:
return hash(self._nodeid)
def setup(self) -> None:
pass
def teardown(self) -> None:
pass
def listchain(self) -> List["Node"]:
"""Return list of all parent collectors up to self, starting from
the root of collection tree.
:returns: The nodes.
"""
chain = []
item: Optional[Node] = self
while item is not None:
chain.append(item)
item = item.parent
chain.reverse()
return chain
def add_marker(
self, marker: Union[str, MarkDecorator], append: bool = True
) -> None:
"""Dynamically add a marker object to the node.
:param marker:
The marker.
:param append:
Whether to append the marker, or prepend it.
"""
from _pytest.mark import MARK_GEN
if isinstance(marker, MarkDecorator):
marker_ = marker
elif isinstance(marker, str):
marker_ = getattr(MARK_GEN, marker)
else:
raise ValueError("is not a string or pytest.mark.* Marker")
self.keywords[marker_.name] = marker_
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
:returns: An iterator of the markers of the node.
"""
return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(
self, name: Optional[str] = None
) -> Iterator[Tuple["Node", Mark]]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
:returns: An iterator of (node, mark) tuples.
"""
for node in reversed(self.listchain()):
for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name:
yield node, mark
@overload
def get_closest_marker(self, name: str) -> Optional[Mark]:
...
@overload
def get_closest_marker(self, name: str, default: Mark) -> Mark:
...
def get_closest_marker(
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
"""Return the first marker matching the name, from closest (for
example function) to farther level (for example module level).
:param default: Fallback return value if no marker was found.
:param name: Name to filter by.
"""
return next(self.iter_markers(name=name), default)
def listextrakeywords(self) -> Set[str]:
"""Return a set of all extra keywords in self and any parents."""
extra_keywords: Set[str] = set()
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
def listnames(self) -> List[str]:
return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None:
"""Register a function to be called without arguments when this node is
finalized.
This method can only be called when this node is active
in a setup chain, for example during self.setup().
"""
self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]:
"""Get the next parent node (including self) which is an instance of
the given class.
:param cls: The node class to search for.
:returns: The node, if found.
"""
current: Optional[Node] = self
while current and not isinstance(current, cls):
current = current.parent
assert current is None or isinstance(current, cls)
return current
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
pass
def _repr_failure_py(
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
) -> TerminalRepr:
from _pytest.fixtures import FixtureLookupError
if isinstance(excinfo.value, ConftestImportFailure):
excinfo = ExceptionInfo.from_exc_info(excinfo.value.excinfo)
if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace:
style = "value"
if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr()
if self.config.getoption("fulltrace", False):
style = "long"
else:
tb = _pytest._code.Traceback([excinfo.traceback[-1]])
self._prunetraceback(excinfo)
if len(excinfo.traceback) == 0:
excinfo.traceback = tb
if style == "auto":
style = "long"
# XXX should excinfo.getrepr record all data and toterminal() process it?
if style is None:
if self.config.getoption("tbstyle", "auto") == "short":
style = "short"
else:
style = "long"
if self.config.getoption("verbose", 0) > 1:
truncate_locals = False
else:
truncate_locals = True
# excinfo.getrepr() formats paths relative to the CWD if `abspath` is False.
# It is possible for a fixture/test to change the CWD while this code runs, which
# would then result in the user seeing confusing paths in the failure message.
# To fix this, if the CWD changed, always display the full absolute path.
# It will be better to just always display paths relative to invocation_dir, but
# this requires a lot of plumbing (#6428).
try:
abspath = Path(os.getcwd()) != self.config.invocation_params.dir
except OSError:
abspath = True
return excinfo.getrepr(
funcargs=True,
abspath=abspath,
showlocals=self.config.getoption("showlocals", False),
style=style,
tbfilter=False, # pruned already, or in --fulltrace mode.
truncate_locals=truncate_locals,
)
def repr_failure(
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
) -> Union[str, TerminalRepr]:
"""Return a representation of a collection or test failure.
.. seealso:: :ref:`non-python tests`
:param excinfo: Exception information for the failure.
"""
return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[int]]:
"""Try to extract the actual location from a node, depending on available attributes:
* "location": a pair (path, lineno)
* "obj": a Python object that the node wraps.
* "fspath": just a path
:rtype: A tuple of (str|Path, int) with filename and line number.
"""
# See Item.location.
location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None)
if location is not None:
return location[:2]
obj = getattr(node, "obj", None)
if obj is not None:
return getfslineno(obj)
return getattr(node, "fspath", "unknown location"), -1
class Collector(Node):
"""Collector instances create children through collect() and thus
iteratively build a tree."""
class CollectError(Exception):
"""An error during collection, contains a custom message."""
def collect(self) -> Iterable[Union["Item", "Collector"]]:
"""Return a list of children (items and collectors) for this
collection node."""
raise NotImplementedError("abstract")
# TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override]
self, excinfo: ExceptionInfo[BaseException]
) -> Union[str, TerminalRepr]:
"""Return a representation of a collection failure.
:param excinfo: Exception information for the failure.
"""
if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False
):
exc = excinfo.value
return str(exc.args[0])
# Respect explicit tbstyle option, but default to "short"
# (_repr_failure_py uses "long" with "fulltrace" option always).
tbstyle = self.config.getoption("tbstyle", "auto")
if tbstyle == "auto":
tbstyle = "short"
return self._repr_failure_py(excinfo, style=tbstyle)
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
if hasattr(self, "path"):
traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback:
ntraceback = ntraceback.cut(excludepath=tracebackcutdir)
excinfo.traceback = ntraceback.filter()
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]:
for initial_path in session._initialpaths:
if commonpath(path, initial_path) == initial_path:
rel = str(path.relative_to(initial_path))
return "" if rel == "." else rel
return None
class FSCollector(Collector):
def __init__(
self,
fspath: Optional[LEGACY_PATH] = None,
path_or_parent: Optional[Union[Path, Node]] = None,
path: Optional[Path] = None,
name: Optional[str] = None,
parent: Optional[Node] = None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
if path_or_parent:
if isinstance(path_or_parent, Node):
assert parent is None
parent = cast(FSCollector, path_or_parent)
elif isinstance(path_or_parent, Path):
assert path is None
path = path_or_parent
path = _imply_path(type(self), path, fspath=fspath)
if name is None:
name = path.name
if parent is not None and parent.path != path:
try:
rel = path.relative_to(parent.path)
except ValueError:
pass
else:
name = str(rel)
name = name.replace(os.sep, SEP)
self.path = path
if session is None:
assert parent is not None
session = parent.session
if nodeid is None:
try:
nodeid = str(self.path.relative_to(session.config.rootpath))
except ValueError:
nodeid = _check_initialpaths_for_relpath(session, path)
if nodeid and os.sep != SEP:
nodeid = nodeid.replace(os.sep, SEP)
super().__init__(
name=name,
parent=parent,
config=config,
session=session,
nodeid=nodeid,
path=path,
)
@classmethod
def from_parent(
cls,
parent,
*,
fspath: Optional[LEGACY_PATH] = None,
path: Optional[Path] = None,
**kw,
):
"""The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
def gethookproxy(self, fspath: "os.PathLike[str]"):
warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
return self.session.gethookproxy(fspath)
def isinitpath(self, path: Union[str, "os.PathLike[str]"]) -> bool:
warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
return self.session.isinitpath(path)
class File(FSCollector):
"""Base class for collecting tests from a file.
:ref:`non-python tests`.
"""
class Item(Node):
"""A basic test invocation item.
Note that for a single function there might be multiple test invocation items.
"""
nextitem = None
def __init__(
self,
name,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
**kw,
) -> None:
# The first two arguments are intentionally passed positionally,
# to keep plugins who define a node type which inherits from
# (pytest.Item, pytest.File) working (see issue #8435).
# They can be made kwargs when the deprecation above is done.
super().__init__(
name,
parent,
config=config,
session=session,
nodeid=nodeid,
**kw,
)
self._report_sections: List[Tuple[str, str, str]] = []
#: A list of tuples (name, value) that holds user defined properties
#: for this test.
self.user_properties: List[Tuple[str, object]] = []
self._check_item_and_collector_diamond_inheritance()
def _check_item_and_collector_diamond_inheritance(self) -> None:
"""
Check if the current type inherits from both File and Collector
at the same time, emitting a warning accordingly (#8447).
"""
cls = type(self)
# We inject an attribute in the type to avoid issuing this warning
# for the same class more than once, which is not helpful.
# It is a hack, but was deemed acceptable in order to avoid
# flooding the user in the common case.
attr_name = "_pytest_diamond_inheritance_warning_shown"
if getattr(cls, attr_name, False):
return
setattr(cls, attr_name, True)
problems = ", ".join(
base.__name__ for base in cls.__bases__ if issubclass(base, Collector)
)
if problems:
warnings.warn(
f"{cls.__name__} is an Item subclass and should not be a collector, "
f"however its bases {problems} are collectors.\n"
"Please split the Collectors and the Item into separate node types.\n"
"Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n"
"example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/",
PytestWarning,
)
def runtest(self) -> None:
"""Run the test case for this item.
Must be implemented by subclasses.
.. seealso:: :ref:`non-python tests`
"""
raise NotImplementedError("runtest must be implemented by Item subclass")
def add_report_section(self, when: str, key: str, content: str) -> None:
"""Add a new report section, similar to what's done internally to add
stdout and stderr captured output::
item.add_report_section("call", "stdout", "report section contents")
:param str when:
One of the possible capture states, ``"setup"``, ``"call"``, ``"teardown"``.
:param str key:
Name of the section, can be customized at will. Pytest uses ``"stdout"`` and
``"stderr"`` internally.
:param str content:
The full contents as a string.
"""
if content:
self._report_sections.append((when, key, content))
def reportinfo(self) -> Tuple[Union["os.PathLike[str]", str], Optional[int], str]:
"""Get location information for this item for test reports.
Returns a tuple with three elements:
- The path of the test (default ``self.path``)
- The line number of the test (default ``None``)
- A name of the test to be shown (default ``""``)
.. seealso:: :ref:`non-python tests`
"""
return self.path, None, ""
@cached_property
def location(self) -> Tuple[str, Optional[int], str]:
location = self.reportinfo()
path = absolutepath(os.fspath(location[0]))
relfspath = self.session._node_location_to_relpath(path)
assert type(location[2]) is str
return (relfspath, location[1], location[2])

View File

@ -0,0 +1,50 @@
"""Run testsuites written for nose."""
import warnings
from _pytest.config import hookimpl
from _pytest.deprecated import NOSE_SUPPORT
from _pytest.fixtures import getfixturemarker
from _pytest.nodes import Item
from _pytest.python import Function
from _pytest.unittest import TestCaseFunction
@hookimpl(trylast=True)
def pytest_runtest_setup(item: Item) -> None:
if not isinstance(item, Function):
return
# Don't do nose style setup/teardown on direct unittest style classes.
if isinstance(item, TestCaseFunction):
return
# Capture the narrowed type of item for the teardown closure,
# see https://github.com/python/mypy/issues/2608
func = item
call_optional(func.obj, "setup", func.nodeid)
func.addfinalizer(lambda: call_optional(func.obj, "teardown", func.nodeid))
# NOTE: Module- and class-level fixtures are handled in python.py
# with `pluginmanager.has_plugin("nose")` checks.
# It would have been nicer to implement them outside of core, but
# it's not straightforward.
def call_optional(obj: object, name: str, nodeid: str) -> bool:
method = getattr(obj, name, None)
if method is None:
return False
is_fixture = getfixturemarker(method) is not None
if is_fixture:
return False
if not callable(method):
return False
# Warn about deprecation of this plugin.
method_name = getattr(method, "__name__", str(method))
warnings.warn(
NOSE_SUPPORT.format(nodeid=nodeid, method=method_name, stage=name), stacklevel=2
)
# If there are any problems allow the exception to raise rather than
# silently ignoring it.
method()
return True

View File

@ -0,0 +1,308 @@
"""Exception classes and constants handling test outcomes as well as
functions creating them."""
import sys
import warnings
from typing import Any
from typing import Callable
from typing import cast
from typing import NoReturn
from typing import Optional
from typing import Type
from typing import TypeVar
from _pytest.deprecated import KEYWORD_MSG_ARG
TYPE_CHECKING = False # Avoid circular import through compat.
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
# typing.Protocol is only available starting from Python 3.8. It is also
# available from typing_extensions, but we don't want a runtime dependency
# on that. So use a dummy runtime implementation.
from typing import Generic
Protocol = Generic
class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info
about test and collection outcomes."""
def __init__(self, msg: Optional[str] = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str):
error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n"
"Perhaps you meant to use a mark?"
)
raise TypeError(error_msg.format(type(self).__name__, type(msg).__name__))
super().__init__(msg)
self.msg = msg
self.pytrace = pytrace
def __repr__(self) -> str:
if self.msg is not None:
return self.msg
return f"<{self.__class__.__name__} instance>"
__str__ = __repr__
TEST_OUTCOME = (OutcomeException, Exception)
class Skipped(OutcomeException):
# XXX hackish: on 3k we fake to live in the builtins
# in order to have Skipped exception printing shorter/nicer
__module__ = "builtins"
def __init__(
self,
msg: Optional[str] = None,
pytrace: bool = True,
allow_module_level: bool = False,
*,
_use_item_location: bool = False,
) -> None:
super().__init__(msg=msg, pytrace=pytrace)
self.allow_module_level = allow_module_level
# If true, the skip location is reported as the item's location,
# instead of the place that raises the exception/calls skip().
self._use_item_location = _use_item_location
class Failed(OutcomeException):
"""Raised from an explicit call to pytest.fail()."""
__module__ = "builtins"
class Exit(Exception):
"""Raised for immediate program exits (no tracebacks/summaries)."""
def __init__(
self, msg: str = "unknown reason", returncode: Optional[int] = None
) -> None:
self.msg = msg
self.returncode = returncode
super().__init__(msg)
# Elaborate hack to work around https://github.com/python/mypy/issues/2087.
# Ideally would just be `exit.Exception = Exit` etc.
_F = TypeVar("_F", bound=Callable[..., object])
_ET = TypeVar("_ET", bound=Type[BaseException])
class _WithException(Protocol[_F, _ET]):
Exception: _ET
__call__: _F
def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _ET]]:
def decorate(func: _F) -> _WithException[_F, _ET]:
func_with_exception = cast(_WithException[_F, _ET], func)
func_with_exception.Exception = exception_type
return func_with_exception
return decorate
# Exposed helper methods.
@_with_exception(Exit)
def exit(
reason: str = "", returncode: Optional[int] = None, *, msg: Optional[str] = None
) -> NoReturn:
"""Exit testing process.
:param reason:
The message to show as the reason for exiting pytest. reason has a default value
only because `msg` is deprecated.
:param returncode:
Return code to be used when exiting pytest.
:param msg:
Same as ``reason``, but deprecated. Will be removed in a future version, use ``reason`` instead.
"""
__tracebackhide__ = True
from _pytest.config import UsageError
if reason and msg:
raise UsageError(
"cannot pass reason and msg to exit(), `msg` is deprecated, use `reason`."
)
if not reason:
if msg is None:
raise UsageError("exit() requires a reason argument")
warnings.warn(KEYWORD_MSG_ARG.format(func="exit"), stacklevel=2)
reason = msg
raise Exit(reason, returncode)
@_with_exception(Skipped)
def skip(
reason: str = "", *, allow_module_level: bool = False, msg: Optional[str] = None
) -> NoReturn:
"""Skip an executing test with the given message.
This function should be called only during testing (setup, call or teardown) or
during collection by using the ``allow_module_level`` flag. This function can
be called in doctests as well.
:param reason:
The message to show the user as reason for the skip.
:param allow_module_level:
Allows this function to be called at module level, skipping the rest
of the module. Defaults to False.
:param msg:
Same as ``reason``, but deprecated. Will be removed in a future version, use ``reason`` instead.
.. note::
It is better to use the :ref:`pytest.mark.skipif ref` marker when
possible to declare a test to be skipped under certain conditions
like mismatching platforms or dependencies.
Similarly, use the ``# doctest: +SKIP`` directive (see :py:data:`doctest.SKIP`)
to skip a doctest statically.
"""
__tracebackhide__ = True
reason = _resolve_msg_to_reason("skip", reason, msg)
raise Skipped(msg=reason, allow_module_level=allow_module_level)
@_with_exception(Failed)
def fail(reason: str = "", pytrace: bool = True, msg: Optional[str] = None) -> NoReturn:
"""Explicitly fail an executing test with the given message.
:param reason:
The message to show the user as reason for the failure.
:param pytrace:
If False, msg represents the full failure information and no
python traceback will be reported.
:param msg:
Same as ``reason``, but deprecated. Will be removed in a future version, use ``reason`` instead.
"""
__tracebackhide__ = True
reason = _resolve_msg_to_reason("fail", reason, msg)
raise Failed(msg=reason, pytrace=pytrace)
def _resolve_msg_to_reason(
func_name: str, reason: str, msg: Optional[str] = None
) -> str:
"""
Handles converting the deprecated msg parameter if provided into
reason, raising a deprecation warning. This function will be removed
when the optional msg argument is removed from here in future.
:param str func_name:
The name of the offending function, this is formatted into the deprecation message.
:param str reason:
The reason= passed into either pytest.fail() or pytest.skip()
:param str msg:
The msg= passed into either pytest.fail() or pytest.skip(). This will
be converted into reason if it is provided to allow pytest.skip(msg=) or
pytest.fail(msg=) to continue working in the interim period.
:returns:
The value to use as reason.
"""
__tracebackhide__ = True
if msg is not None:
if reason:
from pytest import UsageError
raise UsageError(
f"Passing both ``reason`` and ``msg`` to pytest.{func_name}(...) is not permitted."
)
warnings.warn(KEYWORD_MSG_ARG.format(func=func_name), stacklevel=3)
reason = msg
return reason
class XFailed(Failed):
"""Raised from an explicit call to pytest.xfail()."""
@_with_exception(XFailed)
def xfail(reason: str = "") -> NoReturn:
"""Imperatively xfail an executing test or setup function with the given reason.
This function should be called only during testing (setup, call or teardown).
:param reason:
The message to show the user as reason for the xfail.
.. note::
It is better to use the :ref:`pytest.mark.xfail ref` marker when
possible to declare a test to be xfailed under certain conditions
like known bugs or missing features.
"""
__tracebackhide__ = True
raise XFailed(reason)
def importorskip(
modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
) -> Any:
"""Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported.
:param modname:
The name of the module to import.
:param minversion:
If given, the imported module's ``__version__`` attribute must be at
least this minimal version, otherwise the test is still skipped.
:param reason:
If given, this reason is shown as the message when the module cannot
be imported.
:returns:
The imported module. This should be assigned to its canonical name.
Example::
docutils = pytest.importorskip("docutils")
"""
import warnings
__tracebackhide__ = True
compile(modname, "", "eval") # to catch syntaxerrors
with warnings.catch_warnings():
# Make sure to ignore ImportWarnings that might happen because
# of existing directories with the same name we're trying to
# import but without a __init__.py file.
warnings.simplefilter("ignore")
try:
__import__(modname)
except ImportError as exc:
if reason is None:
reason = f"could not import {modname!r}: {exc}"
raise Skipped(reason, allow_module_level=True) from None
mod = sys.modules[modname]
if minversion is None:
return mod
verattr = getattr(mod, "__version__", None)
if minversion is not None:
# Imported lazily to improve start-up time.
from packaging.version import Version
if verattr is None or Version(verattr) < Version(minversion):
raise Skipped(
"module %r has __version__ %r, required is: %r"
% (modname, verattr, minversion),
allow_module_level=True,
)
return mod

View File

@ -0,0 +1,110 @@
"""Submit failure or test session information to a pastebin service."""
import tempfile
from io import StringIO
from typing import IO
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
pastebinfile_key = StashKey[IO[bytes]]()
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group._addoption(
"--pastebin",
metavar="mode",
action="store",
dest="pastebin",
default=None,
choices=["failed", "all"],
help="Send failed|all info to bpaste.net pastebin service",
)
@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all":
tr = config.pluginmanager.getplugin("terminalreporter")
# If no terminal reporter plugin is present, nothing we can do here;
# this can happen when this function executes in a worker node
# when using pytest-xdist, for example.
if tr is not None:
# pastebin file will be UTF-8 encoded binary file.
config.stash[pastebinfile_key] = tempfile.TemporaryFile("w+b")
oldwrite = tr._tw.write
def tee_write(s, **kwargs):
oldwrite(s, **kwargs)
if isinstance(s, str):
s = s.encode("utf-8")
config.stash[pastebinfile_key].write(s)
tr._tw.write = tee_write
def pytest_unconfigure(config: Config) -> None:
if pastebinfile_key in config.stash:
pastebinfile = config.stash[pastebinfile_key]
# Get terminal contents and delete file.
pastebinfile.seek(0)
sessionlog = pastebinfile.read()
pastebinfile.close()
del config.stash[pastebinfile_key]
# Undo our patching in the terminal reporter.
tr = config.pluginmanager.getplugin("terminalreporter")
del tr._tw.__dict__["write"]
# Write summary.
tr.write_sep("=", "Sending information to Paste Service")
pastebinurl = create_new_paste(sessionlog)
tr.write_line("pastebin session-log: %s\n" % pastebinurl)
def create_new_paste(contents: Union[str, bytes]) -> str:
"""Create a new paste using the bpaste.net service.
:contents: Paste contents string.
:returns: URL to the pasted contents, or an error message.
"""
import re
from urllib.request import urlopen
from urllib.parse import urlencode
params = {"code": contents, "lexer": "text", "expiry": "1week"}
url = "https://bpa.st"
try:
response: str = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8")
)
except OSError as exc_info: # urllib errors
return "bad response: %s" % exc_info
m = re.search(r'href="/raw/(\w+)"', response)
if m:
return f"{url}/show/{m.group(1)}"
else:
return "bad response: invalid format ('" + response + "')"
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed":
return
if "failed" in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service")
for rep in terminalreporter.stats["failed"]:
try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError:
msg = terminalreporter._getfailureheadline(rep)
file = StringIO()
tw = create_terminal_writer(terminalreporter.config, file)
rep.toterminal(tw)
s = file.getvalue()
assert len(s)
pastebinurl = create_new_paste(s)
terminalreporter.write_line(f"{msg} --> {pastebinurl}")

View File

@ -0,0 +1,735 @@
import atexit
import contextlib
import fnmatch
import importlib.util
import itertools
import os
import shutil
import sys
import uuid
import warnings
from enum import Enum
from errno import EBADF
from errno import ELOOP
from errno import ENOENT
from errno import ENOTDIR
from functools import partial
from os.path import expanduser
from os.path import expandvars
from os.path import isabs
from os.path import sep
from pathlib import Path
from pathlib import PurePath
from posixpath import sep as posix_sep
from types import ModuleType
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import Optional
from typing import Set
from typing import TypeVar
from typing import Union
from _pytest.compat import assert_never
from _pytest.outcomes import skip
from _pytest.warning_types import PytestWarning
LOCK_TIMEOUT = 60 * 60 * 24 * 3
_AnyPurePath = TypeVar("_AnyPurePath", bound=PurePath)
# The following function, variables and comments were
# copied from cpython 3.9 Lib/pathlib.py file.
# EBADF - guard against macOS `stat` throwing EBADF
_IGNORED_ERRORS = (ENOENT, ENOTDIR, EBADF, ELOOP)
_IGNORED_WINERRORS = (
21, # ERROR_NOT_READY - drive exists but is not accessible
1921, # ERROR_CANT_RESOLVE_FILENAME - fix for broken symlink pointing to itself
)
def _ignore_error(exception):
return (
getattr(exception, "errno", None) in _IGNORED_ERRORS
or getattr(exception, "winerror", None) in _IGNORED_WINERRORS
)
def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
return path.joinpath(".lock")
def on_rm_rf_error(func, path: str, exc, *, start_path: Path) -> bool:
"""Handle known read-only errors during rmtree.
The returned value is used only by our own tests.
"""
exctype, excvalue = exc[:2]
# Another process removed the file in the middle of the "rm_rf" (xdist for example).
# More context: https://github.com/pytest-dev/pytest/issues/5974#issuecomment-543799018
if isinstance(excvalue, FileNotFoundError):
return False
if not isinstance(excvalue, PermissionError):
warnings.warn(
PytestWarning(f"(rm_rf) error removing {path}\n{exctype}: {excvalue}")
)
return False
if func not in (os.rmdir, os.remove, os.unlink):
if func not in (os.open,):
warnings.warn(
PytestWarning(
"(rm_rf) unknown function {} when removing {}:\n{}: {}".format(
func, path, exctype, excvalue
)
)
)
return False
# Chmod + retry.
import stat
def chmod_rw(p: str) -> None:
mode = os.stat(p).st_mode
os.chmod(p, mode | stat.S_IRUSR | stat.S_IWUSR)
# For files, we need to recursively go upwards in the directories to
# ensure they all are also writable.
p = Path(path)
if p.is_file():
for parent in p.parents:
chmod_rw(str(parent))
# Stop when we reach the original path passed to rm_rf.
if parent == start_path:
break
chmod_rw(str(path))
func(path)
return True
def ensure_extended_length_path(path: Path) -> Path:
"""Get the extended-length version of a path (Windows).
On Windows, by default, the maximum length of a path (MAX_PATH) is 260
characters, and operations on paths longer than that fail. But it is possible
to overcome this by converting the path to "extended-length" form before
performing the operation:
https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#maximum-path-length-limitation
On Windows, this function returns the extended-length absolute version of path.
On other platforms it returns path unchanged.
"""
if sys.platform.startswith("win32"):
path = path.resolve()
path = Path(get_extended_length_path_str(str(path)))
return path
def get_extended_length_path_str(path: str) -> str:
"""Convert a path to a Windows extended length path."""
long_path_prefix = "\\\\?\\"
unc_long_path_prefix = "\\\\?\\UNC\\"
if path.startswith((long_path_prefix, unc_long_path_prefix)):
return path
# UNC
if path.startswith("\\\\"):
return unc_long_path_prefix + path[2:]
return long_path_prefix + path
def rm_rf(path: Path) -> None:
"""Remove the path contents recursively, even if some elements
are read-only."""
path = ensure_extended_length_path(path)
onerror = partial(on_rm_rf_error, start_path=path)
shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator[Path]:
"""Find all elements in root that begin with the prefix, case insensitive."""
l_prefix = prefix.lower()
for x in root.iterdir():
if x.name.lower().startswith(l_prefix):
yield x
def extract_suffixes(iter: Iterable[PurePath], prefix: str) -> Iterator[str]:
"""Return the parts of the paths following the prefix.
:param iter: Iterator over path names.
:param prefix: Expected prefix of the path names.
"""
p_len = len(prefix)
for p in iter:
yield p.name[p_len:]
def find_suffixes(root: Path, prefix: str) -> Iterator[str]:
"""Combine find_prefixes and extract_suffixes."""
return extract_suffixes(find_prefixed(root, prefix), prefix)
def parse_num(maybe_num) -> int:
"""Parse number path suffixes, returns -1 on error."""
try:
return int(maybe_num)
except ValueError:
return -1
def _force_symlink(
root: Path, target: Union[str, PurePath], link_to: Union[str, Path]
) -> None:
"""Helper to create the current symlink.
It's full of race conditions that are reasonably OK to ignore
for the context of best effort linking to the latest test run.
The presumption being that in case of much parallelism
the inaccuracy is going to be acceptable.
"""
current_symlink = root.joinpath(target)
try:
current_symlink.unlink()
except OSError:
pass
try:
current_symlink.symlink_to(link_to)
except Exception:
pass
def make_numbered_dir(root: Path, prefix: str, mode: int = 0o700) -> Path:
"""Create a directory with an increased number as suffix for the given prefix."""
for i in range(10):
# try up to 10 times to create the folder
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
new_number = max_existing + 1
new_path = root.joinpath(f"{prefix}{new_number}")
try:
new_path.mkdir(mode=mode)
except Exception:
pass
else:
_force_symlink(root, prefix + "current", new_path)
return new_path
else:
raise OSError(
"could not create numbered dir with prefix "
"{prefix} in {root} after 10 tries".format(prefix=prefix, root=root)
)
def create_cleanup_lock(p: Path) -> Path:
"""Create a lock to prevent premature folder cleanup."""
lock_path = get_lock_path(p)
try:
fd = os.open(str(lock_path), os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
except FileExistsError as e:
raise OSError(f"cannot create lockfile in {p}") from e
else:
pid = os.getpid()
spid = str(pid).encode()
os.write(fd, spid)
os.close(fd)
if not lock_path.is_file():
raise OSError("lock path got renamed after successful creation")
return lock_path
def register_cleanup_lock_removal(lock_path: Path, register=atexit.register):
"""Register a cleanup function for removing a lock, by default on atexit."""
pid = os.getpid()
def cleanup_on_exit(lock_path: Path = lock_path, original_pid: int = pid) -> None:
current_pid = os.getpid()
if current_pid != original_pid:
# fork
return
try:
lock_path.unlink()
except OSError:
pass
return register(cleanup_on_exit)
def maybe_delete_a_numbered_dir(path: Path) -> None:
"""Remove a numbered directory if its lock can be obtained and it does
not seem to be in use."""
path = ensure_extended_length_path(path)
lock_path = None
try:
lock_path = create_cleanup_lock(path)
parent = path.parent
garbage = parent.joinpath(f"garbage-{uuid.uuid4()}")
path.rename(garbage)
rm_rf(garbage)
except OSError:
# known races:
# * other process did a cleanup at the same time
# * deletable folder was found
# * process cwd (Windows)
return
finally:
# If we created the lock, ensure we remove it even if we failed
# to properly remove the numbered dir.
if lock_path is not None:
try:
lock_path.unlink()
except OSError:
pass
def ensure_deletable(path: Path, consider_lock_dead_if_created_before: float) -> bool:
"""Check if `path` is deletable based on whether the lock file is expired."""
if path.is_symlink():
return False
lock = get_lock_path(path)
try:
if not lock.is_file():
return True
except OSError:
# we might not have access to the lock file at all, in this case assume
# we don't have access to the entire directory (#7491).
return False
try:
lock_time = lock.stat().st_mtime
except Exception:
return False
else:
if lock_time < consider_lock_dead_if_created_before:
# We want to ignore any errors while trying to remove the lock such as:
# - PermissionDenied, like the file permissions have changed since the lock creation;
# - FileNotFoundError, in case another pytest process got here first;
# and any other cause of failure.
with contextlib.suppress(OSError):
lock.unlink()
return True
return False
def try_cleanup(path: Path, consider_lock_dead_if_created_before: float) -> None:
"""Try to cleanup a folder if we can ensure it's deletable."""
if ensure_deletable(path, consider_lock_dead_if_created_before):
maybe_delete_a_numbered_dir(path)
def cleanup_candidates(root: Path, prefix: str, keep: int) -> Iterator[Path]:
"""List candidates for numbered directories to be removed - follows py.path."""
max_existing = max(map(parse_num, find_suffixes(root, prefix)), default=-1)
max_delete = max_existing - keep
paths = find_prefixed(root, prefix)
paths, paths2 = itertools.tee(paths)
numbers = map(parse_num, extract_suffixes(paths2, prefix))
for path, number in zip(paths, numbers):
if number <= max_delete:
yield path
def cleanup_numbered_dir(
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float
) -> None:
"""Cleanup for lock driven numbered directories."""
for path in cleanup_candidates(root, prefix, keep):
try_cleanup(path, consider_lock_dead_if_created_before)
for path in root.glob("garbage-*"):
try_cleanup(path, consider_lock_dead_if_created_before)
def make_numbered_dir_with_cleanup(
root: Path,
prefix: str,
keep: int,
lock_timeout: float,
mode: int,
) -> Path:
"""Create a numbered dir with a cleanup lock and remove old ones."""
e = None
for i in range(10):
try:
p = make_numbered_dir(root, prefix, mode)
lock_path = create_cleanup_lock(p)
register_cleanup_lock_removal(lock_path)
except Exception as exc:
e = exc
else:
consider_lock_dead_if_created_before = p.stat().st_mtime - lock_timeout
# Register a cleanup for program exit
atexit.register(
cleanup_numbered_dir,
root,
prefix,
keep,
consider_lock_dead_if_created_before,
)
return p
assert e is not None
raise e
def resolve_from_str(input: str, rootpath: Path) -> Path:
input = expanduser(input)
input = expandvars(input)
if isabs(input):
return Path(input)
else:
return rootpath.joinpath(input)
def fnmatch_ex(pattern: str, path: Union[str, "os.PathLike[str]"]) -> bool:
"""A port of FNMatcher from py.path.common which works with PurePath() instances.
The difference between this algorithm and PurePath.match() is that the
latter matches "**" glob expressions for each part of the path, while
this algorithm uses the whole path instead.
For example:
"tests/foo/bar/doc/test_foo.py" matches pattern "tests/**/doc/test*.py"
with this algorithm, but not with PurePath.match().
This algorithm was ported to keep backward-compatibility with existing
settings which assume paths match according this logic.
References:
* https://bugs.python.org/issue29249
* https://bugs.python.org/issue34731
"""
path = PurePath(path)
iswin32 = sys.platform.startswith("win")
if iswin32 and sep not in pattern and posix_sep in pattern:
# Running on Windows, the pattern has no Windows path separators,
# and the pattern has one or more Posix path separators. Replace
# the Posix path separators with the Windows path separator.
pattern = pattern.replace(posix_sep, sep)
if sep not in pattern:
name = path.name
else:
name = str(path)
if path.is_absolute() and not os.path.isabs(pattern):
pattern = f"*{os.sep}{pattern}"
return fnmatch.fnmatch(name, pattern)
def parts(s: str) -> Set[str]:
parts = s.split(sep)
return {sep.join(parts[: i + 1]) or sep for i in range(len(parts))}
def symlink_or_skip(src, dst, **kwargs):
"""Make a symlink, or skip the test in case symlinks are not supported."""
try:
os.symlink(str(src), str(dst), **kwargs)
except OSError as e:
skip(f"symlinks not supported: {e}")
class ImportMode(Enum):
"""Possible values for `mode` parameter of `import_path`."""
prepend = "prepend"
append = "append"
importlib = "importlib"
class ImportPathMismatchError(ImportError):
"""Raised on import_path() if there is a mismatch of __file__'s.
This can happen when `import_path` is called multiple times with different filenames that has
the same basename but reside in packages
(for example "/tests1/test_foo.py" and "/tests2/test_foo.py").
"""
def import_path(
p: Union[str, "os.PathLike[str]"],
*,
mode: Union[str, ImportMode] = ImportMode.prepend,
root: Path,
) -> ModuleType:
"""Import and return a module from the given path, which can be a file (a module) or
a directory (a package).
The import mechanism used is controlled by the `mode` parameter:
* `mode == ImportMode.prepend`: the directory containing the module (or package, taking
`__init__.py` files into account) will be put at the *start* of `sys.path` before
being imported with `__import__.
* `mode == ImportMode.append`: same as `prepend`, but the directory will be appended
to the end of `sys.path`, if not already in `sys.path`.
* `mode == ImportMode.importlib`: uses more fine control mechanisms provided by `importlib`
to import the module, which avoids having to use `__import__` and muck with `sys.path`
at all. It effectively allows having same-named test modules in different places.
:param root:
Used as an anchor when mode == ImportMode.importlib to obtain
a unique name for the module being imported so it can safely be stored
into ``sys.modules``.
:raises ImportPathMismatchError:
If after importing the given `path` and the module `__file__`
are different. Only raised in `prepend` and `append` modes.
"""
mode = ImportMode(mode)
path = Path(p)
if not path.exists():
raise ImportError(path)
if mode is ImportMode.importlib:
module_name = module_name_from_path(path, root)
for meta_importer in sys.meta_path:
spec = meta_importer.find_spec(module_name, [str(path.parent)])
if spec is not None:
break
else:
spec = importlib.util.spec_from_file_location(module_name, str(path))
if spec is None:
raise ImportError(f"Can't find module {module_name} at location {path}")
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod) # type: ignore[union-attr]
insert_missing_modules(sys.modules, module_name)
return mod
pkg_path = resolve_package_path(path)
if pkg_path is not None:
pkg_root = pkg_path.parent
names = list(path.with_suffix("").relative_to(pkg_root).parts)
if names[-1] == "__init__":
names.pop()
module_name = ".".join(names)
else:
pkg_root = path.parent
module_name = path.stem
# Change sys.path permanently: restoring it at the end of this function would cause surprising
# problems because of delayed imports: for example, a conftest.py file imported by this function
# might have local imports, which would fail at runtime if we restored sys.path.
if mode is ImportMode.append:
if str(pkg_root) not in sys.path:
sys.path.append(str(pkg_root))
elif mode is ImportMode.prepend:
if str(pkg_root) != sys.path[0]:
sys.path.insert(0, str(pkg_root))
else:
assert_never(mode)
importlib.import_module(module_name)
mod = sys.modules[module_name]
if path.name == "__init__.py":
return mod
ignore = os.environ.get("PY_IGNORE_IMPORTMISMATCH", "")
if ignore != "1":
module_file = mod.__file__
if module_file is None:
raise ImportPathMismatchError(module_name, module_file, path)
if module_file.endswith((".pyc", ".pyo")):
module_file = module_file[:-1]
if module_file.endswith(os.path.sep + "__init__.py"):
module_file = module_file[: -(len(os.path.sep + "__init__.py"))]
try:
is_same = _is_same(str(path), module_file)
except FileNotFoundError:
is_same = False
if not is_same:
raise ImportPathMismatchError(module_name, module_file, path)
return mod
# Implement a special _is_same function on Windows which returns True if the two filenames
# compare equal, to circumvent os.path.samefile returning False for mounts in UNC (#7678).
if sys.platform.startswith("win"):
def _is_same(f1: str, f2: str) -> bool:
return Path(f1) == Path(f2) or os.path.samefile(f1, f2)
else:
def _is_same(f1: str, f2: str) -> bool:
return os.path.samefile(f1, f2)
def module_name_from_path(path: Path, root: Path) -> str:
"""
Return a dotted module name based on the given path, anchored on root.
For example: path="projects/src/tests/test_foo.py" and root="/projects", the
resulting module name will be "src.tests.test_foo".
"""
path = path.with_suffix("")
try:
relative_path = path.relative_to(root)
except ValueError:
# If we can't get a relative path to root, use the full path, except
# for the first part ("d:\\" or "/" depending on the platform, for example).
path_parts = path.parts[1:]
else:
# Use the parts for the relative path to the root path.
path_parts = relative_path.parts
return ".".join(path_parts)
def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) -> None:
"""
Used by ``import_path`` to create intermediate modules when using mode=importlib.
When we want to import a module as "src.tests.test_foo" for example, we need
to create empty modules "src" and "src.tests" after inserting "src.tests.test_foo",
otherwise "src.tests.test_foo" is not importable by ``__import__``.
"""
module_parts = module_name.split(".")
while module_name:
if module_name not in modules:
try:
# If sys.meta_path is empty, calling import_module will issue
# a warning and raise ModuleNotFoundError. To avoid the
# warning, we check sys.meta_path explicitly and raise the error
# ourselves to fall back to creating a dummy module.
if not sys.meta_path:
raise ModuleNotFoundError
importlib.import_module(module_name)
except ModuleNotFoundError:
module = ModuleType(
module_name,
doc="Empty module created by pytest's importmode=importlib.",
)
modules[module_name] = module
module_parts.pop(-1)
module_name = ".".join(module_parts)
def resolve_package_path(path: Path) -> Optional[Path]:
"""Return the Python package path by looking for the last
directory upwards which still contains an __init__.py.
Returns None if it can not be determined.
"""
result = None
for parent in itertools.chain((path,), path.parents):
if parent.is_dir():
if not parent.joinpath("__init__.py").is_file():
break
if not parent.name.isidentifier():
break
result = parent
return result
def visit(
path: Union[str, "os.PathLike[str]"], recurse: Callable[["os.DirEntry[str]"], bool]
) -> Iterator["os.DirEntry[str]"]:
"""Walk a directory recursively, in breadth-first order.
Entries at each directory level are sorted.
"""
# Skip entries with symlink loops and other brokenness, so the caller doesn't
# have to deal with it.
entries = []
for entry in os.scandir(path):
try:
entry.is_file()
except OSError as err:
if _ignore_error(err):
continue
raise
entries.append(entry)
entries.sort(key=lambda entry: entry.name)
yield from entries
for entry in entries:
if entry.is_dir() and recurse(entry):
yield from visit(entry.path, recurse)
def absolutepath(path: Union[Path, str]) -> Path:
"""Convert a path to an absolute path using os.path.abspath.
Prefer this over Path.resolve() (see #6523).
Prefer this over Path.absolute() (not public, doesn't normalize).
"""
return Path(os.path.abspath(str(path)))
def commonpath(path1: Path, path2: Path) -> Optional[Path]:
"""Return the common part shared with the other path, or None if there is
no common part.
If one path is relative and one is absolute, returns None.
"""
try:
return Path(os.path.commonpath((str(path1), str(path2))))
except ValueError:
return None
def bestrelpath(directory: Path, dest: Path) -> str:
"""Return a string which is a relative path from directory to dest such
that directory/bestrelpath == dest.
The paths must be either both absolute or both relative.
If no such path can be determined, returns dest.
"""
assert isinstance(directory, Path)
assert isinstance(dest, Path)
if dest == directory:
return os.curdir
# Find the longest common directory.
base = commonpath(directory, dest)
# Can be the case on Windows for two absolute paths on different drives.
# Can be the case for two relative paths without common prefix.
# Can be the case for a relative path and an absolute path.
if not base:
return str(dest)
reldirectory = directory.relative_to(base)
reldest = dest.relative_to(base)
return os.path.join(
# Back from directory to base.
*([os.pardir] * len(reldirectory.parts)),
# Forward from base to dest.
*reldest.parts,
)
# Originates from py. path.local.copy(), with siginficant trims and adjustments.
# TODO(py38): Replace with shutil.copytree(..., symlinks=True, dirs_exist_ok=True)
def copytree(source: Path, target: Path) -> None:
"""Recursively copy a source directory to target."""
assert source.is_dir()
for entry in visit(source, recurse=lambda entry: not entry.is_symlink()):
x = Path(entry)
relpath = x.relative_to(source)
newx = target / relpath
newx.parent.mkdir(exist_ok=True)
if x.is_symlink():
newx.symlink_to(os.readlink(x))
elif x.is_file():
shutil.copyfile(x, newx)
elif x.is_dir():
newx.mkdir(exist_ok=True)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,75 @@
"""Helper plugin for pytester; should not be loaded on its own."""
# This plugin contains assertions used by pytester. pytester cannot
# contain them itself, since it is imported by the `pytest` module,
# hence cannot be subject to assertion rewriting, which requires a
# module to not be already imported.
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
def assertoutcome(
outcomes: Tuple[
Sequence[TestReport],
Sequence[Union[CollectReport, TestReport]],
Sequence[Union[CollectReport, TestReport]],
],
passed: int = 0,
skipped: int = 0,
failed: int = 0,
) -> None:
__tracebackhide__ = True
realpassed, realskipped, realfailed = outcomes
obtained = {
"passed": len(realpassed),
"skipped": len(realskipped),
"failed": len(realfailed),
}
expected = {"passed": passed, "skipped": skipped, "failed": failed}
assert obtained == expected, outcomes
def assert_outcomes(
outcomes: Dict[str, int],
passed: int = 0,
skipped: int = 0,
failed: int = 0,
errors: int = 0,
xpassed: int = 0,
xfailed: int = 0,
warnings: Optional[int] = None,
deselected: Optional[int] = None,
) -> None:
"""Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run."""
__tracebackhide__ = True
obtained = {
"passed": outcomes.get("passed", 0),
"skipped": outcomes.get("skipped", 0),
"failed": outcomes.get("failed", 0),
"errors": outcomes.get("errors", 0),
"xpassed": outcomes.get("xpassed", 0),
"xfailed": outcomes.get("xfailed", 0),
}
expected = {
"passed": passed,
"skipped": skipped,
"failed": failed,
"errors": errors,
"xpassed": xpassed,
"xfailed": xfailed,
}
if warnings is not None:
obtained["warnings"] = outcomes.get("warnings", 0)
expected["warnings"] = warnings
if deselected is not None:
obtained["deselected"] = outcomes.get("deselected", 0)
expected["deselected"] = deselected
assert obtained == expected

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,993 @@
import math
import pprint
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
from numbers import Complex
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import List
from typing import Mapping
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
if TYPE_CHECKING:
from numpy import ndarray
import _pytest._code
from _pytest.compat import final
from _pytest.compat import STRING_TYPES
from _pytest.compat import overload
from _pytest.outcomes import fail
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
at_str = f" at {at}" if at else ""
return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
value, at_str
)
)
def _compare_approx(
full_object: object,
message_data: Sequence[Tuple[str, str, str]],
number_of_elements: int,
different_ids: Sequence[object],
max_abs_diff: float,
max_rel_diff: float,
) -> List[str]:
message_list = list(message_data)
message_list.insert(0, ("Index", "Obtained", "Expected"))
max_sizes = [0, 0, 0]
for index, obtained, expected in message_list:
max_sizes[0] = max(max_sizes[0], len(index))
max_sizes[1] = max(max_sizes[1], len(obtained))
max_sizes[2] = max(max_sizes[2], len(expected))
explanation = [
f"comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:",
f"Max absolute difference: {max_abs_diff}",
f"Max relative difference: {max_rel_diff}",
] + [
f"{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}"
for indexes, obtained, expected in message_list
]
return explanation
# builtin pytest.approx helper
class ApproxBase:
"""Provide shared utilities for making approximate comparisons between
numbers or sequences of numbers."""
# Tell numpy to use our `__eq__` operator instead of its.
__array_ufunc__ = None
__array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True
self.expected = expected
self.abs = abs
self.rel = rel
self.nan_ok = nan_ok
self._check_type()
def __repr__(self) -> str:
raise NotImplementedError
def _repr_compare(self, other_side: Any) -> List[str]:
return [
"comparison failed",
f"Obtained: {other_side}",
f"Expected: {self}",
]
def __eq__(self, actual) -> bool:
return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
)
def __bool__(self):
__tracebackhide__ = True
raise AssertionError(
"approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?"
)
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
def __ne__(self, actual) -> bool:
return not (actual == self)
def _approx_scalar(self, x) -> "ApproxScalar":
if isinstance(x, Decimal):
return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual):
"""Yield all the pairs of numbers to be compared.
This is used to implement the `__eq__` method.
"""
raise NotImplementedError
def _check_type(self) -> None:
"""Raise a TypeError if the expected value is not a valid type."""
# This is only a concern if the expected value is a sequence. In every
# other case, the approx() function ensures that the expected value has
# a numeric type. For this reason, the default is to do nothing. The
# classes that deal with sequences should reimplement this method to
# raise if there are any non-numeric elements in the sequence.
def _recursive_sequence_map(f, x):
"""Recursively map a function over a sequence of arbitrary depth"""
if isinstance(x, (list, tuple)):
seq_type = type(x)
return seq_type(_recursive_sequence_map(f, xi) for xi in x)
else:
return f(x)
class ApproxNumpy(ApproxBase):
"""Perform approximate comparisons where the expected value is numpy array."""
def __repr__(self) -> str:
list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
return f"approx({list_scalars!r})"
def _repr_compare(self, other_side: "ndarray") -> List[str]:
import itertools
import math
def get_value_from_nested_list(
nested_list: List[Any], nd_index: Tuple[Any, ...]
) -> Any:
"""
Helper function to get the value out of a nested list, given an n-dimensional index.
This mimics numpy's indexing, but for raw nested python lists.
"""
value: Any = nested_list
for i in nd_index:
value = value[i]
return value
np_array_shape = self.expected.shape
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
if np_array_shape != other_side.shape:
return [
"Impossible to compare arrays with different shapes.",
f"Shapes: {np_array_shape} and {other_side.shape}",
]
number_of_elements = self.expected.size
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for index in itertools.product(*(range(i) for i in np_array_shape)):
approx_value = get_value_from_nested_list(approx_side_as_seq, index)
other_value = get_value_from_nested_list(other_side, index)
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
max_abs_diff = max(max_abs_diff, abs_diff)
if other_value == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(max_rel_diff, abs_diff / abs(other_value))
different_ids.append(index)
message_data = [
(
str(index),
str(get_value_from_nested_list(other_side, index)),
str(get_value_from_nested_list(approx_side_as_seq, index)),
)
for index in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
import numpy as np
# self.expected is supposed to always be an array here.
if not np.isscalar(actual):
try:
actual = np.asarray(actual)
except Exception as e:
raise TypeError(f"cannot compare '{actual}' to numpy.ndarray") from e
if not np.isscalar(actual) and actual.shape != self.expected.shape:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
import numpy as np
# `actual` can either be a numpy array or a scalar, it is treated in
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
# only method that calls this one.
if np.isscalar(actual):
for i in np.ndindex(self.expected.shape):
yield actual, self.expected[i].item()
else:
for i in np.ndindex(self.expected.shape):
yield actual[i].item(), self.expected[i].item()
class ApproxMapping(ApproxBase):
"""Perform approximate comparisons where the expected value is a mapping
with numeric values (the keys can be anything)."""
def __repr__(self) -> str:
return "approx({!r})".format(
{k: self._approx_scalar(v) for k, v in self.expected.items()}
)
def _repr_compare(self, other_side: Mapping[object, float]) -> List[str]:
import math
approx_side_as_map = {
k: self._approx_scalar(v) for k, v in self.expected.items()
}
number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for (approx_key, approx_value), other_value in zip(
approx_side_as_map.items(), other_side.values()
):
if approx_value != other_value:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
max_rel_diff = max(
max_rel_diff,
abs((approx_value.expected - other_value) / approx_value.expected),
)
different_ids.append(approx_key)
message_data = [
(str(key), str(other_side[key]), str(approx_side_as_map[key]))
for key in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
try:
if set(actual.keys()) != set(self.expected.keys()):
return False
except AttributeError:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
for k in self.expected.keys():
yield actual[k], self.expected[k]
def _check_type(self) -> None:
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
class ApproxSequenceLike(ApproxBase):
"""Perform approximate comparisons where the expected value is a sequence of numbers."""
def __repr__(self) -> str:
seq_type = type(self.expected)
if seq_type not in (tuple, list):
seq_type = list
return "approx({!r})".format(
seq_type(self._approx_scalar(x) for x in self.expected)
)
def _repr_compare(self, other_side: Sequence[float]) -> List[str]:
import math
if len(self.expected) != len(other_side):
return [
"Impossible to compare lists with different sizes.",
f"Lengths: {len(self.expected)} and {len(other_side)}",
]
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)
number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for i, (approx_value, other_value) in enumerate(
zip(approx_side_as_map, other_side)
):
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
max_abs_diff = max(max_abs_diff, abs_diff)
if other_value == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(max_rel_diff, abs_diff / abs(other_value))
different_ids.append(i)
message_data = [
(str(i), str(other_side[i]), str(approx_side_as_map[i]))
for i in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
try:
if len(actual) != len(self.expected):
return False
except TypeError:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
return zip(actual, self.expected)
def _check_type(self) -> None:
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
class ApproxScalar(ApproxBase):
"""Perform approximate comparisons where the expected value is a single number."""
# Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108
DEFAULT_ABSOLUTE_TOLERANCE: Union[float, Decimal] = 1e-12
DEFAULT_RELATIVE_TOLERANCE: Union[float, Decimal] = 1e-6
def __repr__(self) -> str:
"""Return a string communicating both the expected value and the
tolerance for the comparison being made.
For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ±180°``.
"""
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected) # type: ignore[arg-type]
):
return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'.
try:
vetted_tolerance = f"{self.tolerance:.1e}"
if (
isinstance(self.expected, Complex)
and self.expected.imag
and not math.isinf(self.tolerance)
):
vetted_tolerance += " ∠ ±180°"
except ValueError:
vetted_tolerance = "???"
return f"{self.expected} ± {vetted_tolerance}"
def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value
within the pre-specified tolerance."""
asarray = _as_numpy_array(actual)
if asarray is not None:
# Call ``__eq__()`` manually to prevent infinite-recursion with
# numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality.
if actual == self.expected:
return True
# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
return False
# Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex
# numbers.
if math.isnan(abs(self.expected)): # type: ignore[arg-type]
return self.nan_ok and math.isnan(abs(actual)) # type: ignore[arg-type]
# Infinity shouldn't be approximately equal to anything but itself, but
# if there's a relative tolerance, it will be infinite and infinity
# will seem approximately equal to everything. The equal-to-itself
# case would have been short circuited above, so here we can just
# return false if the expected value is infinite. The abs() call is
# for compatibility with complex numbers.
if math.isinf(abs(self.expected)): # type: ignore[arg-type]
return False
# Return true if the two numbers are within the tolerance.
result: bool = abs(self.expected - actual) <= self.tolerance
return result
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@property
def tolerance(self):
"""Return the tolerance for the comparison.
This could be either an absolute tolerance or a relative tolerance,
depending on what the user specified or which would be larger.
"""
def set_default(x, default):
return x if x is not None else default
# Figure out what the absolute tolerance should be. ``self.abs`` is
# either None or a value specified by the user.
absolute_tolerance = set_default(self.abs, self.DEFAULT_ABSOLUTE_TOLERANCE)
if absolute_tolerance < 0:
raise ValueError(
f"absolute tolerance can't be negative: {absolute_tolerance}"
)
if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.")
# If the user specified an absolute tolerance but not a relative one,
# just return the absolute tolerance.
if self.rel is None:
if self.abs is not None:
return absolute_tolerance
# Figure out what the relative tolerance should be. ``self.rel`` is
# either None or a value specified by the user. This is done after
# we've made sure the user didn't ask for an absolute tolerance only,
# because we don't want to raise errors about the relative tolerance if
# we aren't even going to use it.
relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE
) * abs(self.expected)
if relative_tolerance < 0:
raise ValueError(
f"relative tolerance can't be negative: {relative_tolerance}"
)
if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.")
# Return the larger of the relative and absolute tolerances.
return max(relative_tolerance, absolute_tolerance)
class ApproxDecimal(ApproxScalar):
"""Perform approximate comparisons where the expected value is a Decimal."""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
"""Assert that two numbers (or two ordered sequences of numbers) are equal to each other
within some tolerance.
Due to the :doc:`python:tutorial/floatingpoint`, numbers that we
would intuitively expect to be equal are not always so::
>>> 0.1 + 0.2 == 0.3
False
This problem is commonly encountered when writing tests, e.g. when making
sure that floating-point values are what you expect them to be. One way to
deal with this problem is to assert that two floating-point numbers are
equal to within some appropriate tolerance::
>>> abs((0.1 + 0.2) - 0.3) < 1e-6
True
However, comparisons like this are tedious to write and difficult to
understand. Furthermore, absolute comparisons like the one above are
usually discouraged because there's no tolerance that works well for all
situations. ``1e-6`` is good for numbers around ``1``, but too small for
very big numbers and too big for very small ones. It's better to express
the tolerance as a fraction of the expected value, but relative comparisons
like that are even more difficult to write correctly and concisely.
The ``approx`` class performs floating-point comparisons using a syntax
that's as intuitive as possible::
>>> from pytest import approx
>>> 0.1 + 0.2 == approx(0.3)
True
The same syntax also works for ordered sequences of numbers::
>>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
True
``numpy`` arrays::
>>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
True
And for a ``numpy`` array against a scalar::
>>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
True
Only ordered sequences are supported, because ``approx`` needs
to infer the relative position of the sequences without ambiguity. This means
``sets`` and other unordered sequences are not supported.
Finally, dictionary *values* can also be compared::
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
True
The comparison will be true if both mappings have the same keys and their
respective values match the expected tolerances.
**Tolerances**
By default, ``approx`` considers numbers within a relative tolerance of
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
This treatment would lead to surprising results if the expected value was
``0.0``, because nothing but ``0.0`` itself is relatively close to ``0.0``.
To handle this case less surprisingly, ``approx`` also considers numbers
within an absolute tolerance of ``1e-12`` of its expected value to be
equal. Infinity and NaN are special cases. Infinity is only considered
equal to itself, regardless of the relative tolerance. NaN is not
considered equal to anything by default, but you can make it be equal to
itself by setting the ``nan_ok`` argument to True. (This is meant to
facilitate comparing arrays that use NaN to mean "no data".)
Both the relative and absolute tolerances can be changed by passing
arguments to the ``approx`` constructor::
>>> 1.0001 == approx(1)
False
>>> 1.0001 == approx(1, rel=1e-3)
True
>>> 1.0001 == approx(1, abs=1e-3)
True
If you specify ``abs`` but not ``rel``, the comparison will not consider
the relative tolerance at all. In other words, two numbers that are within
the default relative tolerance of ``1e-6`` will still be considered unequal
if they exceed the specified absolute tolerance. If you specify both
``abs`` and ``rel``, the numbers will be considered equal if either
tolerance is met::
>>> 1 + 1e-8 == approx(1)
True
>>> 1 + 1e-8 == approx(1, abs=1e-12)
False
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True
You can also use ``approx`` to compare nonnumeric types, or dicts and
sequences containing nonnumeric types, in which case it falls back to
strict equality. This can be useful for comparing dicts and sequences that
can contain optional values::
>>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
True
>>> [None, 1.0000005] == approx([None,1])
True
>>> ["foo", 1.0000005] == approx([None,1])
False
If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances and should
agree for the most part, but they do have meaningful differences:
- ``math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)``: True if the relative
tolerance is met w.r.t. either ``a`` or ``b`` or if the absolute
tolerance is met. Because the relative tolerance is calculated w.r.t.
both ``a`` and ``b``, this test is symmetric (i.e. neither ``a`` nor
``b`` is a "reference value"). You have to specify an absolute tolerance
if you want to compare to ``0.0`` because there is no tolerance by
default. More information: :py:func:`math.isclose`.
- ``numpy.isclose(a, b, rtol=1e-5, atol=1e-8)``: True if the difference
between ``a`` and ``b`` is less that the sum of the relative tolerance
w.r.t. ``b`` and the absolute tolerance. Because the relative tolerance
is only calculated w.r.t. ``b``, this test is asymmetric and you can
think of ``b`` as the reference value. Support for comparing sequences
is provided by :py:func:`numpy.allclose`. More information:
:std:doc:`numpy:reference/generated/numpy.isclose`.
- ``unittest.TestCase.assertAlmostEqual(a, b)``: True if ``a`` and ``b``
are within an absolute tolerance of ``1e-7``. No relative tolerance is
considered , so this function is not appropriate for very large or very
small numbers. Also, it's only available in subclasses of ``unittest.TestCase``
and it's ugly because it doesn't follow PEP8. More information:
:py:meth:`unittest.TestCase.assertAlmostEqual`.
- ``a == pytest.approx(b, rel=1e-6, abs=1e-12)``: True if the relative
tolerance is met w.r.t. ``b`` or if the absolute tolerance is met.
Because the relative tolerance is only calculated w.r.t. ``b``, this test
is asymmetric and you can think of ``b`` as the reference value. In the
special case that you explicitly specify an absolute tolerance but not a
relative tolerance, only the absolute tolerance is considered.
.. note::
``approx`` can handle numpy arrays, but we recommend the
specialised test helpers in :std:doc:`numpy:reference/routines.testing`
if you need support for comparisons, NaNs, or ULP-based tolerances.
To match strings using regex, you can use
`Matches <https://github.com/asottile/re-assert#re_assertmatchespattern-str-args-kwargs>`_
from the
`re_assert package <https://github.com/asottile/re-assert>`_.
.. warning::
.. versionchanged:: 3.2
In order to avoid inconsistent behavior, :py:exc:`TypeError` is
raised for ``>``, ``>=``, ``<`` and ``<=`` comparisons.
The example below illustrates the problem::
assert approx(0.1) > 0.1 + 1e-10 # calls approx(0.1).__gt__(0.1 + 1e-10)
assert 0.1 + 1e-10 > approx(0.1) # calls approx(0.1).__lt__(0.1 + 1e-10)
In the second example one expects ``approx(0.1).__le__(0.1 + 1e-10)``
to be called. But instead, ``approx(0.1).__lt__(0.1 + 1e-10)`` is used to
comparison. This is because the call hierarchy of rich comparisons
follows a fixed behavior. More information: :py:meth:`object.__ge__`
.. versionchanged:: 3.7.1
``approx`` raises ``TypeError`` when it encounters a dict value or
sequence element of nonnumeric type.
.. versionchanged:: 6.1.0
``approx`` falls back to strict equality for nonnumeric types instead
of raising ``TypeError``.
"""
# Delegate the comparison to a class that knows how to deal with the type
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
#
# The primary responsibility of these classes is to implement ``__eq__()``
# and ``__repr__()``. The former is used to actually check if some
# "actual" value is equivalent to the given expected value within the
# allowed tolerance. The latter is used to show the user the expected
# value and tolerance, in the case that a test failed.
#
# The actual logic for making approximate comparisons can be found in
# ApproxScalar, which is used to compare individual numbers. All of the
# other Approx classes eventually delegate to this class. The ApproxBase
# class provides some convenient methods and overloads, but isn't really
# essential.
__tracebackhide__ = True
if isinstance(expected, Decimal):
cls: Type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
expected = _as_numpy_array(expected)
cls = ApproxNumpy
elif (
hasattr(expected, "__getitem__")
and isinstance(expected, Sized)
# Type ignored because the error is wrong -- not unreachable.
and not isinstance(expected, STRING_TYPES) # type: ignore[unreachable]
):
cls = ApproxSequenceLike
elif (
isinstance(expected, Collection)
# Type ignored because the error is wrong -- not unreachable.
and not isinstance(expected, STRING_TYPES) # type: ignore[unreachable]
):
msg = f"pytest.approx() only supports ordered sequences, but got: {repr(expected)}"
raise TypeError(msg)
else:
cls = ApproxScalar
return cls(expected, rel, abs, nan_ok)
def _is_numpy_array(obj: object) -> bool:
"""
Return true if the given object is implicitly convertible to ndarray,
and numpy is already imported.
"""
return _as_numpy_array(obj) is not None
def _as_numpy_array(obj: object) -> Optional["ndarray"]:
"""
Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None.
"""
import sys
np: Any = sys.modules.get("numpy")
if np is not None:
# avoid infinite recursion on numpy scalars, which have __array__
if np.isscalar(obj):
return None
elif isinstance(obj, np.ndarray):
return obj
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
return np.asarray(obj)
return None
# builtin pytest.raises helper
E = TypeVar("E", bound=BaseException)
@overload
def raises(
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "RaisesContext[E]":
...
@overload
def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
func: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> _pytest._code.ExceptionInfo[E]:
...
def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]:
r"""Assert that a code block/function call raises an exception.
:param typing.Type[E] | typing.Tuple[typing.Type[E], ...] expected_exception:
The excpected exception type, or a tuple if one of multiple possible
exception types are excepted.
:kwparam str | typing.Pattern[str] | None match:
If specified, a string containing a regular expression,
or a regular expression object, that is tested against the string
representation of the exception using :func:`re.search`.
To match a literal string that may contain :ref:`special characters
<re-syntax>`, the pattern can first be escaped with :func:`re.escape`.
(This is only used when :py:func:`pytest.raises` is used as a context manager,
and passed through to the function otherwise.
When using :py:func:`pytest.raises` as a function, you can use:
``pytest.raises(Exc, func, match="passed on").match("my pattern")``.)
.. currentmodule:: _pytest._code
Use ``pytest.raises`` as a context manager, which will capture the exception of the given
type::
>>> import pytest
>>> with pytest.raises(ZeroDivisionError):
... 1/0
If the code block does not raise the expected exception (``ZeroDivisionError`` in the example
above), or no exception at all, the check will fail instead.
You can also use the keyword argument ``match`` to assert that the
exception matches a text or regex::
>>> with pytest.raises(ValueError, match='must be 0 or None'):
... raise ValueError("value must be 0 or None")
>>> with pytest.raises(ValueError, match=r'must be \d+$'):
... raise ValueError("value must be 42")
The context manager produces an :class:`ExceptionInfo` object which can be used to inspect the
details of the captured exception::
>>> with pytest.raises(ValueError) as exc_info:
... raise ValueError("value must be 42")
>>> assert exc_info.type is ValueError
>>> assert exc_info.value.args[0] == "value must be 42"
.. note::
When using ``pytest.raises`` as a context manager, it's worthwhile to
note that normal context manager rules apply and that the exception
raised *must* be the final line in the scope of the context manager.
Lines of code after that, within the scope of the context manager will
not be executed. For example::
>>> value = 15
>>> with pytest.raises(ValueError) as exc_info:
... if value > 10:
... raise ValueError("value must be <= 10")
... assert exc_info.type is ValueError # this will not execute
Instead, the following approach must be taken (note the difference in
scope)::
>>> with pytest.raises(ValueError) as exc_info:
... if value > 10:
... raise ValueError("value must be <= 10")
...
>>> assert exc_info.type is ValueError
**Using with** ``pytest.mark.parametrize``
When using :ref:`pytest.mark.parametrize ref`
it is possible to parametrize tests such that
some runs raise an exception and others do not.
See :ref:`parametrizing_conditional_raising` for an example.
**Legacy form**
It is possible to specify a callable by passing a to-be-called lambda::
>>> raises(ZeroDivisionError, lambda: 1/0)
<ExceptionInfo ...>
or you can specify an arbitrary callable with arguments::
>>> def f(x): return 1/x
...
>>> raises(ZeroDivisionError, f, 0)
<ExceptionInfo ...>
>>> raises(ZeroDivisionError, f, x=0)
<ExceptionInfo ...>
The form above is fully supported but discouraged for new code because the
context manager form is regarded as more readable and less error-prone.
.. note::
Similar to caught exception objects in Python, explicitly clearing
local references to returned ``ExceptionInfo`` objects can
help the Python interpreter speed up its garbage collection.
Clearing those references breaks a reference cycle
(``ExceptionInfo`` --> caught exception --> frame stack raising
the exception --> current frame stack --> local variables -->
``ExceptionInfo``) which makes Python keep all objects referenced
from that cycle (including all local variables in the current
frame) alive until the next cyclic garbage collection run.
More detailed information can be found in the official Python
documentation for :ref:`the try statement <python:try>`.
"""
__tracebackhide__ = True
if not expected_exception:
raise ValueError(
f"Expected an exception type or a tuple of exception types, but got `{expected_exception!r}`. "
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)
if isinstance(expected_exception, type):
excepted_exceptions: Tuple[Type[E], ...] = (expected_exception,)
else:
excepted_exceptions = expected_exception
for exc in excepted_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
message = f"DID NOT RAISE {expected_exception}"
if not args:
match: Optional[Union[str, Pattern[str]]] = kwargs.pop("match", None)
if kwargs:
msg = "Unexpected keyword arguments passed to pytest.raises: "
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
# We just caught the exception - there is a traceback.
assert e.__traceback__ is not None
return _pytest._code.ExceptionInfo.from_exc_info(
(type(e), e, e.__traceback__)
)
fail(message)
# This doesn't work with mypy for now. Use fail.Exception instead.
raises.Exception = fail.Exception # type: ignore
@final
class RaisesContext(Generic[E]):
def __init__(
self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo: Optional[_pytest._code.ExceptionInfo[E]] = None
def __enter__(self) -> _pytest._code.ExceptionInfo[E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True
if exc_type is None:
fail(self.message)
assert self.excinfo is not None
if not issubclass(exc_type, self.expected_exception):
return False
# Cast to narrow the exception type now that it's verified.
exc_info = cast(Tuple[Type[E], E, TracebackType], (exc_type, exc_val, exc_tb))
self.excinfo.fill_unfilled(exc_info)
if self.match_expr is not None:
self.excinfo.match(self.match_expr)
return True

View File

@ -0,0 +1,24 @@
import sys
import pytest
from pytest import Config
from pytest import Parser
def pytest_addoption(parser: Parser) -> None:
parser.addini("pythonpath", type="paths", help="Add paths to sys.path", default=[])
@pytest.hookimpl(tryfirst=True)
def pytest_load_initial_conftests(early_config: Config) -> None:
# `pythonpath = a b` will set `sys.path` to `[a, b, x, y, z, ...]`
for path in reversed(early_config.getini("pythonpath")):
sys.path.insert(0, str(path))
@pytest.hookimpl(trylast=True)
def pytest_unconfigure(config: Config) -> None:
for path in config.getini("pythonpath"):
path_str = str(path)
if path_str in sys.path:
sys.path.remove(path_str)

View File

@ -0,0 +1,313 @@
"""Record warnings during test function execution."""
import re
import warnings
from pprint import pformat
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterator
from typing import List
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from _pytest.compat import final
from _pytest.compat import overload
from _pytest.deprecated import check_ispytest
from _pytest.deprecated import WARNS_NONE_ARG
from _pytest.fixtures import fixture
from _pytest.outcomes import fail
T = TypeVar("T")
@fixture
def recwarn() -> Generator["WarningsRecorder", None, None]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See https://docs.pytest.org/en/latest/how-to/capture-warnings.html for information
on warning categories.
"""
wrec = WarningsRecorder(_ispytest=True)
with wrec:
warnings.simplefilter("default")
yield wrec
@overload
def deprecated_call(
*, match: Optional[Union[str, Pattern[str]]] = ...
) -> "WarningsRecorder":
...
@overload
def deprecated_call( # noqa: F811
func: Callable[..., T], *args: Any, **kwargs: Any
) -> T:
...
def deprecated_call( # noqa: F811
func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any
) -> Union["WarningsRecorder", Any]:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning``.
This function can be used as a context manager::
>>> import warnings
>>> def api_call_v2():
... warnings.warn('use v3 of this api', DeprecationWarning)
... return 200
>>> import pytest
>>> with pytest.deprecated_call():
... assert api_call_v2() == 200
It can also be used by passing a function and ``*args`` and ``**kwargs``,
in which case it will ensure calling ``func(*args, **kwargs)`` produces one of
the warnings types above. The return value is the return value of the function.
In the context manager form you may use the keyword argument ``match`` to assert
that the warning matches a text or regex.
The context manager produces a list of :class:`warnings.WarningMessage` objects,
one for each warning raised.
"""
__tracebackhide__ = True
if func is not None:
args = (func,) + args
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
@overload
def warns(
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = ...,
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "WarningsChecker":
...
@overload
def warns( # noqa: F811
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]],
func: Callable[..., T],
*args: Any,
**kwargs: Any,
) -> T:
...
def warns( # noqa: F811
expected_warning: Union[Type[Warning], Tuple[Type[Warning], ...]] = Warning,
*args: Any,
match: Optional[Union[str, Pattern[str]]] = None,
**kwargs: Any,
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or sequence
of warning classes, and the code inside the ``with`` block must issue at least one
warning of that class or classes.
This helper produces a list of :class:`warnings.WarningMessage` objects, one for
each warning raised (regardless of whether it is an ``expected_warning`` or not).
This function can be used as a context manager, which will capture all the raised
warnings inside it::
>>> import pytest
>>> with pytest.warns(RuntimeWarning):
... warnings.warn("my warning", RuntimeWarning)
In the context manager form you may use the keyword argument ``match`` to assert
that the warning matches a text or regex::
>>> with pytest.warns(UserWarning, match='must be 0 or None'):
... warnings.warn("value must be 0 or None", UserWarning)
>>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("value must be 42", UserWarning)
>>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("this is not here", UserWarning)
Traceback (most recent call last):
...
Failed: DID NOT WARN. No warnings of type ...UserWarning... were emitted...
**Using with** ``pytest.mark.parametrize``
When using :ref:`pytest.mark.parametrize ref` it is possible to parametrize tests
such that some runs raise a warning and others do not.
This could be achieved in the same way as with exceptions, see
:ref:`parametrizing_conditional_raising` for an example.
"""
__tracebackhide__ = True
if not args:
if kwargs:
argnames = ", ".join(sorted(kwargs))
raise TypeError(
f"Unexpected keyword arguments passed to pytest.warns: {argnames}"
"\nUse context-manager form instead?"
)
return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
with WarningsChecker(expected_warning, _ispytest=True):
return func(*args[1:], **kwargs)
class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""A context manager to record raised warnings.
Each recorded warning is an instance of :class:`warnings.WarningMessage`.
Adapted from `warnings.catch_warnings`.
.. note::
``DeprecationWarning`` and ``PendingDeprecationWarning`` are treated
differently; see :ref:`ensuring_function_triggers`.
"""
def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
# Type ignored due to the way typeshed handles warnings.catch_warnings.
super().__init__(record=True) # type: ignore[call-arg]
self._entered = False
self._list: List[warnings.WarningMessage] = []
@property
def list(self) -> List["warnings.WarningMessage"]:
"""The list of recorded warnings."""
return self._list
def __getitem__(self, i: int) -> "warnings.WarningMessage":
"""Get a recorded warning by index."""
return self._list[i]
def __iter__(self) -> Iterator["warnings.WarningMessage"]:
"""Iterate through the recorded warnings."""
return iter(self._list)
def __len__(self) -> int:
"""The number of recorded warnings."""
return len(self._list)
def pop(self, cls: Type[Warning] = Warning) -> "warnings.WarningMessage":
"""Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list):
if issubclass(w.category, cls):
return self._list.pop(i)
__tracebackhide__ = True
raise AssertionError(f"{cls!r} not found in warning list")
def clear(self) -> None:
"""Clear the list of recorded warnings."""
self._list[:] = []
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot enter {self!r} twice")
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot exit {self!r} without entering first")
super().__exit__(exc_type, exc_val, exc_tb)
# Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable.
self._entered = False
@final
class WarningsChecker(WarningsRecorder):
def __init__(
self,
expected_warning: Optional[
Union[Type[Warning], Tuple[Type[Warning], ...]]
] = Warning,
match_expr: Optional[Union[str, Pattern[str]]] = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
super().__init__(_ispytest=True)
msg = "exceptions must be derived from Warning, not %s"
if expected_warning is None:
warnings.warn(WARNS_NONE_ARG, stacklevel=4)
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning:
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
expected_warning_tup = expected_warning
elif issubclass(expected_warning, Warning):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
super().__exit__(exc_type, exc_val, exc_tb)
__tracebackhide__ = True
def found_str():
return pformat([record.message for record in self], indent=2)
# only check if we're not currently handling an exception
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f"The list of emitted warnings is: {found_str()}."
)
elif self.match_expr is not None:
for r in self:
if issubclass(r.category, self.expected_warning):
if re.compile(self.match_expr).search(str(r.message)):
break
else:
fail(
f"""\
DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.
Regex: {self.match_expr}
Emitted warnings: {found_str()}"""
)

View File

@ -0,0 +1,603 @@
import os
from io import StringIO
from pprint import pprint
from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprEntry
from _pytest._code.code import ReprEntryNative
from _pytest._code.code import ReprExceptionInfo
from _pytest._code.code import ReprFileLocation
from _pytest._code.code import ReprFuncArgs
from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.config import Config
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import skip
if TYPE_CHECKING:
from typing_extensions import Literal
from _pytest.runner import CallInfo
def getworkerinfoline(node):
try:
return node._workerinfocache
except AttributeError:
d = node.workerinfo
ver = "%s.%s.%s" % d["version_info"][:3]
node._workerinfocache = s = "[{}] {} -- Python {} {}".format(
d["id"], d["sysplatform"], ver, d["executable"]
)
return s
_R = TypeVar("_R", bound="BaseReport")
class BaseReport:
when: Optional[str]
location: Optional[Tuple[str, Optional[int], str]]
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
]
sections: List[Tuple[str, str]]
nodeid: str
outcome: "Literal['passed', 'failed', 'skipped']"
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)
if TYPE_CHECKING:
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any:
...
def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"):
worker_info = getworkerinfoline(self.node)
if worker_info:
out.line(worker_info)
longrepr = self.longrepr
if longrepr is None:
return
if hasattr(longrepr, "toterminal"):
longrepr_terminal = cast(TerminalRepr, longrepr)
longrepr_terminal.toterminal(out)
else:
try:
s = str(longrepr)
except UnicodeEncodeError:
s = "<unprintable longrepr>"
out.line(s)
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@property
def longreprtext(self) -> str:
"""Read-only property that returns the full string representation of
``longrepr``.
.. versionadded:: 3.0
"""
file = StringIO()
tw = TerminalWriter(file)
tw.hasmarkup = False
self.toterminal(tw)
exc = file.getvalue()
return exc.strip()
@property
def caplog(self) -> str:
"""Return captured log lines, if log capturing is enabled.
.. versionadded:: 3.5
"""
return "\n".join(
content for (prefix, content) in self.get_sections("Captured log")
)
@property
def capstdout(self) -> str:
"""Return captured text from stdout, if capturing is enabled.
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stdout")
)
@property
def capstderr(self) -> str:
"""Return captured text from stderr, if capturing is enabled.
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stderr")
)
@property
def passed(self) -> bool:
"""Whether the outcome is passed."""
return self.outcome == "passed"
@property
def failed(self) -> bool:
"""Whether the outcome is failed."""
return self.outcome == "failed"
@property
def skipped(self) -> bool:
"""Whether the outcome is skipped."""
return self.outcome == "skipped"
@property
def fspath(self) -> str:
"""The path portion of the reported node, as a string."""
return self.nodeid.split("::")[0]
@property
def count_towards_summary(self) -> bool:
"""**Experimental** Whether this report should be counted towards the
totals shown at the end of the test session: "1 passed, 1 failure, etc".
.. note::
This function is considered **experimental**, so beware that it is subject to changes
even in patch releases.
"""
return True
@property
def head_line(self) -> Optional[str]:
"""**Experimental** The head line shown with longrepr output for this
report, more commonly during traceback representation during
failures::
________ Test.foo ________
In the example above, the head_line is "Test.foo".
.. note::
This function is considered **experimental**, so beware that it is subject to changes
even in patch releases.
"""
if self.location is not None:
fspath, lineno, domain = self.location
return domain
return None
def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
)
return verbose
def _to_json(self) -> Dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
return _report_to_json(self)
@classmethod
def _from_json(cls: Type[_R], reportdict: Dict[str, object]) -> _R:
"""Create either a TestReport or CollectReport, depending on the calling class.
It is the callers responsibility to know which class to pass here.
This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
kwargs = _report_kwargs_from_json(reportdict)
return cls(**kwargs)
def _report_unserialization_failure(
type_name: str, report_class: Type[BaseReport], reportdict
) -> NoReturn:
url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO()
pprint("-" * 100, stream=stream)
pprint("INTERNALERROR: Unknown entry type returned: %s" % type_name, stream=stream)
pprint("report_name: %s" % report_class, stream=stream)
pprint(reportdict, stream=stream)
pprint("Please report this bug at %s" % url, stream=stream)
pprint("-" * 100, stream=stream)
raise RuntimeError(stream.getvalue())
@final
class TestReport(BaseReport):
"""Basic test report object (also used for setup and teardown calls if
they fail).
Reports can contain arbitrary extra attributes.
"""
__test__ = False
def __init__(
self,
nodeid: str,
location: Tuple[str, Optional[int], str],
keywords: Mapping[str, Any],
outcome: "Literal['passed', 'failed', 'skipped']",
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
],
when: "Literal['setup', 'call', 'teardown']",
sections: Iterable[Tuple[str, str]] = (),
duration: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None,
**extra,
) -> None:
#: Normalized collection nodeid.
self.nodeid = nodeid
#: A (filesystempath, lineno, domaininfo) tuple indicating the
#: actual location of a test item - it might be different from the
#: collected one e.g. if a method is inherited from a different module.
self.location: Tuple[str, Optional[int], str] = location
#: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation.
self.keywords: Mapping[str, Any] = keywords
#: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
#: None or a failure representation.
self.longrepr = longrepr
#: One of 'setup', 'call', 'teardown' to indicate runtest phase.
self.when = when
#: User properties is a list of tuples (name, value) that holds user
#: defined properties of the test.
self.user_properties = list(user_properties or [])
#: Tuples of str ``(heading, content)`` with extra information
#: for the test report. Used by pytest to add text captured
#: from ``stdout``, ``stderr``, and intercepted logging events. May
#: be used by other plugins to add arbitrary information to reports.
self.sections = list(sections)
#: Time it took to run just the test.
self.duration: float = duration
self.__dict__.update(extra)
def __repr__(self) -> str:
return "<{} {!r} when={!r} outcome={!r}>".format(
self.__class__.__name__, self.nodeid, self.when, self.outcome
)
@classmethod
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
"""Create and fill a TestReport with standard item and call info.
:param item: The item.
:param call: The call info.
"""
when = call.when
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
duration = call.duration
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
if not call.excinfo:
outcome: Literal["passed", "failed", "skipped"] = "passed"
longrepr: Union[
None,
ExceptionInfo[BaseException],
Tuple[str, int, str],
str,
TerminalRepr,
] = None
else:
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
longrepr = excinfo
elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped"
r = excinfo._getreprcrash()
if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2]
assert line is not None
longrepr = os.fspath(path), line + 1, r.message
else:
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
if call.when == "call":
longrepr = item.repr_failure(excinfo)
else: # exception in setup or teardown
longrepr = item._repr_failure_py(
excinfo, style=item.config.getoption("tbstyle", "auto")
)
for rwhen, key, content in item._report_sections:
sections.append((f"Captured {key} {rwhen}", content))
return cls(
item.nodeid,
item.location,
keywords,
outcome,
longrepr,
when,
sections,
duration,
user_properties=item.user_properties,
)
@final
class CollectReport(BaseReport):
"""Collection report object.
Reports can contain arbitrary extra attributes.
"""
when = "collect"
def __init__(
self,
nodeid: str,
outcome: "Literal['passed', 'failed', 'skipped']",
longrepr: Union[
None, ExceptionInfo[BaseException], Tuple[str, int, str], str, TerminalRepr
],
result: Optional[List[Union[Item, Collector]]],
sections: Iterable[Tuple[str, str]] = (),
**extra,
) -> None:
#: Normalized collection nodeid.
self.nodeid = nodeid
#: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
#: None or a failure representation.
self.longrepr = longrepr
#: The collected items and collection nodes.
self.result = result or []
#: Tuples of str ``(heading, content)`` with extra information
#: for the test report. Used by pytest to add text captured
#: from ``stdout``, ``stderr``, and intercepted logging events. May
#: be used by other plugins to add arbitrary information to reports.
self.sections = list(sections)
self.__dict__.update(extra)
@property
def location(self):
return (self.fspath, None, self.fspath)
def __repr__(self) -> str:
return "<CollectReport {!r} lenresult={} outcome={!r}>".format(
self.nodeid, len(self.result), self.outcome
)
class CollectErrorRepr(TerminalRepr):
def __init__(self, msg: str) -> None:
self.longrepr = msg
def toterminal(self, out: TerminalWriter) -> None:
out.line(self.longrepr, red=True)
def pytest_report_to_serializable(
report: Union[CollectReport, TestReport]
) -> Optional[Dict[str, Any]]:
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
return data
# TODO: Check if this is actually reachable.
return None # type: ignore[unreachable]
def pytest_report_from_serializable(
data: Dict[str, Any],
) -> Optional[Union[CollectReport, TestReport]]:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
return TestReport._from_json(data)
elif data["$report_type"] == "CollectReport":
return CollectReport._from_json(data)
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
)
return None
def _report_to_json(report: BaseReport) -> Dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
This was originally the serialize_report() function from xdist (ca03269).
"""
def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative]
) -> Dict[str, Any]:
data = attr.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
data[key] = attr.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]:
result = attr.asdict(reprtraceback)
result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries
]
return result
def serialize_repr_crash(
reprcrash: Optional[ReprFileLocation],
) -> Optional[Dict[str, Any]]:
if reprcrash is not None:
return attr.asdict(reprcrash)
else:
return None
def serialize_exception_longrepr(rep: BaseReport) -> Dict[str, Any]:
assert rep.longrepr is not None
# TODO: Investigate whether the duck typing is really necessary here.
longrepr = cast(ExceptionRepr, rep.longrepr)
result: Dict[str, Any] = {
"reprcrash": serialize_repr_crash(longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
"sections": longrepr.sections,
}
if isinstance(longrepr, ExceptionChainRepr):
result["chain"] = []
for repr_traceback, repr_crash, description in longrepr.chain:
result["chain"].append(
(
serialize_repr_traceback(repr_traceback),
serialize_repr_crash(repr_crash),
description,
)
)
else:
result["chain"] = None
return result
d = report.__dict__.copy()
if hasattr(report.longrepr, "toterminal"):
if hasattr(report.longrepr, "reprtraceback") and hasattr(
report.longrepr, "reprcrash"
):
d["longrepr"] = serialize_exception_longrepr(report)
else:
d["longrepr"] = str(report.longrepr)
else:
d["longrepr"] = report.longrepr
for name in d:
if isinstance(d[name], os.PathLike):
d[name] = os.fspath(d[name])
elif name == "result":
d[name] = None # for now
return d
def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
"""Return **kwargs that can be used to construct a TestReport or
CollectReport instance.
This was originally the serialize_report() function from xdist (ca03269).
"""
def deserialize_repr_entry(entry_data):
data = entry_data["data"]
entry_type = entry_data["type"]
if entry_type == "ReprEntry":
reprfuncargs = None
reprfileloc = None
reprlocals = None
if data["reprfuncargs"]:
reprfuncargs = ReprFuncArgs(**data["reprfuncargs"])
if data["reprfileloc"]:
reprfileloc = ReprFileLocation(**data["reprfileloc"])
if data["reprlocals"]:
reprlocals = ReprLocals(data["reprlocals"]["lines"])
reprentry: Union[ReprEntry, ReprEntryNative] = ReprEntry(
lines=data["lines"],
reprfuncargs=reprfuncargs,
reprlocals=reprlocals,
reprfileloc=reprfileloc,
style=data["style"],
)
elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"])
else:
_report_unserialization_failure(entry_type, TestReport, reportdict)
return reprentry
def deserialize_repr_traceback(repr_traceback_dict):
repr_traceback_dict["reprentries"] = [
deserialize_repr_entry(x) for x in repr_traceback_dict["reprentries"]
]
return ReprTraceback(**repr_traceback_dict)
def deserialize_repr_crash(repr_crash_dict: Optional[Dict[str, Any]]):
if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict)
else:
return None
if (
reportdict["longrepr"]
and "reprcrash" in reportdict["longrepr"]
and "reprtraceback" in reportdict["longrepr"]
):
reprtraceback = deserialize_repr_traceback(
reportdict["longrepr"]["reprtraceback"]
)
reprcrash = deserialize_repr_crash(reportdict["longrepr"]["reprcrash"])
if reportdict["longrepr"]["chain"]:
chain = []
for repr_traceback_data, repr_crash_data, description in reportdict[
"longrepr"
]["chain"]:
chain.append(
(
deserialize_repr_traceback(repr_traceback_data),
deserialize_repr_crash(repr_crash_data),
description,
)
)
exception_info: Union[
ExceptionChainRepr, ReprExceptionInfo
] = ExceptionChainRepr(chain)
else:
exception_info = ReprExceptionInfo(reprtraceback, reprcrash)
for section in reportdict["longrepr"]["sections"]:
exception_info.addsection(*section)
reportdict["longrepr"] = exception_info
return reportdict

View File

@ -0,0 +1,542 @@
"""Basic collect and runtest protocol implementations."""
import bdb
import os
import sys
from typing import Callable
from typing import cast
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest.compat import final
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.outcomes import Exit
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
if TYPE_CHECKING:
from typing_extensions import Literal
from _pytest.main import Session
from _pytest.terminal import TerminalReporter
#
# pytest plugin hooks.
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "Reporting", after="general")
group.addoption(
"--durations",
action="store",
type=int,
default=None,
metavar="N",
help="Show N slowest setup/test durations (N=0 for all)",
)
group.addoption(
"--durations-min",
action="store",
type=float,
default=0.005,
metavar="N",
help="Minimal duration in seconds for inclusion in slowest list. "
"Default: 0.005.",
)
def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None:
durations = terminalreporter.config.option.durations
durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.getvalue("verbose")
if durations is None:
return
tr = terminalreporter
dlist = []
for replist in tr.stats.values():
for rep in replist:
if hasattr(rep, "duration"):
dlist.append(rep)
if not dlist:
return
dlist.sort(key=lambda x: x.duration, reverse=True) # type: ignore[no-any-return]
if not durations:
tr.write_sep("=", "slowest durations")
else:
tr.write_sep("=", "slowest %s durations" % durations)
dlist = dlist[:durations]
for i, rep in enumerate(dlist):
if verbose < 2 and rep.duration < durations_min:
tr.write_line("")
tr.write_line(
"(%s durations < %gs hidden. Use -vv to show these durations.)"
% (len(dlist) - i, durations_min)
)
break
tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}")
def pytest_sessionstart(session: "Session") -> None:
session._setupstate = SetupState()
def pytest_sessionfinish(session: "Session") -> None:
session._setupstate.teardown_exact(None)
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem)
ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
return True
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by
# pytest-rerunfailures.
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
reports.append(call_and_report(item, "call", log))
reports.append(call_and_report(item, "teardown", log, nextitem=nextitem))
# After all teardown hooks have been called
# want funcargs and request info to go away.
if hasrequest:
item._request = False # type: ignore[attr-defined]
item.funcargs = None # type: ignore[attr-defined]
return reports
def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer()
tw.line()
tw.write(" " * 8)
tw.write(item.nodeid)
used_fixtures = sorted(getattr(item, "fixturenames", []))
if used_fixtures:
tw.write(" (fixtures used: {})".format(", ".join(used_fixtures)))
tw.flush()
def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup")
item.session._setupstate.setup(item)
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
item.runtest()
except Exception as e:
# Store trace info to allow postmortem debugging
sys.last_type = type(e)
sys.last_value = e
assert e.__traceback__ is not None
# Skip *this* frame
sys.last_traceback = e.__traceback__.tb_next
raise e
def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None:
_update_current_test_var(item, "teardown")
item.session._setupstate.teardown_exact(nextitem)
_update_current_test_var(item, None)
def _update_current_test_var(
item: Item, when: Optional["Literal['setup', 'call', 'teardown']"]
) -> None:
"""Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment.
"""
var_name = "PYTEST_CURRENT_TEST"
if when:
value = f"{item.nodeid} ({when})"
# don't allow null bytes on environment variables (see #2644, #2957)
value = value.replace("\x00", "(null)")
os.environ[var_name] = value
else:
os.environ.pop(var_name)
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if report.when in ("setup", "teardown"):
if report.failed:
# category, shortletter, verbose-word
return "error", "E", "ERROR"
elif report.skipped:
return "skipped", "s", "SKIPPED"
else:
return "", "", ""
return None
#
# Implementation
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
call = call_runtest_hook(item, when, **kwds)
hook = item.ihook
report: TestReport = hook.pytest_runtest_makereport(item=item, call=call)
if log:
hook.pytest_runtest_logreport(report=report)
if check_interactive_exception(call, report):
hook.pytest_exception_interact(node=item, call=call, report=report)
return report
def check_interactive_exception(call: "CallInfo[object]", report: BaseReport) -> bool:
"""Check whether the call raised an exception that should be reported as
interactive."""
if call.excinfo is None:
# Didn't raise.
return False
if hasattr(report, "wasxfail"):
# Exception was expected.
return False
if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)):
# Special control flow exception.
return False
return True
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
TResult = TypeVar("TResult", covariant=True)
@final
@attr.s(repr=False, init=False, auto_attribs=True)
class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation."""
_result: Optional[TResult]
#: The captured exception of the call, if it raised.
excinfo: Optional[ExceptionInfo[BaseException]]
#: The system time when the call started, in seconds since the epoch.
start: float
#: The system time when the call ended, in seconds since the epoch.
stop: float
#: The call duration, in seconds.
duration: float
#: The context of invocation: "collect", "setup", "call" or "teardown".
when: "Literal['collect', 'setup', 'call', 'teardown']"
def __init__(
self,
result: Optional[TResult],
excinfo: Optional[ExceptionInfo[BaseException]],
start: float,
stop: float,
duration: float,
when: "Literal['collect', 'setup', 'call', 'teardown']",
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._result = result
self.excinfo = excinfo
self.start = start
self.stop = stop
self.duration = duration
self.when = when
@property
def result(self) -> TResult:
"""The return value of the call, if it didn't raise.
Can only be accessed if excinfo is None.
"""
if self.excinfo is not None:
raise AttributeError(f"{self!r} has no valid result")
# The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be
# None, that's why a cast and not an assert).
return cast(TResult, self._result)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
"""Call func, wrapping the result in a CallInfo.
:param func:
The function to call. Called without arguments.
:param when:
The phase in which the function is called.
:param reraise:
Exception or exceptions that shall propagate if raised by the
function, instead of being wrapped in the CallInfo.
"""
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
result: Optional[TResult] = func()
except BaseException:
excinfo = ExceptionInfo.from_current()
if reraise is not None and isinstance(excinfo.value, reraise):
raise
result = None
# use the perf counter
precise_stop = timing.perf_counter()
duration = precise_stop - precise_start
stop = timing.time()
return cls(
start=start,
stop=stop,
duration=duration,
when=when,
result=result,
excinfo=excinfo,
_ispytest=True,
)
def __repr__(self) -> str:
if self.excinfo is None:
return f"<CallInfo when={self.when!r} result: {self._result!r}>"
return f"<CallInfo when={self.when!r} excinfo={self.excinfo!r}>"
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
return TestReport.from_item_and_call(item, call)
def pytest_make_collect_report(collector: Collector) -> CollectReport:
call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
longrepr: Union[None, Tuple[str, int, str], str, TerminalRepr] = None
if not call.excinfo:
outcome: Literal["passed", "skipped", "failed"] = "passed"
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
if unittest is not None:
# Type ignored because unittest is loaded dynamically.
skip_exceptions.append(unittest.SkipTest) # type: ignore
if isinstance(call.excinfo.value, tuple(skip_exceptions)):
outcome = "skipped"
r_ = collector._repr_failure_py(call.excinfo, "line")
assert isinstance(r_, ExceptionChainRepr), repr(r_)
r = r_.reprcrash
assert r
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
errorinfo = collector.repr_failure(call.excinfo)
if not hasattr(errorinfo, "toterminal"):
assert isinstance(errorinfo, str)
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
result = call.result if not call.excinfo else None
rep = CollectReport(collector.nodeid, outcome, longrepr, result)
rep.call = call # type: ignore # see collect_one_node
return rep
class SetupState:
"""Shared state for setting up/tearing down test items or collectors
in a session.
Suppose we have a collection tree as follows:
<Session session>
<Module mod1>
<Function item1>
<Module mod2>
<Function item2>
The SetupState maintains a stack. The stack starts out empty:
[]
During the setup phase of item1, setup(item1) is called. What it does
is:
push session to stack, run session.setup()
push mod1 to stack, run mod1.setup()
push item1 to stack, run item1.setup()
The stack is:
[session, mod1, item1]
While the stack is in this shape, it is allowed to add finalizers to
each of session, mod1, item1 using addfinalizer().
During the teardown phase of item1, teardown_exact(item2) is called,
where item2 is the next item to item1. What it does is:
pop item1 from stack, run its teardowns
pop mod1 from stack, run its teardowns
mod1 was popped because it ended its purpose with item1. The stack is:
[session]
During the setup phase of item2, setup(item2) is called. What it does
is:
push mod2 to stack, run mod2.setup()
push item2 to stack, run item2.setup()
Stack:
[session, mod2, item2]
During the teardown phase of item2, teardown_exact(None) is called,
because item2 is the last item. What it does is:
pop item2 from stack, run its teardowns
pop mod2 from stack, run its teardowns
pop session from stack, run its teardowns
Stack:
[]
The end!
"""
def __init__(self) -> None:
# The stack is in the dict insertion order.
self.stack: Dict[
Node,
Tuple[
# Node's finalizers.
List[Callable[[], object]],
# Node's exception, if its setup raised.
Optional[Union[OutcomeException, Exception]],
],
] = {}
def setup(self, item: Item) -> None:
"""Setup objects along the collector chain to the item."""
needed_collectors = item.listchain()
# If a collector fails its setup, fail its entire subtree of items.
# The setup is not retried for each item - the same exception is used.
for col, (finalizers, exc) in self.stack.items():
assert col in needed_collectors, "previous item was not torn down properly"
if exc:
raise exc
for col in needed_collectors[len(self.stack) :]:
assert col not in self.stack
# Push onto the stack.
self.stack[col] = ([col.teardown], None)
try:
col.setup()
except TEST_OUTCOME as exc:
self.stack[col] = (self.stack[col][0], exc)
raise exc
def addfinalizer(self, finalizer: Callable[[], object], node: Node) -> None:
"""Attach a finalizer to the given node.
The node must be currently active in the stack.
"""
assert node and not isinstance(node, tuple)
assert callable(finalizer)
assert node in self.stack, (node, self.stack)
self.stack[node][0].append(finalizer)
def teardown_exact(self, nextitem: Optional[Item]) -> None:
"""Teardown the current stack up until reaching nodes that nextitem
also descends from.
When nextitem is None (meaning we're at the last item), the entire
stack is torn down.
"""
needed_collectors = nextitem and nextitem.listchain() or []
exc = None
while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break
node, (finalizers, _) = self.stack.popitem()
while finalizers:
fin = finalizers.pop()
try:
fin()
except TEST_OUTCOME as e:
# XXX Only first exception will be seen by user,
# ideally all should be reported.
if exc is None:
exc = e
if exc:
raise exc
if nextitem is None:
assert not self.stack
def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
rep: CollectReport = ihook.pytest_make_collect_report(collector=collector)
call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)
return rep

View File

@ -0,0 +1,91 @@
"""
Scope definition and related utilities.
Those are defined here, instead of in the 'fixtures' module because
their use is spread across many other pytest modules, and centralizing it in 'fixtures'
would cause circular references.
Also this makes the module light to import, as it should.
"""
from enum import Enum
from functools import total_ordering
from typing import Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing_extensions import Literal
_ScopeName = Literal["session", "package", "module", "class", "function"]
@total_ordering
class Scope(Enum):
"""
Represents one of the possible fixture scopes in pytest.
Scopes are ordered from lower to higher, that is:
->>> higher ->>>
Function < Class < Module < Package < Session
<<<- lower <<<-
"""
# Scopes need to be listed from lower to higher.
Function: "_ScopeName" = "function"
Class: "_ScopeName" = "class"
Module: "_ScopeName" = "module"
Package: "_ScopeName" = "package"
Session: "_ScopeName" = "session"
def next_lower(self) -> "Scope":
"""Return the next lower scope."""
index = _SCOPE_INDICES[self]
if index == 0:
raise ValueError(f"{self} is the lower-most scope")
return _ALL_SCOPES[index - 1]
def next_higher(self) -> "Scope":
"""Return the next higher scope."""
index = _SCOPE_INDICES[self]
if index == len(_SCOPE_INDICES) - 1:
raise ValueError(f"{self} is the upper-most scope")
return _ALL_SCOPES[index + 1]
def __lt__(self, other: "Scope") -> bool:
self_index = _SCOPE_INDICES[self]
other_index = _SCOPE_INDICES[other]
return self_index < other_index
@classmethod
def from_user(
cls, scope_name: "_ScopeName", descr: str, where: Optional[str] = None
) -> "Scope":
"""
Given a scope name from the user, return the equivalent Scope enum. Should be used
whenever we want to convert a user provided scope name to its enum object.
If the scope name is invalid, construct a user friendly message and call pytest.fail.
"""
from _pytest.outcomes import fail
try:
# Holding this reference is necessary for mypy at the moment.
scope = Scope(scope_name)
except ValueError:
fail(
"{} {}got an unexpected scope value '{}'".format(
descr, f"from {where} " if where else "", scope_name
),
pytrace=False,
)
return scope
_ALL_SCOPES = list(Scope)
_SCOPE_INDICES = {scope: index for index, scope in enumerate(_ALL_SCOPES)}
# Ordered list of scopes which can contain many tests (in practice all except Function).
HIGH_SCOPES = [x for x in Scope if x is not Scope.Function]

View File

@ -0,0 +1,97 @@
from typing import Generator
from typing import Optional
from typing import Union
import pytest
from _pytest._io.saferepr import saferepr
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
from _pytest.scope import Scope
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setuponly",
"--setup-only",
action="store_true",
help="Only setup fixtures, do not execute tests",
)
group.addoption(
"--setupshow",
"--setup-show",
action="store_true",
help="Show setup of fixtures while executing tests",
)
@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
) -> Generator[None, None, None]:
yield
if request.config.option.setupshow:
if hasattr(request, "param"):
# Save the fixture parameter so ._show_fixture_action() can
# display it now and during the teardown (in .finish()).
if fixturedef.ids:
if callable(fixturedef.ids):
param = fixturedef.ids(request.param)
else:
param = fixturedef.ids[request.param_index]
else:
param = request.param
fixturedef.cached_param = param # type: ignore[attr-defined]
_show_fixture_action(fixturedef, "SETUP")
def pytest_fixture_post_finalizer(fixturedef: FixtureDef[object]) -> None:
if fixturedef.cached_result is not None:
config = fixturedef._fixturemanager.config
if config.option.setupshow:
_show_fixture_action(fixturedef, "TEARDOWN")
if hasattr(fixturedef, "cached_param"):
del fixturedef.cached_param # type: ignore[attr-defined]
def _show_fixture_action(fixturedef: FixtureDef[object], msg: str) -> None:
config = fixturedef._fixturemanager.config
capman = config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture()
tw = config.get_terminal_writer()
tw.line()
# Use smaller indentation the higher the scope: Session = 0, Package = 1, etc.
scope_indent = list(reversed(Scope)).index(fixturedef._scope)
tw.write(" " * 2 * scope_indent)
tw.write(
"{step} {scope} {fixture}".format(
step=msg.ljust(8), # align the output to TEARDOWN
scope=fixturedef.scope[0].upper(),
fixture=fixturedef.argname,
)
)
if msg == "SETUP":
deps = sorted(arg for arg in fixturedef.argnames if arg != "request")
if deps:
tw.write(" (fixtures used: {})".format(", ".join(deps)))
if hasattr(fixturedef, "cached_param"):
tw.write(f"[{saferepr(fixturedef.cached_param, maxsize=42)}]") # type: ignore[attr-defined]
tw.flush()
if capman:
capman.resume_global_capture()
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setuponly:
config.option.setupshow = True
return None

View File

@ -0,0 +1,40 @@
from typing import Optional
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setupplan",
"--setup-plan",
action="store_true",
help="Show what fixtures and tests would be executed but "
"don't execute anything",
)
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
) -> Optional[object]:
# Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request)
fixturedef.cached_result = (None, my_cache_key, None)
return fixturedef.cached_result
return None
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setupplan:
config.option.setuponly = True
config.option.setupshow = True
return None

View File

@ -0,0 +1,296 @@
"""Support for skip/xfail functions and markers."""
import os
import platform
import sys
import traceback
from collections.abc import Mapping
from typing import Generator
from typing import Optional
from typing import Tuple
from typing import Type
import attr
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.mark.structures import Mark
from _pytest.nodes import Item
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.reports import BaseReport
from _pytest.runner import CallInfo
from _pytest.stash import StashKey
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--runxfail",
action="store_true",
dest="runxfail",
default=False,
help="Report the results of xfail tests as if they were not marked",
)
parser.addini(
"xfail_strict",
"Default for the strict parameter of xfail "
"markers when not given explicitly (default: False)",
default=False,
type="bool",
)
def pytest_configure(config: Config) -> None:
if config.option.runxfail:
# yay a hack
import pytest
old = pytest.xfail
config.add_cleanup(lambda: setattr(pytest, "xfail", old))
def nop(*args, **kwargs):
pass
nop.Exception = xfail.Exception # type: ignore[attr-defined]
setattr(pytest, "xfail", nop)
config.addinivalue_line(
"markers",
"skip(reason=None): skip the given test function with an optional reason. "
'Example: skip(reason="no way of currently testing this") skips the '
"test.",
)
config.addinivalue_line(
"markers",
"skipif(condition, ..., *, reason=...): "
"skip the given test function if any of the conditions evaluate to True. "
"Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. "
"See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif",
)
config.addinivalue_line(
"markers",
"xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): "
"mark the test function as an expected failure if any of the conditions "
"evaluate to True. Optionally specify a reason for better reporting "
"and run=False if you don't even want to execute the test function. "
"If only specific exception(s) are expected, you can list them in "
"raises, and if the test fails in other ways, it will be reported as "
"a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail",
)
def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool, str]:
"""Evaluate a single skipif/xfail condition.
If an old-style string condition is given, it is eval()'d, otherwise the
condition is bool()'d. If this fails, an appropriately formatted pytest.fail
is raised.
Returns (result, reason). The reason is only relevant if the result is True.
"""
# String condition.
if isinstance(condition, str):
globals_ = {
"os": os,
"sys": sys,
"platform": platform,
"config": item.config,
}
for dictionary in reversed(
item.ihook.pytest_markeval_namespace(config=item.config)
):
if not isinstance(dictionary, Mapping):
raise ValueError(
"pytest_markeval_namespace() needs to return a dict, got {!r}".format(
dictionary
)
)
globals_.update(dictionary)
if hasattr(item, "obj"):
globals_.update(item.obj.__globals__) # type: ignore[attr-defined]
try:
filename = f"<{mark.name} condition>"
condition_code = compile(condition, filename, "eval")
result = eval(condition_code, globals_)
except SyntaxError as exc:
msglines = [
"Error evaluating %r condition" % mark.name,
" " + condition,
" " + " " * (exc.offset or 0) + "^",
"SyntaxError: invalid syntax",
]
fail("\n".join(msglines), pytrace=False)
except Exception as exc:
msglines = [
"Error evaluating %r condition" % mark.name,
" " + condition,
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
# Boolean condition.
else:
try:
result = bool(condition)
except Exception as exc:
msglines = [
"Error evaluating %r condition as a boolean" % mark.name,
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
reason = mark.kwargs.get("reason", None)
if reason is None:
if isinstance(condition, str):
reason = "condition: " + condition
else:
# XXX better be checked at collection time
msg = (
"Error evaluating %r: " % mark.name
+ "you need to specify reason=STRING when using booleans as conditions."
)
fail(msg, pytrace=False)
return result, reason
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Skip:
"""The result of evaluate_skip_marks()."""
reason: str = "unconditional skip"
def evaluate_skip_marks(item: Item) -> Optional[Skip]:
"""Evaluate skip and skipif marks on item, returning Skip if triggered."""
for mark in item.iter_markers(name="skipif"):
if "condition" not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
return Skip(reason)
# If any of the conditions are true.
for condition in conditions:
result, reason = evaluate_condition(item, mark, condition)
if result:
return Skip(reason)
for mark in item.iter_markers(name="skip"):
try:
return Skip(*mark.args, **mark.kwargs)
except TypeError as e:
raise TypeError(str(e) + " - maybe you meant pytest.mark.skipif?") from None
return None
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Xfail:
"""The result of evaluate_xfail_marks()."""
reason: str
run: bool
strict: bool
raises: Optional[Tuple[Type[BaseException], ...]]
def evaluate_xfail_marks(item: Item) -> Optional[Xfail]:
"""Evaluate xfail marks on item, returning Xfail if triggered."""
for mark in item.iter_markers(name="xfail"):
run = mark.kwargs.get("run", True)
strict = mark.kwargs.get("strict", item.config.getini("xfail_strict"))
raises = mark.kwargs.get("raises", None)
if "condition" not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
return Xfail(reason, run, strict, raises)
# If any of the conditions are true.
for condition in conditions:
result, reason = evaluate_condition(item, mark, condition)
if result:
return Xfail(reason, run, strict, raises)
return None
# Saves the xfail mark evaluation. Can be refreshed during call if None.
xfailed_key = StashKey[Optional[Xfail]]()
@hookimpl(tryfirst=True)
def pytest_runtest_setup(item: Item) -> None:
skipped = evaluate_skip_marks(item)
if skipped:
raise skip.Exception(skipped.reason, _use_item_location=True)
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
@hookimpl(hookwrapper=True)
def pytest_runtest_call(item: Item) -> Generator[None, None, None]:
xfailed = item.stash.get(xfailed_key, None)
if xfailed is None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
yield
# The test run may have added an xfail mark dynamically.
xfailed = item.stash.get(xfailed_key, None)
if xfailed is None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
@hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item: Item, call: CallInfo[None]):
outcome = yield
rep = outcome.get_result()
xfailed = item.stash.get(xfailed_key, None)
if item.config.option.runxfail:
pass # don't interfere
elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
assert call.excinfo.value.msg is not None
rep.wasxfail = "reason: " + call.excinfo.value.msg
rep.outcome = "skipped"
elif not rep.skipped and xfailed:
if call.excinfo:
raises = xfailed.raises
if raises is not None and not isinstance(call.excinfo.value, raises):
rep.outcome = "failed"
else:
rep.outcome = "skipped"
rep.wasxfail = xfailed.reason
elif call.when == "call":
if xfailed.strict:
rep.outcome = "failed"
rep.longrepr = "[XPASS(strict)] " + xfailed.reason
else:
rep.outcome = "passed"
rep.wasxfail = xfailed.reason
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if hasattr(report, "wasxfail"):
if report.skipped:
return "xfailed", "x", "XFAIL"
elif report.passed:
return "xpassed", "X", "XPASS"
return None

View File

@ -0,0 +1,112 @@
from typing import Any
from typing import cast
from typing import Dict
from typing import Generic
from typing import TypeVar
from typing import Union
__all__ = ["Stash", "StashKey"]
T = TypeVar("T")
D = TypeVar("D")
class StashKey(Generic[T]):
"""``StashKey`` is an object used as a key to a :class:`Stash`.
A ``StashKey`` is associated with the type ``T`` of the value of the key.
A ``StashKey`` is unique and cannot conflict with another key.
"""
__slots__ = ()
class Stash:
r"""``Stash`` is a type-safe heterogeneous mutable mapping that
allows keys and value types to be defined separately from
where it (the ``Stash``) is created.
Usually you will be given an object which has a ``Stash``, for example
:class:`~pytest.Config` or a :class:`~_pytest.nodes.Node`:
.. code-block:: python
stash: Stash = some_object.stash
If a module or plugin wants to store data in this ``Stash``, it creates
:class:`StashKey`\s for its keys (at the module level):
.. code-block:: python
# At the top-level of the module
some_str_key = StashKey[str]()
some_bool_key = StashKey[bool]()
To store information:
.. code-block:: python
# Value type must match the key.
stash[some_str_key] = "value"
stash[some_bool_key] = True
To retrieve the information:
.. code-block:: python
# The static type of some_str is str.
some_str = stash[some_str_key]
# The static type of some_bool is bool.
some_bool = stash[some_bool_key]
"""
__slots__ = ("_storage",)
def __init__(self) -> None:
self._storage: Dict[StashKey[Any], object] = {}
def __setitem__(self, key: StashKey[T], value: T) -> None:
"""Set a value for key."""
self._storage[key] = value
def __getitem__(self, key: StashKey[T]) -> T:
"""Get the value for key.
Raises ``KeyError`` if the key wasn't set before.
"""
return cast(T, self._storage[key])
def get(self, key: StashKey[T], default: D) -> Union[T, D]:
"""Get the value for key, or return default if the key wasn't set
before."""
try:
return self[key]
except KeyError:
return default
def setdefault(self, key: StashKey[T], default: T) -> T:
"""Return the value of key if already set, otherwise set the value
of key to default and return default."""
try:
return self[key]
except KeyError:
self[key] = default
return default
def __delitem__(self, key: StashKey[T]) -> None:
"""Delete the value for key.
Raises ``KeyError`` if the key wasn't set before.
"""
del self._storage[key]
def __contains__(self, key: StashKey[T]) -> bool:
"""Return whether key was set."""
return key in self._storage
def __len__(self) -> int:
"""Return how many items exist in the stash."""
return len(self._storage)

View File

@ -0,0 +1,122 @@
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from _pytest import nodes
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.reports import TestReport
if TYPE_CHECKING:
from _pytest.cacheprovider import Cache
STEPWISE_CACHE_DIR = "cache/stepwise"
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--sw",
"--stepwise",
action="store_true",
default=False,
dest="stepwise",
help="Exit on test failure and continue from last failing test next time",
)
group.addoption(
"--sw-skip",
"--stepwise-skip",
action="store_true",
default=False,
dest="stepwise_skip",
help="Ignore the first failing test but stop on the next failing test. "
"Implicitly enables --stepwise.",
)
@pytest.hookimpl
def pytest_configure(config: Config) -> None:
if config.option.stepwise_skip:
# allow --stepwise-skip to work on it's own merits.
config.option.stepwise = True
if config.getoption("stepwise"):
config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
def pytest_sessionfinish(session: Session) -> None:
if not session.config.getoption("stepwise"):
assert session.config.cache is not None
# Clear the list of failing tests if the plugin is not active.
session.config.cache.set(STEPWISE_CACHE_DIR, [])
class StepwisePlugin:
def __init__(self, config: Config) -> None:
self.config = config
self.session: Optional[Session] = None
self.report_status = ""
assert config.cache is not None
self.cache: Cache = config.cache
self.lastfailed: Optional[str] = self.cache.get(STEPWISE_CACHE_DIR, None)
self.skip: bool = config.getoption("stepwise_skip")
def pytest_sessionstart(self, session: Session) -> None:
self.session = session
def pytest_collection_modifyitems(
self, config: Config, items: List[nodes.Item]
) -> None:
if not self.lastfailed:
self.report_status = "no previously failed tests, not skipping."
return
# check all item nodes until we find a match on last failed
failed_index = None
for index, item in enumerate(items):
if item.nodeid == self.lastfailed:
failed_index = index
break
# If the previously failed test was not found among the test items,
# do not skip any tests.
if failed_index is None:
self.report_status = "previously failed test not found, not skipping."
else:
self.report_status = f"skipping {failed_index} already passed items."
deselected = items[:failed_index]
del items[:failed_index]
config.hook.pytest_deselected(items=deselected)
def pytest_runtest_logreport(self, report: TestReport) -> None:
if report.failed:
if self.skip:
# Remove test from the failed ones (if it exists) and unset the skip option
# to make sure the following tests will not be skipped.
if report.nodeid == self.lastfailed:
self.lastfailed = None
self.skip = False
else:
# Mark test as the last failing and interrupt the test session.
self.lastfailed = report.nodeid
assert self.session is not None
self.session.shouldstop = (
"Test failed, continuing from this test next run."
)
else:
# If the test was actually run and did pass.
if report.when == "call":
# Remove test from the failed ones, if exists.
if report.nodeid == self.lastfailed:
self.lastfailed = None
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.config.getoption("verbose") >= 0 and self.report_status:
return f"stepwise: {self.report_status}"
return None
def pytest_sessionfinish(self) -> None:
self.cache.set(STEPWISE_CACHE_DIR, self.lastfailed)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,88 @@
import threading
import traceback
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Optional
from typing import Type
import pytest
# Copied from cpython/Lib/test/support/threading_helper.py, with modifications.
class catch_threading_exception:
"""Context manager catching threading.Thread exception using
threading.excepthook.
Storing exc_value using a custom hook can create a reference cycle. The
reference cycle is broken explicitly when the context manager exits.
Storing thread using a custom hook can resurrect it if it is set to an
object which is being finalized. Exiting the context manager clears the
stored object.
Usage:
with threading_helper.catch_threading_exception() as cm:
# code spawning a thread which raises an exception
...
# check the thread exception: use cm.args
...
# cm.args attribute no longer exists at this point
# (to break a reference cycle)
"""
def __init__(self) -> None:
self.args: Optional["threading.ExceptHookArgs"] = None
self._old_hook: Optional[Callable[["threading.ExceptHookArgs"], Any]] = None
def _hook(self, args: "threading.ExceptHookArgs") -> None:
self.args = args
def __enter__(self) -> "catch_threading_exception":
self._old_hook = threading.excepthook
threading.excepthook = self._hook
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
assert self._old_hook is not None
threading.excepthook = self._old_hook
self._old_hook = None
del self.args
def thread_exception_runtest_hook() -> Generator[None, None, None]:
with catch_threading_exception() as cm:
yield
if cm.args:
thread_name = "<unknown>" if cm.args.thread is None else cm.args.thread.name
msg = f"Exception in thread {thread_name}\n\n"
msg += "".join(
traceback.format_exception(
cm.args.exc_type,
cm.args.exc_value,
cm.args.exc_traceback,
)
)
warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_setup() -> Generator[None, None, None]:
yield from thread_exception_runtest_hook()
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None, None, None]:
yield from thread_exception_runtest_hook()
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None, None, None]:
yield from thread_exception_runtest_hook()

View File

@ -0,0 +1,12 @@
"""Indirection for time functions.
We intentionally grab some "time" functions internally to avoid tests mocking "time" to affect
pytest runtime information (issue #185).
Fixture "mock_timing" also interacts with this module for pytest's own tests.
"""
from time import perf_counter
from time import sleep
from time import time
__all__ = ["perf_counter", "sleep", "time"]

View File

@ -0,0 +1,216 @@
"""Support for providing temporary directories to test functions."""
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import Optional
import attr
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from _pytest.compat import final
from _pytest.config import Config
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
@final
@attr.s(init=False)
class TempPathFactory:
"""Factory for temporary directories under the common base temp directory.
The base directory can be configured using the ``--basetemp`` option.
"""
_given_basetemp = attr.ib(type=Optional[Path])
_trace = attr.ib()
_basetemp = attr.ib(type=Optional[Path])
def __init__(
self,
given_basetemp: Optional[Path],
trace,
basetemp: Optional[Path] = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
if given_basetemp is None:
self._given_basetemp = None
else:
# Use os.path.abspath() to get absolute path instead of resolve() as it
# does not work the same in all platforms (see #4427).
# Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012).
self._given_basetemp = Path(os.path.abspath(str(given_basetemp)))
self._trace = trace
self._basetemp = basetemp
@classmethod
def from_config(
cls,
config: Config,
*,
_ispytest: bool = False,
) -> "TempPathFactory":
"""Create a factory according to pytest configuration.
:meta private:
"""
check_ispytest(_ispytest)
return cls(
given_basetemp=config.option.basetemp,
trace=config.trace.get("tmpdir"),
_ispytest=True,
)
def _ensure_relative_to_basetemp(self, basename: str) -> str:
basename = os.path.normpath(basename)
if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
raise ValueError(f"{basename} is not a normalized and relative path")
return basename
def mktemp(self, basename: str, numbered: bool = True) -> Path:
"""Create a new temporary directory managed by the factory.
:param basename:
Directory base name, must be a relative path.
:param numbered:
If ``True``, ensure the directory is unique by adding a numbered
suffix greater than any existing one: ``basename="foo-"`` and ``numbered=True``
means that this function will create directories named ``"foo-0"``,
``"foo-1"``, ``"foo-2"`` and so on.
:returns:
The path to the new directory.
"""
basename = self._ensure_relative_to_basetemp(basename)
if not numbered:
p = self.getbasetemp().joinpath(basename)
p.mkdir(mode=0o700)
else:
p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
self._trace("mktemp", p)
return p
def getbasetemp(self) -> Path:
"""Return the base temporary directory, creating it if needed.
:returns:
The base temporary directory.
"""
if self._basetemp is not None:
return self._basetemp
if self._given_basetemp is not None:
basetemp = self._given_basetemp
if basetemp.exists():
rm_rf(basetemp)
basetemp.mkdir(mode=0o700)
basetemp = basetemp.resolve()
else:
from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT")
temproot = Path(from_env or tempfile.gettempdir()).resolve()
user = get_user() or "unknown"
# use a sub-directory in the temproot to speed-up
# make_numbered_dir() call
rootdir = temproot.joinpath(f"pytest-of-{user}")
try:
rootdir.mkdir(mode=0o700, exist_ok=True)
except OSError:
# getuser() likely returned illegal characters for the platform, use unknown back off mechanism
rootdir = temproot.joinpath("pytest-of-unknown")
rootdir.mkdir(mode=0o700, exist_ok=True)
# Because we use exist_ok=True with a predictable name, make sure
# we are the owners, to prevent any funny business (on unix, where
# temproot is usually shared).
# Also, to keep things private, fixup any world-readable temp
# rootdir's permissions. Historically 0o755 was used, so we can't
# just error out on this, at least for a while.
if sys.platform != "win32":
uid = os.getuid()
rootdir_stat = rootdir.stat()
# getuid shouldn't fail, but cpython defines such a case.
# Let's hope for the best.
if uid != -1:
if rootdir_stat.st_uid != uid:
raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. "
"Fix this and try again."
)
if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
basetemp = make_numbered_dir_with_cleanup(
prefix="pytest-",
root=rootdir,
keep=3,
lock_timeout=LOCK_TIMEOUT,
mode=0o700,
)
assert basetemp is not None, basetemp
self._basetemp = basetemp
self._trace("new basetemp", basetemp)
return basetemp
def get_user() -> Optional[str]:
"""Return the current user name, or None if getuser() does not work
in the current environment (see #1010)."""
try:
# In some exotic environments, getpass may not be importable.
import getpass
return getpass.getuser()
except (ImportError, KeyError):
return None
def pytest_configure(config: Config) -> None:
"""Create a TempPathFactory and attach it to the config object.
This is to comply with existing plugins which expect the handler to be
available at pytest_configure time, but ideally should be moved entirely
to the tmp_path_factory session fixture.
"""
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
_tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True)
mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False)
@fixture(scope="session")
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
"""Return a :class:`pytest.TempPathFactory` instance for the test session."""
# Set dynamically by pytest_configure() above.
return request.config._tmp_path_factory # type: ignore
def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
name = request.node.name
name = re.sub(r"[\W]", "_", name)
MAXVAL = 30
name = name[:MAXVAL]
return factory.mktemp(name, numbered=True)
@fixture
def tmp_path(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> Path:
"""Return a temporary directory path object which is unique to each test
function invocation, created as a sub directory of the base temporary
directory.
By default, a new base temporary directory is created each test session,
and old bases are removed after 3 sessions, to aid in debugging. If
``--basetemp`` is used then it is cleared each session. See :ref:`base
temporary directory`.
The returned object is a :class:`pathlib.Path` object.
"""
return _mk_tmp(request, tmp_path_factory)

View File

@ -0,0 +1,417 @@
"""Discover and run std-library "unittest" style tests."""
import sys
import traceback
import types
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import _pytest._code
import pytest
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import exit
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Module
from _pytest.runner import CallInfo
from _pytest.scope import Scope
if TYPE_CHECKING:
import unittest
import twisted.trial.unittest
_SysExcInfoType = Union[
Tuple[Type[BaseException], BaseException, types.TracebackType],
Tuple[None, None, None],
]
def pytest_pycollect_makeitem(
collector: Union[Module, Class], name: str, obj: object
) -> Optional["UnitTestCase"]:
# Has unittest been imported and is obj a subclass of its TestCase?
try:
ut = sys.modules["unittest"]
# Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
except Exception:
return None
# Yes, so let's collect it.
item: UnitTestCase = UnitTestCase.from_parent(collector, name=name, obj=obj)
return item
class UnitTestCase(Class):
# Marker for fixturemanger.getfixtureinfo()
# to declare that our children do not support funcargs.
nofuncargs = True
def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader
cls = self.obj
if not getattr(cls, "__test__", True):
return
skipped = _is_skipped(cls)
if not skipped:
self._inject_setup_teardown_fixtures(cls)
self._inject_setup_class_fixture()
self.session._fixturemanager.parsefactories(self, unittest=True)
loader = TestLoader()
foundsomething = False
for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name)
if not getattr(x, "__test__", True):
continue
funcobj = getimfunc(x)
yield TestCaseFunction.from_parent(self, name=name, callobj=funcobj)
foundsomething = True
if not foundsomething:
runtest = getattr(self.obj, "runTest", None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
# Type ignored because `ut` is an opaque module.
if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest")
def _inject_setup_teardown_fixtures(self, cls: type) -> None:
"""Injects a hidden auto-use fixture to invoke setUpClass/setup_method and corresponding
teardown functions (#517)."""
class_fixture = _make_xunit_fixture(
cls,
"setUpClass",
"tearDownClass",
"doClassCleanups",
scope=Scope.Class,
pass_self=False,
)
if class_fixture:
cls.__pytest_class_setup = class_fixture # type: ignore[attr-defined]
method_fixture = _make_xunit_fixture(
cls,
"setup_method",
"teardown_method",
None,
scope=Scope.Function,
pass_self=True,
)
if method_fixture:
cls.__pytest_method_setup = method_fixture # type: ignore[attr-defined]
def _make_xunit_fixture(
obj: type,
setup_name: str,
teardown_name: str,
cleanup_name: Optional[str],
scope: Scope,
pass_self: bool,
):
setup = getattr(obj, setup_name, None)
teardown = getattr(obj, teardown_name, None)
if setup is None and teardown is None:
return None
if cleanup_name:
cleanup = getattr(obj, cleanup_name, lambda *args: None)
else:
def cleanup(*args):
pass
@pytest.fixture(
scope=scope.value,
autouse=True,
# Use a unique name to speed up lookup.
name=f"_unittest_{setup_name}_fixture_{obj.__qualname__}",
)
def fixture(self, request: FixtureRequest) -> Generator[None, None, None]:
if _is_skipped(self):
reason = self.__unittest_skip_why__
raise pytest.skip.Exception(reason, _use_item_location=True)
if setup is not None:
try:
if pass_self:
setup(self, request.function)
else:
setup()
# unittest does not call the cleanup function for every BaseException, so we
# follow this here.
except Exception:
if pass_self:
cleanup(self)
else:
cleanup()
raise
yield
try:
if teardown is not None:
if pass_self:
teardown(self, request.function)
else:
teardown()
finally:
if pass_self:
cleanup(self)
else:
cleanup()
return fixture
class TestCaseFunction(Function):
nofuncargs = True
_excinfo: Optional[List[_pytest._code.ExceptionInfo[BaseException]]] = None
_testcase: Optional["unittest.TestCase"] = None
def _getobj(self):
assert self.parent is not None
# Unlike a regular Function in a Class, where `item.obj` returns
# a *bound* method (attached to an instance), TestCaseFunction's
# `obj` returns an *unbound* method (not attached to an instance).
# This inconsistency is probably not desirable, but needs some
# consideration before changing.
return getattr(self.parent.obj, self.originalname) # type: ignore[attr-defined]
def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Optional[Callable[[], None]] = None
assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined]
self._obj = getattr(self._testcase, self.name)
if hasattr(self, "_request"):
self._request._fillfixtures()
def teardown(self) -> None:
if self._explicit_tearDown is not None:
self._explicit_tearDown()
self._explicit_tearDown = None
self._testcase = None
self._obj = None
def startTest(self, testcase: "unittest.TestCase") -> None:
pass
def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None:
# Unwrap potential exception info (see twisted trial support below).
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
try:
excinfo = _pytest._code.ExceptionInfo[BaseException].from_exc_info(rawexcinfo) # type: ignore[arg-type]
# Invoke the attributes to trigger storing the traceback
# trial causes some issue there.
excinfo.value
excinfo.traceback
except TypeError:
try:
try:
values = traceback.format_exception(*rawexcinfo)
values.insert(
0,
"NOTE: Incompatible Exception Representation, "
"displaying natively:\n\n",
)
fail("".join(values), pytrace=False)
except (fail.Exception, KeyboardInterrupt):
raise
except BaseException:
fail(
"ERROR: Unknown Incompatible Exception "
"representation:\n%r" % (rawexcinfo,),
pytrace=False,
)
except KeyboardInterrupt:
raise
except fail.Exception:
excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo)
def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
try:
if isinstance(rawexcinfo[1], exit.Exception):
exit(rawexcinfo[1].msg)
except TypeError:
pass
self._addexcinfo(rawexcinfo)
def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
self._addexcinfo(rawexcinfo)
def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None:
try:
raise pytest.skip.Exception(reason, _use_item_location=True)
except skip.Exception:
self._addexcinfo(sys.exc_info())
def addExpectedFailure(
self,
testcase: "unittest.TestCase",
rawexcinfo: "_SysExcInfoType",
reason: str = "",
) -> None:
try:
xfail(str(reason))
except xfail.Exception:
self._addexcinfo(sys.exc_info())
def addUnexpectedSuccess(
self,
testcase: "unittest.TestCase",
reason: Optional["twisted.trial.unittest.Todo"] = None,
) -> None:
msg = "Unexpected success"
if reason:
msg += f": {reason.reason}"
# Preserve unittest behaviour - fail the test. Explicitly not an XPASS.
try:
fail(msg, pytrace=False)
except fail.Exception:
self._addexcinfo(sys.exc_info())
def addSuccess(self, testcase: "unittest.TestCase") -> None:
pass
def stopTest(self, testcase: "unittest.TestCase") -> None:
pass
def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
assert self._testcase is not None
maybe_wrap_pytest_function_for_tracing(self)
# Let the unittest framework handle async functions.
if is_async_function(self.obj):
# Type ignored because self acts as the TestResult, but is not actually one.
self._testcase(result=self) # type: ignore[arg-type]
else:
# When --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up
# instance variables, which makes it difficult to debug.
# Arguably we could always postpone tearDown(), but this changes the moment where the
# TestCase instance interacts with the results object, so better to only do it
# when absolutely needed.
# We need to consider if the test itself is skipped, or the whole class.
assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped:
self._explicit_tearDown = self._testcase.tearDown
setattr(self._testcase, "tearDown", lambda *args: None)
# We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
setattr(self._testcase, self.name, self.obj)
try:
self._testcase(result=self) # type: ignore[arg-type]
finally:
delattr(self._testcase, self.name)
def _prunetraceback(
self, excinfo: _pytest._code.ExceptionInfo[BaseException]
) -> None:
super()._prunetraceback(excinfo)
traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest")
)
if traceback:
excinfo.traceback = traceback
@hookimpl(tryfirst=True)
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
if isinstance(item, TestCaseFunction):
if item._excinfo:
call.excinfo = item._excinfo.pop(0)
try:
del call.result
except AttributeError:
pass
# Convert unittest.SkipTest to pytest.skip.
# This is actually only needed for nose, which reuses unittest.SkipTest for
# its own nose.SkipTest. For unittest TestCases, SkipTest is already
# handled internally, and doesn't reach here.
unittest = sys.modules.get("unittest")
if (
unittest
and call.excinfo
and isinstance(call.excinfo.value, unittest.SkipTest) # type: ignore[attr-defined]
):
excinfo = call.excinfo
call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when
)
call.excinfo = call2.excinfo
# Twisted trial support.
@hookimpl(hookwrapper=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
if isinstance(item, TestCaseFunction) and "twisted.trial.unittest" in sys.modules:
ut: Any = sys.modules["twisted.python.failure"]
Failure__init__ = ut.Failure.__init__
check_testcase_implements_trial_reporter()
def excstore(
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None
):
if exc_value is None:
self._rawexcinfo = sys.exc_info()
else:
if exc_type is None:
exc_type = type(exc_value)
self._rawexcinfo = (exc_type, exc_value, exc_tb)
try:
Failure__init__(
self, exc_value, exc_type, exc_tb, captureVars=captureVars
)
except TypeError:
Failure__init__(self, exc_value, exc_type, exc_tb)
ut.Failure.__init__ = excstore
yield
ut.Failure.__init__ = Failure__init__
else:
yield
def check_testcase_implements_trial_reporter(done: List[int] = []) -> None:
if done:
return
from zope.interface import classImplements
from twisted.trial.itrial import IReporter
classImplements(TestCaseFunction, IReporter)
done.append(1)
def _is_skipped(obj) -> bool:
"""Return True if the given object has been marked with @unittest.skip."""
return bool(getattr(obj, "__unittest_skip__", False))

View File

@ -0,0 +1,93 @@
import sys
import traceback
import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Generator
from typing import Optional
from typing import Type
import pytest
# Copied from cpython/Lib/test/support/__init__.py, with modifications.
class catch_unraisable_exception:
"""Context manager catching unraisable exception using sys.unraisablehook.
Storing the exception value (cm.unraisable.exc_value) creates a reference
cycle. The reference cycle is broken explicitly when the context manager
exits.
Storing the object (cm.unraisable.object) can resurrect it if it is set to
an object which is being finalized. Exiting the context manager clears the
stored object.
Usage:
with catch_unraisable_exception() as cm:
# code creating an "unraisable exception"
...
# check the unraisable exception: use cm.unraisable
...
# cm.unraisable attribute no longer exists at this point
# (to break a reference cycle)
"""
def __init__(self) -> None:
self.unraisable: Optional["sys.UnraisableHookArgs"] = None
self._old_hook: Optional[Callable[["sys.UnraisableHookArgs"], Any]] = None
def _hook(self, unraisable: "sys.UnraisableHookArgs") -> None:
# Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle.
self.unraisable = unraisable
def __enter__(self) -> "catch_unraisable_exception":
self._old_hook = sys.unraisablehook
sys.unraisablehook = self._hook
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
assert self._old_hook is not None
sys.unraisablehook = self._old_hook
self._old_hook = None
del self.unraisable
def unraisable_exception_runtest_hook() -> Generator[None, None, None]:
with catch_unraisable_exception() as cm:
yield
if cm.unraisable:
if cm.unraisable.err_msg is not None:
err_msg = cm.unraisable.err_msg
else:
err_msg = "Exception ignored in"
msg = f"{err_msg}: {cm.unraisable.object!r}\n\n"
msg += "".join(
traceback.format_exception(
cm.unraisable.exc_type,
cm.unraisable.exc_value,
cm.unraisable.exc_traceback,
)
)
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_setup() -> Generator[None, None, None]:
yield from unraisable_exception_runtest_hook()
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None, None, None]:
yield from unraisable_exception_runtest_hook()
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None, None, None]:
yield from unraisable_exception_runtest_hook()

View File

@ -0,0 +1,171 @@
import inspect
import warnings
from types import FunctionType
from typing import Any
from typing import Generic
from typing import Type
from typing import TypeVar
import attr
from _pytest.compat import final
class PytestWarning(UserWarning):
"""Base class for all warnings emitted by pytest."""
__module__ = "pytest"
@final
class PytestAssertRewriteWarning(PytestWarning):
"""Warning emitted by the pytest assert rewrite module."""
__module__ = "pytest"
@final
class PytestCacheWarning(PytestWarning):
"""Warning emitted by the cache plugin in various situations."""
__module__ = "pytest"
@final
class PytestConfigWarning(PytestWarning):
"""Warning emitted for configuration issues."""
__module__ = "pytest"
@final
class PytestCollectionWarning(PytestWarning):
"""Warning emitted when pytest is not able to collect a file or symbol in a module."""
__module__ = "pytest"
class PytestDeprecationWarning(PytestWarning, DeprecationWarning):
"""Warning class for features that will be removed in a future version."""
__module__ = "pytest"
class PytestRemovedIn8Warning(PytestDeprecationWarning):
"""Warning class for features that will be removed in pytest 8."""
__module__ = "pytest"
class PytestReturnNotNoneWarning(PytestRemovedIn8Warning):
"""Warning emitted when a test function is returning value other than None."""
__module__ = "pytest"
@final
class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
"""Warning category used to denote experiments in pytest.
Use sparingly as the API might change or even be removed completely in a
future version.
"""
__module__ = "pytest"
@classmethod
def simple(cls, apiname: str) -> "PytestExperimentalApiWarning":
return cls(
"{apiname} is an experimental api that may change over time".format(
apiname=apiname
)
)
@final
class PytestUnhandledCoroutineWarning(PytestReturnNotNoneWarning):
"""Warning emitted for an unhandled coroutine.
A coroutine was encountered when collecting test functions, but was not
handled by any async-aware plugin.
Coroutine test functions are not natively supported.
"""
__module__ = "pytest"
@final
class PytestUnknownMarkWarning(PytestWarning):
"""Warning emitted on use of unknown markers.
See :ref:`mark` for details.
"""
__module__ = "pytest"
@final
class PytestUnraisableExceptionWarning(PytestWarning):
"""An unraisable exception was reported.
Unraisable exceptions are exceptions raised in :meth:`__del__ <object.__del__>`
implementations and similar situations when the exception cannot be raised
as normal.
"""
__module__ = "pytest"
@final
class PytestUnhandledThreadExceptionWarning(PytestWarning):
"""An unhandled exception occurred in a :class:`~threading.Thread`.
Such exceptions don't propagate normally.
"""
__module__ = "pytest"
_W = TypeVar("_W", bound=PytestWarning)
@final
@attr.s(auto_attribs=True)
class UnformattedWarning(Generic[_W]):
"""A warning meant to be formatted during runtime.
This is used to hold warnings that need to format their message at runtime,
as opposed to a direct message.
"""
category: Type["_W"]
template: str
def format(self, **kwargs: Any) -> _W:
"""Return an instance of the warning category, formatted with given kwargs."""
return self.category(self.template.format(**kwargs))
def warn_explicit_for(method: FunctionType, message: PytestWarning) -> None:
"""
Issue the warning :param:`message` for the definition of the given :param:`method`
this helps to log warnigns for functions defined prior to finding an issue with them
(like hook wrappers being marked in a legacy mechanism)
"""
lineno = method.__code__.co_firstlineno
filename = inspect.getfile(method)
module = method.__module__
mod_globals = method.__globals__
try:
warnings.warn_explicit(
message,
type(message),
filename=filename,
module=module,
registry=mod_globals.setdefault("__warningregistry__", {}),
lineno=lineno,
)
except Warning as w:
# If warnings are errors (e.g. -Werror), location information gets lost, so we add it to the message.
raise type(w)(f"{w}\n at {filename}:{lineno}") from None

View File

@ -0,0 +1,148 @@
import sys
import warnings
from contextlib import contextmanager
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from _pytest.config import apply_warning_filters
from _pytest.config import Config
from _pytest.config import parse_warning_filter
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter
if TYPE_CHECKING:
from typing_extensions import Literal
def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
"see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ",
)
@contextmanager
def catch_warnings_for_item(
config: Config,
ihook,
when: "Literal['config', 'collect', 'runtest']",
item: Optional[Item],
) -> Generator[None, None, None]:
"""Context manager that catches warnings generated in the contained execution block.
``item`` can be None if we are not in the context of an item execution.
Each warning captured triggers the ``pytest_warning_recorded`` hook.
"""
config_filters = config.getini("filterwarnings")
cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it.
assert log is not None
if not sys.warnoptions:
# If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)
apply_warning_filters(config_filters, cmdline_filters)
# apply filters from "filterwarnings" marks
nodeid = "" if item is None else item.nodeid
if item is not None:
for mark in item.iter_markers(name="filterwarnings"):
for arg in mark.args:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
yield
for warning_message in log:
ihook.pytest_warning_recorded.call_historic(
kwargs=dict(
warning_message=warning_message,
nodeid=nodeid,
when=when,
location=None,
)
)
def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
"""Convert a warnings.WarningMessage to a string."""
warn_msg = warning_message.message
msg = warnings.formatwarning(
str(warn_msg),
warning_message.category,
warning_message.filename,
warning_message.lineno,
warning_message.line,
)
if warning_message.source is not None:
try:
import tracemalloc
except ImportError:
pass
else:
tb = tracemalloc.get_object_traceback(warning_message.source)
if tb is not None:
formatted_tb = "\n".join(tb.format())
# Use a leading new line to better separate the (large) output
# from the traceback to the previous warning text.
msg += f"\nObject allocated at:\n{formatted_tb}"
else:
# No need for a leading new line.
url = "https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings"
msg += "Enable tracemalloc to get traceback where the object was allocated.\n"
msg += f"See {url} for more info."
return msg
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
):
yield
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_collection(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="collect", item=None
):
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(
terminalreporter: TerminalReporter,
) -> Generator[None, None, None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_load_initial_conftests(
early_config: "Config",
) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None
):
yield

View File

@ -0,0 +1 @@
import _virtualenv

View File

@ -0,0 +1,130 @@
"""Patches that are applied at runtime to the virtual environment"""
# -*- coding: utf-8 -*-
import os
import sys
VIRTUALENV_PATCH_FILE = os.path.join(__file__)
def patch_dist(dist):
"""
Distutils allows user to configure some arguments via a configuration file:
https://docs.python.org/3/install/index.html#distutils-configuration-files
Some of this arguments though don't make sense in context of the virtual environment files, let's fix them up.
"""
# we cannot allow some install config as that would get packages installed outside of the virtual environment
old_parse_config_files = dist.Distribution.parse_config_files
def parse_config_files(self, *args, **kwargs):
result = old_parse_config_files(self, *args, **kwargs)
install = self.get_option_dict("install")
if "prefix" in install: # the prefix governs where to install the libraries
install["prefix"] = VIRTUALENV_PATCH_FILE, os.path.abspath(sys.prefix)
for base in ("purelib", "platlib", "headers", "scripts", "data"):
key = "install_{}".format(base)
if key in install: # do not allow global configs to hijack venv paths
install.pop(key, None)
return result
dist.Distribution.parse_config_files = parse_config_files
# Import hook that patches some modules to ignore configuration values that break package installation in case
# of virtual environments.
_DISTUTILS_PATCH = "distutils.dist", "setuptools.dist"
if sys.version_info > (3, 4):
# https://docs.python.org/3/library/importlib.html#setting-up-an-importer
from functools import partial
from importlib.abc import MetaPathFinder
from importlib.util import find_spec
class _Finder(MetaPathFinder):
"""A meta path finder that allows patching the imported distutils modules"""
fullname = None
# lock[0] is threading.Lock(), but initialized lazily to avoid importing threading very early at startup,
# because there are gevent-based applications that need to be first to import threading by themselves.
# See https://github.com/pypa/virtualenv/issues/1895 for details.
lock = []
def find_spec(self, fullname, path, target=None):
if fullname in _DISTUTILS_PATCH and self.fullname is None:
# initialize lock[0] lazily
if len(self.lock) == 0:
import threading
lock = threading.Lock()
# there is possibility that two threads T1 and T2 are simultaneously running into find_spec,
# observing .lock as empty, and further going into hereby initialization. However due to the GIL,
# list.append() operation is atomic and this way only one of the threads will "win" to put the lock
# - that every thread will use - into .lock[0].
# https://docs.python.org/3/faq/library.html#what-kinds-of-global-value-mutation-are-thread-safe
self.lock.append(lock)
with self.lock[0]:
self.fullname = fullname
try:
spec = find_spec(fullname, path)
if spec is not None:
# https://www.python.org/dev/peps/pep-0451/#how-loading-will-work
is_new_api = hasattr(spec.loader, "exec_module")
func_name = "exec_module" if is_new_api else "load_module"
old = getattr(spec.loader, func_name)
func = self.exec_module if is_new_api else self.load_module
if old is not func:
try:
setattr(spec.loader, func_name, partial(func, old))
except AttributeError:
pass # C-Extension loaders are r/o such as zipimporter with <python 3.7
return spec
finally:
self.fullname = None
@staticmethod
def exec_module(old, module):
old(module)
if module.__name__ in _DISTUTILS_PATCH:
patch_dist(module)
@staticmethod
def load_module(old, name):
module = old(name)
if module.__name__ in _DISTUTILS_PATCH:
patch_dist(module)
return module
sys.meta_path.insert(0, _Finder())
else:
# https://www.python.org/dev/peps/pep-0302/
from imp import find_module
from pkgutil import ImpImporter, ImpLoader
class _VirtualenvImporter(object, ImpImporter):
def __init__(self, path=None):
object.__init__(self)
ImpImporter.__init__(self, path)
def find_module(self, fullname, path=None):
if fullname in _DISTUTILS_PATCH:
try:
return _VirtualenvLoader(fullname, *find_module(fullname.split(".")[-1], path))
except ImportError:
pass
return None
class _VirtualenvLoader(object, ImpLoader):
def __init__(self, fullname, file, filename, etc):
object.__init__(self)
ImpLoader.__init__(self, fullname, file, filename, etc)
def load_module(self, fullname):
module = super(_VirtualenvLoader, self).load_module(fullname)
patch_dist(module)
module.__loader__ = None # distlib fallback
return module
sys.meta_path.append(_VirtualenvImporter())

View File

@ -0,0 +1,79 @@
# SPDX-License-Identifier: MIT
import sys
from functools import partial
from . import converters, exceptions, filters, setters, validators
from ._cmp import cmp_using
from ._config import get_run_validators, set_run_validators
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
from ._make import (
NOTHING,
Attribute,
Factory,
attrib,
attrs,
fields,
fields_dict,
make_class,
validate,
)
from ._version_info import VersionInfo
__version__ = "22.1.0"
__version_info__ = VersionInfo._from_version_string(__version__)
__title__ = "attrs"
__description__ = "Classes Without Boilerplate"
__url__ = "https://www.attrs.org/"
__uri__ = __url__
__doc__ = __description__ + " <" + __uri__ + ">"
__author__ = "Hynek Schlawack"
__email__ = "hs@ox.cx"
__license__ = "MIT"
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"
s = attributes = attrs
ib = attr = attrib
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
__all__ = [
"Attribute",
"Factory",
"NOTHING",
"asdict",
"assoc",
"astuple",
"attr",
"attrib",
"attributes",
"attrs",
"cmp_using",
"converters",
"evolve",
"exceptions",
"fields",
"fields_dict",
"filters",
"get_run_validators",
"has",
"ib",
"make_class",
"resolve_types",
"s",
"set_run_validators",
"setters",
"validate",
"validators",
]
if sys.version_info[:2] >= (3, 6):
from ._next_gen import define, field, frozen, mutable # noqa: F401
__all__.extend(("define", "field", "frozen", "mutable"))

View File

@ -0,0 +1,486 @@
import sys
from typing import (
Any,
Callable,
ClassVar,
Dict,
Generic,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
# `import X as X` is required to make these public
from . import converters as converters
from . import exceptions as exceptions
from . import filters as filters
from . import setters as setters
from . import validators as validators
from ._cmp import cmp_using as cmp_using
from ._version_info import VersionInfo
__version__: str
__version_info__: VersionInfo
__title__: str
__description__: str
__url__: str
__uri__: str
__author__: str
__email__: str
__license__: str
__copyright__: str
_T = TypeVar("_T")
_C = TypeVar("_C", bound=type)
_EqOrderType = Union[bool, Callable[[Any], Any]]
_ValidatorType = Callable[[Any, Attribute[_T], _T], Any]
_ConverterType = Callable[[Any], Any]
_FilterType = Callable[[Attribute[_T], _T], bool]
_ReprType = Callable[[Any], str]
_ReprArgType = Union[bool, _ReprType]
_OnSetAttrType = Callable[[Any, Attribute[Any], Any], Any]
_OnSetAttrArgType = Union[
_OnSetAttrType, List[_OnSetAttrType], setters._NoOpType
]
_FieldTransformer = Callable[
[type, List[Attribute[Any]]], List[Attribute[Any]]
]
# FIXME: in reality, if multiple validators are passed they must be in a list
# or tuple, but those are invariant and so would prevent subtypes of
# _ValidatorType from working when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]
# A protocol to be able to statically accept an attrs class.
class AttrsInstance(Protocol):
__attrs_attrs__: ClassVar[Any]
# _make --
NOTHING: object
# NOTE: Factory lies about its return type to make this possible:
# `x: List[int] # = Factory(list)`
# Work around mypy issue #4554 in the common case by using an overload.
if sys.version_info >= (3, 8):
from typing import Literal
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Callable[[Any], _T],
takes_self: Literal[True],
) -> _T: ...
@overload
def Factory(
factory: Callable[[], _T],
takes_self: Literal[False],
) -> _T: ...
else:
@overload
def Factory(factory: Callable[[], _T]) -> _T: ...
@overload
def Factory(
factory: Union[Callable[[Any], _T], Callable[[], _T]],
takes_self: bool = ...,
) -> _T: ...
# Static type inference support via __dataclass_transform__ implemented as per:
# https://github.com/microsoft/pyright/blob/1.1.135/specs/dataclass_transforms.md
# This annotation must be applied to all overloads of "define" and "attrs"
#
# NOTE: This is a typing construct and does not exist at runtime. Extensions
# wrapping attrs decorators should declare a separate __dataclass_transform__
# signature in the extension module using the specification linked above to
# provide pyright support.
def __dataclass_transform__(
*,
eq_default: bool = True,
order_default: bool = False,
kw_only_default: bool = False,
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]: ...
class Attribute(Generic[_T]):
name: str
default: Optional[_T]
validator: Optional[_ValidatorType[_T]]
repr: _ReprArgType
cmp: _EqOrderType
eq: _EqOrderType
order: _EqOrderType
hash: Optional[bool]
init: bool
converter: Optional[_ConverterType]
metadata: Dict[Any, Any]
type: Optional[Type[_T]]
kw_only: bool
on_setattr: _OnSetAttrType
def evolve(self, **changes: Any) -> "Attribute[Any]": ...
# NOTE: We had several choices for the annotation to use for type arg:
# 1) Type[_T]
# - Pros: Handles simple cases correctly
# - Cons: Might produce less informative errors in the case of conflicting
# TypeVars e.g. `attr.ib(default='bad', type=int)`
# 2) Callable[..., _T]
# - Pros: Better error messages than #1 for conflicting TypeVars
# - Cons: Terrible error messages for validator checks.
# e.g. attr.ib(type=int, validator=validate_str)
# -> error: Cannot infer function type argument
# 3) type (and do all of the work in the mypy plugin)
# - Pros: Simple here, and we could customize the plugin with our own errors.
# - Cons: Would need to write mypy plugin code to handle all the cases.
# We chose option #1.
# `attr` lies about its return type to make the following possible:
# attr() -> Any
# attr(8) -> int
# attr(validator=<some callable>) -> Whatever the callable expects.
# This makes this type of assignments possible:
# x: int = attr(8)
#
# This form catches explicit None or no default but with no other arguments
# returns Any.
@overload
def attrib(
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: None = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def attrib(
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def attrib(
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: Optional[Type[_T]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def attrib(
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
type: object = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
def field(
*,
default: None = ...,
validator: None = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: None = ...,
factory: None = ...,
kw_only: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
# This form catches an explicit None or no default and infers the type from the
# other arguments.
@overload
def field(
*,
default: None = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form catches an explicit default argument.
@overload
def field(
*,
default: _T,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> _T: ...
# This form covers type=non-Type: e.g. forward references (str), Any
@overload
def field(
*,
default: Optional[_T] = ...,
validator: Optional[_ValidatorArgType[_T]] = ...,
repr: _ReprArgType = ...,
hash: Optional[bool] = ...,
init: bool = ...,
metadata: Optional[Mapping[Any, Any]] = ...,
converter: Optional[_ConverterType] = ...,
factory: Optional[Callable[[], _T]] = ...,
kw_only: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
) -> Any: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: _C,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(order_default=True, field_descriptors=(attrib, field))
def attrs(
maybe_cls: None = ...,
these: Optional[Dict[str, Any]] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
auto_detect: bool = ...,
collect_by_mro: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> Callable[[_C], _C]: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: _C,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> _C: ...
@overload
@__dataclass_transform__(field_descriptors=(attrib, field))
def define(
maybe_cls: None = ...,
*,
these: Optional[Dict[str, Any]] = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[bool] = ...,
order: Optional[bool] = ...,
auto_detect: bool = ...,
getstate_setstate: Optional[bool] = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
match_args: bool = ...,
) -> Callable[[_C], _C]: ...
mutable = define
frozen = define # they differ only in their defaults
def fields(cls: Type[AttrsInstance]) -> Any: ...
def fields_dict(cls: Type[AttrsInstance]) -> Dict[str, Attribute[Any]]: ...
def validate(inst: AttrsInstance) -> None: ...
def resolve_types(
cls: _C,
globalns: Optional[Dict[str, Any]] = ...,
localns: Optional[Dict[str, Any]] = ...,
attribs: Optional[List[Attribute[Any]]] = ...,
) -> _C: ...
# TODO: add support for returning a proper attrs class from the mypy plugin
# we use Any instead of _CountingAttr so that e.g. `make_class('Foo',
# [attr.ib()])` is valid
def make_class(
name: str,
attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]],
bases: Tuple[type, ...] = ...,
repr_ns: Optional[str] = ...,
repr: bool = ...,
cmp: Optional[_EqOrderType] = ...,
hash: Optional[bool] = ...,
init: bool = ...,
slots: bool = ...,
frozen: bool = ...,
weakref_slot: bool = ...,
str: bool = ...,
auto_attribs: bool = ...,
kw_only: bool = ...,
cache_hash: bool = ...,
auto_exc: bool = ...,
eq: Optional[_EqOrderType] = ...,
order: Optional[_EqOrderType] = ...,
collect_by_mro: bool = ...,
on_setattr: Optional[_OnSetAttrArgType] = ...,
field_transformer: Optional[_FieldTransformer] = ...,
) -> type: ...
# _funcs --
# TODO: add support for returning TypedDict from the mypy plugin
# FIXME: asdict/astuple do not honor their factory args. Waiting on one of
# these:
# https://github.com/python/mypy/issues/4236
# https://github.com/python/typing/issues/253
# XXX: remember to fix attrs.asdict/astuple too!
def asdict(
inst: AttrsInstance,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
dict_factory: Type[Mapping[Any, Any]] = ...,
retain_collection_types: bool = ...,
value_serializer: Optional[
Callable[[type, Attribute[Any], Any], Any]
] = ...,
tuple_keys: Optional[bool] = ...,
) -> Dict[str, Any]: ...
# TODO: add support for returning NamedTuple from the mypy plugin
def astuple(
inst: AttrsInstance,
recurse: bool = ...,
filter: Optional[_FilterType[Any]] = ...,
tuple_factory: Type[Sequence[Any]] = ...,
retain_collection_types: bool = ...,
) -> Tuple[Any, ...]: ...
def has(cls: type) -> bool: ...
def assoc(inst: _T, **changes: Any) -> _T: ...
def evolve(inst: _T, **changes: Any) -> _T: ...
# _config --
def set_run_validators(run: bool) -> None: ...
def get_run_validators() -> bool: ...
# aliases --
s = attributes = attrs
ib = attr = attrib
dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;)

View File

@ -0,0 +1,155 @@
# SPDX-License-Identifier: MIT
import functools
import types
from ._make import _make_ne
_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
def cmp_using(
eq=None,
lt=None,
le=None,
gt=None,
ge=None,
require_same_type=True,
class_name="Comparable",
):
"""
Create a class that can be passed into `attr.ib`'s ``eq``, ``order``, and
``cmp`` arguments to customize field comparison.
The resulting class will have a full set of ordering methods if
at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.
:param Optional[callable] eq: `callable` used to evaluate equality
of two objects.
:param Optional[callable] lt: `callable` used to evaluate whether
one object is less than another object.
:param Optional[callable] le: `callable` used to evaluate whether
one object is less than or equal to another object.
:param Optional[callable] gt: `callable` used to evaluate whether
one object is greater than another object.
:param Optional[callable] ge: `callable` used to evaluate whether
one object is greater than or equal to another object.
:param bool require_same_type: When `True`, equality and ordering methods
will return `NotImplemented` if objects are not of the same type.
:param Optional[str] class_name: Name of class. Defaults to 'Comparable'.
See `comparison` for more details.
.. versionadded:: 21.1.0
"""
body = {
"__slots__": ["value"],
"__init__": _make_init(),
"_requirements": [],
"_is_comparable_to": _is_comparable_to,
}
# Add operations.
num_order_functions = 0
has_eq_function = False
if eq is not None:
has_eq_function = True
body["__eq__"] = _make_operator("eq", eq)
body["__ne__"] = _make_ne()
if lt is not None:
num_order_functions += 1
body["__lt__"] = _make_operator("lt", lt)
if le is not None:
num_order_functions += 1
body["__le__"] = _make_operator("le", le)
if gt is not None:
num_order_functions += 1
body["__gt__"] = _make_operator("gt", gt)
if ge is not None:
num_order_functions += 1
body["__ge__"] = _make_operator("ge", ge)
type_ = types.new_class(
class_name, (object,), {}, lambda ns: ns.update(body)
)
# Add same type requirement.
if require_same_type:
type_._requirements.append(_check_same_type)
# Add total ordering if at least one operation was defined.
if 0 < num_order_functions < 4:
if not has_eq_function:
# functools.total_ordering requires __eq__ to be defined,
# so raise early error here to keep a nice stack.
raise ValueError(
"eq must be define is order to complete ordering from "
"lt, le, gt, ge."
)
type_ = functools.total_ordering(type_)
return type_
def _make_init():
"""
Create __init__ method.
"""
def __init__(self, value):
"""
Initialize object with *value*.
"""
self.value = value
return __init__
def _make_operator(name, func):
"""
Create operator method.
"""
def method(self, other):
if not self._is_comparable_to(other):
return NotImplemented
result = func(self.value, other.value)
if result is NotImplemented:
return NotImplemented
return result
method.__name__ = "__%s__" % (name,)
method.__doc__ = "Return a %s b. Computed by attrs." % (
_operation_names[name],
)
return method
def _is_comparable_to(self, other):
"""
Check whether `other` is comparable to `self`.
"""
for func in self._requirements:
if not func(self, other):
return False
return True
def _check_same_type(self, other):
"""
Return True if *self* and *other* are of the same type, False otherwise.
"""
return other.value.__class__ is self.value.__class__

View File

@ -0,0 +1,13 @@
from typing import Any, Callable, Optional, Type
_CompareWithType = Callable[[Any, Any], bool]
def cmp_using(
eq: Optional[_CompareWithType],
lt: Optional[_CompareWithType],
le: Optional[_CompareWithType],
gt: Optional[_CompareWithType],
ge: Optional[_CompareWithType],
require_same_type: bool,
class_name: str,
) -> Type: ...

View File

@ -0,0 +1,185 @@
# SPDX-License-Identifier: MIT
import inspect
import platform
import sys
import threading
import types
import warnings
from collections.abc import Mapping, Sequence # noqa
PYPY = platform.python_implementation() == "PyPy"
PY36 = sys.version_info[:2] >= (3, 6)
HAS_F_STRINGS = PY36
PY310 = sys.version_info[:2] >= (3, 10)
if PYPY or PY36:
ordered_dict = dict
else:
from collections import OrderedDict
ordered_dict = OrderedDict
def just_warn(*args, **kw):
warnings.warn(
"Running interpreter doesn't sufficiently support code object "
"introspection. Some features like bare super() or accessing "
"__class__ will not work with slotted classes.",
RuntimeWarning,
stacklevel=2,
)
class _AnnotationExtractor:
"""
Extract type annotations from a callable, returning None whenever there
is none.
"""
__slots__ = ["sig"]
def __init__(self, callable):
try:
self.sig = inspect.signature(callable)
except (ValueError, TypeError): # inspect failed
self.sig = None
def get_first_param_type(self):
"""
Return the type annotation of the first argument if it's not empty.
"""
if not self.sig:
return None
params = list(self.sig.parameters.values())
if params and params[0].annotation is not inspect.Parameter.empty:
return params[0].annotation
return None
def get_return_type(self):
"""
Return the return type if it's not empty.
"""
if (
self.sig
and self.sig.return_annotation is not inspect.Signature.empty
):
return self.sig.return_annotation
return None
def make_set_closure_cell():
"""Return a function of two arguments (cell, value) which sets
the value stored in the closure cell `cell` to `value`.
"""
# pypy makes this easy. (It also supports the logic below, but
# why not do the easy/fast thing?)
if PYPY:
def set_closure_cell(cell, value):
cell.__setstate__((value,))
return set_closure_cell
# Otherwise gotta do it the hard way.
# Create a function that will set its first cellvar to `value`.
def set_first_cellvar_to(value):
x = value
return
# This function will be eliminated as dead code, but
# not before its reference to `x` forces `x` to be
# represented as a closure cell rather than a local.
def force_x_to_be_a_cell(): # pragma: no cover
return x
try:
# Extract the code object and make sure our assumptions about
# the closure behavior are correct.
co = set_first_cellvar_to.__code__
if co.co_cellvars != ("x",) or co.co_freevars != ():
raise AssertionError # pragma: no cover
# Convert this code object to a code object that sets the
# function's first _freevar_ (not cellvar) to the argument.
if sys.version_info >= (3, 8):
def set_closure_cell(cell, value):
cell.cell_contents = value
else:
args = [co.co_argcount]
args.append(co.co_kwonlyargcount)
args.extend(
[
co.co_nlocals,
co.co_stacksize,
co.co_flags,
co.co_code,
co.co_consts,
co.co_names,
co.co_varnames,
co.co_filename,
co.co_name,
co.co_firstlineno,
co.co_lnotab,
# These two arguments are reversed:
co.co_cellvars,
co.co_freevars,
]
)
set_first_freevar_code = types.CodeType(*args)
def set_closure_cell(cell, value):
# Create a function using the set_first_freevar_code,
# whose first closure cell is `cell`. Calling it will
# change the value of that cell.
setter = types.FunctionType(
set_first_freevar_code, {}, "setter", (), (cell,)
)
# And call it to set the cell.
setter(value)
# Make sure it works on this interpreter:
def make_func_with_cell():
x = None
def func():
return x # pragma: no cover
return func
cell = make_func_with_cell().__closure__[0]
set_closure_cell(cell, 100)
if cell.cell_contents != 100:
raise AssertionError # pragma: no cover
except Exception:
return just_warn
else:
return set_closure_cell
set_closure_cell = make_set_closure_cell()
# Thread-local global to track attrs instances which are already being repr'd.
# This is needed because there is no other (thread-safe) way to pass info
# about the instances that are already being repr'd through the call stack
# in order to ensure we don't perform infinite recursion.
#
# For instance, if an instance contains a dict which contains that instance,
# we need to know that we're already repr'ing the outside instance from within
# the dict's repr() call.
#
# This lives here rather than in _make.py so that the functions in _make.py
# don't have a direct reference to the thread-local in their globals dict.
# If they have such a reference, it breaks cloudpickle.
repr_context = threading.local()

View File

@ -0,0 +1,31 @@
# SPDX-License-Identifier: MIT
__all__ = ["set_run_validators", "get_run_validators"]
_run_validators = True
def set_run_validators(run):
"""
Set whether or not validators are run. By default, they are run.
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
moved to new ``attrs`` namespace. Use `attrs.validators.set_disabled()`
instead.
"""
if not isinstance(run, bool):
raise TypeError("'run' must be bool.")
global _run_validators
_run_validators = run
def get_run_validators():
"""
Return whether or not validators are run.
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()`
instead.
"""
return _run_validators

View File

@ -0,0 +1,420 @@
# SPDX-License-Identifier: MIT
import copy
from ._make import NOTHING, _obj_setattr, fields
from .exceptions import AttrsAttributeNotFoundError
def asdict(
inst,
recurse=True,
filter=None,
dict_factory=dict,
retain_collection_types=False,
value_serializer=None,
):
"""
Return the ``attrs`` attribute values of *inst* as a dict.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attrs.Attribute` as the first argument and the
value as the second argument.
:param callable dict_factory: A callable to produce dictionaries from. For
example, to produce ordered dictionaries instead of normal Python
dictionaries, pass in ``collections.OrderedDict``.
:param bool retain_collection_types: Do not convert to ``list`` when
encountering an attribute whose type is ``tuple`` or ``set``. Only
meaningful if ``recurse`` is ``True``.
:param Optional[callable] value_serializer: A hook that is called for every
attribute or dict key/value. It receives the current instance, field
and value and must return the (updated) value. The hook is run *after*
the optional *filter* has been applied.
:rtype: return type of *dict_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.0.0 *dict_factory*
.. versionadded:: 16.1.0 *retain_collection_types*
.. versionadded:: 20.3.0 *value_serializer*
.. versionadded:: 21.3.0 If a dict has a collection for a key, it is
serialized as a tuple.
"""
attrs = fields(inst.__class__)
rv = dict_factory()
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if value_serializer is not None:
v = value_serializer(inst, a, v)
if recurse is True:
if has(v.__class__):
rv[a.name] = asdict(
v,
recurse=True,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain_collection_types is True else list
rv[a.name] = cf(
[
_asdict_anything(
i,
is_key=False,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
for i in v
]
)
elif isinstance(v, dict):
df = dict_factory
rv[a.name] = df(
(
_asdict_anything(
kk,
is_key=True,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
_asdict_anything(
vv,
is_key=False,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
)
for kk, vv in v.items()
)
else:
rv[a.name] = v
else:
rv[a.name] = v
return rv
def _asdict_anything(
val,
is_key,
filter,
dict_factory,
retain_collection_types,
value_serializer,
):
"""
``asdict`` only works on attrs instances, this works on anything.
"""
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
# Attrs class.
rv = asdict(
val,
recurse=True,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
elif isinstance(val, (tuple, list, set, frozenset)):
if retain_collection_types is True:
cf = val.__class__
elif is_key:
cf = tuple
else:
cf = list
rv = cf(
[
_asdict_anything(
i,
is_key=False,
filter=filter,
dict_factory=dict_factory,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
)
for i in val
]
)
elif isinstance(val, dict):
df = dict_factory
rv = df(
(
_asdict_anything(
kk,
is_key=True,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
_asdict_anything(
vv,
is_key=False,
filter=filter,
dict_factory=df,
retain_collection_types=retain_collection_types,
value_serializer=value_serializer,
),
)
for kk, vv in val.items()
)
else:
rv = val
if value_serializer is not None:
rv = value_serializer(None, None, rv)
return rv
def astuple(
inst,
recurse=True,
filter=None,
tuple_factory=tuple,
retain_collection_types=False,
):
"""
Return the ``attrs`` attribute values of *inst* as a tuple.
Optionally recurse into other ``attrs``-decorated classes.
:param inst: Instance of an ``attrs``-decorated class.
:param bool recurse: Recurse into classes that are also
``attrs``-decorated.
:param callable filter: A callable whose return code determines whether an
attribute or element is included (``True``) or dropped (``False``). Is
called with the `attrs.Attribute` as the first argument and the
value as the second argument.
:param callable tuple_factory: A callable to produce tuples from. For
example, to produce lists instead of tuples.
:param bool retain_collection_types: Do not convert to ``list``
or ``dict`` when encountering an attribute which type is
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
``True``.
:rtype: return type of *tuple_factory*
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 16.2.0
"""
attrs = fields(inst.__class__)
rv = []
retain = retain_collection_types # Very long. :/
for a in attrs:
v = getattr(inst, a.name)
if filter is not None and not filter(a, v):
continue
if recurse is True:
if has(v.__class__):
rv.append(
astuple(
v,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
)
elif isinstance(v, (tuple, list, set, frozenset)):
cf = v.__class__ if retain is True else list
rv.append(
cf(
[
astuple(
j,
recurse=True,
filter=filter,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(j.__class__)
else j
for j in v
]
)
)
elif isinstance(v, dict):
df = v.__class__ if retain is True else dict
rv.append(
df(
(
astuple(
kk,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(kk.__class__)
else kk,
astuple(
vv,
tuple_factory=tuple_factory,
retain_collection_types=retain,
)
if has(vv.__class__)
else vv,
)
for kk, vv in v.items()
)
)
else:
rv.append(v)
else:
rv.append(v)
return rv if tuple_factory is list else tuple_factory(rv)
def has(cls):
"""
Check whether *cls* is a class with ``attrs`` attributes.
:param type cls: Class to introspect.
:raise TypeError: If *cls* is not a class.
:rtype: bool
"""
return getattr(cls, "__attrs_attrs__", None) is not None
def assoc(inst, **changes):
"""
Copy *inst* and apply *changes*.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
be found on *cls*.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. deprecated:: 17.1.0
Use `attrs.evolve` instead if you can.
This function will not be removed du to the slightly different approach
compared to `attrs.evolve`.
"""
import warnings
warnings.warn(
"assoc is deprecated and will be removed after 2018/01.",
DeprecationWarning,
stacklevel=2,
)
new = copy.copy(inst)
attrs = fields(inst.__class__)
for k, v in changes.items():
a = getattr(attrs, k, NOTHING)
if a is NOTHING:
raise AttrsAttributeNotFoundError(
"{k} is not an attrs attribute on {cl}.".format(
k=k, cl=new.__class__
)
)
_obj_setattr(new, k, v)
return new
def evolve(inst, **changes):
"""
Create a new instance, based on *inst* with *changes* applied.
:param inst: Instance of a class with ``attrs`` attributes.
:param changes: Keyword changes in the new copy.
:return: A copy of inst with *changes* incorporated.
:raise TypeError: If *attr_name* couldn't be found in the class
``__init__``.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
.. versionadded:: 17.1.0
"""
cls = inst.__class__
attrs = fields(cls)
for a in attrs:
if not a.init:
continue
attr_name = a.name # To deal with private attributes.
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
if init_name not in changes:
changes[init_name] = getattr(inst, attr_name)
return cls(**changes)
def resolve_types(cls, globalns=None, localns=None, attribs=None):
"""
Resolve any strings and forward annotations in type annotations.
This is only required if you need concrete types in `Attribute`'s *type*
field. In other words, you don't need to resolve your types if you only
use them for static type checking.
With no arguments, names will be looked up in the module in which the class
was created. If this is not what you want, e.g. if the name only exists
inside a method, you may pass *globalns* or *localns* to specify other
dictionaries in which to look up these names. See the docs of
`typing.get_type_hints` for more details.
:param type cls: Class to resolve.
:param Optional[dict] globalns: Dictionary containing global variables.
:param Optional[dict] localns: Dictionary containing local variables.
:param Optional[list] attribs: List of attribs for the given class.
This is necessary when calling from inside a ``field_transformer``
since *cls* is not an ``attrs`` class yet.
:raise TypeError: If *cls* is not a class.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class and you didn't pass any attribs.
:raise NameError: If types cannot be resolved because of missing variables.
:returns: *cls* so you can use this function also as a class decorator.
Please note that you have to apply it **after** `attrs.define`. That
means the decorator has to come in the line **before** `attrs.define`.
.. versionadded:: 20.1.0
.. versionadded:: 21.1.0 *attribs*
"""
# Since calling get_type_hints is expensive we cache whether we've
# done it already.
if getattr(cls, "__attrs_types_resolved__", None) != cls:
import typing
hints = typing.get_type_hints(cls, globalns=globalns, localns=localns)
for field in fields(cls) if attribs is None else attribs:
if field.name in hints:
# Since fields have been frozen we must work around it.
_obj_setattr(field, "type", hints[field.name])
# We store the class we resolved so that subclasses know they haven't
# been resolved.
cls.__attrs_types_resolved__ = cls
# Return the class so you can use it as a decorator too.
return cls

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,220 @@
# SPDX-License-Identifier: MIT
"""
These are Python 3.6+-only and keyword-only APIs that call `attr.s` and
`attr.ib` with different default values.
"""
from functools import partial
from . import setters
from ._funcs import asdict as _asdict
from ._funcs import astuple as _astuple
from ._make import (
NOTHING,
_frozen_setattrs,
_ng_default_on_setattr,
attrib,
attrs,
)
from .exceptions import UnannotatedAttributeError
def define(
maybe_cls=None,
*,
these=None,
repr=None,
hash=None,
init=None,
slots=True,
frozen=False,
weakref_slot=True,
str=False,
auto_attribs=None,
kw_only=False,
cache_hash=False,
auto_exc=True,
eq=None,
order=False,
auto_detect=True,
getstate_setstate=None,
on_setattr=None,
field_transformer=None,
match_args=True,
):
r"""
Define an ``attrs`` class.
Differences to the classic `attr.s` that it uses underneath:
- Automatically detect whether or not *auto_attribs* should be `True` (c.f.
*auto_attribs* parameter).
- If *frozen* is `False`, run converters and validators when setting an
attribute by default.
- *slots=True*
.. caution::
Usually this has only upsides and few visible effects in everyday
programming. But it *can* lead to some suprising behaviors, so please
make sure to read :term:`slotted classes`.
- *auto_exc=True*
- *auto_detect=True*
- *order=False*
- Some options that were only relevant on Python 2 or were kept around for
backwards-compatibility have been removed.
Please note that these are all defaults and you can change them as you
wish.
:param Optional[bool] auto_attribs: If set to `True` or `False`, it behaves
exactly like `attr.s`. If left `None`, `attr.s` will try to guess:
1. If any attributes are annotated and no unannotated `attrs.fields`\ s
are found, it assumes *auto_attribs=True*.
2. Otherwise it assumes *auto_attribs=False* and tries to collect
`attrs.fields`\ s.
For now, please refer to `attr.s` for the rest of the parameters.
.. versionadded:: 20.1.0
.. versionchanged:: 21.3.0 Converters are also run ``on_setattr``.
"""
def do_it(cls, auto_attribs):
return attrs(
maybe_cls=cls,
these=these,
repr=repr,
hash=hash,
init=init,
slots=slots,
frozen=frozen,
weakref_slot=weakref_slot,
str=str,
auto_attribs=auto_attribs,
kw_only=kw_only,
cache_hash=cache_hash,
auto_exc=auto_exc,
eq=eq,
order=order,
auto_detect=auto_detect,
collect_by_mro=True,
getstate_setstate=getstate_setstate,
on_setattr=on_setattr,
field_transformer=field_transformer,
match_args=match_args,
)
def wrap(cls):
"""
Making this a wrapper ensures this code runs during class creation.
We also ensure that frozen-ness of classes is inherited.
"""
nonlocal frozen, on_setattr
had_on_setattr = on_setattr not in (None, setters.NO_OP)
# By default, mutable classes convert & validate on setattr.
if frozen is False and on_setattr is None:
on_setattr = _ng_default_on_setattr
# However, if we subclass a frozen class, we inherit the immutability
# and disable on_setattr.
for base_cls in cls.__bases__:
if base_cls.__setattr__ is _frozen_setattrs:
if had_on_setattr:
raise ValueError(
"Frozen classes can't use on_setattr "
"(frozen-ness was inherited)."
)
on_setattr = setters.NO_OP
break
if auto_attribs is not None:
return do_it(cls, auto_attribs)
try:
return do_it(cls, True)
except UnannotatedAttributeError:
return do_it(cls, False)
# maybe_cls's type depends on the usage of the decorator. It's a class
# if it's used as `@attrs` but ``None`` if used as `@attrs()`.
if maybe_cls is None:
return wrap
else:
return wrap(maybe_cls)
mutable = define
frozen = partial(define, frozen=True, on_setattr=None)
def field(
*,
default=NOTHING,
validator=None,
repr=True,
hash=None,
init=True,
metadata=None,
converter=None,
factory=None,
kw_only=False,
eq=None,
order=None,
on_setattr=None,
):
"""
Identical to `attr.ib`, except keyword-only and with some arguments
removed.
.. versionadded:: 20.1.0
"""
return attrib(
default=default,
validator=validator,
repr=repr,
hash=hash,
init=init,
metadata=metadata,
converter=converter,
factory=factory,
kw_only=kw_only,
eq=eq,
order=order,
on_setattr=on_setattr,
)
def asdict(inst, *, recurse=True, filter=None, value_serializer=None):
"""
Same as `attr.asdict`, except that collections types are always retained
and dict is always used as *dict_factory*.
.. versionadded:: 21.3.0
"""
return _asdict(
inst=inst,
recurse=recurse,
filter=filter,
value_serializer=value_serializer,
retain_collection_types=True,
)
def astuple(inst, *, recurse=True, filter=None):
"""
Same as `attr.astuple`, except that collections types are always retained
and `tuple` is always used as the *tuple_factory*.
.. versionadded:: 21.3.0
"""
return _astuple(
inst=inst, recurse=recurse, filter=filter, retain_collection_types=True
)

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: MIT
from functools import total_ordering
from ._funcs import astuple
from ._make import attrib, attrs
@total_ordering
@attrs(eq=False, order=False, slots=True, frozen=True)
class VersionInfo:
"""
A version object that can be compared to tuple of length 1--4:
>>> attr.VersionInfo(19, 1, 0, "final") <= (19, 2)
True
>>> attr.VersionInfo(19, 1, 0, "final") < (19, 1, 1)
True
>>> vi = attr.VersionInfo(19, 2, 0, "final")
>>> vi < (19, 1, 1)
False
>>> vi < (19,)
False
>>> vi == (19, 2,)
True
>>> vi == (19, 2, 1)
False
.. versionadded:: 19.2
"""
year = attrib(type=int)
minor = attrib(type=int)
micro = attrib(type=int)
releaselevel = attrib(type=str)
@classmethod
def _from_version_string(cls, s):
"""
Parse *s* and return a _VersionInfo.
"""
v = s.split(".")
if len(v) == 3:
v.append("final")
return cls(
year=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]
)
def _ensure_tuple(self, other):
"""
Ensure *other* is a tuple of a valid length.
Returns a possibly transformed *other* and ourselves as a tuple of
the same length as *other*.
"""
if self.__class__ is other.__class__:
other = astuple(other)
if not isinstance(other, tuple):
raise NotImplementedError
if not (1 <= len(other) <= 4):
raise NotImplementedError
return astuple(self)[: len(other)], other
def __eq__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
return us == them
def __lt__(self, other):
try:
us, them = self._ensure_tuple(other)
except NotImplementedError:
return NotImplemented
# Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
# have to do anything special with releaselevel for now.
return us < them

View File

@ -0,0 +1,9 @@
class VersionInfo:
@property
def year(self) -> int: ...
@property
def minor(self) -> int: ...
@property
def micro(self) -> int: ...
@property
def releaselevel(self) -> str: ...

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful converters.
"""
import typing
from ._compat import _AnnotationExtractor
from ._make import NOTHING, Factory, pipe
__all__ = [
"default_if_none",
"optional",
"pipe",
"to_bool",
]
def optional(converter):
"""
A converter that allows an attribute to be optional. An optional attribute
is one which can be set to ``None``.
Type annotations will be inferred from the wrapped converter's, if it
has any.
:param callable converter: the converter that is used for non-``None``
values.
.. versionadded:: 17.1.0
"""
def optional_converter(val):
if val is None:
return None
return converter(val)
xtr = _AnnotationExtractor(converter)
t = xtr.get_first_param_type()
if t:
optional_converter.__annotations__["val"] = typing.Optional[t]
rt = xtr.get_return_type()
if rt:
optional_converter.__annotations__["return"] = typing.Optional[rt]
return optional_converter
def default_if_none(default=NOTHING, factory=None):
"""
A converter that allows to replace ``None`` values by *default* or the
result of *factory*.
:param default: Value to be used if ``None`` is passed. Passing an instance
of `attrs.Factory` is supported, however the ``takes_self`` option
is *not*.
:param callable factory: A callable that takes no parameters whose result
is used if ``None`` is passed.
:raises TypeError: If **neither** *default* or *factory* is passed.
:raises TypeError: If **both** *default* and *factory* are passed.
:raises ValueError: If an instance of `attrs.Factory` is passed with
``takes_self=True``.
.. versionadded:: 18.2.0
"""
if default is NOTHING and factory is None:
raise TypeError("Must pass either `default` or `factory`.")
if default is not NOTHING and factory is not None:
raise TypeError(
"Must pass either `default` or `factory` but not both."
)
if factory is not None:
default = Factory(factory)
if isinstance(default, Factory):
if default.takes_self:
raise ValueError(
"`takes_self` is not supported by default_if_none."
)
def default_if_none_converter(val):
if val is not None:
return val
return default.factory()
else:
def default_if_none_converter(val):
if val is not None:
return val
return default
return default_if_none_converter
def to_bool(val):
"""
Convert "boolean" strings (e.g., from env. vars.) to real booleans.
Values mapping to :code:`True`:
- :code:`True`
- :code:`"true"` / :code:`"t"`
- :code:`"yes"` / :code:`"y"`
- :code:`"on"`
- :code:`"1"`
- :code:`1`
Values mapping to :code:`False`:
- :code:`False`
- :code:`"false"` / :code:`"f"`
- :code:`"no"` / :code:`"n"`
- :code:`"off"`
- :code:`"0"`
- :code:`0`
:raises ValueError: for any other value.
.. versionadded:: 21.3.0
"""
if isinstance(val, str):
val = val.lower()
truthy = {True, "true", "t", "yes", "y", "on", "1", 1}
falsy = {False, "false", "f", "no", "n", "off", "0", 0}
try:
if val in truthy:
return True
if val in falsy:
return False
except TypeError:
# Raised when "val" is not hashable (e.g., lists)
pass
raise ValueError("Cannot convert value to bool: {}".format(val))

View File

@ -0,0 +1,13 @@
from typing import Callable, Optional, TypeVar, overload
from . import _ConverterType
_T = TypeVar("_T")
def pipe(*validators: _ConverterType) -> _ConverterType: ...
def optional(converter: _ConverterType) -> _ConverterType: ...
@overload
def default_if_none(default: _T) -> _ConverterType: ...
@overload
def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType: ...
def to_bool(val: str) -> bool: ...

View File

@ -0,0 +1,92 @@
# SPDX-License-Identifier: MIT
class FrozenError(AttributeError):
"""
A frozen/immutable instance or attribute have been attempted to be
modified.
It mirrors the behavior of ``namedtuples`` by using the same error message
and subclassing `AttributeError`.
.. versionadded:: 20.1.0
"""
msg = "can't set attribute"
args = [msg]
class FrozenInstanceError(FrozenError):
"""
A frozen instance has been attempted to be modified.
.. versionadded:: 16.1.0
"""
class FrozenAttributeError(FrozenError):
"""
A frozen attribute has been attempted to be modified.
.. versionadded:: 20.1.0
"""
class AttrsAttributeNotFoundError(ValueError):
"""
An ``attrs`` function couldn't find an attribute that the user asked for.
.. versionadded:: 16.2.0
"""
class NotAnAttrsClassError(ValueError):
"""
A non-``attrs`` class has been passed into an ``attrs`` function.
.. versionadded:: 16.2.0
"""
class DefaultAlreadySetError(RuntimeError):
"""
A default has been set using ``attr.ib()`` and is attempted to be reset
using the decorator.
.. versionadded:: 17.1.0
"""
class UnannotatedAttributeError(RuntimeError):
"""
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
annotation.
.. versionadded:: 17.3.0
"""
class PythonTooOldError(RuntimeError):
"""
It was attempted to use an ``attrs`` feature that requires a newer Python
version.
.. versionadded:: 18.2.0
"""
class NotCallableError(TypeError):
"""
A ``attr.ib()`` requiring a callable has been set with a value
that is not callable.
.. versionadded:: 19.2.0
"""
def __init__(self, msg, value):
super(TypeError, self).__init__(msg, value)
self.msg = msg
self.value = value
def __str__(self):
return str(self.msg)

View File

@ -0,0 +1,17 @@
from typing import Any
class FrozenError(AttributeError):
msg: str = ...
class FrozenInstanceError(FrozenError): ...
class FrozenAttributeError(FrozenError): ...
class AttrsAttributeNotFoundError(ValueError): ...
class NotAnAttrsClassError(ValueError): ...
class DefaultAlreadySetError(RuntimeError): ...
class UnannotatedAttributeError(RuntimeError): ...
class PythonTooOldError(RuntimeError): ...
class NotCallableError(TypeError):
msg: str = ...
value: Any = ...
def __init__(self, msg: str, value: Any) -> None: ...

View File

@ -0,0 +1,51 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful filters for `attr.asdict`.
"""
from ._make import Attribute
def _split_what(what):
"""
Returns a tuple of `frozenset`s of classes and attributes.
"""
return (
frozenset(cls for cls in what if isinstance(cls, type)),
frozenset(cls for cls in what if isinstance(cls, Attribute)),
)
def include(*what):
"""
Include *what*.
:param what: What to include.
:type what: `list` of `type` or `attrs.Attribute`\\ s
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def include_(attribute, value):
return value.__class__ in cls or attribute in attrs
return include_
def exclude(*what):
"""
Exclude *what*.
:param what: What to exclude.
:type what: `list` of classes or `attrs.Attribute`\\ s.
:rtype: `callable`
"""
cls, attrs = _split_what(what)
def exclude_(attribute, value):
return value.__class__ not in cls and attribute not in attrs
return exclude_

View File

@ -0,0 +1,6 @@
from typing import Any, Union
from . import Attribute, _FilterType
def include(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...
def exclude(*what: Union[type, Attribute[Any]]) -> _FilterType[Any]: ...

View File

@ -0,0 +1,73 @@
# SPDX-License-Identifier: MIT
"""
Commonly used hooks for on_setattr.
"""
from . import _config
from .exceptions import FrozenAttributeError
def pipe(*setters):
"""
Run all *setters* and return the return value of the last one.
.. versionadded:: 20.1.0
"""
def wrapped_pipe(instance, attrib, new_value):
rv = new_value
for setter in setters:
rv = setter(instance, attrib, rv)
return rv
return wrapped_pipe
def frozen(_, __, ___):
"""
Prevent an attribute to be modified.
.. versionadded:: 20.1.0
"""
raise FrozenAttributeError()
def validate(instance, attrib, new_value):
"""
Run *attrib*'s validator on *new_value* if it has one.
.. versionadded:: 20.1.0
"""
if _config._run_validators is False:
return new_value
v = attrib.validator
if not v:
return new_value
v(instance, attrib, new_value)
return new_value
def convert(instance, attrib, new_value):
"""
Run *attrib*'s converter -- if it has one -- on *new_value* and return the
result.
.. versionadded:: 20.1.0
"""
c = attrib.converter
if c:
return c(new_value)
return new_value
# Sentinel for disabling class-wide *on_setattr* hooks for certain attributes.
# autodata stopped working, so the docstring is inlined in the API docs.
NO_OP = object()

View File

@ -0,0 +1,19 @@
from typing import Any, NewType, NoReturn, TypeVar, cast
from . import Attribute, _OnSetAttrType
_T = TypeVar("_T")
def frozen(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> NoReturn: ...
def pipe(*setters: _OnSetAttrType) -> _OnSetAttrType: ...
def validate(instance: Any, attribute: Attribute[_T], new_value: _T) -> _T: ...
# convert is allowed to return Any, because they can be chained using pipe.
def convert(
instance: Any, attribute: Attribute[Any], new_value: Any
) -> Any: ...
_NoOpType = NewType("_NoOpType", object)
NO_OP: _NoOpType

View File

@ -0,0 +1,594 @@
# SPDX-License-Identifier: MIT
"""
Commonly useful validators.
"""
import operator
import re
from contextlib import contextmanager
from ._config import get_run_validators, set_run_validators
from ._make import _AndValidator, and_, attrib, attrs
from .exceptions import NotCallableError
try:
Pattern = re.Pattern
except AttributeError: # Python <3.7 lacks a Pattern type.
Pattern = type(re.compile(""))
__all__ = [
"and_",
"deep_iterable",
"deep_mapping",
"disabled",
"ge",
"get_disabled",
"gt",
"in_",
"instance_of",
"is_callable",
"le",
"lt",
"matches_re",
"max_len",
"min_len",
"optional",
"provides",
"set_disabled",
]
def set_disabled(disabled):
"""
Globally disable or enable running validators.
By default, they are run.
:param disabled: If ``True``, disable running all validators.
:type disabled: bool
.. warning::
This function is not thread-safe!
.. versionadded:: 21.3.0
"""
set_run_validators(not disabled)
def get_disabled():
"""
Return a bool indicating whether validators are currently disabled or not.
:return: ``True`` if validators are currently disabled.
:rtype: bool
.. versionadded:: 21.3.0
"""
return not get_run_validators()
@contextmanager
def disabled():
"""
Context manager that disables running validators within its context.
.. warning::
This context manager is not thread-safe!
.. versionadded:: 21.3.0
"""
set_run_validators(False)
try:
yield
finally:
set_run_validators(True)
@attrs(repr=False, slots=True, hash=True)
class _InstanceOfValidator:
type = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not isinstance(value, self.type):
raise TypeError(
"'{name}' must be {type!r} (got {value!r} that is a "
"{actual!r}).".format(
name=attr.name,
type=self.type,
actual=value.__class__,
value=value,
),
attr,
self.type,
value,
)
def __repr__(self):
return "<instance_of validator for type {type!r}>".format(
type=self.type
)
def instance_of(type):
"""
A validator that raises a `TypeError` if the initializer is called
with a wrong type for this particular attribute (checks are performed using
`isinstance` therefore it's also valid to pass a tuple of types).
:param type: The type to check for.
:type type: type or tuple of types
:raises TypeError: With a human readable error message, the attribute
(of type `attrs.Attribute`), the expected type, and the value it
got.
"""
return _InstanceOfValidator(type)
@attrs(repr=False, frozen=True, slots=True)
class _MatchesReValidator:
pattern = attrib()
match_func = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.match_func(value):
raise ValueError(
"'{name}' must match regex {pattern!r}"
" ({value!r} doesn't)".format(
name=attr.name, pattern=self.pattern.pattern, value=value
),
attr,
self.pattern,
value,
)
def __repr__(self):
return "<matches_re validator for pattern {pattern!r}>".format(
pattern=self.pattern
)
def matches_re(regex, flags=0, func=None):
r"""
A validator that raises `ValueError` if the initializer is called
with a string that doesn't match *regex*.
:param regex: a regex string or precompiled pattern to match against
:param int flags: flags that will be passed to the underlying re function
(default 0)
:param callable func: which underlying `re` function to call. Valid options
are `re.fullmatch`, `re.search`, and `re.match`; the default ``None``
means `re.fullmatch`. For performance reasons, the pattern is always
precompiled using `re.compile`.
.. versionadded:: 19.2.0
.. versionchanged:: 21.3.0 *regex* can be a pre-compiled pattern.
"""
valid_funcs = (re.fullmatch, None, re.search, re.match)
if func not in valid_funcs:
raise ValueError(
"'func' must be one of {}.".format(
", ".join(
sorted(
e and e.__name__ or "None" for e in set(valid_funcs)
)
)
)
)
if isinstance(regex, Pattern):
if flags:
raise TypeError(
"'flags' can only be used with a string pattern; "
"pass flags to re.compile() instead"
)
pattern = regex
else:
pattern = re.compile(regex, flags)
if func is re.match:
match_func = pattern.match
elif func is re.search:
match_func = pattern.search
else:
match_func = pattern.fullmatch
return _MatchesReValidator(pattern, match_func)
@attrs(repr=False, slots=True, hash=True)
class _ProvidesValidator:
interface = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.interface.providedBy(value):
raise TypeError(
"'{name}' must provide {interface!r} which {value!r} "
"doesn't.".format(
name=attr.name, interface=self.interface, value=value
),
attr,
self.interface,
value,
)
def __repr__(self):
return "<provides validator for interface {interface!r}>".format(
interface=self.interface
)
def provides(interface):
"""
A validator that raises a `TypeError` if the initializer is called
with an object that does not provide the requested *interface* (checks are
performed using ``interface.providedBy(value)`` (see `zope.interface
<https://zopeinterface.readthedocs.io/en/latest/>`_).
:param interface: The interface to check for.
:type interface: ``zope.interface.Interface``
:raises TypeError: With a human readable error message, the attribute
(of type `attrs.Attribute`), the expected interface, and the
value it got.
"""
return _ProvidesValidator(interface)
@attrs(repr=False, slots=True, hash=True)
class _OptionalValidator:
validator = attrib()
def __call__(self, inst, attr, value):
if value is None:
return
self.validator(inst, attr, value)
def __repr__(self):
return "<optional validator for {what} or None>".format(
what=repr(self.validator)
)
def optional(validator):
"""
A validator that makes an attribute optional. An optional attribute is one
which can be set to ``None`` in addition to satisfying the requirements of
the sub-validator.
:param validator: A validator (or a list of validators) that is used for
non-``None`` values.
:type validator: callable or `list` of callables.
.. versionadded:: 15.1.0
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
"""
if isinstance(validator, list):
return _OptionalValidator(_AndValidator(validator))
return _OptionalValidator(validator)
@attrs(repr=False, slots=True, hash=True)
class _InValidator:
options = attrib()
def __call__(self, inst, attr, value):
try:
in_options = value in self.options
except TypeError: # e.g. `1 in "abc"`
in_options = False
if not in_options:
raise ValueError(
"'{name}' must be in {options!r} (got {value!r})".format(
name=attr.name, options=self.options, value=value
),
attr,
self.options,
value,
)
def __repr__(self):
return "<in_ validator with options {options!r}>".format(
options=self.options
)
def in_(options):
"""
A validator that raises a `ValueError` if the initializer is called
with a value that does not belong in the options provided. The check is
performed using ``value in options``.
:param options: Allowed options.
:type options: list, tuple, `enum.Enum`, ...
:raises ValueError: With a human readable error message, the attribute (of
type `attrs.Attribute`), the expected options, and the value it
got.
.. versionadded:: 17.1.0
.. versionchanged:: 22.1.0
The ValueError was incomplete until now and only contained the human
readable error message. Now it contains all the information that has
been promised since 17.1.0.
"""
return _InValidator(options)
@attrs(repr=False, slots=False, hash=True)
class _IsCallableValidator:
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not callable(value):
message = (
"'{name}' must be callable "
"(got {value!r} that is a {actual!r})."
)
raise NotCallableError(
msg=message.format(
name=attr.name, value=value, actual=value.__class__
),
value=value,
)
def __repr__(self):
return "<is_callable validator>"
def is_callable():
"""
A validator that raises a `attr.exceptions.NotCallableError` if the
initializer is called with a value for this particular attribute
that is not callable.
.. versionadded:: 19.1.0
:raises `attr.exceptions.NotCallableError`: With a human readable error
message containing the attribute (`attrs.Attribute`) name,
and the value it got.
"""
return _IsCallableValidator()
@attrs(repr=False, slots=True, hash=True)
class _DeepIterable:
member_validator = attrib(validator=is_callable())
iterable_validator = attrib(
default=None, validator=optional(is_callable())
)
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.iterable_validator is not None:
self.iterable_validator(inst, attr, value)
for member in value:
self.member_validator(inst, attr, member)
def __repr__(self):
iterable_identifier = (
""
if self.iterable_validator is None
else " {iterable!r}".format(iterable=self.iterable_validator)
)
return (
"<deep_iterable validator for{iterable_identifier}"
" iterables of {member!r}>"
).format(
iterable_identifier=iterable_identifier,
member=self.member_validator,
)
def deep_iterable(member_validator, iterable_validator=None):
"""
A validator that performs deep validation of an iterable.
:param member_validator: Validator(s) to apply to iterable members
:param iterable_validator: Validator to apply to iterable itself
(optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
if isinstance(member_validator, (list, tuple)):
member_validator = and_(*member_validator)
return _DeepIterable(member_validator, iterable_validator)
@attrs(repr=False, slots=True, hash=True)
class _DeepMapping:
key_validator = attrib(validator=is_callable())
value_validator = attrib(validator=is_callable())
mapping_validator = attrib(default=None, validator=optional(is_callable()))
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if self.mapping_validator is not None:
self.mapping_validator(inst, attr, value)
for key in value:
self.key_validator(inst, attr, key)
self.value_validator(inst, attr, value[key])
def __repr__(self):
return (
"<deep_mapping validator for objects mapping {key!r} to {value!r}>"
).format(key=self.key_validator, value=self.value_validator)
def deep_mapping(key_validator, value_validator, mapping_validator=None):
"""
A validator that performs deep validation of a dictionary.
:param key_validator: Validator to apply to dictionary keys
:param value_validator: Validator to apply to dictionary values
:param mapping_validator: Validator to apply to top-level mapping
attribute (optional)
.. versionadded:: 19.1.0
:raises TypeError: if any sub-validators fail
"""
return _DeepMapping(key_validator, value_validator, mapping_validator)
@attrs(repr=False, frozen=True, slots=True)
class _NumberValidator:
bound = attrib()
compare_op = attrib()
compare_func = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if not self.compare_func(value, self.bound):
raise ValueError(
"'{name}' must be {op} {bound}: {value}".format(
name=attr.name,
op=self.compare_op,
bound=self.bound,
value=value,
)
)
def __repr__(self):
return "<Validator for x {op} {bound}>".format(
op=self.compare_op, bound=self.bound
)
def lt(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number larger or equal to *val*.
:param val: Exclusive upper bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, "<", operator.lt)
def le(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number greater than *val*.
:param val: Inclusive upper bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, "<=", operator.le)
def ge(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number smaller than *val*.
:param val: Inclusive lower bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, ">=", operator.ge)
def gt(val):
"""
A validator that raises `ValueError` if the initializer is called
with a number smaller or equal to *val*.
:param val: Exclusive lower bound for values
.. versionadded:: 21.3.0
"""
return _NumberValidator(val, ">", operator.gt)
@attrs(repr=False, frozen=True, slots=True)
class _MaxLengthValidator:
max_length = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if len(value) > self.max_length:
raise ValueError(
"Length of '{name}' must be <= {max}: {len}".format(
name=attr.name, max=self.max_length, len=len(value)
)
)
def __repr__(self):
return "<max_len validator for {max}>".format(max=self.max_length)
def max_len(length):
"""
A validator that raises `ValueError` if the initializer is called
with a string or iterable that is longer than *length*.
:param int length: Maximum length of the string or iterable
.. versionadded:: 21.3.0
"""
return _MaxLengthValidator(length)
@attrs(repr=False, frozen=True, slots=True)
class _MinLengthValidator:
min_length = attrib()
def __call__(self, inst, attr, value):
"""
We use a callable class to be able to change the ``__repr__``.
"""
if len(value) < self.min_length:
raise ValueError(
"Length of '{name}' must be => {min}: {len}".format(
name=attr.name, min=self.min_length, len=len(value)
)
)
def __repr__(self):
return "<min_len validator for {min}>".format(min=self.min_length)
def min_len(length):
"""
A validator that raises `ValueError` if the initializer is called
with a string or iterable that is shorter than *length*.
:param int length: Minimum length of the string or iterable
.. versionadded:: 22.1.0
"""
return _MinLengthValidator(length)

View File

@ -0,0 +1,80 @@
from typing import (
Any,
AnyStr,
Callable,
Container,
ContextManager,
Iterable,
List,
Mapping,
Match,
Optional,
Pattern,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from . import _ValidatorType
from . import _ValidatorArgType
_T = TypeVar("_T")
_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_I = TypeVar("_I", bound=Iterable)
_K = TypeVar("_K")
_V = TypeVar("_V")
_M = TypeVar("_M", bound=Mapping)
def set_disabled(run: bool) -> None: ...
def get_disabled() -> bool: ...
def disabled() -> ContextManager[None]: ...
# To be more precise on instance_of use some overloads.
# If there are more than 3 items in the tuple then we fall back to Any
@overload
def instance_of(type: Type[_T]) -> _ValidatorType[_T]: ...
@overload
def instance_of(type: Tuple[Type[_T]]) -> _ValidatorType[_T]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2]]
) -> _ValidatorType[Union[_T1, _T2]]: ...
@overload
def instance_of(
type: Tuple[Type[_T1], Type[_T2], Type[_T3]]
) -> _ValidatorType[Union[_T1, _T2, _T3]]: ...
@overload
def instance_of(type: Tuple[type, ...]) -> _ValidatorType[Any]: ...
def provides(interface: Any) -> _ValidatorType[Any]: ...
def optional(
validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]]
) -> _ValidatorType[Optional[_T]]: ...
def in_(options: Container[_T]) -> _ValidatorType[_T]: ...
def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ...
def matches_re(
regex: Union[Pattern[AnyStr], AnyStr],
flags: int = ...,
func: Optional[
Callable[[AnyStr, AnyStr, int], Optional[Match[AnyStr]]]
] = ...,
) -> _ValidatorType[AnyStr]: ...
def deep_iterable(
member_validator: _ValidatorArgType[_T],
iterable_validator: Optional[_ValidatorType[_I]] = ...,
) -> _ValidatorType[_I]: ...
def deep_mapping(
key_validator: _ValidatorType[_K],
value_validator: _ValidatorType[_V],
mapping_validator: Optional[_ValidatorType[_M]] = ...,
) -> _ValidatorType[_M]: ...
def is_callable() -> _ValidatorType[_T]: ...
def lt(val: _T) -> _ValidatorType[_T]: ...
def le(val: _T) -> _ValidatorType[_T]: ...
def ge(val: _T) -> _ValidatorType[_T]: ...
def gt(val: _T) -> _ValidatorType[_T]: ...
def max_len(length: int) -> _ValidatorType[_T]: ...
def min_len(length: int) -> _ValidatorType[_T]: ...

View File

@ -0,0 +1,11 @@
Credits
=======
``attrs`` is written and maintained by `Hynek Schlawack <https://hynek.me/>`_.
The development is kindly supported by `Variomedia AG <https://www.variomedia.de/>`_.
A full list of contributors can be found in `GitHub's overview <https://github.com/python-attrs/attrs/graphs/contributors>`_.
Its the spiritual successor of `characteristic <https://characteristic.readthedocs.io/>`_ and aspires to fix some of it clunkiness and unfortunate decisions.
Both were inspired by Twisteds `FancyEqMixin <https://docs.twisted.org/en/stable/api/twisted.python.util.FancyEqMixin.html>`_ but both are implemented using class decorators because `subclassing is bad for you <https://www.youtube.com/watch?v=3MNVP9-hglc>`_, mkay?

Some files were not shown because too many files have changed in this diff Show More