1/*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "mdns_packet_parser.h"
17#include "netmgr_ext_log_wrapper.h"
18#include <cstring>
19
20namespace OHOS {
21namespace NetManagerStandard {
22
23namespace {
24
// Capacity reserved up front for each decoded name string.
constexpr size_t MDNS_STR_INITIAL_SIZE = 16;

// Top two bits of a length byte set => the byte starts a 2-byte compression pointer.
constexpr uint8_t DNS_STR_PTR_U8_MASK = 0xC0;
// The same tag bits viewed on the full 16-bit pointer word.
constexpr uint16_t DNS_STR_PTR_U16_MASK = 0xC000;
// Low six bits of a length byte (0x3F == 63); the serializer limits label lengths with it.
constexpr uint16_t DNS_STR_PTR_LENGTH = 0x3F;
// Zero byte that terminates an encoded name.
constexpr uint8_t DNS_STR_EOL = 0;
31
32template <class T> void WriteRawData(const T &data, MDnsPayload &payload)
33{
34    const uint8_t *begin = reinterpret_cast<const uint8_t *>(&data);
35    payload.insert(payload.end(), begin, begin + sizeof(T));
36}
37
// Overwrites sizeof(T) bytes at `ptr` with the object representation of `data` (host byte order).
// The caller must guarantee that many writable bytes exist at `ptr`; no alignment is required.
template <class T> void WriteRawData(const T &data, uint8_t *ptr)
{
    std::memcpy(ptr, &data, sizeof(T));
}
45
// Copies sizeof(T) bytes starting at `raw` into `data` without any byte-order conversion and returns
// the position just after them. No bounds check is done: the caller must ensure the bytes exist.
// `raw` points into a byte buffer at arbitrary offsets, so the bytes are copied with memcpy rather than
// read through a `const T *` (that would be a misaligned, strict-aliasing-violating access).
template <class T> const uint8_t *ReadRawData(const uint8_t *raw, T &data)
{
    std::memcpy(&data, raw, sizeof(T));
    return raw + sizeof(T);
}
51
// Decodes a 16-bit value stored in network (big-endian) byte order at `raw` into `data`
// and returns the position just after it. No bounds check is done.
const uint8_t *ReadNUint16(const uint8_t *raw, uint16_t &data)
{
    constexpr unsigned byteBits = 8;
    const uint16_t high = raw[0];
    const uint16_t low = raw[1];
    data = static_cast<uint16_t>((high << byteBits) | low);
    return raw + sizeof(uint16_t);
}
58
// Decodes a 32-bit value stored in network (big-endian) byte order at `raw` into `data`
// and returns the position just after it. No bounds check is done.
const uint8_t *ReadNUint32(const uint8_t *raw, uint32_t &data)
{
    constexpr unsigned byteBits = 8;
    uint32_t value = 0;
    for (size_t idx = 0; idx < sizeof(uint32_t); ++idx) {
        value = (value << byteBits) | raw[idx];
    }
    data = value;
    return raw + sizeof(uint32_t);
}
65
66std::string UnDotted(const std::string &name)
67{
68    return EndsWith(name, MDNS_DOMAIN_SPLITER_STR) ? name.substr(0, name.size() - 1) : name;
69}
70
71} // namespace
72
73MDnsMessage MDnsPayloadParser::FromBytes(const MDnsPayload &payload)
74{
75    MDnsMessage msg;
76    errorFlags_ = PARSE_OK;
77    pos_ = Parse(payload.data(), payload, msg);
78    return msg;
79}
80
81MDnsPayload MDnsPayloadParser::ToBytes(const MDnsMessage &msg)
82{
83    MDnsPayload payload;
84    MDnsPayload *cachedPayload = &payload;
85    std::map<std::string, uint16_t> strCacheMap;
86    Serialize(msg, payload, cachedPayload, strCacheMap);
87    return payload;
88}
89
90const uint8_t *MDnsPayloadParser::Parse(const uint8_t *begin, const MDnsPayload &payload, MDnsMessage &msg)
91{
92    begin = ParseHeader(begin, payload, msg.header);
93    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
94        return begin;
95    }
96    for (int i = 0; i < msg.header.qdcount; ++i) {
97        begin = ParseQuestion(begin, payload, msg.questions);
98        if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
99            return begin;
100        }
101    }
102    for (int i = 0; i < msg.header.ancount; ++i) {
103        begin = ParseRR(begin, payload, msg.answers);
104        if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
105            return begin;
106        }
107    }
108    for (int i = 0; i < msg.header.nscount; ++i) {
109        begin = ParseRR(begin, payload, msg.authorities);
110        if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
111            return begin;
112        }
113    }
114    for (int i = 0; i < msg.header.arcount; ++i) {
115        begin = ParseRR(begin, payload, msg.additional);
116        if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
117            return begin;
118        }
119    }
120    return begin;
121}
122
123const uint8_t *MDnsPayloadParser::ParseHeader(const uint8_t *begin, const MDnsPayload &payload,
124                                              DNSProto::Header &header)
125{
126    const uint8_t *end = payload.data() + payload.size();
127    if (end - begin < static_cast<int>(sizeof(DNSProto::Header))) {
128        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
129        return begin;
130    }
131
132    begin = ReadNUint16(begin, header.id);
133    begin = ReadNUint16(begin, header.flags);
134    begin = ReadNUint16(begin, header.qdcount);
135    begin = ReadNUint16(begin, header.ancount);
136    begin = ReadNUint16(begin, header.nscount);
137    begin = ReadNUint16(begin, header.arcount);
138    return begin;
139}
140
141const uint8_t *MDnsPayloadParser::ParseQuestion(const uint8_t *begin, const MDnsPayload &payload,
142                                                std::vector<DNSProto::Question> &questions)
143{
144    questions.emplace_back();
145    begin = ParseDnsString(begin, payload, questions.back().name);
146    questions.back().name = UnDotted(questions.back().name);
147    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
148        questions.pop_back();
149        return begin;
150    }
151
152    const uint8_t *end = payload.data() + payload.size();
153    if (static_cast<ssize_t>(end - begin) < static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t))) {
154        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
155        questions.pop_back();
156        return begin;
157    }
158
159    begin = ReadNUint16(begin, questions.back().qtype);
160    begin = ReadNUint16(begin, questions.back().qclass);
161    return begin;
162}
163
164const uint8_t *MDnsPayloadParser::ParseRR(const uint8_t *begin, const MDnsPayload &payload,
165                                          std::vector<DNSProto::ResourceRecord> &answers)
166{
167    answers.emplace_back();
168    begin = ParseDnsString(begin, payload, answers.back().name);
169    answers.back().name = UnDotted(answers.back().name);
170    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
171        answers.pop_back();
172        return begin;
173    }
174
175    const uint8_t *end = payload.data() + payload.size();
176    if (static_cast<ssize_t>(end - begin) <
177        static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint16_t))) {
178        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
179        answers.pop_back();
180        return begin;
181    }
182    begin = ReadNUint16(begin, answers.back().rtype);
183    begin = ReadNUint16(begin, answers.back().rclass);
184    begin = ReadNUint32(begin, answers.back().ttl);
185    begin = ReadNUint16(begin, answers.back().length);
186    return ParseRData(begin, payload, answers.back().rtype, answers.back().length, answers.back().rdata);
187}
188
189const uint8_t *MDnsPayloadParser::ParseRData(const uint8_t *begin, const MDnsPayload &payload, int type, int length,
190                                             std::any &data)
191{
192    switch (type) {
193        case DNSProto::RRTYPE_A: {
194            const uint8_t *end = payload.data() + payload.size();
195            if (static_cast<size_t>(end - begin) < sizeof(in_addr) || length != sizeof(in_addr)) {
196                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
197                return begin;
198            }
199            in_addr addr;
200            begin = ReadRawData(begin, addr);
201            data = addr;
202            return begin;
203        }
204        case DNSProto::RRTYPE_AAAA: {
205            const uint8_t *end = payload.data() + payload.size();
206            if ((static_cast<ssize_t>(end - begin) <
207                static_cast<ssize_t>(sizeof(in6_addr))) || (length != sizeof(in6_addr))) {
208                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
209                return begin;
210            }
211            in6_addr addr;
212            begin = ReadRawData(begin, addr);
213            data = addr;
214            return begin;
215        }
216        case DNSProto::RRTYPE_PTR: {
217            std::string str;
218            begin = ParseDnsString(begin, payload, str);
219            str = UnDotted(str);
220            if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
221                return begin;
222            }
223            data = str;
224            return begin;
225        }
226        case DNSProto::RRTYPE_SRV: {
227            return ParseSrv(begin, payload, data);
228        }
229        case DNSProto::RRTYPE_TXT: {
230            return ParseTxt(begin, payload, length, data);
231        }
232        default: {
233            errorFlags_ |= PARSE_WARNING_BAD_RRTYPE;
234            return begin + length;
235        }
236    }
237}
238
239const uint8_t *MDnsPayloadParser::ParseSrv(const uint8_t *begin, const MDnsPayload &payload, std::any &data)
240{
241    const uint8_t *end = payload.data() + payload.size();
242    if (static_cast<ssize_t>(end - begin) <
243        static_cast<ssize_t>(sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t))) {
244        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
245        return begin;
246    }
247
248    DNSProto::RDataSrv srv;
249    begin = ReadNUint16(begin, srv.priority);
250    begin = ReadNUint16(begin, srv.weight);
251    begin = ReadNUint16(begin, srv.port);
252    begin = ParseDnsString(begin, payload, srv.name);
253    srv.name = UnDotted(srv.name);
254    if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
255        return begin;
256    }
257    data = srv;
258    return begin;
259}
260
261const uint8_t *MDnsPayloadParser::ParseTxt(const uint8_t *begin, const MDnsPayload &payload, int length, std::any &data)
262{
263    const uint8_t *end = payload.data() + payload.size();
264    if (end - begin < length) {
265        errorFlags_ |= PARSE_ERROR_BAD_SIZE;
266        return begin;
267    }
268
269    data = TxtRecordEncoded(begin, begin + length);
270    return begin + length;
271}
272
273const uint8_t *MDnsPayloadParser::ParseDnsString(const uint8_t *begin, const MDnsPayload &payload, std::string &str)
274{
275    const uint8_t *end = payload.data() + payload.size();
276    const uint8_t *p = begin;
277    str.reserve(MDNS_STR_INITIAL_SIZE);
278    while (p && p < end) {
279        if (*p == 0) {
280            return p + 1;
281        }
282        if (*p <= MDNS_MAX_DOMAIN_LABEL && p + *p < end) {
283            str.append(reinterpret_cast<const char *>(p) + 1, *p);
284            str.push_back(MDNS_DOMAIN_SPLITER);
285            p += (*p + 1);
286        } else if ((*p & DNS_STR_PTR_U8_MASK) == DNS_STR_PTR_U8_MASK) {
287            if (end - p < static_cast<int>(sizeof(uint16_t))) {
288                errorFlags_ |= PARSE_ERROR_BAD_SIZE;
289                return begin;
290            }
291
292            uint16_t offset;
293            const uint8_t *tmp = ReadNUint16(p, offset);
294            offset = offset & ~DNS_STR_PTR_U16_MASK;
295            const uint8_t *next = payload.data() + (offset & ~DNS_STR_PTR_U16_MASK);
296
297            if (next >= end || next >= begin) {
298                errorFlags_ |= PARSE_ERROR_BAD_STRPTR;
299                return begin;
300            }
301            ParseDnsString(next, payload, str);
302            if ((errorFlags_ & PARSE_ERROR) != PARSE_OK) {
303                return begin;
304            }
305            return tmp;
306        } else {
307            errorFlags_ |= PARSE_ERROR_BAD_STR;
308            return p;
309        }
310    }
311    return p;
312}
313
314void MDnsPayloadParser::Serialize(const MDnsMessage &msg, MDnsPayload &payload, MDnsPayload *cachedPayload,
315                                  std::map<std::string, uint16_t> &strCacheMap)
316{
317    payload.reserve(sizeof(DNSProto::Message));
318    DNSProto::Header header = msg.header;
319    header.qdcount = msg.questions.size();
320    header.ancount = msg.answers.size();
321    header.nscount = msg.authorities.size();
322    header.arcount = msg.additional.size();
323    SerializeHeader(header, msg, payload);
324    for (uint16_t i = 0; i < header.qdcount; ++i) {
325        SerializeQuestion(msg.questions[i], payload, cachedPayload, strCacheMap);
326    }
327    for (uint16_t i = 0; i < header.ancount; ++i) {
328        SerializeRR(msg.answers[i], payload, cachedPayload, strCacheMap);
329    }
330    for (uint16_t i = 0; i < header.nscount; ++i) {
331        SerializeRR(msg.authorities[i], payload, cachedPayload, strCacheMap);
332    }
333    for (uint16_t i = 0; i < header.arcount; ++i) {
334        SerializeRR(msg.additional[i], payload, cachedPayload, strCacheMap);
335    }
336}
337
338void MDnsPayloadParser::SerializeHeader(const DNSProto::Header &header, const MDnsMessage &msg, MDnsPayload &payload)
339{
340    WriteRawData(htons(header.id), payload);
341    WriteRawData(htons(header.flags), payload);
342    WriteRawData(htons(header.qdcount), payload);
343    WriteRawData(htons(header.ancount), payload);
344    WriteRawData(htons(header.nscount), payload);
345    WriteRawData(htons(header.arcount), payload);
346}
347
348void MDnsPayloadParser::SerializeQuestion(const DNSProto::Question &question, MDnsPayload &payload,
349                                          MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
350{
351    SerializeDnsString(question.name, payload, cachedPayload, strCacheMap);
352    WriteRawData(htons(question.qtype), payload);
353    WriteRawData(htons(question.qclass), payload);
354}
355
356void MDnsPayloadParser::SerializeRR(const DNSProto::ResourceRecord &rr, MDnsPayload &payload,
357                                    MDnsPayload *cachedPayload, std::map<std::string, uint16_t> &strCacheMap)
358{
359    SerializeDnsString(rr.name, payload, cachedPayload, strCacheMap);
360    WriteRawData(htons(rr.rtype), payload);
361    WriteRawData(htons(rr.rclass), payload);
362    WriteRawData(htonl(rr.ttl), payload);
363    size_t lenStart = payload.size();
364    WriteRawData(htons(rr.length), payload);
365    SerializeRData(rr.rdata, payload, cachedPayload, strCacheMap);
366    uint16_t len = payload.size() - lenStart - sizeof(uint16_t);
367    WriteRawData(htons(len), payload.data() + lenStart);
368}
369
370void MDnsPayloadParser::SerializeRData(const std::any &rdata, MDnsPayload &payload, MDnsPayload *cachedPayload,
371                                       std::map<std::string, uint16_t> &strCacheMap)
372{
373    if (std::any_cast<const in_addr>(&rdata)) {
374        WriteRawData(*std::any_cast<const in_addr>(&rdata), payload);
375    } else if (std::any_cast<const in6_addr>(&rdata)) {
376        WriteRawData(*std::any_cast<const in6_addr>(&rdata), payload);
377    } else if (std::any_cast<const std::string>(&rdata)) {
378        SerializeDnsString(*std::any_cast<const std::string>(&rdata), payload, cachedPayload, strCacheMap);
379    } else if (std::any_cast<const DNSProto::RDataSrv>(&rdata)) {
380        const DNSProto::RDataSrv *srv = std::any_cast<const DNSProto::RDataSrv>(&rdata);
381        WriteRawData(htons(srv->priority), payload);
382        WriteRawData(htons(srv->weight), payload);
383        WriteRawData(htons(srv->port), payload);
384        SerializeDnsString(srv->name, payload, cachedPayload, strCacheMap);
385    } else if (std::any_cast<TxtRecordEncoded>(&rdata)) {
386        const auto *txt = std::any_cast<TxtRecordEncoded>(&rdata);
387        payload.insert(payload.end(), txt->begin(), txt->end());
388    }
389}
390
391void MDnsPayloadParser::SerializeDnsString(const std::string &str, MDnsPayload &payload, MDnsPayload *cachedPayload,
392                                           std::map<std::string, uint16_t> &strCacheMap)
393{
394    size_t pos = 0;
395    while (pos < str.size()) {
396        if ((cachedPayload == &payload) && (strCacheMap.find(str.substr(pos)) != strCacheMap.end())) {
397            return WriteRawData(htons(strCacheMap[str.substr(pos)]), payload);
398        }
399
400        size_t nextDot = str.find(MDNS_DOMAIN_SPLITER, pos);
401        if (nextDot == std::string::npos) {
402            nextDot = str.size();
403        }
404        uint8_t segLen = (nextDot - pos) & DNS_STR_PTR_LENGTH;
405
406        uint16_t strptr = payload.size();
407        WriteRawData(segLen, payload);
408        for (int i = 0; i < segLen; ++i) {
409            WriteRawData(str[pos + i], payload);
410        }
411        strCacheMap[str.substr(pos)] = strptr | DNS_STR_PTR_U16_MASK;
412        pos = nextDot + 1;
413    }
414    WriteRawData(DNS_STR_EOL, payload);
415}
416
417uint32_t MDnsPayloadParser::GetError() const
418{
419    return errorFlags_ & PARSE_ERROR;
420}
421
422} // namespace NetManagerStandard
423} // namespace OHOS
424