xref: /third_party/rust/crates/syn/codegen/src/fold.rs (revision fad3a1d3)
1use crate::{file, full, gen};
2use anyhow::Result;
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::{format_ident, quote};
5use syn::Index;
6use syn_codegen::{Data, Definitions, Features, Node, Type};
7
/// Output path handed to `file::write` for the generated fold module.
const FOLD_SRC: &str = "src/gen/fold.rs";
9
10fn simple_visit(item: &str, name: &TokenStream) -> TokenStream {
11    let ident = gen::under_name(item);
12    let method = format_ident!("fold_{}", ident);
13    quote! {
14        f.#method(#name)
15    }
16}
17
18fn visit(
19    ty: &Type,
20    features: &Features,
21    defs: &Definitions,
22    name: &TokenStream,
23) -> Option<TokenStream> {
24    match ty {
25        Type::Box(t) => {
26            let res = visit(t, features, defs, &quote!(*#name))?;
27            Some(quote! {
28                Box::new(#res)
29            })
30        }
31        Type::Vec(t) => {
32            let operand = quote!(it);
33            let val = visit(t, features, defs, &operand)?;
34            Some(quote! {
35                FoldHelper::lift(#name, |it| #val)
36            })
37        }
38        Type::Punctuated(p) => {
39            let operand = quote!(it);
40            let val = visit(&p.element, features, defs, &operand)?;
41            Some(quote! {
42                FoldHelper::lift(#name, |it| #val)
43            })
44        }
45        Type::Option(t) => {
46            let it = quote!(it);
47            let val = visit(t, features, defs, &it)?;
48            Some(quote! {
49                (#name).map(|it| #val)
50            })
51        }
52        Type::Tuple(t) => {
53            let mut code = TokenStream::new();
54            for (i, elem) in t.iter().enumerate() {
55                let i = Index::from(i);
56                let it = quote!((#name).#i);
57                let val = visit(elem, features, defs, &it).unwrap_or(it);
58                code.extend(val);
59                code.extend(quote!(,));
60            }
61            Some(quote! {
62                (#code)
63            })
64        }
65        Type::Syn(t) => {
66            fn requires_full(features: &Features) -> bool {
67                features.any.contains("full") && features.any.len() == 1
68            }
69            let mut res = simple_visit(t, name);
70            let target = defs.types.iter().find(|ty| ty.ident == *t).unwrap();
71            if requires_full(&target.features) && !requires_full(features) {
72                res = quote!(full!(#res));
73            }
74            Some(res)
75        }
76        Type::Ext(t) if gen::TERMINAL_TYPES.contains(&&t[..]) => Some(simple_visit(t, name)),
77        Type::Ext(_) | Type::Std(_) | Type::Token(_) | Type::Group(_) => None,
78    }
79}
80
81fn node(traits: &mut TokenStream, impls: &mut TokenStream, s: &Node, defs: &Definitions) {
82    let under_name = gen::under_name(&s.ident);
83    let ty = Ident::new(&s.ident, Span::call_site());
84    let fold_fn = format_ident!("fold_{}", under_name);
85
86    let mut fold_impl = TokenStream::new();
87
88    match &s.data {
89        Data::Enum(variants) => {
90            let mut fold_variants = TokenStream::new();
91
92            for (variant, fields) in variants {
93                let variant_ident = Ident::new(variant, Span::call_site());
94
95                if fields.is_empty() {
96                    fold_variants.extend(quote! {
97                        #ty::#variant_ident => {
98                            #ty::#variant_ident
99                        }
100                    });
101                } else {
102                    let mut bind_fold_fields = TokenStream::new();
103                    let mut fold_fields = TokenStream::new();
104
105                    for (idx, ty) in fields.iter().enumerate() {
106                        let binding = format_ident!("_binding_{}", idx);
107
108                        bind_fold_fields.extend(quote! {
109                            #binding,
110                        });
111
112                        let owned_binding = quote!(#binding);
113
114                        fold_fields.extend(
115                            visit(ty, &s.features, defs, &owned_binding).unwrap_or(owned_binding),
116                        );
117
118                        fold_fields.extend(quote!(,));
119                    }
120
121                    fold_variants.extend(quote! {
122                        #ty::#variant_ident(#bind_fold_fields) => {
123                            #ty::#variant_ident(
124                                #fold_fields
125                            )
126                        }
127                    });
128                }
129            }
130
131            fold_impl.extend(quote! {
132                match node {
133                    #fold_variants
134                }
135            });
136        }
137        Data::Struct(fields) => {
138            let mut fold_fields = TokenStream::new();
139
140            for (field, ty) in fields {
141                let id = Ident::new(field, Span::call_site());
142                let ref_toks = quote!(node.#id);
143
144                let fold = visit(ty, &s.features, defs, &ref_toks).unwrap_or(ref_toks);
145
146                fold_fields.extend(quote! {
147                    #id: #fold,
148                });
149            }
150
151            if fields.is_empty() {
152                if ty == "Ident" {
153                    fold_impl.extend(quote! {
154                        let mut node = node;
155                        let span = f.fold_span(node.span());
156                        node.set_span(span);
157                    });
158                }
159                fold_impl.extend(quote! {
160                    node
161                });
162            } else {
163                fold_impl.extend(quote! {
164                    #ty {
165                        #fold_fields
166                    }
167                });
168            }
169        }
170        Data::Private => {
171            if ty == "Ident" {
172                fold_impl.extend(quote! {
173                    let mut node = node;
174                    let span = f.fold_span(node.span());
175                    node.set_span(span);
176                });
177            }
178            fold_impl.extend(quote! {
179                node
180            });
181        }
182    }
183
184    let fold_span_only =
185        s.data == Data::Private && !gen::TERMINAL_TYPES.contains(&s.ident.as_str());
186    if fold_span_only {
187        fold_impl = quote! {
188            let span = f.fold_span(node.span());
189            let mut node = node;
190            node.set_span(span);
191            node
192        };
193    }
194
195    traits.extend(quote! {
196        fn #fold_fn(&mut self, i: #ty) -> #ty {
197            #fold_fn(self, i)
198        }
199    });
200
201    impls.extend(quote! {
202        pub fn #fold_fn<F>(f: &mut F, node: #ty) -> #ty
203        where
204            F: Fold + ?Sized,
205        {
206            #fold_impl
207        }
208    });
209}
210
211pub fn generate(defs: &Definitions) -> Result<()> {
212    let (traits, impls) = gen::traverse(defs, node);
213    let full_macro = full::get_macro();
214    file::write(
215        FOLD_SRC,
216        quote! {
217            // Unreachable code is generated sometimes without the full feature.
218            #![allow(unreachable_code, unused_variables)]
219            #![allow(
220                clippy::match_wildcard_for_single_variants,
221                clippy::needless_match,
222                clippy::needless_pass_by_ref_mut,
223            )]
224
225            #[cfg(any(feature = "full", feature = "derive"))]
226            use crate::gen::helper::fold::*;
227            use crate::*;
228            use proc_macro2::Span;
229
230            #full_macro
231
232            /// Syntax tree traversal to transform the nodes of an owned syntax tree.
233            ///
234            /// See the [module documentation] for details.
235            ///
236            /// [module documentation]: self
237            pub trait Fold {
238                #traits
239            }
240
241            #impls
242        },
243    )?;
244    Ok(())
245}
246