From fcce2a2794417a6ff16148dbb751e402e476084a Mon Sep 17 00:00:00 2001
From: chengfeng27 <chengfeng27@huawei.com>
Date: Tue, 23 Jul 2024 10:46:59 +0800
Subject: [PATCH] fix memory leak

Destroy the NNExecutor held by CacheSession in its destructor, report munmap
failures in UnmapMmapBuffer, make MAP_POPULATE optional in ReadFileByMmap, and
add a cache-based CacheSession that only converts graph input/output tensors
when a compiled NNRt cache is available.

---
 .../core/mindrt/src/thread/core_affinity.cc   |   2 +-
 mindspore/lite/BUILD.gn                       |   5 +-
 mindspore/lite/src/common/mmap_utils.cc       |  14 +-
 mindspore/lite/src/common/mmap_utils.h        |   2 +-
 mindspore/lite/src/litert/cache_session.cc    | 425 ++++++++++++++++++
 mindspore/lite/src/litert/cache_session.h     | 129 ++++++
 .../src/litert/cxx_api/model/model_impl.cc    |  36 +-
 .../delegate/nnrt/extension_options_parser.cc |  12 +
 .../delegate/nnrt/extension_options_parser.h  |   2 +
 mindspore/lite/src/litert/lite_model.cc       |  12 +-
 mindspore/lite/src/litert/lite_model.h        |   2 +-
 mindspore/lite/src/litert/lite_session.h      |   6 +-
 12 files changed, 631 insertions(+), 16 deletions(-)
 create mode 100644 mindspore/lite/src/litert/cache_session.cc
 create mode 100644 mindspore/lite/src/litert/cache_session.h

