1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
#include <stdint.h>
#include <stdlib.h>
#include "hvb_crypto.h"
#include "hvb_util.h"
#include "hvb_rsa.h"
19 
/*
 * Status codes of the internal long-integer helpers. RESULT_OK is zero and every
 * error code is positive (lin_create() tests for failure with "> 0").
 */
enum {
    RESULT_OK = 0,              /* operation succeeded */
    ERROR_MEMORY_EMPTY = 1,     /* hvb_malloc() returned NULL */
    ERROR_MEMORY_NO_ENOUGH = 2, /* destination too small, or memset/memcpy helper failed */
    ERROR_WORDLEN_ZERO = 3,     /* a zero-word allocation was requested */
};
26 
/* Fall back to the LP64/LP32 data model when the C library does not provide __WORDSIZE. */
#ifndef __WORDSIZE
#if defined(__LP64__)
#define __WORDSIZE 64
#elif defined(__LP32__)
#define __WORDSIZE 32
#else
#error "not support word size "
#endif
#endif

/* One limb of a long_int_num is an unsigned long. */
#define WORD_BYTE_SIZE sizeof(unsigned long)
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All-ones limb. Written as ~0UL: the former ((1UL << WORD_BIT_SIZE) - 1) shifted
 * by the full width of unsigned long, which is undefined behaviour in C.
 */
#define WORD_BIT_MASK (~0UL)
#define byte2bit(byte) ((byte) << 3)
/* Half-limb width and mask, used by the portable double-word multiply. */
#define SWORD_BIT_SIZE (WORD_BIT_SIZE / 2)
#define SWORD_BIT_MASK ((1UL << SWORD_BIT_SIZE) - 1)
43 
lin_clear(struct long_int_num *p_a)44 static void lin_clear(struct long_int_num *p_a)
45 {
46     (void)hvb_memset_s(p_a->data_mem, p_a->mem_size, 0, p_a->mem_size);
47 }
48 
lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)49 static int lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)
50 {
51     if (p_src->valid_word_len * WORD_BYTE_SIZE > p_dst->mem_size) {
52         return ERROR_MEMORY_NO_ENOUGH;
53     }
54 
55     if (hvb_memcpy_s(p_dst->p_uint, p_dst->mem_size, p_src->p_uint, p_src->valid_word_len * WORD_BYTE_SIZE) != 0) {
56         return ERROR_MEMORY_NO_ENOUGH;
57     }
58 
59     p_dst->valid_word_len = p_src->valid_word_len;
60 
61     return RESULT_OK;
62 }
63 
lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)64 static int lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)
65 {
66     int i;
67 
68     if (p_a->valid_word_len != p_b->valid_word_len) {
69         return p_a->valid_word_len - p_b->valid_word_len;
70     }
71 
72     if (p_a->valid_word_len == 0) {
73         return 0;
74     }
75 
76     for (i = p_a->valid_word_len - 1; i >= 0; --i) {
77         if (p_a->p_uint[i] != p_b->p_uint[i]) {
78             if (p_a->p_uint[i] > p_b->p_uint[i]) {
79                 return 1;
80             }
81             return -1;
82         }
83     }
84     return 0;
85 }
86 
lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)87 static int lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)
88 {
89     unsigned long *p_data = NULL;
90 
91     if (word_len == 0) {
92         return ERROR_WORDLEN_ZERO;
93     }
94     p_data = hvb_malloc(word_len * WORD_BYTE_SIZE);
95     if (p_data == NULL) {
96         return ERROR_MEMORY_EMPTY;
97     }
98 
99     if (hvb_memset_s(p_data, word_len * WORD_BYTE_SIZE, 0, word_len * WORD_BYTE_SIZE) != 0) {
100         hvb_free(p_data);
101         return ERROR_MEMORY_NO_ENOUGH;
102     }
103 
104     p_long_int->data_mem = p_data;
105     p_long_int->mem_size = word_len * WORD_BYTE_SIZE;
106     p_long_int->p_uint = p_data;
107     p_long_int->valid_word_len = 0;
108 
109     return RESULT_OK;
110 }
111 
lin_create(uint32_t word_len)112 struct long_int_num *lin_create(uint32_t word_len)
113 {
114     struct long_int_num *p_res = NULL;
115 
116     p_res = hvb_malloc(sizeof(struct long_int_num));
117     if (p_res == NULL) {
118         return NULL;
119     }
120 
121     if (lin_calloc(p_res, word_len) > 0) {
122         hvb_free(p_res);
123         return NULL;
124     }
125     p_res->valid_word_len = 0;
126     return p_res;
127 }
128 
lin_free(struct long_int_num *p_long_int)129 void lin_free(struct long_int_num *p_long_int)
130 {
131     if (!p_long_int) {
132         return;
133     }
134     if (p_long_int->p_uint != NULL) {
135         hvb_free(p_long_int->data_mem);
136         p_long_int->p_uint = NULL;
137     }
138     hvb_free(p_long_int);
139 
140     return;
141 }
142 
/*
 * Count the significant bytes of a byte string: size minus the number of zero
 * bytes at its start (pd[0], pd[1], ...).
 *
 * Returns 0 when pd is NULL or every byte is zero.
 */
