1 /*
2 * Copyright (c) 2024-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
#include <algorithm>
#include <cmath>
#include <memory>
#include <utility>

#include "merkle_tree_builder.h"
18
19 using namespace OHOS::SignatureTools::Uscript;
20 namespace OHOS {
21 namespace SignatureTools {
22
// Size in bytes of the page whose digest becomes the root hash (see GetMerkleTree),
// and the input-size threshold below which the single data digest is used directly.
const int MerkleTreeBuilder::FSVERITY_HASH_PAGE_SIZE = 4096;
// Upper bound (2^52 bytes) on the input size accepted by TransInputStreamToHashData.
const int64_t MerkleTreeBuilder::INPUTSTREAM_MAX_SIZE = 4503599627370496L;
// Size in bytes of each block that is hashed into one digest.
const int MerkleTreeBuilder::CHUNK_SIZE = 4096;
// Maximum number of bytes (4 MiB) read or sliced at once and handed to one RunHashTask call.
const long MerkleTreeBuilder::MAX_READ_SIZE = 4194304L;
// Not referenced in this file; presumably an upper bound on worker threads — TODO confirm against users.
const int MerkleTreeBuilder::MAX_PROCESSORS = 32;
// Not referenced in this file; presumably a task-queue capacity — TODO confirm against users.
const int MerkleTreeBuilder::BLOCKINGQUEUE = 4;
29
SetAlgorithm(const std::string& algorithm)30 void MerkleTreeBuilder::SetAlgorithm(const std::string& algorithm)
31 {
32 mAlgorithm = algorithm;
33 }
34
TransInputStreamToHashData(std::istream& inputStream, long size, ByteBuffer* outputBuffer, int bufStartIdx)35 void MerkleTreeBuilder::TransInputStreamToHashData(std::istream& inputStream,
36 long size, ByteBuffer* outputBuffer, int bufStartIdx)
37 {
38 if (size == 0 || static_cast<int64_t>(size) > INPUTSTREAM_MAX_SIZE) {
39 SIGNATURE_TOOLS_LOGE("invalid input stream size");
40 CheckCalculateHashResult = false;
41 return;
42 }
43 std::vector<std::vector<int8_t>> hashes = GetDataHashes(inputStream, size);
44 int32_t writeSize = 0;
45 for (const auto& hash : hashes) {
46 outputBuffer->PutData(writeSize + bufStartIdx, hash.data(), hash.size());
47 writeSize += hash.size();
48 }
49 outputBuffer->SetLimit(outputBuffer->GetCapacity() - bufStartIdx);
50 outputBuffer->SetCapacity(outputBuffer->GetCapacity() - bufStartIdx);
51 outputBuffer->SetPosition(writeSize);
52 }
53
GetDataHashes(std::istream& inputStream, long size)54 std::vector<std::vector<int8_t>> MerkleTreeBuilder::GetDataHashes(std::istream& inputStream, long size)
55 {
56 int count = (int)(GetChunkCount(size, MAX_READ_SIZE));
57 int chunks = (int)(GetChunkCount(size, CHUNK_SIZE));
58 std::vector<std::vector<int8_t>> hashes(chunks);
59 std::vector<std::future<void>> thread_results;
60 long readOffset = 0L;
61 for (int i = 0; i < count; i++) {
62 long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
63 int readSize = (int)((readLimit - readOffset));
64 int fullChunkSize = (int)(GetFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE));
65 ByteBuffer* byteBuffer(new ByteBuffer(fullChunkSize));
66 std::vector<char> buffer(CHUNK_SIZE);
67 int num = 0;
68 int offset = 0;
69 int flag = 0;
70 int len = CHUNK_SIZE;
71 while (num > 0 || flag == 0) {
72 inputStream.read((buffer.data()), len);
73 if (inputStream.fail() && !inputStream.eof()) {
74 PrintErrorNumberMsg("IO_ERROR", IO_ERROR, "Error occurred while reading data");
75 CheckCalculateHashResult = false;
76 return std::vector<std::vector<int8_t>>();
77 }
78 num = inputStream.gcount();
79 byteBuffer->PutData(buffer.data(), num);
80 offset += num;
81 len = std::min(CHUNK_SIZE, readSize - offset);
82 if (len <= 0 || offset == readSize) {
83 break;
84 }
85 flag = 1;
86 }
87 if (offset != readSize) {
88 PrintErrorNumberMsg("READ_FILE_ERROR", IO_ERROR, "Error reading buffer from input");
89 CheckCalculateHashResult = false;
90 return std::vector<std::vector<int8_t>>();
91 }
92 byteBuffer->Flip();
93 int readChunkIndex = (int)(GetFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i));
94 thread_results.push_back(mPools->Enqueue(&MerkleTreeBuilder::RunHashTask, this, std::ref(hashes),
95 byteBuffer, readChunkIndex, 0));
96 readOffset += readSize;
97 }
98 for (auto& thread_result : thread_results) {
99 thread_result.wait();
100 }
101 return hashes;
102 }
103
Slice(ByteBuffer* buffer, int begin, int end)104 ByteBuffer* MerkleTreeBuilder::Slice(ByteBuffer* buffer, int begin, int end)
105 {
106 ByteBuffer* tmpBuffer = buffer->Duplicate();
107 tmpBuffer->SetPosition(0);
108 tmpBuffer->SetLimit(end);
109 tmpBuffer->SetPosition(begin);
110 return &tmpBuffer->slice_for_codesigning();
111 }
112
GetOffsetArrays(long dataSize, int digestSize)113 std::vector<int64_t> MerkleTreeBuilder::GetOffsetArrays(long dataSize, int digestSize)
114 {
115 std::vector<long> levelSize = GetLevelSize(dataSize, digestSize);
116 std::vector<int64_t> levelOffset(levelSize.size() + 1);
117 levelOffset[0] = 0;
118 for (int i = 0; i < levelSize.size(); i++) {
119 levelOffset[i + 1] = levelOffset[i] + levelSize[levelSize.size() - i - 1];
120 }
121 return levelOffset;
122 }
123
GetLevelSize(long dataSize, int digestSize)124 std::vector<long> MerkleTreeBuilder::GetLevelSize(long dataSize, int digestSize)
125 {
126 std::vector<long> levelSize;
127
128 long fullChunkSize = 0L;
129 long originalDataSize = dataSize;
130 do {
131 fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
132 int size = GetFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
133 levelSize.push_back(size);
134 originalDataSize = fullChunkSize;
135 } while (fullChunkSize > CHUNK_SIZE);
136 return levelSize;
137 }
138
RunHashTask(std::vector<std::vector<int8_t>>& hashes, ByteBuffer* buffer, int readChunkIndex, int bufStartIdx)139 void MerkleTreeBuilder::RunHashTask(std::vector<std::vector<int8_t>>& hashes,
140 ByteBuffer* buffer, int readChunkIndex, int bufStartIdx)
141 {
142 int offset = 0;
143
144 std::shared_ptr<ByteBuffer> bufferPtr(buffer);
145 int bufferSize = bufferPtr->GetCapacity();
146 int index = readChunkIndex;
147 while (offset < bufferSize) {
148 ByteBuffer* chunk = Slice(bufferPtr.get(), offset, offset + CHUNK_SIZE);
149 std::vector<int8_t> tmpByte(CHUNK_SIZE);
150 chunk->GetData(offset + bufStartIdx, tmpByte.data(), CHUNK_SIZE);
151 DigestUtils digestUtils(HASH_SHA256);
152 std::string tmpByteStr(tmpByte.begin(), tmpByte.end());
153 digestUtils.AddData(tmpByteStr);
154 std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
155 std::vector<int8_t> hashEle;
156 for (int i = 0; i < result.size(); i++) {
157 hashEle.push_back(result[i]);
158 }
159 hashes[index++] = hashEle;
160 offset += CHUNK_SIZE;
161 delete chunk;
162 }
163 }
164
TransInputDataToHashData(ByteBuffer* inputBuffer, ByteBuffer* outputBuffer, int64_t inputStartIdx, int64_t outputStartIdx)165 void MerkleTreeBuilder::TransInputDataToHashData(ByteBuffer* inputBuffer, ByteBuffer* outputBuffer,
166 int64_t inputStartIdx, int64_t outputStartIdx)
167 {
168 long size = inputBuffer->GetCapacity();
169 int chunks = (int)GetChunkCount(size, CHUNK_SIZE);
170 std::vector<std::vector<int8_t>> hashes(chunks);
171 long readOffset = 0L;
172 int startChunkIndex = 0;
173 while (readOffset < size) {
174 long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
175 ByteBuffer* buffer = Slice(inputBuffer, (int)readOffset, (int)readLimit);
176 buffer->SetPosition(0);
177 int readChunkIndex = startChunkIndex;
178 RunHashTask(hashes, buffer, readChunkIndex, inputStartIdx);
179 int readSize = (int)(readLimit - readOffset);
180 startChunkIndex += (int)GetChunkCount(readSize, CHUNK_SIZE);
181 readOffset += readSize;
182 inputStartIdx += readSize;
183 }
184 int32_t writeSize = 0;
185 for (const auto& hash : hashes) {
186 outputBuffer->PutData(writeSize + outputStartIdx, hash.data(), hash.size());
187 writeSize += hash.size();
188 }
189 }
190
MerkleTreeBuilder()191 OHOS::SignatureTools::MerkleTreeBuilder::MerkleTreeBuilder() :mPools(new Uscript::ThreadPool(POOL_SIZE))
192 {
193 CheckCalculateHashResult = true;
194 }
195
GenerateMerkleTree(std::istream& inputStream, long size, const FsVerityHashAlgorithm& fsVerityHashAlgorithm)196 MerkleTree* MerkleTreeBuilder::GenerateMerkleTree(std::istream& inputStream, long size,
197 const FsVerityHashAlgorithm& fsVerityHashAlgorithm)
198 {
199 SetAlgorithm(fsVerityHashAlgorithm.GetHashAlgorithm());
200 int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
201 std::vector<int64_t> offsetArrays = GetOffsetArrays(size, digestSize);
202 std::unique_ptr<ByteBuffer> allHashBuffer = std::make_unique<ByteBuffer>
203 (ByteBuffer(offsetArrays[offsetArrays.size() - 1]));
204 GenerateHashDataByInputData(inputStream, size, allHashBuffer.get(), offsetArrays, digestSize);
205 GenerateHashDataByHashData(allHashBuffer.get(), offsetArrays, digestSize);
206
207 if (CheckCalculateHashResult) {
208 return GetMerkleTree(allHashBuffer.get(), size, fsVerityHashAlgorithm);
209 }
210 return nullptr;
211 }
212
GenerateHashDataByInputData(std::istream& inputStream, long size, ByteBuffer* outputBuffer, std::vector<int64_t>& offsetArrays, int digestSize)213 void MerkleTreeBuilder::GenerateHashDataByInputData(std::istream& inputStream, long size, ByteBuffer* outputBuffer,
214 std::vector<int64_t>& offsetArrays, int digestSize)
215 {
216 int64_t inputDataOffsetBegin = offsetArrays[offsetArrays.size() - 2];
217 int64_t inputDataOffsetEnd = offsetArrays[offsetArrays.size() - 1];
218 ByteBuffer* hashBuffer = Slice(outputBuffer, 0, inputDataOffsetEnd);
219 TransInputStreamToHashData(inputStream, size, hashBuffer, inputDataOffsetBegin);
220 DataRoundupChunkSize(hashBuffer, size, digestSize);
221 delete hashBuffer;
222 }
223
GenerateHashDataByHashData(ByteBuffer* buffer, std::vector<int64_t>& offsetArrays, int digestSize)224 void MerkleTreeBuilder::GenerateHashDataByHashData(ByteBuffer* buffer,
225 std::vector<int64_t>& offsetArrays, int digestSize)
226 {
227 for (int i = offsetArrays.size() - 3; i >= 0; i--) {
228 int64_t generateOffset = offsetArrays[i];
229 int64_t originalOffset = offsetArrays[i + 1];
230 ByteBuffer* generateHashBuffer = Slice(buffer, offsetArrays[i], offsetArrays[i + 1] + offsetArrays[i]);
231 ByteBuffer* readOnlyBuffer = buffer->Duplicate();
232 ByteBuffer* originalHashBuffer = Slice(readOnlyBuffer, offsetArrays[i + 1], offsetArrays[i + 2]);
233 TransInputDataToHashData(originalHashBuffer, generateHashBuffer, originalOffset, generateOffset);
234 DataRoundupChunkSize(generateHashBuffer, originalHashBuffer->GetCapacity(), digestSize);
235 delete originalHashBuffer;
236 delete readOnlyBuffer;
237 delete generateHashBuffer;
238 }
239 }
240
GetMerkleTree(ByteBuffer* dataBuffer, long inputDataSize, FsVerityHashAlgorithm fsVerityHashAlgorithm)241 MerkleTree* MerkleTreeBuilder::GetMerkleTree(ByteBuffer* dataBuffer, long inputDataSize,
242 FsVerityHashAlgorithm fsVerityHashAlgorithm)
243 {
244 int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
245 dataBuffer->Flip();
246 std::vector<int8_t> rootHash;
247 std::vector<int8_t> tree;
248 if (inputDataSize < FSVERITY_HASH_PAGE_SIZE) {
249 ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, digestSize);
250 rootHash = std::vector<int8_t>(digestSize);
251 fsVerityHashPageBuffer->GetByte((int8_t*)rootHash.data(), digestSize);
252 if (fsVerityHashPageBuffer != nullptr) {
253 delete fsVerityHashPageBuffer;
254 fsVerityHashPageBuffer = nullptr;
255 }
256 } else {
257 tree = std::vector<int8_t>(dataBuffer->GetBufferPtr(), dataBuffer->GetBufferPtr() + dataBuffer->GetCapacity());
258 ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, FSVERITY_HASH_PAGE_SIZE);
259 std::vector<int8_t> fsVerityHashPage(FSVERITY_HASH_PAGE_SIZE);
260 fsVerityHashPageBuffer->GetData(0, fsVerityHashPage.data(), FSVERITY_HASH_PAGE_SIZE);
261 DigestUtils digestUtils(HASH_SHA256);
262 std::string fsVerityHashPageStr(fsVerityHashPage.begin(), fsVerityHashPage.end());
263 digestUtils.AddData(fsVerityHashPageStr);
264 std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
265 for (int i = 0; i < static_cast<int>(result.size()); i++) {
266 rootHash.push_back(result[i]);
267 }
268 if (fsVerityHashPageBuffer != nullptr) {
269 delete fsVerityHashPageBuffer;
270 fsVerityHashPageBuffer = nullptr;
271 }
272 }
273
274 return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
275 }
276
DataRoundupChunkSize(ByteBuffer* data, long originalDataSize, int digestSize)277 void MerkleTreeBuilder::DataRoundupChunkSize(ByteBuffer* data, long originalDataSize, int digestSize)
278 {
279 long fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
280 int diffValue = (int)(fullChunkSize % CHUNK_SIZE);
281 if (diffValue > 0) {
282 std::vector<int8_t> padding(CHUNK_SIZE - diffValue);
283 data->SetPosition(data->GetPosition() + (CHUNK_SIZE - diffValue));
284 }
285 }
286
GetChunkCount(long dataSize, long divisor)287 long MerkleTreeBuilder::GetChunkCount(long dataSize, long divisor)
288 {
289 return (long)std::ceil((double)dataSize / (double)divisor);
290 }
291
GetFullChunkSize(long dataSize, long divisor, long multiplier)292 long MerkleTreeBuilder::GetFullChunkSize(long dataSize, long divisor, long multiplier)
293 {
294 return GetChunkCount(dataSize, divisor) * multiplier;
295 }
296 } // namespace SignatureTools
297 } // namespace OHOS