from test.test_importlib import util as test_util

# Load importlib once per flavor.  `init` maps each flavor name to that
# flavor's importlib module (it is iterated with .items() below and each value
# exposes `_bootstrap`).  The flavors are presumably the frozen and the
# pure-source builds, given the Frozen_/Source_ classes generated below.
init = test_util.import_importlib('importlib')

import sys
import threading
import unittest
import weakref

from test import support
from test.support import threading_helper
from test import lock_tests


# Presumably skips the whole module (module=True) when the platform lacks
# working thread support -- every test below spawns threads.
threading_helper.requires_working_threading(module=True)
16
17
class ModuleLockAsRLockTests:
    """Run the generic RLock test suite against importlib's _ModuleLock.

    ``LockType`` is supplied per importlib flavor by ``test_util.test_both()``.
    _ModuleLock implements only a subset of the RLock API.  The inherited
    tests that rely on missing features are masked by rebinding them to
    None, which unittest's loader skips because it collects only callable
    ``test_*`` attributes.
    """

    @classmethod
    def locktype(cls):
        # Factory used by the inherited lock tests to build a fresh lock.
        return cls.LockType("some_lock")

    # No `with` statement (context manager) support.
    test_with = None
    # No non-blocking acquire (blocking=False).
    test_try_acquire = None
    test_try_acquire_contended = None
    # No acquire(timeout=...).
    test_timeout = None
    # No _is_owned() helper.
    test__is_owned = None
    # No _release_save() helper.
    test_release_save_unacquired = None
    # repr() does not report the locked/unlocked state.
    test_repr = None
    test_locked_repr = None
35
# The _ModuleLock class of each importlib flavor, keyed like `init`.
LOCK_TYPES = {
    flavor: module._bootstrap._ModuleLock
    for flavor, module in init.items()
}

# One concrete RLock-suite TestCase per flavor.
Frozen_ModuleLockAsRLockTests, Source_ModuleLockAsRLockTests = (
    test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests,
                        LockType=LOCK_TYPES)
)
43
44
class DeadlockAvoidanceTests:
    """Check that _ModuleLock reports lock-ordering cycles as DeadlockError.

    Mixin: ``LockType`` and ``DeadlockError`` are supplied for each importlib
    flavor by ``test_util.test_both()``.
    """

    def setUp(self):
        # A tiny switch interval makes the worker threads interleave
        # aggressively.  The AttributeError guard presumably covers
        # interpreters without switch-interval support; in that case there is
        # nothing to restore in tearDown().
        try:
            self.old_switchinterval = sys.getswitchinterval()
            support.setswitchinterval(0.000001)
        except AttributeError:
            self.old_switchinterval = None

    def tearDown(self):
        # Restore the interval saved by setUp(), if one was saved.
        if self.old_switchinterval is not None:
            sys.setswitchinterval(self.old_switchinterval)

    def run_deadlock_avoidance_test(self, create_deadlock):
        """Have worker threads each grab two neighbouring locks of a ring.

        NLOCKS locks form a ring, and each worker takes one edge
        ``(locks[i], locks[(i+1) % NLOCKS])``.  A worker acquires the first
        lock, then waits at a barrier until every worker has attempted its
        first lock, then tries the second.  When *create_deadlock* is true,
        every edge is used and the waits form a cycle.  Otherwise one edge is
        left unused and no cycle exists.

        Returns a list with one ``(first_acquired, second_acquired)`` tuple
        per thread.  False means acquire() raised ``self.DeadlockError``.
        """
        NLOCKS = 10
        locks = [self.LockType(str(i)) for i in range(NLOCKS)]
        pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)]
        if create_deadlock:
            NTHREADS = NLOCKS
        else:
            NTHREADS = NLOCKS - 1
        barrier = threading.Barrier(NTHREADS)
        results = []

        def _acquire(lock):
            """Try to acquire the lock. Return True on success,
            False on deadlock."""
            try:
                lock.acquire()
            except self.DeadlockError:
                return False
            else:
                return True

        def f():
            a, b = pairs.pop()
            ra = _acquire(a)
            # Release in `finally` blocks so that an unexpected exception
            # after the first acquire cannot leave a lock held.  A held lock
            # would keep any worker waiting on it blocked, and the test would
            # hang instead of failing.  On the normal path the order matches
            # the plain version: record the result, release b, release a.
            try:
                barrier.wait()
                rb = _acquire(b)
                try:
                    results.append((ra, rb))
                finally:
                    if rb:
                        b.release()
            finally:
                if ra:
                    a.release()

        lock_tests.Bunch(f, NTHREADS).wait_for_finished()
        # Every worker must have reached the point of recording its result.
        self.assertEqual(len(results), NTHREADS)
        return results

    def test_deadlock(self):
        results = self.run_deadlock_avoidance_test(True)
        # At least one of the threads detected a potential deadlock on its
        # second acquire() call.  It may be several of them, because the
        # deadlock avoidance mechanism is conservative.
        nb_deadlocks = results.count((True, False))
        self.assertGreaterEqual(nb_deadlocks, 1)
        # All remaining threads must have obtained both locks.
        self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks)

    def test_no_deadlock(self):
        # Without a cycle, no acquire() may be reported as a deadlock.
        results = self.run_deadlock_avoidance_test(False)
        self.assertEqual(results.count((True, False)), 0)
        self.assertEqual(results.count((True, True)), len(results))
