From fcce2a2794417a6ff16148dbb751e402e476084a Mon Sep 17 00:00:00 2001
From: chengfeng27 <chengfeng27@huawei.com>
Date: Tue, 23 Jul 2024 10:46:59 +0800
Subject: [PATCH] fix memory leak

---
 .../core/mindrt/src/thread/core_affinity.cc   |   2 +-
 mindspore/lite/BUILD.gn                       |   5 +-
 mindspore/lite/src/common/mmap_utils.cc       |  14 +-
 mindspore/lite/src/common/mmap_utils.h        |   2 +-
 mindspore/lite/src/litert/cache_session.cc    | 425 ++++++++++++++++++
 mindspore/lite/src/litert/cache_session.h     | 129 ++++++
 .../src/litert/cxx_api/model/model_impl.cc    |  36 +-
 .../delegate/nnrt/extension_options_parser.cc |  12 +
 .../delegate/nnrt/extension_options_parser.h  |   2 +
 mindspore/lite/src/litert/lite_model.cc       |  12 +-
 mindspore/lite/src/litert/lite_model.h        |   2 +-
 mindspore/lite/src/litert/lite_session.h      |   6 +-
 12 files changed, 631 insertions(+), 16 deletions(-)
 create mode 100644 mindspore/lite/src/litert/cache_session.cc
 create mode 100644 mindspore/lite/src/litert/cache_session.h

