1/*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15#include <stdio.h>
16#include <stdlib.h>
17#include "hvb_hash_sha256.h"
18#include "hvb_crypto.h"
19#include "hvb_rsa.h"
20#include "hvb_util.h"
21#include "hvb_sysdeps.h"
22#include "hvb_rsa_verify.h"
23
24
/* Size of a SHA-256 digest in bytes; the only digest length this verifier accepts. */
#define SHA256_DIGEST_LEN 32
/* EM = maskedDB || H || 0xbc and DB = PS || 0x01 || salt: the two fixed octets (0x01, 0xbc). */
#define PSS_EM_PADDING_LEN 2
/* Number of zero octets that prefix M' = 0x00 * 8 || mHash || salt. */
#define PSS_MTMP_PADDING_LEN 8
/* One-octet field width: the 0xbc trailer of EM and the 0x01 separator inside DB. */
#define PSS_DB_PADDING_LEN 1
/* Trailer octet that must end EM. */
#define PSS_END_PADDING_UNIT 0xBC
/* Shifted right by (8 * emLen - emBits) to obtain the bits of the first octet that may be set. */
#define PSS_LEFTMOST_BIT_MASK 0xFFU

#define PADDING_UNIT_ZERO 0x00
#define PADDING_UNIT_ONE 0x01
/* Largest supported modulus width in bits. */
#define RSA_WIDTH_MAX 8192

/* Byte/word conversions; a "word" here is an unsigned long. */
#define WORD_BYTE_SIZE sizeof(unsigned long)
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All-ones word. Written as ~0UL because (1UL << WORD_BIT_SIZE) shifts by the
 * full width of the type, which is undefined behavior in C.
 */
