xref: /third_party/rust/crates/syn/codegen/src/eq.rs (revision fad3a1d3)
1use crate::{cfg, file, lookup};
2use anyhow::Result;
3use proc_macro2::{Ident, Span, TokenStream};
4use quote::{format_ident, quote};
5use syn_codegen::{Data, Definitions, Node, Type};
6
/// Path of the generated file holding the `PartialEq`/`Eq` impls; passed to `file::write`.
const EQ_SRC: &str = "src/gen/eq.rs";
8
9fn always_eq(field_type: &Type) -> bool {
10    match field_type {
11        Type::Ext(ty) => ty == "Span",
12        Type::Token(_) | Type::Group(_) => true,
13        Type::Box(inner) => always_eq(inner),
14        Type::Tuple(inner) => inner.iter().all(always_eq),
15        _ => false,
16    }
17}
18
19fn expand_impl_body(defs: &Definitions, node: &Node) -> TokenStream {
20    let type_name = &node.ident;
21    let ident = Ident::new(type_name, Span::call_site());
22
23    match &node.data {
24        Data::Enum(variants) if variants.is_empty() => quote!(match *self {}),
25        Data::Enum(variants) => {
26            let arms = variants.iter().map(|(variant_name, fields)| {
27                let variant = Ident::new(variant_name, Span::call_site());
28                if fields.is_empty() {
29                    quote! {
30                        (#ident::#variant, #ident::#variant) => true,
31                    }
32                } else {
33                    let mut this_pats = Vec::new();
34                    let mut other_pats = Vec::new();
35                    let mut comparisons = Vec::new();
36                    for (i, field) in fields.iter().enumerate() {
37                        if always_eq(field) {
38                            this_pats.push(format_ident!("_"));
39                            other_pats.push(format_ident!("_"));
40                            continue;
41                        }
42                        let this = format_ident!("self{}", i);
43                        let other = format_ident!("other{}", i);
44                        comparisons.push(match field {
45                            Type::Ext(ty) if ty == "TokenStream" => {
46                                quote!(TokenStreamHelper(#this) == TokenStreamHelper(#other))
47                            }
48                            Type::Ext(ty) if ty == "Literal" => {
49                                quote!(#this.to_string() == #other.to_string())
50                            }
51                            _ => quote!(#this == #other),
52                        });
53                        this_pats.push(this);
54                        other_pats.push(other);
55                    }
56                    if comparisons.is_empty() {
57                        comparisons.push(quote!(true));
58                    }
59                    let mut cfg = None;
60                    if node.ident == "Expr" {
61                        if let Type::Syn(ty) = &fields[0] {
62                            if !lookup::node(defs, ty).features.any.contains("derive") {
63                                cfg = Some(quote!(#[cfg(feature = "full")]));
64                            }
65                        }
66                    }
67                    quote! {
68                        #cfg
69                        (#ident::#variant(#(#this_pats),*), #ident::#variant(#(#other_pats),*)) => {
70                            #(#comparisons)&&*
71                        }
72                    }
73                }
74            });
75            let fallthrough = if variants.len() == 1 {
76                None
77            } else {
78                Some(quote!(_ => false,))
79            };
80            quote! {
81                match (self, other) {
82                    #(#arms)*
83                    #fallthrough
84                }
85            }
86        }
87        Data::Struct(fields) => {
88            let mut comparisons = Vec::new();
89            for (f, ty) in fields {
90                if always_eq(ty) {
91                    continue;
92                }
93                let ident = Ident::new(f, Span::call_site());
94                comparisons.push(match ty {
95                    Type::Ext(ty) if ty == "TokenStream" => {
96                        quote!(TokenStreamHelper(&self.#ident) == TokenStreamHelper(&other.#ident))
97                    }
98                    _ => quote!(self.#ident == other.#ident),
99                });
100            }
101            if comparisons.is_empty() {
102                quote!(true)
103            } else {
104                quote!(#(#comparisons)&&*)
105            }
106        }
107        Data::Private => unreachable!(),
108    }
109}
110
111fn expand_impl(defs: &Definitions, node: &Node) -> TokenStream {
112    if node.ident == "Member" || node.ident == "Index" || node.ident == "Lifetime" {
113        return TokenStream::new();
114    }
115
116    let ident = Ident::new(&node.ident, Span::call_site());
117    let cfg_features = cfg::features(&node.features, "extra-traits");
118
119    let eq = quote! {
120        #cfg_features
121        impl Eq for #ident {}
122    };
123
124    let manual_partial_eq = node.data == Data::Private;
125    if manual_partial_eq {
126        return eq;
127    }
128
129    let body = expand_impl_body(defs, node);
130    let other = match &node.data {
131        Data::Enum(variants) if variants.is_empty() => quote!(_other),
132        Data::Struct(fields) if fields.values().all(always_eq) => quote!(_other),
133        _ => quote!(other),
134    };
135
136    quote! {
137        #eq
138
139        #cfg_features
140        impl PartialEq for #ident {
141            fn eq(&self, #other: &Self) -> bool {
142                #body
143            }
144        }
145    }
146}
147
148pub fn generate(defs: &Definitions) -> Result<()> {
149    let mut impls = TokenStream::new();
150    for node in &defs.types {
151        impls.extend(expand_impl(defs, node));
152    }
153
154    file::write(
155        EQ_SRC,
156        quote! {
157            #[cfg(any(feature = "derive", feature = "full"))]
158            use crate::tt::TokenStreamHelper;
159            use crate::*;
160
161            #impls
162        },
163    )?;
164
165    Ok(())
166}
167