diff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc
index 6886f743..6d13724f 100644
--- a/mindspore/core/mindrt/src/thread/core_affinity.cc
+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc
@@ -217,7 +217,7 @@ int GetMaxFrequency(int core_id) {
 
 float CoreAffinity::GetServerFrequency() {
   float max_freq = -1.0f;
-#if defined(__APPLE__) || defined(__MACOSX) || defined(_MSC_VER) || defined(_WIN32)
+#if defined(__APPLE__) || defined(__MACOSX) || defined(_MSC_VER) || defined(_WIN32) || defined(MS_COMPILE_OHOS)
   return max_freq;  // MHz
 #else
   // The CPU cores in the server of the numa architecture are the same.
diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
index acee9733..d8ed3b44 100644
--- a/mindspore/lite/BUILD.gn
+++ b/mindspore/lite/BUILD.gn
@@ -438,7 +438,10 @@ ohos_shared_library("mindspore_lib") {
   if (SUPPORT_NNRT) {
     if (mindspore_feature_nnrt_metagraph) {
       defines += [ "SUPPORT_NNRT_METAGRAPH" ]
-      sources += [ "src/litert/delegate/nnrt/hiai_foundation_wrapper.cc", ]
+      sources += [
+        "src/litert/delegate/nnrt/hiai_foundation_wrapper.cc",
+        "src/litert/cache_session.cc",
+      ]
       print("enabled feature: mindspore_feature_nnrt_metagraph")
     }
     sources += [
diff --git a/mindspore/lite/src/common/mmap_utils.cc b/mindspore/lite/src/common/mmap_utils.cc
index ca8f8d1e..0dd31f7c 100644
--- a/mindspore/lite/src/common/mmap_utils.cc
+++ b/mindspore/lite/src/common/mmap_utils.cc
@@ -24,7 +24,7 @@
 
 namespace mindspore {
 namespace lite {
-void *ReadFileByMmap(const std::string &file, size_t *size) {
+void *ReadFileByMmap(const std::string &file, size_t *size, bool populate) {
 #if !defined(_WIN32) && !defined(_WIN64) && !defined(MS_COMPILE_IOS)
   auto real_path = RealPath(file.c_str());
   auto fd = open(real_path.c_str(), O_RDONLY);
@@ -39,7 +39,12 @@ void *ReadFileByMmap(const std::string &file, size_t *size) {
     return nullptr;
   }
   *size = fd_stat.st_size;
-  auto mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0);
+  void *mmap_buffers;
+  if (populate) {
+    mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0);
+  } else {
+    mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED, fd, 0);
+  }
   close(fd);
   if (mmap_buffers == MAP_FAILED) {
     MS_LOG(ERROR) << "Model mmap failed.";
@@ -54,7 +59,10 @@ void *ReadFileByMmap(const std::string &file, size_t *size) {
 
 void UnmapMmapBuffer(void *buffer, size_t size) {
 #if !defined(_WIN32) && !defined(_WIN64)
-  (void)munmap(buffer, size);
+  auto ret = munmap(buffer, size);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "munmap failed ret: " << ret << ", err: " << strerror(errno);
+  }
 #else
   MS_LOG(ERROR) << "Mmap is unsupported on windows.";
 #endif
diff --git a/mindspore/lite/src/common/mmap_utils.h b/mindspore/lite/src/common/mmap_utils.h
index bdd7c9a5..d3b0ec5f 100644
--- a/mindspore/lite/src/common/mmap_utils.h
+++ b/mindspore/lite/src/common/mmap_utils.h
@@ -20,7 +20,7 @@
 
 namespace mindspore {
 namespace lite {
-void *ReadFileByMmap(const std::string &file, size_t *size);
+void *ReadFileByMmap(const std::string &file, size_t *size, bool populate = true);
 void UnmapMmapBuffer(void *buffer, size_t size);
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/litert/cache_session.cc b/mindspore/lite/src/litert/cache_session.cc
new file mode 100644
index 00000000..7bafe3f7
--- /dev/null
+++ b/mindspore/lite/src/litert/cache_session.cc
@@ -0,0 +1,425 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cache_session.h"
+#include "src/common/context_util.h"
+#include "src/common/tensor_util.h"
+#include "src/common/mmap_utils.h"
+#include "src/common/file_utils.h"
+#include "src/litert/delegate/nnrt/nnrt_model_kernel.h"
+
+namespace mindspore {
+namespace lite {
+CacheSession::~CacheSession() {
+  if (nn_executor_ != nullptr) {
+    OH_NNExecutor_Destroy(&nn_executor_);
+    MS_LOG(INFO) << "Destroy NNExecutor Finish.";
+  }
+}
+
+int CacheSession::CompileGraph(Model *model) {
+  bool expected = false;
+  if (!is_running_.compare_exchange_strong(expected, true)) {
+    MS_LOG(ERROR) << "Not support multi-threading";
+    return RET_ERROR;
+  }
+  // Convert to abstract base model interface
+  auto ret = ConvertInOutTensors(model);
+  context_->set_schema_version(reinterpret_cast<LiteModel *>(model)->GetSchemaVersion());
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ConvertTensors failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+  InitGraphInputTensors(model);
+  InitGraphOutputTensors(model);
+
+  // create NNRt kernel
+  ret = ScheduleToNNRTKernel();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Schedule NNRt kernel failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  InitGraphInOutTensorsMap(model);
+  ret = PrepareKernels(model);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  ret = InitExecutor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "InitExecutor failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  MarkSharedWeight(kernels_);
+  FreePackOpWeight(kernels_);
+
+  is_running_.store(false);
+  return RET_OK;
+}
+
+int CacheSession::InitExecutor() {
+  executor_ = new (std::nothrow) Executor();
+  if (executor_ == nullptr) {
+    MS_LOG(ERROR) << "New Executor failed";
+    return RET_ERROR;
+  }
+  auto ret = executor_->Prepare(kernels_, inputs_, outputs_, context_.get());
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare executor failed: " << ret;
+    return ret;
+  }
+  return RET_OK;
+}
+
+int CacheSession::ConvertInOutTensors(const lite::Model *model) {
+  MS_ASSERT(model != nullptr);
+  auto lite_model = reinterpret_cast<const lite::LiteModel *>(model);
+  uint32_t tensor_count = model->graph_.all_tensors_.size();
+  auto model_input_indices = model->graph_.input_indices_;
+  auto model_output_indices = model->graph_.output_indices_;
+
+  for (uint32_t i = 0; i < tensor_count; ++i) {
+    auto *src_tensor = model->graph_.all_tensors_[i];
+    if (!IsContain(model_input_indices, i) && !IsContain(model_output_indices, i)) {
+      this->tensors_.emplace_back(nullptr);
+      continue;
+    }
+    if (src_tensor == nullptr) {
+      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
+      return RET_NULL_PTR;
+    }
+    auto *dst_tensor = ConvertTensor(*src_tensor);
+    if (dst_tensor == nullptr) {
+      MS_LOG(ERROR) << "Convert new " << i << "th tensor failed!";
+      return RET_NULL_PTR;
+    }
+    auto ret = ConvertTensorsData(lite_model, i, dst_tensor);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Convert data of " << i << "th tensor failed";
+      delete dst_tensor;
+      return ret;
+    }
+    ConvertTensorsQuantParam(src_tensor, dst_tensor);
+    if (IsContain(model_input_indices, i)) {
+      dst_tensor->set_category(Category::GRAPH_INPUT);
+    }
+    if (IsContain(model_output_indices, i)) {
+      // a tensor that is both an input and an output is treated as an input.
+      if (!dst_tensor->IsGraphInput()) {
+        dst_tensor->set_category(Category::GRAPH_OUTPUT);
+      }
+    }
+
+    ret = CheckTensorValid(dst_tensor);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Check " << i << "th tensor failed";
+      delete dst_tensor;
+      return ret;
+    }
+
+    this->tensors_.emplace_back(dst_tensor);
+  }
+  return RET_OK;
+}
+
+int CacheSession::Init(const std::shared_ptr<InnerContext> &context) {
+  if (context == nullptr) {
+    MS_LOG(ERROR) << "context is nullptr";
+    return RET_NULL_PTR;
+  }
+  bool expected = false;
+  if (!is_running_.compare_exchange_strong(expected, true)) {
+    MS_LOG(ERROR) << "Not support multi-threading";
+    return RET_ERROR;
+  }
+  context_ = context;
+  auto ret = context_->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init Context failed";
+    return ret;
+  }
+  ms_context_ = MSContextFromContext(context);
+  if (ms_context_ == nullptr) {
+    MS_LOG(ERROR) << "transfer context to ms context failed.";
+    return RET_NULL_PTR;
+  }
+
+  auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(),
+                           [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
+  if (iter == context_->device_list_.end()) {
+    MS_LOG(ERROR) << "Cannot find NNRT device info";
+    return RET_ERROR;
+  }
+  nnrt_device_info_ = iter->device_info_.nnrt_device_info_;
+
+  const auto &extensions = nnrt_device_info_.extensions_;
+  mindspore::lite::nnrt::ExtensionOptionsParser::Parse(extensions, &extension_options_);
+
+  is_running_.store(false);
+  return RET_OK;
+}
+
+int CacheSession::ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model) {
+  const void *meta_graph = nullptr;
+  meta_graph = reinterpret_cast<const void *>(schema::GetMetaGraph(model_buf));
+  assert(meta_graph != nullptr);
+
+  auto status = GenerateModelInputOutput<schema::MetaGraph, schema::CNode>(
+    *reinterpret_cast<const schema::MetaGraph *>(meta_graph), model->graph_);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "fail to generate model";
+    return status;
+  }
+  model->buf = const_cast<char *>(model_buf);
+  return RET_OK;
+}
+
+int CacheSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) {
+  size_t model_size;
+  bool use_mmap = IsMmapEnable();
+  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, use_mmap);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "Read model file failed";
+    return RET_ERROR;
+  }
+
+  Model *model = nullptr;
+  if (extension_options_.cache_path_.empty()) {
+    MS_LOG(ERROR) << "cache path is empty";
+    return RET_ERROR;
+  } else {
+    model = ImportInOutFromBuffer(model_buf, model_size, true, model_type, model_path);
+  }
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+  dynamic_cast<LiteModel *>(model)->PrepareInnerTensors();
+
+  if (use_mmap) {
+    reinterpret_cast<lite::LiteModel *>(model)->model_buf_by_mmap_ = true;
+  } else {
+    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
+  }
+  (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
+  auto ret = CompileGraph(model);
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    model->buf = nullptr;
+    delete model;
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
+
+Model *CacheSession::ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type,
+                                           const std::string &path) {
+  MS_LOG(INFO) << "import model from lite model";
+  auto *model = new (std::nothrow) LiteModel(path);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "new model fail!";
+    return nullptr;
+  }
+
+  auto status = ParseInputOutputFromModelBuffer(model_buf, model);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "construct model failed.";
+    delete model;
+    return nullptr;
+  }
+  model->buf = const_cast<char *>(model_buf);
+  model->buf_size_ = size;
+  return model;
+}
+
+int CacheSession::ScheduleToNNRTKernel() {
+  if (!IsKirinNPUWithOnlineInference(nnrt_device_info_.device_id_)) {
+    MS_LOG(ERROR) << "only support NPU_ device.";
+    return RET_ERROR;
+  }
+  auto ret = CreateFullModelKernel();
+  if (ret != kSuccess) {
+    MS_LOG(ERROR) << "Build npu model failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+bool CacheSession::IsKirinNPUWithOnlineInference(size_t device_id) {
+  const std::string kirin_npu_name_prefix = "NPU_";
+  const char *device_name;
+  auto ret = OH_NNDevice_GetName(device_id, &device_name);
+  if (ret != OH_NN_SUCCESS) {
+    MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret;
+    return false;
+  }
+
+  if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) {
+    MS_LOG(WARNING) << "strncmp: " << device_id << " failed, device_name: " << device_name;
+    return false;
+  }
+  return true;
+}
+
+Status CacheSession::CreateFullModelKernel() {
+  OH_NNCompilation *nn_compilation = OH_NNCompilation_ConstructForCache();
+  if (nn_compilation == nullptr) {
+    MS_LOG(ERROR) << "Construct NNCompilation failed";
+    return kLiteError;
+  }
+  MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
+
+  auto ret_code = InitNNCompilation(nn_compilation);
+  if (ret_code != kSuccess) {
+    MS_LOG(ERROR) << "Init NNCompilation failed";
+    OH_NNCompilation_Destroy(&nn_compilation);
+    return kLiteError;
+  }
+
+  OH_NNExecutor *nn_executor = nullptr;
+  nn_executor = OH_NNExecutor_Construct(nn_compilation);
+  if (nn_executor == nullptr) {
+    MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
+    OH_NNCompilation_Destroy(&nn_compilation);
+    return kLiteError;
+  }
+  OH_NNCompilation_Destroy(&nn_compilation);
+
+  ms_inputs_ = LiteTensorsToMSTensors(inputs_);
+  ms_outputs_ = LiteTensorsToMSTensors(outputs_);
+  auto nnrt_model_kernel = new (std::nothrow) NNRTModelKernel(nn_executor, nnrt_device_info_, ms_inputs_, ms_outputs_);
+  if (nnrt_model_kernel == nullptr) {
+    OH_NNExecutor_Destroy(&nn_executor);
+    MS_LOG(ERROR) << "new NNRTModelKernel failed";
+    return kLiteError;
+  }
+  nn_executor_ = nn_executor;
+
+  std::shared_ptr<kernel::Kernel> shared_kernel(nnrt_model_kernel);
+  auto *kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
+  if (kernel_exec == nullptr) {
+    MS_LOG(ERROR) << "nnrt kernel exec create failed.";
+    return kLiteError;
+  }
+  auto delegate_type = kNumberTypeFloat32;
+  for (auto &input : nnrt_model_kernel->inputs()) {
+    if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
+      delegate_type = kNumberTypeFloat16;
+      break;
+    }
+  }
+  kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
+  kernel_exec->set_desc(delegate_desc);
+  kernel_exec->set_context(context_.get());
+  kernels_.push_back(kernel_exec);
+
+  return kSuccess;
+}
+
+Status CacheSession::InitNNCompilation(OH_NNCompilation *nn_compilation) const {
+  auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_);
+  if (ret_code != OH_NN_SUCCESS) {
+    MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation,
+                                                 (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_));
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_));
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_);
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code;
+    return kLiteError;
+  }
+
+  if (!extension_options_.cache_path_.empty()) {
+    ret_code = OH_NNCompilation_SetCache(nn_compilation, extension_options_.cache_path_.c_str(),
+                                         extension_options_.cache_version_);
+    if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+      MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code;
+      return kLiteError;
+    }
+  } else {
+    MS_LOG(ERROR) << "NNCompilation must set Cache.";
+    return kLiteError;
+  }
+
+  size_t extension_size = nnrt_device_info_.extensions_.size();
+  for (size_t i = 0; i < extension_size; i++) {
+    auto &src_extension = nnrt_device_info_.extensions_[i];
+    ret_code = OH_NNCompilation_AddExtensionConfig(nn_compilation, src_extension.name.c_str(),
+                                                   (char *)((void *)src_extension.value.data()),
+                                                   src_extension.value.size());
+    if (ret_code != OH_NN_SUCCESS) {
+      MS_LOG(ERROR) << "OH_NNCompilation_AddExtensionConfig " << i << ": " << src_extension.name << " failed, ret: "
+                    << ret_code;
+      return kLiteError;
+    }
+  }
+
+  ret_code = OH_NNCompilation_Build(nn_compilation);
+  if (ret_code != OH_NN_SUCCESS) {
+    MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  return kSuccess;
+}
+
+const char *CacheSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap) {
+  size_t buf_size;
+  char *model_buf;
+  if (use_mmap) {
+    model_buf = reinterpret_cast<char *>(lite::ReadFileByMmap(file.c_str(), &buf_size, false));
+  } else {
+    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
+    model_buf = lite::ReadFile(file.c_str(), &buf_size);
+  }
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model path is invalid";
+    return model_buf;
+  }
+
+  char *lite_buf = nullptr;
+  auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
+  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
+    if (use_mmap) {
+      lite::UnmapMmapBuffer(const_cast<void *>(static_cast<const void *>(model_buf)), buf_size);
+    } else {
+      delete[] model_buf;
+    }
+    model_buf = nullptr;
+    return nullptr;
+  }
+
+  return lite_buf;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/litert/cache_session.h b/mindspore/lite/src/litert/cache_session.h
new file mode 100644
index 00000000..f0ae185a
--- /dev/null
+++ b/mindspore/lite/src/litert/cache_session.h
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
+#define MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
+
+#include "src/litert/lite_session.h"
+#include "src/litert/inner_context.h"
+#include "src/litert/lite_model.h"
+#include "src/litert/delegate/nnrt/extension_options_parser.h"
+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
+
+namespace mindspore {
+namespace lite {
+class CacheSession : public LiteSession {
+ public:
+  CacheSession() = default;
+  ~CacheSession() override;
+  int Init(const std::shared_ptr<InnerContext> &context) override;
+  int CompileGraph(Model *model) override;
+  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) override;
+  static bool IsKirinNPUWithOnlineInference(size_t device_id);
+  const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
+                              bool use_mmap) override;
+  Model *ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf,
+                               mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
+                               const std::string &path = "");
+
+  template <typename T = schema::MetaGraph>
+  bool ConvertInputOutputTensors(const T &meta_graph, LiteGraph &graph_) {
+    if (meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return false;
+    }
+
+    graph_.all_tensors_.resize(meta_graph.allTensors()->size());
+    MS_LOG(INFO) << "convert input/output tensors";
+    for (auto i : graph_.input_indices_) {
+      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
+      if (tensor == nullptr) {
+        MS_LOG(ERROR) << i << " the input tensor in metagraph is nullptr";
+        return false;
+      }
+      MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
+      graph_.all_tensors_[i] = (const_cast<mindspore::schema::Tensor *>(tensor));
+    }
+
+    for (auto i : graph_.output_indices_) {
+      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
+      if (tensor == nullptr) {
+        MS_LOG(ERROR) << i << " the output tensor in metagraph is nullptr";
+      }
+      MS_CHECK_TRUE_RET(tensor != nullptr && tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
+      graph_.all_tensors_[i] = (const_cast<mindspore::schema::Tensor *>(tensor));
+    }
+    return true;
+  }
+
+  template <typename T = schema::MetaGraph, typename U = schema::CNode>
+  int GenerateModelInputOutput(const T &meta_graph, LiteGraph &graph_) {
+    if (meta_graph.name() != nullptr) {
+      graph_.name_ = meta_graph.name()->c_str();
+    }
+    if (meta_graph.version() != nullptr) {
+      graph_.version_ = meta_graph.version()->c_str();
+    }
+
+    if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
+        meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return RET_ERROR;
+    }
+
+    // convert input/output indices
+    auto in_count = meta_graph.inputIndex()->size();
+    for (uint32_t i = 0; i < in_count; ++i) {
+      graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i));
+    }
+    auto out_count = meta_graph.outputIndex()->size();
+    for (uint32_t i = 0; i < out_count; ++i) {
+      graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i));
+    }
+
+    if (!ConvertInputOutputTensors<T>(meta_graph, graph_)) {
+      MS_LOG(ERROR) << "convert tensor failed";
+      return RET_ERROR;
+    }
+    return RET_OK;
+  }
+
+  int ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model);
+  int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
+                            std::map<std::string, unsigned int> *outputGLTexture) override {
+    return RET_ERROR;
+  }
+
+ protected:
+  int ScheduleToNNRTKernel();
+  Status CreateFullModelKernel();
+  Status InitNNCompilation(OH_NNCompilation *nn_compilation) const;
+  int ConvertInOutTensors(const lite::Model *model);
+  int InitExecutor() override;
+  std::vector<mindspore::MSTensor> ms_inputs_;
+  std::vector<mindspore::MSTensor> ms_outputs_;
+
+ private:
+  NNRtDeviceInfo nnrt_device_info_;
+  OH_NNExecutor *nn_executor_{nullptr};
+  nnrt::ExtensionOptions extension_options_;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
diff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
index 02533dc3..cacbf86e 100644
--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
@@ -39,6 +39,11 @@
 #include "src/common/config_file.h"
 #include "src/litert/cpu_info.h"
 #include "src/litert/pack_weight_manager.h"