#define WORD_BIT_MASK (~0UL)
#define bit2byte(bits) ((bits) >> 3)
#define byte2bit(byte) ((byte) << 3)
#define bit_val(x) (1U << (x))
#define bit_mask(x) (bit_val(x) - 1U)
/* Round n up to a multiple of 2^bit. */
#define bit_align(n, bit) (((n) + bit_mask(bit)) & (~(bit_mask(bit))))
/* Number of bytes needed to hold `bits` bits (rounded up). */
#define bit2byte_align(bits) bit2byte(bit_align(bits, 3))
/* Number of words needed to hold `bytes` bytes (rounded up). */
#define byte2dword(bytes) (((bytes) + (WORD_BYTE_SIZE) - 1) / WORD_BYTE_SIZE)
#define dword2byte(words) ((words) * WORD_BYTE_SIZE)
47
48/* calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
49static int emsa_pss_calc_m(const uint8_t *pdigest, uint32_t digestlen,
50                           uint8_t *salt, uint32_t saltlen,
51                           uint8_t **m)
52{
53    uint8_t *m_tmp = NULL;
54    uint32_t m_tmp_len;
55    int ret = VERIFY_OK;
56
57    m_tmp_len = digestlen + saltlen + PSS_MTMP_PADDING_LEN;
58    m_tmp = (uint8_t *)hvb_malloc(m_tmp_len);
59    if (!m_tmp) {
60        return PARAM_EMPTY_ERROR;
61    }
62
63    if (hvb_memset_s(m_tmp, m_tmp_len, 0, PSS_MTMP_PADDING_LEN) !=  0) {
64        ret = MEMORY_ERROR;
65        goto error;
66    }
67
68    if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN], m_tmp_len - PSS_MTMP_PADDING_LEN, pdigest, digestlen) != 0) {
69        ret = MEMORY_ERROR;
70        goto error;
71    }
72
73    if (saltlen != 0 && salt) {
74        if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN + digestlen], saltlen, salt, saltlen) != 0) {
75            ret = MEMORY_ERROR;
76            goto error;
77        }
78    }
79
80    *m = m_tmp;
81    return ret;
82error:
83    hvb_free(m_tmp);
84    return ret;
85}
86
87/* rsa verify last step compare hash value */
88static int emsa_pss_hash_cmp(uint8_t *m_tmp, uint32_t m_tmp_len,
89                             uint8_t *hash, uint32_t digestlen)
90{
91    int ret;
92    uint8_t *hash_tmp = NULL;
93
94    hash_tmp = (uint8_t *)hvb_malloc(digestlen);
95    if (!hash_tmp) {
96        return HASH_CMP_FAIL;
97    }
98    if (hash_sha256_single(m_tmp, m_tmp_len, hash_tmp, digestlen) != HASH_OK) {
99        ret = HASH_CMP_FAIL;
100        goto rsa_error;
101    }
102    /* compare twice */
103    ret = VERIFY_OK;
104    ret += hvb_memcmp(hash, hash_tmp, digestlen);
105    ret += hvb_memcmp(hash, hash_tmp, digestlen);
106    if (ret != VERIFY_OK)
107        ret = HASH_CMP_FAIL;
108rsa_error:
109    hvb_free(hash_tmp);
110    return ret;
111}
112
113static int rsa_pss_get_emlen(uint32_t klen, struct long_int_num *pn,
114                             uint32_t *emlen, uint32_t *embits)
115{
116    *embits = lin_get_bitlen(pn);
117    if (*embits == 0) {
118        return CALC_EMLEN_ERROR;
119    }
120    (*embits)--;
121
122    *emlen = bit2byte_align(*embits);
123    if (*emlen == 0) {
124        return CALC_EMLEN_ERROR;
125    }
126
127    if (*emlen > klen) {
128        return CALC_EMLEN_ERROR;
129    }
130
131    return VERIFY_OK;
132}
133
134/* make generate function V1 */
135static int rsa_gen_mask_mgf_v1(uint8_t *seed, uint32_t seed_len,
136                               uint8_t *mask, uint32_t mask_len)
137{
138    int ret = VERIFY_OK;
139    uint32_t cnt = 0;
140    uint32_t cnt_maxsize = 0;
141    uint8_t *p_tmp = NULL;
142    uint8_t *pt = NULL;
143    uint8_t *pc = NULL;
144    const uint32_t hash_len = SHA256_DIGEST_LEN;
145
146    /* Step 1: mask length is smaller than the maximum key length */
147    if (mask_len > bit2byte(RSA_WIDTH_MAX)) {
148        return CALC_MASK_ERROR;
149    }
150
151    /* Step 2:  Let pt and pt_tmp be the empty octet string. */
152    pt = (uint8_t *)hvb_malloc(mask_len + hash_len);
153    if (!pt) {
154        return CALC_MASK_ERROR;
155    }
156
157    pc = (uint8_t *)hvb_malloc(seed_len + sizeof(uint32_t));
158    if (!pc) {
159        ret = CALC_MASK_ERROR;
160        goto rsa_error;
161    }
162
163    /*
164     * Step 3:  For counter from 0 to (mask_len + hash_len - 1) / hash_len ,
165     * do the following:
166     * string T:   T = T || Hash (pseed || counter)
167     */
168    p_tmp = pt;
169    if (hvb_memcpy_s(pc, seed_len, seed, seed_len) != 0) {
170        ret = MEMORY_ERROR;
171        goto rsa_error;
172    }
173
174    if (hvb_memset_s(pc + seed_len, sizeof(uint32_t), 0, sizeof(uint32_t)) != 0) {
175        ret = MEMORY_ERROR;
176        goto rsa_error;
177    }
178    /* step 3.1: count of Hash blocks needed for mask calculation */
179    cnt_maxsize = (uint32_t)((mask_len + hash_len - 1) / hash_len);
180
181    for (cnt = 0; cnt < cnt_maxsize; cnt++) {
182        /* step 3.2: pt_tmp = pseed ||Counter */
183        pc[seed_len + sizeof(uint32_t) - sizeof(uint8_t)] = cnt;
184
185        /* step 3.3: calc T, T = T || Hash (pt_tmp) */
186        if (hash_sha256_single(pc, seed_len + sizeof(uint32_t), p_tmp, hash_len) != HASH_OK) {
187            ret = CALC_MASK_ERROR;
188            goto rsa_error;
189        }
190        p_tmp += hash_len;
191    }
192    /* Step 4:  Output the leading L octets of T as the octet string mask. */
193    if (hvb_memcpy_s(mask, mask_len, pt, mask_len) != 0) {
194        ret = MEMORY_ERROR;
195        goto rsa_error;
196    }
197
198rsa_error:
199    if (pt != NULL)
200        hvb_free(pt);
201    if (pc != NULL)
202        hvb_free(pc);
203    return ret;
204}
205
206static int emsa_pss_verify_check_db(uint8_t *db, uint32_t db_len,
207                                    uint32_t emlen, uint32_t digestlen,
208                                    uint32_t saltlen)
209{
210    int i;
211
212    for (i = 0; i < emlen - digestlen - saltlen - PSS_EM_PADDING_LEN; i++) {
213        if (db[i] != PADDING_UNIT_ZERO) {
214            return CHECK_DB_ERROR;
215        }
216    }
217
218    if (db[db_len - saltlen - PSS_DB_PADDING_LEN] != PADDING_UNIT_ONE) {
219        return CMP_DB_FAIL;
220    }
221
222    return VERIFY_OK;
223}
224
225static int emsa_pss_verify(uint32_t saltlen, const uint8_t *pdigest,
226                           uint32_t digestlen, uint32_t emlen,
227                           uint32_t embits, uint8_t *pem)
228{
229    int ret;
230    uint32_t i;
231    uint32_t masklen;
232    uint32_t m_tmp_len;
233    uint32_t db_len = 0;
234    uint8_t *hash = NULL;
235    uint8_t *m_tmp = NULL;
236    uint8_t *maskedb = NULL;
237    uint8_t *salt = NULL;
238    uint8_t *db = NULL;
239
240    masklen = byte2bit(emlen) - embits;
241
242    /*
243     * Step 1: Skip digest calculate
244     * Step 2: Check sizes, emLen < hLen + sLen + 2
245     */
246    if (emlen < digestlen + PSS_EM_PADDING_LEN || saltlen > (emlen - digestlen - PSS_EM_PADDING_LEN)) {
247        return CALC_EMLEN_ERROR;
248    }
249    /* Step 3: if rightmost of EM is oxbc */
250    if (pem[emlen - PSS_DB_PADDING_LEN] != PSS_END_PADDING_UNIT) {
251        return CALC_0XBC_ERROR;
252    }
253
254    /* Step 4: set maskedDB and H */
255    maskedb = pem;
256    db_len = emlen - digestlen - PSS_DB_PADDING_LEN;
257    hash = &pem[db_len];
258
259    /* Step 5: Check that the leftmost bits in the leftmost octet of EM have the value 0 */
260    if ((maskedb[0] & (~(PSS_LEFTMOST_BIT_MASK >> masklen))) != 0) {
261        return CALC_EM_ERROR;
262    }
263
264    /* Step 6: calc dbMask, MGF(H) */
265    db = (uint8_t *)hvb_malloc(db_len); /* db is dbmask */
266    if (!db) {
267        return CALC_DB_ERROR;
268    }
269    ret = rsa_gen_mask_mgf_v1(hash, digestlen, db, db_len);
270    if (ret != VERIFY_OK) {
271        goto rsa_error;
272    }
273    /* Step 7: calc db, maskedDB ^ db_mask */
274    for (i = 0; i < db_len; i++) {
275        db[i] = maskedb[i] ^ db[i];
276    }
277
278    /* Step 8: Set the leftmost 8*emLen-emBits bits in DB to zero */
279    db[0] &= PSS_LEFTMOST_BIT_MASK >> masklen;
280
281    /* Step 9: check db padding data */
282    ret = emsa_pss_verify_check_db(db, db_len, emlen, digestlen, saltlen);
283    if (ret != VERIFY_OK) {
284        goto rsa_error;
285    }
286    /* Step 10: set salt be the last slen of DB */
287    if (saltlen != 0) {
288        salt = &db[db_len - saltlen];
289    }
290
291    /* Step 11: calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
292    ret = emsa_pss_calc_m(pdigest, digestlen, salt, saltlen, &m_tmp);
293    if (ret != VERIFY_OK) {
294        goto rsa_error;
295    }
296    /* Step 12: hash_tmp = H' = Hash(M') */
297    m_tmp_len = PSS_MTMP_PADDING_LEN + digestlen + saltlen;
298    ret = emsa_pss_hash_cmp(m_tmp, m_tmp_len, hash, digestlen);
299
300rsa_error:
301    if (db != NULL)
302        hvb_free(db);
303    if (m_tmp != NULL)
304        hvb_free(m_tmp);
305    return ret;
306}
307
/*
 * Copy len bytes from src to dst in reverse order: dst[i] = src[len - 1 - i].
 * Used in this file to move byte strings into and out of long_int_num storage.
 * src is only read (hence const); dst and src must not overlap.
 */
