// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "peerlookup.h"
#include "peer.h"
#include "noise.h"
98c2ecf20Sopenharmony_ci
108c2ecf20Sopenharmony_cistatic struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table,
118c2ecf20Sopenharmony_ci					const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
128c2ecf20Sopenharmony_ci{
138c2ecf20Sopenharmony_ci	/* siphash gives us a secure 64bit number based on a random key. Since
148c2ecf20Sopenharmony_ci	 * the bits are uniformly distributed, we can then mask off to get the
158c2ecf20Sopenharmony_ci	 * bits we need.
168c2ecf20Sopenharmony_ci	 */
178c2ecf20Sopenharmony_ci	const u64 hash = siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key);
188c2ecf20Sopenharmony_ci
198c2ecf20Sopenharmony_ci	return &table->hashtable[hash & (HASH_SIZE(table->hashtable) - 1)];
208c2ecf20Sopenharmony_ci}
218c2ecf20Sopenharmony_ci
228c2ecf20Sopenharmony_cistruct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
238c2ecf20Sopenharmony_ci{
248c2ecf20Sopenharmony_ci	struct pubkey_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
258c2ecf20Sopenharmony_ci
268c2ecf20Sopenharmony_ci	if (!table)
278c2ecf20Sopenharmony_ci		return NULL;
288c2ecf20Sopenharmony_ci
298c2ecf20Sopenharmony_ci	get_random_bytes(&table->key, sizeof(table->key));
308c2ecf20Sopenharmony_ci	hash_init(table->hashtable);
318c2ecf20Sopenharmony_ci	mutex_init(&table->lock);
328c2ecf20Sopenharmony_ci	return table;
338c2ecf20Sopenharmony_ci}
348c2ecf20Sopenharmony_ci
358c2ecf20Sopenharmony_civoid wg_pubkey_hashtable_add(struct pubkey_hashtable *table,
368c2ecf20Sopenharmony_ci			     struct wg_peer *peer)
378c2ecf20Sopenharmony_ci{
388c2ecf20Sopenharmony_ci	mutex_lock(&table->lock);
398c2ecf20Sopenharmony_ci	hlist_add_head_rcu(&peer->pubkey_hash,
408c2ecf20Sopenharmony_ci			   pubkey_bucket(table, peer->handshake.remote_static));
418c2ecf20Sopenharmony_ci	mutex_unlock(&table->lock);
428c2ecf20Sopenharmony_ci}
438c2ecf20Sopenharmony_ci
/* Unlink @peer from @table.  Using hlist_del_init_rcu() leaves the node
 * self-pointing ("unhashed"), so removing a peer that was never added,
 * or was already removed, is harmless.
 */
