xref: /third_party/python/Python/ast.c (revision 7db96d56)
/*
 * This file exposes the PyAST_Validate interface, which checks the integrity
 * of a given abstract syntax tree (including one that was constructed manually).
 */
5#include "Python.h"
6#include "pycore_ast.h"           // asdl_stmt_seq
7#include "pycore_pystate.h"       // _PyThreadState_GET()
8
9#include <assert.h>
10#include <stdbool.h>
11
/* Bookkeeping threaded through every validate_* call in this file. */
struct validator {
    int recursion_depth;            /* nesting level reached so far */
    int recursion_limit;            /* depth beyond which validation fails with RecursionError */
};
16
17static int validate_stmts(struct validator *, asdl_stmt_seq *);
18static int validate_exprs(struct validator *, asdl_expr_seq *, expr_context_ty, int);
19static int validate_patterns(struct validator *, asdl_pattern_seq *, int);
20static int _validate_nonempty_seq(asdl_seq *, const char *, const char *);
21static int validate_stmt(struct validator *, stmt_ty);
22static int validate_expr(struct validator *, expr_ty, expr_context_ty);
23static int validate_pattern(struct validator *, pattern_ty, int);
24
/*
 * Check that the source span recorded on `node` is well formed.
 *
 * Rejects:
 *   - a start line after the end line;
 *   - a negative (unknown) start line or start column whose end value
 *     differs from it;
 *   - on a single-line span, a start column after the end column.
 *
 * Must be used inside a function returning int: on an invalid span it sets
 * ValueError and executes `return 0` from the enclosing function.  Wrapped in
 * do { } while (0) so that `VALIDATE_POSITIONS(x);` is a single statement,
 * which keeps it safe inside unbraced if/else bodies.
 */
