162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0
262306a36Sopenharmony_ci
362306a36Sopenharmony_ciuse proc_macro::{Delimiter, Group, TokenStream, TokenTree};
462306a36Sopenharmony_ciuse std::collections::HashSet;
562306a36Sopenharmony_ciuse std::fmt::Write;
662306a36Sopenharmony_ci
/// Implements the `#[vtable]` attribute.
///
/// The attribute must be applied to a `trait` definition or an `impl` block. Which one it is is
/// decided by whichever of the `trait` / `impl` keywords appears first among the item's top-level
/// tokens.
///
/// For every `fn` found directly in the item's body (nested groups such as method bodies are not
/// searched), an associated constant named `HAS_<NAME>` is generated. `<NAME>` is the method name
/// upper-cased, with any raw-identifier `r#` prefix removed. The constant is `false` on the trait
/// and `true` in the `impl` block. If the body already declares a constant with that name, no
/// constant is generated, which lets the user override it.
///
/// A `USE_VTABLE_ATTR: ()` constant is also emitted. It is declared without a value on the trait
/// and defined in the `impl` block, so an `impl` written without `#[vtable]` lacks it.
///
/// # Panics
///
/// Panics if no `trait` or `impl` keyword is found among the top-level tokens, or if the last
/// top-level token is not a brace-delimited group (the item's body).
pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
    let mut tokens: Vec<_> = ts.into_iter().collect();

    // Scan for the `trait` or `impl` keyword.
    let is_trait = tokens
        .iter()
        .find_map(|token| match token {
            TokenTree::Ident(ident) => match ident.to_string().as_str() {
                "trait" => Some(true),
                "impl" => Some(false),
                _ => None,
            },
            _ => None,
        })
        .expect("#[vtable] attribute should only be applied to trait or impl block");

    // Retrieve the main body. The main body should be the last token tree.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("cannot locate main body of trait or impl block"),
    };

    // Collect the names of top-level `fn` and `const` items. Only the body's direct tokens are
    // visited; method bodies and parameter lists are groups and are not descended into.
    let mut body_it = body.stream().into_iter();
    let mut functions = Vec::new();
    let mut consts = HashSet::new();
    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Ident(ident) if ident.to_string() == "fn" => {
                let fn_name = match body_it.next() {
                    Some(TokenTree::Ident(ident)) => ident.to_string(),
                    // Possibly we've encountered a fn pointer type instead.
                    _ => continue,
                };
                functions.push(fn_name);
            }
            TokenTree::Ident(ident) if ident.to_string() == "const" => {
                let const_name = match body_it.next() {
                    Some(TokenTree::Ident(ident)) => ident.to_string(),
                    // Possibly we've encountered an inline const block instead.
                    _ => continue,
                };
                consts.insert(const_name);
            }
            _ => (),
        }
    }

    // The generated items are built as source text and parsed into tokens below.
    let mut const_items = if is_trait {
        "
                /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)
                /// attribute when implementing this trait.
                const USE_VTABLE_ATTR: ();
        "
        .to_owned()
    } else {
        "const USE_VTABLE_ATTR: () = ();".to_owned()
    };

    for f in functions {
        // `Ident`'s string form keeps the `r#` prefix of raw identifiers (e.g. `r#type`), which
        // would otherwise produce the malformed name `HAS_R#TYPE`.
        let bare_name = f.strip_prefix("r#").unwrap_or(&f);
        let gen_const_name = format!("HAS_{}", bare_name.to_uppercase());
        // Skip if it's declared already -- this allows user override. Because generated names are
        // recorded below, this also prevents duplicate definitions when two method names
        // upper-case to the same constant name.
        if consts.contains(&gen_const_name) {
            continue;
        }
        if is_trait {
            // We don't know on the implementation-site whether a method is required or provided
            // so we have to generate a const for all methods.
            write!(
                const_items,
                "/// Indicates if the `{f}` method is overridden by the implementor.
                const {gen_const_name}: bool = false;",
            )
            .unwrap();
        } else {
            write!(const_items, "const {gen_const_name}: bool = true;").unwrap();
        }
        consts.insert(gen_const_name);
    }

    // Prepend the generated items to the original body contents.
    let mut new_body: TokenStream = const_items
        .parse()
        .expect("#[vtable] generated items should always be valid tokens");
    new_body.extend(body.stream());
    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
    tokens.into_iter().collect()
}
97