1From 49309e193d11f8546f0f564b74ea4adc1425cd9f Mon Sep 17 00:00:00 2001 2From: chengfeng27 <chengfeng27@huawei.com> 3Date: Mon, 9 Sep 2024 09:55:11 +0800 4Subject: [PATCH] fix context double free 5 6--- 7 mindspore/lite/src/litert/c_api/context_c.cc | 50 ++++++++++++-------- 8 mindspore/lite/src/litert/c_api/context_c.h | 16 +++++++ 9 mindspore/lite/src/litert/c_api/model_c.cc | 29 +++++++----- 10 3 files changed, 63 insertions(+), 32 deletions(-) 11 12diff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc 13index 88fdc4d0..5418c46a 100644 14--- a/mindspore/lite/src/litert/c_api/context_c.cc 15+++ b/mindspore/lite/src/litert/c_api/context_c.cc 16@@ -17,6 +17,7 @@ 17 #include "include/api/context.h" 18 #include <string.h> 19 #include "src/litert/c_api/type_c_private.h" 20+#include "src/litert/c_api/context_c.h" 21 #include "src/common/log_adapter.h" 22 #ifdef SUPPORT_NNRT_METAGRAPH 23 #include "src/litert/delegate/nnrt/hiai_foundation_wrapper.h" 24@@ -31,17 +32,26 @@ const auto kNpuNamePrefixLen = 4; 25 26 // ================ Context ================ 27 OH_AI_ContextHandle OH_AI_ContextCreate() { 28- auto impl = new (std::nothrow) mindspore::Context(); 29+ auto impl = new (std::nothrow) mindspore::ContextC(); 30 if (impl == nullptr) { 31 MS_LOG(ERROR) << "memory allocation failed."; 32 return nullptr; 33 } 34+ impl->context_ = new (std::nothrow) mindspore::Context(); 35+ if (impl->context_ == nullptr) { 36+ MS_LOG(ERROR) << "memory allocation failed."; 37+ delete impl; 38+ } 39+ impl->owned_by_model_ = false; 40 return static_cast<OH_AI_ContextHandle>(impl); 41 } 42 43 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 44 if (context != nullptr && *context != nullptr) { 45- auto impl = static_cast<mindspore::Context *>(*context); 46+ auto impl = static_cast<mindspore::ContextC *>(*context); 47+ if (impl->owned_by_model_) { 48+ impl->context_ = nullptr; 49+ } 50 delete impl; 51 *context = nullptr; 52 } 53@@ -52,8 +62,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) 54 MS_LOG(ERROR) << "param is nullptr."; 55 return; 56 } 57- auto impl = static_cast<mindspore::Context *>(context); 58- impl->SetThreadNum(thread_num); 59+ auto impl = static_cast<mindspore::ContextC *>(context); 60+ impl->context_->SetThreadNum(thread_num); 61 } 62 63 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 64@@ -61,8 +71,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 65 MS_LOG(ERROR) << "param is nullptr."; 66 return 0; 67 } 68- auto impl = static_cast<mindspore::Context *>(context); 69- return impl->GetThreadNum(); 70+ auto impl = static_cast<mindspore::ContextC *>(context); 71+ return impl->context_->GetThreadNum(); 72 } 73 74 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 75@@ -70,8 +80,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 76 MS_LOG(ERROR) << "param is nullptr."; 77 return; 78 } 79- auto impl = static_cast<mindspore::Context *>(context); 80- impl->SetThreadAffinity(mode); 81+ auto impl = static_cast<mindspore::ContextC *>(context); 82+ impl->context_->SetThreadAffinity(mode); 83 return; 84 } 85 86@@ -80,8 +90,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 87 MS_LOG(ERROR) << "param is nullptr."; 88 return 0; 89 } 90- auto impl = static_cast<mindspore::Context *>(context); 91- return impl->GetThreadAffinityMode(); 92+ auto impl = static_cast<mindspore::ContextC *>(context); 93+ return impl->context_->GetThreadAffinityMode(); 94 } 95 96 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) { 97@@ -90,8 +100,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i 98 return; 99 } 100 const std::vector<int32_t> vec_core_list(core_list, core_list + core_num); 101- auto impl = static_cast<mindspore::Context *>(context); 102- impl->SetThreadAffinity(vec_core_list); 103+ auto impl = static_cast<mindspore::ContextC *>(context); 104+ impl->context_->SetThreadAffinity(vec_core_list); 105 return; 106 } 107 108@@ -100,8 +110,8 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle 109 MS_LOG(ERROR) << "param is nullptr."; 110 return nullptr; 111 } 112- auto impl = static_cast<mindspore::Context *>(context); 113- auto affinity_core_list = impl->GetThreadAffinityCoreList(); 114+ auto impl = static_cast<mindspore::ContextC *>(context); 115+ auto affinity_core_list = impl->context_->GetThreadAffinityCoreList(); 116 *core_num = affinity_core_list.size(); 117 int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t))); 118 if (core_list == nullptr) { 119@@ -119,8 +129,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle 120 MS_LOG(ERROR) << "param is nullptr."; 121 return; 122 } 123- auto impl = static_cast<mindspore::Context *>(context); 124- impl->SetEnableParallel(is_parallel); 125+ auto impl = static_cast<mindspore::ContextC *>(context); 126+ impl->context_->SetEnableParallel(is_parallel); 127 } 128 129 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 130@@ -128,8 +138,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 131 MS_LOG(ERROR) << "param is nullptr."; 132 return false; 133 } 134- auto impl = static_cast<mindspore::Context *>(context); 135- return impl->GetEnableParallel(); 136+ auto impl = static_cast<mindspore::ContextC *>(context); 137+ return impl->context_->GetEnableParallel(); 138 } 139 140 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) { 141@@ -137,9 +147,9 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan 142 MS_LOG(ERROR) << "param is nullptr."; 143 return; 144 } 145- auto impl = static_cast<mindspore::Context *>(context); 146+ auto impl = static_cast<mindspore::ContextC *>(context); 147 std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info)); 148- impl->MutableDeviceInfo().push_back(device); 149+ impl->context_->MutableDeviceInfo().push_back(device); 150 } 151 152 // ================ DeviceInfo ================ 153diff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h 154index dc88b8a4..34f4d1e4 100644 155--- a/mindspore/lite/src/litert/c_api/context_c.h 156+++ b/mindspore/lite/src/litert/c_api/context_c.h 157@@ -20,5 +20,21 @@ 158 #include <vector> 159 #include <memory> 160 #include "include/c_api/types_c.h" 161+#include "include/api/context.h" 162+ 163+namespace mindspore { 164+class ContextC { 165+ public: 166+ ContextC() : owned_by_model_(false), context_(nullptr) {} 167+ ~ContextC() { 168+ if (context_ != nullptr) { 169+ delete context_; 170+ context_ = nullptr; 171+ } 172+ } 173+ bool owned_by_model_; 174+ Context *context_; 175+}; 176+} // namespace mindspore 177 178 #endif // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_ 179diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc 180index 661a8d06..d8632338 100644 181--- a/mindspore/lite/src/litert/c_api/model_c.cc 182+++ b/mindspore/lite/src/litert/c_api/model_c.cc 183@@ -15,6 +15,7 @@ 184 */ 185 #include "include/c_api/model_c.h" 186 #include "type_c_private.h" 187+#include "context_c.h" 188 #include <vector> 189 #include <cstdint> 190 #include "include/api/context.h" 191@@ -191,10 +192,11 @@ OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, s 192 MS_LOG(ERROR) << "model_type is invalid."; 193 return OH_AI_STATUS_LITE_PARAM_INVALID; 194 } 195- mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 196+ mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 197 auto impl = static_cast<mindspore::ModelC *>(model); 198- if (impl->context_.get() != context) { 199- impl->context_.reset(context); 200+ if (impl->context_.get() != context->context_) { 201+ impl->context_.reset(context->context_); 202+ context->owned_by_model_ = true; 203 } 204 auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_); 205 return static_cast<OH_AI_Status>(ret.StatusCode()); 206@@ -210,10 +212,11 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model 207 MS_LOG(ERROR) << "model_type is invalid."; 208 return OH_AI_STATUS_LITE_PARAM_INVALID; 209 } 210- mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 211+ mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 212 auto impl = static_cast<mindspore::ModelC *>(model); 213- if (impl->context_.get() != context) { 214- impl->context_.reset(context); 215+ if (impl->context_.get() != context->context_) { 216+ impl->context_.reset(context->context_); 217+ context->owned_by_model_ = true; 218 } 219 auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_); 220 return static_cast<OH_AI_Status>(ret.StatusCode()); 221@@ -447,10 +450,11 @@ OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_da 222 MS_LOG(ERROR) << "load ms file failed."; 223 return OH_AI_STATUS_LITE_ERROR; 224 } 225- auto context = static_cast<mindspore::Context *>(model_context); 226+ auto context = static_cast<mindspore::ContextC *>(model_context); 227 auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 228- if (impl->context_.get() != context) { 229- impl->context_.reset(context); 230+ if (impl->context_.get() != context->context_) { 231+ impl->context_.reset(context->context_); 232+ context->owned_by_model_ = true; 233 } 234 auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 235 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 236@@ -479,10 +483,11 @@ OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char * 237 MS_LOG(ERROR) << "load ms file failed. " << model_path; 238 return OH_AI_STATUS_LITE_ERROR; 239 } 240- auto context = static_cast<mindspore::Context *>(model_context); 241+ auto context = static_cast<mindspore::ContextC *>(model_context); 242 auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 243- if (impl->context_.get() != context) { 244- impl->context_.reset(context); 245+ if (impl->context_.get() != context->context_) { 246+ impl->context_.reset(context->context_); 247+ context->owned_by_model_ = true; 248 } 249 auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 250 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 251-- 2522.17.1 253 254