aboutsummaryrefslogtreecommitdiff
path: root/cros_utils/contextlib3_test.py
blob: 76c010f260f782fca11cf2d315a110af5df21bf0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# Copyright 2019 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Tests for contextlib3"""

from __future__ import division
from __future__ import print_function

import contextlib
import unittest

import contextlib3


class SomeException(Exception):
  """Just an alternative to ValueError in the Exception class hierarchy."""
  pass


class TestExitStack(unittest.TestCase):
  """Tests contextlib3.ExitStack"""

  def test_exceptions_in_exit_override_exceptions_in_with(self):

    @contextlib.contextmanager
    def raise_exit():
      raised = False
      try:
        yield
      except Exception:
        raised = True
        raise ValueError
      finally:
        self.assertTrue(raised)

    # (As noted in comments in contextlib3, this behavior is consistent with
    # how python2 works. Namely, if __exit__ raises, the exception from
    # __exit__ overrides the inner exception)
    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(raise_exit())
        raise SomeException()

  def test_raising_in_exit_doesnt_block_later_exits(self):
    exited = []

    @contextlib.contextmanager
    def raise_exit():
      try:
        yield
      finally:
        exited.append('raise')
        raise ValueError

    @contextlib.contextmanager
    def push_exit():
      try:
        yield
      finally:
        exited.append('push')

    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(push_exit())
        stack.enter_context(raise_exit())
    self.assertEqual(exited, ['raise', 'push'])

    exited = []
    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(push_exit())
        stack.enter_context(raise_exit())
        raise SomeException()
    self.assertEqual(exited, ['raise', 'push'])

  def test_push_doesnt_enter_the_context(self):
    exited = []

    test_self = self

    class Manager(object):
      """A simple ContextManager for testing purposes"""

      def __enter__(self):
        test_self.fail('context manager was entered :(')

      def __exit__(self, *args, **kwargs):
        exited.append(1)

    with contextlib3.ExitStack() as stack:
      stack.push(Manager())
      self.assertEqual(exited, [])
    self.assertEqual(exited, [1])

  def test_callbacks_are_run_properly(self):
    callback_was_run = []

    def callback(arg, some_kwarg=None):
      self.assertEqual(arg, 41)
      self.assertEqual(some_kwarg, 42)
      callback_was_run.append(1)

    with contextlib3.ExitStack() as stack:
      stack.callback(callback, 41, some_kwarg=42)
      self.assertEqual(callback_was_run, [])
    self.assertEqual(callback_was_run, [1])

    callback_was_run = []
    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.callback(callback, 41, some_kwarg=42)
        raise ValueError()
    self.assertEqual(callback_was_run, [1])

  def test_finallys_are_run(self):
    finally_run = []

    @contextlib.contextmanager
    def append_on_exit():
      try:
        yield
      finally:
        finally_run.append(0)

    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(append_on_exit())
        raise ValueError()
    self.assertEqual(finally_run, [0])

  def test_unwinding_happens_in_reverse_order(self):
    exit_runs = []

    @contextlib.contextmanager
    def append_things(start_push, end_push):
      exit_runs.append(start_push)
      try:
        yield
      finally:
        exit_runs.append(end_push)

    with contextlib3.ExitStack() as stack:
      stack.enter_context(append_things(1, 4))
      stack.enter_context(append_things(2, 3))
    self.assertEqual(exit_runs, [1, 2, 3, 4])

    exit_runs = []
    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(append_things(1, 4))
        stack.enter_context(append_things(2, 3))
        raise ValueError
    self.assertEqual(exit_runs, [1, 2, 3, 4])

  def test_exceptions_are_propagated(self):

    @contextlib.contextmanager
    def die_on_regular_exit():
      yield
      self.fail('Unreachable in theory')

    with self.assertRaises(ValueError):
      with contextlib3.ExitStack() as stack:
        stack.enter_context(die_on_regular_exit())
        raise ValueError()

  def test_exceptions_can_be_blocked(self):

    @contextlib.contextmanager
    def block():
      try:
        yield
      except Exception:
        pass

    with contextlib3.ExitStack() as stack:
      stack.enter_context(block())
      raise ValueError()

  def test_objects_are_returned_from_enter_context(self):

    @contextlib.contextmanager
    def yield_arg(arg):
      yield arg

    with contextlib3.ExitStack() as stack:
      val = stack.enter_context(yield_arg(1))
      self.assertEqual(val, 1)


if __name__ == '__main__':
  unittest.main()