// SPDX-License-Identifier: GPL-2.0-only
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold below which a received packet may be copied into the tail
 * of the previously queued skb (see virtio_transport_recv_enqueue())
 * instead of being queued on its own.
 */
#define GOOD_COPY_LEN  128

static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *t = vsock_core_get_transport(vsk);

	if (WARN_ON(!t))
		return NULL;

	return container_of(t, struct virtio_transport, transport);
}

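/* Sketch of how a transport typically wires up this common code (for
 * illustration only; see net/vmw_vsock/virtio_transport.c and
 * drivers/vhost/vsock.c for the real users):
 *
 *	static struct virtio_transport my_transport = {
 *		.transport = {
 *			.stream_enqueue	= virtio_transport_stream_enqueue,
 *			.stream_dequeue	= virtio_transport_stream_dequeue,
 *			...
 *		},
 *		.send_pkt = my_send_pkt,
 *	};
 *
 * Packets received from the peer are handed to
 * virtio_transport_recv_pkt(), which consumes or frees the skb.
 */
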
/* Returns a new packet on success, otherwise returns NULL. */
static struct sk_buff *
virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
	struct virtio_vsock_hdr *hdr;
	struct sk_buff *skb;
	void *payload;
	int err;

	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
	if (!skb)
		return NULL;

	hdr = virtio_vsock_hdr(skb);
	hdr->type	= cpu_to_le16(info->type);
	hdr->op		= cpu_to_le16(info->op);
	hdr->src_cid	= cpu_to_le64(src_cid);
	hdr->dst_cid	= cpu_to_le64(dst_cid);
	hdr->src_port	= cpu_to_le32(src_port);
	hdr->dst_port	= cpu_to_le32(dst_port);
	hdr->flags	= cpu_to_le32(info->flags);
	hdr->len	= cpu_to_le32(len);
	hdr->buf_alloc	= cpu_to_le32(0);
	hdr->fwd_cnt	= cpu_to_le32(0);

	if (info->msg && len > 0) {
		payload = skb_put(skb, len);
		err = memcpy_from_msg(payload, info->msg, len);
		if (err)
			goto out;

		if (msg_data_left(info->msg) == 0 &&
		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);

			if (info->msg->msg_flags & MSG_EOR)
				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
		}
	}

	if (info->reply)
		virtio_vsock_skb_set_reply(skb);

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
		WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
		goto out;
	}

	return skb;

out:
	kfree_skb(skb);
	return NULL;
}

/* Packet capture */
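/* Build an af_vsockmon frame (vsockmon header, followed by the virtio
 * vsock header and payload) so that every packet can be mirrored to
 * vsockmon devices for packet capture.
 */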
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_hdr *pkt_hdr;
	struct sk_buff *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	pkt_hdr = virtio_vsock_hdr(pkt);
	payload_len = pkt->len;
	payload_buf = pkt->data;

	skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt_hdr->src_cid;
	hdr->src_port = pkt_hdr->src_port;
	hdr->dst_cid = pkt_hdr->dst_cid;
	hdr->dst_port = pkt_hdr->dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	switch (le16_to_cpu(pkt_hdr->op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));

	if (payload_len)
		skb_put_data(skb, payload_buf, payload_len);

	return skb;
}

void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
{
	if (virtio_vsock_skb_tap_delivered(skb))
		return;

	vsock_deliver_tap(virtio_transport_build_skb, skb);
	virtio_vsock_skb_set_tap_delivered(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

static u16 virtio_transport_get_type(struct sock *sk)
{
	if (sk->sk_type == SOCK_STREAM)
		return VIRTIO_VSOCK_TYPE_STREAM;
	else
		return VIRTIO_VSOCK_TYPE_SEQPACKET;
}

/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 */
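/* The payload is sent as a series of packets of at most
 * VIRTIO_VSOCK_MAX_PKT_BUF_SIZE bytes each, and only up to the credit
 * currently granted by the peer; the caller may therefore see a short
 * write.
 */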
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	u32 pkt_len = info->pkt_len;
	u32 rest_len;
	int ret;

	info->type = virtio_transport_get_type(sk_vsock(vsk));

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid	= vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	rest_len = pkt_len;

	do {
		struct sk_buff *skb;
		size_t skb_len;

		skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);

		skb = virtio_transport_alloc_skb(info, skb_len,
						 src_cid, src_port,
						 dst_cid, dst_port);
		if (!skb) {
			ret = -ENOMEM;
			break;
		}

		virtio_transport_inc_tx_pkt(vvs, skb);

		ret = t_ops->send_pkt(skb);
		if (ret < 0)
			break;

		/* Both virtio and vhost 'send_pkt()' return 'skb_len' on
		 * success, but for reliability use 'ret' instead of
		 * 'skb_len'. If a partial send ('ret' != 'skb_len') somehow
		 * happens, we break out of the loop but still account the
		 * returned value in 'virtio_transport_put_credit()'.
		 */
		rest_len -= ret;

		if (WARN_ONCE(ret != skb_len,
			      "'send_pkt()' returns %i, but %zu expected\n",
			      ret, skb_len))
			break;
	} while (rest_len);

	virtio_transport_put_credit(vvs, rest_len);

	/* Return number of bytes, if any data has been sent. */
	if (rest_len != pkt_len)
		ret = pkt_len - rest_len;

	return ret;
}

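/* RX credit bookkeeping: 'rx_bytes' is the payload currently sitting in
 * 'rx_queue' and may never exceed 'buf_alloc'; 'fwd_cnt' counts bytes
 * already delivered to the application and is advertised back to the
 * peer in every transmitted header so it can replenish its credit.
 */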
static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	if (vvs->rx_bytes + len > vvs->buf_alloc)
		return false;

	vvs->rx_bytes += len;
	return true;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	vvs->rx_bytes -= len;
	vvs->fwd_cnt += len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);

	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

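/* The credit available for transmission at any instant is
 *
 *	peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 *
 * i.e. the peer's receive buffer minus the bytes in flight (sent by us
 * but not yet consumed by the peer's application).
 */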
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	if (!credit)
		return 0;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	if (!credit)
		return;

	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);

	skb_queue_walk(&vvs->rx_queue, skb) {
		size_t bytes;

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		spin_unlock_bh(&vvs->rx_lock);

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		total += bytes;

		spin_lock_bh(&vvs->rx_lock);

		if (total == len)
			break;
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	if (total)
		err = total;
	return err;
}

static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	size_t bytes, total = 0;
	struct sk_buff *skb;
	u32 fwd_cnt_delta;
	bool low_rx_bytes;
	int err = -EFAULT;
	u32 free_space;

	spin_lock_bh(&vvs->rx_lock);

	if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
		      "rx_queue is empty, but rx_bytes is non-zero\n")) {
		spin_unlock_bh(&vvs->rx_lock);
		return err;
	}

	while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
		skb = skb_peek(&vvs->rx_queue);

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		skb_pull(skb, bytes);

		if (skb->len == 0) {
			u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len);

			virtio_transport_dec_rx_pkt(vvs, pkt_len);
			__skb_unlink(skb, &vvs->rx_queue);
			consume_skb(skb);
		}
	}

	fwd_cnt_delta = vvs->fwd_cnt - vvs->last_fwd_cnt;
	free_space = vvs->buf_alloc - fwd_cnt_delta;
	low_rx_bytes = (vvs->rx_bytes <
			sock_rcvlowat(sk_vsock(vsk), 0, INT_MAX));

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values. Also send credit update message when
	 * number of bytes in rx queue is not enough to wake up reader.
	 */
	if (fwd_cnt_delta &&
	    (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE || low_rx_bytes))
		virtio_transport_send_credit_update(vsk);

	return total;

out:
	if (total)
		err = total;
	return err;
}

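/* A SEQPACKET message may span several packets: the last one carries
 * VIRTIO_VSOCK_SEQ_EOM (plus VIRTIO_VSOCK_SEQ_EOR when the sender used
 * MSG_EOR), and the receive paths below rely on those flags to find
 * message boundaries.
 */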
static ssize_t
virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
				   struct msghdr *msg)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total, len;

	spin_lock_bh(&vvs->rx_lock);

	if (!vvs->msg_count) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	total = 0;
	len = msg_data_left(msg);

	skb_queue_walk(&vvs->rx_queue, skb) {
		struct virtio_vsock_hdr *hdr;

		if (total < len) {
			size_t bytes;
			int err;

			bytes = len - total;
			if (bytes > skb->len)
				bytes = skb->len;

			spin_unlock_bh(&vvs->rx_lock);

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			err = memcpy_to_msg(msg, skb->data, bytes);
			if (err)
				return err;

			spin_lock_bh(&vvs->rx_lock);
		}

		total += skb->len;
		hdr = virtio_vsock_hdr(skb);

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;

			break;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;
}

static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
						 struct msghdr *msg,
						 int flags)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	int dequeued_len = 0;
	size_t user_buf_len = msg_data_left(msg);
	bool msg_ready = false;
	struct sk_buff *skb;

	spin_lock_bh(&vvs->rx_lock);

	if (vvs->msg_count == 0) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	while (!msg_ready) {
		struct virtio_vsock_hdr *hdr;
		size_t pkt_len;

		skb = __skb_dequeue(&vvs->rx_queue);
		if (!skb)
			break;
		hdr = virtio_vsock_hdr(skb);
		pkt_len = (size_t)le32_to_cpu(hdr->len);

		if (dequeued_len >= 0) {
			size_t bytes_to_copy;

			bytes_to_copy = min(user_buf_len, pkt_len);

			if (bytes_to_copy) {
				int err;

				/* sk_lock is held by caller so no one else can dequeue.
				 * Unlock rx_lock since memcpy_to_msg() may sleep.
				 */
				spin_unlock_bh(&vvs->rx_lock);

				err = memcpy_to_msg(msg, skb->data, bytes_to_copy);
				if (err) {
					/* Copy of message failed. Rest of
					 * fragments will be freed without copy.
					 */
					dequeued_len = err;
				} else {
					user_buf_len -= bytes_to_copy;
				}

				spin_lock_bh(&vvs->rx_lock);
			}

			if (dequeued_len >= 0)
				dequeued_len += pkt_len;
		}

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			msg_ready = true;
			vvs->msg_count--;

			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;
		}

		virtio_transport_dec_rx_pkt(vvs, pkt_len);
		kfree_skb(skb);
	}

	spin_unlock_bh(&vvs->rx_lock);

	virtio_transport_send_credit_update(vsk);

	return dequeued_len;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);
	else
		return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_seqpacket_do_peek(vsk, msg);
	else
		return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);

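/* A message bigger than the peer's whole receive buffer could never be
 * sent in full, so reject it up front instead of blocking forever
 * waiting for credit that cannot arrive.
 */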
int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	spin_lock_bh(&vvs->tx_lock);

	if (len > vvs->peer_buf_alloc) {
		spin_unlock_bh(&vvs->tx_lock);
		return -EMSGSIZE;
	}

	spin_unlock_bh(&vvs->tx_lock);

	return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	u32 msg_count;

	spin_lock_bh(&vvs->rx_lock);
	msg_count = vvs->msg_count;
	spin_unlock_bh(&vvs->rx_lock);

	return msg_count;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = (s64)vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk && psk->trans) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	skb_queue_head_init(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* sk_lock held by the caller */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	*data_ready_now = vsock_stream_has_data(vsk) >= target;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct sk_buff *skb)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.reply = !!skb,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(hdr->type),
		.reply = true,
	};
	struct sk_buff *reply;

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	if (!t)
		return -ENOTCONN;

	reply = virtio_transport_alloc_skb(&info, 0,
					   le64_to_cpu(hdr->dst_cid),
					   le32_to_cpu(hdr->dst_port),
					   le64_to_cpu(hdr->src_cid),
					   le32_to_cpu(hdr->src_port));
	if (!reply)
		return -ENOMEM;

	return t->send_pkt(reply);
}

/* This function should be called with sk_lock held and SOCK_DONE set */
static void virtio_transport_remove_sock(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	/* We don't need to take rx_lock, as the socket is closing and we are
	 * removing it.
	 */
	__skb_queue_purge(&vvs->rx_queue);
	vsock_remove_sock(vsk);
}

static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		virtio_transport_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

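/* Close path: send a SHUTDOWN, optionally linger for the peer's RST,
 * and if the peer has not answered by the time we give up, arm a
 * delayed work that forces a reset after VSOCK_CLOSE_TIMEOUT so the
 * socket cannot linger forever.
 */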
/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
		remove_sock = virtio_transport_close(vsk);

	if (remove_sock) {
		sock_set_flag(sk, SOCK_DONE);
		virtio_transport_remove_sock(vsk);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int skerr;
	int err;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, skb);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk_error_report(sk);
	return err;
}

static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct sk_buff *skb)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;
	struct virtio_vsock_hdr *hdr;
	u32 len;

	hdr = virtio_vsock_hdr(skb);
	len = le32_to_cpu(hdr->len);

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
	if (!can_enqueue) {
		free_pkt = true;
		goto out;
	}

	if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
		vvs->msg_count++;

	/* Try to copy small packets into the buffer of last packet queued,
	 * to avoid wasting memory queueing the entire buffer with a small
	 * payload.
	 */
	if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) {
		struct virtio_vsock_hdr *last_hdr;
		struct sk_buff *last_skb;

		last_skb = skb_peek_tail(&vvs->rx_queue);
		last_hdr = virtio_vsock_hdr(last_skb);

		/* If there is space in the last packet queued, we copy the
		 * new packet in its buffer. We avoid this if the last packet
		 * queued has VIRTIO_VSOCK_SEQ_EOM set, because that flag
		 * delimits a SEQPACKET message, so this skb is the first
		 * packet of a new message.
		 */
		if (skb->len < skb_tailroom(last_skb) &&
		    !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
			memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
			free_pkt = true;
			last_hdr->flags |= hdr->flags;
			le32_add_cpu(&last_hdr->len, len);
			goto out;
		}
	}

	__skb_queue_tail(&vvs->rx_queue, skb);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		kfree_skb(skb);
}

static int
virtio_transport_recv_connected(struct sock *sk,
				struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RW:
		virtio_transport_recv_enqueue(vsk, skb);
		vsock_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		virtio_transport_send_credit_update(vsk);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK) {
			if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
				(void)virtio_transport_reset(vsk, NULL);
				virtio_transport_do_close(vsk, true);
			}
			/* Remove this socket anyway because the remote peer
			 * sent the shutdown. This way a new connection will
			 * succeed if the remote peer reuses the same source
			 * port, even if the old socket is still unreleased
			 * but now disconnected.
			 */
			vsock_remove_sock(vsk);
		}
		if (le32_to_cpu(hdr->flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	kfree_skb(skb);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.remote_cid = le64_to_cpu(hdr->src_cid),
		.remote_port = le32_to_cpu(hdr->src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in the
	 * remote peer, but since they are only used to receive requests, we
	 * can assume that there is always space available in the other peer.
	 */
	if (!vvs)
		return true;

	/* buf_alloc and fwd_cnt are always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
			     struct virtio_transport *t)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, skb);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	sk_acceptq_added(sk);

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* Transport assigned (looking at remote_addr) must be the same
	 * where we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, skb);
		sock_put(child);
		return ret;
	}

	if (virtio_transport_space_update(child, skb))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, skb);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

static bool virtio_transport_valid_type(u16 type)
{
	return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
	       (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));
	vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(hdr->len),
					le16_to_cpu(hdr->type),
					le16_to_cpu(hdr->op),
					le32_to_cpu(hdr->flags),
					le32_to_cpu(hdr->buf_alloc),
					le32_to_cpu(hdr->fwd_cnt));

	if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
		(void)virtio_transport_reset_no_sock(t, skb);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, skb);
			goto free_pkt;
		}
	}

	if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		sock_put(sk);
		goto free_pkt;
	}

	if (!skb_set_owner_sk_safe(skb, sk)) {
		WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
		goto free_pkt;
	}

	vsk = vsock_sk(sk);

	lock_sock(sk);

	/* Check if sk has been closed before lock_sock */
	if (sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		release_sock(sk);
		sock_put(sk);
		goto free_pkt;
	}

	space_available = virtio_transport_space_update(sk, skb);

	/* Update CID in case it has changed after a transport reset event */
	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
		vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, skb, t);
		kfree_skb(skb);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, skb);
		kfree_skb(skb);
		break;
	case TCP_ESTABLISHED:
		virtio_transport_recv_connected(sk, skb);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, skb);
		kfree_skb(skb);
		break;
	default:
		(void)virtio_transport_reset_no_sock(t, skb);
		kfree_skb(skb);
		break;
	}

	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	kfree_skb(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

/* Remove skbs found in a queue that have a vsk that matches.
 *
 * Each skb is freed.
 *
 * Returns the count of skbs that were reply packets.
 */
int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
{
	struct sk_buff_head freeme;
	struct sk_buff *skb, *tmp;
	int cnt = 0;

	skb_queue_head_init(&freeme);

	spin_lock_bh(&queue->lock);
	skb_queue_walk_safe(queue, skb, tmp) {
		if (vsock_sk(skb->sk) != vsk)
			continue;

		__skb_unlink(skb, queue);
		__skb_queue_tail(&freeme, skb);

		if (virtio_vsock_skb_reply(skb))
			cnt++;
	}
	spin_unlock_bh(&queue->lock);

	__skb_queue_purge(&freeme);

	return cnt;
}
EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);

int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sock *sk = sk_vsock(vsk);
	struct sk_buff *skb;
	int off = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);
	/* Use __skb_recv_datagram() for race-free handling of the receive. It
	 * works for types other than dgrams.
	 */
	skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
	spin_unlock_bh(&vvs->rx_lock);

	if (!skb)
		return err;

	return recv_actor(sk, skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_read_skb);

int virtio_transport_notify_set_rcvlowat(struct vsock_sock *vsk, int val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool send_update;

	spin_lock_bh(&vvs->rx_lock);

	/* If the number of available bytes is less than the new SO_RCVLOWAT
	 * value, kick the sender to send more data, because the sender may
	 * sleep in its 'send()' syscall waiting for enough space at our side.
	 * Also don't send a credit update when the peer already knows the
	 * actual value - such a transmission would be useless.
	 */
	send_update = (vvs->rx_bytes < val) &&
		      (vvs->fwd_cnt != vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	if (send_update) {
		int err;

		err = virtio_transport_send_credit_update(vsk);
		if (err < 0)
			return err;
	}

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_set_rcvlowat);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");