#define VALIDATE_POSITIONS(node) \
    do { \
        if ((node)->lineno > (node)->end_lineno) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node line range (%d, %d) is not valid", \
                         (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if (((node)->lineno < 0 && (node)->end_lineno != (node)->lineno) || \
            ((node)->col_offset < 0 && (node)->col_offset != (node)->end_col_offset)) { \
            PyErr_Format(PyExc_ValueError, \
                         "AST node column range (%d, %d) for line range (%d, %d) is not valid", \
                         (node)->col_offset, (node)->end_col_offset, \
                         (node)->lineno, (node)->end_lineno); \
            return 0; \
        } \
        if ((node)->lineno == (node)->end_lineno && \
            (node)->col_offset > (node)->end_col_offset) { \
            PyErr_Format(PyExc_ValueError, \
                         "line %d, column %d-%d is not a valid range", \
                         (node)->lineno, (node)->col_offset, (node)->end_col_offset); \
            return 0; \
        } \
    } while (0)
45
46static int
47validate_name(PyObject *name)
48{
49    assert(!PyErr_Occurred());
50    assert(PyUnicode_Check(name));
51    static const char * const forbidden[] = {
52        "None",
53        "True",
54        "False",
55        NULL
56    };
57    for (int i = 0; forbidden[i] != NULL; i++) {
58        if (_PyUnicode_EqualToASCIIString(name, forbidden[i])) {
59            PyErr_Format(PyExc_ValueError, "identifier field can't represent '%s' constant", forbidden[i]);
60            return 0;
61        }
62    }
63    return 1;
64}
65
66static int
67validate_comprehension(struct validator *state, asdl_comprehension_seq *gens)
68{
69    assert(!PyErr_Occurred());
70    if (!asdl_seq_LEN(gens)) {
71        PyErr_SetString(PyExc_ValueError, "comprehension with no generators");
72        return 0;
73    }
74    for (Py_ssize_t i = 0; i < asdl_seq_LEN(gens); i++) {
75        comprehension_ty comp = asdl_seq_GET(gens, i);
76        if (!validate_expr(state, comp->target, Store) ||
77            !validate_expr(state, comp->iter, Load) ||
78            !validate_exprs(state, comp->ifs, Load, 0))
79            return 0;
80    }
81    return 1;
82}
83
84static int
85validate_keywords(struct validator *state, asdl_keyword_seq *keywords)
86{
87    assert(!PyErr_Occurred());
88    for (Py_ssize_t i = 0; i < asdl_seq_LEN(keywords); i++)
89        if (!validate_expr(state, (asdl_seq_GET(keywords, i))->value, Load))
90            return 0;
91    return 1;
92}
93
94static int
95validate_args(struct validator *state, asdl_arg_seq *args)
96{
97    assert(!PyErr_Occurred());
98    for (Py_ssize_t i = 0; i < asdl_seq_LEN(args); i++) {
99        arg_ty arg = asdl_seq_GET(args, i);
100        VALIDATE_POSITIONS(arg);
101        if (arg->annotation && !validate_expr(state, arg->annotation, Load))
102            return 0;
103    }
104    return 1;
105}
106
107static const char *
108expr_context_name(expr_context_ty ctx)
109{
110    switch (ctx) {
111    case Load:
112        return "Load";
113    case Store:
114        return "Store";
115    case Del:
116        return "Del";
117    // No default case so compiler emits warning for unhandled cases
118    }
119    Py_UNREACHABLE();
120}
121
122static int
123validate_arguments(struct validator *state, arguments_ty args)
124{
125    assert(!PyErr_Occurred());
126    if (!validate_args(state, args->posonlyargs) || !validate_args(state, args->args)) {
127        return 0;
128    }
129    if (args->vararg && args->vararg->annotation
130        && !validate_expr(state, args->vararg->annotation, Load)) {
131            return 0;
132    }
133    if (!validate_args(state, args->kwonlyargs))
134        return 0;
135    if (args->kwarg && args->kwarg->annotation
136        && !validate_expr(state, args->kwarg->annotation, Load)) {
137            return 0;
138    }
139    if (asdl_seq_LEN(args->defaults) > asdl_seq_LEN(args->posonlyargs) + asdl_seq_LEN(args->args)) {
140        PyErr_SetString(PyExc_ValueError, "more positional defaults than args on arguments");
141        return 0;
142    }
143    if (asdl_seq_LEN(args->kw_defaults) != asdl_seq_LEN(args->kwonlyargs)) {
144        PyErr_SetString(PyExc_ValueError, "length of kwonlyargs is not the same as "
145                        "kw_defaults on arguments");
146        return 0;
147    }
148    return validate_exprs(state, args->defaults, Load, 0) && validate_exprs(state, args->kw_defaults, Load, 1);
149}
150
151static int
152validate_constant(struct validator *state, PyObject *value)
153{
154    assert(!PyErr_Occurred());
155    if (value == Py_None || value == Py_Ellipsis)
156        return 1;
157
158    if (PyLong_CheckExact(value)
159            || PyFloat_CheckExact(value)
160            || PyComplex_CheckExact(value)
161            || PyBool_Check(value)
162            || PyUnicode_CheckExact(value)
163            || PyBytes_CheckExact(value))
164        return 1;
165
166    if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) {
167        if (++state->recursion_depth > state->recursion_limit) {
168            PyErr_SetString(PyExc_RecursionError,
169                            "maximum recursion depth exceeded during compilation");
170            return 0;
171        }
172
173        PyObject *it = PyObject_GetIter(value);
174        if (it == NULL)
175            return 0;
176
177        while (1) {
178            PyObject *item = PyIter_Next(it);
179            if (item == NULL) {
180                if (PyErr_Occurred()) {
181                    Py_DECREF(it);
182                    return 0;
183                }
184                break;
185            }
186
187            if (!validate_constant(state, item)) {
188                Py_DECREF(it);
189                Py_DECREF(item);
190                return 0;
191            }
192            Py_DECREF(item);
193        }
194
195        Py_DECREF(it);
196        --state->recursion_depth;
197        return 1;
198    }
199
200    if (!PyErr_Occurred()) {
201        PyErr_Format(PyExc_TypeError,
202                     "got an invalid type in Constant: %s",
203                     _PyType_Name(Py_TYPE(value)));
204    }
205    return 0;
206}
207
208static int
209validate_expr(struct validator *state, expr_ty exp, expr_context_ty ctx)
210{
211    assert(!PyErr_Occurred());
212    VALIDATE_POSITIONS(exp);
213    int ret = -1;
214    if (++state->recursion_depth > state->recursion_limit) {
215        PyErr_SetString(PyExc_RecursionError,
216                        "maximum recursion depth exceeded during compilation");
217        return 0;
218    }
219    int check_ctx = 1;
220    expr_context_ty actual_ctx;
221
222    /* First check expression context. */
223    switch (exp->kind) {
224    case Attribute_kind:
225        actual_ctx = exp->v.Attribute.ctx;
226        break;
227    case Subscript_kind:
228        actual_ctx = exp->v.Subscript.ctx;
229        break;
230    case Starred_kind:
231        actual_ctx = exp->v.Starred.ctx;
232        break;
233    case Name_kind:
234        if (!validate_name(exp->v.Name.id)) {
235            return 0;
236        }
237        actual_ctx = exp->v.Name.ctx;
238        break;
239    case List_kind:
240        actual_ctx = exp->v.List.ctx;
241        break;
242    case Tuple_kind:
243        actual_ctx = exp->v.Tuple.ctx;
244        break;
245    default:
246        if (ctx != Load) {
247            PyErr_Format(PyExc_ValueError, "expression which can't be "
248                         "assigned to in %s context", expr_context_name(ctx));
249            return 0;
250        }
251        check_ctx = 0;
252        /* set actual_ctx to prevent gcc warning */
253        actual_ctx = 0;
254    }
255    if (check_ctx && actual_ctx != ctx) {
256        PyErr_Format(PyExc_ValueError, "expression must have %s context but has %s instead",
257                     expr_context_name(ctx), expr_context_name(actual_ctx));
258        return 0;
259    }
260
261    /* Now validate expression. */
262    switch (exp->kind) {
263    case BoolOp_kind:
264        if (asdl_seq_LEN(exp->v.BoolOp.values) < 2) {
265            PyErr_SetString(PyExc_ValueError, "BoolOp with less than 2 values");
266            return 0;
267        }
268        ret = validate_exprs(state, exp->v.BoolOp.values, Load, 0);
269        break;
270    case BinOp_kind:
271        ret = validate_expr(state, exp->v.BinOp.left, Load) &&
272            validate_expr(state, exp->v.BinOp.right, Load);
273        break;
274    case UnaryOp_kind:
275        ret = validate_expr(state, exp->v.UnaryOp.operand, Load);
276        break;
277    case Lambda_kind:
278        ret = validate_arguments(state, exp->v.Lambda.args) &&
279            validate_expr(state, exp->v.Lambda.body, Load);
280        break;
281    case IfExp_kind:
282        ret = validate_expr(state, exp->v.IfExp.test, Load) &&
283            validate_expr(state, exp->v.IfExp.body, Load) &&
284            validate_expr(state, exp->v.IfExp.orelse, Load);
285        break;
286    case Dict_kind:
287        if (asdl_seq_LEN(exp->v.Dict.keys) != asdl_seq_LEN(exp->v.Dict.values)) {
288            PyErr_SetString(PyExc_ValueError,
289                            "Dict doesn't have the same number of keys as values");
290            return 0;
291        }
292        /* null_ok=1 for keys expressions to allow dict unpacking to work in
293           dict literals, i.e. ``{**{a:b}}`` */
294        ret = validate_exprs(state, exp->v.Dict.keys, Load, /*null_ok=*/ 1) &&
295            validate_exprs(state, exp->v.Dict.values, Load, /*null_ok=*/ 0);
296        break;
297    case Set_kind:
298        ret = validate_exprs(state, exp->v.Set.elts, Load, 0);
299        break;
300#define COMP(NAME) \
301        case NAME ## _kind: \
302            ret = validate_comprehension(state, exp->v.NAME.generators) && \
303                validate_expr(state, exp->v.NAME.elt, Load); \
304            break;
305    COMP(ListComp)
306    COMP(SetComp)
307    COMP(GeneratorExp)
308#undef COMP
309    case DictComp_kind:
310        ret = validate_comprehension(state, exp->v.DictComp.generators) &&
311            validate_expr(state, exp->v.DictComp.key, Load) &&
312            validate_expr(state, exp->v.DictComp.value, Load);
313        break;
314    case Yield_kind:
315        ret = !exp->v.Yield.value || validate_expr(state, exp->v.Yield.value, Load);
316        break;
317    case YieldFrom_kind:
318        ret = validate_expr(state, exp->v.YieldFrom.value, Load);
319        break;
320    case Await_kind:
321        ret = validate_expr(state, exp->v.Await.value, Load);
322        break;
323    case Compare_kind:
324        if (!asdl_seq_LEN(exp->v.Compare.comparators)) {
325            PyErr_SetString(PyExc_ValueError, "Compare with no comparators");
326            return 0;
327        }
328        if (asdl_seq_LEN(exp->v.Compare.comparators) !=
329            asdl_seq_LEN(exp->v.Compare.ops)) {
330            PyErr_SetString(PyExc_ValueError, "Compare has a different number "
331                            "of comparators and operands");
332            return 0;
333        }
334        ret = validate_exprs(state, exp->v.Compare.comparators, Load, 0) &&
335            validate_expr(state, exp->v.Compare.left, Load);
336        break;
337    case Call_kind:
338        ret = validate_expr(state, exp->v.Call.func, Load) &&
339            validate_exprs(state, exp->v.Call.args, Load, 0) &&
340            validate_keywords(state, exp->v.Call.keywords);
341        break;
342    case Constant_kind:
343        if (!validate_constant(state, exp->v.Constant.value)) {
344            return 0;
345        }
346        ret = 1;
347        break;
348    case JoinedStr_kind:
349        ret = validate_exprs(state, exp->v.JoinedStr.values, Load, 0);
350        break;
351    case FormattedValue_kind:
352        if (validate_expr(state, exp->v.FormattedValue.value, Load) == 0)
353            return 0;
354        if (exp->v.FormattedValue.format_spec) {
355            ret = validate_expr(state, exp->v.FormattedValue.format_spec, Load);
356            break;
357        }
358        ret = 1;
359        break;
360    case Attribute_kind:
361        ret = validate_expr(state, exp->v.Attribute.value, Load);
362        break;
363    case Subscript_kind:
364        ret = validate_expr(state, exp->v.Subscript.slice, Load) &&
365            validate_expr(state, exp->v.Subscript.value, Load);
366        break;
367    case Starred_kind:
368        ret = validate_expr(state, exp->v.Starred.value, ctx);
369        break;
370    case Slice_kind:
371        ret = (!exp->v.Slice.lower || validate_expr(state, exp->v.Slice.lower, Load)) &&
372            (!exp->v.Slice.upper || validate_expr(state, exp->v.Slice.upper, Load)) &&
373            (!exp->v.Slice.step || validate_expr(state, exp->v.Slice.step, Load));
374        break;
375    case List_kind:
376        ret = validate_exprs(state, exp->v.List.elts, ctx, 0);
377        break;
378    case Tuple_kind:
379        ret = validate_exprs(state, exp->v.Tuple.elts, ctx, 0);
380        break;
381    case NamedExpr_kind:
382        ret = validate_expr(state, exp->v.NamedExpr.value, Load);
383        break;
384    /* This last case doesn't have any checking. */
385    case Name_kind:
386        ret = 1;
387        break;
388    // No default case so compiler emits warning for unhandled cases
389    }
390    if (ret < 0) {
391        PyErr_SetString(PyExc_SystemError, "unexpected expression");
392        ret = 0;
393    }
394    state->recursion_depth--;
395    return ret;
396}
397
398
399// Note: the ensure_literal_* functions are only used to validate a restricted
400//       set of non-recursive literals that have already been checked with
401//       validate_expr, so they don't accept the validator state
402static int
403ensure_literal_number(expr_ty exp, bool allow_real, bool allow_imaginary)
404{
405    assert(exp->kind == Constant_kind);
406    PyObject *value = exp->v.Constant.value;
407    return (allow_real && PyFloat_CheckExact(value)) ||
408           (allow_real && PyLong_CheckExact(value)) ||
409           (allow_imaginary && PyComplex_CheckExact(value));
410}
411
412static int
413ensure_literal_negative(expr_ty exp, bool allow_real, bool allow_imaginary)
414{
415    assert(exp->kind == UnaryOp_kind);
416    // Must be negation ...
417    if (exp->v.UnaryOp.op != USub) {
418        return 0;
419    }
420    // ... of a constant ...
421    expr_ty operand = exp->v.UnaryOp.operand;
422    if (operand->kind != Constant_kind) {
423        return 0;
424    }
425    // ... number
426    return ensure_literal_number(operand, allow_real, allow_imaginary);
427}
428
429static int
430ensure_literal_complex(expr_ty exp)
431{
432    assert(exp->kind == BinOp_kind);
433    expr_ty left = exp->v.BinOp.left;
434    expr_ty right = exp->v.BinOp.right;
435    // Ensure op is addition or subtraction
436    if (exp->v.BinOp.op != Add && exp->v.BinOp.op != Sub) {
437        return 0;
438    }
439    // Check LHS is a real number (potentially signed)
440    switch (left->kind)
441    {
442        case Constant_kind:
443            if (!ensure_literal_number(left, /*real=*/true, /*imaginary=*/false)) {
444                return 0;
445            }
446            break;
447        case UnaryOp_kind:
448            if (!ensure_literal_negative(left, /*real=*/true, /*imaginary=*/false)) {
449                return 0;
450            }
451            break;
452        default:
453            return 0;
454    }
455    // Check RHS is an imaginary number (no separate sign allowed)
456    switch (right->kind)
457    {
458        case Constant_kind:
459            if (!ensure_literal_number(right, /*real=*/false, /*imaginary=*/true)) {
460                return 0;
461            }
462            break;
463        default:
464            return 0;
465    }
466    return 1;
467}
468
469static int
470validate_pattern_match_value(struct validator *state, expr_ty exp)
471{
472    assert(!PyErr_Occurred());
473    if (!validate_expr(state, exp, Load)) {
474        return 0;
475    }
476
477    switch (exp->kind)
478    {
479        case Constant_kind:
480            /* Ellipsis and immutable sequences are not allowed.
481               For True, False and None, MatchSingleton() should
482               be used */
483            if (!validate_expr(state, exp, Load)) {
484                return 0;
485            }
486            PyObject *literal = exp->v.Constant.value;
487            if (PyLong_CheckExact(literal) || PyFloat_CheckExact(literal) ||
488                PyBytes_CheckExact(literal) || PyComplex_CheckExact(literal) ||
489                PyUnicode_CheckExact(literal)) {
490                return 1;
491            }
492            PyErr_SetString(PyExc_ValueError,
493                            "unexpected constant inside of a literal pattern");
494            return 0;
495        case Attribute_kind:
496            // Constants and attribute lookups are always permitted
497            return 1;
498        case UnaryOp_kind:
499            // Negated numbers are permitted (whether real or imaginary)
500            // Compiler will complain if AST folding doesn't create a constant
501            if (ensure_literal_negative(exp, /*real=*/true, /*imaginary=*/true)) {
502                return 1;
503            }
504            break;
505        case BinOp_kind:
506            // Complex literals are permitted
507            // Compiler will complain if AST folding doesn't create a constant
508            if (ensure_literal_complex(exp)) {
509                return 1;
510            }
511            break;
512        case JoinedStr_kind:
513            // Handled in the later stages
514            return 1;
515        default:
516            break;
517    }
518    PyErr_SetString(PyExc_ValueError,
519                    "patterns may only match literals and attribute lookups");
520    return 0;
521}
522
523static int
524validate_capture(PyObject *name)
525{
526    assert(!PyErr_Occurred());
527    if (_PyUnicode_EqualToASCIIString(name, "_")) {
528        PyErr_Format(PyExc_ValueError, "can't capture name '_' in patterns");
529        return 0;
530    }
531    return validate_name(name);
532}
533
534static int
535validate_pattern(struct validator *state, pattern_ty p, int star_ok)
536{
537    assert(!PyErr_Occurred());
538    VALIDATE_POSITIONS(p);
539    int ret = -1;
540    if (++state->recursion_depth > state->recursion_limit) {
541        PyErr_SetString(PyExc_RecursionError,
542                        "maximum recursion depth exceeded during compilation");
543        return 0;
544    }
545    switch (p->kind) {
546        case MatchValue_kind:
547            ret = validate_pattern_match_value(state, p->v.MatchValue.value);
548            break;
549        case MatchSingleton_kind:
550            ret = p->v.MatchSingleton.value == Py_None || PyBool_Check(p->v.MatchSingleton.value);
551            if (!ret) {
552                PyErr_SetString(PyExc_ValueError,
553                                "MatchSingleton can only contain True, False and None");
554            }
555            break;
556        case MatchSequence_kind:
557            ret = validate_patterns(state, p->v.MatchSequence.patterns, /*star_ok=*/1);
558            break;
559        case MatchMapping_kind:
560            if (asdl_seq_LEN(p->v.MatchMapping.keys) != asdl_seq_LEN(p->v.MatchMapping.patterns)) {
561                PyErr_SetString(PyExc_ValueError,
562                                "MatchMapping doesn't have the same number of keys as patterns");
563                ret = 0;
564                break;
565            }
566
567            if (p->v.MatchMapping.rest && !validate_capture(p->v.MatchMapping.rest)) {
568                ret = 0;
569                break;
570            }
571
572            asdl_expr_seq *keys = p->v.MatchMapping.keys;
573            for (Py_ssize_t i = 0; i < asdl_seq_LEN(keys); i++) {
574                expr_ty key = asdl_seq_GET(keys, i);
575                if (key->kind == Constant_kind) {
576                    PyObject *literal = key->v.Constant.value;
577                    if (literal == Py_None || PyBool_Check(literal)) {
578                        /* validate_pattern_match_value will ensure the key
579                           doesn't contain True, False and None but it is
580                           syntactically valid, so we will pass those on in
581                           a special case. */
582                        continue;
583                    }
584                }
585                if (!validate_pattern_match_value(state, key)) {
586                    ret = 0;
587                    break;
588                }
589            }
590            if (ret == 0) {
591                break;
592            }
593            ret = validate_patterns(state, p->v.MatchMapping.patterns, /*star_ok=*/0);
594            break;
595        case MatchClass_kind:
596            if (asdl_seq_LEN(p->v.MatchClass.kwd_attrs) != asdl_seq_LEN(p->v.MatchClass.kwd_patterns)) {
597                PyErr_SetString(PyExc_ValueError,
598                                "MatchClass doesn't have the same number of keyword attributes as patterns");
599                ret = 0;
600                break;
601            }
602            if (!validate_expr(state, p->v.MatchClass.cls, Load)) {
603                ret = 0;
604                break;
605            }
606
607            expr_ty cls = p->v.MatchClass.cls;
608            while (1) {
609                if (cls->kind == Name_kind) {
610                    break;
611                }
612                else if (cls->kind == Attribute_kind) {
613                    cls = cls->v.Attribute.value;
614                    continue;
615                }
616                else {
617                    PyErr_SetString(PyExc_ValueError,
618                                    "MatchClass cls field can only contain Name or Attribute nodes.");
619                    ret = 0;
620                    break;
621                }
622            }
623            if (ret == 0) {
624                break;
625            }
626
627            for (Py_ssize_t i = 0; i < asdl_seq_LEN(p->v.MatchClass.kwd_attrs); i++) {
628                PyObject *identifier = asdl_seq_GET(p->v.MatchClass.kwd_attrs, i);
629                if (!validate_name(identifier)) {
630                    ret = 0;
631                    break;
632                }
633            }
634            if (ret == 0) {
635                break;
636            }
637
638            if (!validate_patterns(state, p->v.MatchClass.patterns, /*star_ok=*/0)) {
639                ret = 0;
640                break;
641            }
642
643            ret = validate_patterns(state, p->v.MatchClass.kwd_patterns, /*star_ok=*/0);
644            break;
645        case MatchStar_kind:
646            if (!star_ok) {
647                PyErr_SetString(PyExc_ValueError, "can't use MatchStar here");
648                ret = 0;
649                break;
650            }
651            ret = p->v.MatchStar.name == NULL || validate_capture(p->v.MatchStar.name);
652            break;
653        case MatchAs_kind:
654            if (p->v.MatchAs.name && !validate_capture(p->v.MatchAs.name)) {
655                ret = 0;
656                break;
657            }
658            if (p->v.MatchAs.pattern == NULL) {
659                ret = 1;
660            }
661            else if (p->v.MatchAs.name == NULL) {
662                PyErr_SetString(PyExc_ValueError,
663                                "MatchAs must specify a target name if a pattern is given");
664                ret = 0;
665            }
666            else {
667                ret = validate_pattern(state, p->v.MatchAs.pattern, /*star_ok=*/0);
668            }
669            break;
670        case MatchOr_kind:
671            if (asdl_seq_LEN(p->v.MatchOr.patterns) < 2) {
672                PyErr_SetString(PyExc_ValueError,
673                                "MatchOr requires at least 2 patterns");
674                ret = 0;
675                break;
676            }
677            ret = validate_patterns(state, p->v.MatchOr.patterns, /*star_ok=*/0);
678            break;
679    // No default case, so the compiler will emit a warning if new pattern
680    // kinds are added without being handled here
681    }
682    if (ret < 0) {
683        PyErr_SetString(PyExc_SystemError, "unexpected pattern");
684        ret = 0;
685    }
686    state->recursion_depth--;
687    return ret;
688}
689
690static int
691_validate_nonempty_seq(asdl_seq *seq, const char *what, const char *owner)
692{
693    if (asdl_seq_LEN(seq))
694        return 1;
695    PyErr_Format(PyExc_ValueError, "empty %s on %s", what, owner);
696    return 0;
697}
698#define validate_nonempty_seq(seq, what, owner) _validate_nonempty_seq((asdl_seq*)seq, what, owner)
699
700static int
701validate_assignlist(struct validator *state, asdl_expr_seq *targets, expr_context_ty ctx)
702{
703    assert(!PyErr_Occurred());
704    return validate_nonempty_seq(targets, "targets", ctx == Del ? "Delete" : "Assign") &&
705        validate_exprs(state, targets, ctx, 0);
706}
707
708static int
709validate_body(struct validator *state, asdl_stmt_seq *body, const char *owner)
710{
711    assert(!PyErr_Occurred());
712    return validate_nonempty_seq(body, "body", owner) && validate_stmts(state, body);
713}
714
715static int
716validate_stmt(struct validator *state, stmt_ty stmt)
717{
718    assert(!PyErr_Occurred());
719    VALIDATE_POSITIONS(stmt);
720    int ret = -1;
721    if (++state->recursion_depth > state->recursion_limit) {
722        PyErr_SetString(PyExc_RecursionError,
723                        "maximum recursion depth exceeded during compilation");
724        return 0;
725    }
726    switch (stmt->kind) {
727    case FunctionDef_kind:
728        ret = validate_body(state, stmt->v.FunctionDef.body, "FunctionDef") &&
729            validate_arguments(state, stmt->v.FunctionDef.args) &&
730            validate_exprs(state, stmt->v.FunctionDef.decorator_list, Load, 0) &&
731            (!stmt->v.FunctionDef.returns ||
732             validate_expr(state, stmt->v.FunctionDef.returns, Load));
733        break;
734    case ClassDef_kind:
735        ret = validate_body(state, stmt->v.ClassDef.body, "ClassDef") &&
736            validate_exprs(state, stmt->v.ClassDef.bases, Load, 0) &&
737            validate_keywords(state, stmt->v.ClassDef.keywords) &&
738            validate_exprs(state, stmt->v.ClassDef.decorator_list, Load, 0);
739        break;
740    case Return_kind:
741        ret = !stmt->v.Return.value || validate_expr(state, stmt->v.Return.value, Load);
742        break;
743    case Delete_kind:
744        ret = validate_assignlist(state, stmt->v.Delete.targets, Del);
745        break;
746    case Assign_kind:
747        ret = validate_assignlist(state, stmt->v.Assign.targets, Store) &&
748            validate_expr(state, stmt->v.Assign.value, Load);
749        break;
750    case AugAssign_kind:
751        ret = validate_expr(state, stmt->v.AugAssign.target, Store) &&
752            validate_expr(state, stmt->v.AugAssign.value, Load);
753        break;
754    case AnnAssign_kind:
755        if (stmt->v.AnnAssign.target->kind != Name_kind &&
756            stmt->v.AnnAssign.simple) {
757            PyErr_SetString(PyExc_TypeError,
758                            "AnnAssign with simple non-Name target");
759            return 0;
760        }
761        ret = validate_expr(state, stmt->v.AnnAssign.target, Store) &&
762               (!stmt->v.AnnAssign.value ||
763                validate_expr(state, stmt->v.AnnAssign.value, Load)) &&
764               validate_expr(state, stmt->v.AnnAssign.annotation, Load);
765        break;
766    case For_kind:
767        ret = validate_expr(state, stmt->v.For.target, Store) &&
768            validate_expr(state, stmt->v.For.iter, Load) &&
769            validate_body(state, stmt->v.For.body, "For") &&
770            validate_stmts(state, stmt->v.For.orelse);
771        break;
772    case AsyncFor_kind:
773        ret = validate_expr(state, stmt->v.AsyncFor.target, Store) &&
774            validate_expr(state, stmt->v.AsyncFor.iter, Load) &&
775            validate_body(state, stmt->v.AsyncFor.body, "AsyncFor") &&
776            validate_stmts(state, stmt->v.AsyncFor.orelse);
777        break;
778    case While_kind:
779        ret = validate_expr(state, stmt->v.While.test, Load) &&
780            validate_body(state, stmt->v.While.body, "While") &&
781            validate_stmts(state, stmt->v.While.orelse);
782        break;
783    case If_kind:
784        ret = validate_expr(state, stmt->v.If.test, Load) &&
785            validate_body(state, stmt->v.If.body, "If") &&
786            validate_stmts(state, stmt->v.If.orelse);
787        break;
788    case With_kind:
789        if (!validate_nonempty_seq(stmt->v.With.items, "items", "With"))
790            return 0;
791        for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.With.items); i++) {
792            withitem_ty item = asdl_seq_GET(stmt->v.With.items, i);
793            if (!validate_expr(state, item->context_expr, Load) ||
794                (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
795                return 0;
796        }
797        ret = validate_body(state, stmt->v.With.body, "With");
798        break;
799    case AsyncWith_kind:
800        if (!validate_nonempty_seq(stmt->v.AsyncWith.items, "items", "AsyncWith"))
801            return 0;
802        for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.AsyncWith.items); i++) {
803            withitem_ty item = asdl_seq_GET(stmt->v.AsyncWith.items, i);
804            if (!validate_expr(state, item->context_expr, Load) ||
805                (item->optional_vars && !validate_expr(state, item->optional_vars, Store)))
806                return 0;
807        }
808        ret = validate_body(state, stmt->v.AsyncWith.body, "AsyncWith");
809        break;
810    case Match_kind:
811        if (!validate_expr(state, stmt->v.Match.subject, Load)
812            || !validate_nonempty_seq(stmt->v.Match.cases, "cases", "Match")) {
813            return 0;
814        }
815        for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Match.cases); i++) {
816            match_case_ty m = asdl_seq_GET(stmt->v.Match.cases, i);
817            if (!validate_pattern(state, m->pattern, /*star_ok=*/0)
818                || (m->guard && !validate_expr(state, m->guard, Load))
819                || !validate_body(state, m->body, "match_case")) {
820                return 0;
821            }
822        }
823        ret = 1;
824        break;
825    case Raise_kind:
826        if (stmt->v.Raise.exc) {
827            ret = validate_expr(state, stmt->v.Raise.exc, Load) &&
828                (!stmt->v.Raise.cause || validate_expr(state, stmt->v.Raise.cause, Load));
829            break;
830        }
831        if (stmt->v.Raise.cause) {
832            PyErr_SetString(PyExc_ValueError, "Raise with cause but no exception");
833            return 0;
834        }
835        ret = 1;
836        break;
837    case Try_kind:
838        if (!validate_body(state, stmt->v.Try.body, "Try"))
839            return 0;
840        if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
841            !asdl_seq_LEN(stmt->v.Try.finalbody)) {
842            PyErr_SetString(PyExc_ValueError, "Try has neither except handlers nor finalbody");
843            return 0;
844        }
845        if (!asdl_seq_LEN(stmt->v.Try.handlers) &&
846            asdl_seq_LEN(stmt->v.Try.orelse)) {
847            PyErr_SetString(PyExc_ValueError, "Try has orelse but no except handlers");
848            return 0;
849        }
850        for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.Try.handlers); i++) {
851            excepthandler_ty handler = asdl_seq_GET(stmt->v.Try.handlers, i);
852            VALIDATE_POSITIONS(handler);
853            if ((handler->v.ExceptHandler.type &&
854                 !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
855                !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
856                return 0;
857        }
858        ret = (!asdl_seq_LEN(stmt->v.Try.finalbody) ||
859                validate_stmts(state, stmt->v.Try.finalbody)) &&
860            (!asdl_seq_LEN(stmt->v.Try.orelse) ||
861             validate_stmts(state, stmt->v.Try.orelse));
862        break;
863    case TryStar_kind:
864        if (!validate_body(state, stmt->v.TryStar.body, "TryStar"))
865            return 0;
866        if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
867            !asdl_seq_LEN(stmt->v.TryStar.finalbody)) {
868            PyErr_SetString(PyExc_ValueError, "TryStar has neither except handlers nor finalbody");
869            return 0;
870        }
871        if (!asdl_seq_LEN(stmt->v.TryStar.handlers) &&
872            asdl_seq_LEN(stmt->v.TryStar.orelse)) {
873            PyErr_SetString(PyExc_ValueError, "TryStar has orelse but no except handlers");
874            return 0;
875        }
876        for (Py_ssize_t i = 0; i < asdl_seq_LEN(stmt->v.TryStar.handlers); i++) {
877            excepthandler_ty handler = asdl_seq_GET(stmt->v.TryStar.handlers, i);
878            if ((handler->v.ExceptHandler.type &&
879                 !validate_expr(state, handler->v.ExceptHandler.type, Load)) ||
880                !validate_body(state, handler->v.ExceptHandler.body, "ExceptHandler"))
881                return 0;
882        }
883        ret = (!asdl_seq_LEN(stmt->v.TryStar.finalbody) ||
884                validate_stmts(state, stmt->v.TryStar.finalbody)) &&
885            (!asdl_seq_LEN(stmt->v.TryStar.orelse) ||
886             validate_stmts(state, stmt->v.TryStar.orelse));
887        break;
888    case Assert_kind:
889        ret = validate_expr(state, stmt->v.Assert.test, Load) &&
890            (!stmt->v.Assert.msg || validate_expr(state, stmt->v.Assert.msg, Load));
891        break;
892    case Import_kind:
893        ret = validate_nonempty_seq(stmt->v.Import.names, "names", "Import");
894        break;
895    case ImportFrom_kind:
896        if (stmt->v.ImportFrom.level < 0) {
897            PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level");
898            return 0;
899        }
900        ret = validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom");
901        break;
902    case Global_kind:
903        ret = validate_nonempty_seq(stmt->v.Global.names, "names", "Global");
904        break;
905    case Nonlocal_kind:
906        ret = validate_nonempty_seq(stmt->v.Nonlocal.names, "names", "Nonlocal");
907        break;
908    case Expr_kind:
909        ret = validate_expr(state, stmt->v.Expr.value, Load);
910        break;
911    case AsyncFunctionDef_kind:
912        ret = validate_body(state, stmt->v.AsyncFunctionDef.body, "AsyncFunctionDef") &&
913            validate_arguments(state, stmt->v.AsyncFunctionDef.args) &&
914            validate_exprs(state, stmt->v.AsyncFunctionDef.decorator_list, Load, 0) &&
915            (!stmt->v.AsyncFunctionDef.returns ||
916             validate_expr(state, stmt->v.AsyncFunctionDef.returns, Load));
917        break;
918    case Pass_kind:
919    case Break_kind:
920    case Continue_kind:
921        ret = 1;
922        break;
923    // No default case so compiler emits warning for unhandled cases
924    }
925    if (ret < 0) {
926        PyErr_SetString(PyExc_SystemError, "unexpected statement");
927        ret = 0;
928    }
929    state->recursion_depth--;
930    return ret;
931}
932
933static int
934validate_stmts(struct validator *state, asdl_stmt_seq *seq)
935{
936    assert(!PyErr_Occurred());
937    for (Py_ssize_t i = 0; i < asdl_seq_LEN(seq); i++) {
938        stmt_ty stmt = asdl_seq_GET(seq, i);
939        if (stmt) {
940            if (!validate_stmt(state, stmt))
941                return 0;
942        }
943        else {
944            PyErr_SetString(PyExc_ValueError,
945                            "None disallowed in statement list");
946            return 0;
947        }
948    }
949    return 1;
950}
951
952static int
953validate_exprs(struct validator *state, asdl_expr_seq *exprs, expr_context_ty ctx, int null_ok)
954{
955    assert(!PyErr_Occurred());
956    for (Py_ssize_t i = 0; i < asdl_seq_LEN(exprs); i++) {
957        expr_ty expr = asdl_seq_GET(exprs, i);
958        if (expr) {
959            if (!validate_expr(state, expr, ctx))
960                return 0;
961        }
962        else if (!null_ok) {
963            PyErr_SetString(PyExc_ValueError,
964                            "None disallowed in expression list");
965            return 0;
966        }
967
968    }
969    return 1;
970}
971
972static int
973validate_patterns(struct validator *state, asdl_pattern_seq *patterns, int star_ok)
974{
975    assert(!PyErr_Occurred());
976    for (Py_ssize_t i = 0; i < asdl_seq_LEN(patterns); i++) {
977        pattern_ty pattern = asdl_seq_GET(patterns, i);
978        if (!validate_pattern(state, pattern, star_ok)) {
979            return 0;
980        }
981    }
982    return 1;
983}
984
985
/* See comments in symtable.c. */
/* Factor by which _PyAST_Validate() multiplies the thread's current
   recursion depth and the interpreter recursion limit when seeding the
   validator's own depth counter and limit. */
