162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0 262306a36Sopenharmony_ci/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 362306a36Sopenharmony_ci 462306a36Sopenharmony_ci#include <linux/skmsg.h> 562306a36Sopenharmony_ci#include <net/sock.h> 662306a36Sopenharmony_ci#include <net/udp.h> 762306a36Sopenharmony_ci#include <net/inet_common.h> 862306a36Sopenharmony_ci 962306a36Sopenharmony_ci#include "udp_impl.h" 1062306a36Sopenharmony_ci 1162306a36Sopenharmony_cistatic struct proto *udpv6_prot_saved __read_mostly; 1262306a36Sopenharmony_ci 1362306a36Sopenharmony_cistatic int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 1462306a36Sopenharmony_ci int flags, int *addr_len) 1562306a36Sopenharmony_ci{ 1662306a36Sopenharmony_ci#if IS_ENABLED(CONFIG_IPV6) 1762306a36Sopenharmony_ci if (sk->sk_family == AF_INET6) 1862306a36Sopenharmony_ci return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); 1962306a36Sopenharmony_ci#endif 2062306a36Sopenharmony_ci return udp_prot.recvmsg(sk, msg, len, flags, addr_len); 2162306a36Sopenharmony_ci} 2262306a36Sopenharmony_ci 2362306a36Sopenharmony_cistatic bool udp_sk_has_data(struct sock *sk) 2462306a36Sopenharmony_ci{ 2562306a36Sopenharmony_ci return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 2662306a36Sopenharmony_ci !skb_queue_empty(&sk->sk_receive_queue); 2762306a36Sopenharmony_ci} 2862306a36Sopenharmony_ci 2962306a36Sopenharmony_cistatic bool psock_has_data(struct sk_psock *psock) 3062306a36Sopenharmony_ci{ 3162306a36Sopenharmony_ci return !skb_queue_empty(&psock->ingress_skb) || 3262306a36Sopenharmony_ci !sk_psock_queue_empty(psock); 3362306a36Sopenharmony_ci} 3462306a36Sopenharmony_ci 3562306a36Sopenharmony_ci#define udp_msg_has_data(__sk, __psock) \ 3662306a36Sopenharmony_ci ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 3762306a36Sopenharmony_ci 3862306a36Sopenharmony_cistatic int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 3962306a36Sopenharmony_ci long timeo) 4062306a36Sopenharmony_ci{ 4162306a36Sopenharmony_ci DEFINE_WAIT_FUNC(wait, woken_wake_function); 4262306a36Sopenharmony_ci int ret = 0; 4362306a36Sopenharmony_ci 4462306a36Sopenharmony_ci if (sk->sk_shutdown & RCV_SHUTDOWN) 4562306a36Sopenharmony_ci return 1; 4662306a36Sopenharmony_ci 4762306a36Sopenharmony_ci if (!timeo) 4862306a36Sopenharmony_ci return ret; 4962306a36Sopenharmony_ci 5062306a36Sopenharmony_ci add_wait_queue(sk_sleep(sk), &wait); 5162306a36Sopenharmony_ci sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 5262306a36Sopenharmony_ci ret = udp_msg_has_data(sk, psock); 5362306a36Sopenharmony_ci if (!ret) { 5462306a36Sopenharmony_ci wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 5562306a36Sopenharmony_ci ret = udp_msg_has_data(sk, psock); 5662306a36Sopenharmony_ci } 5762306a36Sopenharmony_ci sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 5862306a36Sopenharmony_ci remove_wait_queue(sk_sleep(sk), &wait); 5962306a36Sopenharmony_ci return ret; 6062306a36Sopenharmony_ci} 6162306a36Sopenharmony_ci 6262306a36Sopenharmony_cistatic int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 6362306a36Sopenharmony_ci int flags, int *addr_len) 6462306a36Sopenharmony_ci{ 6562306a36Sopenharmony_ci struct sk_psock *psock; 6662306a36Sopenharmony_ci int copied, ret; 6762306a36Sopenharmony_ci 6862306a36Sopenharmony_ci if (unlikely(flags & MSG_ERRQUEUE)) 6962306a36Sopenharmony_ci return inet_recv_error(sk, msg, len, addr_len); 7062306a36Sopenharmony_ci 7162306a36Sopenharmony_ci if (!len) 7262306a36Sopenharmony_ci return 0; 7362306a36Sopenharmony_ci 7462306a36Sopenharmony_ci psock = sk_psock_get(sk); 7562306a36Sopenharmony_ci if (unlikely(!psock)) 7662306a36Sopenharmony_ci return sk_udp_recvmsg(sk, msg, len, flags, addr_len); 7762306a36Sopenharmony_ci 7862306a36Sopenharmony_ci if (!psock_has_data(psock)) { 7962306a36Sopenharmony_ci ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 8062306a36Sopenharmony_ci goto out; 8162306a36Sopenharmony_ci } 8262306a36Sopenharmony_ci 8362306a36Sopenharmony_cimsg_bytes_ready: 8462306a36Sopenharmony_ci copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 8562306a36Sopenharmony_ci if (!copied) { 8662306a36Sopenharmony_ci long timeo; 8762306a36Sopenharmony_ci int data; 8862306a36Sopenharmony_ci 8962306a36Sopenharmony_ci timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 9062306a36Sopenharmony_ci data = udp_msg_wait_data(sk, psock, timeo); 9162306a36Sopenharmony_ci if (data) { 9262306a36Sopenharmony_ci if (psock_has_data(psock)) 9362306a36Sopenharmony_ci goto msg_bytes_ready; 9462306a36Sopenharmony_ci ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 9562306a36Sopenharmony_ci goto out; 9662306a36Sopenharmony_ci } 9762306a36Sopenharmony_ci copied = -EAGAIN; 9862306a36Sopenharmony_ci } 9962306a36Sopenharmony_ci ret = copied; 10062306a36Sopenharmony_ciout: 10162306a36Sopenharmony_ci sk_psock_put(sk, psock); 10262306a36Sopenharmony_ci return ret; 10362306a36Sopenharmony_ci} 10462306a36Sopenharmony_ci 10562306a36Sopenharmony_cienum { 10662306a36Sopenharmony_ci UDP_BPF_IPV4, 10762306a36Sopenharmony_ci UDP_BPF_IPV6, 10862306a36Sopenharmony_ci UDP_BPF_NUM_PROTS, 10962306a36Sopenharmony_ci}; 11062306a36Sopenharmony_ci 11162306a36Sopenharmony_cistatic DEFINE_SPINLOCK(udpv6_prot_lock); 11262306a36Sopenharmony_cistatic struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 11362306a36Sopenharmony_ci 11462306a36Sopenharmony_cistatic void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 11562306a36Sopenharmony_ci{ 11662306a36Sopenharmony_ci *prot = *base; 11762306a36Sopenharmony_ci prot->close = sock_map_close; 11862306a36Sopenharmony_ci prot->recvmsg = udp_bpf_recvmsg; 11962306a36Sopenharmony_ci prot->sock_is_readable = sk_msg_is_readable; 12062306a36Sopenharmony_ci} 12162306a36Sopenharmony_ci 12262306a36Sopenharmony_cistatic void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 12362306a36Sopenharmony_ci{ 12462306a36Sopenharmony_ci if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 12562306a36Sopenharmony_ci spin_lock_bh(&udpv6_prot_lock); 12662306a36Sopenharmony_ci if (likely(ops != udpv6_prot_saved)) { 12762306a36Sopenharmony_ci udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 12862306a36Sopenharmony_ci smp_store_release(&udpv6_prot_saved, ops); 12962306a36Sopenharmony_ci } 13062306a36Sopenharmony_ci spin_unlock_bh(&udpv6_prot_lock); 13162306a36Sopenharmony_ci } 13262306a36Sopenharmony_ci} 13362306a36Sopenharmony_ci 13462306a36Sopenharmony_cistatic int __init udp_bpf_v4_build_proto(void) 13562306a36Sopenharmony_ci{ 13662306a36Sopenharmony_ci udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 13762306a36Sopenharmony_ci return 0; 13862306a36Sopenharmony_ci} 13962306a36Sopenharmony_cilate_initcall(udp_bpf_v4_build_proto); 14062306a36Sopenharmony_ci 14162306a36Sopenharmony_ciint udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 14262306a36Sopenharmony_ci{ 14362306a36Sopenharmony_ci int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 14462306a36Sopenharmony_ci 14562306a36Sopenharmony_ci if (restore) { 14662306a36Sopenharmony_ci sk->sk_write_space = psock->saved_write_space; 14762306a36Sopenharmony_ci sock_replace_proto(sk, psock->sk_proto); 14862306a36Sopenharmony_ci return 0; 14962306a36Sopenharmony_ci } 15062306a36Sopenharmony_ci 15162306a36Sopenharmony_ci if (sk->sk_family == AF_INET6) 15262306a36Sopenharmony_ci udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 15362306a36Sopenharmony_ci 15462306a36Sopenharmony_ci sock_replace_proto(sk, &udp_bpf_prots[family]); 15562306a36Sopenharmony_ci return 0; 15662306a36Sopenharmony_ci} 15762306a36Sopenharmony_ciEXPORT_SYMBOL_GPL(udp_bpf_update_proto); 158