1 /*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include <stdlib.h>
16 #include "hvb_crypto.h"
17 #include "hvb_util.h"
18 #include "hvb_rsa.h"
19
/* Status codes shared by the internal long-integer helpers; 0 means success. */
enum {
    RESULT_OK = 0,
    ERROR_MEMORY_EMPTY = 1,     /* allocation returned NULL */
    ERROR_MEMORY_NO_ENOUGH = 2, /* destination too small, or a bounded mem* helper failed */
    ERROR_WORDLEN_ZERO = 3,     /* a zero-word allocation was requested */
};
26
#ifndef __WORDSIZE
#if defined(__LP64__)
#define __WORDSIZE 64
#elif defined(__LP32__)
#define __WORDSIZE 32
#else
#error "not support word size "
#endif
#endif

/* One limb (machine word) of a long integer, in bytes and in bits. */
#define WORD_BYTE_SIZE (sizeof(unsigned long))
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All-ones limb. Written as ~0UL because (1UL << WORD_BIT_SIZE) - 1 shifts by
 * the full width of the type, which is undefined behaviour in C.
 */
#define WORD_BIT_MASK (~0UL)
#define byte2bit(byte) ((byte) << 3)
/* Half-limb width and mask, used by the portable word multiply. */
#define SWORD_BIT_SIZE (WORD_BIT_SIZE / 2)
#define SWORD_BIT_MASK ((1UL << SWORD_BIT_SIZE) - 1)
43
lin_clear(struct long_int_num *p_a)44 static void lin_clear(struct long_int_num *p_a)
45 {
46 (void)hvb_memset_s(p_a->data_mem, p_a->mem_size, 0, p_a->mem_size);
47 }
48
lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)49 static int lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)
50 {
51 if (p_src->valid_word_len * WORD_BYTE_SIZE > p_dst->mem_size) {
52 return ERROR_MEMORY_NO_ENOUGH;
53 }
54
55 if (hvb_memcpy_s(p_dst->p_uint, p_dst->mem_size, p_src->p_uint, p_src->valid_word_len * WORD_BYTE_SIZE) != 0) {
56 return ERROR_MEMORY_NO_ENOUGH;
57 }
58
59 p_dst->valid_word_len = p_src->valid_word_len;
60
61 return RESULT_OK;
62 }
63
lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)64 static int lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)
65 {
66 int i;
67
68 if (p_a->valid_word_len != p_b->valid_word_len) {
69 return p_a->valid_word_len - p_b->valid_word_len;
70 }
71
72 if (p_a->valid_word_len == 0) {
73 return 0;
74 }
75
76 for (i = p_a->valid_word_len - 1; i >= 0; --i) {
77 if (p_a->p_uint[i] != p_b->p_uint[i]) {
78 if (p_a->p_uint[i] > p_b->p_uint[i]) {
79 return 1;
80 }
81 return -1;
82 }
83 }
84 return 0;
85 }
86
lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)87 static int lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)
88 {
89 unsigned long *p_data = NULL;
90
91 if (word_len == 0) {
92 return ERROR_WORDLEN_ZERO;
93 }
94 p_data = hvb_malloc(word_len * WORD_BYTE_SIZE);
95 if (p_data == NULL) {
96 return ERROR_MEMORY_EMPTY;
97 }
98
99 if (hvb_memset_s(p_data, word_len * WORD_BYTE_SIZE, 0, word_len * WORD_BYTE_SIZE) != 0) {
100 hvb_free(p_data);
101 return ERROR_MEMORY_NO_ENOUGH;
102 }
103
104 p_long_int->data_mem = p_data;
105 p_long_int->mem_size = word_len * WORD_BYTE_SIZE;
106 p_long_int->p_uint = p_data;
107 p_long_int->valid_word_len = 0;
108
109 return RESULT_OK;
110 }
111
lin_create(uint32_t word_len)112 struct long_int_num *lin_create(uint32_t word_len)
113 {
114 struct long_int_num *p_res = NULL;
115
116 p_res = hvb_malloc(sizeof(struct long_int_num));
117 if (p_res == NULL) {
118 return NULL;
119 }
120
121 if (lin_calloc(p_res, word_len) > 0) {
122 hvb_free(p_res);
123 return NULL;
124 }
125 p_res->valid_word_len = 0;
126 return p_res;
127 }
128
lin_free(struct long_int_num *p_long_int)129 void lin_free(struct long_int_num *p_long_int)
130 {
131 if (!p_long_int) {
132 return;
133 }
134 if (p_long_int->p_uint != NULL) {
135 hvb_free(p_long_int->data_mem);
136 p_long_int->p_uint = NULL;
137 }
138 hvb_free(p_long_int);
139
140 return;
141 }
142
/*
 * Length of the byte string pd[0..size) once its leading zero bytes are skipped
 * (presumably a big-endian magnitude). Returns 0 for a NULL pointer or an
 * all-zero / empty input.
 */