106
107
# The _DeadlockError class of each importlib flavor, keyed like `init`.
DEADLOCK_ERRORS = {
    flavor: module._bootstrap._DeadlockError
    for flavor, module in init.items()
}

# One concrete deadlock-avoidance TestCase per flavor.
Frozen_DeadlockAvoidanceTests, Source_DeadlockAvoidanceTests = (
    test_util.test_both(DeadlockAvoidanceTests,
                        LockType=LOCK_TYPES,
                        DeadlockError=DEADLOCK_ERRORS)
)
116
117
class LifetimeTests:
    """Check that importlib's per-module locks do not outlive their users.

    Mixin: ``init`` (the importlib module under test) is supplied per flavor
    by ``test_util.test_both()``.
    """

    @property
    def bootstrap(self):
        """The ``_bootstrap`` submodule of the importlib flavor under test."""
        init_module = self.init
        return init_module._bootstrap

    def test_lock_lifetime(self):
        # Obtaining a lock registers it in _module_locks.  Once the last
        # strong reference is dropped and garbage is collected, the entry
        # must disappear and the lock object itself must be gone.
        lock_name = "xyzzy"
        bs = self.bootstrap
        self.assertNotIn(lock_name, bs._module_locks)
        module_lock = bs._get_module_lock(lock_name)
        self.assertIn(lock_name, bs._module_locks)
        lock_ref = weakref.ref(module_lock)
        del module_lock
        support.gc_collect()
        self.assertNotIn(lock_name, bs._module_locks)
        self.assertIsNone(lock_ref())

    def test_all_locks(self):
        # After a collection no module locks may remain registered.  The
        # registry itself is passed as the failure message for debugging.
        support.gc_collect()
        leftover = self.bootstrap._module_locks
        self.assertEqual(0, len(leftover), leftover)
139
140
# One concrete lock-lifetime TestCase per importlib flavor.
Frozen_LifetimeTests, Source_LifetimeTests = test_util.test_both(
    LifetimeTests, init=init)
144
145
def setUpModule():
    """Snapshot thread state before this module's tests run.

    The snapshot is handed to ``threading_helper.threading_cleanup`` as a
    module cleanup.  That helper presumably reports threads leaked by the
    tests.
    """
    saved_state = threading_helper.threading_setup()
    unittest.addModuleCleanup(threading_helper.threading_cleanup, *saved_state)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
153