+#ifdef SUPPORT_NNRT_METAGRAPH
+#include "src/litert/cache_session.h"
+#include "src/litert/delegate/nnrt/extension_options_parser.h"
+#endif
+
 namespace mindspore {
 namespace {
 const char *const kExecutionPlan = "execution_plan";
@@ -1006,7 +1011,36 @@ float ModelImpl::GetLearningRate() {
 }
 
 lite::LiteSession *ModelImpl::CreateLiteSession(const std::shared_ptr<lite::InnerContext> &context) {
-  auto session = new (std::nothrow) lite::LiteSession();
+  if (context == nullptr) {
+    MS_LOG(ERROR) << "context is nullptr";
+    return nullptr;
+  }
+  lite::LiteSession *session = nullptr;
+#ifdef SUPPORT_NNRT_METAGRAPH
+  auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
+                           [](lite::DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
+  if (iter != context->device_list_.end()) {
+    const auto &nnrt_device_info = iter->device_info_.nnrt_device_info_;
+    if (lite::CacheSession::IsKirinNPUWithOnlineInference(nnrt_device_info.device_id_)) {
+      const auto &extensions = nnrt_device_info.extensions_;
+      lite::nnrt::ExtensionOptions extension_options;
+      mindspore::lite::nnrt::ExtensionOptionsParser::Parse(extensions, &extension_options);
+      auto has_cache = OH_NNModel_HasCache(extension_options.cache_path_.c_str(), extension_options.model_name.c_str(),
+                                           extension_options.cache_version_);
+      if (has_cache) {
+        session = reinterpret_cast<lite::LiteSession *>(new (std::nothrow) lite::CacheSession());
+        if (session == nullptr) {
+          MS_LOG(ERROR) << "create cache session failed";
+          return nullptr;
+        }
+      }
+    }
+  }
+#endif
+
+  if (session == nullptr) {
+    session = new (std::nothrow) lite::LiteSession();
+  }
   if (session == nullptr) {
     MS_LOG(ERROR) << "create session failed";
     return nullptr;
diff --git a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
index e35cc2a5..a66cd5ea 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
+++ b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
@@ -30,6 +30,7 @@ const std::string kCachePath = "CachePath";
 const std::string kCacheVersion = "CacheVersion";
 const std::string kBandMode = "BandMode";
 const std::string kQuantConfigData = "QuantConfigData";
+const std::string kModelName = "ModelName";
 }  // namespace
 
 int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, ExtensionOptions *param) {
@@ -39,6 +40,7 @@ int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, Exte
   DoParseCacheVersion(extensions, &param->cache_version_);
   DoParseBondMode(extensions, &param->band_mode);
   DoParseQuantConfig(extensions, &param->quant_config, &param->quant_config_size, &param->is_optional_quant_setted);
+  DoParseModelName(extensions, &param->model_name);
   return RET_OK;
 }
 
@@ -89,4 +91,14 @@ void ExtensionOptionsParser::DoParseQuantConfig(const std::vector<Extension> &ex
     *quant_setted = true;
   }
 }
+
+void ExtensionOptionsParser::DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name) {
+  MS_CHECK_TRUE_RET_VOID(model_name != nullptr);
+  auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
+    return extension.name == kModelName;
+  });
+  if (iter_config != extensions.end()) {
+    *model_name = std::string(iter_config->value.begin(), iter_config->value.end());
+  }
+}
 }  // mindspore::lite::nnrt
