xref: /kernel/linux/linux-6.6/drivers/vhost/vdpa.c (revision 62306a36)
1// SPDX-License-Identifier: GPL-2.0
2/*
3 * Copyright (C) 2018-2020 Intel Corporation.
4 * Copyright (C) 2020 Red Hat, Inc.
5 *
6 * Author: Tiwei Bie <tiwei.bie@intel.com>
7 *         Jason Wang <jasowang@redhat.com>
8 *
9 * Thanks to Michael S. Tsirkin for the valuable comments and
10 * suggestions, and thanks to Cunming Liang and Zhihong Wang for all
11 * their support.
12 */
13
14#include <linux/kernel.h>
15#include <linux/module.h>
16#include <linux/cdev.h>
17#include <linux/device.h>
18#include <linux/mm.h>
19#include <linux/slab.h>
20#include <linux/iommu.h>
21#include <linux/uuid.h>
22#include <linux/vdpa.h>
23#include <linux/nospec.h>
24#include <linux/vhost.h>
25
26#include "vhost.h"
27
28enum {
29	VHOST_VDPA_BACKEND_FEATURES =
30	(1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
32	(1ULL << VHOST_BACKEND_F_IOTLB_ASID),
33};
34
35#define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
36
37#define VHOST_VDPA_IOTLB_BUCKETS 16
38
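/*
 * Each address space ID (ASID) used by the backend gets its own
 * vhost_iotlb, wrapped in a vhost_vdpa_as and kept in a small hash
 * table (vhost_vdpa->as[]) keyed by asid % VHOST_VDPA_IOTLB_BUCKETS.
 */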
39struct vhost_vdpa_as {
40	struct hlist_node hash_link;
41	struct vhost_iotlb iotlb;
42	u32 id;
43};
44
45struct vhost_vdpa {
46	struct vhost_dev vdev;
47	struct iommu_domain *domain;
48	struct vhost_virtqueue *vqs;
49	struct completion completion;
50	struct vdpa_device *vdpa;
51	struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
52	struct device dev;
53	struct cdev cdev;
54	atomic_t opened;
55	u32 nvqs;
56	int virtio_id;
57	int minor;
58	struct eventfd_ctx *config_ctx;
59	int in_batch;
60	struct vdpa_iova_range range;
61	u32 batch_asid;
62};
63
64static DEFINE_IDA(vhost_vdpa_ida);
65
66static dev_t vhost_vdpa_major;
67
68static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
69				   struct vhost_iotlb *iotlb, u64 start,
70				   u64 last, u32 asid);
71
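/*
 * Every IOTLB handled below is embedded in a vhost_vdpa_as, so the
 * owning ASID can be recovered from the IOTLB pointer itself via
 * container_of().
 */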
72static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
73{
74	struct vhost_vdpa_as *as = container_of(iotlb, struct
75						vhost_vdpa_as, iotlb);
76	return as->id;
77}
78
79static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
80{
81	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
82	struct vhost_vdpa_as *as;
83
84	hlist_for_each_entry(as, head, hash_link)
85		if (as->id == asid)
86			return as;
87
88	return NULL;
89}
90
91static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
92{
93	struct vhost_vdpa_as *as = asid_to_as(v, asid);
94
95	if (!as)
96		return NULL;
97
98	return &as->iotlb;
99}
100
101static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
102{
103	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
104	struct vhost_vdpa_as *as;
105
106	if (asid_to_as(v, asid))
107		return NULL;
108
109	if (asid >= v->vdpa->nas)
110		return NULL;
111
112	as = kmalloc(sizeof(*as), GFP_KERNEL);
113	if (!as)
114		return NULL;
115
116	vhost_iotlb_init(&as->iotlb, 0, 0);
117	as->id = asid;
118	hlist_add_head(&as->hash_link, head);
119
120	return as;
121}
122
123static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
124						      u32 asid)
125{
126	struct vhost_vdpa_as *as = asid_to_as(v, asid);
127
128	if (as)
129		return as;
130
131	return vhost_vdpa_alloc_as(v, asid);
132}
133
134static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
135{
136	struct vhost_vdpa_as *as = asid_to_as(v, asid);
137
138	if (!as)
139		return -EINVAL;
140
141	hlist_del(&as->hash_link);
142	vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
143	kfree(as);
144
145	return 0;
146}
147
148static void handle_vq_kick(struct vhost_work *work)
149{
150	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
151						  poll.work);
152	struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
153	const struct vdpa_config_ops *ops = v->vdpa->config;
154
155	ops->kick_vq(v->vdpa, vq - v->vqs);
156}
157
158static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
159{
160	struct vhost_virtqueue *vq = private;
161	struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
162
163	if (call_ctx)
164		eventfd_signal(call_ctx, 1);
165
166	return IRQ_HANDLED;
167}
168
169static irqreturn_t vhost_vdpa_config_cb(void *private)
170{
171	struct vhost_vdpa *v = private;
172	struct eventfd_ctx *config_ctx = v->config_ctx;
173
174	if (config_ctx)
175		eventfd_signal(config_ctx, 1);
176
177	return IRQ_HANDLED;
178}
179
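/*
 * If the parent device exposes a dedicated interrupt for this
 * virtqueue, register an irq bypass producer keyed by the call eventfd
 * so a consumer (e.g. KVM's irqfd) can deliver the interrupt to the
 * guest directly instead of going through eventfd signalling.
 */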
180static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
181{
182	struct vhost_virtqueue *vq = &v->vqs[qid];
183	const struct vdpa_config_ops *ops = v->vdpa->config;
184	struct vdpa_device *vdpa = v->vdpa;
185	int ret, irq;
186
187	if (!ops->get_vq_irq)
188		return;
189
190	irq = ops->get_vq_irq(vdpa, qid);
191	if (irq < 0)
192		return;
193
194	irq_bypass_unregister_producer(&vq->call_ctx.producer);
195	if (!vq->call_ctx.ctx)
196		return;
197
198	vq->call_ctx.producer.token = vq->call_ctx.ctx;
199	vq->call_ctx.producer.irq = irq;
200	ret = irq_bypass_register_producer(&vq->call_ctx.producer);
201	if (unlikely(ret))
202		dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration failed, ret = %d\n",
203			 qid, vq->call_ctx.producer.token, ret);
204}
205
206static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
207{
208	struct vhost_virtqueue *vq = &v->vqs[qid];
209
210	irq_bypass_unregister_producer(&vq->call_ctx.producer);
211}
212
213static int vhost_vdpa_reset(struct vhost_vdpa *v)
214{
215	struct vdpa_device *vdpa = v->vdpa;
216
217	v->in_batch = 0;
218
219	return vdpa_reset(vdpa);
220}
221
222static long vhost_vdpa_bind_mm(struct vhost_vdpa *v)
223{
224	struct vdpa_device *vdpa = v->vdpa;
225	const struct vdpa_config_ops *ops = vdpa->config;
226
227	if (!vdpa->use_va || !ops->bind_mm)
228		return 0;
229
230	return ops->bind_mm(vdpa, v->vdev.mm);
231}
232
233static void vhost_vdpa_unbind_mm(struct vhost_vdpa *v)
234{
235	struct vdpa_device *vdpa = v->vdpa;
236	const struct vdpa_config_ops *ops = vdpa->config;
237
238	if (!vdpa->use_va || !ops->unbind_mm)
239		return;
240
241	ops->unbind_mm(vdpa);
242}
243
244static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
245{
246	struct vdpa_device *vdpa = v->vdpa;
247	const struct vdpa_config_ops *ops = vdpa->config;
248	u32 device_id;
249
250	device_id = ops->get_device_id(vdpa);
251
252	if (copy_to_user(argp, &device_id, sizeof(device_id)))
253		return -EFAULT;
254
255	return 0;
256}
257
258static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
259{
260	struct vdpa_device *vdpa = v->vdpa;
261	const struct vdpa_config_ops *ops = vdpa->config;
262	u8 status;
263
264	status = ops->get_status(vdpa);
265
266	if (copy_to_user(statusp, &status, sizeof(status)))
267		return -EFAULT;
268
269	return 0;
270}
271
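/*
 * Note the ordering around DRIVER_OK below: vq irq bypass producers
 * are unregistered before a status write that drops DRIVER_OK reaches
 * the device, and are set up again only after DRIVER_OK has been set.
 */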
272static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
273{
274	struct vdpa_device *vdpa = v->vdpa;
275	const struct vdpa_config_ops *ops = vdpa->config;
276	u8 status, status_old;
277	u32 nvqs = v->nvqs;
278	int ret;
279	u16 i;
280
281	if (copy_from_user(&status, statusp, sizeof(status)))
282		return -EFAULT;
283
284	status_old = ops->get_status(vdpa);
285
286	/*
287	 * Userspace shouldn't remove status bits unless it resets the
288	 * status to 0.
289	 */
290	if (status != 0 && (status_old & ~status) != 0)
291		return -EINVAL;
292
293	if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
294		for (i = 0; i < nvqs; i++)
295			vhost_vdpa_unsetup_vq_irq(v, i);
296
297	if (status == 0) {
298		ret = vdpa_reset(vdpa);
299		if (ret)
300			return ret;
301	} else
302		vdpa_set_status(vdpa, status);
303
304	if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
305		for (i = 0; i < nvqs; i++)
306			vhost_vdpa_setup_vq_irq(v, i);
307
308	return 0;
309}
310
311static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
312				      struct vhost_vdpa_config *c)
313{
314	struct vdpa_device *vdpa = v->vdpa;
315	size_t size = vdpa->config->get_config_size(vdpa);
316
317	if (c->len == 0 || c->off > size)
318		return -EINVAL;
319
320	if (c->len > size - c->off)
321		return -E2BIG;
322
323	return 0;
324}
325
326static long vhost_vdpa_get_config(struct vhost_vdpa *v,
327				  struct vhost_vdpa_config __user *c)
328{
329	struct vdpa_device *vdpa = v->vdpa;
330	struct vhost_vdpa_config config;
331	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
332	u8 *buf;
333
334	if (copy_from_user(&config, c, size))
335		return -EFAULT;
336	if (vhost_vdpa_config_validate(v, &config))
337		return -EINVAL;
338	buf = kvzalloc(config.len, GFP_KERNEL);
339	if (!buf)
340		return -ENOMEM;
341
342	vdpa_get_config(vdpa, config.off, buf, config.len);
343
344	if (copy_to_user(c->buf, buf, config.len)) {
345		kvfree(buf);
346		return -EFAULT;
347	}
348
349	kvfree(buf);
350	return 0;
351}
352
353static long vhost_vdpa_set_config(struct vhost_vdpa *v,
354				  struct vhost_vdpa_config __user *c)
355{
356	struct vdpa_device *vdpa = v->vdpa;
357	struct vhost_vdpa_config config;
358	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
359	u8 *buf;
360
361	if (copy_from_user(&config, c, size))
362		return -EFAULT;
363	if (vhost_vdpa_config_validate(v, &config))
364		return -EINVAL;
365
366	buf = vmemdup_user(c->buf, config.len);
367	if (IS_ERR(buf))
368		return PTR_ERR(buf);
369
370	vdpa_set_config(vdpa, config.off, buf, config.len);
371
372	kvfree(buf);
373	return 0;
374}
375
376static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
377{
378	struct vdpa_device *vdpa = v->vdpa;
379	const struct vdpa_config_ops *ops = vdpa->config;
380
381	return ops->suspend;
382}
383
384static bool vhost_vdpa_can_resume(const struct vhost_vdpa *v)
385{
386	struct vdpa_device *vdpa = v->vdpa;
387	const struct vdpa_config_ops *ops = vdpa->config;
388
389	return ops->resume;
390}
391
392static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
393{
394	struct vdpa_device *vdpa = v->vdpa;
395	const struct vdpa_config_ops *ops = vdpa->config;
396	u64 features;
397
398	features = ops->get_device_features(vdpa);
399
400	if (copy_to_user(featurep, &features, sizeof(features)))
401		return -EFAULT;
402
403	return 0;
404}
405
406static u64 vhost_vdpa_get_backend_features(const struct vhost_vdpa *v)
407{
408	struct vdpa_device *vdpa = v->vdpa;
409	const struct vdpa_config_ops *ops = vdpa->config;
410
411	if (!ops->get_backend_features)
412		return 0;
413	else
414		return ops->get_backend_features(vdpa);
415}
416
417static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
418{
419	struct vdpa_device *vdpa = v->vdpa;
420	const struct vdpa_config_ops *ops = vdpa->config;
421	struct vhost_dev *d = &v->vdev;
422	u64 actual_features;
423	u64 features;
424	int i;
425
426	/*
427	 * It's not allowed to change the features after they have
428	 * been negotiated.
429	 */
430	if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
431		return -EBUSY;
432
433	if (copy_from_user(&features, featurep, sizeof(features)))
434		return -EFAULT;
435
436	if (vdpa_set_features(vdpa, features))
437		return -EINVAL;
438
439	/* let the vqs know what has been configured */
440	actual_features = ops->get_driver_features(vdpa);
441	for (i = 0; i < d->nvqs; ++i) {
442		struct vhost_virtqueue *vq = d->vqs[i];
443
444		mutex_lock(&vq->mutex);
445		vq->acked_features = actual_features;
446		mutex_unlock(&vq->mutex);
447	}
448
449	return 0;
450}
451
452static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
453{
454	struct vdpa_device *vdpa = v->vdpa;
455	const struct vdpa_config_ops *ops = vdpa->config;
456	u16 num;
457
458	num = ops->get_vq_num_max(vdpa);
459
460	if (copy_to_user(argp, &num, sizeof(num)))
461		return -EFAULT;
462
463	return 0;
464}
465
466static void vhost_vdpa_config_put(struct vhost_vdpa *v)
467{
468	if (v->config_ctx) {
469		eventfd_ctx_put(v->config_ctx);
470		v->config_ctx = NULL;
471	}
472}
473
474static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
475{
476	struct vdpa_callback cb;
477	int fd;
478	struct eventfd_ctx *ctx;
479
480	cb.callback = vhost_vdpa_config_cb;
481	cb.private = v;
482	if (copy_from_user(&fd, argp, sizeof(fd)))
483		return -EFAULT;
484
485	ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
486	swap(ctx, v->config_ctx);
487
488	if (!IS_ERR_OR_NULL(ctx))
489		eventfd_ctx_put(ctx);
490
491	if (IS_ERR(v->config_ctx)) {
492		long ret = PTR_ERR(v->config_ctx);
493
494		v->config_ctx = NULL;
495		return ret;
496	}
497
498	v->vdpa->config->set_config_cb(v->vdpa, &cb);
499
500	return 0;
501}
502
503static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
504{
505	struct vhost_vdpa_iova_range range = {
506		.first = v->range.first,
507		.last = v->range.last,
508	};
509
510	if (copy_to_user(argp, &range, sizeof(range)))
511		return -EFAULT;
512	return 0;
513}
514
515static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
516{
517	struct vdpa_device *vdpa = v->vdpa;
518	const struct vdpa_config_ops *ops = vdpa->config;
519	u32 size;
520
521	size = ops->get_config_size(vdpa);
522
523	if (copy_to_user(argp, &size, sizeof(size)))
524		return -EFAULT;
525
526	return 0;
527}
528
529static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
530{
531	struct vdpa_device *vdpa = v->vdpa;
532
533	if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
534		return -EFAULT;
535
536	return 0;
537}
538
539/* After a successful return of this ioctl the device must not process more
540 * virtqueue descriptors. The device can answer reads or writes of config
541 * fields as if it were not suspended. In particular, writing to "queue_enable"
542 * with a value of 1 will not make the device start processing buffers.
543 */
544static long vhost_vdpa_suspend(struct vhost_vdpa *v)
545{
546	struct vdpa_device *vdpa = v->vdpa;
547	const struct vdpa_config_ops *ops = vdpa->config;
548
549	if (!ops->suspend)
550		return -EOPNOTSUPP;
551
552	return ops->suspend(vdpa);
553}
554
555/* After a successful return of this ioctl the device resumes processing
556 * virtqueue descriptors. The device becomes fully operational the same way it
557 * was before it was suspended.
558 */
559static long vhost_vdpa_resume(struct vhost_vdpa *v)
560{
561	struct vdpa_device *vdpa = v->vdpa;
562	const struct vdpa_config_ops *ops = vdpa->config;
563
564	if (!ops->resume)
565		return -EOPNOTSUPP;
566
567	return ops->resume(vdpa);
568}
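
/*
 * Illustrative userspace sketch (not part of the driver): with 'fd' an
 * open vhost-vdpa character device and VHOST_BACKEND_F_SUSPEND /
 * VHOST_BACKEND_F_RESUME advertised by VHOST_GET_BACKEND_FEATURES, a
 * VMM could do roughly:
 *
 *	ioctl(fd, VHOST_VDPA_SUSPEND);	// stop descriptor processing
 *	...				// e.g. read back vring state
 *	ioctl(fd, VHOST_VDPA_RESUME);	// resume processing
 */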
569
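/*
 * For packed virtqueues, the 16-bit index exchanged with userspace via
 * VHOST_GET_VRING_BASE / VHOST_SET_VRING_BASE carries the wrap counter
 * in bit 15 and the ring index in bits 0-14; the packing and unpacking
 * below implement that encoding.
 */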
570static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
571				   void __user *argp)
572{
573	struct vdpa_device *vdpa = v->vdpa;
574	const struct vdpa_config_ops *ops = vdpa->config;
575	struct vdpa_vq_state vq_state;
576	struct vdpa_callback cb;
577	struct vhost_virtqueue *vq;
578	struct vhost_vring_state s;
579	u32 idx;
580	long r;
581
582	r = get_user(idx, (u32 __user *)argp);
583	if (r < 0)
584		return r;
585
586	if (idx >= v->nvqs)
587		return -ENOBUFS;
588
589	idx = array_index_nospec(idx, v->nvqs);
590	vq = &v->vqs[idx];
591
592	switch (cmd) {
593	case VHOST_VDPA_SET_VRING_ENABLE:
594		if (copy_from_user(&s, argp, sizeof(s)))
595			return -EFAULT;
596		ops->set_vq_ready(vdpa, idx, s.num);
597		return 0;
598	case VHOST_VDPA_GET_VRING_GROUP:
599		if (!ops->get_vq_group)
600			return -EOPNOTSUPP;
601		s.index = idx;
602		s.num = ops->get_vq_group(vdpa, idx);
603		if (s.num >= vdpa->ngroups)
604			return -EIO;
605		else if (copy_to_user(argp, &s, sizeof(s)))
606			return -EFAULT;
607		return 0;
608	case VHOST_VDPA_SET_GROUP_ASID:
609		if (copy_from_user(&s, argp, sizeof(s)))
610			return -EFAULT;
611		if (s.num >= vdpa->nas)
612			return -EINVAL;
613		if (!ops->set_group_asid)
614			return -EOPNOTSUPP;
615		return ops->set_group_asid(vdpa, idx, s.num);
616	case VHOST_GET_VRING_BASE:
617		r = ops->get_vq_state(v->vdpa, idx, &vq_state);
618		if (r)
619			return r;
620
621		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
622			vq->last_avail_idx = vq_state.packed.last_avail_idx |
623					     (vq_state.packed.last_avail_counter << 15);
624			vq->last_used_idx = vq_state.packed.last_used_idx |
625					    (vq_state.packed.last_used_counter << 15);
626		} else {
627			vq->last_avail_idx = vq_state.split.avail_index;
628		}
629		break;
630	}
631
632	r = vhost_vring_ioctl(&v->vdev, cmd, argp);
633	if (r)
634		return r;
635
636	switch (cmd) {
637	case VHOST_SET_VRING_ADDR:
638		if (ops->set_vq_address(vdpa, idx,
639					(u64)(uintptr_t)vq->desc,
640					(u64)(uintptr_t)vq->avail,
641					(u64)(uintptr_t)vq->used))
642			r = -EINVAL;
643		break;
644
645	case VHOST_SET_VRING_BASE:
646		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
647			vq_state.packed.last_avail_idx = vq->last_avail_idx & 0x7fff;
648			vq_state.packed.last_avail_counter = !!(vq->last_avail_idx & 0x8000);
649			vq_state.packed.last_used_idx = vq->last_used_idx & 0x7fff;
650			vq_state.packed.last_used_counter = !!(vq->last_used_idx & 0x8000);
651		} else {
652			vq_state.split.avail_index = vq->last_avail_idx;
653		}
654		r = ops->set_vq_state(vdpa, idx, &vq_state);
655		break;
656
657	case VHOST_SET_VRING_CALL:
658		if (vq->call_ctx.ctx) {
659			cb.callback = vhost_vdpa_virtqueue_cb;
660			cb.private = vq;
661			cb.trigger = vq->call_ctx.ctx;
662		} else {
663			cb.callback = NULL;
664			cb.private = NULL;
665			cb.trigger = NULL;
666		}
667		ops->set_vq_cb(vdpa, idx, &cb);
668		vhost_vdpa_setup_vq_irq(v, idx);
669		break;
670
671	case VHOST_SET_VRING_NUM:
672		ops->set_vq_num(vdpa, idx, vq->num);
673		break;
674	}
675
676	return r;
677}
678
679static long vhost_vdpa_unlocked_ioctl(struct file *filep,
680				      unsigned int cmd, unsigned long arg)
681{
682	struct vhost_vdpa *v = filep->private_data;
683	struct vhost_dev *d = &v->vdev;
684	void __user *argp = (void __user *)arg;
685	u64 __user *featurep = argp;
686	u64 features;
687	long r = 0;
688
689	if (cmd == VHOST_SET_BACKEND_FEATURES) {
690		if (copy_from_user(&features, featurep, sizeof(features)))
691			return -EFAULT;
692		if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
693				 BIT_ULL(VHOST_BACKEND_F_SUSPEND) |
694				 BIT_ULL(VHOST_BACKEND_F_RESUME) |
695				 BIT_ULL(VHOST_BACKEND_F_ENABLE_AFTER_DRIVER_OK)))
696			return -EOPNOTSUPP;
697		if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
698		     !vhost_vdpa_can_suspend(v))
699			return -EOPNOTSUPP;
700		if ((features & BIT_ULL(VHOST_BACKEND_F_RESUME)) &&
701		     !vhost_vdpa_can_resume(v))
702			return -EOPNOTSUPP;
703		vhost_set_backend_features(&v->vdev, features);
704		return 0;
705	}
706
707	mutex_lock(&d->mutex);
708
709	switch (cmd) {
710	case VHOST_VDPA_GET_DEVICE_ID:
711		r = vhost_vdpa_get_device_id(v, argp);
712		break;
713	case VHOST_VDPA_GET_STATUS:
714		r = vhost_vdpa_get_status(v, argp);
715		break;
716	case VHOST_VDPA_SET_STATUS:
717		r = vhost_vdpa_set_status(v, argp);
718		break;
719	case VHOST_VDPA_GET_CONFIG:
720		r = vhost_vdpa_get_config(v, argp);
721		break;
722	case VHOST_VDPA_SET_CONFIG:
723		r = vhost_vdpa_set_config(v, argp);
724		break;
725	case VHOST_GET_FEATURES:
726		r = vhost_vdpa_get_features(v, argp);
727		break;
728	case VHOST_SET_FEATURES:
729		r = vhost_vdpa_set_features(v, argp);
730		break;
731	case VHOST_VDPA_GET_VRING_NUM:
732		r = vhost_vdpa_get_vring_num(v, argp);
733		break;
734	case VHOST_VDPA_GET_GROUP_NUM:
735		if (copy_to_user(argp, &v->vdpa->ngroups,
736				 sizeof(v->vdpa->ngroups)))
737			r = -EFAULT;
738		break;
739	case VHOST_VDPA_GET_AS_NUM:
740		if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
741			r = -EFAULT;
742		break;
743	case VHOST_SET_LOG_BASE:
744	case VHOST_SET_LOG_FD:
745		r = -ENOIOCTLCMD;
746		break;
747	case VHOST_VDPA_SET_CONFIG_CALL:
748		r = vhost_vdpa_set_config_call(v, argp);
749		break;
750	case VHOST_GET_BACKEND_FEATURES:
751		features = VHOST_VDPA_BACKEND_FEATURES;
752		if (vhost_vdpa_can_suspend(v))
753			features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
754		if (vhost_vdpa_can_resume(v))
755			features |= BIT_ULL(VHOST_BACKEND_F_RESUME);
756		features |= vhost_vdpa_get_backend_features(v);
757		if (copy_to_user(featurep, &features, sizeof(features)))
758			r = -EFAULT;
759		break;
760	case VHOST_VDPA_GET_IOVA_RANGE:
761		r = vhost_vdpa_get_iova_range(v, argp);
762		break;
763	case VHOST_VDPA_GET_CONFIG_SIZE:
764		r = vhost_vdpa_get_config_size(v, argp);
765		break;
766	case VHOST_VDPA_GET_VQS_COUNT:
767		r = vhost_vdpa_get_vqs_count(v, argp);
768		break;
769	case VHOST_VDPA_SUSPEND:
770		r = vhost_vdpa_suspend(v);
771		break;
772	case VHOST_VDPA_RESUME:
773		r = vhost_vdpa_resume(v);
774		break;
775	default:
776		r = vhost_dev_ioctl(&v->vdev, cmd, argp);
777		if (r == -ENOIOCTLCMD)
778			r = vhost_vdpa_vring_ioctl(v, cmd, argp);
779		break;
780	}
781
782	if (r)
783		goto out;
784
785	switch (cmd) {
786	case VHOST_SET_OWNER:
787		r = vhost_vdpa_bind_mm(v);
788		if (r)
789			vhost_dev_reset_owner(d, NULL);
790		break;
791	}
792out:
793	mutex_unlock(&d->mutex);
794	return r;
795}
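
/*
 * A parent that implements ->dma_map is expected to provide ->dma_unmap
 * as well, hence the gate on ->dma_map below.  Parents with ->set_map
 * translate by themselves and need no per-range unmap here; everything
 * else falls back to the platform IOMMU domain.
 */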
796static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
797				     struct vhost_iotlb_map *map, u32 asid)
798{
799	struct vdpa_device *vdpa = v->vdpa;
800	const struct vdpa_config_ops *ops = vdpa->config;

801	if (ops->dma_map) {
802		ops->dma_unmap(vdpa, asid, map->start, map->size);
803	} else if (ops->set_map == NULL) {
804		iommu_unmap(v->domain, map->start, map->size);
805	}
806}
807
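/*
 * PA mode teardown: for every mapping in the range, unpin the backing
 * pages (dirtying them if the mapping was writable), drop them from the
 * mm's pinned_vm accounting, remove the translation from the device or
 * IOMMU and free the IOTLB entry.
 */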
808static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
809				u64 start, u64 last, u32 asid)
810{
811	struct vhost_dev *dev = &v->vdev;
812	struct vhost_iotlb_map *map;
813	struct page *page;
814	unsigned long pfn, pinned;
815
816	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
817		pinned = PFN_DOWN(map->size);
818		for (pfn = PFN_DOWN(map->addr);
819		     pinned > 0; pfn++, pinned--) {
820			page = pfn_to_page(pfn);
821			if (map->perm & VHOST_ACCESS_WO)
822				set_page_dirty_lock(page);
823			unpin_user_page(page);
824		}
825		atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
826		vhost_vdpa_general_unmap(v, map, asid);
827		vhost_iotlb_map_free(iotlb, map);
828	}
829}
830
831static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
832				u64 start, u64 last, u32 asid)
833{
834	struct vhost_iotlb_map *map;
835	struct vdpa_map_file *map_file;
836
837	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
838		map_file = (struct vdpa_map_file *)map->opaque;
839		fput(map_file->file);
840		kfree(map_file);
841		vhost_vdpa_general_unmap(v, map, asid);
842		vhost_iotlb_map_free(iotlb, map);
843	}
844}
845
846static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
847				   struct vhost_iotlb *iotlb, u64 start,
848				   u64 last, u32 asid)
849{
850	struct vdpa_device *vdpa = v->vdpa;
851
852	if (vdpa->use_va)
853		return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);
854
855	return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
856}
857
858static int perm_to_iommu_flags(u32 perm)
859{
860	int flags = 0;
861
862	switch (perm) {
863	case VHOST_ACCESS_WO:
864		flags |= IOMMU_WRITE;
865		break;
866	case VHOST_ACCESS_RO:
867		flags |= IOMMU_READ;
868		break;
869	case VHOST_ACCESS_RW:
870		flags |= (IOMMU_WRITE | IOMMU_READ);
871		break;
872	default:
873		WARN(1, "invalid vhost IOTLB permission\n");
874		break;
875	}
876
877	return flags | IOMMU_CACHE;
878}
879
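/*
 * Install a mapping through one of three paths: the parent's ->dma_map,
 * the parent's ->set_map (which replays the whole IOTLB and is deferred
 * to BATCH_END while a batch is in flight), or the platform IOMMU
 * domain set up in vhost_vdpa_alloc_domain().
 */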
880static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
881			  u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
882{
883	struct vhost_dev *dev = &v->vdev;
884	struct vdpa_device *vdpa = v->vdpa;
885	const struct vdpa_config_ops *ops = vdpa->config;
886	u32 asid = iotlb_to_asid(iotlb);
887	int r = 0;
888
889	r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
890				      pa, perm, opaque);
891	if (r)
892		return r;
893
894	if (ops->dma_map) {
895		r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
896	} else if (ops->set_map) {
897		if (!v->in_batch)
898			r = ops->set_map(vdpa, asid, iotlb);
899	} else {
900		r = iommu_map(v->domain, iova, pa, size,
901			      perm_to_iommu_flags(perm), GFP_KERNEL);
902	}
903	if (r) {
904		vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
905		return r;
906	}
907
908	if (!vdpa->use_va)
909		atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
910
911	return 0;
912}
913
914static void vhost_vdpa_unmap(struct vhost_vdpa *v,
915			     struct vhost_iotlb *iotlb,
916			     u64 iova, u64 size)
917{
918	struct vdpa_device *vdpa = v->vdpa;
919	const struct vdpa_config_ops *ops = vdpa->config;
920	u32 asid = iotlb_to_asid(iotlb);
921
922	vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);
923
924	if (ops->set_map) {
925		if (!v->in_batch)
926			ops->set_map(vdpa, asid, iotlb);
927	}
929}
930
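/*
 * VA mode: walk the VMAs covering the user range and record only
 * shared, file-backed mappings, taking a reference on the file and
 * noting its offset so the parent device can translate the user VA
 * later.  VMAs that don't qualify are silently skipped.
 */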
931static int vhost_vdpa_va_map(struct vhost_vdpa *v,
932			     struct vhost_iotlb *iotlb,
933			     u64 iova, u64 size, u64 uaddr, u32 perm)
934{
935	struct vhost_dev *dev = &v->vdev;
936	u64 offset, map_size, map_iova = iova;
937	struct vdpa_map_file *map_file;
938	struct vm_area_struct *vma;
939	int ret = 0;
940
941	mmap_read_lock(dev->mm);
942
943	while (size) {
944		vma = find_vma(dev->mm, uaddr);
945		if (!vma) {
946			ret = -EINVAL;
947			break;
948		}
949		map_size = min(size, vma->vm_end - uaddr);
950		if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
951			!(vma->vm_flags & (VM_IO | VM_PFNMAP))))
952			goto next;
953
954		map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
955		if (!map_file) {
956			ret = -ENOMEM;
957			break;
958		}
959		offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
960		map_file->offset = offset;
961		map_file->file = get_file(vma->vm_file);
962		ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
963				     perm, map_file);
964		if (ret) {
965			fput(map_file->file);
966			kfree(map_file);
967			break;
968		}
969next:
970		size -= map_size;
971		uaddr += map_size;
972		map_iova += map_size;
973	}
974	if (ret)
975		vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
976
977	mmap_read_unlock(dev->mm);
978
979	return ret;
980}
981
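/*
 * PA mode: pin the user pages in page_list-sized batches (bounded by
 * RLIMIT_MEMLOCK), coalesce physically contiguous runs and hand each
 * run to vhost_vdpa_map() as a single IOTLB entry.  On failure the
 * partially pinned and mapped range is rolled back.
 */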
982static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
983			     struct vhost_iotlb *iotlb,
984			     u64 iova, u64 size, u64 uaddr, u32 perm)
985{
986	struct vhost_dev *dev = &v->vdev;
987	struct page **page_list;
988	unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
989	unsigned int gup_flags = FOLL_LONGTERM;
990	unsigned long npages, cur_base, map_pfn, last_pfn = 0;
991	unsigned long lock_limit, sz2pin, nchunks, i;
992	u64 start = iova;
993	long pinned;
994	int ret = 0;
995
996	/* Limit the use of memory for bookkeeping */
997	page_list = (struct page **) __get_free_page(GFP_KERNEL);
998	if (!page_list)
999		return -ENOMEM;
1000
1001	if (perm & VHOST_ACCESS_WO)
1002		gup_flags |= FOLL_WRITE;
1003
1004	npages = PFN_UP(size + (iova & ~PAGE_MASK));
1005	if (!npages) {
1006		ret = -EINVAL;
1007		goto free;
1008	}
1009
1010	mmap_read_lock(dev->mm);
1011
1012	lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
1013	if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
1014		ret = -ENOMEM;
1015		goto unlock;
1016	}
1017
1018	cur_base = uaddr & PAGE_MASK;
1019	iova &= PAGE_MASK;
1020	nchunks = 0;
1021
1022	while (npages) {
1023		sz2pin = min_t(unsigned long, npages, list_size);
1024		pinned = pin_user_pages(cur_base, sz2pin,
1025					gup_flags, page_list);
1026		if (sz2pin != pinned) {
1027			if (pinned < 0) {
1028				ret = pinned;
1029			} else {
1030				unpin_user_pages(page_list, pinned);
1031				ret = -ENOMEM;
1032			}
1033			goto out;
1034		}
1035		nchunks++;
1036
1037		if (!last_pfn)
1038			map_pfn = page_to_pfn(page_list[0]);
1039
1040		for (i = 0; i < pinned; i++) {
1041			unsigned long this_pfn = page_to_pfn(page_list[i]);
1042			u64 csize;
1043
1044			if (last_pfn && (this_pfn != last_pfn + 1)) {
1045				/* Map the contiguous chunk of memory pinned so far */
1046				csize = PFN_PHYS(last_pfn - map_pfn + 1);
1047				ret = vhost_vdpa_map(v, iotlb, iova, csize,
1048						     PFN_PHYS(map_pfn),
1049						     perm, NULL);
1050				if (ret) {
1051					/*
1052					 * Unpin the pages that are left unmapped
1053					 * from this point on in the current
1054					 * page_list. The remaining outstanding
1055					 * ones which may stride across several
1056					 * chunks will be covered in the common
1057					 * error path subsequently.
1058					 */
1059					unpin_user_pages(&page_list[i],
1060							 pinned - i);
1061					goto out;
1062				}
1063
1064				map_pfn = this_pfn;
1065				iova += csize;
1066				nchunks = 0;
1067			}
1068
1069			last_pfn = this_pfn;
1070		}
1071
1072		cur_base += PFN_PHYS(pinned);
1073		npages -= pinned;
1074	}
1075
1076	/* Map the last contiguous chunk */
1077	ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
1078			     PFN_PHYS(map_pfn), perm, NULL);
1079out:
1080	if (ret) {
1081		if (nchunks) {
1082			unsigned long pfn;
1083
1084			/*
1085			 * Unpin the outstanding pages which were pinned but
1086			 * never mapped, due to a vhost_vdpa_map() or
1087			 * pin_user_pages() failure.
1088			 *
1089			 * Mapped pages are accounted in vhost_vdpa_map(),
1090			 * hence the corresponding unpinning will be handled
1091			 * by vhost_vdpa_unmap().
1092			 */
1093			WARN_ON(!last_pfn);
1094			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
1095				unpin_user_page(pfn_to_page(pfn));
1096		}
1097		vhost_vdpa_unmap(v, iotlb, start, size);
1098	}
1099unlock:
1100	mmap_read_unlock(dev->mm);
1101free:
1102	free_page((unsigned long)page_list);
1103	return ret;
1105}
1106
1107static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
1108					   struct vhost_iotlb *iotlb,
1109					   struct vhost_iotlb_msg *msg)
1110{
1111	struct vdpa_device *vdpa = v->vdpa;
1112
1113	if (msg->iova < v->range.first || !msg->size ||
1114	    msg->iova > U64_MAX - msg->size + 1 ||
1115	    msg->iova + msg->size - 1 > v->range.last)
1116		return -EINVAL;
1117
1118	if (vhost_iotlb_itree_first(iotlb, msg->iova,
1119				    msg->iova + msg->size - 1))
1120		return -EEXIST;
1121
1122	if (vdpa->use_va)
1123		return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
1124					 msg->uaddr, msg->perm);
1125
1126	return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
1127				 msg->perm);
1128}
1129
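/*
 * IOTLB messages from userspace: BATCH_BEGIN latches the ASID for the
 * batch, UPDATE/INVALIDATE modify that ASID's IOTLB, and BATCH_END
 * pushes the accumulated IOTLB to the parent in a single ->set_map()
 * call.
 */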
1130static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1131					struct vhost_iotlb_msg *msg)
1132{
1133	struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
1134	struct vdpa_device *vdpa = v->vdpa;
1135	const struct vdpa_config_ops *ops = vdpa->config;
1136	struct vhost_iotlb *iotlb = NULL;
1137	struct vhost_vdpa_as *as = NULL;
1138	int r = 0;
1139
1140	mutex_lock(&dev->mutex);
1141
1142	r = vhost_dev_check_owner(dev);
1143	if (r)
1144		goto unlock;
1145
1146	if (msg->type == VHOST_IOTLB_UPDATE ||
1147	    msg->type == VHOST_IOTLB_BATCH_BEGIN) {
1148		as = vhost_vdpa_find_alloc_as(v, asid);
1149		if (!as) {
1150			dev_err(&v->dev, "can't find or allocate asid %d\n",
1151				asid);
1152			r = -EINVAL;
1153			goto unlock;
1154		}
1155		iotlb = &as->iotlb;
1156	} else
1157		iotlb = asid_to_iotlb(v, asid);
1158
1159	if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
1160		if (v->in_batch && v->batch_asid != asid) {
1161			dev_info(&v->dev, "batch asid %d doesn't match msg asid %d\n",
1162				 v->batch_asid, asid);
1163		}
1164		if (!iotlb)
1165			dev_err(&v->dev, "no iotlb for asid %d\n", asid);
1166		r = -EINVAL;
1167		goto unlock;
1168	}
1169
1170	switch (msg->type) {
1171	case VHOST_IOTLB_UPDATE:
1172		r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
1173		break;
1174	case VHOST_IOTLB_INVALIDATE:
1175		vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
1176		break;
1177	case VHOST_IOTLB_BATCH_BEGIN:
1178		v->batch_asid = asid;
1179		v->in_batch = true;
1180		break;
1181	case VHOST_IOTLB_BATCH_END:
1182		if (v->in_batch && ops->set_map)
1183			ops->set_map(vdpa, asid, iotlb);
1184		v->in_batch = false;
1185		break;
1186	default:
1187		r = -EINVAL;
1188		break;
1189	}
1190unlock:
1191	mutex_unlock(&dev->mutex);
1192
1193	return r;
1194}
1195
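/*
 * Illustrative only: userspace drives the IOTLB handling above by
 * write()ing vhost messages to the device fd.  With
 * VHOST_BACKEND_F_IOTLB_MSG_V2 negotiated, a single mapping update
 * might look roughly like:
 *
 *	struct vhost_msg_v2 msg = {
 *		.type = VHOST_IOTLB_MSG_V2,
 *		.asid = 0,
 *		.iotlb = {
 *			.iova = iova, .size = size, .uaddr = uaddr,
 *			.perm = VHOST_ACCESS_RW,
 *			.type = VHOST_IOTLB_UPDATE,
 *		},
 *	};
 *	write(fd, &msg, sizeof(msg));
 */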
1196static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
1197					 struct iov_iter *from)
1198{
1199	struct file *file = iocb->ki_filp;
1200	struct vhost_vdpa *v = file->private_data;
1201	struct vhost_dev *dev = &v->vdev;
1202
1203	return vhost_chr_write_iter(dev, from);
1204}
1205
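/*
 * A platform IOMMU domain is only needed when the parent neither
 * translates by itself (->set_map) nor provides ->dma_map, and only
 * cache-coherent IOMMUs are accepted.
 */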
1206static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
1207{
1208	struct vdpa_device *vdpa = v->vdpa;
1209	const struct vdpa_config_ops *ops = vdpa->config;
1210	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1211	const struct bus_type *bus;
1212	int ret;
1213
1214	/* Device wants to do DMA by itself */
1215	if (ops->set_map || ops->dma_map)
1216		return 0;
1217
1218	bus = dma_dev->bus;
1219	if (!bus)
1220		return -EFAULT;
1221
1222	if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY)) {
1223		dev_warn_once(&v->dev,
1224			      "Failed to allocate domain, device is not IOMMU cache coherent capable\n");
1225		return -ENOTSUPP;
1226	}
1227
1228	v->domain = iommu_domain_alloc(bus);
1229	if (!v->domain)
1230		return -EIO;
1231
1232	ret = iommu_attach_device(v->domain, dma_dev);
1233	if (ret)
1234		goto err_attach;
1235
1236	return 0;
1237
1238err_attach:
1239	iommu_domain_free(v->domain);
1240	v->domain = NULL;
1241	return ret;
1242}
1243
1244static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
1245{
1246	struct vdpa_device *vdpa = v->vdpa;
1247	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1248
1249	if (v->domain) {
1250		iommu_detach_device(v->domain, dma_dev);
1251		iommu_domain_free(v->domain);
1252	}
1253
1254	v->domain = NULL;
1255}
1256
1257static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
1258{
1259	struct vdpa_iova_range *range = &v->range;
1260	struct vdpa_device *vdpa = v->vdpa;
1261	const struct vdpa_config_ops *ops = vdpa->config;
1262
1263	if (ops->get_iova_range) {
1264		*range = ops->get_iova_range(vdpa);
1265	} else if (v->domain && v->domain->geometry.force_aperture) {
1266		range->first = v->domain->geometry.aperture_start;
1267		range->last = v->domain->geometry.aperture_end;
1268	} else {
1269		range->first = 0;
1270		range->last = ULLONG_MAX;
1271	}
1272}
1273
1274static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
1275{
1276	struct vhost_vdpa_as *as;
1277	u32 asid;
1278
1279	for (asid = 0; asid < v->vdpa->nas; asid++) {
1280		as = asid_to_as(v, asid);
1281		if (as)
1282			vhost_vdpa_remove_as(v, asid);
1283	}
1284
1285	vhost_vdpa_free_domain(v);
1286	vhost_dev_cleanup(&v->vdev);
1287	kfree(v->vdev.vqs);
1288}
1289
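/*
 * Only a single opener is allowed at a time (tracked in v->opened);
 * release() completes v->completion so that vhost_vdpa_remove() can
 * wait for an open file to go away before dropping the device.
 */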
1290static int vhost_vdpa_open(struct inode *inode, struct file *filep)
1291{
1292	struct vhost_vdpa *v;
1293	struct vhost_dev *dev;
1294	struct vhost_virtqueue **vqs;
1295	int r, opened;
1296	u32 i, nvqs;
1297
1298	v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
1299
1300	opened = atomic_cmpxchg(&v->opened, 0, 1);
1301	if (opened)
1302		return -EBUSY;
1303
1304	nvqs = v->nvqs;
1305	r = vhost_vdpa_reset(v);
1306	if (r)
1307		goto err;
1308
1309	vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
1310	if (!vqs) {
1311		r = -ENOMEM;
1312		goto err;
1313	}
1314
1315	dev = &v->vdev;
1316	for (i = 0; i < nvqs; i++) {
1317		vqs[i] = &v->vqs[i];
1318		vqs[i]->handle_kick = handle_vq_kick;
1319	}
1320	vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
1321		       vhost_vdpa_process_iotlb_msg);
1322
1323	r = vhost_vdpa_alloc_domain(v);
1324	if (r)
1325		goto err_alloc_domain;
1326
1327	vhost_vdpa_set_iova_range(v);
1328
1329	filep->private_data = v;
1330
1331	return 0;
1332
1333err_alloc_domain:
1334	vhost_vdpa_cleanup(v);
1335err:
1336	atomic_dec(&v->opened);
1337	return r;
1338}
1339
1340static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1341{
1342	u32 i;
1343
1344	for (i = 0; i < v->nvqs; i++)
1345		vhost_vdpa_unsetup_vq_irq(v, i);
1346}
1347
1348static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1349{
1350	struct vhost_vdpa *v = filep->private_data;
1351	struct vhost_dev *d = &v->vdev;
1352
1353	mutex_lock(&d->mutex);
1354	filep->private_data = NULL;
1355	vhost_vdpa_clean_irq(v);
1356	vhost_vdpa_reset(v);
1357	vhost_dev_stop(&v->vdev);
1358	vhost_vdpa_unbind_mm(v);
1359	vhost_vdpa_config_put(v);
1360	vhost_vdpa_cleanup(v);
1361	mutex_unlock(&d->mutex);
1362
1363	atomic_dec(&v->opened);
1364	complete(&v->completion);
1365
1366	return 0;
1367}
1368
1369#ifdef CONFIG_MMU
1370static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1371{
1372	struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1373	struct vdpa_device *vdpa = v->vdpa;
1374	const struct vdpa_config_ops *ops = vdpa->config;
1375	struct vdpa_notification_area notify;
1376	struct vm_area_struct *vma = vmf->vma;
1377	u16 index = vma->vm_pgoff;
1378
1379	notify = ops->get_vq_notification(vdpa, index);
1380
1381	vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1382	if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1383			    PFN_DOWN(notify.addr), PAGE_SIZE,
1384			    vma->vm_page_prot))
1385		return VM_FAULT_SIGBUS;
1386
1387	return VM_FAULT_NOPAGE;
1388}
1389
1390static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1391	.fault = vhost_vdpa_fault,
1392};
1393
1394static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1395{
1396	struct vhost_vdpa *v = vma->vm_file->private_data;
1397	struct vdpa_device *vdpa = v->vdpa;
1398	const struct vdpa_config_ops *ops = vdpa->config;
1399	struct vdpa_notification_area notify;
1400	unsigned long index = vma->vm_pgoff;
1401
1402	if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1403		return -EINVAL;
1404	if ((vma->vm_flags & VM_SHARED) == 0)
1405		return -EINVAL;
1406	if (vma->vm_flags & VM_READ)
1407		return -EINVAL;
1408	if (index > 65535)
1409		return -EINVAL;
1410	if (!ops->get_vq_notification)
1411		return -ENOTSUPP;
1412
1413	/* To be safe and easily modelled by userspace, we only
1414	 * support doorbells that sit on a page boundary and
1415	 * do not share the page with other registers.
1416	 */
1417	notify = ops->get_vq_notification(vdpa, index);
1418	if (notify.addr & (PAGE_SIZE - 1))
1419		return -EINVAL;
1420	if (vma->vm_end - vma->vm_start != notify.size)
1421		return -ENOTSUPP;
1422
1423	vm_flags_set(vma, VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP);
1424	vma->vm_ops = &vhost_vdpa_vm_ops;
1425	return 0;
1426}
1427#endif /* CONFIG_MMU */
1428
1429static const struct file_operations vhost_vdpa_fops = {
1430	.owner		= THIS_MODULE,
1431	.open		= vhost_vdpa_open,
1432	.release	= vhost_vdpa_release,
1433	.write_iter	= vhost_vdpa_chr_write_iter,
1434	.unlocked_ioctl	= vhost_vdpa_unlocked_ioctl,
1435#ifdef CONFIG_MMU
1436	.mmap		= vhost_vdpa_mmap,
1437#endif /* CONFIG_MMU */
1438	.compat_ioctl	= compat_ptr_ioctl,
1439};
1440
1441static void vhost_vdpa_release_dev(struct device *device)
1442{
1443	struct vhost_vdpa *v =
1444	       container_of(device, struct vhost_vdpa, dev);
1445
1446	ida_simple_remove(&vhost_vdpa_ida, v->minor);
1447	kfree(v->vqs);
1448	kfree(v);
1449}
1450
1451static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1452{
1453	const struct vdpa_config_ops *ops = vdpa->config;
1454	struct vhost_vdpa *v;
1455	int minor;
1456	int i, r;
1457
1458	/* We can't support a platform IOMMU device with more than one
1459	 * group or address space.
1460	 */
1461	if (!ops->set_map && !ops->dma_map &&
1462	    (vdpa->ngroups > 1 || vdpa->nas > 1))
1463		return -EOPNOTSUPP;
1464
1465	v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1466	if (!v)
1467		return -ENOMEM;
1468
1469	minor = ida_simple_get(&vhost_vdpa_ida, 0,
1470			       VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1471	if (minor < 0) {
1472		kfree(v);
1473		return minor;
1474	}
1475
1476	atomic_set(&v->opened, 0);
1477	v->minor = minor;
1478	v->vdpa = vdpa;
1479	v->nvqs = vdpa->nvqs;
1480	v->virtio_id = ops->get_device_id(vdpa);
1481
1482	device_initialize(&v->dev);
1483	v->dev.release = vhost_vdpa_release_dev;
1484	v->dev.parent = &vdpa->dev;
1485	v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1486	v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1487			       GFP_KERNEL);
1488	if (!v->vqs) {
1489		r = -ENOMEM;
1490		goto err;
1491	}
1492
1493	r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1494	if (r)
1495		goto err;
1496
1497	cdev_init(&v->cdev, &vhost_vdpa_fops);
1498	v->cdev.owner = THIS_MODULE;
1499
1500	r = cdev_device_add(&v->cdev, &v->dev);
1501	if (r)
1502		goto err;
1503
1504	init_completion(&v->completion);
1505	vdpa_set_drvdata(vdpa, v);
1506
1507	for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
1508		INIT_HLIST_HEAD(&v->as[i]);
1509
1510	return 0;
1511
1512err:
1513	put_device(&v->dev);
1514	return r;
1515}
1516
1517static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1518{
1519	struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1520	int opened;
1521
1522	cdev_device_del(&v->cdev, &v->dev);
1523
1524	do {
1525		opened = atomic_cmpxchg(&v->opened, 0, 1);
1526		if (!opened)
1527			break;
1528		wait_for_completion(&v->completion);
1529	} while (1);
1530
1531	put_device(&v->dev);
1532}
1533
1534static struct vdpa_driver vhost_vdpa_driver = {
1535	.driver = {
1536		.name	= "vhost_vdpa",
1537	},
1538	.probe	= vhost_vdpa_probe,
1539	.remove	= vhost_vdpa_remove,
1540};
1541
1542static int __init vhost_vdpa_init(void)
1543{
1544	int r;
1545
1546	r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1547				"vhost-vdpa");
1548	if (r)
1549		goto err_alloc_chrdev;
1550
1551	r = vdpa_register_driver(&vhost_vdpa_driver);
1552	if (r)
1553		goto err_vdpa_register_driver;
1554
1555	return 0;
1556
1557err_vdpa_register_driver:
1558	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1559err_alloc_chrdev:
1560	return r;
1561}
1562module_init(vhost_vdpa_init);
1563
1564static void __exit vhost_vdpa_exit(void)
1565{
1566	vdpa_unregister_driver(&vhost_vdpa_driver);
1567	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1568}
1569module_exit(vhost_vdpa_exit);
1570
1571MODULE_VERSION("0.0.1");
1572MODULE_LICENSE("GPL v2");
1573MODULE_AUTHOR("Intel Corporation");
1574MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");
1575