xref: /third_party/nghttp2/src/shrpx_router.cc (revision 2c593315)
1/*
2 * nghttp2 - HTTP/2 C Library
3 *
4 * Copyright (c) 2015 Tatsuhiro Tsujikawa
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining
7 * a copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sublicense, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice shall be
15 * included in all copies or substantial portions of the Software.
16 *
17 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
20 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
21 * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
22 * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
23 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 */
25#include "shrpx_router.h"
26
27#include <algorithm>
28
29#include "shrpx_config.h"
30#include "shrpx_log.h"
31
32namespace shrpx {
33
34RNode::RNode() : s(nullptr), len(0), index(-1), wildcard_index(-1) {}
35
36RNode::RNode(const char *s, size_t len, ssize_t index, ssize_t wildcard_index)
37    : s(s), len(len), index(index), wildcard_index(wildcard_index) {}
38
39Router::Router() : balloc_(1024, 1024), root_{} {}
40
41Router::~Router() {}
42
43namespace {
44RNode *find_next_node(const RNode *node, char c) {
45  auto itr = std::lower_bound(std::begin(node->next), std::end(node->next), c,
46                              [](const std::unique_ptr<RNode> &lhs,
47                                 const char c) { return lhs->s[0] < c; });
48  if (itr == std::end(node->next) || (*itr)->s[0] != c) {
49    return nullptr;
50  }
51
52  return (*itr).get();
53}
54} // namespace
55
56namespace {
57void add_next_node(RNode *node, std::unique_ptr<RNode> new_node) {
58  auto itr = std::lower_bound(std::begin(node->next), std::end(node->next),
59                              new_node->s[0],
60                              [](const std::unique_ptr<RNode> &lhs,
61                                 const char c) { return lhs->s[0] < c; });
62  node->next.insert(itr, std::move(new_node));
63}
64} // namespace
65
66void Router::add_node(RNode *node, const char *pattern, size_t patlen,
67                      ssize_t index, ssize_t wildcard_index) {
68  auto pat = make_string_ref(balloc_, StringRef{pattern, patlen});
69  auto new_node =
70      std::make_unique<RNode>(pat.c_str(), pat.size(), index, wildcard_index);
71  add_next_node(node, std::move(new_node));
72}
73
// Inserts |pattern| into the radix tree and associates it with |idx|.
//
// If |wildcard| is true, |idx| is stored in the terminal node's
// wildcard_index; otherwise it is stored in its index.  Nodes whose
// strings only partially share a prefix with |pattern| are split so
// that the shared part becomes a common parent.
//
// Returns |idx| when the pattern (of that kind) is new.  If the same
// pattern of the same kind was already added, the tree is left
// unchanged and the previously stored index is returned instead.
//
// NOTE(review): the first iteration reads pattern[0] even when
// |pattern| is empty; this relies on StringRef's operator[] being safe
// there -- confirm callers never pass an empty pattern.
size_t Router::add_route(const StringRef &pattern, size_t idx, bool wildcard) {
  // Exactly one of these ends up != -1, selecting which slot to fill.
  ssize_t index = -1, wildcard_index = -1;
  if (wildcard) {
    wildcard_index = idx;
  } else {
    index = idx;
  }

  auto node = &root_;
  // Number of pattern characters already consumed by ancestors of |node|.
  size_t i = 0;

  for (;;) {
    auto next_node = find_next_node(node, pattern[i]);
    if (next_node == nullptr) {
      // No child shares even the first character: hang the remaining
      // pattern off |node| as a new leaf.
      add_node(node, pattern.c_str() + i, pattern.size() - i, index,
               wildcard_index);
      return idx;
    }

    node = next_node;

    auto slen = pattern.size() - i;
    auto s = pattern.c_str() + i;
    auto n = std::min(node->len, slen);
    // j becomes the length of the common prefix of node->s and s.
    size_t j;
    for (j = 0; j < n && node->s[j] == s[j]; ++j)
      ;
    if (j == n) {
      // The common prefix was matched
      if (slen == node->len) {
        // Complete match
        if (index != -1) {
          if (node->index != -1) {
            // Return the existing index for duplicates.
            return node->index;
          }
          node->index = index;
          return idx;
        }

        assert(wildcard_index != -1);

        // Same duplicate rule for wildcard routes.
        if (node->wildcard_index != -1) {
          return node->wildcard_index;
        }
        node->wildcard_index = wildcard_index;
        return idx;
      }

      if (slen > node->len) {
        // We still have pattern to add
        i += j;

        continue;
      }
      // Otherwise the remaining pattern is a strict prefix of node->s;
      // fall through to the split below.
    }

    if (node->len > j) {
      // node must be split into 2 nodes.  new_node is now the child
      // of node.
      // The child takes over the tail of node->s, the old route indices
      // and all existing children; |node| keeps only the first j bytes.
      auto new_node = std::make_unique<RNode>(
          &node->s[j], node->len - j, node->index, node->wildcard_index);
      std::swap(node->next, new_node->next);

      node->len = j;
      node->index = -1;
      node->wildcard_index = -1;

      add_next_node(node, std::move(new_node));

      if (slen == j) {
        // The pattern ends exactly at the split point, so the shortened
        // |node| itself becomes the terminal node for this route.
        node->index = index;
        node->wildcard_index = wildcard_index;
        return idx;
      }
    }

    i += j;

    // The pattern diverged inside |node|: add the unmatched remainder as
    // a sibling of the split-off child.
    assert(pattern.size() > i);
    add_node(node, pattern.c_str() + i, pattern.size() - i, index,
             wildcard_index);

    return idx;
  }
}
160
161namespace {
162const RNode *match_complete(size_t *offset, const RNode *node,
163                            const char *first, const char *last) {
164  *offset = 0;
165
166  if (first == last) {
167    return node;
168  }
169
170  auto p = first;
171
172  for (;;) {
173    auto next_node = find_next_node(node, *p);
174    if (next_node == nullptr) {
175      return nullptr;
176    }
177
178    node = next_node;
179
180    auto n = std::min(node->len, static_cast<size_t>(last - p));
181    if (memcmp(node->s, p, n) != 0) {
182      return nullptr;
183    }
184    p += n;
185    if (p == last) {
186      *offset = n;
187      return node;
188    }
189  }
190}
191} // namespace
192
193namespace {
// Continues a lookup for the path [first, last), starting at |node|.
// |offset| bytes of |node|'s string have already been consumed, for
// example by match_complete() on the host.
//
// Returns the best candidate node, or nullptr.  Candidates are:
//   - a node where the path ends exactly and index != -1;
//   - a node whose pattern is the path plus one trailing '/' (so "/foo"
//     matches pattern "/foo/");
//   - a node fully consumed before the path ends that has a
//     wildcard_index, or an index with a trailing '/' (prefix match).
// Deeper candidates overwrite shallower ones along the walk.
//
// *pattern_is_wildcard is set to true when the caller should read
// wildcard_index of the returned node, and to false when it should read
// index.
//
// NOTE(review): for an empty path the node is returned whenever
// node->len == offset, without checking its index; the caller is
// expected to treat index == -1 as no match.
const RNode *match_partial(bool *pattern_is_wildcard, const RNode *node,
                           size_t offset, const char *first, const char *last) {
  *pattern_is_wildcard = false;

  if (first == last) {
    if (node->len == offset) {
      return node;
    }
    return nullptr;
  }

  auto p = first;

  // Best prefix/wildcard candidate seen so far.
  const RNode *found_node = nullptr;

  if (offset > 0) {
    // First finish comparing the remainder of the partially consumed
    // |node| before descending into its children.
    auto n = std::min(node->len - offset, static_cast<size_t>(last - first));
    if (memcmp(node->s + offset, first, n) != 0) {
      return nullptr;
    }

    p += n;

    if (p == last) {
      if (node->len == offset + n) {
        if (node->index != -1) {
          return node;
        }

        // The last '/' handling, see below.
        node = find_next_node(node, '/');
        if (node != nullptr && node->index != -1 && node->len == 1) {
          return node;
        }

        return nullptr;
      }

      // The last '/' handling, see below.
      if (node->index != -1 && offset + n + 1 == node->len &&
          node->s[node->len - 1] == '/') {
        return node;
      }

      return nullptr;
    }

    // The path continues past |node|: remember it if it can serve as a
    // prefix match.
    if (node->wildcard_index != -1) {
      found_node = node;
      *pattern_is_wildcard = true;
    } else if (node->index != -1 && node->s[node->len - 1] == '/') {
      found_node = node;
      *pattern_is_wildcard = false;
    }

    assert(node->len == offset + n);
  }

  for (;;) {
    auto next_node = find_next_node(node, *p);
    if (next_node == nullptr) {
      return found_node;
    }

    node = next_node;

    auto n = std::min(node->len, static_cast<size_t>(last - p));
    if (memcmp(node->s, p, n) != 0) {
      return found_node;
    }

    p += n;

    if (p == last) {
      if (node->len == n) {
        // Complete match with this node
        if (node->index != -1) {
          *pattern_is_wildcard = false;
          return node;
        }

        // The last '/' handling, see below.
        node = find_next_node(node, '/');
        if (node != nullptr && node->index != -1 && node->len == 1) {
          *pattern_is_wildcard = false;
          return node;
        }

        return found_node;
      }

      // We allow match without trailing "/" at the end of pattern.
      // So, if pattern ends with '/', and pattern and path matches
      // without that slash, we consider they match to deal with
      // request to the directory without trailing slash.  That is if
      // pattern is "/foo/" and path is "/foo", we consider they
      // match.
      if (node->index != -1 && n + 1 == node->len && node->s[n] == '/') {
        *pattern_is_wildcard = false;
        return node;
      }

      return found_node;
    }

    if (node->wildcard_index != -1) {
      found_node = node;
      *pattern_is_wildcard = true;
    } else if (node->index != -1 && node->s[node->len - 1] == '/') {
      // This is the case when pattern which ends with "/" is included
      // in query.
      found_node = node;
      *pattern_is_wildcard = false;
    }

    // p != last here, so the whole node string was consumed.
    assert(node->len == n);
  }
}
312} // namespace
313
314ssize_t Router::match(const StringRef &host, const StringRef &path) const {
315  const RNode *node;
316  size_t offset;
317
318  node = match_complete(&offset, &root_, std::begin(host), std::end(host));
319  if (node == nullptr) {
320    return -1;
321  }
322
323  bool pattern_is_wildcard;
324  node = match_partial(&pattern_is_wildcard, node, offset, std::begin(path),
325                       std::end(path));
326  if (node == nullptr || node == &root_) {
327    return -1;
328  }
329
330  return pattern_is_wildcard ? node->wildcard_index : node->index;
331}
332
333ssize_t Router::match(const StringRef &s) const {
334  const RNode *node;
335  size_t offset;
336
337  node = match_complete(&offset, &root_, std::begin(s), std::end(s));
338  if (node == nullptr) {
339    return -1;
340  }
341
342  if (node->len != offset) {
343    return -1;
344  }
345
346  return node->index;
347}
348
349namespace {
350const RNode *match_prefix(size_t *nread, const RNode *node, const char *first,
351                          const char *last) {
352  if (first == last) {
353    return nullptr;
354  }
355
356  auto p = first;
357
358  for (;;) {
359    auto next_node = find_next_node(node, *p);
360    if (next_node == nullptr) {
361      return nullptr;
362    }
363
364    node = next_node;
365
366    auto n = std::min(node->len, static_cast<size_t>(last - p));
367    if (memcmp(node->s, p, n) != 0) {
368      return nullptr;
369    }
370
371    p += n;
372
373    if (p != last) {
374      if (node->index != -1) {
375        *nread = p - first;
376        return node;
377      }
378      continue;
379    }
380
381    if (node->len == n) {
382      *nread = p - first;
383      return node;
384    }
385
386    return nullptr;
387  }
388}
389} // namespace
390
391ssize_t Router::match_prefix(size_t *nread, const RNode **last_node,
392                             const StringRef &s) const {
393  if (*last_node == nullptr) {
394    *last_node = &root_;
395  }
396
397  auto node =
398      ::shrpx::match_prefix(nread, *last_node, std::begin(s), std::end(s));
399  if (node == nullptr) {
400    return -1;
401  }
402
403  *last_node = node;
404
405  return node->index;
406}
407
408namespace {
409void dump_node(const RNode *node, int depth) {
410  fprintf(stderr, "%*ss='%.*s', len=%zu, index=%zd\n", depth, "",
411          (int)node->len, node->s, node->len, node->index);
412  for (auto &nd : node->next) {
413    dump_node(nd.get(), depth + 4);
414  }
415}
416} // namespace
417
418void Router::dump() const { dump_node(&root_, 0); }
419
420} // namespace shrpx
421