aboutsummaryrefslogtreecommitdiff
path: root/setuptools/tests/contexts.py
blob: 58948824353efeee97edd40f0988d472404ff845 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import tempfile
import os
import shutil
import sys
import contextlib
import site
import io

import pkg_resources
from filelock import FileLock


@contextlib.contextmanager
def tempdir(cd=lambda dir: None, **kwargs):
    temp_dir = tempfile.mkdtemp(**kwargs)
    orig_dir = os.getcwd()
    try:
        cd(temp_dir)
        yield temp_dir
    finally:
        cd(orig_dir)
        shutil.rmtree(temp_dir)


@contextlib.contextmanager
def environment(**replacements):
    """
    In a context, patch the environment with replacements. Pass None values
    to clear the values.
    """
    saved = dict(
        (key, os.environ[key])
        for key in replacements
        if key in os.environ
    )

    # remove values that are null
    remove = (key for (key, value) in replacements.items() if value is None)
    for key in list(remove):
        os.environ.pop(key, None)
        replacements.pop(key)

    os.environ.update(replacements)

    try:
        yield saved
    finally:
        for key in replacements:
            os.environ.pop(key, None)
        os.environ.update(saved)


@contextlib.contextmanager
def quiet():
    """
    Redirect stdout/stderr to StringIO objects to prevent console output from
    distutils commands.
    """

    old_stdout = sys.stdout
    old_stderr = sys.stderr
    new_stdout = sys.stdout = io.StringIO()
    new_stderr = sys.stderr = io.StringIO()
    try:
        yield new_stdout, new_stderr
    finally:
        new_stdout.seek(0)
        new_stderr.seek(0)
        sys.stdout = old_stdout
        sys.stderr = old_stderr


@contextlib.contextmanager
def save_user_site_setting():
    saved = site.ENABLE_USER_SITE
    try:
        yield saved
    finally:
        site.ENABLE_USER_SITE = saved


@contextlib.contextmanager
def save_pkg_resources_state():
    pr_state = pkg_resources.__getstate__()
    # also save sys.path
    sys_path = sys.path[:]
    try:
        yield pr_state, sys_path
    finally:
        sys.path[:] = sys_path
        pkg_resources.__setstate__(pr_state)


@contextlib.contextmanager
def suppress_exceptions(*excs):
    try:
        yield
    except excs:
        pass


def multiproc(request):
    """
    Return True if running under xdist and multiple
    workers are used.
    """
    try:
        worker_id = request.getfixturevalue('worker_id')
    except Exception:
        return False
    return worker_id != 'master'


@contextlib.contextmanager
def session_locked_tmp_dir(request, tmp_path_factory, name):
    """Uses a file lock to guarantee only one worker can access a temp dir"""
    # get the temp directory shared by all workers
    base = tmp_path_factory.getbasetemp()
    shared_dir = base.parent if multiproc(request) else base

    locked_dir = shared_dir / name
    with FileLock(locked_dir.with_suffix(".lock")):
        # ^-- prevent multiple workers to access the directory at once
        locked_dir.mkdir(exist_ok=True, parents=True)
        yield locked_dir