1/*
2 * Copyright (C) 2017 Linaro Ltd;  <ard.biesheuvel@linaro.org>
3 *
4 * This program is free software; you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License version 2 as
6 * published by the Free Software Foundation.
7 *
8 */
9
10#include <linux/libfdt_env.h>
11#include <libfdt.h>
12#include <linux/types.h>
13#include <generated/compile.h>
14#include <generated/utsrelease.h>
15#include <linux/pgtable.h>
16
17#include CONFIG_UNCOMPRESS_INCLUDE
18
/*
 * Bookkeeping for the KASLR slot search: the physical window that candidate
 * kernel placements are drawn from, and the ranges that must not be
 * overwritten by the decompressed image.
 */
struct regions {
	u32 pa_start;		/* first candidate physical address (128 MB aligned by the caller) */
	u32 pa_end;		/* candidates are taken while pa < pa_end */
	u32 image_size;		/* bytes each candidate must have free from pa onwards */
	u32 zimage_start;	/* the running zImage itself */
	u32 zimage_size;
	u32 dtb_start;		/* the flattened device tree */
	u32 dtb_size;
	u32 initrd_start;	/* initrd, 0/0 when none was found */
	u32 initrd_size;
	int reserved_mem;	/* offset of /reserved-memory, negative if absent */
	int reserved_mem_addr_cells;	/* #address-cells of /reserved-memory */
	int reserved_mem_size_cells;	/* #size-cells of /reserved-memory */
};

/*
 * CRC helper defined outside this file; kaslr_early_init() uses it to mix
 * the DTB contents and the build id into the random seed.
 */
extern u32 __crc16(u32 crc, u32 const input[], int byte_count);
35
/*
 * Minimal memparse(): parse an unsigned number in decimal, octal (leading
 * '0') or hex (leading "0x"/"0X"), optionally followed by a single K/M/G
 * size suffix. Parsing stops at ',', ' ', NUL or right after a suffix.
 *
 * On success returns the value and, if @retptr is non-NULL, stores a
 * pointer to the first unparsed character there. On a malformed number
 * (an unknown character, or a digit not valid in the detected base)
 * returns 0 and stores NULL in *@retptr.
 */
static u32 __memparse(const char *val, const char **retptr)
{
	unsigned int base = 10;
	u32 ret = 0;

	if (*val == '0') {
		val++;
		if (*val == 'x' || *val == 'X') {
			val++;
			base = 16;
		} else {
			base = 8;
		}
	}

	while (*val != ',' && *val != ' ' && *val != '\0') {
		char c = *val++;
		unsigned int digit;

		switch (c) {
		case '0' ... '9':
			digit = c - '0';
			break;
		case 'a' ... 'f':
			digit = c - 'a' + 10;
			break;
		case 'A' ... 'F':
			digit = c - 'A' + 10;
			break;
		case 'g':
		case 'G':
			ret <<= 10;
			/* fall through */
		case 'm':
		case 'M':
			ret <<= 10;
			/* fall through */
		case 'k':
		case 'K':
			ret <<= 10;
			/* a size suffix terminates the number */
			goto out;
		default:
			goto fail;
		}

		/* e.g. 'e' in a decimal number or '9' in an octal one */
		if (digit >= base)
			goto fail;
		ret = ret * base + digit;
	}
out:
	if (retptr)
		*retptr = val;
	return ret;

fail:
	if (retptr)
		*retptr = NULL;
	return 0;
}
86
/*
 * Overlap test for [s1, e1] and [s2, e2]. Both end points are compared
 * inclusively, so ranges that merely touch (e1 == s2) count as overlapping.
 */
