1 /*
2  * Copyright (c) 2024-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <cmath>
16 
17 #include "merkle_tree_builder.h"
18 
19 using namespace OHOS::SignatureTools::Uscript;
20 namespace OHOS {
21 namespace SignatureTools {
22 
23 const int MerkleTreeBuilder::FSVERITY_HASH_PAGE_SIZE = 4096;
24 const int64_t MerkleTreeBuilder::INPUTSTREAM_MAX_SIZE = 4503599627370496L;
25 const int MerkleTreeBuilder::CHUNK_SIZE = 4096;
26 const long MerkleTreeBuilder::MAX_READ_SIZE = 4194304L;
27 const int MerkleTreeBuilder::MAX_PROCESSORS = 32;
28 const int MerkleTreeBuilder::BLOCKINGQUEUE = 4;
29 
SetAlgorithm(const std::string& algorithm)30 void MerkleTreeBuilder::SetAlgorithm(const std::string& algorithm)
31 {
32     mAlgorithm = algorithm;
33 }
34 
TransInputStreamToHashData(std::istream& inputStream, long size, ByteBuffer* outputBuffer, int bufStartIdx)35 void MerkleTreeBuilder::TransInputStreamToHashData(std::istream& inputStream,
36                                                    long size, ByteBuffer* outputBuffer, int bufStartIdx)
37 {
38     if (size == 0 || static_cast<int64_t>(size) > INPUTSTREAM_MAX_SIZE) {
39         SIGNATURE_TOOLS_LOGE("invalid input stream size");
40         CheckCalculateHashResult = false;
41         return;
42     }
43     std::vector<std::vector<int8_t>> hashes = GetDataHashes(inputStream, size);
44     int32_t writeSize = 0;
45     for (const auto& hash : hashes) {
46         outputBuffer->PutData(writeSize + bufStartIdx, hash.data(), hash.size());
47         writeSize += hash.size();
48     }
49     outputBuffer->SetLimit(outputBuffer->GetCapacity() - bufStartIdx);
50     outputBuffer->SetCapacity(outputBuffer->GetCapacity() - bufStartIdx);
51     outputBuffer->SetPosition(writeSize);
52 }
53 
GetDataHashes(std::istream& inputStream, long size)54 std::vector<std::vector<int8_t>> MerkleTreeBuilder::GetDataHashes(std::istream& inputStream, long size)
55 {
56     int count = (int)(GetChunkCount(size, MAX_READ_SIZE));
57     int chunks = (int)(GetChunkCount(size, CHUNK_SIZE));
58     std::vector<std::vector<int8_t>> hashes(chunks);
59     std::vector<std::future<void>> thread_results;
60     long readOffset = 0L;
61     for (int i = 0; i < count; i++) {
62         long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
63         int readSize = (int)((readLimit - readOffset));
64         int fullChunkSize = (int)(GetFullChunkSize(readSize, CHUNK_SIZE, CHUNK_SIZE));
65         ByteBuffer* byteBuffer(new ByteBuffer(fullChunkSize));
66         std::vector<char> buffer(CHUNK_SIZE);
67         int num = 0;
68         int offset = 0;
69         int flag = 0;
70         int len = CHUNK_SIZE;
71         while (num > 0 || flag == 0) {
72             inputStream.read((buffer.data()), len);
73             if (inputStream.fail() && !inputStream.eof()) {
74                 PrintErrorNumberMsg("IO_ERROR", IO_ERROR, "Error occurred while reading data");
75                 CheckCalculateHashResult = false;
76                 return std::vector<std::vector<int8_t>>();
77             }
78             num = inputStream.gcount();
79             byteBuffer->PutData(buffer.data(), num);
80             offset += num;
81             len = std::min(CHUNK_SIZE, readSize - offset);
82             if (len <= 0 || offset == readSize) {
83                 break;
84             }
85             flag = 1;
86         }
87         if (offset != readSize) {
88             PrintErrorNumberMsg("READ_FILE_ERROR", IO_ERROR, "Error reading buffer from input");
89             CheckCalculateHashResult = false;
90             return std::vector<std::vector<int8_t>>();
91         }
92         byteBuffer->Flip();
93         int readChunkIndex = (int)(GetFullChunkSize(MAX_READ_SIZE, CHUNK_SIZE, i));
94         thread_results.push_back(mPools->Enqueue(&MerkleTreeBuilder::RunHashTask, this, std::ref(hashes),
95                                                  byteBuffer, readChunkIndex, 0));
96         readOffset += readSize;
97     }
98     for (auto& thread_result : thread_results) {
99         thread_result.wait();
100     }
101     return hashes;
102 }
103 
Slice(ByteBuffer* buffer, int begin, int end)104 ByteBuffer* MerkleTreeBuilder::Slice(ByteBuffer* buffer, int begin, int end)
105 {
106     ByteBuffer* tmpBuffer = buffer->Duplicate();
107     tmpBuffer->SetPosition(0);
108     tmpBuffer->SetLimit(end);
109     tmpBuffer->SetPosition(begin);
110     return &tmpBuffer->slice_for_codesigning();
111 }
112 
GetOffsetArrays(long dataSize, int digestSize)113 std::vector<int64_t> MerkleTreeBuilder::GetOffsetArrays(long dataSize, int digestSize)
114 {
115     std::vector<long> levelSize = GetLevelSize(dataSize, digestSize);
116     std::vector<int64_t> levelOffset(levelSize.size() + 1);
117     levelOffset[0] = 0;
118     for (int i = 0; i < levelSize.size(); i++) {
119         levelOffset[i + 1] = levelOffset[i] + levelSize[levelSize.size() - i - 1];
120     }
121     return levelOffset;
122 }
123 
GetLevelSize(long dataSize, int digestSize)124 std::vector<long> MerkleTreeBuilder::GetLevelSize(long dataSize, int digestSize)
125 {
126     std::vector<long> levelSize;
127 
128     long fullChunkSize = 0L;
129     long originalDataSize = dataSize;
130     do {
131         fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
132         int size = GetFullChunkSize(fullChunkSize, CHUNK_SIZE, CHUNK_SIZE);
133         levelSize.push_back(size);
134         originalDataSize = fullChunkSize;
135     } while (fullChunkSize > CHUNK_SIZE);
136     return levelSize;
137 }
138 
RunHashTask(std::vector<std::vector<int8_t>>& hashes, ByteBuffer* buffer, int readChunkIndex, int bufStartIdx)139 void MerkleTreeBuilder::RunHashTask(std::vector<std::vector<int8_t>>& hashes,
140                                     ByteBuffer* buffer, int readChunkIndex, int bufStartIdx)
141 {
142     int offset = 0;
143 
144     std::shared_ptr<ByteBuffer> bufferPtr(buffer);
145     int bufferSize = bufferPtr->GetCapacity();
146     int index = readChunkIndex;
147     while (offset < bufferSize) {
148         ByteBuffer* chunk = Slice(bufferPtr.get(), offset, offset + CHUNK_SIZE);
149         std::vector<int8_t> tmpByte(CHUNK_SIZE);
150         chunk->GetData(offset + bufStartIdx, tmpByte.data(), CHUNK_SIZE);
151         DigestUtils digestUtils(HASH_SHA256);
152         std::string tmpByteStr(tmpByte.begin(), tmpByte.end());
153         digestUtils.AddData(tmpByteStr);
154         std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
155         std::vector<int8_t> hashEle;
156         for (int i = 0; i < result.size(); i++) {
157             hashEle.push_back(result[i]);
158         }
159         hashes[index++] = hashEle;
160         offset += CHUNK_SIZE;
161         delete chunk;
162     }
163 }
164 
TransInputDataToHashData(ByteBuffer* inputBuffer, ByteBuffer* outputBuffer, int64_t inputStartIdx, int64_t outputStartIdx)165 void MerkleTreeBuilder::TransInputDataToHashData(ByteBuffer* inputBuffer, ByteBuffer* outputBuffer,
166                                                  int64_t inputStartIdx, int64_t outputStartIdx)
167 {
168     long size = inputBuffer->GetCapacity();
169     int chunks = (int)GetChunkCount(size, CHUNK_SIZE);
170     std::vector<std::vector<int8_t>> hashes(chunks);
171     long readOffset = 0L;
172     int startChunkIndex = 0;
173     while (readOffset < size) {
174         long readLimit = std::min(readOffset + MAX_READ_SIZE, size);
175         ByteBuffer* buffer = Slice(inputBuffer, (int)readOffset, (int)readLimit);
176         buffer->SetPosition(0);
177         int readChunkIndex = startChunkIndex;
178         RunHashTask(hashes, buffer, readChunkIndex, inputStartIdx);
179         int readSize = (int)(readLimit - readOffset);
180         startChunkIndex += (int)GetChunkCount(readSize, CHUNK_SIZE);
181         readOffset += readSize;
182         inputStartIdx += readSize;
183     }
184     int32_t writeSize = 0;
185     for (const auto& hash : hashes) {
186         outputBuffer->PutData(writeSize + outputStartIdx, hash.data(), hash.size());
187         writeSize += hash.size();
188     }
189 }
190 
MerkleTreeBuilder()191 OHOS::SignatureTools::MerkleTreeBuilder::MerkleTreeBuilder() :mPools(new Uscript::ThreadPool(POOL_SIZE))
192 {
193     CheckCalculateHashResult = true;
194 }
195 
GenerateMerkleTree(std::istream& inputStream, long size, const FsVerityHashAlgorithm& fsVerityHashAlgorithm)196 MerkleTree* MerkleTreeBuilder::GenerateMerkleTree(std::istream& inputStream, long size,
197                                                   const FsVerityHashAlgorithm& fsVerityHashAlgorithm)
198 {
199     SetAlgorithm(fsVerityHashAlgorithm.GetHashAlgorithm());
200     int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
201     std::vector<int64_t> offsetArrays = GetOffsetArrays(size, digestSize);
202     std::unique_ptr<ByteBuffer> allHashBuffer = std::make_unique<ByteBuffer>
203         (ByteBuffer(offsetArrays[offsetArrays.size() - 1]));
204     GenerateHashDataByInputData(inputStream, size, allHashBuffer.get(), offsetArrays, digestSize);
205     GenerateHashDataByHashData(allHashBuffer.get(), offsetArrays, digestSize);
206 
207     if (CheckCalculateHashResult) {
208         return GetMerkleTree(allHashBuffer.get(), size, fsVerityHashAlgorithm);
209     }
210     return nullptr;
211 }
212 
GenerateHashDataByInputData(std::istream& inputStream, long size, ByteBuffer* outputBuffer, std::vector<int64_t>& offsetArrays, int digestSize)213 void MerkleTreeBuilder::GenerateHashDataByInputData(std::istream& inputStream, long size, ByteBuffer* outputBuffer,
214                                                     std::vector<int64_t>& offsetArrays, int digestSize)
215 {
216     int64_t inputDataOffsetBegin = offsetArrays[offsetArrays.size() - 2];
217     int64_t inputDataOffsetEnd = offsetArrays[offsetArrays.size() - 1];
218     ByteBuffer* hashBuffer = Slice(outputBuffer, 0, inputDataOffsetEnd);
219     TransInputStreamToHashData(inputStream, size, hashBuffer, inputDataOffsetBegin);
220     DataRoundupChunkSize(hashBuffer, size, digestSize);
221     delete hashBuffer;
222 }
223 
GenerateHashDataByHashData(ByteBuffer* buffer, std::vector<int64_t>& offsetArrays, int digestSize)224 void MerkleTreeBuilder::GenerateHashDataByHashData(ByteBuffer* buffer,
225                                                    std::vector<int64_t>& offsetArrays, int digestSize)
226 {
227     for (int i = offsetArrays.size() - 3; i >= 0; i--) {
228         int64_t generateOffset = offsetArrays[i];
229         int64_t originalOffset = offsetArrays[i + 1];
230         ByteBuffer* generateHashBuffer = Slice(buffer, offsetArrays[i], offsetArrays[i + 1] + offsetArrays[i]);
231         ByteBuffer* readOnlyBuffer = buffer->Duplicate();
232         ByteBuffer* originalHashBuffer = Slice(readOnlyBuffer, offsetArrays[i + 1], offsetArrays[i + 2]);
233         TransInputDataToHashData(originalHashBuffer, generateHashBuffer, originalOffset, generateOffset);
234         DataRoundupChunkSize(generateHashBuffer, originalHashBuffer->GetCapacity(), digestSize);
235         delete originalHashBuffer;
236         delete readOnlyBuffer;
237         delete generateHashBuffer;
238     }
239 }
240 
GetMerkleTree(ByteBuffer* dataBuffer, long inputDataSize, FsVerityHashAlgorithm fsVerityHashAlgorithm)241 MerkleTree* MerkleTreeBuilder::GetMerkleTree(ByteBuffer* dataBuffer, long inputDataSize,
242                                              FsVerityHashAlgorithm fsVerityHashAlgorithm)
243 {
244     int digestSize = fsVerityHashAlgorithm.GetOutputByteSize();
245     dataBuffer->Flip();
246     std::vector<int8_t> rootHash;
247     std::vector<int8_t> tree;
248     if (inputDataSize < FSVERITY_HASH_PAGE_SIZE) {
249         ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, digestSize);
250         rootHash = std::vector<int8_t>(digestSize);
251         fsVerityHashPageBuffer->GetByte((int8_t*)rootHash.data(), digestSize);
252         if (fsVerityHashPageBuffer != nullptr) {
253             delete fsVerityHashPageBuffer;
254             fsVerityHashPageBuffer = nullptr;
255         }
256     } else {
257         tree = std::vector<int8_t>(dataBuffer->GetBufferPtr(), dataBuffer->GetBufferPtr() + dataBuffer->GetCapacity());
258         ByteBuffer* fsVerityHashPageBuffer = Slice(dataBuffer, 0, FSVERITY_HASH_PAGE_SIZE);
259         std::vector<int8_t> fsVerityHashPage(FSVERITY_HASH_PAGE_SIZE);
260         fsVerityHashPageBuffer->GetData(0, fsVerityHashPage.data(), FSVERITY_HASH_PAGE_SIZE);
261         DigestUtils digestUtils(HASH_SHA256);
262         std::string fsVerityHashPageStr(fsVerityHashPage.begin(), fsVerityHashPage.end());
263         digestUtils.AddData(fsVerityHashPageStr);
264         std::string result = digestUtils.Result(DigestUtils::Type::BINARY);
265         for (int i = 0; i < static_cast<int>(result.size()); i++) {
266             rootHash.push_back(result[i]);
267         }
268         if (fsVerityHashPageBuffer != nullptr) {
269             delete fsVerityHashPageBuffer;
270             fsVerityHashPageBuffer = nullptr;
271         }
272     }
273 
274     return new MerkleTree(rootHash, tree, fsVerityHashAlgorithm);
275 }
276 
DataRoundupChunkSize(ByteBuffer* data, long originalDataSize, int digestSize)277 void MerkleTreeBuilder::DataRoundupChunkSize(ByteBuffer* data, long originalDataSize, int digestSize)
278 {
279     long fullChunkSize = GetFullChunkSize(originalDataSize, CHUNK_SIZE, digestSize);
280     int diffValue = (int)(fullChunkSize % CHUNK_SIZE);
281     if (diffValue > 0) {
282         std::vector<int8_t> padding(CHUNK_SIZE - diffValue);
283         data->SetPosition(data->GetPosition() + (CHUNK_SIZE - diffValue));
284     }
285 }
286 
GetChunkCount(long dataSize, long divisor)287 long MerkleTreeBuilder::GetChunkCount(long dataSize, long divisor)
288 {
289     return (long)std::ceil((double)dataSize / (double)divisor);
290 }
291 
GetFullChunkSize(long dataSize, long divisor, long multiplier)292 long MerkleTreeBuilder::GetFullChunkSize(long dataSize, long divisor, long multiplier)
293 {
294     return GetChunkCount(dataSize, divisor) * multiplier;
295 }
296 } // namespace SignatureTools
297 } // namespace OHOS