1/* MIT License
2 *
3 * Copyright (c) 2023 Brad House
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a copy
6 * of this software and associated documentation files (the "Software"), to deal
7 * in the Software without restriction, including without limitation the rights
8 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 * copies of the Software, and to permit persons to whom the Software is
10 * furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 *
24 * SPDX-License-Identifier: MIT
25 */
26#include "ares_setup.h"
27#include "ares.h"
28#include "ares_private.h"
29
/* Query cache: cached responses indexed by a string key, with a second index
 * ordered by expiration time so stale entries can be purged from the front. */
struct ares__qcache {
  /* Key string -> ares__qcache_entry_t.  ares__qcache_create() builds this
   * with a NULL callback (presumably "no value free function"), so the table
   * is not expected to free entries; the expire list owns them. */
  ares__htable_strvp_t *cache;
  /* Entries ordered by ares__qcache_entry_sort_cb (ascending expire_ts) and
   * freed through ares__qcache_entry_destroy_cb when removed. */
  ares__slist_t        *expire;
  /* Upper bound, in seconds, applied to any TTL computed for an entry. */
  unsigned int          max_ttl;
};
35
/* A single cached response. */
typedef struct {
  /* Cache key built from the original request (see ares__qcache_calc_key) */
  char              *key;
  /* Cached response record; freed by ares__qcache_entry_destroy_cb */
  ares_dns_record_t *dnsrec;
  /* Absolute time in seconds (same clock as the caller's `now`) at which the
   * entry expires */
  time_t             expire_ts;
  /* Time in seconds the entry was inserted; the elapsed time since then is
   * subtracted from record TTLs when the entry is served */
  time_t             insert_ts;
} ares__qcache_entry_t;
42
43static char *ares__qcache_calc_key(const ares_dns_record_t *dnsrec)
44{
45  ares__buf_t     *buf = ares__buf_create();
46  size_t           i;
47  ares_status_t    status;
48  ares_dns_flags_t flags;
49
50  if (dnsrec == NULL || buf == NULL) {
51    return NULL;
52  }
53
54  /* Format is OPCODE|FLAGS[|QTYPE1|QCLASS1|QNAME1]... */
55
56  status = ares__buf_append_str(
57    buf, ares_dns_opcode_tostr(ares_dns_record_get_opcode(dnsrec)));
58  if (status != ARES_SUCCESS) {
59    goto fail;
60  }
61
62  status = ares__buf_append_byte(buf, '|');
63  if (status != ARES_SUCCESS) {
64    goto fail;
65  }
66
67  flags = ares_dns_record_get_flags(dnsrec);
68  /* Only care about RD and CD */
69  if (flags & ARES_FLAG_RD) {
70    status = ares__buf_append_str(buf, "rd");
71    if (status != ARES_SUCCESS) {
72      goto fail;
73    }
74  }
75  if (flags & ARES_FLAG_CD) {
76    status = ares__buf_append_str(buf, "cd");
77    if (status != ARES_SUCCESS) {
78      goto fail;
79    }
80  }
81
82  for (i = 0; i < ares_dns_record_query_cnt(dnsrec); i++) {
83    const char         *name;
84    ares_dns_rec_type_t qtype;
85    ares_dns_class_t    qclass;
86
87    status = ares_dns_record_query_get(dnsrec, i, &name, &qtype, &qclass);
88    if (status != ARES_SUCCESS) {
89      goto fail;
90    }
91
92    status = ares__buf_append_byte(buf, '|');
93    if (status != ARES_SUCCESS) {
94      goto fail;
95    }
96
97    status = ares__buf_append_str(buf, ares_dns_rec_type_tostr(qtype));
98    if (status != ARES_SUCCESS) {
99      goto fail;
100    }
101
102    status = ares__buf_append_byte(buf, '|');
103    if (status != ARES_SUCCESS) {
104      goto fail;
105    }
106
107    status = ares__buf_append_str(buf, ares_dns_class_tostr(qclass));
108    if (status != ARES_SUCCESS) {
109      goto fail;
110    }
111
112    status = ares__buf_append_byte(buf, '|');
113    if (status != ARES_SUCCESS) {
114      goto fail;
115    }
116
117    status = ares__buf_append_str(buf, name);
118    if (status != ARES_SUCCESS) {
119      goto fail;
120    }
121  }
122
123  return ares__buf_finish_str(buf, NULL);
124
125fail:
126  ares__buf_destroy(buf);
127  return NULL;
128}
129
130static void ares__qcache_expire(ares__qcache_t       *cache,
131                                const struct timeval *now)
132{
133  ares__slist_node_t *node;
134
135  if (cache == NULL) {
136    return;
137  }
138
139  while ((node = ares__slist_node_first(cache->expire)) != NULL) {
140    const ares__qcache_entry_t *entry = ares__slist_node_val(node);
141    if (entry->expire_ts > now->tv_sec) {
142      break;
143    }
144
145    ares__htable_strvp_remove(cache->cache, entry->key);
146    ares__slist_node_destroy(node);
147  }
148}
149
150void ares__qcache_flush(ares__qcache_t *cache)
151{
152  struct timeval now;
153  memset(&now, 0, sizeof(now));
154  ares__qcache_expire(cache, &now);
155}
156
157void ares__qcache_destroy(ares__qcache_t *cache)
158{
159  if (cache == NULL) {
160    return;
161  }
162
163  ares__htable_strvp_destroy(cache->cache);
164  ares__slist_destroy(cache->expire);
165  ares_free(cache);
166}
167
168static int ares__qcache_entry_sort_cb(const void *arg1, const void *arg2)
169{
170  const ares__qcache_entry_t *entry1 = arg1;
171  const ares__qcache_entry_t *entry2 = arg2;
172
173  if (entry1->expire_ts > entry2->expire_ts) {
174    return 1;
175  }
176
177  if (entry1->expire_ts < entry2->expire_ts) {
178    return -1;
179  }
180
181  return 0;
182}
183
184static void ares__qcache_entry_destroy_cb(void *arg)
185{
186  ares__qcache_entry_t *entry = arg;
187  if (entry == NULL) {
188    return;
189  }
190
191  ares_free(entry->key);
192  ares_dns_record_destroy(entry->dnsrec);
193  ares_free(entry);
194}
195
196ares_status_t ares__qcache_create(ares_rand_state *rand_state,
197                                  unsigned int     max_ttl,
198                                  ares__qcache_t **cache_out)
199{
200  ares_status_t   status = ARES_SUCCESS;
201  ares__qcache_t *cache;
202
203  cache = ares_malloc_zero(sizeof(*cache));
204  if (cache == NULL) {
205    status = ARES_ENOMEM;
206    goto done;
207  }
208
209  cache->cache = ares__htable_strvp_create(NULL);
210  if (cache->cache == NULL) {
211    status = ARES_ENOMEM;
212    goto done;
213  }
214
215  cache->expire = ares__slist_create(rand_state, ares__qcache_entry_sort_cb,
216                                     ares__qcache_entry_destroy_cb);
217  if (cache->expire == NULL) {
218    status = ARES_ENOMEM;
219    goto done;
220  }
221
222  cache->max_ttl = max_ttl;
223
224done:
225  if (status != ARES_SUCCESS) {
226    *cache_out = NULL;
227    ares__qcache_destroy(cache);
228    return status;
229  }
230
231  *cache_out = cache;
232  return status;
233}
234
235static unsigned int ares__qcache_calc_minttl(ares_dns_record_t *dnsrec)
236{
237  unsigned int minttl = 0xFFFFFFFF;
238  size_t       sect;
239
240  for (sect = ARES_SECTION_ANSWER; sect <= ARES_SECTION_ADDITIONAL; sect++) {
241    size_t i;
242    for (i = 0; i < ares_dns_record_rr_cnt(dnsrec, (ares_dns_section_t)sect);
243         i++) {
244      const ares_dns_rr_t *rr =
245        ares_dns_record_rr_get(dnsrec, (ares_dns_section_t)sect, i);
246      ares_dns_rec_type_t type = ares_dns_rr_get_type(rr);
247      unsigned int        ttl  = ares_dns_rr_get_ttl(rr);
248      if (type == ARES_REC_TYPE_OPT || type == ARES_REC_TYPE_SOA) {
249        continue;
250      }
251
252      if (ttl < minttl) {
253        minttl = ttl;
254      }
255    }
256  }
257
258  return minttl;
259}
260
261static unsigned int ares__qcache_soa_minimum(ares_dns_record_t *dnsrec)
262{
263  size_t i;
264
265  /* RFC 2308 Section 5 says its the minimum of MINIMUM and the TTL of the
266   * record. */
267  for (i = 0; i < ares_dns_record_rr_cnt(dnsrec, ARES_SECTION_AUTHORITY); i++) {
268    const ares_dns_rr_t *rr =
269      ares_dns_record_rr_get(dnsrec, ARES_SECTION_AUTHORITY, i);
270    ares_dns_rec_type_t type = ares_dns_rr_get_type(rr);
271    unsigned int        ttl;
272    unsigned int        minimum;
273
274    if (type != ARES_REC_TYPE_SOA) {
275      continue;
276    }
277
278    minimum = ares_dns_rr_get_u32(rr, ARES_RR_SOA_MINIMUM);
279    ttl     = ares_dns_rr_get_ttl(rr);
280
281    if (ttl > minimum) {
282      return minimum;
283    }
284    return ttl;
285  }
286
287  return 0;
288}
289
290static char *ares__qcache_calc_key_frombuf(const unsigned char *qbuf,
291                                           size_t               qlen)
292{
293  ares_status_t      status;
294  ares_dns_record_t *dnsrec = NULL;
295  char              *key    = NULL;
296
297  status = ares_dns_parse(qbuf, qlen, 0, &dnsrec);
298  if (status != ARES_SUCCESS) {
299    goto done;
300  }
301
302  key = ares__qcache_calc_key(dnsrec);
303
304done:
305  ares_dns_record_destroy(dnsrec);
306  return key;
307}
308
309/* On success, takes ownership of dnsrec */
310static ares_status_t ares__qcache_insert(ares__qcache_t      *qcache,
311                                         ares_dns_record_t   *dnsrec,
312                                         const unsigned char *qbuf, size_t qlen,
313                                         const struct timeval *now)
314{
315  ares__qcache_entry_t *entry;
316  unsigned int          ttl;
317  ares_dns_rcode_t      rcode = ares_dns_record_get_rcode(dnsrec);
318  ares_dns_flags_t      flags = ares_dns_record_get_flags(dnsrec);
319
320  if (qcache == NULL || dnsrec == NULL) {
321    return ARES_EFORMERR;
322  }
323
324  /* Only save NOERROR or NXDOMAIN */
325  if (rcode != ARES_RCODE_NOERROR && rcode != ARES_RCODE_NXDOMAIN) {
326    return ARES_ENOTIMP;
327  }
328
329  /* Don't save truncated queries */
330  if (flags & ARES_FLAG_TC) {
331    return ARES_ENOTIMP;
332  }
333
334  /* Look at SOA for NXDOMAIN for minimum */
335  if (rcode == ARES_RCODE_NXDOMAIN) {
336    ttl = ares__qcache_soa_minimum(dnsrec);
337  } else {
338    ttl = ares__qcache_calc_minttl(dnsrec);
339  }
340
341  /* Don't cache something that is already expired */
342  if (ttl == 0) {
343    return ARES_EREFUSED;
344  }
345
346  if (ttl > qcache->max_ttl) {
347    ttl = qcache->max_ttl;
348  }
349
350  entry = ares_malloc_zero(sizeof(*entry));
351  if (entry == NULL) {
352    goto fail;
353  }
354
355  entry->dnsrec    = dnsrec;
356  entry->expire_ts = now->tv_sec + (time_t)ttl;
357  entry->insert_ts = now->tv_sec;
358
359  /* We can't guarantee the server responded with the same flags as the
360   * request had, so we have to re-parse the request in order to generate the
361   * key for caching, but we'll only do this once we know for sure we really
362   * want to cache it */
363  entry->key = ares__qcache_calc_key_frombuf(qbuf, qlen);
364  if (entry->key == NULL) {
365    goto fail;
366  }
367
368  if (!ares__htable_strvp_insert(qcache->cache, entry->key, entry)) {
369    goto fail;
370  }
371
372  if (ares__slist_insert(qcache->expire, entry) == NULL) {
373    goto fail;
374  }
375
376  return ARES_SUCCESS;
377
378fail:
379  if (entry != NULL && entry->key != NULL) {
380    ares__htable_strvp_remove(qcache->cache, entry->key);
381    ares_free(entry->key);
382    ares_free(entry);
383  }
384  return ARES_ENOMEM;
385}
386
387static ares_status_t ares__qcache_fetch(ares__qcache_t          *qcache,
388                                        const ares_dns_record_t *dnsrec,
389                                        const struct timeval    *now,
390                                        unsigned char **buf, size_t *buf_len)
391{
392  char                 *key = NULL;
393  ares__qcache_entry_t *entry;
394  ares_status_t         status;
395
396  if (qcache == NULL || dnsrec == NULL) {
397    return ARES_EFORMERR;
398  }
399
400  ares__qcache_expire(qcache, now);
401
402  key = ares__qcache_calc_key(dnsrec);
403  if (key == NULL) {
404    status = ARES_ENOMEM;
405    goto done;
406  }
407
408  entry = ares__htable_strvp_get_direct(qcache->cache, key);
409  if (entry == NULL) {
410    status = ARES_ENOTFOUND;
411    goto done;
412  }
413
414  ares_dns_record_write_ttl_decrement(
415    entry->dnsrec, (unsigned int)(now->tv_sec - entry->insert_ts));
416
417  status = ares_dns_write(entry->dnsrec, buf, buf_len);
418
419done:
420  ares_free(key);
421  return status;
422}
423
424ares_status_t ares_qcache_insert(ares_channel_t       *channel,
425                                 const struct timeval *now,
426                                 const struct query   *query,
427                                 ares_dns_record_t    *dnsrec)
428{
429  return ares__qcache_insert(channel->qcache, dnsrec, query->qbuf, query->qlen,
430                             now);
431}
432
433ares_status_t ares_qcache_fetch(ares_channel_t       *channel,
434                                const struct timeval *now,
435                                const unsigned char *qbuf, size_t qlen,
436                                unsigned char **abuf, size_t *alen)
437{
438  ares_status_t      status;
439  ares_dns_record_t *dnsrec = NULL;
440
441  if (channel->qcache == NULL) {
442    return ARES_ENOTFOUND;
443  }
444
445  status = ares_dns_parse(qbuf, qlen, 0, &dnsrec);
446  if (status != ARES_SUCCESS) {
447    goto done;
448  }
449
450  status = ares__qcache_fetch(channel->qcache, dnsrec, now, abuf, alen);
451
452done:
453  ares_dns_record_destroy(dnsrec);
454  return status;
455}
456