1/*
2 * Copyright (c) 2024-2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include <fstream>
17#include <map>
18#include <cstdio>
19#include <cstdlib>
20
21#include "fs_digest_utils.h"
22#include "file_utils.h"
23#include "hash_utils.h"
24
25namespace OHOS {
26namespace SignatureTools {
27
28int HashUtils::GetHashAlgsId(const std::string& algMethod)
29{
30    int result = static_cast<int>(HashAlgs::USE_NONE);
31    if (0 == algMethod.compare("SHA-256")) {
32        result = static_cast<int>(HashAlgs::USE_SHA256);
33    }
34    if (0 == algMethod.compare("SHA-384")) {
35        result = static_cast<int>(HashAlgs::USE_SHA384);
36    }
37    if (0 == algMethod.compare("SHA-512")) {
38        result = static_cast<int>(HashAlgs::USE_SHA512);
39    }
40    return result;
41}
42
43std::string HashUtils::GetHashAlgName(int algId)
44{
45    if (static_cast<int>(HashAlgs::USE_SHA256) == algId) {
46        return "SHA-256";
47    }
48    if (static_cast<int>(HashAlgs::USE_SHA384) == algId) {
49        return "SHA-384";
50    }
51    if (static_cast<int>(HashAlgs::USE_SHA512) == algId) {
52        return "SHA-512";
53    }
54    return "";
55}
56
57std::vector<int8_t> HashUtils::GetFileDigest(const std::string& inputFile, const std::string& algName)
58{
59    std::vector<int8_t> result;
60
61    std::ifstream input(inputFile, std::ios::binary);
62    if (0 != input.rdstate()) {
63        PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR, "failed to get input stream object!");
64        return std::vector<int8_t>();
65    }
66
67    char buffer[HASH_LEN] = { 0 };
68    int num = 0;
69    std::map<int, std::vector<int8_t>> hashMap;
70
71    while (!input.eof()) {
72        input.read(buffer, HASH_LEN);
73
74        if (input.fail() && !input.eof()) {
75            PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR, "error occurred while reading data.");
76            return std::vector<int8_t>();
77        }
78
79        std::streamsize readLen = input.gcount();
80        std::string str;
81        for (int i = 0; i < readLen; ++i) {
82            str.push_back(buffer[i]);
83        }
84
85        std::vector<int8_t> dig = GetByteDigest(str, readLen, algName);
86        hashMap.emplace(num, dig);
87        ++num;
88    }
89
90    if (hashMap.empty()) {
91        PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR, "hashMap is empty.");
92        return std::vector<int8_t>();
93    }
94
95    DigestUtils fileDigestUtils(HASH_SHA256);
96    for (const auto& hashMapItem : hashMap) {
97        std::string str(hashMapItem.second.begin(), hashMapItem.second.end());
98        fileDigestUtils.AddData(str);
99    }
100    std::string digest = fileDigestUtils.Result(DigestUtils::Type::BINARY);
101    for (std::string::size_type i = 0; i < digest.size(); i++) {
102        result.push_back(digest[i]);
103    }
104    return result;
105}
106
107std::vector<int8_t> HashUtils::GetDigestFromBytes(const std::vector<int8_t>& fileBytes, int64_t length,
108                                                  const std::string& algName)
109{
110    if (fileBytes.empty() || length <= 0) {
111        PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR, "file bytes is empty.");
112        return std::vector<int8_t>();
113    }
114    std::map<int, std::vector<int8_t>> hashMap;
115    int64_t readLength = 0;
116    int64_t num = 0;
117    while (readLength < length) {
118        int64_t blockLength = length - readLength > HASH_LEN ? HASH_LEN : (length - readLength);
119        std::string readStr(fileBytes.begin() + readLength, fileBytes.begin() + readLength + blockLength);
120        std::vector<int8_t> dig = GetByteDigest(readStr, readStr.size(), algName);
121        hashMap.emplace(num, dig);
122        ++num;
123        readLength += readStr.size();
124    }
125    if (hashMap.empty()) {
126        PrintErrorNumberMsg("VERIFY_ERROR", VERIFY_ERROR, "hashMap is empty.");
127        return std::vector<int8_t>();
128    }
129    DigestUtils digestUtils(HASH_SHA256);
130    for (const auto& item : hashMap) {
131        std::string str(item.second.begin(), item.second.end());
132        digestUtils.AddData(str);
133    }
134    std::string digest = digestUtils.Result(DigestUtils::Type::BINARY);
135    std::vector<int8_t> result;
136    for (std::string::size_type i = 0; i < digest.size(); i++) {
137        result.push_back(digest[i]);
138    }
139    return result;
140}
141
142std::vector<int8_t> HashUtils::GetByteDigest(const std::string& str, int count, const std::string& algMethod)
143{
144    std::vector<int8_t> result;
145    DigestUtils digestUtils(HASH_SHA256);
146    digestUtils.AddData(str);
147    std::string digest = digestUtils.Result(DigestUtils::Type::BINARY);
148    for (std::string::size_type i = 0; i < digest.size(); i++) {
149        result.push_back(digest[i]);
150    }
151    return result;
152}
153
154} // namespace SignatureTools
155} // namespace OHOS