uint32_t bn_get_valid_len(const uint8_t *pd, uint32_t size)
{
    uint32_t zeros = 0;

    if (pd == NULL) {
        return 0;
    }

    while (zeros < size && pd[zeros] == 0) {
        ++zeros;
    }

    return size - zeros;
}
158 
lin_update_valid_len(struct long_int_num *p_a)159 void lin_update_valid_len(struct long_int_num *p_a)
160 {
161     unsigned long *p_data = NULL;
162     uint32_t i;
163 
164     if (!p_a) {
165         return;
166     }
167 
168     p_data = p_a->p_uint + p_a->valid_word_len - 1;
169     for (i = 0; i < p_a->valid_word_len; ++i) {
170         if (*p_data != 0) {
171             break;
172         }
173         --p_data;
174     }
175     p_a->valid_word_len -= i;
176 }
177 
/*
 * Full-width multiply of two limbs: *res_hi receives the upper word and *res_low
 * the lower word of the double-width product a * b.
 *
 * The implementation is chosen at compile time:
 *  - AArch64: plain multiply for the low word, UMULH for the high word;
 *  - a 128-bit unsigned type is available (__SIZEOF_INT128__) and unsigned long
 *    is 64 bits: one 128-bit multiply;
 *  - unsigned long long is at least twice as wide as unsigned long: one
 *    double-width multiply;
 *  - otherwise: schoolbook multiplication on half-words.
 *
 * The old guard "#if defined(__uint128_t)" never matched: __uint128_t is a builtin
 * type, not a macro. That branch was therefore dead, and it also failed to declare
 * `bb` for 32-bit words.
 */
static void lin_mul_word(unsigned long a, unsigned long b, unsigned long *res_hi, unsigned long *res_low)
{
#if defined(__aarch64__)
    unsigned long hi = 0;
    *res_low = a * b;
    __asm__ volatile ("umulh %0, %1, %2" : "+r"(hi) : "r"(a), "r"(b) :);
    *res_hi = hi;
#elif defined(__SIZEOF_INT128__) && defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 8)
    __uint128_t prod = (__uint128_t)a * b;

    *res_hi = (unsigned long)(prod >> 64);
    *res_low = (unsigned long)prod;
#elif defined(__SIZEOF_LONG__) && defined(__SIZEOF_LONG_LONG__) && \
    (__SIZEOF_LONG_LONG__ >= 2 * __SIZEOF_LONG__)
    unsigned long long prod = (unsigned long long)a * b;

    *res_hi = (unsigned long)(prod >> (sizeof(unsigned long) * 8));
    *res_low = (unsigned long)prod;
#else
    /* Half-limb width/mask, derived locally so the helper stands on its own. */
    const unsigned int half_bits = (unsigned int)(sizeof(unsigned long) * 4);
    const unsigned long half_mask = (1UL << half_bits) - 1;
    unsigned long a_h = a >> half_bits;
    unsigned long a_l = a & half_mask;
    unsigned long b_h = b >> half_bits;
    unsigned long b_l = b & half_mask;
    unsigned long res_h = a_h * b_h;
    unsigned long res_l = a_l * b_l;
    unsigned long cross;
    unsigned long prev;

    /* Add a_h * b_l * 2^half_bits into (res_h:res_l), carrying out of res_l. */
    cross = a_h * b_l;
    res_h += cross >> half_bits;
    prev = res_l;
    res_l += cross << half_bits;
    res_h += prev > res_l;

    /* Same for the other cross product a_l * b_h. */
    cross = a_l * b_h;
    res_h += cross >> half_bits;
    prev = res_l;
    res_l += cross << half_bits;
    res_h += prev > res_l;

    *res_hi = res_h;
    *res_low = res_l;
#endif
}
229 
lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)230 static void lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)
231 {
232     uint32_t i;
233     unsigned long c;
234     unsigned long t;
235 
236     c = 0;
237     for (i = 0; i < p_b->valid_word_len; ++i) {
238         t = p_a->p_uint[i] < c;
239         p_a->p_uint[i] = p_a->p_uint[i] - c;
240 
241         c = (p_a->p_uint[i] < p_b->p_uint[i]) + t;
242         p_a->p_uint[i] = p_a->p_uint[i] - p_b->p_uint[i];
243     }
244     for (; i < p_a->valid_word_len && c; ++i) {
245         t = p_a->p_uint[i] < c;
246         p_a->p_uint[i] = p_a->p_uint[i] - c;
247         c = t;
248     }
249     lin_update_valid_len(p_a);
250 }
251 
/*
 * Double-word plus word: (r##_h : r##_l) = (a##_h : a##_l) + b, where a and r are
 * name prefixes joined by token pasting. The carry out of the low word is
 * detected as r##_l < b after the wrapping addition. Overflow of the high word is
 * discarded.
 *
 * Order matters: r##_l is stored before the carry test. b may name r's own high
 * word, as in dword_add_word(t, x_h, x); that word is read before r##_h is
 * overwritten. b is evaluated twice.
 */
