// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 */

#include <asm/page.h>
#include <linux/mm.h>
#include <linux/mm_types.h>
#include <linux/radix-tree.h>
#include <linux/rmap.h>
#include <linux/slab.h>
#include <linux/oom.h> /* find_lock_task_mm */

#include <linux/mm_purgeable.h>

struct uxpte_t {
	atomic64_t val;
};

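/*
 * Layout of one user-extended PTE (uxpte), stored as a 64-bit atomic value:
 *   bit 0      - present bit, set while the backing purgeable page is present
 *   bits 63:1  - pin refcount taken by userspace
 * UXPTE_UNDER_RECLAIM is a negative refcount written by lock_uxpte() to mark
 * an entry whose page is currently being reclaimed.
 */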
#define UXPTE_SIZE_SHIFT 3
#define UXPTE_SIZE (1 << UXPTE_SIZE_SHIFT)

#define UXPTE_PER_PAGE_SHIFT (PAGE_SHIFT - UXPTE_SIZE_SHIFT)
#define UXPTE_PER_PAGE (1 << UXPTE_PER_PAGE_SHIFT)

#define UXPTE_PRESENT_BIT 1
#define UXPTE_PRESENT_MASK ((1 << UXPTE_PRESENT_BIT) - 1)
#define UXPTE_REFCNT_ONE (1 << UXPTE_PRESENT_BIT)
#define UXPTE_UNDER_RECLAIM (-UXPTE_REFCNT_ONE)

#define vpn(vaddr) ((vaddr) >> PAGE_SHIFT)
#define uxpte_pn(vaddr) (vpn(vaddr) >> UXPTE_PER_PAGE_SHIFT)
#define uxpte_off(vaddr) (vpn(vaddr) & (UXPTE_PER_PAGE - 1))
#define uxpn2addr(uxpn) ((uxpn) << (UXPTE_PER_PAGE_SHIFT + PAGE_SHIFT))
#define uxpte_refcnt(uxpte) ((uxpte) >> UXPTE_PRESENT_BIT)
#define uxpte_present(uxpte) ((uxpte) & UXPTE_PRESENT_MASK)

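/*
 * Example of the address math above, assuming PAGE_SHIFT == 12 (4 KiB pages):
 * each uxpte is 8 bytes, so one uxpte page holds UXPTE_PER_PAGE = 512 entries
 * and covers 512 * 4 KiB = 2 MiB of virtual address space. For a user address
 * vaddr, uxpte_pn(vaddr) = vaddr >> 21 selects the uxpte page in the radix
 * tree and uxpte_off(vaddr) = (vaddr >> 12) & 511 selects the entry in it.
 */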
static inline long uxpte_read(struct uxpte_t *uxpte)
{
	return atomic64_read(&uxpte->val);
}

static inline void uxpte_set(struct uxpte_t *uxpte, long val)
{
	atomic64_set(&uxpte->val, val);
}

static inline bool uxpte_cas(struct uxpte_t *uxpte, long old, long new)
{
	return atomic64_cmpxchg(&uxpte->val, old, new) == old;
}

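/*
 * mm_init_uxpgd - reset the per-mm uxpte directory state.
 * The radix tree root itself is allocated lazily in lookup_uxpte_page().
 */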
void mm_init_uxpgd(struct mm_struct *mm)
{
	mm->uxpgd = NULL;
	spin_lock_init(&mm->uxpgd_lock);
}

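/*
 * mm_clear_uxpgd - drop every uxpte page and free the radix tree root.
 * kfree(NULL) is safe, so this also handles an mm that never allocated
 * a uxpgd.
 */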
void mm_clear_uxpgd(struct mm_struct *mm)
{
	struct page *page = NULL;
	void **slot = NULL;
	struct radix_tree_iter iter;

	spin_lock(&mm->uxpgd_lock);
	if (!mm->uxpgd)
		goto out;
	radix_tree_for_each_slot(slot, mm->uxpgd, &iter, 0) {
		page = radix_tree_delete(mm->uxpgd, iter.index);
		put_page(page);
	}
out:
	kfree(mm->uxpgd);
	mm->uxpgd = NULL;
	spin_unlock(&mm->uxpgd_lock);
}

/*
 * Look up the uxpte page that covers @addr, allocating it on demand when
 * @alloc is true. Caller must hold mm->uxpgd_lock; the lock is dropped and
 * retaken around the GFP_KERNEL allocations below.
 */
static struct page *lookup_uxpte_page(struct vm_area_struct *vma,
	unsigned long addr, bool alloc)
{
	struct radix_tree_root *uxpgd = NULL;
	struct page *page = NULL;
	struct folio *new_folio = NULL;
	struct page *new_page = NULL;
	struct mm_struct *mm = vma->vm_mm;
	unsigned long uxpn = uxpte_pn(addr);

	if (mm->uxpgd)
		goto lookup;
	if (!alloc)
		goto out;
	spin_unlock(&mm->uxpgd_lock);
	uxpgd = kzalloc(sizeof(struct radix_tree_root), GFP_KERNEL);
	if (!uxpgd) {
		pr_err("uxpgd alloc failed.\n");
		spin_lock(&mm->uxpgd_lock);
		goto out;
	}
	INIT_RADIX_TREE(uxpgd, GFP_KERNEL);
	spin_lock(&mm->uxpgd_lock);
	if (mm->uxpgd)
		kfree(uxpgd);
	else
		mm->uxpgd = uxpgd;
lookup:
	page = radix_tree_lookup(mm->uxpgd, uxpn);
	if (page)
		goto out;
	if (!alloc)
		goto out;
	spin_unlock(&mm->uxpgd_lock);
	new_folio = vma_alloc_zeroed_movable_folio(vma, addr);
	if (!new_folio) {
		pr_err("uxpte page alloc failed.\n");
		spin_lock(&mm->uxpgd_lock);
		goto out;
	}
	new_page = &new_folio->page;
	if (radix_tree_preload(GFP_KERNEL)) {
		put_page(new_page);
		pr_err("radix preload failed.\n");
		spin_lock(&mm->uxpgd_lock);
		goto out;
	}
	spin_lock(&mm->uxpgd_lock);
	page = radix_tree_lookup(mm->uxpgd, uxpn);
	if (page) {
		put_page(new_page);
	} else {
		page = new_page;
		radix_tree_insert(mm->uxpgd, uxpn, page);
	}
	radix_tree_preload_end();
out:
	return page;
}

/*
 * Return the uxpte entry for @addr, or NULL if its uxpte page does not exist
 * (and @alloc is false) or could not be allocated.
 * Caller must hold mm->uxpgd_lock.
 */
static struct uxpte_t *lookup_uxpte(struct vm_area_struct *vma,
		unsigned long addr, bool alloc)
{
	struct uxpte_t *uxpte = NULL;
	struct page *page = NULL;

	page = lookup_uxpte_page(vma, addr, alloc);
	if (!page)
		return NULL;
	uxpte = page_to_virt(page);

	return uxpte + uxpte_off(addr);
}

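/*
 * lock_uxpte - try to take the uxpte entry for @addr for reclaim.
 * Succeeds only when the entry has no userspace pins and is not already
 * under reclaim, in which case it is set to UXPTE_UNDER_RECLAIM.
 * Returns true on success.
 *
 * A hypothetical caller on the reclaim path might pair it with
 * unlock_uxpte(), e.g.:
 *
 *	if (lock_uxpte(vma, addr)) {
 *		... discard the purgeable page ...
 *		unlock_uxpte(vma, addr);
 *	}
 */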
bool lock_uxpte(struct vm_area_struct *vma, unsigned long addr)
{
	struct uxpte_t *uxpte = NULL;
	long val = 0;

	spin_lock(&vma->vm_mm->uxpgd_lock);
	uxpte = lookup_uxpte(vma, addr, true);
	if (!uxpte)
		goto unlock;
retry:
	val = uxpte_read(uxpte);
	if (uxpte_refcnt(val))
		goto unlock;
	if (!uxpte_cas(uxpte, val, UXPTE_UNDER_RECLAIM))
		goto retry;
	val = UXPTE_UNDER_RECLAIM;
unlock:
	spin_unlock(&vma->vm_mm->uxpgd_lock);

	return val == UXPTE_UNDER_RECLAIM;
}

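/*
 * unlock_uxpte - clear the uxpte entry for @addr, dropping the
 * UXPTE_UNDER_RECLAIM marker written by lock_uxpte().
 */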
void unlock_uxpte(struct vm_area_struct *vma, unsigned long addr)
{
	struct uxpte_t *uxpte = NULL;

	spin_lock(&vma->vm_mm->uxpgd_lock);
	uxpte = lookup_uxpte(vma, addr, false);
	if (!uxpte)
		goto unlock;
	uxpte_set(uxpte, 0);
unlock:
	spin_unlock(&vma->vm_mm->uxpgd_lock);
}

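/*
 * uxpte_set_present - mark the purgeable page at @addr as present.
 * Returns true if the present bit is set on return (either it was already
 * set or this call set it); false if the uxpte could not be allocated.
 */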
bool uxpte_set_present(struct vm_area_struct *vma, unsigned long addr)
{
	struct uxpte_t *uxpte = NULL;
	long val = 0;

	spin_lock(&vma->vm_mm->uxpgd_lock);
	uxpte = lookup_uxpte(vma, addr, true);
	if (!uxpte)
		goto unlock;
retry:
	val = uxpte_read(uxpte);
	if (uxpte_present(val))
		goto unlock;
	if (!uxpte_cas(uxpte, val, val + 1))
		goto retry;
	val++;
unlock:
	spin_unlock(&vma->vm_mm->uxpgd_lock);

	return uxpte_present(val);
}

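/*
 * uxpte_clear_present - clear the present bit for the page at @addr.
 * No-op if the uxpte page was never allocated or the bit is already clear.
 */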
void uxpte_clear_present(struct vm_area_struct *vma, unsigned long addr)
{
	struct uxpte_t *uxpte = NULL;
	long val = 0;

	spin_lock(&vma->vm_mm->uxpgd_lock);
	uxpte = lookup_uxpte(vma, addr, false);
	if (!uxpte)
		goto unlock;
retry:
	val = uxpte_read(uxpte);
	if (!uxpte_present(val))
		goto unlock;
	if (!uxpte_cas(uxpte, val, val - 1))
		goto retry;
unlock:
	spin_unlock(&vma->vm_mm->uxpgd_lock);
}

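/*
 * do_uxpte_page_fault - build the PTE for a fault on a uxpte mapping.
 * The faulting address is translated into the uxpn space via vm_pgoff,
 * the backing uxpte page is looked up (allocating it if needed) and a PTE
 * for it is prepared in @entry, writable and dirty when the VMA permits
 * writes. Returns VM_FAULT_OOM on allocation failure.
 */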
vm_fault_t do_uxpte_page_fault(struct vm_fault *vmf, pte_t *entry)
{
	struct vm_area_struct *vma = vmf->vma;
	unsigned long vma_uxpn = vma->vm_pgoff;
	unsigned long off_uxpn = vpn(vmf->address - vma->vm_start);
	unsigned long addr = uxpn2addr(vma_uxpn + off_uxpn);
	struct page *page = NULL;

	if (unlikely(anon_vma_prepare(vma)))
		return VM_FAULT_OOM;

	spin_lock(&vma->vm_mm->uxpgd_lock);
	page = lookup_uxpte_page(vma, addr, true);
	spin_unlock(&vma->vm_mm->uxpgd_lock);

	if (!page)
		return VM_FAULT_OOM;

	*entry = mk_pte(page, vma->vm_page_prot);
	*entry = pte_sw_mkyoung(*entry);
	if (vma->vm_flags & VM_WRITE)
		*entry = pte_mkwrite(pte_mkdirty(*entry), vma);
	return 0;
}

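/*
 * Walk every uxpte page of @mm and count present entries (total) and entries
 * that are additionally pinned by userspace. Takes mm->uxpgd_lock internally.
 */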
static void __mm_purg_pages_info(struct mm_struct *mm, unsigned long *total_purg_pages,
	unsigned long *pined_purg_pages)
{
	struct page *page = NULL;
	void **slot = NULL;
	struct radix_tree_iter iter;
	struct uxpte_t *uxpte = NULL;
	long pte_entry = 0;
	int index = 0;
	unsigned long nr_total = 0, nr_pined = 0;

	spin_lock(&mm->uxpgd_lock);
	if (!mm->uxpgd)
		goto out;
	radix_tree_for_each_slot(slot, mm->uxpgd, &iter, 0) {
		page = radix_tree_deref_slot(slot);
		if (unlikely(!page))
			continue;
		uxpte = page_to_virt(page);
		for (index = 0; index < UXPTE_PER_PAGE; index++) {
			pte_entry = uxpte_read(&(uxpte[index]));
			if (uxpte_present(pte_entry) == 0) /* not present */
				continue;
			nr_total++;
			if (uxpte_refcnt(pte_entry) > 0) /* pinned by user */
				nr_pined++;
		}
	}
out:
	spin_unlock(&mm->uxpgd_lock);

	if (total_purg_pages)
		*total_purg_pages = nr_total;

	if (pined_purg_pages)
		*pined_purg_pages = nr_pined;
}

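/*
 * mm_purg_pages_info - report purgeable page counts for a single mm.
 * Either output pointer may be NULL when the caller does not need that count.
 */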
void mm_purg_pages_info(struct mm_struct *mm, unsigned long *total_purg_pages,
	unsigned long *pined_purg_pages)
{
	if (unlikely(!mm))
		return;

	if (!total_purg_pages && !pined_purg_pages)
		return;

	__mm_purg_pages_info(mm, total_purg_pages, pined_purg_pages);
}

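/*
 * purg_pages_info - sum purgeable page counts over all processes.
 * Iterates the process list under RCU, using find_lock_task_mm() to skip
 * kthreads and tasks whose threads have all dropped their mm; per-task
 * and total counts are also logged via pr_info().
 */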
void purg_pages_info(unsigned long *total_purg_pages, unsigned long *pined_purg_pages)
{
	struct task_struct *p = NULL;
	struct task_struct *tsk = NULL;
	unsigned long mm_nr_purge = 0, mm_nr_pined = 0;
	unsigned long nr_total = 0, nr_pined = 0;

	if (!total_purg_pages && !pined_purg_pages)
		return;

	if (total_purg_pages)
		*total_purg_pages = 0;

	if (pined_purg_pages)
		*pined_purg_pages = 0;

	rcu_read_lock();
	for_each_process(p) {
		tsk = find_lock_task_mm(p);
		if (!tsk) {
			/*
			 * It is a kthread or all of p's threads have already
			 * detached their mm's.
			 */
			continue;
		}
		__mm_purg_pages_info(tsk->mm, &mm_nr_purge, &mm_nr_pined);
		nr_total += mm_nr_purge;
		nr_pined += mm_nr_pined;
		task_unlock(tsk);

		if (mm_nr_purge > 0) {
			pr_info("purgemm: tsk: %s %lu pined in %lu pages\n", tsk->comm ?: "NULL",
				mm_nr_pined, mm_nr_purge);
		}
	}
	rcu_read_unlock();
	if (total_purg_pages)
		*total_purg_pages = nr_total;

	if (pined_purg_pages)
		*pined_purg_pages = nr_pined;
	pr_info("purgemm: Sum: %lu pined in %lu pages\n", nr_pined, nr_total);
}