1/* 2 * Copyright (C) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include "mdns_packet_parser.h" 17#include "netmgr_ext_log_wrapper.h" 18#include <cstring> 19 20namespace OHOS { 21namespace NetManagerStandard { 22 23namespace { 24 25constexpr size_t MDNS_STR_INITIAL_SIZE = 16; 26 27constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xc0; 28constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xc000; 29constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3f; 30constexpr uint8_t DNS_STR_EOL = '\0'; 31 32template <class T> void WriteRawData(const T &data, MDnsPayload &payload) 33{ 34 const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data); 35 payload.insert(payload.end(), begin, begin + sizeof(T)); 36} 37 38template <class T> void WriteRawData(const T &data, uint8_t *ptr) 39{ 40 const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data); 41 for (size_t i = 0; i < sizeof(T); ++i) { 42 ptr[i] = *begin++; 43 } 44} 45 46template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data) 47{ 48 data = *reinterpret_cast<const T *>(raw); 49 return raw + sizeof(T); 50} 51 52const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data) 53{ 54 const uint8_t *tmp = ReadRawData(raw, data); 55 data = ntohs(data); 56 return tmp; 57} 58 59const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data) 60{ 61 const uint8_t *tmp = ReadRawData(raw, data); 62 data = ntohl(data); 63 return tmp; 64} 65 66std::string UnDotted(const std::string &name) 67{ 68 return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name; 69} 70 71} // namespace 72 73MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload) 74{ 75 MDnsMessage msg; 76 errorFlags_ = PARSE_OK; 77 pos_ = Parse(payload.data(), payload, msg); 78 return msg; 79} 80 81MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg) 82{ 83 MDnsPayload payload; 84 MDnsPayload *cachedPayload = &payload; 85 std::map<std::string, uint16_t> strCacheMap; 86 Serialize(msg, payload, cachedPayload, strCacheMap); 87 return payload; 88} 89 90const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg) 91{ 92 begin = ParseHeader(begin, payload, msg.header); 93 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 94 return begin; 95 } 96 for (int i = 0; i < msg.header.qdcount; ++i) { 97 begin = ParseQuestion(begin, payload, msg.questions); 98 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 99 return begin; 100 } 101 } 102 for (int i = 0; i < msg.header.ancount; ++i) { 103 begin = ParseRR(begin, payload, msg.answers); 104 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 105 return begin; 106 } 107 } 108 for (int i = 0; i < msg.header.nscount; ++i) { 109 begin = ParseRR(begin, payload, msg.authorities); 110 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 111 return begin; 112 } 113 } 114 for (int i = 0; i < msg.header.arcount; ++i) { 115 begin = ParseRR(begin, payload, msg.additional); 116 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 117 return begin; 118 } 119 } 120 return begin; 121} 122 123const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload, 124 DNSProto::Header &header) 125{ 126 const uint8_t *end = payload.data() + payload.size(); 127 if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) { 128 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 129 return begin; 130 } 131 132 begin = ReadNUint16(begin, header.id); 133 begin = ReadNUint16(begin, header.flags); 134 begin = ReadNUint16(begin, header.qdcount); 135 begin = ReadNUint16(begin, header.ancount); 136 begin = ReadNUint16(begin, header.nscount); 137 begin = ReadNUint16(begin, header.arcount); 138 return begin; 139} 140 141const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload, 142 std::vector<DNSProto::Question> &questions) 143{ 144 questions.emplace_back(); 145 begin = ParseDnsString(begin, payload, questions.back().name); 146 questions.back().name = UnDotted(questions.back().name); 147 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 148 questions.pop_back(); 149 return begin; 150 } 151 152 const uint8_t *end = payload.data() + payload.size(); 153 if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) { 154 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 155 questions.pop_back(); 156 return begin; 157 } 158 159 begin = ReadNUint16(begin, questions.back().qtype); 160 begin = ReadNUint16(begin, questions.back().qclass); 161 return begin; 162} 163 164const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload, 165 std::vector<DNSProto::ResourceRecord> &answers) 166{ 167 answers.emplace_back(); 168 begin = ParseDnsString(begin, payload, answers.back().name); 169 answers.back().name = UnDotted(answers.back().name); 170 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 171 answers.pop_back(); 172 return begin; 173 } 174 175 const uint8_t *end = payload.data() + payload.size(); 176 if (static_cast<ssize_t>(end - begin) < 177 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) { 178 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 179 answers.pop_back(); 180 return begin; 181 } 182 begin = ReadNUint16(begin, answers.back().rtype); 183 begin = ReadNUint16(begin, answers.back().rclass); 184 begin = ReadNUint32(begin, answers.back().ttl); 185 begin = ReadNUint16(begin, answers.back().length); 186 return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata); 187} 188 189const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length, 190 std::any &data) 191{ 192 switch (type) { 193 case DNSProto::RRTYPE_A: { 194 const uint8_t *end = payload.data() + payload.size(); 195 if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) { 196 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 197 return begin; 198 } 199 in_addr addr; 200 begin = ReadRawData(begin, addr); 201 data = addr; 202 return begin; 203 } 204 case DNSProto::RRTYPE_AAAA: { 205 const uint8_t *end = payload.data() + payload.size(); 206 if ((static_cast<ssize_t>(end - begin) < 207 static_cast<ssize_t>(sizeof(in6_addr))) || (length != sizeof(in6_addr))) { 208 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 209 return begin; 210 } 211 in6_addr addr; 212 begin = ReadRawData(begin, addr); 213 data = addr; 214 return begin; 215 } 216 case DNSProto::RRTYPE_PTR: { 217 std::string str; 218 begin = ParseDnsString(begin, payload, str); 219 str = UnDotted(str); 220 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 221 return begin; 222 } 223 data = str; 224 return begin; 225 } 226 case DNSProto::RRTYPE_SRV: { 227 return ParseSrv(begin, payload, data); 228 } 229 case DNSProto::RRTYPE_TXT: { 230 return ParseTxt(begin, payload, length, data); 231 } 232 default: { 233 errorFlags_ |= PARSE_WARNING_BAD_RRTYPE; 234 return begin + length; 235 } 236 } 237} 238 239const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data) 240{ 241 const uint8_t *end = payload.data() + payload.size(); 242 if (static_cast<ssize_t>(end - begin) < 243 static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) { 244 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 245 return begin; 246 } 247 248 DNSProto::RDataSrv srv; 249 begin = ReadNUint16(begin, srv.priority); 250 begin = ReadNUint16(begin, srv.weight); 251 begin = ReadNUint16(begin, srv.port); 252 begin = ParseDnsString(begin, payload, srv.name); 253 srv.name = UnDotted(srv.name); 254 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 255 return begin; 256 } 257 data = srv; 258 return begin; 259} 260 261const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data) 262{ 263 const uint8_t *end = payload.data() + payload.size(); 264 if (end - begin < length) { 265 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 266 return begin; 267 } 268 269 data = TxtRecordEncoded(begin, begin + length); 270 return begin + length; 271} 272 273const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str) 274{ 275 const uint8_t *end = payload.data() + payload.size(); 276 const uint8_t *p = begin; 277 str.reserve(MDNS_STR_INITIAL_SIZE); 278 while (p && p < end) { 279 if (*p == 0) { 280 return p + 1; 281 } 282 if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) { 283 str.append(reinterpret_cast<const char *>(p) + 1, *p); 284 str.push_back(MDNS_DOMAIN_SPLITER); 285 p += (*p + 1); 286 } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) { 287 if (end - p < static_cast<int>(sizeof(uint16_t))) { 288 errorFlags_ |= PARSE_ERROR_BAD_SIZE; 289 return begin; 290 } 291 292 uint16_t offset; 293 const uint8_t *tmp = ReadNUint16(p, offset); 294 offset = offset & ~DNS_STR_PTR_U16_MASK; 295 const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK); 296 297 if (next >= end || next >= begin) { 298 errorFlags_ |= PARSE_ERROR_BAD_STRPTR; 299 return begin; 300 } 301 ParseDnsString(next, payload, str); 302 if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) { 303 return begin; 304 } 305 return tmp; 306 } else { 307 errorFlags_ |= PARSE_ERROR_BAD_STR; 308 return p; 309 } 310 } 311 return p; 312} 313 314void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload, MDnsPayload *cachedPayload, 315 std::map<std::string, uint16_t> &strCacheMap) 316{ 317 payload.reserve(sizeof(DNSProto::Message)); 318 DNSProto::Header header = msg.header; 319 header.qdcount = msg.questions.size(); 320 header.ancount = msg.answers.size(); 321 header.nscount = msg.authorities.size(); 322 header.arcount = msg.additional.size(); 323 SerializeHeader(header, msg, payload); 324 for (uint16_t i = 0; i < header.qdcount; ++i) { 325 SerializeQuestion(msg.questions[i], payload, cachedPayload, strCacheMap); 326 } 327 for (uint16_t i = 0; i < header.ancount; ++i) { 328 SerializeRR(msg.answers[i], payload, cachedPayload, strCacheMap); 329 } 330 for (uint16_t i = 0; i < header.nscount; ++i) { 331 SerializeRR(msg.authorities[i], payload, cachedPayload, strCacheMap); 332 } 333 for (uint16_t i = 0; i < header.arcount; ++i) { 334 SerializeRR(msg.additional[i], payload, cachedPayload, strCacheMap); 335 } 336} 337 338void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload) 339{ 340 WriteRawData(htons(header.id), payload); 341 WriteRawData(htons(header.flags), payload); 342 WriteRawData(htons(header.qdcount), payload); 343 WriteRawData(htons(header.ancount), payload); 344 WriteRawData(htons(header.nscount), payload); 345 WriteRawData(htons(header.arcount), payload); 346} 347 348void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, MDnsPayload &payload, 349 MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap) 350{ 351 SerializeDnsString(question.name, payload, cachedPayload, strCacheMap); 352 WriteRawData(htons(question.qtype), payload); 353 WriteRawData(htons(question.qclass), payload); 354} 355 356void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, MDnsPayload &payload, 357 MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap) 358{ 359 SerializeDnsString(rr.name, payload, cachedPayload, strCacheMap); 360 WriteRawData(htons(rr.rtype), payload); 361 WriteRawData(htons(rr.rclass), payload); 362 WriteRawData(htonl(rr.ttl), payload); 363 size_t lenStart = payload.size(); 364 WriteRawData(htons(rr.length), payload); 365 SerializeRData(rr.rdata, payload, cachedPayload, strCacheMap); 366 uint16_t len = payload.size() - lenStart - sizeof(uint16_t); 367 WriteRawData(htons(len), payload.data() + lenStart); 368} 369 370void MDnsPayloadParser::SerializeRData(const std::any &rdata, MDnsPayload &payload, MDnsPayload *cachedPayload, 371 std::map<std::string, uint16_t> &strCacheMap) 372{ 373 if (std::any_cast<const in_addr>(&rdata)) { 374 WriteRawData(*std::any_cast<const in_addr>(&rdata), payload); 375 } else if (std::any_cast<const in6_addr>(&rdata)) { 376 WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload); 377 } else if (std::any_cast<const std::string>(&rdata)) { 378 SerializeDnsString(*std::any_cast<const std::string>(&rdata), payload, cachedPayload, strCacheMap); 379 } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) { 380 const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata); 381 WriteRawData(htons(srv->priority), payload); 382 WriteRawData(htons(srv->weight), payload); 383 WriteRawData(htons(srv->port), payload); 384 SerializeDnsString(srv->name, payload, cachedPayload, strCacheMap); 385 } else if (std::any_cast<TxtRecordEncoded>(&rdata)) { 386 const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata); 387 payload.insert(payload.end(), txt->begin(), txt->end()); 388 } 389} 390 391void MDnsPayloadParser::SerializeDnsString(const std::string &str, MDnsPayload &payload, MDnsPayload *cachedPayload, 392 std::map<std::string, uint16_t> &strCacheMap) 393{ 394 size_t pos = 0; 395 while (pos < str.size()) { 396 if ((cachedPayload == &payload) && (strCacheMap.find(str.substr(pos)) != strCacheMap.end())) { 397 return WriteRawData(htons(strCacheMap[str.substr(pos)]), payload); 398 } 399 400 size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos); 401 if (nextDot == std::string::npos) { 402 nextDot = str.size(); 403 } 404 uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH; 405 406 uint16_t strptr = payload.size(); 407 WriteRawData(segLen, payload); 408 for (int i = 0; i < segLen; ++i) { 409 WriteRawData(str[pos + i], payload); 410 } 411 strCacheMap[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK; 412 pos = nextDot + 1; 413 } 414 WriteRawData(DNS_STR_EOL, payload); 415} 416 417uint32_t MDnsPayloadParser::GetError() const 418{ 419 return errorFlags_ & PARSE_ERROR; 420} 421 422} // namespace NetManagerStandard 423} // namespace OHOS 424