#define dword_add_word(a, b, r)                               \
    do {                                                      \
        r##_l = (b) + a##_l;                                  \
        r##_h = (unsigned long)(r##_l < (b)) + a##_h;         \
    } while (0)
257 
montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_res)258 static void montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n,
259                                unsigned long n_n0_i, struct long_int_num *p_res)
260 {
261     unsigned long x_h, x_l;
262     unsigned long d0;
263     unsigned long y_h, y_l;
264     unsigned long t_h, t_l;
265     unsigned long *p_ad = p_a->p_uint;
266     unsigned long *p_nd = p_n->p_uint;
267     unsigned long *p_rd = p_res->p_uint;
268     uint32_t i;
269 
270     while (p_a->valid_word_len > p_n->valid_word_len){
271         lin_sub(p_a, p_n);
272     }
273 
274     lin_mul_word(p_a->p_uint[0], b, &x_h, &x_l);
275 
276     dword_add_word(x, p_rd[0], x);
277 
278     d0 = x_l * n_n0_i;
279 
280     lin_mul_word(d0, p_nd[0], &y_h, &y_l);
281     dword_add_word(y, x_l, y);
282 
283     for (i = 1; i < p_a->valid_word_len; ++i) {
284         lin_mul_word(p_ad[i], b, &t_h, &t_l);
285         dword_add_word(t, p_rd[i], t);
286         dword_add_word(t, x_h, x);
287 
288         lin_mul_word(d0, p_nd[i], &t_h, &t_l);
289         dword_add_word(t, x_l, t);
290         dword_add_word(t, y_h, y);
291 
292         p_rd[i - 1] = y_l;
293     }
294 
295     p_rd[i - 1] = x_h + y_h;
296 
297     p_res->valid_word_len = p_n->valid_word_len;
298     if (p_rd[i - 1] < x_h) {
299         lin_sub(p_res, p_n);
300     }
301 }
302 
montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_res)303 static void montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n,
304                                unsigned long n_n0_i, struct long_int_num *p_res)
305 {
306     uint32_t i;
307 
308     lin_clear(p_res);
309 
310     for (i = 0; i < p_b->valid_word_len; ++i) {
311         montgomery_mul_add(p_a, p_b->p_uint[i], p_n, n_n0_i, p_res);
312     }
313 }
314 
montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i, struct long_int_num *p_rr, uint32_t exp)315 struct long_int_num *montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i,
316                                         struct long_int_num *p_rr, uint32_t exp)
317 {
318     struct long_int_num *p_res = NULL;
319     struct long_int_num *p_mr = NULL;
320     struct long_int_num *p_square = NULL;
321     int i;
322     if ((exp & 1UL) == 0) {
323         goto fail_final;
324     }
325 
326     p_mr = lin_create(p_n->valid_word_len);
327     if (p_mr == NULL) {
328         goto fail_final;
329     }
330 
331     p_square = lin_create(p_n->valid_word_len);
332     if (p_square == NULL) {
333         goto fail_final;
334     }
335 
336     p_res = lin_create(p_n->valid_word_len);
337     if (p_res == NULL) {
338         goto fail_final;
339     }
340 
341     montgomery_mod_mul(p_m, p_rr, p_n, n_n0_i, p_mr);
342     i = byte2bit(sizeof(exp)) - 1;
343     for (; i >= 0; --i) {
344         if (exp & (1UL << i)) {
345             break;
346         }
347     }
348 
349     lin_copy(p_mr, p_res);
350 
351     for (--i; i > 0; --i) {
352         montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
353         if (exp & (1UL << i)) {
354             montgomery_mod_mul(p_mr, p_square, p_n, n_n0_i, p_res);
355         } else {
356             lin_copy(p_square, p_res);
357         }
358     }
359     montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
360     montgomery_mod_mul(p_m, p_square, p_n, n_n0_i, p_res);
361 
362     if (lin_compare(p_res, p_n) >= 0) {
363         lin_sub(p_res, p_n);
364     }
365 
366 fail_final:
367     lin_free(p_mr);
368     lin_free(p_square);
369 
370     return p_res;
371 }
372 
lin_get_bitlen(struct long_int_num *p_a)373 uint32_t lin_get_bitlen(struct long_int_num *p_a)
374 {
375     int i;
376     int bit_len;
377     unsigned long *p_data = NULL;
378     unsigned long value;
379 
380     if (!p_a || p_a->valid_word_len == 0) {
381         return 0;
382     }
383     p_data = p_a->p_uint;
384     for (i = p_a->valid_word_len - 1; i >= 0; --i) {
385         if (p_data[i] != 0) {
386             break;
387         }
388     }
389 
390     bit_len = (i + 1) * WORD_BIT_SIZE;
391 
392     if (bit_len == 0) {
393         return 0;
394     }
395 
396     for (value = p_data[i]; ((signed long)value) > 0; value = value << 1) {
397         --bit_len;
398     }
399 
400     return bit_len;
401 }
402