18c2ecf20Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0
28c2ecf20Sopenharmony_ci/*
38c2ecf20Sopenharmony_ci * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
48c2ecf20Sopenharmony_ci */
58c2ecf20Sopenharmony_ci
68c2ecf20Sopenharmony_ci#include "ratelimiter.h"
78c2ecf20Sopenharmony_ci#include <linux/siphash.h>
88c2ecf20Sopenharmony_ci#include <linux/mm.h>
98c2ecf20Sopenharmony_ci#include <linux/slab.h>
108c2ecf20Sopenharmony_ci#include <net/ip.h>
118c2ecf20Sopenharmony_ci
128c2ecf20Sopenharmony_cistatic struct kmem_cache *entry_cache;
138c2ecf20Sopenharmony_cistatic hsiphash_key_t key;
148c2ecf20Sopenharmony_cistatic spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
158c2ecf20Sopenharmony_cistatic DEFINE_MUTEX(init_lock);
168c2ecf20Sopenharmony_cistatic u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
178c2ecf20Sopenharmony_cistatic atomic_t total_entries = ATOMIC_INIT(0);
188c2ecf20Sopenharmony_cistatic unsigned int max_entries, table_size;
198c2ecf20Sopenharmony_cistatic void wg_ratelimiter_gc_entries(struct work_struct *);
208c2ecf20Sopenharmony_cistatic DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
218c2ecf20Sopenharmony_cistatic struct hlist_head *table_v4;
228c2ecf20Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
238c2ecf20Sopenharmony_cistatic struct hlist_head *table_v6;
248c2ecf20Sopenharmony_ci#endif
258c2ecf20Sopenharmony_ci
268c2ecf20Sopenharmony_cistruct ratelimiter_entry {
278c2ecf20Sopenharmony_ci	u64 last_time_ns, tokens, ip;
288c2ecf20Sopenharmony_ci	void *net;
298c2ecf20Sopenharmony_ci	spinlock_t lock;
308c2ecf20Sopenharmony_ci	struct hlist_node hash;
318c2ecf20Sopenharmony_ci	struct rcu_head rcu;
328c2ecf20Sopenharmony_ci};
338c2ecf20Sopenharmony_ci
348c2ecf20Sopenharmony_cienum {
358c2ecf20Sopenharmony_ci	PACKETS_PER_SECOND = 20,
368c2ecf20Sopenharmony_ci	PACKETS_BURSTABLE = 5,
378c2ecf20Sopenharmony_ci	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
388c2ecf20Sopenharmony_ci	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
398c2ecf20Sopenharmony_ci};
408c2ecf20Sopenharmony_ci
418c2ecf20Sopenharmony_cistatic void entry_free(struct rcu_head *rcu)
428c2ecf20Sopenharmony_ci{
438c2ecf20Sopenharmony_ci	kmem_cache_free(entry_cache,
448c2ecf20Sopenharmony_ci			container_of(rcu, struct ratelimiter_entry, rcu));
458c2ecf20Sopenharmony_ci	atomic_dec(&total_entries);
468c2ecf20Sopenharmony_ci}
478c2ecf20Sopenharmony_ci
488c2ecf20Sopenharmony_cistatic void entry_uninit(struct ratelimiter_entry *entry)
498c2ecf20Sopenharmony_ci{
508c2ecf20Sopenharmony_ci	hlist_del_rcu(&entry->hash);
518c2ecf20Sopenharmony_ci	call_rcu(&entry->rcu, entry_free);
528c2ecf20Sopenharmony_ci}
538c2ecf20Sopenharmony_ci
548c2ecf20Sopenharmony_ci/* Calling this function with a NULL work uninits all entries. */
558c2ecf20Sopenharmony_cistatic void wg_ratelimiter_gc_entries(struct work_struct *work)
568c2ecf20Sopenharmony_ci{
578c2ecf20Sopenharmony_ci	const u64 now = ktime_get_coarse_boottime_ns();
588c2ecf20Sopenharmony_ci	struct ratelimiter_entry *entry;
598c2ecf20Sopenharmony_ci	struct hlist_node *temp;
608c2ecf20Sopenharmony_ci	unsigned int i;
618c2ecf20Sopenharmony_ci
628c2ecf20Sopenharmony_ci	for (i = 0; i < table_size; ++i) {
638c2ecf20Sopenharmony_ci		spin_lock(&table_lock);
648c2ecf20Sopenharmony_ci		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
658c2ecf20Sopenharmony_ci			if (unlikely(!work) ||
668c2ecf20Sopenharmony_ci			    now - entry->last_time_ns > NSEC_PER_SEC)
678c2ecf20Sopenharmony_ci				entry_uninit(entry);
688c2ecf20Sopenharmony_ci		}
698c2ecf20Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
708c2ecf20Sopenharmony_ci		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
718c2ecf20Sopenharmony_ci			if (unlikely(!work) ||
728c2ecf20Sopenharmony_ci			    now - entry->last_time_ns > NSEC_PER_SEC)
738c2ecf20Sopenharmony_ci				entry_uninit(entry);
748c2ecf20Sopenharmony_ci		}
758c2ecf20Sopenharmony_ci#endif
768c2ecf20Sopenharmony_ci		spin_unlock(&table_lock);
778c2ecf20Sopenharmony_ci		if (likely(work))
788c2ecf20Sopenharmony_ci			cond_resched();
798c2ecf20Sopenharmony_ci	}
808c2ecf20Sopenharmony_ci	if (likely(work))
818c2ecf20Sopenharmony_ci		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
828c2ecf20Sopenharmony_ci}
838c2ecf20Sopenharmony_ci
848c2ecf20Sopenharmony_cibool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
858c2ecf20Sopenharmony_ci{
868c2ecf20Sopenharmony_ci	/* We only take the bottom half of the net pointer, so that we can hash
878c2ecf20Sopenharmony_ci	 * 3 words in the end. This way, siphash's len param fits into the final
888c2ecf20Sopenharmony_ci	 * u32, and we don't incur an extra round.
898c2ecf20Sopenharmony_ci	 */
908c2ecf20Sopenharmony_ci	const u32 net_word = (unsigned long)net;
918c2ecf20Sopenharmony_ci	struct ratelimiter_entry *entry;
928c2ecf20Sopenharmony_ci	struct hlist_head *bucket;
938c2ecf20Sopenharmony_ci	u64 ip;
948c2ecf20Sopenharmony_ci
958c2ecf20Sopenharmony_ci	if (skb->protocol == htons(ETH_P_IP)) {
968c2ecf20Sopenharmony_ci		ip = (u64 __force)ip_hdr(skb)->saddr;
978c2ecf20Sopenharmony_ci		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
988c2ecf20Sopenharmony_ci				   (table_size - 1)];
998c2ecf20Sopenharmony_ci	}
1008c2ecf20Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
1018c2ecf20Sopenharmony_ci	else if (skb->protocol == htons(ETH_P_IPV6)) {
1028c2ecf20Sopenharmony_ci		/* Only use 64 bits, so as to ratelimit the whole /64. */
1038c2ecf20Sopenharmony_ci		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
1048c2ecf20Sopenharmony_ci		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
1058c2ecf20Sopenharmony_ci				   (table_size - 1)];
1068c2ecf20Sopenharmony_ci	}
1078c2ecf20Sopenharmony_ci#endif
1088c2ecf20Sopenharmony_ci	else
1098c2ecf20Sopenharmony_ci		return false;
1108c2ecf20Sopenharmony_ci	rcu_read_lock();
1118c2ecf20Sopenharmony_ci	hlist_for_each_entry_rcu(entry, bucket, hash) {
1128c2ecf20Sopenharmony_ci		if (entry->net == net && entry->ip == ip) {
1138c2ecf20Sopenharmony_ci			u64 now, tokens;
1148c2ecf20Sopenharmony_ci			bool ret;
1158c2ecf20Sopenharmony_ci			/* Quasi-inspired by nft_limit.c, but this is actually a
1168c2ecf20Sopenharmony_ci			 * slightly different algorithm. Namely, we incorporate
1178c2ecf20Sopenharmony_ci			 * the burst as part of the maximum tokens, rather than
1188c2ecf20Sopenharmony_ci			 * as part of the rate.
1198c2ecf20Sopenharmony_ci			 */
1208c2ecf20Sopenharmony_ci			spin_lock(&entry->lock);
1218c2ecf20Sopenharmony_ci			now = ktime_get_coarse_boottime_ns();
1228c2ecf20Sopenharmony_ci			tokens = min_t(u64, TOKEN_MAX,
1238c2ecf20Sopenharmony_ci				       entry->tokens + now -
1248c2ecf20Sopenharmony_ci					       entry->last_time_ns);
1258c2ecf20Sopenharmony_ci			entry->last_time_ns = now;
1268c2ecf20Sopenharmony_ci			ret = tokens >= PACKET_COST;
1278c2ecf20Sopenharmony_ci			entry->tokens = ret ? tokens - PACKET_COST : tokens;
1288c2ecf20Sopenharmony_ci			spin_unlock(&entry->lock);
1298c2ecf20Sopenharmony_ci			rcu_read_unlock();
1308c2ecf20Sopenharmony_ci			return ret;
1318c2ecf20Sopenharmony_ci		}
1328c2ecf20Sopenharmony_ci	}
1338c2ecf20Sopenharmony_ci	rcu_read_unlock();
1348c2ecf20Sopenharmony_ci
1358c2ecf20Sopenharmony_ci	if (atomic_inc_return(&total_entries) > max_entries)
1368c2ecf20Sopenharmony_ci		goto err_oom;
1378c2ecf20Sopenharmony_ci
1388c2ecf20Sopenharmony_ci	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
1398c2ecf20Sopenharmony_ci	if (unlikely(!entry))
1408c2ecf20Sopenharmony_ci		goto err_oom;
1418c2ecf20Sopenharmony_ci
1428c2ecf20Sopenharmony_ci	entry->net = net;
1438c2ecf20Sopenharmony_ci	entry->ip = ip;
1448c2ecf20Sopenharmony_ci	INIT_HLIST_NODE(&entry->hash);
1458c2ecf20Sopenharmony_ci	spin_lock_init(&entry->lock);
1468c2ecf20Sopenharmony_ci	entry->last_time_ns = ktime_get_coarse_boottime_ns();
1478c2ecf20Sopenharmony_ci	entry->tokens = TOKEN_MAX - PACKET_COST;
1488c2ecf20Sopenharmony_ci	spin_lock(&table_lock);
1498c2ecf20Sopenharmony_ci	hlist_add_head_rcu(&entry->hash, bucket);
1508c2ecf20Sopenharmony_ci	spin_unlock(&table_lock);
1518c2ecf20Sopenharmony_ci	return true;
1528c2ecf20Sopenharmony_ci
1538c2ecf20Sopenharmony_cierr_oom:
1548c2ecf20Sopenharmony_ci	atomic_dec(&total_entries);
1558c2ecf20Sopenharmony_ci	return false;
1568c2ecf20Sopenharmony_ci}
1578c2ecf20Sopenharmony_ci
1588c2ecf20Sopenharmony_ciint wg_ratelimiter_init(void)
1598c2ecf20Sopenharmony_ci{
1608c2ecf20Sopenharmony_ci	mutex_lock(&init_lock);
1618c2ecf20Sopenharmony_ci	if (++init_refcnt != 1)
1628c2ecf20Sopenharmony_ci		goto out;
1638c2ecf20Sopenharmony_ci
1648c2ecf20Sopenharmony_ci	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
1658c2ecf20Sopenharmony_ci	if (!entry_cache)
1668c2ecf20Sopenharmony_ci		goto err;
1678c2ecf20Sopenharmony_ci
1688c2ecf20Sopenharmony_ci	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
1698c2ecf20Sopenharmony_ci	 * but what it shares in common is that it uses a massive hashtable. So,
1708c2ecf20Sopenharmony_ci	 * we borrow their wisdom about good table sizes on different systems
1718c2ecf20Sopenharmony_ci	 * dependent on RAM. This calculation here comes from there.
1728c2ecf20Sopenharmony_ci	 */
1738c2ecf20Sopenharmony_ci	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
1748c2ecf20Sopenharmony_ci		max_t(unsigned long, 16, roundup_pow_of_two(
1758c2ecf20Sopenharmony_ci			(totalram_pages() << PAGE_SHIFT) /
1768c2ecf20Sopenharmony_ci			(1U << 14) / sizeof(struct hlist_head)));
1778c2ecf20Sopenharmony_ci	max_entries = table_size * 8;
1788c2ecf20Sopenharmony_ci
1798c2ecf20Sopenharmony_ci	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
1808c2ecf20Sopenharmony_ci	if (unlikely(!table_v4))
1818c2ecf20Sopenharmony_ci		goto err_kmemcache;
1828c2ecf20Sopenharmony_ci
1838c2ecf20Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
1848c2ecf20Sopenharmony_ci	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
1858c2ecf20Sopenharmony_ci	if (unlikely(!table_v6)) {
1868c2ecf20Sopenharmony_ci		kvfree(table_v4);
1878c2ecf20Sopenharmony_ci		goto err_kmemcache;
1888c2ecf20Sopenharmony_ci	}
1898c2ecf20Sopenharmony_ci#endif
1908c2ecf20Sopenharmony_ci
1918c2ecf20Sopenharmony_ci	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
1928c2ecf20Sopenharmony_ci	get_random_bytes(&key, sizeof(key));
1938c2ecf20Sopenharmony_ciout:
1948c2ecf20Sopenharmony_ci	mutex_unlock(&init_lock);
1958c2ecf20Sopenharmony_ci	return 0;
1968c2ecf20Sopenharmony_ci
1978c2ecf20Sopenharmony_cierr_kmemcache:
1988c2ecf20Sopenharmony_ci	kmem_cache_destroy(entry_cache);
1998c2ecf20Sopenharmony_cierr:
2008c2ecf20Sopenharmony_ci	--init_refcnt;
2018c2ecf20Sopenharmony_ci	mutex_unlock(&init_lock);
2028c2ecf20Sopenharmony_ci	return -ENOMEM;
2038c2ecf20Sopenharmony_ci}
2048c2ecf20Sopenharmony_ci
2058c2ecf20Sopenharmony_civoid wg_ratelimiter_uninit(void)
2068c2ecf20Sopenharmony_ci{
2078c2ecf20Sopenharmony_ci	mutex_lock(&init_lock);
2088c2ecf20Sopenharmony_ci	if (!init_refcnt || --init_refcnt)
2098c2ecf20Sopenharmony_ci		goto out;
2108c2ecf20Sopenharmony_ci
2118c2ecf20Sopenharmony_ci	cancel_delayed_work_sync(&gc_work);
2128c2ecf20Sopenharmony_ci	wg_ratelimiter_gc_entries(NULL);
2138c2ecf20Sopenharmony_ci	rcu_barrier();
2148c2ecf20Sopenharmony_ci	kvfree(table_v4);
2158c2ecf20Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
2168c2ecf20Sopenharmony_ci	kvfree(table_v6);
2178c2ecf20Sopenharmony_ci#endif
2188c2ecf20Sopenharmony_ci	kmem_cache_destroy(entry_cache);
2198c2ecf20Sopenharmony_ciout:
2208c2ecf20Sopenharmony_ci	mutex_unlock(&init_lock);
2218c2ecf20Sopenharmony_ci}
2228c2ecf20Sopenharmony_ci
2238c2ecf20Sopenharmony_ci#include "selftest/ratelimiter.c"
224