1be168c0dSopenharmony_ciFrom 49309e193d11f8546f0f564b74ea4adc1425cd9f Mon Sep 17 00:00:00 2001 2be168c0dSopenharmony_ciFrom: chengfeng27 <chengfeng27@huawei.com> 3be168c0dSopenharmony_ciDate: Mon, 9 Sep 2024 09:55:11 +0800 4be168c0dSopenharmony_ciSubject: [PATCH] fix context double free 5be168c0dSopenharmony_ci 6be168c0dSopenharmony_ci--- 7be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.cc | 50 ++++++++++++-------- 8be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.h | 16 +++++++ 9be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/model_c.cc | 29 +++++++----- 10be168c0dSopenharmony_ci 3 files changed, 63 insertions(+), 32 deletions(-) 11be168c0dSopenharmony_ci 12be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc 13be168c0dSopenharmony_ciindex 88fdc4d0..5418c46a 100644 14be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.cc 15be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.cc 16be168c0dSopenharmony_ci@@ -17,6 +17,7 @@ 17be168c0dSopenharmony_ci #include "include/api/context.h" 18be168c0dSopenharmony_ci #include <string.h> 19be168c0dSopenharmony_ci #include "src/litert/c_api/type_c_private.h" 20be168c0dSopenharmony_ci+#include "src/litert/c_api/context_c.h" 21be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 22be168c0dSopenharmony_ci #ifdef SUPPORT_NNRT_METAGRAPH 23be168c0dSopenharmony_ci #include "src/litert/delegate/nnrt/hiai_foundation_wrapper.h" 24be168c0dSopenharmony_ci@@ -31,17 +32,26 @@ const auto kNpuNamePrefixLen = 4; 25be168c0dSopenharmony_ci 26be168c0dSopenharmony_ci // ================ Context ================ 27be168c0dSopenharmony_ci OH_AI_ContextHandle OH_AI_ContextCreate() { 28be168c0dSopenharmony_ci- auto impl = new (std::nothrow) mindspore::Context(); 29be168c0dSopenharmony_ci+ auto impl = new (std::nothrow) mindspore::ContextC(); 30be168c0dSopenharmony_ci if (impl == nullptr) { 31be168c0dSopenharmony_ci MS_LOG(ERROR) << "memory allocation failed."; 32be168c0dSopenharmony_ci return nullptr; 33be168c0dSopenharmony_ci } 34be168c0dSopenharmony_ci+ impl->context_ = new (std::nothrow) mindspore::Context(); 35be168c0dSopenharmony_ci+ if (impl->context_ == nullptr) { 36be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "memory allocation failed."; 37be168c0dSopenharmony_ci+ delete impl; 38be168c0dSopenharmony_ci+ } 39be168c0dSopenharmony_ci+ impl->owned_by_model_ = false; 40be168c0dSopenharmony_ci return static_cast<OH_AI_ContextHandle>(impl); 41be168c0dSopenharmony_ci } 42be168c0dSopenharmony_ci 43be168c0dSopenharmony_ci void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 44be168c0dSopenharmony_ci if (context != nullptr && *context != nullptr) { 45be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(*context); 46be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(*context); 47be168c0dSopenharmony_ci+ if (impl->owned_by_model_) { 48be168c0dSopenharmony_ci+ impl->context_ = nullptr; 49be168c0dSopenharmony_ci+ } 50be168c0dSopenharmony_ci delete impl; 51be168c0dSopenharmony_ci *context = nullptr; 52be168c0dSopenharmony_ci } 53be168c0dSopenharmony_ci@@ -52,8 +62,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) 54be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 55be168c0dSopenharmony_ci return; 56be168c0dSopenharmony_ci } 57be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 58be168c0dSopenharmony_ci- impl->SetThreadNum(thread_num); 59be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 60be168c0dSopenharmony_ci+ impl->context_->SetThreadNum(thread_num); 61be168c0dSopenharmony_ci } 62be168c0dSopenharmony_ci 63be168c0dSopenharmony_ci int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 64be168c0dSopenharmony_ci@@ -61,8 +71,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 65be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 66be168c0dSopenharmony_ci return 0; 67be168c0dSopenharmony_ci } 68be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 69be168c0dSopenharmony_ci- return impl->GetThreadNum(); 70be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 71be168c0dSopenharmony_ci+ return impl->context_->GetThreadNum(); 72be168c0dSopenharmony_ci } 73be168c0dSopenharmony_ci 74be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 75be168c0dSopenharmony_ci@@ -70,8 +80,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 76be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 77be168c0dSopenharmony_ci return; 78be168c0dSopenharmony_ci } 79be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 80be168c0dSopenharmony_ci- impl->SetThreadAffinity(mode); 81be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 82be168c0dSopenharmony_ci+ impl->context_->SetThreadAffinity(mode); 83be168c0dSopenharmony_ci return; 84be168c0dSopenharmony_ci } 85be168c0dSopenharmony_ci 86be168c0dSopenharmony_ci@@ -80,8 +90,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 87be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 88be168c0dSopenharmony_ci return 0; 89be168c0dSopenharmony_ci } 90be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 91be168c0dSopenharmony_ci- return impl->GetThreadAffinityMode(); 92be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 93be168c0dSopenharmony_ci+ return impl->context_->GetThreadAffinityMode(); 94be168c0dSopenharmony_ci } 95be168c0dSopenharmony_ci 96be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) { 97be168c0dSopenharmony_ci@@ -90,8 +100,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i 98be168c0dSopenharmony_ci return; 99be168c0dSopenharmony_ci } 100be168c0dSopenharmony_ci const std::vector<int32_t> vec_core_list(core_list, core_list + core_num); 101be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 102be168c0dSopenharmony_ci- impl->SetThreadAffinity(vec_core_list); 103be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 104be168c0dSopenharmony_ci+ impl->context_->SetThreadAffinity(vec_core_list); 105be168c0dSopenharmony_ci return; 106be168c0dSopenharmony_ci } 107be168c0dSopenharmony_ci 108be168c0dSopenharmony_ci@@ -100,8 +110,8 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle 109be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 110be168c0dSopenharmony_ci return nullptr; 111be168c0dSopenharmony_ci } 112be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 113be168c0dSopenharmony_ci- auto affinity_core_list = impl->GetThreadAffinityCoreList(); 114be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 115be168c0dSopenharmony_ci+ auto affinity_core_list = impl->context_->GetThreadAffinityCoreList(); 116be168c0dSopenharmony_ci *core_num = affinity_core_list.size(); 117be168c0dSopenharmony_ci int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t))); 118be168c0dSopenharmony_ci if (core_list == nullptr) { 119be168c0dSopenharmony_ci@@ -119,8 +129,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle 120be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 121be168c0dSopenharmony_ci return; 122be168c0dSopenharmony_ci } 123be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 124be168c0dSopenharmony_ci- impl->SetEnableParallel(is_parallel); 125be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 126be168c0dSopenharmony_ci+ impl->context_->SetEnableParallel(is_parallel); 127be168c0dSopenharmony_ci } 128be168c0dSopenharmony_ci 129be168c0dSopenharmony_ci bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 130be168c0dSopenharmony_ci@@ -128,8 +138,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 131be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 132be168c0dSopenharmony_ci return false; 133be168c0dSopenharmony_ci } 134be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 135be168c0dSopenharmony_ci- return impl->GetEnableParallel(); 136be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 137be168c0dSopenharmony_ci+ return impl->context_->GetEnableParallel(); 138be168c0dSopenharmony_ci } 139be168c0dSopenharmony_ci 140be168c0dSopenharmony_ci void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) { 141be168c0dSopenharmony_ci@@ -137,9 +147,9 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan 142be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 143be168c0dSopenharmony_ci return; 144be168c0dSopenharmony_ci } 145be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::Context *>(context); 146be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ContextC *>(context); 147be168c0dSopenharmony_ci std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info)); 148be168c0dSopenharmony_ci- impl->MutableDeviceInfo().push_back(device); 149be168c0dSopenharmony_ci+ impl->context_->MutableDeviceInfo().push_back(device); 150be168c0dSopenharmony_ci } 151be168c0dSopenharmony_ci 152be168c0dSopenharmony_ci // ================ DeviceInfo ================ 153be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h 154be168c0dSopenharmony_ciindex dc88b8a4..34f4d1e4 100644 155be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.h 156be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.h 157be168c0dSopenharmony_ci@@ -20,5 +20,21 @@ 158be168c0dSopenharmony_ci #include <vector> 159be168c0dSopenharmony_ci #include <memory> 160be168c0dSopenharmony_ci #include "include/c_api/types_c.h" 161be168c0dSopenharmony_ci+#include "include/api/context.h" 162be168c0dSopenharmony_ci+ 163be168c0dSopenharmony_ci+namespace mindspore { 164be168c0dSopenharmony_ci+class ContextC { 165be168c0dSopenharmony_ci+ public: 166be168c0dSopenharmony_ci+ ContextC() : owned_by_model_(false), context_(nullptr) {} 167be168c0dSopenharmony_ci+ ~ContextC() { 168be168c0dSopenharmony_ci+ if (context_ != nullptr) { 169be168c0dSopenharmony_ci+ delete context_; 170be168c0dSopenharmony_ci+ context_ = nullptr; 171be168c0dSopenharmony_ci+ } 172be168c0dSopenharmony_ci+ } 173be168c0dSopenharmony_ci+ bool owned_by_model_; 174be168c0dSopenharmony_ci+ Context *context_; 175be168c0dSopenharmony_ci+}; 176be168c0dSopenharmony_ci+} // namespace mindspore 177be168c0dSopenharmony_ci 178be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_ 179be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc 180be168c0dSopenharmony_ciindex 661a8d06..d8632338 100644 181be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/model_c.cc 182be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/model_c.cc 183be168c0dSopenharmony_ci@@ -15,6 +15,7 @@ 184be168c0dSopenharmony_ci */ 185be168c0dSopenharmony_ci #include "include/c_api/model_c.h" 186be168c0dSopenharmony_ci #include "type_c_private.h" 187be168c0dSopenharmony_ci+#include "context_c.h" 188be168c0dSopenharmony_ci #include <vector> 189be168c0dSopenharmony_ci #include <cstdint> 190be168c0dSopenharmony_ci #include "include/api/context.h" 191be168c0dSopenharmony_ci@@ -191,10 +192,11 @@ OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, s 192be168c0dSopenharmony_ci MS_LOG(ERROR) << "model_type is invalid."; 193be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_PARAM_INVALID; 194be168c0dSopenharmony_ci } 195be168c0dSopenharmony_ci- mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 196be168c0dSopenharmony_ci+ mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 197be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 198be168c0dSopenharmony_ci- if (impl->context_.get() != context) { 199be168c0dSopenharmony_ci- impl->context_.reset(context); 200be168c0dSopenharmony_ci+ if (impl->context_.get() != context->context_) { 201be168c0dSopenharmony_ci+ impl->context_.reset(context->context_); 202be168c0dSopenharmony_ci+ context->owned_by_model_ = true; 203be168c0dSopenharmony_ci } 204be168c0dSopenharmony_ci auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_); 205be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 206be168c0dSopenharmony_ci@@ -210,10 +212,11 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model 207be168c0dSopenharmony_ci MS_LOG(ERROR) << "model_type is invalid."; 208be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_PARAM_INVALID; 209be168c0dSopenharmony_ci } 210be168c0dSopenharmony_ci- mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 211be168c0dSopenharmony_ci+ mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 212be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 213be168c0dSopenharmony_ci- if (impl->context_.get() != context) { 214be168c0dSopenharmony_ci- impl->context_.reset(context); 215be168c0dSopenharmony_ci+ if (impl->context_.get() != context->context_) { 216be168c0dSopenharmony_ci+ impl->context_.reset(context->context_); 217be168c0dSopenharmony_ci+ context->owned_by_model_ = true; 218be168c0dSopenharmony_ci } 219be168c0dSopenharmony_ci auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_); 220be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 221be168c0dSopenharmony_ci@@ -447,10 +450,11 @@ OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_da 222be168c0dSopenharmony_ci MS_LOG(ERROR) << "load ms file failed."; 223be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_ERROR; 224be168c0dSopenharmony_ci } 225be168c0dSopenharmony_ci- auto context = static_cast<mindspore::Context *>(model_context); 226be168c0dSopenharmony_ci+ auto context = static_cast<mindspore::ContextC *>(model_context); 227be168c0dSopenharmony_ci auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 228be168c0dSopenharmony_ci- if (impl->context_.get() != context) { 229be168c0dSopenharmony_ci- impl->context_.reset(context); 230be168c0dSopenharmony_ci+ if (impl->context_.get() != context->context_) { 231be168c0dSopenharmony_ci+ impl->context_.reset(context->context_); 232be168c0dSopenharmony_ci+ context->owned_by_model_ = true; 233be168c0dSopenharmony_ci } 234be168c0dSopenharmony_ci auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 235be168c0dSopenharmony_ci std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 236be168c0dSopenharmony_ci@@ -479,10 +483,11 @@ OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char * 237be168c0dSopenharmony_ci MS_LOG(ERROR) << "load ms file failed. " << model_path; 238be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_ERROR; 239be168c0dSopenharmony_ci } 240be168c0dSopenharmony_ci- auto context = static_cast<mindspore::Context *>(model_context); 241be168c0dSopenharmony_ci+ auto context = static_cast<mindspore::ContextC *>(model_context); 242be168c0dSopenharmony_ci auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 243be168c0dSopenharmony_ci- if (impl->context_.get() != context) { 244be168c0dSopenharmony_ci- impl->context_.reset(context); 245be168c0dSopenharmony_ci+ if (impl->context_.get() != context->context_) { 246be168c0dSopenharmony_ci+ impl->context_.reset(context->context_); 247be168c0dSopenharmony_ci+ context->owned_by_model_ = true; 248be168c0dSopenharmony_ci } 249be168c0dSopenharmony_ci auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 250be168c0dSopenharmony_ci std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 251be168c0dSopenharmony_ci-- 252be168c0dSopenharmony_ci2.17.1 253be168c0dSopenharmony_ci 254