// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

662306a36Sopenharmony_ci#include "ratelimiter.h"
762306a36Sopenharmony_ci#include <linux/siphash.h>
862306a36Sopenharmony_ci#include <linux/mm.h>
962306a36Sopenharmony_ci#include <linux/slab.h>
1062306a36Sopenharmony_ci#include <net/ip.h>
1162306a36Sopenharmony_ci
1262306a36Sopenharmony_cistatic struct kmem_cache *entry_cache;
1362306a36Sopenharmony_cistatic hsiphash_key_t key;
1462306a36Sopenharmony_cistatic spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
1562306a36Sopenharmony_cistatic DEFINE_MUTEX(init_lock);
1662306a36Sopenharmony_cistatic u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
1762306a36Sopenharmony_cistatic atomic_t total_entries = ATOMIC_INIT(0);
1862306a36Sopenharmony_cistatic unsigned int max_entries, table_size;
1962306a36Sopenharmony_cistatic void wg_ratelimiter_gc_entries(struct work_struct *);
2062306a36Sopenharmony_cistatic DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
2162306a36Sopenharmony_cistatic struct hlist_head *table_v4;
2262306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
2362306a36Sopenharmony_cistatic struct hlist_head *table_v6;
2462306a36Sopenharmony_ci#endif
2562306a36Sopenharmony_ci
2662306a36Sopenharmony_cistruct ratelimiter_entry {
2762306a36Sopenharmony_ci	u64 last_time_ns, tokens, ip;
2862306a36Sopenharmony_ci	void *net;
2962306a36Sopenharmony_ci	spinlock_t lock;
3062306a36Sopenharmony_ci	struct hlist_node hash;
3162306a36Sopenharmony_ci	struct rcu_head rcu;
3262306a36Sopenharmony_ci};
3362306a36Sopenharmony_ci
3462306a36Sopenharmony_cienum {
3562306a36Sopenharmony_ci	PACKETS_PER_SECOND = 20,
3662306a36Sopenharmony_ci	PACKETS_BURSTABLE = 5,
3762306a36Sopenharmony_ci	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
3862306a36Sopenharmony_ci	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
3962306a36Sopenharmony_ci};
4062306a36Sopenharmony_ci
4162306a36Sopenharmony_cistatic void entry_free(struct rcu_head *rcu)
4262306a36Sopenharmony_ci{
4362306a36Sopenharmony_ci	kmem_cache_free(entry_cache,
4462306a36Sopenharmony_ci			container_of(rcu, struct ratelimiter_entry, rcu));
4562306a36Sopenharmony_ci	atomic_dec(&total_entries);
4662306a36Sopenharmony_ci}
4762306a36Sopenharmony_ci
4862306a36Sopenharmony_cistatic void entry_uninit(struct ratelimiter_entry *entry)
4962306a36Sopenharmony_ci{
5062306a36Sopenharmony_ci	hlist_del_rcu(&entry->hash);
5162306a36Sopenharmony_ci	call_rcu(&entry->rcu, entry_free);
5262306a36Sopenharmony_ci}
5362306a36Sopenharmony_ci
5462306a36Sopenharmony_ci/* Calling this function with a NULL work uninits all entries. */
5562306a36Sopenharmony_cistatic void wg_ratelimiter_gc_entries(struct work_struct *work)
5662306a36Sopenharmony_ci{
5762306a36Sopenharmony_ci	const u64 now = ktime_get_coarse_boottime_ns();
5862306a36Sopenharmony_ci	struct ratelimiter_entry *entry;
5962306a36Sopenharmony_ci	struct hlist_node *temp;
6062306a36Sopenharmony_ci	unsigned int i;
6162306a36Sopenharmony_ci
6262306a36Sopenharmony_ci	for (i = 0; i < table_size; ++i) {
6362306a36Sopenharmony_ci		spin_lock(&table_lock);
6462306a36Sopenharmony_ci		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
6562306a36Sopenharmony_ci			if (unlikely(!work) ||
6662306a36Sopenharmony_ci			    now - entry->last_time_ns > NSEC_PER_SEC)
6762306a36Sopenharmony_ci				entry_uninit(entry);
6862306a36Sopenharmony_ci		}
6962306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
7062306a36Sopenharmony_ci		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
7162306a36Sopenharmony_ci			if (unlikely(!work) ||
7262306a36Sopenharmony_ci			    now - entry->last_time_ns > NSEC_PER_SEC)
7362306a36Sopenharmony_ci				entry_uninit(entry);
7462306a36Sopenharmony_ci		}
7562306a36Sopenharmony_ci#endif
7662306a36Sopenharmony_ci		spin_unlock(&table_lock);
7762306a36Sopenharmony_ci		if (likely(work))
7862306a36Sopenharmony_ci			cond_resched();
7962306a36Sopenharmony_ci	}
8062306a36Sopenharmony_ci	if (likely(work))
8162306a36Sopenharmony_ci		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
8262306a36Sopenharmony_ci}
8362306a36Sopenharmony_ci
8462306a36Sopenharmony_cibool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
8562306a36Sopenharmony_ci{
8662306a36Sopenharmony_ci	/* We only take the bottom half of the net pointer, so that we can hash
8762306a36Sopenharmony_ci	 * 3 words in the end. This way, siphash's len param fits into the final
8862306a36Sopenharmony_ci	 * u32, and we don't incur an extra round.
8962306a36Sopenharmony_ci	 */
9062306a36Sopenharmony_ci	const u32 net_word = (unsigned long)net;
9162306a36Sopenharmony_ci	struct ratelimiter_entry *entry;
9262306a36Sopenharmony_ci	struct hlist_head *bucket;
9362306a36Sopenharmony_ci	u64 ip;
9462306a36Sopenharmony_ci
9562306a36Sopenharmony_ci	if (skb->protocol == htons(ETH_P_IP)) {
9662306a36Sopenharmony_ci		ip = (u64 __force)ip_hdr(skb)->saddr;
9762306a36Sopenharmony_ci		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
9862306a36Sopenharmony_ci				   (table_size - 1)];
9962306a36Sopenharmony_ci	}
10062306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
10162306a36Sopenharmony_ci	else if (skb->protocol == htons(ETH_P_IPV6)) {
10262306a36Sopenharmony_ci		/* Only use 64 bits, so as to ratelimit the whole /64. */
10362306a36Sopenharmony_ci		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
10462306a36Sopenharmony_ci		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
10562306a36Sopenharmony_ci				   (table_size - 1)];
10662306a36Sopenharmony_ci	}
10762306a36Sopenharmony_ci#endif
10862306a36Sopenharmony_ci	else
10962306a36Sopenharmony_ci		return false;
11062306a36Sopenharmony_ci	rcu_read_lock();
11162306a36Sopenharmony_ci	hlist_for_each_entry_rcu(entry, bucket, hash) {
11262306a36Sopenharmony_ci		if (entry->net == net && entry->ip == ip) {
11362306a36Sopenharmony_ci			u64 now, tokens;
11462306a36Sopenharmony_ci			bool ret;
11562306a36Sopenharmony_ci			/* Quasi-inspired by nft_limit.c, but this is actually a
11662306a36Sopenharmony_ci			 * slightly different algorithm. Namely, we incorporate
11762306a36Sopenharmony_ci			 * the burst as part of the maximum tokens, rather than
11862306a36Sopenharmony_ci			 * as part of the rate.
11962306a36Sopenharmony_ci			 */
12062306a36Sopenharmony_ci			spin_lock(&entry->lock);
12162306a36Sopenharmony_ci			now = ktime_get_coarse_boottime_ns();
12262306a36Sopenharmony_ci			tokens = min_t(u64, TOKEN_MAX,
12362306a36Sopenharmony_ci				       entry->tokens + now -
12462306a36Sopenharmony_ci					       entry->last_time_ns);
12562306a36Sopenharmony_ci			entry->last_time_ns = now;
12662306a36Sopenharmony_ci			ret = tokens >= PACKET_COST;
12762306a36Sopenharmony_ci			entry->tokens = ret ? tokens - PACKET_COST : tokens;
12862306a36Sopenharmony_ci			spin_unlock(&entry->lock);
12962306a36Sopenharmony_ci			rcu_read_unlock();
13062306a36Sopenharmony_ci			return ret;
13162306a36Sopenharmony_ci		}
13262306a36Sopenharmony_ci	}
13362306a36Sopenharmony_ci	rcu_read_unlock();
13462306a36Sopenharmony_ci
13562306a36Sopenharmony_ci	if (atomic_inc_return(&total_entries) > max_entries)
13662306a36Sopenharmony_ci		goto err_oom;
13762306a36Sopenharmony_ci
13862306a36Sopenharmony_ci	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
13962306a36Sopenharmony_ci	if (unlikely(!entry))
14062306a36Sopenharmony_ci		goto err_oom;
14162306a36Sopenharmony_ci
14262306a36Sopenharmony_ci	entry->net = net;
14362306a36Sopenharmony_ci	entry->ip = ip;
14462306a36Sopenharmony_ci	INIT_HLIST_NODE(&entry->hash);
14562306a36Sopenharmony_ci	spin_lock_init(&entry->lock);
14662306a36Sopenharmony_ci	entry->last_time_ns = ktime_get_coarse_boottime_ns();
14762306a36Sopenharmony_ci	entry->tokens = TOKEN_MAX - PACKET_COST;
14862306a36Sopenharmony_ci	spin_lock(&table_lock);
14962306a36Sopenharmony_ci	hlist_add_head_rcu(&entry->hash, bucket);
15062306a36Sopenharmony_ci	spin_unlock(&table_lock);
15162306a36Sopenharmony_ci	return true;
15262306a36Sopenharmony_ci
15362306a36Sopenharmony_cierr_oom:
15462306a36Sopenharmony_ci	atomic_dec(&total_entries);
15562306a36Sopenharmony_ci	return false;
15662306a36Sopenharmony_ci}
15762306a36Sopenharmony_ci
15862306a36Sopenharmony_ciint wg_ratelimiter_init(void)
15962306a36Sopenharmony_ci{
16062306a36Sopenharmony_ci	mutex_lock(&init_lock);
16162306a36Sopenharmony_ci	if (++init_refcnt != 1)
16262306a36Sopenharmony_ci		goto out;
16362306a36Sopenharmony_ci
16462306a36Sopenharmony_ci	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
16562306a36Sopenharmony_ci	if (!entry_cache)
16662306a36Sopenharmony_ci		goto err;
16762306a36Sopenharmony_ci
16862306a36Sopenharmony_ci	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
16962306a36Sopenharmony_ci	 * but what it shares in common is that it uses a massive hashtable. So,
17062306a36Sopenharmony_ci	 * we borrow their wisdom about good table sizes on different systems
17162306a36Sopenharmony_ci	 * dependent on RAM. This calculation here comes from there.
17262306a36Sopenharmony_ci	 */
17362306a36Sopenharmony_ci	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
17462306a36Sopenharmony_ci		max_t(unsigned long, 16, roundup_pow_of_two(
17562306a36Sopenharmony_ci			(totalram_pages() << PAGE_SHIFT) /
17662306a36Sopenharmony_ci			(1U << 14) / sizeof(struct hlist_head)));
17762306a36Sopenharmony_ci	max_entries = table_size * 8;
17862306a36Sopenharmony_ci
17962306a36Sopenharmony_ci	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
18062306a36Sopenharmony_ci	if (unlikely(!table_v4))
18162306a36Sopenharmony_ci		goto err_kmemcache;
18262306a36Sopenharmony_ci
18362306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
18462306a36Sopenharmony_ci	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
18562306a36Sopenharmony_ci	if (unlikely(!table_v6)) {
18662306a36Sopenharmony_ci		kvfree(table_v4);
18762306a36Sopenharmony_ci		goto err_kmemcache;
18862306a36Sopenharmony_ci	}
18962306a36Sopenharmony_ci#endif
19062306a36Sopenharmony_ci
19162306a36Sopenharmony_ci	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
19262306a36Sopenharmony_ci	get_random_bytes(&key, sizeof(key));
19362306a36Sopenharmony_ciout:
19462306a36Sopenharmony_ci	mutex_unlock(&init_lock);
19562306a36Sopenharmony_ci	return 0;
19662306a36Sopenharmony_ci
19762306a36Sopenharmony_cierr_kmemcache:
19862306a36Sopenharmony_ci	kmem_cache_destroy(entry_cache);
19962306a36Sopenharmony_cierr:
20062306a36Sopenharmony_ci	--init_refcnt;
20162306a36Sopenharmony_ci	mutex_unlock(&init_lock);
20262306a36Sopenharmony_ci	return -ENOMEM;
20362306a36Sopenharmony_ci}
20462306a36Sopenharmony_ci
20562306a36Sopenharmony_civoid wg_ratelimiter_uninit(void)
20662306a36Sopenharmony_ci{
20762306a36Sopenharmony_ci	mutex_lock(&init_lock);
20862306a36Sopenharmony_ci	if (!init_refcnt || --init_refcnt)
20962306a36Sopenharmony_ci		goto out;
21062306a36Sopenharmony_ci
21162306a36Sopenharmony_ci	cancel_delayed_work_sync(&gc_work);
21262306a36Sopenharmony_ci	wg_ratelimiter_gc_entries(NULL);
21362306a36Sopenharmony_ci	rcu_barrier();
21462306a36Sopenharmony_ci	kvfree(table_v4);
21562306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
21662306a36Sopenharmony_ci	kvfree(table_v6);
21762306a36Sopenharmony_ci#endif
21862306a36Sopenharmony_ci	kmem_cache_destroy(entry_cache);
21962306a36Sopenharmony_ciout:
22062306a36Sopenharmony_ci	mutex_unlock(&init_lock);
22162306a36Sopenharmony_ci}
22262306a36Sopenharmony_ci
22362306a36Sopenharmony_ci#include "selftest/ratelimiter.c"
224