1/*
2 *  PSA RSA layer on top of Mbed TLS crypto
3 */
4/*
5 *  Copyright The Mbed TLS Contributors
6 *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7 */
8
9#include "common.h"
10
11#if defined(MBEDTLS_PSA_CRYPTO_C)
12
13#include <psa/crypto.h>
14#include "psa/crypto_values.h"
15#include "psa_crypto_core.h"
16#include "psa_crypto_random_impl.h"
17#include "psa_crypto_rsa.h"
18#include "psa_crypto_hash.h"
19#include "mbedtls/psa_util.h"
20
21#include <stdlib.h>
22#include <string.h>
23#include "mbedtls/platform.h"
24
25#include <mbedtls/rsa.h>
26#include <mbedtls/error.h>
27#include "rsa_internal.h"
28
29#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
30    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
31    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
32    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
33    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) || \
34    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
35    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
36
37/* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
38 * that are not a multiple of 8) well. For example, there is only
39 * mbedtls_rsa_get_len(), which returns a number of bytes, and no
40 * way to return the exact bit size of a key.
41 * To keep things simple, reject non-byte-aligned key sizes. */
42static psa_status_t psa_check_rsa_key_byte_aligned(
43    const mbedtls_rsa_context *rsa)
44{
45    mbedtls_mpi n;
46    psa_status_t status;
47    mbedtls_mpi_init(&n);
48    status = mbedtls_to_psa_error(
49        mbedtls_rsa_export(rsa, &n, NULL, NULL, NULL, NULL));
50    if (status == PSA_SUCCESS) {
51        if (mbedtls_mpi_bitlen(&n) % 8 != 0) {
52            status = PSA_ERROR_NOT_SUPPORTED;
53        }
54    }
55    mbedtls_mpi_free(&n);
56    return status;
57}
58
59psa_status_t mbedtls_psa_rsa_load_representation(
60    psa_key_type_t type, const uint8_t *data, size_t data_length,
61    mbedtls_rsa_context **p_rsa)
62{
63    psa_status_t status;
64    size_t bits;
65
66    *p_rsa = mbedtls_calloc(1, sizeof(mbedtls_rsa_context));
67    if (*p_rsa == NULL) {
68        return PSA_ERROR_INSUFFICIENT_MEMORY;
69    }
70    mbedtls_rsa_init(*p_rsa);
71
72    /* Parse the data. */
73    if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
74        status = mbedtls_to_psa_error(mbedtls_rsa_parse_key(*p_rsa, data, data_length));
75    } else {
76        status = mbedtls_to_psa_error(mbedtls_rsa_parse_pubkey(*p_rsa, data, data_length));
77    }
78    if (status != PSA_SUCCESS) {
79        goto exit;
80    }
81
82    /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
83     * supports non-byte-aligned key sizes, but not well. For example,
84     * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
85    bits = PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(*p_rsa));
86    if (bits > PSA_VENDOR_RSA_MAX_KEY_BITS) {
87        status = PSA_ERROR_NOT_SUPPORTED;
88        goto exit;
89    }
90    status = psa_check_rsa_key_byte_aligned(*p_rsa);
91    if (status != PSA_SUCCESS) {
92        goto exit;
93    }
94
95exit:
96    return status;
97}
98#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
99        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
100        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
101        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
102        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) ||
103        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
104        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
105
106#if (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) && \
107    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) || \
108    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
109psa_status_t mbedtls_psa_rsa_import_key(
110    const psa_key_attributes_t *attributes,
111    const uint8_t *data, size_t data_length,
112    uint8_t *key_buffer, size_t key_buffer_size,
113    size_t *key_buffer_length, size_t *bits)
114{
115    psa_status_t status;
116    mbedtls_rsa_context *rsa = NULL;
117
118    /* Parse input */
119    status = mbedtls_psa_rsa_load_representation(attributes->type,
120                                                 data,
121                                                 data_length,
122                                                 &rsa);
123    if (status != PSA_SUCCESS) {
124        goto exit;
125    }
126
127    *bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(mbedtls_rsa_get_len(rsa));
128
129    /* Re-export the data to PSA export format, such that we can store export
130     * representation in the key slot. Export representation in case of RSA is
131     * the smallest representation that's allowed as input, so a straight-up
132     * allocation of the same size as the input buffer will be large enough. */
133    status = mbedtls_psa_rsa_export_key(attributes->type,
134                                        rsa,
135                                        key_buffer,
136                                        key_buffer_size,
137                                        key_buffer_length);
138exit:
139    /* Always free the RSA object */
140    mbedtls_rsa_free(rsa);
141    mbedtls_free(rsa);
142
143    return status;
144}
145#endif /* (defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_IMPORT) &&
146        *  defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT)) ||
147        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
148
149#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) || \
150    defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
/* Serialize an RSA context into the PSA export format.
 *
 * type         Key type. A key-pair type selects the private-key encoding
 *              (mbedtls_rsa_write_key), any other type the public-key
 *              encoding (mbedtls_rsa_write_pubkey).
 * rsa          Context to serialize.
 * data         Output buffer of data_size bytes.
 * data_length  On success, receives the encoded length in bytes.
 *
 * On success the encoding starts at data[0]. On failure the whole buffer
 * is zeroed and the Mbed TLS error is translated with
 * mbedtls_to_psa_error().
 */
