1use crate::internals::ast::{Container, Data}; 2use crate::internals::{attr, ungroup}; 3use proc_macro2::Span; 4use std::collections::HashSet; 5use syn::punctuated::{Pair, Punctuated}; 6use syn::Token; 7 8// Remove the default from every type parameter because in the generated impls 9// they look like associated types: "error: associated type bindings are not 10// allowed here". 11pub fn without_defaults(generics: &syn::Generics) -> syn::Generics { 12 syn::Generics { 13 params: generics 14 .params 15 .iter() 16 .map(|param| match param { 17 syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam { 18 eq_token: None, 19 default: None, 20 ..param.clone() 21 }), 22 _ => param.clone(), 23 }) 24 .collect(), 25 ..generics.clone() 26 } 27} 28 29pub fn with_where_predicates( 30 generics: &syn::Generics, 31 predicates: &[syn::WherePredicate], 32) -> syn::Generics { 33 let mut generics = generics.clone(); 34 generics 35 .make_where_clause() 36 .predicates 37 .extend(predicates.iter().cloned()); 38 generics 39} 40 41pub fn with_where_predicates_from_fields( 42 cont: &Container, 43 generics: &syn::Generics, 44 from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>, 45) -> syn::Generics { 46 let predicates = cont 47 .data 48 .all_fields() 49 .filter_map(|field| from_field(&field.attrs)) 50 .flat_map(<[syn::WherePredicate]>::to_vec); 51 52 let mut generics = generics.clone(); 53 generics.make_where_clause().predicates.extend(predicates); 54 generics 55} 56 57pub fn with_where_predicates_from_variants( 58 cont: &Container, 59 generics: &syn::Generics, 60 from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>, 61) -> syn::Generics { 62 let variants = match &cont.data { 63 Data::Enum(variants) => variants, 64 Data::Struct(_, _) => { 65 return generics.clone(); 66 } 67 }; 68 69 let predicates = variants 70 .iter() 71 .filter_map(|variant| from_variant(&variant.attrs)) 72 .flat_map(<[syn::WherePredicate]>::to_vec); 73 74 let mut generics = generics.clone(); 75 generics.make_where_clause().predicates.extend(predicates); 76 generics 77} 78 79// Puts the given bound on any generic type parameters that are used in fields 80// for which filter returns true. 81// 82// For example, the following struct needs the bound `A: Serialize, B: 83// Serialize`. 84// 85// struct S<'b, A, B: 'b, C> { 86// a: A, 87// b: Option<&'b B> 88// #[serde(skip_serializing)] 89// c: C, 90// } 91pub fn with_bound( 92 cont: &Container, 93 generics: &syn::Generics, 94 filter: fn(&attr::Field, Option<&attr::Variant>) -> bool, 95 bound: &syn::Path, 96) -> syn::Generics { 97 struct FindTyParams<'ast> { 98 // Set of all generic type parameters on the current struct (A, B, C in 99 // the example). Initialized up front. 100 all_type_params: HashSet<syn::Ident>, 101 102 // Set of generic type parameters used in fields for which filter 103 // returns true (A and B in the example). Filled in as the visitor sees 104 // them. 105 relevant_type_params: HashSet<syn::Ident>, 106 107 // Fields whose type is an associated type of one of the generic type 108 // parameters. 109 associated_type_usage: Vec<&'ast syn::TypePath>, 110 } 111 112 impl<'ast> FindTyParams<'ast> { 113 fn visit_field(&mut self, field: &'ast syn::Field) { 114 if let syn::Type::Path(ty) = ungroup(&field.ty) { 115 if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() { 116 if self.all_type_params.contains(&t.ident) { 117 self.associated_type_usage.push(ty); 118 } 119 } 120 } 121 self.visit_type(&field.ty); 122 } 123 124 fn visit_path(&mut self, path: &'ast syn::Path) { 125 if let Some(seg) = path.segments.last() { 126 if seg.ident == "PhantomData" { 127 // Hardcoded exception, because PhantomData<T> implements 128 // Serialize and Deserialize whether or not T implements it. 129 return; 130 } 131 } 132 if path.leading_colon.is_none() && path.segments.len() == 1 { 133 let id = &path.segments[0].ident; 134 if self.all_type_params.contains(id) { 135 self.relevant_type_params.insert(id.clone()); 136 } 137 } 138 for segment in &path.segments { 139 self.visit_path_segment(segment); 140 } 141 } 142 143 // Everything below is simply traversing the syntax tree. 144 145 fn visit_type(&mut self, ty: &'ast syn::Type) { 146 match ty { 147 #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] 148 syn::Type::Array(ty) => self.visit_type(&ty.elem), 149 syn::Type::BareFn(ty) => { 150 for arg in &ty.inputs { 151 self.visit_type(&arg.ty); 152 } 153 self.visit_return_type(&ty.output); 154 } 155 syn::Type::Group(ty) => self.visit_type(&ty.elem), 156 syn::Type::ImplTrait(ty) => { 157 for bound in &ty.bounds { 158 self.visit_type_param_bound(bound); 159 } 160 } 161 syn::Type::Macro(ty) => self.visit_macro(&ty.mac), 162 syn::Type::Paren(ty) => self.visit_type(&ty.elem), 163 syn::Type::Path(ty) => { 164 if let Some(qself) = &ty.qself { 165 self.visit_type(&qself.ty); 166 } 167 self.visit_path(&ty.path); 168 } 169 syn::Type::Ptr(ty) => self.visit_type(&ty.elem), 170 syn::Type::Reference(ty) => self.visit_type(&ty.elem), 171 syn::Type::Slice(ty) => self.visit_type(&ty.elem), 172 syn::Type::TraitObject(ty) => { 173 for bound in &ty.bounds { 174 self.visit_type_param_bound(bound); 175 } 176 } 177 syn::Type::Tuple(ty) => { 178 for elem in &ty.elems { 179 self.visit_type(elem); 180 } 181 } 182 183 syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {} 184 185 _ => {} 186 } 187 } 188 189 fn visit_path_segment(&mut self, segment: &'ast syn::PathSegment) { 190 self.visit_path_arguments(&segment.arguments); 191 } 192 193 fn visit_path_arguments(&mut self, arguments: &'ast syn::PathArguments) { 194 match arguments { 195 syn::PathArguments::None => {} 196 syn::PathArguments::AngleBracketed(arguments) => { 197 for arg in &arguments.args { 198 match arg { 199 #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] 200 syn::GenericArgument::Type(arg) => self.visit_type(arg), 201 syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty), 202 syn::GenericArgument::Lifetime(_) 203 | syn::GenericArgument::Const(_) 204 | syn::GenericArgument::AssocConst(_) 205 | syn::GenericArgument::Constraint(_) => {} 206 _ => {} 207 } 208 } 209 } 210 syn::PathArguments::Parenthesized(arguments) => { 211 for argument in &arguments.inputs { 212 self.visit_type(argument); 213 } 214 self.visit_return_type(&arguments.output); 215 } 216 } 217 } 218 219 fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) { 220 match return_type { 221 syn::ReturnType::Default => {} 222 syn::ReturnType::Type(_, output) => self.visit_type(output), 223 } 224 } 225 226 fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) { 227 match bound { 228 #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] 229 syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path), 230 syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {} 231 _ => {} 232 } 233 } 234 235 // Type parameter should not be considered used by a macro path. 236 // 237 // struct TypeMacro<T> { 238 // mac: T!(), 239 // marker: PhantomData<T>, 240 // } 241 fn visit_macro(&mut self, _mac: &'ast syn::Macro) {} 242 } 243 244 let all_type_params = generics 245 .type_params() 246 .map(|param| param.ident.clone()) 247 .collect(); 248 249 let mut visitor = FindTyParams { 250 all_type_params, 251 relevant_type_params: HashSet::new(), 252 associated_type_usage: Vec::new(), 253 }; 254 match &cont.data { 255 Data::Enum(variants) => { 256 for variant in variants { 257 let relevant_fields = variant 258 .fields 259 .iter() 260 .filter(|field| filter(&field.attrs, Some(&variant.attrs))); 261 for field in relevant_fields { 262 visitor.visit_field(field.original); 263 } 264 } 265 } 266 Data::Struct(_, fields) => { 267 for field in fields.iter().filter(|field| filter(&field.attrs, None)) { 268 visitor.visit_field(field.original); 269 } 270 } 271 } 272 273 let relevant_type_params = visitor.relevant_type_params; 274 let associated_type_usage = visitor.associated_type_usage; 275 let new_predicates = generics 276 .type_params() 277 .map(|param| param.ident.clone()) 278 .filter(|id| relevant_type_params.contains(id)) 279 .map(|id| syn::TypePath { 280 qself: None, 281 path: id.into(), 282 }) 283 .chain(associated_type_usage.into_iter().cloned()) 284 .map(|bounded_ty| { 285 syn::WherePredicate::Type(syn::PredicateType { 286 lifetimes: None, 287 // the type parameter that is being bounded e.g. T 288 bounded_ty: syn::Type::Path(bounded_ty), 289 colon_token: <Token![:]>::default(), 290 // the bound e.g. Serialize 291 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound { 292 paren_token: None, 293 modifier: syn::TraitBoundModifier::None, 294 lifetimes: None, 295 path: bound.clone(), 296 })] 297 .into_iter() 298 .collect(), 299 }) 300 }); 301 302 let mut generics = generics.clone(); 303 generics 304 .make_where_clause() 305 .predicates 306 .extend(new_predicates); 307 generics 308} 309 310pub fn with_self_bound( 311 cont: &Container, 312 generics: &syn::Generics, 313 bound: &syn::Path, 314) -> syn::Generics { 315 let mut generics = generics.clone(); 316 generics 317 .make_where_clause() 318 .predicates 319 .push(syn::WherePredicate::Type(syn::PredicateType { 320 lifetimes: None, 321 // the type that is being bounded e.g. MyStruct<'a, T> 322 bounded_ty: type_of_item(cont), 323 colon_token: <Token![:]>::default(), 324 // the bound e.g. Default 325 bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound { 326 paren_token: None, 327 modifier: syn::TraitBoundModifier::None, 328 lifetimes: None, 329 path: bound.clone(), 330 })] 331 .into_iter() 332 .collect(), 333 })); 334 generics 335} 336 337pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics { 338 let bound = syn::Lifetime::new(lifetime, Span::call_site()); 339 let def = syn::LifetimeParam { 340 attrs: Vec::new(), 341 lifetime: bound.clone(), 342 colon_token: None, 343 bounds: Punctuated::new(), 344 }; 345 346 let params = Some(syn::GenericParam::Lifetime(def)) 347 .into_iter() 348 .chain(generics.params.iter().cloned().map(|mut param| { 349 match &mut param { 350 syn::GenericParam::Lifetime(param) => { 351 param.bounds.push(bound.clone()); 352 } 353 syn::GenericParam::Type(param) => { 354 param 355 .bounds 356 .push(syn::TypeParamBound::Lifetime(bound.clone())); 357 } 358 syn::GenericParam::Const(_) => {} 359 } 360 param 361 })) 362 .collect(); 363 364 syn::Generics { 365 params, 366 ..generics.clone() 367 } 368} 369 370fn type_of_item(cont: &Container) -> syn::Type { 371 syn::Type::Path(syn::TypePath { 372 qself: None, 373 path: syn::Path { 374 leading_colon: None, 375 segments: vec![syn::PathSegment { 376 ident: cont.ident.clone(), 377 arguments: syn::PathArguments::AngleBracketed( 378 syn::AngleBracketedGenericArguments { 379 colon2_token: None, 380 lt_token: <Token![<]>::default(), 381 args: cont 382 .generics 383 .params 384 .iter() 385 .map(|param| match param { 386 syn::GenericParam::Type(param) => { 387 syn::GenericArgument::Type(syn::Type::Path(syn::TypePath { 388 qself: None, 389 path: param.ident.clone().into(), 390 })) 391 } 392 syn::GenericParam::Lifetime(param) => { 393 syn::GenericArgument::Lifetime(param.lifetime.clone()) 394 } 395 syn::GenericParam::Const(_) => { 396 panic!("Serde does not support const generics yet"); 397 } 398 }) 399 .collect(), 400 gt_token: <Token![>]>::default(), 401 }, 402 ), 403 }] 404 .into_iter() 405 .collect(), 406 }, 407 }) 408} 409