1/*
2 * Copyright (c) 2024-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15#include "sign_info.h"
16
17namespace OHOS {
18namespace SignatureTools {
19
20SignInfo::SignInfo()
21{
22    saltSize = 0;
23    sigSize = 0;
24    flags = 0;
25    dataSize = 0;
26    salt = std::vector<int8_t>();
27    extensionNum = 0;
28    extensionOffset = 0;
29    signature = std::vector<int8_t>();
30    zeroPadding = std::vector<int8_t>();
31}
32
33SignInfo::SignInfo(int32_t saltSize,
34                   int32_t flags,
35                   int64_t dataSize,
36                   const std::vector<int8_t>& salt,
37                   const std::vector<int8_t>& sig)
38{
39    this->saltSize = saltSize;
40    this->flags = flags;
41    this->dataSize = dataSize;
42    if (salt.empty()) {
43        this->salt = std::vector<int8_t>(SALT_BUFFER_LENGTH, 0);
44    } else {
45        this->salt = salt;
46    }
47    signature = sig;
48    sigSize = sig.empty() ? 0 : sig.size();
49    // align for extension after signature
50    zeroPadding = std::vector<int8_t>((SignInfo::SIGNATURE_ALIGNMENT
51                                            - (sigSize % SignInfo::SIGNATURE_ALIGNMENT))
52                                            % SignInfo::SIGNATURE_ALIGNMENT, 0);
53    extensionNum = 0;
54    extensionOffset = 0;
55}
56
57SignInfo::SignInfo(int32_t saltSize,
58                   int32_t sigSize,
59                   int32_t flags,
60                   int64_t dataSize,
61                   const std::vector<int8_t>& salt,
62                   int32_t extensionNum,
63                   int32_t extensionOffset,
64                   const std::vector<int8_t>& signature,
65                   const std::vector<int8_t>& zeroPadding,
66                   const std::vector<MerkleTreeExtension*>& extensionList)
67{
68    this->saltSize = saltSize;
69    this->sigSize = sigSize;
70    this->flags = flags;
71    this->dataSize = dataSize;
72    this->salt = salt;
73    this->extensionNum = extensionNum;
74    this->extensionOffset = extensionOffset;
75    this->signature = signature;
76    this->zeroPadding = zeroPadding;
77    this->extensionList = extensionList;
78}
79
80SignInfo::SignInfo(const SignInfo& other)
81{
82    this->saltSize = other.saltSize;
83    this->sigSize = other.sigSize;
84    this->flags = other.flags;
85    this->dataSize = other.dataSize;
86    this->salt = other.salt;
87    this->extensionNum = other.extensionNum;
88    this->extensionOffset = other.extensionOffset;
89    this->signature = other.signature;
90    this->zeroPadding = other.zeroPadding;
91    for (MerkleTreeExtension* ext : other.extensionList) {
92        MerkleTreeExtension* extTmp = new MerkleTreeExtension(*(MerkleTreeExtension*)(ext));
93        this->extensionList.push_back(extTmp);
94    }
95}
96
97SignInfo& SignInfo::operator=(const SignInfo& other)
98{
99    if (this == &other) {
100        return *this;
101    }
102    this->saltSize = other.saltSize;
103    this->sigSize = other.sigSize;
104    this->flags = other.flags;
105    this->dataSize = other.dataSize;
106    this->salt = other.salt;
107    this->extensionNum = other.extensionNum;
108    this->extensionOffset = other.extensionOffset;
109    this->signature = other.signature;
110    this->zeroPadding = other.zeroPadding;
111    for (Extension* ext : other.extensionList) {
112        MerkleTreeExtension* extTmp = new MerkleTreeExtension(*(MerkleTreeExtension*)(ext));
113        this->extensionList.push_back(extTmp);
114    }
115    return *this;
116}
117
118SignInfo::~SignInfo()
119{
120    for (Extension* ext : extensionList) {
121        if (ext) {
122            delete ext;
123            ext = nullptr;
124        }
125    }
126}
127
128int32_t SignInfo::GetSize()
129{
130    int blockSize = SignInfo::SIGN_INFO_SIZE_WITHOUT_SIGNATURE + signature.size() + zeroPadding.size();
131    for (Extension* ext : extensionList) {
132        blockSize += ext->GetSize();
133    }
134    return blockSize;
135}
136
137void SignInfo::AddExtension(MerkleTreeExtension* extension)
138{
139    extensionOffset = GetSize();
140    extensionList.push_back(extension);
141    extensionNum = extensionList.size();
142}
143
144Extension* SignInfo::GetExtensionByType(int32_t type)
145{
146    for (Extension* ext : extensionList) {
147        if (ext->IsType(type)) {
148            return ext;
149        }
150    }
151    return nullptr;
152}
153
154int32_t SignInfo::GetExtensionNum()
155{
156    return extensionNum;
157}
158
159std::vector<int8_t> SignInfo::GetSignature()
160{
161    return signature;
162}
163
164int64_t SignInfo::GetDataSize()
165{
166    return dataSize;
167}
168
169void SignInfo::ToByteArray(std::vector<int8_t> &ret)
170{
171    std::unique_ptr<ByteBuffer> bf = std::make_unique<ByteBuffer>(ByteBuffer(GetSize()));
172    std::vector<int8_t> empt(GetSize());
173    bf->PutData(empt.data(), empt.size());
174    bf->Clear();
175    bf->PutInt32(saltSize);
176    bf->PutInt32(sigSize);
177    bf->PutInt32(flags);
178    bf->PutInt64(dataSize);
179    bf->PutData(salt.data(), salt.size());
180    bf->PutInt32(extensionNum);
181    bf->PutInt32(extensionOffset);
182    bf->PutData(signature.data(), signature.size());
183    bf->PutData(zeroPadding.data(), zeroPadding.size());
184    // put extension
185    for (Extension* ext : extensionList) {
186        std::vector<int8_t> ret;
187        ext->ToByteArray(ret);
188        bf->PutData(ret.data(), ret.size());
189    }
190    bf->Flip();
191    ret = std::vector<int8_t>(bf->GetBufferPtr(), bf->GetBufferPtr() + bf.get()->GetCapacity());
192    return;
193}
194
195std::vector<MerkleTreeExtension*> SignInfo::ParseMerkleTreeExtension(ByteBuffer* bf, int32_t inExtensionNum)
196{
197    std::vector<MerkleTreeExtension*> inExtensionList;
198    if (inExtensionNum == 1) {
199        // parse merkle tree extension
200        int32_t extensionType = 0;
201        bf->GetInt32(extensionType);
202        if (extensionType != MerkleTreeExtension::MERKLE_TREE_INLINED) {
203            PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR,
204                                "The extension type of SignInfo is incorrect.");
205            return inExtensionList;
206        }
207        int32_t extensionSize = 0;
208        bf->GetInt32(extensionSize);
209        if (extensionSize != MerkleTreeExtension::MERKLE_TREE_EXTENSION_DATA_SIZE) {
210            PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR,
211                                "The extension size of SignInfo is incorrect.");
212            return inExtensionList;
213        }
214        std::vector<int8_t> merkleTreeExtension(MerkleTreeExtension::MERKLE_TREE_EXTENSION_DATA_SIZE, 0);
215        bf->GetByte(merkleTreeExtension.data(), merkleTreeExtension.size());
216        MerkleTreeExtension* pMerkleTreeExtension = MerkleTreeExtension::FromByteArray(merkleTreeExtension);
217        if (pMerkleTreeExtension) {
218            inExtensionList.push_back(pMerkleTreeExtension);
219        }
220    }
221    return inExtensionList;
222}
223
224SignInfo SignInfo::FromByteArray(std::vector<int8_t> bytes)
225{
226    std::unique_ptr<ByteBuffer> bf = std::make_unique<ByteBuffer>(ByteBuffer(bytes.size()));
227    bf->PutData(bytes.data(), bytes.size());
228    bf->Flip();
229    int32_t inSaltSize = 0;
230    bool flag = bf->GetInt32(inSaltSize);
231    int32_t inSigSize = 0;
232    bool ret = bf->GetInt32(inSigSize);
233    if (!flag || !ret || inSaltSize < 0 || inSigSize < 0) {
234        SIGNATURE_TOOLS_LOGE("Invalid saltSize or sigSize of SignInfo, saltSize: %d, sigSize: %d",
235            inSaltSize, inSigSize);
236        return SignInfo();
237    }
238    int32_t inFlags = 0;
239    flag = bf->GetInt32(inFlags);
240    if (!flag || (inFlags != 0 && inFlags != SignInfo::FLAG_MERKLE_TREE_INCLUDED)) {
241        SIGNATURE_TOOLS_LOGE("Invalid flags of SignInfo: %d", inFlags);
242        return SignInfo();
243    }
244    int64_t inDataSize = 0;
245    flag = bf->GetInt64(inDataSize);
246    if (!flag || (inDataSize < 0)) {
247        SIGNATURE_TOOLS_LOGE("Invalid dataSize of SignInfo");
248        return SignInfo();
249    }
250    std::vector<int8_t> inSalt(SignInfo::SALT_BUFFER_LENGTH, 0);
251    bf->GetByte(inSalt.data(), SignInfo::SALT_BUFFER_LENGTH);
252    int32_t inExtensionNum = 0;
253    flag = bf->GetInt32(inExtensionNum);
254    if (!flag || inExtensionNum < 0 || inExtensionNum > SignInfo::MAX_EXTENSION_NUM) {
255        SIGNATURE_TOOLS_LOGE("Invalid extensionNum of SignInfo: %d", inExtensionNum);
256        return SignInfo();
257    }
258    int32_t inExtensionOffset = 0;
259    flag = bf->GetInt32(inExtensionOffset);
260    if (!flag || inExtensionOffset < 0 || inExtensionOffset % SignInfo::SIGNATURE_ALIGNMENT != 0) {
261        SIGNATURE_TOOLS_LOGE("Invalid extensionOffset of SignInfo: %d", inExtensionOffset);
262        return SignInfo();
263    }
264    std::vector<int8_t> inSignature(inSigSize, 0);
265    bf->GetByte(inSignature.data(), inSigSize);
266    std::vector<int8_t> inZeroPadding((SignInfo::SIGNATURE_ALIGNMENT - (inSigSize % SignInfo::SIGNATURE_ALIGNMENT))
267                                      % SignInfo::SIGNATURE_ALIGNMENT, 0);
268    bf->GetByte(inZeroPadding.data(), inZeroPadding.size());
269    std::vector<MerkleTreeExtension*> inExtensionList = ParseMerkleTreeExtension(bf.get(), inExtensionNum);
270    return SignInfo(inSaltSize, inSigSize, inFlags, inDataSize, inSalt, inExtensionNum, inExtensionOffset,
271                    inSignature, inZeroPadding, inExtensionList);
272}
273
274}
275}