diff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc
index 6886f743..6d13724f 100644
--- a/mindspore/core/mindrt/src/thread/core_affinity.cc
+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc
@@ -217,7 +217,7 @@ int GetMaxFrequency(int core_id) {
 
 float CoreAffinity::GetServerFrequency() {
   float max_freq = -1.0f;
-#if defined(__APPLE__) || defined(__MACOSX) || defined(_MSC_VER) || defined(_WIN32)
+#if defined(__APPLE__) || defined(__MACOSX) || defined(_MSC_VER) || defined(_WIN32) || defined(MS_COMPILE_OHOS)
   return max_freq;  // MHz
 #else
   // The CPU cores in the server of the numa architecture are the same.
diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
index acee9733..d8ed3b44 100644
--- a/mindspore/lite/BUILD.gn
+++ b/mindspore/lite/BUILD.gn
@@ -438,7 +438,10 @@ ohos_shared_library("mindspore_lib") {
   if (SUPPORT_NNRT) {
     if (mindspore_feature_nnrt_metagraph) {
       defines += [ "SUPPORT_NNRT_METAGRAPH" ]
-      sources += [ "src/litert/delegate/nnrt/hiai_foundation_wrapper.cc", ]
+      sources += [
+        "src/litert/delegate/nnrt/hiai_foundation_wrapper.cc",
+        "src/litert/cache_session.cc",
+      ]
       print("enabled feature: mindspore_feature_nnrt_metagraph")
     }
     sources += [
diff --git a/mindspore/lite/src/common/mmap_utils.cc b/mindspore/lite/src/common/mmap_utils.cc
index ca8f8d1e..0dd31f7c 100644
--- a/mindspore/lite/src/common/mmap_utils.cc
+++ b/mindspore/lite/src/common/mmap_utils.cc
@@ -24,7 +24,7 @@
 
 namespace mindspore {
 namespace lite {
-void *ReadFileByMmap(const std::string &file, size_t *size) {
+void *ReadFileByMmap(const std::string &file, size_t *size, bool populate) {
 #if !defined(_WIN32) && !defined(_WIN64) && !defined(MS_COMPILE_IOS)
   auto real_path = RealPath(file.c_str());
   auto fd = open(real_path.c_str(), O_RDONLY);
@@ -39,7 +39,12 @@ void *ReadFileByMmap(const std::string &file, size_t *size) {
     return nullptr;
   }
   *size = fd_stat.st_size;
-  auto mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0);
+  void *mmap_buffers;
+  if (populate) {
+    mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED | MAP_POPULATE, fd, 0);
+  } else {
+    mmap_buffers = mmap(nullptr, *size, PROT_READ, MAP_SHARED, fd, 0);
+  }
   close(fd);
   if (mmap_buffers == MAP_FAILED) {
     MS_LOG(ERROR) << "Model mmap failed.";
@@ -54,7 +59,10 @@ void *ReadFileByMmap(const std::string &file, size_t *size) {
 
 void UnmapMmapBuffer(void *buffer, size_t size) {
 #if !defined(_WIN32) && !defined(_WIN64)
-  (void)munmap(buffer, size);
+  auto ret = munmap(buffer, size);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "munmap failed ret: " << ret << ", err: " << strerror(errno);
+  }
 #else
   MS_LOG(ERROR) << "Mmap is unsupported on windows.";
 #endif
diff --git a/mindspore/lite/src/common/mmap_utils.h b/mindspore/lite/src/common/mmap_utils.h
index bdd7c9a5..d3b0ec5f 100644
--- a/mindspore/lite/src/common/mmap_utils.h
+++ b/mindspore/lite/src/common/mmap_utils.h
@@ -20,7 +20,7 @@
 
 namespace mindspore {
 namespace lite {
-void *ReadFileByMmap(const std::string &file, size_t *size);
+void *ReadFileByMmap(const std::string &file, size_t *size, bool populate = true);
 void UnmapMmapBuffer(void *buffer, size_t size);
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/src/litert/cache_session.cc b/mindspore/lite/src/litert/cache_session.cc
new file mode 100644
index 00000000..7bafe3f7
--- /dev/null
+++ b/mindspore/lite/src/litert/cache_session.cc
@@ -0,0 +1,425 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cache_session.h"
+#include "src/common/context_util.h"
+#include "src/common/tensor_util.h"
+#include "src/common/mmap_utils.h"
+#include "src/common/file_utils.h"
+#include "src/litert/delegate/nnrt/nnrt_model_kernel.h"
+
+namespace mindspore {
+namespace lite {
+CacheSession::~CacheSession() {
+  if (nn_executor_ != nullptr) {
+    OH_NNExecutor_Destroy(&nn_executor_);
+    MS_LOG(INFO) << "Destroy NNExecutor Finish.";
+  }
+}
+
+int CacheSession::CompileGraph(Model *model) {
+  bool expected = false;
+  if (!is_running_.compare_exchange_strong(expected, true)) {
+    MS_LOG(ERROR) << "Not support multi-threading";
+    return RET_ERROR;
+  }
+  // Convert to abstract base model interface
+  auto ret = ConvertInOutTensors(model);
+  context_->set_schema_version(reinterpret_cast<LiteModel *>(model)->GetSchemaVersion());
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ConvertTensors failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+  InitGraphInputTensors(model);
+  InitGraphOutputTensors(model);
+
+  // create NNRt kernel
+  ret = ScheduleToNNRTKernel();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Schedule NNRt kernel failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  InitGraphInOutTensorsMap(model);
+  ret = PrepareKernels(model);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare kernels failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  ret = InitExecutor();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "InitExecutor failed: " << ret;
+    is_running_.store(false);
+    return ret;
+  }
+
+  MarkSharedWeight(kernels_);
+  FreePackOpWeight(kernels_);
+
+  is_running_.store(false);
+  return RET_OK;
+}
+
+int CacheSession::InitExecutor() {
+  executor_ = new (std::nothrow) Executor();
+  if (executor_ == nullptr) {
+    MS_LOG(ERROR) << "New Executor failed";
+    return RET_ERROR;
+  }
+  auto ret = executor_->Prepare(kernels_, inputs_, outputs_, context_.get());
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Prepare executor failed: " << ret;
+    return ret;
+  }
+  return RET_OK;
+}
+
+int CacheSession::ConvertInOutTensors(const lite::Model *model) {
+  MS_ASSERT(model != nullptr);
+  auto lite_model = reinterpret_cast<const lite::LiteModel *>(model);
+  uint32_t tensor_count = model->graph_.all_tensors_.size();
+  auto model_input_indices = model->graph_.input_indices_;
+  auto model_output_indices = model->graph_.output_indices_;
+
+  for (uint32_t i = 0; i < tensor_count; ++i) {
+    auto *src_tensor = model->graph_.all_tensors_[i];
+    if (!IsContain(model_input_indices, i) && !IsContain(model_output_indices, i)) {
+      this->tensors_.emplace_back(nullptr);
+      continue;
+    }
+    if (src_tensor == nullptr) {
+      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
+      return RET_NULL_PTR;
+    }
+    auto *dst_tensor = ConvertTensor(*src_tensor);
+    if (dst_tensor == nullptr) {
+      MS_LOG(ERROR) << "Convert new " << i << "th tensor failed!";
+      return RET_NULL_PTR;
+    }
+    auto ret = ConvertTensorsData(lite_model, i, dst_tensor);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Convert data of " << i << "th tensor failed";
+      delete dst_tensor;
+      return ret;
+    }
+    ConvertTensorsQuantParam(src_tensor, dst_tensor);
+    if (IsContain(model_input_indices, i)) {
+      dst_tensor->set_category(Category::GRAPH_INPUT);
+    }
+    if (IsContain(model_output_indices, i)) {
+      // a tensor is as both input and output, would be treated as an input.
+      if (!dst_tensor->IsGraphInput()) {
+        dst_tensor->set_category(Category::GRAPH_OUTPUT);
+      }
+    }
+
+    ret = CheckTensorValid(dst_tensor);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Check " << i << "th tensor failed";
+      delete dst_tensor;
+      return ret;
+    }
+
+    this->tensors_.emplace_back(dst_tensor);
+  }
+  return RET_OK;
+}
+
+int CacheSession::Init(const std::shared_ptr<InnerContext> &context) {
+  if (context == nullptr) {
+    MS_LOG(ERROR) << "context is nullptr";
+    return RET_NULL_PTR;
+  }
+  bool expected = false;
+  if (!is_running_.compare_exchange_strong(expected, true)) {
+    MS_LOG(ERROR) << "Not support multi-threading";
+    return RET_ERROR;
+  }
+  context_ = context;
+  auto ret = context_->Init();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Init Context failed";
+    return ret;
+  }
+  ms_context_ = MSContextFromContext(context);
+  if (ms_context_ == nullptr) {
+    MS_LOG(ERROR) << "transfer context to ms context failed.";
+    return RET_NULL_PTR;
+  }
+
+  auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(),
+                           [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
+  if(iter == context_->device_list_.end()) {
+    MS_LOG(ERROR) << "Found non NNRT device info";
+    return RET_ERROR;
+  }
+  nnrt_device_info_ = iter->device_info_.nnrt_device_info_;
+
+  const auto &extensions = nnrt_device_info_.extensions_;
+  mindspore::lite::nnrt::ExtensionOptionsParser::Parse(extensions, &extension_options_);
+
+  is_running_.store(false);
+  return RET_OK;
+}
+
+int CacheSession::ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model) {
+  const void *meta_graph = nullptr;
+  meta_graph = reinterpret_cast<const void *>(schema::GetMetaGraph(model_buf));
+  assert(meta_graph != nullptr);
+
+  auto status = GenerateModelInputOutput<schema::MetaGraph, schema::CNode>(
+    *reinterpret_cast<const schema::MetaGraph *>(meta_graph), model->graph_);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "fail to generate model";
+    return status;
+  }
+  model->buf = const_cast<char *>(model_buf);
+  return RET_OK;
+}
+
+int CacheSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) {
+  size_t model_size;
+  bool use_mmap = IsMmapEnable();
+  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, use_mmap);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "Read model file failed";
+    return RET_ERROR;
+  }
+
+  Model *model = nullptr;
+  if (extension_options_.cache_path_.empty()) {
+    MS_LOG(ERROR) << "cache path is empty";
+    return RET_ERROR;
+  } else {
+    model = ImportInOutFromBuffer(model_buf, model_size, true, model_type, model_path);
+    dynamic_cast<LiteModel *>(model)->PrepareInnerTensors();
+  }
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+
+  if (use_mmap) {
+    reinterpret_cast<lite::LiteModel *>(model)->model_buf_by_mmap_ = true;
+  } else {
+    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
+  }
+  (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
+  auto ret = CompileGraph(model);
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    model->buf = nullptr;
+    delete model;
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
+
+Model *CacheSession::ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf, mindspore::ModelType model_type,
+                                           const std::string &path) {
+  MS_LOG(INFO) << "import model from lite model";
+  auto *model = new (std::nothrow) LiteModel(path);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "new model fail!";
+    return nullptr;
+  }
+
+  auto status = ParseInputOutputFromModelBuffer(model_buf, model);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "construct model failed.";
+    delete model;
+    return nullptr;
+  }
+  model->buf = const_cast<char *>(model_buf);
+  model->buf_size_ = size;
+  return model;
+}
+
+int CacheSession::ScheduleToNNRTKernel() {
+  if (!IsKirinNPUWithOnlineInference(nnrt_device_info_.device_id_)) {
+    MS_LOG(ERROR) << "only support NPU_ device.";
+    return RET_ERROR;
+  }
+  auto ret = CreateFullModelKernel();
+  if (ret != kSuccess) {
+    MS_LOG(ERROR) << "Build npu model failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+bool CacheSession::IsKirinNPUWithOnlineInference(size_t device_id) {
+  const std::string kirin_npu_name_prefix = "NPU_";
+  const char *device_name;
+  auto ret = OH_NNDevice_GetName(device_id, &device_name);
+  if (ret != OH_NN_SUCCESS) {
+    MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret;
+    return false;
+  }
+
+  if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) {
+    MS_LOG(WARNING) << "strncmp: " << device_id << " failed, device_name: " << device_name;
+    return false;
+  }
+  return true;
+}
+
+Status CacheSession::CreateFullModelKernel() {
+  OH_NNCompilation* nn_compilation = OH_NNCompilation_ConstructForCache();
+  if (nn_compilation == nullptr) {
+    MS_LOG(ERROR) << "Construct NNCompilation failed";
+    return kLiteError;
+  }
+  MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
+
+  auto ret_code = InitNNCompilation(nn_compilation);
+  if (ret_code != kSuccess) {
+    MS_LOG(ERROR) << "Init NNCompilation failed";
+    OH_NNCompilation_Destroy(&nn_compilation);
+    return kLiteError;
+  }
+
+  OH_NNExecutor *nn_executor = nullptr;
+  nn_executor = OH_NNExecutor_Construct(nn_compilation);
+  if (nn_executor == nullptr) {
+    MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
+    OH_NNCompilation_Destroy(&nn_compilation);
+    return kLiteError;
+  }
+  OH_NNCompilation_Destroy(&nn_compilation);
+
+  ms_inputs_ = LiteTensorsToMSTensors(inputs_);
+  ms_outputs_ = LiteTensorsToMSTensors(outputs_);
+  auto nnrt_model_kernel = new (std::nothrow) NNRTModelKernel(nn_executor, nnrt_device_info_, ms_inputs_, ms_outputs_);
+  if (nnrt_model_kernel == nullptr) {
+    OH_NNExecutor_Destroy(&nn_executor);
+    MS_LOG(ERROR) << "new NNRTModelKernel failed";
+    return kLiteError;
+  }
+  nn_executor_ = nn_executor;
+
+  std::shared_ptr<kernel::Kernel> shared_kernel(nnrt_model_kernel);
+  auto *kernel_exec = new (std::nothrow) kernel::KernelExec(shared_kernel);
+  if (kernel_exec == nullptr) {
+    MS_LOG(ERROR) << "nnrt kernel exec create failed.";
+    return kLiteError;
+  }
+  auto delegate_type = kNumberTypeFloat32;
+  for (auto &input : nnrt_model_kernel->inputs()) {
+    if (static_cast<TypeId>(input.DataType()) == kNumberTypeFloat16) {
+      delegate_type = kNumberTypeFloat16;
+      break;
+    }
+  }
+  kernel::KernelKey delegate_desc{kernel::kDelegate, delegate_type, NHWC, schema::PrimitiveType_NONE, "", ""};
+  kernel_exec->set_desc(delegate_desc);
+  kernel_exec->set_context(context_.get());
+  kernels_.push_back(kernel_exec);
+
+  return kSuccess;
+}
+
+Status CacheSession::InitNNCompilation(OH_NNCompilation *nn_compilation) const {
+  auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_);
+  if (ret_code != OH_NN_SUCCESS) {
+    MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation,
+                                                 (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_));
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_));
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_);
+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+    MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code;
+    return kLiteError;
+  }
+
+  if (!extension_options_.cache_path_.empty()) {
+    ret_code = OH_NNCompilation_SetCache(nn_compilation, extension_options_.cache_path_.c_str(),
+                                         extension_options_.cache_version_);
+    if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
+      MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code;
+      return kLiteError;
+    }
+  } else {
+    MS_LOG(ERROR) << "NNCompilation must set Cache.";
+    return kLiteError;
+  }
+
+  size_t extension_size = nnrt_device_info_.extensions_.size();
+  for (size_t i = 0; i < extension_size; i++) {
+    auto &src_extensoin = nnrt_device_info_.extensions_[i];
+    ret_code = OH_NNCompilation_AddExtensionConfig(nn_compilation, src_extensoin.name.c_str(),
+                                                   (char *)((void *)src_extensoin.value.data()),
+                                                   src_extensoin.value.size());
+    if (ret_code != OH_NN_SUCCESS) {
+      MS_LOG(ERROR) << "OH_NNCompilation_AddExtensionConfig " << i << ": "<< src_extensoin.name << " failed, ret: "
+                    << ret_code;
+      return kLiteError;
+    }
+  }
+
+  ret_code = OH_NNCompilation_Build(nn_compilation);
+  if (ret_code != OH_NN_SUCCESS) {
+    MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code;
+    return kLiteError;
+  }
+  return kSuccess;
+}
+
+const char *CacheSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap) {
+  size_t buf_size;
+  char *model_buf;
+  if (use_mmap) {
+    model_buf = reinterpret_cast<char *>(lite::ReadFileByMmap(file.c_str(), &buf_size, false));
+  } else {
+    MS_LOG(WARNING) << "Memory may exceed the limit of business demands.";
+    model_buf = lite::ReadFile(file.c_str(), &buf_size);
+  }
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model path is invalid";
+    return model_buf;
+  }
+
+  char *lite_buf = nullptr;
+  auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type);
+  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
+    if (use_mmap) {
+      lite::UnmapMmapBuffer(const_cast<void *>(static_cast<const void *>(model_buf)), buf_size);
+    } else {
+      delete[] model_buf;
+    }
+    model_buf = nullptr;
+    return nullptr;
+  }
+
+  return lite_buf;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/src/litert/cache_session.h b/mindspore/lite/src/litert/cache_session.h
new file mode 100644
index 00000000..f0ae185a
--- /dev/null
+++ b/mindspore/lite/src/litert/cache_session.h
@@ -0,0 +1,129 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
+#define MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
+
+#include "src/litert/lite_session.h"
+#include "src/litert/inner_context.h"
+#include "src/litert/lite_model.h"
+#include "src/litert/delegate/nnrt/extension_options_parser.h"
+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
+
+namespace mindspore {
+namespace lite {
+class CacheSession : public LiteSession {
+ public:
+  CacheSession() = default;
+  ~CacheSession() override;
+  int Init(const std::shared_ptr<InnerContext> &context) override;
+  int CompileGraph(Model *model) override;
+  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) override;
+  static bool IsKirinNPUWithOnlineInference(size_t device_id);
+  const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
+                              bool use_mmap) override;
+  Model* ImportInOutFromBuffer(const char *model_buf, size_t size, bool take_buf,
+                               mindspore::ModelType model_type = mindspore::ModelType::kMindIR_Lite,
+                               const std::string &path = "");
+
+  template <typename T = schema::MetaGraph>
+  bool ConvertInputOutputTensors(const T &meta_graph, LiteGraph &graph_) {
+    if (meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return false;
+    }
+
+    graph_.all_tensors_.resize(meta_graph.allTensors()->size());
+    MS_LOG(INFO) << "convert input/output tensors";
+    for (auto i: graph_.input_indices_) {
+      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
+      if (tensor == nullptr) {
+        MS_LOG(ERROR) << i << " the input tensor in metagraph is nullptr";
+        return false;
+      }
+      MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
+      graph_.all_tensors_[i] = (const_cast<mindspore::schema::Tensor *>(tensor));
+    }
+
+    for (auto i: graph_.output_indices_) {
+      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
+      if (tensor == nullptr) {
+        MS_LOG(ERROR) << i << " the output tensor in metagraph is nullptr";
+      }
+      MS_CHECK_TRUE_RET(tensor->format() >= schema::Format_MIN && tensor->format() <= schema::Format_MAX, false);
+      graph_.all_tensors_[i] = (const_cast<mindspore::schema::Tensor *>(tensor));
+    }
+    return true;
+  }
+
+  template <typename T = schema::MetaGraph, typename U = schema::CNode>
+  int GenerateModelInputOutput(const T &meta_graph, LiteGraph &graph_) {
+    if (meta_graph.name() != nullptr) {
+      graph_.name_ = meta_graph.name()->c_str();
+    }
+    if (meta_graph.version() != nullptr) {
+      graph_.version_ = meta_graph.version()->c_str();
+    }
+
+    if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
+        meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return RET_ERROR;
+    }
+
+    // converterInputOutput
+    auto in_count = meta_graph.inputIndex()->size();
+    for (uint32_t i = 0; i < in_count; ++i) {
+      graph_.input_indices_.push_back(meta_graph.inputIndex()->Get(i));
+    }
+    auto out_count = meta_graph.outputIndex()->size();
+    for (uint32_t i = 0; i < out_count; ++i) {
+      graph_.output_indices_.push_back(meta_graph.outputIndex()->Get(i));
+    }
+
+    if (!ConvertInputOutputTensors<T>(meta_graph, graph_)) {
+      MS_LOG(ERROR) << "convert tensor failed";
+      return RET_ERROR;
+    }
+    return RET_OK;
+  }
+
+  int ParseInputOutputFromModelBuffer(const char *model_buf, LiteModel *model);
+  int BindGLTexture2DMemory(const std::map<std::string, unsigned int> &inputGLTexture,
+                            std::map<std::string, unsigned int> *outputGLTexture) override {
+    return RET_ERROR;
+  }
+
+ protected:
+  int ScheduleToNNRTKernel();
+  Status CreateFullModelKernel();
+  Status InitNNCompilation(OH_NNCompilation *nn_compilation) const;
+  int ConvertInOutTensors(const lite::Model *model);
+  int InitExecutor() override;
+  std::vector<mindspore::MSTensor> ms_inputs_;
+  std::vector<mindspore::MSTensor> ms_outputs_;
+
+ private:
+  NNRtDeviceInfo nnrt_device_info_;
+  OH_NNExecutor *nn_executor_{nullptr};
+  nnrt::ExtensionOptions extension_options_;
+};
+}  // namespace lite
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_SRC_LITERT_CACHE_SESSION_H_
diff --git a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
index 02533dc3..cacbf86e 100644
--- a/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/litert/cxx_api/model/model_impl.cc
@@ -39,6 +39,11 @@
 #include "src/common/config_file.h"
 #include "src/litert/cpu_info.h"
 #include "src/litert/pack_weight_manager.h"
