/* BEGIN_HEADER */

#include "psa/crypto.h"
#include "test/psa_crypto_helpers.h"

/* Compare two PSA status codes on behalf of TEST_STATUS.
 *
 * Return 1 without calling mbedtls_test_equal() when one value is
 * PSA_ERROR_INVALID_ARGUMENT and the other is PSA_ERROR_NOT_SUPPORTED,
 * in either order. In every other case, forward all the arguments to
 * mbedtls_test_equal() and return its result.
 */
static int test_equal_status(const char *test,
                             int line_no, const char *filename,
                             psa_status_t value1,
                             psa_status_t value2)
{
    const int value1_is_lenient = (value1 == PSA_ERROR_INVALID_ARGUMENT ||
                                   value1 == PSA_ERROR_NOT_SUPPORTED);
    const int value2_is_lenient = (value2 == PSA_ERROR_INVALID_ARGUMENT ||
                                   value2 == PSA_ERROR_NOT_SUPPORTED);

    /* Two distinct values that are both in the lenient pair are exactly
     * the INVALID_ARGUMENT/NOT_SUPPORTED mix that is tolerated. */
    if (value1_is_lenient && value2_is_lenient && value1 != value2) {
        return 1;
    }

    return mbedtls_test_equal(test, line_no, filename, value1, value2);
}

/** Variant of #TEST_EQUAL for #psa_status_t values in which
 * #PSA_ERROR_INVALID_ARGUMENT and #PSA_ERROR_NOT_SUPPORTED compare equal
 * to each other. When the comparison fails, control jumps to the
 * \c exit label of the calling function.
 *
 * The two error codes are treated as interchangeable because the
 * library's behavior does not always match the strict expectations of
 * the test case generator. This is a stopgap: in the long run, the
 * expectations should be clarified and the library and the test case
 * generator reconciled.
 *
 * \param expr1     First status expression.
 * \param expr2     Second status expression.
 */
#define TEST_STATUS(expr1, expr2)                                        \
    do {                                                                 \
        if (!test_equal_status(#expr1 " == " #expr2,                     \
                               __LINE__, __FILE__,                       \
                               expr1, expr2)) {                          \
            goto exit;                                                   \
        }                                                                \
    } while (0)

/* END_HEADER */

/* BEGIN_DEPENDENCIES
 * depends_on:MBEDTLS_PSA_CRYPTO_C
 * END_DEPENDENCIES
 */

/* BEGIN_CASE */
/* Check that each hash entry point returns exactly expected_status_arg
 * when called with alg_arg: the multipart setup, the one-shot compute
 * and the one-shot compare. */