\ No newline at end of file
diff --git a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
index f24682ce..9a030ad6 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
+++ b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
@@ -29,6 +29,7 @@ struct ExtensionOptions {
   void *quant_config;
   size_t quant_config_size = 0;
   bool is_optional_quant_setted = false;
+  std::string model_name = "";
 };
 
 class ExtensionOptionsParser {
@@ -41,6 +42,7 @@ private:
                             bool *quant_setted);
   static void DoParseCachePath(const std::vector<Extension> &extensions, std::string *cache_path);
   static void DoParseCacheVersion(const std::vector<Extension> &extensions, uint32_t *cache_version);
+  static void DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name);
 };
 
 }  // namespace mindspore::lite::nnrt
diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc
index 006bc02c..5acf5760 100644
--- a/mindspore/lite/src/litert/lite_model.cc
+++ b/mindspore/lite/src/litert/lite_model.cc
@@ -538,14 +538,16 @@ bool LiteModel::PrepareInnerTensors() {
       MS_LOG(ERROR) << "Create SchemaTensorWrapper return nullptr";
       return false;
     }
+    if (graph_.all_tensors_.at(i) != nullptr) {
 #ifdef ENABLE_LITE_HELPER
-    if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir,
-                              infer_helpers)) {
+      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir,
+                                infer_helpers)) {
 #else
-    if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir)) {
+      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir)) {
 #endif
-      delete tensor_wrapper;
-      return false;
+        delete tensor_wrapper;
+        return false;
+      }
     }
     this->inner_all_tensors_[i] = tensor_wrapper;
   }