+#ifdef SUPPORT_NNRT_METAGRAPH
+#include "src/litert/cache_session.h"
+#include "src/litert/delegate/nnrt/extension_options_parser.h"
+#endif
+
 namespace mindspore {
 namespace {
 const char *const kExecutionPlan = "execution_plan";
@@ -1006,7 +1011,36 @@ float ModelImpl::GetLearningRate() {
 }
 
 lite::LiteSession *ModelImpl::CreateLiteSession(const std::shared_ptr<lite::InnerContext> &context) {
-  auto session = new (std::nothrow) lite::LiteSession();
+  if (context == nullptr) {
+    MS_LOG(ERROR) << "context is nullptr";
+    return nullptr;
+  }
+  lite::LiteSession *session = nullptr;
+#ifdef SUPPORT_NNRT_METAGRAPH
+  auto iter = std::find_if(context->device_list_.begin(), context->device_list_.end(),
+                           [](lite::DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
+  if(iter != context->device_list_.end()) {
+    const auto &nnrt_device_info = iter->device_info_.nnrt_device_info_;
+    if (lite::CacheSession::IsKirinNPUWithOnlineInference(nnrt_device_info.device_id_)) {
+      const auto &extensions = nnrt_device_info.extensions_;
+      lite::nnrt::ExtensionOptions extension_options;
+      mindspore::lite::nnrt::ExtensionOptionsParser::Parse(extensions, &extension_options);
+      auto has_cache = OH_NNModel_HasCache(extension_options.cache_path_.c_str(), extension_options.model_name.c_str(),
+                                           extension_options.cache_version_);
+      if (has_cache) {
+        session = reinterpret_cast<lite::LiteSession *>(new (std::nothrow) lite::CacheSession());
+        if (session == nullptr) {
+          MS_LOG(ERROR) << "create cache session failed";
+          return nullptr;
+        }
+      }
+    }
+  }
+#endif
+
+  if (session == nullptr) {
+    session = new (std::nothrow) lite::LiteSession();
+  }
   if (session == nullptr) {
     MS_LOG(ERROR) << "create session failed";
     return nullptr;
diff --git a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
index e35cc2a5..a66cd5ea 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
+++ b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.cc
@@ -30,6 +30,7 @@ const std::string kCachePath = "CachePath";
 const std::string kCacheVersion = "CacheVersion";
 const std::string kBandMode = "BandMode";
 const std::string kQuantConfigData = "QuantConfigData";
+const std::string kModelName = "ModelName";
 }  // namespace
 
 int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, ExtensionOptions *param) {
@@ -39,6 +40,7 @@ int ExtensionOptionsParser::Parse(const std::vector<Extension> &extensions, Exte
   DoParseCacheVersion(extensions, &param->cache_version_);
   DoParseBondMode(extensions, &param->band_mode);
   DoParseQuantConfig(extensions, &param->quant_config, &param->quant_config_size, &param->is_optional_quant_setted);
+  DoParseModelName(extensions, &param->model_name);
   return RET_OK;
 }
 
@@ -89,4 +91,14 @@ void ExtensionOptionsParser::DoParseQuantConfig(const std::vector<Extension> &ex
     *quant_setted = true;
   }
 }
+
+void ExtensionOptionsParser::DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name) {
+  MS_CHECK_TRUE_RET_VOID(model_name != nullptr);
+  auto iter_config = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
+    return extension.name == kModelName;
+  });
+  if (iter_config != extensions.end()) {
+    *model_name = std::string(iter_config->value.begin(), iter_config->value.end());
+  }
+}
 }  // mindspore::lite::nnrt
