1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "huks_attest_verifier.h"
17
#include <openssl/asn1.h>
#include <openssl/err.h>
#include <openssl/obj_mac.h>
#include <openssl/objects.h>
#include <openssl/ossl_typ.h>
#include <openssl/pem.h>
#include <openssl/safestack.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <openssl/x509_vfy.h>
#include <string>
#include <vector>
29
30#include "byte_buffer.h"
31#include "cert_utils.h"
32#include "log.h"
33#include "openssl_utils.h"
34
35namespace OHOS {
36namespace Security {
37namespace CodeSign {
// Trusted attestation root CA bundle; AddCAToStore() reads every PEM certificate from this file.
static const std::string ATTEST_ROOT_CA_PATH = "/system/etc/security/trusted_attest_root_ca.cer";
// The custom OID definitions below are laid out as {dotted OID, name, description}.
// GetNidFromDefination() passes element 0, element 1 and the last element to CreateNIDFromOID()
// (presumably as OID / short name / long name -- confirm against openssl_utils).
// Attestation extension: VerifyExtension() walks this extension of the signing certificate.
static const std::vector<std::string> ATTESTTATION_EXTENSION = {
    "1.3.6.1.4.1.2011.2.376.1.3",
    "AttestationInfo",
    "Attestation Information"
};

// SA info element inside the attestation data: its OCTET STRING value must contain
// LOCAL_CODE_SIGN_SA_NAME (see CompareTargetValue()).
// NOTE: these strings are passed to CreateNIDFromOID() at runtime, so their spelling is kept as-is.
static const std::vector<std::string> SA_INFO_EXTENSION = {
    "1.3.6.1.4.1.2011.2.376.2.1.3.1",
    "SA INFO",
    "SystemAbiliy Information"
};

// Challenge element inside the attestation data: its OCTET STRING value must equal the
// caller-supplied challenge byte for byte (see CompareTargetValue()).
static const std::vector<std::string> CHALLENGE_EXTENSION = {
    "1.3.6.1.4.1.2011.2.376.2.1.4",
    "Challenge",
    "Challenge"
};

// Substring that the SA info value must contain for the attestation to be accepted.
static const std::string LOCAL_CODE_SIGN_SA_NAME = "local_code_sign";

// Minimum number of entries an OID definition vector must have.
static constexpr uint32_t MIN_VECTOR_SIZE = 3;
// NIDs resolved once by InitVerifier(). Plain globals with no synchronization.
static bool g_verifierInited = false;
static int g_saNid = 0;
static int g_challengeNid = 0;
static int g_attestationNid = 0;
#ifdef VERIFY_KEY_ATTEST_CERTCHAIN
// Size of the buffer used to read the intermediate CA common name.
static constexpr uint32_t COMMON_NAME_BUF_SIZE = 256;
#endif
67
68static inline int GetNidFromDefination(const std::vector<std::string> &defVector)
69{
70    if (defVector.size() < MIN_VECTOR_SIZE) {
71        return NID_undef;
72    }
73    return CreateNIDFromOID(defVector[0], defVector[1], defVector[defVector.size() - 1]);
74}
75
76static void InitVerifier()
77{
78    if (g_verifierInited) {
79        return;
80    }
81    g_saNid = GetNidFromDefination(SA_INFO_EXTENSION);
82    g_challengeNid = GetNidFromDefination(CHALLENGE_EXTENSION);
83    g_attestationNid = GetNidFromDefination(ATTESTTATION_EXTENSION);
84    LOG_DEBUG("g_saNid = %{public}d, g_challengeNid = %{public}d, g_attestationNid = %{public}d",
85        g_saNid, g_challengeNid, g_attestationNid);
86    g_verifierInited = true;
87}
88
89static bool AddCAToStore(X509_STORE *store)
90{
91    FILE *fp = fopen(ATTEST_ROOT_CA_PATH.c_str(), "r");
92    if (fp == nullptr) {
93        LOG_ERROR("Open file failed.");
94        return false;
95    }
96
97    X509 *caCert = nullptr;
98    do {
99        caCert = PEM_read_X509(fp, nullptr, nullptr, nullptr);
100        if (caCert == nullptr) {
101            break;
102        }
103        if (X509_STORE_add_cert(store, caCert) <= 0) {
104            LOG_ERROR("add cert to X509 store failed");
105            GetOpensslErrorMessage();
106        }
107        LOG_INFO("Add root CA successfully");
108    } while (caCert != nullptr);
109    (void) fclose(fp);
110    return true;
111}
112
113static bool VerifyIssurCert(X509 *cert, STACK_OF(X509) *chain)
114{
115    X509_STORE *store = X509_STORE_new();
116    if (store == nullptr) {
117        return false;
118    }
119
120    bool ret = false;
121    X509_STORE_CTX *storeCtx = nullptr;
122
123    do {
124        if (!AddCAToStore(store)) {
125            break;
126        }
127        storeCtx = X509_STORE_CTX_new();
128        if (storeCtx == nullptr) {
129            break;
130        }
131
132        if (!X509_STORE_CTX_init(storeCtx, store, cert, chain)) {
133            LOG_ERROR("init X509_STORE_CTX failed.");
134            break;
135        }
136        X509_STORE_CTX_set_purpose(storeCtx, X509_PURPOSE_ANY);
137        // because user can set date of device, validation skip time check for fool-proofing
138        X509_STORE_CTX_set_flags(storeCtx, X509_V_FLAG_NO_CHECK_TIME);
139        int index = X509_verify_cert(storeCtx);
140        if (index <= 0) {
141            index = X509_STORE_CTX_get_error(storeCtx);
142            LOG_ERROR("Verify cert failed, msg = %{public}s", X509_verify_cert_error_string(index));
143            break;
144        }
145        ret = true;
146    } while (0);
147    if (!ret) {
148        GetOpensslErrorMessage();
149    }
150    X509_STORE_CTX_free(storeCtx);
151    X509_STORE_free(store);
152    return ret;
153}
154
155static bool VerifySigningCert(X509 *signCert, X509 *issuerCert)
156{
157    EVP_PKEY *key = X509_get0_pubkey(issuerCert);
158    if (key == nullptr) {
159        LOG_ERROR("get pub key failed.");
160        return false;
161    }
162    if (X509_verify(signCert, key) <= 0) {
163        LOG_ERROR("verify signing cert failed.");
164        GetOpensslErrorMessage();
165        return false;
166    }
167    return true;
168}
169
170static bool CompareTargetValue(int nid, uint8_t *data, int size, const ByteBuffer &challenge)
171{
172    if (nid == g_saNid) {
173        std::string str(reinterpret_cast<char *>(data), size);
174        LOG_INFO("compare with proc = %{private}s", str.c_str());
175        return str.find(LOCAL_CODE_SIGN_SA_NAME) != std::string::npos;
176    } else if (nid == g_challengeNid) {
177        LOG_INFO("compare with challenge");
178        return (static_cast<uint32_t>(size) == challenge.GetSize())
179                    && (memcmp(data, challenge.GetBuffer(), size) == 0);
180    }
181    return true;
182}
183
184static bool ParseASN1Sequence(uint8_t *data, int size, const ByteBuffer &challenge)
185{
186    STACK_OF(ASN1_TYPE) *types = d2i_ASN1_SEQUENCE_ANY(
187        nullptr, const_cast<const uint8_t **>(&data), size);
188    if (types == nullptr) {
189        return false;
190    }
191
192    int num = sk_ASN1_TYPE_num(types);
193    int curNid = -1;
194    bool ret = true;
195    for (int i = 0; i < num; i++) {
196        ASN1_TYPE *type = sk_ASN1_TYPE_value(types, i);
197        if (type->type == V_ASN1_SEQUENCE) {
198            ret = ParseASN1Sequence(type->value.sequence->data, type->value.sequence->length,
199                challenge);
200        } else if (type->type == V_ASN1_OBJECT) {
201            ASN1_OBJECT *obj = type->value.object;
202            curNid = OBJ_obj2nid(obj);
203        } else if (type->type == V_ASN1_OCTET_STRING) {
204            ASN1_OCTET_STRING *value = type->value.octet_string;
205            ret = CompareTargetValue(curNid, value->data, value->length, challenge);
206        }
207        if (!ret) {
208            LOG_ERROR("Value is unexpected");
209            break;
210        }
211    }
212    return ret;
213}
214
215static bool VerifyExtension(X509 *cert, const ByteBuffer &challenge)
216{
217    if (challenge.GetBuffer() == nullptr) {
218        return false;
219    }
220
221    const STACK_OF(X509_EXTENSION) *exts = X509_get0_extensions(cert);
222    int num;
223
224    if ((num = sk_X509_EXTENSION_num(exts)) <= 0) {
225        LOG_ERROR("Get extension failed.");
226        return false;
227    }
228
229    InitVerifier();
230    for (int i = 0; i < num; i++) {
231        X509_EXTENSION *ext = sk_X509_EXTENSION_value(exts, i);
232        ASN1_OBJECT *obj = X509_EXTENSION_get_object(ext);
233        if (obj == nullptr) {
234            LOG_ERROR("Get ans1 object faild");
235            continue;
236        }
237        int curNid = OBJ_obj2nid(obj);
238        if (g_attestationNid == curNid) {
239            const ASN1_OCTET_STRING *extData = X509_EXTENSION_get_data(ext);
240            if (!ParseASN1Sequence(extData->data, extData->length, challenge)) {
241                LOG_INFO("extension verify failed.");
242                return false;
243            }
244        }
245    }
246    return true;
247}
248
249#ifdef CODE_SIGNATURE_DEBUGGABLE
250static void ShowCertInfo(const std::vector<ByteBuffer> &certChainBuffer,
251    const ByteBuffer &issuerBuffer, const ByteBuffer &certBuffer)
252{
253    std::string pem;
254    LOG_INFO("Dump cert chain");
255    for (auto cert: certChainBuffer) {
256        if (ConvertCertToPEMString(cert, pem)) {
257            LOG_INFO("%{private}s", pem.c_str());
258        }
259    }
260    LOG_INFO("Dump issuer cert");
261    if (ConvertCertToPEMString(issuerBuffer, pem)) {
262        LOG_INFO("%{private}s", pem.c_str());
263    }
264    LOG_INFO("Dump signing cert");
265    if (ConvertCertToPEMString(certBuffer, pem)) {
266        LOG_INFO("%{private}s", pem.c_str());
267    }
268}
269#endif
270
271bool VerifyCertAndExtension(X509 *signCert, X509 *issuerCert, const ByteBuffer &challenge)
272{
273    if (!VerifySigningCert(signCert, issuerCert)) {
274        return false;
275    }
276    LOG_DEBUG("Verify sign cert pass");
277
278    if (!VerifyExtension(signCert, challenge)) {
279        LOG_ERROR("Verify extension failed.");
280        return false;
281    }
282    LOG_INFO("Verify success");
283    return true;
284}
285
286static bool VerifyIntermediateCASubject(const std::vector<ByteBuffer> &certChainBuffer)
287{
288#ifndef VERIFY_KEY_ATTEST_CERTCHAIN
289    LOG_INFO("Skip intermediate CA subject verification.");
290    return true;
291#else
292    if (certChainBuffer.empty()) {
293        LOG_ERROR("The vector is empty");
294        return false;
295    }
296
297    auto certBuffer = certChainBuffer.back();
298    X509 *cert = LoadCertFromBuffer(certBuffer.GetBuffer(), certBuffer.GetSize());
299    if (cert == nullptr) {
300        LOG_ERROR("Load intermediate CA cert failed.");
301        return false;
302    }
303
304    bool ret = false;
305    do {
306        X509_NAME *subjectName = X509_get_subject_name(cert);
307        if (subjectName == nullptr) {
308            LOG_ERROR("Get subject name failed.");
309            break;
310        }
311
312        char commonNameBuf[COMMON_NAME_BUF_SIZE] = {0};
313        int len = X509_NAME_get_text_by_NID(subjectName, NID_commonName, commonNameBuf, COMMON_NAME_BUF_SIZE);
314        if (len <= 0) {
315            LOG_ERROR("Get common name failed.");
316            break;
317        }
318
319        if (!strstr(commonNameBuf, "Huawei CBG Mobile Equipment CA") &&
320            !strstr(commonNameBuf, "Huawei CBG Equipment S2 CA") &&
321            !strstr(commonNameBuf, "Huawei CBG Equipment S3 CA")) {
322            LOG_ERROR("Intermediate CA common name not matched, common name:%{private}s", commonNameBuf);
323            break;
324        }
325
326        ret = true;
327    } while (0);
328
329    X509_free(cert);
330    return ret;
331#endif
332}
333
334bool GetVerifiedCert(const ByteBuffer &buffer, const ByteBuffer &challenge, ByteBuffer &certBuffer)
335{
336    std::vector<ByteBuffer> certChainBuffer;
337    ByteBuffer issuerBuffer;
338    if (!GetCertChainFormBuffer(buffer, certBuffer, issuerBuffer, certChainBuffer)) {
339        return false;
340    }
341    X509 *issuerCert = LoadCertFromBuffer(issuerBuffer.GetBuffer(), issuerBuffer.GetSize());
342    if (issuerCert == nullptr) {
343        LOG_ERROR("Load issuerCert cert failed.");
344        return false;
345    }
346    bool ret = false;
347    X509 *signCert = nullptr;
348    STACK_OF(X509 *) certChain = nullptr;
349    do {
350        certChain = MakeStackOfCerts(certChainBuffer);
351        if (certChain == nullptr) {
352            LOG_ERROR("Load cert chain failed.");
353            break;
354        }
355        if (!VerifyIntermediateCASubject(certChainBuffer)) {
356            LOG_ERROR("Failed to verify the Intermediate CA subject.");
357            break;
358        }
359        if (!VerifyIssurCert(issuerCert, certChain)) {
360            LOG_ERROR("Verify issuer cert not pass.");
361            break;
362        }
363        LOG_DEBUG("Verify issuer cert pass");
364        signCert = LoadCertFromBuffer(certBuffer.GetBuffer(), certBuffer.GetSize());
365        if (signCert == nullptr) {
366            LOG_ERROR("Load signing cert failed.");
367            break;
368        }
369        if (!VerifyCertAndExtension(signCert, issuerCert, challenge)) {
370            break;
371        }
372        ret = true;
373    } while (0);
374    X509_free(signCert);
375    X509_free(issuerCert);
376    sk_X509_pop_free(certChain, X509_free);
377#ifdef CODE_SIGNATURE_DEBUGGABLE
378    if (!ret) {
379        ShowCertInfo(certChainBuffer, issuerBuffer, certBuffer);
380    }
381#endif
382    LOG_INFO("verify finished, ret = %{public}d.", ret);
383    return ret;
384}
385}
386}
387}
388