/* SPDX-License-Identifier: GPL-2.0 */
/*
 * NH - ε-almost-universal hash function, x86_64 AVX2 accelerated
 *
 * Copyright 2018 Google LLC
 *
 * Author: Eric Biggers <ebiggers@google.com>
 */

#include <linux/linkage.h>
#include <linux/cfi_types.h>

#define		PASS0_SUMS	%ymm0
#define		PASS1_SUMS	%ymm1
#define		PASS2_SUMS	%ymm2
#define		PASS3_SUMS	%ymm3
#define		K0		%ymm4
#define		K0_XMM		%xmm4
#define		K1		%ymm5
#define		K1_XMM		%xmm5
#define		K2		%ymm6
#define		K2_XMM		%xmm6
#define		K3		%ymm7
#define		K3_XMM		%xmm7
#define		T0		%ymm8
#define		T1		%ymm9
#define		T2		%ymm10
#define		T2_XMM		%xmm10
#define		T3		%ymm11
#define		T3_XMM		%xmm11
#define		T4		%ymm12
#define		T5		%ymm13
#define		T6		%ymm14
#define		T7		%ymm15
#define		KEY		%rdi
#define		MESSAGE		%rsi
#define		MESSAGE_LEN	%rdx
#define		HASH		%rcx

.macro _nh_2xstride	k0, k1, k2, k3

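	// Hash 2 message strides at once: on entry, T3 holds 32 bytes of
	// message (two 16-byte strides, one per 128-bit lane), and \k0-\k3
	// hold the key words for passes 0-3, each lane holding the key for
	// the corresponding stride.  The 64-bit products are accumulated into
	// PASS0_SUMS through PASS3_SUMS.
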
	// Add message words to key words
	vpaddd		\k0, T3, T0
	vpaddd		\k1, T3, T1
	vpaddd		\k2, T3, T2
	vpaddd		\k3, T3, T3

	// Multiply 32x32 => 64 and accumulate
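	//
	// Each 128-bit lane of T0-T3 now holds four dwords (x0, x1, x2, x3).
	// The 0x10 shuffle moves (x0, x1) into the even dword positions and
	// the 0x32 shuffle moves (x2, x3) there; vpmuludq then multiplies the
	// even dwords, giving the 64-bit products x0*x2 and x1*x3 in each
	// lane, which are added to the per-pass accumulators.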
	vpshufd		$0x10, T0, T4
	vpshufd		$0x32, T0, T0
	vpshufd		$0x10, T1, T5
	vpshufd		$0x32, T1, T1
	vpshufd		$0x10, T2, T6
	vpshufd		$0x32, T2, T2
	vpshufd		$0x10, T3, T7
	vpshufd		$0x32, T3, T3
	vpmuludq	T4, T0, T0
	vpmuludq	T5, T1, T1
	vpmuludq	T6, T2, T2
	vpmuludq	T7, T3, T3
	vpaddq		T0, PASS0_SUMS, PASS0_SUMS
	vpaddq		T1, PASS1_SUMS, PASS1_SUMS
	vpaddq		T2, PASS2_SUMS, PASS2_SUMS
	vpaddq		T3, PASS3_SUMS, PASS3_SUMS
.endm

/*
 * void nh_avx2(const u32 *key, const u8 *message, size_t message_len,
 *		__le64 hash[NH_NUM_PASSES])
 *
 * It's guaranteed that message_len % 16 == 0.
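 *
 * As a rough scalar sketch of what the vector code computes: each 16-byte
 * message stride is taken as little-endian u32 words m0..m3, with its key
 * starting at word offset j (j advances by 4 words per stride), and for each
 * pass p = 0..3:
 *
 *	hash[p] += (u64)(u32)(m0 + key[j + 4*p])     * (u32)(m2 + key[j + 4*p + 2]) +
 *		   (u64)(u32)(m1 + key[j + 4*p + 1]) * (u32)(m3 + key[j + 4*p + 3])
 *
 * where the additions are mod 2^32 and the accumulation is mod 2^64.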
 */
SYM_TYPED_FUNC_START(nh_avx2)

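	// Load the pass-0 and pass-1 keys for the first two strides.  Each
	// 32-byte load spans two consecutive 16-byte key offsets, so the low
	// and high 128-bit lanes hold the key for two adjacent strides.  Then
	// zero the accumulators.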
	vmovdqu		0x00(KEY), K0
	vmovdqu		0x10(KEY), K1
	add		$0x20, KEY
	vpxor		PASS0_SUMS, PASS0_SUMS, PASS0_SUMS
	vpxor		PASS1_SUMS, PASS1_SUMS, PASS1_SUMS
	vpxor		PASS2_SUMS, PASS2_SUMS, PASS2_SUMS
	vpxor		PASS3_SUMS, PASS3_SUMS, PASS3_SUMS

	sub		$0x40, MESSAGE_LEN
	jl		.Lloop4_done
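
	// Main loop: hash 64 bytes of message (4 strides) per iteration,
	// consuming 64 bytes of key, using two invocations of _nh_2xstride
	// with the two key register pairs swapping roles between them.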
.Lloop4:
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3

	vmovdqu		0x20(MESSAGE), T3
	vmovdqu		0x20(KEY), K0
	vmovdqu		0x30(KEY), K1
	_nh_2xstride	K2, K3, K0, K1

	add		$0x40, MESSAGE
	add		$0x40, KEY
	sub		$0x40, MESSAGE_LEN
	jge		.Lloop4

.Lloop4_done:
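	// MESSAGE_LEN went negative above; recover the number of message
	// bytes still remaining (0, 16, 32, or 48, since message_len is a
	// multiple of 16).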
	and		$0x3f, MESSAGE_LEN
	jz		.Ldone

	cmp		$0x20, MESSAGE_LEN
	jl		.Llast

	// 2 or 3 strides remain; do 2 more.
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3
	add		$0x20, MESSAGE
	add		$0x20, KEY
	sub		$0x20, MESSAGE_LEN
	jz		.Ldone
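	// One final stride remains.  Its pass-0 and pass-1 keys are currently
	// in K2/K3; move them into K0/K1, where the .Llast code expects them.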
	vmovdqa		K2, K0
	vmovdqa		K3, K1
.Llast:
	// Last stride.  Zero the high 128 bits of the message and keys so they
	// don't affect the result when processing them like 2 strides.
	vmovdqu		(MESSAGE), T3_XMM
	vmovdqa		K0_XMM, K0_XMM
	vmovdqa		K1_XMM, K1_XMM
	vmovdqu		0x00(KEY), K2_XMM
	vmovdqu		0x10(KEY), K3_XMM
	_nh_2xstride	K0, K1, K2, K3

.Ldone:
	// Sum the accumulators for each pass, then store the sums to 'hash'

	// PASS0_SUMS is (0A 0B 0C 0D)
	// PASS1_SUMS is (1A 1B 1C 1D)
	// PASS2_SUMS is (2A 2B 2C 2D)
	// PASS3_SUMS is (3A 3B 3C 3D)
	// We need the horizontal sums:
	//     (0A + 0B + 0C + 0D,
	//	1A + 1B + 1C + 1D,
	//	2A + 2B + 2C + 2D,
	//	3A + 3B + 3C + 3D)
	//

	vpunpcklqdq	PASS1_SUMS, PASS0_SUMS, T0	// T0 = (0A 1A 0C 1C)
	vpunpckhqdq	PASS1_SUMS, PASS0_SUMS, T1	// T1 = (0B 1B 0D 1D)
	vpunpcklqdq	PASS3_SUMS, PASS2_SUMS, T2	// T2 = (2A 3A 2C 3C)
	vpunpckhqdq	PASS3_SUMS, PASS2_SUMS, T3	// T3 = (2B 3B 2D 3D)

	vinserti128	$0x1, T2_XMM, T0, T4		// T4 = (0A 1A 2A 3A)
	vinserti128	$0x1, T3_XMM, T1, T5		// T5 = (0B 1B 2B 3B)
	vperm2i128	$0x31, T2, T0, T0		// T0 = (0C 1C 2C 3C)
	vperm2i128	$0x31, T3, T1, T1		// T1 = (0D 1D 2D 3D)

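	// Add the four transposed vectors; each 64-bit lane of T0 then holds
	// the complete sum for one pass.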
	vpaddq		T5, T4, T4
	vpaddq		T1, T0, T0
	vpaddq		T4, T0, T0
	vmovdqu		T0, (HASH)
	RET
SYM_FUNC_END(nh_avx2)