static bool regions_intersect(u32 s1, u32 e1, u32 s2, u32 e2)
{
	/* disjoint only if one range ends strictly before the other starts */
	return !(e1 < s2 || e2 < s1);
}
91
92static bool intersects_reserved_region(const void *fdt, u32 start,
93				       u32 end, struct regions *regions)
94{
95	int subnode, len, i;
96	u64 base, size;
97
98	/* check for overlap with /memreserve/ entries */
99	for (i = 0; i < fdt_num_mem_rsv(fdt); i++) {
100		if (fdt_get_mem_rsv(fdt, i, &base, &size) < 0)
101			continue;
102		if (regions_intersect(start, end, base, base + size))
103			return true;
104	}
105
106	if (regions->reserved_mem < 0)
107		return false;
108
109	/* check for overlap with static reservations in /reserved-memory */
110	for (subnode = fdt_first_subnode(fdt, regions->reserved_mem);
111	     subnode >= 0;
112	     subnode = fdt_next_subnode(fdt, subnode)) {
113		const fdt32_t *reg;
114
115		len = 0;
116		reg = fdt_getprop(fdt, subnode, "reg", &len);
117		while (len >= (regions->reserved_mem_addr_cells +
118			       regions->reserved_mem_size_cells)) {
119
120			base = fdt32_to_cpu(reg[0]);
121			if (regions->reserved_mem_addr_cells == 2)
122				base = (base << 32) | fdt32_to_cpu(reg[1]);
123
124			reg += regions->reserved_mem_addr_cells;
125			len -= 4 * regions->reserved_mem_addr_cells;
126
127			size = fdt32_to_cpu(reg[0]);
128			if (regions->reserved_mem_size_cells == 2)
129				size = (size << 32) | fdt32_to_cpu(reg[1]);
130
131			reg += regions->reserved_mem_size_cells;
132			len -= 4 * regions->reserved_mem_size_cells;
133
134			if (base >= regions->pa_end)
135				continue;
136
137			if (regions_intersect(start, end, base,
138					      min(base + size, (u64)U32_MAX)))
139				return true;
140		}
141	}
142	return false;
143}
144
145static bool intersects_occupied_region(const void *fdt, u32 start,
146				       u32 end, struct regions *regions)
147{
148	if (regions_intersect(start, end, regions->zimage_start,
149			      regions->zimage_start + regions->zimage_size))
150		return true;
151
152	if (regions_intersect(start, end, regions->initrd_start,
153			      regions->initrd_start + regions->initrd_size))
154		return true;
155
156	if (regions_intersect(start, end, regions->dtb_start,
157			      regions->dtb_start + regions->dtb_size))
158		return true;
159
160	return intersects_reserved_region(fdt, start, end, regions);
161}
162
163static u32 count_suitable_regions(const void *fdt, struct regions *regions,
164				  u32 *bitmap)
165{
166	u32 pa, i = 0, ret = 0;
167
168	for (pa = regions->pa_start; pa < regions->pa_end; pa += SZ_2M, i++) {
169		if (!intersects_occupied_region(fdt, pa,
170						pa + regions->image_size,
171						regions)) {
172			ret++;
173		} else {
174			/* set 'occupied' bit */
175			bitmap[i >> 5] |= BIT(i & 0x1f);
176		}
177	}
178	return ret;
179}
180
181/* The caller ensures that num is within the range of regions.*/
/*
 * Return the index of the (num + 1)-th clear bit in @bitmap, which is
 * @size 32-bit words long. Set bits mark occupied slots. If fewer than
 * num + 1 bits are clear, the total bit count is returned; the caller is
 * expected to pass a num below the number of free slots.
 */
static u32 get_region_number(u32 num, u32 *bitmap, u32 size)
{
	const u32 nbits = size * sizeof(u32) * BITS_PER_BYTE;
	u32 bit;

	for (bit = 0; bit < nbits; bit++) {
		if (!(bitmap[bit / 32] & BIT(bit % 32))) {
			/* a free slot: is it the one we were asked for? */
			if (num == 0)
				return bit;
			num--;
		}
	}
	return nbits;
}
195
196static void get_cell_sizes(const void *fdt, int node, int *addr_cells,
197			   int *size_cells)
198{
199	const int *prop;
200	int len;
201
202	/*
203	 * Retrieve the #address-cells and #size-cells properties
204	 * from the 'node', or use the default if not provided.
205	 */
206	*addr_cells = *size_cells = 1;
207
208	prop = fdt_getprop(fdt, node, "#address-cells", &len);
209	if (len == 4)
210		*addr_cells = fdt32_to_cpu(*prop);
211	prop = fdt_getprop(fdt, node, "#size-cells", &len);
212	if (len == 4)
213		*size_cells = fdt32_to_cpu(*prop);
214}
215
216/*
217 * Original method only consider the first memory node in dtb,
218 * but there may be more than one memory nodes, we only consider
219 * the memory node zImage exists.
220 */
221static u32 get_memory_end(const void *fdt, u32 zimage_start)
222{
223	int mem_node, address_cells, size_cells, len;
224	const fdt32_t *reg;
225
226	/* Look for a node called "memory" at the lowest level of the tree */
227	mem_node = fdt_path_offset(fdt, "/memory");
228	if (mem_node <= 0)
229		return 0;
230
231	get_cell_sizes(fdt, 0, &address_cells, &size_cells);
232
233	while(mem_node >= 0) {
234		/*
235		 * Now find the 'reg' property of the /memory node, and iterate over
236		 * the base/size pairs.
237		 */
238		len = 0;
239		reg = fdt_getprop(fdt, mem_node, "reg", &len);
240		while (len >= 4 * (address_cells + size_cells)) {
241			u64 base, size;
242			base = fdt32_to_cpu(reg[0]);
243			if (address_cells == 2)
244				base = (base << 32) | fdt32_to_cpu(reg[1]);
245
246			reg += address_cells;
247			len -= 4 * address_cells;
248
249			size = fdt32_to_cpu(reg[0]);
250			if (size_cells == 2)
251				size = (size << 32) | fdt32_to_cpu(reg[1]);
252
253			reg += size_cells;
254			len -= 4 * size_cells;
255
256			/* Get the base and size of the zimage memory node */
257			if (zimage_start >= base && zimage_start < base + size)
258				return base + size;
259		}
260		/* If current memory node is not the one zImage exists, then traverse next memory node. */
261		mem_node = fdt_node_offset_by_prop_value(fdt, mem_node, "device_type", "memory", sizeof("memory"));
262	}
263
264	return 0;
265}
266
/*
 * Return the first position in NUL-terminated @s1 at which the first @l2
 * bytes of @s2 occur, or NULL. An @l2 of 0 matches at @s1 itself.
 */