void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
				struct wg_peer *peer)
{
	mutex_lock(&table->lock);
	hlist_del_init_rcu(&peer->pubkey_hash);
	mutex_unlock(&table->lock);
}
518c2ecf20Sopenharmony_ci
528c2ecf20Sopenharmony_ci/* Returns a strong reference to a peer */
538c2ecf20Sopenharmony_cistruct wg_peer *
548c2ecf20Sopenharmony_ciwg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
558c2ecf20Sopenharmony_ci			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
568c2ecf20Sopenharmony_ci{
578c2ecf20Sopenharmony_ci	struct wg_peer *iter_peer, *peer = NULL;
588c2ecf20Sopenharmony_ci
598c2ecf20Sopenharmony_ci	rcu_read_lock_bh();
608c2ecf20Sopenharmony_ci	hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey),
618c2ecf20Sopenharmony_ci				    pubkey_hash) {
628c2ecf20Sopenharmony_ci		if (!memcmp(pubkey, iter_peer->handshake.remote_static,
638c2ecf20Sopenharmony_ci			    NOISE_PUBLIC_KEY_LEN)) {
648c2ecf20Sopenharmony_ci			peer = iter_peer;
658c2ecf20Sopenharmony_ci			break;
668c2ecf20Sopenharmony_ci		}
678c2ecf20Sopenharmony_ci	}
688c2ecf20Sopenharmony_ci	peer = wg_peer_get_maybe_zero(peer);
698c2ecf20Sopenharmony_ci	rcu_read_unlock_bh();
708c2ecf20Sopenharmony_ci	return peer;
718c2ecf20Sopenharmony_ci}
728c2ecf20Sopenharmony_ci
738c2ecf20Sopenharmony_cistatic struct hlist_head *index_bucket(struct index_hashtable *table,
748c2ecf20Sopenharmony_ci				       const __le32 index)
758c2ecf20Sopenharmony_ci{
768c2ecf20Sopenharmony_ci	/* Since the indices are random and thus all bits are uniformly
778c2ecf20Sopenharmony_ci	 * distributed, we can find its bucket simply by masking.
788c2ecf20Sopenharmony_ci	 */
798c2ecf20Sopenharmony_ci	return &table->hashtable[(__force u32)index &
808c2ecf20Sopenharmony_ci				 (HASH_SIZE(table->hashtable) - 1)];
818c2ecf20Sopenharmony_ci}
828c2ecf20Sopenharmony_ci
838c2ecf20Sopenharmony_cistruct index_hashtable *wg_index_hashtable_alloc(void)
848c2ecf20Sopenharmony_ci{
858c2ecf20Sopenharmony_ci	struct index_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
868c2ecf20Sopenharmony_ci
878c2ecf20Sopenharmony_ci	if (!table)
888c2ecf20Sopenharmony_ci		return NULL;
898c2ecf20Sopenharmony_ci
908c2ecf20Sopenharmony_ci	hash_init(table->hashtable);
918c2ecf20Sopenharmony_ci	spin_lock_init(&table->lock);
928c2ecf20Sopenharmony_ci	return table;
938c2ecf20Sopenharmony_ci}
948c2ecf20Sopenharmony_ci
/* At the moment, we limit ourselves to 2^20 total peers, which generally might
 * amount to 2^20*3 items in this hashtable. The algorithm below works by
 * picking a random number and testing it. We can see that these limits mean we
 * usually succeed pretty quickly:
 *
 * >>> def calculation(tries, size):
 * ...     return (size / 2**32)**(tries - 1) *  (1 - (size / 2**32))
 * ...
 * >>> calculation(1, 2**20 * 3)
 * 0.999267578125
 * >>> calculation(2, 2**20 * 3)
 * 0.0007318854331970215
 * >>> calculation(3, 2**20 * 3)
 * 5.360489012673497e-07
 * >>> calculation(4, 2**20 * 3)
 * 3.9261394135792216e-10
 *
 * At the moment, we don't do any masking, so this algorithm isn't exactly
 * constant time in either the random guessing or in the hash list lookup. We
 * could require a minimum of 3 tries, which would successfully mask the
 * guessing. This would not, however, help with the growing hash lengths, which
 * is another thing to consider moving forward.
 */
1188c2ecf20Sopenharmony_ci
/* Assign @entry a fresh, unique session index and publish it in @table.
 *
 * Returns the chosen index in wire (little-endian) byte order.  Any index
 * the entry previously held is dropped first, so this also re-keys an
 * existing entry.
 */
__le32 wg_index_hashtable_insert(struct index_hashtable *table,
				 struct index_hashtable_entry *entry)
{
	struct index_hashtable_entry *existing_entry;

	/* Drop any stale index so the entry is never reachable under two
	 * different indices at once.
	 */
	spin_lock_bh(&table->lock);
	hlist_del_init_rcu(&entry->index_hash);
	spin_unlock_bh(&table->lock);

	rcu_read_lock_bh();

search_unused_slot:
	/* First we try to find an unused slot, randomly, while unlocked. */
	entry->index = (__force __le32)get_random_u32();
	hlist_for_each_entry_rcu_bh(existing_entry,
				    index_bucket(table, entry->index),
				    index_hash) {
		if (existing_entry->index == entry->index)
			/* If it's already in use, we continue searching. */
			goto search_unused_slot;
	}

	/* Once we've found an unused slot, we lock it, and then double-check
	 * that nobody else stole it from us.
	 */
	spin_lock_bh(&table->lock);
	hlist_for_each_entry_rcu_bh(existing_entry,
				    index_bucket(table, entry->index),
				    index_hash) {
		if (existing_entry->index == entry->index) {
			spin_unlock_bh(&table->lock);
			/* If it was stolen, we start over. */
			goto search_unused_slot;
		}
	}
	/* Otherwise, we know we have it exclusively (since we're locked),
	 * so we insert.
	 */
	hlist_add_head_rcu(&entry->index_hash,
			   index_bucket(table, entry->index));
	spin_unlock_bh(&table->lock);

