1be168c0dSopenharmony_ciFrom 49309e193d11f8546f0f564b74ea4adc1425cd9f Mon Sep 17 00:00:00 2001
2be168c0dSopenharmony_ciFrom: chengfeng27 <chengfeng27@huawei.com>
3be168c0dSopenharmony_ciDate: Mon, 9 Sep 2024 09:55:11 +0800
4be168c0dSopenharmony_ciSubject: [PATCH] fix context double free
5be168c0dSopenharmony_ci
6be168c0dSopenharmony_ci---
7be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.cc | 50 ++++++++++++--------
8be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.h  | 16 +++++++
9be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/model_c.cc   | 29 +++++++-----
10be168c0dSopenharmony_ci 3 files changed, 63 insertions(+), 32 deletions(-)
11be168c0dSopenharmony_ci
12be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc
13be168c0dSopenharmony_ciindex 88fdc4d0..5418c46a 100644
14be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.cc
15be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.cc
16be168c0dSopenharmony_ci@@ -17,6 +17,7 @@
17be168c0dSopenharmony_ci #include "include/api/context.h"
18be168c0dSopenharmony_ci #include <string.h>
19be168c0dSopenharmony_ci #include "src/litert/c_api/type_c_private.h"
20be168c0dSopenharmony_ci+#include "src/litert/c_api/context_c.h"
21be168c0dSopenharmony_ci #include "src/common/log_adapter.h"
22be168c0dSopenharmony_ci #ifdef SUPPORT_NNRT_METAGRAPH
23be168c0dSopenharmony_ci #include "src/litert/delegate/nnrt/hiai_foundation_wrapper.h"
24be168c0dSopenharmony_ci@@ -31,17 +32,26 @@ const auto kNpuNamePrefixLen = 4;
25be168c0dSopenharmony_ci 
26be168c0dSopenharmony_ci // ================ Context ================
27be168c0dSopenharmony_ci OH_AI_ContextHandle OH_AI_ContextCreate() {
28be168c0dSopenharmony_ci-  auto impl = new (std::nothrow) mindspore::Context();
29be168c0dSopenharmony_ci+  auto impl = new (std::nothrow) mindspore::ContextC();
30be168c0dSopenharmony_ci   if (impl == nullptr) {
31be168c0dSopenharmony_ci     MS_LOG(ERROR) << "memory allocation failed.";
32be168c0dSopenharmony_ci     return nullptr;
33be168c0dSopenharmony_ci   }
34be168c0dSopenharmony_ci+  impl->context_ = new (std::nothrow) mindspore::Context();
35be168c0dSopenharmony_ci+  if (impl->context_ == nullptr) {
36be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "memory allocation failed.";
37be168c0dSopenharmony_ci+    delete impl;
38be168c0dSopenharmony_ci+  }
39be168c0dSopenharmony_ci+  impl->owned_by_model_ = false;
40be168c0dSopenharmony_ci   return static_cast<OH_AI_ContextHandle>(impl);
41be168c0dSopenharmony_ci }
42be168c0dSopenharmony_ci 
43be168c0dSopenharmony_ci void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
44be168c0dSopenharmony_ci   if (context != nullptr && *context != nullptr) {
45be168c0dSopenharmony_ci-    auto impl = static_cast<mindspore::Context *>(*context);
46be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::ContextC *>(*context);
47be168c0dSopenharmony_ci+    if (impl->owned_by_model_) {
48be168c0dSopenharmony_ci+      impl->context_ = nullptr;
49be168c0dSopenharmony_ci+    }
50be168c0dSopenharmony_ci     delete impl;
51be168c0dSopenharmony_ci     *context = nullptr;
52be168c0dSopenharmony_ci   }
53be168c0dSopenharmony_ci@@ -52,8 +62,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num)
54be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
55be168c0dSopenharmony_ci     return;
56be168c0dSopenharmony_ci   }
57be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
58be168c0dSopenharmony_ci-  impl->SetThreadNum(thread_num);
59be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
60be168c0dSopenharmony_ci+  impl->context_->SetThreadNum(thread_num);
61be168c0dSopenharmony_ci }
62be168c0dSopenharmony_ci 
63be168c0dSopenharmony_ci int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
64be168c0dSopenharmony_ci@@ -61,8 +71,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
65be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
66be168c0dSopenharmony_ci     return 0;
67be168c0dSopenharmony_ci   }
68be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
69be168c0dSopenharmony_ci-  return impl->GetThreadNum();
70be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
71be168c0dSopenharmony_ci+  return impl->context_->GetThreadNum();
72be168c0dSopenharmony_ci }
73be168c0dSopenharmony_ci 
74be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
75be168c0dSopenharmony_ci@@ -70,8 +80,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
76be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
77be168c0dSopenharmony_ci     return;
78be168c0dSopenharmony_ci   }
79be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
80be168c0dSopenharmony_ci-  impl->SetThreadAffinity(mode);
81be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
82be168c0dSopenharmony_ci+  impl->context_->SetThreadAffinity(mode);
83be168c0dSopenharmony_ci   return;
84be168c0dSopenharmony_ci }
85be168c0dSopenharmony_ci 
86be168c0dSopenharmony_ci@@ -80,8 +90,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
87be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
88be168c0dSopenharmony_ci     return 0;
89be168c0dSopenharmony_ci   }
90be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
91be168c0dSopenharmony_ci-  return impl->GetThreadAffinityMode();
92be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
93be168c0dSopenharmony_ci+  return impl->context_->GetThreadAffinityMode();
94be168c0dSopenharmony_ci }
95be168c0dSopenharmony_ci 
96be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
97be168c0dSopenharmony_ci@@ -90,8 +100,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i
98be168c0dSopenharmony_ci     return;
99be168c0dSopenharmony_ci   }
100be168c0dSopenharmony_ci   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
101be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
102be168c0dSopenharmony_ci-  impl->SetThreadAffinity(vec_core_list);
103be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
104be168c0dSopenharmony_ci+  impl->context_->SetThreadAffinity(vec_core_list);
105be168c0dSopenharmony_ci   return;
106be168c0dSopenharmony_ci }
107be168c0dSopenharmony_ci 
108be168c0dSopenharmony_ci@@ -100,8 +110,8 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle
109be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
110be168c0dSopenharmony_ci     return nullptr;
111be168c0dSopenharmony_ci   }
112be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
113be168c0dSopenharmony_ci-  auto affinity_core_list = impl->GetThreadAffinityCoreList();
114be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
115be168c0dSopenharmony_ci+  auto affinity_core_list = impl->context_->GetThreadAffinityCoreList();
116be168c0dSopenharmony_ci   *core_num = affinity_core_list.size();
117be168c0dSopenharmony_ci   int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t)));
118be168c0dSopenharmony_ci   if (core_list == nullptr) {
119be168c0dSopenharmony_ci@@ -119,8 +129,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle
120be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
121be168c0dSopenharmony_ci     return;
122be168c0dSopenharmony_ci   }
123be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
124be168c0dSopenharmony_ci-  impl->SetEnableParallel(is_parallel);
125be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
126be168c0dSopenharmony_ci+  impl->context_->SetEnableParallel(is_parallel);
127be168c0dSopenharmony_ci }
128be168c0dSopenharmony_ci 
129be168c0dSopenharmony_ci bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
130be168c0dSopenharmony_ci@@ -128,8 +138,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
131be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
132be168c0dSopenharmony_ci     return false;
133be168c0dSopenharmony_ci   }
134be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
135be168c0dSopenharmony_ci-  return impl->GetEnableParallel();
136be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
137be168c0dSopenharmony_ci+  return impl->context_->GetEnableParallel();
138be168c0dSopenharmony_ci }
139be168c0dSopenharmony_ci 
140be168c0dSopenharmony_ci void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
141be168c0dSopenharmony_ci@@ -137,9 +147,9 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan
142be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
143be168c0dSopenharmony_ci     return;
144be168c0dSopenharmony_ci   }
145be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::Context *>(context);
146be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ContextC *>(context);
147be168c0dSopenharmony_ci   std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
148be168c0dSopenharmony_ci-  impl->MutableDeviceInfo().push_back(device);
149be168c0dSopenharmony_ci+  impl->context_->MutableDeviceInfo().push_back(device);
150be168c0dSopenharmony_ci }
151be168c0dSopenharmony_ci 
152be168c0dSopenharmony_ci // ================ DeviceInfo ================
153be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h
154be168c0dSopenharmony_ciindex dc88b8a4..34f4d1e4 100644
155be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.h
156be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.h
157be168c0dSopenharmony_ci@@ -20,5 +20,21 @@
158be168c0dSopenharmony_ci #include <vector>
159be168c0dSopenharmony_ci #include <memory>
160be168c0dSopenharmony_ci #include "include/c_api/types_c.h"
161be168c0dSopenharmony_ci+#include "include/api/context.h"
162be168c0dSopenharmony_ci+
163be168c0dSopenharmony_ci+namespace mindspore {
164be168c0dSopenharmony_ci+class ContextC {
165be168c0dSopenharmony_ci+ public:
166be168c0dSopenharmony_ci+  ContextC() : owned_by_model_(false), context_(nullptr) {}
167be168c0dSopenharmony_ci+  ~ContextC() {
168be168c0dSopenharmony_ci+    if (context_ != nullptr) {
169be168c0dSopenharmony_ci+      delete context_;
170be168c0dSopenharmony_ci+      context_ = nullptr;
171be168c0dSopenharmony_ci+    }
172be168c0dSopenharmony_ci+  }
173be168c0dSopenharmony_ci+  bool owned_by_model_;
174be168c0dSopenharmony_ci+  Context *context_;
175be168c0dSopenharmony_ci+};
176be168c0dSopenharmony_ci+}  // namespace mindspore
177be168c0dSopenharmony_ci 
178be168c0dSopenharmony_ci #endif  // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_
179be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
180be168c0dSopenharmony_ciindex 661a8d06..d8632338 100644
181be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/model_c.cc
182be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/model_c.cc
183be168c0dSopenharmony_ci@@ -15,6 +15,7 @@
184be168c0dSopenharmony_ci  */
185be168c0dSopenharmony_ci #include "include/c_api/model_c.h"
186be168c0dSopenharmony_ci #include "type_c_private.h"
187be168c0dSopenharmony_ci+#include "context_c.h"
188be168c0dSopenharmony_ci #include <vector>
189be168c0dSopenharmony_ci #include <cstdint>
190be168c0dSopenharmony_ci #include "include/api/context.h"
191be168c0dSopenharmony_ci@@ -191,10 +192,11 @@ OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, s
192be168c0dSopenharmony_ci     MS_LOG(ERROR) << "model_type is invalid.";
193be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_PARAM_INVALID;
194be168c0dSopenharmony_ci   }
195be168c0dSopenharmony_ci-  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
196be168c0dSopenharmony_ci+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
197be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
198be168c0dSopenharmony_ci-  if (impl->context_.get() != context) {
199be168c0dSopenharmony_ci-    impl->context_.reset(context);
200be168c0dSopenharmony_ci+  if (impl->context_.get() != context->context_) {
201be168c0dSopenharmony_ci+    impl->context_.reset(context->context_);
202be168c0dSopenharmony_ci+    context->owned_by_model_ = true;
203be168c0dSopenharmony_ci   }
204be168c0dSopenharmony_ci   auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
205be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
206be168c0dSopenharmony_ci@@ -210,10 +212,11 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model
207be168c0dSopenharmony_ci     MS_LOG(ERROR) << "model_type is invalid.";
208be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_PARAM_INVALID;
209be168c0dSopenharmony_ci   }
210be168c0dSopenharmony_ci-  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
211be168c0dSopenharmony_ci+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
212be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
213be168c0dSopenharmony_ci-  if (impl->context_.get() != context) {
214be168c0dSopenharmony_ci-    impl->context_.reset(context);
215be168c0dSopenharmony_ci+  if (impl->context_.get() != context->context_) {
216be168c0dSopenharmony_ci+    impl->context_.reset(context->context_);
217be168c0dSopenharmony_ci+    context->owned_by_model_ = true;
218be168c0dSopenharmony_ci   }
219be168c0dSopenharmony_ci   auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
220be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
221be168c0dSopenharmony_ci@@ -447,10 +450,11 @@ OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_da
222be168c0dSopenharmony_ci     MS_LOG(ERROR) << "load ms file failed.";
223be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_ERROR;
224be168c0dSopenharmony_ci   }
225be168c0dSopenharmony_ci-  auto context = static_cast<mindspore::Context *>(model_context);
226be168c0dSopenharmony_ci+  auto context = static_cast<mindspore::ContextC *>(model_context);
227be168c0dSopenharmony_ci   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
228be168c0dSopenharmony_ci-  if (impl->context_.get() != context) {
229be168c0dSopenharmony_ci-    impl->context_.reset(context);
230be168c0dSopenharmony_ci+  if (impl->context_.get() != context->context_) {
231be168c0dSopenharmony_ci+    impl->context_.reset(context->context_);
232be168c0dSopenharmony_ci+    context->owned_by_model_ = true;
233be168c0dSopenharmony_ci   }
234be168c0dSopenharmony_ci   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
235be168c0dSopenharmony_ci                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
236be168c0dSopenharmony_ci@@ -479,10 +483,11 @@ OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *
237be168c0dSopenharmony_ci     MS_LOG(ERROR) << "load ms file failed. " << model_path;
238be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_ERROR;
239be168c0dSopenharmony_ci   }
240be168c0dSopenharmony_ci-  auto context = static_cast<mindspore::Context *>(model_context);
241be168c0dSopenharmony_ci+  auto context = static_cast<mindspore::ContextC *>(model_context);
242be168c0dSopenharmony_ci   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
243be168c0dSopenharmony_ci-  if (impl->context_.get() != context) {
244be168c0dSopenharmony_ci-    impl->context_.reset(context);
245be168c0dSopenharmony_ci+  if (impl->context_.get() != context->context_) {
246be168c0dSopenharmony_ci+    impl->context_.reset(context->context_);
247be168c0dSopenharmony_ci+    context->owned_by_model_ = true;
248be168c0dSopenharmony_ci   }
249be168c0dSopenharmony_ci   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
250be168c0dSopenharmony_ci                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
251be168c0dSopenharmony_ci-- 
252be168c0dSopenharmony_ci2.17.1
253be168c0dSopenharmony_ci
254