1from test.test_importlib import abc, util
2
3machinery = util.import_importlib('importlib.machinery')
4
5from test.support import captured_stdout, import_helper, STDLIB_DIR
6import _imp
7import contextlib
8import marshal
9import os.path
10import types
11import unittest
12import warnings
13
14
@contextlib.contextmanager
def deprecated():
    """Silence DeprecationWarning while the ``with`` body runs.

    The warning filters in effect before entry are restored on exit,
    because the filter change is made inside ``warnings.catch_warnings()``.
    """
    saved_filters = warnings.catch_warnings()
    with saved_filters:
        warnings.simplefilter('ignore', category=DeprecationWarning)
        yield
20
21
@contextlib.contextmanager
def fresh(name, *, oldapi=False):
    """Run the ``with`` body in a clean frozen-import environment for *name*.

    The body runs inside ``util.uncache(name)`` and
    ``import_helper.frozen_modules()`` (entered in that order).  When
    *oldapi* is true, DeprecationWarning is additionally silenced via
    ``deprecated()`` so that deprecated loader APIs can be called quietly.
    """
    # nullcontext() stands in for deprecated() when no silencing is wanted;
    # creating the generator-based manager here has no side effects until
    # it is entered by the ``with`` below.
    maybe_quiet = deprecated() if oldapi else contextlib.nullcontext()
    with util.uncache(name), import_helper.frozen_modules(), maybe_quiet:
        yield
31
32
def resolve_stdlib_file(name, ispkg=False):
    """Return the stdlib source path for the dotted module name *name*.

    Each dotted component of *name* becomes a path segment under
    ``STDLIB_DIR``.  If *ispkg* is true the result points at the package's
    ``__init__.py``; otherwise ``.py`` is appended to the last component.

    Raises ValueError if *name* is empty (or otherwise falsy).
    """
    # An explicit check rather than ``assert`` so an empty name is still
    # rejected when the test suite runs under ``python -O``; otherwise the
    # result would silently be ``STDLIB_DIR + '.py'``.
    if not name:
        raise ValueError(f'module name must be non-empty, got {name!r}')
    parts = name.split('.')
    if ispkg:
        return os.path.join(STDLIB_DIR, *parts, '__init__.py')
    return os.path.join(STDLIB_DIR, *parts) + '.py'
