// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES		0x10000
#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	atomic_t count;				/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

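/*
 * Per-device IOMMUv2 state: the sparse PASID-state table, the IOMMU
 * domain used for v2 paging and the optional driver callbacks for
 * invalid PPRs and context invalidation.
 */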
struct device_state {
	struct list_head list;
	u16 devid;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u16 devid;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

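/* Build the PCI requester ID: bus number in bits 15:8, devfn in bits 7:0 */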
static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->devid == devid)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

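/*
 * Walk the multi-level PASID-state table. Each level decodes 9 bits of
 * the PASID, so every table page holds 512 pointers. With alloc == true
 * missing intermediate levels are allocated with GFP_ATOMIC.
 */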
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		atomic_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (atomic_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

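/*
 * Drop the caller's reference, wait until all remaining references are
 * released and then free the pasid_state.
 */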
static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	atomic_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid; no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

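/*
 * MMU notifier callback: the CPU page tables changed for [start, end),
 * so flush the corresponding IOTLB entries for this PASID. A single-page
 * flush is used when the range fits within one page, otherwise the whole
 * TLB for the PASID is flushed.
 */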
static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range	= mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

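/*
 * Complete a PPR tag once its last in-flight fault is done. The response
 * is only sent when the device asked for a completion (finish bit set),
 * and the tag state is reset for the next use.
 */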
static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

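/*
 * Worker function: resolve one queued PPR fault by faulting the page in
 * through handle_mm_fault() under mmap_read_lock(), then send the PPR
 * completion and drop the references taken in ppr_notifier().
 */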
static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

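/*
 * Notifier callback for PPR (Peripheral Page Request) events reported by
 * the IOMMU driver. Allocations use GFP_ATOMIC and the actual page-fault
 * handling is deferred to a work item on iommu_wq, since sleeping is not
 * allowed here.
 */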
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	devid = iommu_fault->device_id;
	pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	pci_dev_put(pdev);
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

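/*
 * Bind a task's address space to a PASID on the given device: take a
 * reference on the mm, register an mmu_notifier for it and program the
 * mm's page-table root into the GCR3 table for this PASID.
 */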
int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	atomic_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Marked as valid only after the
					     pasid is fully set up */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

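/*
 * Undo amd_iommu_bind_pasid(): clear the PASID state, unregister the
 * mmu_notifier (which clears the GCR3 entry via mn_release) and drop
 * the references taken at bind time.
 */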
void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

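/*
 * Set up IOMMUv2 support for a device: allocate the device_state and the
 * root page of the PASID table, create a direct-mapped v2 domain and
 * attach the device's IOMMU group to it.
 */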
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (mem_encrypt_active())
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->devid = devid;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(devid) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

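/*
 * Tear down IOMMUv2 support for a device: remove it from the global
 * list, free all remaining PASID states and wait for the last reference
 * to the device_state before freeing it.
 */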
void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(devid);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	put_device_state(dev_state);
	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	spin_lock_init(&state_lock);

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	int i;

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it.
	 */
	for (i = 0; i < MAX_DEVICES; ++i) {
		dev_state = get_device_state(i);

		if (dev_state == NULL)
			continue;

		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);