1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "model_manager.h"
17 
18 #include <dlfcn.h>
19 
20 #include "directory_ex.h"
21 #include "file_ex.h"
22 #include "config_data_manager.h"
23 #include "security_guard_log.h"
24 #include "model_manager_impl.h"
25 #include "database_manager.h"
26 #include "security_guard_define.h"
27 
28 namespace OHOS::Security::SecurityGuard {
29 std::shared_ptr<IModelManager> ModelManager::modelManagerApi_ = std::make_shared<ModelManagerImpl>();
30 
namespace {
    // A model library is only loaded when its resolved path starts with this prefix.
    constexpr const char *PREFIX_MODEL_PATH = "/system/lib64/lib";
    // Model id that Init() leaves out when loading startup models.
    constexpr uint32_t AUDIT_MODEL = 3001000003U;
}
35 
Init()36 void ModelManager::Init()
37 {
38     std::vector<uint32_t> modelIds = ConfigDataManager::GetInstance().GetAllModelIds();
39     ModelCfg cfg;
40     for (uint32_t modelId : modelIds) {
41         bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, cfg);
42         if (!success) {
43             continue;
44         }
45         SGLOGI("modelId is %{public}u, start_mode: %{public}u", modelId, cfg.startMode);
46         if (cfg.startMode != START_ON_STARTUP) {
47             continue;
48         }
49         if (cfg.modelId != AUDIT_MODEL) {
50             (void) InitModel(modelId);
51             continue;
52         }
53     }
54     const uint32_t modelId = 3001000008;
55     if (FileExists("/data/service/el1/public/security_guard/related_event_analysis.json")) {
56         (void) InitModel(modelId);
57     }
58 }
59 
InitModel(uint32_t modelId)60 int32_t ModelManager::InitModel(uint32_t modelId)
61 {
62     std::unordered_map<uint32_t, std::unique_ptr<ModelAttrs>>::iterator iter;
63     {
64         std::lock_guard<std::mutex> lock(mutex_);
65         iter = modelIdApiMap_.find(modelId);
66         if (iter != modelIdApiMap_.end() && iter->second != nullptr && iter->second->GetModelApi() != nullptr) {
67             iter->second->GetModelApi()->Release();
68             modelIdApiMap_.erase(iter);
69         }
70     }
71 
72     ModelCfg cfg;
73     bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, cfg);
74     if (!success) {
75         SGLOGE("the model not support, modelId=%{public}u", modelId);
76         return NOT_FOUND;
77     }
78     std::string realPath;
79     if (!PathToRealPath(cfg.path, realPath) || realPath.find(PREFIX_MODEL_PATH) != 0) {
80         return FILE_ERR;
81     }
82     void *handle = dlopen(realPath.c_str(), RTLD_LAZY);
83     if (handle == nullptr) {
84         SGLOGE("modelId=%{public}u, open failed, reason:%{public}s", modelId, dlerror());
85         return FAILED;
86     }
87     std::unique_ptr<ModelAttrs> attr = std::make_unique<ModelAttrs>();
88     attr->SetHandle(handle);
89     auto getModelApi = (GetModelApi)dlsym(handle, "GetModelApi");
90     if (getModelApi == nullptr) {
91         SGLOGE("get model api func is nullptr");
92         return FAILED;
93     }
94     IModel *api = getModelApi();
95     if (api == nullptr) {
96         SGLOGE("get model api is nullptr");
97         return FAILED;
98     }
99     attr->SetModelApi(api);
100     int32_t ret = attr->GetModelApi()->Init(modelManagerApi_);
101     if (ret != SUCCESS) {
102         SGLOGE("model api init failed, ret=%{public}d", ret);
103         return ret;
104     }
105     {
106         std::lock_guard<std::mutex> lock(mutex_);
107         modelIdApiMap_[modelId] = std::move(attr);
108     }
109     SGLOGI("init model success, modelId=%{public}u", modelId);
110     return SUCCESS;
111 }
112 
GetResult(uint32_t modelId, const std::string &param)113 std::string ModelManager::GetResult(uint32_t modelId, const std::string &param)
114 {
115     std::string result = "unknown";
116     int32_t ret = InitModel(modelId);
117     if (ret != SUCCESS) {
118         return result;
119     }
120 
121     {
122         std::lock_guard<std::mutex> lock(mutex_);
123         auto iter = modelIdApiMap_.find(modelId);
124         if (iter == modelIdApiMap_.end() || iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
125             SGLOGI("the model has not been initialized, begin init, modelId=%{public}u", modelId);
126             return result;
127         }
128         result = iter->second->GetModelApi()->GetResult(modelId, param);
129     }
130     ModelCfg config;
131     bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, config);
132     if (success && config.startMode == START_ON_DEMAND) {
133         Release(modelId);
134     }
135     return result;
136 }
137 
SubscribeResult(uint32_t modelId, std::shared_ptr<IModelResultListener> listener)138 int32_t ModelManager::SubscribeResult(uint32_t modelId, std::shared_ptr<IModelResultListener> listener)
139 {
140     int32_t ret = InitModel(modelId);
141     if (ret != SUCCESS) {
142         return ret;
143     }
144 
145     std::lock_guard<std::mutex> lock(mutex_);
146     auto iter = modelIdApiMap_.find(modelId);
147     if (iter == modelIdApiMap_.end() || iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
148         SGLOGI("the model has not been initialized, modelId=%{public}u", modelId);
149         return FAILED;
150     }
151 
152     return iter->second->GetModelApi()->SubscribeResult(listener);
153 }
154 
Release(uint32_t modelId)155 void ModelManager::Release(uint32_t modelId)
156 {
157     std::lock_guard<std::mutex> lock(mutex_);
158     auto iter = modelIdApiMap_.find(modelId);
159     if (iter == modelIdApiMap_.end()) {
160         SGLOGI("the model has not been initialized, modelId=%{public}u", modelId);
161         return;
162     }
163 
164     if (iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
165         SGLOGI("the model attr is nullptr, modelId=%{public}u", modelId);
166         modelIdApiMap_.erase(iter);
167         return;
168     }
169 
170     iter->second->GetModelApi()->Release();
171     modelIdApiMap_.erase(iter);
172 }
173 } // namespace OHOS::Security::SecurityGuard
174