1// Copyright (c) 2023 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include "flags.h"

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <regex>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <variant>
#include <vector>
25
26namespace flags {
27
// Tokens that FlagList::parse() did not consume as flags: tokens that do not
// start with '-', the lone "-" token, and every token following a lone "--".
// The list is cleared at the start of each FlagList::parse() call.
std::vector<std::string> positional_arguments;
29
30namespace {
31
// A single command-line token, as a C string.
using token_t = const char*;
// Cursor into an array of tokens. The parsing code in this file stops when the
// pointed-to token is nullptr.
using token_iterator_t = token_t*;
34
// Returns the name part of a flag token, leading hyphen(s) included.
// Short flags are returned verbatim. For long flags, everything from the first
// '=' onwards (the inline value) is dropped; a long flag without '=' is
// returned verbatim as well.
inline std::string get_flag_name(const std::string& flag, bool is_short_flag) {
  // `substr` clamps its count to the string length, so `npos` (short flag, or
  // no '=' found) selects the whole token.
  const std::string::size_type name_length =
      is_short_flag ? std::string::npos : flag.find('=');
  return flag.substr(0, name_length);
}
49
50// Parse a boolean flag. Returns `true` if the parsing succeeded, `false`
51// otherwise.
52bool parse_bool_flag(Flag<bool>& flag, bool is_short_flag,
53                     const std::string& token) {
54  if (is_short_flag) {
55    flag.value() = true;
56    return true;
57  }
58
59  const std::string raw_flag(token);
60  size_t equal_index = raw_flag.find('=');
61  if (equal_index == std::string::npos) {
62    flag.value() = true;
63    return true;
64  }
65
66  const std::string value = raw_flag.substr(equal_index + 1);
67  if (value == "true") {
68    flag.value() = true;
69    return true;
70  }
71
72  if (value == "false") {
73    flag.value() = false;
74    return true;
75  }
76
77  return false;
78}
79
80// Parse a uint32_t flag value.
81bool parse_flag_value(Flag<uint32_t>& flag, const std::string& value) {
82  std::regex unsigned_pattern("^ *[0-9]+ *$");
83  if (!std::regex_match(value, unsigned_pattern)) {
84    std::cerr << "'" << value << "' is not a unsigned number." << std::endl;
85    return false;
86  }
87
88  errno = 0;
89  char* end_ptr = nullptr;
90  const uint64_t number = strtoull(value.c_str(), &end_ptr, 10);
91  if (end_ptr == nullptr || end_ptr != value.c_str() + value.size() ||
92      errno == EINVAL) {
93    std::cerr << "'" << value << "' is not a unsigned number." << std::endl;
94    return false;
95  }
96
97  if (errno == ERANGE || number > static_cast<size_t>(UINT32_MAX)) {
98    std::cerr << "'" << value << "' cannot be represented as a 32bit unsigned."
99              << std::endl;
100    return false;
101  }
102
103  flag.value() = static_cast<uint32_t>(number);
104  return true;
105}
106
107// "Parse" a string flag value (assigns it, cannot fail).
108bool parse_flag_value(Flag<std::string>& flag, const std::string& value) {
109  flag.value() = value;
110  return true;
111}
112
113// Parse a potential multi-token flag. Moves the iterator to the last flag's
114// token if it's a multi-token flag. Returns `true` if the parsing succeeded.
115// The iterator is moved to the last parsed token.
116template <typename T>
117bool parse_flag(Flag<T>& flag, bool is_short_flag, const char*** iterator) {
118  const std::string raw_flag(**iterator);
119  std::string raw_value;
120  const size_t equal_index = raw_flag.find('=');
121
122  if (is_short_flag || equal_index == std::string::npos) {
123    if ((*iterator)[1] == nullptr) {
124      return false;
125    }
126
127    // This is a bi-token flag. Moving iterator to the last parsed token.
128    raw_value = (*iterator)[1];
129    *iterator += 1;
130  } else {
131    // This is a mono-token flag, no need to move the iterator.
132    raw_value = raw_flag.substr(equal_index + 1);
133  }
134
135  return parse_flag_value(flag, raw_value);
136}
137
138}  // namespace
139
140// This is the function to expand if you want to support a new type.
141bool FlagList::parse_flag_info(FlagInfo& info, token_iterator_t* iterator) {
142  bool success = false;
143
144  std::visit(
145      [&](auto&& item) {
146        using T = std::decay_t<decltype(item.get())>;
147        if constexpr (std::is_same_v<T, Flag<bool>>) {
148          success = parse_bool_flag(item.get(), info.is_short, **iterator);
149        } else if constexpr (std::is_same_v<T, Flag<std::string>>) {
150          success = parse_flag(item.get(), info.is_short, iterator);
151        } else if constexpr (std::is_same_v<T, Flag<uint32_t>>) {
152          success = parse_flag(item.get(), info.is_short, iterator);
153        } else {
154          static_assert(always_false_v<T>, "Unsupported flag type.");
155        }
156      },
157      info.flag);
158
159  return success;
160}
161
162bool FlagList::parse(token_t* argv) {
163  flags::positional_arguments.clear();
164  std::unordered_set<const FlagInfo*> parsed_flags;
165
166  bool ignore_flags = false;
167  for (const char** it = argv + 1; *it != nullptr; it++) {
168    if (ignore_flags) {
169      flags::positional_arguments.emplace_back(*it);
170      continue;
171    }
172
173    // '--' alone is used to mark the end of the flags.
174    if (std::strcmp(*it, "--") == 0) {
175      ignore_flags = true;
176      continue;
177    }
178
179    // '-' alone is not a flag, but often used to say 'stdin'.
180    if (std::strcmp(*it, "-") == 0) {
181      flags::positional_arguments.emplace_back(*it);
182      continue;
183    }
184
185    const std::string raw_flag(*it);
186    if (raw_flag.size() == 0) {
187      continue;
188    }
189
190    if (raw_flag[0] != '-') {
191      flags::positional_arguments.emplace_back(*it);
192      continue;
193    }
194
195    // Only case left: flags (long and shorts).
196    if (raw_flag.size() < 2) {
197      std::cerr << "Unknown flag " << raw_flag << std::endl;
198      return false;
199    }
200    const bool is_short_flag = std::strncmp(*it, "--", 2) != 0;
201    const std::string flag_name = get_flag_name(raw_flag, is_short_flag);
202
203    auto needle = std::find_if(
204        get_flags().begin(), get_flags().end(),
205        [&flag_name](const auto& item) { return item.name == flag_name; });
206    if (needle == get_flags().end()) {
207      std::cerr << "Unknown flag " << flag_name << std::endl;
208      return false;
209    }
210
211    if (parsed_flags.count(&*needle) != 0) {
212      std::cerr << "The flag " << flag_name << " was specified multiple times."
213                << std::endl;
214      return false;
215    }
216    parsed_flags.insert(&*needle);
217
218    if (!parse_flag_info(*needle, &it)) {
219      std::cerr << "Invalid usage for flag " << flag_name << std::endl;
220      return false;
221    }
222  }
223
224  // Check that we parsed all required flags.
225  for (const auto& flag : get_flags()) {
226    if (!flag.required) {
227      continue;
228    }
229
230    if (parsed_flags.count(&flag) == 0) {
231      std::cerr << "Missing required flag " << flag.name << std::endl;
232      return false;
233    }
234  }
235
236  return true;
237}
238
239// Just the public wrapper around the parse function.
240bool Parse(const char** argv) { return FlagList::parse(argv); }
241
242}  // namespace flags
243