xref: /base/startup/hvb/libhvb/src/crypto/hvb_rsa.c (revision 7310c0d0)
1/*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
#include <stdint.h>
#include <stdlib.h>
#include "hvb_crypto.h"
#include "hvb_util.h"
#include "hvb_rsa.h"
18#include "hvb_rsa.h"
19
/*
 * Status codes returned by the internal long-integer helpers.
 * RESULT_OK is zero; every failure code is a distinct positive value.
 */
enum {
    RESULT_OK = 0,
    ERROR_MEMORY_EMPTY = 1,     /* allocation returned NULL */
    ERROR_MEMORY_NO_ENOUGH = 2, /* destination too small, or memset/memcpy failed */
    ERROR_WORDLEN_ZERO = 3,     /* a zero-word allocation was requested */
};
26
/*
 * Provide __WORDSIZE when the C library does not define it.
 * This file uses unsigned long as its word, so the value follows the width of long.
 * Besides __LP64__/__LP32__, the common ILP32/LP64 spellings and the
 * compiler-provided __SIZEOF_LONG__ are accepted. This lets 32-bit toolchains
 * whose libc lacks __WORDSIZE still build.
 */
#ifndef __WORDSIZE
#if defined(__LP64__) || defined(_LP64)
#define __WORDSIZE 64
#elif defined(__LP32__) || defined(__ILP32__) || (defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 4))
#define __WORDSIZE 32
#else
#error "not support word size "
#endif
#endif
36
/* One "word" of a long integer is an unsigned long. */
#define WORD_BYTE_SIZE sizeof(unsigned long)
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All-ones word. Written as ~0UL because the former (1UL << WORD_BIT_SIZE) - 1
 * shifts by the full width of unsigned long, which is undefined behavior.
 */
#define WORD_BIT_MASK (~0UL)
#define byte2bit(byte) ((byte) << 3)
/* Half-word width and mask, used to split a word for the portable multiply in lin_mul_word. */
#define SWORD_BIT_SIZE (WORD_BIT_SIZE / 2)
#define SWORD_BIT_MASK ((1UL << SWORD_BIT_SIZE) - 1)
43
44static void lin_clear(struct long_int_num *p_a)
45{
46    (void)hvb_memset_s(p_a->data_mem, p_a->mem_size, 0, p_a->mem_size);
47}
48
49static int lin_copy(struct long_int_num *p_src, struct long_int_num *p_dst)
50{
51    if (p_src->valid_word_len * WORD_BYTE_SIZE > p_dst->mem_size) {
52        return ERROR_MEMORY_NO_ENOUGH;
53    }
54
55    if (hvb_memcpy_s(p_dst->p_uint, p_dst->mem_size, p_src->p_uint, p_src->valid_word_len * WORD_BYTE_SIZE) != 0) {
56        return ERROR_MEMORY_NO_ENOUGH;
57    }
58
59    p_dst->valid_word_len = p_src->valid_word_len;
60
61    return RESULT_OK;
62}
63
64static int lin_compare(struct long_int_num *p_a, struct long_int_num *p_b)
65{
66    int i;
67
68    if (p_a->valid_word_len != p_b->valid_word_len) {
69        return p_a->valid_word_len - p_b->valid_word_len;
70    }
71
72    if (p_a->valid_word_len == 0) {
73        return 0;
74    }
75
76    for (i = p_a->valid_word_len - 1; i >= 0; --i) {
77        if (p_a->p_uint[i] != p_b->p_uint[i]) {
78            if (p_a->p_uint[i] > p_b->p_uint[i]) {
79                return 1;
80            }
81            return -1;
82        }
83    }
84    return 0;
85}
86
87static int lin_calloc(struct long_int_num *p_long_int, uint32_t word_len)
88{
89    unsigned long *p_data = NULL;
90
91    if (word_len == 0) {
92        return ERROR_WORDLEN_ZERO;
93    }
94    p_data = hvb_malloc(word_len * WORD_BYTE_SIZE);
95    if (p_data == NULL) {
96        return ERROR_MEMORY_EMPTY;
97    }
98
99    if (hvb_memset_s(p_data, word_len * WORD_BYTE_SIZE, 0, word_len * WORD_BYTE_SIZE) != 0) {
100        hvb_free(p_data);
101        return ERROR_MEMORY_NO_ENOUGH;
102    }
103
104    p_long_int->data_mem = p_data;
105    p_long_int->mem_size = word_len * WORD_BYTE_SIZE;
106    p_long_int->p_uint = p_data;
107    p_long_int->valid_word_len = 0;
108
109    return RESULT_OK;
110}
111
112struct long_int_num *lin_create(uint32_t word_len)
113{
114    struct long_int_num *p_res = NULL;
115
116    p_res = hvb_malloc(sizeof(struct long_int_num));
117    if (p_res == NULL) {
118        return NULL;
119    }
120
121    if (lin_calloc(p_res, word_len) > 0) {
122        hvb_free(p_res);
123        return NULL;
124    }
125    p_res->valid_word_len = 0;
126    return p_res;
127}
128
129void lin_free(struct long_int_num *p_long_int)
130{
131    if (!p_long_int) {
132        return;
133    }
134    if (p_long_int->p_uint != NULL) {
135        hvb_free(p_long_int->data_mem);
136        p_long_int->p_uint = NULL;
137    }
138    hvb_free(p_long_int);
139
140    return;
141}
142
/*
 * Return the number of bytes in pd[0..size) left after stripping the run of zero
 * bytes at the start (lowest indices) of the buffer. Presumably pd holds a
 * big-endian number. Returns 0 for a NULL pointer or an all-zero buffer.
 */
