162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0-only
262306a36Sopenharmony_ci/*
362306a36Sopenharmony_ci * Handshake request lifetime events
462306a36Sopenharmony_ci *
562306a36Sopenharmony_ci * Author: Chuck Lever <chuck.lever@oracle.com>
662306a36Sopenharmony_ci *
762306a36Sopenharmony_ci * Copyright (c) 2023, Oracle and/or its affiliates.
862306a36Sopenharmony_ci */
962306a36Sopenharmony_ci
1062306a36Sopenharmony_ci#include <linux/types.h>
1162306a36Sopenharmony_ci#include <linux/socket.h>
1262306a36Sopenharmony_ci#include <linux/kernel.h>
1362306a36Sopenharmony_ci#include <linux/module.h>
1462306a36Sopenharmony_ci#include <linux/skbuff.h>
1562306a36Sopenharmony_ci#include <linux/inet.h>
1662306a36Sopenharmony_ci#include <linux/fdtable.h>
1762306a36Sopenharmony_ci#include <linux/rhashtable.h>
1862306a36Sopenharmony_ci
1962306a36Sopenharmony_ci#include <net/sock.h>
2062306a36Sopenharmony_ci#include <net/genetlink.h>
2162306a36Sopenharmony_ci#include <net/netns/generic.h>
2262306a36Sopenharmony_ci
2362306a36Sopenharmony_ci#include <kunit/visibility.h>
2462306a36Sopenharmony_ci
2562306a36Sopenharmony_ci#include <uapi/linux/handshake.h>
2662306a36Sopenharmony_ci#include "handshake.h"
2762306a36Sopenharmony_ci
2862306a36Sopenharmony_ci#include <trace/events/handshake.h>
2962306a36Sopenharmony_ci
3062306a36Sopenharmony_ci/*
3162306a36Sopenharmony_ci * We need both a handshake_req -> sock mapping, and a sock ->
3262306a36Sopenharmony_ci * handshake_req mapping. Both are one-to-one.
3362306a36Sopenharmony_ci *
 * To avoid adding another pointer field to struct sock, net/handshake
 * maintains a hash table, keyed by the memory address of the struct
 * sock, to find the handshake_req outstanding for that socket. The
 * reverse direction uses a simple pointer field (hr_sk) in struct
 * handshake_req.
3962306a36Sopenharmony_ci */
4062306a36Sopenharmony_ci
4162306a36Sopenharmony_cistatic struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;
4262306a36Sopenharmony_ci
4362306a36Sopenharmony_cistatic const struct rhashtable_params handshake_rhash_params = {
4462306a36Sopenharmony_ci	.key_len		= sizeof_field(struct handshake_req, hr_sk),
4562306a36Sopenharmony_ci	.key_offset		= offsetof(struct handshake_req, hr_sk),
4662306a36Sopenharmony_ci	.head_offset		= offsetof(struct handshake_req, hr_rhash),
4762306a36Sopenharmony_ci	.automatic_shrinking	= true,
4862306a36Sopenharmony_ci};
4962306a36Sopenharmony_ci
5062306a36Sopenharmony_ciint handshake_req_hash_init(void)
5162306a36Sopenharmony_ci{
5262306a36Sopenharmony_ci	return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
5362306a36Sopenharmony_ci}
5462306a36Sopenharmony_ci
5562306a36Sopenharmony_civoid handshake_req_hash_destroy(void)
5662306a36Sopenharmony_ci{
5762306a36Sopenharmony_ci	rhashtable_destroy(&handshake_rhashtbl);
5862306a36Sopenharmony_ci}
5962306a36Sopenharmony_ci
6062306a36Sopenharmony_cistruct handshake_req *handshake_req_hash_lookup(struct sock *sk)
6162306a36Sopenharmony_ci{
6262306a36Sopenharmony_ci	return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
6362306a36Sopenharmony_ci				      handshake_rhash_params);
6462306a36Sopenharmony_ci}
6562306a36Sopenharmony_ciEXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);
6662306a36Sopenharmony_ci
6762306a36Sopenharmony_cistatic bool handshake_req_hash_add(struct handshake_req *req)
6862306a36Sopenharmony_ci{
6962306a36Sopenharmony_ci	int ret;
7062306a36Sopenharmony_ci
7162306a36Sopenharmony_ci	ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl,
7262306a36Sopenharmony_ci					    &req->hr_rhash,
7362306a36Sopenharmony_ci					    handshake_rhash_params);
7462306a36Sopenharmony_ci	return ret == 0;
7562306a36Sopenharmony_ci}
7662306a36Sopenharmony_ci
7762306a36Sopenharmony_cistatic void handshake_req_destroy(struct handshake_req *req)
7862306a36Sopenharmony_ci{
7962306a36Sopenharmony_ci	if (req->hr_proto->hp_destroy)
8062306a36Sopenharmony_ci		req->hr_proto->hp_destroy(req);
8162306a36Sopenharmony_ci	rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
8262306a36Sopenharmony_ci			       handshake_rhash_params);
8362306a36Sopenharmony_ci	kfree(req);
8462306a36Sopenharmony_ci}
8562306a36Sopenharmony_ci
8662306a36Sopenharmony_cistatic void handshake_sk_destruct(struct sock *sk)
8762306a36Sopenharmony_ci{
8862306a36Sopenharmony_ci	void (*sk_destruct)(struct sock *sk);
8962306a36Sopenharmony_ci	struct handshake_req *req;
9062306a36Sopenharmony_ci
9162306a36Sopenharmony_ci	req = handshake_req_hash_lookup(sk);
9262306a36Sopenharmony_ci	if (!req)
9362306a36Sopenharmony_ci		return;
9462306a36Sopenharmony_ci
9562306a36Sopenharmony_ci	trace_handshake_destruct(sock_net(sk), req, sk);
9662306a36Sopenharmony_ci	sk_destruct = req->hr_odestruct;
9762306a36Sopenharmony_ci	handshake_req_destroy(req);
9862306a36Sopenharmony_ci	if (sk_destruct)
9962306a36Sopenharmony_ci		sk_destruct(sk);
10062306a36Sopenharmony_ci}
10162306a36Sopenharmony_ci
10262306a36Sopenharmony_ci/**
10362306a36Sopenharmony_ci * handshake_req_alloc - Allocate a handshake request
10462306a36Sopenharmony_ci * @proto: security protocol
10562306a36Sopenharmony_ci * @flags: memory allocation flags
10662306a36Sopenharmony_ci *
10762306a36Sopenharmony_ci * Returns an initialized handshake_req or NULL.
10862306a36Sopenharmony_ci */
10962306a36Sopenharmony_cistruct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
11062306a36Sopenharmony_ci					  gfp_t flags)
11162306a36Sopenharmony_ci{
11262306a36Sopenharmony_ci	struct handshake_req *req;
11362306a36Sopenharmony_ci
11462306a36Sopenharmony_ci	if (!proto)
11562306a36Sopenharmony_ci		return NULL;
11662306a36Sopenharmony_ci	if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE)
11762306a36Sopenharmony_ci		return NULL;
11862306a36Sopenharmony_ci	if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
11962306a36Sopenharmony_ci		return NULL;
12062306a36Sopenharmony_ci	if (!proto->hp_accept || !proto->hp_done)
12162306a36Sopenharmony_ci		return NULL;
12262306a36Sopenharmony_ci
12362306a36Sopenharmony_ci	req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags);
12462306a36Sopenharmony_ci	if (!req)
12562306a36Sopenharmony_ci		return NULL;
12662306a36Sopenharmony_ci
12762306a36Sopenharmony_ci	INIT_LIST_HEAD(&req->hr_list);
12862306a36Sopenharmony_ci	req->hr_proto = proto;
12962306a36Sopenharmony_ci	return req;
13062306a36Sopenharmony_ci}
13162306a36Sopenharmony_ciEXPORT_SYMBOL(handshake_req_alloc);
13262306a36Sopenharmony_ci
13362306a36Sopenharmony_ci/**
13462306a36Sopenharmony_ci * handshake_req_private - Get per-handshake private data
13562306a36Sopenharmony_ci * @req: handshake arguments
13662306a36Sopenharmony_ci *
13762306a36Sopenharmony_ci */
13862306a36Sopenharmony_civoid *handshake_req_private(struct handshake_req *req)
13962306a36Sopenharmony_ci{
14062306a36Sopenharmony_ci	return (void *)&req->hr_priv;
14162306a36Sopenharmony_ci}
14262306a36Sopenharmony_ciEXPORT_SYMBOL(handshake_req_private);
14362306a36Sopenharmony_ci
14462306a36Sopenharmony_cistatic bool __add_pending_locked(struct handshake_net *hn,
14562306a36Sopenharmony_ci				 struct handshake_req *req)
14662306a36Sopenharmony_ci{
14762306a36Sopenharmony_ci	if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
14862306a36Sopenharmony_ci		return false;
14962306a36Sopenharmony_ci	hn->hn_pending++;
15062306a36Sopenharmony_ci	list_add_tail(&req->hr_list, &hn->hn_requests);
15162306a36Sopenharmony_ci	return true;
15262306a36Sopenharmony_ci}
15362306a36Sopenharmony_ci
15462306a36Sopenharmony_cistatic void __remove_pending_locked(struct handshake_net *hn,
15562306a36Sopenharmony_ci				    struct handshake_req *req)
15662306a36Sopenharmony_ci{
15762306a36Sopenharmony_ci	hn->hn_pending--;
15862306a36Sopenharmony_ci	list_del_init(&req->hr_list);
15962306a36Sopenharmony_ci}
16062306a36Sopenharmony_ci
16162306a36Sopenharmony_ci/*
16262306a36Sopenharmony_ci * Returns %true if the request was found on @net's pending list,
16362306a36Sopenharmony_ci * otherwise %false.
16462306a36Sopenharmony_ci *
16562306a36Sopenharmony_ci * If @req was on a pending list, it has not yet been accepted.
16662306a36Sopenharmony_ci */
16762306a36Sopenharmony_cistatic bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
16862306a36Sopenharmony_ci{
16962306a36Sopenharmony_ci	bool ret = false;
17062306a36Sopenharmony_ci
17162306a36Sopenharmony_ci	spin_lock(&hn->hn_lock);
17262306a36Sopenharmony_ci	if (!list_empty(&req->hr_list)) {
17362306a36Sopenharmony_ci		__remove_pending_locked(hn, req);
17462306a36Sopenharmony_ci		ret = true;
17562306a36Sopenharmony_ci	}
17662306a36Sopenharmony_ci	spin_unlock(&hn->hn_lock);
17762306a36Sopenharmony_ci
17862306a36Sopenharmony_ci	return ret;
17962306a36Sopenharmony_ci}
18062306a36Sopenharmony_ci
18162306a36Sopenharmony_cistruct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
18262306a36Sopenharmony_ci{
18362306a36Sopenharmony_ci	struct handshake_req *req, *pos;
18462306a36Sopenharmony_ci
18562306a36Sopenharmony_ci	req = NULL;
18662306a36Sopenharmony_ci	spin_lock(&hn->hn_lock);
18762306a36Sopenharmony_ci	list_for_each_entry(pos, &hn->hn_requests, hr_list) {
18862306a36Sopenharmony_ci		if (pos->hr_proto->hp_handler_class != class)
18962306a36Sopenharmony_ci			continue;
19062306a36Sopenharmony_ci		__remove_pending_locked(hn, pos);
19162306a36Sopenharmony_ci		req = pos;
19262306a36Sopenharmony_ci		break;
19362306a36Sopenharmony_ci	}
19462306a36Sopenharmony_ci	spin_unlock(&hn->hn_lock);
19562306a36Sopenharmony_ci
19662306a36Sopenharmony_ci	return req;
19762306a36Sopenharmony_ci}
19862306a36Sopenharmony_ciEXPORT_SYMBOL_IF_KUNIT(handshake_req_next);
19962306a36Sopenharmony_ci
20062306a36Sopenharmony_ci/**
20162306a36Sopenharmony_ci * handshake_req_submit - Submit a handshake request
20262306a36Sopenharmony_ci * @sock: open socket on which to perform the handshake
20362306a36Sopenharmony_ci * @req: handshake arguments
20462306a36Sopenharmony_ci * @flags: memory allocation flags
20562306a36Sopenharmony_ci *
20662306a36Sopenharmony_ci * Return values:
20762306a36Sopenharmony_ci *   %0: Request queued
20862306a36Sopenharmony_ci *   %-EINVAL: Invalid argument
20962306a36Sopenharmony_ci *   %-EBUSY: A handshake is already under way for this socket
21062306a36Sopenharmony_ci *   %-ESRCH: No handshake agent is available
21162306a36Sopenharmony_ci *   %-EAGAIN: Too many pending handshake requests
21262306a36Sopenharmony_ci *   %-ENOMEM: Failed to allocate memory
21362306a36Sopenharmony_ci *   %-EMSGSIZE: Failed to construct notification message
21462306a36Sopenharmony_ci *   %-EOPNOTSUPP: Handshake module not initialized
21562306a36Sopenharmony_ci *
21662306a36Sopenharmony_ci * A zero return value from handshake_req_submit() means that
21762306a36Sopenharmony_ci * exactly one subsequent completion callback is guaranteed.
21862306a36Sopenharmony_ci *
21962306a36Sopenharmony_ci * A negative return value from handshake_req_submit() means that
22062306a36Sopenharmony_ci * no completion callback will be done and that @req has been
22162306a36Sopenharmony_ci * destroyed.
22262306a36Sopenharmony_ci */
22362306a36Sopenharmony_ciint handshake_req_submit(struct socket *sock, struct handshake_req *req,
22462306a36Sopenharmony_ci			 gfp_t flags)
22562306a36Sopenharmony_ci{
22662306a36Sopenharmony_ci	struct handshake_net *hn;
22762306a36Sopenharmony_ci	struct net *net;
22862306a36Sopenharmony_ci	int ret;
22962306a36Sopenharmony_ci
23062306a36Sopenharmony_ci	if (!sock || !req || !sock->file) {
23162306a36Sopenharmony_ci		kfree(req);
23262306a36Sopenharmony_ci		return -EINVAL;
23362306a36Sopenharmony_ci	}
23462306a36Sopenharmony_ci
23562306a36Sopenharmony_ci	req->hr_sk = sock->sk;
23662306a36Sopenharmony_ci	if (!req->hr_sk) {
23762306a36Sopenharmony_ci		kfree(req);
23862306a36Sopenharmony_ci		return -EINVAL;
23962306a36Sopenharmony_ci	}
24062306a36Sopenharmony_ci	req->hr_odestruct = req->hr_sk->sk_destruct;
24162306a36Sopenharmony_ci	req->hr_sk->sk_destruct = handshake_sk_destruct;
24262306a36Sopenharmony_ci
24362306a36Sopenharmony_ci	ret = -EOPNOTSUPP;
24462306a36Sopenharmony_ci	net = sock_net(req->hr_sk);
24562306a36Sopenharmony_ci	hn = handshake_pernet(net);
24662306a36Sopenharmony_ci	if (!hn)
24762306a36Sopenharmony_ci		goto out_err;
24862306a36Sopenharmony_ci
24962306a36Sopenharmony_ci	ret = -EAGAIN;
25062306a36Sopenharmony_ci	if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
25162306a36Sopenharmony_ci		goto out_err;
25262306a36Sopenharmony_ci
25362306a36Sopenharmony_ci	spin_lock(&hn->hn_lock);
25462306a36Sopenharmony_ci	ret = -EOPNOTSUPP;
25562306a36Sopenharmony_ci	if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
25662306a36Sopenharmony_ci		goto out_unlock;
25762306a36Sopenharmony_ci	ret = -EBUSY;
25862306a36Sopenharmony_ci	if (!handshake_req_hash_add(req))
25962306a36Sopenharmony_ci		goto out_unlock;
26062306a36Sopenharmony_ci	if (!__add_pending_locked(hn, req))
26162306a36Sopenharmony_ci		goto out_unlock;
26262306a36Sopenharmony_ci	spin_unlock(&hn->hn_lock);
26362306a36Sopenharmony_ci
26462306a36Sopenharmony_ci	ret = handshake_genl_notify(net, req->hr_proto, flags);
26562306a36Sopenharmony_ci	if (ret) {
26662306a36Sopenharmony_ci		trace_handshake_notify_err(net, req, req->hr_sk, ret);
26762306a36Sopenharmony_ci		if (remove_pending(hn, req))
26862306a36Sopenharmony_ci			goto out_err;
26962306a36Sopenharmony_ci	}
27062306a36Sopenharmony_ci
27162306a36Sopenharmony_ci	/* Prevent socket release while a handshake request is pending */
27262306a36Sopenharmony_ci	sock_hold(req->hr_sk);
27362306a36Sopenharmony_ci
27462306a36Sopenharmony_ci	trace_handshake_submit(net, req, req->hr_sk);
27562306a36Sopenharmony_ci	return 0;
27662306a36Sopenharmony_ci
27762306a36Sopenharmony_ciout_unlock:
27862306a36Sopenharmony_ci	spin_unlock(&hn->hn_lock);
27962306a36Sopenharmony_ciout_err:
28062306a36Sopenharmony_ci	trace_handshake_submit_err(net, req, req->hr_sk, ret);
28162306a36Sopenharmony_ci	handshake_req_destroy(req);
28262306a36Sopenharmony_ci	return ret;
28362306a36Sopenharmony_ci}
28462306a36Sopenharmony_ciEXPORT_SYMBOL(handshake_req_submit);
28562306a36Sopenharmony_ci
28662306a36Sopenharmony_civoid handshake_complete(struct handshake_req *req, unsigned int status,
28762306a36Sopenharmony_ci			struct genl_info *info)
28862306a36Sopenharmony_ci{
28962306a36Sopenharmony_ci	struct sock *sk = req->hr_sk;
29062306a36Sopenharmony_ci	struct net *net = sock_net(sk);
29162306a36Sopenharmony_ci
29262306a36Sopenharmony_ci	if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
29362306a36Sopenharmony_ci		trace_handshake_complete(net, req, sk, status);
29462306a36Sopenharmony_ci		req->hr_proto->hp_done(req, status, info);
29562306a36Sopenharmony_ci
29662306a36Sopenharmony_ci		/* Handshake request is no longer pending */
29762306a36Sopenharmony_ci		sock_put(sk);
29862306a36Sopenharmony_ci	}
29962306a36Sopenharmony_ci}
30062306a36Sopenharmony_ciEXPORT_SYMBOL_IF_KUNIT(handshake_complete);
30162306a36Sopenharmony_ci
30262306a36Sopenharmony_ci/**
30362306a36Sopenharmony_ci * handshake_req_cancel - Cancel an in-progress handshake
30462306a36Sopenharmony_ci * @sk: socket on which there is an ongoing handshake
30562306a36Sopenharmony_ci *
30662306a36Sopenharmony_ci * Request cancellation races with request completion. To determine
30762306a36Sopenharmony_ci * who won, callers examine the return value from this function.
30862306a36Sopenharmony_ci *
30962306a36Sopenharmony_ci * Return values:
31062306a36Sopenharmony_ci *   %true - Uncompleted handshake request was canceled
31162306a36Sopenharmony_ci *   %false - Handshake request already completed or not found
31262306a36Sopenharmony_ci */
31362306a36Sopenharmony_cibool handshake_req_cancel(struct sock *sk)
31462306a36Sopenharmony_ci{
31562306a36Sopenharmony_ci	struct handshake_req *req;
31662306a36Sopenharmony_ci	struct handshake_net *hn;
31762306a36Sopenharmony_ci	struct net *net;
31862306a36Sopenharmony_ci
31962306a36Sopenharmony_ci	net = sock_net(sk);
32062306a36Sopenharmony_ci	req = handshake_req_hash_lookup(sk);
32162306a36Sopenharmony_ci	if (!req) {
32262306a36Sopenharmony_ci		trace_handshake_cancel_none(net, req, sk);
32362306a36Sopenharmony_ci		return false;
32462306a36Sopenharmony_ci	}
32562306a36Sopenharmony_ci
32662306a36Sopenharmony_ci	hn = handshake_pernet(net);
32762306a36Sopenharmony_ci	if (hn && remove_pending(hn, req)) {
32862306a36Sopenharmony_ci		/* Request hadn't been accepted */
32962306a36Sopenharmony_ci		goto out_true;
33062306a36Sopenharmony_ci	}
33162306a36Sopenharmony_ci	if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
33262306a36Sopenharmony_ci		/* Request already completed */
33362306a36Sopenharmony_ci		trace_handshake_cancel_busy(net, req, sk);
33462306a36Sopenharmony_ci		return false;
33562306a36Sopenharmony_ci	}
33662306a36Sopenharmony_ci
33762306a36Sopenharmony_ciout_true:
33862306a36Sopenharmony_ci	trace_handshake_cancel(net, req, sk);
33962306a36Sopenharmony_ci
34062306a36Sopenharmony_ci	/* Handshake request is no longer pending */
34162306a36Sopenharmony_ci	sock_put(sk);
34262306a36Sopenharmony_ci	return true;
34362306a36Sopenharmony_ci}
34462306a36Sopenharmony_ciEXPORT_SYMBOL(handshake_req_cancel);
345