static char *__strstr(const char *s1, const char *s2, int l2)
{
	int remaining = strlen(s1);
	const char *pos;

	/* only positions with at least l2 bytes left can hold a match */
	for (pos = s1; remaining >= l2; pos++, remaining--) {
		if (memcmp(pos, s2, l2) == 0)
			return (char *)pos;
	}
	return NULL;
}
280
281static const char *get_cmdline_param(const char *cmdline, const char *param,
282				     int param_size)
283{
284	static const char default_cmdline[] = CONFIG_CMDLINE;
285	const char *p;
286
287	if (!IS_ENABLED(CONFIG_CMDLINE_FORCE) && cmdline != NULL) {
288		p = __strstr(cmdline, param, param_size);
289		if (p == cmdline ||
290		    (p > cmdline && *(p - 1) == ' '))
291			return p;
292	}
293
294	if (IS_ENABLED(CONFIG_CMDLINE_FORCE)  ||
295	    IS_ENABLED(CONFIG_CMDLINE_EXTEND)) {
296		p = __strstr(default_cmdline, param, param_size);
297		if (p == default_cmdline ||
298		    (p > default_cmdline && *(p - 1) == ' '))
299			return p;
300	}
301	return NULL;
302}
303
304static void __puthex32(const char *name, u32 val)
305{
306	int i;
307
308	while (*name)
309		putc(*name++);
310	putc(':');
311	for (i = 28; i >= 0; i -= 4) {
312		char c = (val >> i) & 0xf;
313
314		if (c < 10)
315			putc(c + '0');
316		else
317			putc(c + 'a' - 10);
318	}
319	putc('\r');
320	putc('\n');
321}
322#define puthex32(val)	__puthex32(#val, (val))
323
324u32 kaslr_early_init(u32 *kaslr_offset, u32 image_base, u32 image_size,
325		     u32 seed, u32 zimage_start, const void *fdt,
326		     u32 zimage_end)
327{
328	static const char __aligned(4) build_id[] = UTS_VERSION UTS_RELEASE;
329	u32 bitmap[(VMALLOC_END - PAGE_OFFSET) / SZ_2M / 32] = {};
330	struct regions regions;
331	const char *command_line;
332	const char *p;
333	int chosen, len;
334	u32 lowmem_top, count, num, mem_fdt;
335
336	if (fdt_check_header(fdt))
337		return 0;
338
339	chosen = fdt_path_offset(fdt, "/chosen");
340	if (chosen < 0)
341		return 0;
342
343	command_line = fdt_getprop(fdt, chosen, "bootargs", &len);
344
345	/* check the command line for the presence of 'nokaslr' */
346	p = get_cmdline_param(command_line, "nokaslr", sizeof("nokaslr") - 1);
347	if (p != NULL)
348		return 0;
349
350	/* check the command line for the presence of 'vmalloc=' */
351	p = get_cmdline_param(command_line, "vmalloc=", sizeof("vmalloc=") - 1);
352	if (p != NULL)
353		lowmem_top = VMALLOC_END - __memparse(p + 8, NULL) -
354			     VMALLOC_OFFSET;
355	else
356		lowmem_top = VMALLOC_DEFAULT_BASE;
357
358	regions.image_size = image_base % SZ_128M + round_up(image_size, SZ_2M);
359	regions.pa_start = round_down(image_base, SZ_128M);
360	regions.pa_end = lowmem_top - PAGE_OFFSET + regions.pa_start;
361	regions.zimage_start = zimage_start;
362	regions.zimage_size = zimage_end - zimage_start;
363	regions.dtb_start = (u32)fdt;
364	regions.dtb_size = fdt_totalsize(fdt);
365
366	/*
367	 * Stir up the seed a bit by taking the CRC of the DTB:
368	 * hopefully there's a /chosen/kaslr-seed in there.
369	 */
370	seed = __crc16(seed, fdt, regions.dtb_size);
371
372	/* stir a bit more using data that changes between builds */
373	seed = __crc16(seed, (u32 *)build_id, sizeof(build_id));
374
375	/* check for initrd on the command line */
376	regions.initrd_start = regions.initrd_size = 0;
377	p = get_cmdline_param(command_line, "initrd=", sizeof("initrd=") - 1);
378	if (p != NULL) {
379		regions.initrd_start = __memparse(p + 7, &p);
380		if (*p++ == ',')
381			regions.initrd_size = __memparse(p, NULL);
382		if (regions.initrd_size == 0)
383			regions.initrd_start = 0;
384	}
385
386	/* ... or in /chosen */
387	if (regions.initrd_size == 0) {
388		const fdt32_t *prop;
389		u64 start = 0, end = 0;
390
391		prop = fdt_getprop(fdt, chosen, "linux,initrd-start", &len);
392		if (prop) {
393			start = fdt32_to_cpu(prop[0]);
394			if (len == 8)
395				start = (start << 32) | fdt32_to_cpu(prop[1]);
396		}
397
398		prop = fdt_getprop(fdt, chosen, "linux,initrd-end", &len);
399		if (prop) {
400			end = fdt32_to_cpu(prop[0]);
401			if (len == 8)
402				end = (end << 32) | fdt32_to_cpu(prop[1]);
403		}
404		if (start != 0 && end != 0 && start < U32_MAX) {
405			regions.initrd_start = start;
406			regions.initrd_size = max_t(u64, end, U32_MAX) - start;
407		}
408	}
409
410	/*
411	 * check the memory nodes for the size of the lowmem region, traverse
412	 * all memory nodes to find the node in which zImage exists, we
413	 * randomize kernel only in the one zImage exists.
414	 */
415	mem_fdt = get_memory_end(fdt, zimage_start);
416	if (mem_fdt)
417		regions.pa_end = min(regions.pa_end, mem_fdt) - regions.image_size;
418	else
419		regions.pa_end = regions.pa_end - regions.image_size;
420
421	puthex32(regions.image_size);
422	puthex32(regions.pa_start);
423	puthex32(regions.pa_end);
424	puthex32(regions.zimage_start);
425	puthex32(regions.zimage_size);
426	puthex32(regions.dtb_start);
427	puthex32(regions.dtb_size);
428	puthex32(regions.initrd_start);
429	puthex32(regions.initrd_size);
430
431	/* check for a reserved-memory node and record its cell sizes */
432	regions.reserved_mem = fdt_path_offset(fdt, "/reserved-memory");
433	if (regions.reserved_mem >= 0)
434		get_cell_sizes(fdt, regions.reserved_mem,
435			       &regions.reserved_mem_addr_cells,
436			       &regions.reserved_mem_size_cells);
437
438	/*
439	 * Iterate over the physical memory range covered by the lowmem region
440	 * in 2 MB increments, and count each offset at which we don't overlap
441	 * with any of the reserved regions for the zImage itself, the DTB,
442	 * the initrd and any regions described as reserved in the device tree.
443	 * If the region does overlap, set the respective bit in the bitmap[].
444	 * Using this random value, we go over the bitmap and count zero bits
445	 * until we counted enough iterations, and return the offset we ended
446	 * up at.
447	 */
448	count = count_suitable_regions(fdt, &regions, bitmap);
449	puthex32(count);
450
451	num = ((u16)seed * count) >> 16;
452	puthex32(num);
453
454	*kaslr_offset = get_region_number(num, bitmap, sizeof(bitmap) / sizeof(u32)) * SZ_2M;
455	puthex32(*kaslr_offset);
456
457	return *kaslr_offset;
458}
459