39
40
class ExecModuleTests(abc.LoaderTests):
    """Exercise FrozenImporter.exec_module() against frozen test modules."""

    def exec_module(self, name, origname=None):
        """Execute frozen module *name* via FrozenImporter.exec_module().

        The ModuleSpec is built by hand so that *origname* (defaulting to
        *name*) and the matching stdlib source path can be recorded in
        ``loader_state``.  After execution the module's ``main()`` is called
        with stdout captured.

        Returns ``(module, output)`` where *output* is the text ``main()``
        printed.
        """
        origname = origname or name
        with import_helper.frozen_modules():
            is_package = self.machinery.FrozenImporter.is_package(name)
        spec = self.machinery.ModuleSpec(
            name,
            self.machinery.FrozenImporter,
            origin='frozen',
            is_package=is_package,
            loader_state=types.SimpleNamespace(
                origname=origname,
                filename=resolve_stdlib_file(origname, is_package),
            ),
        )
        module = types.ModuleType(name)
        module.__spec__ = spec
        # Use an assertion method rather than a bare ``assert`` so the
        # precondition is still checked under ``python -O``.
        self.assertFalse(hasattr(module, 'initialized'))

        with fresh(name):
            self.machinery.FrozenImporter.exec_module(module)
        with captured_stdout() as stdout:
            module.main()

        self.assertTrue(module.initialized)
        self.assertTrue(hasattr(module, '__spec__'))
        self.assertEqual(module.__spec__.origin, 'frozen')
        return module, stdout.getvalue()

    def test_module(self):
        name = '__hello__'
        module, output = self.exec_module(name)
        check = {'__name__': name}
        for attr, value in check.items():
            self.assertEqual(getattr(module, attr), value)
        self.assertEqual(output, 'Hello world!\n')
        self.assertTrue(hasattr(module, '__spec__'))
        self.assertEqual(module.__spec__.loader_state.origname, name)

    def test_package(self):
        name = '__phello__'
        module, output = self.exec_module(name)
        check = {'__name__': name}
        for attr, value in check.items():
            attr_value = getattr(module, attr)
            self.assertEqual(attr_value, value,
                             'for {name}.{attr}, {given!r} != {expected!r}'.format(
                                 name=name, attr=attr, given=attr_value,
                                 expected=value))
        self.assertEqual(output, 'Hello world!\n')
        self.assertEqual(module.__spec__.loader_state.origname, name)

    def test_lacking_parent(self):
        name = '__phello__.spam'
        with util.uncache('__phello__'):
            module, output = self.exec_module(name)
        check = {'__name__': name}
        for attr, value in check.items():
            attr_value = getattr(module, attr)
            # Format the actual value with !r, matching test_package.
            self.assertEqual(attr_value, value,
                             'for {name}.{attr}, {given!r} != {expected!r}'.format(
                                 name=name, attr=attr, given=attr_value,
                                 expected=value))
        self.assertEqual(output, 'Hello world!\n')

    def test_module_repr(self):
        name = '__hello__'
        module, output = self.exec_module(name)
        with deprecated():
            repr_str = self.machinery.FrozenImporter.module_repr(module)
        self.assertEqual(repr_str,
                         "<module '__hello__' (frozen)>")

    def test_module_repr_indirect(self):
        name = '__hello__'
        module, output = self.exec_module(name)
        self.assertEqual(repr(module),
                         "<module '__hello__' (frozen)>")

    # No way to trigger an error in a frozen module.
    test_state_after_failure = None

    def test_unloadable(self):
        with import_helper.frozen_modules():
            # assertIsNone instead of a bare assert so the check survives -O.
            self.assertIsNone(
                self.machinery.FrozenImporter.find_spec('_not_real'))
        with self.assertRaises(ImportError) as cm:
            self.exec_module('_not_real')
        self.assertEqual(cm.exception.name, '_not_real')
129
130
# Derive the concrete Frozen_*/Source_* TestCase variants of
# ExecModuleTests, each bound to a `machinery` module via util.test_both.
Frozen_ExecModuleTests, Source_ExecModuleTests = util.test_both(
    ExecModuleTests, machinery=machinery)
134
135
class LoaderTests(abc.LoaderTests):
    """Exercise the deprecated FrozenImporter.load_module() API."""

    def load_module(self, name):
        """Load frozen module *name* with FrozenImporter.load_module().

        Loading happens inside ``fresh(name, oldapi=True)`` so the old API's
        DeprecationWarning is silenced.  The module's ``main()`` is then run
        with stdout captured.

        Returns ``(module, stdout)``; read the captured text with
        ``stdout.getvalue()``.
        """
        with fresh(name, oldapi=True):
            module = self.machinery.FrozenImporter.load_module(name)
        with captured_stdout() as stdout:
            module.main()
        return module, stdout

    def test_module(self):
        module, stdout = self.load_module('__hello__')
        filename = resolve_stdlib_file('__hello__')
        check = {'__name__': '__hello__',
                 '__package__': '',
                 '__loader__': self.machinery.FrozenImporter,
                 '__file__': filename,
                 }
        for attr, value in check.items():
            self.assertEqual(getattr(module, attr, None), value)
        self.assertEqual(stdout.getvalue(), 'Hello world!\n')

    def test_package(self):
        module, stdout = self.load_module('__phello__')
        filename = resolve_stdlib_file('__phello__', ispkg=True)
        pkgdir = os.path.dirname(filename)
        check = {'__name__': '__phello__',
                 '__package__': '__phello__',
                 '__path__': [pkgdir],
                 '__loader__': self.machinery.FrozenImporter,
                 '__file__': filename,
                 }
        for attr, value in check.items():
            attr_value = getattr(module, attr, None)
            self.assertEqual(attr_value, value,
                             "for __phello__.%s, %r != %r" %
                             (attr, attr_value, value))
        self.assertEqual(stdout.getvalue(), 'Hello world!\n')

    def test_lacking_parent(self):
        with util.uncache('__phello__'):
            module, stdout = self.load_module('__phello__.spam')
        filename = resolve_stdlib_file('__phello__.spam')
        check = {'__name__': '__phello__.spam',
                 '__package__': '__phello__',
                 '__loader__': self.machinery.FrozenImporter,
                 '__file__': filename,
                 }
        for attr, value in check.items():
            # Default to None (as the sibling tests do) so a missing
            # attribute is reported by the assertion message below rather
            # than as a bare AttributeError.
            attr_value = getattr(module, attr, None)
            self.assertEqual(attr_value, value,
                             "for __phello__.spam.%s, %r != %r" %
                             (attr, attr_value, value))
        self.assertEqual(stdout.getvalue(), 'Hello world!\n')

    def test_module_reuse(self):
        with fresh('__hello__', oldapi=True):
            module1 = self.machinery.FrozenImporter.load_module('__hello__')
            module2 = self.machinery.FrozenImporter.load_module('__hello__')
        with captured_stdout() as stdout:
            module1.main()
            module2.main()
        self.assertIs(module1, module2)
        self.assertEqual(stdout.getvalue(),
                         'Hello world!\nHello world!\n')

    def test_module_repr(self):
        with fresh('__hello__', oldapi=True):
            module = self.machinery.FrozenImporter.load_module('__hello__')
            repr_str = self.machinery.FrozenImporter.module_repr(module)
        self.assertEqual(repr_str,
                         "<module '__hello__' (frozen)>")

    # No way to trigger an error in a frozen module.
    test_state_after_failure = None

    def test_unloadable(self):
        with import_helper.frozen_modules():
            with deprecated():
                # assertIsNone instead of a bare assert so the check
                # survives ``python -O``.
                self.assertIsNone(
                    self.machinery.FrozenImporter.find_module('_not_real'))
            with self.assertRaises(ImportError) as cm:
                self.load_module('_not_real')
            self.assertEqual(cm.exception.name, '_not_real')
