1/* 2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15#include <stdio.h> 16#include <stdlib.h> 17#include "hvb_hash_sha256.h" 18#include "hvb_crypto.h" 19#include "hvb_rsa.h" 20#include "hvb_util.h" 21#include "hvb_sysdeps.h" 22#include "hvb_rsa_verify.h" 23 24 25#define SHA256_DIGEST_LEN 32 26#define PSS_EM_PADDING_LEN 2 27#define PSS_MTMP_PADDING_LEN 8 28#define PSS_DB_PADDING_LEN 1 29#define PSS_END_PADDING_UNIT 0xBC 30#define PSS_LEFTMOST_BIT_MASK 0xFFU 31 32#define PADDING_UNIT_ZERO 0x00 33#define PADDING_UNIT_ONE 0x01 34#define RSA_WIDTH_MAX 8192 35 36#define WORD_BYTE_SIZE sizeof(unsigned long) 37#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8) 38#define WORD_BIT_MASK (((1UL << WORD_BIT_SIZE) - 1)) 39#define bit2byte(bits) ((bits) >> 3) 40#define byte2bit(byte) ((byte) << 3) 41#define bit_val(x) (1U << (x)) 42#define bit_mask(x) (bit_val(x) - 1U) 43#define bit_align(n, bit) (((n) + bit_mask(bit)) & (~(bit_mask(bit)))) 44#define bit2byte_align(bits) bit2byte(bit_align(bits, 3)) 45#define byte2dword(bytes) (((bytes) + (WORD_BYTE_SIZE) - 1) / WORD_BYTE_SIZE) 46#define dword2byte(words) ((words) * WORD_BYTE_SIZE) 47 48/* calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */ 49static int emsa_pss_calc_m(const uint8_t *pdigest, uint32_t digestlen, 50 uint8_t *salt, uint32_t saltlen, 51 uint8_t **m) 52{ 53 uint8_t *m_tmp = NULL; 54 uint32_t m_tmp_len; 55 int ret = VERIFY_OK; 56 57 m_tmp_len = digestlen + saltlen + PSS_MTMP_PADDING_LEN; 58 m_tmp = (uint8_t *)hvb_malloc(m_tmp_len); 59 if (!m_tmp) { 60 return PARAM_EMPTY_ERROR; 61 } 62 63 if (hvb_memset_s(m_tmp, m_tmp_len, 0, PSS_MTMP_PADDING_LEN) != 0) { 64 ret = MEMORY_ERROR; 65 goto error; 66 } 67 68 if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN], m_tmp_len - PSS_MTMP_PADDING_LEN, pdigest, digestlen) != 0) { 69 ret = MEMORY_ERROR; 70 goto error; 71 } 72 73 if (saltlen != 0 && salt) { 74 if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN + digestlen], saltlen, salt, saltlen) != 0) { 75 ret = MEMORY_ERROR; 76 goto error; 77 } 78 } 79 80 *m = m_tmp; 81 return ret; 82error: 83 hvb_free(m_tmp); 84 return ret; 85} 86 87/* rsa verify last step compare hash value */ 88static int emsa_pss_hash_cmp(uint8_t *m_tmp, uint32_t m_tmp_len, 89 uint8_t *hash, uint32_t digestlen) 90{ 91 int ret; 92 uint8_t *hash_tmp = NULL; 93 94 hash_tmp = (uint8_t *)hvb_malloc(digestlen); 95 if (!hash_tmp) { 96 return HASH_CMP_FAIL; 97 } 98 if (hash_sha256_single(m_tmp, m_tmp_len, hash_tmp, digestlen) != HASH_OK) { 99 ret = HASH_CMP_FAIL; 100 goto rsa_error; 101 } 102 /* compare twice */ 103 ret = VERIFY_OK; 104 ret += hvb_memcmp(hash, hash_tmp, digestlen); 105 ret += hvb_memcmp(hash, hash_tmp, digestlen); 106 if (ret != VERIFY_OK) 107 ret = HASH_CMP_FAIL; 108rsa_error: 109 hvb_free(hash_tmp); 110 return ret; 111} 112 113static int rsa_pss_get_emlen(uint32_t klen, struct long_int_num *pn, 114 uint32_t *emlen, uint32_t *embits) 115{ 116 *embits = lin_get_bitlen(pn); 117 if (*embits == 0) { 118 return CALC_EMLEN_ERROR; 119 } 120 (*embits)--; 121 122 *emlen = bit2byte_align(*embits); 123 if (*emlen == 0) { 124 return CALC_EMLEN_ERROR; 125 } 126 127 if (*emlen > klen) { 128 return CALC_EMLEN_ERROR; 129 } 130 131 return VERIFY_OK; 132} 133 134/* make generate function V1 */ 135static int rsa_gen_mask_mgf_v1(uint8_t *seed, uint32_t seed_len, 136 uint8_t *mask, uint32_t mask_len) 137{ 138 int ret = VERIFY_OK; 139 uint32_t cnt = 0; 140 uint32_t cnt_maxsize = 0; 141 uint8_t *p_tmp = NULL; 142 uint8_t *pt = NULL; 143 uint8_t *pc = NULL; 144 const uint32_t hash_len = SHA256_DIGEST_LEN; 145 146 /* Step 1: mask length is smaller than the maximum key length */ 147 if (mask_len > bit2byte(RSA_WIDTH_MAX)) { 148 return CALC_MASK_ERROR; 149 } 150 151 /* Step 2: Let pt and pt_tmp be the empty octet string. */ 152 pt = (uint8_t *)hvb_malloc(mask_len + hash_len); 153 if (!pt) { 154 return CALC_MASK_ERROR; 155 } 156 157 pc = (uint8_t *)hvb_malloc(seed_len + sizeof(uint32_t)); 158 if (!pc) { 159 ret = CALC_MASK_ERROR; 160 goto rsa_error; 161 } 162 163 /* 164 * Step 3: For counter from 0 to (mask_len + hash_len - 1) / hash_len , 165 * do the following: 166 * string T: T = T || Hash (pseed || counter) 167 */ 168 p_tmp = pt; 169 if (hvb_memcpy_s(pc, seed_len, seed, seed_len) != 0) { 170 ret = MEMORY_ERROR; 171 goto rsa_error; 172 } 173 174 if (hvb_memset_s(pc + seed_len, sizeof(uint32_t), 0, sizeof(uint32_t)) != 0) { 175 ret = MEMORY_ERROR; 176 goto rsa_error; 177 } 178 /* step 3.1: count of Hash blocks needed for mask calculation */ 179 cnt_maxsize = (uint32_t)((mask_len + hash_len - 1) / hash_len); 180 181 for (cnt = 0; cnt < cnt_maxsize; cnt++) { 182 /* step 3.2: pt_tmp = pseed ||Counter */ 183 pc[seed_len + sizeof(uint32_t) - sizeof(uint8_t)] = cnt; 184 185 /* step 3.3: calc T, T = T || Hash (pt_tmp) */ 186 if (hash_sha256_single(pc, seed_len + sizeof(uint32_t), p_tmp, hash_len) != HASH_OK) { 187 ret = CALC_MASK_ERROR; 188 goto rsa_error; 189 } 190 p_tmp += hash_len; 191 } 192 /* Step 4: Output the leading L octets of T as the octet string mask. */ 193 if (hvb_memcpy_s(mask, mask_len, pt, mask_len) != 0) { 194 ret = MEMORY_ERROR; 195 goto rsa_error; 196 } 197 198rsa_error: 199 if (pt != NULL) 200 hvb_free(pt); 201 if (pc != NULL) 202 hvb_free(pc); 203 return ret; 204} 205 206static int emsa_pss_verify_check_db(uint8_t *db, uint32_t db_len, 207 uint32_t emlen, uint32_t digestlen, 208 uint32_t saltlen) 209{ 210 int i; 211 212 for (i = 0; i < emlen - digestlen - saltlen - PSS_EM_PADDING_LEN; i++) { 213 if (db[i] != PADDING_UNIT_ZERO) { 214 return CHECK_DB_ERROR; 215 } 216 } 217 218 if (db[db_len - saltlen - PSS_DB_PADDING_LEN] != PADDING_UNIT_ONE) { 219 return CMP_DB_FAIL; 220 } 221 222 return VERIFY_OK; 223} 224 225static int emsa_pss_verify(uint32_t saltlen, const uint8_t *pdigest, 226 uint32_t digestlen, uint32_t emlen, 227 uint32_t embits, uint8_t *pem) 228{ 229 int ret; 230 uint32_t i; 231 uint32_t masklen; 232 uint32_t m_tmp_len; 233 uint32_t db_len = 0; 234 uint8_t *hash = NULL; 235 uint8_t *m_tmp = NULL; 236 uint8_t *maskedb = NULL; 237 uint8_t *salt = NULL; 238 uint8_t *db = NULL; 239 240 masklen = byte2bit(emlen) - embits; 241 242 /* 243 * Step 1: Skip digest calculate 244 * Step 2: Check sizes, emLen < hLen + sLen + 2 245 */ 246 if (emlen < digestlen + PSS_EM_PADDING_LEN || saltlen > (emlen - digestlen - PSS_EM_PADDING_LEN)) { 247 return CALC_EMLEN_ERROR; 248 } 249 /* Step 3: if rightmost of EM is oxbc */ 250 if (pem[emlen - PSS_DB_PADDING_LEN] != PSS_END_PADDING_UNIT) { 251 return CALC_0XBC_ERROR; 252 } 253 254 /* Step 4: set maskedDB and H */ 255 maskedb = pem; 256 db_len = emlen - digestlen - PSS_DB_PADDING_LEN; 257 hash = &pem[db_len]; 258 259 /* Step 5: Check that the leftmost bits in the leftmost octet of EM have the value 0 */ 260 if ((maskedb[0] & (~(PSS_LEFTMOST_BIT_MASK >> masklen))) != 0) { 261 return CALC_EM_ERROR; 262 } 263 264 /* Step 6: calc dbMask, MGF(H) */ 265 db = (uint8_t *)hvb_malloc(db_len); /* db is dbmask */ 266 if (!db) { 267 return CALC_DB_ERROR; 268 } 269 ret = rsa_gen_mask_mgf_v1(hash, digestlen, db, db_len); 270 if (ret != VERIFY_OK) { 271 goto rsa_error; 272 } 273 /* Step 7: calc db, maskedDB ^ db_mask */ 274 for (i = 0; i < db_len; i++) { 275 db[i] = maskedb[i] ^ db[i]; 276 } 277 278 /* Step 8: Set the leftmost 8*emLen-emBits bits in DB to zero */ 279 db[0] &= PSS_LEFTMOST_BIT_MASK >> masklen; 280 281 /* Step 9: check db padding data */ 282 ret = emsa_pss_verify_check_db(db, db_len, emlen, digestlen, saltlen); 283 if (ret != VERIFY_OK) { 284 goto rsa_error; 285 } 286 /* Step 10: set salt be the last slen of DB */ 287 if (saltlen != 0) { 288 salt = &db[db_len - saltlen]; 289 } 290 291 /* Step 11: calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */ 292 ret = emsa_pss_calc_m(pdigest, digestlen, salt, saltlen, &m_tmp); 293 if (ret != VERIFY_OK) { 294 goto rsa_error; 295 } 296 /* Step 12: hash_tmp = H' = Hash(M') */ 297 m_tmp_len = PSS_MTMP_PADDING_LEN + digestlen + saltlen; 298 ret = emsa_pss_hash_cmp(m_tmp, m_tmp_len, hash, digestlen); 299 300rsa_error: 301 if (db != NULL) 302 hvb_free(db); 303 if (m_tmp != NULL) 304 hvb_free(m_tmp); 305 return ret; 306} 307 308static inline void invert_copy(uint8_t *dst, uint8_t *src, uint32_t len) 309{ 310 for (uint32_t i = 0; i < len; i++) { 311 dst[i] = src[len - i - 1]; 312 } 313} 314 315static int hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey *pkey, const uint8_t *pdigest, 316 uint32_t digestlen, uint8_t *psign, uint32_t signlen) 317{ 318 uint32_t klen; 319 uint32_t n_validlen; 320 321 if (!pkey || !pdigest || !psign) { 322 return PARAM_EMPTY_ERROR; 323 } 324 if (!pkey->pn || !pkey->p_rr || pkey->n_n0_i == 0) { 325 return PUBKEY_EMPTY_ERROR; 326 } 327 klen = bit2byte(pkey->width); 328 n_validlen = bn_get_valid_len(pkey->pn, pkey->nlen); 329 if (digestlen != SHA256_DIGEST_LEN) { 330 return DIGEST_LEN_ERROR; 331 } 332 if (n_validlen != klen || pkey->rlen > pkey->nlen) { 333 return PUBKEY_LEN_ERROR; 334 } 335 if (signlen > klen) { 336 return SIGN_LEN_ERROR; 337 } 338 339 return VERIFY_OK; 340} 341 342static int hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey *pkey, uint8_t *psign, 343 uint32_t signlen, struct long_int_num *p_n, 344 struct long_int_num *p_rr, struct long_int_num *p_m) 345{ 346 invert_copy((uint8_t *)p_n->p_uint, pkey->pn, pkey->nlen); 347 p_n->valid_word_len = byte2dword(pkey->nlen); 348 lin_update_valid_len(p_n); 349 if (!p_n) { 350 return PUBKEY_EMPTY_ERROR; 351 } 352 353 invert_copy((uint8_t *)p_m->p_uint, psign, signlen); 354 p_m->valid_word_len = byte2dword(pkey->nlen); 355 lin_update_valid_len(p_m); 356 if (!p_m) { 357 return SIGN_EMPTY_ERROR; 358 } 359 360 invert_copy((uint8_t *)p_rr->p_uint, pkey->p_rr, pkey->rlen); 361 p_rr->valid_word_len = byte2dword(pkey->nlen); 362 lin_update_valid_len(p_rr); 363 if (!p_rr) { 364 return PUBKEY_EMPTY_ERROR; 365 } 366 367 return VERIFY_OK; 368} 369 370int hvb_rsa_verify_pss(const struct hvb_rsa_pubkey 371 *pkey, const uint8_t *pdigest, 372 uint32_t digestlen, uint8_t *psign, 373 uint32_t signlen, uint32_t saltlen) 374{ 375 int ret; 376 uint32_t klen; 377 uint32_t emlen; 378 uint32_t embits; 379 unsigned long n_n0_i; 380 struct long_int_num *p_n = NULL; 381 struct long_int_num *p_m = NULL; 382 struct long_int_num *p_rr = NULL; 383 struct long_int_num *em = NULL; 384 uint8_t *em_data = NULL; 385 386 ret = hvb_rsa_verify_pss_param_check(pkey, pdigest, digestlen, psign, signlen); 387 if (ret != VERIFY_OK) { 388 return ret; 389 } 390 391 n_n0_i = (unsigned long)pkey->n_n0_i; 392 klen = bit2byte(pkey->width); 393 p_n = lin_create(byte2dword(pkey->nlen)); 394 if (!p_n) { 395 return MEMORY_ERROR; 396 } 397 p_m = lin_create(byte2dword(pkey->nlen)); 398 if (!p_m) { 399 ret = MEMORY_ERROR; 400 goto rsa_error; 401 } 402 p_rr = lin_create(byte2dword(pkey->nlen)); 403 if (!p_rr) { 404 ret = MEMORY_ERROR; 405 goto rsa_error; 406 } 407 ret = hvb_rsa_verify_pss_param_convert(pkey, psign, signlen, p_n, p_rr, p_m); 408 if (ret != VERIFY_OK) { 409 goto rsa_error; 410 } 411 /* Step 1: RSA prim decrypt */ 412 em = montgomery_mod_exp(p_m, p_n, n_n0_i, p_rr, pkey->e); 413 if (!em) { 414 ret = MOD_EXP_CALC_FAIL; 415 goto rsa_error; 416 } 417 418 lin_update_valid_len(em); 419 em_data = hvb_malloc(klen); 420 if (!em_data) { 421 ret = MOD_EXP_CALC_FAIL; 422 goto rsa_error; 423 } 424 425 if (hvb_memset_s(em_data, klen, 0, klen) != 0) { 426 ret = MEMORY_ERROR; 427 goto rsa_error; 428 } 429 invert_copy(em_data, (uint8_t *)em->p_uint, klen); 430 /* Step 2: emsa pss verify */ 431 ret = rsa_pss_get_emlen(klen, p_n, &emlen, &embits); 432 if (ret != VERIFY_OK) { 433 goto rsa_error; 434 } 435 if (klen - emlen == 1 && em_data[0] != 0) { 436 ret = MOD_EXP_CALC_FAIL; 437 goto rsa_error; 438 } 439 ret = emsa_pss_verify(saltlen, pdigest, digestlen, emlen, embits, em_data + klen - emlen); 440 441rsa_error: 442 lin_free(em); 443 lin_free(p_n); 444 lin_free(p_m); 445 lin_free(p_rr); 446 if (em_data) { 447 hvb_free(em_data); 448 } 449 450 return ret; 451} 452