\ No newline at end of file
diff --git a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
index f24682ce..9a030ad6 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
+++ b/mindspore/lite/src/litert/delegate/nnrt/extension_options_parser.h
@@ -29,6 +29,7 @@ struct ExtensionOptions {
   void *quant_config;
   size_t quant_config_size = 0;
   bool is_optional_quant_setted = false;
+  std::string model_name = "";
 };
 
 class ExtensionOptionsParser {
@@ -41,6 +42,7 @@ private:
                             bool *quant_setted);
   static void DoParseCachePath(const std::vector<Extension> &extensions, std::string *cache_path);
   static void DoParseCacheVersion(const std::vector<Extension> &extensions, uint32_t *cache_version);
+  static void DoParseModelName(const std::vector<Extension> &extensions, std::string *model_name);
 };
 
 }  // namespace mindspore::lite::nnrt
diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc
index 006bc02c..5acf5760 100644
--- a/mindspore/lite/src/litert/lite_model.cc
+++ b/mindspore/lite/src/litert/lite_model.cc
@@ -538,14 +538,16 @@ bool LiteModel::PrepareInnerTensors() {
       MS_LOG(ERROR) << "Create SchemaTensorWrapper return nullptr";
       return false;
     }
+    if (graph_.all_tensors_.at(i) != nullptr) {
 #ifdef ENABLE_LITE_HELPER
-    if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir,
-                              infer_helpers)) {
+      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir,
+                                infer_helpers)) {
 #else
-    if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir)) {
+      if (!tensor_wrapper->Init(*(graph_.all_tensors_.at(i)), static_cast<SCHEMA_VERSION>(schema_version_), dir)) {
 #endif
-      delete tensor_wrapper;
-      return false;
+        delete tensor_wrapper;
+        return false;
+      }
     }
     this->inner_all_tensors_[i] = tensor_wrapper;
   }
diff --git a/mindspore/lite/src/litert/lite_model.h b/mindspore/lite/src/litert/lite_model.h
index 647746a2..c0847c1e 100644
--- a/mindspore/lite/src/litert/lite_model.h
+++ b/mindspore/lite/src/litert/lite_model.h
@@ -66,13 +66,13 @@ class MS_API LiteModel : public Model {
 
   static int VersionVerify(flatbuffers::Verifier *verify);
 
- private:
 #ifdef ENABLE_LITE_HELPER
   bool PrepareInnerTensors(mindspore::infer::helper::InferHelpers *infer_helpers = nullptr);
 #else
   bool PrepareInnerTensors();
 #endif
 
+ private:
   bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params);
 
   template <typename T = schema::MetaGraph, typename U = schema::CNode>
diff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h
index 64a5f6d3..487b382a 100644
--- a/mindspore/lite/src/litert/lite_session.h
+++ b/mindspore/lite/src/litert/lite_session.h
@@ -57,10 +57,10 @@ class MS_API LiteSession {
 #else
   int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size);
 #endif
-  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type);
+  virtual int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type);
   mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf, size_t *size,
                                        mindspore::ModelType model_type);
-  const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap);
+  virtual const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size, bool use_mmap);
   virtual int Init(const std::shared_ptr<InnerContext> &context);
   virtual void BindThread(bool if_bind);
   virtual int CompileGraph(Model *model);
@@ -168,10 +168,10 @@ class MS_API LiteSession {
   static void MarkSharedWeight(const std::vector<kernel::KernelExec *> &kernels);
   std::string ParseWeightPath();
   bool IsMmapEnable();
+  virtual int InitExecutor();
 
 private:
   int PreCheck(Model *model);
-  int InitExecutor();
  void ResetInputsShape(const std::vector<std::vector<int>> &dims);
  int ContextInit(const std::shared_ptr<InnerContext> &context);
  int CreateTensorRTDelegate();
-- 
2.17.1