	rcu_read_unlock_bh();

	return entry->index;
}
1658c2ecf20Sopenharmony_ci
/* Atomically substitute @new for @old in @table, keeping @old's index,
 * so concurrent lookups never observe a window with neither entry
 * present.
 *
 * Returns true on success, or false if @old is not currently hashed
 * (e.g. it was already removed), in which case @new is left untouched.
 */
bool wg_index_hashtable_replace(struct index_hashtable *table,
				struct index_hashtable_entry *old,
				struct index_hashtable_entry *new)
{
	bool ret;

	spin_lock_bh(&table->lock);
	/* The unhashed check must happen under the lock so it cannot race
	 * with a concurrent remove of @old.
	 */
	ret = !hlist_unhashed(&old->index_hash);
	if (unlikely(!ret))
		goto out;

	new->index = old->index;
	hlist_replace_rcu(&old->index_hash, &new->index_hash);

	/* Calling init here NULLs out index_hash, and in fact after this
	 * function returns, it's theoretically possible for this to get
	 * reinserted elsewhere. That means the RCU lookup below might either
	 * terminate early or jump between buckets, in which case the packet
	 * simply gets dropped, which isn't terrible.
	 */
	INIT_HLIST_NODE(&old->index_hash);
out:
	spin_unlock_bh(&table->lock);
	return ret;
}
1918c2ecf20Sopenharmony_ci
/* Retire @entry's index.  Idempotent: hlist_del_init_rcu() tolerates an
 * entry that is not currently hashed, so double-removal is harmless.
 */
void wg_index_hashtable_remove(struct index_hashtable *table,
			       struct index_hashtable_entry *entry)
{
	spin_lock_bh(&table->lock);
	hlist_del_init_rcu(&entry->index_hash);
	spin_unlock_bh(&table->lock);
}
1998c2ecf20Sopenharmony_ci
2008c2ecf20Sopenharmony_ci/* Returns a strong reference to a entry->peer */
2018c2ecf20Sopenharmony_cistruct index_hashtable_entry *
2028c2ecf20Sopenharmony_ciwg_index_hashtable_lookup(struct index_hashtable *table,
2038c2ecf20Sopenharmony_ci			  const enum index_hashtable_type type_mask,
2048c2ecf20Sopenharmony_ci			  const __le32 index, struct wg_peer **peer)
2058c2ecf20Sopenharmony_ci{
2068c2ecf20Sopenharmony_ci	struct index_hashtable_entry *iter_entry, *entry = NULL;
2078c2ecf20Sopenharmony_ci
2088c2ecf20Sopenharmony_ci	rcu_read_lock_bh();
2098c2ecf20Sopenharmony_ci	hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index),
2108c2ecf20Sopenharmony_ci				    index_hash) {
2118c2ecf20Sopenharmony_ci		if (iter_entry->index == index) {
2128c2ecf20Sopenharmony_ci			if (likely(iter_entry->type & type_mask))
2138c2ecf20Sopenharmony_ci				entry = iter_entry;
2148c2ecf20Sopenharmony_ci			break;
2158c2ecf20Sopenharmony_ci		}
2168c2ecf20Sopenharmony_ci	}
2178c2ecf20Sopenharmony_ci	if (likely(entry)) {
2188c2ecf20Sopenharmony_ci		entry->peer = wg_peer_get_maybe_zero(entry->peer);
2198c2ecf20Sopenharmony_ci		if (likely(entry->peer))
2208c2ecf20Sopenharmony_ci			*peer = entry->peer;
2218c2ecf20Sopenharmony_ci		else
2228c2ecf20Sopenharmony_ci			entry = NULL;
2238c2ecf20Sopenharmony_ci	}
2248c2ecf20Sopenharmony_ci	rcu_read_unlock_bh();
2258c2ecf20Sopenharmony_ci	return entry;
2268c2ecf20Sopenharmony_ci}
227