// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * INET		An implementation of the TCP/IP protocol suite for the LINUX
 *		operating system.  INET is implemented using the BSD Socket
 *		interface as the means of communication with the user level.
 *
 *		Generic INET transport hashtables
 *
 * Authors:	Lotsa people, from code originally in tcp
 */

#include <linux/module.h>
#include <linux/random.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/wait.h>
#include <linux/vmalloc.h>
#include <linux/memblock.h>
#include <linux/hck/lite_hck_inet.h>

#include <net/addrconf.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#if IS_ENABLED(CONFIG_IPV6)
#include <net/inet6_hashtables.h>
#endif
#include <net/secure_seq.h>
#include <net/ip.h>
#include <net/tcp.h>
#include <net/sock_reuseport.h>

static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
			const __u16 lport, const __be32 faddr,
			const __be16 fport)
{
	static u32 inet_ehash_secret __read_mostly;

	net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret));

	return __inet_ehashfn(laddr, lport, faddr, fport,
			      inet_ehash_secret + net_hash_mix(net));
}

/* This function handles inet_sock, but also timewait and request sockets
 * for IPv4/IPv6.
 */
static u32 sk_ehashfn(const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6 &&
	    !ipv6_addr_v4mapped(&sk->sk_v6_daddr))
		return inet6_ehashfn(sock_net(sk),
				     &sk->sk_v6_rcv_saddr, sk->sk_num,
				     &sk->sk_v6_daddr, sk->sk_dport);
#endif

	if (sk->sk_family == AF_NINET) {
		u32 ret = 0;

		CALL_HCK_LITE_HOOK(nip_ninet_ehashfn_lhck, sk, &ret);
		return ret;
	}

	return inet_ehashfn(sock_net(sk),
			    sk->sk_rcv_saddr, sk->sk_num,
			    sk->sk_daddr, sk->sk_dport);
}

/*
 * Allocate and initialize a new local port bind bucket.
 * The bindhash mutex for snum's hash chain must be held here.
 */
struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
						 struct net *net,
						 struct inet_bind_hashbucket *head,
						 const unsigned short snum,
						 int l3mdev)
{
	struct inet_bind_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);

	if (tb) {
		write_pnet(&tb->ib_net, net);
		tb->l3mdev    = l3mdev;
		tb->port      = snum;
		tb->fastreuse = 0;
		tb->fastreuseport = 0;
		INIT_HLIST_HEAD(&tb->owners);
		hlist_add_head(&tb->node, &head->chain);
	}
	return tb;
}

/*
 * Caller must hold hashbucket lock for this tb with local BH disabled
 */
void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb)
{
	if (hlist_empty(&tb->owners)) {
		__hlist_del(&tb->node);
		kmem_cache_free(cachep, tb);
	}
}

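/* Record @snum as the socket's local port and link the socket onto the bind
 * bucket's owner list.
 */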
void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
		    const unsigned short snum)
{
	inet_sk(sk)->inet_num = snum;
	sk_add_bind_node(sk, &tb->owners);
	inet_csk(sk)->icsk_bind_hash = tb;
}

/*
 * Get rid of any references to a local port held by the given sock.
 */
static void __inet_put_port(struct sock *sk)
{
	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
	const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
			hashinfo->bhash_size);
	struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
	struct inet_bind_bucket *tb;

	spin_lock(&head->lock);
	tb = inet_csk(sk)->icsk_bind_hash;
	__sk_del_bind_node(sk);
	inet_csk(sk)->icsk_bind_hash = NULL;
	inet_sk(sk)->inet_num = 0;
	inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
	spin_unlock(&head->lock);
}

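/* Release the local port held by @sk, with bottom halves disabled around the
 * bind hash update.
 */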
void inet_put_port(struct sock *sk)
{
	local_bh_disable();
	__inet_put_port(sk);
	local_bh_enable();
}
EXPORT_SYMBOL(inet_put_port);

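/* Make a child socket inherit the bind bucket of the listener @sk. If the
 * child ended up on a different port (e.g. via tproxy redirection), look up
 * or create a matching bucket instead; see the note below.
 */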
int __inet_inherit_port(const struct sock *sk, struct sock *child)
{
	struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
	unsigned short port = inet_sk(child)->inet_num;
	const int bhash = inet_bhashfn(sock_net(sk), port,
			table->bhash_size);
	struct inet_bind_hashbucket *head = &table->bhash[bhash];
	struct inet_bind_bucket *tb;
	int l3mdev;

	spin_lock(&head->lock);
	tb = inet_csk(sk)->icsk_bind_hash;
	if (unlikely(!tb)) {
		spin_unlock(&head->lock);
		return -ENOENT;
	}
	if (tb->port != port) {
		l3mdev = inet_sk_bound_l3mdev(sk);

		/* NOTE: using tproxy and redirecting skbs to a proxy
		 * on a different listener port breaks the assumption
		 * that the listener socket's icsk_bind_hash is the same
		 * as that of the child socket. We have to look up or
		 * create a new bind bucket for the child here. */
		inet_bind_bucket_for_each(tb, &head->chain) {
			if (net_eq(ib_net(tb), sock_net(sk)) &&
			    tb->l3mdev == l3mdev && tb->port == port)
				break;
		}
		if (!tb) {
			tb = inet_bind_bucket_create(table->bind_bucket_cachep,
						     sock_net(sk), head, port,
						     l3mdev);
			if (!tb) {
				spin_unlock(&head->lock);
				return -ENOMEM;
			}
		}
		inet_csk_update_fastreuse(tb, child);
	}
	inet_bind_hash(child, tb, port);
	spin_unlock(&head->lock);

	return 0;
}
EXPORT_SYMBOL_GPL(__inet_inherit_port);

static struct inet_listen_hashbucket *
inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk)
{
	u32 hash;

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		hash = ipv6_portaddr_hash(sock_net(sk),
					  &sk->sk_v6_rcv_saddr,
					  inet_sk(sk)->inet_num);
	else
#endif
		hash = ipv4_portaddr_hash(sock_net(sk),
					  inet_sk(sk)->inet_rcv_saddr,
					  inet_sk(sk)->inet_num);
	return inet_lhash2_bucket(h, hash);
}

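/* Add @sk to the second listening hash table (keyed by port and local
 * address), if it has been allocated; inet_unhash2() undoes this.
 */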
static void inet_hash2(struct inet_hashinfo *h, struct sock *sk)
{
	struct inet_listen_hashbucket *ilb2;

	if (!h->lhash2)
		return;

	ilb2 = inet_lhash2_bucket_sk(h, sk);

	spin_lock(&ilb2->lock);
	if (sk->sk_reuseport && sk->sk_family == AF_INET6)
		hlist_add_tail_rcu(&inet_csk(sk)->icsk_listen_portaddr_node,
				   &ilb2->head);
	else
		hlist_add_head_rcu(&inet_csk(sk)->icsk_listen_portaddr_node,
				   &ilb2->head);
	ilb2->count++;
	spin_unlock(&ilb2->lock);
}

static void inet_unhash2(struct inet_hashinfo *h, struct sock *sk)
{
	struct inet_listen_hashbucket *ilb2;

	if (!h->lhash2 ||
	    WARN_ON_ONCE(hlist_unhashed(&inet_csk(sk)->icsk_listen_portaddr_node)))
		return;

	ilb2 = inet_lhash2_bucket_sk(h, sk);

	spin_lock(&ilb2->lock);
	hlist_del_init_rcu(&inet_csk(sk)->icsk_listen_portaddr_node);
	ilb2->count--;
	spin_unlock(&ilb2->lock);
}

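/* Score a candidate listener for an incoming IPv4 packet: -1 means no match;
 * otherwise the score grows with the specificity of the match (bound device,
 * exact AF_INET family, matching incoming CPU).
 */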
static inline int compute_score(struct sock *sk, struct net *net,
				const unsigned short hnum, const __be32 daddr,
				const int dif, const int sdif)
{
	int score = -1;

	if (net_eq(sock_net(sk), net) && sk->sk_num == hnum &&
			!ipv6_only_sock(sk)) {
		if (sk->sk_rcv_saddr != daddr)
			return -1;

		if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
			return -1;
		score =  sk->sk_bound_dev_if ? 2 : 1;

		if (sk->sk_family == PF_INET)
			score++;
		if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
			score++;
	}
	return score;
}

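/* If @sk has SO_REUSEPORT enabled, let its reuseport group select the final
 * socket based on the 4-tuple hash; returns NULL when @sk is not a reuseport
 * socket or the group made no selection.
 */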
static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
					    struct sk_buff *skb, int doff,
					    __be32 saddr, __be16 sport,
					    __be32 daddr, unsigned short hnum)
{
	struct sock *reuse_sk = NULL;
	u32 phash;

	if (sk->sk_reuseport) {
		phash = inet_ehashfn(net, daddr, hnum, saddr, sport);
		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
	}
	return reuse_sk;
}

/*
 * There are some nice properties to exploit here. The BSD API
 * does not allow a listening sock to specify the remote port or the
 * remote address for the connection. So always assume those are both
 * wildcarded during the search since they can never be otherwise.
 */

/* called with rcu_read_lock() : No refcount taken on the socket */
static struct sock *inet_lhash2_lookup(struct net *net,
				struct inet_listen_hashbucket *ilb2,
				struct sk_buff *skb, int doff,
				const __be32 saddr, __be16 sport,
				const __be32 daddr, const unsigned short hnum,
				const int dif, const int sdif)
{
	struct inet_connection_sock *icsk;
	struct sock *sk, *result = NULL;
	int score, hiscore = 0;

	inet_lhash2_for_each_icsk_rcu(icsk, &ilb2->head) {
		sk = (struct sock *)icsk;
		score = compute_score(sk, net, hnum, daddr, dif, sdif);
		if (score > hiscore) {
			result = lookup_reuseport(net, sk, skb, doff,
						  saddr, sport, daddr, hnum);
			if (result)
				return result;

			result = sk;
			hiscore = score;
		}
	}

	return result;
}

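/* Give a BPF sk_lookup program a chance to select the listener. Only the TCP
 * hash table is supported; if the selected socket is part of a reuseport
 * group, the group picks the final listener.
 */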
static inline struct sock *inet_lookup_run_bpf(struct net *net,
					       struct inet_hashinfo *hashinfo,
					       struct sk_buff *skb, int doff,
					       __be32 saddr, __be16 sport,
					       __be32 daddr, u16 hnum)
{
	struct sock *sk, *reuse_sk;
	bool no_reuseport;

	if (hashinfo != &tcp_hashinfo)
		return NULL; /* only TCP is supported */

	no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP,
					    saddr, sport, daddr, hnum, &sk);
	if (no_reuseport || IS_ERR_OR_NULL(sk))
		return sk;

	reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
	if (reuse_sk)
		sk = reuse_sk;
	return sk;
}

struct sock *__inet_lookup_listener(struct net *net,
				    struct inet_hashinfo *hashinfo,
				    struct sk_buff *skb, int doff,
				    const __be32 saddr, __be16 sport,
				    const __be32 daddr, const unsigned short hnum,
				    const int dif, const int sdif)
{
	struct inet_listen_hashbucket *ilb2;
	struct sock *result = NULL;
	unsigned int hash2;

	/* Lookup redirect from BPF */
	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
		result = inet_lookup_run_bpf(net, hashinfo, skb, doff,
					     saddr, sport, daddr, hnum);
		if (result)
			goto done;
	}

	hash2 = ipv4_portaddr_hash(net, daddr, hnum);
	ilb2 = inet_lhash2_bucket(hashinfo, hash2);

	result = inet_lhash2_lookup(net, ilb2, skb, doff,
				    saddr, sport, daddr, hnum,
				    dif, sdif);
	if (result)
		goto done;

	/* Lookup lhash2 with INADDR_ANY */
	hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
	ilb2 = inet_lhash2_bucket(hashinfo, hash2);

	result = inet_lhash2_lookup(net, ilb2, skb, doff,
				    saddr, sport, htonl(INADDR_ANY), hnum,
				    dif, sdif);
done:
	if (IS_ERR(result))
		return NULL;
	return result;
}
EXPORT_SYMBOL_GPL(__inet_lookup_listener);

/* All sockets share common refcount, but have different destructors */
void sock_gen_put(struct sock *sk)
{
	if (!refcount_dec_and_test(&sk->sk_refcnt))
		return;

	if (sk->sk_state == TCP_TIME_WAIT)
		inet_twsk_free(inet_twsk(sk));
	else if (sk->sk_state == TCP_NEW_SYN_RECV)
		reqsk_free(inet_reqsk(sk));
	else
		sk_free(sk);
}
EXPORT_SYMBOL_GPL(sock_gen_put);

void sock_edemux(struct sk_buff *skb)
{
	sock_gen_put(skb->sk);
}
EXPORT_SYMBOL(sock_edemux);

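/* Lockless (RCU) lookup of an established or timewait socket matching the
 * given 4-tuple. On success a reference is taken on the returned socket;
 * callers typically drop it with sock_gen_put().
 */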
struct sock *__inet_lookup_established(struct net *net,
				  struct inet_hashinfo *hashinfo,
				  const __be32 saddr, const __be16 sport,
				  const __be32 daddr, const u16 hnum,
				  const int dif, const int sdif)
{
	INET_ADDR_COOKIE(acookie, saddr, daddr);
	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
	struct sock *sk;
	const struct hlist_nulls_node *node;
	/* Optimize here for a direct hit; only listening connections can
	 * have wildcards anyway.
	 */
	unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport);
	unsigned int slot = hash & hashinfo->ehash_mask;
	struct inet_ehash_bucket *head = &hashinfo->ehash[slot];

begin:
	sk_nulls_for_each_rcu(sk, node, &head->chain) {
		if (sk->sk_hash != hash)
			continue;
		if (likely(INET_MATCH(net, sk, acookie, ports, dif, sdif))) {
			if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
				goto out;
			if (unlikely(!INET_MATCH(net, sk, acookie,
						 ports, dif, sdif))) {
				sock_gen_put(sk);
				goto begin;
			}
			goto found;
		}
	}
	/*
	 * If the nulls value we got at the end of this lookup is
	 * not the expected one, we must restart the lookup.
	 * We probably met an item that was moved to another chain.
	 */
	if (get_nulls_value(node) != slot)
		goto begin;
out:
	sk = NULL;
found:
	return sk;
}
EXPORT_SYMBOL_GPL(__inet_lookup_established);

/* called with local bh disabled */
static int __inet_check_established(struct inet_timewait_death_row *death_row,
				    struct sock *sk, __u16 lport,
				    struct inet_timewait_sock **twp)
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_sock *inet = inet_sk(sk);
	__be32 daddr = inet->inet_rcv_saddr;
	__be32 saddr = inet->inet_daddr;
	int dif = sk->sk_bound_dev_if;
	struct net *net = sock_net(sk);
	int sdif = l3mdev_master_ifindex_by_index(net, dif);
	INET_ADDR_COOKIE(acookie, saddr, daddr);
	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
	unsigned int hash = inet_ehashfn(net, daddr, lport,
					 saddr, inet->inet_dport);
	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
	spinlock_t *lock = inet_ehash_lockp(hinfo, hash);
	struct sock *sk2;
	const struct hlist_nulls_node *node;
	struct inet_timewait_sock *tw = NULL;

	spin_lock(lock);

	sk_nulls_for_each(sk2, node, &head->chain) {
		if (sk2->sk_hash != hash)
			continue;

		if (likely(INET_MATCH(net, sk2, acookie, ports, dif, sdif))) {
			if (sk2->sk_state == TCP_TIME_WAIT) {
				tw = inet_twsk(sk2);
				if (twsk_unique(sk, sk2, twp))
					break;
			}
			goto not_unique;
		}
	}

	/* Must record num and sport now; otherwise we would see
	 * a socket with an inconsistent identity in the hash table.
	 */
	inet->inet_num = lport;
	inet->inet_sport = htons(lport);
	sk->sk_hash = hash;
	WARN_ON(!sk_unhashed(sk));
	__sk_nulls_add_node_rcu(sk, &head->chain);
	if (tw) {
		sk_nulls_del_node_init_rcu((struct sock *)tw);
		__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
	}
	spin_unlock(lock);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);

	if (twp) {
		*twp = tw;
	} else if (tw) {
		/* Silly. Should hash-dance instead... */
		inet_twsk_deschedule_put(tw);
	}
	return 0;

not_unique:
	spin_unlock(lock);
	return -EADDRNOTAVAIL;
}

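/* Per-destination offset used to randomize the ephemeral port search for
 * outgoing connections.
 */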
static u64 inet_sk_port_offset(const struct sock *sk)
{
	const struct inet_sock *inet = inet_sk(sk);

	return secure_ipv4_port_ephemeral(inet->inet_rcv_saddr,
					  inet->inet_daddr,
					  inet->inet_dport);
}

/* Searches for an existing socket in the ehash bucket list.
 * Returns true if found, false otherwise.
 */
static bool inet_ehash_lookup_by_sk(struct sock *sk,
				    struct hlist_nulls_head *list)
{
	const __portpair ports = INET_COMBINED_PORTS(sk->sk_dport, sk->sk_num);
	const int sdif = sk->sk_bound_dev_if;
	const int dif = sk->sk_bound_dev_if;
	const struct hlist_nulls_node *node;
	struct net *net = sock_net(sk);
	struct sock *esk;

	INET_ADDR_COOKIE(acookie, sk->sk_daddr, sk->sk_rcv_saddr);

	sk_nulls_for_each_rcu(esk, node, list) {
		if (esk->sk_hash != sk->sk_hash)
			continue;
		if (sk->sk_family == AF_INET) {
			if (unlikely(INET_MATCH(net, esk, acookie,
						ports, dif, sdif))) {
				return true;
			}
		}
#if IS_ENABLED(CONFIG_IPV6)
		else if (sk->sk_family == AF_INET6) {
			if (unlikely(inet6_match(net, esk,
						 &sk->sk_v6_daddr,
						 &sk->sk_v6_rcv_saddr,
						 ports, dif, sdif))) {
				return true;
			}
		}
#endif
	}
	return false;
}

/* Insert a socket into ehash, and possibly remove another one
 * (the other one can be a SYN_RECV or TIMEWAIT socket).
 * If a matching socket already exists, sk is not inserted and
 * *found_dup_sk is set to true.
 */
bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
{
	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
	struct hlist_nulls_head *list;
	struct inet_ehash_bucket *head;
	spinlock_t *lock;
	bool ret = true;

	WARN_ON_ONCE(!sk_unhashed(sk));

	sk->sk_hash = sk_ehashfn(sk);
	head = inet_ehash_bucket(hashinfo, sk->sk_hash);
	list = &head->chain;
	lock = inet_ehash_lockp(hashinfo, sk->sk_hash);

	spin_lock(lock);
	if (osk) {
		WARN_ON_ONCE(sk->sk_hash != osk->sk_hash);
		ret = sk_nulls_del_node_init_rcu(osk);
	} else if (found_dup_sk) {
		*found_dup_sk = inet_ehash_lookup_by_sk(sk, list);
		if (*found_dup_sk)
			ret = false;
	}

	if (ret)
		__sk_nulls_add_node_rcu(sk, list);

	spin_unlock(lock);

	return ret;
}

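/* Like inet_ehash_insert(), but also updates the protocol's inuse counter on
 * success, and destroys @sk if a duplicate socket prevented the insertion.
 */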
bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk)
{
	bool ok = inet_ehash_insert(sk, osk, found_dup_sk);

	if (ok) {
		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
	} else {
		this_cpu_inc(*sk->sk_prot->orphan_count);
		inet_sk_set_state(sk, TCP_CLOSE);
		sock_set_flag(sk, SOCK_DEAD);
		inet_csk_destroy_sock(sk);
	}
	return ok;
}
EXPORT_SYMBOL_GPL(inet_ehash_nolisten);

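/* Attach @sk to an existing reuseport group on the same port and address,
 * owned by the same user, or allocate a new group if none matches.
 */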
static int inet_reuseport_add_sock(struct sock *sk,
				   struct inet_listen_hashbucket *ilb)
{
	struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash;
	const struct hlist_nulls_node *node;
	struct sock *sk2;
	kuid_t uid = sock_i_uid(sk);

	sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) {
		if (sk2 != sk &&
		    sk2->sk_family == sk->sk_family &&
		    ipv6_only_sock(sk2) == ipv6_only_sock(sk) &&
		    sk2->sk_bound_dev_if == sk->sk_bound_dev_if &&
		    inet_csk(sk2)->icsk_bind_hash == tb &&
		    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
		    inet_rcv_saddr_equal(sk, sk2, false))
			return reuseport_add_sock(sk, sk2,
						  inet_rcv_saddr_any(sk));
	}

	return reuseport_alloc(sk, inet_rcv_saddr_any(sk));
}

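/* Hash @sk: non-listening sockets go into the established hash; listeners go
 * into the listening hash (and lhash2), joining a reuseport group first when
 * SO_REUSEPORT is set.
 */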
int __inet_hash(struct sock *sk, struct sock *osk)
{
	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
	struct inet_listen_hashbucket *ilb;
	int err = 0;

	if (sk->sk_state != TCP_LISTEN) {
		local_bh_disable();
		inet_ehash_nolisten(sk, osk, NULL);
		local_bh_enable();
		return 0;
	}
	WARN_ON(!sk_unhashed(sk));
	ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];

	spin_lock(&ilb->lock);
	if (sk->sk_reuseport) {
		err = inet_reuseport_add_sock(sk, ilb);
		if (err)
			goto unlock;
	}
	if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport &&
		sk->sk_family == AF_INET6)
		__sk_nulls_add_node_tail_rcu(sk, &ilb->nulls_head);
	else
		__sk_nulls_add_node_rcu(sk, &ilb->nulls_head);
	inet_hash2(hashinfo, sk);
	ilb->count++;
	sock_set_flag(sk, SOCK_RCU_FREE);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
unlock:
	spin_unlock(&ilb->lock);

	return err;
}
EXPORT_SYMBOL(__inet_hash);

int inet_hash(struct sock *sk)
{
	int err = 0;

	if (sk->sk_state != TCP_CLOSE)
		err = __inet_hash(sk, NULL);

	return err;
}
EXPORT_SYMBOL_GPL(inet_hash);

static void __inet_unhash(struct sock *sk, struct inet_listen_hashbucket *ilb)
{
	if (sk_unhashed(sk))
		return;

	if (rcu_access_pointer(sk->sk_reuseport_cb))
		reuseport_detach_sock(sk);
	if (ilb) {
		struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;

		inet_unhash2(hashinfo, sk);
		ilb->count--;
	}
	__sk_nulls_del_node_init_rcu(sk);
	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
}

void inet_unhash(struct sock *sk)
{
	struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;

	if (sk_unhashed(sk))
		return;

	if (sk->sk_state == TCP_LISTEN) {
		struct inet_listen_hashbucket *ilb;

		ilb = &hashinfo->listening_hash[inet_sk_listen_hashfn(sk)];
		/* Don't disable bottom halves while acquiring the lock to
		 * avoid circular locking dependency on PREEMPT_RT.
		 */
		spin_lock(&ilb->lock);
		__inet_unhash(sk, ilb);
		spin_unlock(&ilb->lock);
	} else {
		spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash);

		spin_lock_bh(lock);
		__inet_unhash(sk, NULL);
		spin_unlock_bh(lock);
	}
}
EXPORT_SYMBOL_GPL(inet_unhash);

/* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
 * Note that we use 32bit integers (vs. the RFC's 'short integers')
 * because 2^16 is not a multiple of num_ephemeral and this
 * property might be used by a clever attacker.
 *
 * The RFC claims that TABLE_LENGTH=10 buckets give an improvement, but
 * attacks have since been demonstrated, so we use 65536 buckets by default
 * to provide better isolation and privacy, at the expense of 256kB
 * of kernel memory.
 */
#define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER)
static u32 *table_perturb;

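/* Pick a source port for a connect(): if the socket is already bound to a
 * port, only run the uniqueness check; otherwise walk the local port range,
 * starting at a per-destination offset (see the algorithm described above),
 * until an unused bind bucket is found or check_established() succeeds.
 */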
int __inet_hash_connect(struct inet_timewait_death_row *death_row,
		struct sock *sk, u64 port_offset,
		int (*check_established)(struct inet_timewait_death_row *,
			struct sock *, __u16, struct inet_timewait_sock **))
{
	struct inet_hashinfo *hinfo = death_row->hashinfo;
	struct inet_timewait_sock *tw = NULL;
	struct inet_bind_hashbucket *head;
	int port = inet_sk(sk)->inet_num;
	struct net *net = sock_net(sk);
	struct inet_bind_bucket *tb;
	u32 remaining, offset;
	int ret, i, low, high;
	int l3mdev;
	u32 index;

	if (port) {
		local_bh_disable();
		ret = check_established(death_row, sk, port, NULL);
		local_bh_enable();
		return ret;
	}

	l3mdev = inet_sk_bound_l3mdev(sk);

	inet_get_local_port_range(net, &low, &high);
	high++; /* [32768, 60999] -> [32768, 61000[ */
	remaining = high - low;
	if (likely(remaining > 1))
		remaining &= ~1U;

	get_random_slow_once(table_perturb,
			     INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb));
	index = port_offset & (INET_TABLE_PERTURB_SIZE - 1);

	offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32);
	offset %= remaining;
	/* In the first pass we try ports of the same parity as @low;
	 * inet_csk_get_port() makes the opposite choice.
	 */
	offset &= ~1U;
other_parity_scan:
	port = low + offset;
	for (i = 0; i < remaining; i += 2, port += 2) {
		if (unlikely(port >= high))
			port -= remaining;
		if (inet_is_local_reserved_port(net, port))
			continue;
		head = &hinfo->bhash[inet_bhashfn(net, port,
						  hinfo->bhash_size)];
		spin_lock_bh(&head->lock);

		/* Does not bother with rcv_saddr checks, because
		 * the established check is already unique enough.
		 */
		inet_bind_bucket_for_each(tb, &head->chain) {
			if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
			    tb->port == port) {
				if (tb->fastreuse >= 0 ||
				    tb->fastreuseport >= 0)
					goto next_port;
				WARN_ON(hlist_empty(&tb->owners));
				if (!check_established(death_row, sk,
						       port, &tw))
					goto ok;
				goto next_port;
			}
		}

		tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep,
					     net, head, port, l3mdev);
		if (!tb) {
			spin_unlock_bh(&head->lock);
			return -ENOMEM;
		}
		tb->fastreuse = -1;
		tb->fastreuseport = -1;
		goto ok;
next_port:
		spin_unlock_bh(&head->lock);
		cond_resched();
	}

	offset++;
	if ((offset & 1) && remaining > 1)
		goto other_parity_scan;

	return -EADDRNOTAVAIL;

ok:
	/* Here we want to add a little bit of randomness to the next source
	 * port that will be chosen. We use a max() with a random value here
	 * so that under low contention the randomness is maximal and under
	 * high contention it may be nonexistent.
	 */
	i = max_t(int, i, (prandom_u32() & 7) * 2);
	WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);

	/* Head lock still held and bh's disabled */
	inet_bind_hash(sk, tb, port);
	if (sk_unhashed(sk)) {
		inet_sk(sk)->inet_sport = htons(port);
		inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
	}
	if (tw)
		inet_twsk_bind_unhash(tw, hinfo);
	spin_unlock(&head->lock);
	if (tw)
		inet_twsk_deschedule_put(tw);
	local_bh_enable();
	return 0;
}

/*
 * Bind a port for a connect operation and hash it.
 */
int inet_hash_connect(struct inet_timewait_death_row *death_row,
		      struct sock *sk)
{
	u64 port_offset = 0;

	if (!inet_sk(sk)->inet_num)
		port_offset = inet_sk_port_offset(sk);
	return __inet_hash_connect(death_row, sk, port_offset,
				   __inet_check_established);
}
EXPORT_SYMBOL_GPL(inet_hash_connect);

void inet_hashinfo_init(struct inet_hashinfo *h)
{
	int i;

	for (i = 0; i < INET_LHTABLE_SIZE; i++) {
		spin_lock_init(&h->listening_hash[i].lock);
		INIT_HLIST_NULLS_HEAD(&h->listening_hash[i].nulls_head,
				      i + LISTENING_NULLS_BASE);
		h->listening_hash[i].count = 0;
	}

	h->lhash2 = NULL;
}
EXPORT_SYMBOL_GPL(inet_hashinfo_init);

static void init_hashinfo_lhash2(struct inet_hashinfo *h)
{
	int i;

	for (i = 0; i <= h->lhash2_mask; i++) {
		spin_lock_init(&h->lhash2[i].lock);
		INIT_HLIST_HEAD(&h->lhash2[i].head);
		h->lhash2[i].count = 0;
	}
}

void __init inet_hashinfo2_init(struct inet_hashinfo *h, const char *name,
				unsigned long numentries, int scale,
				unsigned long low_limit,
				unsigned long high_limit)
{
	h->lhash2 = alloc_large_system_hash(name,
					    sizeof(*h->lhash2),
					    numentries,
					    scale,
					    0,
					    NULL,
					    &h->lhash2_mask,
					    low_limit,
					    high_limit);
	init_hashinfo_lhash2(h);

	/* this one is used for source ports of outgoing connections */
	table_perturb = kmalloc_array(INET_TABLE_PERTURB_SIZE,
				      sizeof(*table_perturb), GFP_KERNEL);
	if (!table_perturb)
		panic("TCP: failed to alloc table_perturb");
}

int inet_hashinfo2_init_mod(struct inet_hashinfo *h)
{
	h->lhash2 = kmalloc_array(INET_LHTABLE_SIZE, sizeof(*h->lhash2), GFP_KERNEL);
	if (!h->lhash2)
		return -ENOMEM;

	h->lhash2_mask = INET_LHTABLE_SIZE - 1;
	/* INET_LHTABLE_SIZE must be a power of 2 */
	BUG_ON(INET_LHTABLE_SIZE & h->lhash2_mask);

	init_hashinfo_lhash2(h);
	return 0;
}
EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod);

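/* Size the array of ehash chain locks: roughly two cache lines' worth of
 * spinlocks per possible CPU, rounded up to a power of two and capped at the
 * number of hash buckets.
 */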
int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo)
{
	unsigned int locksz = sizeof(spinlock_t);
	unsigned int i, nblocks = 1;

	if (locksz != 0) {
		/* allocate 2 cache lines or at least one spinlock per cpu */
		nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U);
		nblocks = roundup_pow_of_two(nblocks * num_possible_cpus());

		/* no more locks than number of hash buckets */
		nblocks = min(nblocks, hashinfo->ehash_mask + 1);

		hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL);
		if (!hashinfo->ehash_locks)
			return -ENOMEM;

		for (i = 0; i < nblocks; i++)
			spin_lock_init(&hashinfo->ehash_locks[i]);
	}
	hashinfo->ehash_locks_mask = nblocks - 1;
	return 0;
}
EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc);