From 49309e193d11f8546f0f564b74ea4adc1425cd9f Mon Sep 17 00:00:00 2001 From: chengfeng27 Date: Mon, 9 Sep 2024 09:55:11 +0800 Subject: [PATCH] fix context double free --- mindspore/lite/src/litert/c_api/context_c.cc | 50 ++++++++++++-------- mindspore/lite/src/litert/c_api/context_c.h | 16 +++++++ mindspore/lite/src/litert/c_api/model_c.cc | 29 +++++++----- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc index 88fdc4d0..5418c46a 100644 --- a/mindspore/lite/src/litert/c_api/context_c.cc +++ b/mindspore/lite/src/litert/c_api/context_c.cc @@ -17,6 +17,7 @@ #include "include/api/context.h" #include #include "src/litert/c_api/type_c_private.h" +#include "src/litert/c_api/context_c.h" #include "src/common/log_adapter.h" #ifdef SUPPORT_NNRT_METAGRAPH #include "src/litert/delegate/nnrt/hiai_foundation_wrapper.h" @@ -31,17 +32,26 @@ const auto kNpuNamePrefixLen = 4; // ================ Context ================ OH_AI_ContextHandle OH_AI_ContextCreate() { - auto impl = new (std::nothrow) mindspore::Context(); + auto impl = new (std::nothrow) mindspore::ContextC(); if (impl == nullptr) { MS_LOG(ERROR) << "memory allocation failed."; return nullptr; } + impl->context_ = new (std::nothrow) mindspore::Context(); + if (impl->context_ == nullptr) { + MS_LOG(ERROR) << "memory allocation failed."; + delete impl; + } + impl->owned_by_model_ = false; return static_cast(impl); } void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { if (context != nullptr && *context != nullptr) { - auto impl = static_cast(*context); + auto impl = static_cast(*context); + if (impl->owned_by_model_) { + impl->context_ = nullptr; + } delete impl; *context = nullptr; } @@ -52,8 +62,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - impl->SetThreadNum(thread_num); + auto impl = static_cast(context); + impl->context_->SetThreadNum(thread_num); } int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { @@ -61,8 +71,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return 0; } - auto impl = static_cast(context); - return impl->GetThreadNum(); + auto impl = static_cast(context); + return impl->context_->GetThreadNum(); } void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { @@ -70,8 +80,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - impl->SetThreadAffinity(mode); + auto impl = static_cast(context); + impl->context_->SetThreadAffinity(mode); return; } @@ -80,8 +90,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return 0; } - auto impl = static_cast(context); - return impl->GetThreadAffinityMode(); + auto impl = static_cast(context); + return impl->context_->GetThreadAffinityMode(); } void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) { @@ -90,8 +100,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i return; } const std::vector vec_core_list(core_list, core_list + core_num); - auto impl = static_cast(context); - impl->SetThreadAffinity(vec_core_list); + auto impl = static_cast(context); + impl->context_->SetThreadAffinity(vec_core_list); return; } @@ -100,8 +110,8 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle MS_LOG(ERROR) << "param is nullptr."; return nullptr; } - auto impl = static_cast(context); - auto affinity_core_list = impl->GetThreadAffinityCoreList(); + auto impl = static_cast(context); + auto affinity_core_list = impl->context_->GetThreadAffinityCoreList(); *core_num = affinity_core_list.size(); int32_t *core_list = static_cast(malloc((*core_num) * sizeof(int32_t))); if (core_list == nullptr) { @@ -119,8 +129,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); - impl->SetEnableParallel(is_parallel); + auto impl = static_cast(context); + impl->context_->SetEnableParallel(is_parallel); } bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { @@ -128,8 +138,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { MS_LOG(ERROR) << "param is nullptr."; return false; } - auto impl = static_cast(context); - return impl->GetEnableParallel(); + auto impl = static_cast(context); + return impl->context_->GetEnableParallel(); } void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) { @@ -137,9 +147,9 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan MS_LOG(ERROR) << "param is nullptr."; return; } - auto impl = static_cast(context); + auto impl = static_cast(context); std::shared_ptr device(static_cast(device_info)); - impl->MutableDeviceInfo().push_back(device); + impl->context_->MutableDeviceInfo().push_back(device); } // ================ DeviceInfo ================ diff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h index dc88b8a4..34f4d1e4 100644 --- a/mindspore/lite/src/litert/c_api/context_c.h +++ b/mindspore/lite/src/litert/c_api/context_c.h @@ -20,5 +20,21 @@ #include #include #include "include/c_api/types_c.h" +#include "include/api/context.h" + +namespace mindspore { +class ContextC { + public: + ContextC() : owned_by_model_(false), context_(nullptr) {} + ~ContextC() { + if (context_ != nullptr) { + delete context_; + context_ = nullptr; + } + } + bool owned_by_model_; + Context *context_; +}; +} // namespace mindspore #endif // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_ diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc index 661a8d06..d8632338 100644 --- a/mindspore/lite/src/litert/c_api/model_c.cc +++ b/mindspore/lite/src/litert/c_api/model_c.cc @@ -15,6 +15,7 @@ */ #include "include/c_api/model_c.h" #include "type_c_private.h" +#include "context_c.h" #include #include #include "include/api/context.h" @@ -191,10 +192,11 @@ OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, s MS_LOG(ERROR) << "model_type is invalid."; return OH_AI_STATUS_LITE_PARAM_INVALID; } - mindspore::Context *context = static_cast(model_context); + mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); - if (impl->context_.get() != context) { - impl->context_.reset(context); + if (impl->context_.get() != context->context_) { + impl->context_.reset(context->context_); + context->owned_by_model_ = true; } auto ret = impl->model_->Build(model_data, data_size, static_cast(model_type), impl->context_); return static_cast(ret.StatusCode()); @@ -210,10 +212,11 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model MS_LOG(ERROR) << "model_type is invalid."; return OH_AI_STATUS_LITE_PARAM_INVALID; } - mindspore::Context *context = static_cast(model_context); + mindspore::ContextC *context = static_cast(model_context); auto impl = static_cast(model); - if (impl->context_.get() != context) { - impl->context_.reset(context); + if (impl->context_.get() != context->context_) { + impl->context_.reset(context->context_); + context->owned_by_model_ = true; } auto ret = impl->model_->Build(model_path, static_cast(model_type), impl->context_); return static_cast(ret.StatusCode()); @@ -447,10 +450,11 @@ OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_da MS_LOG(ERROR) << "load ms file failed."; return OH_AI_STATUS_LITE_ERROR; } - auto context = static_cast(model_context); + auto context = static_cast(model_context); auto build_train_cfg = static_cast(train_cfg); - if (impl->context_.get() != context) { - impl->context_.reset(context); + if (impl->context_.get() != context->context_) { + impl->context_.reset(context->context_); + context->owned_by_model_ = true; } auto ret = impl->model_->Build(static_cast(graph), impl->context_, std::shared_ptr(build_train_cfg)); @@ -479,10 +483,11 @@ OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char * MS_LOG(ERROR) << "load ms file failed. " << model_path; return OH_AI_STATUS_LITE_ERROR; } - auto context = static_cast(model_context); + auto context = static_cast(model_context); auto build_train_cfg = static_cast(train_cfg); - if (impl->context_.get() != context) { - impl->context_.reset(context); + if (impl->context_.get() != context->context_) { + impl->context_.reset(context->context_); + context->owned_by_model_ = true; } auto ret = impl->model_->Build(static_cast(graph), impl->context_, std::shared_ptr(build_train_cfg)); -- 2.17.1