1/*
2 * Copyright (c) 2023-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "cert_utils.h"

#include <cstring>
#include <limits>
#include <memory>
#include <openssl/rand.h>
#include <securec.h>
#include <string>
#include <utility>
#include <vector>

#include "byte_buffer.h"
#include "errcode.h"
#include "huks_param_set.h"
#include "log.h"
29
30namespace OHOS {
31namespace Security {
32namespace CodeSign {
// Capacity, in bytes, of the buffer allocated for each certificate blob by ConstructDataToCertChain.
static constexpr uint32_t CERT_DATA_SIZE = 8192;
// Length, in bytes, of a generated random challenge; also the largest size CheckChallengeSize accepts.
static constexpr uint32_t CHALLENGE_LEN = 32;
35
/*
 * Expose the storage of a uint32_t as raw bytes so it can be passed to the
 * byte-oriented ByteBuffer::PutData; the bytes appear in native byte order.
 */
static inline uint8_t *CastToUint8Ptr(uint32_t *value)
{
    void *storage = value;
    return static_cast<uint8_t *>(storage);
}
40
41bool ConstructDataToCertChain(struct HksCertChain **certChain, int certsCount)
42{
43    *certChain = static_cast<struct HksCertChain *>(malloc(sizeof(struct HksCertChain)));
44    if (*certChain == nullptr) {
45        LOG_ERROR("malloc fail");
46        return false;
47    }
48    (*certChain)->certsCount = CERT_COUNT;
49
50    (*certChain)->certs = static_cast<struct HksBlob *>(malloc(sizeof(struct HksBlob) *
51        ((*certChain)->certsCount)));
52    if ((*certChain)->certs == nullptr) {
53        free(*certChain);
54        *certChain = nullptr;
55        return false;
56    }
57    for (uint32_t i = 0; i < (*certChain)->certsCount; i++) {
58        (*certChain)->certs[i].size = CERT_DATA_SIZE;
59        (*certChain)->certs[i].data = static_cast<uint8_t *>(malloc((*certChain)->certs[i].size));
60        if ((*certChain)->certs[i].data == nullptr) {
61            LOG_ERROR("malloc fail");
62            FreeCertChain(certChain, i);
63            return false;
64        }
65    }
66    return true;
67}
68
69void FreeCertChain(struct HksCertChain **certChain, const uint32_t pos)
70{
71    if (*certChain == nullptr) {
72        return;
73    }
74    if ((*certChain)->certs == nullptr) {
75        free(*certChain);
76        *certChain = nullptr;
77        return;
78    }
79    for (uint32_t j = 0; j < pos; j++) {
80        if ((*certChain)->certs[j].data != nullptr) {
81            free((*certChain)->certs[j].data);
82            (*certChain)->certs[j].data = nullptr;
83        }
84    }
85    free((*certChain)->certs);
86    (*certChain)->certs = nullptr;
87    free(*certChain);
88    *certChain = nullptr;
89}
90
91bool FormattedCertChain(const HksCertChain *certChain, ByteBuffer &buffer)
92{
93    uint32_t certsCount = certChain->certsCount;
94    uint32_t totalLen = sizeof(uint32_t);
95    for (uint32_t i = 0; i < certsCount; i++) {
96        totalLen += sizeof(uint32_t) + certChain->certs[i].size;
97    }
98
99    buffer.Resize(totalLen);
100    if (!buffer.PutData(0, CastToUint8Ptr(&certsCount), sizeof(uint32_t))) {
101        return false;
102    }
103    uint32_t pos = sizeof(uint32_t);
104    for (uint32_t i = 0; i < certsCount; i++) {
105        if (!buffer.PutData(pos, CastToUint8Ptr(&certChain->certs[i].size), sizeof(uint32_t))) {
106            return false;
107        }
108        pos += sizeof(uint32_t);
109        if (!buffer.PutData(pos, certChain->certs[i].data, certChain->certs[i].size)) {
110            return false;
111        }
112        pos += certChain->certs[i].size;
113    }
114    return true;
115}
116
/*
 * Read a native-endian uint32_t at bufferPtr into retSize, then advance
 * bufferPtr and shrink restSize by sizeof(uint32_t).
 *
 * @return false, leaving all three arguments untouched, when fewer than
 *         sizeof(uint32_t) bytes remain.
 */
static inline bool CheckSizeAndAssign(uint8_t *&bufferPtr, uint32_t &restSize, uint32_t &retSize)
{
    if (restSize < sizeof(uint32_t)) {
        return false;
    }
    // Copy the bytes instead of dereferencing a reinterpret_cast. The cursor
    // advances by arbitrary cert sizes, so it is generally not 4-byte aligned,
    // and the cast would also break strict-aliasing rules.
    (void)std::memcpy(&retSize, bufferPtr, sizeof(uint32_t));
    bufferPtr += sizeof(uint32_t);
    restSize -= sizeof(uint32_t);
    return true;
}
127
128static inline bool CheckSizeAndCopy(uint8_t *&bufferPtr, uint32_t &restSize, const uint32_t size,
129    ByteBuffer &ret)
130{
131    if (restSize < size) {
132        return false;
133    }
134    if (!ret.CopyFrom(bufferPtr, size)) {
135        return false;
136    }
137    bufferPtr += size;
138    restSize -= size;
139    return true;
140}
141
142bool GetCertChainFormBuffer(const ByteBuffer &certChainBuffer,
143    ByteBuffer &signCert, ByteBuffer &issuer, std::vector<ByteBuffer> &chain)
144{
145    uint8_t *rawPtr = certChainBuffer.GetBuffer();
146    if (rawPtr == nullptr || certChainBuffer.GetSize() < sizeof(uint32_t)) {
147        LOG_ERROR("empty cert chain buffer.");
148        return false;
149    }
150    uint32_t certsCount = *reinterpret_cast<uint32_t *>(rawPtr);
151    rawPtr += sizeof(uint32_t);
152    if (certsCount == 0) {
153        return false;
154    }
155
156    uint32_t certSize;
157    bool ret = true;
158    uint32_t restSize = certChainBuffer.GetSize() - sizeof(uint32_t);
159    for (uint32_t i = 0; i < certsCount - 1; i++) {
160        if (!CheckSizeAndAssign(rawPtr, restSize, certSize)) {
161            return false;
162        }
163        if (i == 0) {
164            ret = CheckSizeAndCopy(rawPtr, restSize, certSize, signCert);
165        } else if (i == 1) {
166            ret = CheckSizeAndCopy(rawPtr, restSize, certSize, issuer);
167        } else {
168            ByteBuffer cert;
169            ret = CheckSizeAndCopy(rawPtr, restSize, certSize, cert);
170            chain.emplace_back(cert);
171        }
172        if (!ret) {
173            LOG_ERROR("failed at index = %{public}u", i);
174            break;
175        }
176    }
177    return ret;
178}
179
180std::unique_ptr<ByteBuffer> GetRandomChallenge()
181{
182    std::unique_ptr<ByteBuffer> challenge = std::make_unique<ByteBuffer>(CHALLENGE_LEN);
183    if (challenge == nullptr) {
184        return nullptr;
185    }
186    RAND_bytes(challenge->GetBuffer(), CHALLENGE_LEN);
187    return challenge;
188}
189
190bool CheckChallengeSize(uint32_t size)
191{
192    if (size > CHALLENGE_LEN) {
193        return false;
194    }
195    return true;
196}
197}
198}
199}