uint32_t bn_get_valid_len(const uint8_t *pd, uint32_t size)
{
    uint32_t lead_zeros = 0;

    if (pd == NULL) {
        return 0;
    }

    while (lead_zeros < size && pd[lead_zeros] == 0) {
        ++lead_zeros;
    }

    return size - lead_zeros;
}
158
lin_update_valid_len(struct long_int_num *p_a)159 void lin_update_valid_len(struct long_int_num *p_a)
160 {
161 unsigned long *p_data = NULL;
162 uint32_t i;
163
164 if (!p_a) {
165 return;
166 }
167
168 p_data = p_a->p_uint + p_a->valid_word_len - 1;
169 for (i = 0; i < p_a->valid_word_len; ++i) {
170 if (*p_data != 0) {
171 break;
172 }
173 --p_data;
174 }
175 p_a->valid_word_len -= i;
176 }
177
/*
 * Full-width product of two words: (*res_hi : *res_low) = a * b.
 *
 * Implementations, in order of preference:
 *   - AArch64: low half by a plain multiply, high half with UMULH.
 *   - 64-bit unsigned long with a 128-bit integer type: one widening multiply.
 *     __SIZEOF_INT128__ is the predefined macro that signals that type. Testing
 *     defined(__uint128_t) does not work because __uint128_t is a type, not a macro.
 *   - 32-bit unsigned long: widening multiply through unsigned long long.
 *   - Otherwise: portable schoolbook multiply on half-words.
 * The high half is taken with casts/shifts of the wider type, never with a
 * full-width shift of unsigned long (which would be undefined behaviour).
 */
