// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>
#include <linux/cc_platform.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define PRI_QUEUE_SIZE		512

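/* Per PRI tag: number of in-flight faults and the response to send when done */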
struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	refcount_t count;				/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;                 /* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

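/* Per-device IOMMUv2 state, looked up by PCI segment/bus/device/function */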
struct device_state {
	struct list_head list;
	u32 sbdf;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

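/* A single PPR fault, queued to iommu_wq and handled in do_fault() */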
struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static struct device_state *__get_device_state(u32 sbdf)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->sbdf == sbdf)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u32 sbdf)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(sbdf);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

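	/*
	 * Walk the multi-level PASID table: each level indexes with 9 bits
	 * of the PASID (512 pointers per page); level 0 holds the
	 * pasid_state pointers themselves.
	 */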
	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		refcount_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (refcount_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

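/* Drop a reference, wait until all remaining references are gone, then free */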
static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	if (!refcount_dec_and_test(&pasid_state->count))
		wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/* Clear the pasid state so that the pasid can be re-used */
		clear_pasid_state(dev_state, pasid_state->pasid);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

static void mn_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
					struct mm_struct *mm,
					unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

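	/*
	 * If the invalidated range lies within a single page, flush only
	 * that page on the IOMMU; otherwise flush the whole TLB for this
	 * PASID.
	 */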
	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release			= mn_release,
	.arch_invalidate_secondary_tlbs	= mn_arch_invalidate_secondary_tlbs,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

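/*
 * Complete one in-flight fault for a PRI tag. The PPR completion is sent to
 * the device only when the last fault for the tag has been handled and the
 * IOMMU requested a response for this tag.
 */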
static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = vma_lookup(mm, address);
	if (!vma)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid, seg_id;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	seg_id = PCI_SBDF_TO_SEGID(iommu_fault->sbdf);
	devid = PCI_SBDF_TO_DEVID(iommu_fault->sbdf);
	pdev = pci_get_domain_bus_and_slot(seg_id, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(&pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->sbdf);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	pci_dev_put(pdev);
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u32 sbdf;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf      = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	refcount_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	ret = mmu_notifier_register(&pasid_state->mn, mm);
	if (ret)
		goto out_free;

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u32 sbdf;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);
	dev_state = get_device_state(sbdf);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u32 sbdf;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	sbdf = get_pci_sbdf_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->sbdf = sbdf;

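	/* Compute how many 9-bit table levels are needed to cover all PASIDs */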
	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	/* See iommu_is_default_domain() */
	dev_state->domain->type = IOMMU_DOMAIN_IDENTITY;
	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(sbdf) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;

	if (!amd_iommu_v2_supported())
		return;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	put_device_state(dev_state);
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u32 sbdf;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	sbdf = get_pci_sbdf_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(sbdf);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state, *next;
	unsigned long flags;
	LIST_HEAD(freelist);

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	spin_lock_irqsave(&state_lock, flags);

	list_for_each_entry_safe(dev_state, next, &state_list, list) {
		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		list_del(&dev_state->list);
		list_add_tail(&dev_state->list, &freelist);
	}

	spin_unlock_irqrestore(&state_lock, flags);

	/*
	 * Since free_device_state waits on the count to be zero,
	 * we need to free dev_state outside the spinlock.
	 */
	list_for_each_entry_safe(dev_state, next, &freelist, list) {
		list_del(&dev_state->list);
		free_device_state(dev_state);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);