// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)    "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES		0x10000
#define PRI_QUEUE_SIZE		512

struct pri_queue {
	atomic_t inflight;
	bool finish;
	int status;
};

struct pasid_state {
	struct list_head list;			/* For global state-list */
	atomic_t count;				/* Reference count */
	unsigned mmu_notifier_count;		/* Counting nested mmu_notifier
						   calls */
	struct mm_struct *mm;			/* mm_struct for the faults */
	struct mmu_notifier mn;			/* mmu_notifier handle */
	struct pri_queue pri[PRI_QUEUE_SIZE];	/* PRI tag states */
	struct device_state *device_state;	/* Link to our device_state */
	u32 pasid;				/* PASID index */
	bool invalid;				/* Used during setup and
						   teardown of the pasid */
	spinlock_t lock;			/* Protect pri_queues and
						   mmu_notifier_count */
	wait_queue_head_t wq;			/* To wait for count == 0 */
};

struct device_state {
	struct list_head list;
	u16 devid;
	atomic_t count;
	struct pci_dev *pdev;
	struct pasid_state **states;
	struct iommu_domain *domain;
	int pasid_levels;
	int max_pasids;
	amd_iommu_invalid_ppr_cb inv_ppr_cb;
	amd_iommu_invalidate_ctx inv_ctx_cb;
	spinlock_t lock;
	wait_queue_head_t wq;
};

struct fault {
	struct work_struct work;
	struct device_state *dev_state;
	struct pasid_state *state;
	struct mm_struct *mm;
	u64 address;
	u16 devid;
	u32 pasid;
	u16 tag;
	u16 finish;
	u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

static u16 device_id(struct pci_dev *pdev)
{
	u16 devid;

	devid = pdev->bus->number;
	devid = (devid << 8) | pdev->devfn;

	return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
	struct device_state *dev_state;

	list_for_each_entry(dev_state, &state_list, list) {
		if (dev_state->devid == devid)
			return dev_state;
	}

	return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
	struct device_state *dev_state;
	unsigned long flags;

	spin_lock_irqsave(&state_lock, flags);
	dev_state = __get_device_state(devid);
	if (dev_state != NULL)
		atomic_inc(&dev_state->count);
	spin_unlock_irqrestore(&state_lock, flags);

	return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
	struct iommu_group *group;

	/*
	 * First detach device from domain - No more PRI requests will arrive
	 * from that device after it is unbound from the IOMMUv2 domain.
	 */
	group = iommu_group_get(&dev_state->pdev->dev);
	if (WARN_ON(!group))
		return;

	iommu_detach_group(dev_state->domain, group);

	iommu_group_put(group);

	/* Everything is down now, free the IOMMUv2 domain */
	iommu_domain_free(dev_state->domain);

	/* Finally get rid of the device-state */
	kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
	if (atomic_dec_and_test(&dev_state->count))
		wake_up(&dev_state->wq);
}

/*
 * Must be called under dev_state->lock.  The pasid_state pointers are
 * kept in a small radix-tree-like table with 512 entries per level;
 * each level consumes nine bits of the PASID.
 */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
						  u32 pasid, bool alloc)
{
	struct pasid_state **root, **ptr;
	int level, index;

	level = dev_state->pasid_levels;
	root  = dev_state->states;

	while (true) {

		index = (pasid >> (9 * level)) & 0x1ff;
		ptr   = &root[index];

		if (level == 0)
			break;

		if (*ptr == NULL) {
			if (!alloc)
				return NULL;

			*ptr = (void *)get_zeroed_page(GFP_ATOMIC);
			if (*ptr == NULL)
				return NULL;
		}

		root   = (struct pasid_state **)*ptr;
		level -= 1;
	}

	return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	ret = -ENOMEM;
	if (ptr == NULL)
		goto out_unlock;

	ret = -ENOMEM;
	if (*ptr != NULL)
		goto out_unlock;

	*ptr = pasid_state;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
	struct pasid_state **ptr;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, true);

	if (ptr == NULL)
		goto out_unlock;

	*ptr = NULL;

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
					   u32 pasid)
{
	struct pasid_state **ptr, *ret = NULL;
	unsigned long flags;

	spin_lock_irqsave(&dev_state->lock, flags);
	ptr = __get_pasid_state_ptr(dev_state, pasid, false);

	if (ptr == NULL)
		goto out_unlock;

	ret = *ptr;
	if (ret)
		atomic_inc(&ret->count);

out_unlock:
	spin_unlock_irqrestore(&dev_state->lock, flags);

	return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
	kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
	if (atomic_dec_and_test(&pasid_state->count))
		wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
	atomic_dec(&pasid_state->count);
	wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
	free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
	struct iommu_domain *domain;

	domain = pasid_state->device_state->domain;

	/*
	 * Mark pasid_state as invalid, no more faults will be added to the
	 * work queue after this is visible everywhere.
	 */
	pasid_state->invalid = true;

	/* Make sure this is visible */
	smp_wmb();

	/* After this the device/pasid can't access the mm anymore */
	amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

	/* Make sure no more pending faults are in the queue */
	flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		free_page((unsigned long)tbl[i]);
	}
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
	struct pasid_state **ptr;
	int i;

	for (i = 0; i < 512; ++i) {
		if (tbl[i] == NULL)
			continue;

		ptr = (struct pasid_state **)tbl[i];
		free_pasid_states_level1(ptr);
	}
}

static void free_pasid_states(struct device_state *dev_state)
{
	struct pasid_state *pasid_state;
	int i;

	for (i = 0; i < dev_state->max_pasids; ++i) {
		pasid_state = get_pasid_state(dev_state, i);
		if (pasid_state == NULL)
			continue;

		put_pasid_state(pasid_state);

		/*
		 * This will call the mn_release function and
		 * unbind the PASID
		 */
		mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

		put_pasid_state_wait(pasid_state); /* Reference taken in
						      amd_iommu_bind_pasid */

		/* Drop reference taken in amd_iommu_bind_pasid */
		put_device_state(dev_state);
	}

	if (dev_state->pasid_levels == 2)
		free_pasid_states_level2(dev_state->states);
	else if (dev_state->pasid_levels == 1)
		free_pasid_states_level1(dev_state->states);
	else
		BUG_ON(dev_state->pasid_levels != 0);

	free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
	return container_of(mn, struct pasid_state, mn);
}

static void mn_invalidate_range(struct mmu_notifier *mn,
				struct mm_struct *mm,
				unsigned long start, unsigned long end)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;

	pasid_state = mn_to_state(mn);
	dev_state   = pasid_state->device_state;

	if ((start ^ (end - 1)) < PAGE_SIZE)
		amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
				     start);
	else
		amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	bool run_inv_ctx_cb;

	might_sleep();

	pasid_state    = mn_to_state(mn);
	dev_state      = pasid_state->device_state;
	run_inv_ctx_cb = !pasid_state->invalid;

	if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
		dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

	unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
	.release		= mn_release,
	.invalidate_range	= mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
			       u16 tag, int status)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	pasid_state->pri[tag].status = status;
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
			   struct pasid_state *pasid_state,
			   u16 tag)
{
	unsigned long flags;

	spin_lock_irqsave(&pasid_state->lock, flags);
	if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
	    pasid_state->pri[tag].finish) {
		amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
				       pasid_state->pri[tag].status, tag);
		pasid_state->pri[tag].finish = false;
		pasid_state->pri[tag].status = PPR_SUCCESS;
	}
	spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
	int status;

	if (!fault->dev_state->inv_ppr_cb) {
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		return;
	}

	status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
					      fault->pasid,
					      fault->address,
					      fault->flags);
	switch (status) {
	case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
		set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
		break;
	case AMD_IOMMU_INV_PRI_RSP_INVALID:
		set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
		break;
	case AMD_IOMMU_INV_PRI_RSP_FAIL:
		set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
		break;
	default:
		BUG();
	}
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
	unsigned long requested = 0;

	if (fault->flags & PPR_FAULT_EXEC)
		requested |= VM_EXEC;

	if (fault->flags & PPR_FAULT_READ)
		requested |= VM_READ;

	if (fault->flags & PPR_FAULT_WRITE)
		requested |= VM_WRITE;

	return (requested & ~vma->vm_flags) != 0;
}

static void do_fault(struct work_struct *work)
{
	struct fault *fault = container_of(work, struct fault, work);
	struct vm_area_struct *vma;
	vm_fault_t ret = VM_FAULT_ERROR;
	unsigned int flags = 0;
	struct mm_struct *mm;
	u64 address;

	mm = fault->state->mm;
	address = fault->address;

	if (fault->flags & PPR_FAULT_USER)
		flags |= FAULT_FLAG_USER;
	if (fault->flags & PPR_FAULT_WRITE)
		flags |= FAULT_FLAG_WRITE;
	flags |= FAULT_FLAG_REMOTE;

	mmap_read_lock(mm);
	vma = find_extend_vma(mm, address);
	if (!vma || address < vma->vm_start)
		/* failed to get a vma in the right range */
		goto out;

	/* Check if we have the right permissions on the vma */
	if (access_error(vma, fault))
		goto out;

	ret = handle_mm_fault(vma, address, flags, NULL);
out:
	mmap_read_unlock(mm);

	if (ret & VM_FAULT_ERROR)
		/* failed to service fault */
		handle_fault_error(fault);

	finish_pri_tag(fault->dev_state, fault->state, fault->tag);

	put_pasid_state(fault->state);

	kfree(fault);
}

static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
	struct amd_iommu_fault *iommu_fault;
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct pci_dev *pdev = NULL;
	unsigned long flags;
	struct fault *fault;
	bool finish;
	u16 tag, devid;
	int ret;

	iommu_fault = data;
	tag         = iommu_fault->tag & 0x1ff;
	finish      = (iommu_fault->tag >> 9) & 1;

	devid = iommu_fault->device_id;
	pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
					   devid & 0xff);
	if (!pdev)
		return -ENODEV;

	ret = NOTIFY_DONE;

	/* In kdump kernel pci dev is not initialized yet -> send INVALID */
	if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
		amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out;
	}

	dev_state = get_device_state(iommu_fault->device_id);
	if (dev_state == NULL)
		goto out;

	pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
	if (pasid_state == NULL || pasid_state->invalid) {
		/* We know the device but not the PASID -> send INVALID */
		amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
				       PPR_INVALID, tag);
		goto out_drop_state;
	}

	spin_lock_irqsave(&pasid_state->lock, flags);
	atomic_inc(&pasid_state->pri[tag].inflight);
	if (finish)
		pasid_state->pri[tag].finish = true;
	spin_unlock_irqrestore(&pasid_state->lock, flags);

	fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
	if (fault == NULL) {
		/* We are OOM - send success and let the device re-fault */
		finish_pri_tag(dev_state, pasid_state, tag);
		goto out_drop_state;
	}

	fault->dev_state = dev_state;
	fault->address   = iommu_fault->address;
	fault->state     = pasid_state;
	fault->tag       = tag;
	fault->finish    = finish;
	fault->pasid     = iommu_fault->pasid;
	fault->flags     = iommu_fault->flags;
	INIT_WORK(&fault->work, do_fault);

	queue_work(iommu_wq, &fault->work);

	ret = NOTIFY_OK;

out_drop_state:

	if (ret != NOTIFY_OK && pasid_state)
		put_pasid_state(pasid_state);

	put_device_state(dev_state);

out:
	pci_dev_put(pdev);
	return ret;
}

static struct notifier_block ppr_nb = {
	.notifier_call = ppr_notifier,
};

int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
			 struct task_struct *task)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	struct mm_struct *mm;
	u16 devid;
	int ret;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid     = device_id(pdev);
	dev_state = get_device_state(devid);

	if (dev_state == NULL)
		return -EINVAL;

	ret = -EINVAL;
	if (pasid >= dev_state->max_pasids)
		goto out;

	ret = -ENOMEM;
	pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
	if (pasid_state == NULL)
		goto out;

	atomic_set(&pasid_state->count, 1);
	init_waitqueue_head(&pasid_state->wq);
	spin_lock_init(&pasid_state->lock);

	mm                        = get_task_mm(task);
	pasid_state->mm           = mm;
	pasid_state->device_state = dev_state;
	pasid_state->pasid        = pasid;
	pasid_state->invalid      = true; /* Mark as valid only if we are
					     done with setting up the pasid */
	pasid_state->mn.ops       = &iommu_mn;

	if (pasid_state->mm == NULL)
		goto out_free;

	mmu_notifier_register(&pasid_state->mn, mm);

	ret = set_pasid_state(dev_state, pasid_state, pasid);
	if (ret)
		goto out_unregister;

	ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
					__pa(pasid_state->mm->pgd));
	if (ret)
		goto out_clear_state;

	/* Now we are ready to handle faults */
	pasid_state->invalid = false;

	/*
	 * Drop the reference to the mm_struct here. We rely on the
	 * mmu_notifier release call-back to inform us when the mm
	 * is going away.
	 */
	mmput(mm);

	return 0;

out_clear_state:
	clear_pasid_state(dev_state, pasid);

out_unregister:
	mmu_notifier_unregister(&pasid_state->mn, mm);
	mmput(mm);

out_free:
	free_pasid_state(pasid_state);

out:
	put_device_state(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
	struct pasid_state *pasid_state;
	struct device_state *dev_state;
	u16 devid;

	might_sleep();

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);
	dev_state = get_device_state(devid);
	if (dev_state == NULL)
		return;

	if (pasid >= dev_state->max_pasids)
		goto out;

	pasid_state = get_pasid_state(dev_state, pasid);
	if (pasid_state == NULL)
		goto out;
	/*
	 * Drop reference taken here. We are safe because we still hold
	 * the reference taken in the amd_iommu_bind_pasid function.
	 */
	put_pasid_state(pasid_state);

	/* Clear the pasid state so that the pasid can be re-used */
	clear_pasid_state(dev_state, pasid_state->pasid);

	/*
	 * Call mmu_notifier_unregister to drop our reference
	 * to pasid_state->mm
	 */
	mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

	put_pasid_state_wait(pasid_state); /* Reference taken in
					      amd_iommu_bind_pasid */
out:
	/* Drop reference taken in this function */
	put_device_state(dev_state);

	/* Drop reference taken in amd_iommu_bind_pasid */
	put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
	struct device_state *dev_state;
	struct iommu_group *group;
	unsigned long flags;
	int ret, tmp;
	u16 devid;

	might_sleep();

	/*
	 * When memory encryption is active the device is likely not in a
	 * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
	 */
	if (mem_encrypt_active())
		return -ENODEV;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	if (pasids <= 0 || pasids > (PASID_MASK + 1))
		return -EINVAL;

	devid = device_id(pdev);

	dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
	if (dev_state == NULL)
		return -ENOMEM;

	spin_lock_init(&dev_state->lock);
	init_waitqueue_head(&dev_state->wq);
	dev_state->pdev  = pdev;
	dev_state->devid = devid;

	tmp = pasids;
	for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
		dev_state->pasid_levels += 1;

	atomic_set(&dev_state->count, 1);
	dev_state->max_pasids = pasids;

	ret = -ENOMEM;
	dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
	if (dev_state->states == NULL)
		goto out_free_dev_state;

	dev_state->domain = iommu_domain_alloc(&pci_bus_type);
	if (dev_state->domain == NULL)
		goto out_free_states;

	amd_iommu_domain_direct_map(dev_state->domain);

	ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
	if (ret)
		goto out_free_domain;

	group = iommu_group_get(&pdev->dev);
	if (!group) {
		ret = -EINVAL;
		goto out_free_domain;
	}

	ret = iommu_attach_group(dev_state->domain, group);
	if (ret != 0)
		goto out_drop_group;

	iommu_group_put(group);

	spin_lock_irqsave(&state_lock, flags);

	if (__get_device_state(devid) != NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		ret = -EBUSY;
		goto out_free_domain;
	}

	list_add_tail(&dev_state->list, &state_list);

	spin_unlock_irqrestore(&state_lock, flags);

	return 0;

out_drop_group:
	iommu_group_put(group);

out_free_domain:
	iommu_domain_free(dev_state->domain);

out_free_states:
	free_page((unsigned long)dev_state->states);

out_free_dev_state:
	kfree(dev_state);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;

	if (!amd_iommu_v2_supported())
		return;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	dev_state = __get_device_state(devid);
	if (dev_state == NULL) {
		spin_unlock_irqrestore(&state_lock, flags);
		return;
	}

	list_del(&dev_state->list);

	spin_unlock_irqrestore(&state_lock, flags);

	/* Get rid of any remaining pasid states */
	free_pasid_states(dev_state);

	put_device_state(dev_state);
	/*
	 * Wait until the last reference is dropped before freeing
	 * the device state.
	 */
	wait_event(dev_state->wq, !atomic_read(&dev_state->count));
	free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
				 amd_iommu_invalid_ppr_cb cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ppr_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
				    amd_iommu_invalidate_ctx cb)
{
	struct device_state *dev_state;
	unsigned long flags;
	u16 devid;
	int ret;

	if (!amd_iommu_v2_supported())
		return -ENODEV;

	devid = device_id(pdev);

	spin_lock_irqsave(&state_lock, flags);

	ret = -EINVAL;
	dev_state = __get_device_state(devid);
	if (dev_state == NULL)
		goto out_unlock;

	dev_state->inv_ctx_cb = cb;

	ret = 0;

out_unlock:
	spin_unlock_irqrestore(&state_lock, flags);

	return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
	int ret;

	if (!amd_iommu_v2_supported()) {
		pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
		/*
		 * Load anyway to provide the symbols to other modules
		 * which may use AMD IOMMUv2 optionally.
		 */
		return 0;
	}

	spin_lock_init(&state_lock);

	ret = -ENOMEM;
	iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
	if (iommu_wq == NULL)
		goto out;

	amd_iommu_register_ppr_notifier(&ppr_nb);

	pr_info("AMD IOMMUv2 loaded and initialized\n");

	return 0;

out:
	return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
	struct device_state *dev_state;
	int i;

	if (!amd_iommu_v2_supported())
		return;

	amd_iommu_unregister_ppr_notifier(&ppr_nb);

	flush_workqueue(iommu_wq);

	/*
	 * The loop below might call flush_workqueue(), so call
	 * destroy_workqueue() after it
	 */
	for (i = 0; i < MAX_DEVICES; ++i) {
		dev_state = get_device_state(i);

		if (dev_state == NULL)
			continue;

		WARN_ON_ONCE(1);

		put_device_state(dev_state);
		amd_iommu_free_device(dev_state->pdev);
	}

	destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);
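/*
 * Illustrative usage sketch (compiled out, not part of the driver): a
 * PRI/PASID-capable PCI driver would typically pair the exported calls
 * as shown below, amd_iommu_init_device()/amd_iommu_free_device() at
 * probe/remove time and amd_iommu_bind_pasid()/amd_iommu_unbind_pasid()
 * per process.  The "example_" names and the PASID count of 16 are
 * assumptions made for illustration only; error handling is minimal.
 */
#if 0
static int example_enable_pri(struct pci_dev *pdev, u32 pasid)
{
	int ret;

	/* Allocate the device_state and IOMMUv2 domain for this device */
	ret = amd_iommu_init_device(pdev, 16);
	if (ret)
		return ret;

	/* Bind the current task's address space to the given PASID */
	ret = amd_iommu_bind_pasid(pdev, pasid, current);
	if (ret)
		amd_iommu_free_device(pdev);

	return ret;
}

static void example_disable_pri(struct pci_dev *pdev, u32 pasid)
{
	amd_iommu_unbind_pasid(pdev, pasid);
	amd_iommu_free_device(pdev);
}
#endif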