static void lin_mul_word(unsigned long a, unsigned long b, unsigned long *res_hi, unsigned long *res_low)
{
#if defined(__aarch64__)
    unsigned long hi = 0;
    *res_low = a * b;
    __asm__ volatile ("umulh %0, %1, %2" : "+r"(hi) : "r"(a), "r"(b) :);
    *res_hi = hi;
#elif defined(__SIZEOF_INT128__) && defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 8)
    __uint128_t prod = (__uint128_t)a * b;

    *res_hi = (unsigned long)(prod >> 64);
    *res_low = (unsigned long)prod;
#elif defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 4)
    unsigned long long prod = (unsigned long long)a * b;

    *res_hi = (unsigned long)(prod >> 32);
    *res_low = (unsigned long)prod;
#else
    unsigned long a_h, a_l;
    unsigned long b_h, b_l;
    unsigned long res_h, res_l;
    unsigned long c, t;

    a_h = a >> SWORD_BIT_SIZE;
    a_l = a & SWORD_BIT_MASK;
    b_h = b >> SWORD_BIT_SIZE;
    b_l = b & SWORD_BIT_MASK;

    res_h = a_h * b_h;
    res_l = a_l * b_l;

    /* add the first cross product, propagating the carry out of the low word */
    c = a_h * b_l;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;

    /* add the second cross product the same way */
    c = a_l * b_h;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;

    *res_hi = res_h;
    *res_low = res_l;
#endif
}
229
lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)230 static void lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)
231 {
232 uint32_t i;
233 unsigned long c;
234 unsigned long t;
235
236 c = 0;
237 for (i = 0; i < p_b->valid_word_len; ++i) {
238 t = p_a->p_uint[i] < c;
239 p_a->p_uint[i] = p_a->p_uint[i] - c;
240
241 c = (p_a->p_uint[i] < p_b->p_uint[i]) + t;
242 p_a->p_uint[i] = p_a->p_uint[i] - p_b->p_uint[i];
243 }
244 for (; i < p_a->valid_word_len && c; ++i) {
245 t = p_a->p_uint[i] < c;
246 p_a->p_uint[i] = p_a->p_uint[i] - c;
247 c = t;
248 }
249 lin_update_valid_len(p_a);
250 }
251
252 #define dword_add_word(a, b, r) \
253 do { \
254 r##_l = a##_l + (b); \
255 r##_h = a##_h + (r##_l < (b)); \
256 } while (0)
257
montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_res)258 static void montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n,
259 unsigned long n_n0_i, struct long_int_num *p_res)
260 {
261 unsigned long x_h, x_l;
262 unsigned long d0;
263 unsigned long y_h, y_l;
264 unsigned long t_h, t_l;
265 unsigned long *p_ad = p_a->p_uint;
266 unsigned long *p_nd = p_n->p_uint;
267 unsigned long *p_rd = p_res->p_uint;
268 uint32_t i;
269
270 while (p_a->valid_word_len > p_n->valid_word_len){
271 lin_sub(p_a, p_n);
272 }
273
274 lin_mul_word(p_a->p_uint[0], b, &x_h, &x_l);
275
276 dword_add_word(x, p_rd[0], x);
277
278 d0 = x_l * n_n0_i;
279
280 lin_mul_word(d0, p_nd[0], &y_h, &y_l);
281 dword_add_word(y, x_l, y);
282
283 for (i = 1; i < p_a->valid_word_len; ++i) {
284 lin_mul_word(p_ad[i], b, &t_h, &t_l);
285 dword_add_word(t, p_rd[i], t);
286 dword_add_word(t, x_h, x);
287
288 lin_mul_word(d0, p_nd[i], &t_h, &t_l);
289 dword_add_word(t, x_l, t);
290 dword_add_word(t, y_h, y);
291
292 p_rd[i - 1] = y_l;
293 }
294
295 p_rd[i - 1] = x_h + y_h;
296
297 p_res->valid_word_len = p_n->valid_word_len;
298 if (p_rd[i - 1] < x_h) {
299 lin_sub(p_res, p_n);
300 }
301 }
302
montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_res)303 static void montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n,
304 unsigned long n_n0_i, struct long_int_num *p_res)
305 {
306 uint32_t i;
307
308 lin_clear(p_res);
309
310 for (i = 0; i < p_b->valid_word_len; ++i) {
311 montgomery_mul_add(p_a, p_b->p_uint[i], p_n, n_n0_i, p_res);
312 }
313 }
314
montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_rr, uint32_t exp)315 struct long_int_num *montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i,
316 struct long_int_num *p_rr, uint32_t exp)
317 {
318 struct long_int_num *p_res = NULL;
319 struct long_int_num *p_mr = NULL;
320 struct long_int_num *p_square = NULL;
321 int i;
322 if ((exp & 1UL) == 0) {
323 goto fail_final;
324 }
325
326 p_mr = lin_create(p_n->valid_word_len);
327 if (p_mr == NULL) {
328 goto fail_final;
329 }
330
331 p_square = lin_create(p_n->valid_word_len);
332 if (p_square == NULL) {
333 goto fail_final;
334 }
335
336 p_res = lin_create(p_n->valid_word_len);
337 if (p_res == NULL) {
338 goto fail_final;
339 }
340
341 montgomery_mod_mul(p_m, p_rr, p_n, n_n0_i, p_mr);
342 i = byte2bit(sizeof(exp)) - 1;
343 for (; i >= 0; --i) {
344 if (exp & (1UL << i)) {
345 break;
346 }
347 }
348
349 lin_copy(p_mr, p_res);
350
351 for (--i; i > 0; --i) {
352 montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
353 if (exp & (1UL << i)) {
354 montgomery_mod_mul(p_mr, p_square, p_n, n_n0_i, p_res);
355 } else {
356 lin_copy(p_square, p_res);
357 }
358 }
359 montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
360 montgomery_mod_mul(p_m, p_square, p_n, n_n0_i, p_res);
361
362 if (lin_compare(p_res, p_n) >= 0) {
363 lin_sub(p_res, p_n);
364 }
365
366 fail_final:
367 lin_free(p_mr);
368 lin_free(p_square);
369
370 return p_res;
371 }
372
lin_get_bitlen(struct long_int_num *p_a)373 uint32_t lin_get_bitlen(struct long_int_num *p_a)
374 {
375 int i;
376 int bit_len;
377 unsigned long *p_data = NULL;
378 unsigned long value;
379
380 if (!p_a || p_a->valid_word_len == 0) {
381 return 0;
382 }
383 p_data = p_a->p_uint;
384 for (i = p_a->valid_word_len - 1; i >= 0; --i) {
385 if (p_data[i] != 0) {
386 break;
387 }
388 }
389
390 bit_len = (i + 1) * WORD_BIT_SIZE;
391
392 if (bit_len == 0) {
393 return 0;
394 }
395
396 for (value = p_data[i]; ((signed long)value) > 0; value = value << 1) {
397 --bit_len;
398 }
399
400 return bit_len;
401 }
402