summaryrefslogtreecommitdiff
path: root/lib/python2.7/test/test_contextlib.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/python2.7/test/test_contextlib.py')
-rw-r--r--lib/python2.7/test/test_contextlib.py326
1 files changed, 326 insertions, 0 deletions
diff --git a/lib/python2.7/test/test_contextlib.py b/lib/python2.7/test/test_contextlib.py
new file mode 100644
index 0000000..f28c95e
--- /dev/null
+++ b/lib/python2.7/test/test_contextlib.py
@@ -0,0 +1,326 @@
+"""Unit tests for contextlib.py, and other context managers."""
+
+import sys
+import tempfile
+import unittest
+from contextlib import * # Tests __all__
+from test import test_support
+try:
+ import threading
+except ImportError:
+ threading = None
+
+
+class ContextManagerTestCase(unittest.TestCase):
+
+ def test_contextmanager_plain(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ yield 42
+ state.append(999)
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_finally(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ finally:
+ state.append(999)
+ with self.assertRaises(ZeroDivisionError):
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError()
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_no_reraise(self):
+ @contextmanager
+ def whee():
+ yield
+ ctx = whee()
+ ctx.__enter__()
+ # Calling __exit__ should not result in an exception
+ self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
+
+ def test_contextmanager_trap_yield_after_throw(self):
+ @contextmanager
+ def whoo():
+ try:
+ yield
+ except:
+ yield
+ ctx = whoo()
+ ctx.__enter__()
+ self.assertRaises(
+ RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
+ )
+
+ def test_contextmanager_except(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ except ZeroDivisionError, e:
+ state.append(e.args[0])
+ self.assertEqual(state, [1, 42, 999])
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError(999)
+ self.assertEqual(state, [1, 42, 999])
+
+ def _create_contextmanager_attribs(self):
+ def attribs(**kw):
+ def decorate(func):
+ for k,v in kw.items():
+ setattr(func,k,v)
+ return func
+ return decorate
+ @contextmanager
+ @attribs(foo='bar')
+ def baz(spam):
+ """Whee!"""
+ return baz
+
+ def test_contextmanager_attribs(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__name__,'baz')
+ self.assertEqual(baz.foo, 'bar')
+
+ @unittest.skipIf(sys.flags.optimize >= 2,
+ "Docstrings are omitted with -O2 and above")
+ def test_contextmanager_doc_attrib(self):
+ baz = self._create_contextmanager_attribs()
+ self.assertEqual(baz.__doc__, "Whee!")
+
+class NestedTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_nested(self):
+ @contextmanager
+ def a():
+ yield 1
+ @contextmanager
+ def b():
+ yield 2
+ @contextmanager
+ def c():
+ yield 3
+ with nested(a(), b(), c()) as (x, y, z):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ self.assertEqual(z, 3)
+
+ def test_nested_cleanup(self):
+ state = []
+ @contextmanager
+ def a():
+ state.append(1)
+ try:
+ yield 2
+ finally:
+ state.append(3)
+ @contextmanager
+ def b():
+ state.append(4)
+ try:
+ yield 5
+ finally:
+ state.append(6)
+ with self.assertRaises(ZeroDivisionError):
+ with nested(a(), b()) as (x, y):
+ state.append(x)
+ state.append(y)
+ 1 // 0
+ self.assertEqual(state, [1, 4, 2, 5, 6, 3])
+
+ def test_nested_right_exception(self):
+ @contextmanager
+ def a():
+ yield 1
+ class b(object):
+ def __enter__(self):
+ return 2
+ def __exit__(self, *exc_info):
+ try:
+ raise Exception()
+ except:
+ pass
+ with self.assertRaises(ZeroDivisionError):
+ with nested(a(), b()) as (x, y):
+ 1 // 0
+ self.assertEqual((x, y), (1, 2))
+
+ def test_nested_b_swallows(self):
+ @contextmanager
+ def a():
+ yield
+ @contextmanager
+ def b():
+ try:
+ yield
+ except:
+ # Swallow the exception
+ pass
+ try:
+ with nested(a(), b()):
+ 1 // 0
+ except ZeroDivisionError:
+ self.fail("Didn't swallow ZeroDivisionError")
+
+ def test_nested_break(self):
+ @contextmanager
+ def a():
+ yield
+ state = 0
+ while True:
+ state += 1
+ with nested(a(), a()):
+ break
+ state += 10
+ self.assertEqual(state, 1)
+
+ def test_nested_continue(self):
+ @contextmanager
+ def a():
+ yield
+ state = 0
+ while state < 3:
+ state += 1
+ with nested(a(), a()):
+ continue
+ state += 10
+ self.assertEqual(state, 3)
+
+ def test_nested_return(self):
+ @contextmanager
+ def a():
+ try:
+ yield
+ except:
+ pass
+ def foo():
+ with nested(a(), a()):
+ return 1
+ return 10
+ self.assertEqual(foo(), 1)
+
+class ClosingTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_closing(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ self.assertEqual(state, [1])
+
+ def test_closing_error(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ with self.assertRaises(ZeroDivisionError):
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ 1 // 0
+ self.assertEqual(state, [1])
+
+class FileContextTestCase(unittest.TestCase):
+
+ def testWithOpen(self):
+ tfn = tempfile.mktemp()
+ try:
+ f = None
+ with open(tfn, "w") as f:
+ self.assertFalse(f.closed)
+ f.write("Booh\n")
+ self.assertTrue(f.closed)
+ f = None
+ with self.assertRaises(ZeroDivisionError):
+ with open(tfn, "r") as f:
+ self.assertFalse(f.closed)
+ self.assertEqual(f.read(), "Booh\n")
+ 1 // 0
+ self.assertTrue(f.closed)
+ finally:
+ test_support.unlink(tfn)
+
+@unittest.skipUnless(threading, 'Threading required for this test.')
+class LockContextTestCase(unittest.TestCase):
+
+ def boilerPlate(self, lock, locked):
+ self.assertFalse(locked())
+ with lock:
+ self.assertTrue(locked())
+ self.assertFalse(locked())
+ with self.assertRaises(ZeroDivisionError):
+ with lock:
+ self.assertTrue(locked())
+ 1 // 0
+ self.assertFalse(locked())
+
+ def testWithLock(self):
+ lock = threading.Lock()
+ self.boilerPlate(lock, lock.locked)
+
+ def testWithRLock(self):
+ lock = threading.RLock()
+ self.boilerPlate(lock, lock._is_owned)
+
+ def testWithCondition(self):
+ lock = threading.Condition()
+ def locked():
+ return lock._is_owned()
+ self.boilerPlate(lock, locked)
+
+ def testWithSemaphore(self):
+ lock = threading.Semaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+ def testWithBoundedSemaphore(self):
+ lock = threading.BoundedSemaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+# This is needed to make the test actually run under regrtest.py!
+def test_main():
+ with test_support.check_warnings(("With-statements now directly support "
+ "multiple context managers",
+ DeprecationWarning)):
+ test_support.run_unittest(__name__)
+
+if __name__ == "__main__":
+ test_main()