#define COMPILER_STACK_FRAME_SCALE 3
988
989int
990_PyAST_Validate(mod_ty mod)
991{
992    assert(!PyErr_Occurred());
993    int res = -1;
994    struct validator state;
995    PyThreadState *tstate;
996    int recursion_limit = Py_GetRecursionLimit();
997    int starting_recursion_depth;
998
999    /* Setup recursion depth check counters */
1000    tstate = _PyThreadState_GET();
1001    if (!tstate) {
1002        return 0;
1003    }
1004    /* Be careful here to prevent overflow. */
1005    int recursion_depth = tstate->recursion_limit - tstate->recursion_remaining;
1006    starting_recursion_depth = (recursion_depth< INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1007        recursion_depth * COMPILER_STACK_FRAME_SCALE : recursion_depth;
1008    state.recursion_depth = starting_recursion_depth;
1009    state.recursion_limit = (recursion_limit < INT_MAX / COMPILER_STACK_FRAME_SCALE) ?
1010        recursion_limit * COMPILER_STACK_FRAME_SCALE : recursion_limit;
1011
1012    switch (mod->kind) {
1013    case Module_kind:
1014        res = validate_stmts(&state, mod->v.Module.body);
1015        break;
1016    case Interactive_kind:
1017        res = validate_stmts(&state, mod->v.Interactive.body);
1018        break;
1019    case Expression_kind:
1020        res = validate_expr(&state, mod->v.Expression.body, Load);
1021        break;
1022    case FunctionType_kind:
1023        res = validate_exprs(&state, mod->v.FunctionType.argtypes, Load, /*null_ok=*/0) &&
1024              validate_expr(&state, mod->v.FunctionType.returns, Load);
1025        break;
1026    // No default case so compiler emits warning for unhandled cases
1027    }
1028
1029    if (res < 0) {
1030        PyErr_SetString(PyExc_SystemError, "impossible module node");
1031        return 0;
1032    }
1033
1034    /* Check that the recursion depth counting balanced correctly */
1035    if (res && state.recursion_depth != starting_recursion_depth) {
1036        PyErr_Format(PyExc_SystemError,
1037            "AST validator recursion depth mismatch (before=%d, after=%d)",
1038            starting_recursion_depth, state.recursion_depth);
1039        return 0;
1040    }
1041    return res;
1042}
1043
1044PyObject *
1045_PyAST_GetDocString(asdl_stmt_seq *body)
1046{
1047    if (!asdl_seq_LEN(body)) {
1048        return NULL;
1049    }
1050    stmt_ty st = asdl_seq_GET(body, 0);
1051    if (st->kind != Expr_kind) {
1052        return NULL;
1053    }
1054    expr_ty e = st->v.Expr.value;
1055    if (e->kind == Constant_kind && PyUnicode_CheckExact(e->v.Constant.value)) {
1056        return e->v.Constant.value;
1057    }
1058    return NULL;
1059}
1060