psa_status_t mbedtls_psa_rsa_export_key(psa_key_type_t type,
                                        mbedtls_rsa_context *rsa,
                                        uint8_t *data,
                                        size_t data_size,
                                        size_t *data_length)
{
    int ret;
    /* The writers fill the buffer backwards, starting from this end
     * pointer. */
    uint8_t *end = data + data_size;

    /* PSA Crypto API defines the format of an RSA key as a DER-encoded
     * representation of the non-encrypted PKCS#1 RSAPrivateKey for a
     * private key and of the RFC3279 RSAPublicKey for a public key. */
    if (PSA_KEY_TYPE_IS_KEY_PAIR(type)) {
        ret = mbedtls_rsa_write_key(rsa, data, &end);
    } else {
        ret = mbedtls_rsa_write_pubkey(rsa, data, &end);
    }

    /* A negative value is an error code; otherwise ret is the number of
     * bytes written at the end of the buffer. */
    if (ret < 0) {
        /* Clean up in case the writer failed halfway through. */
        memset(data, 0, data_size);
        return mbedtls_to_psa_error(ret);
    }

    /* The mbedtls_rsa_write_xxx functions write to the end of the buffer.
     * Move the data to the beginning and erase remaining data
     * at the original location. */
    if (2 * (size_t) ret <= data_size) {
        /* Source [data_size - ret, data_size) and destination [0, ret)
         * do not overlap, so memcpy is safe. Then wipe the old copy. */
        memcpy(data, data + data_size - ret, ret);
        memset(data + data_size - ret, 0, ret);
    } else if ((size_t) ret < data_size) {
        /* The regions overlap: memmove is required. Everything past the
         * moved encoding is zeroed. */
        memmove(data, data + data_size - ret, ret);
        memset(data + ret, 0, data_size - ret);
    }
    /* If ret == data_size, the encoding already starts at data[0]. */

    *data_length = ret;
    return PSA_SUCCESS;
}
189
190psa_status_t mbedtls_psa_rsa_export_public_key(
191    const psa_key_attributes_t *attributes,
192    const uint8_t *key_buffer, size_t key_buffer_size,
193    uint8_t *data, size_t data_size, size_t *data_length)
194{
195    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
196    mbedtls_rsa_context *rsa = NULL;
197
198    status = mbedtls_psa_rsa_load_representation(
199        attributes->type, key_buffer, key_buffer_size, &rsa);
200    if (status != PSA_SUCCESS) {
201        return status;
202    }
203
204    status = mbedtls_psa_rsa_export_key(PSA_KEY_TYPE_RSA_PUBLIC_KEY,
205                                        rsa,
206                                        data,
207                                        data_size,
208                                        data_length);
209
210    mbedtls_rsa_free(rsa);
211    mbedtls_free(rsa);
212
213    return status;
214}
215#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_EXPORT) ||
216        * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
217
218#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
219static psa_status_t psa_rsa_read_exponent(const uint8_t *e_bytes,
220                                          size_t e_length,
221                                          int *exponent)
222{
223    size_t i;
224    uint32_t acc = 0;
225
226    /* Mbed TLS encodes the public exponent as an int. For simplicity, only
227     * support values that fit in a 32-bit integer, which is larger than
228     * int on just about every platform anyway. */
229    if (e_length > sizeof(acc)) {
230        return PSA_ERROR_NOT_SUPPORTED;
231    }
232    for (i = 0; i < e_length; i++) {
233        acc = (acc << 8) | e_bytes[i];
234    }
235    if (acc > INT_MAX) {
236        return PSA_ERROR_NOT_SUPPORTED;
237    }
238    *exponent = acc;
239    return PSA_SUCCESS;
240}
241
242psa_status_t mbedtls_psa_rsa_generate_key(
243    const psa_key_attributes_t *attributes,
244    const psa_key_production_parameters_t *params, size_t params_data_length,
245    uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
246{
247    psa_status_t status;
248    mbedtls_rsa_context rsa;
249    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
250    int exponent = 65537;
251
252    if (params_data_length != 0) {
253        status = psa_rsa_read_exponent(params->data, params_data_length,
254                                       &exponent);
255        if (status != PSA_SUCCESS) {
256            return status;
257        }
258    }
259
260    mbedtls_rsa_init(&rsa);
261    ret = mbedtls_rsa_gen_key(&rsa,
262                              mbedtls_psa_get_random,
263                              MBEDTLS_PSA_RANDOM_STATE,
264                              (unsigned int) attributes->bits,
265                              exponent);
266    if (ret != 0) {
267        return mbedtls_to_psa_error(ret);
268    }
269
270    status = mbedtls_psa_rsa_export_key(attributes->type,
271                                        &rsa, key_buffer, key_buffer_size,
272                                        key_buffer_length);
273    mbedtls_rsa_free(&rsa);
274
275    return status;
276}
277#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE) */
278
279/****************************************************************/
280/* Sign/verify hashes */
281/****************************************************************/
282
283#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
284    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
285
286/* Decode the hash algorithm from alg and store the mbedtls encoding in
287 * md_alg. Verify that the hash length is acceptable. */
288static psa_status_t psa_rsa_decode_md_type(psa_algorithm_t alg,
289                                           size_t hash_length,
290                                           mbedtls_md_type_t *md_alg)
291{
292    psa_algorithm_t hash_alg = PSA_ALG_SIGN_GET_HASH(alg);
293    *md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
294
295    /* The Mbed TLS RSA module uses an unsigned int for hash length
296     * parameters. Validate that it fits so that we don't risk an
297     * overflow later. */
298#if SIZE_MAX > UINT_MAX
299    if ((int)hash_length > (int)UINT_MAX) {
300        return PSA_ERROR_INVALID_ARGUMENT;
301    }
302#endif
303
304    /* For signatures using a hash, the hash length must be correct. */
305    if (alg != PSA_ALG_RSA_PKCS1V15_SIGN_RAW) {
306        if (*md_alg == MBEDTLS_MD_NONE) {
307            return PSA_ERROR_NOT_SUPPORTED;
308        }
309        if (mbedtls_md_get_size_from_type(*md_alg) != hash_length) {
310            return PSA_ERROR_INVALID_ARGUMENT;
311        }
312    }
313
314    return PSA_SUCCESS;
315}
316
317psa_status_t mbedtls_psa_rsa_sign_hash(
318    const psa_key_attributes_t *attributes,
319    const uint8_t *key_buffer, size_t key_buffer_size,
320    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
321    uint8_t *signature, size_t signature_size, size_t *signature_length)
322{
323    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
324    mbedtls_rsa_context *rsa = NULL;
325    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
326    mbedtls_md_type_t md_alg;
327
328    status = mbedtls_psa_rsa_load_representation(attributes->type,
329                                                 key_buffer,
330                                                 key_buffer_size,
331                                                 &rsa);
332    if (status != PSA_SUCCESS) {
333        return status;
334    }
335
336    status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
337    if (status != PSA_SUCCESS) {
338        goto exit;
339    }
340
341    if (signature_size < mbedtls_rsa_get_len(rsa)) {
342        status = PSA_ERROR_BUFFER_TOO_SMALL;
343        goto exit;
344    }
345
346#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
347    if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
348        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
349                                      MBEDTLS_MD_NONE);
350        if (ret == 0) {
351            ret = mbedtls_rsa_pkcs1_sign(rsa,
352                                         mbedtls_psa_get_random,
353                                         MBEDTLS_PSA_RANDOM_STATE,
354                                         md_alg,
355                                         (unsigned int) hash_length,
356                                         hash,
357                                         signature);
358        }
359    } else
360#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
361#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
362    if (PSA_ALG_IS_RSA_PSS(alg)) {
363        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
364
365        if (ret == 0) {
366            ret = mbedtls_rsa_rsassa_pss_sign(rsa,
367                                              mbedtls_psa_get_random,
368                                              MBEDTLS_PSA_RANDOM_STATE,
369                                              MBEDTLS_MD_NONE,
370                                              (unsigned int) hash_length,
371                                              hash,
372                                              signature);
373        }
374    } else
375#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
376    {
377        status = PSA_ERROR_INVALID_ARGUMENT;
378        goto exit;
379    }
380
381    if (ret == 0) {
382        *signature_length = mbedtls_rsa_get_len(rsa);
383    }
384    status = mbedtls_to_psa_error(ret);
385
386exit:
387    mbedtls_rsa_free(rsa);
388    mbedtls_free(rsa);
389
390    return status;
391}
392
393#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
394static int rsa_pss_expected_salt_len(psa_algorithm_t alg,
395                                     const mbedtls_rsa_context *rsa,
396                                     size_t hash_length)
397{
398    if (PSA_ALG_IS_RSA_PSS_ANY_SALT(alg)) {
399        return MBEDTLS_RSA_SALT_LEN_ANY;
400    }
401    /* Otherwise: standard salt length, i.e. largest possible salt length
402     * up to the hash length. */
403    int klen = (int) mbedtls_rsa_get_len(rsa);   // known to fit
404    int hlen = (int) hash_length; // known to fit
405    int room = klen - 2 - hlen;
406    if (room < 0) {
407        return 0;  // there is no valid signature in this case anyway
408    } else if (room > hlen) {
409        return hlen;
410    } else {
411        return room;
412    }
413}
414#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
415
/* Verify an RSA signature over a precomputed hash.
 *
 * key_buffer holds the key in PSA export format (type from attributes). It
 * is parsed into a temporary context that is freed on every path.
 * alg selects PKCS#1 v1.5 or PSS, and hash/hash_length is the digest. The
 * signature to check is signature/signature_length.
 *
 * Returns:
 * - PSA_SUCCESS if the signature verifies.
 * - PSA_ERROR_INVALID_SIGNATURE if signature_length differs from the
 *   modulus length, or if Mbed TLS reports invalid padding.
 * - PSA_ERROR_INVALID_ARGUMENT or PSA_ERROR_NOT_SUPPORTED for a bad hash
 *   length or hash algorithm (see psa_rsa_decode_md_type).
 * - PSA_ERROR_INVALID_ARGUMENT for a signature algorithm not handled in
 *   this build.
 * - Otherwise, a translated Mbed TLS error.
 */
