1// Copyright (c) 2023 Google LLC. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "flags.h" 16 17#include <cstdlib> 18#include <cstring> 19#include <iostream> 20#include <regex> 21#include <string> 22#include <unordered_set> 23#include <variant> 24#include <vector> 25 26namespace flags { 27 28std::vector<std::string> positional_arguments; 29 30namespace { 31 32using token_t = const char*; 33using token_iterator_t = token_t*; 34 35// Extracts the flag name from a potential token. 36// This function only looks for a '=', to split the flag name from the value for 37// long-form flags. Returns the name of the flag, prefixed with the hyphen(s). 38inline std::string get_flag_name(const std::string& flag, bool is_short_flag) { 39 if (is_short_flag) { 40 return flag; 41 } 42 43 size_t equal_index = flag.find('='); 44 if (equal_index == std::string::npos) { 45 return flag; 46 } 47 return flag.substr(0, equal_index); 48} 49 50// Parse a boolean flag. Returns `true` if the parsing succeeded, `false` 51// otherwise. 52bool parse_bool_flag(Flag<bool>& flag, bool is_short_flag, 53 const std::string& token) { 54 if (is_short_flag) { 55 flag.value() = true; 56 return true; 57 } 58 59 const std::string raw_flag(token); 60 size_t equal_index = raw_flag.find('='); 61 if (equal_index == std::string::npos) { 62 flag.value() = true; 63 return true; 64 } 65 66 const std::string value = raw_flag.substr(equal_index + 1); 67 if (value == "true") { 68 flag.value() = true; 69 return true; 70 } 71 72 if (value == "false") { 73 flag.value() = false; 74 return true; 75 } 76 77 return false; 78} 79 80// Parse a uint32_t flag value. 81bool parse_flag_value(Flag<uint32_t>& flag, const std::string& value) { 82 std::regex unsigned_pattern("^ *[0-9]+ *$"); 83 if (!std::regex_match(value, unsigned_pattern)) { 84 std::cerr << "'" << value << "' is not a unsigned number." << std::endl; 85 return false; 86 } 87 88 errno = 0; 89 char* end_ptr = nullptr; 90 const uint64_t number = strtoull(value.c_str(), &end_ptr, 10); 91 if (end_ptr == nullptr || end_ptr != value.c_str() + value.size() || 92 errno == EINVAL) { 93 std::cerr << "'" << value << "' is not a unsigned number." << std::endl; 94 return false; 95 } 96 97 if (errno == ERANGE || number > static_cast<size_t>(UINT32_MAX)) { 98 std::cerr << "'" << value << "' cannot be represented as a 32bit unsigned." 99 << std::endl; 100 return false; 101 } 102 103 flag.value() = static_cast<uint32_t>(number); 104 return true; 105} 106 107// "Parse" a string flag value (assigns it, cannot fail). 108bool parse_flag_value(Flag<std::string>& flag, const std::string& value) { 109 flag.value() = value; 110 return true; 111} 112 113// Parse a potential multi-token flag. Moves the iterator to the last flag's 114// token if it's a multi-token flag. Returns `true` if the parsing succeeded. 115// The iterator is moved to the last parsed token. 116template <typename T> 117bool parse_flag(Flag<T>& flag, bool is_short_flag, const char*** iterator) { 118 const std::string raw_flag(**iterator); 119 std::string raw_value; 120 const size_t equal_index = raw_flag.find('='); 121 122 if (is_short_flag || equal_index == std::string::npos) { 123 if ((*iterator)[1] == nullptr) { 124 return false; 125 } 126 127 // This is a bi-token flag. Moving iterator to the last parsed token. 128 raw_value = (*iterator)[1]; 129 *iterator += 1; 130 } else { 131 // This is a mono-token flag, no need to move the iterator. 132 raw_value = raw_flag.substr(equal_index + 1); 133 } 134 135 return parse_flag_value(flag, raw_value); 136} 137 138} // namespace 139 140// This is the function to expand if you want to support a new type. 141bool FlagList::parse_flag_info(FlagInfo& info, token_iterator_t* iterator) { 142 bool success = false; 143 144 std::visit( 145 [&](auto&& item) { 146 using T = std::decay_t<decltype(item.get())>; 147 if constexpr (std::is_same_v<T, Flag<bool>>) { 148 success = parse_bool_flag(item.get(), info.is_short, **iterator); 149 } else if constexpr (std::is_same_v<T, Flag<std::string>>) { 150 success = parse_flag(item.get(), info.is_short, iterator); 151 } else if constexpr (std::is_same_v<T, Flag<uint32_t>>) { 152 success = parse_flag(item.get(), info.is_short, iterator); 153 } else { 154 static_assert(always_false_v<T>, "Unsupported flag type."); 155 } 156 }, 157 info.flag); 158 159 return success; 160} 161 162bool FlagList::parse(token_t* argv) { 163 flags::positional_arguments.clear(); 164 std::unordered_set<const FlagInfo*> parsed_flags; 165 166 bool ignore_flags = false; 167 for (const char** it = argv + 1; *it != nullptr; it++) { 168 if (ignore_flags) { 169 flags::positional_arguments.emplace_back(*it); 170 continue; 171 } 172 173 // '--' alone is used to mark the end of the flags. 174 if (std::strcmp(*it, "--") == 0) { 175 ignore_flags = true; 176 continue; 177 } 178 179 // '-' alone is not a flag, but often used to say 'stdin'. 180 if (std::strcmp(*it, "-") == 0) { 181 flags::positional_arguments.emplace_back(*it); 182 continue; 183 } 184 185 const std::string raw_flag(*it); 186 if (raw_flag.size() == 0) { 187 continue; 188 } 189 190 if (raw_flag[0] != '-') { 191 flags::positional_arguments.emplace_back(*it); 192 continue; 193 } 194 195 // Only case left: flags (long and shorts). 196 if (raw_flag.size() < 2) { 197 std::cerr << "Unknown flag " << raw_flag << std::endl; 198 return false; 199 } 200 const bool is_short_flag = std::strncmp(*it, "--", 2) != 0; 201 const std::string flag_name = get_flag_name(raw_flag, is_short_flag); 202 203 auto needle = std::find_if( 204 get_flags().begin(), get_flags().end(), 205 [&flag_name](const auto& item) { return item.name == flag_name; }); 206 if (needle == get_flags().end()) { 207 std::cerr << "Unknown flag " << flag_name << std::endl; 208 return false; 209 } 210 211 if (parsed_flags.count(&*needle) != 0) { 212 std::cerr << "The flag " << flag_name << " was specified multiple times." 213 << std::endl; 214 return false; 215 } 216 parsed_flags.insert(&*needle); 217 218 if (!parse_flag_info(*needle, &it)) { 219 std::cerr << "Invalid usage for flag " << flag_name << std::endl; 220 return false; 221 } 222 } 223 224 // Check that we parsed all required flags. 225 for (const auto& flag : get_flags()) { 226 if (!flag.required) { 227 continue; 228 } 229 230 if (parsed_flags.count(&flag) == 0) { 231 std::cerr << "Missing required flag " << flag.name << std::endl; 232 return false; 233 } 234 } 235 236 return true; 237} 238 239// Just the public wrapper around the parse function. 240bool Parse(const char** argv) { return FlagList::parse(argv); } 241 242} // namespace flags 243