xref: /kernel/linux/linux-6.6/net/handshake/request.c (revision 62306a36)
1// SPDX-License-Identifier: GPL-2.0-only
2/*
3 * Handshake request lifetime events
4 *
5 * Author: Chuck Lever <chuck.lever@oracle.com>
6 *
7 * Copyright (c) 2023, Oracle and/or its affiliates.
8 */
9
10#include <linux/types.h>
11#include <linux/socket.h>
12#include <linux/kernel.h>
13#include <linux/module.h>
14#include <linux/skbuff.h>
15#include <linux/inet.h>
16#include <linux/fdtable.h>
17#include <linux/rhashtable.h>
18
19#include <net/sock.h>
20#include <net/genetlink.h>
21#include <net/netns/generic.h>
22
23#include <kunit/visibility.h>
24
25#include <uapi/linux/handshake.h>
26#include "handshake.h"
27
28#include <trace/events/handshake.h>
29
30/*
31 * We need both a handshake_req -> sock mapping, and a sock ->
32 * handshake_req mapping. Both are one-to-one.
33 *
34 * To avoid adding another pointer field to struct sock, net/handshake
35 * maintains a hash table, indexed by the memory address of @sock, to
36 * find the struct handshake_req outstanding for that socket. The
37 * reverse direction uses a simple pointer field in the handshake_req
38 * struct.
39 */
40
41static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;
42
43static const struct rhashtable_params handshake_rhash_params = {
44	.key_len		= sizeof_field(struct handshake_req, hr_sk),
45	.key_offset		= offsetof(struct handshake_req, hr_sk),
46	.head_offset		= offsetof(struct handshake_req, hr_rhash),
47	.automatic_shrinking	= true,
48};
49
50int handshake_req_hash_init(void)
51{
52	return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
53}
54
55void handshake_req_hash_destroy(void)
56{
57	rhashtable_destroy(&handshake_rhashtbl);
58}
59
60struct handshake_req *handshake_req_hash_lookup(struct sock *sk)
61{
62	return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
63				      handshake_rhash_params);
64}
65EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);
66
67static bool handshake_req_hash_add(struct handshake_req *req)
68{
69	int ret;
70
71	ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl,
72					    &req->hr_rhash,
73					    handshake_rhash_params);
74	return ret == 0;
75}
76
77static void handshake_req_destroy(struct handshake_req *req)
78{
79	if (req->hr_proto->hp_destroy)
80		req->hr_proto->hp_destroy(req);
81	rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
82			       handshake_rhash_params);
83	kfree(req);
84}
85
86static void handshake_sk_destruct(struct sock *sk)
87{
88	void (*sk_destruct)(struct sock *sk);
89	struct handshake_req *req;
90
91	req = handshake_req_hash_lookup(sk);
92	if (!req)
93		return;
94
95	trace_handshake_destruct(sock_net(sk), req, sk);
96	sk_destruct = req->hr_odestruct;
97	handshake_req_destroy(req);
98	if (sk_destruct)
99		sk_destruct(sk);
100}
101
102/**
103 * handshake_req_alloc - Allocate a handshake request
104 * @proto: security protocol
105 * @flags: memory allocation flags
106 *
107 * Returns an initialized handshake_req or NULL.
108 */
109struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
110					  gfp_t flags)
111{
112	struct handshake_req *req;
113
114	if (!proto)
115		return NULL;
116	if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE)
117		return NULL;
118	if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
119		return NULL;
120	if (!proto->hp_accept || !proto->hp_done)
121		return NULL;
122
123	req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags);
124	if (!req)
125		return NULL;
126
127	INIT_LIST_HEAD(&req->hr_list);
128	req->hr_proto = proto;
129	return req;
130}
131EXPORT_SYMBOL(handshake_req_alloc);
132
133/**
134 * handshake_req_private - Get per-handshake private data
135 * @req: handshake arguments
136 *
137 */
138void *handshake_req_private(struct handshake_req *req)
139{
140	return (void *)&req->hr_priv;
141}
142EXPORT_SYMBOL(handshake_req_private);
143
144static bool __add_pending_locked(struct handshake_net *hn,
145				 struct handshake_req *req)
146{
147	if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
148		return false;
149	hn->hn_pending++;
150	list_add_tail(&req->hr_list, &hn->hn_requests);
151	return true;
152}
153
154static void __remove_pending_locked(struct handshake_net *hn,
155				    struct handshake_req *req)
156{
157	hn->hn_pending--;
158	list_del_init(&req->hr_list);
159}
160
161/*
162 * Returns %true if the request was found on @net's pending list,
163 * otherwise %false.
164 *
165 * If @req was on a pending list, it has not yet been accepted.
166 */
167static bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
168{
169	bool ret = false;
170
171	spin_lock(&hn->hn_lock);
172	if (!list_empty(&req->hr_list)) {
173		__remove_pending_locked(hn, req);
174		ret = true;
175	}
176	spin_unlock(&hn->hn_lock);
177
178	return ret;
179}
180
181struct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
182{
183	struct handshake_req *req, *pos;
184
185	req = NULL;
186	spin_lock(&hn->hn_lock);
187	list_for_each_entry(pos, &hn->hn_requests, hr_list) {
188		if (pos->hr_proto->hp_handler_class != class)
189			continue;
190		__remove_pending_locked(hn, pos);
191		req = pos;
192		break;
193	}
194	spin_unlock(&hn->hn_lock);
195
196	return req;
197}
198EXPORT_SYMBOL_IF_KUNIT(handshake_req_next);
199
200/**
201 * handshake_req_submit - Submit a handshake request
202 * @sock: open socket on which to perform the handshake
203 * @req: handshake arguments
204 * @flags: memory allocation flags
205 *
206 * Return values:
207 *   %0: Request queued
208 *   %-EINVAL: Invalid argument
209 *   %-EBUSY: A handshake is already under way for this socket
210 *   %-ESRCH: No handshake agent is available
211 *   %-EAGAIN: Too many pending handshake requests
212 *   %-ENOMEM: Failed to allocate memory
213 *   %-EMSGSIZE: Failed to construct notification message
214 *   %-EOPNOTSUPP: Handshake module not initialized
215 *
216 * A zero return value from handshake_req_submit() means that
217 * exactly one subsequent completion callback is guaranteed.
218 *
219 * A negative return value from handshake_req_submit() means that
220 * no completion callback will be done and that @req has been
221 * destroyed.
222 */
223int handshake_req_submit(struct socket *sock, struct handshake_req *req,
224			 gfp_t flags)
225{
226	struct handshake_net *hn;
227	struct net *net;
228	int ret;
229
230	if (!sock || !req || !sock->file) {
231		kfree(req);
232		return -EINVAL;
233	}
234
235	req->hr_sk = sock->sk;
236	if (!req->hr_sk) {
237		kfree(req);
238		return -EINVAL;
239	}
240	req->hr_odestruct = req->hr_sk->sk_destruct;
241	req->hr_sk->sk_destruct = handshake_sk_destruct;
242
243	ret = -EOPNOTSUPP;
244	net = sock_net(req->hr_sk);
245	hn = handshake_pernet(net);
246	if (!hn)
247		goto out_err;
248
249	ret = -EAGAIN;
250	if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
251		goto out_err;
252
253	spin_lock(&hn->hn_lock);
254	ret = -EOPNOTSUPP;
255	if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
256		goto out_unlock;
257	ret = -EBUSY;
258	if (!handshake_req_hash_add(req))
259		goto out_unlock;
260	if (!__add_pending_locked(hn, req))
261		goto out_unlock;
262	spin_unlock(&hn->hn_lock);
263
264	ret = handshake_genl_notify(net, req->hr_proto, flags);
265	if (ret) {
266		trace_handshake_notify_err(net, req, req->hr_sk, ret);
267		if (remove_pending(hn, req))
268			goto out_err;
269	}
270
271	/* Prevent socket release while a handshake request is pending */
272	sock_hold(req->hr_sk);
273
274	trace_handshake_submit(net, req, req->hr_sk);
275	return 0;
276
277out_unlock:
278	spin_unlock(&hn->hn_lock);
279out_err:
280	trace_handshake_submit_err(net, req, req->hr_sk, ret);
281	handshake_req_destroy(req);
282	return ret;
283}
284EXPORT_SYMBOL(handshake_req_submit);
285
286void handshake_complete(struct handshake_req *req, unsigned int status,
287			struct genl_info *info)
288{
289	struct sock *sk = req->hr_sk;
290	struct net *net = sock_net(sk);
291
292	if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
293		trace_handshake_complete(net, req, sk, status);
294		req->hr_proto->hp_done(req, status, info);
295
296		/* Handshake request is no longer pending */
297		sock_put(sk);
298	}
299}
300EXPORT_SYMBOL_IF_KUNIT(handshake_complete);
301
302/**
303 * handshake_req_cancel - Cancel an in-progress handshake
304 * @sk: socket on which there is an ongoing handshake
305 *
306 * Request cancellation races with request completion. To determine
307 * who won, callers examine the return value from this function.
308 *
309 * Return values:
310 *   %true - Uncompleted handshake request was canceled
311 *   %false - Handshake request already completed or not found
312 */
313bool handshake_req_cancel(struct sock *sk)
314{
315	struct handshake_req *req;
316	struct handshake_net *hn;
317	struct net *net;
318
319	net = sock_net(sk);
320	req = handshake_req_hash_lookup(sk);
321	if (!req) {
322		trace_handshake_cancel_none(net, req, sk);
323		return false;
324	}
325
326	hn = handshake_pernet(net);
327	if (hn && remove_pending(hn, req)) {
328		/* Request hadn't been accepted */
329		goto out_true;
330	}
331	if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
332		/* Request already completed */
333		trace_handshake_cancel_busy(net, req, sk);
334		return false;
335	}
336
337out_true:
338	trace_handshake_cancel(net, req, sk);
339
340	/* Handshake request is no longer pending */
341	sock_put(sk);
342	return true;
343}
344EXPORT_SYMBOL(handshake_req_cancel);
345