162306a36Sopenharmony_ci// SPDX-License-Identifier: GPL-2.0 262306a36Sopenharmony_ci/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */ 362306a36Sopenharmony_ci 462306a36Sopenharmony_ci#include <linux/skmsg.h> 562306a36Sopenharmony_ci#include <linux/bpf.h> 662306a36Sopenharmony_ci#include <net/sock.h> 762306a36Sopenharmony_ci#include <net/af_unix.h> 862306a36Sopenharmony_ci 962306a36Sopenharmony_ci#define unix_sk_has_data(__sk, __psock) \ 1062306a36Sopenharmony_ci ({ !skb_queue_empty(&__sk->sk_receive_queue) || \ 1162306a36Sopenharmony_ci !skb_queue_empty(&__psock->ingress_skb) || \ 1262306a36Sopenharmony_ci !list_empty(&__psock->ingress_msg); \ 1362306a36Sopenharmony_ci }) 1462306a36Sopenharmony_ci 1562306a36Sopenharmony_cistatic int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock, 1662306a36Sopenharmony_ci long timeo) 1762306a36Sopenharmony_ci{ 1862306a36Sopenharmony_ci DEFINE_WAIT_FUNC(wait, woken_wake_function); 1962306a36Sopenharmony_ci struct unix_sock *u = unix_sk(sk); 2062306a36Sopenharmony_ci int ret = 0; 2162306a36Sopenharmony_ci 2262306a36Sopenharmony_ci if (sk->sk_shutdown & RCV_SHUTDOWN) 2362306a36Sopenharmony_ci return 1; 2462306a36Sopenharmony_ci 2562306a36Sopenharmony_ci if (!timeo) 2662306a36Sopenharmony_ci return ret; 2762306a36Sopenharmony_ci 2862306a36Sopenharmony_ci add_wait_queue(sk_sleep(sk), &wait); 2962306a36Sopenharmony_ci sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 3062306a36Sopenharmony_ci if (!unix_sk_has_data(sk, psock)) { 3162306a36Sopenharmony_ci mutex_unlock(&u->iolock); 3262306a36Sopenharmony_ci wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 3362306a36Sopenharmony_ci mutex_lock(&u->iolock); 3462306a36Sopenharmony_ci ret = unix_sk_has_data(sk, psock); 3562306a36Sopenharmony_ci } 3662306a36Sopenharmony_ci sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 3762306a36Sopenharmony_ci remove_wait_queue(sk_sleep(sk), &wait); 3862306a36Sopenharmony_ci return ret; 3962306a36Sopenharmony_ci} 4062306a36Sopenharmony_ci 4162306a36Sopenharmony_cistatic int __unix_recvmsg(struct sock *sk, struct msghdr *msg, 4262306a36Sopenharmony_ci size_t len, int flags) 4362306a36Sopenharmony_ci{ 4462306a36Sopenharmony_ci if (sk->sk_type == SOCK_DGRAM) 4562306a36Sopenharmony_ci return __unix_dgram_recvmsg(sk, msg, len, flags); 4662306a36Sopenharmony_ci else 4762306a36Sopenharmony_ci return __unix_stream_recvmsg(sk, msg, len, flags); 4862306a36Sopenharmony_ci} 4962306a36Sopenharmony_ci 5062306a36Sopenharmony_cistatic int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 5162306a36Sopenharmony_ci size_t len, int flags, int *addr_len) 5262306a36Sopenharmony_ci{ 5362306a36Sopenharmony_ci struct unix_sock *u = unix_sk(sk); 5462306a36Sopenharmony_ci struct sk_psock *psock; 5562306a36Sopenharmony_ci int copied; 5662306a36Sopenharmony_ci 5762306a36Sopenharmony_ci if (!len) 5862306a36Sopenharmony_ci return 0; 5962306a36Sopenharmony_ci 6062306a36Sopenharmony_ci psock = sk_psock_get(sk); 6162306a36Sopenharmony_ci if (unlikely(!psock)) 6262306a36Sopenharmony_ci return __unix_recvmsg(sk, msg, len, flags); 6362306a36Sopenharmony_ci 6462306a36Sopenharmony_ci mutex_lock(&u->iolock); 6562306a36Sopenharmony_ci if (!skb_queue_empty(&sk->sk_receive_queue) && 6662306a36Sopenharmony_ci sk_psock_queue_empty(psock)) { 6762306a36Sopenharmony_ci mutex_unlock(&u->iolock); 6862306a36Sopenharmony_ci sk_psock_put(sk, psock); 6962306a36Sopenharmony_ci return __unix_recvmsg(sk, msg, len, flags); 7062306a36Sopenharmony_ci } 7162306a36Sopenharmony_ci 7262306a36Sopenharmony_cimsg_bytes_ready: 7362306a36Sopenharmony_ci copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 7462306a36Sopenharmony_ci if (!copied) { 7562306a36Sopenharmony_ci long timeo; 7662306a36Sopenharmony_ci int data; 7762306a36Sopenharmony_ci 7862306a36Sopenharmony_ci timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 7962306a36Sopenharmony_ci data = unix_msg_wait_data(sk, psock, timeo); 8062306a36Sopenharmony_ci if (data) { 8162306a36Sopenharmony_ci if (!sk_psock_queue_empty(psock)) 8262306a36Sopenharmony_ci goto msg_bytes_ready; 8362306a36Sopenharmony_ci mutex_unlock(&u->iolock); 8462306a36Sopenharmony_ci sk_psock_put(sk, psock); 8562306a36Sopenharmony_ci return __unix_recvmsg(sk, msg, len, flags); 8662306a36Sopenharmony_ci } 8762306a36Sopenharmony_ci copied = -EAGAIN; 8862306a36Sopenharmony_ci } 8962306a36Sopenharmony_ci mutex_unlock(&u->iolock); 9062306a36Sopenharmony_ci sk_psock_put(sk, psock); 9162306a36Sopenharmony_ci return copied; 9262306a36Sopenharmony_ci} 9362306a36Sopenharmony_ci 9462306a36Sopenharmony_cistatic struct proto *unix_dgram_prot_saved __read_mostly; 9562306a36Sopenharmony_cistatic DEFINE_SPINLOCK(unix_dgram_prot_lock); 9662306a36Sopenharmony_cistatic struct proto unix_dgram_bpf_prot; 9762306a36Sopenharmony_ci 9862306a36Sopenharmony_cistatic struct proto *unix_stream_prot_saved __read_mostly; 9962306a36Sopenharmony_cistatic DEFINE_SPINLOCK(unix_stream_prot_lock); 10062306a36Sopenharmony_cistatic struct proto unix_stream_bpf_prot; 10162306a36Sopenharmony_ci 10262306a36Sopenharmony_cistatic void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 10362306a36Sopenharmony_ci{ 10462306a36Sopenharmony_ci *prot = *base; 10562306a36Sopenharmony_ci prot->close = sock_map_close; 10662306a36Sopenharmony_ci prot->recvmsg = unix_bpf_recvmsg; 10762306a36Sopenharmony_ci prot->sock_is_readable = sk_msg_is_readable; 10862306a36Sopenharmony_ci} 10962306a36Sopenharmony_ci 11062306a36Sopenharmony_cistatic void unix_stream_bpf_rebuild_protos(struct proto *prot, 11162306a36Sopenharmony_ci const struct proto *base) 11262306a36Sopenharmony_ci{ 11362306a36Sopenharmony_ci *prot = *base; 11462306a36Sopenharmony_ci prot->close = sock_map_close; 11562306a36Sopenharmony_ci prot->recvmsg = unix_bpf_recvmsg; 11662306a36Sopenharmony_ci prot->sock_is_readable = sk_msg_is_readable; 11762306a36Sopenharmony_ci prot->unhash = sock_map_unhash; 11862306a36Sopenharmony_ci} 11962306a36Sopenharmony_ci 12062306a36Sopenharmony_cistatic void unix_dgram_bpf_check_needs_rebuild(struct proto *ops) 12162306a36Sopenharmony_ci{ 12262306a36Sopenharmony_ci if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) { 12362306a36Sopenharmony_ci spin_lock_bh(&unix_dgram_prot_lock); 12462306a36Sopenharmony_ci if (likely(ops != unix_dgram_prot_saved)) { 12562306a36Sopenharmony_ci unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops); 12662306a36Sopenharmony_ci smp_store_release(&unix_dgram_prot_saved, ops); 12762306a36Sopenharmony_ci } 12862306a36Sopenharmony_ci spin_unlock_bh(&unix_dgram_prot_lock); 12962306a36Sopenharmony_ci } 13062306a36Sopenharmony_ci} 13162306a36Sopenharmony_ci 13262306a36Sopenharmony_cistatic void unix_stream_bpf_check_needs_rebuild(struct proto *ops) 13362306a36Sopenharmony_ci{ 13462306a36Sopenharmony_ci if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) { 13562306a36Sopenharmony_ci spin_lock_bh(&unix_stream_prot_lock); 13662306a36Sopenharmony_ci if (likely(ops != unix_stream_prot_saved)) { 13762306a36Sopenharmony_ci unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops); 13862306a36Sopenharmony_ci smp_store_release(&unix_stream_prot_saved, ops); 13962306a36Sopenharmony_ci } 14062306a36Sopenharmony_ci spin_unlock_bh(&unix_stream_prot_lock); 14162306a36Sopenharmony_ci } 14262306a36Sopenharmony_ci} 14362306a36Sopenharmony_ci 14462306a36Sopenharmony_ciint unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 14562306a36Sopenharmony_ci{ 14662306a36Sopenharmony_ci if (sk->sk_type != SOCK_DGRAM) 14762306a36Sopenharmony_ci return -EOPNOTSUPP; 14862306a36Sopenharmony_ci 14962306a36Sopenharmony_ci if (restore) { 15062306a36Sopenharmony_ci sk->sk_write_space = psock->saved_write_space; 15162306a36Sopenharmony_ci sock_replace_proto(sk, psock->sk_proto); 15262306a36Sopenharmony_ci return 0; 15362306a36Sopenharmony_ci } 15462306a36Sopenharmony_ci 15562306a36Sopenharmony_ci unix_dgram_bpf_check_needs_rebuild(psock->sk_proto); 15662306a36Sopenharmony_ci sock_replace_proto(sk, &unix_dgram_bpf_prot); 15762306a36Sopenharmony_ci return 0; 15862306a36Sopenharmony_ci} 15962306a36Sopenharmony_ci 16062306a36Sopenharmony_ciint unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 16162306a36Sopenharmony_ci{ 16262306a36Sopenharmony_ci struct sock *sk_pair; 16362306a36Sopenharmony_ci 16462306a36Sopenharmony_ci /* Restore does not decrement the sk_pair reference yet because we must 16562306a36Sopenharmony_ci * keep the a reference to the socket until after an RCU grace period 16662306a36Sopenharmony_ci * and any pending sends have completed. 16762306a36Sopenharmony_ci */ 16862306a36Sopenharmony_ci if (restore) { 16962306a36Sopenharmony_ci sk->sk_write_space = psock->saved_write_space; 17062306a36Sopenharmony_ci sock_replace_proto(sk, psock->sk_proto); 17162306a36Sopenharmony_ci return 0; 17262306a36Sopenharmony_ci } 17362306a36Sopenharmony_ci 17462306a36Sopenharmony_ci /* psock_update_sk_prot can be called multiple times if psock is 17562306a36Sopenharmony_ci * added to multiple maps and/or slots in the same map. There is 17662306a36Sopenharmony_ci * also an edge case where replacing a psock with itself can trigger 17762306a36Sopenharmony_ci * an extra psock_update_sk_prot during the insert process. So it 17862306a36Sopenharmony_ci * must be safe to do multiple calls. Here we need to ensure we don't 17962306a36Sopenharmony_ci * increment the refcnt through sock_hold many times. There will only 18062306a36Sopenharmony_ci * be a single matching destroy operation. 18162306a36Sopenharmony_ci */ 18262306a36Sopenharmony_ci if (!psock->sk_pair) { 18362306a36Sopenharmony_ci sk_pair = unix_peer(sk); 18462306a36Sopenharmony_ci sock_hold(sk_pair); 18562306a36Sopenharmony_ci psock->sk_pair = sk_pair; 18662306a36Sopenharmony_ci } 18762306a36Sopenharmony_ci 18862306a36Sopenharmony_ci unix_stream_bpf_check_needs_rebuild(psock->sk_proto); 18962306a36Sopenharmony_ci sock_replace_proto(sk, &unix_stream_bpf_prot); 19062306a36Sopenharmony_ci return 0; 19162306a36Sopenharmony_ci} 19262306a36Sopenharmony_ci 19362306a36Sopenharmony_civoid __init unix_bpf_build_proto(void) 19462306a36Sopenharmony_ci{ 19562306a36Sopenharmony_ci unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto); 19662306a36Sopenharmony_ci unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto); 19762306a36Sopenharmony_ci 19862306a36Sopenharmony_ci} 199