1import sysconfig
2import textwrap
3import unittest
4import os
5import shutil
6import tempfile
7from pathlib import Path
8
9from test import test_tools
10from test import support
11from test.support import os_helper
12from test.support.script_helper import assert_python_ok
13
# Skip this whole test module when the interpreter was built with the
# PGO profile-use flag (its presence in PY_CFLAGS_NODIST marks a PGO build).
_py_cflags_nodist = sysconfig.get_config_var("PY_CFLAGS_NODIST")
_pgo_flag = sysconfig.get_config_var("PGO_PROF_USE_FLAG")
if _pgo_flag and _py_cflags_nodist and _pgo_flag in _py_cflags_nodist:
    raise unittest.SkipTest("peg_generator test disabled under PGO build")

# pegen lives under Tools/peg_generator; skip if absent and temporarily put
# that directory on sys.path so its packages can be imported here.
test_tools.skip_if_missing("peg_generator")
with test_tools.imports_under_tool("peg_generator"):
    from pegen.grammar_parser import GeneratedParser as GrammarParser
    from pegen.testutil import (
        parse_string,
        generate_parser_c_extension,
        generate_c_parser_source,
    )
    from pegen.ast_dump import ast_dump
28
29
# Script template executed in a subprocess by TestCParser.run_test().
# Two placeholders are filled via str.format():
#   {extension_path} - directory containing the freshly built "parse"
#                      extension module (prepended to sys.path below);
#   {test_source}    - body of Tests.test_parse, pre-indented 8 spaces
#                      by run_test() to sit inside the method.
TEST_TEMPLATE = """
tmp_dir = {extension_path!r}

import ast
import traceback
import sys
import unittest

from test import test_tools
with test_tools.imports_under_tool("peg_generator"):
    from pegen.ast_dump import ast_dump

sys.path.insert(0, tmp_dir)
import parse

class Tests(unittest.TestCase):

    def check_input_strings_for_grammar(
        self,
        valid_cases = (),
        invalid_cases = (),
    ):
        if valid_cases:
            for case in valid_cases:
                parse.parse_string(case, mode=0)

        if invalid_cases:
            for case in invalid_cases:
                with self.assertRaises(SyntaxError):
                    parse.parse_string(case, mode=0)

    def verify_ast_generation(self, stmt):
        expected_ast = ast.parse(stmt)
        actual_ast = parse.parse_string(stmt, mode=1)
        self.assertEqual(ast_dump(expected_ast), ast_dump(actual_ast))

    def test_parse(self):
        {test_source}

unittest.main()
"""
71
72
@support.requires_subprocess()
class TestCParser(unittest.TestCase):
    """Build C parser extensions from small PEG grammars and exercise them.

    Most tests compile a ``parse`` extension module from an inline grammar
    (via pegen) in a per-test temporary directory, then run assertions
    against it in a subprocess using TEST_TEMPLATE and assert_python_ok().
    A few tests only generate the C source in-process and inspect it.
    """

    @classmethod
    def setUpClass(cls):
        """Pick a tempdir base and create a shared static-library dir."""
        # When running under regrtest, a separate tempdir is used
        # as the current directory and watched for left-overs.
        # Reusing that as the base for temporary directories
        # ensures everything is cleaned up properly and
        # cleans up afterwards if not (with warnings).
        cls.tmp_base = os.getcwd()
        if os.path.samefile(cls.tmp_base, os_helper.SAVEDCWD):
            cls.tmp_base = None
        # Create a directory for the reusable static library part of
        # the pegen extension build process.  This greatly reduces the
        # runtime overhead of spawning compiler processes.
        cls.library_dir = tempfile.mkdtemp(dir=cls.tmp_base)
        cls.addClassCleanup(shutil.rmtree, cls.library_dir)

    def setUp(self):
        """Snapshot sysconfig vars and chdir into a fresh temp directory."""
        # Backup: the extension build may mutate sysconfig's config vars.
        self._backup_config_vars = dict(sysconfig._CONFIG_VARS)
        cmd = support.missing_compiler_executable()
        if cmd is not None:
            self.skipTest("The %r command is not found" % cmd)
        self.old_cwd = os.getcwd()
        self.tmp_path = tempfile.mkdtemp(dir=self.tmp_base)
        # change_cwd is a context manager; enterContext registers its exit
        # as a cleanup, so the cwd is also restored via cleanup machinery.
        self.enterContext(os_helper.change_cwd(self.tmp_path))

    def tearDown(self):
        """Leave and remove the temp dir; restore sysconfig's config vars."""
        os.chdir(self.old_cwd)
        shutil.rmtree(self.tmp_path)
        sysconfig._CONFIG_VARS.clear()
        sysconfig._CONFIG_VARS.update(self._backup_config_vars)

    def build_extension(self, grammar_source):
        """Parse *grammar_source* and compile a 'parse' extension in cwd."""
        grammar = parse_string(grammar_source, GrammarParser)
        # Because setUp() already changes the current directory to the
        # temporary path, use a relative path here to prevent excessive
        # path lengths when compiling.
        generate_parser_c_extension(grammar, Path('.'), library_dir=self.library_dir)

    def run_test(self, grammar_source, test_source):
        """Build the extension, then run *test_source* in a subprocess.

        *test_source* is dedented, re-indented to fit inside the
        ``test_parse`` method of TEST_TEMPLATE, and executed with ``-c``;
        assert_python_ok() fails the test on a non-zero exit.
        """
        self.build_extension(grammar_source)
        test_source = textwrap.indent(textwrap.dedent(test_source), 8 * " ")
        assert_python_ok(
            "-c",
            TEST_TEMPLATE.format(extension_path=self.tmp_path, test_source=test_source),
        )

    def test_c_parser(self) -> None:
        """Arithmetic grammar with AST actions matches ast.parse() output."""
        grammar_source = """
        start[mod_ty]: a[asdl_stmt_seq*]=stmt* $ { _PyAST_Module(a, NULL, p->arena) }
        stmt[stmt_ty]: a=expr_stmt { a }
        expr_stmt[stmt_ty]: a=expression NEWLINE { _PyAST_Expr(a, EXTRA) }
        expression[expr_ty]: ( l=expression '+' r=term { _PyAST_BinOp(l, Add, r, EXTRA) }
                            | l=expression '-' r=term { _PyAST_BinOp(l, Sub, r, EXTRA) }
                            | t=term { t }
                            )
        term[expr_ty]: ( l=term '*' r=factor { _PyAST_BinOp(l, Mult, r, EXTRA) }
                    | l=term '/' r=factor { _PyAST_BinOp(l, Div, r, EXTRA) }
                    | f=factor { f }
                    )
        factor[expr_ty]: ('(' e=expression ')' { e }
                        | a=atom { a }
                        )
        atom[expr_ty]: ( n=NAME { n }
                    | n=NUMBER { n }
                    | s=STRING { s }
                    )
        """
        test_source = """
        expressions = [
            "4+5",
            "4-5",
            "4*5",
            "1+4*5",
            "1+4/5",
            "(1+1) + (1+1)",
            "(1+1) - (1+1)",
            "(1+1) * (1+1)",
            "(1+1) / (1+1)",
        ]

        for expr in expressions:
            the_ast = parse.parse_string(expr, mode=1)
            expected_ast = ast.parse(expr)
            self.assertEqual(ast_dump(the_ast), ast_dump(expected_ast))
        """
        self.run_test(grammar_source, test_source)

    def test_lookahead(self) -> None:
        """Positive lookahead (&NAME) accepts a name but rejects a number."""
        grammar_source = """
        start: NAME &NAME expr NEWLINE? ENDMARKER
        expr: NAME | NUMBER
        """
        test_source = """
        valid_cases = ["foo bar"]
        invalid_cases = ["foo 34"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_negative_lookahead(self) -> None:
        """Negative lookahead (!NAME) accepts a number but rejects a name."""
        grammar_source = """
        start: NAME !NAME expr NEWLINE? ENDMARKER
        expr: NAME | NUMBER
        """
        test_source = """
        valid_cases = ["foo 34"]
        invalid_cases = ["foo bar"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_cut(self) -> None:
        """The cut operator (~) commits to the first alternative after X."""
        grammar_source = """
        start: X ~ Y Z | X Q S
        X: 'x'
        Y: 'y'
        Z: 'z'
        Q: 'q'
        S: 's'
        """
        test_source = """
        valid_cases = ["x y z"]
        invalid_cases = ["x q s"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_gather(self) -> None:
        """Gather (';'.rule+) parses separated items without a trailing sep."""
        grammar_source = """
        start: ';'.pass_stmt+ NEWLINE
        pass_stmt: 'pass'
        """
        test_source = """
        valid_cases = ["pass", "pass; pass"]
        invalid_cases = ["pass;", "pass; pass;"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_left_recursion(self) -> None:
        """Directly left-recursive rule parses chained additions."""
        grammar_source = """
        start: expr NEWLINE
        expr: ('-' term | expr '+' term | term)
        term: NUMBER
        """
        test_source = """
        valid_cases = ["-34", "34", "34 + 12", "1 + 1 + 2 + 3"]
        self.check_input_strings_for_grammar(valid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_advanced_left_recursive(self) -> None:
        """Left recursion through a nullable rule (sign) still terminates."""
        grammar_source = """
        start: NUMBER | sign start
        sign: ['-']
        """
        test_source = """
        valid_cases = ["23", "-34"]
        self.check_input_strings_for_grammar(valid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_mutually_left_recursive(self) -> None:
        """Two rules that are left-recursive through each other parse."""
        grammar_source = """
        start: foo 'E'
        foo: bar 'A' | 'B'
        bar: foo 'C' | 'D'
        """
        test_source = """
        valid_cases = ["B E", "D A C A E"]
        self.check_input_strings_for_grammar(valid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_nasty_mutually_left_recursive(self) -> None:
        """Mutual left recursion mixing '+'/'-' suffixes rejects bad input."""
        grammar_source = """
        start: target '='
        target: maybe '+' | NAME
        maybe: maybe '-' | target
        """
        test_source = """
        valid_cases = ["x ="]
        invalid_cases = ["x - + ="]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_return_stmt_noexpr_action(self) -> None:
        """Action building a bare 'return' statement produces the right AST."""
        grammar_source = """
        start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        statements[asdl_stmt_seq*]: a[asdl_stmt_seq*]=statement+ { a }
        statement[stmt_ty]: simple_stmt
        simple_stmt[stmt_ty]: small_stmt
        small_stmt[stmt_ty]: return_stmt
        return_stmt[stmt_ty]: a='return' NEWLINE { _PyAST_Return(NULL, EXTRA) }
        """
        test_source = """
        stmt = "return"
        self.verify_ast_generation(stmt)
        """
        self.run_test(grammar_source, test_source)

    def test_gather_action_ast(self) -> None:
        """Gather combined with an AST-building action yields a Module."""
        grammar_source = """
        start[mod_ty]: a[asdl_stmt_seq*]=';'.pass_stmt+ NEWLINE ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA)}
        """
        test_source = """
        stmt = "pass; pass"
        self.verify_ast_generation(stmt)
        """
        self.run_test(grammar_source, test_source)

    def test_pass_stmt_action(self) -> None:
        """Action building a 'pass' statement produces the right AST."""
        grammar_source = """
        start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        statements[asdl_stmt_seq*]: a[asdl_stmt_seq*]=statement+ { a }
        statement[stmt_ty]: simple_stmt
        simple_stmt[stmt_ty]: small_stmt
        small_stmt[stmt_ty]: pass_stmt
        pass_stmt[stmt_ty]: a='pass' NEWLINE { _PyAST_Pass(EXTRA) }
        """
        test_source = """
        stmt = "pass"
        self.verify_ast_generation(stmt)
        """
        self.run_test(grammar_source, test_source)

    def test_if_stmt_action(self) -> None:
        """Grammar with compound statements (if/block) builds and parses."""
        grammar_source = """
        start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        statements[asdl_stmt_seq*]: a=statement+ { (asdl_stmt_seq*)_PyPegen_seq_flatten(p, a) }
        statement[asdl_stmt_seq*]:  a=compound_stmt { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, a) } | simple_stmt

        simple_stmt[asdl_stmt_seq*]: a=small_stmt b=further_small_stmt* [';'] NEWLINE {
                                            (asdl_stmt_seq*)_PyPegen_seq_insert_in_front(p, a, b) }
        further_small_stmt[stmt_ty]: ';' a=small_stmt { a }

        block: simple_stmt | NEWLINE INDENT a=statements DEDENT { a }

        compound_stmt: if_stmt

        if_stmt: 'if' a=full_expression ':' b=block { _PyAST_If(a, b, NULL, EXTRA) }

        small_stmt[stmt_ty]: pass_stmt

        pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA) }

        full_expression: NAME
        """
        test_source = """
        stmt = "pass"
        self.verify_ast_generation(stmt)
        """
        self.run_test(grammar_source, test_source)

    def test_same_name_different_types(self) -> None:
        """The same variable name may bind different C types per alternative."""
        grammar_source = """
        start[mod_ty]: a[asdl_stmt_seq*]=import_from+ NEWLINE ENDMARKER { _PyAST_Module(a, NULL, p->arena)}
        import_from[stmt_ty]: ( a='from' !'import' c=simple_name 'import' d=import_as_names_from {
                                _PyAST_ImportFrom(c->v.Name.id, d, 0, EXTRA) }
                            | a='from' '.' 'import' c=import_as_names_from {
                                _PyAST_ImportFrom(NULL, c, 1, EXTRA) }
                            )
        simple_name[expr_ty]: NAME
        import_as_names_from[asdl_alias_seq*]: a[asdl_alias_seq*]=','.import_as_name_from+ { a }
        import_as_name_from[alias_ty]: a=NAME 'as' b=NAME { _PyAST_alias(((expr_ty) a)->v.Name.id, ((expr_ty) b)->v.Name.id, EXTRA) }
        """
        test_source = """
        for stmt in ("from a import b as c", "from . import a as b"):
            expected_ast = ast.parse(stmt)
            actual_ast = parse.parse_string(stmt, mode=1)
            self.assertEqual(ast_dump(expected_ast), ast_dump(actual_ast))
        """
        self.run_test(grammar_source, test_source)

    def test_with_stmt_with_paren(self) -> None:
        """Parenthesized multi-item 'with' statement builds the expected AST."""
        grammar_source = """
        start[mod_ty]: a=[statements] ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        statements[asdl_stmt_seq*]: a=statement+ { (asdl_stmt_seq*)_PyPegen_seq_flatten(p, a) }
        statement[asdl_stmt_seq*]: a=compound_stmt { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, a) }
        compound_stmt[stmt_ty]: with_stmt
        with_stmt[stmt_ty]: (
            a='with' '(' b[asdl_withitem_seq*]=','.with_item+ ')' ':' c=block {
                _PyAST_With(b, (asdl_stmt_seq*) _PyPegen_singleton_seq(p, c), NULL, EXTRA) }
        )
        with_item[withitem_ty]: (
            e=NAME o=['as' t=NAME { t }] { _PyAST_withitem(e, _PyPegen_set_expr_context(p, o, Store), p->arena) }
        )
        block[stmt_ty]: a=pass_stmt NEWLINE { a } | NEWLINE INDENT a=pass_stmt DEDENT { a }
        pass_stmt[stmt_ty]: a='pass' { _PyAST_Pass(EXTRA) }
        """
        test_source = """
        stmt = "with (\\n    a as b,\\n    c as d\\n): pass"
        the_ast = parse.parse_string(stmt, mode=1)
        self.assertTrue(ast_dump(the_ast).startswith(
            "Module(body=[With(items=[withitem(context_expr=Name(id='a', ctx=Load()), optional_vars=Name(id='b', ctx=Store())), "
            "withitem(context_expr=Name(id='c', ctx=Load()), optional_vars=Name(id='d', ctx=Store()))]"
        ))
        """
        self.run_test(grammar_source, test_source)

    def test_ternary_operator(self) -> None:
        """List-comprehension grammar with nested groups/options parses."""
        grammar_source = """
        start[mod_ty]: a=expr ENDMARKER { _PyAST_Module(a, NULL, p->arena) }
        expr[asdl_stmt_seq*]: a=listcomp NEWLINE { (asdl_stmt_seq*)_PyPegen_singleton_seq(p, _PyAST_Expr(a, EXTRA)) }
        listcomp[expr_ty]: (
            a='[' b=NAME c=for_if_clauses d=']' { _PyAST_ListComp(b, c, EXTRA) }
        )
        for_if_clauses[asdl_comprehension_seq*]: (
            a[asdl_comprehension_seq*]=(y=[ASYNC] 'for' a=NAME 'in' b=NAME c[asdl_expr_seq*]=('if' z=NAME { z })*
                { _PyAST_comprehension(_PyAST_Name(((expr_ty) a)->v.Name.id, Store, EXTRA), b, c, (y == NULL) ? 0 : 1, p->arena) })+ { a }
        )
        """
        test_source = """
        stmt = "[i for i in a if b]"
        self.verify_ast_generation(stmt)
        """
        self.run_test(grammar_source, test_source)

    def test_syntax_error_for_string(self) -> None:
        """SyntaxError tracebacks report <string> and 'invalid syntax'."""
        grammar_source = """
        start: expr+ NEWLINE? ENDMARKER
        expr: NAME
        """
        test_source = r"""
        for text in ("a b 42 b a", "\u540d \u540d 42 \u540d \u540d"):
            try:
                parse.parse_string(text, mode=0)
            except SyntaxError as e:
                tb = traceback.format_exc()
            self.assertTrue('File "<string>", line 1' in tb)
            self.assertTrue(f"SyntaxError: invalid syntax" in tb)
        """
        self.run_test(grammar_source, test_source)

    def test_headers_and_trailer(self) -> None:
        """@header/@subheader/@trailer directives land in the C source."""
        grammar_source = """
        @header 'SOME HEADER'
        @subheader 'SOME SUBHEADER'
        @trailer 'SOME TRAILER'
        start: expr+ NEWLINE? ENDMARKER
        expr: x=NAME
        """
        grammar = parse_string(grammar_source, GrammarParser)
        parser_source = generate_c_parser_source(grammar)

        self.assertTrue("SOME HEADER" in parser_source)
        self.assertTrue("SOME SUBHEADER" in parser_source)
        self.assertTrue("SOME TRAILER" in parser_source)

    def test_error_in_rules(self) -> None:
        """An action that raises propagates out of the generated parser."""
        grammar_source = """
        start: expr+ NEWLINE? ENDMARKER
        expr: NAME {PyTuple_New(-1)}
        """
        # PyTuple_New raises SystemError if an invalid argument was passed.
        test_source = """
        with self.assertRaises(SystemError):
            parse.parse_string("a", mode=0)
        """
        self.run_test(grammar_source, test_source)

    def test_no_soft_keywords(self) -> None:
        """Single-quoted keywords do not emit soft-keyword checks."""
        grammar_source = """
        start: expr+ NEWLINE? ENDMARKER
        expr: 'foo'
        """
        grammar = parse_string(grammar_source, GrammarParser)
        parser_source = generate_c_parser_source(grammar)
        assert "expect_soft_keyword" not in parser_source

    def test_soft_keywords(self) -> None:
        """Double-quoted keywords emit soft-keyword checks in the C source."""
        grammar_source = """
        start: expr+ NEWLINE? ENDMARKER
        expr: "foo"
        """
        grammar = parse_string(grammar_source, GrammarParser)
        parser_source = generate_c_parser_source(grammar)
        assert "expect_soft_keyword" in parser_source

    def test_soft_keywords_parse(self) -> None:
        """A soft keyword ("if") may also appear as an ordinary NAME."""
        grammar_source = """
        start: "if" expr '+' expr NEWLINE
        expr: NAME
        """
        test_source = """
        valid_cases = ["if if + if"]
        invalid_cases = ["if if"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_soft_keywords_lookahead(self) -> None:
        """Lookahead on a soft keyword (&"if") works like other lookaheads."""
        grammar_source = """
        start: &"if" "if" expr '+' expr NEWLINE
        expr: NAME
        """
        test_source = """
        valid_cases = ["if if + if"]
        invalid_cases = ["if if"]
        self.check_input_strings_for_grammar(valid_cases, invalid_cases)
        """
        self.run_test(grammar_source, test_source)

    def test_forced(self) -> None:
        """Forced token (&&':') raises a SyntaxError naming the token."""
        grammar_source = """
        start: NAME &&':' | NAME
        """
        test_source = """
        self.assertEqual(parse.parse_string("number :", mode=0), None)
        with self.assertRaises(SyntaxError) as e:
            parse.parse_string("a", mode=0)
        self.assertIn("expected ':'", str(e.exception))
        """
        self.run_test(grammar_source, test_source)

    def test_forced_with_group(self) -> None:
        """Forced group (&&(':' | ';')) names the whole group in the error."""
        grammar_source = """
        start: NAME &&(':' | ';') | NAME
        """
        test_source = """
        self.assertEqual(parse.parse_string("number :", mode=0), None)
        self.assertEqual(parse.parse_string("number ;", mode=0), None)
        with self.assertRaises(SyntaxError) as e:
            parse.parse_string("a", mode=0)
        self.assertIn("expected (':' | ';')", e.exception.args[0])
        """
        self.run_test(grammar_source, test_source)
505