uint32_t bn_get_valid_len(const uint8_t *pd, uint32_t size)
{
    uint32_t skipped = 0;

    if (pd == NULL) {
        return 0;
    }

    while (skipped < size && pd[skipped] == 0) {
        ++skipped;
    }
    return size - skipped;
}
158
159void lin_update_valid_len(struct long_int_num *p_a)
160{
161    unsigned long *p_data = NULL;
162    uint32_t i;
163
164    if (!p_a) {
165        return;
166    }
167
168    p_data = p_a->p_uint + p_a->valid_word_len - 1;
169    for (i = 0; i < p_a->valid_word_len; ++i) {
170        if (*p_data != 0) {
171            break;
172        }
173        --p_data;
174    }
175    p_a->valid_word_len -= i;
176}
177
/*
 * Full-width multiply of two words: the double-word product a * b is returned
 * split into its high word (*res_hi) and low word (*res_low).
 */
static void lin_mul_word(unsigned long a, unsigned long b, unsigned long *res_hi, unsigned long *res_low)
{
#if defined(__aarch64__)
    unsigned long hi = 0;
    *res_low = a * b;
    /* umulh yields the upper 64 bits of the 64x64-bit product. */
    __asm__ volatile ("umulh %0, %1, %2" : "+r"(hi) : "r"(a), "r"(b) :);
    *res_hi = hi;
#elif defined(__SIZEOF_INT128__) && defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 8)
    /*
     * __uint128_t is a compiler built-in type, not a macro, so `defined(__uint128_t)`
     * is never true. GCC/Clang advertise 128-bit support through __SIZEOF_INT128__.
     */
    __uint128_t prod = (__uint128_t)a * b;
    *res_hi = (unsigned long)(prod >> 64);
    *res_low = (unsigned long)prod;
#elif defined(__SIZEOF_LONG__) && (__SIZEOF_LONG__ == 4)
    /* 32-bit words: unsigned long long (at least 64 bits) holds the whole product. */
    unsigned long long prod = (unsigned long long)a * b;
    *res_hi = (unsigned long)(prod >> 32);
    *res_low = (unsigned long)prod;
#else
    /* Portable schoolbook multiply on half-words; each partial product fits in one word. */
    unsigned long a_h, a_l;
    unsigned long b_h, b_l;
    unsigned long res_h, res_l;
    unsigned long c, t;
    a_h = a >> SWORD_BIT_SIZE;
    a_l = a & SWORD_BIT_MASK;
    b_h = b >> SWORD_BIT_SIZE;
    b_l = b & SWORD_BIT_MASK;

    res_h = a_h * b_h;
    res_l = a_l * b_l;

    /* Add each cross term shifted by half a word, propagating the carry from the low word. */
    c = a_h * b_l;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;

    c = a_l * b_h;
    res_h += c >> SWORD_BIT_SIZE;
    t = res_l;
    res_l += c << SWORD_BIT_SIZE;
    res_h += t > res_l;
    *res_hi  = res_h;
    *res_low = res_l;
#endif
}
229
230static void lin_sub(struct long_int_num *p_a, struct long_int_num *p_b)
231{
232    uint32_t i;
233    unsigned long c;
234    unsigned long t;
235
236    c = 0;
237    for (i = 0; i < p_b->valid_word_len; ++i) {
238        t = p_a->p_uint[i] < c;
239        p_a->p_uint[i] = p_a->p_uint[i] - c;
240
241        c = (p_a->p_uint[i] < p_b->p_uint[i]) + t;
242        p_a->p_uint[i] = p_a->p_uint[i] - p_b->p_uint[i];
243    }
244    for (; i < p_a->valid_word_len && c; ++i) {
245        t = p_a->p_uint[i] < c;
246        p_a->p_uint[i] = p_a->p_uint[i] - c;
247        c = t;
248    }
249    lin_update_valid_len(p_a);
250}
251
/*
 * Double-word plus word: (r##_h : r##_l) = (a##_h : a##_l) + b.
 * The carry out of the low word is added to the high word; a carry out of the
 * high word is discarded. b is evaluated exactly once and cached. The result is
 * then correct even when b names r##_l, which is overwritten before the carry
 * test, and side effects in b are not repeated.
 */
#define dword_add_word(a, b, r)                        \
    do {                                               \
        unsigned long dw_addend_ = (b);                \
        r##_l = a##_l + dw_addend_;                    \
        r##_h = a##_h + (r##_l < dw_addend_);          \
    } while (0)
