162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0-only
262306a36Sopenharmony_ci/*
362306a36Sopenharmony_ci * Generic netlink handshake service
462306a36Sopenharmony_ci *
562306a36Sopenharmony_ci * Author: Chuck Lever <chuck.lever@oracle.com>
662306a36Sopenharmony_ci *
762306a36Sopenharmony_ci * Copyright (c) 2023, Oracle and/or its affiliates.
862306a36Sopenharmony_ci */
962306a36Sopenharmony_ci
1062306a36Sopenharmony_ci#include <linux/types.h>
1162306a36Sopenharmony_ci#include <linux/socket.h>
1262306a36Sopenharmony_ci#include <linux/kernel.h>
1362306a36Sopenharmony_ci#include <linux/module.h>
1462306a36Sopenharmony_ci#include <linux/skbuff.h>
1562306a36Sopenharmony_ci#include <linux/mm.h>
1662306a36Sopenharmony_ci
1762306a36Sopenharmony_ci#include <net/sock.h>
1862306a36Sopenharmony_ci#include <net/genetlink.h>
1962306a36Sopenharmony_ci#include <net/netns/generic.h>
2062306a36Sopenharmony_ci
2162306a36Sopenharmony_ci#include <kunit/visibility.h>
2262306a36Sopenharmony_ci
2362306a36Sopenharmony_ci#include <uapi/linux/handshake.h>
2462306a36Sopenharmony_ci#include "handshake.h"
2562306a36Sopenharmony_ci#include "genl.h"
2662306a36Sopenharmony_ci
2762306a36Sopenharmony_ci#include <trace/events/handshake.h>
2862306a36Sopenharmony_ci
2962306a36Sopenharmony_ci/**
3062306a36Sopenharmony_ci * handshake_genl_notify - Notify handlers that a request is waiting
3162306a36Sopenharmony_ci * @net: target network namespace
3262306a36Sopenharmony_ci * @proto: handshake protocol
3362306a36Sopenharmony_ci * @flags: memory allocation control flags
3462306a36Sopenharmony_ci *
3562306a36Sopenharmony_ci * Returns zero on success or a negative errno if notification failed.
3662306a36Sopenharmony_ci */
3762306a36Sopenharmony_ciint handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
3862306a36Sopenharmony_ci			  gfp_t flags)
3962306a36Sopenharmony_ci{
4062306a36Sopenharmony_ci	struct sk_buff *msg;
4162306a36Sopenharmony_ci	void *hdr;
4262306a36Sopenharmony_ci
4362306a36Sopenharmony_ci	/* Disable notifications during unit testing */
4462306a36Sopenharmony_ci	if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
4562306a36Sopenharmony_ci		return 0;
4662306a36Sopenharmony_ci
4762306a36Sopenharmony_ci	if (!genl_has_listeners(&handshake_nl_family, net,
4862306a36Sopenharmony_ci				proto->hp_handler_class))
4962306a36Sopenharmony_ci		return -ESRCH;
5062306a36Sopenharmony_ci
5162306a36Sopenharmony_ci	msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
5262306a36Sopenharmony_ci	if (!msg)
5362306a36Sopenharmony_ci		return -ENOMEM;
5462306a36Sopenharmony_ci
5562306a36Sopenharmony_ci	hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
5662306a36Sopenharmony_ci			  HANDSHAKE_CMD_READY);
5762306a36Sopenharmony_ci	if (!hdr)
5862306a36Sopenharmony_ci		goto out_free;
5962306a36Sopenharmony_ci
6062306a36Sopenharmony_ci	if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
6162306a36Sopenharmony_ci			proto->hp_handler_class) < 0) {
6262306a36Sopenharmony_ci		genlmsg_cancel(msg, hdr);
6362306a36Sopenharmony_ci		goto out_free;
6462306a36Sopenharmony_ci	}
6562306a36Sopenharmony_ci
6662306a36Sopenharmony_ci	genlmsg_end(msg, hdr);
6762306a36Sopenharmony_ci	return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
6862306a36Sopenharmony_ci				       0, proto->hp_handler_class, flags);
6962306a36Sopenharmony_ci
7062306a36Sopenharmony_ciout_free:
7162306a36Sopenharmony_ci	nlmsg_free(msg);
7262306a36Sopenharmony_ci	return -EMSGSIZE;
7362306a36Sopenharmony_ci}
7462306a36Sopenharmony_ci
7562306a36Sopenharmony_ci/**
7662306a36Sopenharmony_ci * handshake_genl_put - Create a generic netlink message header
7762306a36Sopenharmony_ci * @msg: buffer in which to create the header
7862306a36Sopenharmony_ci * @info: generic netlink message context
7962306a36Sopenharmony_ci *
8062306a36Sopenharmony_ci * Returns a ready-to-use header, or NULL.
8162306a36Sopenharmony_ci */
8262306a36Sopenharmony_cistruct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
8362306a36Sopenharmony_ci				    struct genl_info *info)
8462306a36Sopenharmony_ci{
8562306a36Sopenharmony_ci	return genlmsg_put(msg, info->snd_portid, info->snd_seq,
8662306a36Sopenharmony_ci			   &handshake_nl_family, 0, info->genlhdr->cmd);
8762306a36Sopenharmony_ci}
8862306a36Sopenharmony_ciEXPORT_SYMBOL(handshake_genl_put);
8962306a36Sopenharmony_ci
9062306a36Sopenharmony_ciint handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
9162306a36Sopenharmony_ci{
9262306a36Sopenharmony_ci	struct net *net = sock_net(skb->sk);
9362306a36Sopenharmony_ci	struct handshake_net *hn = handshake_pernet(net);
9462306a36Sopenharmony_ci	struct handshake_req *req = NULL;
9562306a36Sopenharmony_ci	struct socket *sock;
9662306a36Sopenharmony_ci	int class, fd, err;
9762306a36Sopenharmony_ci
9862306a36Sopenharmony_ci	err = -EOPNOTSUPP;
9962306a36Sopenharmony_ci	if (!hn)
10062306a36Sopenharmony_ci		goto out_status;
10162306a36Sopenharmony_ci
10262306a36Sopenharmony_ci	err = -EINVAL;
10362306a36Sopenharmony_ci	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
10462306a36Sopenharmony_ci		goto out_status;
10562306a36Sopenharmony_ci	class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);
10662306a36Sopenharmony_ci
10762306a36Sopenharmony_ci	err = -EAGAIN;
10862306a36Sopenharmony_ci	req = handshake_req_next(hn, class);
10962306a36Sopenharmony_ci	if (!req)
11062306a36Sopenharmony_ci		goto out_status;
11162306a36Sopenharmony_ci
11262306a36Sopenharmony_ci	sock = req->hr_sk->sk_socket;
11362306a36Sopenharmony_ci	fd = get_unused_fd_flags(O_CLOEXEC);
11462306a36Sopenharmony_ci	if (fd < 0) {
11562306a36Sopenharmony_ci		err = fd;
11662306a36Sopenharmony_ci		goto out_complete;
11762306a36Sopenharmony_ci	}
11862306a36Sopenharmony_ci
11962306a36Sopenharmony_ci	err = req->hr_proto->hp_accept(req, info, fd);
12062306a36Sopenharmony_ci	if (err) {
12162306a36Sopenharmony_ci		put_unused_fd(fd);
12262306a36Sopenharmony_ci		goto out_complete;
12362306a36Sopenharmony_ci	}
12462306a36Sopenharmony_ci
12562306a36Sopenharmony_ci	fd_install(fd, get_file(sock->file));
12662306a36Sopenharmony_ci
12762306a36Sopenharmony_ci	trace_handshake_cmd_accept(net, req, req->hr_sk, fd);
12862306a36Sopenharmony_ci	return 0;
12962306a36Sopenharmony_ci
13062306a36Sopenharmony_ciout_complete:
13162306a36Sopenharmony_ci	handshake_complete(req, -EIO, NULL);
13262306a36Sopenharmony_ciout_status:
13362306a36Sopenharmony_ci	trace_handshake_cmd_accept_err(net, req, NULL, err);
13462306a36Sopenharmony_ci	return err;
13562306a36Sopenharmony_ci}
13662306a36Sopenharmony_ci
13762306a36Sopenharmony_ciint handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
13862306a36Sopenharmony_ci{
13962306a36Sopenharmony_ci	struct net *net = sock_net(skb->sk);
14062306a36Sopenharmony_ci	struct handshake_req *req;
14162306a36Sopenharmony_ci	struct socket *sock;
14262306a36Sopenharmony_ci	int fd, status, err;
14362306a36Sopenharmony_ci
14462306a36Sopenharmony_ci	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
14562306a36Sopenharmony_ci		return -EINVAL;
14662306a36Sopenharmony_ci	fd = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
14762306a36Sopenharmony_ci
14862306a36Sopenharmony_ci	sock = sockfd_lookup(fd, &err);
14962306a36Sopenharmony_ci	if (!sock)
15062306a36Sopenharmony_ci		return err;
15162306a36Sopenharmony_ci
15262306a36Sopenharmony_ci	req = handshake_req_hash_lookup(sock->sk);
15362306a36Sopenharmony_ci	if (!req) {
15462306a36Sopenharmony_ci		err = -EBUSY;
15562306a36Sopenharmony_ci		trace_handshake_cmd_done_err(net, req, sock->sk, err);
15662306a36Sopenharmony_ci		fput(sock->file);
15762306a36Sopenharmony_ci		return err;
15862306a36Sopenharmony_ci	}
15962306a36Sopenharmony_ci
16062306a36Sopenharmony_ci	trace_handshake_cmd_done(net, req, sock->sk, fd);
16162306a36Sopenharmony_ci
16262306a36Sopenharmony_ci	status = -EIO;
16362306a36Sopenharmony_ci	if (info->attrs[HANDSHAKE_A_DONE_STATUS])
16462306a36Sopenharmony_ci		status = nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
16562306a36Sopenharmony_ci
16662306a36Sopenharmony_ci	handshake_complete(req, status, info);
16762306a36Sopenharmony_ci	fput(sock->file);
16862306a36Sopenharmony_ci	return 0;
16962306a36Sopenharmony_ci}
17062306a36Sopenharmony_ci
17162306a36Sopenharmony_cistatic unsigned int handshake_net_id;
17262306a36Sopenharmony_ci
17362306a36Sopenharmony_cistatic int __net_init handshake_net_init(struct net *net)
17462306a36Sopenharmony_ci{
17562306a36Sopenharmony_ci	struct handshake_net *hn = net_generic(net, handshake_net_id);
17662306a36Sopenharmony_ci	unsigned long tmp;
17762306a36Sopenharmony_ci	struct sysinfo si;
17862306a36Sopenharmony_ci
17962306a36Sopenharmony_ci	/*
18062306a36Sopenharmony_ci	 * Arbitrary limit to prevent handshakes that do not make
18162306a36Sopenharmony_ci	 * progress from clogging up the system. The cap scales up
18262306a36Sopenharmony_ci	 * with the amount of physical memory on the system.
18362306a36Sopenharmony_ci	 */
18462306a36Sopenharmony_ci	si_meminfo(&si);
18562306a36Sopenharmony_ci	tmp = si.totalram / (25 * si.mem_unit);
18662306a36Sopenharmony_ci	hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
18762306a36Sopenharmony_ci
18862306a36Sopenharmony_ci	spin_lock_init(&hn->hn_lock);
18962306a36Sopenharmony_ci	hn->hn_pending = 0;
19062306a36Sopenharmony_ci	hn->hn_flags = 0;
19162306a36Sopenharmony_ci	INIT_LIST_HEAD(&hn->hn_requests);
19262306a36Sopenharmony_ci	return 0;
19362306a36Sopenharmony_ci}
19462306a36Sopenharmony_ci
19562306a36Sopenharmony_cistatic void __net_exit handshake_net_exit(struct net *net)
19662306a36Sopenharmony_ci{
19762306a36Sopenharmony_ci	struct handshake_net *hn = net_generic(net, handshake_net_id);
19862306a36Sopenharmony_ci	struct handshake_req *req;
19962306a36Sopenharmony_ci	LIST_HEAD(requests);
20062306a36Sopenharmony_ci
20162306a36Sopenharmony_ci	/*
20262306a36Sopenharmony_ci	 * Drain the net's pending list. Requests that have been
20362306a36Sopenharmony_ci	 * accepted and are in progress will be destroyed when
20462306a36Sopenharmony_ci	 * the socket is closed.
20562306a36Sopenharmony_ci	 */
20662306a36Sopenharmony_ci	spin_lock(&hn->hn_lock);
20762306a36Sopenharmony_ci	set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
20862306a36Sopenharmony_ci	list_splice_init(&requests, &hn->hn_requests);
20962306a36Sopenharmony_ci	spin_unlock(&hn->hn_lock);
21062306a36Sopenharmony_ci
21162306a36Sopenharmony_ci	while (!list_empty(&requests)) {
21262306a36Sopenharmony_ci		req = list_first_entry(&requests, struct handshake_req, hr_list);
21362306a36Sopenharmony_ci		list_del(&req->hr_list);
21462306a36Sopenharmony_ci
21562306a36Sopenharmony_ci		/*
21662306a36Sopenharmony_ci		 * Requests on this list have not yet been
21762306a36Sopenharmony_ci		 * accepted, so they do not have an fd to put.
21862306a36Sopenharmony_ci		 */
21962306a36Sopenharmony_ci
22062306a36Sopenharmony_ci		handshake_complete(req, -ETIMEDOUT, NULL);
22162306a36Sopenharmony_ci	}
22262306a36Sopenharmony_ci}
22362306a36Sopenharmony_ci
22462306a36Sopenharmony_cistatic struct pernet_operations handshake_genl_net_ops = {
22562306a36Sopenharmony_ci	.init		= handshake_net_init,
22662306a36Sopenharmony_ci	.exit		= handshake_net_exit,
22762306a36Sopenharmony_ci	.id		= &handshake_net_id,
22862306a36Sopenharmony_ci	.size		= sizeof(struct handshake_net),
22962306a36Sopenharmony_ci};
23062306a36Sopenharmony_ci
23162306a36Sopenharmony_ci/**
23262306a36Sopenharmony_ci * handshake_pernet - Get the handshake private per-net structure
23362306a36Sopenharmony_ci * @net: network namespace
23462306a36Sopenharmony_ci *
23562306a36Sopenharmony_ci * Returns a pointer to the net's private per-net structure for the
23662306a36Sopenharmony_ci * handshake module, or NULL if handshake_init() failed.
23762306a36Sopenharmony_ci */
23862306a36Sopenharmony_cistruct handshake_net *handshake_pernet(struct net *net)
23962306a36Sopenharmony_ci{
24062306a36Sopenharmony_ci	return handshake_net_id ?
24162306a36Sopenharmony_ci		net_generic(net, handshake_net_id) : NULL;
24262306a36Sopenharmony_ci}
24362306a36Sopenharmony_ciEXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
24462306a36Sopenharmony_ci
24562306a36Sopenharmony_cistatic int __init handshake_init(void)
24662306a36Sopenharmony_ci{
24762306a36Sopenharmony_ci	int ret;
24862306a36Sopenharmony_ci
24962306a36Sopenharmony_ci	ret = handshake_req_hash_init();
25062306a36Sopenharmony_ci	if (ret) {
25162306a36Sopenharmony_ci		pr_warn("handshake: hash initialization failed (%d)\n", ret);
25262306a36Sopenharmony_ci		return ret;
25362306a36Sopenharmony_ci	}
25462306a36Sopenharmony_ci
25562306a36Sopenharmony_ci	ret = genl_register_family(&handshake_nl_family);
25662306a36Sopenharmony_ci	if (ret) {
25762306a36Sopenharmony_ci		pr_warn("handshake: netlink registration failed (%d)\n", ret);
25862306a36Sopenharmony_ci		handshake_req_hash_destroy();
25962306a36Sopenharmony_ci		return ret;
26062306a36Sopenharmony_ci	}
26162306a36Sopenharmony_ci
26262306a36Sopenharmony_ci	/*
26362306a36Sopenharmony_ci	 * ORDER: register_pernet_subsys must be done last.
26462306a36Sopenharmony_ci	 *
26562306a36Sopenharmony_ci	 *	If initialization does not make it past pernet_subsys
26662306a36Sopenharmony_ci	 *	registration, then handshake_net_id will remain 0. That
26762306a36Sopenharmony_ci	 *	shunts the handshake consumer API to return ENOTSUPP
26862306a36Sopenharmony_ci	 *	to prevent it from dereferencing something that hasn't
26962306a36Sopenharmony_ci	 *	been allocated.
27062306a36Sopenharmony_ci	 */
27162306a36Sopenharmony_ci	ret = register_pernet_subsys(&handshake_genl_net_ops);
27262306a36Sopenharmony_ci	if (ret) {
27362306a36Sopenharmony_ci		pr_warn("handshake: pernet registration failed (%d)\n", ret);
27462306a36Sopenharmony_ci		genl_unregister_family(&handshake_nl_family);
27562306a36Sopenharmony_ci		handshake_req_hash_destroy();
27662306a36Sopenharmony_ci	}
27762306a36Sopenharmony_ci
27862306a36Sopenharmony_ci	return ret;
27962306a36Sopenharmony_ci}
28062306a36Sopenharmony_ci
28162306a36Sopenharmony_cistatic void __exit handshake_exit(void)
28262306a36Sopenharmony_ci{
28362306a36Sopenharmony_ci	unregister_pernet_subsys(&handshake_genl_net_ops);
28462306a36Sopenharmony_ci	handshake_net_id = 0;
28562306a36Sopenharmony_ci
28662306a36Sopenharmony_ci	handshake_req_hash_destroy();
28762306a36Sopenharmony_ci	genl_unregister_family(&handshake_nl_family);
28862306a36Sopenharmony_ci}
28962306a36Sopenharmony_ci
29062306a36Sopenharmony_cimodule_init(handshake_init);
29162306a36Sopenharmony_cimodule_exit(handshake_exit);
292