162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0
262306a36Sopenharmony_ci/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
362306a36Sopenharmony_ci
462306a36Sopenharmony_ci#include <linux/skmsg.h>
562306a36Sopenharmony_ci#include <net/sock.h>
662306a36Sopenharmony_ci#include <net/udp.h>
762306a36Sopenharmony_ci#include <net/inet_common.h>
862306a36Sopenharmony_ci
962306a36Sopenharmony_ci#include "udp_impl.h"
1062306a36Sopenharmony_ci
1162306a36Sopenharmony_cistatic struct proto *udpv6_prot_saved __read_mostly;
1262306a36Sopenharmony_ci
1362306a36Sopenharmony_cistatic int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
1462306a36Sopenharmony_ci			  int flags, int *addr_len)
1562306a36Sopenharmony_ci{
1662306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6)
1762306a36Sopenharmony_ci	if (sk->sk_family == AF_INET6)
1862306a36Sopenharmony_ci		return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len);
1962306a36Sopenharmony_ci#endif
2062306a36Sopenharmony_ci	return udp_prot.recvmsg(sk, msg, len, flags, addr_len);
2162306a36Sopenharmony_ci}
2262306a36Sopenharmony_ci
2362306a36Sopenharmony_cistatic bool udp_sk_has_data(struct sock *sk)
2462306a36Sopenharmony_ci{
2562306a36Sopenharmony_ci	return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
2662306a36Sopenharmony_ci	       !skb_queue_empty(&sk->sk_receive_queue);
2762306a36Sopenharmony_ci}
2862306a36Sopenharmony_ci
2962306a36Sopenharmony_cistatic bool psock_has_data(struct sk_psock *psock)
3062306a36Sopenharmony_ci{
3162306a36Sopenharmony_ci	return !skb_queue_empty(&psock->ingress_skb) ||
3262306a36Sopenharmony_ci	       !sk_psock_queue_empty(psock);
3362306a36Sopenharmony_ci}
3462306a36Sopenharmony_ci
3562306a36Sopenharmony_ci#define udp_msg_has_data(__sk, __psock)	\
3662306a36Sopenharmony_ci		({ udp_sk_has_data(__sk) || psock_has_data(__psock); })
3762306a36Sopenharmony_ci
3862306a36Sopenharmony_cistatic int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
3962306a36Sopenharmony_ci			     long timeo)
4062306a36Sopenharmony_ci{
4162306a36Sopenharmony_ci	DEFINE_WAIT_FUNC(wait, woken_wake_function);
4262306a36Sopenharmony_ci	int ret = 0;
4362306a36Sopenharmony_ci
4462306a36Sopenharmony_ci	if (sk->sk_shutdown & RCV_SHUTDOWN)
4562306a36Sopenharmony_ci		return 1;
4662306a36Sopenharmony_ci
4762306a36Sopenharmony_ci	if (!timeo)
4862306a36Sopenharmony_ci		return ret;
4962306a36Sopenharmony_ci
5062306a36Sopenharmony_ci	add_wait_queue(sk_sleep(sk), &wait);
5162306a36Sopenharmony_ci	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
5262306a36Sopenharmony_ci	ret = udp_msg_has_data(sk, psock);
5362306a36Sopenharmony_ci	if (!ret) {
5462306a36Sopenharmony_ci		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
5562306a36Sopenharmony_ci		ret = udp_msg_has_data(sk, psock);
5662306a36Sopenharmony_ci	}
5762306a36Sopenharmony_ci	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
5862306a36Sopenharmony_ci	remove_wait_queue(sk_sleep(sk), &wait);
5962306a36Sopenharmony_ci	return ret;
6062306a36Sopenharmony_ci}
6162306a36Sopenharmony_ci
6262306a36Sopenharmony_cistatic int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
6362306a36Sopenharmony_ci			   int flags, int *addr_len)
6462306a36Sopenharmony_ci{
6562306a36Sopenharmony_ci	struct sk_psock *psock;
6662306a36Sopenharmony_ci	int copied, ret;
6762306a36Sopenharmony_ci
6862306a36Sopenharmony_ci	if (unlikely(flags & MSG_ERRQUEUE))
6962306a36Sopenharmony_ci		return inet_recv_error(sk, msg, len, addr_len);
7062306a36Sopenharmony_ci
7162306a36Sopenharmony_ci	if (!len)
7262306a36Sopenharmony_ci		return 0;
7362306a36Sopenharmony_ci
7462306a36Sopenharmony_ci	psock = sk_psock_get(sk);
7562306a36Sopenharmony_ci	if (unlikely(!psock))
7662306a36Sopenharmony_ci		return sk_udp_recvmsg(sk, msg, len, flags, addr_len);
7762306a36Sopenharmony_ci
7862306a36Sopenharmony_ci	if (!psock_has_data(psock)) {
7962306a36Sopenharmony_ci		ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
8062306a36Sopenharmony_ci		goto out;
8162306a36Sopenharmony_ci	}
8262306a36Sopenharmony_ci
8362306a36Sopenharmony_cimsg_bytes_ready:
8462306a36Sopenharmony_ci	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
8562306a36Sopenharmony_ci	if (!copied) {
8662306a36Sopenharmony_ci		long timeo;
8762306a36Sopenharmony_ci		int data;
8862306a36Sopenharmony_ci
8962306a36Sopenharmony_ci		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
9062306a36Sopenharmony_ci		data = udp_msg_wait_data(sk, psock, timeo);
9162306a36Sopenharmony_ci		if (data) {
9262306a36Sopenharmony_ci			if (psock_has_data(psock))
9362306a36Sopenharmony_ci				goto msg_bytes_ready;
9462306a36Sopenharmony_ci			ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len);
9562306a36Sopenharmony_ci			goto out;
9662306a36Sopenharmony_ci		}
9762306a36Sopenharmony_ci		copied = -EAGAIN;
9862306a36Sopenharmony_ci	}
9962306a36Sopenharmony_ci	ret = copied;
10062306a36Sopenharmony_ciout:
10162306a36Sopenharmony_ci	sk_psock_put(sk, psock);
10262306a36Sopenharmony_ci	return ret;
10362306a36Sopenharmony_ci}
10462306a36Sopenharmony_ci
10562306a36Sopenharmony_cienum {
10662306a36Sopenharmony_ci	UDP_BPF_IPV4,
10762306a36Sopenharmony_ci	UDP_BPF_IPV6,
10862306a36Sopenharmony_ci	UDP_BPF_NUM_PROTS,
10962306a36Sopenharmony_ci};
11062306a36Sopenharmony_ci
11162306a36Sopenharmony_cistatic DEFINE_SPINLOCK(udpv6_prot_lock);
11262306a36Sopenharmony_cistatic struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
11362306a36Sopenharmony_ci
11462306a36Sopenharmony_cistatic void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
11562306a36Sopenharmony_ci{
11662306a36Sopenharmony_ci	*prot        = *base;
11762306a36Sopenharmony_ci	prot->close  = sock_map_close;
11862306a36Sopenharmony_ci	prot->recvmsg = udp_bpf_recvmsg;
11962306a36Sopenharmony_ci	prot->sock_is_readable = sk_msg_is_readable;
12062306a36Sopenharmony_ci}
12162306a36Sopenharmony_ci
12262306a36Sopenharmony_cistatic void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
12362306a36Sopenharmony_ci{
12462306a36Sopenharmony_ci	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
12562306a36Sopenharmony_ci		spin_lock_bh(&udpv6_prot_lock);
12662306a36Sopenharmony_ci		if (likely(ops != udpv6_prot_saved)) {
12762306a36Sopenharmony_ci			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
12862306a36Sopenharmony_ci			smp_store_release(&udpv6_prot_saved, ops);
12962306a36Sopenharmony_ci		}
13062306a36Sopenharmony_ci		spin_unlock_bh(&udpv6_prot_lock);
13162306a36Sopenharmony_ci	}
13262306a36Sopenharmony_ci}
13362306a36Sopenharmony_ci
13462306a36Sopenharmony_cistatic int __init udp_bpf_v4_build_proto(void)
13562306a36Sopenharmony_ci{
13662306a36Sopenharmony_ci	udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
13762306a36Sopenharmony_ci	return 0;
13862306a36Sopenharmony_ci}
13962306a36Sopenharmony_cilate_initcall(udp_bpf_v4_build_proto);
14062306a36Sopenharmony_ci
14162306a36Sopenharmony_ciint udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
14262306a36Sopenharmony_ci{
14362306a36Sopenharmony_ci	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
14462306a36Sopenharmony_ci
14562306a36Sopenharmony_ci	if (restore) {
14662306a36Sopenharmony_ci		sk->sk_write_space = psock->saved_write_space;
14762306a36Sopenharmony_ci		sock_replace_proto(sk, psock->sk_proto);
14862306a36Sopenharmony_ci		return 0;
14962306a36Sopenharmony_ci	}
15062306a36Sopenharmony_ci
15162306a36Sopenharmony_ci	if (sk->sk_family == AF_INET6)
15262306a36Sopenharmony_ci		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
15362306a36Sopenharmony_ci
15462306a36Sopenharmony_ci	sock_replace_proto(sk, &udp_bpf_prots[family]);
15562306a36Sopenharmony_ci	return 0;
15662306a36Sopenharmony_ci}
15762306a36Sopenharmony_ciEXPORT_SYMBOL_GPL(udp_bpf_update_proto);
158