psa_status_t mbedtls_psa_rsa_verify_hash(
    const psa_key_attributes_t *attributes,
    const uint8_t *key_buffer, size_t key_buffer_size,
    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
    const uint8_t *signature, size_t signature_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    mbedtls_rsa_context *rsa = NULL;
    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
    mbedtls_md_type_t md_alg;

    status = mbedtls_psa_rsa_load_representation(attributes->type,
                                                 key_buffer,
                                                 key_buffer_size,
                                                 &rsa);
    if (status != PSA_SUCCESS) {
        /* Use the common cleanup, which frees rsa (NULL is accepted). */
        goto exit;
    }

    status = psa_rsa_decode_md_type(alg, hash_length, &md_alg);
    if (status != PSA_SUCCESS) {
        goto exit;
    }

    /* A valid signature is exactly as long as the modulus. */
    if (signature_length != mbedtls_rsa_get_len(rsa)) {
        status = PSA_ERROR_INVALID_SIGNATURE;
        goto exit;
    }

#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
    if (PSA_ALG_IS_RSA_PKCS1V15_SIGN(alg)) {
        /* The hash type is passed to the verify call directly, so the
         * context is configured with MBEDTLS_MD_NONE. */
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V15,
                                      MBEDTLS_MD_NONE);
        if (ret == 0) {
            ret = mbedtls_rsa_pkcs1_verify(rsa,
                                           md_alg,
                                           (unsigned int) hash_length,
                                           hash,
                                           signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
    if (PSA_ALG_IS_RSA_PSS(alg)) {
        ret = mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
        if (ret == 0) {
            /* md_alg is passed twice: once as the message hash, and once
             * as what is presumably the MGF1 hash (5th argument). Confirm
             * the parameter order in rsa.h. */
            int slen = rsa_pss_expected_salt_len(alg, rsa, hash_length);
            ret = mbedtls_rsa_rsassa_pss_verify_ext(rsa,
                                                    md_alg,
                                                    (unsigned) hash_length,
                                                    hash,
                                                    md_alg,
                                                    slen,
                                                    signature);
        }
    } else
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
    {
        status = PSA_ERROR_INVALID_ARGUMENT;
        goto exit;
    }

    /* Mbed TLS distinguishes "invalid padding" from "valid padding but
     * the rest of the signature is invalid". This has little use in
     * practice and PSA doesn't report this distinction. */
    status = (ret == MBEDTLS_ERR_RSA_INVALID_PADDING) ?
             PSA_ERROR_INVALID_SIGNATURE :
             mbedtls_to_psa_error(ret);

exit:
    mbedtls_rsa_free(rsa);
    mbedtls_free(rsa);

    return status;
}
491
492#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
493        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
494
495/****************************************************************/
496/* Asymmetric cryptography */
497/****************************************************************/
498
499#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
500static int psa_rsa_oaep_set_padding_mode(psa_algorithm_t alg,
501                                         mbedtls_rsa_context *rsa)
502{
503    psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH(alg);
504    mbedtls_md_type_t md_alg = mbedtls_md_type_from_psa_alg(hash_alg);
505
506    /* Just to get the error status right, as rsa_set_padding() doesn't
507     * distinguish between "bad RSA algorithm" and "unknown hash". */
508    if (mbedtls_md_info_from_type(md_alg) == NULL) {
509        return PSA_ERROR_NOT_SUPPORTED;
510    }
511
512    return mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_alg);
513}
514#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
515
/* Encrypt a message with an RSA key using the built-in implementation.
 *
 * attributes     Key attributes. Only RSA key types are handled; anything
 *                else returns PSA_ERROR_NOT_SUPPORTED.
 * key_buffer     Key material in PSA export format.
 * alg            PSA_ALG_RSA_PKCS1V15_CRYPT or an OAEP algorithm. Any other
 *                value returns PSA_ERROR_INVALID_ARGUMENT. An algorithm
 *                whose built-in support is compiled out returns
 *                PSA_ERROR_NOT_SUPPORTED.
 * input          Plaintext of input_length bytes.
 * salt           Forwarded with salt_length to
 *                mbedtls_rsa_rsaes_oaep_encrypt(), presumably as the OAEP
 *                label (confirm in rsa.h). Unused for PKCS#1 v1.5.
 * output         Must hold at least the modulus length in bytes; otherwise
 *                PSA_ERROR_BUFFER_TOO_SMALL is returned.
 * output_length  Set to the modulus length, on success only.
 */
psa_status_t mbedtls_psa_asymmetric_encrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* These parameters are unused when both built-in RSA encryption
     * algorithms are compiled out; the casts silence warnings. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    if (PSA_KEY_TYPE_IS_RSA(attributes->type)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* The ciphertext is exactly one modulus long. */
        if (output_size < mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_BUFFER_TOO_SMALL;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_encrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          input_length,
                                          input,
                                          output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            /* Configure OAEP padding with the hash encoded in alg. */
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_encrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               input_length,
                                               input,
                                               output));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        /* Common cleanup: report the length on success, then release the
         * temporary context (both free calls accept NULL). */
        if (status == PSA_SUCCESS) {
            *output_length = mbedtls_rsa_get_len(rsa);
        }

        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
609
/* Decrypt a message with an RSA key pair using the built-in implementation.
 *
 * attributes     Key attributes. Only PSA_KEY_TYPE_RSA_KEY_PAIR is
 *                handled; anything else returns PSA_ERROR_NOT_SUPPORTED.
 * key_buffer     Key material in PSA export format.
 * alg            PSA_ALG_RSA_PKCS1V15_CRYPT or an OAEP algorithm. Any other
 *                value returns PSA_ERROR_INVALID_ARGUMENT. An algorithm
 *                whose built-in support is compiled out returns
 *                PSA_ERROR_NOT_SUPPORTED.
 * input          Ciphertext. input_length must equal the modulus length;
 *                otherwise PSA_ERROR_INVALID_ARGUMENT is returned.
 * salt           Forwarded with salt_length to
 *                mbedtls_rsa_rsaes_oaep_decrypt(), presumably as the OAEP
 *                label (confirm in rsa.h). Unused for PKCS#1 v1.5.
 * output         Plaintext buffer of output_size bytes.
 * output_length  Set to 0 on entry; the decrypt call writes the plaintext
 *                length into it.
 */
psa_status_t mbedtls_psa_asymmetric_decrypt(const psa_key_attributes_t *attributes,
                                            const uint8_t *key_buffer,
                                            size_t key_buffer_size,
                                            psa_algorithm_t alg,
                                            const uint8_t *input,
                                            size_t input_length,
                                            const uint8_t *salt,
                                            size_t salt_length,
                                            uint8_t *output,
                                            size_t output_size,
                                            size_t *output_length)
{
    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
    /* These parameters are unused when both built-in RSA encryption
     * algorithms are compiled out; the casts silence warnings. */
    (void) key_buffer;
    (void) key_buffer_size;
    (void) input;
    (void) input_length;
    (void) salt;
    (void) salt_length;
    (void) output;
    (void) output_size;
    (void) output_length;

    /* Report zero bytes unless a decrypt call below sets the length. */
    *output_length = 0;

    if (attributes->type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
        mbedtls_rsa_context *rsa = NULL;
        status = mbedtls_psa_rsa_load_representation(attributes->type,
                                                     key_buffer,
                                                     key_buffer_size,
                                                     &rsa);
        if (status != PSA_SUCCESS) {
            goto rsa_exit;
        }

        /* The ciphertext must be exactly one modulus long. */
        if (input_length != mbedtls_rsa_get_len(rsa)) {
            status = PSA_ERROR_INVALID_ARGUMENT;
            goto rsa_exit;
        }
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */

        if (alg == PSA_ALG_RSA_PKCS1V15_CRYPT) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT)
            status = mbedtls_to_psa_error(
                mbedtls_rsa_pkcs1_decrypt(rsa,
                                          mbedtls_psa_get_random,
                                          MBEDTLS_PSA_RANDOM_STATE,
                                          output_length,
                                          input,
                                          output,
                                          output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT */
        } else
        if (PSA_ALG_IS_RSA_OAEP(alg)) {
#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
            /* Configure OAEP padding with the hash encoded in alg. */
            status = mbedtls_to_psa_error(
                psa_rsa_oaep_set_padding_mode(alg, rsa));
            if (status != PSA_SUCCESS) {
                goto rsa_exit;
            }

            status = mbedtls_to_psa_error(
                mbedtls_rsa_rsaes_oaep_decrypt(rsa,
                                               mbedtls_psa_get_random,
                                               MBEDTLS_PSA_RANDOM_STATE,
                                               salt, salt_length,
                                               output_length,
                                               input,
                                               output,
                                               output_size));
#else
            status = PSA_ERROR_NOT_SUPPORTED;
#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP */
        } else {
            status = PSA_ERROR_INVALID_ARGUMENT;
        }

#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
        defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
rsa_exit:
        /* Common cleanup for the temporary context (NULL is accepted). */
        mbedtls_rsa_free(rsa);
        mbedtls_free(rsa);
#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
    } else {
        status = PSA_ERROR_NOT_SUPPORTED;
    }

    return status;
}
705
706#endif /* MBEDTLS_PSA_CRYPTO_C */
707