218
219
# Derive the concrete Frozen_*/Source_* TestCase variants of LoaderTests,
# each bound to a `machinery` module via util.test_both.
Frozen_LoaderTests, Source_LoaderTests = util.test_both(
    LoaderTests, machinery=machinery)
223
224
class InspectLoaderTests:

    """Tests for the InspectLoader methods for FrozenImporter.

    This is a mixin: the concrete TestCase classes (which supply
    ``self.machinery`` and the assertion methods) are produced by
    ``util.test_both`` below.
    """

    def test_get_code(self):
        # Make sure that the code object is good.
        name = '__hello__'
        with import_helper.frozen_modules():
            code = self.machinery.FrozenImporter.get_code(name)
            mod = types.ModuleType(name)
            exec(code, mod.__dict__)
        with captured_stdout() as stdout:
            mod.main()
        self.assertTrue(hasattr(mod, 'initialized'))
        self.assertEqual(stdout.getvalue(), 'Hello world!\n')

    def test_get_source(self):
        # Should always return None.
        with import_helper.frozen_modules():
            result = self.machinery.FrozenImporter.get_source('__hello__')
        self.assertIsNone(result)

    def test_is_package(self):
        # Should be able to tell what is a package.
        test_for = (('__hello__', False), ('__phello__', True),
                    ('__phello__.spam', False))
        for name, is_package in test_for:
            # subTest reports each failing name and keeps checking the rest.
            with self.subTest(name=name):
                with import_helper.frozen_modules():
                    result = self.machinery.FrozenImporter.is_package(name)
                self.assertEqual(bool(result), is_package)

    def test_failure(self):
        # Raise ImportError for modules that are not frozen.
        for meth_name in ('get_code', 'get_source', 'is_package'):
            # subTest identifies which method misbehaved in a failure report.
            with self.subTest(method=meth_name):
                method = getattr(self.machinery.FrozenImporter, meth_name)
                with self.assertRaises(ImportError) as cm:
                    with import_helper.frozen_modules():
                        method('importlib')
                self.assertEqual(cm.exception.name, 'importlib')
264
# Derive the concrete Frozen_*/Source_* TestCase variants of the
# InspectLoaderTests mixin, each bound to a `machinery` module via
# util.test_both.
Frozen_ILTests, Source_ILTests = util.test_both(
    InspectLoaderTests, machinery=machinery)
268
269
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
272