void hash_fail(int alg_arg, int expected_status_arg)
{
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
    size_t length = SIZE_MAX;
    uint8_t output[PSA_HASH_MAX_SIZE] = { 0 };
    uint8_t input[1] = { 'A' };

    PSA_INIT();

    /* Multipart interface. */
    TEST_EQUAL(expected_status, psa_hash_setup(&operation, alg));

    /* One-shot interfaces. The zero-filled output buffer also serves as
     * the reference hash passed to the comparison. */
    TEST_EQUAL(expected_status,
               psa_hash_compute(alg,
                                input, sizeof(input),
                                output, sizeof(output),
                                &length));
    TEST_EQUAL(expected_status,
               psa_hash_compare(alg,
                                input, sizeof(input),
                                output, sizeof(output)));

exit:
    psa_hash_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that every MAC entry point fails with expected_status_arg for the
 * given key and algorithm: multipart sign and verify setup, then one-shot
 * compute and verify. The statuses INVALID_ARGUMENT and NOT_SUPPORTED are
 * accepted interchangeably, see TEST_STATUS.
 *
 * key_type_arg and key_data describe the key to import, and alg_arg is
 * the algorithm under test. */
void mac_fail(int key_type_arg, data_t *key_data,
              int alg_arg, int expected_status_arg)
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t input[1] = { 'A' };
    uint8_t output[PSA_MAC_MAX_SIZE] = { 0 };
    size_t length = SIZE_MAX;

    PSA_INIT();

    psa_set_key_type(&attributes, key_type);
    /* The PSA Crypto API gates MAC computation and verification on the
     * SIGN_MESSAGE and VERIFY_MESSAGE usage flags, so request those
     * rather than the SIGN_HASH and VERIFY_HASH flags, which govern
     * asymmetric signature of a hash. */
    psa_set_key_usage_flags(&attributes,
                            PSA_KEY_USAGE_SIGN_MESSAGE |
                            PSA_KEY_USAGE_VERIFY_MESSAGE);
    psa_set_key_algorithm(&attributes, alg);
    PSA_ASSERT(psa_import_key(&attributes,
                              key_data->x, key_data->len,
                              &key_id));

    /* Multipart interface. */
    TEST_STATUS(expected_status,
                psa_mac_sign_setup(&operation, key_id, alg));
    TEST_STATUS(expected_status,
                psa_mac_verify_setup(&operation, key_id, alg));

    /* One-shot interface. The zero-filled output buffer also serves as
     * the MAC value passed to the verification. */
    TEST_STATUS(expected_status,
                psa_mac_compute(key_id, alg,
                                input, sizeof(input),
                                output, sizeof(output), &length));
    TEST_STATUS(expected_status,
                psa_mac_verify(key_id, alg,
                               input, sizeof(input),
                               output, sizeof(output)));

exit:
    psa_mac_abort(&operation);
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that every unauthenticated cipher entry point fails with
 * expected_status_arg for the given key and algorithm: multipart encrypt
 * and decrypt setup, then one-shot encrypt and decrypt. The statuses
 * INVALID_ARGUMENT and NOT_SUPPORTED are accepted interchangeably, see
 * TEST_STATUS. */
void cipher_fail(int key_type_arg, data_t *key_data,
                 int alg_arg, int expected_status_arg)
{
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_cipher_operation_t operation = PSA_CIPHER_OPERATION_INIT;
    size_t length = SIZE_MAX;
    uint8_t output[64] = { 0 };
    uint8_t input[1] = { 'A' };

    PSA_INIT();

    /* Import the key with a policy that names the algorithm under test
     * and allows both encryption and decryption. */
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_usage_flags(&attributes,
                            PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key_id));

    /* Multipart interface. */
    TEST_STATUS(expected_status,
                psa_cipher_encrypt_setup(&operation, key_id, alg));
    TEST_STATUS(expected_status,
                psa_cipher_decrypt_setup(&operation, key_id, alg));

    /* One-shot interface. */
    TEST_STATUS(expected_status,
                psa_cipher_encrypt(key_id, alg,
                                   input, sizeof(input),
                                   output, sizeof(output),
                                   &length));
    TEST_STATUS(expected_status,
                psa_cipher_decrypt(key_id, alg,
                                   input, sizeof(input),
                                   output, sizeof(output),
                                   &length));

exit:
    psa_cipher_abort(&operation);
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that every AEAD entry point fails with expected_status_arg for
 * the given key and algorithm: multipart encrypt and decrypt setup, then
 * one-shot encrypt and decrypt. The statuses INVALID_ARGUMENT and
 * NOT_SUPPORTED are accepted interchangeably, see TEST_STATUS. */
void aead_fail(int key_type_arg, data_t *key_data,
               int alg_arg, int expected_status_arg)
{
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_aead_operation_t operation = PSA_AEAD_OPERATION_INIT;
    size_t length = SIZE_MAX;
    uint8_t output[64] = { 0 };
    /* 15 letters plus the terminating null byte fill the 16 bytes. */
    uint8_t input[16] = "ABCDEFGHIJKLMNO";

    PSA_INIT();

    /* Import the key with a policy that names the algorithm under test
     * and allows both encryption and decryption. */
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_usage_flags(&attributes,
                            PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key_id));

    /* Multipart interface. */
    TEST_STATUS(expected_status,
                psa_aead_encrypt_setup(&operation, key_id, alg));
    TEST_STATUS(expected_status,
                psa_aead_decrypt_setup(&operation, key_id, alg));

    /* One-shot interface. The same 16-byte buffer is passed both as the
     * nonce and as the message, and no additional data is supplied. */
    TEST_STATUS(expected_status,
                psa_aead_encrypt(key_id, alg,
                                 input, sizeof(input),
                                 NULL, 0,
                                 input, sizeof(input),
                                 output, sizeof(output), &length));
    TEST_STATUS(expected_status,
                psa_aead_decrypt(key_id, alg,
                                 input, sizeof(input),
                                 NULL, 0,
                                 input, sizeof(input),
                                 output, sizeof(output), &length));

exit:
    psa_aead_abort(&operation);
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that the hash-signature entry points fail with expected_status_arg
 * for the given key and algorithm: one-shot sign and interruptible sign
 * start, and, unless private_only is nonzero, one-shot verify and
 * interruptible verify start. The statuses INVALID_ARGUMENT and
 * NOT_SUPPORTED are accepted interchangeably, see TEST_STATUS.
 *
 * key_type_arg and key_data describe the key to import, and alg_arg is
 * the algorithm under test. */
void sign_fail(int key_type_arg, data_t *key_data,
               int alg_arg, int private_only,
               int expected_status_arg)
{
    psa_status_t expected_status = expected_status_arg;
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    uint8_t input[1] = { 'A' };
    uint8_t output[PSA_SIGNATURE_MAX_SIZE] = { 0 };
    size_t length = SIZE_MAX;
    psa_sign_hash_interruptible_operation_t sign_operation =
        psa_sign_hash_interruptible_operation_init();
    psa_verify_hash_interruptible_operation_t verify_operation =
        psa_verify_hash_interruptible_operation_init();

    PSA_INIT();

    psa_set_key_type(&attributes, key_type);
    psa_set_key_usage_flags(&attributes,
                            PSA_KEY_USAGE_SIGN_HASH |
                            PSA_KEY_USAGE_VERIFY_HASH);
    psa_set_key_algorithm(&attributes, alg);
    PSA_ASSERT(psa_import_key(&attributes,
                              key_data->x, key_data->len,
                              &key_id));

    TEST_STATUS(expected_status,
                psa_sign_hash(key_id, alg,
                              input, sizeof(input),
                              output, sizeof(output), &length));

    TEST_STATUS(expected_status,
                psa_sign_hash_start(&sign_operation, key_id, alg,
                                    input, sizeof(input)));

    /* On the expected path, also check that aborting after the failed
     * start reports success. */
    PSA_ASSERT(psa_sign_hash_abort(&sign_operation));

    if (!private_only) {
        /* Determine a plausible signature size to avoid an INVALID_SIGNATURE
         * error based on this. */
        PSA_ASSERT(psa_get_key_attributes(key_id, &attributes));
        size_t key_bits = psa_get_key_bits(&attributes);
        size_t output_length = sizeof(output);
        if (PSA_KEY_TYPE_IS_RSA(key_type)) {
            output_length = PSA_BITS_TO_BYTES(key_bits);
        } else if (PSA_KEY_TYPE_IS_ECC(key_type)) {
            output_length = 2 * PSA_BITS_TO_BYTES(key_bits);
        }
        TEST_ASSERT(output_length <= sizeof(output));
        TEST_STATUS(expected_status,
                    psa_verify_hash(key_id, alg,
                                    input, sizeof(input),
                                    output, output_length));

        TEST_STATUS(expected_status,
                    psa_verify_hash_start(&verify_operation, key_id, alg,
                                          input, sizeof(input),
                                          output, output_length));

        PSA_ASSERT(psa_verify_hash_abort(&verify_operation));
    }

exit:
    /* Abort the interruptible operations here as well. If a start call
     * does not return the expected status, for example because it
     * unexpectedly succeeds, TEST_STATUS jumps straight to this label and
     * the in-line aborts above are skipped, which would leave the
     * operation active. The PSA Crypto API allows aborting an operation
     * that is inactive or has already been aborted. */
    psa_sign_hash_abort(&sign_operation);
    psa_verify_hash_abort(&verify_operation);
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that asymmetric decryption and, unless private_only is nonzero,
 * asymmetric encryption fail with expected_status_arg for the given key
 * and algorithm. The statuses INVALID_ARGUMENT and NOT_SUPPORTED are
 * accepted interchangeably, see TEST_STATUS. */
void asymmetric_encryption_fail(int key_type_arg, data_t *key_data,
                                int alg_arg, int private_only,
                                int expected_status_arg)
{
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    size_t length = SIZE_MAX;
    uint8_t ciphertext[PSA_ASYMMETRIC_ENCRYPT_OUTPUT_MAX_SIZE] = { 0 };
    uint8_t plaintext[PSA_ASYMMETRIC_DECRYPT_OUTPUT_MAX_SIZE] = { 0 };

    PSA_INIT();

    /* Import the key with a policy that names the algorithm under test
     * and allows both encryption and decryption. */
    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_usage_flags(&attributes,
                            PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key_id));

    /* Encryption is only attempted when private_only is zero. A single
     * byte of plaintext and no salt are supplied. */
    if (!private_only) {
        TEST_STATUS(expected_status,
                    psa_asymmetric_encrypt(key_id, alg,
                                           plaintext, 1,
                                           NULL, 0,
                                           ciphertext, sizeof(ciphertext),
                                           &length));
    }

    /* Decryption is always attempted, on the whole ciphertext buffer. */
    TEST_STATUS(expected_status,
                psa_asymmetric_decrypt(key_id, alg,
                                       ciphertext, sizeof(ciphertext),
                                       NULL, 0,
                                       plaintext, sizeof(plaintext),
                                       &length));

exit:
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that setting up a key derivation operation with alg_arg returns
 * exactly expected_status_arg. */
void key_derivation_fail(int alg_arg, int expected_status_arg)
{
    psa_key_derivation_operation_t operation =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;

    PSA_INIT();

    /* Strict comparison: this uses TEST_EQUAL, not the lenient
     * TEST_STATUS. */
    TEST_EQUAL(expected_status, psa_key_derivation_setup(&operation, alg));

exit:
    psa_key_derivation_abort(&operation);
    PSA_DONE();
}
/* END_CASE */

/* BEGIN_CASE */
/* Check that key agreement fails with expected_status_arg for the given
 * key and algorithm, both as a raw key agreement and, when HKDF and
 * SHA-256 are enabled, as an input step of a key derivation. The peer key
 * is the public key exported from the imported key itself, or is empty
 * when the key type is neither a key pair nor a public key. The statuses
 * INVALID_ARGUMENT and NOT_SUPPORTED are accepted interchangeably, see
 * TEST_STATUS. */
void key_agreement_fail(int key_type_arg, data_t *key_data,
                        int alg_arg, int private_only,
                        int expected_status_arg)
{
    psa_key_type_t key_type = key_type_arg;
    psa_algorithm_t alg = alg_arg;
    psa_status_t expected_status = expected_status_arg;
    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
    psa_key_derivation_operation_t operation =
        PSA_KEY_DERIVATION_OPERATION_INIT;
    uint8_t output[PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE] = { 0 };
    size_t length = 0;
    uint8_t public_key[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE] = { 0 };
    size_t public_key_length = 0;

    /* There are no public-key operations. */
    (void) private_only;

    PSA_INIT();

    psa_set_key_algorithm(&attributes, alg);
    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
    psa_set_key_type(&attributes, key_type);
    PSA_ASSERT(psa_import_key(&attributes, key_data->x, key_data->len,
                              &key_id));

    /* When the key has a public part, use it as the peer key. Otherwise
     * the peer key stays empty. */
    if (PSA_KEY_TYPE_IS_PUBLIC_KEY(key_type) ||
        PSA_KEY_TYPE_IS_KEY_PAIR(key_type)) {
        PSA_ASSERT(psa_export_public_key(key_id,
                                         public_key, sizeof(public_key),
                                         &public_key_length));
    }

    /* Standalone key agreement. */
    TEST_STATUS(expected_status,
                psa_raw_key_agreement(alg, key_id,
                                      public_key, public_key_length,
                                      output, sizeof(output), &length));

#if defined(PSA_WANT_ALG_HKDF) && defined(PSA_WANT_ALG_SHA_256)
    /* Key agreement as the secret input of an HKDF-SHA-256 derivation.
     * Setting up the derivation itself must succeed. */
    PSA_ASSERT(psa_key_derivation_setup(&operation,
                                        PSA_ALG_HKDF(PSA_ALG_SHA_256)));
    TEST_STATUS(expected_status,
                psa_key_derivation_key_agreement(
                    &operation,
                    PSA_KEY_DERIVATION_INPUT_SECRET,
                    key_id,
                    public_key, public_key_length));
#endif

exit:
    psa_key_derivation_abort(&operation);
    psa_destroy_key(key_id);
    psa_reset_key_attributes(&attributes);
    PSA_DONE();
}
/* END_CASE */
