1From 49309e193d11f8546f0f564b74ea4adc1425cd9f Mon Sep 17 00:00:00 2001
2From: chengfeng27 <chengfeng27@huawei.com>
3Date: Mon, 9 Sep 2024 09:55:11 +0800
4Subject: [PATCH] fix context double free
5
6---
7 mindspore/lite/src/litert/c_api/context_c.cc | 50 ++++++++++++--------
8 mindspore/lite/src/litert/c_api/context_c.h  | 16 +++++++
9 mindspore/lite/src/litert/c_api/model_c.cc   | 29 +++++++-----
10 3 files changed, 63 insertions(+), 32 deletions(-)
11
12diff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc
13index 88fdc4d0..5418c46a 100644
14--- a/mindspore/lite/src/litert/c_api/context_c.cc
15+++ b/mindspore/lite/src/litert/c_api/context_c.cc
16@@ -17,6 +17,7 @@
17 #include "include/api/context.h"
18 #include <string.h>
19 #include "src/litert/c_api/type_c_private.h"
20+#include "src/litert/c_api/context_c.h"
21 #include "src/common/log_adapter.h"
22 #ifdef SUPPORT_NNRT_METAGRAPH
23 #include "src/litert/delegate/nnrt/hiai_foundation_wrapper.h"
24@@ -31,17 +32,26 @@ const auto kNpuNamePrefixLen = 4;
25 
26 // ================ Context ================
27 OH_AI_ContextHandle OH_AI_ContextCreate() {
28-  auto impl = new (std::nothrow) mindspore::Context();
29+  auto impl = new (std::nothrow) mindspore::ContextC();
30   if (impl == nullptr) {
31     MS_LOG(ERROR) << "memory allocation failed.";
32     return nullptr;
33   }
34+  impl->context_ = new (std::nothrow) mindspore::Context();
35+  if (impl->context_ == nullptr) {
36+    MS_LOG(ERROR) << "memory allocation failed.";
37+    delete impl;
38+  }
39+  impl->owned_by_model_ = false;
40   return static_cast<OH_AI_ContextHandle>(impl);
41 }
42 
43 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
44   if (context != nullptr && *context != nullptr) {
45-    auto impl = static_cast<mindspore::Context *>(*context);
46+    auto impl = static_cast<mindspore::ContextC *>(*context);
47+    if (impl->owned_by_model_) {
48+      impl->context_ = nullptr;
49+    }
50     delete impl;
51     *context = nullptr;
52   }
53@@ -52,8 +62,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num)
54     MS_LOG(ERROR) << "param is nullptr.";
55     return;
56   }
57-  auto impl = static_cast<mindspore::Context *>(context);
58-  impl->SetThreadNum(thread_num);
59+  auto impl = static_cast<mindspore::ContextC *>(context);
60+  impl->context_->SetThreadNum(thread_num);
61 }
62 
63 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
64@@ -61,8 +71,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
65     MS_LOG(ERROR) << "param is nullptr.";
66     return 0;
67   }
68-  auto impl = static_cast<mindspore::Context *>(context);
69-  return impl->GetThreadNum();
70+  auto impl = static_cast<mindspore::ContextC *>(context);
71+  return impl->context_->GetThreadNum();
72 }
73 
74 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
75@@ -70,8 +80,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
76     MS_LOG(ERROR) << "param is nullptr.";
77     return;
78   }
79-  auto impl = static_cast<mindspore::Context *>(context);
80-  impl->SetThreadAffinity(mode);
81+  auto impl = static_cast<mindspore::ContextC *>(context);
82+  impl->context_->SetThreadAffinity(mode);
83   return;
84 }
85 
86@@ -80,8 +90,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
87     MS_LOG(ERROR) << "param is nullptr.";
88     return 0;
89   }
90-  auto impl = static_cast<mindspore::Context *>(context);
91-  return impl->GetThreadAffinityMode();
92+  auto impl = static_cast<mindspore::ContextC *>(context);
93+  return impl->context_->GetThreadAffinityMode();
94 }
95 
96 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
97@@ -90,8 +100,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i
98     return;
99   }
100   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
101-  auto impl = static_cast<mindspore::Context *>(context);
102-  impl->SetThreadAffinity(vec_core_list);
103+  auto impl = static_cast<mindspore::ContextC *>(context);
104+  impl->context_->SetThreadAffinity(vec_core_list);
105   return;
106 }
107 
108@@ -100,8 +110,8 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle
109     MS_LOG(ERROR) << "param is nullptr.";
110     return nullptr;
111   }
112-  auto impl = static_cast<mindspore::Context *>(context);
113-  auto affinity_core_list = impl->GetThreadAffinityCoreList();
114+  auto impl = static_cast<mindspore::ContextC *>(context);
115+  auto affinity_core_list = impl->context_->GetThreadAffinityCoreList();
116   *core_num = affinity_core_list.size();
117   int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t)));
118   if (core_list == nullptr) {
119@@ -119,8 +129,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle
120     MS_LOG(ERROR) << "param is nullptr.";
121     return;
122   }
123-  auto impl = static_cast<mindspore::Context *>(context);
124-  impl->SetEnableParallel(is_parallel);
125+  auto impl = static_cast<mindspore::ContextC *>(context);
126+  impl->context_->SetEnableParallel(is_parallel);
127 }
128 
129 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
130@@ -128,8 +138,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
131     MS_LOG(ERROR) << "param is nullptr.";
132     return false;
133   }
134-  auto impl = static_cast<mindspore::Context *>(context);
135-  return impl->GetEnableParallel();
136+  auto impl = static_cast<mindspore::ContextC *>(context);
137+  return impl->context_->GetEnableParallel();
138 }
139 
140 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
141@@ -137,9 +147,9 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan
142     MS_LOG(ERROR) << "param is nullptr.";
143     return;
144   }
145-  auto impl = static_cast<mindspore::Context *>(context);
146+  auto impl = static_cast<mindspore::ContextC *>(context);
147   std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
148-  impl->MutableDeviceInfo().push_back(device);
149+  impl->context_->MutableDeviceInfo().push_back(device);
150 }
151 
152 // ================ DeviceInfo ================
153diff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h
154index dc88b8a4..34f4d1e4 100644
155--- a/mindspore/lite/src/litert/c_api/context_c.h
156+++ b/mindspore/lite/src/litert/c_api/context_c.h
157@@ -20,5 +20,21 @@
158 #include <vector>
159 #include <memory>
160 #include "include/c_api/types_c.h"
161+#include "include/api/context.h"
162+
163+namespace mindspore {
164+class ContextC {
165+ public:
166+  ContextC() : owned_by_model_(false), context_(nullptr) {}
167+  ~ContextC() {
168+    if (context_ != nullptr) {
169+      delete context_;
170+      context_ = nullptr;
171+    }
172+  }
173+  bool owned_by_model_;
174+  Context *context_;
175+};
176+}  // namespace mindspore
177 
178 #endif  // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_
179diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
180index 661a8d06..d8632338 100644
181--- a/mindspore/lite/src/litert/c_api/model_c.cc
182+++ b/mindspore/lite/src/litert/c_api/model_c.cc
183@@ -15,6 +15,7 @@
184  */
185 #include "include/c_api/model_c.h"
186 #include "type_c_private.h"
187+#include "context_c.h"
188 #include <vector>
189 #include <cstdint>
190 #include "include/api/context.h"
191@@ -191,10 +192,11 @@ OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, s
192     MS_LOG(ERROR) << "model_type is invalid.";
193     return OH_AI_STATUS_LITE_PARAM_INVALID;
194   }
195-  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
196+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
197   auto impl = static_cast<mindspore::ModelC *>(model);
198-  if (impl->context_.get() != context) {
199-    impl->context_.reset(context);
200+  if (impl->context_.get() != context->context_) {
201+    impl->context_.reset(context->context_);
202+    context->owned_by_model_ = true;
203   }
204   auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
205   return static_cast<OH_AI_Status>(ret.StatusCode());
206@@ -210,10 +212,11 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model
207     MS_LOG(ERROR) << "model_type is invalid.";
208     return OH_AI_STATUS_LITE_PARAM_INVALID;
209   }
210-  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
211+  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
212   auto impl = static_cast<mindspore::ModelC *>(model);
213-  if (impl->context_.get() != context) {
214-    impl->context_.reset(context);
215+  if (impl->context_.get() != context->context_) {
216+    impl->context_.reset(context->context_);
217+    context->owned_by_model_ = true;
218   }
219   auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
220   return static_cast<OH_AI_Status>(ret.StatusCode());
221@@ -447,10 +450,11 @@ OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_da
222     MS_LOG(ERROR) << "load ms file failed.";
223     return OH_AI_STATUS_LITE_ERROR;
224   }
225-  auto context = static_cast<mindspore::Context *>(model_context);
226+  auto context = static_cast<mindspore::ContextC *>(model_context);
227   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
228-  if (impl->context_.get() != context) {
229-    impl->context_.reset(context);
230+  if (impl->context_.get() != context->context_) {
231+    impl->context_.reset(context->context_);
232+    context->owned_by_model_ = true;
233   }
234   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
235                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
236@@ -479,10 +483,11 @@ OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *
237     MS_LOG(ERROR) << "load ms file failed. " << model_path;
238     return OH_AI_STATUS_LITE_ERROR;
239   }
240-  auto context = static_cast<mindspore::Context *>(model_context);
241+  auto context = static_cast<mindspore::ContextC *>(model_context);
242   auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
243-  if (impl->context_.get() != context) {
244-    impl->context_.reset(context);
245+  if (impl->context_.get() != context->context_) {
246+    impl->context_.reset(context->context_);
247+    context->owned_by_model_ = true;
248   }
249   auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
250                                  std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
251-- 
2522.17.1
253
254