static inline void invert_copy(uint8_t *dst, const uint8_t *src, uint32_t len)
{
    for (uint32_t i = 0; i < len; i++) {
        dst[i] = src[len - i - 1];
    }
}
314
315static int hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey *pkey, const uint8_t *pdigest,
316                                          uint32_t digestlen, uint8_t *psign, uint32_t signlen)
317{
318    uint32_t klen;
319    uint32_t n_validlen;
320
321    if (!pkey || !pdigest || !psign) {
322        return PARAM_EMPTY_ERROR;
323    }
324    if (!pkey->pn || !pkey->p_rr || pkey->n_n0_i == 0) {
325        return PUBKEY_EMPTY_ERROR;
326    }
327    klen = bit2byte(pkey->width);
328    n_validlen = bn_get_valid_len(pkey->pn, pkey->nlen);
329    if (digestlen != SHA256_DIGEST_LEN) {
330        return DIGEST_LEN_ERROR;
331    }
332    if (n_validlen != klen || pkey->rlen > pkey->nlen) {
333        return PUBKEY_LEN_ERROR;
334    }
335    if (signlen > klen) {
336        return SIGN_LEN_ERROR;
337    }
338
339    return VERIFY_OK;
340}
341
342static int hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey *pkey, uint8_t *psign,
343                                            uint32_t signlen, struct long_int_num *p_n,
344                                            struct long_int_num *p_rr, struct long_int_num *p_m)
345{
346    invert_copy((uint8_t *)p_n->p_uint, pkey->pn, pkey->nlen);
347    p_n->valid_word_len = byte2dword(pkey->nlen);
348    lin_update_valid_len(p_n);
349    if (!p_n) {
350        return PUBKEY_EMPTY_ERROR;
351    }
352
353    invert_copy((uint8_t *)p_m->p_uint, psign, signlen);
354    p_m->valid_word_len = byte2dword(pkey->nlen);
355    lin_update_valid_len(p_m);
356    if (!p_m) {
357        return SIGN_EMPTY_ERROR;
358    }
359
360    invert_copy((uint8_t *)p_rr->p_uint, pkey->p_rr, pkey->rlen);
361    p_rr->valid_word_len = byte2dword(pkey->nlen);
362    lin_update_valid_len(p_rr);
363    if (!p_rr) {
364        return PUBKEY_EMPTY_ERROR;
365    }
366
367    return VERIFY_OK;
368}
369
370int hvb_rsa_verify_pss(const struct hvb_rsa_pubkey
371                       *pkey, const uint8_t *pdigest,
372                       uint32_t digestlen, uint8_t *psign,
373                       uint32_t signlen, uint32_t saltlen)
374{
375    int ret;
376    uint32_t klen;
377    uint32_t emlen;
378    uint32_t embits;
379    unsigned long n_n0_i;
380    struct long_int_num *p_n = NULL;
381    struct long_int_num *p_m = NULL;
382    struct long_int_num *p_rr = NULL;
383    struct long_int_num *em = NULL;
384    uint8_t *em_data = NULL;
385
386    ret = hvb_rsa_verify_pss_param_check(pkey, pdigest, digestlen, psign, signlen);
387    if (ret != VERIFY_OK) {
388        return ret;
389    }
390
391    n_n0_i = (unsigned long)pkey->n_n0_i;
392    klen = bit2byte(pkey->width);
393    p_n = lin_create(byte2dword(pkey->nlen));
394    if (!p_n) {
395        return MEMORY_ERROR;
396    }
397    p_m = lin_create(byte2dword(pkey->nlen));
398    if (!p_m) {
399        ret = MEMORY_ERROR;
400        goto rsa_error;
401    }
402    p_rr = lin_create(byte2dword(pkey->nlen));
403    if (!p_rr) {
404        ret = MEMORY_ERROR;
405        goto rsa_error;
406    }
407    ret = hvb_rsa_verify_pss_param_convert(pkey, psign, signlen, p_n, p_rr, p_m);
408    if (ret != VERIFY_OK) {
409        goto rsa_error;
410    }
411    /* Step 1: RSA prim decrypt */
412    em = montgomery_mod_exp(p_m, p_n, n_n0_i, p_rr, pkey->e);
413    if (!em) {
414        ret = MOD_EXP_CALC_FAIL;
415        goto rsa_error;
416    }
417
418    lin_update_valid_len(em);
419    em_data = hvb_malloc(klen);
420    if (!em_data) {
421        ret = MOD_EXP_CALC_FAIL;
422        goto rsa_error;
423    }
424
425    if (hvb_memset_s(em_data, klen, 0, klen) != 0) {
426        ret = MEMORY_ERROR;
427        goto rsa_error;
428    }
429    invert_copy(em_data, (uint8_t *)em->p_uint, klen);
430    /* Step 2: emsa pss verify */
431    ret = rsa_pss_get_emlen(klen, p_n, &emlen, &embits);
432    if (ret != VERIFY_OK) {
433        goto rsa_error;
434    }
435    if (klen - emlen == 1 && em_data[0] != 0) {
436        ret = MOD_EXP_CALC_FAIL;
437        goto rsa_error;
438    }
439    ret = emsa_pss_verify(saltlen, pdigest, digestlen, emlen, embits, em_data + klen - emlen);
440
441rsa_error:
442    lin_free(em);
443    lin_free(p_n);
444    lin_free(p_m);
445    lin_free(p_rr);
446    if (em_data) {
447        hvb_free(em_data);
448    }
449
450    return ret;
451}
452