diff --git a/mindspore/lite/src/litert/lite_model.h b/mindspore/lite/src/litert/lite_model.h
index 647746a2..c0847c1e 100644
--- a/mindspore/lite/src/litert/lite_model.h
+++ b/mindspore/lite/src/litert/lite_model.h
@@ -66,13 +66,13 @@ class MS_API LiteModel : public Model {
 
   static int VersionVerify(flatbuffers::Verifier *verify);
 
- private:
 #ifdef ENABLE_LITE_HELPER
   bool PrepareInnerTensors(mindspore::infer::helper::InferHelpers *infer_helpers = nullptr);
 #else
   bool PrepareInnerTensors();
 #endif
 
+ private:
   bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params);
 
   template <typename T = schema::MetaGraph, typename U = schema::CNode>
diff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h
index 64a5f6d3..487b382a 100644
--- a/mindspore/lite/src/litert/lite_session.h
+++ b/mindspore/lite/src/litert/lite_session.h
@@ -57,10 +57,10 @@ class MS_API LiteSession {
 #else
   int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size);
 #endif
-  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type);
+  virtual int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type);
   mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf, size_t *size,
                                        mindspore::ModelType model_type);
-  const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap);
+  virtual const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap);
   virtual int Init(const std::shared_ptr<InnerContext> &context);
   virtual void BindThread(bool if_bind);
   virtual int CompileGraph(Model *model);
@@ -168,10 +168,10 @@ class MS_API LiteSession {
   static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels);
   std::string ParseWeightPath();
   bool IsMmapEnable();
+  virtual int InitExecutor();
 
  private:
   int PreCheck(Model *model);
-  int InitExecutor();
   void ResetInputsShape(const std::vector<std::vector<int>> &dims);
   int ContextInit(const std::shared_ptr<InnerContext> &context);
   int CreateTensorRTDelegate();
-- 
2.17.1
