1From baf2daaebd70448cddd35f5011642fe585d071b5 Mon Sep 17 00:00:00 2001 2From: chengfeng27 <chengfeng27@huawei.com> 3Date: Tue, 5 Mar 2024 20:00:24 +0800 4Subject: [PATCH] hilog use macro definition api 5 6--- 7 cmake/external_libs/flatbuffers.cmake | 4 +- 8 include/api/context.h | 65 ++ 9 include/c_api/context_c.h | 111 +++ 10 include/c_api/model_c.h | 178 ++++ 11 include/c_api/tensor_c.h | 14 + 12 include/c_api/types_c.h | 57 +- 13 include/sdk_api/context.h | 103 +++ 14 include/sdk_api/tensor.h | 13 + 15 include/sdk_api/types.h | 38 +- 16 .../plugin/device/cpu/kernel/nnacl/BUILD.gn | 3 + 17 .../device/cpu/kernel/nnacl/CMakeLists.txt | 2 +- 18 .../kernel/nnacl/avx/scatter_nd_binary_avx.h | 66 ++ 19 .../nnacl/avx512/scatter_nd_binary_avx512.h | 66 ++ 20 .../cpu/kernel/nnacl/base/scatter_nd_binary.c | 28 + 21 .../cpu/kernel/nnacl/base/scatter_nd_binary.h | 3 + 22 .../nnacl/base/scatter_nd_binary_simd.h.in | 14 + 23 .../kernel/nnacl/custom_is_inf_parameter.h | 26 + 24 .../nnacl/custom_masked_fill_parameter.h | 26 + 25 .../custom_tensor_scatter_max_parameter.h | 26 + 26 .../kernel/nnacl/infer/custom_is_inf_infer.c | 38 + 27 .../kernel/nnacl/infer/custom_is_inf_infer.h | 31 + 28 .../nnacl/infer/custom_masked_fill_infer.c | 37 + 29 .../nnacl/infer/custom_masked_fill_infer.h | 31 + 30 .../infer/custom_tensor_scatter_max_infer.c | 37 + 31 .../infer/custom_tensor_scatter_max_infer.h | 31 + 32 .../nnacl/neon/scatter_nd_binary_neon.h | 65 ++ 33 .../plugin/device/cpu/kernel/nnacl/op_base.h | 4 + 34 .../cpu/kernel/nnacl/scatter_nd_binary_simd.h | 36 + 35 .../kernel/nnacl/sse/scatter_nd_binary_sse.h | 66 ++ 36 mindspore/core/mindrt/BUILD.gn | 9 +- 37 .../mindrt/src/thread/actor_threadpool.cc | 2 +- 38 .../core/mindrt/src/thread/core_affinity.cc | 6 +- 39 .../core/mindrt/src/thread/core_affinity.h | 2 +- 40 .../mindrt/src/thread/parallel_threadpool.cc | 2 +- 41 mindspore/core/mindrt/src/thread/threadlog.h | 28 +- 42 .../core/mindrt/src/thread/threadpool.cc | 7 +- 43 mindspore/lite/BUILD.gn | 82 +- 44 mindspore/lite/CMakeLists.txt | 5 +- 45 mindspore/lite/include/lite_types.h | 1 + 46 mindspore/lite/include/model.h | 4 + 47 .../lite/include/registry/converter_context.h | 4 +- 48 mindspore/lite/mindir/include/mindir.h | 2 + 49 mindspore/lite/mindir/src/mindir.cc | 40 + 50 mindspore/lite/mindir/src/mindir_tensor.cc | 2 +- 51 mindspore/lite/mindir/src/utils.cc | 2 +- 52 mindspore/lite/src/CMakeLists.txt | 6 +- 53 mindspore/lite/src/common/context_util.cc | 14 +- 54 mindspore/lite/src/common/log.cc | 33 +- 55 mindspore/lite/src/common/log.h | 50 +- 56 .../common/ops/populate/custom_populate.cc | 53 ++ 57 mindspore/lite/src/litert/c_api/context_c.cc | 372 +++++++- 58 mindspore/lite/src/litert/c_api/context_c.h | 23 - 59 mindspore/lite/src/litert/c_api/model_c.cc | 724 ++++++++------- 60 mindspore/lite/src/litert/c_api/tensor_c.cc | 78 +- 61 .../lite/src/litert/c_api/type_c_private.h | 40 + 62 mindspore/lite/src/litert/cxx_api/context.cc | 85 ++ 63 .../lite/src/litert/cxx_api/converters.cc | 60 +- 64 .../lite/src/litert/cxx_api/converters.h | 4 +- 65 .../src/litert/delegate/nnrt/CMakeLists.txt | 27 +- 66 .../delegate/nnrt/checker/primitive_check.cc | 2 + 67 .../src/litert/delegate/nnrt/nnrt_delegate.cc | 836 ++++++++++++++---- 68 .../src/litert/delegate/nnrt/nnrt_delegate.h | 74 +- 69 .../litert/delegate/nnrt/nnrt_model_kernel.cc | 3 +- 70 .../litert/delegate/nnrt/nnrt_model_kernel.h | 2 +- 71 .../src/litert/delegate/nnrt/nnrt_stub.cc | 99 +++ 72 mindspore/lite/src/litert/infer_manager.cc | 3 +- 73 
mindspore/lite/src/litert/inner_context.cc | 4 + 74 mindspore/lite/src/litert/inner_context.h | 14 + 75 mindspore/lite/src/litert/kernel/cpu/BUILD.gn | 51 +- 76 .../src/litert/kernel/cpu/base/custom_base.cc | 46 + 77 .../src/litert/kernel/cpu/base/custom_base.h | 43 + 78 .../litert/kernel/cpu/base/custom_is_inf.cc | 61 ++ 79 .../litert/kernel/cpu/base/custom_is_inf.h | 38 + 80 .../kernel/cpu/base/custom_masked_fill.cc | 84 ++ 81 .../kernel/cpu/base/custom_masked_fill.h | 35 + 82 .../kernel/cpu/base/custom_tensor_scatter.cc | 75 ++ 83 .../kernel/cpu/base/custom_tensor_scatter.h | 36 + 84 mindspore/lite/src/litert/lite_model.cc | 29 + 85 mindspore/lite/src/litert/lite_session.cc | 39 +- 86 mindspore/lite/src/litert/lite_session.h | 1 + 87 mindspore/lite/src/litert/scheduler.cc | 17 + 88 mindspore/lite/src/litert/tensor_category.cc | 4 + 89 mindspore/lite/src/litert/tensor_category.h | 1 + 90 mindspore/lite/test/CMakeLists.txt | 15 +- 91 mindspore/lite/test/runtest.sh | 1 + 92 .../test/ut/test_data/third_party_model.cfg | 8 + 93 .../tools/converter/api/converter_api_test.cc | 10 + 94 .../third_party_param_parser_test.cc | 176 ++++ 95 .../lite/tools/benchmark/benchmark_base.cc | 2 +- 96 .../lite/tools/benchmark/benchmark_base.h | 2 +- 97 .../lite/tools/benchmark/benchmark_c_api.cc | 4 + 98 .../tools/benchmark/benchmark_unified_api.cc | 5 + 99 .../lite/tools/benchmark_train/CMakeLists.txt | 3 + 100 mindspore/lite/tools/benchmark_train/main.cc | 3 +- 101 .../lite/tools/benchmark_train/net_runner.cc | 10 +- 102 .../lite/tools/benchmark_train/net_train.cc | 418 +-------- 103 .../lite/tools/benchmark_train/net_train.h | 229 +---- 104 .../tools/benchmark_train/net_train_base.cc | 410 +++++++++ 105 .../tools/benchmark_train/net_train_base.h | 288 ++++++ 106 .../tools/benchmark_train/net_train_c_api.cc | 659 ++++++++++++++ 107 .../tools/benchmark_train/net_train_c_api.h | 121 +++ 108 .../tools/benchmark_train/run_net_train.cc | 86 ++ 109 .../tools/benchmark_train/run_net_train.h | 22 + 110 mindspore/lite/tools/converter/CMakeLists.txt | 4 + 111 .../config_parser/config_file_parser.cc | 27 + 112 .../config_parser/config_file_parser.h | 15 + 113 .../config_parser/third_party_param_parser.cc | 299 +++++++ 114 .../config_parser/third_party_param_parser.h | 44 + 115 mindspore/lite/tools/converter/converter.cc | 34 +- 116 .../tools/converter/converter_funcgraph.cc | 13 +- 117 .../converter_lite/converter_flags.cc | 4 +- 118 .../tools/converter/cxx_api/converter_para.h | 14 + 119 .../tools/converter/graphdef_transform.cc | 44 + 120 .../parser/third_party/CMakeLists.txt | 4 + 121 .../third_party/third_party_model_parser.cc | 277 ++++++ 122 .../third_party/third_party_model_parser.h | 50 ++ 123 .../registry/model_parser_registry.cc | 4 +- 124 117 files changed, 6456 insertions(+), 1432 deletions(-) 125 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 126 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 127 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 128 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 129 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 130 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 131 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 132 create mode 100644 
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 133 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 134 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 135 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 136 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 137 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 138 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 139 create mode 100644 mindspore/lite/src/litert/c_api/type_c_private.h 140 create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 141 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 142 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 143 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 144 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 145 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 146 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 147 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 148 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 149 create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg 150 create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 151 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.cc 152 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.h 153 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.cc 154 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.h 155 create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.cc 156 create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.h 157 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 158 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 159 create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 160 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 161 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 162 163diff --git a/cmake/external_libs/flatbuffers.cmake b/cmake/external_libs/flatbuffers.cmake 164index 2fde4311..87f0425b 100644 165--- a/cmake/external_libs/flatbuffers.cmake 166+++ b/cmake/external_libs/flatbuffers.cmake 167@@ -21,8 +21,8 @@ else() 168 # flatbuffers.lib cimplied by msvc 169 set(CMAKE_STATIC_LIBRARY_PREFIX "") 170 else() 171- set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong") 172- set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong") 173+ set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable") 174+ set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable") 175 endif() 176 177 if(WIN32) 178diff --git a/include/api/context.h b/include/api/context.h 179index 
c9fb11f0..eb704d44 100644
--- a/include/api/context.h
+++ b/include/api/context.h
@@ -39,6 +39,8 @@ enum DeviceType {
   kAscend310,
   kCustomDevice,
   kAllDevice,
+  // ohos-only device range: [60, 80)
+  kNNRt = 60,
   // add new type here
   kInvalidDeviceType = 100,
 };
@@ -598,5 +600,68 @@ void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_
   SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
 }
 std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
+
+struct Extension {
+  std::string name;
+  std::vector<uint8_t> value;
+};
+
+class MS_API NNRTDeviceInfo : public DeviceInfoContext {
+ public:
+  /// \brief Get the type of this DeviceInfoContext.
+  ///
+  /// \return Type of this DeviceInfoContext.
+  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; };
+
+  /// \brief Set device id.
+  ///
+  /// \param[in] device_id The device id.
+  void SetDeviceID(size_t device_id);
+
+  /// \brief Get the device id.
+  ///
+  /// \return The device id.
+  size_t GetDeviceID() const;
+
+  /// \brief Set performance mode.
+  ///
+  /// \param[in] performance_mode The performance mode.
+  void SetPerformanceMode(int performance_mode);
+
+  /// \brief Get performance mode.
+  ///
+  /// \return The performance mode.
+  int GetPerformanceMode() const;
+
+  /// \brief Set priority.
+  ///
+  /// \param[in] priority The priority.
+  void SetPriority(int priority);
+
+  /// \brief Get priority.
+  ///
+  /// \return The priority.
+  int GetPriority() const;
+
+  /// \brief Set whether to enable float16 inference.
+  ///
+  /// \param[in] is_fp16 Enable float16 inference or not.
+  void SetEnableFP16(bool is_fp16);
+
+  /// \brief Get whether float16 inference is enabled.
+  ///
+  /// \return Whether float16 inference is enabled.
+  bool GetEnableFP16() const;
+
+  /// \brief Set extensions.
+  ///
+  /// \param[in] extensions Extension array.
+  void SetExtensions(const std::vector<Extension> &extensions);
+
+  /// \brief Get extensions.
+  ///
+  /// \return Extension array.
+  std::vector<Extension> GetExtensions() const;
+};
 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
diff --git a/include/c_api/context_c.h b/include/c_api/context_c.h
index 53839e80..8951da25 100644
--- a/include/c_api/context_c.h
+++ b/include/c_api/context_c.h
@@ -19,6 +19,7 @@
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
+#include "include/c_api/status_c.h"
 #include "include/c_api/types_c.h"

 #ifdef __cplusplus
@@ -173,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
 /// \return NPU frequency
 OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);

+/// \brief Obtain all device descriptions in NNRT.
+///
+/// \param[out] num Number of NNRT device descriptions.
+///
+/// \return NNRT device description array.
+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
+
+/// \brief Obtain the specified element in the NNRT device description array.
+///
+/// \param[in] descs NNRT device description array.
+/// \param[in] index Element index.
+///
+/// \return NNRT device description.
+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
+
+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs().
+///
+/// \param[in] desc NNRT device description array.
+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
+
+/// \brief Obtain the device id in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device id.
+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Obtain the device name in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device name.
+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Obtain the device type in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device type.
+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Create the NNRT device info by exactly matching the specific device name.
+///
+/// \param[in] name NNRT device name.
+///
+/// \return Device info object handle.
+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
+
+/// \brief Create the NNRT device info by finding the first device with the specific device type.
+///
+/// \param[in] type NNRT device type.
+///
+/// \return Device info object handle.
+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
+
+/// \brief Set the NNRT device id, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] device_id NNRT device id.
+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
+
+/// \brief Obtain the NNRT device id, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+///
+/// \return NNRT device id.
+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
+
+/// \brief Set the NNRT performance mode, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] mode NNRT performance mode.
+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
+
+/// \brief Obtain the NNRT performance mode, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+///
+/// \return NNRT performance mode.
+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
+
+/// \brief Set the NNRT priority, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] priority NNRT priority.
+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
+
+/// \brief Obtain the NNRT priority, Only valid for NNRT.
371+/// 372+/// \param[in] device_info Device info object handle. 373+/// 374+/// \return NNRT priority. 375+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info); 376+ 377+/// \brief Add extension of key/value format to device info, Only valid for NNRT. 378+/// 379+/// \param[in] device_info Device info object handle. 380+/// \param[in] name The content of key as a C string. 381+/// \param[in] value The pointer to the value, which is a byte array. 382+/// \param[in] value_size The size of the value, which is a byte array. 383+/// 384+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed. 385+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size); 386 #ifdef __cplusplus 387 } 388 #endif 389diff --git a/include/c_api/model_c.h b/include/c_api/model_c.h 390index 12a46bcd..2286e673 100644 391--- a/include/c_api/model_c.h 392+++ b/include/c_api/model_c.h 393@@ -26,6 +26,8 @@ extern "C" { 394 395 typedef void *OH_AI_ModelHandle; 396 397+typedef void *OH_AI_TrainCfgHandle; 398+ 399 typedef struct OH_AI_TensorHandleArray { 400 size_t handle_num; 401 OH_AI_TensorHandle *handle_list; 402@@ -168,6 +170,182 @@ OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHa 403 /// \return The output tensor handle with the given name, if the name is not found, an NULL is returned. 404 OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name); 405 406+/// \brief Create a TrainCfg object. Only valid for Lite Train. 407+/// 408+/// \return TrainCfg object handle. 409+OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate(); 410+ 411+/// \brief Destroy the train_cfg object. Only valid for Lite Train. 412+/// 413+/// \param[in] train_cfg TrainCfg object handle. 414+OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg); 415+ 416+/// \brief Obtains part of the name that identify a loss kernel. Only valid for Lite Train. 417+/// 418+/// \param[in] train_cfg TrainCfg object handle. 419+/// \param[in] num The num of loss_name. 420+/// 421+/// \return loss_name. 422+OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num); 423+ 424+/// \brief Set part of the name that identify a loss kernel. Only valid for Lite Train. 425+/// 426+/// \param[in] train_cfg TrainCfg object handle. 427+/// \param[in] loss_name define part of the name that identify a loss kernel. 428+/// \param[in] num The num of loss_name. 429+OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num); 430+ 431+/// \brief Obtains optimization level of the train_cfg. Only valid for Lite Train. 432+/// 433+/// \param[in] train_cfg TrainCfg object handle. 434+/// 435+/// \return OH_AI_OptimizationLevel. 436+OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg); 437+ 438+/// \brief Set optimization level of the train_cfg. Only valid for Lite Train. 439+/// 440+/// \param[in] train_cfg TrainCfg object handle. 441+/// \param[in] level The optimization level of train_cfg. 442+OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level); 443+ 444+/// \brief Build the train model from model buffer so that it can run on a device. Only valid for Lite Train. 445+/// 446+/// \param[in] model Model object handle. 
447+/// \param[in] model_data Define the buffer read from a model file. 448+/// \param[in] data_size Define bytes number of model file buffer. 449+/// \param[in] model_type Define The type of model file. 450+/// \param[in] model_context Define the context used to store options during execution. 451+/// \param[in] train_cfg Define the config used by training. 452+/// 453+/// \return OH_AI_Status. 454+OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 455+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 456+ const OH_AI_TrainCfgHandle train_cfg); 457+ 458+/// \brief Build the train model from model file buffer so that it can run on a device. Only valid for Lite Train. 459+/// 460+/// \param[in] model Model object handle. 461+/// \param[in] model_path Define the model path. 462+/// \param[in] model_type Define The type of model file. 463+/// \param[in] model_context Define the context used to store options during execution. 464+/// \param[in] train_cfg Define the config used by training. 465+/// 466+/// \return OH_AI_Status. 467+OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, 468+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 469+ const OH_AI_TrainCfgHandle train_cfg); 470+ 471+/// \brief Train model by step. Only valid for Lite Train. 472+/// 473+/// \param[in] model Model object handle. 474+/// \param[in] before CallBack before predict. 475+/// \param[in] after CallBack after predict. 476+/// 477+/// \return OH_AI_Status. 478+OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, 479+ const OH_AI_KernelCallBack after); 480+ 481+/// \brief Sets the Learning Rate of the training. Only valid for Lite Train. 482+/// 483+/// \param[in] learning_rate to set. 484+/// 485+/// \return OH_AI_Status of operation. 486+OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate); 487+ 488+/// \brief Obtains the Learning Rate of the optimizer. Only valid for Lite Train. 489+/// 490+/// \return Learning rate. 0.0 if no optimizer was found. 491+OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model); 492+ 493+/// \brief Obtains all weights tensors of the model. Only valid for Lite Train. 494+/// 495+/// \param[in] model Model object handle. 496+/// 497+/// \return The vector that includes all gradient tensors. 498+OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model); 499+ 500+/// \brief update weights tensors of the model. Only valid for Lite Train. 501+/// 502+/// \param[in] new_weights A vector new weights. 503+/// 504+/// \return OH_AI_Status 505+OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights); 506+ 507+/// \brief Get the model running mode. 508+/// 509+/// \param[in] model Model object handle. 510+/// 511+/// \return Is Train Mode or not. 512+OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model); 513+ 514+/// \brief Set the model running mode. Only valid for Lite Train. 515+/// 516+/// \param[in] model Model object handle. 517+/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode. 518+/// 519+/// \return OH_AI_Status. 520+OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train); 521+ 522+/// \brief Setup training with virtual batches. Only valid for Lite Train. 
523+/// 524+/// \param[in] model Model object handle. 525+/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable. 526+/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration. 527+/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration. 528+/// 529+/// \return OH_AI_Status. 530+OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, 531+ float momentum); 532+ 533+/// \brief Export training model from file. Only valid for Lite Train. 534+/// 535+/// \param[in] model The model data. 536+/// \param[in] model_type The model file type. 537+/// \param[in] model_file The exported model file. 538+/// \param[in] quantization_type The quantification type. 539+/// \param[in] export_inference_only Whether to export a reasoning only model. 540+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as 541+/// empty, and export the complete reasoning model. 542+/// \param[in] num The number of output_tensor_name. 543+/// 544+/// \return OH_AI_Status. 545+OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, 546+ OH_AI_QuantizationType quantization_type, bool export_inference_only, 547+ char **output_tensor_name, size_t num); 548+ 549+/// \brief Export training model from buffer. Only valid for Lite Train. 550+/// 551+/// \param[in] model The model data. 552+/// \param[in] model_type The model file type. 553+/// \param[in] model_data The exported model buffer. 554+/// \param[in] data_size The exported model buffer size. 555+/// \param[in] quantization_type The quantification type. 556+/// \param[in] export_inference_only Whether to export a reasoning only model. 557+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as 558+/// empty, and export the complete reasoning model. 559+/// \param[in] num The number of output_tensor_name. 560+/// 561+/// \return OH_AI_Status. 562+OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, 563+ size_t *data_size, OH_AI_QuantizationType quantization_type, 564+ bool export_inference_only, char **output_tensor_name, size_t num); 565+ 566+/// \brief Export model's weights, which can be used in micro only. Only valid for Lite Train. 567+/// 568+/// \param[in] model The model data. 569+/// \param[in] model_type The model file type. 570+/// \param[in] weight_file The path of exported weight file. 571+/// \param[in] is_inference Whether to export weights from a reasoning model. Currently, only support this is `true`. 572+/// \param[in] enable_fp16 Float-weight is whether to be saved in float16 format. 573+/// \param[in] changeable_weights_name The set the name of these weight tensors, whose shape is changeable. 574+/// \param[in] num The number of changeable_weights_name. 575+/// 576+/// \return OH_AI_Status. 
+OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
+                                                               const char *weight_file, bool is_inference,
+                                                               bool enable_fp16, char **changeable_weights_name,
+                                                               size_t num);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/c_api/tensor_c.h b/include/c_api/tensor_c.h
index f18ba163..6d2aaab6 100644
--- a/include/c_api/tensor_c.h
+++ b/include/c_api/tensor_c.h
@@ -17,6 +17,7 @@
 #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H

 #include <stddef.h>
+#include "include/c_api/status_c.h"
 #include "include/c_api/types_c.h"
 #include "include/c_api/data_type_c.h"
 #include "include/c_api/format_c.h"
@@ -112,6 +113,19 @@ OH_AI_API OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor);
 /// \param[in] data A pointer to the data of the tensor.
 OH_AI_API void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data);

+/// \brief Set the data for the tensor with a user-allocated data buffer.
+/// The main purpose of this interface is to provide a way of using memory already allocated by the user as the
+/// Model's input, instead of memory allocated inside the Model object, which saves one copy.
+/// Note: The tensor does not free the data provided by the invoker. The invoker is responsible for freeing it, and
+/// this free action must not be performed before destruction of the tensor.
+///
+/// \param[in] tensor Tensor object handle.
+/// \param[in] data A pointer to the user data buffer.
+/// \param[in] data_size The byte size of the user data buffer.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
+
 /// \brief Obtain the data pointer of the tensor.
 ///
 /// \param[in] tensor Tensor object handle.
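Before the next hunk, a minimal usage sketch of the OH_AI_TensorSetUserData() zero-copy path documented above. This is illustrative only and not part of the patch: FeedUserBuffer is a hypothetical helper, while OH_AI_ModelGetInputs(), OH_AI_TensorGetDataSize() and OH_AI_STATUS_LITE_ERROR come from the existing C API.

#include "include/c_api/model_c.h"
#include "include/c_api/status_c.h"

/* Sketch: hand a caller-owned buffer to the first model input without a copy.
 * The tensor does not take ownership, so `buffer` must stay alive until the
 * tensor is destroyed (see the note in tensor_c.h above). */
OH_AI_Status FeedUserBuffer(OH_AI_ModelHandle model, void *buffer, size_t buffer_size) {
  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
  if (inputs.handle_num == 0 || inputs.handle_list == NULL) {
    return OH_AI_STATUS_LITE_ERROR;
  }
  OH_AI_TensorHandle input = inputs.handle_list[0];
  if (OH_AI_TensorGetDataSize(input) != buffer_size) {
    return OH_AI_STATUS_LITE_ERROR; /* size mismatch: refuse rather than overrun */
  }
  return OH_AI_TensorSetUserData(input, buffer, buffer_size);
}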
diff --git a/include/c_api/types_c.h b/include/c_api/types_c.h
index dba54ffa..e520e336 100644
--- a/include/c_api/types_c.h
+++ b/include/c_api/types_c.h
@@ -40,10 +40,65 @@ typedef enum OH_AI_DeviceType {
   OH_AI_DEVICETYPE_KIRIN_NPU,
   // add new type here
   // ohos-only device range: [60, 80)
-  OH_AI_DEVICETYPE__NNRT = 60,
+  OH_AI_DEVICETYPE_NNRT = 60,
   OH_AI_DEVICETYPE_INVALID = 100,
 } OH_AI_DeviceType;

+typedef enum OH_AI_NNRTDeviceType {
+  /** Devices that are not CPU, GPU, or dedicated accelerator */
+  OH_AI_NNRTDEVICE_OTHERS = 0,
+  /** CPU device */
+  OH_AI_NNRTDEVICE_CPU = 1,
+  /** GPU device */
+  OH_AI_NNRTDEVICE_GPU = 2,
+  /** Dedicated hardware accelerator */
+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
+} OH_AI_NNRTDeviceType;
+
+typedef enum OH_AI_PerformanceMode {
+  /** No performance mode preference */
+  OH_AI_PERFORMANCE_NONE = 0,
+  /** Low power consumption mode */
+  OH_AI_PERFORMANCE_LOW = 1,
+  /** Medium performance mode */
+  OH_AI_PERFORMANCE_MEDIUM = 2,
+  /** High performance mode */
+  OH_AI_PERFORMANCE_HIGH = 3,
+  /** Ultimate performance mode */
+  OH_AI_PERFORMANCE_EXTREME = 4
+} OH_AI_PerformanceMode;
+
+typedef enum OH_AI_Priority {
+  /** No priority preference */
+  OH_AI_PRIORITY_NONE = 0,
+  /** Low priority */
+  OH_AI_PRIORITY_LOW = 1,
+  /** Medium priority */
+  OH_AI_PRIORITY_MEDIUM = 2,
+  /** High priority */
+  OH_AI_PRIORITY_HIGH = 3
+} OH_AI_Priority;
+
+typedef enum OH_AI_OptimizationLevel {
+  /** Do not change */
+  OH_AI_KO0 = 0,
+  /** Cast network to float16, keep batchnorm and loss in float32 */
+  OH_AI_KO2 = 2,
+  /** Cast network to float16, including batchnorm */
+  OH_AI_KO3 = 3,
+  /** Choose optimization based on device */
+  OH_AI_KAUTO = 4,
+  OH_AI_KOPTIMIZATIONTYPE = 0xFFFFFFFF
+} OH_AI_OptimizationLevel;
+
+typedef enum OH_AI_QuantizationType {
+  OH_AI_NO_QUANT = 0,
+  OH_AI_WEIGHT_QUANT = 1,
+  OH_AI_FULL_QUANT = 2,
+  OH_AI_UNKNOWN_QUANT_TYPE = 0xFFFFFFFF
+} OH_AI_QuantizationType;
+
+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/sdk_api/context.h b/include/sdk_api/context.h
index 5bfc9279..e12b8d6f 100644
--- a/include/sdk_api/context.h
+++ b/include/sdk_api/context.h
@@ -174,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
 /// \return NPU frequency
 OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);

+/// \brief Obtain all device descriptions in NNRT.
+///
+/// \param[out] num Number of NNRT device descriptions.
+///
+/// \return NNRT device description array.
+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
+
+/// \brief Obtain the specified element in the NNRT device description array.
+///
+/// \param[in] descs NNRT device description array.
+/// \param[in] index Element index.
+///
+/// \return NNRT device description.
+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
+
+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs().
+///
+/// \param[in] desc NNRT device description array.
+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
+
+/// \brief Obtain the device id in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device id.
+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Obtain the device name in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device name.
+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Obtain the device type in NNRT device description.
+///
+/// \param[in] desc pointer to the NNRT device description instance.
+///
+/// \return NNRT device type.
+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
+
+/// \brief Create the NNRT device info by exactly matching the specific device name.
+///
+/// \param[in] name NNRT device name.
+///
+/// \return Device info object handle.
+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
+
+/// \brief Create the NNRT device info by finding the first device with the specific device type.
+///
+/// \param[in] type NNRT device type.
+///
+/// \return Device info object handle.
+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
+
+/// \brief Set the NNRT device id, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] device_id NNRT device id.
+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
+
+/// \brief Obtain the NNRT device id, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+///
+/// \return NNRT device id.
+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
+
+/// \brief Set the NNRT performance mode, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] mode NNRT performance mode.
+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
+
+/// \brief Obtain the NNRT performance mode, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+///
+/// \return NNRT performance mode.
+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
+
+/// \brief Set the NNRT priority, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] priority NNRT priority.
+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
+
+/// \brief Obtain the NNRT priority, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+///
+/// \return NNRT priority.
+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info);
+
+/// \brief Add extension of key/value format to device info, Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] name The content of key as a C string.
+/// \param[in] value The pointer to the value, which is a byte array.
+/// \param[in] value_size The size of the value, which is a byte array.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size);
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/sdk_api/tensor.h b/include/sdk_api/tensor.h
index f6ba02cd..3dad04ac 100644
--- a/include/sdk_api/tensor.h
+++ b/include/sdk_api/tensor.h
@@ -17,6 +17,7 @@
 #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H

 #include <stddef.h>
+#include "mindspore/status.h"
 #include "mindspore/types.h"
 #include "mindspore/data_type.h"
 #include "mindspore/format.h"
@@ -140,6 +141,18 @@ OH_AI_API int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor);
 /// \return The data size of the tensor.
 OH_AI_API size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor);

+/// \brief Set the data for the tensor with a user-allocated data buffer.
+/// The main purpose of this interface is to provide a way of using memory already allocated by the user as the
+/// Model's input, instead of memory allocated inside the Model object, which saves one copy.
+/// Note: The tensor does not free the data provided by the invoker. The invoker is responsible for freeing it, and
+/// this free action must not be performed before destruction of the tensor.
+///
+/// \param[in] tensor Tensor object handle.
+/// \param[in] data A pointer to the user data buffer.
+/// \param[in] data_size The byte size of the user data buffer.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
 #ifdef __cplusplus
 }
 #endif
diff --git a/include/sdk_api/types.h b/include/sdk_api/types.h
index a39c6daa..d38660b0 100644
--- a/include/sdk_api/types.h
+++ b/include/sdk_api/types.h
@@ -40,10 +40,46 @@ typedef enum OH_AI_DeviceType {
   OH_AI_DEVICETYPE_KIRIN_NPU,
   // add new type here
   // ohos-only device range: [60, 80)
-  OH_AI_DeviceType_NNRT = 60,
+  OH_AI_DEVICETYPE_NNRT = 60,
   OH_AI_DEVICETYPE_INVALID = 100,
 } OH_AI_DeviceType;

+typedef enum OH_AI_NNRTDeviceType {
+  /** Devices that are not CPU, GPU, or dedicated accelerator */
+  OH_AI_NNRTDEVICE_OTHERS = 0,
+  /** CPU device */
+  OH_AI_NNRTDEVICE_CPU = 1,
+  /** GPU device */
+  OH_AI_NNRTDEVICE_GPU = 2,
+  /** Dedicated hardware accelerator */
+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
+} OH_AI_NNRTDeviceType;
+
+typedef enum OH_AI_PerformanceMode {
+  /** No performance mode preference */
+  OH_AI_PERFORMANCE_NONE = 0,
+  /** Low power consumption mode */
+  OH_AI_PERFORMANCE_LOW = 1,
+  /** Medium performance mode */
+  OH_AI_PERFORMANCE_MEDIUM = 2,
+  /** High performance mode */
+  OH_AI_PERFORMANCE_HIGH = 3,
+  /** Ultimate performance mode */
+  OH_AI_PERFORMANCE_EXTREME = 4
+} OH_AI_PerformanceMode;
+
+typedef enum OH_AI_Priority {
+  /** No priority preference */
+  OH_AI_PRIORITY_NONE = 0,
+  /** Low priority */
+  OH_AI_PRIORITY_LOW = 1,
+  /** Medium priority */
+  OH_AI_PRIORITY_MEDIUM = 2,
+  /** High priority */
+  OH_AI_PRIORITY_HIGH = 3
+} OH_AI_Priority;
+
+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
index 7bbc3782..103e53b7 100644
---
a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 888+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 889@@ -498,6 +498,9 @@ infer_shape_sources = [ 890 "infer/crop_infer.c", 891 "infer/cumsum_infer.c", 892 "infer/custom_gru_infer.c", 893+ "infer/custom_masked_fill_infer.c", 894+ "infer/custom_is_inf_infer.c", 895+ "infer/custom_tensor_scatter_max_infer.c", 896 "infer/decoder_layer_infer.c", 897 "infer/deconv2d_infer.c", 898 "infer/depth_to_space_infer.c", 899diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 900index c1685a65..6fef44fd 100644 901--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 902+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 903@@ -238,7 +238,7 @@ endif() 904 if(PLATFORM_ARM) 905 set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c) 906 set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C 907- COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") 908+ COMPILE_FLAGS "${CMAKE_C_FLAGS} -w -fno-fast-math") 909 endif() 910 911 add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) 912diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 913new file mode 100644 914index 00000000..14bd1d76 915--- /dev/null 916+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 917@@ -0,0 +1,66 @@ 918+/** 919+* Copyright 2023 Huawei Technologies Co., Ltd 920+* 921+* Licensed under the Apache License, Version 2.0 (the "License"); 922+* you may not use this file except in compliance with the License. 923+* You may obtain a copy of the License at 924+* 925+* http://www.apache.org/licenses/LICENSE-2.0 926+* 927+* Unless required by applicable law or agreed to in writing, software 928+* distributed under the License is distributed on an "AS IS" BASIS, 929+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 930+* See the License for the specific language governing permissions and 931+* limitations under the License. 
932+*/ 933+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 934+#define NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 935+ 936+#include "nnacl/intrinsics/ms_simd_instructions.h" 937+#include "nnacl/intrinsics/ms_simd_avx_instructions.h" 938+ 939+#ifdef __cplusplus 940+extern "C" { 941+#endif 942+#pragma GCC push_options 943+#pragma GCC target("avx", "avx2") 944+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX_INSTRUCTION 945+#define BLOCK_NUM 8 946+#define MS_SIMD_AVX 947+ 948+static inline int ScatterNDAddFp32AVX(int index, const float *update, int size, float *output) { 949+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 950+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 951+ } 952+ return index; 953+} 954+ 955+static inline int ScatterNDAddInt32AVX(int index, const int *update, int size, int *output) { 956+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 957+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 958+ } 959+ return index; 960+} 961+ 962+static inline int ScatterNDMaxFp32AVX(int index, const float *update, int size, float *output) { 963+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 964+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 965+ } 966+ return index; 967+} 968+ 969+static inline int ScatterNDMaxInt32AVX(int index, const int *update, int size, int *output) { 970+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 971+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 972+ } 973+ return index; 974+} 975+ 976+#undef MS_SIMD_INSTRUCTION 977+#undef BLOCK_NUM 978+#pragma GCC pop_options 979+#undef MS_SIMD_AVX 980+#ifdef __cplusplus 981+} 982+#endif 983+#endif // NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 984diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 985new file mode 100644 986index 00000000..abf024c5 987--- /dev/null 988+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 989@@ -0,0 +1,66 @@ 990+/** 991+* Copyright 2023 Huawei Technologies Co., Ltd 992+* 993+* Licensed under the Apache License, Version 2.0 (the "License"); 994+* you may not use this file except in compliance with the License. 995+* You may obtain a copy of the License at 996+* 997+* http://www.apache.org/licenses/LICENSE-2.0 998+* 999+* Unless required by applicable law or agreed to in writing, software 1000+* distributed under the License is distributed on an "AS IS" BASIS, 1001+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1002+* See the License for the specific language governing permissions and 1003+* limitations under the License. 
1004+*/ 1005+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1006+#define NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1007+ 1008+#include "nnacl/intrinsics/ms_simd_instructions.h" 1009+#include "nnacl/intrinsics/ms_simd_avx512_instructions.h" 1010+ 1011+#ifdef __cplusplus 1012+extern "C" { 1013+#endif 1014+#pragma GCC push_options 1015+#pragma GCC target("avx512f") 1016+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX512_INSTRUCTION 1017+#define BLOCK_NUM 16 1018+#define MS_SIMD_AVX512 1019+ 1020+static inline int ScatterNDAddFp32AVX512(int index, const float *update, int size, float *output) { 1021+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1022+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1023+ } 1024+ return index; 1025+} 1026+ 1027+static inline int ScatterNDAddInt32AVX512(int index, const int *update, int size, int *output) { 1028+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1029+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1030+ } 1031+ return index; 1032+} 1033+ 1034+static inline int ScatterNDMaxFp32AVX512(int index, const float *update, int size, float *output) { 1035+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1036+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1037+ } 1038+ return index; 1039+} 1040+ 1041+static inline int ScatterNDMaxInt32AVX512(int index, const int *update, int size, int *output) { 1042+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1043+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1044+ } 1045+ return index; 1046+} 1047+ 1048+#undef MS_SIMD_INSTRUCTION 1049+#undef BLOCK_NUM 1050+#pragma GCC pop_options 1051+#undef MS_SIMD_AVX512 1052+#ifdef __cplusplus 1053+} 1054+#endif 1055+#endif // NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1056diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1057index bca71f55..e496bb4b 100644 1058--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1059+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1060@@ -77,3 +77,31 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, 1061 } 1062 return NNACL_OK; 1063 } 1064+ 1065+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1066+ int task_id) { 1067+ if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { 1068+ return NNACL_NULL_PTR; 1069+ } 1070+ if (param->op_parameter.thread_num_ == 0) { 1071+ return NNACL_ERR; 1072+ } 1073+ int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); 1074+ int begin = unit_per_thread * task_id; 1075+ int end = MSMIN(begin + unit_per_thread, param->num_unit); 1076+ if (type == 0) { 1077+ float *update_fp32 = (float *)update; 1078+ float *output_fp32 = (float *)output; 1079+ for (int i = begin; i < end; i++) { 1080+ const float *update_data = update_fp32 + i * param->unit_size; 1081+ float *output_data = output_fp32 + output_unit_offsets[i]; 1082+ int j = 0; 1083+ 1084+ SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, 
param->unit_size, output_data); 1085+ for (; j < param->unit_size; j++) { 1086+ output_data[j] = fmaxf(update_data[j], output_data[j]); 1087+ } 1088+ } 1089+ } 1090+ return NNACL_OK; 1091+} 1092diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1093index 3af55335..36657cd9 100644 1094--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1095+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1096@@ -27,6 +27,9 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, 1097 1098 int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1099 int task_id); 1100+ 1101+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1102+ int task_id); 1103 #ifdef __cplusplus 1104 } 1105 #endif 1106diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1107index c72d9cc2..46bb20ce 100644 1108--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1109+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1110@@ -38,6 +38,20 @@ static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *upda 1111 return index; 1112 } 1113 1114+static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { 1115+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1116+SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1117+} 1118+return index; 1119+} 1120+ 1121+static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { 1122+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1123+SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1124+} 1125+return index; 1126+} 1127+ 1128 @SIMD_INSTRUCTION_END@ 1129 #ifdef __cplusplus 1130 } 1131diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 1132new file mode 100644 1133index 00000000..e1eae394 1134--- /dev/null 1135+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 1136@@ -0,0 +1,26 @@ 1137+/** 1138+ * Copyright 2023 Huawei Technologies Co., Ltd 1139+ * 1140+ * Licensed under the Apache License, Version 2.0 (the "License"); 1141+ * you may not use this file except in compliance with the License. 1142+ * You may obtain a copy of the License at 1143+ * 1144+ * http://www.apache.org/licenses/LICENSE-2.0 1145+ * 1146+ * Unless required by applicable law or agreed to in writing, software 1147+ * distributed under the License is distributed on an "AS IS" BASIS, 1148+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1149+ * See the License for the specific language governing permissions and 1150+ * limitations under the License. 
1151+ */ 1152+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1153+#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1154+ 1155+#include "nnacl/op_base.h" 1156+ 1157+typedef struct CustomIsInfParameter { 1158+ // Primitive parameter 1159+ OpParameter op_parameter_; 1160+} CustomIsInfParameter; 1161+ 1162+#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1163diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 1164new file mode 100644 1165index 00000000..047d3d3f 1166--- /dev/null 1167+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 1168@@ -0,0 +1,26 @@ 1169+/** 1170+ * Copyright 2023 Huawei Technologies Co., Ltd 1171+ * 1172+ * Licensed under the Apache License, Version 2.0 (the "License"); 1173+ * you may not use this file except in compliance with the License. 1174+ * You may obtain a copy of the License at 1175+ * 1176+ * http://www.apache.org/licenses/LICENSE-2.0 1177+ * 1178+ * Unless required by applicable law or agreed to in writing, software 1179+ * distributed under the License is distributed on an "AS IS" BASIS, 1180+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1181+ * See the License for the specific language governing permissions and 1182+ * limitations under the License. 1183+ */ 1184+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1185+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1186+ 1187+#include "nnacl/op_base.h" 1188+ 1189+typedef struct CustomMaskedFillParameter { 1190+ // Primitive parameter 1191+ OpParameter op_parameter_; 1192+} CustomMaskedFillParameter; 1193+ 1194+#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1195diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 1196new file mode 100644 1197index 00000000..ba6940db 1198--- /dev/null 1199+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 1200@@ -0,0 +1,26 @@ 1201+/** 1202+ * Copyright 2023 Huawei Technologies Co., Ltd 1203+ * 1204+ * Licensed under the Apache License, Version 2.0 (the "License"); 1205+ * you may not use this file except in compliance with the License. 1206+ * You may obtain a copy of the License at 1207+ * 1208+ * http://www.apache.org/licenses/LICENSE-2.0 1209+ * 1210+ * Unless required by applicable law or agreed to in writing, software 1211+ * distributed under the License is distributed on an "AS IS" BASIS, 1212+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1213+ * See the License for the specific language governing permissions and 1214+ * limitations under the License. 
1215+ */ 1216+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1217+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1218+ 1219+#include "nnacl/op_base.h" 1220+ 1221+typedef struct CustomTensorScatterMaxParameter { 1222+ // Primitive parameter 1223+ OpParameter op_parameter_; 1224+} CustomTensorScatterMaxParameter; 1225+ 1226+#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1227diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 1228new file mode 100644 1229index 00000000..fc87d157 1230--- /dev/null 1231+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 1232@@ -0,0 +1,38 @@ 1233+/** 1234+ * Copyright 2023 Huawei Technologies Co., Ltd 1235+ * 1236+ * Licensed under the Apache License, Version 2.0 (the "License"); 1237+ * you may not use this file except in compliance with the License. 1238+ * You may obtain a copy of the License at 1239+ * 1240+ * http://www.apache.org/licenses/LICENSE-2.0 1241+ * 1242+ * Unless required by applicable law or agreed to in writing, software 1243+ * distributed under the License is distributed on an "AS IS" BASIS, 1244+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1245+ * See the License for the specific language governing permissions and 1246+ * limitations under the License. 1247+ */ 1248+ 1249+#include "nnacl/infer/custom_is_inf_infer.h" 1250+#include "nnacl/infer/infer_register.h" 1251+ 1252+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1253+ OpParameter *parameter) { 1254+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM); 1255+ if (check_ret != NNACL_OK) { 1256+ return check_ret; 1257+ } 1258+ 1259+ const TensorC *input = inputs[0]; 1260+ TensorC *output = outputs[0]; 1261+ output->data_type_ = kNumberTypeBool; 1262+ output->format_ = input->format_; 1263+ if (!InferFlag(inputs, inputs_size)) { 1264+ return NNACL_INFER_INVALID; 1265+ } 1266+ SetShapeTensor(output, input); 1267+ return NNACL_OK; 1268+} 1269+ 1270+REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape) 1271diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 1272new file mode 100644 1273index 00000000..d1b4b33d 1274--- /dev/null 1275+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 1276@@ -0,0 +1,31 @@ 1277+/** 1278+ * Copyright 2023 Huawei Technologies Co., Ltd 1279+ * 1280+ * Licensed under the Apache License, Version 2.0 (the "License"); 1281+ * you may not use this file except in compliance with the License. 1282+ * You may obtain a copy of the License at 1283+ * 1284+ * http://www.apache.org/licenses/LICENSE-2.0 1285+ * 1286+ * Unless required by applicable law or agreed to in writing, software 1287+ * distributed under the License is distributed on an "AS IS" BASIS, 1288+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1289+ * See the License for the specific language governing permissions and 1290+ * limitations under the License. 
1291+ */ 1292+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1293+#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1294+ 1295+#include "nnacl/infer/common_infer.h" 1296+ 1297+#ifdef __cplusplus 1298+extern "C" { 1299+#endif 1300+ 1301+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1302+ OpParameter *parameter); 1303+ 1304+#ifdef __cplusplus 1305+} 1306+#endif 1307+#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1308diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 1309new file mode 100644 1310index 00000000..957a4d4f 1311--- /dev/null 1312+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 1313@@ -0,0 +1,37 @@ 1314+/** 1315+ * Copyright 2023 Huawei Technologies Co., Ltd 1316+ * 1317+ * Licensed under the Apache License, Version 2.0 (the "License"); 1318+ * you may not use this file except in compliance with the License. 1319+ * You may obtain a copy of the License at 1320+ * 1321+ * http://www.apache.org/licenses/LICENSE-2.0 1322+ * 1323+ * Unless required by applicable law or agreed to in writing, software 1324+ * distributed under the License is distributed on an "AS IS" BASIS, 1325+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1326+ * See the License for the specific language governing permissions and 1327+ * limitations under the License. 1328+ */ 1329+ 1330+#include "nnacl/infer/custom_masked_fill_infer.h" 1331+#include "nnacl/infer/infer_register.h" 1332+ 1333+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1334+ OpParameter *parameter) { 1335+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 1336+ if (check_ret != NNACL_OK) { 1337+ return check_ret; 1338+ } 1339+ 1340+ const TensorC *input = inputs[0]; 1341+ TensorC *output = outputs[0]; 1342+ SetDataTypeFormat(output, input); 1343+ if (!InferFlag(inputs, inputs_size)) { 1344+ return NNACL_INFER_INVALID; 1345+ } 1346+ SetShapeTensor(output, input); 1347+ return NNACL_OK; 1348+} 1349+ 1350+REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape) 1351diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 1352new file mode 100644 1353index 00000000..a8adbae2 1354--- /dev/null 1355+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 1356@@ -0,0 +1,31 @@ 1357+/** 1358+ * Copyright 2023 Huawei Technologies Co., Ltd 1359+ * 1360+ * Licensed under the Apache License, Version 2.0 (the "License"); 1361+ * you may not use this file except in compliance with the License. 1362+ * You may obtain a copy of the License at 1363+ * 1364+ * http://www.apache.org/licenses/LICENSE-2.0 1365+ * 1366+ * Unless required by applicable law or agreed to in writing, software 1367+ * distributed under the License is distributed on an "AS IS" BASIS, 1368+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1369+ * See the License for the specific language governing permissions and 1370+ * limitations under the License. 
1371+ */ 1372+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1373+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1374+ 1375+#include "nnacl/infer/common_infer.h" 1376+ 1377+#ifdef __cplusplus 1378+extern "C" { 1379+#endif 1380+ 1381+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1382+ OpParameter *parameter); 1383+ 1384+#ifdef __cplusplus 1385+} 1386+#endif 1387+#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1388diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 1389new file mode 100644 1390index 00000000..be6716ba 1391--- /dev/null 1392+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 1393@@ -0,0 +1,37 @@ 1394+/** 1395+ * Copyright 2023 Huawei Technologies Co., Ltd 1396+ * 1397+ * Licensed under the Apache License, Version 2.0 (the "License"); 1398+ * you may not use this file except in compliance with the License. 1399+ * You may obtain a copy of the License at 1400+ * 1401+ * http://www.apache.org/licenses/LICENSE-2.0 1402+ * 1403+ * Unless required by applicable law or agreed to in writing, software 1404+ * distributed under the License is distributed on an "AS IS" BASIS, 1405+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1406+ * See the License for the specific language governing permissions and 1407+ * limitations under the License. 1408+ */ 1409+ 1410+#include "nnacl/infer/custom_tensor_scatter_max_infer.h" 1411+#include "nnacl/infer/infer_register.h" 1412+ 1413+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 1414+ size_t outputs_size, OpParameter *parameter) { 1415+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 1416+ if (check_ret != NNACL_OK) { 1417+ return check_ret; 1418+ } 1419+ 1420+ const TensorC *input = inputs[0]; 1421+ TensorC *output = outputs[0]; 1422+ SetDataTypeFormat(output, input); 1423+ if (!InferFlag(inputs, inputs_size)) { 1424+ return NNACL_INFER_INVALID; 1425+ } 1426+ SetShapeTensor(output, input); 1427+ return NNACL_OK; 1428+} 1429+ 1430+REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape) 1431diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 1432new file mode 100644 1433index 00000000..641aa483 1434--- /dev/null 1435+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 1436@@ -0,0 +1,31 @@ 1437+/** 1438+ * Copyright 2023 Huawei Technologies Co., Ltd 1439+ * 1440+ * Licensed under the Apache License, Version 2.0 (the "License"); 1441+ * you may not use this file except in compliance with the License. 1442+ * You may obtain a copy of the License at 1443+ * 1444+ * http://www.apache.org/licenses/LICENSE-2.0 1445+ * 1446+ * Unless required by applicable law or agreed to in writing, software 1447+ * distributed under the License is distributed on an "AS IS" BASIS, 1448+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1449+ * See the License for the specific language governing permissions and 1450+ * limitations under the License. 
1451+ */ 1452+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1453+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1454+ 1455+#include "nnacl/infer/common_infer.h" 1456+ 1457+#ifdef __cplusplus 1458+extern "C" { 1459+#endif 1460+ 1461+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 1462+ size_t outputs_size, OpParameter *parameter); 1463+ 1464+#ifdef __cplusplus 1465+} 1466+#endif 1467+#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1468diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 1469new file mode 100644 1470index 00000000..d7c34768 1471--- /dev/null 1472+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 1473@@ -0,0 +1,65 @@ 1474+/** 1475+* Copyright 2023 Huawei Technologies Co., Ltd 1476+* 1477+* Licensed under the Apache License, Version 2.0 (the "License"); 1478+* you may not use this file except in compliance with the License. 1479+* You may obtain a copy of the License at 1480+* 1481+* http://www.apache.org/licenses/LICENSE-2.0 1482+* 1483+* Unless required by applicable law or agreed to in writing, software 1484+* distributed under the License is distributed on an "AS IS" BASIS, 1485+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1486+* See the License for the specific language governing permissions and 1487+* limitations under the License. 1488+*/ 1489+#ifndef NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 1490+#define NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 1491+ 1492+#include "nnacl/intrinsics/ms_simd_instructions.h" 1493+#include "nnacl/intrinsics/ms_simd_neon_instructions.h" 1494+ 1495+#ifdef __cplusplus 1496+extern "C" { 1497+#endif 1498+ 1499+#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION 1500+#define BLOCK_NUM 4 1501+#define MS_SIMD_NEON 1502+ 1503+static inline int ScatterNDAddFp32NEON(int index, const float *update, int size, float *output) { 1504+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1505+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1506+ } 1507+ return index; 1508+} 1509+ 1510+static inline int ScatterNDAddInt32NEON(int index, const int *update, int size, int *output) { 1511+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1512+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1513+ } 1514+ return index; 1515+} 1516+ 1517+static inline int ScatterNDMaxFp32NEON(int index, const float *update, int size, float *output) { 1518+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1519+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1520+ } 1521+ return index; 1522+} 1523+ 1524+static inline int ScatterNDMaxInt32NEON(int index, const int *update, int size, int *output) { 1525+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1526+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1527+ } 1528+ return index; 1529+} 1530+ 1531+#undef MS_SIMD_INSTRUCTION 1532+#undef BLOCK_NUM 1533+ 1534+#undef MS_SIMD_NEON 1535+#ifdef __cplusplus 1536+} 1537+#endif 1538+#endif // NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 
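Note: the NEON helpers above, like their SSE/AVX/AVX512 counterparts added elsewhere in this patch, only process the block-aligned prefix of the buffer. Each function advances index in steps of BLOCK_NUM and returns the first element it did not handle, so the caller is expected to finish the remainder with a scalar loop. Below is a minimal sketch of that calling convention in plain C, with the intrinsics replaced by a scalar stand-in so it compiles on its own; the real kernels in scatter_nd_binary.c presumably dispatch to the intrinsics variants through the nnacl SIMD machinery instead.

/* Stand-in for ScatterNDAddFp32NEON/SSE/AVX: handles whole blocks of four
 * elements and returns the index of the first element left untouched. */
static int ScatterNDAddFp32Block(int index, const float *update, int size, float *output) {
  const int block_num = 4;
  for (int block_max_size = size - block_num + 1; index < block_max_size; index += block_num) {
    for (int i = 0; i < block_num; ++i) {
      output[index + i] += update[index + i];
    }
  }
  return index;
}

/* Caller pattern: vectorizable prefix first, then a scalar tail loop. */
void ScatterNDAddFp32(const float *update, int size, float *output) {
  int index = ScatterNDAddFp32Block(0, update, size, output);
  for (; index < size; ++index) {
    output[index] += update[index];
  }
}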
1539diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 1540index 955a70a5..895f7e3d 100644 1541--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 1542+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 1543@@ -558,6 +558,10 @@ enum PrimType { 1544 PrimType_Inner_CustomGru = 10010, 1545 PrimType_Inner_CastGatherReduceFusion = 10011, 1546 PrimType_Inner_ReduceConcatFusion = 10012, 1547+ PrimType_Inner_ThirdPartyModel = 10013, 1548+ PrimType_Inner_CustomMaskedFill = 10014, 1549+ PrimType_Inner_CustomTensorScatterMax = 10015, 1550+ PrimType_Inner_CustomIsInf = 10016, 1551 PrimType_InnerOpMax, 1552 PrimType_InnerOpMin = PrimType_Inner_ToFormat 1553 }; 1554diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 1555new file mode 100644 1556index 00000000..dd9878f7 1557--- /dev/null 1558+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 1559@@ -0,0 +1,36 @@ 1560+/** 1561+* Copyright 2023 Huawei Technologies Co., Ltd 1562+* 1563+* Licensed under the Apache License, Version 2.0 (the "License"); 1564+* you may not use this file except in compliance with the License. 1565+* You may obtain a copy of the License at 1566+* 1567+* http://www.apache.org/licenses/LICENSE-2.0 1568+* 1569+* Unless required by applicable law or agreed to in writing, software 1570+* distributed under the License is distributed on an "AS IS" BASIS, 1571+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1572+* See the License for the specific language governing permissions and 1573+* limitations under the License. 1574+*/ 1575+#ifndef NNACL_SCATTER_ND_BINARY_SIMD_H_ 1576+#define NNACL_SCATTER_ND_BINARY_SIMD_H_ 1577+ 1578+#include "nnacl/intrinsics/ms_simd_instructions.h" 1579+#ifdef ENABLE_AVX512 1580+#include "nnacl/avx512/scatter_nd_binary_avx512.h" 1581+#endif 1582+ 1583+#ifdef ENABLE_AVX 1584+#include "nnacl/avx/scatter_nd_binary_avx.h" 1585+#endif 1586+ 1587+#ifdef ENABLE_SSE 1588+#include "nnacl/sse/scatter_nd_binary_sse.h" 1589+#endif 1590+ 1591+#ifdef ENABLE_ARM 1592+#include "nnacl/neon/scatter_nd_binary_neon.h" 1593+#endif 1594+ 1595+#endif 1596diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 1597new file mode 100644 1598index 00000000..983d2923 1599--- /dev/null 1600+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 1601@@ -0,0 +1,66 @@ 1602+/** 1603+* Copyright 2023 Huawei Technologies Co., Ltd 1604+* 1605+* Licensed under the Apache License, Version 2.0 (the "License"); 1606+* you may not use this file except in compliance with the License. 1607+* You may obtain a copy of the License at 1608+* 1609+* http://www.apache.org/licenses/LICENSE-2.0 1610+* 1611+* Unless required by applicable law or agreed to in writing, software 1612+* distributed under the License is distributed on an "AS IS" BASIS, 1613+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1614+* See the License for the specific language governing permissions and 1615+* limitations under the License. 
1616+*/ 1617+#ifndef NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1618+#define NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1619+ 1620+#include "nnacl/intrinsics/ms_simd_instructions.h" 1621+#include "nnacl/intrinsics/ms_simd_sse_instructions.h" 1622+ 1623+#ifdef __cplusplus 1624+extern "C" { 1625+#endif 1626+#pragma GCC push_options 1627+#pragma GCC target("sse4.1") 1628+#define MS_SIMD_INSTRUCTION MS_SIMD_SSE_INSTRUCTION 1629+#define BLOCK_NUM 4 1630+#define MS_SIMD_SSE 1631+ 1632+static inline int ScatterNDAddFp32SSE(int index, const float *update, int size, float *output) { 1633+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1634+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1635+ } 1636+ return index; 1637+} 1638+ 1639+static inline int ScatterNDAddInt32SSE(int index, const int *update, int size, int *output) { 1640+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1641+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1642+ } 1643+ return index; 1644+} 1645+ 1646+static inline int ScatterNDMaxFp32SSE(int index, const float *update, int size, float *output) { 1647+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1648+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1649+ } 1650+ return index; 1651+} 1652+ 1653+static inline int ScatterNDMaxInt32SSE(int index, const int *update, int size, int *output) { 1654+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1655+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1656+ } 1657+ return index; 1658+} 1659+ 1660+#undef MS_SIMD_INSTRUCTION 1661+#undef BLOCK_NUM 1662+#pragma GCC pop_options 1663+#undef MS_SIMD_SSE 1664+#ifdef __cplusplus 1665+} 1666+#endif 1667+#endif // NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1668diff --git a/mindspore/core/mindrt/BUILD.gn b/mindspore/core/mindrt/BUILD.gn 1669index b56d5f5c..b0e7c70d 100644 1670--- a/mindspore/core/mindrt/BUILD.gn 1671+++ b/mindspore/core/mindrt/BUILD.gn 1672@@ -41,8 +41,15 @@ ohos_source_set("mindrt_obj") { 1673 "../../core/", 1674 ] 1675 1676+ defines = [ 1677+ "ENABLE_MINDRT", 1678+ "MS_COMPILE_OHOS", 1679+ "BUILD_LITE", 1680+ ] 1681+ 1682+ external_deps = [ "hilog:libhilog" ] 1683+ 1684 remove_configs = [ "//build/config/compiler:no_rtti" ] 1685- defines = [ "BUILD_LITE" ] 1686 1687 part_name = "mindspore" 1688 subsystem_name = "thirdparty" 1689diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1690index 70414757..c50c46e0 100644 1691--- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1692+++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1693@@ -32,7 +32,7 @@ void ActorWorker::RunWithSpin() { 1694 } 1695 #if !defined(__APPLE__) && !defined(_MSC_VER) 1696 static std::atomic_int index{0}; 1697- (void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str()); 1698+ (void)pthread_setname_np(pthread_self(), ("OS_Actor_" + std::to_string(index++)).c_str()); 1699 #endif 1700 #ifdef PLATFORM_86 1701 // Some CPU kernels need set the flush zero mode to improve performance. 
1702diff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc 1703index 33bf3529..a3478dff 100644 1704--- a/mindspore/core/mindrt/src/thread/core_affinity.cc 1705+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc 1706@@ -344,12 +344,12 @@ int CoreAffinity::InitBindCoreId(size_t thread_num, BindMode bind_mode) { 1707 int CoreAffinity::SetAffinity() { return THREAD_OK; } 1708 #elif defined(BIND_CORE) 1709 int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) { 1710-#ifdef __ANDROID__ 1711-#if __ANDROID_API__ >= 21 1712+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1713+#if (__ANDROID_API__ >= 21) || defined(MS_COMPILE_OHOS) 1714 THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]); 1715 int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set); 1716 if (ret != THREAD_OK) { 1717- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret); 1718+ THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", pthread_gettid_np(thread_id), ret); 1719 return THREAD_ERROR; 1720 } 1721 #endif 1722diff --git a/mindspore/core/mindrt/src/thread/core_affinity.h b/mindspore/core/mindrt/src/thread/core_affinity.h 1723index 2dd2abd1..28b0967a 100644 1724--- a/mindspore/core/mindrt/src/thread/core_affinity.h 1725+++ b/mindspore/core/mindrt/src/thread/core_affinity.h 1726@@ -23,7 +23,7 @@ 1727 #ifdef PARALLEL_INFERENCE 1728 #define BIND_CORE 1729 #endif 1730-#ifdef __ANDROID__ 1731+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1732 #define BIND_CORE 1733 #include <sched.h> 1734 #endif 1735diff --git a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1736index 9e0dd25c..09c39f32 100644 1737--- a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1738+++ b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1739@@ -48,7 +48,7 @@ void ParallelWorker::ParallelRun() { 1740 SetAffinity(); 1741 } 1742 #if !defined(__APPLE__) && !defined(_MSC_VER) 1743- (void)pthread_setname_np(pthread_self(), ("ParallelThread_" + std::to_string(worker_id_)).c_str()); 1744+ (void)pthread_setname_np(pthread_self(), ("OS_Parallel_" + std::to_string(worker_id_)).c_str()); 1745 #endif 1746 #ifdef PLATFORM_86 1747 // Some CPU kernels need set the flush zero mode to improve performance. 1748diff --git a/mindspore/core/mindrt/src/thread/threadlog.h b/mindspore/core/mindrt/src/thread/threadlog.h 1749index 7ed917f1..b212a401 100644 1750--- a/mindspore/core/mindrt/src/thread/threadlog.h 1751+++ b/mindspore/core/mindrt/src/thread/threadlog.h 1752@@ -16,7 +16,9 @@ 1753 1754 #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_ 1755 #define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_ 1756- 1757+#ifdef MS_COMPILE_OHOS 1758+#include "hilog/log.h" 1759+#endif 1760 namespace mindspore { 1761 #ifdef THREAD_POOL_DEBUG 1762 #include <stdio.h> 1763@@ -32,13 +34,35 @@ namespace mindspore { 1764 } 1765 #else 1766 #define THREAD_DEBUG(content, ...) 1767-#define THREAD_INFO(content, ...) 1768 #define THREAD_TEST_TRUE(flag) 1769+ 1770 #if defined(__ANDROID__) 1771+#define THREAD_INFO(content, ...) 1772 #include <android/log.h> 1773 #define THREAD_ERROR(content, args...) \ 1774 { __android_log_print(ANDROID_LOG_ERROR, "MS_LITE", "%s|%d: " #content "\r\n", __func__, __LINE__, ##args); } 1775+ 1776+#elif defined(MS_COMPILE_OHOS) // For OHOS, use hilog. 
1777+ 1778+#define MINDRT_OHOS_LOG_DOMAIN 0x2102 1779+#define MINDRT_OHOS_LOG_TAG "MS_LITE" 1780+ 1781+#ifdef MS_COMPILE_WITH_OHOS_NDK 1782+// When build with OHOS NDK, use public api of hilog module. 1783+#define THREAD_INFO(content, args...) \ 1784+ { OH_LOG_Print(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1785+#define THREAD_ERROR(content, args...) \ 1786+ { OH_LOG_Print(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1787+#else 1788+// When build in OHOS repo, use inner api of hilog module. 1789+#define THREAD_INFO(content, args...) \ 1790+ { HiLogPrint(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1791+#define THREAD_ERROR(content, args...) \ 1792+ { HiLogPrint(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1793+#endif 1794+ 1795 #else 1796+#define THREAD_INFO(content, ...) 1797 #define THREAD_ERROR(content, ...) 1798 #endif 1799 #endif 1800diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc 1801index c56e0425..2301be8c 100644 1802--- a/mindspore/core/mindrt/src/thread/threadpool.cc 1803+++ b/mindspore/core/mindrt/src/thread/threadpool.cc 1804@@ -68,10 +68,11 @@ void Worker::SetAffinity() { 1805 #ifdef _WIN32 1806 SetWindowsSelfAffinity(core_id_); 1807 #elif defined(BIND_CORE) 1808-#ifdef __ANDROID__ 1809+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1810+ THREAD_INFO("thread: %d, mask: %lu", gettid(), mask_.__bits[0]); 1811 int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_); 1812 if (ret != THREAD_OK) { 1813- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno); 1814+ THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", gettid(), errno); 1815 } 1816 return; 1817 #else 1818@@ -111,7 +112,7 @@ void Worker::Run() { 1819 } 1820 #if !defined(__APPLE__) && !defined(_MSC_VER) 1821 static std::atomic_int index = {0}; 1822- (void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str()); 1823+ (void)pthread_setname_np(pthread_self(), ("OS_Kernel_" + std::to_string(index++)).c_str()); 1824 #endif 1825 #ifdef PLATFORM_86 1826 // Some CPU kernels need set the flush zero mode to improve performance. 
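Note: with this threadlog.h change, THREAD_INFO and THREAD_ERROR on OHOS forward their printf-style arguments to hilog (OH_LOG_Print via the public NDK API, HiLogPrint when built inside the OHOS repo) instead of compiling to nothing, so existing call sites keep working unchanged. hilog redacts format arguments that are not marked %{public}, which is why the call sites above switch %d to %{public}d. Below is a stand-alone sketch of the macro pattern with printf standing in for the hilog call so it builds anywhere; the DEMO_* names are illustrative only, and the real macros build the format via #content rather than plain string concatenation.

#include <stdio.h>

#define DEMO_LOG_DOMAIN 0x2102
#define DEMO_LOG_TAG "MS_LITE"

/* Same shape as THREAD_ERROR above: prepend function and line, then forward
 * the variadic arguments; the real macro hands these to OH_LOG_Print or
 * HiLogPrint together with the domain and tag. */
#define DEMO_THREAD_ERROR(content, args...) \
  { printf("[0x%x/%s] %s:%d " content "\n", DEMO_LOG_DOMAIN, DEMO_LOG_TAG, __func__, __LINE__, ##args); }

int main(void) {
  int tid = 1234;
  int err = 22;
  DEMO_THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", tid, err);
  return 0;
}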
1827diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 1828index a774b58c..f7e465e2 100644 1829--- a/mindspore/lite/BUILD.gn 1830+++ b/mindspore/lite/BUILD.gn 1831@@ -71,9 +71,14 @@ 1832 1833 import("//build/ohos.gni") 1834 1835+declare_args() { 1836+ mindspore_feature_nnrt_metagraph = false 1837+} 1838+ 1839 ohos_group("mindspore") { 1840 deps = [ 1841 ":mindspore_lib", 1842+ ":mindspore_ndk", 1843 ":mindspore_train_lib", 1844 "mindir:mindir_lib", 1845 "src/litert/js_api:mindsporelite_napi" 1846@@ -180,7 +185,6 @@ lite_mindrt_sources = [ 1847 ] 1848 1849 all_lite_sources += cxx_api_sources 1850-all_lite_sources += c_api_sources 1851 all_lite_sources += api_source 1852 all_lite_sources += control_flow_kernel_sources 1853 all_lite_sources += experimental_sources 1854@@ -368,7 +372,6 @@ ohos_shared_library("mindspore_lib") { 1855 sources = all_sources 1856 1857 include_dirs = [ 1858- "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1859 "//third_party/flatbuffers/include", 1860 "./", 1861 "../", 1862@@ -384,6 +387,7 @@ ohos_shared_library("mindspore_lib") { 1863 "../ccsrc/", 1864 "src/litert/kernel/cpu/", 1865 "../core/mindrt/src/", 1866+ "//foundation/ai/neural_network_runtime/", 1867 ] 1868 1869 defines = [ 1870@@ -426,24 +430,29 @@ ohos_shared_library("mindspore_lib") { 1871 1872 external_deps = [ "hilog:libhilog" ] 1873 1874- output_name = "libmindspore-lite.huawei" 1875+ output_name = "libmindspore-lite" 1876 output_extension = "so" 1877 innerapi_tags = [ "platformsdk" ] 1878 SUPPORT_NNRT = true 1879 if (SUPPORT_NNRT) { 1880+ if (mindspore_feature_nnrt_metagraph) { 1881+ defines += [ "SUPPORT_NNRT_METAGRAPH" ] 1882+ print("enabled feature: mindspore_feature_nnrt_metagraph") 1883+ } 1884 sources += [ 1885 "src/litert/delegate/nnrt/checker/primitive_check.cc", 1886 "src/litert/delegate/nnrt/nnrt_delegate.cc", 1887 "src/litert/delegate/nnrt/nnrt_model_kernel.cc", 1888 ] 1889 include_dirs += [ 1890- "//foundation/ai/neural_network_runtime", 1891 "src/delegate/nnrt/include", 1892 "../../mindspore/core/ir", 1893 "mindir/include", 1894 "mindir/inner_headers", 1895 ] 1896+ 1897 external_deps += [ "neural_network_runtime:nnrt_target" ] 1898+ 1899 deps += [ "mindir:mindir_lib" ] 1900 defines += [ "SUPPORT_NNRT" ] 1901 } 1902@@ -461,6 +470,67 @@ ohos_shared_library("mindspore_lib") { 1903 subsystem_name = "thirdparty" 1904 } 1905 1906+# NDK lib 1907+ohos_shared_library("mindspore_ndk") { 1908+ deps = [ 1909+ ":mindspore_lib", 1910+ ":mindspore_train_lib" 1911+ ] 1912+ 1913+ sources = c_api_sources 1914+ 1915+ include_dirs = [ 1916+ "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1917+ "//third_party/flatbuffers/include", 1918+ "./", 1919+ "../", 1920+ "../../", 1921+ "../core", 1922+ "src", 1923+ "src/c_api/", 1924+ "../ccsrc/plugin/device/cpu/kernel/", 1925+ "../core/mindrt/src/", 1926+ "../core/mindrt/include/", 1927+ "../../third_party/", 1928+ "./schema/", 1929+ "../ccsrc/", 1930+ "//foundation/ai/neural_network_runtime/", 1931+ ] 1932+ 1933+ defines = [ 1934+ "SUPPORT_NNRT", 1935+ "MS_COMPILE_OHOS", 1936+ "PRIMITIVE_WRITEABLE", 1937+ "RUNTIME_PASS_CLIP", 1938+ "ENABLE_MULTI_LAYOUT", 1939+ "VERSION_STR=\"2.1.0\"", 1940+ ] 1941+ 1942+ configs = [ 1943+ ":mindspore_api", 1944+ ":disable_android", 1945+ ":secure_option", 1946+ ] 1947+ 1948+ external_deps = [ "neural_network_runtime:nnrt_target" ] 1949+ 1950+ remove_configs = [ "//build/config/compiler:no_rtti" ] 1951+ 1952+ output_name = "libmindspore_lite_ndk" 1953+ output_extension = "so" 1954+ 
innerapi_tags = [ "ndk"] 1955+ cflags_cc = [ 1956+ "-Wno-ignored-qualifiers", 1957+ "-Wunused-private-field", 1958+ "-Wno-unused-private-field", 1959+ "-Wno-inconsistent-missing-override", 1960+ "-Wno-macro-redefined", 1961+ "-Wno-constant-conversion", 1962+ ] 1963+ part_name = "mindspore" 1964+ subsystem_name = "thirdparty" 1965+} 1966+ 1967 # Train library 1968 expression_cxx_api_sources = [ 1969 "src/litert/cxx_api/expression/net.cc", 1970@@ -614,7 +684,6 @@ ohos_shared_library("mindspore_train_lib") { 1971 sources = all_train_sources 1972 1973 include_dirs = [ 1974- "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1975 "//third_party/flatbuffers/include", 1976 "./", 1977 "../", 1978@@ -698,6 +767,9 @@ config("disable_android") { 1979 "-U__ANDROID__", 1980 "-U__ANDROID_API__", 1981 ] 1982+ ldflags = [ 1983+ "-Wl,--no-as-needed", 1984+ ] 1985 } 1986 1987 config("secure_option") { 1988diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt 1989index 72337f70..1faf2f38 100644 1990--- a/mindspore/lite/CMakeLists.txt 1991+++ b/mindspore/lite/CMakeLists.txt 1992@@ -298,8 +298,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite") 1993 elseif(TOOLCHAIN_NAME STREQUAL "ohos") 1994 set(TARGET_OHOS on) 1995 add_compile_definitions(MS_COMPILE_OHOS) 1996- set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions") 1997- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions") 1998+ add_compile_definitions(MS_COMPILE_WITH_OHOS_NDK) 1999+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins") 2000+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins") 2001 endif() 2002 2003 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0 2004diff --git a/mindspore/lite/include/lite_types.h b/mindspore/lite/include/lite_types.h 2005index 017e98a8..860390d5 100644 2006--- a/mindspore/lite/include/lite_types.h 2007+++ b/mindspore/lite/include/lite_types.h 2008@@ -42,6 +42,7 @@ typedef enum { 2009 DT_NPU, /**< NPU device type */ 2010 DT_ASCEND, /**< ASCEND device type */ 2011 DT_CUSTOM, /**< EXTEND device type */ 2012+ DT_NNRT, /**< NNRT device type */ 2013 DT_END /**< NO device type */ 2014 } DeviceType; 2015 2016diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h 2017index 93e27ea9..b96c7e35 100644 2018--- a/mindspore/lite/include/model.h 2019+++ b/mindspore/lite/include/model.h 2020@@ -25,6 +25,7 @@ namespace mindspore { 2021 namespace schema { 2022 struct Tensor; 2023 } // namespace schema 2024+ 2025 namespace lite { 2026 typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType; 2027 2028@@ -62,7 +63,10 @@ struct MS_API LiteGraph { 2029 bool model_obfuscated_ = false; 2030 std::vector<unsigned char *> deobf_prims_; 2031 #endif 2032+ 2033+ std::string ToString() const; 2034 }; 2035+ 2036 struct MS_API Model { 2037 LiteGraph graph_; 2038 char *buf = nullptr; 2039diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h 2040index 2d72b200..4bc92599 100644 2041--- a/mindspore/lite/include/registry/converter_context.h 2042+++ b/mindspore/lite/include/registry/converter_context.h 2043@@ -39,7 +39,9 @@ enum MS_API FmkType : int { 2044 kFmkTypeMs = 3, 2045 kFmkTypeTflite = 4, 2046 kFmkTypePytorch = 5, 2047- kFmkTypeMsLite = 6, 2048+ kFmkTypeThirdParty = 6, 2049+ 
kFmkTypeMsLite = 7, 2050+ kFmkTypeEnd = 8, // For range check purpose, valid range: [0, kFmkTypeEnd) 2051 }; 2052 2053 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. 2054diff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h 2055index ca811dce..f47cad8c 100644 2056--- a/mindspore/lite/mindir/include/mindir.h 2057+++ b/mindspore/lite/mindir/include/mindir.h 2058@@ -151,6 +151,8 @@ int64_t MindIR_Conv2DFusion_GetOutChannel(ConstPrimitivePtr primitive); 2059 void MindIR_Conv2DFusion_SetOutChannel(PrimitivePtr *primitive, int64_t out_channel); 2060 ActivationType MindIR_Conv2DFusion_GetActivationType(ConstPrimitivePtr primitive); 2061 void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type); 2062+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive); 2063+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format); 2064 2065 // ********** Conv2dTransposeFusion ********** 2066 PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive( 2067diff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc 2068index 7fc9c00e..374bbef5 100644 2069--- a/mindspore/lite/mindir/src/mindir.cc 2070+++ b/mindspore/lite/mindir/src/mindir.cc 2071@@ -1215,6 +1215,46 @@ void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationTy 2072 } 2073 } 2074 2075+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive) { 2076+ if (primitive != nullptr) { 2077+ auto prim = static_cast<const schema::Primitive *>(primitive); 2078+ auto value = prim->value_as_Conv2DFusion(); 2079+ if (prim != nullptr && value != nullptr) { 2080+ return static_cast<Format>(value->format()); 2081+ } else { 2082+ Format en = static_cast<Format>(0); 2083+ return en; 2084+ } 2085+ } else { 2086+ Format en = static_cast<Format>(0); 2087+ return en; 2088+ } 2089+} 2090+ 2091+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format) { 2092+ if (primitive != nullptr && *primitive != nullptr) { 2093+ auto prim = static_cast<schema::Primitive *>(*primitive); 2094+ auto value = prim->value_as_Conv2DFusion(); 2095+ if (prim != nullptr && value != nullptr) { 2096+ flatbuffers::FlatBufferBuilder fbb; 2097+ auto ops_offset = schema::CreateConv2DFusion( 2098+ fbb, static_cast<schema::Format>(format), 2099+ fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()), 2100+ fbb.CreateVector(value->stride()->data(), value->stride()->size()), 2101+ fbb.CreateVector(value->dilation()->data(), value->dilation()->size()), 2102+ static_cast<schema::PadMode>(value->pad_mode()), 2103+ fbb.CreateVector(value->pad_list()->data(), value->pad_list()->size()), 0, value->group(), value->in_channel(), 2104+ value->out_channel(), static_cast<schema::ActivationType>(value->activation_type())); 2105+ auto prim_offset = 2106+ schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(NODE_TYPE_CONV2D_FUSION), ops_offset.o); 2107+ fbb.Finish(prim_offset); 2108+ auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim); 2109+ auto ret_value = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr); 2110+ *primitive = ret_value; 2111+ } 2112+ } 2113+} 2114+ 2115 // ********** Conv2dTransposeFusion ********** 2116 PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive( 2117 const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, 2118diff 
--git a/mindspore/lite/mindir/src/mindir_tensor.cc b/mindspore/lite/mindir/src/mindir_tensor.cc 2119index 9ec2d0e4..2db4ce8b 100644 2120--- a/mindspore/lite/mindir/src/mindir_tensor.cc 2121+++ b/mindspore/lite/mindir/src/mindir_tensor.cc 2122@@ -134,7 +134,7 @@ void MindIR_Tensor_SetDataType(TensorPtr *tensor, DataType data_type) { 2123 name = fbb.CreateString(value->name()->c_str(), value->name()->size()); 2124 } 2125 auto ops_offset = 2126- schema::CreateTensor(fbb, 0, value->dataType(), dims, static_cast<schema::Format>(value->format()), 0, 0, data, 2127+ schema::CreateTensor(fbb, 0, data_type, dims, static_cast<schema::Format>(value->format()), 0, 0, data, 2128 ConvertQuantParams(fbb, value->quantParams()), 0, name); 2129 fbb.Finish(ops_offset); 2130 auto new_addr = MindIRMemoryManager::GetInstance()->CreateTensorFromBuilder(fbb, value); 2131diff --git a/mindspore/lite/mindir/src/utils.cc b/mindspore/lite/mindir/src/utils.cc 2132index 28d66ceb..b044f414 100644 2133--- a/mindspore/lite/mindir/src/utils.cc 2134+++ b/mindspore/lite/mindir/src/utils.cc 2135@@ -22,7 +22,7 @@ namespace lite { 2136 2137 // ********** PrimitiveBase ********** 2138 NodeType MindIR_Primitive_GetType(PrimitivePtr primitive) { 2139- auto prim = flatbuffers::GetMutableRoot<schema::Primitive>(primitive); 2140+ auto prim = static_cast<schema::Primitive *>(primitive); 2141 auto type = prim->value_type(); 2142 return static_cast<NodeType>(type); 2143 } 2144diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 2145index 5afccc87..de1781cd 100644 2146--- a/mindspore/lite/src/CMakeLists.txt 2147+++ b/mindspore/lite/src/CMakeLists.txt 2148@@ -410,6 +410,11 @@ add_subdirectory(common) 2149 add_library(lite_src_mid OBJECT ${LITE_SRC}) 2150 add_dependencies(lite_src_mid lite_src_common_mid fbs_src fbs_inner_src) 2151 2152+if(SUPPORT_NNRT) 2153+ add_subdirectory(litert/delegate/nnrt) 2154+ target_link_libraries(lite_src_mid nnrt_mid) 2155+endif() 2156+ 2157 if(MSLITE_ENABLE_ACL) 2158 include_directories(${TOP_DIR}/graphengine/910/inc/external) 2159 if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) 2160@@ -497,7 +502,6 @@ if(MSLITE_ENABLE_MINDRT) 2161 endif() 2162 2163 if (SUPPORT_NNRT) 2164- add_subdirectory(litert/delegate/nnrt) 2165 target_link_libraries(mindspore-lite nnrt_mid) 2166 target_link_libraries(mindspore-lite_static nnrt_mid) 2167 endif() 2168diff --git a/mindspore/lite/src/common/context_util.cc b/mindspore/lite/src/common/context_util.cc 2169index f011e0d7..0fa4ebd0 100644 2170--- a/mindspore/lite/src/common/context_util.cc 2171+++ b/mindspore/lite/src/common/context_util.cc 2172@@ -118,6 +118,17 @@ std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceCo 2173 MS_CHECK_TRUE_RET(device_info != nullptr, nullptr); 2174 return device_info; 2175 } 2176+ 2177+std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext( 2178+ const lite::DeviceContext &nnrt_context) { 2179+ if (nnrt_context.device_type_ != DT_NNRT) { 2180+ MS_LOG(ERROR) << "Function input parameter is not NNRt context."; 2181+ return nullptr; 2182+ } 2183+ auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>(); 2184+ MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr); 2185+ return nnrt_info; 2186+} 2187 } // namespace 2188 2189 mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) { 2190@@ -144,7 +155,8 @@ mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &co 2191 {DT_GPU, 
GPUDeviceInfoFromGPUDeviceContext}, 2192 {DT_NPU, NPUDeviceInfoFromNPUDeviceContext}, 2193 {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext}, 2194- {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}}; 2195+ {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}, 2196+ {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}}; 2197 for (auto &device_context : context->device_list_) { 2198 auto device_type = device_context.device_type_; 2199 if (transfer_funcs.find(device_type) == transfer_funcs.end()) { 2200diff --git a/mindspore/lite/src/common/log.cc b/mindspore/lite/src/common/log.cc 2201index 66c0d76b..f1040662 100644 2202--- a/mindspore/lite/src/common/log.cc 2203+++ b/mindspore/lite/src/common/log.cc 2204@@ -21,6 +21,13 @@ 2205 #include <android/log.h> 2206 #endif 2207 2208+#ifdef MS_COMPILE_OHOS 2209+#define LOG_DOMAIN 0xD002102 2210+#define LOG_TAG "MS_LITE" 2211+#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2212+#include "hilog/log.h" 2213+#endif 2214+ 2215 // namespace to support utils module definition namespace mindspore constexpr const char *ANDROID_LOG_TAG = "MS_LITE"; 2216 namespace mindspore { 2217 #if defined(__ANDROID__) 2218@@ -73,17 +80,33 @@ static int GetAndroidLogLevel(LiteLogLevel level) { 2219 2220 #ifdef MS_COMPILE_OHOS 2221 void PrintHiLog(LiteLogLevel level, const char *file, int line, const char *func, const char *msg) { 2222+#ifdef MS_COMPILE_WITH_OHOS_NDK 2223+ // When build with OHOS NDK, use public api of hilog module. 2224 if (level == LiteLogLevel::DEBUG) { 2225- OHOS::HiviewDFX::HiLog::Debug(MSLite_LABEL, FORMAT, file, line, func, msg); 2226+ OH_LOG_Print(LOG_APP, LOG_DEBUG, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2227 } else if (level == LiteLogLevel::INFO) { 2228- OHOS::HiviewDFX::HiLog::Info(MSLite_LABEL, FORMAT, file, line, func, msg); 2229+ OH_LOG_Print(LOG_APP, LOG_INFO, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2230 } else if (level == LiteLogLevel::WARNING) { 2231- OHOS::HiviewDFX::HiLog::Warn(MSLite_LABEL, FORMAT, file, line, func, msg); 2232+ OH_LOG_Print(LOG_APP, LOG_WARN, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2233 } else if (level == LiteLogLevel::ERROR) { 2234- OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg); 2235+ OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2236 } else { 2237- OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg); 2238+ OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2239 } 2240+#else 2241+ // When build in OHOS repo, use inner api of hilog module. 
2242+ if (level == LiteLogLevel::DEBUG) { 2243+ HILOG_DEBUG(LOG_APP, FORMAT, file, line, func, msg); 2244+ } else if (level == LiteLogLevel::INFO) { 2245+ HILOG_INFO(LOG_APP, FORMAT, file, line, func, msg); 2246+ } else if (level == LiteLogLevel::WARNING) { 2247+ HILOG_WARN(LOG_APP, FORMAT, file, line, func, msg); 2248+ } else if (level == LiteLogLevel::ERROR) { 2249+ HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg); 2250+ } else { 2251+ HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg); 2252+ } 2253+#endif 2254 } 2255 #endif 2256 2257diff --git a/mindspore/lite/src/common/log.h b/mindspore/lite/src/common/log.h 2258index 3002a454..bea21f01 100644 2259--- a/mindspore/lite/src/common/log.h 2260+++ b/mindspore/lite/src/common/log.h 2261@@ -23,12 +23,6 @@ 2262 #include <unordered_map> 2263 #include "utils/overload.h" 2264 2265-#ifdef MS_COMPILE_OHOS 2266-#define LOG_DOMAIN 0x2102 2267-#define LOG_TAG "MS_Lite" 2268-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2269-#include "hilog/log.h" 2270-#endif 2271 // NOTICE: when relative path of 'log.h' changed, macro 'LITE_LOG_HEAR_FILE_REL_PATH' must be changed 2272 #ifndef LITE_LOG_HEAR_FILE_REL_PATH 2273 #define LITE_LOG_HEAR_FILE_REL_PATH "mindspore/lite/src/common/log.h" 2274@@ -140,6 +134,9 @@ class LiteLogWriter { 2275 LiteLogLevel log_level_; 2276 }; 2277 2278+#define MSLOG_IF(level) \ 2279+ mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2280+ mindspore::LiteLogStream() 2281 2282 #define MS_LOG(level) MS_LOG_##level 2283 2284@@ -148,47 +145,6 @@ class LiteLogWriter { 2285 #define MS_LOG_WARNING MSLOG_IF(mindspore::LiteLogLevel::WARNING) 2286 #define MS_LOG_ERROR MSLOG_IF(mindspore::LiteLogLevel::ERROR) 2287 2288- 2289-#ifdef MS_COMPILE_OHOS 2290-namespace { 2291-constexpr unsigned int MSLITE_DOMAIN_ID_START = 0xD0029A0; 2292-constexpr unsigned int MSLITE_DOMAIN_ID_END = MSLITE_DOMAIN_ID_START + 32; 2293-constexpr unsigned int TEST_DOMAIN_ID = 0xD000F00; 2294-} // namespace 2295- 2296-#define FILE_NAME (__builtin_strrchr(__FILE__, '/') ? 
__builtin_strrchr(__FILE__, '/') + 1 : __FILE__) 2297-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2298- 2299-#define MSLOG_IF(level) \ 2300- mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2301- mindspore::LiteLogStream() 2302- 2303-enum MSLiteManagerLogLabel { 2304- // Component labels, you can add if needed 2305- COMP_FWK = 0, 2306- // Test label 2307- LABEL_TEST, 2308- // The end of labels, max to the domain id range length 32 2309- LABEL_END, 2310-}; 2311- 2312-enum MSLiteManagerLogDomain { 2313- DOMAIN_FRAMEWORK = MSLITE_DOMAIN_ID_START + COMP_FWK, // 0xD0029A0 2314- DOMAIN_TEST = TEST_DOMAIN_ID, // 0xD000F00 2315- DOMAIN_END = MSLITE_DOMAIN_ID_END, // Max to 0xD002940, keep the sequence and length same as MSLiteManagerLogLabel 2316-}; 2317- 2318-// Keep the sequence and length same as MSLiteManagerLogDomain 2319-static constexpr OHOS::HiviewDFX::HiLogLabel MSLite_LABEL = {LOG_CORE, DOMAIN_FRAMEWORK, "MSLiteFwk"}; 2320- 2321-#else 2322- 2323-#define MSLOG_IF(level) \ 2324- mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2325- mindspore::LiteLogStream() 2326- 2327-#endif 2328- 2329 } // namespace mindspore 2330 2331 #ifdef Debug 2332diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc 2333index 5e1878b9..13957ed7 100644 2334--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc 2335+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc 2336@@ -19,6 +19,9 @@ 2337 #include "nnacl/custom_parameter.h" 2338 #include "nnacl/split_parameter.h" 2339 #include "nnacl/custom_gru_parameter.h" 2340+#include "nnacl/custom_masked_fill_parameter.h" 2341+#include "nnacl/custom_is_inf_parameter.h" 2342+#include "nnacl/custom_tensor_scatter_max_parameter.h" 2343 using mindspore::schema::PrimitiveType_Custom; 2344 2345 namespace mindspore { 2346@@ -92,6 +95,39 @@ OpParameter *CreateCustomGruParameter() { 2347 return reinterpret_cast<OpParameter *>(param); 2348 } 2349 2350+OpParameter *CreateCustomIsInfParameter() { 2351+ auto *param = static_cast<CustomIsInfParameter *>(malloc(sizeof(CustomIsInfParameter))); 2352+ if (param == nullptr) { 2353+ MS_LOG(ERROR) << "malloc CustomIsInfParameter failed."; 2354+ return nullptr; 2355+ } 2356+ memset(param, 0, sizeof(CustomIsInfParameter)); 2357+ param->op_parameter_.type_ = PrimType_Inner_CustomIsInf; 2358+ return reinterpret_cast<OpParameter *>(param); 2359+} 2360+ 2361+OpParameter *CreateCustomTensorScatterMaxParameter() { 2362+ auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter))); 2363+ if (param == nullptr) { 2364+ MS_LOG(ERROR) << "malloc CustomTensorScatterMaxParameter failed."; 2365+ return nullptr; 2366+ } 2367+ memset(param, 0, sizeof(CustomTensorScatterMaxParameter)); 2368+ param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax; 2369+ return reinterpret_cast<OpParameter *>(param); 2370+} 2371+ 2372+OpParameter *CreateCustomMaskedFillParameter() { 2373+ auto *param = static_cast<CustomMaskedFillParameter *>(malloc(sizeof(CustomMaskedFillParameter))); 2374+ if (param == nullptr) { 2375+ MS_LOG(ERROR) << "malloc CustomMaskedFillParameter failed."; 2376+ return nullptr; 2377+ } 2378+ memset(param, 0, sizeof(CustomMaskedFillParameter)); 2379+ param->op_parameter_.type_ = PrimType_Inner_CustomMaskedFill; 2380+ return reinterpret_cast<OpParameter *>(param); 2381+} 
2382+ 2383 OpParameter *PopulateCustomParameter(const void *prim) { 2384 MS_CHECK_TRUE_RET(prim != nullptr, nullptr); 2385 auto primitive = static_cast<const schema::Primitive *>(prim); 2386@@ -131,6 +167,23 @@ OpParameter *PopulateCustomParameter(const void *prim) { 2387 return CreateCustomGruParameter(); 2388 } else if (type == "CastGatherReduceFusion") { 2389 return CreateParam(PrimType_Inner_CastGatherReduceFusion); 2390+ } else if (type == "ThirdPartyModel") { 2391+ auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter))); 2392+ if (param == nullptr) { 2393+ MS_LOG(ERROR) << "malloc CustomParameter failed."; 2394+ return nullptr; 2395+ } 2396+ memset(param, 0, sizeof(CustomParameter)); 2397+ param->op_parameter_.type_ = PrimType_Inner_ThirdPartyModel; 2398+ // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary. 2399+ param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim)); 2400+ return reinterpret_cast<OpParameter *>(param); 2401+ } else if (type == "MaskedFill") { 2402+ return CreateCustomMaskedFillParameter(); 2403+ } else if (type == "TensorScatterMax") { 2404+ return CreateCustomTensorScatterMaxParameter(); 2405+ } else if (type == "IsInf") { 2406+ return CreateCustomIsInfParameter(); 2407 } else { 2408 MS_LOG(ERROR) << "Unsupported custom type: " << type; 2409 } 2410diff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc 2411index f614ef09..c5f825aa 100644 2412--- a/mindspore/lite/src/litert/c_api/context_c.cc 2413+++ b/mindspore/lite/src/litert/c_api/context_c.cc 2414@@ -14,12 +14,17 @@ 2415 * limitations under the License. 2416 */ 2417 #include "include/c_api/context_c.h" 2418-#include "src/litert/c_api/context_c.h" 2419+#include "include/api/context.h" 2420+#include <string.h> 2421+#include "src/litert/c_api/type_c_private.h" 2422 #include "src/common/log_adapter.h" 2423+#ifdef SUPPORT_NNRT 2424+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 2425+#endif 2426 2427 // ================ Context ================ 2428 OH_AI_ContextHandle OH_AI_ContextCreate() { 2429- auto impl = new (std::nothrow) mindspore::ContextC; 2430+ auto impl = new (std::nothrow) mindspore::Context(); 2431 if (impl == nullptr) { 2432 MS_LOG(ERROR) << "memory allocation failed."; 2433 return nullptr; 2434@@ -29,7 +34,7 @@ OH_AI_ContextHandle OH_AI_ContextCreate() { 2435 2436 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 2437 if (context != nullptr && *context != nullptr) { 2438- auto impl = static_cast<mindspore::ContextC *>(*context); 2439+ auto impl = static_cast<mindspore::Context *>(*context); 2440 delete impl; 2441 *context = nullptr; 2442 } 2443@@ -40,8 +45,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) 2444 MS_LOG(ERROR) << "param is nullptr."; 2445 return; 2446 } 2447- auto impl = static_cast<mindspore::ContextC *>(context); 2448- impl->thread_num = thread_num; 2449+ auto impl = static_cast<mindspore::Context *>(context); 2450+ impl->SetThreadNum(thread_num); 2451 } 2452 2453 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 2454@@ -49,8 +54,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 2455 MS_LOG(ERROR) << "param is nullptr."; 2456 return 0; 2457 } 2458- auto impl = static_cast<mindspore::ContextC *>(context); 2459- return impl->thread_num; 2460+ auto impl = static_cast<mindspore::Context *>(context); 2461+ return impl->GetThreadNum(); 2462 } 
2463 2464 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 2465@@ -58,8 +63,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 2466 MS_LOG(ERROR) << "param is nullptr."; 2467 return; 2468 } 2469- auto impl = static_cast<mindspore::ContextC *>(context); 2470- impl->affinity_mode = mode; 2471+ auto impl = static_cast<mindspore::Context *>(context); 2472+ impl->SetThreadAffinity(mode); 2473 return; 2474 } 2475 2476@@ -68,8 +73,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 2477 MS_LOG(ERROR) << "param is nullptr."; 2478 return 0; 2479 } 2480- auto impl = static_cast<mindspore::ContextC *>(context); 2481- return impl->affinity_mode; 2482+ auto impl = static_cast<mindspore::Context *>(context); 2483+ return impl->GetThreadAffinityMode(); 2484 } 2485 2486 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) { 2487@@ -78,8 +83,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i 2488 return; 2489 } 2490 const std::vector<int32_t> vec_core_list(core_list, core_list + core_num); 2491- auto impl = static_cast<mindspore::ContextC *>(context); 2492- impl->affinity_core_list = vec_core_list; 2493+ auto impl = static_cast<mindspore::Context *>(context); 2494+ impl->SetThreadAffinity(vec_core_list); 2495 return; 2496 } 2497 2498@@ -88,9 +93,18 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle 2499 MS_LOG(ERROR) << "param is nullptr."; 2500 return nullptr; 2501 } 2502- auto impl = static_cast<mindspore::ContextC *>(context); 2503- *core_num = impl->affinity_core_list.size(); 2504- return impl->affinity_core_list.data(); 2505+ auto impl = static_cast<mindspore::Context *>(context); 2506+ auto affinity_core_list = impl->GetThreadAffinityCoreList(); 2507+ *core_num = affinity_core_list.size(); 2508+ int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t))); 2509+ if (core_list == nullptr) { 2510+ MS_LOG(ERROR) << "malloc core_list is null."; 2511+ return nullptr; 2512+ } 2513+ for (size_t i = 0; i < affinity_core_list.size(); i++) { 2514+ core_list[i] = affinity_core_list[i]; 2515+ } 2516+ return core_list; 2517 } 2518 2519 void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) { 2520@@ -98,8 +112,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle 2521 MS_LOG(ERROR) << "param is nullptr."; 2522 return; 2523 } 2524- auto impl = static_cast<mindspore::ContextC *>(context); 2525- impl->enable_parallel = is_parallel; 2526+ auto impl = static_cast<mindspore::Context *>(context); 2527+ impl->SetEnableParallel(is_parallel); 2528 } 2529 2530 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 2531@@ -107,8 +121,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 2532 MS_LOG(ERROR) << "param is nullptr."; 2533 return false; 2534 } 2535- auto impl = static_cast<mindspore::ContextC *>(context); 2536- return impl->enable_parallel; 2537+ auto impl = static_cast<mindspore::Context *>(context); 2538+ return impl->GetEnableParallel(); 2539 } 2540 2541 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) { 2542@@ -116,25 +130,36 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan 2543 MS_LOG(ERROR) << "param is nullptr."; 2544 return; 2545 } 2546- auto impl = static_cast<mindspore::ContextC 
*>(context); 2547- std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info)); 2548- impl->device_info_list.push_back(device); 2549+ auto impl = static_cast<mindspore::Context *>(context); 2550+ std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info)); 2551+ impl->MutableDeviceInfo().push_back(device); 2552 } 2553 2554 // ================ DeviceInfo ================ 2555 OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) { 2556- mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC; 2557+ mindspore::DeviceInfoContext *impl; 2558+ if (OH_AI_DEVICETYPE_CPU == device_type) { 2559+ impl = new (std::nothrow) mindspore::CPUDeviceInfo(); 2560+ } else if (OH_AI_DEVICETYPE_GPU == device_type) { 2561+ impl = new (std::nothrow) mindspore::GPUDeviceInfo(); 2562+ } else if (OH_AI_DEVICETYPE_KIRIN_NPU == device_type) { 2563+ impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo(); 2564+ } else if (OH_AI_DEVICETYPE_NNRT == device_type) { 2565+ impl = new (std::nothrow) mindspore::NNRTDeviceInfo(); 2566+ } else { 2567+ MS_LOG(ERROR) << "device_type is invalid."; 2568+ impl = nullptr; 2569+ } 2570 if (impl == nullptr) { 2571 MS_LOG(ERROR) << "memory allocation failed."; 2572 return nullptr; 2573 } 2574- impl->device_type = device_type; 2575 return static_cast<OH_AI_DeviceInfoHandle>(impl); 2576 } 2577 2578 void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) { 2579 if (device_info != nullptr && *device_info != nullptr) { 2580- auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info); 2581+ auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info); 2582 delete impl; 2583 *device_info = nullptr; 2584 } 2585@@ -145,8 +170,8 @@ void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char 2586 MS_LOG(ERROR) << "param is nullptr."; 2587 return; 2588 } 2589- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2590- impl->provider = provider; 2591+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2592+ impl->SetProvider(provider); 2593 } 2594 2595 const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) { 2596@@ -154,8 +179,14 @@ const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info 2597 MS_LOG(ERROR) << "param is nullptr."; 2598 return nullptr; 2599 } 2600- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2601- return impl->provider.c_str(); 2602+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2603+ char *provider = static_cast<char *>(malloc(impl->GetProvider().size() + 1)); 2604+ if (provider == nullptr) { 2605+ MS_LOG(ERROR) << "malloc provider is null."; 2606+ return nullptr; 2607+ } 2608+ strcpy(provider, impl->GetProvider().c_str()); 2609+ return provider; 2610 } 2611 2612 void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) { 2613@@ -163,8 +194,8 @@ void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const 2614 MS_LOG(ERROR) << "param is nullptr."; 2615 return; 2616 } 2617- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2618- impl->provider_device = device; 2619+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2620+ impl->SetProviderDevice(device); 2621 } 2622 2623 const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) { 2624@@ -172,8 +203,14 @@ const char 
*OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle devic 2625 MS_LOG(ERROR) << "param is nullptr."; 2626 return nullptr; 2627 } 2628- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2629- return impl->provider_device.c_str(); 2630+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2631+ char *provider_device = static_cast<char *>(malloc(impl->GetProviderDevice().size() + 1)); 2632+ if (provider_device == nullptr) { 2633+ MS_LOG(ERROR) << "malloc provider_device is null."; 2634+ return nullptr; 2635+ } 2636+ strcpy(provider_device, impl->GetProviderDevice().c_str()); 2637+ return provider_device; 2638 } 2639 2640 OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) { 2641@@ -181,8 +218,8 @@ OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle devi 2642 MS_LOG(ERROR) << "param is nullptr."; 2643 return OH_AI_DEVICETYPE_INVALID; 2644 } 2645- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2646- return impl->device_type; 2647+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2648+ return static_cast<OH_AI_DeviceType>(impl->GetDeviceType()); 2649 } 2650 2651 void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) { 2652@@ -190,9 +227,17 @@ void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_f 2653 MS_LOG(ERROR) << "param is nullptr."; 2654 return; 2655 } 2656- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2657- if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) { 2658- impl->enable_fp16 = is_fp16; 2659+ 2660+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2661+ if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2662+ auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info); 2663+ impl->SetEnableFP16(is_fp16); 2664+ } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2665+ auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info); 2666+ impl->SetEnableFP16(is_fp16); 2667+ } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2668+ auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info); 2669+ impl->SetEnableFP16(is_fp16); 2670 } else { 2671 MS_LOG(ERROR) << "Unsupported Feature."; 2672 } 2673@@ -203,11 +248,19 @@ bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) { 2674 MS_LOG(ERROR) << "param is nullptr."; 2675 return false; 2676 } 2677- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2678- if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) { 2679- return impl->enable_fp16; 2680+ 2681+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2682+ if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2683+ auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info); 2684+ return impl->GetEnableFP16(); 2685+ } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2686+ auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info); 2687+ return impl->GetEnableFP16(); 2688+ } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2689+ auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info); 2690+ return 
impl->GetEnableFP16(); 2691 } else { 2692- MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type; 2693+ MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl_device->GetDeviceType(); 2694 return false; 2695 } 2696 } 2697@@ -217,9 +270,10 @@ void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int freque 2698 MS_LOG(ERROR) << "param is nullptr."; 2699 return; 2700 } 2701- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2702- if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) { 2703- impl->frequency = frequency; 2704+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2705+ if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) { 2706+ auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info); 2707+ impl->SetFrequency(frequency); 2708 } else { 2709 MS_LOG(ERROR) << "Unsupported Feature."; 2710 } 2711@@ -230,11 +284,231 @@ int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) { // 2712 MS_LOG(ERROR) << "param is nullptr."; 2713 return -1; 2714 } 2715- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2716- if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) { 2717- return impl->frequency; 2718+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2719+ if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) { 2720+ auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info); 2721+ return impl->GetFrequency(); 2722 } else { 2723 MS_LOG(ERROR) << "Unsupported Feature."; 2724 return -1; 2725 } 2726 } 2727+ 2728+NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num) { 2729+ if (num == nullptr) { 2730+ MS_LOG(ERROR) << "Input num is null"; 2731+ return nullptr; 2732+ } 2733+#ifdef SUPPORT_NNRT 2734+ *num = 0; 2735+ 2736+ const size_t *all_device_ids; 2737+ uint32_t device_count; 2738+ auto ret = OH_NNDevice_GetAllDevicesID(&all_device_ids, &device_count); 2739+ if ((ret != OH_NN_SUCCESS) || (device_count == 0)) { 2740+ MS_LOG(ERROR) << "NNRT get all device id failed, ret: " << ret; 2741+ return nullptr; 2742+ } 2743+ 2744+ NNRTDeviceDesc *desc = (NNRTDeviceDesc *)malloc(sizeof(NNRTDeviceDesc) * device_count); 2745+ if (desc == nullptr) { 2746+ MS_LOG(ERROR) << "NNRT allocate desc failed"; 2747+ return nullptr; 2748+ } 2749+ 2750+ for (uint32_t i = 0; i < device_count; i++) { 2751+ desc[i].device_id = all_device_ids[i]; 2752+ OH_NN_DeviceType type; 2753+ (void)OH_NNDevice_GetType(all_device_ids[i], &type); 2754+ desc[i].device_type = static_cast<OH_AI_NNRTDeviceType>(type); 2755+ 2756+ const char *name = nullptr; 2757+ (void)OH_NNDevice_GetName(all_device_ids[i], &name); 2758+ desc[i].device_name[127] = '\0'; 2759+ strncpy(desc[i].device_name, name, 127); 2760+ } 2761+ *num = device_count; 2762+ return desc; 2763+#else 2764+ return nullptr; 2765+#endif 2766+} 2767+ 2768+NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index) { 2769+ if (descs == nullptr) { 2770+ MS_LOG(ERROR) << "descs is null"; 2771+ return nullptr; 2772+ } 2773+ return descs + index; 2774+} 2775+ 2776+void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc) { 2777+ if (desc == nullptr) { 2778+ MS_LOG(WARNING) << "desc is null"; 2779+ return; 2780+ } 2781+ free(*desc); 2782+ *desc = nullptr; 2783+} 2784+ 2785+size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2786+ if (desc == nullptr) { 2787+ MS_LOG(ERROR) << "NNRT desc is 
null"; 2788+ return 0; 2789+ } 2790+ return desc->device_id; 2791+} 2792+ 2793+const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2794+ if (desc == nullptr) { 2795+ MS_LOG(ERROR) << "NNRT desc is null"; 2796+ return nullptr; 2797+ } 2798+ return desc->device_name; 2799+} 2800+ 2801+OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2802+ if (desc == nullptr) { 2803+ MS_LOG(ERROR) << "NNRT desc is null"; 2804+ return OH_AI_NNRTDeviceType::OH_AI_NNRTDEVICE_OTHERS; 2805+ } 2806+ return desc->device_type; 2807+} 2808+ 2809+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name) { 2810+ size_t num = 0; 2811+ NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num); 2812+ if (desc == nullptr) { 2813+ MS_LOG(ERROR) << "Get all device desc failed"; 2814+ return nullptr; 2815+ } 2816+ 2817+ OH_AI_DeviceInfoHandle handle = nullptr; 2818+ for (size_t i = 0; i < num; i++) { 2819+ if (strncmp(desc[i].device_name, name, NNRT_DEVICE_NAME_MAX - 1) == 0) { 2820+ handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 2821+ OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id); 2822+ break; 2823+ } 2824+ } 2825+ OH_AI_DestroyAllNNRTDeviceDescs(&desc); 2826+ return handle; 2827+} 2828+ 2829+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type) { 2830+ size_t num = 0; 2831+ NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num); 2832+ if (desc == nullptr) { 2833+ MS_LOG(ERROR) << "Get all device desc failed"; 2834+ return nullptr; 2835+ } 2836+ 2837+ OH_AI_DeviceInfoHandle handle = nullptr; 2838+ for (size_t i = 0; i < num; i++) { 2839+ if (desc[i].device_type == type) { 2840+ handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 2841+ OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id); 2842+ break; 2843+ } 2844+ } 2845+ OH_AI_DestroyAllNNRTDeviceDescs(&desc); 2846+ return handle; 2847+} 2848+ 2849+void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id) { 2850+ if (device_info == nullptr) { 2851+ MS_LOG(ERROR) << "device info is null"; 2852+ return; 2853+ } 2854+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2855+ MS_LOG(ERROR) << "Set device_id of non-NNRT device is not allowable, ignored"; 2856+ return; 2857+ } 2858+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2859+ impl->SetDeviceID(device_id); 2860+} 2861+ 2862+size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info) { 2863+ if (device_info == nullptr) { 2864+ MS_LOG(ERROR) << "device info is null"; 2865+ return 0; 2866+ } 2867+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2868+ MS_LOG(ERROR) << "Get device_id of non-NNRT device is not allowable, ignored"; 2869+ return 0; 2870+ } 2871+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2872+ return impl->GetDeviceID(); 2873+} 2874+ 2875+void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode) { 2876+ if (device_info == nullptr) { 2877+ MS_LOG(ERROR) << "device info is null"; 2878+ return; 2879+ } 2880+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2881+ MS_LOG(ERROR) << "Set performance_mode of non-NNRT device is not allowable, ignored"; 2882+ return; 2883+ } 2884+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2885+ impl->SetPerformanceMode(mode); 2886+} 2887+ 2888+OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const 
OH_AI_DeviceInfoHandle device_info) { 2889+ if (device_info == nullptr) { 2890+ MS_LOG(ERROR) << "device info is null"; 2891+ return OH_AI_PERFORMANCE_NONE; 2892+ } 2893+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2894+ MS_LOG(ERROR) << "Get performance_mode of non-NNRT device is not allowable, ignored"; 2895+ return OH_AI_PERFORMANCE_NONE; 2896+ } 2897+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2898+ return static_cast<OH_AI_PerformanceMode>(impl->GetPerformanceMode()); 2899+} 2900+ 2901+void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority) { 2902+ if (device_info == nullptr) { 2903+ MS_LOG(ERROR) << "device info is null"; 2904+ return; 2905+ } 2906+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2907+ MS_LOG(ERROR) << "Set priority of non-NNRT device is not allowable, ignored"; 2908+ return; 2909+ } 2910+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2911+ impl->SetPriority(priority); 2912+} 2913+ 2914+OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info) { 2915+ if (device_info == nullptr) { 2916+ MS_LOG(ERROR) << "device info is null"; 2917+ return OH_AI_PRIORITY_NONE; 2918+ } 2919+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2920+ MS_LOG(ERROR) << "Get priority of non-NNRT device is not allowable, ignored"; 2921+ return OH_AI_PRIORITY_NONE; 2922+ } 2923+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2924+ return static_cast<OH_AI_Priority>(impl->GetPriority()); 2925+} 2926+ 2927+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, 2928+ const char *name, const char*value, size_t value_size) { 2929+ if (device_info == nullptr) { 2930+ MS_LOG(ERROR) << "device info is null"; 2931+ return OH_AI_STATUS_LITE_NULLPTR; 2932+ } 2933+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2934+ MS_LOG(ERROR) << "Add extension to non-NNRT device is not allowable, ignored"; 2935+ return OH_AI_STATUS_LITE_ERROR; 2936+ } 2937+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2938+ mindspore::Extension extension; 2939+ extension.name = std::string(name); 2940+ extension.value = std::vector<uint8_t>(value, value + value_size); 2941+ std::vector<mindspore::Extension> extension_list = impl->GetExtensions(); 2942+ extension_list.push_back(extension); 2943+ impl->SetExtensions(extension_list); 2944+ return OH_AI_STATUS_SUCCESS; 2945+} 2946\ No newline at end of file 2947diff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h 2948index 076f4d1f..dc88b8a4 100644 2949--- a/mindspore/lite/src/litert/c_api/context_c.h 2950+++ b/mindspore/lite/src/litert/c_api/context_c.h 2951@@ -21,27 +21,4 @@ 2952 #include <memory> 2953 #include "include/c_api/types_c.h" 2954 2955-namespace mindspore { 2956-class Allocator; 2957-class Delegate; 2958- 2959-typedef struct DeviceInfoC { 2960- OH_AI_DeviceType device_type; 2961- bool enable_fp16 = false; 2962- int frequency = 3; 2963- std::string provider; 2964- std::string provider_device; 2965- std::shared_ptr<Allocator> allocator = nullptr; 2966-} DeviceInfoC; 2967- 2968-typedef struct ContextC { 2969- std::vector<std::shared_ptr<DeviceInfoC>> device_info_list; 2970- int32_t thread_num = 2; 2971- bool enable_parallel = false; 2972- std::vector<int32_t> affinity_core_list; 2973- int affinity_mode = 0; 2974- int 
delegate_mode = 0; 2975- std::shared_ptr<Delegate> delegate = nullptr; 2976-} ContextC; 2977-} // namespace mindspore 2978 #endif // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_ 2979diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc 2980index 802df6b1..9da52d76 100644 2981--- a/mindspore/lite/src/litert/c_api/model_c.cc 2982+++ b/mindspore/lite/src/litert/c_api/model_c.cc 2983@@ -17,321 +17,135 @@ 2984 #include <vector> 2985 #include <cstdint> 2986 #include "include/api/context.h" 2987+#include <include/api/serialization.h> 2988 #include "include/api/types.h" 2989 #include "src/litert/cxx_api/tensor/tensor_impl.h" 2990 #include "src/litert/cxx_api/converters.h" 2991-#include "src/litert/lite_session.h" 2992-#include "src/litert/cpu_info.h" 2993+#include "src/litert//cxx_api/model/model_impl.h" 2994 2995 namespace mindspore { 2996 class ModelC { 2997- public: 2998- ModelC() : session_(nullptr), context_(nullptr) {} 2999+public: 3000+ ModelC() : model_(nullptr) {} 3001 ~ModelC() { 3002- for (auto &impl : tensor_map_) { 3003- delete impl.second; 3004+ for (auto in : inputs_) { 3005+ delete in; 3006+ } 3007+ for (auto out : outputs_) { 3008+ delete out; 3009+ } 3010+ for (auto out : outputs_train_) { 3011+ delete out; 3012 } 3013 } 3014 3015- Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context); 3016- Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context); 3017- Status Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes); 3018- 3019- Status Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, size_t *output_num, 3020- const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after); 3021- 3022- LiteTensorImpl **GetInputs(size_t *input_num); 3023- LiteTensorImpl **GetOutputs(size_t *output_num); 3024+ MSTensor **GetInputs(size_t *input_num); 3025+ MSTensor **GetOutputs(size_t *output_num); 3026+ mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &oh_callback); 3027+ std::shared_ptr<Model> model_; 3028+ std::shared_ptr<Context> context_; 3029 3030- private: 3031- Status RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after); 3032- void ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors); 3033- LiteTensorImpl *TensorToTensorImpl(mindspore::lite::Tensor *tensor); 3034- 3035- private: 3036- std::shared_ptr<lite::LiteSession> session_ = nullptr; 3037- std::shared_ptr<const ContextC> context_ = nullptr; 3038- std::map<mindspore::lite::Tensor *, LiteTensorImpl *> tensor_map_; 3039- std::vector<LiteTensorImpl *> inputs_; 3040- std::vector<LiteTensorImpl *> outputs_; 3041- bool is_already_built = false; 3042+private: 3043+ MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors); 3044+ std::vector<MSTensor *> inputs_; 3045+ std::vector<MSTensor *> outputs_; 3046+ std::vector<MSTensor *> outputs_train_; 3047 }; 3048 3049-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) { 3050- if (is_already_built) { 3051- MS_LOG(ERROR) << "The model is already built."; 3052- return kLiteModelRebuild; 3053- } 3054- if (!PlatformInstructionSetSupportCheck()) { 3055- MS_LOG(ERROR) << "The platform exist don't support's instruction."; 3056- return kLiteNotSupport; 3057- } 3058- if(context_.get() != model_context){ 3059- 
context_.reset(model_context); 3060- } 3061- session_ = std::make_shared<lite::LiteSession>(); 3062- if (session_ == nullptr) { 3063- MS_LOG(ERROR) << "create session failed"; 3064- return kLiteNullptr; 3065- } 3066- auto ret = session_->Init(ContextUtils::Convert(model_context)); 3067- if (ret != mindspore::lite::RET_OK) { 3068- MS_LOG(ERROR) << "init session failed"; 3069- return static_cast<StatusCode>(ret); 3070- } 3071- ret = session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size); 3072- if (ret != RET_OK) { 3073- MS_LOG(ERROR) << "Load and compile failed"; 3074- return static_cast<StatusCode>(ret); 3075- } 3076- is_already_built = true; 3077- return static_cast<StatusCode>(kSuccess); 3078-} 3079- 3080-Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) { 3081- if (is_already_built) { 3082- MS_LOG(ERROR) << "The model is already built."; 3083- return kLiteModelRebuild; 3084- } 3085- if (!PlatformInstructionSetSupportCheck()) { 3086- MS_LOG(ERROR) << "The platform exist don't support's instruction."; 3087- return kLiteNotSupport; 3088- } 3089- if(context_.get() != model_context){ 3090- context_.reset(model_context); 3091- } 3092- session_ = std::make_shared<lite::LiteSession>(); 3093- if (session_ == nullptr) { 3094- MS_LOG(ERROR) << "create session failed"; 3095- return kLiteNullptr; 3096- } 3097- auto ret = session_->Init(ContextUtils::Convert(model_context)); 3098- if (ret != mindspore::lite::RET_OK) { 3099- MS_LOG(ERROR) << "init session failed"; 3100- return static_cast<StatusCode>(ret); 3101+MSTensor **ModelC::GetInputs(size_t *input_num) { 3102+ if (model_ == nullptr) { 3103+ MS_LOG(ERROR) << "model_ is nullptr."; 3104+ return nullptr; 3105 } 3106- ret = session_->LoadModelAndCompileByPath(model_path, model_type); 3107- if (ret != RET_OK) { 3108- MS_LOG(ERROR) << "Load and compile failed"; 3109- return static_cast<StatusCode>(ret); 3110+ if (!inputs_.empty()) { 3111+ *input_num = inputs_.size(); 3112+ return inputs_.data(); 3113 } 3114- is_already_built = true; 3115- return static_cast<StatusCode>(kSuccess); 3116-} 3117 3118-Status ModelC::Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) { 3119- std::vector<lite::Tensor *> inner_input; 3120- size_t input_num = inputs.size(); 3121- for (size_t i = 0; i < input_num; i++) { 3122- auto input = inputs[i]; 3123- if (input == nullptr || input->lite_tensor() == nullptr) { 3124- MS_LOG(ERROR) << "Input tensor is null."; 3125- return kLiteInputTensorError; 3126+ auto inputs = model_->GetInputs(); 3127+ *input_num = inputs.size(); 3128+ inputs_.resize(inputs.size(), nullptr); 3129+ for (size_t i = 0; i < inputs.size(); i++) { 3130+ inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl()); 3131+ if (inputs_[i] == nullptr) { 3132+ inputs_.clear(); 3133+ return nullptr; 3134 } 3135- inner_input.push_back(input->lite_tensor()); 3136 } 3137- size_t shape_num = shapes.size(); 3138- std::vector<std::vector<int32_t>> inner_shapes(shape_num); 3139- for (size_t i = 0; i < shape_num; i++) { 3140- std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(inner_shapes[i]), 3141- [](int64_t value) { return static_cast<int32_t>(value); }); 3142- } 3143- if (session_ == nullptr) { 3144- MS_LOG(ERROR) << "Session implement is null."; 3145- return kLiteNullptr; 3146- } 3147- auto ret = session_->Resize(inner_input, inner_shapes); 3148- return static_cast<StatusCode>(ret); 3149+ return inputs_.data(); 
3150 } 3151 3152-void ModelC::ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors) { 3153- for (size_t j = 0; j < old_data.size(); j++) { 3154- tensors.at(j)->set_data(old_data.at(j)); 3155+MSTensor **ModelC::GetOutputs(size_t *output_num) { 3156+ if (model_->GetTrainMode() == true) { 3157+ return GetOutputsTensor(output_num, &outputs_train_); 3158+ } else { 3159+ return GetOutputsTensor(output_num, &outputs_); 3160 } 3161 } 3162 3163-Status ModelC::Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, 3164- size_t *output_num, const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) { 3165- if (outputs == nullptr || session_ == nullptr) { 3166- MS_LOG(ERROR) << "param is nullptr."; 3167- return kLiteError; 3168+MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) { 3169+ if (model_ == nullptr) { 3170+ MS_LOG(ERROR) << "model_ is nullptr."; 3171+ return nullptr; 3172 } 3173- auto model_inputs = session_->GetInputs(); 3174- if (model_inputs.size() != input_num) { 3175- MS_LOG(ERROR) << "Wrong input size."; 3176- return kLiteError; 3177+ if (!vec_tensors->empty()) { 3178+ *output_num = vec_tensors->size(); 3179+ return vec_tensors->data(); 3180 } 3181- std::vector<void *> old_data; 3182- for (size_t i = 0; i < input_num; i++) { 3183- auto real_input = model_inputs[i]; 3184- auto user_input = static_cast<LiteTensorImpl *>(inputs[i]); 3185- if (user_input->DataType() != static_cast<DataType>(real_input->data_type())) { 3186- ResetTensorData(old_data, model_inputs); 3187- MS_LOG(ERROR) << "DataType does not match, input:" << user_input->Name() 3188- << ", real:" << real_input->tensor_name(); 3189- return kLiteInputTensorError; 3190- } 3191- if (user_input->Data() == nullptr) { 3192- ResetTensorData(old_data, model_inputs); 3193- MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has no data."; 3194- return kLiteInputTensorError; 3195- } 3196 3197- // GPU tensor can't manipulate CPU memory which the user provides. 3198- // When model input is GPU tensor and user input is NOT GPU data, 3199- // just free model input's data for late GPU Tensor filling. 3200- if (IS_OPENCL_ALLOCATOR(real_input->allocator()) && (!IS_OPENCL_ALLOCATOR(user_input->GetAllocator()))) { 3201- real_input->FreeData(); 3202- } 3203- old_data.push_back(real_input->data()); // Save original data in model tensors. 3204- 3205- if (real_input->data_type() == kObjectTypeString) { 3206- std::vector<int32_t> shape; 3207- std::transform(user_input->Shape().begin(), user_input->Shape().end(), std::back_inserter(shape), 3208- [](int64_t value) { return static_cast<int32_t>(value); }); 3209- real_input->set_shape(shape); 3210- real_input->set_data(user_input->MutableData()); 3211- } else { 3212- if (user_input->MutableData() != real_input->data()) { 3213- if (real_input->Size() != user_input->DataSize()) { 3214- ResetTensorData(old_data, model_inputs); 3215- MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has wrong data size."; 3216- return kLiteInputTensorError; 3217- } 3218- if (!IS_OPENCL_ALLOCATOR(real_input->allocator())) { 3219- real_input->set_data(user_input->MutableData()); 3220- } else { 3221- // Use outside CPU data to fill GPU Tensor. 
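// The per-tensor data-swapping logic deleted in this hunk is replaced by the rewritten
// OH_AI_ModelPredict further down in this file, which converts the C handles to mindspore::MSTensor
// and delegates to Model::Predict. A minimal caller-side sketch of that new path, assuming the
// OH_AI_ContextCreate / OH_AI_ContextAddDeviceInfo declarations from context_c.h and using
// "model.ms" as a placeholder file name:
//
//   OH_AI_ContextHandle ctx = OH_AI_ContextCreate();
//   OH_AI_ContextAddDeviceInfo(ctx, OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU));
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   if (OH_AI_ModelBuildFromFile(model, "model.ms", OH_AI_MODELTYPE_MINDIR, ctx) != OH_AI_STATUS_SUCCESS) {
//     return;  // build failed
//   }
//   OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
//   for (size_t i = 0; i < inputs.handle_num; i++) {
//     void *buf = OH_AI_TensorGetMutableData(inputs.handle_list[i]);  // fill with user data
//     (void)buf;
//   }
//   OH_AI_TensorHandleArray outputs = {0, NULL};
//   if (OH_AI_ModelPredict(model, inputs, &outputs, NULL, NULL) == OH_AI_STATUS_SUCCESS) {
//     const float *out0 = (const float *)OH_AI_TensorGetData(outputs.handle_list[0]);
//     (void)out0;
//   }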
3222- auto dst_data = real_input->MutableData(); 3223- auto src_data = user_input->MutableData(); 3224- (void)memcpy(dst_data, src_data, real_input->Size()); 3225- } 3226- } 3227- } 3228- } 3229- auto ret = RunGraph(before, after); 3230- ResetTensorData(old_data, model_inputs); 3231- if (ret != kSuccess) { 3232- MS_LOG(ERROR) << "Run graph failed."; 3233- return ret; 3234- } 3235- 3236- *outputs = reinterpret_cast<OH_AI_TensorHandle *>(GetOutputs(output_num)); 3237- return kSuccess; 3238-} 3239- 3240-Status ModelC::RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) { 3241- KernelCallBack before_call_back = nullptr; 3242- KernelCallBack after_call_back = nullptr; 3243- if (before != nullptr) { 3244- before_call_back = [&](const std::vector<mindspore::lite::Tensor *> &before_inputs, 3245- const std::vector<mindspore::lite::Tensor *> &before_outputs, 3246- const MSCallBackParam &call_param) { 3247- std::vector<LiteTensorImpl> inputs_impl; 3248- std::vector<LiteTensorImpl> outputs_impl; 3249- std::vector<OH_AI_TensorHandle> op_inputs; 3250- std::vector<OH_AI_TensorHandle> op_outputs; 3251- size_t op_input_num = before_inputs.size(); 3252- for (size_t i = 0; i < op_input_num; i++) { 3253- inputs_impl.emplace_back(before_inputs[i]); 3254- op_inputs.push_back(&(inputs_impl.back())); 3255- } 3256- size_t op_output_num = before_outputs.size(); 3257- for (size_t i = 0; i < op_output_num; i++) { 3258- outputs_impl.emplace_back(before_outputs[i]); 3259- op_outputs.push_back(&(outputs_impl.back())); 3260- } 3261- const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()), 3262- const_cast<char *>(call_param.node_type.c_str())}; 3263- OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()}; 3264- OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()}; 3265- return before(inputs, outputs, op_info); 3266- }; 3267- } 3268- if (after != nullptr) { 3269- after_call_back = [&](const std::vector<mindspore::lite::Tensor *> &after_inputs, 3270- const std::vector<mindspore::lite::Tensor *> &after_outputs, 3271- const MSCallBackParam &call_param) { 3272- std::vector<LiteTensorImpl> inputs_impl; 3273- std::vector<LiteTensorImpl> outputs_impl; 3274- std::vector<OH_AI_TensorHandle> op_inputs; 3275- std::vector<OH_AI_TensorHandle> op_outputs; 3276- size_t op_input_num = after_inputs.size(); 3277- for (size_t i = 0; i < op_input_num; i++) { 3278- inputs_impl.emplace_back(after_inputs[i]); 3279- op_inputs.push_back(&(inputs_impl.back())); 3280- } 3281- size_t op_output_num = after_outputs.size(); 3282- for (size_t i = 0; i < op_output_num; i++) { 3283- outputs_impl.emplace_back(after_outputs[i]); 3284- op_outputs.push_back(&(outputs_impl.back())); 3285- } 3286- const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()), 3287- const_cast<char *>(call_param.node_type.c_str())}; 3288- OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()}; 3289- OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()}; 3290- return after(inputs, outputs, op_info); 3291- }; 3292- } 3293- auto ret = session_->RunGraph(before_call_back, after_call_back); 3294- return static_cast<StatusCode>(ret); 3295-} 3296- 3297-LiteTensorImpl *ModelC::TensorToTensorImpl(mindspore::lite::Tensor *tensor) { 3298- LiteTensorImpl *impl = nullptr; 3299- auto iter = tensor_map_.find(tensor); 3300- if (iter != tensor_map_.end()) { 3301- impl = iter->second; 3302- } else { 3303- impl = new (std::nothrow) LiteTensorImpl(tensor); 3304- if 
(impl == nullptr || impl->lite_tensor() == nullptr) {
-      MS_LOG(ERROR) << "Create tensor failed.";
+  auto outputs = model_->GetOutputs();
+  *output_num = outputs.size();
+  vec_tensors->resize(outputs.size(), nullptr);
+  for (size_t i = 0; i < outputs.size(); i++) {
+    (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
+    if ((*vec_tensors)[i] == nullptr) {
+      vec_tensors->clear();
       return nullptr;
     }
-    tensor_map_[tensor] = impl;
   }
-  return impl;
+  return vec_tensors->data();
 }
 
-LiteTensorImpl **ModelC::GetInputs(size_t *input_num) {
-  if (session_ == nullptr || input_num == nullptr) {
-    MS_LOG(ERROR) << "Session is null.";
-    return nullptr;
-  }
-  auto inputs = session_->GetInputs();
-  *input_num = inputs.size();
-  if (inputs_.capacity() < *input_num) {
-    inputs_.reserve(*input_num);
-  }
-  inputs_.clear();
-  std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_),
-                 [&](lite::Tensor *input) { return TensorToTensorImpl(input); });
-  return inputs_.data();
-}
+mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &oh_callback) {
+  mindspore::MSKernelCallBack call_back = nullptr;
+  if (oh_callback != nullptr) {
+    call_back = [&](const std::vector<mindspore::MSTensor> &inputs,
+                    const std::vector<mindspore::MSTensor> &outputs,
+                    const mindspore::MSCallBackParam &opInfo) {
+      std::vector<OH_AI_TensorHandle> vec_inputs;
+      std::vector<OH_AI_TensorHandle> vec_outputs;
+      OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()),
+                                       const_cast<char *>(opInfo.node_type.c_str())};
+      size_t inputs_handle_num = inputs.size();
+      for (size_t i = 0; i < inputs_handle_num; i++) {
+        vec_inputs.push_back(
+          static_cast<OH_AI_TensorHandle>(&(const_cast<std::vector<mindspore::MSTensor> &>(inputs)[i])));
+      }
+      size_t outputs_handle_num = outputs.size();
+      for (size_t i = 0; i < outputs_handle_num; i++) {
+        vec_outputs.push_back(
+          static_cast<OH_AI_TensorHandle>(&(const_cast<std::vector<mindspore::MSTensor> &>(outputs)[i])));
+      }
 
-LiteTensorImpl **ModelC::GetOutputs(size_t *output_num) {
-  if (session_ == nullptr || output_num == nullptr) {
-    MS_LOG(ERROR) << "Session is null.";
-    return nullptr;
-  }
-  auto outputs = session_->GetOutputs();
-  *output_num = outputs.size();
-  if (outputs_.capacity() < *output_num) {
-    outputs_.reserve(*output_num);
+      OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
+      OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
+      return oh_callback(handle_inputs, handle_outputs, call_back);
+    };
   }
-  outputs_.clear();
-  std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_),
-                 [&](std::unordered_map<std::string, mindspore::lite::Tensor *>::value_type iter) {
-                   return TensorToTensorImpl(iter.second);
-                 });
-  return outputs_.data();
+  return call_back;
 }
 } // namespace mindspore
 
 OH_AI_ModelHandle OH_AI_ModelCreate() {
   auto impl = new (std::nothrow) mindspore::ModelC();
   if (impl == nullptr) {
-    MS_LOG(ERROR) << "Model implement is null.";
+    MS_LOG(ERROR) << "Model implement is nullptr.";
+    return nullptr;
+  }
+  impl->model_ = std::make_shared<mindspore::Model>();
+  if (impl->model_ == nullptr) {
+    
MS_LOG(ERROR) << "model_ is nullptr."; 3391+ delete impl; 3392 return nullptr; 3393 } 3394 return static_cast<OH_AI_ModelHandle>(impl); 3395@@ -358,55 +172,59 @@ size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) { 3396 OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 3397 OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) { 3398 if (model == nullptr || model_data == nullptr || model_context == nullptr) { 3399- MS_LOG(ERROR) << "param is nullptr."; 3400+ MS_LOG(ERROR) << "model/model_data/model_context is nullptr."; 3401 return OH_AI_STATUS_LITE_NULLPTR; 3402 } 3403 if (model_type == OH_AI_MODELTYPE_INVALID) { 3404- MS_LOG(ERROR) << "param is invalid."; 3405+ MS_LOG(ERROR) << "model_type is invalid."; 3406 return OH_AI_STATUS_LITE_PARAM_INVALID; 3407 } 3408- mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 3409+ mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 3410 auto impl = static_cast<mindspore::ModelC *>(model); 3411- auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context); 3412+ if (impl->context_.get() != context) { 3413+ impl->context_.reset(context); 3414+ } 3415+ auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_); 3416 return static_cast<OH_AI_Status>(ret.StatusCode()); 3417 } 3418 3419 OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type, 3420 const OH_AI_ContextHandle model_context) { 3421 if (model == nullptr || model_path == nullptr || model_context == nullptr) { 3422- MS_LOG(ERROR) << "param is nullptr."; 3423+ MS_LOG(ERROR) << "model/model_path/model_context is nullptr."; 3424 return OH_AI_STATUS_LITE_NULLPTR; 3425 } 3426 if (model_type == OH_AI_MODELTYPE_INVALID) { 3427- MS_LOG(ERROR) << "param is invalid."; 3428+ MS_LOG(ERROR) << "model_type is invalid."; 3429 return OH_AI_STATUS_LITE_PARAM_INVALID; 3430 } 3431- mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 3432+ mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 3433 auto impl = static_cast<mindspore::ModelC *>(model); 3434- auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context); 3435+ if (impl->context_.get() != context) { 3436+ impl->context_.reset(context); 3437+ } 3438+ auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_); 3439 return static_cast<OH_AI_Status>(ret.StatusCode()); 3440 } 3441 3442 OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, 3443 OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) { 3444 if (model == nullptr || shape_infos == nullptr) { 3445- MS_LOG(ERROR) << "param is nullptr."; 3446+ MS_LOG(ERROR) << "model/shape_infos is nullptr."; 3447 return OH_AI_STATUS_LITE_NULLPTR; 3448 } 3449- std::vector<mindspore::LiteTensorImpl *> vec_inputs; 3450- std::transform(inputs.handle_list, inputs.handle_list + inputs.handle_num, std::back_inserter(vec_inputs), 3451- [](OH_AI_TensorHandle value) { return static_cast<mindspore::LiteTensorImpl *>(value); }); 3452+ std::vector<mindspore::MSTensor> vec_inputs; 3453+ for (size_t i = 0; i < inputs.handle_num; ++i) { 3454+ vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i])); 3455+ } 3456+ 3457 std::vector<std::vector<int64_t>> vec_dims; 
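// For reference, a caller-side sketch of the resize path handled here, assuming the OH_AI_ShapeInfo
// struct from types_c.h carries a fixed-size shape array plus shape_num; the 2x224x224x3 NHWC shape
// is only an illustrative value:
//
//   OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
//   OH_AI_ShapeInfo shape_info;
//   shape_info.shape_num = 4;
//   shape_info.shape[0] = 2;    // new batch size
//   shape_info.shape[1] = 224;
//   shape_info.shape[2] = 224;
//   shape_info.shape[3] = 3;
//   if (OH_AI_ModelResize(model, inputs, &shape_info, 1) != OH_AI_STATUS_SUCCESS) {
//     // each dimension is forwarded to Model::Resize as int64_t; handle the error here
//   }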
3458 for (size_t i = 0; i < shape_info_num; i++) { 3459 std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num); 3460- if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) { 3461- MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]"; 3462- return OH_AI_STATUS_LITE_PARAM_INVALID; 3463- } 3464 vec_dims.push_back(shape); 3465 } 3466 auto impl = static_cast<mindspore::ModelC *>(model); 3467- auto ret = impl->Resize(vec_inputs, vec_dims); 3468+ auto ret = impl->model_->Resize(vec_inputs, vec_dims); 3469 return static_cast<OH_AI_Status>(ret.StatusCode()); 3470 } 3471 3472@@ -414,15 +232,25 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl 3473 OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before, 3474 const OH_AI_KernelCallBack after) { 3475 if (model == nullptr) { 3476- MS_LOG(ERROR) << "param is nullptr."; 3477+ MS_LOG(ERROR) << "model is nullptr."; 3478 return OH_AI_STATUS_LITE_NULLPTR; 3479 } 3480+ std::vector<mindspore::MSTensor> ms_tensor_inputs; 3481+ for (size_t i = 0; i < inputs.handle_num; i++) { 3482+ auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]); 3483+ ms_tensor_inputs.push_back(*user_input); 3484+ } 3485+ 3486 auto impl = static_cast<mindspore::ModelC *>(model); 3487- auto ret = impl->Predict(inputs.handle_list, inputs.handle_num, &(outputs->handle_list), &(outputs->handle_num), 3488- before, after); 3489+ mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before); 3490+ mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after); 3491+ 3492+ std::vector<mindspore::MSTensor> ms_tensor_outputs; 3493+ auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back); 3494 if (!ret.IsOk()) { 3495 MS_LOG(ERROR) << "Predict fail, ret :" << ret; 3496 } 3497+ outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&outputs->handle_num)); 3498 return static_cast<OH_AI_Status>(ret.StatusCode()); 3499 } 3500 3501@@ -431,11 +259,6 @@ OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallB 3502 return OH_AI_STATUS_LITE_NOT_SUPPORT; 3503 } 3504 3505-OH_AI_Status OH_AI_ModelSetTrainMode(const OH_AI_ModelHandle model, bool train) { 3506- MS_LOG(ERROR) << "Unsupported Feature."; 3507- return OH_AI_STATUS_LITE_NOT_SUPPORT; 3508-} 3509- 3510 OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) { 3511 MS_LOG(ERROR) << "Unsupported Feature."; 3512 return OH_AI_STATUS_LITE_NOT_SUPPORT; 3513@@ -443,7 +266,7 @@ OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char * 3514 3515 OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 3516 if (model == nullptr) { 3517- MS_LOG(ERROR) << "param is nullptr."; 3518+ MS_LOG(ERROR) << "model is nullptr."; 3519 return {0, nullptr}; 3520 } 3521 auto impl = static_cast<mindspore::ModelC *>(model); 3522@@ -454,7 +277,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 3523 3524 OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) { 3525 if (model == nullptr) { 3526- MS_LOG(ERROR) << "param is nullptr."; 3527+ MS_LOG(ERROR) << "model is nullptr."; 3528 return {0, nullptr}; 3529 } 3530 auto impl = static_cast<mindspore::ModelC *>(model); 3531@@ -465,7 +288,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const 
OH_AI_ModelHandle model) {
 
 OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
   if (model == nullptr || tensor_name == nullptr) {
-    MS_LOG(ERROR) << "param is nullptr.";
+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
     return nullptr;
   }
   auto impl = static_cast<mindspore::ModelC *>(model);
@@ -482,7 +305,7 @@ OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model
 
 OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
   if (model == nullptr || tensor_name == nullptr) {
-    MS_LOG(ERROR) << "param is nullptr.";
+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
     return nullptr;
   }
   auto impl = static_cast<mindspore::ModelC *>(model);
@@ -496,3 +319,294 @@ OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle mode
   MS_LOG(ERROR) << "tensor is not exist.";
   return nullptr;
 }
+
+OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
+  auto impl = new (std::nothrow) mindspore::TrainCfg();
+  if (impl == nullptr) {
+    MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
+    return nullptr;
+  }
+  return static_cast<OH_AI_TrainCfgHandle>(impl);
+}
+
+void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
+  if (train_cfg != nullptr && *train_cfg != nullptr) {
+    auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
+    delete impl;
+    *train_cfg = nullptr;
+  }
+}
+
+char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
+  if (train_cfg == nullptr || num == nullptr) {
+    MS_LOG(ERROR) << "train_cfg/num is nullptr.";
+    return nullptr;
+  }
+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
+  auto loss_name = impl->GetLossName();
+  *num = loss_name.size();
+  char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
+  if (name == nullptr) {
+    MS_LOG(ERROR) << "Failed to malloc loss_name.";
+    return nullptr;
+  }
+  for (size_t i = 0; i < loss_name.size(); i++) {
+    name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
+    strcpy(name[i], loss_name[i].c_str());
+  }
+  return name;
+}
+
+void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
+  if (train_cfg == nullptr) {
+    MS_LOG(ERROR) << "train_cfg is nullptr.";
+    return;
+  }
+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
+  std::vector<std::string> vec_name;
+  for (size_t i = 0; i < num; i++) {
+    vec_name.push_back(loss_name[i]);
+  }
+  impl->SetLossName(vec_name);
+}
+
+OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
+  if (train_cfg == nullptr) {
+    MS_LOG(ERROR) << "train_cfg is nullptr.";
+    return OH_AI_KO0;
+  }
+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
+  return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
+}
+
+void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
+  if (train_cfg == nullptr) {
+    MS_LOG(ERROR) << "train_cfg is nullptr.";
+    return;
+  }
+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
+  impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
+}
+ 
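// A short end-to-end sketch of the TrainCfg helpers above combined with the train-build and
// train-step entry points defined below; "lenet_train.ms", the loss-node name, and the ctx handle
// are placeholders assumed to be set up by the caller:
//
//   OH_AI_TrainCfgHandle cfg = OH_AI_TrainCfgCreate();
//   const char *loss_names[] = {"loss_softmax_cross_entropy"};
//   OH_AI_TrainCfgSetLossName(cfg, loss_names, 1);
//   OH_AI_TrainCfgSetOptimizationLevel(cfg, OH_AI_KO0);
//
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   if (OH_AI_TrainModelBuildFromFile(model, "lenet_train.ms", OH_AI_MODELTYPE_MINDIR, ctx, cfg) ==
//       OH_AI_STATUS_SUCCESS) {
//     OH_AI_ModelSetTrainMode(model, true);
//     // fill the input tensors for one batch, then step the optimizer once:
//     OH_AI_RunStep(model, NULL, NULL);
//   }
//   OH_AI_TrainCfgDestroy(&cfg);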
3622+OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 3623+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 3624+ const OH_AI_TrainCfgHandle train_cfg) { 3625+ if (model == nullptr || model_data == nullptr || model_context == nullptr) { 3626+ MS_LOG(ERROR) << "model/model_data/model_context is nullptr."; 3627+ return OH_AI_STATUS_LITE_NULLPTR; 3628+ } 3629+ if (model_type == OH_AI_MODELTYPE_INVALID) { 3630+ MS_LOG(ERROR) << "model_type is invalid."; 3631+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3632+ } 3633+ auto impl = static_cast<mindspore::ModelC *>(model); 3634+ 3635+ mindspore::Graph graph; 3636+ auto status = mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph); 3637+ if (status != mindspore::kSuccess) { 3638+ MS_LOG(ERROR) << "load ms file failed."; 3639+ return OH_AI_STATUS_LITE_ERROR; 3640+ } 3641+ auto context = static_cast<mindspore::Context *>(model_context); 3642+ auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 3643+ if (impl->context_.get() != context) { 3644+ impl->context_.reset(context); 3645+ } 3646+ auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 3647+ std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 3648+ if (ret != mindspore::kSuccess) { 3649+ MS_LOG(ERROR) << "Load and compile failed"; 3650+ } 3651+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3652+} 3653+ 3654+OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, 3655+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 3656+ const OH_AI_TrainCfgHandle train_cfg) { 3657+ if (model == nullptr || model_path == nullptr || model_context == nullptr) { 3658+ MS_LOG(ERROR) << "model/model_path/model_context is nullptr."; 3659+ return OH_AI_STATUS_LITE_NULLPTR; 3660+ } 3661+ if (model_type == OH_AI_MODELTYPE_INVALID) { 3662+ MS_LOG(ERROR) << "model_type is invalid."; 3663+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3664+ } 3665+ auto impl = static_cast<mindspore::ModelC *>(model); 3666+ 3667+ mindspore::Graph graph; 3668+ auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph); 3669+ if (status != mindspore::kSuccess) { 3670+ MS_LOG(ERROR) << "load ms file failed. 
" << model_path; 3671+ return OH_AI_STATUS_LITE_ERROR; 3672+ } 3673+ auto context = static_cast<mindspore::Context *>(model_context); 3674+ auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 3675+ if (impl->context_.get() != context) { 3676+ impl->context_.reset(context); 3677+ } 3678+ auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 3679+ std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 3680+ if (ret != mindspore::kSuccess) { 3681+ MS_LOG(ERROR) << "Load and compile failed"; 3682+ } 3683+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3684+} 3685+ 3686+OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) { 3687+ if (model == nullptr) { 3688+ MS_LOG(ERROR) << "model is nullptr."; 3689+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3690+ } 3691+ auto impl = static_cast<mindspore::ModelC *>(model); 3692+ auto ret = impl->model_->SetLearningRate(learning_rate); 3693+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3694+} 3695+ 3696+float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) { 3697+ if (model == nullptr) { 3698+ MS_LOG(ERROR) << "model is nullptr."; 3699+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3700+ } 3701+ auto impl = static_cast<mindspore::ModelC *>(model); 3702+ return impl->model_->GetLearningRate(); 3703+} 3704+ 3705+OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) { 3706+ if (model == nullptr) { 3707+ MS_LOG(ERROR) << "model is nullptr."; 3708+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3709+ } 3710+ auto impl = static_cast<mindspore::ModelC *>(model); 3711+ auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after)); 3712+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3713+} 3714+ 3715+OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) { 3716+ if (model == nullptr) { 3717+ MS_LOG(ERROR) << "model is nullptr."; 3718+ return {0, nullptr}; 3719+ } 3720+ auto impl = static_cast<mindspore::ModelC *>(model); 3721+ auto features = impl->model_->GetFeatureMaps(); 3722+ size_t handle_num = features.size(); 3723+ 3724+ mindspore::MSTensor **handle_list = static_cast<mindspore::MSTensor **>(malloc( 3725+ handle_num * sizeof(mindspore::MSTensor *))); 3726+ if (handle_list == nullptr) { 3727+ MS_LOG(ERROR) << "Failed to malloc handle_list."; 3728+ return {0, nullptr}; 3729+ } 3730+ for (size_t i = 0; i < handle_num; i++) { 3731+ handle_list[i] = new mindspore::MSTensor(features[i].impl()); 3732+ } 3733+ return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)}; 3734+} 3735+ 3736+OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) { 3737+ if (model == nullptr) { 3738+ MS_LOG(ERROR) << "model is nullptr."; 3739+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3740+ } 3741+ auto impl = static_cast<mindspore::ModelC *>(model); 3742+ std::vector<mindspore::MSTensor> weights; 3743+ for (size_t i = 0; i < new_weights.handle_num; i++) { 3744+ weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i])); 3745+ } 3746+ auto ret = impl->model_->UpdateWeights(weights); 3747+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3748+} 3749+ 3750+bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) { 3751+ if (model == nullptr) { 3752+ MS_LOG(ERROR) << "model is nullptr."; 3753+ return false; 3754+ } 3755+ auto impl = static_cast<mindspore::ModelC *>(model); 3756+ return 
impl->model_->GetTrainMode(); 3757+} 3758+ 3759+OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) { 3760+ if (model == nullptr) { 3761+ MS_LOG(ERROR) << "model is nullptr."; 3762+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3763+ } 3764+ auto impl = static_cast<mindspore::ModelC *>(model); 3765+ auto ret = impl->model_->SetTrainMode(train); 3766+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3767+} 3768+ 3769+OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) { 3770+ if (model == nullptr) { 3771+ MS_LOG(ERROR) << "model is nullptr."; 3772+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3773+ } 3774+ auto impl = static_cast<mindspore::ModelC *>(model); 3775+ auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum); 3776+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3777+} 3778+ 3779+OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, 3780+ OH_AI_QuantizationType quantization_type, bool export_inference_only, 3781+ char **output_tensor_name, size_t num) { 3782+ if (model == nullptr) { 3783+ MS_LOG(ERROR) << "model is nullptr."; 3784+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3785+ } 3786+ auto impl = static_cast<mindspore::ModelC *>(model); 3787+ std::vector<std::string> tensor_name; 3788+ for (size_t i = 0; i < num; i++) { 3789+ tensor_name.push_back(output_tensor_name[i]); 3790+ } 3791+ auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), 3792+ model_file, 3793+ static_cast<mindspore::QuantizationType>(quantization_type), 3794+ export_inference_only, tensor_name); 3795+ if (!ret.IsOk()) { 3796+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3797+ } 3798+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3799+} 3800+ 3801+OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, 3802+ size_t *data_size, OH_AI_QuantizationType quantization_type, 3803+ bool export_inference_only, char **output_tensor_name, size_t num) { 3804+ if (model == nullptr) { 3805+ MS_LOG(ERROR) << "model is nullptr."; 3806+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3807+ } 3808+ auto impl = static_cast<mindspore::ModelC *>(model); 3809+ std::vector<std::string> tensor_name; 3810+ for (size_t i = 0; i < num; i++) { 3811+ tensor_name.push_back(output_tensor_name[i]); 3812+ } 3813+ mindspore::Buffer buffer; 3814+ auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), 3815+ &buffer, static_cast<mindspore::QuantizationType>(quantization_type), 3816+ export_inference_only, tensor_name); 3817+ auto data = static_cast<char *>(buffer.MutableData()); 3818+ *model_data = (char *) malloc(buffer.DataSize()); 3819+ *data_size = buffer.DataSize(); 3820+ memcpy(*model_data, data, buffer.DataSize()); 3821+ if (!ret.IsOk()) { 3822+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3823+ } 3824+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3825+} 3826+ 3827+OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file, 3828+ bool is_inference, bool enable_fp16, char **changeable_weights_name, size_t num) { 3829+ if (model == nullptr) { 3830+ MS_LOG(ERROR) << "model is nullptr."; 3831+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3832+ } 3833+ auto impl = static_cast<mindspore::ModelC *>(model); 3834+ 
std::vector<std::string> weights_name; 3835+ for (size_t i = 0; i < num; i++) { 3836+ weights_name.push_back(changeable_weights_name[i]); 3837+ } 3838+ auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16, weights_name); 3839+ if (!ret.IsOk()) { 3840+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3841+ } 3842+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3843+} 3844diff --git a/mindspore/lite/src/litert/c_api/tensor_c.cc b/mindspore/lite/src/litert/c_api/tensor_c.cc 3845index 7b5c4c2f..4b1e6aff 100644 3846--- a/mindspore/lite/src/litert/c_api/tensor_c.cc 3847+++ b/mindspore/lite/src/litert/c_api/tensor_c.cc 3848@@ -17,7 +17,6 @@ 3849 #include "include/api/status.h" 3850 #include "src/tensor.h" 3851 #include "src/litert/cxx_api/tensor/tensor_impl.h" 3852-#include "src/litert/inner_allocator.h" 3853 3854 OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num, 3855 const void *data, size_t data_len) { 3856@@ -31,18 +30,23 @@ OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, con 3857 } 3858 auto lite_tensor = 3859 mindspore::lite::Tensor::CreateTensor(name, static_cast<mindspore::TypeId>(type), vec_shape, data, data_len); 3860- auto impl = new (std::nothrow) mindspore::LiteTensorImpl(lite_tensor); 3861- if (impl == nullptr || impl->lite_tensor() == nullptr) { 3862+ auto lite_tensor_impl = std::make_shared<mindspore::LiteTensorImpl>(lite_tensor); 3863+ if (lite_tensor_impl == nullptr || lite_tensor_impl->lite_tensor() == nullptr) { 3864 MS_LOG(ERROR) << "Failed to allocate tensor impl."; 3865 return nullptr; 3866 } 3867- impl->set_from_session(false); 3868+ lite_tensor_impl->set_from_session(false); 3869+ auto impl = new (std::nothrow) mindspore::MSTensor(lite_tensor_impl); 3870+ if (impl == nullptr) { 3871+ MS_LOG(ERROR) << "Failed to allocate MSTensor."; 3872+ return nullptr; 3873+ } 3874 return impl; 3875 } 3876 3877 void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) { 3878 if (tensor != nullptr && *tensor != nullptr) { 3879- auto impl = static_cast<mindspore::LiteTensorImpl *>(*tensor); 3880+ auto impl = static_cast<mindspore::MSTensor *>(*tensor); 3881 delete impl; 3882 *tensor = nullptr; 3883 } 3884@@ -53,20 +57,14 @@ OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) { 3885 MS_LOG(ERROR) << "param is nullptr."; 3886 return nullptr; 3887 } 3888- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3889- auto lite_tensor = static_cast<mindspore::lite::Tensor *>(impl->lite_tensor()); 3890- auto clone = mindspore::lite::Tensor::CopyTensor(*lite_tensor, true, lite_tensor->allocator()); 3891- if (clone == nullptr) { 3892- MS_LOG(ERROR) << "Failed to allocate tensor."; 3893- return nullptr; 3894- } 3895- auto clone_impl = new (std::nothrow) mindspore::LiteTensorImpl(clone); 3896+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3897+ auto clone_impl = impl->Clone(); 3898 if (clone_impl == nullptr) { 3899- delete clone; 3900 MS_LOG(ERROR) << "Failed to allocate tensor impl."; 3901 return nullptr; 3902 } 3903- clone_impl->set_from_session(false); 3904+ std::static_pointer_cast<mindspore::LiteTensorImpl>(clone_impl->impl())->set_own_data(false); 3905+ clone_impl->SetTensorName(impl->Name() + "_duplicate"); 3906 return clone_impl; 3907 } 3908 3909@@ -75,8 +73,8 @@ void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) { 3910 
MS_LOG(ERROR) << "param is nullptr."; 3911 return; 3912 } 3913- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3914- impl->SetName(name); 3915+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3916+ impl->SetTensorName(name); 3917 } 3918 3919 const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 3920@@ -84,8 +82,8 @@ const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 3921 MS_LOG(ERROR) << "param is nullptr."; 3922 return nullptr; 3923 } 3924- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3925- return impl->Name().c_str(); 3926+ auto ms_tensor = static_cast<mindspore::MSTensor *>(tensor); 3927+ return std::static_pointer_cast<mindspore::LiteTensorImpl>(ms_tensor->impl())->Name().c_str(); 3928 } 3929 3930 void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 3931@@ -93,7 +91,7 @@ void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 3932 MS_LOG(ERROR) << "param is nullptr."; 3933 return; 3934 } 3935- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3936+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3937 impl->SetDataType(static_cast<mindspore::DataType>(type)); 3938 } 3939 3940@@ -102,7 +100,7 @@ OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) { 3941 MS_LOG(ERROR) << "param is nullptr."; 3942 return OH_AI_DATATYPE_UNKNOWN; 3943 } 3944- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3945+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3946 auto dtype = impl->DataType(); 3947 return static_cast<OH_AI_DataType>(dtype); 3948 } 3949@@ -112,7 +110,7 @@ void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_ 3950 MS_LOG(ERROR) << "param is nullptr."; 3951 return; 3952 } 3953- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3954+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3955 std::vector<int64_t> vec_shape(shape_num); 3956 for (size_t i = 0; i < shape_num; i++) { 3957 vec_shape[i] = shape[i]; 3958@@ -125,7 +123,7 @@ const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *sha 3959 MS_LOG(ERROR) << "param is nullptr."; 3960 return nullptr; 3961 } 3962- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3963+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3964 *shape_num = impl->Shape().size(); 3965 return impl->Shape().data(); 3966 } 3967@@ -135,7 +133,7 @@ void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) { 3968 MS_LOG(ERROR) << "param is nullptr."; 3969 return; 3970 } 3971- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3972+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3973 return impl->SetFormat(static_cast<mindspore::Format>(format)); 3974 } 3975 3976@@ -144,8 +142,8 @@ OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) { 3977 MS_LOG(ERROR) << "param is nullptr."; 3978 return OH_AI_FORMAT_NHWC; 3979 } 3980- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3981- return static_cast<OH_AI_Format>(impl->Format()); 3982+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3983+ return static_cast<OH_AI_Format>(impl->format()); 3984 } 3985 3986 void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 3987@@ -153,16 +151,34 @@ void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 3988 MS_LOG(ERROR) << "param is nullptr."; 3989 return; 3990 } 3991- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3992+ auto 
impl = static_cast<mindspore::MSTensor *>(tensor); 3993 return impl->SetData(data, true); 3994 } 3995 3996+OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size) { 3997+ if (tensor == nullptr) { 3998+ MS_LOG(ERROR) << "param is nullptr."; 3999+ return OH_AI_STATUS_LITE_NULLPTR; 4000+ } 4001+ 4002+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4003+ if ((impl->DataSize() > 0) && (data_size != impl->DataSize())) { 4004+ MS_LOG(ERROR) << "input data size does not match inner data size"; 4005+ return OH_AI_STATUS_LITE_PARAM_INVALID; 4006+ } 4007+ 4008+ // This is one tricky way to represent that the inner data is not owned by tensor itself. 4009+ impl->SetAllocator(nullptr); 4010+ impl->SetData(data, false); 4011+ return OH_AI_STATUS_SUCCESS; 4012+} 4013+ 4014 const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) { 4015 if (tensor == nullptr) { 4016 MS_LOG(ERROR) << "param is nullptr."; 4017 return nullptr; 4018 } 4019- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4020+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4021 return impl->Data().get(); 4022 } 4023 4024@@ -171,7 +187,7 @@ void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) { 4025 MS_LOG(ERROR) << "param is nullptr."; 4026 return nullptr; 4027 } 4028- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4029+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4030 return impl->MutableData(); 4031 } 4032 4033@@ -180,7 +196,7 @@ int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) { 4034 MS_LOG(ERROR) << "param is nullptr."; 4035 return 0; 4036 } 4037- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4038+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4039 return impl->ElementNum(); 4040 } 4041 4042@@ -189,6 +205,6 @@ size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) { 4043 MS_LOG(ERROR) << "param is nullptr."; 4044 return 0; 4045 } 4046- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4047+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4048 return impl->DataSize(); 4049 } 4050diff --git a/mindspore/lite/src/litert/c_api/type_c_private.h b/mindspore/lite/src/litert/c_api/type_c_private.h 4051new file mode 100644 4052index 00000000..2d3b3883 4053--- /dev/null 4054+++ b/mindspore/lite/src/litert/c_api/type_c_private.h 4055@@ -0,0 +1,40 @@ 4056+/** 4057+ * Copyright 2023 Huawei Technologies Co., Ltd 4058+ * 4059+ * Licensed under the Apache License, Version 2.0 (the "License"); 4060+ * you may not use this file except in compliance with the License. 4061+ * You may obtain a copy of the License at 4062+ * 4063+ * http://www.apache.org/licenses/LICENSE-2.0 4064+ * 4065+ * Unless required by applicable law or agreed to in writing, software 4066+ * distributed under the License is distributed on an "AS IS" BASIS, 4067+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4068+ * See the License for the specific language governing permissions and 4069+ * limitations under the License. 
4070+ */ 4071+#ifndef MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4072+#define MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4073+ 4074+#include <string> 4075+#include <vector> 4076+#include <memory> 4077+#include <stddef.h> 4078+#include "include/c_api/types_c.h" 4079+ 4080+#ifdef __cplusplus 4081+extern "C" { 4082+#endif 4083+ 4084+#define NNRT_DEVICE_NAME_MAX (128) 4085+ 4086+struct NNRTDeviceDesc { 4087+ size_t device_id; 4088+ OH_AI_NNRTDeviceType device_type; 4089+ char device_name[NNRT_DEVICE_NAME_MAX]; 4090+}; 4091+ 4092+#ifdef __cplusplus 4093+} 4094+#endif 4095+#endif // MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4096diff --git a/mindspore/lite/src/litert/cxx_api/context.cc b/mindspore/lite/src/litert/cxx_api/context.cc 4097index 1371bcf0..e5f19d28 100644 4098--- a/mindspore/lite/src/litert/cxx_api/context.cc 4099+++ b/mindspore/lite/src/litert/cxx_api/context.cc 4100@@ -50,6 +50,11 @@ constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dyn 4101 constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size"; 4102 constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize"; 4103 constexpr auto kModelOptionAscendRankID = "mindspore.option.ascend.rank_id"; 4104+constexpr auto kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id"; 4105+constexpr auto kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode"; 4106+constexpr auto kModelOptionNNRTPriority = "mindspore.option.nnrt.priority"; 4107+constexpr auto kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16"; 4108+constexpr auto kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions"; 4109 #ifdef USE_GLOG 4110 extern "C" { 4111 extern void mindspore_log_init(); 4112@@ -684,4 +689,84 @@ std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const { 4113 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize); 4114 return StringToChar(ref); 4115 } 4116+ 4117+void NNRTDeviceInfo::SetDeviceID(size_t device_id) { 4118+ if (data_ == nullptr) { 4119+ MS_LOG(ERROR) << "Invalid context."; 4120+ return; 4121+ } 4122+ data_->params[kModelOptionNNRTDeviceID] = device_id; 4123+} 4124+ 4125+size_t NNRTDeviceInfo::GetDeviceID() const { 4126+ if (data_ == nullptr) { 4127+ MS_LOG(ERROR) << "Invalid context."; 4128+ return 0; 4129+ } 4130+ return GetValue<size_t>(data_, kModelOptionNNRTDeviceID); 4131+} 4132+ 4133+void NNRTDeviceInfo::SetPerformanceMode(int performance_mode) { 4134+ if (data_ == nullptr) { 4135+ MS_LOG(ERROR) << "Invalid context."; 4136+ return; 4137+ } 4138+ data_->params[kModelOptionNNRTPerformanceMode] = performance_mode; 4139+} 4140+ 4141+int NNRTDeviceInfo::GetPerformanceMode() const { 4142+ if (data_ == nullptr) { 4143+ MS_LOG(ERROR) << "Invalid context."; 4144+ return 0; 4145+ } 4146+ return GetValue<int>(data_, kModelOptionNNRTPerformanceMode); 4147+} 4148+ 4149+void NNRTDeviceInfo::SetPriority(int priority) { 4150+ if (data_ == nullptr) { 4151+ MS_LOG(ERROR) << "Invalid context."; 4152+ return; 4153+ } 4154+ data_->params[kModelOptionNNRTPriority] = priority; 4155+} 4156+ 4157+int NNRTDeviceInfo::GetPriority() const { 4158+ if (data_ == nullptr) { 4159+ MS_LOG(ERROR) << "Invalid context."; 4160+ return 0; 4161+ } 4162+ return GetValue<int>(data_, kModelOptionNNRTPriority); 4163+} 4164+ 4165+void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) { 4166+ if (data_ == nullptr) { 4167+ MS_LOG(ERROR) << "Invalid context."; 4168+ return; 4169+ } 
4170+ data_->params[kModelOptionNNRTEnableFP16] = is_fp16; 4171+} 4172+ 4173+bool NNRTDeviceInfo::GetEnableFP16() const { 4174+ if (data_ == nullptr) { 4175+ MS_LOG(ERROR) << "Invalid context."; 4176+ return false; 4177+ } 4178+ return GetValue<bool>(data_, kModelOptionNNRTEnableFP16); 4179+} 4180+ 4181+void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) { 4182+ if (data_ == nullptr) { 4183+ MS_LOG(ERROR) << "Invalid context."; 4184+ return; 4185+ } 4186+ data_->params[kModelOptionNNRTExtensions] = extensions; 4187+} 4188+ 4189+std::vector<Extension> NNRTDeviceInfo::GetExtensions() const { 4190+ if (data_ == nullptr) { 4191+ MS_LOG(ERROR) << "Invalid context."; 4192+ return {}; 4193+ } 4194+ return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions); 4195+} 4196 } // namespace mindspore 4197diff --git a/mindspore/lite/src/litert/cxx_api/converters.cc b/mindspore/lite/src/litert/cxx_api/converters.cc 4198index 0ff345cc..e54a36ee 100644 4199--- a/mindspore/lite/src/litert/cxx_api/converters.cc 4200+++ b/mindspore/lite/src/litert/cxx_api/converters.cc 4201@@ -86,6 +86,23 @@ Status ContextUtils::AddCustomDevice(lite::InnerContext *inner_context, 4202 return kSuccess; 4203 } 4204 4205+Status ContextUtils::AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, 4206+ int priority, bool enable_fp16, const std::vector<Extension> &extensions) { 4207+ lite::DeviceInfo device_info = {0}; 4208+ device_info.nnrt_device_info_.device_id_ = device_id; 4209+ device_info.nnrt_device_info_.performance_mode_ = performance_mode; 4210+ device_info.nnrt_device_info_.priority_ = priority; 4211+ device_info.nnrt_device_info_.enable_fp16_ = enable_fp16; 4212+ for (auto src_extension: extensions) { 4213+ lite::Extension dest_extension; 4214+ dest_extension.name = src_extension.name; 4215+ dest_extension.value = src_extension.value; 4216+ device_info.nnrt_device_info_.extensions_.push_back(dest_extension); 4217+ } 4218+ inner_context->device_list_.push_back({lite::DT_NNRT, device_info}); 4219+ return kSuccess; 4220+} 4221+ 4222 void ContextUtils::ResetContextDefaultParam(Context *context) { 4223 if (context->GetInterOpParallelNum() == 0) { 4224 context->SetInterOpParallelNum(kDefaultInterOpParallelNum); 4225@@ -163,44 +180,11 @@ std::shared_ptr<lite::InnerContext> ContextUtils::Convert(Context *context) { 4226 ret = AddAscendDevice(inner_context.get(), device.get()); 4227 } else if (device->GetDeviceType() == kCustomDevice) { 4228 ret = AddCustomDevice(inner_context.get(), device); 4229- } 4230- if (ret != kSuccess) { 4231- MS_LOG(ERROR) << "Add device failed!"; 4232- return nullptr; 4233- } 4234- } 4235- return inner_context; 4236-} 4237- 4238-std::shared_ptr<lite::InnerContext> ContextUtils::Convert(const ContextC *context_c) { 4239- auto inner_context = std::make_shared<lite::InnerContext>(); 4240- if ((context_c == nullptr) || (inner_context == nullptr)) { 4241- MS_LOG(ERROR) << "Invalid context pointers."; 4242- return nullptr; 4243- } 4244- auto device_list = context_c->device_info_list; 4245- if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) { 4246- MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices; 4247- return nullptr; 4248- } 4249- SetContextAttr(context_c->thread_num, 1, context_c->enable_parallel, context_c->affinity_core_list, 4250- context_c->delegate_mode, context_c->delegate, inner_context.get()); 4251- inner_context->device_list_.clear(); 4252- Status ret = kLiteError; 4253- for (auto 
&device_info_c : device_list) {
-    MS_CHECK_TRUE_RET(device_info_c != nullptr, nullptr);
-    lite::DeviceInfo device_info = {{0}};
-    if (device_info_c->device_type == OH_AI_DEVICETYPE_CPU) {
-      if (device_info_c->allocator == nullptr) {
-        device_info_c->allocator = Allocator::Create();
-      }
-      ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16,
-                         device_info_c->provider, device_info_c->provider_device, inner_context.get());
-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_GPU) {
-      ret = AddGpuDevice(device_info_c->enable_fp16, 0, 0, 0, false, nullptr, nullptr, device_info_c->provider,
-                         device_info_c->provider_device, device_info_c->allocator, inner_context.get());
-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
-      ret = AddNpuDevice(device_info_c->enable_fp16, device_info_c->frequency, inner_context.get());
+    } else if (device->GetDeviceType() == kNNRt) {
+      auto nnrt_device_info = device->Cast<NNRTDeviceInfo>();
+      ret = AddNNRtDevice(inner_context.get(), nnrt_device_info->GetDeviceID(),
+                          nnrt_device_info->GetPerformanceMode(), nnrt_device_info->GetPriority(),
+                          nnrt_device_info->GetEnableFP16(), nnrt_device_info->GetExtensions());
     }
     if (ret != kSuccess) {
       MS_LOG(ERROR) << "Add device failed!";
diff --git a/mindspore/lite/src/litert/cxx_api/converters.h b/mindspore/lite/src/litert/cxx_api/converters.h
index 0c043fc3..1af7c7df 100644
--- a/mindspore/lite/src/litert/cxx_api/converters.h
+++ b/mindspore/lite/src/litert/cxx_api/converters.h
@@ -24,14 +24,12 @@
 #include "include/api/cfg.h"
 #include "include/train/train_cfg.h"
 #include "src/litert/inner_context.h"
-#include "src/litert/c_api/context_c.h"
 #include "src/common/log_adapter.h"
 
 namespace mindspore {
 class MS_API ContextUtils {
  public:
   static std::shared_ptr<lite::InnerContext> Convert(Context *context);
-  static std::shared_ptr<lite::InnerContext> Convert(const ContextC *context_c);
 
  private:
   static void SetContextAttr(int32_t thread_num, int32_t inter_op_parallel_num, bool enable_parallel,
@@ -48,6 +46,8 @@ class MS_API ContextUtils {
   static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context);
   static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
   static Status AddCustomDevice(lite::InnerContext *inner_context, const std::shared_ptr<DeviceInfoContext> &device);
+  static Status AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, int priority,
+                              bool enable_fp16, const std::vector<Extension> &extensions);
   static bool IsAffinityModeValid(int affinity_mode) {
     return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
   }
diff --git a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
index 70aa63f3..625459e2 100644
--- a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
+++ b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
@@ -1,30 +1,13 @@
 include_directories(${DDK_PATH})
 include_directories($(CCSRC_DIR)/plugin/device/cpu/kernel)
+include_directories(${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/)
 
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
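Before continuing with the delegate's CMake changes, a usage illustration (a sketch, not part of the patch): the NNRTDeviceInfo options above are what ContextUtils::AddNNRtDevice() copies field-for-field into lite::InnerContext. Assuming the public Context/NNRTDeviceInfo/Extension declarations match the signatures shown in this patch, and with placeholder values for the device id, performance mode, priority and cache path, a caller could configure NNRT roughly like this; "CachePath" and "CacheVersion" are the extension names that NNRTDelegate::InitCachePath() parses further down in this patch.

#include <memory>
#include <string>
#include <vector>
#include "include/api/context.h"

// Sketch: build a Context that schedules inference onto an NNRT device.
std::shared_ptr<mindspore::Context> BuildNNRTContext(size_t nnrt_device_id) {
  auto context = std::make_shared<mindspore::Context>();
  auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
  nnrt_info->SetDeviceID(nnrt_device_id);  // an id reported by the NNRT runtime (placeholder here)
  nnrt_info->SetPerformanceMode(0);        // forwarded later to OH_NNCompilation_SetPerformanceMode
  nnrt_info->SetPriority(0);               // forwarded later to OH_NNCompilation_SetPriority
  nnrt_info->SetEnableFP16(false);         // forwarded later to OH_NNCompilation_EnableFloat16

  // Extensions are free-form name/value pairs; "CachePath" and "CacheVersion" are the
  // names NNRTDelegate::InitCachePath() looks for when a compile cache should be used.
  const std::string cache_path = "/data/local/tmp/nnrt_cache";  // placeholder path
  const std::string cache_version = "1";
  std::vector<mindspore::Extension> extensions;
  extensions.push_back({"CachePath", std::vector<uint8_t>(cache_path.begin(), cache_path.end())});
  extensions.push_back({"CacheVersion", std::vector<uint8_t>(cache_version.begin(), cache_version.end())});
  nnrt_info->SetExtensions(extensions);

  context->MutableDeviceInfos().push_back(nnrt_info);
  return context;
}

Because AddNNRtDevice() copies these values through unchanged, whatever integer codes the NNRT headers define for performance mode and priority are passed straight to the compilation stage.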
4313-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include/inner) 4314-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include) 4315+ 4316 file(GLOB_RECURSE NNRT_SRC 4317 ${CMAKE_CURRENT_SOURCE_DIR}/*.cc 4318 ) 4319- 4320-#add_library(hiai SHARED IMPORTED) 4321-#set_target_properties(hiai PROPERTIES IMPORTED_LOCATION 4322-# ${DDK_LIB_PATH}/libhiai.so) 4323-#add_library(hiai_ir SHARED IMPORTED) 4324-#set_target_properties(hiai_ir PROPERTIES IMPORTED_LOCATION 4325-# ${DDK_LIB_PATH}/libhiai_ir.so) 4326-#add_library(hiai_ir_build SHARED IMPORTED) 4327-#set_target_properties(hiai_ir_build PROPERTIES IMPORTED_LOCATION 4328-# ${DDK_LIB_PATH}/libhiai_ir_build.so) 4329-#add_library(npu_kernel_mid OBJECT ${NPU_RUNTIME_SRC}) 4330-#add_dependencies(npu_kernel_mid fbs_src) 4331-#target_link_libraries( 4332-# npu_kernel_mid 4333-# hiai 4334-# hiai_ir 4335-# hiai_ir_build 4336-#) 4337- 4338 file(GLOB convert_source checker/*.cc) 4339-add_library(nnr_mid OBJECT ${NNRT_SRC} ${convert_source} ) 4340\ No newline at end of file 4341+ 4342+add_library(nnrt_mid OBJECT ${NNRT_SRC} ${convert_source}) 4343+target_include_directories(nnrt_mid PUBLIC ${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/) 4344\ No newline at end of file 4345diff --git a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4346index 4df7e477..6b191c8e 100644 4347--- a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4348+++ b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4349@@ -109,6 +109,8 @@ Status CheckPrimitiveSupported(const schema::Primitive *primitive) { 4350 return mindspore::kSuccess; 4351 case schema::PrimitiveType_Unsqueeze: 4352 return mindspore::kSuccess; 4353+ case schema::PrimitiveType_Custom: 4354+ return mindspore::kSuccess; 4355 default: { 4356 MS_LOG(WARNING) << "No primitive type :" << (int)(type); 4357 return mindspore::kLiteSuccessExit; 4358diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4359index 34897331..9f012e76 100644 4360--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4361+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4362@@ -13,144 +13,637 @@ 4363 * See the License for the specific language governing permissions and 4364 * limitations under the License. 
4365 */ 4366+ 4367+#include <unordered_set> 4368+#include <numeric> 4369 #include "nnrt_delegate.h" 4370 #include "checker/primitive_check.h" 4371 #include "src/common/log_adapter.h" 4372-#include "interfaces/kits/c/neural_network_runtime.h" 4373+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 4374 #include "interfaces/innerkits/c/neural_network_runtime_inner.h" 4375 #include "nnrt_model_kernel.h" 4376+#include "schema/model_generated.h" 4377+#include "schema/ops_generated.h" 4378+#include "flatbuffers/flatbuffers.h" 4379+#include "litert/tensor_category.h" 4380+ 4381+namespace mindspore { 4382+namespace lite { 4383+void NNRTDelegate::InitCachePath() { 4384+ static const std::string kCachePathName = "CachePath"; 4385+ static const std::string kCacheVersion = "CacheVersion"; 4386+ 4387+ const auto &extensions = nnrt_device_info_.extensions_; 4388 4389-mindspore::Status mindspore::NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) { 4390- if (this->nnrt_lite_graph == nullptr) { 4391- MS_LOG(ERROR) << "nnrt_lite_graph is nullptr."; 4392- return mindspore::kLiteError; 4393+ auto iter_path = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) { 4394+ return extension.name == kCachePathName; 4395+ }); 4396+ if (iter_path != extensions.end()) { 4397+ cache_path_ = std::string(iter_path->value.begin(), iter_path->value.end()); 4398 } 4399- if (this->nnrt_lite_graph->sub_graphs_.empty()) { 4400- // must have at lease one subgraph 4401- MS_LOG(ERROR) << "must have at lease one subgraph"; 4402- return mindspore::kLiteError; 4403+ 4404+ auto iter_version = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) { 4405+ return extension.name == kCacheVersion; 4406+ }); 4407+ if (iter_version != extensions.end()) { 4408+ std::string version_str = std::string(iter_version->value.begin(), iter_version->value.end()); 4409+ cache_version_ = static_cast<uint32_t>(std::atol(version_str.c_str())); 4410 } 4411- OH_NN_ReturnCode ret_code; 4412- OH_NNModel *oh_nnmodel = OH_NNModel_Construct(); 4413- if (oh_nnmodel == nullptr) { 4414- MS_LOG(ERROR) << "Construct NNModel failed, oh_nnmodel is nullptr."; 4415- return mindspore::kLiteError; 4416+} 4417+ 4418+Status NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) { 4419+#ifdef SUPPORT_NNRT_METAGRAPH 4420+ if (IsKirinNPU()) { 4421+ MS_LOG(DEBUG) << "Choose to build nnrt model with Metagraph"; 4422+ InitCachePath(); 4423+ return BuildKirinNPUModel(model); 4424 } 4425+#endif 4426 4427- ret_code = OH_NNModel_BuildFromLiteGraph(oh_nnmodel, this->nnrt_lite_graph); 4428- if (ret_code != OH_NN_SUCCESS) { 4429- MS_LOG(ERROR) << "Build NNModel failed, OH_NN_ReturnCode = " << ret_code; 4430- OH_NNModel_Destroy(&oh_nnmodel); 4431- return mindspore::kLiteError; 4432+ return BuildNormalModel(model); 4433+} 4434+ 4435+bool NNRTDelegate::IsCustomModel() const { 4436+ // check if there is only one Cutsom kernel in LiteModel. 
4437+ if (lite_graph_ == nullptr) { 4438+ return false; 4439+ } 4440+ if (lite_graph_->all_nodes_.size() != 1) { 4441+ return false; 4442+ } 4443+ auto node = lite_graph_->all_nodes_[0]; 4444+ if (node == nullptr) { 4445+ return false; 4446+ } 4447+ if (node->node_type_ != mindspore::schema::PrimitiveType_Custom) { 4448+ return false; 4449+ } 4450+ return true; 4451+} 4452+ 4453+#ifdef SUPPORT_NNRT_METAGRAPH 4454+bool NNRTDelegate::IsKirinNPU() const { 4455+ const std::string kirin_npu_name_prefix = "NPU_"; 4456+ auto device_id = nnrt_device_info_.device_id_; 4457+ const char *device_name; 4458+ auto ret = OH_NNDevice_GetName(device_id, &device_name); 4459+ if (ret != OH_NN_SUCCESS) { 4460+ MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret; 4461+ return false; 4462+ } 4463+ 4464+ if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) { 4465+ MS_LOG(WARNING) << "strncmp: " << device_id << " failed, device_name: " << device_name; 4466+ return false; 4467+ } 4468+ return true; 4469+} 4470+ 4471+Status NNRTDelegate::BuildKirinNPUModel(DelegateModel<schema::Primitive> *model) { 4472+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4473+ if (nn_model == nullptr) { 4474+ MS_LOG(ERROR) << "Create NNModel failed, result is nullptr"; 4475+ return kLiteNullptr; 4476+ } 4477+ 4478+ size_t extension_size = nnrt_device_info_.extensions_.size(); 4479+ std::vector<OH_NN_Extension> extensions; 4480+ MS_LOG_DEBUG << "set extensions, item number: " << extension_size; 4481+ const size_t kExtensionNameMax = 128; // This is a length limitation in NNRT API. 4482+ for (size_t i = 0; i < extension_size; i++) { 4483+ auto &src_extension = nnrt_device_info_.extensions_[i]; 4484+ OH_NN_Extension dst_extension; 4485+ dst_extension.name[kExtensionNameMax - 1] = '\0'; 4486+ strncpy(dst_extension.name, src_extension.name.c_str(), kExtensionNameMax - 1); 4487+ dst_extension.value = (char *)((void *)src_extension.value.data()); 4488+ dst_extension.valueSize = src_extension.value.size(); 4489+ extensions.push_back(dst_extension); 4490+ MS_LOG_DEBUG << "set extension, item name: " << dst_extension.name << ", value size: " << dst_extension.valueSize; 4491+ } 4492+ 4493+ if (IsCustomModel()) { 4494+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_); 4495+ if (ret != OH_NN_SUCCESS) { 4496+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4497+ OH_NNModel_Destroy(&nn_model); 4498+ return kLiteError; 4499+ } 4500+ } else { 4501+ SetKirinModelInputsAndOutputs(nn_model); 4502+ auto ret = OH_NNModel_BuildFromMetaGraph(nn_model, meta_graph_, extensions.data(), extensions.size()); 4503+ if (ret != OH_NN_SUCCESS) { 4504+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4505+ OH_NNModel_Destroy(&nn_model); 4506+ return kLiteError; 4507+ } 4508+ } 4509+ 4510+ auto ret2 = CreateFullModelKernel(model, nn_model); 4511+ if (ret2 != kSuccess) { 4512+ MS_LOG(ERROR) << "Create full model kernel failed, ret: " << ret2; 4513+ return kLiteError; 4514 } 4515- MS_LOG(INFO) << "NNRTDelegate creates NNModel success."; 4516+ return kSuccess; 4517+} 4518+ 4519+std::vector<OH_NN_TensorInfo> NNRTDelegate::CreateNNTensorInfos(const std::vector<uint32_t> &indices) const { 4520+ std::vector<OH_NN_TensorInfo> nn_tensor_infos; 4521+ for (auto index: indices) { 4522+ auto tensor = lite_graph_->all_tensors_[index]; 4523+ auto shape = tensor->dims(); 4524+ auto data_type = tensor->dataType(); 4525+ auto name = tensor->name(); 4526+ auto format = tensor->format(); 4527 
4528- OH_NNCompilation *oh_nn_compilation = nullptr; 4529- oh_nn_compilation = OH_NNCompilation_Construct(oh_nnmodel); 4530+ OH_NN_TensorInfo info; 4531+ info.dataType = CastToNNRTDataType(static_cast<mindspore::DataType>(data_type)); 4532+ info.dimensions = shape->data(); 4533+ info.dimensionCount = shape->size(); 4534+ strcpy(info.name, name->c_str()); 4535+ info.format = CastToNNRTFormat(static_cast<Format>(format)); 4536+ nn_tensor_infos.push_back(info); 4537+ } 4538+ return nn_tensor_infos; 4539+} 4540 4541- if (oh_nn_compilation == nullptr) { 4542+Status NNRTDelegate::SetKirinModelInputsAndOutputs(OH_NNModel *nn_model) { 4543+ std::vector<OH_NN_TensorInfo> inputInfos; 4544+ std::vector<OH_NN_TensorInfo> outputInfos; 4545+ auto input_infos = CreateNNTensorInfos(lite_graph_->input_indices_); 4546+ auto output_infos = CreateNNTensorInfos(lite_graph_->output_indices_); 4547+ OH_NNModel_SetInputsAndOutputsInfo(nn_model, input_infos.data(), input_infos.size(), output_infos.data(), 4548+ output_infos.size()); 4549+ return kSuccess; 4550+} 4551+ 4552+Status NNRTDelegate::CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model) { 4553+ OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model); 4554+ if (nn_compilation == nullptr) { 4555 MS_LOG(ERROR) << "Construct NNCompilation failed"; 4556- OH_NNModel_Destroy(&oh_nnmodel); 4557- return mindspore::kLiteError; 4558+ OH_NNModel_Destroy(&nn_model); 4559+ return kLiteError; 4560 } 4561- MS_LOG(INFO) << "NNRTDelegate creates NNCompilation success."; 4562+ MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success."; 4563 4564- const size_t *allDevicesID = nullptr; 4565- uint32_t device_count = 0; 4566- ret_code = OH_NNDevice_GetAllDevicesID(&allDevicesID, &device_count); 4567- if (ret_code != OH_NN_SUCCESS) { 4568- MS_LOG(ERROR) << "NNModel GetAllDevicesID failed, OH_NN_ReturnCode = " << ret_code; 4569- OH_NNCompilation_Destroy(&oh_nn_compilation); 4570- OH_NNModel_Destroy(&oh_nnmodel); 4571- return mindspore::kLiteError; 4572+ auto ret_code = InitNNCompilation(nn_compilation); 4573+ if (ret_code != kSuccess) { 4574+ MS_LOG(ERROR) << "Init NNCompilation failed"; 4575+ OH_NNModel_Destroy(&nn_model); 4576+ OH_NNCompilation_Destroy(&nn_compilation); 4577+ return kLiteError; 4578 } 4579+ OH_NNModel_Destroy(&nn_model); 4580 4581- if (device_count <= 0) { 4582- MS_LOG(WARNING) << "No NNRt Device found, fall back to CPU. 
"; 4583- // OH_NNCompilation_Destroy(&oh_nn_compilation); 4584- // OH_NNModel_Destroy(&oh_nnmodel); 4585- return mindspore::kSuccess; 4586+ OH_NNExecutor *nn_executor = nullptr; 4587+ nn_executor = OH_NNExecutor_Construct(nn_compilation); 4588+ if (nn_executor == nullptr) { 4589+ MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code; 4590+ OH_NNCompilation_Destroy(&nn_compilation); 4591+ return kLiteError; 4592 } 4593- MS_LOG(INFO) << "NNRTDelegate GetAllDevicesID success."; 4594+ OH_NNCompilation_Destroy(&nn_compilation); 4595 4596- // check if model ops are supported 4597- const bool *issupported = nullptr; 4598+ auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, model->inputs(), model->outputs()); 4599+ if (nnrt_model_kernel == nullptr) { 4600+ OH_NNExecutor_Destroy(&nn_executor); 4601+ MS_LOG(ERROR) << "new NNRTModelKernel failed"; 4602+ return kLiteError; 4603+ } 4604+ model->Replace(model->BeginKernelIterator(), model->EndKernelIterator(), nnrt_model_kernel); 4605+ return kSuccess; 4606+} 4607+#endif 4608+ 4609+Status NNRTDelegate::BuildNormalModel(DelegateModel<schema::Primitive> *model) { 4610+ MS_LOG(DEBUG) << "Start to build NNRT model."; 4611+ if ((lite_graph_ == nullptr) || (lite_graph_->sub_graphs_.size() > 1)) { 4612+ MS_LOG(WARNING) << "LiteGraph contains more than one subgraph. NNRT does not support control-flow model yet, fallback to CPU"; 4613+ return kSuccess; 4614+ } 4615+ 4616+ OH_NNModel *full_model = CreateFullNNModel(); 4617+ if (full_model == nullptr) { 4618+ MS_LOG(WARNING) << "Build full NNModel failed, fallback to CPU"; 4619+ return kSuccess; 4620+ } 4621+ std::vector<bool> op_supports = QueryOpSupports(full_model); 4622+ if (op_supports.empty()) { 4623+ MS_LOG(WARNING) << "Query no op supports for full model, fallback to CPU"; 4624+ OH_NNModel_Destroy(&full_model); 4625+ return kSuccess; 4626+ } 4627+ auto nnrt_subgraph_ranges = GetNNRTSubgraphRanges(model, op_supports); 4628+ MS_LOG(INFO) << "Found NNRT subgraph count: " << nnrt_subgraph_ranges.size(); 4629+ 4630+ std::vector<LiteGraph *> sub_lite_graphs; 4631+ auto ret = CreateLiteGraphForNNRTSubgraph(nnrt_subgraph_ranges, &sub_lite_graphs); 4632+ if (ret != kSuccess) { 4633+ OH_NNModel_Destroy(&full_model); 4634+ MS_LOG(WARNING) << "Create NNRT sub LiteGraph failed, fallback to CPU"; 4635+ return kSuccess; 4636+ } 4637+ 4638+ std::vector<NNRTModelKernel *> nnrt_subgraph_kernels; 4639+ ret = CreateNNRTSubgraphKernels(model, sub_lite_graphs, nnrt_subgraph_ranges, &nnrt_subgraph_kernels); 4640+ if (ret != kSuccess) { 4641+ OH_NNModel_Destroy(&full_model); 4642+ MS_LOG(WARNING) << "Create NNRT subgraph kernel failed, fallback to CPU"; 4643+ return kSuccess; 4644+ } 4645+ 4646+ ReplaceNNRTKernelsInDelegateModel(model, nnrt_subgraph_ranges, nnrt_subgraph_kernels); 4647+ OH_NNModel_Destroy(&full_model); 4648+ MS_LOG(INFO) << "NNRTDelegate build success."; 4649+ return kSuccess; 4650+} 4651+ 4652+OH_NNModel *NNRTDelegate::CreateFullNNModel() { 4653+ if (lite_graph_ == nullptr) { 4654+ MS_LOG(ERROR) << "Lite graph is null"; 4655+ return nullptr; 4656+ } 4657+ 4658+ if (lite_graph_->sub_graphs_.empty()) { 4659+ MS_LOG(ERROR) << "Lite graph must have at lease one subgraph"; 4660+ return nullptr; 4661+ } 4662+ 4663+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4664+ if (nn_model == nullptr) { 4665+ MS_LOG(ERROR) << "Create NNModel failed, result is nullptr"; 4666+ return nullptr; 4667+ } 4668+ 4669+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_); 4670+ if (ret != 
OH_NN_SUCCESS) { 4671+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4672+ OH_NNModel_Destroy(&nn_model); 4673+ return nullptr; 4674+ } 4675+ return nn_model; 4676+} 4677+ 4678+std::vector<bool> NNRTDelegate::QueryOpSupports(OH_NNModel *nn_model) { 4679+ const bool *is_supported = nullptr; // Note: this memory is owned by nn_model, don't free alone. 4680 uint32_t op_count = 0; 4681- ret_code = OH_NNModel_GetAvailableOperations(oh_nnmodel, allDevicesID[0], &issupported, &op_count); 4682- if (ret_code != OH_NN_SUCCESS) { 4683- MS_LOG(ERROR) << "NNModel GetAvailableOperations failed, OH_NN_ReturnCode = " << ret_code 4684- << ", maybe due to dataParcel data length limitaion. Fall back to CPU."; 4685- OH_NNCompilation_Destroy(&oh_nn_compilation); 4686- OH_NNModel_Destroy(&oh_nnmodel); 4687- return mindspore::kSuccess; 4688+ auto ret = OH_NNModel_GetAvailableOperations(nn_model, nnrt_device_info_.device_id_, &is_supported, &op_count); 4689+ if (ret != OH_NN_SUCCESS) { 4690+ MS_LOG(WARNING) << "NNModel GetAvailableOperations failed, ret: " << ret 4691+ << ", maybe caused by dataParcel data length limitation"; 4692+ return {}; 4693 } 4694- uint32_t supported_op_count = 0; 4695- for (uint32_t i = 0; i < op_count; i++) { 4696- if (issupported[i]) { 4697- supported_op_count++; 4698+ std::vector<bool> op_supports(is_supported, is_supported + op_count); 4699+ return op_supports; 4700+} 4701+ 4702+/* Find continuous sub-sequence in op_supports. */ 4703+std::vector<NNRTOpRange> NNRTDelegate::GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model, 4704+ const std::vector<bool> &op_supports) { 4705+ std::vector<NNRTOpRange> nnrt_subgraph_ranges; 4706+ NNRTOpRange op_range; 4707+ bool start_count = false; 4708+ for (size_t i = 0; i < op_supports.size(); i++) { 4709+ if (op_supports[i]) { 4710+ if (start_count == false) { 4711+ start_count = true; 4712+ op_range.begin_index_ = i; 4713+ op_range.begin_iter_ = model->BeginKernelIterator() + i; 4714+ } 4715+ } else { 4716+ if (start_count == true) { 4717+ start_count = false; 4718+ op_range.end_index_ = i; 4719+ op_range.end_iter_ = model->BeginKernelIterator() + i; 4720+ nnrt_subgraph_ranges.push_back(op_range); 4721+ } 4722 } 4723 } 4724- if (op_count != supported_op_count) { 4725- MS_LOG(WARNING) << "this model has " << op_count << "ops, but NNRT only support " << supported_op_count 4726- << " ops, fall back to CPU."; 4727- // must support all op, else fall back to CPU 4728- OH_NNCompilation_Destroy(&oh_nn_compilation); 4729- OH_NNModel_Destroy(&oh_nnmodel); 4730- return mindspore::kSuccess; 4731+ // handle last true subsequence 4732+ if (start_count == true) { 4733+ op_range.end_index_ = op_supports.size(); 4734+ op_range.end_iter_ = model->EndKernelIterator(); 4735+ nnrt_subgraph_ranges.push_back(op_range); 4736+ MS_LOG(INFO) << "Schedule NNRT subgraph range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")"; 4737 } 4738- MS_LOG(INFO) << "NNRtDelegate supports all op in this model."; 4739+ return nnrt_subgraph_ranges; 4740+} 4741+ 4742+/** 4743+ * This method ONLY works when the follow pre-conditions are satisfied: 4744+ * 1. The node order of lite_graph_->all_nodes should be consistent with DelegateModel sequence. 4745+ * This ensures the kernel replacement in DelegateModel based on the re-organizing info from lite_graph_ is correct. 4746+ * 2. The node indices of lite_graph_->sub_graphs[0].node_indices should be monotonically increasing from 0 to size - 1. 
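 *
 * Worked example of the partitioning that feeds this method: GetNNRTSubgraphRanges() scans the
 * op-support mask returned by OH_NNModel_GetAvailableOperations() and emits one half-open range
 * per maximal run of supported ops, e.g. op_supports = {true, true, false, true, true, true}
 * yields the ranges [0, 2) and [3, 6). Each range then becomes one sub-LiteGraph here and is
 * compiled into a single NNRTModelKernel, while unsupported ops (index 2 above) stay on the
 * original kernels of the DelegateModel (the CPU fallback path).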
4747+ */ 4748+Status NNRTDelegate::CreateLiteGraphForNNRTSubgraph( 4749+ const std::vector<NNRTOpRange> &nnrt_op_ranges, 4750+ std::vector<LiteGraph *> *sub_lite_graphs) { 4751+ MS_LOG(INFO) << "Start creating LiteGraph for NNRT subgraph"; 4752+ for (const auto &op_range: nnrt_op_ranges) { 4753+ MS_LOG(INFO) << "Process op range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")"; 4754+ LiteGraph *sub_lite_graph = new (std::nothrow)LiteGraph; 4755+ if (sub_lite_graph == nullptr) { 4756+ MS_LOG(ERROR) << "Allocate LiteGraph failed"; 4757+ return kLiteError; 4758+ } 4759+ sub_lite_graph->name_ = lite_graph_->name_; 4760+ sub_lite_graph->version_ = lite_graph_->version_; 4761 4762- ret_code = OH_NNCompilation_SetDevice(oh_nn_compilation, allDevicesID[0]); 4763+ auto sub_graph = new (std::nothrow)LiteGraph::SubGraph; 4764+ if (sub_graph == nullptr) { 4765+ MS_LOG(ERROR) << "Allocate SubGraph failed"; 4766+ return kLiteError; 4767+ } 4768+ sub_graph->name_ = lite_graph_->name_; 4769+ sub_lite_graph->sub_graphs_.push_back(sub_graph); 4770 4771+ // deal with all_nodes 4772+ MS_LOG(INFO) << "Assemble all_nodes..."; 4773+ int new_node_index = 0; 4774+ std::map<uint32_t, schema::Tensor *> in_tensor_index_map; 4775+ std::map<uint32_t, schema::Tensor *> out_tensor_index_map; 4776+ for (size_t index = op_range.begin_index_; index < op_range.end_index_; index++) { 4777+ LiteGraph::Node *node = new (std::nothrow)LiteGraph::Node; 4778+ if (node == nullptr) { 4779+ MS_LOG(ERROR) << "Allocate Node failed"; 4780+ return kLiteError; 4781+ } 4782+ *node = *lite_graph_->all_nodes_[index]; 4783+ sub_lite_graph->all_nodes_.push_back(node); 4784+ sub_graph->node_indices_.push_back(new_node_index++); 4785+ 4786+ for (auto i: node->input_indices_) { 4787+ in_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]); 4788+ } 4789+ for (auto i: node->output_indices_) { 4790+ out_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]); 4791+ } 4792+ } 4793+ 4794+ // deal with all_tensors 4795+ MS_LOG(INFO) << "Assemble all_tensors..."; 4796+ std::set<schema::Tensor *> tensors; 4797+ for (auto iter: in_tensor_index_map) { 4798+ tensors.emplace(iter.second); 4799+ } 4800+ for (auto iter: out_tensor_index_map) { 4801+ tensors.emplace(iter.second); 4802+ } 4803+ 4804+ uint32_t new_index = 0; 4805+ std::map<schema::Tensor *, uint32_t> new_tensor_maps; 4806+ for (auto tensor: tensors) { 4807+ new_tensor_maps.emplace(tensor, new_index++); 4808+ } 4809+ 4810+ sub_lite_graph->all_tensors_ = std::vector<schema::Tensor *>(tensors.begin(), tensors.end()); 4811+ 4812+ // deal with every node's input/output indices 4813+ MS_LOG(INFO) << "Set input/output indices of each node..."; 4814+ for (auto node: sub_lite_graph->all_nodes_) { 4815+ for (auto &index : node->input_indices_) { 4816+ index = new_tensor_maps.at(in_tensor_index_map.at(index)); 4817+ } 4818+ for (auto &index : node->output_indices_) { 4819+ index = new_tensor_maps.at(out_tensor_index_map.at(index)); 4820+ } 4821+ } 4822+ 4823+ // deal with subgraph's input/output indices 4824+ MS_LOG(INFO) << "Set input/output indices of each subgraph..."; 4825+ sub_graph->tensor_indices_ = std::vector<uint32_t>(tensors.size()); 4826+ std::iota(sub_graph->tensor_indices_.begin(), sub_graph->tensor_indices_.end(), 0U); 4827+ 4828+ for (auto iter: in_tensor_index_map) { 4829+ auto new_tensor_index = new_tensor_maps[iter.second]; 4830+ MS_LOG(DEBUG) << "handle input: old: " << iter.first << ", new: " << new_tensor_index << std::endl; 4831+ if 
(IsConstTensor(*iter.second)) { 4832+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl; 4833+ continue; 4834+ } 4835+ 4836+ bool is_subgraph_input = true; 4837+ for (auto node: sub_lite_graph->all_nodes_) { 4838+ if (std::find(node->output_indices_.begin(), node->output_indices_.end(), new_tensor_index) != 4839+ node->output_indices_.end()) { 4840+ is_subgraph_input = false; 4841+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is not subgraph input." << std::endl; 4842+ break; 4843+ } 4844+ } 4845+ if (is_subgraph_input) { 4846+ sub_graph->input_indices_.push_back(new_tensor_index); 4847+ MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph input." << std::endl; 4848+ } 4849+ } 4850+ 4851+ for (auto iter: out_tensor_index_map) { 4852+ int new_tensor_index = new_tensor_maps.at(iter.second); 4853+ MS_LOG(DEBUG) << "handle output: old: " << iter.first << ", new: " << new_tensor_index << std::endl; 4854+ if (IsConstTensor(*iter.second)) { 4855+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl; 4856+ continue; 4857+ } 4858+ 4859+ bool is_subgraph_output = false; 4860+ for (size_t i = 0; i < lite_graph_->all_nodes_.size(); i++) { 4861+ if ((i >= op_range.begin_index_) && (i < op_range.end_index_)) { 4862+ continue; 4863+ } 4864+ auto node = lite_graph_->all_nodes_[i]; 4865+ if (std::find(node->input_indices_.begin(), node->input_indices_.end(), iter.first) != 4866+ node->input_indices_.end()) { // As the input of node which does not belong to the subgraph. 4867+ is_subgraph_output = true; 4868+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is original subgraph output. node: " << node->primitive_ << std::endl; 4869+ break; 4870+ } 4871+ } 4872+ bool is_graph_output = (std::find(lite_graph_->output_indices_.begin(),lite_graph_->output_indices_.end(), 4873+ iter.first) != lite_graph_->output_indices_.end()); 4874+ if (is_graph_output) { 4875+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is graph output." << std::endl; 4876+ } 4877+ if (is_subgraph_output || is_graph_output) { 4878+ sub_graph->output_indices_.push_back(new_tensor_index); 4879+ MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph output." << std::endl; 4880+ } 4881+ } 4882+ 4883+ // deal with full-graph's input/output indices 4884+ sub_lite_graph->input_indices_ = sub_graph->input_indices_; 4885+ sub_lite_graph->output_indices_ = sub_graph->output_indices_; 4886+ sub_lite_graphs->push_back(sub_lite_graph); 4887+ } 4888+ MS_LOG(INFO) << "Finished creating LiteGraph for NNRT subgraph"; 4889+ return kSuccess; 4890+} 4891+ 4892+struct TensorLocation { 4893+ uint32_t node_index; // the index of node which the tensor belongs to. 4894+ uint32_t tensor_index; // the index of node in/out tensors which the tensor is located at. 
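  // Example (for illustration): if sub-LiteGraph input tensor 7 is found as the second entry of
  // node 3's input_indices_, the lookup below records {node_index = 3, tensor_index = 1} and then
  // takes the matching MSTensor from (*(begin_iter_ + 3))->inputs()[1] of the DelegateModel.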
4895+}; 4896+ 4897+Status NNRTDelegate::InitNNCompilation(OH_NNCompilation *nn_compilation) const { 4898+ auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_); 4899 if (ret_code != OH_NN_SUCCESS) { 4900- MS_LOG(ERROR) << "NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code; 4901- OH_NNCompilation_Destroy(&oh_nn_compilation); 4902- OH_NNModel_Destroy(&oh_nnmodel); 4903- return mindspore::kLiteError; 4904+ MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code; 4905+ return kLiteError; 4906+ } 4907+ ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation, 4908+ (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_)); 4909+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4910+ MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code; 4911+ return kLiteError; 4912+ } 4913+ ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_)); 4914+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4915+ MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code; 4916+ return kLiteError; 4917+ } 4918+ ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_); 4919+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4920+ MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code; 4921+ return kLiteError; 4922 } 4923 4924- ret_code = OH_NNCompilation_Build(oh_nn_compilation); 4925+ if (!cache_path_.empty()) { // Set cache path if user indeed set it. 4926+ ret_code = OH_NNCompilation_SetCache(nn_compilation, cache_path_.c_str(), cache_version_); 4927+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4928+ MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code; 4929+ return kLiteError; 4930+ } 4931+ } 4932 4933+ ret_code = OH_NNCompilation_Build(nn_compilation); 4934 if (ret_code != OH_NN_SUCCESS) { 4935- MS_LOG(ERROR) << "Build NNCompilation failed, OH_NN_ReturnCode = " << ret_code; 4936- OH_NNCompilation_Destroy(&oh_nn_compilation); 4937- OH_NNModel_Destroy(&oh_nnmodel); 4938- return mindspore::kLiteError; 4939- } 4940- 4941- MS_LOG(DEBUG) << "NNRTDelegate SetDevice success."; 4942- 4943- OH_NNExecutor *oh_nn_executor = nullptr; 4944- oh_nn_executor = OH_NNExecutor_Construct(oh_nn_compilation); 4945- if (oh_nn_executor == nullptr) { 4946- MS_LOG(ERROR) << "Construct NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code; 4947- OH_NNCompilation_Destroy(&oh_nn_compilation); 4948- OH_NNModel_Destroy(&oh_nnmodel); 4949- return mindspore::kLiteError; 4950- } 4951- MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success."; 4952- mindspore::Status prepare_data_ret; 4953- auto nnr_model_kernel = new (std::nothrow) NNRTModelKernel(oh_nn_executor, model->inputs(), model->outputs()); 4954- if (nnr_model_kernel == nullptr) { 4955- MS_LOG(ERROR) << "new NNRTModelKernel failed"; 4956- return mindspore::kLiteError; 4957+ MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code; 4958+ return kLiteError; 4959 } 4960- OH_NNCompilation_Destroy(&oh_nn_compilation); 4961- OH_NNModel_Destroy(&oh_nnmodel); 4962- KernelIter from = model->BeginKernelIterator(); 4963- KernelIter end = model->EndKernelIterator(); 4964- model->Replace(from, end, nnr_model_kernel); 4965+ return kSuccess; 4966+} 4967+ 4968+Status NNRTDelegate::CreateNNRTSubgraphKernels(DelegateModel<schema::Primitive> 
*model, 4969+ const std::vector<LiteGraph *> &sub_lite_graphs, const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 4970+ std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels) { 4971+ for (size_t i = 0; i < sub_lite_graphs.size(); i++) { 4972+ auto sub_lite_graph = sub_lite_graphs[i]; 4973+ 4974+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4975+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, sub_lite_graph); 4976+ if (ret != OH_NN_SUCCESS) { 4977+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4978+ OH_NNModel_Destroy(&nn_model); 4979+ return kLiteError; 4980+ } 4981 4982- MS_LOG(INFO) << "NNRTDelegate build success."; 4983- return mindspore::kSuccess; 4984+ OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model); 4985+ if (nn_compilation == nullptr) { 4986+ MS_LOG(ERROR) << "Construct NNCompilation failed"; 4987+ OH_NNModel_Destroy(&nn_model); 4988+ return kLiteError; 4989+ } 4990+ MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success."; 4991+ 4992+ auto ret_code = InitNNCompilation(nn_compilation); 4993+ if (ret_code != kSuccess) { 4994+ MS_LOG(ERROR) << "Init NNCompilation failed"; 4995+ OH_NNCompilation_Destroy(&nn_compilation); 4996+ OH_NNModel_Destroy(&nn_model); 4997+ return kLiteError; 4998+ } 4999+ 5000+ OH_NNExecutor *nn_executor = nullptr; 5001+ nn_executor = OH_NNExecutor_Construct(nn_compilation); 5002+ if (nn_executor == nullptr) { 5003+ MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code; 5004+ OH_NNCompilation_Destroy(&nn_compilation); 5005+ OH_NNModel_Destroy(&nn_model); 5006+ return kLiteError; 5007+ } 5008+ MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success."; 5009+ 5010+ bool format_not_support = false; 5011+ std::vector<MSTensor> in_tensors; 5012+ for (auto index: sub_lite_graph->sub_graphs_[0]->input_indices_) { 5013+ TensorLocation location; 5014+ for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) { 5015+ auto node = sub_lite_graph->all_nodes_[node_index]; 5016+ auto iter = std::find(node->input_indices_.begin(), node->input_indices_.end(), index); 5017+ if (iter != node->input_indices_.end()) { 5018+ uint32_t tensor_index = iter - node->input_indices_.begin(); 5019+ location.node_index = node_index; 5020+ location.tensor_index = tensor_index; 5021+ MS_LOG(INFO) << "Found graph input index: " << index << " is the " << tensor_index << "th input of the node " << node->primitive_; 5022+ break; 5023+ } 5024+ } 5025+ KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index; 5026+ in_tensors.push_back((*kernel_iter)->inputs()[location.tensor_index]); 5027+ if (in_tensors.back().format() != Format::NHWC) { 5028+ format_not_support = true; 5029+ break ; 5030+ } 5031+ } 5032+ 5033+ std::vector<MSTensor> out_tensors; 5034+ for (auto index: sub_lite_graph->sub_graphs_[0]->output_indices_) { 5035+ TensorLocation location; 5036+ for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) { 5037+ auto node = sub_lite_graph->all_nodes_[node_index]; 5038+ auto iter = std::find(node->output_indices_.begin(), node->output_indices_.end(), index); 5039+ if (iter != node->output_indices_.end()) { 5040+ uint32_t tensor_index = iter - node->output_indices_.begin(); 5041+ location.node_index = node_index; 5042+ location.tensor_index = tensor_index; 5043+ MS_LOG(INFO) << "Found graph output index: " << index << " is the " << tensor_index << "th output of the node " << node->primitive_; 5044+ break; 5045+ } 5046+ } 5047+ KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ 
+ location.node_index; 5048+ out_tensors.push_back((*kernel_iter)->outputs()[location.tensor_index]); 5049+ if (out_tensors.back().format() != Format::NHWC) { 5050+ format_not_support = true; 5051+ break ; 5052+ } 5053+ } 5054+ if (format_not_support) { 5055+ MS_LOG(WARNING) << "Not support in/out tensor format, skip this subgraph"; 5056+ OH_NNCompilation_Destroy(&nn_compilation); 5057+ OH_NNModel_Destroy(&nn_model); 5058+ nnrt_subgraph_kernels->push_back(nullptr); 5059+ continue ; 5060+ } 5061+ 5062+ auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, in_tensors, out_tensors); 5063+ if (nnrt_model_kernel == nullptr) { 5064+ MS_LOG(ERROR) << "new NNRTModelKernel failed"; 5065+ return kLiteError; 5066+ } 5067+ OH_NNCompilation_Destroy(&nn_compilation); 5068+ OH_NNModel_Destroy(&nn_model); 5069+ nnrt_subgraph_kernels->push_back(nnrt_model_kernel); 5070+ } 5071+ return kSuccess; 5072 } 5073 5074-mindspore::Status mindspore::NNRTDelegate::Init() { 5075- MS_LOG(DEBUG) << "NNRTDelegate init success."; 5076- return mindspore::kSuccess; 5077+void NNRTDelegate::ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model, 5078+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5079+ const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels) { 5080+ // Here we perform the replacement from back to front intentionally! If replace from front to end, the kernel 5081+ // sequence would shrink and the later begin_iter_/end_iter_ may be erased already. 5082+ for (int i = nnrt_subgraph_ranges.size() - 1; i >= 0; i--) { 5083+ if (nnrt_subgraph_kernels[i] == nullptr) { 5084+ continue; 5085+ } 5086+ auto from = nnrt_subgraph_ranges[i].begin_iter_; 5087+ auto end = nnrt_subgraph_ranges[i].end_iter_; 5088+ (void)model->Replace(from, end, nnrt_subgraph_kernels[i]); 5089+ MS_LOG(INFO) << "Replace nnrt subgraph kernel in range: [" << (from - model->BeginKernelIterator()) 5090+ << ", " << (end - model->BeginKernelIterator()) << ")"; 5091+ } 5092 } 5093-mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model, 5094- OH_NNExecutor *oh_nn_executor) { 5095+ 5096+Status NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model, 5097+ OH_NNExecutor *oh_nn_executor) { 5098 auto input_tensors = model->inputs(); 5099 for (size_t i = 0; i < input_tensors.size(); i++) { 5100 auto tensor = input_tensors[i]; 5101@@ -161,10 +654,10 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5102 std::vector<double> scale; 5103 std::vector<int32_t> zero_point; 5104 if (!tmp_quant_param.empty()) { 5105- quant_param = new (std::nothrow) OH_NN_QuantParam; 5106+ quant_param = new(std::nothrow) OH_NN_QuantParam; 5107 if (quant_param == nullptr) { 5108 MS_LOG(ERROR) << "new OH_NN_QuantParam failed."; 5109- return mindspore::kLiteError; 5110+ return kLiteError; 5111 } 5112 for (auto qparam : tmp_quant_param) { 5113 bit_num.emplace_back(qparam.bit_num); 5114@@ -176,12 +669,12 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5115 quant_param->scale = scale.data(); 5116 quant_param->zeroPoint = zero_point.data(); 5117 } 5118- auto oprend = new (std::nothrow) OH_NN_Tensor; 5119+ auto oprend = new(std::nothrow) OH_NN_Tensor; 5120 if (oprend == nullptr) { 5121 MS_LOG(ERROR) << "new OH_NN_Tensor Failed"; 5122- return mindspore::kLiteError; 5123+ return kLiteError; 5124 } 5125- oprend->dataType = ConvertDataType(tensor.DataType()); 5126+ oprend->dataType = CastToNNRTDataType(tensor.DataType()); 5127 
oprend->dimensionCount = tensor_shape.size(); 5128 5129 std::vector<int32_t> dimensions_list; 5130@@ -191,14 +684,14 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5131 } else { 5132 MS_LOG(ERROR) << "NNExecutor SetInput failed,tensor dimension is is too large, max dim = " << INT32_MAX 5133 << ", but get dimension = " << shape; 5134- return mindspore::kLiteError; 5135+ return kLiteError; 5136 } 5137 } 5138 oprend->dimensions = dimensions_list.data(); 5139 oprend->quantParam = quant_param; 5140 oprend->type = OH_NN_TENSOR; 5141 OH_NN_ReturnCode ret_code = 5142- OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5143+ OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5144 delete (oprend); 5145 5146 if (!tmp_quant_param.empty()) { 5147@@ -209,70 +702,41 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5148 if (ret_code != OH_NN_SUCCESS) { 5149 MS_LOG(ERROR) << "NNExecutor SetInput failed, current input tensor is" << tensor.Name() 5150 << "OH_NN_ReturnCode = " << ret_code; 5151- return mindspore::kLiteError; 5152+ return kLiteError; 5153 } 5154 } 5155+ return kSuccess; 5156+} 5157+ 5158+OH_NN_DataType NNRTDelegate::CastToNNRTDataType(DataType data_type) { 5159+ const std::unordered_map<DataType, OH_NN_DataType> kDataTypeMap = { 5160+ {DataType::kNumberTypeBool, OH_NN_BOOL}, 5161+ {DataType::kNumberTypeInt8, OH_NN_INT8}, 5162+ {DataType::kNumberTypeInt16, OH_NN_INT16}, 5163+ {DataType::kNumberTypeInt32, OH_NN_INT32}, 5164+ {DataType::kNumberTypeInt64, OH_NN_INT64}, 5165+ {DataType::kNumberTypeUInt8, OH_NN_UINT8}, 5166+ {DataType::kNumberTypeUInt16, OH_NN_UINT16}, 5167+ {DataType::kNumberTypeUInt32, OH_NN_UINT32}, 5168+ {DataType::kNumberTypeUInt64, OH_NN_UINT64}, 5169+ {DataType::kNumberTypeFloat16, OH_NN_FLOAT16}, 5170+ {DataType::kNumberTypeFloat32, OH_NN_FLOAT32}, 5171+ {DataType::kNumberTypeFloat64, OH_NN_FLOAT64}, 5172+ }; 5173 5174- return mindspore::kSuccess; 5175+ auto iter = kDataTypeMap.find(data_type); 5176+ if (iter == kDataTypeMap.end()) { 5177+ return OH_NN_UNKNOWN; 5178+ } 5179+ return iter->second; 5180 } 5181-OH_NN_DataType mindspore::NNRTDelegate::ConvertDataType(mindspore::DataType data_type) { 5182- OH_NN_DataType oh_data_type; 5183- switch (data_type) { 5184- case mindspore::DataType::kTypeUnknown: 5185- case mindspore::DataType::kObjectTypeString: 5186- case mindspore::DataType::kObjectTypeList: 5187- case mindspore::DataType::kObjectTypeTuple: 5188- case mindspore::DataType::kObjectTypeTensorType: 5189- case mindspore::DataType::kNumberTypeBegin: 5190- case mindspore::DataType::kNumberTypeEnd: 5191- case mindspore::DataType::kInvalidType: 5192- oh_data_type = OH_NN_UNKNOWN; 5193- break; 5194- case mindspore::DataType::kNumberTypeBool: 5195- oh_data_type = OH_NN_BOOL; 5196- break; 5197- case mindspore::DataType::kNumberTypeInt8: 5198- oh_data_type = OH_NN_INT8; 5199- break; 5200- case mindspore::DataType::kNumberTypeInt16: 5201- oh_data_type = OH_NN_INT16; 5202- break; 5203- case mindspore::DataType::kNumberTypeInt32: 5204- oh_data_type = OH_NN_INT32; 5205- break; 5206- case mindspore::DataType::kNumberTypeInt64: 5207- oh_data_type = OH_NN_INT64; 5208- break; 5209- case mindspore::DataType::kNumberTypeUInt8: 5210- oh_data_type = OH_NN_UINT8; 5211- break; 5212- case mindspore::DataType::kNumberTypeUInt16: 5213- oh_data_type = OH_NN_UINT16; 5214- break; 5215- case mindspore::DataType::kNumberTypeUInt32: 5216- oh_data_type 
= OH_NN_UINT32; 5217- break; 5218- case mindspore::DataType::kNumberTypeUInt64: 5219- oh_data_type = OH_NN_UINT64; 5220- break; 5221- case mindspore::DataType::kNumberTypeFloat16: 5222- oh_data_type = OH_NN_FLOAT16; 5223- break; 5224- case mindspore::DataType::kNumberTypeFloat32: 5225- oh_data_type = OH_NN_FLOAT32; 5226- break; 5227- case mindspore::DataType::kNumberTypeFloat64: 5228- oh_data_type = OH_NN_FLOAT64; 5229- break; 5230- default: { 5231- oh_data_type = OH_NN_UNKNOWN; 5232- } 5233- } 5234- return oh_data_type; 5235+ 5236+OH_NN_Format NNRTDelegate::CastToNNRTFormat(Format format) { 5237+ return OH_NN_FORMAT_NHWC; 5238 } 5239 5240-mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model, 5241- OH_NNExecutor *oh_nn_executor) { 5242+Status NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model, 5243+ OH_NNExecutor *oh_nn_executor) { 5244 auto output_tensors = model->outputs(); 5245 for (size_t i = 0; i < output_tensors.size(); i++) { 5246 auto tensor = output_tensors[i]; 5247@@ -280,17 +744,17 @@ mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema:: 5248 if (ret_code != OH_NN_SUCCESS) { 5249 MS_LOG(ERROR) << "NNExecutor SetOutput failed, current out tensor is" << tensor.Name() 5250 << ", OH_NN_ReturnCode = " << ret_code; 5251- return mindspore::kLiteError; 5252+ return kLiteError; 5253 } 5254 } 5255- return mindspore::kSuccess; 5256+ return kSuccess; 5257 } 5258 5259-void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGraph &lite_graph) { 5260+void NNRTDelegate::ShallowCopyLiteGraph(const lite::LiteGraph &lite_graph) { 5261 Status ret; 5262 for (auto node : lite_graph.all_nodes_) { 5263 ret = lite::CheckPrimitiveSupported(static_cast<const schema::Primitive *>(node->primitive_)); 5264- if (ret == mindspore::kLiteError) { 5265+ if (ret == kLiteError) { 5266 MS_LOG(ERROR) << " primitive supported check failed."; 5267 return; 5268 } 5269@@ -299,7 +763,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5270 node_list.reserve(lite_graph.all_nodes_.size()); 5271 // copy node 5272 for (auto node : lite_graph.all_nodes_) { 5273- auto new_node = new (std::nothrow) LiteGraph::Node; 5274+ auto new_node = new(std::nothrow) LiteGraph::Node; 5275 if (new_node == nullptr) { 5276 MS_LOG(ERROR) << " new LiteGraph::Node failed."; 5277 return; 5278@@ -318,7 +782,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5279 // copy subgraph 5280 std::vector<LiteGraph::SubGraph *> subgraph_list; 5281 for (auto subgraph : lite_graph.sub_graphs_) { 5282- auto new_subgraph = new (std::nothrow) LiteGraph::SubGraph; 5283+ auto new_subgraph = new(std::nothrow) LiteGraph::SubGraph; 5284 if (new_subgraph == nullptr) { 5285 MS_LOG(ERROR) << "new LiteGraph::Subgraph failed."; 5286 return; 5287@@ -331,30 +795,32 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5288 } 5289 for (auto tensor : lite_graph.all_tensors_) { 5290 ret = lite::CheckTensorSupported(static_cast<const schema::Tensor *>(tensor)); 5291- if (ret == mindspore::kLiteError) { 5292+ if (ret == kLiteError) { 5293 MS_LOG(ERROR) << "tensor supported check failed."; 5294 return; 5295 } 5296 } 5297 5298- nnrt_lite_graph = new (std::nothrow) lite::LiteGraph(); 5299- if (nnrt_lite_graph == nullptr) { 5300+ lite_graph_ = new(std::nothrow) lite::LiteGraph(); 5301+ if (lite_graph_ == nullptr) { 5302 MS_LOG(ERROR) << "new LiteGraph failed."; 5303 return; 
5304 } 5305 5306- nnrt_lite_graph->name_ = lite_graph.name_; 5307- nnrt_lite_graph->version_ = lite_graph.version_; 5308- nnrt_lite_graph->input_indices_ = lite_graph.input_indices_; 5309- nnrt_lite_graph->output_indices_ = lite_graph.output_indices_; 5310- nnrt_lite_graph->all_tensors_ = lite_graph.all_tensors_; 5311- nnrt_lite_graph->all_nodes_ = node_list; 5312- nnrt_lite_graph->sub_graphs_ = subgraph_list; 5313+ lite_graph_->name_ = lite_graph.name_; 5314+ lite_graph_->version_ = lite_graph.version_; 5315+ lite_graph_->input_indices_ = lite_graph.input_indices_; 5316+ lite_graph_->output_indices_ = lite_graph.output_indices_; 5317+ lite_graph_->all_tensors_ = lite_graph.all_tensors_; 5318+ lite_graph_->all_nodes_ = node_list; 5319+ lite_graph_->sub_graphs_ = subgraph_list; 5320 MS_LOG(INFO) << "ShallowCopyLiteGraph success."; 5321 } 5322 5323-mindspore::NNRTDelegate::~NNRTDelegate() { 5324- if (this->nnrt_lite_graph != nullptr) { 5325+NNRTDelegate::~NNRTDelegate() { 5326+ if (lite_graph_ != nullptr) { 5327 MS_LOG(ERROR) << "Delete NNRTDelegate."; 5328 } 5329-}; 5330+} 5331+} // namespace lite 5332+} // namespace mindspore 5333diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5334index c2847704..52626339 100644 5335--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5336+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5337@@ -15,37 +15,81 @@ 5338 */ 5339 #ifndef MINDSPORE_NNR_DELEGATE_H 5340 #define MINDSPORE_NNR_DELEGATE_H 5341+ 5342 #include <vector> 5343 #include <map> 5344 #include "include/api/delegate.h" 5345 #include "include/model.h" 5346-#include "interfaces/kits/c/neural_network_runtime_type.h" 5347-namespace mindspore { 5348+#include "src/litert/inner_context.h" 5349+#include "nnrt_model_kernel.h" 5350+#include "schema/model_generated.h" 5351+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h" 5352+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5353+#include "interfaces/innerkits/c/neural_network_runtime_inner.h" 5354 5355-using namespace lite; 5356+namespace mindspore { 5357+namespace lite { 5358+struct NNRTOpRange { 5359+ /* NNRT kernel range in DelegateModel: [begin_iter_, end_iter_) */ 5360+ KernelIter begin_iter_; 5361+ KernelIter end_iter_; 5362+ /* NNRT node range in lite_graph_: [begin_index_, end_index_) */ 5363+ size_t begin_index_; 5364+ size_t end_index_; 5365+}; 5366 5367 class NNRTDelegate : public Delegate { 5368 public: 5369- NNRTDelegate() : Delegate(){}; 5370- 5371+ NNRTDelegate() = default; 5372+ NNRTDelegate(const NNRtDeviceInfo &nnrt_device_info) : nnrt_device_info_(nnrt_device_info) {} 5373 ~NNRTDelegate() override; 5374- 5375- Status Init() override; 5376- 5377+ Status Init() override { return kSuccess; } 5378 Status Build(DelegateModel<schema::Primitive> *model) override; 5379- 5380 void ShallowCopyLiteGraph(const lite::LiteGraph &liteGraph); 5381- 5382- protected: 5383- LiteGraph *nnrt_lite_graph = nullptr; 5384+ void SetMetaGraph(const void *meta_graph) { 5385+ meta_graph_ = meta_graph; 5386+ } 5387+ static std::vector<NNRTOpRange> GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model, 5388+ const std::vector<bool> &op_supports); 5389 5390 private: 5391- // static LiteGraph* CreateLiteGraph(const LiteGraph &liteGraph); 5392+ void InitCachePath(); 5393+ Status BuildNormalModel(DelegateModel<schema::Primitive> *model); 5394+ OH_NNModel *CreateFullNNModel(); 5395+ std::vector<bool> 
QueryOpSupports(OH_NNModel *nn_model); 5396+ Status CreateLiteGraphForNNRTSubgraph( 5397+ const std::vector<NNRTOpRange> &nnrt_op_ranges, 5398+ std::vector<LiteGraph *> *sub_lite_graphs); 5399+ Status CreateNNRTSubgraphKernels( 5400+ DelegateModel<schema::Primitive> *model, 5401+ const std::vector<LiteGraph *> &sub_lite_graphs, 5402+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5403+ std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels); 5404+ void ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model, 5405+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5406+ const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels); 5407 Status PrepareInputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor); 5408 Status PrepareOutputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor); 5409- OH_NN_DataType ConvertDataType(mindspore::DataType data_type); 5410-}; 5411+ Status InitNNCompilation(OH_NNCompilation *nn_compilation) const; 5412+ static OH_NN_DataType CastToNNRTDataType(mindspore::DataType data_type); 5413+ static OH_NN_Format CastToNNRTFormat(Format format); 5414+ bool IsCustomModel() const; 5415+ 5416+#ifdef SUPPORT_NNRT_METAGRAPH 5417+ bool IsKirinNPU() const; 5418+ Status BuildKirinNPUModel(DelegateModel<schema::Primitive> *model); 5419+ Status SetKirinModelInputsAndOutputs(OH_NNModel *nn_model); 5420+ std::vector<OH_NN_TensorInfo> CreateNNTensorInfos(const std::vector<uint32_t> &indices) const; 5421+ Status CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model); 5422+#endif 5423 5424+ NNRtDeviceInfo nnrt_device_info_; 5425+ LiteGraph *lite_graph_ = nullptr; 5426+ const void *meta_graph_ = nullptr; 5427+ std::string cache_path_ = ""; 5428+ uint32_t cache_version_ = 0; 5429+}; 5430+} // namespace lite 5431 } // namespace mindspore 5432 5433 #endif // MINDSPORE_NNR_DELEGATE_H 5434diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5435index 5acf2e9a..67443e08 100644 5436--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5437+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5438@@ -97,7 +97,7 @@ OH_NN_DataType mindspore::NNRTModelKernel::ConvertDataType(mindspore::DataType d 5439 } 5440 int mindspore::NNRTModelKernel::PrepareInputs() { 5441 auto input_tensors = this->inputs(); 5442- for (int i = 0; i < input_tensors.size(); i++) { 5443+ for (size_t i = 0; i < input_tensors.size(); i++) { 5444 auto tensor = input_tensors[i]; 5445 auto tensor_shape = tensor.Shape(); 5446 auto tmp_quant_param = tensor.QuantParams(); 5447@@ -142,6 +142,7 @@ int mindspore::NNRTModelKernel::PrepareInputs() { 5448 oprend->dimensions = dimensions_list.data(); 5449 oprend->quantParam = quant_param; 5450 oprend->type = OH_NN_TENSOR; 5451+ MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData() << ", size: " << tensor.DataSize(); 5452 OH_NN_ReturnCode ret_code = 5453 OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5454 delete (oprend); 5455diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5456index cf9481df..ea15f7ca 100644 5457--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5458+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5459@@ -20,7 +20,7 @@ 5460 #include <map> 5461 #include <utility> 5462 
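To show how the pieces of the refactored delegate interface above fit together, here is a minimal, hypothetical wiring sketch; in the real runtime the lite session and scheduler perform these steps, the option values below are placeholders, and the LiteGraph and DelegateModel are assumed to already exist.

#include <memory>
#include "src/litert/delegate/nnrt/nnrt_delegate.h"

// Hypothetical wiring of the refactored delegate; lite_graph and delegate_model come from the
// surrounding runtime and are assumed to be valid here.
void BuildWithNNRTDelegate(const mindspore::lite::LiteGraph &lite_graph,
                           mindspore::DelegateModel<mindspore::schema::Primitive> *delegate_model) {
  mindspore::lite::NNRtDeviceInfo device_info;
  device_info.device_id_ = 0;          // placeholder: an id reported by OH_NNDevice_GetAllDevicesID
  device_info.performance_mode_ = 0;   // placeholder values, forwarded by InitNNCompilation()
  device_info.priority_ = 0;
  device_info.enable_fp16_ = false;

  mindspore::lite::NNRTDelegate delegate(device_info);
  delegate.ShallowCopyLiteGraph(lite_graph);  // the delegate keeps its own shallow copy of the graph
  if (delegate.Init() != mindspore::kSuccess || delegate.Build(delegate_model) != mindspore::kSuccess) {
    // Build() already falls back to CPU (returning kSuccess) for unsupported graphs, so reaching
    // this branch means the delegate itself could not be set up.
    return;
  }
  // On success, the NNRT-capable kernel ranges in delegate_model have been replaced by
  // NNRTModelKernel instances; the remaining kernels still run on their original backends.
}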
#include "include/api/kernel.h" 5463-#include "interfaces/kits/c/neural_network_runtime.h" 5464+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5465 #include "src/common/log_adapter.h" 5466 #include "include/errorcode.h" 5467 5468diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 5469new file mode 100644 5470index 00000000..8ac283af 5471--- /dev/null 5472+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 5473@@ -0,0 +1,99 @@ 5474+/** 5475+* Copyright 2023 Huawei Technologies Co., Ltd 5476+* 5477+* Licensed under the Apache License, Version 2.0 (the "License"); 5478+* you may not use this file except in compliance with the License. 5479+* You may obtain a copy of the License at 5480+* 5481+* http://www.apache.org/licenses/LICENSE-2.0 5482+* 5483+* Unless required by applicable law or agreed to in writing, software 5484+* distributed under the License is distributed on an "AS IS" BASIS, 5485+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5486+* See the License for the specific language governing permissions and 5487+* limitations under the License. 5488+*/ 5489+ 5490+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5491+#include "interfaces/innerkits/c/neural_network_runtime_inner.h" 5492+ 5493+OH_NNModel *OH_NNModel_Construct(void) { 5494+ return NULL; 5495+} 5496+ 5497+OH_NN_ReturnCode OH_NNExecutor_Run(OH_NNExecutor *executor) { 5498+ return OH_NN_SUCCESS; 5499+} 5500+ 5501+OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation) { 5502+ return OH_NN_SUCCESS; 5503+} 5504+ 5505+void OH_NNCompilation_Destroy(OH_NNCompilation **compilation) {} 5506+ 5507+OH_NNExecutor *OH_NNExecutor_Construct(OH_NNCompilation *compilation) { 5508+ return NULL; 5509+} 5510+ 5511+void OH_NNExecutor_Destroy(OH_NNExecutor **executor) {} 5512+ 5513+OH_NNCompilation *OH_NNCompilation_Construct(const OH_NNModel *model) { 5514+ return NULL; 5515+} 5516+ 5517+OH_NN_ReturnCode OH_NNDevice_GetAllDevicesID(const size_t **allDevicesID, uint32_t *deviceCount) { 5518+ return OH_NN_SUCCESS; 5519+} 5520+ 5521+OH_NN_ReturnCode OH_NNExecutor_SetOutput(OH_NNExecutor *executor, 5522+ uint32_t outputIndex, 5523+ void *dataBuffer, 5524+ size_t length) { 5525+ return OH_NN_SUCCESS; 5526+} 5527+ 5528+OH_NN_ReturnCode OH_NNCompilation_SetDevice(OH_NNCompilation *compilation, size_t deviceID) { 5529+ return OH_NN_SUCCESS; 5530+} 5531+ 5532+OH_NN_ReturnCode OH_NNExecutor_SetInput(OH_NNExecutor *executor, 5533+ uint32_t inputIndex, 5534+ const OH_NN_Tensor *tensor, 5535+ const void *dataBuffer, 5536+ size_t length) { 5537+ return OH_NN_SUCCESS; 5538+} 5539+ 5540+void OH_NNModel_Destroy(OH_NNModel **model) {} 5541+ 5542+OH_NN_ReturnCode OH_NNModel_GetAvailableOperations(OH_NNModel *model, 5543+ size_t deviceID, 5544+ const bool **isSupported, 5545+ uint32_t *opCount) { 5546+ return OH_NN_SUCCESS; 5547+} 5548+ 5549+OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const void *liteGraph) { 5550+ return OH_NN_SUCCESS; 5551+} 5552+ 5553+OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name) { 5554+ return OH_NN_SUCCESS; 5555+} 5556+ 5557+OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType *deviceType) { 5558+ return OH_NN_SUCCESS; 5559+} 5560+ 5561+OH_NN_ReturnCode OH_NNCompilation_SetPriority(OH_NNCompilation *compilation, OH_NN_Priority priority) { 5562+ return OH_NN_SUCCESS; 5563+} 5564+ 5565+OH_NN_ReturnCode 
OH_NNCompilation_EnableFloat16(OH_NNCompilation *compilation, bool enableFloat16) { 5566+ return OH_NN_SUCCESS; 5567+} 5568+ 5569+OH_NN_ReturnCode OH_NNCompilation_SetPerformanceMode(OH_NNCompilation *compilation, 5570+ OH_NN_PerformanceMode performanceMode) { 5571+ return OH_NN_SUCCESS; 5572+} 5573\ No newline at end of file 5574diff --git a/mindspore/lite/src/litert/infer_manager.cc b/mindspore/lite/src/litert/infer_manager.cc 5575index 2b21d1ca..908ab122 100644 5576--- a/mindspore/lite/src/litert/infer_manager.cc 5577+++ b/mindspore/lite/src/litert/infer_manager.cc 5578@@ -162,7 +162,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto 5579 if (parameter->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion) || 5580 parameter->type_ == static_cast<int>(schema::PrimitiveType_Switch) || 5581 parameter->type_ == static_cast<int>(schema::PrimitiveType_Call) || 5582- parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer)) { 5583+ parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer) || 5584+ parameter->type_ == static_cast<int>(PrimType_Inner_ThirdPartyModel)) { 5585 MS_LOG(INFO) << "no need infer shape."; 5586 return RET_OK; 5587 } 5588diff --git a/mindspore/lite/src/litert/inner_context.cc b/mindspore/lite/src/litert/inner_context.cc 5589index 7cbac8f7..bf585ff0 100644 5590--- a/mindspore/lite/src/litert/inner_context.cc 5591+++ b/mindspore/lite/src/litert/inner_context.cc 5592@@ -122,6 +122,10 @@ int InnerContext::Init() { 5593 #endif 5594 } 5595 5596+ if (IsDeviceTypeEnabled(DT_NNRT)) { 5597+ MS_LOG(DEBUG) << "NNRT enabled."; 5598+ } 5599+ 5600 if (CreateThreadPool(false)) { 5601 MS_LOG(ERROR) << "CreateThreadPool failed."; 5602 return RET_ERROR; 5603diff --git a/mindspore/lite/src/litert/inner_context.h b/mindspore/lite/src/litert/inner_context.h 5604index 88281eb1..8735961c 100644 5605--- a/mindspore/lite/src/litert/inner_context.h 5606+++ b/mindspore/lite/src/litert/inner_context.h 5607@@ -71,12 +71,26 @@ typedef struct CustomDeviceInfo { 5608 std::shared_ptr<DeviceInfoContext> user_defined_device_info_; 5609 } CustomDeviceInfo; 5610 5611+typedef struct Extension { 5612+ std::string name; // config name 5613+ std::vector<uint8_t> value; // config value 5614+} Extension; 5615+ 5616+typedef struct NNRtDeviceInfo { 5617+ size_t device_id_ = 0; 5618+ int priority_ = 0; 5619+ int performance_mode_ = 0; 5620+ bool enable_fp16_ = false; 5621+ std::vector<Extension> extensions_; 5622+} NNRtDeviceInfo; 5623+ 5624 struct DeviceInfo { 5625 CpuDeviceInfo cpu_device_info_; 5626 GpuDeviceInfo gpu_device_info_; 5627 NpuDeviceInfo npu_device_info_; 5628 AscendDeviceInfo ascend_device_info_; 5629 CustomDeviceInfo custom_device_info_; 5630+ NNRtDeviceInfo nnrt_device_info_; 5631 }; 5632 5633 struct DeviceContext { 5634diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5635index 48308425..65065b5b 100644 5636--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5637+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5638@@ -13,6 +13,10 @@ cpu_kernel_sources = [ 5639 "base/call.cc", 5640 "base/constant_of_shape.cc", 5641 "base/convolution_base.cc", 5642+ "base/custom_base.cc", 5643+ "base/custom_masked_fill.cc", 5644+ "base/custom_is_inf.cc", 5645+ "base/custom_tensor_scatter.cc", 5646 "base/detection_post_process_base.cc", 5647 "base/format_transpose.cc", 5648 "base/group_convolution_base.cc", 5649@@ -37,7 +41,6 @@ cpu_kernel_sources = [ 5650 "fp32/batchnorm_fp32.cc", 5651 
"fp32/batch_to_space_fp32.cc", 5652 "fp32/broadcast_to_fp32.cc", 5653- "fp32/cast_for_x86_fp16.cc", 5654 "fp32/cast_fp32.cc", 5655 "fp32/convolution_1x1_fp32.cc", 5656 "fp32/convolution_delegate_fp32.cc", 5657@@ -118,6 +121,10 @@ cpu_kernel_sources = [ 5658 "fp32/online_fusion/split_reduce_concat_fp32.cc", 5659 ] 5660 5661+if ((target_cpu != "arm") && (target_cpu != "arm64")) { 5662+ cpu_kernel_sources += [ "src/runtime/kernel/cpu/fp32/cast_for_x86_fp16.cc" ] 5663+} 5664+ 5665 arm64_cpu_kernel_sources = [ 5666 "fp32/convolution_im2col_arm64_fp32.cc", 5667 "fp32/matmul_fp32_arm64.cc", 5668@@ -142,6 +149,42 @@ sse_avx_avx512_kernel_sources = [ 5669 "fp32/matmul_fp32_avx512.cc", 5670 ] 5671 5672+fp16_kernel_sources = [ 5673+ "fp16/batchnorm_fp16.cc", 5674+ "fp16/biasadd_fp16.cc", 5675+ "fp16/cast_fp16.cc", 5676+ "fp16/common_fp16.cc", 5677+ "fp16/convolution_1x1_fp16.cc", 5678+ "fp16/convolution_delegate_fp16.cc", 5679+ "fp16/convolution_depthwise_3x3_fp16.cc", 5680+ "fp16/convolution_depthwise_fp16.cc", 5681+ "fp16/convolution_depthwise_slidewindow_fp16.cc", 5682+ "fp16/convolution_fp16.cc", 5683+ "fp16/convolution_winograd_fp16.cc", 5684+ "fp16/custom_gru_fp16.cc", 5685+ "fp16/deconvolution_depthwise_fp16.cc", 5686+ "fp16/deconvolution_fp16.cc", 5687+ "fp16/deconvolution_winograd_fp16.cc", 5688+ "fp16/depth_to_space_fp16.cc", 5689+ "fp16/dynamic_quant_fp16.cc", 5690+ "fp16/fullconnection_fp16.cc", 5691+ "fp16/fused_batchnorm_fp16.cc", 5692+ "fp16/group_convolution_fp16.cc", 5693+ "fp16/gru_fp16.cc", 5694+ "fp16/instance_norm_fp16.cc", 5695+ "fp16/layout_transform_fp16.cc", 5696+ "fp16/lstm_fp16.cc", 5697+ "fp16/matmul_base_fp16.cc", 5698+ "fp16/matmul_fp16.cc", 5699+ "fp16/power_fp16.cc", 5700+ "fp16/prelu_fp16.cc", 5701+ "fp16/quant_dtype_cast_fp16.cc", 5702+ "fp16/reduce_fp16.cc", 5703+ "fp16/resize_fp16.cc", 5704+ "fp16/slice_fp16.cc", 5705+ "fp16/where_fp16.cc", 5706+] 5707+ 5708 int8_kernel_sources = [ 5709 "int8/activation_int8.cc", 5710 "int8/add_int8.cc", 5711@@ -227,6 +270,12 @@ all_cpu_kernel_sources += int8_kernel_sources 5712 all_cpu_kernel_sources += string_kernel_sources 5713 all_cpu_kernel_sources += control_kernel_sources 5714 5715+if (target_cpu == "arm64") { 5716+ all_cpu_kernel_sources += fp16_kernel_sources 5717+} else { 5718+ not_needed(fp16_kernel_sources) 5719+} 5720+ 5721 if (target_cpu == "arm") { 5722 all_cpu_kernel_sources -= arm64_cpu_kernel_sources 5723 all_cpu_kernel_sources -= sse_avx_avx512_kernel_sources 5724diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 5725new file mode 100644 5726index 00000000..9921e063 5727--- /dev/null 5728+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 5729@@ -0,0 +1,46 @@ 5730+/** 5731+ * Copyright 2022 Huawei Technologies Co., Ltd 5732+ * 5733+ * Licensed under the Apache License, Version 2.0 (the "License"); 5734+ * you may not use this file except in compliance with the License. 5735+ * You may obtain a copy of the License at 5736+ * 5737+ * http://www.apache.org/licenses/LICENSE-2.0 5738+ * 5739+ * Unless required by applicable law or agreed to in writing, software 5740+ * distributed under the License is distributed on an "AS IS" BASIS, 5741+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5742+ * See the License for the specific language governing permissions and 5743+ * limitations under the License. 
5744+ */ 5745+ 5746+#include "src/litert/kernel/cpu/base/custom_base.h" 5747+#include <algorithm> 5748+#include <utility> 5749+#include <vector> 5750+#include "src/litert/kernel_registry.h" 5751+#include "nnacl/op_base.h" 5752+ 5753+using mindspore::kernel::KERNEL_ARCH; 5754+using mindspore::lite::KernelRegistrar; 5755+using mindspore::lite::RET_ERROR; 5756+using mindspore::lite::RET_OK; 5757+using mindspore::schema::PrimitiveType_Custom; 5758+ 5759+namespace mindspore::kernel { 5760+int CustomBaseCPUKernel::Prepare() { 5761+ return RET_OK; 5762+} 5763+ 5764+int CustomBaseCPUKernel::ReSize() { 5765+ return RET_OK; 5766+} 5767+ 5768+int CustomBaseCPUKernel::Run() { 5769+ return RET_OK; 5770+} 5771+ 5772+REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5773+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5774+REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5775+} // namespace mindspore::kernel 5776diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 5777new file mode 100644 5778index 00000000..ecb4c72d 5779--- /dev/null 5780+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 5781@@ -0,0 +1,43 @@ 5782+/** 5783+ * Copyright 2022 Huawei Technologies Co., Ltd 5784+ * 5785+ * Licensed under the Apache License, Version 2.0 (the "License"); 5786+ * you may not use this file except in compliance with the License. 5787+ * You may obtain a copy of the License at 5788+ * 5789+ * http://www.apache.org/licenses/LICENSE-2.0 5790+ * 5791+ * Unless required by applicable law or agreed to in writing, software 5792+ * distributed under the License is distributed on an "AS IS" BASIS, 5793+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5794+ * See the License for the specific language governing permissions and 5795+ * limitations under the License. 
5796+ */ 5797+ 5798+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5799+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5800+ 5801+#include <vector> 5802+#include "src/litert/lite_kernel.h" 5803+#include "nnacl/custom_parameter.h" 5804+ 5805+namespace mindspore::kernel { 5806+class CustomBaseCPUKernel : public LiteKernel { 5807+ public: 5808+ CustomBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 5809+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 5810+ : LiteKernel(parameter, inputs, outputs, ctx) { 5811+ custom_param_ = reinterpret_cast<CustomParameter *>(op_parameter_); 5812+ } 5813+ ~CustomBaseCPUKernel() override = default; 5814+ 5815+ int Prepare() override; 5816+ int ReSize() override; 5817+ int Run() override; 5818+ 5819+ private: 5820+ CustomParameter *custom_param_ = nullptr; 5821+}; 5822+} // namespace mindspore::kernel 5823+ 5824+#endif // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5825diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 5826new file mode 100644 5827index 00000000..edffea42 5828--- /dev/null 5829+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 5830@@ -0,0 +1,61 @@ 5831+/** 5832+ * Copyright 2023 Huawei Technologies Co., Ltd 5833+ * 5834+ * Licensed under the Apache License, Version 2.0 (the "License"); 5835+ * you may not use this file except in compliance with the License. 5836+ * You may obtain a copy of the License at 5837+ * 5838+ * http://www.apache.org/licenses/LICENSE-2.0 5839+ * 5840+ * Unless required by applicable law or agreed to in writing, software 5841+ * distributed under the License is distributed on an "AS IS" BASIS, 5842+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5843+ * See the License for the specific language governing permissions and 5844+ * limitations under the License. 
5845+ */ 5846+#include "src/litert/kernel_registry.h" 5847+#include "include/errorcode.h" 5848+#include "src/litert/kernel/cpu/base/custom_is_inf.h" 5849+#include "src/common/tensor_util.h" 5850+#include "nnacl/op_base.h" 5851+ 5852+using mindspore::lite::KernelRegistrar; 5853+using mindspore::lite::RET_ERROR; 5854+using mindspore::lite::RET_OK; 5855+ 5856+namespace mindspore::kernel { 5857+ 5858+int CustomIsInfCPUKernel::Prepare() { 5859+ CHECK_LESS_RETURN(in_tensors_.size(), C1NUM); 5860+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 5861+ return RET_OK; 5862+} 5863+ 5864+int CustomIsInfCPUKernel::ReSize() { return RET_OK; } 5865+ 5866+void CustomIsInfCPUKernel::LaunchKernelFloat(const float *input, bool *output) { 5867+ auto elem_num = in_tensors_[FIRST_INPUT]->ElementsNum(); 5868+ 5869+ for (int i = 0; i < elem_num; i++) { 5870+ output[i] = std::isinf(input[i]); 5871+ } 5872+} 5873+ 5874+int CustomIsInfCPUKernel::Run() { 5875+ auto input = in_tensors_[FIRST_INPUT]; 5876+ auto output = out_tensors_[FIRST_INPUT]; 5877+ CHECK_NULL_RETURN(input); 5878+ CHECK_NULL_RETURN(output); 5879+ 5880+ if (input->data_type() == kNumberTypeFloat32 || input->data_type() == kNumberTypeFloat) { 5881+ LaunchKernelFloat(reinterpret_cast<const float *>(input->data()), reinterpret_cast<bool *>(output->data())); 5882+ } else { 5883+ MS_LOG(ERROR) << "unsupported input data type " << input->data_type(); 5884+ return RET_ERROR; 5885+ } 5886+ 5887+ return RET_OK; 5888+} 5889+ 5890+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomIsInf, LiteKernelCreator<CustomIsInfCPUKernel>) 5891+} // namespace mindspore::kernel 5892diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 5893new file mode 100644 5894index 00000000..e63d8ec7 5895--- /dev/null 5896+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 5897@@ -0,0 +1,38 @@ 5898+/** 5899+ * Copyright 2023 Huawei Technologies Co., Ltd 5900+ * 5901+ * Licensed under the Apache License, Version 2.0 (the "License"); 5902+ * you may not use this file except in compliance with the License. 5903+ * You may obtain a copy of the License at 5904+ * 5905+ * http://www.apache.org/licenses/LICENSE-2.0 5906+ * 5907+ * Unless required by applicable law or agreed to in writing, software 5908+ * distributed under the License is distributed on an "AS IS" BASIS, 5909+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5910+ * See the License for the specific language governing permissions and 5911+ * limitations under the License. 
5912+ */ 5913+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5914+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5915+ 5916+#include <vector> 5917+#include "src/litert/lite_kernel.h" 5918+ 5919+namespace mindspore::kernel { 5920+class CustomIsInfCPUKernel : public LiteKernel { 5921+ public: 5922+ CustomIsInfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 5923+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 5924+ : LiteKernel(parameter, inputs, outputs, ctx) {} 5925+ ~CustomIsInfCPUKernel() override = default; 5926+ int Prepare() override; 5927+ int ReSize() override; 5928+ int Run() override; 5929+ 5930+ private: 5931+ void LaunchKernelFloat(const float *input, bool *output); 5932+}; 5933+} // namespace mindspore::kernel 5934+ 5935+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5936diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 5937new file mode 100644 5938index 00000000..9af1af5d 5939--- /dev/null 5940+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 5941@@ -0,0 +1,84 @@ 5942+/** 5943+ * Copyright 2023 Huawei Technologies Co., Ltd 5944+ * 5945+ * Licensed under the Apache License, Version 2.0 (the "License"); 5946+ * you may not use this file except in compliance with the License. 5947+ * You may obtain a copy of the License at 5948+ * 5949+ * http://www.apache.org/licenses/LICENSE-2.0 5950+ * 5951+ * Unless required by applicable law or agreed to in writing, software 5952+ * distributed under the License is distributed on an "AS IS" BASIS, 5953+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5954+ * See the License for the specific language governing permissions and 5955+ * limitations under the License. 
5956+ */ 5957+#include "src/litert/kernel_registry.h" 5958+#include "include/errorcode.h" 5959+#include "src/litert/kernel/cpu/base/custom_masked_fill.h" 5960+#include "src/common/tensor_util.h" 5961+#include "nnacl/op_base.h" 5962+ 5963+using mindspore::lite::KernelRegistrar; 5964+using mindspore::lite::RET_ERROR; 5965+using mindspore::lite::RET_OK; 5966+ 5967+namespace mindspore::kernel { 5968+ 5969+int CustomMaskedFillCPUKernel::Prepare() { 5970+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM); 5971+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 5972+ 5973+ // only support input value as a single float value 5974+ MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat32 || 5975+ in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat, 5976+ RET_ERROR, "input dtype must be float32"); 5977+ if (in_tensors_[THIRD_INPUT]->ElementsNum() != 1) { 5978+ MS_LOG(ERROR) << "only support fill value as a single float"; 5979+ return RET_ERROR; 5980+ } 5981+ MS_CHECK_TRUE_MSG(in_tensors_[SECOND_INPUT]->data_type() == mindspore::TypeId::kNumberTypeBool, RET_ERROR, 5982+ "mask dtype must be bool"); 5983+ if (!InferShapeDone()) { 5984+ return RET_OK; 5985+ } 5986+ return ReSize(); 5987+} 5988+ 5989+int CustomMaskedFillCPUKernel::ReSize() { return RET_OK; } 5990+ 5991+int CustomMaskedFillCPUKernel::Run() { 5992+ auto input = in_tensors_[FIRST_INPUT]; 5993+ auto mask = in_tensors_[SECOND_INPUT]; 5994+ auto value = in_tensors_[THIRD_INPUT]; 5995+ auto output = out_tensors_[FIRST_INPUT]; 5996+ CHECK_NULL_RETURN(input); 5997+ CHECK_NULL_RETURN(mask); 5998+ CHECK_NULL_RETURN(value); 5999+ CHECK_NULL_RETURN(output); 6000+ 6001+ if (input->shape() != mask->shape()) { 6002+ MS_LOG(ERROR) << "Not support broadcast mask to input"; 6003+ return RET_ERROR; 6004+ } 6005+ 6006+ auto value_data = reinterpret_cast<float *>(value->data()); 6007+ auto fill_value = value_data[0]; 6008+ 6009+ auto data_num = input->ElementsNum(); 6010+ auto input_data = reinterpret_cast<float *>(input->data()); 6011+ auto mask_data = reinterpret_cast<bool *>(mask->data()); 6012+ auto output_data = reinterpret_cast<float *>(output->data()); 6013+ for (int64_t i = 0; i < data_num; i++) { 6014+ if (mask_data[i]) { 6015+ output_data[i] = fill_value; 6016+ } else { 6017+ output_data[i] = input_data[i]; 6018+ } 6019+ } 6020+ 6021+ return RET_OK; 6022+} 6023+ 6024+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomMaskedFill, LiteKernelCreator<CustomMaskedFillCPUKernel>) 6025+} // namespace mindspore::kernel 6026diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 6027new file mode 100644 6028index 00000000..04a2dcab 6029--- /dev/null 6030+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 6031@@ -0,0 +1,35 @@ 6032+/** 6033+ * Copyright 2023 Huawei Technologies Co., Ltd 6034+ * 6035+ * Licensed under the Apache License, Version 2.0 (the "License"); 6036+ * you may not use this file except in compliance with the License. 6037+ * You may obtain a copy of the License at 6038+ * 6039+ * http://www.apache.org/licenses/LICENSE-2.0 6040+ * 6041+ * Unless required by applicable law or agreed to in writing, software 6042+ * distributed under the License is distributed on an "AS IS" BASIS, 6043+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6044+ * See the License for the specific language governing permissions and 6045+ * limitations under the License. 
6046+ */ 6047+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6048+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6049+ 6050+#include <vector> 6051+#include "src/litert/lite_kernel.h" 6052+ 6053+namespace mindspore::kernel { 6054+class CustomMaskedFillCPUKernel : public LiteKernel { 6055+ public: 6056+ CustomMaskedFillCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6057+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6058+ : LiteKernel(parameter, inputs, outputs, ctx) {} 6059+ ~CustomMaskedFillCPUKernel() override = default; 6060+ int Prepare() override; 6061+ int ReSize() override; 6062+ int Run() override; 6063+}; 6064+} // namespace mindspore::kernel 6065+ 6066+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6067diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 6068new file mode 100644 6069index 00000000..d52d67d5 6070--- /dev/null 6071+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 6072@@ -0,0 +1,75 @@ 6073+/** 6074+ * Copyright 2022 Huawei Technologies Co., Ltd 6075+ * 6076+ * Licensed under the Apache License, Version 2.0 (the "License"); 6077+ * you may not use this file except in compliance with the License. 6078+ * You may obtain a copy of the License at 6079+ * 6080+ * http://www.apache.org/licenses/LICENSE-2.0 6081+ * 6082+ * Unless required by applicable law or agreed to in writing, software 6083+ * distributed under the License is distributed on an "AS IS" BASIS, 6084+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6085+ * See the License for the specific language governing permissions and 6086+ * limitations under the License. 6087+ */ 6088+ 6089+#include "src/litert/kernel/cpu/base/custom_tensor_scatter.h" 6090+#include <cstring> 6091+#include "schema/model_generated.h" 6092+#include "src/litert/kernel_registry.h" 6093+#include "include/errorcode.h" 6094+#include "nnacl/base/scatter_nd_binary.h" 6095+ 6096+using mindspore::kernel::KERNEL_ARCH; 6097+using mindspore::lite::KernelRegistrar; 6098+using mindspore::lite::RET_ERROR; 6099+using mindspore::lite::RET_OK; 6100+ 6101+namespace mindspore::kernel { 6102+namespace { 6103+int TensorScatterRun(void *cdata, int task_id, float, float) { 6104+ auto kernel = static_cast<CustomTensorScatterCPUKernel *>(cdata); 6105+ CHECK_NULL_RETURN(kernel); 6106+ return kernel->TensorScatterDispatch(task_id); 6107+} 6108+} // namespace 6109+ 6110+int CustomTensorScatterCPUKernel::TensorScatterDispatch(int task_id) { 6111+ auto data_type = in_tensors_[kScatterUpdateInputIndex]->data_type(); 6112+ if (data_type != kNumberTypeFloat32) { 6113+ MS_LOG(ERROR) << "TensorScatterMax only support float32 input tensor, but got " << data_type; 6114+ return RET_ERROR; 6115+ } 6116+ int type = data_type == kNumberTypeFloat32 ? 
0 : 1; 6117+  // multi-threading still has some problems to solve 6118+  param_->op_parameter.thread_num_ = 1; 6119+  auto ret = ScatterNDMax(in_tensors_[kScatterUpdateIndex]->data(), out_tensors_[kOutputIndex]->data(), 6120+                          output_unit_offsets_.data(), param_, type, task_id); 6121+  if (ret != RET_OK) { 6122+    MS_LOG(ERROR) << "ScatterNDMax failed, ret: " << ret; 6123+    return RET_ERROR; 6124+  } 6125+  return RET_OK; 6126+} 6127+ 6128+int CustomTensorScatterCPUKernel::Run() { 6129+  auto in_tensor = in_tensors().front(); 6130+  auto out_tensor = out_tensors().front(); 6131+  (void)memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size()); 6132+  auto indices = in_tensors_.at(kScatterIndicesIndex); 6133+  if (!indices->IsConst() && ReSize() != RET_OK) { 6134+    MS_LOG(ERROR) << "TensorScatterMax resize failed."; 6135+    return RET_ERROR; 6136+  } 6137+ 6138+  auto ret = ParallelLaunch(ms_context_, TensorScatterRun, this, op_parameter_->thread_num_); 6139+  if (ret != RET_OK) { 6140+    MS_LOG(ERROR) << "TensorScatterMax error, error_code[" << ret << "]"; 6141+  } 6142+  return ret; 6143+} 6144+ 6145+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomTensorScatterMax, 6146+           LiteKernelCreator<CustomTensorScatterCPUKernel>) 6147+}  // namespace mindspore::kernel 6148diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 6149new file mode 100644 6150index 00000000..e39733c5 6151--- /dev/null 6152+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 6153@@ -0,0 +1,36 @@ 6154+/** 6155+ * Copyright 2022 Huawei Technologies Co., Ltd 6156+ * 6157+ * Licensed under the Apache License, Version 2.0 (the "License"); 6158+ * you may not use this file except in compliance with the License. 6159+ * You may obtain a copy of the License at 6160+ * 6161+ * http://www.apache.org/licenses/LICENSE-2.0 6162+ * 6163+ * Unless required by applicable law or agreed to in writing, software 6164+ * distributed under the License is distributed on an "AS IS" BASIS, 6165+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6166+ * See the License for the specific language governing permissions and 6167+ * limitations under the License.
6168+ */ 6169+ 6170+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6171+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6172+ 6173+#include <vector> 6174+#include "src/litert/kernel/cpu/base/scatter_nd_binary.h" 6175+ 6176+namespace mindspore::kernel { 6177+class CustomTensorScatterCPUKernel : public ScatterNDBinaryCPUKernel { 6178+ public: 6179+ explicit CustomTensorScatterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6180+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6181+ : ScatterNDBinaryCPUKernel(parameter, inputs, outputs, ctx) {} 6182+ ~CustomTensorScatterCPUKernel() override = default; 6183+ 6184+ int Run() override; 6185+ int TensorScatterDispatch(int task_id); 6186+}; 6187+} // namespace mindspore::kernel 6188+ 6189+#endif // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6190diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc 6191index 2c5bc658..13652633 100644 6192--- a/mindspore/lite/src/litert/lite_model.cc 6193+++ b/mindspore/lite/src/litert/lite_model.cc 6194@@ -98,6 +98,8 @@ int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) { 6195 if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr || 6196 sub_graph.tensorIndices() == nullptr) { 6197 MS_LOG(ERROR) << "sub_graph is invalid"; 6198+ MS_LOG(ERROR) << "sub_graph.name() = " << sub_graph.name() << ", sub_graph.inputIndices() = " << sub_graph.inputIndices() 6199+ << ", sub_graph.outputIndices() = " << sub_graph.outputIndices() << ", sub_graph.tensorIndices() = " << sub_graph.tensorIndices(); 6200 return RET_ERROR; 6201 } 6202 6203@@ -620,6 +622,33 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, minds 6204 return model; 6205 } 6206 6207+std::string LiteGraph::ToString() const { 6208+ std::stringstream ss; 6209+ ss << "all_nodes: " << all_nodes_.size() << std::endl; 6210+ for (size_t i = 0; i < all_nodes_.size(); i++) { 6211+ ss << "- node " << i << ": " << all_nodes_[i]->primitive_ << std::endl; 6212+ ss << "- node " << i << " input_indices_: " << all_nodes_[i]->input_indices_ << std::endl; 6213+ ss << "- node " << i << " output_indices_: " << all_nodes_[i]->output_indices_ << std::endl; 6214+ } 6215+ ss << "all_tensors: " << all_tensors_.size() << std::endl; 6216+ for (size_t i = 0; i < all_tensors_.size(); i++) { 6217+ ss << "- tensor " << i << ": " << all_tensors_[i] << std::endl; 6218+ } 6219+ ss << "input_indices: " << input_indices_<< std::endl; 6220+ ss << "output_indices: " << output_indices_ << std::endl; 6221+ 6222+ ss << "subgraphs: " << std::endl; 6223+ int count = 0; 6224+ for (auto subgraph: sub_graphs_) { 6225+ ss << "- subgraph " << count++ << std::endl; 6226+ ss << "--- subgraph input " << subgraph->input_indices_ << std::endl; 6227+ ss << "--- subgraph output " << subgraph->output_indices_ << std::endl; 6228+ ss << "--- subgraph node " << subgraph->node_indices_ << std::endl; 6229+ ss << "--- subgraph tensor " << subgraph->tensor_indices_ << std::endl; 6230+ } 6231+ return ss.str(); 6232+} 6233+ 6234 Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } 6235 6236 Model *Model::Import(const char *filename) { return ImportFromPath(filename); } 6237diff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc 6238index 8f54879e..f635c8d2 100644 6239--- 
a/mindspore/lite/src/litert/lite_session.cc 6240+++ b/mindspore/lite/src/litert/lite_session.cc 6241@@ -67,6 +67,9 @@ 6242 #include "thread/parallel_thread_pool_manager.h" 6243 #endif 6244 #include "src/litert/runtime_packed_node_pass.h" 6245+#ifdef SUPPORT_NNRT 6246+#include "src/litert/delegate/nnrt/nnrt_delegate.h" 6247+#endif 6248 6249 using AbstractBaseModel = mindspore::infer::AbstractBaseModel; 6250 6251@@ -635,12 +638,6 @@ int LiteSession::CompileGraph(Model *model) { 6252 MarkSharedWeight(kernels_); 6253 FreePackOpWeight(kernels_); 6254 6255- ret = RuntimeAllocatorInit(); 6256- if (ret != RET_OK) { 6257- MS_LOG(ERROR) << "Runtime allocator init failed."; 6258- is_running_.store(false); 6259- return ret; 6260- } 6261 infer_along_running_ = infer_along_running_ && (runtime_allocator_ == nullptr); 6262 if (infer_along_running_) { 6263 this->context_->set_infer_checker(InferCheckerAll); 6264@@ -1092,6 +1089,27 @@ int LiteSession::CreateCoreMLDelegate() { 6265 return RET_OK; 6266 } 6267 6268+int LiteSession::CreateNNRTDelegate() { 6269+#if SUPPORT_NNRT 6270+ auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(), 6271+ [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; }); 6272+ if(iter == context_->device_list_.end()) { 6273+ MS_LOG(ERROR) << "Found non NNRT device info"; 6274+ return RET_ERROR; 6275+ } 6276+ 6277+ delegate_ = std::make_shared<NNRTDelegate>(iter->device_info_.nnrt_device_info_); 6278+ if (delegate_ == nullptr) { 6279+ MS_LOG(ERROR) << "New NNRT delegate failed"; 6280+ return RET_ERROR; 6281+ } 6282+// ((NNRTDelegate *)(delegate_.get()))->SetMetaGraph(this->model_->buf); 6283+ delegate_device_type_ = DT_NNRT; 6284+ this->context_->delegate = delegate_; 6285+#endif 6286+ return RET_OK; 6287+}; 6288+ 6289 int LiteSession::DelegateInit() { 6290 #ifndef DELEGATE_CLIP 6291 int ret = RET_OK; 6292@@ -1115,6 +1133,8 @@ int LiteSession::DelegateInit() { 6293 ret = CreateNPUDelegate(); 6294 } else if (context_->IsDeviceTypeEnabled(DT_GPU)) { 6295 ret = CreateTensorRTDelegate(); 6296+ } else if (context_->IsDeviceTypeEnabled(DT_NNRT)) { 6297+ ret = CreateNNRTDelegate(); 6298 } 6299 } 6300 6301@@ -1496,12 +1516,6 @@ int LiteSession::Resize(const std::vector<mindspore::lite::Tensor *> &inputs, 6302 return ret; 6303 } 6304 6305- if (RuntimeAllocatorInit() != RET_OK) { 6306- MS_LOG(ERROR) << "Runtime allocator in resize failed."; 6307- is_running_.store(false); 6308- return RET_ERROR; 6309- } 6310- 6311 auto status = GraphOptimizePass(&kernels_); 6312 if (status != RET_OK) { 6313 MS_LOG(ERROR) << "GraphOptimizePass failed."; 6314@@ -2022,7 +2036,6 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, 6315 delete model; 6316 return RET_ERROR; 6317 } 6318- model->Free(); 6319 set_model(model); 6320 return RET_OK; 6321 } 6322diff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h 6323index f8f8fe08..64a5f6d3 100644 6324--- a/mindspore/lite/src/litert/lite_session.h 6325+++ b/mindspore/lite/src/litert/lite_session.h 6326@@ -178,6 +178,7 @@ class MS_API LiteSession { 6327 int CreateNPUDelegate(); 6328 int CreateNNAPIDelegate(); 6329 int CreateCoreMLDelegate(); 6330+ int CreateNNRTDelegate(); 6331 int DelegateInit(); 6332 int InitGPURuntime(); 6333 int InitSharedThreadPool(); 6334diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc 6335index 11382b09..199b4361 100644 6336--- a/mindspore/lite/src/litert/scheduler.cc 6337+++ 
b/mindspore/lite/src/litert/scheduler.cc 6338@@ -60,6 +60,9 @@ 6339 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) 6340 #include "thread/parallel_thread_pool_manager.h" 6341 #endif 6342+#ifdef SUPPORT_NNRT 6343+#include "src/litert/delegate/nnrt/nnrt_delegate.h" 6344+#endif 6345 6346 using AbstractBaseModel = mindspore::infer::AbstractBaseModel; 6347 6348@@ -368,6 +371,7 @@ STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *ker 6349 } 6350 6351 int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { 6352+ MS_LOG(DEBUG) << "Start schedule."; 6353 int check_input_ret = CheckInputParam(dst_kernels); 6354 if (check_input_ret != RET_OK) { 6355 MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret; 6356@@ -404,11 +408,13 @@ int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { 6357 } 6358 shape_fusion_pass_->StoreStateAndReset(); 6359 6360+ MS_LOG(DEBUG) << "Start to init delegate kernels."; 6361 ret = InitDelegateKernels(dst_kernels); 6362 if (ret != RET_OK) { 6363 MS_LOG(ERROR) << "Repalce delegate kernels failed."; 6364 return ret; 6365 } 6366+ MS_LOG(DEBUG) << "Finish to init delegate kernels."; 6367 6368 ret = CheckCpuValid(dst_kernels); 6369 if (ret != RET_OK) { 6370@@ -500,6 +506,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_ker 6371 MS_LOG(ERROR) << "New delegate model failed."; 6372 return RET_NULL_PTR; 6373 } 6374+ 6375+#ifdef SUPPORT_NNRT 6376+ if (context_->IsDeviceTypeEnabled(DT_NNRT)) { 6377+ auto delegate = static_cast<NNRTDelegate *>(delegate_.get()); 6378+ delegate->ShallowCopyLiteGraph(this->src_model_->graph_); 6379+ void *meta_graph = reinterpret_cast<void*>(const_cast<mindspore::schema::MetaGraph *>( 6380+ mindspore::schema::GetMetaGraph(this->src_model_->buf))); 6381+ delegate->SetMetaGraph(meta_graph); 6382+ } 6383+#endif 6384+ 6385 auto ret = delegate_->Build(model); 6386 if (ret != mindspore::kSuccess) { 6387 delete model; 6388diff --git a/mindspore/lite/src/litert/tensor_category.cc b/mindspore/lite/src/litert/tensor_category.cc 6389index 70d13865..e57cdb28 100644 6390--- a/mindspore/lite/src/litert/tensor_category.cc 6391+++ b/mindspore/lite/src/litert/tensor_category.cc 6392@@ -30,5 +30,9 @@ Category TensorCategory(const schema::Tensor &tensor) { 6393 auto data_size = tensor.data() == nullptr ? 
0 : tensor.data()->size(); 6394 return TensorCategory(tensor.nodeType(), shape_num, TypeId(tensor.dataType()), data_size); 6395 } 6396+ 6397+bool IsConstTensor(const schema::Tensor &tensor) { 6398+ return TensorCategory(tensor) != Category::VAR; 6399+} 6400 } // namespace lite 6401 } // namespace mindspore 6402diff --git a/mindspore/lite/src/litert/tensor_category.h b/mindspore/lite/src/litert/tensor_category.h 6403index 83273032..70e65b31 100644 6404--- a/mindspore/lite/src/litert/tensor_category.h 6405+++ b/mindspore/lite/src/litert/tensor_category.h 6406@@ -35,6 +35,7 @@ enum Category { 6407 6408 Category TensorCategory(const int node_type, const size_t shape_num, const TypeId data_type, const size_t data_size); 6409 Category TensorCategory(const schema::Tensor &tensor); 6410+bool IsConstTensor(const schema::Tensor &tensor); 6411 } // namespace lite 6412 } // namespace mindspore 6413 #endif // MINDSPORE_LITE_SRC_RUNTIME_TENSOR_CATEGORY_H_ 6414diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt 6415index 60e240f0..78dab536 100644 6416--- a/mindspore/lite/test/CMakeLists.txt 6417+++ b/mindspore/lite/test/CMakeLists.txt 6418@@ -28,10 +28,14 @@ file(GLOB_RECURSE TEST_UT_SRC 6419 ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc 6420 ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc 6421 ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc 6422- ${TEST_DIR}/ut/src/api/context_c_test.cc 6423- ${TEST_DIR}/ut/src/api/model_c_test.cc 6424- ${TEST_DIR}/ut/src/api/tensor_c_test.cc` 6425+# ${TEST_DIR}/ut/src/api/context_c_test.cc 6426+# ${TEST_DIR}/ut/src/api/model_c_test.cc 6427+# ${TEST_DIR}/ut/src/api/tensor_c_test.cc` 6428 ) 6429+if(MSLITE_ENABLE_NNRT) 6430+ list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/nnrt_delegate/nnrt_delegate_tests.cc) 6431+endif() 6432+ 6433 if(MSLITE_ENABLE_SERVER_INFERENCE) 6434 list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/api/model_parallel_runner_test.cc) 6435 endif() 6436@@ -86,7 +90,7 @@ endif() 6437 6438 if(MSLITE_ENABLE_INT8) 6439 file(GLOB_RECURSE TEST_INT8_UT_SRC 6440- ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc 6441+# ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc 6442 ${TEST_DIR}/ut/nnacl/int8/*.cc 6443 ) 6444 list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC}) 6445@@ -118,6 +122,7 @@ if(MSLITE_ENABLE_CONVERTER) 6446 ${TEST_DIR}/ut/tools/converter/registry/*.cc 6447 ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc 6448 ${TEST_DIR}/ut/tools/converter/api/*.cc 6449+ ${TEST_DIR}/ut/tools/converter/config_parser/*.cc 6450 ${TEST_DIR}/st/converter_test.cc 6451 ${TEST_DIR}/st/delegate_test.cc 6452 ${TEST_DIR}/st/mindrt_parallel_test.cc 6453@@ -232,7 +237,7 @@ endif() 6454 6455 if(MSLITE_ENABLE_CONVERTER) 6456 target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid 6457- onnx_parser_mid tf_parser_mid) 6458+ onnx_parser_mid tf_parser_mid third_party_parser_mid) 6459 endif() 6460 6461 if(MSLITE_ENABLE_MODEL_OBF) 6462diff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh 6463index c0d6d843..abdea6f4 100644 6464--- a/mindspore/lite/test/runtest.sh 6465+++ b/mindspore/lite/test/runtest.sh 6466@@ -80,6 +80,7 @@ if [ "$ENABLE_CONVERTER_TEST" = true ]; then 6467 ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry" 6468 ./lite-test-converter --gtest_filter="TestConverterAPI.*" 6469 ./lite-test-converter --gtest_filter="SpecifyGraphOutputFormatTest*" 6470+ ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*" 6471 fi 6472 ./lite-test --gtest_filter="TestRegistry.TestAdd" 6473 
./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd" 6474diff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg 6475new file mode 100644 6476index 00000000..b5fcba75 6477--- /dev/null 6478+++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg 6479@@ -0,0 +1,8 @@ 6480+[third_party_model] 6481+input_names=demo_in_0;demo_in_1;demo_in_2 6482+input_dtypes=float32;float16;float64 6483+input_shapes=1;2,3;4,5,6 6484+output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4 6485+output_dtypes=int32;int16;int8;uint8 6486+output_shapes=10;20,30;40;50,60,70 6487+extended_parameters=foo:foo_value;bar:bar_value 6488diff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6489index 549bdd72..e73afc0e 100644 6490--- a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6491+++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6492@@ -34,3 +34,13 @@ TEST(TestConverterAPI, ConvertCaffeWithNotExistWeight) { 6493 mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeCaffe, caffe_model, output_model, caffe_weight); 6494 ASSERT_FALSE(converter.Convert().IsOk()); 6495 } 6496+ 6497+TEST(TestConverterAPI, ConvertThirdParty) { 6498+ std::string third_party_model = "./relu.mindir"; 6499+ std::string config_model = "./third_party_model.cfg"; 6500+ std::string output_model = "./demo_third_party.ms"; 6501+ 6502+ mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model); 6503+ converter.SetConfigFile(config_model); 6504+ ASSERT_TRUE(converter.Convert().IsOk()); 6505+} 6506\ No newline at end of file 6507diff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 6508new file mode 100644 6509index 00000000..c8eb5536 6510--- /dev/null 6511+++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 6512@@ -0,0 +1,176 @@ 6513+/** 6514+ * Copyright 2023 Huawei Technologies Co., Ltd 6515+ * 6516+ * Licensed under the Apache License, Version 2.0 (the "License"); 6517+ * you may not use this file except in compliance with the License. 6518+ * You may obtain a copy of the License at 6519+ * 6520+ * http://www.apache.org/licenses/LICENSE-2.0 6521+ * 6522+ * Unless required by applicable law or agreed to in writing, software 6523+ * distributed under the License is distributed on an "AS IS" BASIS, 6524+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6525+ * See the License for the specific language governing permissions and 6526+ * limitations under the License. 6527+ */ 6528+ 6529+#include "gtest/gtest.h" 6530+#include "tools/converter/config_parser/third_party_param_parser.h" 6531+ 6532+using mindspore::ThirdPartyModelParam; 6533+using mindspore::TypeId; 6534+using mindspore::lite::RET_OK; 6535+using mindspore::lite::ThirdPartyModelString; 6536+using mindspore::lite::ThirdPartyParamParser; 6537+ 6538+const ThirdPartyModelString kDemoSISOParam = { 6539+ // SISO is short for single-input-single-output. 
6540+ .input_dtypes = "float32", 6541+ .input_shapes = "1,2,3,4", 6542+ .input_names = "siso_input", 6543+ .output_dtypes = "int32", 6544+ .output_shapes = "2", 6545+ .output_names = "siso_output", 6546+ .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value", 6547+}; 6548+ 6549+const ThirdPartyModelString kDemoMIMOParam = { 6550+ // MIMO is short for multiple-input-multiple-output. 6551+ .input_dtypes = "float32;int8;float16", 6552+ .input_shapes = "1,2,3,4;5,6;7,8,9", 6553+ .input_names = "mimo_in_0;mimo_in_1;mimo_in_2", 6554+ .output_dtypes = "int32;float32", 6555+ .output_shapes = "2,4;10,20,30", 6556+ .output_names = "mimo_out_0;mimo_out_1", 6557+ .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value", 6558+}; 6559+ 6560+TEST(TestThirdPartyParamParser, ParseSISOParam) { 6561+ ThirdPartyModelString param_string = kDemoSISOParam; 6562+ ThirdPartyModelParam result; 6563+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6564+ 6565+ ASSERT_EQ(result.input_names, std::vector<std::string>{"siso_input"}); 6566+ ASSERT_EQ(result.input_shapes.size(), 1U); 6567+ std::vector<int64_t> expect_in_shape = {1, 2, 3, 4}; 6568+ ASSERT_EQ(result.input_shapes[0], expect_in_shape); 6569+ ASSERT_EQ(result.input_dtypes, std::vector<TypeId>{TypeId::kNumberTypeFloat32}); 6570+ 6571+ ASSERT_EQ(result.output_names, std::vector<std::string>{"siso_output"}); 6572+ ASSERT_EQ(result.output_shapes.size(), 1U); 6573+ std::vector<int64_t> expect_out_shape = {2}; 6574+ ASSERT_EQ(result.output_shapes[0], expect_out_shape); 6575+ ASSERT_EQ(result.output_dtypes, std::vector<TypeId>{TypeId::kNumberTypeInt32}); 6576+ 6577+ const auto &ext_param = result.extended_parameters; 6578+ ASSERT_EQ(ext_param.size(), 2U); 6579+ ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end()); 6580+ auto expect_foo_value = ext_param.at("siso_foo"); 6581+ ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value"); 6582+ ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end()); 6583+ auto expect_bar_value = ext_param.at("siso_bar"); 6584+ ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value"); 6585+} 6586+ 6587+TEST(TestThirdPartyParamParser, ParseValidDtype) { 6588+ ThirdPartyModelString param_string = kDemoSISOParam; 6589+ const std::vector<std::string> kValidDtypeStrings = { 6590+ "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool", 6591+ }; 6592+ 6593+ const std::vector<TypeId> kExpects = { 6594+ TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16, 6595+ TypeId::kNumberTypeInt64, TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16, 6596+ TypeId::kNumberTypeInt8, TypeId::kNumberTypeUInt8, TypeId::kNumberTypeBool}; 6597+ 6598+ for (size_t i = 0; i < kValidDtypeStrings.size(); i++) { 6599+ param_string.input_dtypes = kValidDtypeStrings[i]; 6600+ ThirdPartyModelParam result; 6601+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6602+ ASSERT_EQ(result.input_dtypes[0], kExpects[i]); 6603+ } 6604+} 6605+ 6606+TEST(TestThirdPartyParamParser, ParseInvalidDtype) { 6607+ ThirdPartyModelParam result; 6608+ ThirdPartyModelString param_string = kDemoSISOParam; 6609+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6610+ param_string.input_dtypes = "bad_dtype"; 6611+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6612+} 6613+ 6614+TEST(TestThirdPartyParamParser, ParseValidShape) { 6615+ 
ThirdPartyModelString param_string = kDemoSISOParam; 6616+ param_string.input_shapes = "256,256,1024,96"; // Only support fixed shape. 6617+ ThirdPartyModelParam result; 6618+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6619+ std::vector<int64_t> expect = {256, 256, 1024, 96}; 6620+ ASSERT_EQ(result.input_shapes[0], expect); 6621+} 6622+ 6623+TEST(TestThirdPartyParamParser, ParseInvalidShape) { 6624+ ThirdPartyModelParam result; 6625+ ThirdPartyModelString param_string = kDemoSISOParam; 6626+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6627+ 6628+ param_string.input_shapes = "256,256,1024,-1"; 6629+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6630+ 6631+ param_string.input_shapes = "256,256,0,96"; 6632+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6633+ 6634+ param_string.input_shapes = "256,-256,1024,96"; 6635+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6636+ 6637+ param_string.input_shapes = "256,foo,1024,96"; 6638+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6639+} 6640+ 6641+TEST(TestThirdPartyParamParser, ParseDefaultName) { 6642+ ThirdPartyModelParam result; 6643+ ThirdPartyModelString param_string = kDemoSISOParam; 6644+ param_string.input_names = ""; 6645+ param_string.output_names = ""; 6646+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6647+ ASSERT_EQ(result.input_names[0], "in_0"); 6648+ ASSERT_EQ(result.output_names[0], "out_0"); 6649+} 6650+ 6651+TEST(TestThirdPartyParamParser, ParseMIMOParam) { 6652+ ThirdPartyModelString param_string = kDemoMIMOParam; 6653+ ThirdPartyModelParam result; 6654+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6655+ 6656+ std::vector<std::string> expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"}; 6657+ ASSERT_EQ(result.input_names, expect_input_names); 6658+ std::vector<std::vector<int64_t>> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}}; 6659+ ASSERT_EQ(result.input_shapes, expect_input_shapes); 6660+ std::vector<TypeId> expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8, 6661+ TypeId::kNumberTypeFloat16}; 6662+ ASSERT_EQ(result.input_dtypes, expect_input_dtypes); 6663+ 6664+ std::vector<std::string> expect_output_names = {"mimo_out_0", "mimo_out_1"}; 6665+ ASSERT_EQ(result.output_names, expect_output_names); 6666+ std::vector<std::vector<int64_t>> expect_output_shapes = {{2, 4}, {10, 20, 30}}; 6667+ ASSERT_EQ(result.output_shapes, expect_output_shapes); 6668+ std::vector<TypeId> expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32}; 6669+ ASSERT_EQ(result.output_dtypes, expect_output_dtypes); 6670+} 6671+ 6672+TEST(TestThirdPartyParamParser, ParseMismatchedShapeAndDtypeSize) { 6673+ ThirdPartyModelString param_string = kDemoMIMOParam; 6674+ ThirdPartyModelParam result; 6675+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6676+ 6677+ param_string.input_shapes = "1,2,3,4;5,6"; // shape size is 2 while dtype size is 3. 
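  // The parser is expected to reject configs whose per-tensor counts disagree; here two shapes are
  // paired with three dtypes, and the next test checks the analogous name/dtype mismatch.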
6678+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6679+} 6680+ 6681+TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) { 6682+ ThirdPartyModelString param_string = kDemoMIMOParam; 6683+ ThirdPartyModelParam result; 6684+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6685+ 6686+ param_string.input_names = "mimo_in_0;mimo_in_1"; // name size is 2 while dtype size is 3. 6687+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6688+} 6689diff --git a/mindspore/lite/tools/benchmark/benchmark_base.cc b/mindspore/lite/tools/benchmark/benchmark_base.cc 6690index 16b1e218..ebaa9212 100644 6691--- a/mindspore/lite/tools/benchmark/benchmark_base.cc 6692+++ b/mindspore/lite/tools/benchmark/benchmark_base.cc 6693@@ -323,7 +323,7 @@ int BenchmarkBase::CheckThreadNumValid() { 6694 6695 int BenchmarkBase::CheckDeviceTypeValid() { 6696 if (flags_->device_ != "CPU" && flags_->device_ != "GPU" && flags_->device_ != "NPU" && 6697- flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P") { 6698+ flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P" && flags_->device_ != "NNRT") { 6699 MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported."; 6700 std::cerr << "Device type:" << flags_->device_ << " is not supported." << std::endl; 6701 return RET_ERROR; 6702diff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h 6703index acdea21a..f818270c 100644 6704--- a/mindspore/lite/tools/benchmark/benchmark_base.h 6705+++ b/mindspore/lite/tools/benchmark/benchmark_base.h 6706@@ -122,7 +122,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { 6707 AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 6708 AddFlag(&BenchmarkFlags::group_info_file_, "GroupInfoFile", "Communication group info file", ""); 6709 AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", ""); 6710- AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | Auto", "CPU"); 6711+ AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | NNRT | Auto", "CPU"); 6712 AddFlag(&BenchmarkFlags::provider_, "provider", "device provider litert | tensorrt | mindrt", "litert"); 6713 AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1); 6714 // MarkPerformance 6715diff --git a/mindspore/lite/tools/benchmark/benchmark_c_api.cc b/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6716index 252e65c6..cb0c56b0 100644 6717--- a/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6718+++ b/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6719@@ -125,6 +125,10 @@ int BenchmarkCApi::InitContext() { 6720 OH_AI_DeviceInfoSetFrequency(npu_device_info, kFrequencyDefault); 6721 OH_AI_ContextAddDeviceInfo(context_, npu_device_info); 6722 } 6723+ if (flags_->device_ == "NNRT") { 6724+ OH_AI_DeviceInfoHandle nnrt_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 6725+ OH_AI_ContextAddDeviceInfo(context_, nnrt_device_info); 6726+ } 6727 OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 6728 OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); 6729 OH_AI_ContextAddDeviceInfo(context_, cpu_device_info); 6730diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6731index 
bb36c168..c18111b6 100644 6732--- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6733+++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6734@@ -521,6 +521,11 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context> 6735 // InitMSContextForAscend(context, &device_list); 6736 } 6737 6738+ if (flags_->device_ == "NNRT" || flags_->device_ == "Auto") { 6739+ std::shared_ptr<NNRTDeviceInfo> nnrt_device_info = std::make_shared<NNRTDeviceInfo>(); 6740+ device_list.push_back(nnrt_device_info); 6741+ } 6742+ 6743 // CPU priority is behind GPU and NPU 6744 std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>(); 6745 device_info->SetEnableFP16(flags_->enable_fp16_); 6746diff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6747index 0c558524..1b9fc347 100644 6748--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6749+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6750@@ -9,6 +9,9 @@ set(COMMON_SRC 6751 set(TEST_SRC 6752 ${CMAKE_CURRENT_SOURCE_DIR}/main.cc 6753 ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc 6754+ ${CMAKE_CURRENT_SOURCE_DIR}/net_train_base.cc 6755+ ${CMAKE_CURRENT_SOURCE_DIR}/run_net_train.cc 6756+ ${CMAKE_CURRENT_SOURCE_DIR}/net_train_c_api.cc 6757 ) 6758 6759 # add static securec link library 6760diff --git a/mindspore/lite/tools/benchmark_train/main.cc b/mindspore/lite/tools/benchmark_train/main.cc 6761index abf3d9dd..76f85aa7 100644 6762--- a/mindspore/lite/tools/benchmark_train/main.cc 6763+++ b/mindspore/lite/tools/benchmark_train/main.cc 6764@@ -17,7 +17,8 @@ 6765 #include <malloc.h> 6766 #include <unistd.h> 6767 #include <fstream> 6768-#include "tools/benchmark_train/net_train.h" 6769+#include <iostream> 6770+#include "tools/benchmark_train/run_net_train.h" 6771 6772 void PrintMem() { 6773 std::string proc_file = "/proc/" + std::to_string(getpid()) + "/status"; 6774diff --git a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc 6775index 9b63d29f..edf3e964 100644 6776--- a/mindspore/lite/tools/benchmark_train/net_runner.cc 6777+++ b/mindspore/lite/tools/benchmark_train/net_runner.cc 6778@@ -15,7 +15,7 @@ 6779 */ 6780 6781 #include "tools/benchmark_train/net_runner.h" 6782-#include "tools/benchmark_train/net_train.h" 6783+#include "tools/benchmark_train/net_train_base.h" 6784 #include <getopt.h> 6785 #include <malloc.h> 6786 #include <cmath> 6787@@ -187,7 +187,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) { 6788 auto output = tensor.Data(); 6789 size_t size; 6790 std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; 6791- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size)); 6792+ auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(output_file.c_str(), &size)); 6793 if (bin_buf == nullptr) { 6794 MS_LOG(ERROR) << "ReadFile return nullptr"; 6795 std::cout << "ReadFile return nullptr" << std::endl; 6796@@ -200,7 +200,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) { 6797 << ", read size: " << size << std::endl; 6798 return mindspore::kLiteError; 6799 } 6800- float bias = mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(), 6801+ float bias = mindspore::lite::NetTrainBase::CompareData<float>(bin_buf.get(), tensor.ElementNum(), 6802 reinterpret_cast<const float *>(output.get())); 
6803 if (bias >= 0) { 6804 total_bias += bias; 6805@@ -332,7 +332,7 @@ int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) { 6806 } 6807 size_t size; 6808 std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; 6809- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size)); 6810+ auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(file_name.c_str(), &size)); 6811 if (bin_buf == nullptr) { 6812 MS_LOG(ERROR) << "ReadFile return nullptr"; 6813 std::cout << "ReadFile return nullptr" << std::endl; 6814@@ -368,4 +368,4 @@ int CallBack(mindspore::lite::NetTrainFlags *flags) { 6815 return nr.Main(); 6816 } 6817 6818-int init = mindspore::lite::NetTrain::SetNr(CallBack); 6819+int init = mindspore::lite::NetTrainBase::SetNr(CallBack); 6820diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc 6821index d1150043..514bba53 100644 6822--- a/mindspore/lite/tools/benchmark_train/net_train.cc 6823+++ b/mindspore/lite/tools/benchmark_train/net_train.cc 6824@@ -31,74 +31,11 @@ 6825 6826 namespace mindspore { 6827 namespace lite { 6828-static const char *DELIM_SLASH = "/"; 6829-constexpr const char *DELIM_COLON = ":"; 6830-constexpr const char *DELIM_COMMA = ","; 6831-constexpr int RET_TOO_BIG = -9; 6832 constexpr int kField0 = 0; 6833 constexpr int kField1 = 1; 6834 constexpr int kField2 = 2; 6835 constexpr int kField3 = 3; 6836 constexpr int kField4 = 4; 6837-constexpr int kFieldsToPrint = 5; 6838-constexpr int kPrintOffset = 4; 6839-static const int kTHOUSAND = 1000; 6840-constexpr int kDumpInputsAndOutputs = 0; 6841-constexpr int kDumpOutputs = 2; 6842- 6843-const std::unordered_map<int, std::string> kTypeIdMap{ 6844- {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"}, 6845- {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"}, 6846- {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"}, 6847- {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"}, 6848- {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}}; 6849- 6850-const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{ 6851- {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"}, {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"}, 6852- {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"}, {mindspore::CKHW, "CKHW"}, {mindspore::KHWC, "KHWC"}, 6853- {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"}, {mindspore::HW4, "HW4"}, {mindspore::NC, "NC"}, 6854- {mindspore::NC4, "NC4"}, {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}}; 6855- 6856-std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr; 6857- 6858-int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) { 6859- nr_cb_ = param; 6860- return 0; 6861-} 6862- 6863-float *NetTrain::ReadFileBuf(const std::string file, size_t *size) { 6864- if (file.empty()) { 6865- MS_LOG(ERROR) << "file is nullptr"; 6866- return nullptr; 6867- } 6868- MS_ASSERT(size != nullptr); 6869- std::string real_path = RealPath(file.c_str()); 6870- std::ifstream ifs(real_path); 6871- if (!ifs.good()) { 6872- MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 6873- return nullptr; 6874- } 6875- 6876- if (!ifs.is_open()) { 6877- MS_LOG(ERROR) << "file: " << real_path << " open failed"; 6878- return nullptr; 6879- } 6880- 6881- ifs.seekg(0, 
std::ios::end); 6882- *size = ifs.tellg(); 6883- std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1); 6884- if (buf == nullptr) { 6885- MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; 6886- ifs.close(); 6887- return nullptr; 6888- } 6889- 6890- ifs.seekg(0, std::ios::beg); 6891- ifs.read(reinterpret_cast<char *>(buf.get()), *size); 6892- ifs.close(); 6893- 6894- return buf.release(); 6895-} 6896 6897 int NetTrain::GenerateInputData() { 6898 for (auto tensor : ms_inputs_for_api_) { 6899@@ -120,28 +57,6 @@ int NetTrain::GenerateInputData() { 6900 return RET_OK; 6901 } 6902 6903-int NetTrain::LoadInput() { 6904- inputs_buf_.clear(); 6905- inputs_size_.clear(); 6906- batch_num_ = 0; 6907- if (flags_->in_data_file_.empty()) { 6908- auto status = GenerateInputData(); 6909- if (status != RET_OK) { 6910- std::cerr << "Generate input data error " << status << std::endl; 6911- MS_LOG(ERROR) << "Generate input data error " << status; 6912- return status; 6913- } 6914- } else { 6915- auto status = ReadInputFile(); 6916- if (status != RET_OK) { 6917- std::cerr << "Read Input File error, " << status << std::endl; 6918- MS_LOG(ERROR) << "Read Input File error, " << status; 6919- return status; 6920- } 6921- } 6922- return RET_OK; 6923-} 6924- 6925 int NetTrain::LoadStepInput(size_t step) { 6926 if (step >= batch_num_) { 6927 auto cur_batch = step + 1; 6928@@ -269,30 +184,6 @@ int NetTrain::CompareOutput() { 6929 } 6930 } 6931 6932-std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 6933- const std::string &file_type, const size_t &idx) { 6934- std::string file_name = op_name; 6935- auto pos = file_name.find_first_of('/'); 6936- while (pos != std::string::npos) { 6937- file_name.replace(pos, 1, "."); 6938- pos = file_name.find_first_of('/'); 6939- } 6940- file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_"; 6941- for (const auto &dim : tensor->Shape()) { 6942- file_name += std::to_string(dim) + "_"; 6943- } 6944- if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) { 6945- file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType())); 6946- } 6947- auto tensor_format = tensor->format(); 6948- if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) { 6949- file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin"; 6950- } 6951- 6952- file_name += ".bin"; 6953- return file_name; 6954-} 6955- 6956 int NetTrain::MarkPerformance() { 6957 MS_LOG(INFO) << "Running train loops..."; 6958 std::cout << "Running train loops..." << std::endl; 6959@@ -574,26 +465,6 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string 6960 return RET_OK; 6961 } 6962 6963-int NetTrain::RunNetTrain() { 6964- auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1); 6965- bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty(); 6966- auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_); 6967- if (status != RET_OK) { 6968- MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status; 6969- std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". 
Status is " << status 6970- << std::endl; 6971- return status; 6972- } 6973- 6974- status = CheckExecutionOfSavedModels(); // re-initialize sessions according to flags 6975- if (status != RET_OK) { 6976- MS_LOG(ERROR) << "Run CheckExecute error: " << status; 6977- std::cout << "Run CheckExecute error: " << status << std::endl; 6978- return status; 6979- } 6980- return RET_OK; 6981-} 6982- 6983 int NetTrain::SaveModels() { 6984 if (!flags_->export_file_.empty()) { 6985 if (flags_->bb_model_file_.empty()) { 6986@@ -635,77 +506,6 @@ int NetTrain::SaveModels() { 6987 return RET_OK; 6988 } 6989 6990-int NetTrain::CheckExecutionOfSavedModels() { 6991- int status = RET_OK; 6992- if (!flags_->export_file_.empty()) { 6993- status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0); 6994- if (status != RET_OK) { 6995- MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status; 6996- std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl; 6997- return status; 6998- } 6999- if (flags_->bb_model_file_.empty()) { 7000- status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false); 7001- if (status != RET_OK) { 7002- MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status; 7003- std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl; 7004- return status; 7005- } 7006- } 7007- } 7008- if (!flags_->inference_file_.empty()) { 7009- status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0); 7010- if (status != RET_OK) { 7011- MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status; 7012- std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl; 7013- return status; 7014- } 7015- status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false); 7016- if (status != RET_OK) { 7017- MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status; 7018- std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl; 7019- return status; 7020- } 7021- } 7022- return status; 7023-} 7024- 7025-void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) { 7026- if (tensor == nullptr) { 7027- MS_LOG(ERROR) << "input tensor is nullptr."; 7028- return; 7029- } 7030- int tensor_size = tensor->ElementNum(); 7031- void *data = tensor->MutableData(); 7032- auto *fdata = reinterpret_cast<float *>(tensor->MutableData()); 7033- auto type = tensor->DataType(); 7034- std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum="; 7035- switch (type) { 7036- case mindspore::DataType::kNumberTypeFloat32: 7037- TensorNan(reinterpret_cast<float *>(data), tensor_size); 7038- std::cout << TensorSum<float>(data, tensor_size) << std::endl; 7039- std::cout << "tensor name: " << tensor->Name() << std::endl; 7040- std::cout << "data: "; 7041- for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) { 7042- std::cout << static_cast<float>(fdata[i]) << ", "; 7043- } 7044- std::cout << std::endl; 7045- break; 7046- case mindspore::DataType::kNumberTypeInt32: 7047- std::cout << TensorSum<int>(data, tensor_size) << std::endl; 7048- break; 7049-#ifdef ENABLE_FP16 7050- case mindspore::DataType::kNumberTypeFloat16: 7051- std::cout << 
TensorSum<float16_t>(data, tensor_size) << std::endl; 7052- TensorNan(reinterpret_cast<float16_t *>(data), tensor_size); 7053- break; 7054-#endif 7055- default: 7056- std::cout << "unsupported type:" << static_cast<int>(type) << std::endl; 7057- break; 7058- } 7059-} 7060- 7061 int NetTrain::InitDumpTensorDataCallbackParameter() { 7062 // before callback 7063 before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs, 7064@@ -815,178 +615,6 @@ int NetTrain::InitTimeProfilingCallbackParameter() { 7065 return RET_OK; 7066 } 7067 7068-int NetTrain::InitCallbackParameter() { 7069- int ret = RET_OK; 7070- if (flags_->dump_tensor_data_) { 7071- ret = InitDumpTensorDataCallbackParameter(); 7072- } else if (flags_->time_profiling_) { 7073- ret = InitTimeProfilingCallbackParameter(); 7074- } 7075- return ret; 7076-} 7077- 7078-void NetTrainFlags::InitResizeDimsList() { 7079- std::string content = this->resize_dims_in_; 7080- std::vector<int> shape; 7081- auto shape_strs = StrSplit(content, std::string(DELIM_COLON)); 7082- for (const auto &shape_str : shape_strs) { 7083- shape.clear(); 7084- auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA)); 7085- std::cout << "Resize Dims: "; 7086- for (const auto &dim_str : dim_strs) { 7087- std::cout << dim_str << " "; 7088- shape.emplace_back(static_cast<int>(std::stoi(dim_str))); 7089- } 7090- std::cout << std::endl; 7091- this->resize_dims_.emplace_back(shape); 7092- } 7093-} 7094- 7095-int NetTrain::Init() { 7096- if (this->flags_ == nullptr) { 7097- return 1; 7098- } 7099- MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_; 7100- MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_; 7101- MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_; 7102- MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_; 7103- MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_; 7104- MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_; 7105- MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; 7106- MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; 7107- MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; 7108- MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; 7109- MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_; 7110- 7111- if (this->flags_->epochs_ < 0) { 7112- MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; 7113- std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl; 7114- return RET_ERROR; 7115- } 7116- 7117- if (this->flags_->num_threads_ < 1) { 7118- MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0"; 7119- std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl; 7120- return RET_ERROR; 7121- } 7122- 7123- this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? 
kImage : kBinary; 7124- 7125- if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) { 7126- MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided"; 7127- std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl; 7128- return RET_ERROR; 7129- } 7130- 7131- if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) { 7132- MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided"; 7133- std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl; 7134- return RET_ERROR; 7135- } 7136- 7137- if (flags_->model_file_.empty()) { 7138- MS_LOG(ERROR) << "modelPath is required"; 7139- std::cerr << "modelPath is required" << std::endl; 7140- return 1; 7141- } 7142- 7143- // get dump data output path 7144- auto dump_cfg_path = std::getenv(dump::kConfigPath); 7145- if (dump_cfg_path != nullptr) { 7146- flags_->dump_tensor_data_ = true; 7147- if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) { 7148- MS_LOG(ERROR) << "parse dump config file failed."; 7149- return RET_ERROR; 7150- } 7151- } else { 7152- MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data"; 7153- } 7154- 7155- auto status = InitCallbackParameter(); 7156- if (status != RET_OK) { 7157- MS_LOG(ERROR) << "Init callback Parameter failed."; 7158- std::cerr << "Init callback Parameter failed." << std::endl; 7159- return RET_ERROR; 7160- } 7161- 7162- flags_->InitResizeDimsList(); 7163- if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && 7164- flags_->resize_dims_.size() != flags_->input_data_list_.size()) { 7165- MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; 7166- std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; 7167- return RET_ERROR; 7168- } 7169- return RET_OK; 7170-} 7171- 7172-namespace { 7173-constexpr int kNumToPrint = 5; 7174-} 7175- 7176-int NetTrain::InitDumpConfigFromJson(std::string path) { 7177- auto real_path = RealPath(path.c_str()); 7178- std::ifstream ifs(real_path); 7179- if (!ifs.good()) { 7180- MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7181- return RET_ERROR; 7182- } 7183- if (!ifs.is_open()) { 7184- MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7185- return RET_ERROR; 7186- } 7187- 7188- try { 7189- dump_cfg_json_ = nlohmann::json::parse(ifs); 7190- } catch (const nlohmann::json::parse_error &error) { 7191- MS_LOG(ERROR) << "parse json file failed, please check your file."; 7192- return RET_ERROR; 7193- } 7194- if (dump_cfg_json_[dump::kSettings] == nullptr) { 7195- MS_LOG(ERROR) << "\"common_dump_settings\" is required."; 7196- return RET_ERROR; 7197- } 7198- if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) { 7199- MS_LOG(ERROR) << "\"dump_mode\" is required."; 7200- return RET_ERROR; 7201- } 7202- if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) { 7203- MS_LOG(ERROR) << "\"path\" is required."; 7204- return RET_ERROR; 7205- } 7206- if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) { 7207- dump_cfg_json_[dump::kSettings][dump::kNetName] = "default"; 7208- } 7209- if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) { 7210- dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0; 7211- } 7212- if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr && 7213- !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) { 7214- if 
(dump_cfg_json_[dump::kSettings][dump::kMode] == 0) { 7215- MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)"; 7216- return RET_ERROR; 7217- } 7218- } 7219- 7220- auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>(); 7221- auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>(); 7222- if (abs_path.back() == '\\' || abs_path.back() == '/') { 7223- dump_file_output_dir_ = abs_path + net_name; 7224- } else { 7225-#ifdef _WIN32 7226- dump_file_output_dir_ = abs_path + "\\" + net_name; 7227-#else 7228- dump_file_output_dir_ = abs_path + "/" + net_name; 7229-#endif 7230- } 7231- 7232- auto status = CreateOutputDir(&dump_file_output_dir_); 7233- if (status != RET_OK) { 7234- MS_LOG(ERROR) << "create data output directory failed."; 7235- return RET_ERROR; 7236- } 7237- return RET_OK; 7238-} 7239- 7240 int NetTrain::PrintResult(const std::vector<std::string> &title, 7241 const std::map<std::string, std::pair<int, float>> &result) { 7242 std::vector<size_t> columnLenMax(kFieldsToPrint); 7243@@ -1035,7 +663,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7244 } 7245 7246 printf("-------------------------------------------------------------------------\n"); 7247- for (int i = 0; i < kNumToPrint; i++) { 7248+ for (int i = 0; i < kFieldsToPrint; i++) { 7249 auto printBuf = title[i]; 7250 if (printBuf.size() > columnLenMax.at(i)) { 7251 columnLenMax.at(i) = printBuf.size(); 7252@@ -1045,7 +673,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7253 } 7254 printf("\n"); 7255 for (auto &row : rows) { 7256- for (int j = 0; j < kNumToPrint; j++) { 7257+ for (int j = 0; j < kFieldsToPrint; j++) { 7258 auto printBuf = row[j]; 7259 printBuf.resize(columnLenMax.at(j), ' '); 7260 printf("%s\t", printBuf.c_str()); 7261@@ -1054,47 +682,5 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7262 } 7263 return RET_OK; 7264 } 7265- 7266-int RunNetTrain(int argc, const char **argv) { 7267- NetTrainFlags flags; 7268- Option<std::string> err = flags.ParseFlags(argc, argv); 7269- 7270- if (err.IsSome()) { 7271- std::cerr << err.Get() << std::endl; 7272- std::cerr << flags.Usage() << std::endl; 7273- return RET_ERROR; 7274- } 7275- 7276- if (flags.help) { 7277- std::cerr << flags.Usage() << std::endl; 7278- return RET_OK; 7279- } 7280- if (flags.unified_api_) { 7281- return NetTrain::RunNr(&flags); 7282- } 7283- NetTrain net_trainer(&flags); 7284- auto status = net_trainer.Init(); 7285- if (status != RET_OK) { 7286- MS_LOG(ERROR) << "NetTrain init Error : " << status; 7287- std::cerr << "NetTrain init Error : " << status << std::endl; 7288- return RET_ERROR; 7289- } 7290- 7291- status = net_trainer.RunNetTrain(); 7292- if (status != RET_OK) { 7293- MS_LOG(ERROR) << "Run NetTrain " 7294- << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7295- << " Failed : " << status; 7296- std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7297- << " Failed : " << status << std::endl; 7298- return RET_ERROR; 7299- } 7300- 7301- MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7302- << " Success."; 7303- std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7304- << " Success." 
<< std::endl; 7305- return RET_OK; 7306-} 7307 } // namespace lite 7308 } // namespace mindspore 7309diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h 7310index 67e58a04..bdf0ec88 100644 7311--- a/mindspore/lite/tools/benchmark_train/net_train.h 7312+++ b/mindspore/lite/tools/benchmark_train/net_train.h 7313@@ -42,183 +42,22 @@ 7314 #include "tools/common/flag_parser.h" 7315 #include "src/common/file_utils.h" 7316 #include "src/common/utils.h" 7317- 7318-#ifdef ENABLE_FP16 7319-static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) { 7320- volatile float16_t d = var; 7321- return d != d; 7322-} 7323-#endif 7324+#include "tools/benchmark_train/net_train_base.h" 7325 7326 namespace mindspore::lite { 7327-enum MS_API DataType { kImage = 0, kBinary = 1 }; 7328- 7329-constexpr float relativeTolerance = 1e-5; 7330-constexpr float absoluteTolerance = 1e-8; 7331 extern const std::unordered_map<int, std::string> kTypeIdMap; 7332 extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; 7333 7334-namespace dump { 7335-constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG"; 7336-constexpr auto kSettings = "common_dump_settings"; 7337-constexpr auto kMode = "dump_mode"; 7338-constexpr auto kPath = "path"; 7339-constexpr auto kNetName = "net_name"; 7340-constexpr auto kInputOutput = "input_output"; 7341-constexpr auto kKernels = "kernels"; 7342-} // namespace dump 7343- 7344-template <typename T> 7345-float TensorSum(const void *data, int size) { 7346- const T *typed_data = reinterpret_cast<const T *>(data); 7347- float sum = 0.f; 7348- for (int i = 0; i < size; i++) { 7349- sum += static_cast<float>(typed_data[i]); 7350- } 7351- return sum; 7352-} 7353- 7354-class MS_API NetTrainFlags : public virtual FlagParser { 7355+class MS_API NetTrain : public NetTrainBase { 7356 public: 7357- NetTrainFlags() { 7358- // common 7359- AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", ""); 7360- AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", ""); 7361- AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 7362- // MarkPerformance 7363- AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); 7364- AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); 7365- AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); 7366- AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1); 7367- // MarkAccuracy 7368- AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); 7369- AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); 7370- AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); 7371- AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); 7372- AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); 7373- AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); 7374- AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", ""); 7375- AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); 7376- AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", 7377- "Shape of input data, the format should be NHWC. e.g. 
1,32,32,32:1,1,32,32,1", ""); 7378- AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false); 7379- } 7380- 7381- ~NetTrainFlags() override = default; 7382- void InitResizeDimsList(); 7383+ explicit NetTrain(NetTrainFlags *flags) : NetTrainBase(flags) {} 7384+ virtual ~NetTrain() {} 7385 7386- public: 7387- // common 7388- std::string model_file_; 7389- std::string in_data_file_; 7390- std::string bb_model_file_; 7391- std::vector<std::string> input_data_list_; 7392- DataType in_data_type_; 7393- std::string in_data_type_in_ = "bin"; 7394- int cpu_bind_mode_ = 1; 7395- bool enable_fp16_ = false; 7396- bool virtual_batch_ = false; 7397- // MarkPerformance 7398- int num_threads_ = 1; 7399- int warm_up_loop_count_ = 0; 7400- bool time_profiling_; 7401- int epochs_ = 1; 7402- // MarkAccuracy 7403- std::string data_file_; 7404- std::string data_type_ = "FLOAT"; 7405- float accuracy_threshold_; 7406- // Resize 7407- std::string export_file_ = ""; 7408- std::string resize_dims_in_ = ""; 7409- bool layer_checksum_ = false; 7410- std::vector<std::vector<int>> resize_dims_; 7411- std::string loss_name_ = ""; 7412- std::string inference_file_ = ""; 7413- bool unified_api_ = false; 7414- bool dump_tensor_data_ = false; 7415-}; 7416- 7417-class MS_API NetTrain { 7418- public: 7419- explicit NetTrain(NetTrainFlags *flags) : flags_(flags) {} 7420- virtual ~NetTrain() = default; 7421- 7422- int Init(); 7423- int RunNetTrain(); 7424- static float *ReadFileBuf(const std::string file, size_t *size); 7425- static int SetNr(std::function<int(NetTrainFlags *)> param); 7426- static int RunNr(NetTrainFlags *flags) { 7427- if (nr_cb_ != nullptr) { 7428- return nr_cb_(flags); 7429- } 7430- MS_LOG(WARNING) << "unified api was not tested"; 7431- std::cout << "unified api was not tested"; 7432- return RET_OK; 7433- } 7434- // tensorData need to be converter first 7435- template <typename T> 7436- static float CompareData(const float *refOutput, int size, const T *msTensorData) { 7437- size_t errorCount = 0; 7438- float meanError = 0; 7439- std::cout << "Out tensor size is: " << size << std::endl; 7440- std::cout << "Data of model output: "; 7441- for (int j = 0; j < std::min(50, size); j++) { 7442- std::cout << static_cast<float>(msTensorData[j]) << " "; 7443- } 7444- std::cout << std::endl; 7445- std::cout << "Data of Ref output : "; 7446- for (int j = 0; j < std::min(50, size); j++) { 7447- std::cout << refOutput[j] << " "; 7448- } 7449- std::cout << std::endl; 7450- for (int j = 0; j < size; j++) { 7451- if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { 7452- std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; 7453- MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; 7454- return RET_ERROR; 7455- } 7456- 7457- auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); 7458- auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]); 7459- if (absoluteError > tolerance) { 7460- if (fabs(refOutput[j]) == 0) { 7461- if (absoluteError > 1e-5) { 7462- meanError += absoluteError; 7463- errorCount++; 7464- } else { 7465- continue; 7466- } 7467- } else { 7468- // just assume that atol = rtol 7469- meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN); 7470- errorCount++; 7471- } 7472- } 7473- } 7474- std::cout << std::endl; 7475- if (meanError > 0.0f) { 7476- meanError /= errorCount; 7477- } 7478- 7479- if (meanError <= 0.0000001) { 7480- std::cout << "Mean bias of tensor: 0%" << std::endl; 
7481- } else { 7482- std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; 7483- } 7484- return meanError; 7485- } 7486- int InitDumpConfigFromJson(std::string path); 7487- 7488- private: 7489- // call GenerateInputData or ReadInputFile to init inputTensors 7490- int LoadInput(); 7491- void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out); 7492+ protected: 7493 // call GenerateRandomData to fill inputTensors 7494- int GenerateInputData(); 7495+ int GenerateInputData() override; 7496 7497- int GenerateRandomData(mindspore::MSTensor *tensor); 7498- 7499- int ReadInputFile(); 7500+ int ReadInputFile() override; 7501 7502 int LoadStepInput(size_t step); 7503 7504@@ -227,20 +66,19 @@ class MS_API NetTrain { 7505 void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg); 7506 7507 int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs, 7508- bool check_accuracy = true); 7509+ bool check_accuracy = true) override; 7510 7511 int CreateAndRunNetworkForInference(const std::string &filename, const std::shared_ptr<mindspore::Context> &context); 7512 7513 int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 7514 const std::shared_ptr<mindspore::Context> &context, 7515 const std::shared_ptr<TrainCfg> &train_cfg, int epochs); 7516- int InitCallbackParameter(); 7517 7518- int InitDumpTensorDataCallbackParameter(); 7519+ int InitDumpTensorDataCallbackParameter() override; 7520 7521- int InitTimeProfilingCallbackParameter(); 7522+ int InitTimeProfilingCallbackParameter() override; 7523 7524- int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result); 7525+ int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override; 7526 7527 template <typename T> 7528 void PrintInputData(mindspore::MSTensor *input) { 7529@@ -256,39 +94,11 @@ class MS_API NetTrain { 7530 std::cout << std::endl; 7531 } 7532 7533- template <typename T> 7534- std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) { 7535- std::vector<int64_t> dims; 7536- for (auto shape : srcDims) { 7537- dims.push_back(static_cast<int64_t>(shape)); 7538- } 7539- return dims; 7540- } 7541- int MarkPerformance(); 7542- int MarkAccuracy(bool enforce_accuracy = true); 7543- int CompareOutput(); 7544- int SaveModels(); 7545- int CheckExecutionOfSavedModels(); 7546- void TensorNan(const float *data, int size) { 7547- for (int i = 0; i < size; i++) { 7548- if (std::isnan(data[i])) { 7549- std::cout << "nan value of index=" << i << ", " << data[i] << std::endl; 7550- break; 7551- } 7552- } 7553- } 7554-#ifdef ENABLE_FP16 7555- void TensorNan(float16_t *data, int size) { 7556- for (int i = 0; i < size; i++) { 7557- if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) { 7558- std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl; 7559- break; 7560- } 7561- } 7562- } 7563-#endif 7564- NetTrainFlags *flags_{nullptr}; 7565- static std::function<int(NetTrainFlags *)> nr_cb_; 7566+ int MarkPerformance() override; 7567+ int MarkAccuracy(bool enforce_accuracy = true) override; 7568+ int CompareOutput() override; 7569+ int SaveModels() override; 7570+ 7571 // callback parameters 7572 uint64_t op_begin_ = 0; 7573 int op_call_times_total_ = 0; 7574@@ -301,13 +111,6 @@ class MS_API NetTrain { 7575 7576 mindspore::MSKernelCallBack before_call_back_{nullptr}; 
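For illustration only — a minimal sketch (class name hypothetical, not from the patch) of the override surface a concrete runner implements after this split into NetTrainBase and backend subclasses; net_train_c_api.cc further below does the same for the OH_AI C API:

#include "tools/benchmark_train/net_train_base.h"

namespace mindspore::lite {
// Hypothetical backend, shown only to make the NetTrainBase hook set explicit.
class NetTrainDemoBackend : public NetTrainBase {
 public:
  explicit NetTrainDemoBackend(NetTrainFlags *flags) : NetTrainBase(flags) {}
  ~NetTrainDemoBackend() override = default;

 protected:
  int GenerateInputData() override { return RET_OK; }
  int ReadInputFile() override { return RET_OK; }
  int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
                          bool check_accuracy = true) override { return RET_OK; }
  int InitDumpTensorDataCallbackParameter() override { return RET_OK; }
  int InitTimeProfilingCallbackParameter() override { return RET_OK; }
  int PrintResult(const std::vector<std::string> &title,
                  const std::map<std::string, std::pair<int, float>> &result) override { return RET_OK; }
  int MarkPerformance() override { return RET_OK; }
  int MarkAccuracy(bool enforce_accuracy = true) override { return RET_OK; }
  int CompareOutput() override { return RET_OK; }
  int SaveModels() override { return RET_OK; }
};
}  // namespace mindspore::lite

A driver would construct such a backend with parsed NetTrainFlags and call Init() followed by RunNetTrain(), both of which remain non-virtual on NetTrainBase.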
7577 mindspore::MSKernelCallBack after_call_back_{nullptr}; 7578- nlohmann::json dump_cfg_json_; 7579- std::string dump_file_output_dir_; 7580- std::vector<std::shared_ptr<char>> inputs_buf_; 7581- std::vector<size_t> inputs_size_; 7582- size_t batch_num_ = 0; 7583 }; 7584- 7585-int MS_API RunNetTrain(int argc, const char **argv); 7586 } // namespace mindspore::lite 7587 #endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ 7588diff --git a/mindspore/lite/tools/benchmark_train/net_train_base.cc b/mindspore/lite/tools/benchmark_train/net_train_base.cc 7589new file mode 100644 7590index 00000000..8d3c75de 7591--- /dev/null 7592+++ b/mindspore/lite/tools/benchmark_train/net_train_base.cc 7593@@ -0,0 +1,410 @@ 7594+/** 7595+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 7596+ * 7597+ * Licensed under the Apache License, Version 2.0 (the "License"); 7598+ * you may not use this file except in compliance with the License. 7599+ * You may obtain a copy of the License at 7600+ * 7601+ * http://www.apache.org/licenses/LICENSE-2.0 7602+ * 7603+ * Unless required by applicable law or agreed to in writing, software 7604+ * distributed under the License is distributed on an "AS IS" BASIS, 7605+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7606+ * See the License for the specific language governing permissions and 7607+ * limitations under the License. 7608+ */ 7609+ 7610+#include "tools/benchmark_train/net_train_base.h" 7611+#define __STDC_FORMAT_MACROS 7612+#undef __STDC_FORMAT_MACROS 7613+#include <algorithm> 7614+#include <cstring> 7615+#ifdef ENABLE_NEON 7616+#include <arm_neon.h> 7617+#endif 7618+#include "src/common/common.h" 7619+#include "include/api/serialization.h" 7620+ 7621+namespace mindspore { 7622+namespace lite { 7623+const std::unordered_map<int, std::string> kTypeIdMap{ 7624+ {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"}, 7625+ {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"}, 7626+ {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"}, 7627+ {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"}, 7628+ {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}}; 7629+ 7630+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{ 7631+ {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"}, {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"}, 7632+ {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"}, {mindspore::CKHW, "CKHW"}, {mindspore::KHWC, "KHWC"}, 7633+ {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"}, {mindspore::HW4, "HW4"}, {mindspore::NC, "NC"}, 7634+ {mindspore::NC4, "NC4"}, {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}}; 7635+ 7636+std::function<int(NetTrainFlags *)> NetTrainBase::nr_cb_ = nullptr; 7637+ 7638+int NetTrainBase::SetNr(std::function<int(NetTrainFlags *)> param) { 7639+ nr_cb_ = param; 7640+ return 0; 7641+} 7642+ 7643+float *NetTrainBase::ReadFileBuf(const std::string file, size_t *size) { 7644+ if (file.empty()) { 7645+ MS_LOG(ERROR) << "file is nullptr"; 7646+ return nullptr; 7647+ } 7648+ MS_ASSERT(size != nullptr); 7649+ std::string real_path = RealPath(file.c_str()); 7650+ std::ifstream ifs(real_path); 7651+ if (!ifs.good()) { 7652+ MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7653+ return nullptr; 7654+ } 7655+ 7656+ if (!ifs.is_open()) { 7657+ MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7658+ return 
nullptr; 7659+ } 7660+ 7661+ ifs.seekg(0, std::ios::end); 7662+ *size = ifs.tellg(); 7663+ std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1); 7664+ if (buf == nullptr) { 7665+ MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; 7666+ ifs.close(); 7667+ return nullptr; 7668+ } 7669+ 7670+ ifs.seekg(0, std::ios::beg); 7671+ ifs.read(reinterpret_cast<char *>(buf.get()), *size); 7672+ ifs.close(); 7673+ 7674+ return buf.release(); 7675+} 7676+ 7677+int NetTrainBase::GenerateRandomData(mindspore::MSTensor *tensor) { 7678+ auto input_data = tensor->MutableData(); 7679+ if (input_data == nullptr) { 7680+ MS_LOG(ERROR) << "MallocData for inTensor failed"; 7681+ return RET_ERROR; 7682+ } 7683+ auto tensor_byte_size = tensor->DataSize(); 7684+ char *casted_data = static_cast<char *>(input_data); 7685+ for (size_t i = 0; i < tensor_byte_size; i++) { 7686+ casted_data[i] = 7687+ (tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0); 7688+ } 7689+ return RET_OK; 7690+} 7691+ 7692+int NetTrainBase::LoadInput() { 7693+ inputs_buf_.clear(); 7694+ inputs_size_.clear(); 7695+ batch_num_ = 0; 7696+ if (flags_->in_data_file_.empty()) { 7697+ auto status = GenerateInputData(); 7698+ if (status != RET_OK) { 7699+ std::cerr << "Generate input data error " << status << std::endl; 7700+ MS_LOG(ERROR) << "Generate input data error " << status; 7701+ return status; 7702+ } 7703+ } else { 7704+ auto status = ReadInputFile(); 7705+ if (status != RET_OK) { 7706+ std::cerr << "Read Input File error, " << status << std::endl; 7707+ MS_LOG(ERROR) << "Read Input File error, " << status; 7708+ return status; 7709+ } 7710+ } 7711+ return RET_OK; 7712+} 7713+ 7714+int NetTrainBase::RunNetTrain() { 7715+ auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1); 7716+ bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty(); 7717+ auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_); 7718+ if (status != RET_OK) { 7719+ MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status; 7720+ std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". 
Status is " << status 7721+ << std::endl; 7722+ return status; 7723+ } 7724+ 7725+ status = CheckExecutionOfSavedModels(); // re-initialize sessions according to flags 7726+ if (status != RET_OK) { 7727+ MS_LOG(ERROR) << "Run CheckExecute error: " << status; 7728+ std::cout << "Run CheckExecute error: " << status << std::endl; 7729+ return status; 7730+ } 7731+ return RET_OK; 7732+} 7733+ 7734+int NetTrainBase::CheckExecutionOfSavedModels() { 7735+ int status = RET_OK; 7736+ if (!flags_->export_file_.empty()) { 7737+ status = CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0); 7738+ if (status != RET_OK) { 7739+ MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status; 7740+ std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl; 7741+ return status; 7742+ } 7743+ if (flags_->bb_model_file_.empty()) { 7744+ status = CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false); 7745+ if (status != RET_OK) { 7746+ MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status; 7747+ std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl; 7748+ return status; 7749+ } 7750+ } 7751+ } 7752+ if (!flags_->inference_file_.empty()) { 7753+ status = CreateAndRunNetwork(flags_->inference_file_, "", false, 0); 7754+ if (status != RET_OK) { 7755+ MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status; 7756+ std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl; 7757+ return status; 7758+ } 7759+ status = CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false); 7760+ if (status != RET_OK) { 7761+ MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status; 7762+ std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl; 7763+ return status; 7764+ } 7765+ } 7766+ return status; 7767+} 7768+ 7769+void NetTrainBase::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) { 7770+ if (tensor == nullptr) { 7771+ MS_LOG(ERROR) << "input tensor is nullptr."; 7772+ return; 7773+ } 7774+ int tensor_size = tensor->ElementNum(); 7775+ void *data = tensor->MutableData(); 7776+ auto *fdata = reinterpret_cast<float *>(tensor->MutableData()); 7777+ auto type = tensor->DataType(); 7778+ std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum="; 7779+ switch (type) { 7780+ case mindspore::DataType::kNumberTypeFloat32: 7781+ TensorNan(reinterpret_cast<float *>(data), tensor_size); 7782+ std::cout << TensorSum<float>(data, tensor_size) << std::endl; 7783+ std::cout << "tensor name: " << tensor->Name() << std::endl; 7784+ std::cout << "data: "; 7785+ for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) { 7786+ std::cout << static_cast<float>(fdata[i]) << ", "; 7787+ } 7788+ std::cout << std::endl; 7789+ break; 7790+ case mindspore::DataType::kNumberTypeInt32: 7791+ std::cout << TensorSum<int>(data, tensor_size) << std::endl; 7792+ break; 7793+#ifdef ENABLE_FP16 7794+ case mindspore::DataType::kNumberTypeFloat16: 7795+ std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl; 7796+ TensorNan(reinterpret_cast<float16_t *>(data), tensor_size); 7797+ break; 7798+#endif 7799+ default: 7800+ std::cout << "unsupported type:" << static_cast<int>(type) << std::endl; 7801+ break; 7802+ } 
7803+} 7804+ 7805+std::string NetTrainBase::GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 7806+ const std::string &file_type, const size_t &idx) { 7807+ std::string file_name = op_name; 7808+ auto pos = file_name.find_first_of('/'); 7809+ while (pos != std::string::npos) { 7810+ file_name.replace(pos, 1, "."); 7811+ pos = file_name.find_first_of('/'); 7812+ } 7813+ file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_"; 7814+ for (const auto &dim : tensor->Shape()) { 7815+ file_name += std::to_string(dim) + "_"; 7816+ } 7817+ if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) { 7818+ file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType())); 7819+ } 7820+ auto tensor_format = tensor->format(); 7821+ if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) { 7822+ file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin"; 7823+ } 7824+ 7825+ file_name += ".bin"; 7826+ return file_name; 7827+} 7828+ 7829+int NetTrainBase::InitCallbackParameter() { 7830+ int ret = RET_OK; 7831+ if (flags_->dump_tensor_data_) { 7832+ ret = InitDumpTensorDataCallbackParameter(); 7833+ } else if (flags_->time_profiling_) { 7834+ ret = InitTimeProfilingCallbackParameter(); 7835+ } 7836+ return ret; 7837+} 7838+ 7839+void NetTrainFlags::InitResizeDimsList() { 7840+ std::string content = this->resize_dims_in_; 7841+ if (content.empty()) { 7842+ return; 7843+ } 7844+ std::vector<int> shape; 7845+ auto shape_strs = StrSplit(content, std::string(DELIM_COLON)); 7846+ for (const auto &shape_str : shape_strs) { 7847+ shape.clear(); 7848+ auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA)); 7849+ std::cout << "Resize Dims: "; 7850+ for (const auto &dim_str : dim_strs) { 7851+ std::cout << dim_str << " "; 7852+ shape.emplace_back(static_cast<int>(std::stoi(dim_str))); 7853+ } 7854+ std::cout << std::endl; 7855+ this->resize_dims_.emplace_back(shape); 7856+ } 7857+} 7858+ 7859+int NetTrainBase::Init() { 7860+ if (this->flags_ == nullptr) { 7861+ return 1; 7862+ } 7863+ MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_; 7864+ MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_; 7865+ MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_; 7866+ MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_; 7867+ MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_; 7868+ MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_; 7869+ MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; 7870+ MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; 7871+ MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; 7872+ MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; 7873+ MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_; 7874+ 7875+ if (this->flags_->epochs_ < 0) { 7876+ MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; 7877+ std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl; 7878+ return RET_ERROR; 7879+ } 7880+ 7881+ if (this->flags_->num_threads_ < 1) { 7882+ MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0"; 7883+ std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl; 7884+ return RET_ERROR; 7885+ } 7886+ 7887+ this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? 
kImage : kBinary; 7888+ 7889+ if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) { 7890+ MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided"; 7891+ std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl; 7892+ return RET_ERROR; 7893+ } 7894+ 7895+ if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) { 7896+ MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided"; 7897+ std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl; 7898+ return RET_ERROR; 7899+ } 7900+ 7901+ if (flags_->model_file_.empty()) { 7902+ MS_LOG(ERROR) << "modelPath is required"; 7903+ std::cerr << "modelPath is required" << std::endl; 7904+ return 1; 7905+ } 7906+ 7907+ // get dump data output path 7908+ auto dump_cfg_path = std::getenv(dump::kConfigPath); 7909+ if (dump_cfg_path != nullptr) { 7910+ flags_->dump_tensor_data_ = true; 7911+ if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) { 7912+ MS_LOG(ERROR) << "parse dump config file failed."; 7913+ return RET_ERROR; 7914+ } 7915+ } else { 7916+ MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data"; 7917+ } 7918+ 7919+ auto status = InitCallbackParameter(); 7920+ if (status != RET_OK) { 7921+ MS_LOG(ERROR) << "Init callback Parameter failed."; 7922+ std::cerr << "Init callback Parameter failed." << std::endl; 7923+ return RET_ERROR; 7924+ } 7925+ 7926+ flags_->InitResizeDimsList(); 7927+ if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && 7928+ flags_->resize_dims_.size() != flags_->input_data_list_.size()) { 7929+ MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; 7930+ std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; 7931+ return RET_ERROR; 7932+ } 7933+ return RET_OK; 7934+} 7935+ 7936+int NetTrainBase::InitDumpConfigFromJson(std::string path) { 7937+ auto real_path = RealPath(path.c_str()); 7938+ std::ifstream ifs(real_path); 7939+ if (!ifs.good()) { 7940+ MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7941+ return RET_ERROR; 7942+ } 7943+ if (!ifs.is_open()) { 7944+ MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7945+ return RET_ERROR; 7946+ } 7947+ 7948+ try { 7949+ dump_cfg_json_ = nlohmann::json::parse(ifs); 7950+ } catch (const nlohmann::json::parse_error &error) { 7951+ MS_LOG(ERROR) << "parse json file failed, please check your file."; 7952+ return RET_ERROR; 7953+ } 7954+ if (dump_cfg_json_[dump::kSettings] == nullptr) { 7955+ MS_LOG(ERROR) << "\"common_dump_settings\" is required."; 7956+ return RET_ERROR; 7957+ } 7958+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) { 7959+ MS_LOG(ERROR) << "\"dump_mode\" is required."; 7960+ return RET_ERROR; 7961+ } 7962+ if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) { 7963+ MS_LOG(ERROR) << "\"path\" is required."; 7964+ return RET_ERROR; 7965+ } 7966+ if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) { 7967+ dump_cfg_json_[dump::kSettings][dump::kNetName] = "default"; 7968+ } 7969+ if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) { 7970+ dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0; 7971+ } 7972+ if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr && 7973+ !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) { 7974+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) { 7975+ 
MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)"; 7976+ return RET_ERROR; 7977+ } 7978+ } 7979+ 7980+ auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>(); 7981+ auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>(); 7982+ if (abs_path.back() == '\\' || abs_path.back() == '/') { 7983+ dump_file_output_dir_ = abs_path + net_name; 7984+ } else { 7985+#ifdef _WIN32 7986+ dump_file_output_dir_ = abs_path + "\\" + net_name; 7987+#else 7988+ dump_file_output_dir_ = abs_path + "/" + net_name; 7989+#endif 7990+ } 7991+ 7992+ auto status = CreateOutputDir(&dump_file_output_dir_); 7993+ if (status != RET_OK) { 7994+ MS_LOG(ERROR) << "create data output directory failed."; 7995+ return RET_ERROR; 7996+ } 7997+ return RET_OK; 7998+} 7999+ 8000+NetTrainBase:: ~NetTrainBase() { 8001+} 8002+} // namespace lite 8003+} // namespace mindspore 8004diff --git a/mindspore/lite/tools/benchmark_train/net_train_base.h b/mindspore/lite/tools/benchmark_train/net_train_base.h 8005new file mode 100644 8006index 00000000..e3d5f39a 8007--- /dev/null 8008+++ b/mindspore/lite/tools/benchmark_train/net_train_base.h 8009@@ -0,0 +1,288 @@ 8010+/** 8011+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 8012+ * 8013+ * Licensed under the Apache License, Version 2.0 (the "License"); 8014+ * you may not use this file except in compliance with the License. 8015+ * You may obtain a copy of the License at 8016+ * 8017+ * http://www.apache.org/licenses/LICENSE-2.0 8018+ * 8019+ * Unless required by applicable law or agreed to in writing, software 8020+ * distributed under the License is distributed on an "AS IS" BASIS, 8021+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8022+ * See the License for the specific language governing permissions and 8023+ * limitations under the License. 
8024+ */ 8025+ 8026+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8027+#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8028+ 8029+#include <getopt.h> 8030+#include <csignal> 8031+#include <unordered_map> 8032+#include <fstream> 8033+#include <iostream> 8034+#include <map> 8035+#include <cmath> 8036+#include <string> 8037+#include <vector> 8038+#include <memory> 8039+#include <cfloat> 8040+#include <utility> 8041+#include <algorithm> 8042+#include <nlohmann/json.hpp> 8043+#include "include/api/model.h" 8044+#include "include/api/types.h" 8045+#include "include/api/context.h" 8046+#include "include/api/cfg.h" 8047+ 8048+#ifdef ENABLE_FP16 8049+#include <arm_neon.h> 8050+#endif 8051+#include "tools/common/flag_parser.h" 8052+#include "src/common/file_utils.h" 8053+#include "src/common/utils.h" 8054+ 8055+#ifdef ENABLE_FP16 8056+static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) { 8057+ volatile float16_t d = var; 8058+ return d != d; 8059+} 8060+#endif 8061+ 8062+namespace mindspore::lite { 8063+enum MS_API DataType { kImage = 0, kBinary = 1 }; 8064+ 8065+constexpr float relativeTolerance = 1e-5; 8066+constexpr float absoluteTolerance = 1e-8; 8067+extern const std::unordered_map<int, std::string> kTypeIdMap; 8068+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; 8069+ 8070+constexpr const char *DELIM_SLASH = "/"; 8071+constexpr const char *DELIM_COLON = ":"; 8072+constexpr const char *DELIM_COMMA = ","; 8073+ 8074+constexpr int RET_TOO_BIG = -9; 8075+constexpr int kFieldsToPrint = 5; 8076+constexpr int kPrintOffset = 4; 8077+constexpr int kDumpInputsAndOutputs = 0; 8078+constexpr int kDumpOutputs = 2; 8079+constexpr int kTHOUSAND = 1000; 8080+ 8081+namespace dump { 8082+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG"; 8083+constexpr auto kSettings = "common_dump_settings"; 8084+constexpr auto kMode = "dump_mode"; 8085+constexpr auto kPath = "path"; 8086+constexpr auto kNetName = "net_name"; 8087+constexpr auto kInputOutput = "input_output"; 8088+constexpr auto kKernels = "kernels"; 8089+} // namespace dump 8090+ 8091+template <typename T> 8092+float TensorSum(const void *data, int size) { 8093+ const T *typed_data = reinterpret_cast<const T *>(data); 8094+ float sum = 0.f; 8095+ for (int i = 0; i < size; i++) { 8096+ sum += static_cast<float>(typed_data[i]); 8097+ } 8098+ return sum; 8099+} 8100+ 8101+class MS_API NetTrainFlags : public virtual FlagParser { 8102+ public: 8103+ NetTrainFlags() { 8104+ // common 8105+ AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", ""); 8106+ AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", ""); 8107+ AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 8108+ // MarkPerformance 8109+ AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); 8110+ AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); 8111+ AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); 8112+ AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1); 8113+ // MarkAccuracy 8114+ AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); 8115+ AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); 8116+ AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of 
accuracy", 0.5); 8117+ AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); 8118+ AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); 8119+ AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); 8120+ AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", ""); 8121+ AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); 8122+ AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", 8123+ "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); 8124+ AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false); 8125+ } 8126+ 8127+ ~NetTrainFlags() override = default; 8128+ void InitResizeDimsList(); 8129+ 8130+ public: 8131+ // common 8132+ std::string model_file_; 8133+ std::string in_data_file_; 8134+ std::string bb_model_file_; 8135+ std::vector<std::string> input_data_list_; 8136+ DataType in_data_type_; 8137+ std::string in_data_type_in_ = "bin"; 8138+ int cpu_bind_mode_ = 1; 8139+ bool enable_fp16_ = false; 8140+ bool virtual_batch_ = false; 8141+ // MarkPerformance 8142+ int num_threads_ = 1; 8143+ int warm_up_loop_count_ = 0; 8144+ bool time_profiling_; 8145+ int epochs_ = 1; 8146+ // MarkAccuracy 8147+ std::string data_file_; 8148+ std::string data_type_ = "FLOAT"; 8149+ float accuracy_threshold_; 8150+ // Resize 8151+ std::string export_file_ = ""; 8152+ std::string resize_dims_in_ = ""; 8153+ bool layer_checksum_ = false; 8154+ std::vector<std::vector<int>> resize_dims_; 8155+ std::string loss_name_ = ""; 8156+ std::string inference_file_ = ""; 8157+ bool unified_api_ = false; 8158+ bool dump_tensor_data_ = false; 8159+}; 8160+ 8161+class MS_API NetTrainBase { 8162+ public: 8163+ explicit NetTrainBase(NetTrainFlags *flags) : flags_(flags) {} 8164+ virtual ~NetTrainBase(); 8165+ 8166+ int Init(); 8167+ int RunNetTrain(); 8168+ static float *ReadFileBuf(const std::string file, size_t *size); 8169+ static int SetNr(std::function<int(NetTrainFlags *)> param); 8170+ static int RunNr(NetTrainFlags *flags) { 8171+ if (nr_cb_ != nullptr) { 8172+ return nr_cb_(flags); 8173+ } 8174+ MS_LOG(WARNING) << "unified api was not tested"; 8175+ std::cout << "unified api was not tested"; 8176+ return RET_OK; 8177+ } 8178+ // tensorData need to be converter first 8179+ template <typename T> 8180+ static float CompareData(const float *refOutput, int size, const T *msTensorData) { 8181+ size_t errorCount = 0; 8182+ float meanError = 0; 8183+ std::cout << "Out tensor size is: " << size << std::endl; 8184+ std::cout << "Data of model output: "; 8185+ for (int j = 0; j < std::min(50, size); j++) { 8186+ std::cout << static_cast<float>(msTensorData[j]) << " "; 8187+ } 8188+ std::cout << std::endl; 8189+ std::cout << "Data of Ref output : "; 8190+ for (int j = 0; j < std::min(50, size); j++) { 8191+ std::cout << refOutput[j] << " "; 8192+ } 8193+ std::cout << std::endl; 8194+ for (int j = 0; j < size; j++) { 8195+ if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { 8196+ std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; 8197+ MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; 8198+ return RET_ERROR; 8199+ } 8200+ 8201+ auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); 8202+ auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]); 8203+ if (absoluteError > tolerance) { 8204+ if 
(fabs(refOutput[j]) == 0) { 8205+ if (absoluteError > 1e-5) { 8206+ meanError += absoluteError; 8207+ errorCount++; 8208+ } else { 8209+ continue; 8210+ } 8211+ } else { 8212+ // just assume that atol = rtol 8213+ meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN); 8214+ errorCount++; 8215+ } 8216+ } 8217+ } 8218+ std::cout << std::endl; 8219+ if (meanError > 0.0f) { 8220+ meanError /= errorCount; 8221+ } 8222+ 8223+ if (meanError <= 0.0000001) { 8224+ std::cout << "Mean bias of tensor: 0%" << std::endl; 8225+ } else { 8226+ std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; 8227+ } 8228+ return meanError; 8229+ } 8230+ int InitDumpConfigFromJson(std::string path); 8231+ 8232+ protected: 8233+ // call GenerateInputData or ReadInputFile to init inputTensors 8234+ int LoadInput(); 8235+ void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out); 8236+ // call GenerateRandomData to fill inputTensors 8237+ virtual int GenerateInputData() = 0; 8238+ 8239+ int GenerateRandomData(mindspore::MSTensor *tensor); 8240+ 8241+ std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 8242+ const std::string &file_type, const size_t &idx); 8243+ virtual int ReadInputFile() = 0; 8244+ 8245+ virtual int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs, 8246+ bool check_accuracy = true) = 0; 8247+ 8248+ int InitCallbackParameter(); 8249+ 8250+ virtual int InitDumpTensorDataCallbackParameter() = 0; 8251+ 8252+ virtual int InitTimeProfilingCallbackParameter() = 0; 8253+ 8254+ virtual int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) = 0; 8255+ 8256+ template <typename T> 8257+ std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) { 8258+ std::vector<int64_t> dims; 8259+ for (auto shape : srcDims) { 8260+ dims.push_back(static_cast<int64_t>(shape)); 8261+ } 8262+ return dims; 8263+ } 8264+ virtual int MarkPerformance() = 0; 8265+ virtual int MarkAccuracy(bool enforce_accuracy = true) = 0; 8266+ virtual int CompareOutput() = 0; 8267+ virtual int SaveModels() = 0; 8268+ int CheckExecutionOfSavedModels(); 8269+ void TensorNan(const float *data, int size) { 8270+ for (int i = 0; i < size; i++) { 8271+ if (std::isnan(data[i])) { 8272+ std::cout << "nan value of index=" << i << ", " << data[i] << std::endl; 8273+ break; 8274+ } 8275+ } 8276+ } 8277+#ifdef ENABLE_FP16 8278+ void TensorNan(float16_t *data, int size) { 8279+ for (int i = 0; i < size; i++) { 8280+ if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) { 8281+ std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl; 8282+ break; 8283+ } 8284+ } 8285+ } 8286+#endif 8287+ NetTrainFlags *flags_{nullptr}; 8288+ static std::function<int(NetTrainFlags *)> nr_cb_; 8289+ 8290+ nlohmann::json dump_cfg_json_; 8291+ std::string dump_file_output_dir_; 8292+ std::vector<std::shared_ptr<char>> inputs_buf_; 8293+ std::vector<size_t> inputs_size_; 8294+ size_t batch_num_ = 0; 8295+}; 8296+} // namespace mindspore::lite 8297+#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8298diff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.cc b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc 8299new file mode 100644 8300index 00000000..4dcf3af6 8301--- /dev/null 8302+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc 8303@@ -0,0 +1,659 @@ 8304+/** 8305+ * Copyright 
2023-2023 Huawei Technologies Co., Ltd 8306+ * 8307+ * Licensed under the Apache License, Version 2.0 (the "License"); 8308+ * you may not use this file except in compliance with the License. 8309+ * You may obtain a copy of the License at 8310+ * 8311+ * http://www.apache.org/licenses/LICENSE-2.0 8312+ * 8313+ * Unless required by applicable law or agreed to in writing, software 8314+ * distributed under the License is distributed on an "AS IS" BASIS, 8315+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8316+ * See the License for the specific language governing permissions and 8317+ * limitations under the License. 8318+ */ 8319+ 8320+#include "net_train_c_api.h" 8321+#include "securec/include/securec.h" 8322+ 8323+namespace mindspore { 8324+namespace lite { 8325+uint64_t g_op_begin_ = 0; 8326+int g_op_call_times_total_ = 0; 8327+float g_op_cost_total_ = 0.0f; 8328+ 8329+int NetTrainCApi::GenerateInputData() { 8330+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8331+ OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i]; 8332+ auto data_type = OH_AI_TensorGetDataType(tensor); 8333+ if (data_type == OH_AI_DATATYPE_OBJECTTYPE_STRING) { 8334+ MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING"; 8335+ return RET_ERROR; 8336+ } else { 8337+ (void)GenerateRandomData(static_cast<mindspore::MSTensor *>(tensor)); 8338+ } 8339+ } 8340+ return RET_OK; 8341+} 8342+ 8343+int NetTrainCApi::SaveModels() { 8344+ if (!flags_->export_file_.empty()) { 8345+ if (flags_->bb_model_file_.empty()) { 8346+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, false, 8347+ nullptr, 0); 8348+ if (status != OH_AI_STATUS_SUCCESS) { 8349+ MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt"; 8350+ std::cout << "Export quantized model error " << flags_->export_file_ + "_qt" << std::endl; 8351+ return RET_ERROR; 8352+ } 8353+ } 8354+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_).c_str(), OH_AI_NO_QUANT, false, 8355+ nullptr, 0); 8356+ 8357+ if (status != OH_AI_STATUS_SUCCESS) { 8358+ MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_; 8359+ std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl; 8360+ return RET_ERROR; 8361+ } 8362+ } 8363+ if (!flags_->inference_file_.empty()) { 8364+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, true, 8365+ nullptr, 0); 8366+ if (status != OH_AI_STATUS_SUCCESS) { 8367+ MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt"; 8368+ std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl; 8369+ return RET_ERROR; 8370+ } 8371+ 8372+ auto tick = GetTimeUs(); 8373+ status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_).c_str(), OH_AI_NO_QUANT, true, 8374+ nullptr, 0); 8375+ if (status != OH_AI_STATUS_SUCCESS) { 8376+ MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_; 8377+ std::cout << "Export non quantized inference model error " << flags_->inference_file_ << std::endl; 8378+ return RET_ERROR; 8379+ } 8380+ std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n"; 8381+ } 8382+ return RET_OK; 8383+} 8384+ 8385+int NetTrainCApi::LoadStepInput(size_t step) { 8386+ 
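    // Descriptive note, inferred from ReadInputFile below: inputs_buf_[i] holds every batch of input i back to
    // back, so the slice for a given step starts at byte offset step * tensor_data_size and is copied into the
    // tensor's mutable data as a whole.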
if (step >= batch_num_) { 8387+ auto cur_batch = step + 1; 8388+ MS_LOG(ERROR) << "Max input Batch is:" << batch_num_ << " but got batch :" << cur_batch; 8389+ return RET_ERROR; 8390+ } 8391+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8392+ OH_AI_TensorHandle cur_tensor = ms_inputs_for_api_.handle_list[i]; 8393+ MS_ASSERT(cur_tensor != nullptr); 8394+ auto tensor_data_size = OH_AI_TensorGetDataSize(cur_tensor); 8395+ auto input_data = OH_AI_TensorGetMutableData(cur_tensor); 8396+ MS_ASSERT(input_data != nullptr); 8397+ memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size); 8398+ } 8399+ return RET_OK; 8400+} 8401+ 8402+int NetTrainCApi::ReadInputFile() { 8403+ if (this->flags_->in_data_type_ == lite::kImage) { 8404+ MS_LOG(ERROR) << "Unsupported image input"; 8405+ return RET_ERROR; 8406+ } else { 8407+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8408+ OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i]; 8409+ MS_ASSERT(tensor != nullptr); 8410+ size_t size; 8411+ std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; 8412+ auto bin_buf = lite::ReadFile(file_name.c_str(), &size); 8413+ if (bin_buf == nullptr) { 8414+ MS_LOG(ERROR) << "ReadFile failed"; 8415+ return RET_ERROR; 8416+ } 8417+ auto tensor_data_size = OH_AI_TensorGetDataSize(tensor); 8418+ MS_ASSERT(tensor_data_size != 0); 8419+ if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) { 8420+ std::cerr << "Input binary file size error, required :N * " << tensor_data_size << ", in fact: " << size 8421+ << " ,file_name: " << file_name.c_str() << std::endl; 8422+ MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size 8423+ << " ,file_name: " << file_name.c_str(); 8424+ delete bin_buf; 8425+ return RET_ERROR; 8426+ } 8427+ inputs_buf_.emplace_back(bin_buf); 8428+ inputs_size_.emplace_back(size); 8429+ batch_num_ = size / tensor_data_size; 8430+ } 8431+ } 8432+ return RET_OK; 8433+} 8434+ 8435+int NetTrainCApi::InitDumpTensorDataCallbackParameter() { 8436+ MS_LOG(ERROR) << "Unsupported feature."; 8437+ return RET_ERROR; 8438+} 8439+ 8440+int NetTrainCApi::InitTimeProfilingCallbackParameter() { 8441+ before_call_back_ = TimeProfilingBeforeCallback; 8442+ after_call_back_ = TimeProfilingAfterCallback; 8443+ return RET_OK; 8444+} 8445+ 8446+int NetTrainCApi::InitMSContext() { 8447+ context_ = OH_AI_ContextCreate(); 8448+ if (context_ == nullptr) { 8449+ MS_LOG(INFO) << "OH_AI_ContextCreate failed"; 8450+ return RET_ERROR; 8451+ } 8452+ OH_AI_ContextSetThreadNum(context_, flags_->num_threads_); 8453+ OH_AI_ContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_); 8454+ 8455+ OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 8456+ OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); 8457+ OH_AI_ContextAddDeviceInfo(context_, cpu_device_info); 8458+ return RET_OK; 8459+} 8460+ 8461+char **NetTrainCApi::TransStrVectorToCharArrays(const std::vector<std::string> &s) { 8462+ char **char_arr = static_cast<char **>(malloc(s.size() * sizeof(char *))); 8463+ for (size_t i = 0; i < s.size(); i++) { 8464+ char_arr[i] = static_cast<char *>(malloc((s[i].size() + 1))); 8465+ strcpy(char_arr[i], s[i].c_str()); 8466+ } 8467+ return char_arr; 8468+} 8469+ 8470+std::vector<std::string> NetTrainCApi::TransCharArraysToStrVector(char **c, const size_t &num) { 8471+ 
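    // Descriptive note: this helper only copies the strings; the char ** argument (for example the array
    // returned by OH_AI_TrainCfgGetLossName in InitTrainCfg below) remains owned by the caller, which frees
    // each entry afterwards.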
std::vector<std::string> str; 8472+ for (size_t i = 0; i < num; i++) { 8473+ str.push_back(std::string(c[i])); 8474+ } 8475+ return str; 8476+} 8477+ 8478+void NetTrainCApi::InitTrainCfg() { 8479+ if (flags_->loss_name_.empty()) { 8480+ return; 8481+ } 8482+ 8483+ std::string delimiter = ","; 8484+ size_t pos = 0; 8485+ std::string token; 8486+ train_cfg_ = OH_AI_TrainCfgCreate(); 8487+ size_t num = 0; 8488+ std::vector<std::string> train_cfg_loss_name; 8489+ OH_AI_TrainCfgSetLossName(train_cfg_, nullptr, train_cfg_loss_name.size()); 8490+ while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) { 8491+ token = flags_->loss_name_.substr(0, pos); 8492+ flags_->loss_name_.erase(0, pos + delimiter.length()); // change to delim without deletion 8493+ char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num); 8494+ train_cfg_loss_name = TransCharArraysToStrVector(name, num); 8495+ train_cfg_loss_name.push_back(token); 8496+ char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name); 8497+ OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size()); 8498+ for (size_t i = 0; i < train_cfg_loss_name.size(); i++) { 8499+ free(loss_name[i]); 8500+ } 8501+ free(loss_name); 8502+ for (size_t i = 0; i < num; i++) { 8503+ free(name[i]); 8504+ } 8505+ free(name); 8506+ } 8507+ if (!(flags_->loss_name_.empty())) { 8508+ char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num); 8509+ train_cfg_loss_name = TransCharArraysToStrVector(name, num); 8510+ train_cfg_loss_name.push_back(flags_->loss_name_); 8511+ char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name); 8512+ OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size()); 8513+ for (size_t i = 0; i < train_cfg_loss_name.size(); i++) { 8514+ free(loss_name[i]); 8515+ } 8516+ free(loss_name); 8517+ for (size_t i = 0; i < num; i++) { 8518+ free(name[i]); 8519+ } 8520+ free(name); 8521+ } 8522+} 8523+ 8524+int NetTrainCApi::CreateAndRunNetworkForInference(const std::string &filename, 8525+ const OH_AI_ContextHandle &context) { 8526+ std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); 8527+ std::string filenamems = filename; 8528+ if (filenamems.substr(filenamems.find_last_of('.') + 1) != "ms") { 8529+ filenamems = filenamems + ".ms"; 8530+ } 8531+ MS_LOG(INFO) << "start reading model file " << filenamems.c_str(); 8532+ std::cout << "start reading model file " << filenamems.c_str() << std::endl; 8533+ auto status = OH_AI_ModelBuildFromFile(ms_model_, filenamems.c_str(), 8534+ static_cast<OH_AI_ModelType>(mindspore::kMindIR), context); 8535+ if (status != OH_AI_STATUS_SUCCESS) { 8536+ MS_LOG(ERROR) << "ms model build failed. " << model_name; 8537+ return RET_ERROR; 8538+ } 8539+ return RET_OK; 8540+} 8541+ 8542+int NetTrainCApi::CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 8543+ const OH_AI_ContextHandle &context, 8544+ const OH_AI_TrainCfgHandle &train_cfg, int epochs) { 8545+ std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); 8546+ OH_AI_Status status; 8547+ if (!bb_filename.empty()) { 8548+ MS_LOG(ERROR) << "build transfer learning not supported. 
" << model_name; 8549+ return RET_ERROR; 8550+ } else { 8551+ MS_LOG(INFO) << "Build mindspore model from model file" << filename.c_str(); 8552+ std::cout << "Build mindspore model from model file" << filename.c_str() << std::endl; 8553+ status = OH_AI_TrainModelBuildFromFile(ms_model_, filename.c_str(), OH_AI_MODELTYPE_MINDIR, context, train_cfg); 8554+ if (status != OH_AI_STATUS_SUCCESS) { 8555+ MS_LOG(ERROR) << "build transfer learning failed. " << model_name; 8556+ return RET_ERROR; 8557+ } 8558+ } 8559+ if (epochs > 0) { 8560+ if (flags_->virtual_batch_) { 8561+ OH_AI_ModelSetupVirtualBatch(ms_model_, epochs, -1.0f, -1.0f); 8562+ } 8563+ status = OH_AI_ModelSetTrainMode(ms_model_, true); 8564+ if (status != OH_AI_STATUS_SUCCESS) { 8565+ MS_LOG(ERROR) << "set train mode failed. "; 8566+ return RET_ERROR; 8567+ } 8568+ } 8569+ return RET_OK; 8570+} 8571+ 8572+int NetTrainCApi::CompareOutput() { 8573+ std::cout << "================ Comparing Forward Output data ================" << std::endl; 8574+ float total_bias = 0; 8575+ int total_size = 0; 8576+ bool has_error = false; 8577+ auto output_tensors_handle = OH_AI_ModelGetOutputs(ms_model_); 8578+ 8579+ std::vector<mindspore::MSTensor> output_tensors; 8580+ for (size_t i = 0; i < output_tensors_handle.handle_num; i++) { 8581+ output_tensors.push_back(*static_cast<mindspore::MSTensor *>(output_tensors_handle.handle_list[i])); 8582+ } 8583+ if (output_tensors.empty()) { 8584+ MS_LOG(ERROR) << "Cannot find output tensors, get model output failed"; 8585+ return RET_ERROR; 8586+ } 8587+ std::map<std::string, MSTensor> ordered_outputs; 8588+ for (const auto &output_tensor : output_tensors) { 8589+ ordered_outputs.insert({output_tensor.Name(), output_tensor}); 8590+ } 8591+ int i = 1; 8592+ mindspore::MSTensor tensor; 8593+ for (auto &ordered_output : ordered_outputs) { 8594+ tensor = ordered_output.second; 8595+ std::cout << "output is tensor " << ordered_output.first << "\n"; 8596+ auto outputs = tensor.MutableData(); 8597+ size_t size; 8598+ std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; 8599+ auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size)); 8600+ if (bin_buf == nullptr) { 8601+ MS_LOG(ERROR) << "ReadFile return nullptr"; 8602+ std::cout << "ReadFile return nullptr" << std::endl; 8603+ return RET_ERROR; 8604+ } 8605+ if (size != tensor.DataSize()) { 8606+ MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() 8607+ << ", read size: " << size; 8608+ std::cout << "Output buffer and output file differ by size. 
Tensor size: " << tensor.DataSize() 8609+ << ", read size: " << size << std::endl; 8610+ return RET_ERROR; 8611+ } 8612+ float bias = CompareData<float>(bin_buf.get(), tensor.ElementNum(), reinterpret_cast<float *>(outputs)); 8613+ if (bias >= 0) { 8614+ total_bias += bias; 8615+ total_size++; 8616+ } else { 8617+ has_error = true; 8618+ break; 8619+ } 8620+ i++; 8621+ } 8622+ 8623+ if (!has_error) { 8624+ float mean_bias; 8625+ if (total_size != 0) { 8626+ mean_bias = total_bias / total_size * 100; 8627+ } else { 8628+ mean_bias = 0; 8629+ } 8630+ 8631+ std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" 8632+ << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl; 8633+ std::cout << "=======================================================" << std::endl << std::endl; 8634+ 8635+ if (mean_bias > this->flags_->accuracy_threshold_) { 8636+ MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%"; 8637+ std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; 8638+ return RET_TOO_BIG; 8639+ } else { 8640+ return RET_OK; 8641+ } 8642+ } else { 8643+ MS_LOG(ERROR) << "Error in CompareData"; 8644+ std::cerr << "Error in CompareData" << std::endl; 8645+ std::cout << "=======================================================" << std::endl << std::endl; 8646+ return RET_ERROR; 8647+ } 8648+} 8649+ 8650+int NetTrainCApi::MarkPerformance() { 8651+ MS_LOG(INFO) << "Running train loops..."; 8652+ std::cout << "Running train loops..." << std::endl; 8653+ uint64_t time_min = 0xFFFFFFFFFFFFFFFF; 8654+ uint64_t time_max = 0; 8655+ uint64_t time_avg = 0; 8656+ std::vector<MSTensor> outputs; 8657+ 8658+ for (int i = 0; i < flags_->epochs_; i++) { 8659+ auto start = GetTimeUs(); 8660+ for (size_t step = 0; step < batch_num_; step++) { 8661+ MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step; 8662+ auto ret = LoadStepInput(step); 8663+ if (ret != RET_OK) { 8664+ return ret; 8665+ } 8666+ auto status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_); 8667+ if (status != OH_AI_STATUS_SUCCESS) { 8668+ MS_LOG(ERROR) << "Inference error " << status; 8669+ std::cerr << "Inference error " << status; 8670+ return RET_ERROR; 8671+ } 8672+ } 8673+ 8674+ auto end = GetTimeUs(); 8675+ auto time = end - start; 8676+ time_min = std::min(time_min, time); 8677+ time_max = std::max(time_max, time); 8678+ time_avg += time; 8679+ } 8680+ 8681+ if (flags_->time_profiling_) { 8682+ const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; 8683+ const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; 8684+ PrintResult(per_op_name, g_c_op_times_by_name_); 8685+ PrintResult(per_op_type, g_c_op_times_by_type_); 8686+ } 8687+ 8688+ if (flags_->epochs_ > 0) { 8689+ time_avg /= static_cast<size_t>(flags_->epochs_); 8690+ MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 8691+ << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f 8692+ << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f; 8693+ printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n", 8694+ flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_, 8695+ time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f); 8696+ } 8697+ return 
RET_OK; 8698+} 8699+ 8700+int NetTrainCApi::MarkAccuracy(bool enforce_accuracy) { 8701+ MS_LOG(INFO) << "MarkAccuracy"; 8702+ auto load_ret = LoadStepInput(0); 8703+ if (load_ret != RET_OK) { 8704+ return load_ret; 8705+ } 8706+ auto status = PrintInputData(); 8707+ if (status != RET_OK) { 8708+ MS_LOG(ERROR) << "PrintInputData failed, ret: " << status; 8709+ return status; 8710+ } 8711+ status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_); 8712+ if (status != OH_AI_STATUS_SUCCESS) { 8713+ MS_LOG(ERROR) << "Inference error " << status; 8714+ std::cerr << "Inference error " << status << std::endl; 8715+ return RET_ERROR; 8716+ } 8717+ 8718+ auto ret = CompareOutput(); 8719+ if (ret == RET_TOO_BIG && !enforce_accuracy) { 8720+ MS_LOG(INFO) << "Accuracy Error is big but not enforced"; 8721+ std::cout << "Accuracy Error is big but not enforced" << std::endl; 8722+ return RET_OK; 8723+ } 8724+ 8725+ if (ret != RET_OK) { 8726+ MS_LOG(ERROR) << "Compare output error " << ret; 8727+ std::cerr << "Compare output error " << ret << std::endl; 8728+ return ret; 8729+ } 8730+ return RET_OK; 8731+} 8732+ 8733+int NetTrainCApi::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, 8734+ int epochs, bool check_accuracy) { 8735+ auto start_prepare_time = GetTimeUs(); 8736+ 8737+ int ret = InitMSContext(); 8738+ if (ret != RET_OK) { 8739+ MS_LOG(ERROR) << "InitContext failed, ret: " << ret; 8740+ return ret; 8741+ } 8742+ 8743+ InitTrainCfg(); 8744+ ms_model_ = OH_AI_ModelCreate(); 8745+ 8746+ if (is_train) { 8747+ ret = CreateAndRunNetworkForTrain(filename, bb_filename, context_ , train_cfg_, epochs); 8748+ if (ret != RET_OK) { 8749+ MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed."; 8750+ return RET_ERROR; 8751+ } 8752+ } else { 8753+ ret = CreateAndRunNetworkForInference(filename, context_); 8754+ if (ret != RET_OK) { 8755+ MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed."; 8756+ return RET_ERROR; 8757+ } 8758+ } 8759+ 8760+ ms_inputs_for_api_ = OH_AI_ModelGetInputs(ms_model_); 8761+ if (ms_inputs_for_api_.handle_list == nullptr) { 8762+ MS_LOG(ERROR) << "OH_AI_ModelGetInputs failed, ret: "; 8763+ return RET_ERROR; 8764+ } 8765+ 8766+ if (!flags_->resize_dims_.empty()) { 8767+ std::vector<OH_AI_ShapeInfo> shape_infos; 8768+ std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(shape_infos), 8769+ [&](auto &shapes) { 8770+ OH_AI_ShapeInfo shape_info; 8771+ shape_info.shape_num = shapes.size(); 8772+ for (size_t i = 0; i < shape_info.shape_num; i++) { 8773+ shape_info.shape[i] = shapes[i]; 8774+ } 8775+ return shape_info; 8776+ }); 8777+ auto status = OH_AI_ModelResize(ms_model_, ms_inputs_for_api_, shape_infos.data(), shape_infos.size()); 8778+ if (status != OH_AI_STATUS_SUCCESS) { 8779+ MS_LOG(ERROR) << "Input tensor resize failed."; 8780+ std::cout << "Input tensor resize failed."; 8781+ return RET_ERROR; 8782+ } 8783+ } 8784+ 8785+ auto end_prepare_time = GetTimeUs(); 8786+ MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms"; 8787+ std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl; 8788+ // Load input 8789+ MS_LOG(INFO) << "Load input data"; 8790+ auto status = LoadInput(); 8791+ if (status != RET_OK) { 8792+ MS_LOG(ERROR) << "Load input data error"; 8793+ std::cout << "Load input data error" << std::endl; 8794+ return status; 8795+ } 8796+ 8797+ if ((epochs > 0) && is_train) { 8798+ 
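      // Training path, descriptive note: the timed train loop in MarkPerformance runs first, SaveModels then
      // exports the trained and inference graphs when the export/inference flags are set, and accuracy is
      // checked afterwards in eval mode only when a data_file_ is supplied.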
status = MarkPerformance(); 8799+ if (status != RET_OK) { 8800+ MS_LOG(ERROR) << "Run MarkPerformance error: " << status; 8801+ std::cout << "Run MarkPerformance error: " << status << std::endl; 8802+ return status; 8803+ } 8804+ SaveModels(); // save file if flags are on 8805+ } 8806+ if (!flags_->data_file_.empty()) { 8807+ auto res = OH_AI_ModelSetTrainMode(ms_model_, false); 8808+ if (res != OH_AI_STATUS_SUCCESS) { 8809+ MS_LOG(ERROR) << "set eval mode failed. "; 8810+ return RET_ERROR; 8811+ } 8812+ 8813+ status = MarkAccuracy(check_accuracy); 8814+ if (status != RET_OK) { 8815+ MS_LOG(ERROR) << "Run MarkAccuracy error: " << status; 8816+ std::cout << "Run MarkAccuracy error: " << status << std::endl; 8817+ return status; 8818+ } 8819+ } 8820+ return RET_OK; 8821+} 8822+ 8823+int NetTrainCApi::PrintInputData() { 8824+ constexpr int64_t kPrintDataNum = 20; 8825+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8826+ auto input = ms_inputs_for_api_.handle_list[i]; 8827+ std::cout << "InData" << i << ": "; 8828+ auto data_type = static_cast<TypeId>(OH_AI_TensorGetDataType(input)); 8829+ if (data_type == TypeId::kObjectTypeString) { 8830+ MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING."; 8831+ return RET_ERROR; 8832+ } 8833+ auto tensor_data = OH_AI_TensorGetData(input); 8834+ size_t print_num = std::min(OH_AI_TensorGetElementNum(input), kPrintDataNum); 8835+ for (size_t j = 0; j < print_num; j++) { 8836+ if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat) { 8837+ std::cout << static_cast<const float *>(tensor_data)[j] << " "; 8838+ } else if (data_type == TypeId::kNumberTypeInt8) { 8839+ std::cout << static_cast<const int8_t *>(tensor_data)[j] << " "; 8840+ } else if (data_type == TypeId::kNumberTypeUInt8) { 8841+ std::cout << static_cast<const uint8_t *>(tensor_data)[j] << " "; 8842+ } else if (data_type == TypeId::kNumberTypeInt32) { 8843+ std::cout << static_cast<const int32_t *>(tensor_data)[j] << " "; 8844+ } else if (data_type == TypeId::kNumberTypeInt64) { 8845+ std::cout << static_cast<const int64_t *>(tensor_data)[j] << " "; 8846+ } else if (data_type == TypeId::kNumberTypeBool) { 8847+ std::cout << static_cast<const bool *>(tensor_data)[j] << " "; 8848+ } else { 8849+ MS_LOG(ERROR) << "Datatype: " << data_type << " is not supported."; 8850+ return RET_ERROR; 8851+ } 8852+ } 8853+ std::cout << std::endl; 8854+ } 8855+ return RET_OK; 8856+} 8857+ 8858+int NetTrainCApi::PrintResult(const std::vector<std::string> &title, 8859+ const std::map<std::string, std::pair<int, float>> &result) { 8860+ std::vector<size_t> columnLenMax(kFieldsToPrint); 8861+ std::vector<std::vector<std::string>> rows; 8862+ 8863+ for (auto &iter : result) { 8864+ std::string stringBuf[kFieldsToPrint]; 8865+ std::vector<std::string> columns; 8866+ size_t len = 0; 8867+ int index = 0; 8868+ len = iter.first.size(); 8869+ if (len > columnLenMax.at(index)) { 8870+ columnLenMax.at(index) = len + kPrintOffset; 8871+ } 8872+ columns.push_back(iter.first); 8873+ 8874+ index++; 8875+ if (title[0] == "opName") { 8876+ stringBuf[index] = std::to_string(iter.second.second / flags_->epochs_); 8877+ } else { 8878+ stringBuf[index] = std::to_string(iter.second.second / iter.second.first); 8879+ } 8880+ len = stringBuf[index].length(); 8881+ if (len > columnLenMax.at(index)) { 8882+ columnLenMax.at(index) = len + kPrintOffset; 8883+ } 8884+ columns.emplace_back(stringBuf[index]); 8885+ 8886+ index++; 8887+ stringBuf[index] = std::to_string(iter.second.second / 
g_op_cost_total_); 8888+ len = stringBuf[index].length(); 8889+ if (len > columnLenMax.at(index)) { 8890+ columnLenMax.at(index) = len + kPrintOffset; 8891+ } 8892+ columns.emplace_back(stringBuf[index]); 8893+ 8894+ index++; 8895+ stringBuf[index] = std::to_string(iter.second.first); 8896+ len = stringBuf[index].length(); 8897+ if (len > columnLenMax.at(index)) { 8898+ columnLenMax.at(index) = len + kPrintOffset; 8899+ } 8900+ columns.emplace_back(stringBuf[index]); 8901+ 8902+ index++; 8903+ stringBuf[index] = std::to_string(iter.second.second); 8904+ len = stringBuf[index].length(); 8905+ if (len > columnLenMax.at(index)) { 8906+ columnLenMax.at(index) = len + kPrintOffset; 8907+ } 8908+ columns.emplace_back(stringBuf[index]); 8909+ 8910+ rows.push_back(columns); 8911+ } 8912+ 8913+ printf("-------------------------------------------------------------------------\n"); 8914+ for (int i = 0; i < kFieldsToPrint; i++) { 8915+ auto printBuf = title[i]; 8916+ if (printBuf.size() > columnLenMax.at(i)) { 8917+ columnLenMax.at(i) = printBuf.size(); 8918+ } 8919+ printBuf.resize(columnLenMax.at(i), ' '); 8920+ printf("%s\t", printBuf.c_str()); 8921+ } 8922+ printf("\n"); 8923+ for (auto &row : rows) { 8924+ for (int j = 0; j < kFieldsToPrint; j++) { 8925+ auto printBuf = row[j]; 8926+ printBuf.resize(columnLenMax.at(j), ' '); 8927+ printf("%s\t", printBuf.c_str()); 8928+ } 8929+ printf("\n"); 8930+ } 8931+ return RET_OK; 8932+} 8933+ 8934+bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 8935+ const OH_AI_CallBackParam kernel_Info) { 8936+ if (g_c_op_times_by_type_.find(kernel_Info.node_type) == g_c_op_times_by_type_.end()) { 8937+ g_c_op_times_by_type_.insert(std::make_pair(kernel_Info.node_type, std::make_pair(0, 0.0f))); 8938+ } 8939+ if (g_c_op_times_by_name_.find(kernel_Info.node_name) == g_c_op_times_by_name_.end()) { 8940+ g_c_op_times_by_name_.insert(std::make_pair(kernel_Info.node_name, std::make_pair(0, 0.0f))); 8941+ } 8942+ 8943+ g_op_call_times_total_++; 8944+ g_op_begin_ = mindspore::lite::GetTimeUs(); 8945+ return true; 8946+} 8947+ 8948+bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 8949+ const OH_AI_CallBackParam kernel_Info) { 8950+ uint64_t opEnd = mindspore::lite::GetTimeUs(); 8951+ float cost = static_cast<float>(opEnd - g_op_begin_) / 1000.0f; 8952+ g_op_cost_total_ += cost; 8953+ g_c_op_times_by_type_[kernel_Info.node_type].first++; 8954+ g_c_op_times_by_type_[kernel_Info.node_type].second += cost; 8955+ g_c_op_times_by_name_[kernel_Info.node_name].first++; 8956+ g_c_op_times_by_name_[kernel_Info.node_name].second += cost; 8957+ return true; 8958+} 8959+} // namespace lite 8960+} // namespace mindspore 8961+ 8962+ 8963diff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.h b/mindspore/lite/tools/benchmark_train/net_train_c_api.h 8964new file mode 100644 8965index 00000000..bb84d3c1 8966--- /dev/null 8967+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.h 8968@@ -0,0 +1,121 @@ 8969+/** 8970+ * Copyright 2023-2023 Huawei Technologies Co., Ltd 8971+ * 8972+ * Licensed under the Apache License, Version 2.0 (the "License"); 8973+ * you may not use this file except in compliance with the License. 
8974+ * You may obtain a copy of the License at 8975+ * 8976+ * http://www.apache.org/licenses/LICENSE-2.0 8977+ * 8978+ * Unless required by applicable law or agreed to in writing, software 8979+ * distributed under the License is distributed on an "AS IS" BASIS, 8980+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8981+ * See the License for the specific language governing permissions and 8982+ * limitations under the License. 8983+ */ 8984+ 8985+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H 8986+#define MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H 8987+ 8988+#include <getopt.h> 8989+#include <csignal> 8990+#include <unordered_map> 8991+#include <fstream> 8992+#include <iostream> 8993+#include <map> 8994+#include <cmath> 8995+#include <string> 8996+#include <vector> 8997+#include <memory> 8998+#include <cfloat> 8999+#include <utility> 9000+#include <algorithm> 9001+#include <nlohmann/json.hpp> 9002+#include "include/api/model.h" 9003+#include "include/api/types.h" 9004+#include "include/api/context.h" 9005+#include "include/api/cfg.h" 9006+ 9007+#include "include/c_api/model_c.h" 9008+#include "include/c_api/context_c.h" 9009+ 9010+#ifdef ENABLE_FP16 9011+#include <arm_neon.h> 9012+#endif 9013+#include "tools/common/flag_parser.h" 9014+#include "src/common/file_utils.h" 9015+#include "src/common/utils.h" 9016+#include "tools/benchmark_train/net_train_base.h" 9017+ 9018+namespace mindspore::lite { 9019+ namespace { 9020+ std::map<std::string, std::pair<int, float>> g_c_op_times_by_type_; 9021+ std::map<std::string, std::pair<int, float>> g_c_op_times_by_name_; 9022+ } 9023+#ifdef __cplusplus 9024+ extern "C" { 9025+#endif 9026+ bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 9027+ const OH_AI_CallBackParam kernel_Info); 9028+ bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 9029+ const OH_AI_CallBackParam kernel_Info); 9030+#ifdef __cplusplus 9031+ } 9032+#endif 9033+ 9034+class MS_API NetTrainCApi : public NetTrainBase { 9035+ public: 9036+ explicit NetTrainCApi(NetTrainFlags *flags) : NetTrainBase(flags) {} 9037+ virtual ~NetTrainCApi() {}; 9038+ 9039+ protected: 9040+ // call GenerateRandomData to fill inputTensors 9041+ int GenerateInputData() override; 9042+ 9043+ int ReadInputFile() override; 9044+ 9045+ int LoadStepInput(size_t step); 9046+ 9047+ int InitMSContext(); 9048+ 9049+ void InitTrainCfg(); 9050+ 9051+ char **TransStrVectorToCharArrays(const std::vector<std::string> &s); 9052+ 9053+ std::vector<std::string> TransCharArraysToStrVector(char **c, const size_t &num); 9054+ 9055+ int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs, 9056+ bool check_accuracy = true) override; 9057+ 9058+ int CreateAndRunNetworkForInference(const std::string &filename, const OH_AI_ContextHandle &context); 9059+ 9060+ int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 9061+ const OH_AI_ContextHandle &context, 9062+ const OH_AI_TrainCfgHandle &train_cfg, int epochs); 9063+ 9064+ int InitDumpTensorDataCallbackParameter() override; 9065+ 9066+ int InitTimeProfilingCallbackParameter() override; 9067+ 9068+ int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override; 9069+ 9070+ int PrintInputData(); 9071+ 9072+ int MarkPerformance() override; 9073+ 9074+ int MarkAccuracy(bool 
enforce_accuracy = true) override; 9075+ 9076+ int CompareOutput() override; 9077+ 9078+ int SaveModels() override; 9079+ 9080+ OH_AI_ModelHandle ms_model_; 9081+ OH_AI_TensorHandleArray ms_inputs_for_api_; 9082+ OH_AI_ContextHandle context_ = nullptr; 9083+ OH_AI_TrainCfgHandle train_cfg_ = nullptr; 9084+ OH_AI_KernelCallBack before_call_back_{nullptr}; 9085+ OH_AI_KernelCallBack after_call_back_{nullptr}; 9086+}; 9087+} // namespace mindspore::lite 9088+ 9089+#endif //MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H 9090diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.cc b/mindspore/lite/tools/benchmark_train/run_net_train.cc 9091new file mode 100644 9092index 00000000..37a7e602 9093--- /dev/null 9094+++ b/mindspore/lite/tools/benchmark_train/run_net_train.cc 9095@@ -0,0 +1,86 @@ 9096+/** 9097+ * Copyright 2020 Huawei Technologies Co., Ltd 9098+ * 9099+ * Licensed under the Apache License, Version 2.0 (the "License"); 9100+ * you may not use this file except in compliance with the License. 9101+ * You may obtain a copy of the License at 9102+ * 9103+ * http://www.apache.org/licenses/LICENSE-2.0 9104+ * 9105+ * Unless required by applicable law or agreed to in writing, software 9106+ * distributed under the License is distributed on an "AS IS" BASIS, 9107+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9108+ * See the License for the specific language governing permissions and 9109+ * limitations under the License. 9110+ */ 9111+ 9112+#include "tools/benchmark_train/run_net_train.h" 9113+#include "tools/benchmark_train/net_train.h" 9114+#include "tools/benchmark_train/net_train_c_api.h" 9115+ 9116+namespace mindspore { 9117+namespace lite { 9118+int RunNetTrain(int argc, const char **argv) { 9119+ NetTrainFlags flags; 9120+ Option<std::string> err = flags.ParseFlags(argc, argv); 9121+ 9122+ if (err.IsSome()) { 9123+ std::cerr << err.Get() << std::endl; 9124+ std::cerr << flags.Usage() << std::endl; 9125+ return RET_ERROR; 9126+ } 9127+ 9128+ if (flags.help) { 9129+ std::cerr << flags.Usage() << std::endl; 9130+ return RET_OK; 9131+ } 9132+ if (flags.unified_api_) { 9133+ return NetTrain::RunNr(&flags); 9134+ } 9135+ 9136+ auto api_type = std::getenv("MSLITE_API_TYPE"); 9137+ if (api_type != nullptr) { 9138+ MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type; 9139+ std::cout << "MSLITE_API_TYPE = " << api_type << std::endl; 9140+ } 9141+ 9142+ NetTrainBase *net_trainer = nullptr; 9143+ if (api_type == nullptr || std::string(api_type) == "NEW") { 9144+ net_trainer = new (std::nothrow) NetTrain(&flags); 9145+ } else if (std::string(api_type) == "C") { 9146+ net_trainer = new (std::nothrow) NetTrainCApi(&flags); 9147+ } else { 9148+ MS_LOG(ERROR) << "Invalid MSLITE_API_TYPE, (NEW/C, default:NEW)"; 9149+ return RET_ERROR; 9150+ } 9151+ 9152+ if (net_trainer == nullptr) { 9153+ MS_LOG(ERROR) << "new net_trainer failed."; 9154+ return RET_ERROR; 9155+ } 9156+ auto status = net_trainer->Init(); 9157+ if (status != RET_OK) { 9158+ MS_LOG(ERROR) << "NetTrain init Error : " << status; 9159+ std::cerr << "NetTrain init Error : " << status << std::endl; 9160+ return RET_ERROR; 9161+ } 9162+ 9163+ status = net_trainer->RunNetTrain(); 9164+ if (status != RET_OK) { 9165+ MS_LOG(ERROR) << "Run NetTrain " 9166+ << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str() 9167+ << " Failed : " << status; 9168+ std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str() 9169+ << " Failed : " << status << 
std::endl; 9170+ return RET_ERROR; 9171+ } 9172+ 9173+ MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str() 9174+ << " Success."; 9175+ std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str() 9176+ << " Success." << std::endl; 9177+ delete net_trainer; 9178+ return RET_OK; 9179+} 9180+} // namespace lite 9181+} // namespace mindspore 9182\ No newline at end of file 9183diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.h b/mindspore/lite/tools/benchmark_train/run_net_train.h 9184new file mode 100644 9185index 00000000..9ca2d73c 9186--- /dev/null 9187+++ b/mindspore/lite/tools/benchmark_train/run_net_train.h 9188@@ -0,0 +1,22 @@ 9189+/** 9190+ * Copyright 2023-2023 Huawei Technologies Co., Ltd 9191+ * 9192+ * Licensed under the Apache License, Version 2.0 (the "License"); 9193+ * you may not use this file except in compliance with the License. 9194+ * You may obtain a copy of the License at 9195+ * 9196+ * http://www.apache.org/licenses/LICENSE-2.0 9197+ * 9198+ * Unless required by applicable law or agreed to in writing, software 9199+ * distributed under the License is distributed on an "AS IS" BASIS, 9200+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9201+ * See the License for the specific language governing permissions and 9202+ * limitations under the License. 9203+ */ 9204+ 9205+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9206+#define MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9207+namespace mindspore::lite { 9208+int RunNetTrain(int argc, const char **argv); 9209+} // namespace mindspore::lite 9210+#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9211diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt 9212index 1e09d2ed..f854620f 100644 9213--- a/mindspore/lite/tools/converter/CMakeLists.txt 9214+++ b/mindspore/lite/tools/converter/CMakeLists.txt 9215@@ -7,6 +7,8 @@ endif() 9216 9217 set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) 9218 9219+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/) 9220+ 9221 if(ENABLE_GPU) 9222 add_compile_definitions(ENABLE_GPU) 9223 endif() 9224@@ -70,6 +72,7 @@ add_subdirectory(parser/caffe) 9225 add_subdirectory(parser/tflite) 9226 add_subdirectory(parser/onnx) 9227 add_subdirectory(parser/tf) 9228+add_subdirectory(parser/third_party) 9229 if(ENABLE_CONVERT_PYTORCH_MODEL) 9230 add_subdirectory(parser/pytorch) 9231 endif() 9232@@ -363,6 +366,7 @@ target_link_libraries(mindspore_converter 9233 tf_parser_mid 9234 caffe_parser_mid 9235 onnx_parser_mid 9236+ third_party_parser_mid 9237 lite_exporter_mid 9238 graph_pass_mid 9239 fusion_mid 9240diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9241index fecc56d9..2e7ca749 100644 9242--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9243+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc 9244@@ -34,6 +34,7 @@ constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param"; 9245 constexpr auto kDataPreprocessParam = "data_preprocess_param"; 9246 constexpr auto kRegistry = "registry"; 9247 constexpr auto kMicroParam = "micro_param"; 9248+constexpr auto kThirdPartyModelParam = "third_party_model"; 9249 constexpr auto kCpuOptionParam = "cpu_option_cfg_param"; 9250 constexpr auto kCustomOppPath = "custom_opp_path"; 9251 constexpr 
auto kTransformQuantParam = "transform_quant_param"; 9252@@ -330,6 +331,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin 9253 MS_LOG(ERROR) << "ParseMicroParamString failed."; 9254 return ret; 9255 } 9256+ ret = ParseThirdPartyParamString(*maps); 9257+ (void)maps->erase(kThirdPartyModelParam); 9258+ if (ret != RET_OK) { 9259+ MS_LOG(ERROR) << "ParseThirdPartyParamString failed."; 9260+ return ret; 9261+ } 9262 ret = ParseWeightQuantString(*maps); 9263 (void)maps->erase(kWeightQuantParam); 9264 if (ret != RET_OK) { 9265@@ -594,5 +601,25 @@ int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::ma 9266 } 9267 return RET_OK; 9268 } 9269+ 9270+int ConfigFileParser::ParseThirdPartyParamString( 9271+ const std::map<std::string, std::map<std::string, std::string>> &sections) { 9272+ if (sections.find(kThirdPartyModelParam) == sections.end()) { 9273+ return RET_OK; 9274+ } 9275+ const auto &input_args = sections.at(kThirdPartyModelParam); 9276+ const std::map<std::string, std::string &> kValidArgs = { 9277+ {"input_shapes", third_party_model_string_.input_shapes}, 9278+ {"input_dtypes", third_party_model_string_.input_dtypes}, 9279+ {"input_names", third_party_model_string_.input_names}, 9280+ {"input_formats", third_party_model_string_.input_formats}, 9281+ {"output_shapes", third_party_model_string_.output_shapes}, 9282+ {"output_dtypes", third_party_model_string_.output_dtypes}, 9283+ {"output_names", third_party_model_string_.output_names}, 9284+ {"output_formats", third_party_model_string_.output_formats}, 9285+ {"extended_parameters", third_party_model_string_.extended_parameters}, 9286+ }; 9287+ return SetMapData(input_args, kValidArgs, kThirdPartyModelParam); 9288+} 9289 } // namespace lite 9290 } // namespace mindspore 9291diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9292index 31269816..6997bac8 100644 9293--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9294+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h 9295@@ -110,6 +110,18 @@ struct MicroParamString { 9296 std::string changeable_weights_name; 9297 }; 9298 9299+struct ThirdPartyModelString { 9300+ std::string input_dtypes; 9301+ std::string input_shapes; 9302+ std::string input_names; // optional, default: "" 9303+ std::string input_formats; // optional, default: NHWC 9304+ std::string output_dtypes; 9305+ std::string output_shapes; 9306+ std::string output_names; // optional, default: "" 9307+ std::string output_formats; // optional, default: NHWC 9308+ std::string extended_parameters; // format: {key1:value1;key2:value2} 9309+}; 9310+ 9311 struct CpuOptionCfgString { 9312 std::string architecture; 9313 std::string instruction; 9314@@ -144,6 +156,7 @@ class ConfigFileParser { 9315 RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; } 9316 AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; } 9317 MicroParamString GetMicroParamString() { return this->micro_param_string_; } 9318+ lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; } 9319 CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; } 9320 TransformQuantString GetTransformQuantString() const { return this->transform_quant_string_; } 9321 AscendQuantString GetAscendQuantString() const { return this->ascend_quant_string_; } 9322@@ -161,6
+174,7 @@ class ConfigFileParser { 9323 int SetMapData(const std::map<std::string, std::string> &input_map, 9324 const std::map<std::string, std::string &> &parse_map, const std::string §ion); 9325 int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9326+ int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> §ions); 9327 int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9328 int ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9329 int ParseAscendQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9330@@ -176,6 +190,7 @@ class ConfigFileParser { 9331 RegistryInfoString registry_info_string_; 9332 AclOptionCfgString acl_option_cfg_string_; 9333 MicroParamString micro_param_string_; 9334+ lite::ThirdPartyModelString third_party_model_string_; 9335 CpuOptionCfgString cpu_option_cfg_string_; 9336 TransformQuantString transform_quant_string_; 9337 AscendQuantString ascend_quant_string_; 9338diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 9339new file mode 100644 9340index 00000000..aee6a29c 9341--- /dev/null 9342+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 9343@@ -0,0 +1,299 @@ 9344+/** 9345+ * Copyright 2023 Huawei Technologies Co., Ltd 9346+ * 9347+ * Licensed under the Apache License, Version 2.0 (the "License"); 9348+ * you may not use this file except in compliance with the License. 9349+ * You may obtain a copy of the License at 9350+ * 9351+ * http://www.apache.org/licenses/LICENSE-2.0 9352+ * 9353+ * Unless required by applicable law or agreed to in writing, software 9354+ * distributed under the License is distributed on an "AS IS" BASIS, 9355+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9356+ * See the License for the specific language governing permissions and 9357+ * limitations under the License. 9358+ */ 9359+ 9360+#include "tools/converter/config_parser/third_party_param_parser.h" 9361+#include <vector> 9362+#include <string> 9363+#include <map> 9364+#include "include/errorcode.h" 9365+#include "src/common/log_adapter.h" 9366+#include "nnacl/op_base.h" 9367+#include "tools/common/string_util.h" 9368+ 9369+namespace mindspore { 9370+namespace lite { 9371+namespace { 9372+const std::map<std::string, TypeId> kDataTypeMap = { 9373+ {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32}, 9374+ {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64}, 9375+ {"int32", TypeId::kNumberTypeInt32}, {"int16", TypeId::kNumberTypeInt16}, 9376+ {"int8", TypeId::kNumberTypeInt8}, {"uint8", TypeId::kNumberTypeUInt8}, 9377+ {"bool", TypeId::kNumberTypeBool}, 9378+}; 9379+ 9380+TypeId ConvertDataType(const std::string &type) { 9381+ auto iter = kDataTypeMap.find(type); 9382+ if (iter == kDataTypeMap.end()) { 9383+ return TypeId::kTypeUnknown; 9384+ } 9385+ return iter->second; 9386+} 9387+} // namespace 9388+ 9389+/** 9390+ * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]]. 
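 *
 * For illustration only, a hypothetical snippet of the converter config section these strings come from
 * (the section and key names are the ones registered in ConfigFileParser::ParseThirdPartyParamString above;
 * the value syntax follows the formats documented in this file):
 *   [third_party_model]
 *   input_shapes=1,256,256,3;3,96
 *   input_dtypes=float32;int32
 *   output_shapes=1,1000
 *   output_dtypes=float32
 *   extended_parameters=key_1:value_1;key_2:value_2
 * input_names/input_formats and output_names/output_formats are optional; missing names are generated from
 * the "in"/"out" prefixes and formats default to NHWC, as handled by DoParseNames and DoParseFormats below.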
9391+ */ 9392+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) { 9393+ MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR); 9394+ dst_shapes->clear(); 9395+ 9396+ auto tmp_shapes = SplitStringToVector(src, ";"); 9397+ for (auto tmp_shape : tmp_shapes) { 9398+ auto tmp = SplitStringToVector(tmp_shape, ","); 9399+ std::vector<int64_t> shape = {}; 9400+ for (auto t : tmp) { 9401+ int value = 0; 9402+ if (!ConvertIntNum(t, &value)) { 9403+ MS_LOG(ERROR) << "Found error when converting shape string to integer"; 9404+ return RET_ERROR; 9405+ } 9406+ if (value <= 0) { // Valid shape value should be greater than 0. 9407+ MS_LOG(ERROR) << "Only support fixed shapes in third party param"; 9408+ return RET_ERROR; 9409+ } 9410+ shape.push_back(value); 9411+ } 9412+ dst_shapes->push_back(shape); 9413+ } 9414+ return RET_OK; 9415+} 9416+ 9417+/** 9418+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}. 9419+ */ 9420+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src, 9421+ std::map<std::string, std::vector<uint8_t>> *dst_ext_param) { 9422+ MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR); 9423+ constexpr size_t kKeyIndex = 0U; 9424+ constexpr size_t kValueIndex = 1U; 9425+ constexpr size_t kKeyValueSize = 2U; 9426+ 9427+ if (src == "") { // Just return if 'extended_parameters' is not configured. 9428+ return RET_OK; 9429+ } 9430+ 9431+ auto tmp_list = SplitStringToVector(src, ";"); 9432+ std::map<std::string, std::vector<uint8_t>> tmp_map = {}; 9433+ for (auto tmp : tmp_list) { 9434+ auto key_and_value = SplitStringToVector(tmp, ":"); 9435+ if (key_and_value.size() != kKeyValueSize) { 9436+ MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format"; 9437+ return RET_ERROR; 9438+ } 9439+ auto key = key_and_value[kKeyIndex]; 9440+ auto value = key_and_value[kValueIndex]; 9441+ if (tmp_map.find(key) != tmp_map.end()) { 9442+ MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated"; 9443+ return RET_ERROR; 9444+ } 9445+ tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end())); 9446+ } 9447+ 9448+ *dst_ext_param = tmp_map; 9449+ return RET_OK; 9450+} 9451+ 9452+/** 9453+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32] 9454+ */ 9455+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) { 9456+ MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR); 9457+ dst_dtypes->clear(); 9458+ auto tmp_dtypes = SplitStringToVector(src, ";"); 9459+ for (auto tmp_dtype : tmp_dtypes) { 9460+ TypeId type = ConvertDataType(tmp_dtype); 9461+ if (type == kTypeUnknown) { 9462+ MS_LOG(ERROR) << "Parse dtypes in third party model config failed"; 9463+ return RET_ERROR; 9464+ } 9465+ dst_dtypes->push_back(type); 9466+ } 9467+ return RET_OK; 9468+} 9469+ 9470+/** 9471+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"] 9472+ * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n" 9473+ */ 9474+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, 9475+ std::vector<std::string> *dst_names) { 9476+ MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR); 9477+ std::string tmp_names = src; 9478+ if (tmp_names.empty()) { 9479+ std::string tmp = ""; 9480+ for (size_t i = 0; i < num; i++) { 9481+ tmp += 
default_prefix + "_" + std::to_string(i); 9482+ if (i + 1 < num) { 9483+ tmp += ";"; 9484+ } 9485+ } 9486+ tmp_names = tmp; 9487+ } 9488+ 9489+ *dst_names = SplitStringToVector(tmp_names, ";"); 9490+ if (dst_names->size() != num) { 9491+ MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal"; 9492+ return RET_ERROR; 9493+ } 9494+ return RET_OK; 9495+} 9496+ 9497+/** 9498+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC] 9499+ */ 9500+namespace { 9501+ int StringToFormat(const std::string &format_string, schema::Format *format) { 9502+ static const std::unordered_map<std::string, schema::Format> kFormatTable = { 9503+ {"NCHW", schema::Format::Format_NCHW}, 9504+ {"NHWC", schema::Format::Format_NHWC}, 9505+ {"NHWC4", schema::Format::Format_NHWC4}, 9506+ {"HWKC", schema::Format::Format_HWKC}, 9507+ {"HWCK", schema::Format::Format_HWCK}, 9508+ {"KCHW", schema::Format::Format_KCHW}, 9509+ {"CKHW", schema::Format::Format_CKHW}, 9510+ {"KHWC", schema::Format::Format_KHWC}, 9511+ {"CHWK", schema::Format::Format_CHWK}, 9512+ {"HW", schema::Format::Format_HW}, 9513+ {"HW4", schema::Format::Format_HW4}, 9514+ {"NC", schema::Format::Format_NC}, 9515+ {"NC4", schema::Format::Format_NC4}, 9516+ {"NC4HW4", schema::Format::Format_NC4HW4}, 9517+ {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT}, 9518+ {"NCDHW", schema::Format::Format_NCDHW}, 9519+ {"NWC", schema::Format::Format_NWC}, 9520+ {"NCW", schema::Format::Format_NCW}, 9521+ }; 9522+ 9523+ if (format == nullptr) { 9524+ return RET_NULL_PTR; 9525+ } 9526+ 9527+ auto iter = kFormatTable.find(format_string); 9528+ if (iter == kFormatTable.end()) { 9529+ return RET_PARAM_INVALID; 9530+ } 9531+ 9532+ *format = iter->second; 9533+ return RET_OK; 9534+ } 9535+} 9536+ 9537+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num, 9538+ std::vector<schema::Format> *result_formats) { 9539+ MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR); 9540+ std::string tmp_names = src; 9541+ if (tmp_names.empty()) { 9542+ std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC); 9543+ *result_formats = default_formats; 9544+ return RET_OK; 9545+ } 9546+ 9547+ auto format_strings = SplitStringToVector(tmp_names, ";"); 9548+ if (format_strings.size() != num) { 9549+ MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal"; 9550+ return RET_ERROR; 9551+ } 9552+ 9553+ std::vector<schema::Format> result(num); 9554+ for (size_t i = 0; i < num; i++) { 9555+ if (StringToFormat(format_strings[i], &result[i]) != RET_OK) { 9556+ MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid"; 9557+ return RET_PARAM_INVALID; 9558+ } 9559+ } 9560+ *result_formats = result; 9561+ return RET_OK; 9562+} 9563+ 9564+int ThirdPartyParamParser::Parse(const ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param) { 9565+ MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR); 9566+ 9567+ auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes)); 9568+ if (ret != RET_OK) { 9569+ MS_LOG(ERROR) << "Parse input shapes of third party param failed"; 9570+ return RET_ERROR; 9571+ } 9572+ 9573+ ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes)); 9574+ if (ret != RET_OK) { 9575+ MS_LOG(ERROR) << "Parse input dtypes of third party param failed"; 9576+ return RET_ERROR; 9577+ } 9578+ 9579+ auto input_shape_num = param->input_shapes.size(); 9580+ auto input_dtype_num = 
param->input_dtypes.size(); 9581+ if (input_shape_num != input_dtype_num) { 9582+ MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num 9583+ << " are not equal"; 9584+ return RET_ERROR; 9585+ } 9586+ 9587+ ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats)); 9588+ if (ret != RET_OK) { 9589+ MS_LOG(ERROR) << "Parse input formats of third party param failed"; 9590+ return RET_ERROR; 9591+ } 9592+ 9593+ const std::string kInputNamePrefix = "in"; 9594+ ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names)); 9595+ if (ret != RET_OK) { 9596+ MS_LOG(ERROR) << "Parse input names of third party param failed"; 9597+ return RET_ERROR; 9598+ } 9599+ 9600+ ret = DoParseShape(param_string.output_shapes, &(param->output_shapes)); 9601+ if (ret != RET_OK) { 9602+ MS_LOG(ERROR) << "Parse output shapes of third party param failed"; 9603+ return RET_ERROR; 9604+ } 9605+ 9606+ ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes)); 9607+ if (ret != RET_OK) { 9608+ MS_LOG(ERROR) << "Parse output dtypes of third party param failed"; 9609+ return RET_ERROR; 9610+ } 9611+ 9612+ auto output_shape_num = param->output_shapes.size(); 9613+ auto output_dtype_num = param->output_dtypes.size(); 9614+ if (output_shape_num != output_dtype_num) { 9615+ MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num 9616+ << " are not equal"; 9617+ return RET_ERROR; 9618+ } 9619+ 9620+ ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats)); 9621+ if (ret != RET_OK) { 9622+ MS_LOG(ERROR) << "Parse output formats of third party param failed"; 9623+ return RET_ERROR; 9624+ } 9625+ 9626+ const std::string kOutputNamePrefix = "out"; 9627+ ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names)); 9628+ if (ret != RET_OK) { 9629+ MS_LOG(ERROR) << "Parse output names of third party param failed"; 9630+ return RET_ERROR; 9631+ } 9632+ 9633+ ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters)); 9634+ if (ret != RET_OK) { 9635+ MS_LOG(ERROR) << "Parse extended parameter of third party param failed"; 9636+ return RET_ERROR; 9637+ } 9638+ 9639+ return RET_OK; 9640+} 9641+} // namespace lite 9642+} // namespace mindspore 9643diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 9644new file mode 100644 9645index 00000000..5cf6e8fb 9646--- /dev/null 9647+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 9648@@ -0,0 +1,44 @@ 9649+/** 9650+ * Copyright 2023 Huawei Technologies Co., Ltd 9651+ * 9652+ * Licensed under the Apache License, Version 2.0 (the "License"); 9653+ * you may not use this file except in compliance with the License. 9654+ * You may obtain a copy of the License at 9655+ * 9656+ * http://www.apache.org/licenses/LICENSE-2.0 9657+ * 9658+ * Unless required by applicable law or agreed to in writing, software 9659+ * distributed under the License is distributed on an "AS IS" BASIS, 9660+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9661+ * See the License for the specific language governing permissions and 9662+ * limitations under the License.
9663+ */ 9664+ 9665+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9666+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9667+#include <string> 9668+#include <vector> 9669+#include <map> 9670+#include "include/errorcode.h" 9671+#include "tools/converter/cxx_api/converter_para.h" 9672+#include "tools/converter/config_parser/config_file_parser.h" 9673+ 9674+namespace mindspore { 9675+namespace lite { 9676+class ThirdPartyParamParser { 9677+ public: 9678+ static int Parse(const lite::ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param); 9679+ 9680+ private: 9681+ static int DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes); 9682+ static int DoParseExtendedParameters(const std::string &src, 9683+ std::map<std::string, std::vector<uint8_t>> *dst_ext_param); 9684+ static int DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes); 9685+ static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, 9686+ std::vector<std::string> *dst_names); 9687+ static int DoParseFormats(const std::string &src, size_t num, std::vector<schema::Format> *result_formats); 9688+}; 9689+} // namespace lite 9690+} // namespace mindspore 9691+ 9692+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9693diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 9694index df3176c2..a61bd51c 100644 9695--- a/mindspore/lite/tools/converter/converter.cc 9696+++ b/mindspore/lite/tools/converter/converter.cc 9697@@ -49,6 +49,7 @@ 9698 #include "tools/converter/config_parser/preprocess_parser.h" 9699 #include "tools/converter/config_parser/quant_param_parser.h" 9700 #include "tools/converter/config_parser/graph_kernel_param_parser.h" 9701+#include "tools/converter/config_parser/third_party_param_parser.h" 9702 #include "tools/converter/converter_funcgraph.h" 9703 #include "tools/converter/converter_metagraph.h" 9704 #include "tools/common/string_util.h" 9705@@ -472,6 +473,12 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9706 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; 9707 return ret; 9708 } 9709+ ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(), 9710+ ¶m->thirdPartyModelParam); 9711+ if (ret != RET_OK) { 9712+ MS_LOG(ERROR) << "Parse third party param failed."; 9713+ return ret; 9714+ } 9715 ret = InitExtendedIntegrationInfo(param, *config_parser); 9716 if (ret != RET_OK) { 9717 MS_LOG(ERROR) << "Parse extended integration info failed."; 9718@@ -699,19 +706,20 @@ std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const s 9719 9720 int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) { 9721 if (param != nullptr) { 9722- std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9723- FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9724- FmkType::kFmkTypeMsLite}; 9725- if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) { 9726- MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9727- "kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite|kFmkTypeMsLite" 9728- << ", but got " << param->fmk_type; 9729- return RET_INPUT_PARAM_INVALID; 9730- } 9731- if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { 9732- MS_LOG(ERROR) << "INPUT ILLEGAL: 
weight_file is not a valid flag"; 9733- return RET_INPUT_PARAM_INVALID; 9734- } 9735+ return RET_OK; 9736+ } 9737+ std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9738+ FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9739+ FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty}; 9740+ if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { 9741+ MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9742+ "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY" 9743+ << ", but got " << param->fmk_type; 9744+ return RET_INPUT_PARAM_INVALID; 9745+ } 9746+ if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { 9747+ MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; 9748+ return RET_INPUT_PARAM_INVALID; 9749 } 9750 return RET_OK; 9751 } 9752diff --git a/mindspore/lite/tools/converter/converter_funcgraph.cc b/mindspore/lite/tools/converter/converter_funcgraph.cc 9753index f03f995c..61d5c463 100644 9754--- a/mindspore/lite/tools/converter/converter_funcgraph.cc 9755+++ b/mindspore/lite/tools/converter/converter_funcgraph.cc 9756@@ -90,6 +90,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C 9757 converter_parameters.save_type = param->save_type; 9758 converter_parameters.model_file = param->model_file; 9759 converter_parameters.weight_file = param->weight_file; 9760+ converter_parameters.attrs.emplace("config_file", param->config_file); 9761 func_graph_base = model_parser->Parse(converter_parameters); 9762 if (func_graph_base == nullptr) { 9763 delete model_parser; 9764@@ -447,11 +448,13 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> ¶m, 9765 return status; 9766 } 9767 9768- AnfTransform funcgraph_transform; 9769- status = funcgraph_transform.Transform(func_graph, param); 9770- if (status != RET_OK) { 9771- MS_LOG(ERROR) << "Transform anf graph failed."; 9772- return status; 9773+ if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) { 9774+ AnfTransform funcgraph_transform; 9775+ status = funcgraph_transform.Transform(func_graph, param); 9776+ if (status != RET_OK) { 9777+ MS_LOG(ERROR) << "Transform anf graph failed."; 9778+ return status; 9779+ } 9780 } 9781 9782 status = UnifyFuncGraphOutputFormat(param, func_graph); 9783diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9784index 4883c48d..024e209f 100644 9785--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9786+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9787@@ -138,11 +138,11 @@ int Flags::InitFmk() { 9788 // value check not here, it is in converter c++ API's CheckValueParam method. 
9789 std::map<std::string, FmkType> StrToEnumFmkTypeMap = { 9790 {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, 9791- {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}}; 9792+ {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}, {"THIRDPARTY", kFmkTypeThirdParty}}; 9793 if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) { 9794 this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn); 9795 } else { 9796- std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl; 9797+ std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl; 9798 return RET_INPUT_PARAM_INVALID; 9799 } 9800 9801diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h 9802index a4f72a69..33210fd0 100644 9803--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h 9804+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h 9805@@ -21,6 +21,7 @@ 9806 #include <vector> 9807 #include <set> 9808 #include "include/converter.h" 9809+#include "mindapi/base/type_id.h" 9810 #include "tools/converter/quantizer/quant_params.h" 9811 #include "tools/converter/preprocess/preprocess_param.h" 9812 #include "tools/converter/adapter/acl/common/acl_types.h" 9813@@ -35,6 +36,18 @@ struct ParallelSplitConfig { 9814 std::vector<std::string> parallel_devices_; 9815 }; 9816 9817+struct ThirdPartyModelParam { 9818+ std::vector<TypeId> input_dtypes; 9819+ std::vector<std::vector<int64_t>> input_shapes; 9820+ std::vector<std::string> input_names; 9821+ std::vector<schema::Format> input_formats; 9822+ std::vector<TypeId> output_dtypes; 9823+ std::vector<std::vector<int64_t>> output_shapes; 9824+ std::vector<std::string> output_names; 9825+ std::vector<schema::Format> output_formats; 9826+ std::map<std::string, std::vector<uint8_t>> extended_parameters; 9827+}; 9828+ 9829 struct CpuOptionCfg { 9830 std::string architecture; 9831 std::string instruction; 9832@@ -97,6 +110,7 @@ struct ConverterPara { 9833 lite::acl::AclModelOptionCfg aclModelOptionCfgParam; 9834 lite::micro::MicroParam microParam; 9835 ParallelSplitConfig parallel_split_config; 9836+ ThirdPartyModelParam thirdPartyModelParam; 9837 AscendGeOptionCfg ascendGeOptionCfg; 9838 std::string device; 9839 std::string provider; 9840diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 9841index 90b744e5..bf1a82ae 100644 9842--- a/mindspore/lite/tools/converter/graphdef_transform.cc 9843+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 9844@@ -76,11 +76,55 @@ int QuantTransform(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGrap 9845 } 9846 return RET_OK; 9847 } 9848+ 9849+int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> output_shapes) { 9850+ const auto &out_indices = meta_graph->outputIndex; 9851+ for (size_t i = 0; i < out_indices.size(); i++) { 9852+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 9853+ out_tensor->dims = {}; 9854+ for (size_t k = 0; k < output_shapes[i].size(); k++) { 9855+ out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k])); 9856+ } 9857+ } 9858+ return RET_OK; 9859+} 9860+ 9861+void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara ¶) { 9862+ const auto &in_indices = meta_graph->inputIndex; 9863+ for (size_t i = 0; i < in_indices.size(); 
i++) { 9864+ auto &in_tensor = meta_graph->allTensors[in_indices[i]]; 9865+ in_tensor->format = para.thirdPartyModelParam.input_formats[i]; 9866+ MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format); 9867+ } 9868+ 9869+ const auto &out_indices = meta_graph->outputIndex; 9870+ for (size_t i = 0; i < out_indices.size(); i++) { 9871+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 9872+ out_tensor->format = para.thirdPartyModelParam.output_formats[i]; 9873+ MS_LOG_DEBUG << "output " << i << " format: " << EnumNameFormat(out_tensor->format); 9874+ } 9875+} 9876 } // namespace 9877 9878 int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 9879 MS_ASSERT(param != nullptr); 9880 STATUS status; 9881+ 9882+ if (param->fmk_type == converter::kFmkTypeThirdParty) { 9883+ 9884+ // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function. 9885+ // So we don't perform legacy optimization for kFmkTypeThirdParty case. 9886+ auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes); 9887+ if (ret != RET_OK) { 9888+ MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret; 9889+ return ret; 9890+ } 9891+ 9892+ // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph. 9893+ FillGraphInputAndOutputFormats(graph_defT_, *param); 9894+ return RET_OK; 9895+ } 9896+ 9897 { 9898 auto old_nodes = GetGraphNodes(*graph_defT_); 9899 Optimizer unused_op_remove_optimizer; 9900diff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 9901new file mode 100644 9902index 00000000..b55e0194 9903--- /dev/null 9904+++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 9905@@ -0,0 +1,4 @@ 9906+add_library(third_party_parser_mid OBJECT third_party_model_parser.cc) 9907+add_dependencies(third_party_parser_mid proto_mid) 9908+add_dependencies(third_party_parser_mid fbs_src) 9909+add_dependencies(third_party_parser_mid fbs_inner_src) 9910\ No newline at end of file 9911diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 9912new file mode 100644 9913index 00000000..652db4af 9914--- /dev/null 9915+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 9916@@ -0,0 +1,277 @@ 9917+/** 9918+ * Copyright 2023 Huawei Technologies Co., Ltd 9919+ * 9920+ * Licensed under the Apache License, Version 2.0 (the "License"); 9921+ * you may not use this file except in compliance with the License. 9922+ * You may obtain a copy of the License at 9923+ * 9924+ * http://www.apache.org/licenses/LICENSE-2.0 9925+ * 9926+ * Unless required by applicable law or agreed to in writing, software 9927+ * distributed under the License is distributed on an "AS IS" BASIS, 9928+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9929+ * See the License for the specific language governing permissions and 9930+ * limitations under the License. 
9931+ */ 9932+#include "tools/converter/parser/third_party/third_party_model_parser.h" 9933+#include <string> 9934+#include <vector> 9935+#include <memory> 9936+#include "ir/value.h" 9937+#include "mindapi/base/type_id.h" 9938+#include "src/common/log_util.h" 9939+#include "src/common/file_utils.h" 9940+#include "nnacl/op_base.h" 9941+#include "ops/primitive_c.h" 9942+#include "ops/custom.h" 9943+#include "ops/tuple_get_item.h" 9944+#include "ops/make_tuple.h" 9945+#include "ops/return.h" 9946+#include "tools/converter/config_parser/config_file_parser.h" 9947+#include "include/registry/model_parser_registry.h" 9948+#include "tools/common/graph_util.h" 9949+#include "tools/common/tensor_util.h" 9950+#include "tools/converter/converter_context.h" 9951+#include "tools/converter/parser/lite_model_parser_creator.h" 9952+ 9953+using mindspore::converter::kFmkTypeThirdParty; 9954+ 9955+namespace mindspore { 9956+namespace lite { 9957+api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) { 9958+ model_file_ = flag.model_file; 9959+ auto &attrs = flag.attrs; 9960+ auto iter = attrs.find("config_file"); 9961+ if (iter == attrs.end()) { 9962+ return nullptr; 9963+ } 9964+ auto config_file = iter->second; 9965+ 9966+ auto ret = InitConfig(config_file); 9967+ if (ret != RET_OK) { 9968+ MS_LOG(ERROR) << "Init config for third party model parsing failed"; 9969+ return nullptr; 9970+ } 9971+ 9972+ return CreateFuncGraph(); 9973+} 9974+ 9975+STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { 9976+ lite::ConfigFileParser config_parser; 9977+ if (config_file.empty()) { 9978+ MS_LOG(ERROR) << "Missing config file in converting third party model"; 9979+ return RET_ERROR; 9980+ } 9981+ auto ret = config_parser.ParseConfigFile(config_file); 9982+ if (ret != RET_OK) { 9983+ MS_LOG(ERROR) << "Get third party model section from config file failed"; 9984+ return RET_ERROR; 9985+ } 9986+ 9987+ ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m_); 9988+ if (ret != RET_OK) { 9989+ MS_LOG(ERROR) << "Parse third party model param failed."; 9990+ return ret; 9991+ } 9992+ return RET_OK; 9993+} 9994+ 9995+api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() { 9996+ auto func_graph = std::make_shared<FuncGraph>(); 9997+ MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); 9998+ auto type_value = MakeValue(static_cast<int>(converter::kFmkTypeThirdParty)); 9999+ MS_CHECK_TRUE_RET(type_value != nullptr, nullptr); 10000+ func_graph->set_attr("fmk", type_value); 10001+ auto attr_value = MakeValue("third_party"); 10002+ MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr); 10003+ func_graph->set_attr("graph_name", attr_value); 10004+ 10005+ std::vector<AnfNodePtr> input_nodes = {}; 10006+ auto ret = BuildGraphInputs(func_graph, &input_nodes); 10007+ if (ret != RET_OK) { 10008+ MS_LOG(ERROR) << "Create func graph input nodes failed"; 10009+ return nullptr; 10010+ } 10011+ 10012+ CNodePtr custom_node = nullptr; 10013+ ret = BuildCustomOp(func_graph, input_nodes, &custom_node); 10014+ if (ret != RET_OK) { 10015+ MS_LOG(ERROR) << "Create func graph custom op node failed"; 10016+ return nullptr; 10017+ } 10018+ 10019+ ret = BuildGraphOutputs(func_graph, custom_node); 10020+ if (ret != RET_OK) { 10021+ MS_LOG(ERROR) << "Create func graph output nodes failed"; 10022+ return nullptr; 10023+ } 10024+ 10025+ static auto manager = Manage(func_graph); 10026+ func_graph->set_manager(manager); 10027+ 10028+ auto result_graph = 
api::MakeShared<api::FuncGraph>(func_graph); 10029+ return result_graph; 10030+} 10031+ 10032+STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs) { 10033+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 10034+ auto &dtypes = param_.input_dtypes; 10035+ auto &shapes = param_.input_shapes; 10036+ auto &names = param_.input_names; 10037+ 10038+ auto input_size = dtypes.size(); 10039+ 10040+ // Create parameter nodes for graph inputs 10041+ for (size_t i = 0; i < input_size; i++) { 10042+ auto parameter = func_graph->add_parameter(); 10043+ MSLITE_CHECK_PTR(parameter); 10044+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 10045+ if (abstract_tensor == nullptr) { 10046+ MS_LOG(ERROR) << "Create tensor abstract failed"; 10047+ return RET_ERROR; 10048+ } 10049+ parameter->set_abstract(abstract_tensor); 10050+ parameter->set_name(names[i]); 10051+ op_inputs->push_back(parameter); 10052+ } 10053+ 10054+ // Create parameter nodes for const tensor which wrapped third model buffer. 10055+ size_t model_size = 0U; 10056+ auto model_data = ReadFile(model_file_.c_str(), &model_size); 10057+ std::vector<int64_t> model_shape = {static_cast<int64_t>(model_size)}; 10058+ auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8); 10059+ if (tensor_info == nullptr) { 10060+ MS_LOG(ERROR) << "init tensor info failed"; 10061+ delete model_data; 10062+ return RET_NULL_PTR; 10063+ } 10064+ auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c()); 10065+ if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) { 10066+ MS_LOG(ERROR) << "memcpy failed."; 10067+ delete model_data; 10068+ return RET_ERROR; 10069+ } 10070+ delete model_data; 10071+ auto parameter = func_graph->add_parameter(); 10072+ MSLITE_CHECK_PTR(parameter); 10073+ auto status = InitParameterFromTensorInfo(parameter, tensor_info); 10074+ if (status != RET_OK) { 10075+ MS_LOG(ERROR) << "init parameter from tensor info failed."; 10076+ return RET_ERROR; 10077+ } 10078+ parameter->set_name("ThirdPartyModel"); 10079+ op_inputs->push_back(parameter); 10080+ return RET_OK; 10081+} 10082+ 10083+STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 10084+ CNodePtr *operator_node) { 10085+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 10086+ NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY"); 10087+ STATUS status = RET_OK; 10088+ 10089+ // create primitive and build CNode of CUSTOM operator 10090+ ops::PrimitiveCPtr primitive_c; 10091+ auto prim = std::make_unique<ops::Custom>(); 10092+ MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR); 10093+ prim->set_type("ThirdPartyModel"); 10094+ 10095+ const auto &attr = param_.extended_parameters; 10096+ prim->set_attr(attr); 10097+ primitive_c = prim->GetPrim(); 10098+ if (primitive_c == nullptr) { 10099+ MS_LOG(ERROR) << "failed to create primitive: custom"; 10100+ return RET_ERROR; 10101+ } 10102+ 10103+ auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs); 10104+ MSLITE_CHECK_PTR(operator_cnode); 10105+ operator_cnode->set_fullname_with_scope("Custom"); 10106+ *operator_node = operator_cnode; 10107+ return status; 10108+} 10109+ 10110+STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) { 10111+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 10112+ 10113+ auto dtypes = param_.output_dtypes; 10114+ auto 
shapes = param_.output_shapes; 10115+ auto names = param_.output_names; 10116+ 10117+ auto output_size = dtypes.size(); 10118+ std::vector<AnfNodePtr> output_nodes = {}; 10119+ 10120+ // Use TupleGetItem to wrap op outputs. 10121+ AbstractBasePtrList abstract_list; 10122+ for (size_t i = 0; i < output_size; i++) { 10123+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 10124+ if (abstract_tensor == nullptr) { 10125+ MS_LOG(ERROR) << "Create tensor abstract failed"; 10126+ return RET_ERROR; 10127+ } 10128+ abstract_list.emplace_back(abstract_tensor); 10129+ auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); 10130+ if (tuple_get_item_prim_ptr == nullptr) { 10131+ MS_LOG(ERROR) << "new TupleGetItem failed"; 10132+ return RET_NULL_PTR; 10133+ } 10134+ auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim(); 10135+ MSLITE_CHECK_PTR(tuple_get_item_prim_c); 10136+ auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c); 10137+ MSLITE_CHECK_PTR(tuple_get_item_prim); 10138+ auto get_item_value = NewValueNode(MakeValue<int>(i)); 10139+ MSLITE_CHECK_PTR(get_item_value); 10140+ std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value}; 10141+ CNodePtr get_item_cnode = func_graph->NewCNode(inputs); 10142+ MSLITE_CHECK_PTR(get_item_cnode); 10143+ std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i); 10144+ auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); 10145+ if (get_item_abstract == nullptr) { 10146+ MS_LOG(ERROR) << "Create tensor abstract failed"; 10147+ return RET_ERROR; 10148+ } 10149+ get_item_cnode->set_fullname_with_scope(output_item_name); 10150+ get_item_cnode->set_abstract(get_item_abstract); 10151+ output_nodes.push_back(get_item_cnode); 10152+ } 10153+ auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); 10154+ MSLITE_CHECK_PTR(abstract_tuple); 10155+ operator_node->set_abstract(abstract_tuple); 10156+ 10157+ // Use a MakeTuple node to wrap all outputs as a single input of the Return node. 10158+ auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>(); 10159+ if (make_tuple_prim_ptr == nullptr) { 10160+ MS_LOG(ERROR) << "new MakeTuple failed"; 10161+ return RET_NULL_PTR; 10162+ } 10163+ auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim(); 10164+ MSLITE_CHECK_PTR(make_tuple_prim_c); 10165+ auto make_tuple_prim = NewValueNode(make_tuple_prim_c); 10166+ MSLITE_CHECK_PTR(make_tuple_prim); 10167+ std::vector<AnfNodePtr> make_tuple_inputs = output_nodes; 10168+ make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); 10169+ auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs); 10170+ MSLITE_CHECK_PTR(make_tuple_cnode); 10171+ make_tuple_cnode->set_fullname_with_scope("return_tuple"); 10172+ 10173+ auto return_prim_ptr = std::make_shared<ops::Return>(); 10174+ if (return_prim_ptr == nullptr) { 10175+ MS_LOG(ERROR) << "new Return failed"; 10176+ return RET_NULL_PTR; 10177+ } 10178+ auto return_prim_c = return_prim_ptr->GetPrim(); 10179+ MSLITE_CHECK_PTR(return_prim_c); 10180+ std::vector<AnfNodePtr> op_inputs{make_tuple_cnode}; 10181+ auto cnode = func_graph->NewCNode(return_prim_c, op_inputs); 10182+ MSLITE_CHECK_PTR(cnode); 10183+ cnode->set_fullname_with_scope("Return"); 10184+ func_graph->set_return(cnode); 10185+ 10186+ // Save original output tensor names.
10187+ ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names); 10188+ return RET_OK; 10189+} 10190+ 10191+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>) 10192+} // namespace lite 10193+} // namespace mindspore 10194diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 10195new file mode 100644 10196index 00000000..c4b197b8 10197--- /dev/null 10198+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 10199@@ -0,0 +1,50 @@ 10200+/** 10201+ * Copyright 2023 Huawei Technologies Co., Ltd 10202+ * 10203+ * Licensed under the Apache License, Version 2.0 (the "License"); 10204+ * you may not use this file except in compliance with the License. 10205+ * You may obtain a copy of the License at 10206+ * 10207+ * http://www.apache.org/licenses/LICENSE-2.0 10208+ * 10209+ * Unless required by applicable law or agreed to in writing, software 10210+ * distributed under the License is distributed on an "AS IS" BASIS, 10211+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10212+ * See the License for the specific language governing permissions and 10213+ * limitations under the License. 10214+ */ 10215+ 10216+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10217+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10218+ 10219+#include <string> 10220+#include <vector> 10221+#include "schema/inner/model_generated.h" 10222+#include "base/base.h" 10223+#include "ir/anf.h" 10224+#include "ir/func_graph.h" 10225+#include "include/errorcode.h" 10226+#include "include/registry/model_parser.h" 10227+#include "tools/converter/config_parser/third_party_param_parser.h" 10228+ 10229+namespace mindspore { 10230+namespace lite { 10231+class ThirdPartyModelParser : public converter::ModelParser { 10232+ public: 10233+ api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; 10234+ 10235+ private: 10236+ STATUS InitConfig(const std::string &config_file); 10237+ api::FuncGraphPtr CreateFuncGraph(); 10238+ STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs); 10239+ STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 10240+ CNodePtr *operator_node); 10241+ STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node); 10242+ 10243+ std::string model_file_ = ""; 10244+ ThirdPartyModelParam param_; 10245+}; 10246+} // namespace lite 10247+} // namespace mindspore 10248+ 10249+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10250diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10251index 832fb92d..6bc2d4d3 100644 10252--- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10253+++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10254@@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room; 10255 } // namespace 10256 10257 ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { 10258- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 10259+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 10260 MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 10261 return; 
10262 } 10263@@ -38,7 +38,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator 10264 } 10265 10266 converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { 10267- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 10268+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 10269 MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 10270 return nullptr; 10271 } 10272-- 102732.17.1 10274 10275
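Usage sketch (illustrative only, not part of the patch): the hunks above add a THIRDPARTY framework type to the converter. ThirdPartyParamParser turns the textual third-party entries of the converter config file into a typed ThirdPartyModelParam, and ThirdPartyModelParser then builds a FuncGraph whose single Custom node carries the raw model bytes together with the declared input and output tensors. The snippet below shows roughly how the parser stage could be driven on its own; only the Parse signature and the ThirdPartyModelParam fields come from the patch, while the input_shapes/input_dtypes fields of ThirdPartyModelString (assumed to mirror the output_* fields used above) and the exact textual value syntax are assumptions made for illustration.

    // Minimal sketch of driving ThirdPartyParamParser directly. Assumes
    // ThirdPartyModelString exposes input_shapes/input_dtypes fields that mirror
    // the output_* fields above, and assumes the value syntax shown here.
    #include "include/errorcode.h"
    #include "tools/converter/config_parser/config_file_parser.h"
    #include "tools/converter/config_parser/third_party_param_parser.h"

    int ParseThirdPartyParamExample() {
      mindspore::lite::ThirdPartyModelString param_string;
      param_string.input_shapes = "1,3,224,224";  // assumed textual syntax
      param_string.input_dtypes = "float32";      // assumed textual syntax
      param_string.input_names = "img";
      param_string.output_shapes = "1,1000";
      param_string.output_dtypes = "float32";
      param_string.output_names = "prob";

      mindspore::lite::ThirdPartyModelParam param;
      int ret = mindspore::lite::ThirdPartyParamParser::Parse(param_string, &param);
      if (ret != mindspore::lite::RET_OK) {
        return ret;  // e.g. shape/dtype counts do not match, or a field failed to parse
      }
      // param.input_shapes, param.output_dtypes, ... are now typed vectors that
      // ThirdPartyModelParser consumes when it creates the graph inputs and the
      // Custom node wrapping the model buffer.
      return mindspore::lite::RET_OK;
    }

In the converter itself the same call is made from ConverterImpl::ParseParam (converter.cc hunk above) and from ThirdPartyModelParser::InitConfig once the parser has been selected through REG_MODEL_PARSER(kFmkTypeThirdParty, ...).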