257
258static void montgomery_mul_add(struct long_int_num *p_a, unsigned long b, struct long_int_num *p_n,
259                               unsigned long n_n0_i, struct long_int_num *p_res)
260{
261    unsigned long x_h, x_l;
262    unsigned long d0;
263    unsigned long y_h, y_l;
264    unsigned long t_h, t_l;
265    unsigned long *p_ad = p_a->p_uint;
266    unsigned long *p_nd = p_n->p_uint;
267    unsigned long *p_rd = p_res->p_uint;
268    uint32_t i;
269
270    while (p_a->valid_word_len > p_n->valid_word_len){
271        lin_sub(p_a, p_n);
272    }
273
274    lin_mul_word(p_a->p_uint[0], b, &x_h, &x_l);
275
276    dword_add_word(x, p_rd[0], x);
277
278    d0 = x_l * n_n0_i;
279
280    lin_mul_word(d0, p_nd[0], &y_h, &y_l);
281    dword_add_word(y, x_l, y);
282
283    for (i = 1; i < p_a->valid_word_len; ++i) {
284        lin_mul_word(p_ad[i], b, &t_h, &t_l);
285        dword_add_word(t, p_rd[i], t);
286        dword_add_word(t, x_h, x);
287
288        lin_mul_word(d0, p_nd[i], &t_h, &t_l);
289        dword_add_word(t, x_l, t);
290        dword_add_word(t, y_h, y);
291
292        p_rd[i - 1] = y_l;
293    }
294
295    p_rd[i - 1] = x_h + y_h;
296
297    p_res->valid_word_len = p_n->valid_word_len;
298    if (p_rd[i - 1] < x_h) {
299        lin_sub(p_res, p_n);
300    }
301}
302
303static void montgomery_mod_mul(struct long_int_num *p_a, struct long_int_num *p_b, struct long_int_num *p_n,
304                               unsigned long n_n0_i, struct long_int_num *p_res)
305{
306    uint32_t i;
307
308    lin_clear(p_res);
309
310    for (i = 0; i < p_b->valid_word_len; ++i) {
311        montgomery_mul_add(p_a, p_b->p_uint[i], p_n, n_n0_i, p_res);
312    }
313}
314
315struct long_int_num *montgomery_mod_exp(struct long_int_num *p_m, struct long_int_num *p_n, unsigned long n_n0_i,
316                                        struct long_int_num *p_rr, uint32_t exp)
317{
318    struct long_int_num *p_res = NULL;
319    struct long_int_num *p_mr = NULL;
320    struct long_int_num *p_square = NULL;
321    int i;
322    if ((exp & 1UL) == 0) {
323        goto fail_final;
324    }
325
326    p_mr = lin_create(p_n->valid_word_len);
327    if (p_mr == NULL) {
328        goto fail_final;
329    }
330
331    p_square = lin_create(p_n->valid_word_len);
332    if (p_square == NULL) {
333        goto fail_final;
334    }
335
336    p_res = lin_create(p_n->valid_word_len);
337    if (p_res == NULL) {
338        goto fail_final;
339    }
340
341    montgomery_mod_mul(p_m, p_rr, p_n, n_n0_i, p_mr);
342    i = byte2bit(sizeof(exp)) - 1;
343    for (; i >= 0; --i) {
344        if (exp & (1UL << i)) {
345            break;
346        }
347    }
348
349    lin_copy(p_mr, p_res);
350
351    for (--i; i > 0; --i) {
352        montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
353        if (exp & (1UL << i)) {
354            montgomery_mod_mul(p_mr, p_square, p_n, n_n0_i, p_res);
355        } else {
356            lin_copy(p_square, p_res);
357        }
358    }
359    montgomery_mod_mul(p_res, p_res, p_n, n_n0_i, p_square);
360    montgomery_mod_mul(p_m, p_square, p_n, n_n0_i, p_res);
361
362    if (lin_compare(p_res, p_n) >= 0) {
363        lin_sub(p_res, p_n);
364    }
365
366fail_final:
367    lin_free(p_mr);
368    lin_free(p_square);
369
370    return p_res;
371}
372
373uint32_t lin_get_bitlen(struct long_int_num *p_a)
374{
375    int i;
376    int bit_len;
377    unsigned long *p_data = NULL;
378    unsigned long value;
379
380    if (!p_a || p_a->valid_word_len == 0) {
381        return 0;
382    }
383    p_data = p_a->p_uint;
384    for (i = p_a->valid_word_len - 1; i >= 0; --i) {
385        if (p_data[i] != 0) {
386            break;
387        }
388    }
389
390    bit_len = (i + 1) * WORD_BIT_SIZE;
391
392    if (bit_len == 0) {
393        return 0;
394    }
395
396    for (value = p_data[i]; ((signed long)value) > 0; value = value << 1) {
397        --bit_len;
398    }
399
400    return bit_len;
401}
402