1use crate::syntax::Atom::{self, *};
2use proc_macro2::{Literal, Span, TokenStream};
3use quote::ToTokens;
4use std::cmp::Ordering;
5use std::collections::BTreeSet;
6use std::fmt::{self, Display};
7use std::str::FromStr;
8use std::u64;
9use syn::{Error, Expr, Lit, Result, Token, UnOp};
10
11pub struct DiscriminantSet {
12    repr: Option<Atom>,
13    values: BTreeSet<Discriminant>,
14    previous: Option<Discriminant>,
15}
16
17#[derive(Copy, Clone, Eq, PartialEq)]
18pub struct Discriminant {
19    sign: Sign,
20    magnitude: u64,
21}
22
/// The sign half of a sign-and-magnitude [`Discriminant`].
#[derive(Clone, Copy, PartialEq, Eq)]
enum Sign {
    /// The value is the negation of its magnitude.
    Negative,
    /// The value equals its magnitude.
    Positive,
}
28
29impl DiscriminantSet {
30    pub fn new(repr: Option<Atom>) -> Self {
31        DiscriminantSet {
32            repr,
33            values: BTreeSet::new(),
34            previous: None,
35        }
36    }
37
38    pub fn insert(&mut self, expr: &Expr) -> Result<Discriminant> {
39        let (discriminant, repr) = expr_to_discriminant(expr)?;
40        match (self.repr, repr) {
41            (None, Some(new_repr)) => {
42                if let Some(limits) = Limits::of(new_repr) {
43                    for &past in &self.values {
44                        if limits.min <= past && past <= limits.max {
45                            continue;
46                        }
47                        let msg = format!(
48                            "discriminant value `{}` is outside the limits of {}",
49                            past, new_repr,
50                        );
51                        return Err(Error::new(Span::call_site(), msg));
52                    }
53                }
54                self.repr = Some(new_repr);
55            }
56            (Some(prev), Some(repr)) if prev != repr => {
57                let msg = format!("expected {}, found {}", prev, repr);
58                return Err(Error::new(Span::call_site(), msg));
59            }
60            _ => {}
61        }
62        insert(self, discriminant)
63    }
64
65    pub fn insert_next(&mut self) -> Result<Discriminant> {
66        let discriminant = match self.previous {
67            None => Discriminant::zero(),
68            Some(mut discriminant) => match discriminant.sign {
69                Sign::Negative => {
70                    discriminant.magnitude -= 1;
71                    if discriminant.magnitude == 0 {
72                        discriminant.sign = Sign::Positive;
73                    }
74                    discriminant
75                }
76                Sign::Positive => {
77                    if discriminant.magnitude == u64::MAX {
78                        let msg = format!("discriminant overflow on value after {}", u64::MAX);
79                        return Err(Error::new(Span::call_site(), msg));
80                    }
81                    discriminant.magnitude += 1;
82                    discriminant
83                }
84            },
85        };
86        insert(self, discriminant)
87    }
88
89    pub fn inferred_repr(&self) -> Result<Atom> {
90        if let Some(repr) = self.repr {
91            return Ok(repr);
92        }
93        if self.values.is_empty() {
94            return Ok(U8);
95        }
96        let min = *self.values.iter().next().unwrap();
97        let max = *self.values.iter().next_back().unwrap();
98        for limits in &LIMITS {
99            if limits.min <= min && max <= limits.max {
100                return Ok(limits.repr);
101            }
102        }
103        let msg = "these discriminant values do not fit in any supported enum repr type";
104        Err(Error::new(Span::call_site(), msg))
105    }
106}
107
108fn expr_to_discriminant(expr: &Expr) -> Result<(Discriminant, Option<Atom>)> {
109    match expr {
110        Expr::Lit(expr) => {
111            if let Lit::Int(lit) = &expr.lit {
112                let discriminant = lit.base10_parse::<Discriminant>()?;
113                let repr = parse_int_suffix(lit.suffix())?;
114                return Ok((discriminant, repr));
115            }
116        }
117        Expr::Unary(unary) => {
118            if let UnOp::Neg(_) = unary.op {
119                let (mut discriminant, repr) = expr_to_discriminant(&unary.expr)?;
120                discriminant.sign = match discriminant.sign {
121                    Sign::Positive => Sign::Negative,
122                    Sign::Negative => Sign::Positive,
123                };
124                return Ok((discriminant, repr));
125            }
126        }
127        _ => {}
128    }
129    Err(Error::new_spanned(
130        expr,
131        "enums with non-integer literal discriminants are not supported yet",
132    ))
133}
134
135fn insert(set: &mut DiscriminantSet, discriminant: Discriminant) -> Result<Discriminant> {
136    if let Some(expected_repr) = set.repr {
137        if let Some(limits) = Limits::of(expected_repr) {
138            if discriminant < limits.min || limits.max < discriminant {
139                let msg = format!(
140                    "discriminant value `{}` is outside the limits of {}",
141                    discriminant, expected_repr,
142                );
143                return Err(Error::new(Span::call_site(), msg));
144            }
145        }
146    }
147    set.values.insert(discriminant);
148    set.previous = Some(discriminant);
149    Ok(discriminant)
150}
151
152impl Discriminant {
153    pub const fn zero() -> Self {
154        Discriminant {
155            sign: Sign::Positive,
156            magnitude: 0,
157        }
158    }
159
160    const fn pos(u: u64) -> Self {
161        Discriminant {
162            sign: Sign::Positive,
163            magnitude: u,
164        }
165    }
166
167    const fn neg(i: i64) -> Self {
168        Discriminant {
169            sign: if i < 0 {
170                Sign::Negative
171            } else {
172                Sign::Positive
173            },
174            // This is `i.abs() as u64` but without overflow on MIN. Uses the
175            // fact that MIN.wrapping_abs() wraps back to MIN whose binary
176            // representation is 1<<63, and thus the `as u64` conversion
177            // produces 1<<63 too which happens to be the correct unsigned
178            // magnitude.
179            magnitude: i.wrapping_abs() as u64,
180        }
181    }
182
183    #[cfg(feature = "experimental-enum-variants-from-header")]
184    pub const fn checked_succ(self) -> Option<Self> {
185        match self.sign {
186            Sign::Negative => {
187                if self.magnitude == 1 {
188                    Some(Discriminant::zero())
189                } else {
190                    Some(Discriminant {
191                        sign: Sign::Negative,
192                        magnitude: self.magnitude - 1,
193                    })
194                }
195            }
196            Sign::Positive => match self.magnitude.checked_add(1) {
197                Some(magnitude) => Some(Discriminant {
198                    sign: Sign::Positive,
199                    magnitude,
200                }),
201                None => None,
202            },
203        }
204    }
205}
206
207impl Display for Discriminant {
208    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
209        if self.sign == Sign::Negative {
210            f.write_str("-")?;
211        }
212        write!(f, "{}", self.magnitude)
213    }
214}
215
216impl ToTokens for Discriminant {
217    fn to_tokens(&self, tokens: &mut TokenStream) {
218        if self.sign == Sign::Negative {
219            Token![-](Span::call_site()).to_tokens(tokens);
220        }
221        Literal::u64_unsuffixed(self.magnitude).to_tokens(tokens);
222    }
223}
224
225impl FromStr for Discriminant {
226    type Err = Error;
227
228    fn from_str(mut s: &str) -> Result<Self> {
229        let sign = if s.starts_with('-') {
230            s = &s[1..];
231            Sign::Negative
232        } else {
233            Sign::Positive
234        };
235        match s.parse::<u64>() {
236            Ok(magnitude) => Ok(Discriminant { sign, magnitude }),
237            Err(_) => Err(Error::new(
238                Span::call_site(),
239                "discriminant value outside of supported range",
240            )),
241        }
242    }
243}
244
245impl Ord for Discriminant {
246    fn cmp(&self, other: &Self) -> Ordering {
247        use self::Sign::{Negative, Positive};
248        match (self.sign, other.sign) {
249            (Negative, Negative) => self.magnitude.cmp(&other.magnitude).reverse(),
250            (Negative, Positive) => Ordering::Less, // negative < positive
251            (Positive, Negative) => Ordering::Greater, // positive > negative
252            (Positive, Positive) => self.magnitude.cmp(&other.magnitude),
253        }
254    }
255}
256
257impl PartialOrd for Discriminant {
258    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
259        Some(self.cmp(other))
260    }
261}
262
263fn parse_int_suffix(suffix: &str) -> Result<Option<Atom>> {
264    if suffix.is_empty() {
265        return Ok(None);
266    }
267    if let Some(atom) = Atom::from_str(suffix) {
268        match atom {
269            U8 | U16 | U32 | U64 | Usize | I8 | I16 | I32 | I64 | Isize => return Ok(Some(atom)),
270            _ => {}
271        }
272    }
273    let msg = format!("unrecognized integer suffix: `{}`", suffix);
274    Err(Error::new(Span::call_site(), msg))
275}
276
277#[derive(Copy, Clone)]
278struct Limits {
279    repr: Atom,
280    min: Discriminant,
281    max: Discriminant,
282}
283
284impl Limits {
285    fn of(repr: Atom) -> Option<Limits> {
286        for limits in &LIMITS {
287            if limits.repr == repr {
288                return Some(*limits);
289            }
290        }
291        None
292    }
293}
294
295const LIMITS: [Limits; 8] = [
296    Limits {
297        repr: U8,
298        min: Discriminant::zero(),
299        max: Discriminant::pos(std::u8::MAX as u64),
300    },
301    Limits {
302        repr: I8,
303        min: Discriminant::neg(std::i8::MIN as i64),
304        max: Discriminant::pos(std::i8::MAX as u64),
305    },
306    Limits {
307        repr: U16,
308        min: Discriminant::zero(),
309        max: Discriminant::pos(std::u16::MAX as u64),
310    },
311    Limits {
312        repr: I16,
313        min: Discriminant::neg(std::i16::MIN as i64),
314        max: Discriminant::pos(std::i16::MAX as u64),
315    },
316    Limits {
317        repr: U32,
318        min: Discriminant::zero(),
319        max: Discriminant::pos(std::u32::MAX as u64),
320    },
321    Limits {
322        repr: I32,
323        min: Discriminant::neg(std::i32::MIN as i64),
324        max: Discriminant::pos(std::i32::MAX as u64),
325    },
326    Limits {
327        repr: U64,
328        min: Discriminant::zero(),
329        max: Discriminant::pos(std::u64::MAX),
330    },
331    Limits {
332        repr: I64,
333        min: Discriminant::neg(std::i64::MIN),
334        max: Discriminant::pos(std::i64::MAX as u64),
335    },
336];
337