1be168c0dSopenharmony_ciFrom baf2daaebd70448cddd35f5011642fe585d071b5 Mon Sep 17 00:00:00 2001 2be168c0dSopenharmony_ciFrom: chengfeng27 <chengfeng27@huawei.com> 3be168c0dSopenharmony_ciDate: Tue, 5 Mar 2024 20:00:24 +0800 4be168c0dSopenharmony_ciSubject: [PATCH] hilog use macro definition api 5be168c0dSopenharmony_ci 6be168c0dSopenharmony_ci--- 7be168c0dSopenharmony_ci cmake/external_libs/flatbuffers.cmake | 4 +- 8be168c0dSopenharmony_ci include/api/context.h | 65 ++ 9be168c0dSopenharmony_ci include/c_api/context_c.h | 111 +++ 10be168c0dSopenharmony_ci include/c_api/model_c.h | 178 ++++ 11be168c0dSopenharmony_ci include/c_api/tensor_c.h | 14 + 12be168c0dSopenharmony_ci include/c_api/types_c.h | 57 +- 13be168c0dSopenharmony_ci include/sdk_api/context.h | 103 +++ 14be168c0dSopenharmony_ci include/sdk_api/tensor.h | 13 + 15be168c0dSopenharmony_ci include/sdk_api/types.h | 38 +- 16be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/BUILD.gn | 3 + 17be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/CMakeLists.txt | 2 +- 18be168c0dSopenharmony_ci .../kernel/nnacl/avx/scatter_nd_binary_avx.h | 66 ++ 19be168c0dSopenharmony_ci .../nnacl/avx512/scatter_nd_binary_avx512.h | 66 ++ 20be168c0dSopenharmony_ci .../cpu/kernel/nnacl/base/scatter_nd_binary.c | 28 + 21be168c0dSopenharmony_ci .../cpu/kernel/nnacl/base/scatter_nd_binary.h | 3 + 22be168c0dSopenharmony_ci .../nnacl/base/scatter_nd_binary_simd.h.in | 14 + 23be168c0dSopenharmony_ci .../kernel/nnacl/custom_is_inf_parameter.h | 26 + 24be168c0dSopenharmony_ci .../nnacl/custom_masked_fill_parameter.h | 26 + 25be168c0dSopenharmony_ci .../custom_tensor_scatter_max_parameter.h | 26 + 26be168c0dSopenharmony_ci .../kernel/nnacl/infer/custom_is_inf_infer.c | 38 + 27be168c0dSopenharmony_ci .../kernel/nnacl/infer/custom_is_inf_infer.h | 31 + 28be168c0dSopenharmony_ci .../nnacl/infer/custom_masked_fill_infer.c | 37 + 29be168c0dSopenharmony_ci .../nnacl/infer/custom_masked_fill_infer.h | 31 + 30be168c0dSopenharmony_ci 
.../infer/custom_tensor_scatter_max_infer.c | 37 + 31be168c0dSopenharmony_ci .../infer/custom_tensor_scatter_max_infer.h | 31 + 32be168c0dSopenharmony_ci .../nnacl/neon/scatter_nd_binary_neon.h | 65 ++ 33be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/op_base.h | 4 + 34be168c0dSopenharmony_ci .../cpu/kernel/nnacl/scatter_nd_binary_simd.h | 36 + 35be168c0dSopenharmony_ci .../kernel/nnacl/sse/scatter_nd_binary_sse.h | 66 ++ 36be168c0dSopenharmony_ci mindspore/core/mindrt/BUILD.gn | 9 +- 37be168c0dSopenharmony_ci .../mindrt/src/thread/actor_threadpool.cc | 2 +- 38be168c0dSopenharmony_ci .../core/mindrt/src/thread/core_affinity.cc | 6 +- 39be168c0dSopenharmony_ci .../core/mindrt/src/thread/core_affinity.h | 2 +- 40be168c0dSopenharmony_ci .../mindrt/src/thread/parallel_threadpool.cc | 2 +- 41be168c0dSopenharmony_ci mindspore/core/mindrt/src/thread/threadlog.h | 28 +- 42be168c0dSopenharmony_ci .../core/mindrt/src/thread/threadpool.cc | 7 +- 43be168c0dSopenharmony_ci mindspore/lite/BUILD.gn | 82 +- 44be168c0dSopenharmony_ci mindspore/lite/CMakeLists.txt | 5 +- 45be168c0dSopenharmony_ci mindspore/lite/include/lite_types.h | 1 + 46be168c0dSopenharmony_ci mindspore/lite/include/model.h | 4 + 47be168c0dSopenharmony_ci .../lite/include/registry/converter_context.h | 4 +- 48be168c0dSopenharmony_ci mindspore/lite/mindir/include/mindir.h | 2 + 49be168c0dSopenharmony_ci mindspore/lite/mindir/src/mindir.cc | 40 + 50be168c0dSopenharmony_ci mindspore/lite/mindir/src/mindir_tensor.cc | 2 +- 51be168c0dSopenharmony_ci mindspore/lite/mindir/src/utils.cc | 2 +- 52be168c0dSopenharmony_ci mindspore/lite/src/CMakeLists.txt | 6 +- 53be168c0dSopenharmony_ci mindspore/lite/src/common/context_util.cc | 14 +- 54be168c0dSopenharmony_ci mindspore/lite/src/common/log.cc | 33 +- 55be168c0dSopenharmony_ci mindspore/lite/src/common/log.h | 50 +- 56be168c0dSopenharmony_ci .../common/ops/populate/custom_populate.cc | 53 ++ 57be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.cc 
| 372 +++++++- 58be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.h | 23 - 59be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/model_c.cc | 724 ++++++++------- 60be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/tensor_c.cc | 78 +- 61be168c0dSopenharmony_ci .../lite/src/litert/c_api/type_c_private.h | 40 + 62be168c0dSopenharmony_ci mindspore/lite/src/litert/cxx_api/context.cc | 85 ++ 63be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/converters.cc | 60 +- 64be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/converters.h | 4 +- 65be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/CMakeLists.txt | 27 +- 66be168c0dSopenharmony_ci .../delegate/nnrt/checker/primitive_check.cc | 2 + 67be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_delegate.cc | 836 ++++++++++++++---- 68be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_delegate.h | 74 +- 69be168c0dSopenharmony_ci .../litert/delegate/nnrt/nnrt_model_kernel.cc | 3 +- 70be168c0dSopenharmony_ci .../litert/delegate/nnrt/nnrt_model_kernel.h | 2 +- 71be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_stub.cc | 99 +++ 72be168c0dSopenharmony_ci mindspore/lite/src/litert/infer_manager.cc | 3 +- 73be168c0dSopenharmony_ci mindspore/lite/src/litert/inner_context.cc | 4 + 74be168c0dSopenharmony_ci mindspore/lite/src/litert/inner_context.h | 14 + 75be168c0dSopenharmony_ci mindspore/lite/src/litert/kernel/cpu/BUILD.gn | 51 +- 76be168c0dSopenharmony_ci .../src/litert/kernel/cpu/base/custom_base.cc | 46 + 77be168c0dSopenharmony_ci .../src/litert/kernel/cpu/base/custom_base.h | 43 + 78be168c0dSopenharmony_ci .../litert/kernel/cpu/base/custom_is_inf.cc | 61 ++ 79be168c0dSopenharmony_ci .../litert/kernel/cpu/base/custom_is_inf.h | 38 + 80be168c0dSopenharmony_ci .../kernel/cpu/base/custom_masked_fill.cc | 84 ++ 81be168c0dSopenharmony_ci .../kernel/cpu/base/custom_masked_fill.h | 35 + 82be168c0dSopenharmony_ci .../kernel/cpu/base/custom_tensor_scatter.cc | 75 ++ 
83be168c0dSopenharmony_ci .../kernel/cpu/base/custom_tensor_scatter.h | 36 + 84be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_model.cc | 29 + 85be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_session.cc | 39 +- 86be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_session.h | 1 + 87be168c0dSopenharmony_ci mindspore/lite/src/litert/scheduler.cc | 17 + 88be168c0dSopenharmony_ci mindspore/lite/src/litert/tensor_category.cc | 4 + 89be168c0dSopenharmony_ci mindspore/lite/src/litert/tensor_category.h | 1 + 90be168c0dSopenharmony_ci mindspore/lite/test/CMakeLists.txt | 15 +- 91be168c0dSopenharmony_ci mindspore/lite/test/runtest.sh | 1 + 92be168c0dSopenharmony_ci .../test/ut/test_data/third_party_model.cfg | 8 + 93be168c0dSopenharmony_ci .../tools/converter/api/converter_api_test.cc | 10 + 94be168c0dSopenharmony_ci .../third_party_param_parser_test.cc | 176 ++++ 95be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_base.cc | 2 +- 96be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_base.h | 2 +- 97be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_c_api.cc | 4 + 98be168c0dSopenharmony_ci .../tools/benchmark/benchmark_unified_api.cc | 5 + 99be168c0dSopenharmony_ci .../lite/tools/benchmark_train/CMakeLists.txt | 3 + 100be168c0dSopenharmony_ci mindspore/lite/tools/benchmark_train/main.cc | 3 +- 101be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_runner.cc | 10 +- 102be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_train.cc | 418 +-------- 103be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_train.h | 229 +---- 104be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_base.cc | 410 +++++++++ 105be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_base.h | 288 ++++++ 106be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_c_api.cc | 659 ++++++++++++++ 107be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_c_api.h | 121 +++ 108be168c0dSopenharmony_ci 
.../tools/benchmark_train/run_net_train.cc | 86 ++ 109be168c0dSopenharmony_ci .../tools/benchmark_train/run_net_train.h | 22 + 110be168c0dSopenharmony_ci mindspore/lite/tools/converter/CMakeLists.txt | 4 + 111be168c0dSopenharmony_ci .../config_parser/config_file_parser.cc | 27 + 112be168c0dSopenharmony_ci .../config_parser/config_file_parser.h | 15 + 113be168c0dSopenharmony_ci .../config_parser/third_party_param_parser.cc | 299 +++++++ 114be168c0dSopenharmony_ci .../config_parser/third_party_param_parser.h | 44 + 115be168c0dSopenharmony_ci mindspore/lite/tools/converter/converter.cc | 34 +- 116be168c0dSopenharmony_ci .../tools/converter/converter_funcgraph.cc | 13 +- 117be168c0dSopenharmony_ci .../converter_lite/converter_flags.cc | 4 +- 118be168c0dSopenharmony_ci .../tools/converter/cxx_api/converter_para.h | 14 + 119be168c0dSopenharmony_ci .../tools/converter/graphdef_transform.cc | 44 + 120be168c0dSopenharmony_ci .../parser/third_party/CMakeLists.txt | 4 + 121be168c0dSopenharmony_ci .../third_party/third_party_model_parser.cc | 277 ++++++ 122be168c0dSopenharmony_ci .../third_party/third_party_model_parser.h | 50 ++ 123be168c0dSopenharmony_ci .../registry/model_parser_registry.cc | 4 +- 124be168c0dSopenharmony_ci 117 files changed, 6456 insertions(+), 1432 deletions(-) 125be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 126be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 127be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 128be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 129be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 130be168c0dSopenharmony_ci create mode 100644 
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 131be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 132be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 133be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 134be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 135be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 136be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 137be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 138be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 139be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/c_api/type_c_private.h 140be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 141be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 142be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 143be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 144be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 145be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 146be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 147be168c0dSopenharmony_ci create mode 100644 
mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 148be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 149be168c0dSopenharmony_ci create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg 150be168c0dSopenharmony_ci create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 151be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.cc 152be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.h 153be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.cc 154be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.h 155be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.cc 156be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.h 157be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 158be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 159be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 160be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 161be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 162be168c0dSopenharmony_ci 163be168c0dSopenharmony_cidiff --git a/cmake/external_libs/flatbuffers.cmake b/cmake/external_libs/flatbuffers.cmake 164be168c0dSopenharmony_ciindex 2fde4311..87f0425b 100644 165be168c0dSopenharmony_ci--- a/cmake/external_libs/flatbuffers.cmake 166be168c0dSopenharmony_ci+++ b/cmake/external_libs/flatbuffers.cmake 167be168c0dSopenharmony_ci@@ -21,8 +21,8 @@ 
else() 168be168c0dSopenharmony_ci # flatbuffers.lib cimplied by msvc 169be168c0dSopenharmony_ci set(CMAKE_STATIC_LIBRARY_PREFIX "") 170be168c0dSopenharmony_ci else() 171be168c0dSopenharmony_ci- set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong") 172be168c0dSopenharmony_ci- set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong") 173be168c0dSopenharmony_ci+ set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable") 174be168c0dSopenharmony_ci+ set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable") 175be168c0dSopenharmony_ci endif() 176be168c0dSopenharmony_ci 177be168c0dSopenharmony_ci if(WIN32) 178be168c0dSopenharmony_cidiff --git a/include/api/context.h b/include/api/context.h 179be168c0dSopenharmony_ciindex c9fb11f0..eb704d44 100644 180be168c0dSopenharmony_ci--- a/include/api/context.h 181be168c0dSopenharmony_ci+++ b/include/api/context.h 182be168c0dSopenharmony_ci@@ -39,6 +39,8 @@ enum DeviceType { 183be168c0dSopenharmony_ci kAscend310, 184be168c0dSopenharmony_ci kCustomDevice, 185be168c0dSopenharmony_ci kAllDevice, 186be168c0dSopenharmony_ci+ //ohos-only device range[60,80) 187be168c0dSopenharmony_ci+ kNNRt = 60, 188be168c0dSopenharmony_ci // add new type here 189be168c0dSopenharmony_ci kInvalidDeviceType = 100, 190be168c0dSopenharmony_ci }; 191be168c0dSopenharmony_ci@@ -598,5 +600,68 @@ void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_ 192be168c0dSopenharmony_ci SetBufferOptimizeMode(StringToChar(buffer_optimize_mode)); 193be168c0dSopenharmony_ci } 194be168c0dSopenharmony_ci std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); } 195be168c0dSopenharmony_ci+ 196be168c0dSopenharmony_ci+struct Extension { 197be168c0dSopenharmony_ci+ std::string name; 198be168c0dSopenharmony_ci+ 
std::vector<uint8_t> value; 199be168c0dSopenharmony_ci+}; 200be168c0dSopenharmony_ci+ 201be168c0dSopenharmony_ci+class MS_API NNRTDeviceInfo : public DeviceInfoContext { 202be168c0dSopenharmony_ci+ public: 203be168c0dSopenharmony_ci+ /// \brief Get the type of this DeviceInfoContext. 204be168c0dSopenharmony_ci+ /// 205be168c0dSopenharmony_ci+ /// \return Type of this DeviceInfoContext. 206be168c0dSopenharmony_ci+ enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; }; 207be168c0dSopenharmony_ci+ 208be168c0dSopenharmony_ci+ /// \brief Set device id. 209be168c0dSopenharmony_ci+ /// 210be168c0dSopenharmony_ci+ /// \param[in] device_id The device id. 211be168c0dSopenharmony_ci+ void SetDeviceID(size_t device_id); 212be168c0dSopenharmony_ci+ 213be168c0dSopenharmony_ci+ /// \brief Get the device id. 214be168c0dSopenharmony_ci+ /// 215be168c0dSopenharmony_ci+ /// \return The device id. 216be168c0dSopenharmony_ci+ size_t GetDeviceID() const; 217be168c0dSopenharmony_ci+ 218be168c0dSopenharmony_ci+ /// \brief Set performance mode. 219be168c0dSopenharmony_ci+ /// 220be168c0dSopenharmony_ci+ /// \param[in] performance_mode The performance mode. 221be168c0dSopenharmony_ci+ void SetPerformanceMode(int performance_mode); 222be168c0dSopenharmony_ci+ 223be168c0dSopenharmony_ci+ /// \brief Get performance mode. 224be168c0dSopenharmony_ci+ /// 225be168c0dSopenharmony_ci+ /// \return The performance mode. 226be168c0dSopenharmony_ci+ int GetPerformanceMode() const; 227be168c0dSopenharmony_ci+ 228be168c0dSopenharmony_ci+ /// \brief Set priority. 229be168c0dSopenharmony_ci+ /// 230be168c0dSopenharmony_ci+ /// \param[in] priority The priority. 231be168c0dSopenharmony_ci+ void SetPriority(int priority); 232be168c0dSopenharmony_ci+ 233be168c0dSopenharmony_ci+ /// \brief Get priority. 234be168c0dSopenharmony_ci+ /// 235be168c0dSopenharmony_ci+ /// \return The priority. 
236be168c0dSopenharmony_ci+ int GetPriority() const; 237be168c0dSopenharmony_ci+ 238be168c0dSopenharmony_ci+ /// \brief Set enables to perform the float16 inference 239be168c0dSopenharmony_ci+ /// 240be168c0dSopenharmony_ci+ /// \param[in] is_fp16 Enable float16 inference or not. 241be168c0dSopenharmony_ci+ void SetEnableFP16(bool is_fp16); 242be168c0dSopenharmony_ci+ 243be168c0dSopenharmony_ci+ /// \brief Get enables to perform the float16 inference 244be168c0dSopenharmony_ci+ /// 245be168c0dSopenharmony_ci+ /// \return Whether enable float16 inference. 246be168c0dSopenharmony_ci+ bool GetEnableFP16() const; 247be168c0dSopenharmony_ci+ 248be168c0dSopenharmony_ci+ /// \brief Set extensions 249be168c0dSopenharmony_ci+ /// 250be168c0dSopenharmony_ci+ /// \param[in] extensions The extension array. 251be168c0dSopenharmony_ci+ void SetExtensions(const std::vector<Extension> &extensions); 252be168c0dSopenharmony_ci+ 253be168c0dSopenharmony_ci+ /// \brief Get extensions 254be168c0dSopenharmony_ci+ /// 255be168c0dSopenharmony_ci+ /// \return extension array. 
256be168c0dSopenharmony_ci+ std::vector<Extension> GetExtensions() const; 257be168c0dSopenharmony_ci+}; 258be168c0dSopenharmony_ci } // namespace mindspore 259be168c0dSopenharmony_ci #endif // MINDSPORE_INCLUDE_API_CONTEXT_H 260be168c0dSopenharmony_cidiff --git a/include/c_api/context_c.h b/include/c_api/context_c.h 261be168c0dSopenharmony_ciindex 53839e80..8951da25 100644 262be168c0dSopenharmony_ci--- a/include/c_api/context_c.h 263be168c0dSopenharmony_ci+++ b/include/c_api/context_c.h 264be168c0dSopenharmony_ci@@ -19,6 +19,7 @@ 265be168c0dSopenharmony_ci #include <stddef.h> 266be168c0dSopenharmony_ci #include <stdint.h> 267be168c0dSopenharmony_ci #include <stdbool.h> 268be168c0dSopenharmony_ci+#include "include/c_api/status_c.h" 269be168c0dSopenharmony_ci #include "include/c_api/types_c.h" 270be168c0dSopenharmony_ci 271be168c0dSopenharmony_ci #ifdef __cplusplus 272be168c0dSopenharmony_ci@@ -173,6 +174,116 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, 273be168c0dSopenharmony_ci /// \return NPU frequency 274be168c0dSopenharmony_ci OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info); 275be168c0dSopenharmony_ci 276be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT. 277be168c0dSopenharmony_ci+/// 278be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description. 279be168c0dSopenharmony_ci+/// 280be168c0dSopenharmony_ci+/// \return NNRT device description array. 281be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num); 282be168c0dSopenharmony_ci+ 283be168c0dSopenharmony_ci+/// \brief Obtain the specified element in NNRt device description array. 284be168c0dSopenharmony_ci+/// 285be168c0dSopenharmony_ci+/// \param[in] descs NNRT device description array. 286be168c0dSopenharmony_ci+/// \param[in] index Element index. 287be168c0dSopenharmony_ci+/// 288be168c0dSopenharmony_ci+/// \return NNRT device description. 
289be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index); 290be168c0dSopenharmony_ci+ 291be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT. 292be168c0dSopenharmony_ci+/// 293be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description. 294be168c0dSopenharmony_ci+/// 295be168c0dSopenharmony_ci+/// \return NNRT device description array. 296be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num); 297be168c0dSopenharmony_ci+ 298be168c0dSopenharmony_ci+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs(). 299be168c0dSopenharmony_ci+/// 300be168c0dSopenharmony_ci+/// \param[in] desc NNRT device description array. 301be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc); 302be168c0dSopenharmony_ci+ 303be168c0dSopenharmony_ci+/// \brief Obtain the device id in NNRT device description. 304be168c0dSopenharmony_ci+/// 305be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 306be168c0dSopenharmony_ci+/// 307be168c0dSopenharmony_ci+/// \return NNRT device id. 308be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 309be168c0dSopenharmony_ci+ 310be168c0dSopenharmony_ci+/// \brief Obtain the device name in NNRT device description. 311be168c0dSopenharmony_ci+/// 312be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 313be168c0dSopenharmony_ci+/// 314be168c0dSopenharmony_ci+/// \return NNRT device name. 315be168c0dSopenharmony_ci+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 316be168c0dSopenharmony_ci+ 317be168c0dSopenharmony_ci+/// \brief Obtain the device type in NNRT device description. 
318be168c0dSopenharmony_ci+/// 319be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 320be168c0dSopenharmony_ci+/// 321be168c0dSopenharmony_ci+/// \return NNRT device type. 322be168c0dSopenharmony_ci+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 323be168c0dSopenharmony_ci+ 324be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by exactly matching the specific device name. 325be168c0dSopenharmony_ci+/// 326be168c0dSopenharmony_ci+/// \param[in] name NNRt device name. 327be168c0dSopenharmony_ci+/// 328be168c0dSopenharmony_ci+/// \return Device info object handle. 329be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name); 330be168c0dSopenharmony_ci+ 331be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by finding the first device with the specific device type. 332be168c0dSopenharmony_ci+/// 333be168c0dSopenharmony_ci+/// \param[in] type NNRt device type. 334be168c0dSopenharmony_ci+/// 335be168c0dSopenharmony_ci+/// \return Device info object handle. 336be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type); 337be168c0dSopenharmony_ci+ 338be168c0dSopenharmony_ci+/// \brief Set the NNRT device id, Only valid for NNRT. 339be168c0dSopenharmony_ci+/// 340be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 341be168c0dSopenharmony_ci+/// \param[in] device_id NNRT device id. 342be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id); 343be168c0dSopenharmony_ci+ 344be168c0dSopenharmony_ci+/// \brief Obtain the NNRT device id, Only valid for NNRT. 345be168c0dSopenharmony_ci+/// 346be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 347be168c0dSopenharmony_ci+/// 348be168c0dSopenharmony_ci+/// \return NNRT device id. 
349be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info); 350be168c0dSopenharmony_ci+ 351be168c0dSopenharmony_ci+/// \brief Set the NNRT performance mode, Only valid for NNRT. 352be168c0dSopenharmony_ci+/// 353be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 354be168c0dSopenharmony_ci+/// \param[in] mode NNRT performance mode. 355be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode); 356be168c0dSopenharmony_ci+ 357be168c0dSopenharmony_ci+/// \brief Obtain the NNRT performance mode, Only valid for NNRT. 358be168c0dSopenharmony_ci+/// 359be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 360be168c0dSopenharmony_ci+/// 361be168c0dSopenharmony_ci+/// \return NNRT performance mode. 362be168c0dSopenharmony_ci+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info); 363be168c0dSopenharmony_ci+ 364be168c0dSopenharmony_ci+/// \brief Set the NNRT priority, Only valid for NNRT. 365be168c0dSopenharmony_ci+/// 366be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 367be168c0dSopenharmony_ci+/// \param[in] priority NNRT priority. 368be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority); 369be168c0dSopenharmony_ci+ 370be168c0dSopenharmony_ci+/// \brief Obtain the NNRT priority, Only valid for NNRT. 371be168c0dSopenharmony_ci+/// 372be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 373be168c0dSopenharmony_ci+/// 374be168c0dSopenharmony_ci+/// \return NNRT priority. 
375be168c0dSopenharmony_ci+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info); 376be168c0dSopenharmony_ci+ 377be168c0dSopenharmony_ci+/// \brief Add extension of key/value format to device info, Only valid for NNRT. 378be168c0dSopenharmony_ci+/// 379be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 380be168c0dSopenharmony_ci+/// \param[in] name The content of key as a C string. 381be168c0dSopenharmony_ci+/// \param[in] value The pointer to the value, which is a byte array. 382be168c0dSopenharmony_ci+/// \param[in] value_size The size of the value, which is a byte array. 383be168c0dSopenharmony_ci+/// 384be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed. 385be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size); 386be168c0dSopenharmony_ci #ifdef __cplusplus 387be168c0dSopenharmony_ci } 388be168c0dSopenharmony_ci #endif 389be168c0dSopenharmony_cidiff --git a/include/c_api/model_c.h b/include/c_api/model_c.h 390be168c0dSopenharmony_ciindex 12a46bcd..2286e673 100644 391be168c0dSopenharmony_ci--- a/include/c_api/model_c.h 392be168c0dSopenharmony_ci+++ b/include/c_api/model_c.h 393be168c0dSopenharmony_ci@@ -26,6 +26,8 @@ extern "C" { 394be168c0dSopenharmony_ci 395be168c0dSopenharmony_ci typedef void *OH_AI_ModelHandle; 396be168c0dSopenharmony_ci 397be168c0dSopenharmony_ci+typedef void *OH_AI_TrainCfgHandle; 398be168c0dSopenharmony_ci+ 399be168c0dSopenharmony_ci typedef struct OH_AI_TensorHandleArray { 400be168c0dSopenharmony_ci size_t handle_num; 401be168c0dSopenharmony_ci OH_AI_TensorHandle *handle_list; 402be168c0dSopenharmony_ci@@ -168,6 +170,182 @@ OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHa 403be168c0dSopenharmony_ci /// \return The output tensor handle with the given name, if the name is not found, 
an NULL is returned. 404be168c0dSopenharmony_ci OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name); 405be168c0dSopenharmony_ci 406be168c0dSopenharmony_ci+/// \brief Create a TrainCfg object. Only valid for Lite Train. 407be168c0dSopenharmony_ci+/// 408be168c0dSopenharmony_ci+/// \return TrainCfg object handle. 409be168c0dSopenharmony_ci+OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate(); 410be168c0dSopenharmony_ci+ 411be168c0dSopenharmony_ci+/// \brief Destroy the train_cfg object. Only valid for Lite Train. 412be168c0dSopenharmony_ci+/// 413be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle. 414be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg); 415be168c0dSopenharmony_ci+ 416be168c0dSopenharmony_ci+/// \brief Obtains part of the name that identify a loss kernel. Only valid for Lite Train. 417be168c0dSopenharmony_ci+/// 418be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle. 419be168c0dSopenharmony_ci+/// \param[in] num The num of loss_name. 420be168c0dSopenharmony_ci+/// 421be168c0dSopenharmony_ci+/// \return loss_name. 422be168c0dSopenharmony_ci+OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num); 423be168c0dSopenharmony_ci+ 424be168c0dSopenharmony_ci+/// \brief Set part of the name that identify a loss kernel. Only valid for Lite Train. 425be168c0dSopenharmony_ci+/// 426be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle. 427be168c0dSopenharmony_ci+/// \param[in] loss_name define part of the name that identify a loss kernel. 428be168c0dSopenharmony_ci+/// \param[in] num The num of loss_name. 429be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num); 430be168c0dSopenharmony_ci+ 431be168c0dSopenharmony_ci+/// \brief Obtains optimization level of the train_cfg. 
Only valid for Lite Train. 432be168c0dSopenharmony_ci+/// 433be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle. 434be168c0dSopenharmony_ci+/// 435be168c0dSopenharmony_ci+/// \return OH_AI_OptimizationLevel. 436be168c0dSopenharmony_ci+OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg); 437be168c0dSopenharmony_ci+ 438be168c0dSopenharmony_ci+/// \brief Set optimization level of the train_cfg. Only valid for Lite Train. 439be168c0dSopenharmony_ci+/// 440be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle. 441be168c0dSopenharmony_ci+/// \param[in] level The optimization level of train_cfg. 442be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level); 443be168c0dSopenharmony_ci+ 444be168c0dSopenharmony_ci+/// \brief Build the train model from model buffer so that it can run on a device. Only valid for Lite Train. 445be168c0dSopenharmony_ci+/// 446be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 447be168c0dSopenharmony_ci+/// \param[in] model_data Define the buffer read from a model file. 448be168c0dSopenharmony_ci+/// \param[in] data_size Define bytes number of model file buffer. 449be168c0dSopenharmony_ci+/// \param[in] model_type Define The type of model file. 450be168c0dSopenharmony_ci+/// \param[in] model_context Define the context used to store options during execution. 451be168c0dSopenharmony_ci+/// \param[in] train_cfg Define the config used by training. 452be168c0dSopenharmony_ci+/// 453be168c0dSopenharmony_ci+/// \return OH_AI_Status. 
454be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 455be168c0dSopenharmony_ci+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 456be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle train_cfg); 457be168c0dSopenharmony_ci+ 458be168c0dSopenharmony_ci+/// \brief Build the train model from model file buffer so that it can run on a device. Only valid for Lite Train. 459be168c0dSopenharmony_ci+/// 460be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 461be168c0dSopenharmony_ci+/// \param[in] model_path Define the model path. 462be168c0dSopenharmony_ci+/// \param[in] model_type Define The type of model file. 463be168c0dSopenharmony_ci+/// \param[in] model_context Define the context used to store options during execution. 464be168c0dSopenharmony_ci+/// \param[in] train_cfg Define the config used by training. 465be168c0dSopenharmony_ci+/// 466be168c0dSopenharmony_ci+/// \return OH_AI_Status. 467be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, 468be168c0dSopenharmony_ci+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 469be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle train_cfg); 470be168c0dSopenharmony_ci+ 471be168c0dSopenharmony_ci+/// \brief Train model by step. Only valid for Lite Train. 472be168c0dSopenharmony_ci+/// 473be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 474be168c0dSopenharmony_ci+/// \param[in] before CallBack before predict. 475be168c0dSopenharmony_ci+/// \param[in] after CallBack after predict. 476be168c0dSopenharmony_ci+/// 477be168c0dSopenharmony_ci+/// \return OH_AI_Status. 
478be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, 479be168c0dSopenharmony_ci+ const OH_AI_KernelCallBack after); 480be168c0dSopenharmony_ci+ 481be168c0dSopenharmony_ci+/// \brief Sets the Learning Rate of the training. Only valid for Lite Train. 482be168c0dSopenharmony_ci+/// 483be168c0dSopenharmony_ci+/// \param[in] learning_rate to set. 484be168c0dSopenharmony_ci+/// 485be168c0dSopenharmony_ci+/// \return OH_AI_Status of operation. 486be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate); 487be168c0dSopenharmony_ci+ 488be168c0dSopenharmony_ci+/// \brief Obtains the Learning Rate of the optimizer. Only valid for Lite Train. 489be168c0dSopenharmony_ci+/// 490be168c0dSopenharmony_ci+/// \return Learning rate. 0.0 if no optimizer was found. 491be168c0dSopenharmony_ci+OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model); 492be168c0dSopenharmony_ci+ 493be168c0dSopenharmony_ci+/// \brief Obtains all weights tensors of the model. Only valid for Lite Train. 494be168c0dSopenharmony_ci+/// 495be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 496be168c0dSopenharmony_ci+/// 497be168c0dSopenharmony_ci+/// \return The vector that includes all gradient tensors. 498be168c0dSopenharmony_ci+OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model); 499be168c0dSopenharmony_ci+ 500be168c0dSopenharmony_ci+/// \brief update weights tensors of the model. Only valid for Lite Train. 501be168c0dSopenharmony_ci+/// 502be168c0dSopenharmony_ci+/// \param[in] new_weights A vector new weights. 
503be168c0dSopenharmony_ci+/// 504be168c0dSopenharmony_ci+/// \return OH_AI_Status 505be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights); 506be168c0dSopenharmony_ci+ 507be168c0dSopenharmony_ci+/// \brief Get the model running mode. 508be168c0dSopenharmony_ci+/// 509be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 510be168c0dSopenharmony_ci+/// 511be168c0dSopenharmony_ci+/// \return Is Train Mode or not. 512be168c0dSopenharmony_ci+OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model); 513be168c0dSopenharmony_ci+ 514be168c0dSopenharmony_ci+/// \brief Set the model running mode. Only valid for Lite Train. 515be168c0dSopenharmony_ci+/// 516be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 517be168c0dSopenharmony_ci+/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode. 518be168c0dSopenharmony_ci+/// 519be168c0dSopenharmony_ci+/// \return OH_AI_Status. 520be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train); 521be168c0dSopenharmony_ci+ 522be168c0dSopenharmony_ci+/// \brief Setup training with virtual batches. Only valid for Lite Train. 523be168c0dSopenharmony_ci+/// 524be168c0dSopenharmony_ci+/// \param[in] model Model object handle. 525be168c0dSopenharmony_ci+/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable. 526be168c0dSopenharmony_ci+/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration. 527be168c0dSopenharmony_ci+/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration. 528be168c0dSopenharmony_ci+/// 529be168c0dSopenharmony_ci+/// \return OH_AI_Status. 
530be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, 531be168c0dSopenharmony_ci+ float momentum); 532be168c0dSopenharmony_ci+ 533be168c0dSopenharmony_ci+/// \brief Export training model from file. Only valid for Lite Train. 534be168c0dSopenharmony_ci+/// 535be168c0dSopenharmony_ci+/// \param[in] model The model data. 536be168c0dSopenharmony_ci+/// \param[in] model_type The model file type. 537be168c0dSopenharmony_ci+/// \param[in] model_file The exported model file. 538be168c0dSopenharmony_ci+/// \param[in] quantization_type The quantification type. 539be168c0dSopenharmony_ci+/// \param[in] export_inference_only Whether to export a reasoning only model. 540be168c0dSopenharmony_ci+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as 541be168c0dSopenharmony_ci+/// empty, and export the complete reasoning model. 542be168c0dSopenharmony_ci+/// \param[in] num The number of output_tensor_name. 543be168c0dSopenharmony_ci+/// 544be168c0dSopenharmony_ci+/// \return OH_AI_Status. 545be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, 546be168c0dSopenharmony_ci+ OH_AI_QuantizationType quantization_type, bool export_inference_only, 547be168c0dSopenharmony_ci+ char **output_tensor_name, size_t num); 548be168c0dSopenharmony_ci+ 549be168c0dSopenharmony_ci+/// \brief Export training model from buffer. Only valid for Lite Train. 550be168c0dSopenharmony_ci+/// 551be168c0dSopenharmony_ci+/// \param[in] model The model data. 552be168c0dSopenharmony_ci+/// \param[in] model_type The model file type. 553be168c0dSopenharmony_ci+/// \param[in] model_data The exported model buffer. 554be168c0dSopenharmony_ci+/// \param[in] data_size The exported model buffer size. 555be168c0dSopenharmony_ci+/// \param[in] quantization_type The quantification type. 
556be168c0dSopenharmony_ci+/// \param[in] export_inference_only Whether to export a reasoning only model. 557be168c0dSopenharmony_ci+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as 558be168c0dSopenharmony_ci+/// empty, and export the complete reasoning model. 559be168c0dSopenharmony_ci+/// \param[in] num The number of output_tensor_name. 560be168c0dSopenharmony_ci+/// 561be168c0dSopenharmony_ci+/// \return OH_AI_Status. 562be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, 563be168c0dSopenharmony_ci+ size_t *data_size, OH_AI_QuantizationType quantization_type, 564be168c0dSopenharmony_ci+ bool export_inference_only, char **output_tensor_name, size_t num); 565be168c0dSopenharmony_ci+ 566be168c0dSopenharmony_ci+/// \brief Export model's weights, which can be used in micro only. Only valid for Lite Train. 567be168c0dSopenharmony_ci+/// 568be168c0dSopenharmony_ci+/// \param[in] model The model data. 569be168c0dSopenharmony_ci+/// \param[in] model_type The model file type. 570be168c0dSopenharmony_ci+/// \param[in] weight_file The path of exported weight file. 571be168c0dSopenharmony_ci+/// \param[in] is_inference Whether to export weights from a reasoning model. Currently, only support this is `true`. 572be168c0dSopenharmony_ci+/// \param[in] enable_fp16 Float-weight is whether to be saved in float16 format. 573be168c0dSopenharmony_ci+/// \param[in] changeable_weights_name The set the name of these weight tensors, whose shape is changeable. 574be168c0dSopenharmony_ci+/// \param[in] num The number of changeable_weights_name. 575be168c0dSopenharmony_ci+/// 576be168c0dSopenharmony_ci+/// \return OH_AI_Status. 
577be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, 578be168c0dSopenharmony_ci+ const char *weight_file, bool is_inference, 579be168c0dSopenharmony_ci+ bool enable_fp16, char **changeable_weights_name, 580be168c0dSopenharmony_ci+ size_t num); 581be168c0dSopenharmony_ci+ 582be168c0dSopenharmony_ci #ifdef __cplusplus 583be168c0dSopenharmony_ci } 584be168c0dSopenharmony_ci #endif 585be168c0dSopenharmony_cidiff --git a/include/c_api/tensor_c.h b/include/c_api/tensor_c.h 586be168c0dSopenharmony_ciindex f18ba163..6d2aaab6 100644 587be168c0dSopenharmony_ci--- a/include/c_api/tensor_c.h 588be168c0dSopenharmony_ci+++ b/include/c_api/tensor_c.h 589be168c0dSopenharmony_ci@@ -17,6 +17,7 @@ 590be168c0dSopenharmony_ci #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H 591be168c0dSopenharmony_ci 592be168c0dSopenharmony_ci #include <stddef.h> 593be168c0dSopenharmony_ci+#include "include/c_api/status_c.h" 594be168c0dSopenharmony_ci #include "include/c_api/types_c.h" 595be168c0dSopenharmony_ci #include "include/c_api/data_type_c.h" 596be168c0dSopenharmony_ci #include "include/c_api/format_c.h" 597be168c0dSopenharmony_ci@@ -112,6 +113,19 @@ OH_AI_API OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor); 598be168c0dSopenharmony_ci /// \param[in] data A pointer to the data of the tensor. 599be168c0dSopenharmony_ci OH_AI_API void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data); 600be168c0dSopenharmony_ci 601be168c0dSopenharmony_ci+/// \brief Set the data for the tensor with user-allocated data buffer. 602be168c0dSopenharmony_ci+/// The main purpose of this interface is providing a way of using memory already allocated by user as the Model's 603be168c0dSopenharmony_ci+/// input, but not which allocated inside the Model object. It can reduce one copy. 604be168c0dSopenharmony_ci+/// Note: The tensor won't free the data provided by invoker. Invoker has the responsibility to free it. 
And this 605be168c0dSopenharmony_ci+/// free action should not be performed before destruction of the tensor. 606be168c0dSopenharmony_ci+/// 607be168c0dSopenharmony_ci+/// \param[in] tensor Tensor object handle. 608be168c0dSopenharmony_ci+/// \param[in] data A pointer to the user data buffer. 609be168c0dSopenharmony_ci+/// \param[in] data_size The byte size of the user data buffer. 610be168c0dSopenharmony_ci+/// 611be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed. 612be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size); 613be168c0dSopenharmony_ci+ 614be168c0dSopenharmony_ci /// \brief Obtain the data pointer of the tensor. 615be168c0dSopenharmony_ci /// 616be168c0dSopenharmony_ci /// \param[in] tensor Tensor object handle. 617be168c0dSopenharmony_cidiff --git a/include/c_api/types_c.h b/include/c_api/types_c.h 618be168c0dSopenharmony_ciindex dba54ffa..e520e336 100644 619be168c0dSopenharmony_ci--- a/include/c_api/types_c.h 620be168c0dSopenharmony_ci+++ b/include/c_api/types_c.h 621be168c0dSopenharmony_ci@@ -40,10 +40,65 @@ typedef enum OH_AI_DeviceType { 622be168c0dSopenharmony_ci OH_AI_DEVICETYPE_KIRIN_NPU, 623be168c0dSopenharmony_ci // add new type here 624be168c0dSopenharmony_ci // ohos-only device range: [60, 80) 625be168c0dSopenharmony_ci- OH_AI_DEVICETYPE__NNRT = 60, 626be168c0dSopenharmony_ci+ OH_AI_DEVICETYPE_NNRT = 60, 627be168c0dSopenharmony_ci OH_AI_DEVICETYPE_INVALID = 100, 628be168c0dSopenharmony_ci } OH_AI_DeviceType; 629be168c0dSopenharmony_ci 630be168c0dSopenharmony_ci+typedef enum OH_AI_NNRTDeviceType { 631be168c0dSopenharmony_ci+ /** Devices that are not CPU, GPU, or dedicated accelerator */ 632be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_OTHERS = 0, 633be168c0dSopenharmony_ci+ /** CPU device */ 634be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_CPU = 1, 635be168c0dSopenharmony_ci+ /** GPU device */ 636be168c0dSopenharmony_ci+ 
OH_AI_NNRTDEVICE_GPU = 2, 637be168c0dSopenharmony_ci+ /** Dedicated hardware accelerator */ 638be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_ACCELERATOR = 3, 639be168c0dSopenharmony_ci+} OH_AI_NNRTDeviceType; 640be168c0dSopenharmony_ci+ 641be168c0dSopenharmony_ci+typedef enum OH_AI_PerformanceMode { 642be168c0dSopenharmony_ci+ /** No performance mode preference */ 643be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_NONE = 0, 644be168c0dSopenharmony_ci+ /** Low power consumption mode*/ 645be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_LOW = 1, 646be168c0dSopenharmony_ci+ /** Medium performance mode */ 647be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_MEDIUM = 2, 648be168c0dSopenharmony_ci+ /** High performance mode */ 649be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_HIGH = 3, 650be168c0dSopenharmony_ci+ /** Ultimate performance mode */ 651be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_EXTREME = 4 652be168c0dSopenharmony_ci+} OH_AI_PerformanceMode; 653be168c0dSopenharmony_ci+ 654be168c0dSopenharmony_ci+typedef enum OH_AI_Priority { 655be168c0dSopenharmony_ci+ /** No priority preference */ 656be168c0dSopenharmony_ci+ OH_AI_PRIORITY_NONE = 0, 657be168c0dSopenharmony_ci+ /** Low priority */ 658be168c0dSopenharmony_ci+ OH_AI_PRIORITY_LOW = 1, 659be168c0dSopenharmony_ci+ /** Medium priority */ 660be168c0dSopenharmony_ci+ OH_AI_PRIORITY_MEDIUM = 2, 661be168c0dSopenharmony_ci+ /** High priority */ 662be168c0dSopenharmony_ci+ OH_AI_PRIORITY_HIGH = 3 663be168c0dSopenharmony_ci+} OH_AI_Priority; 664be168c0dSopenharmony_ci+ 665be168c0dSopenharmony_ci+typedef enum OH_AI_OptimizationLevel { 666be168c0dSopenharmony_ci+ /** Do not change */ 667be168c0dSopenharmony_ci+ OH_AI_KO0 = 0, 668be168c0dSopenharmony_ci+ /** Cast network to float16, keep batchnorm and loss in float32 */ 669be168c0dSopenharmony_ci+ OH_AI_KO2 = 2, 670be168c0dSopenharmony_ci+ /** Cast network to float16, including batchnorm */ 671be168c0dSopenharmony_ci+ OH_AI_KO3 = 3, 672be168c0dSopenharmony_ci+ /** Choose optimization based on device 
*/ 673be168c0dSopenharmony_ci+ OH_AI_KAUTO = 4, 674be168c0dSopenharmony_ci+ OH_AI_KOPTIMIZATIONTYPE = 0xFFFFFFFF 675be168c0dSopenharmony_ci+} OH_AI_OptimizationLevel; 676be168c0dSopenharmony_ci+ 677be168c0dSopenharmony_ci+typedef enum OH_AI_QuantizationType { 678be168c0dSopenharmony_ci+ OH_AI_NO_QUANT = 0, 679be168c0dSopenharmony_ci+ OH_AI_WEIGHT_QUANT = 1, 680be168c0dSopenharmony_ci+ OH_AI_FULL_QUANT = 2, 681be168c0dSopenharmony_ci+ OH_AI_UNKNOWN_QUANT_TYPE = 0xFFFFFFFF 682be168c0dSopenharmony_ci+} OH_AI_QuantizationType; 683be168c0dSopenharmony_ci+ 684be168c0dSopenharmony_ci+typedef struct NNRTDeviceDesc NNRTDeviceDesc; 685be168c0dSopenharmony_ci #ifdef __cplusplus 686be168c0dSopenharmony_ci } 687be168c0dSopenharmony_ci #endif 688be168c0dSopenharmony_cidiff --git a/include/sdk_api/context.h b/include/sdk_api/context.h 689be168c0dSopenharmony_ciindex 5bfc9279..e12b8d6f 100644 690be168c0dSopenharmony_ci--- a/include/sdk_api/context.h 691be168c0dSopenharmony_ci+++ b/include/sdk_api/context.h 692be168c0dSopenharmony_ci@@ -174,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, 693be168c0dSopenharmony_ci /// \return NPU frequency 694be168c0dSopenharmony_ci OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info); 695be168c0dSopenharmony_ci 696be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT. 697be168c0dSopenharmony_ci+/// 698be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description. 699be168c0dSopenharmony_ci+/// 700be168c0dSopenharmony_ci+/// \return NNRT device description array. 701be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num); 702be168c0dSopenharmony_ci+ 703be168c0dSopenharmony_ci+/// \brief Obtain the specified element in NNRt device description array. 704be168c0dSopenharmony_ci+/// 705be168c0dSopenharmony_ci+/// \param[in] descs NNRT device description array. 
706be168c0dSopenharmony_ci+/// \param[in] index Element index. 707be168c0dSopenharmony_ci+/// 708be168c0dSopenharmony_ci+/// \return NNRT device description. 709be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index); 710be168c0dSopenharmony_ci+ 711be168c0dSopenharmony_ci+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs(). 712be168c0dSopenharmony_ci+/// 713be168c0dSopenharmony_ci+/// \param[in] desc NNRT device description array. 714be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc); 715be168c0dSopenharmony_ci+ 716be168c0dSopenharmony_ci+/// \brief Obtain the device id in NNRT device description. 717be168c0dSopenharmony_ci+/// 718be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 719be168c0dSopenharmony_ci+/// 720be168c0dSopenharmony_ci+/// \return NNRT device id. 721be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 722be168c0dSopenharmony_ci+ 723be168c0dSopenharmony_ci+/// \brief Obtain the device name in NNRT device description. 724be168c0dSopenharmony_ci+/// 725be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 726be168c0dSopenharmony_ci+/// 727be168c0dSopenharmony_ci+/// \return NNRT device name. 728be168c0dSopenharmony_ci+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 729be168c0dSopenharmony_ci+ 730be168c0dSopenharmony_ci+/// \brief Obtain the device type in NNRT device description. 731be168c0dSopenharmony_ci+/// 732be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance. 733be168c0dSopenharmony_ci+/// 734be168c0dSopenharmony_ci+/// \return NNRT device type. 
735be168c0dSopenharmony_ci+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc); 736be168c0dSopenharmony_ci+ 737be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by exactly matching the specific device name. 738be168c0dSopenharmony_ci+/// 739be168c0dSopenharmony_ci+/// \param[in] name NNRT device name. 740be168c0dSopenharmony_ci+/// 741be168c0dSopenharmony_ci+/// \return Device info object handle. 742be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name); 743be168c0dSopenharmony_ci+ 744be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by finding the first device with the specific device type. 745be168c0dSopenharmony_ci+/// 746be168c0dSopenharmony_ci+/// \param[in] type NNRT device type. 747be168c0dSopenharmony_ci+/// 748be168c0dSopenharmony_ci+/// \return Device info object handle. 749be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type); 750be168c0dSopenharmony_ci+ 751be168c0dSopenharmony_ci+/// \brief Set the NNRT device id, Only valid for NNRT. 752be168c0dSopenharmony_ci+/// 753be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 754be168c0dSopenharmony_ci+/// \param[in] device_id NNRT device id. 755be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id); 756be168c0dSopenharmony_ci+ 757be168c0dSopenharmony_ci+/// \brief Obtain the NNRT device id, Only valid for NNRT. 758be168c0dSopenharmony_ci+/// 759be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 760be168c0dSopenharmony_ci+/// 761be168c0dSopenharmony_ci+/// \return NNRT device id. 762be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info); 763be168c0dSopenharmony_ci+ 764be168c0dSopenharmony_ci+/// \brief Set the NNRT performance mode, Only valid for NNRT. 
765be168c0dSopenharmony_ci+/// 766be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 767be168c0dSopenharmony_ci+/// \param[in] mode NNRT performance mode. 768be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode); 769be168c0dSopenharmony_ci+ 770be168c0dSopenharmony_ci+/// \brief Obtain the NNRT performance mode, Only valid for NNRT. 771be168c0dSopenharmony_ci+/// 772be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 773be168c0dSopenharmony_ci+/// 774be168c0dSopenharmony_ci+/// \return NNRT performance mode. 775be168c0dSopenharmony_ci+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info); 776be168c0dSopenharmony_ci+ 777be168c0dSopenharmony_ci+/// \brief Set the NNRT priority, Only valid for NNRT. 778be168c0dSopenharmony_ci+/// 779be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 780be168c0dSopenharmony_ci+/// \param[in] priority NNRT priority. 781be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority); 782be168c0dSopenharmony_ci+ 783be168c0dSopenharmony_ci+/// \brief Obtain the NNRT priority, Only valid for NNRT. 784be168c0dSopenharmony_ci+/// 785be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 786be168c0dSopenharmony_ci+/// 787be168c0dSopenharmony_ci+/// \return NNRT priority. 788be168c0dSopenharmony_ci+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info); 789be168c0dSopenharmony_ci+ 790be168c0dSopenharmony_ci+/// \brief Add extension of key/value format to device info, Only valid for NNRT. 791be168c0dSopenharmony_ci+/// 792be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle. 793be168c0dSopenharmony_ci+/// \param[in] name The content of key as a C string. 
794be168c0dSopenharmony_ci+/// \param[in] value The pointer to the value, which is a byte array. 795be168c0dSopenharmony_ci+/// \param[in] value_size The size of the value, which is a byte array. 796be168c0dSopenharmony_ci+/// 797be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed. 798be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size); 799be168c0dSopenharmony_ci #ifdef __cplusplus 800be168c0dSopenharmony_ci } 801be168c0dSopenharmony_ci #endif 802be168c0dSopenharmony_cidiff --git a/include/sdk_api/tensor.h b/include/sdk_api/tensor.h 803be168c0dSopenharmony_ciindex f6ba02cd..3dad04ac 100644 804be168c0dSopenharmony_ci--- a/include/sdk_api/tensor.h 805be168c0dSopenharmony_ci+++ b/include/sdk_api/tensor.h 806be168c0dSopenharmony_ci@@ -17,6 +17,7 @@ 807be168c0dSopenharmony_ci #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H 808be168c0dSopenharmony_ci 809be168c0dSopenharmony_ci #include <stddef.h> 810be168c0dSopenharmony_ci+#include "mindspore/status.h" 811be168c0dSopenharmony_ci #include "mindspore/types.h" 812be168c0dSopenharmony_ci #include "mindspore/data_type.h" 813be168c0dSopenharmony_ci #include "mindspore/format.h" 814be168c0dSopenharmony_ci@@ -140,6 +141,18 @@ OH_AI_API int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor); 815be168c0dSopenharmony_ci /// \return The data size of the tensor. 816be168c0dSopenharmony_ci OH_AI_API size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor); 817be168c0dSopenharmony_ci 818be168c0dSopenharmony_ci+/// \brief Set the data for the tensor with user-allocated data buffer. 819be168c0dSopenharmony_ci+/// The main purpose of this interface is providing a way of using memory already allocated by user as the Model's 820be168c0dSopenharmony_ci+/// input, but not which allocated inside the Model object. It can reduce one copy. 
821be168c0dSopenharmony_ci+/// Note: The tensor won't free the data provided by invoker. Invoker has the responsibility to free it. And this 822be168c0dSopenharmony_ci+/// free action should not be performed before destruction of the tensor. 823be168c0dSopenharmony_ci+/// 824be168c0dSopenharmony_ci+/// \param[in] tensor Tensor object handle. 825be168c0dSopenharmony_ci+/// \param[in] data A pointer to the user data buffer. 826be168c0dSopenharmony_ci+/// \param[in] data_size The byte size of the user data buffer. 827be168c0dSopenharmony_ci+/// 828be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed. 829be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size); 830be168c0dSopenharmony_ci #ifdef __cplusplus 831be168c0dSopenharmony_ci } 832be168c0dSopenharmony_ci #endif 833be168c0dSopenharmony_cidiff --git a/include/sdk_api/types.h b/include/sdk_api/types.h 834be168c0dSopenharmony_ciindex a39c6daa..d38660b0 100644 835be168c0dSopenharmony_ci--- a/include/sdk_api/types.h 836be168c0dSopenharmony_ci+++ b/include/sdk_api/types.h 837be168c0dSopenharmony_ci@@ -40,10 +40,46 @@ typedef enum OH_AI_DeviceType { 838be168c0dSopenharmony_ci OH_AI_DEVICETYPE_KIRIN_NPU, 839be168c0dSopenharmony_ci // add new type here 840be168c0dSopenharmony_ci // ohos-only device range: [60, 80) 841be168c0dSopenharmony_ci- OH_AI_DeviceType_NNRT = 60, 842be168c0dSopenharmony_ci+ OH_AI_DEVICETYPE_NNRT = 60, 843be168c0dSopenharmony_ci OH_AI_DEVICETYPE_INVALID = 100, 844be168c0dSopenharmony_ci } OH_AI_DeviceType; 845be168c0dSopenharmony_ci 846be168c0dSopenharmony_ci+typedef enum OH_AI_NNRTDeviceType { 847be168c0dSopenharmony_ci+ /** Devices that are not CPU, GPU, or dedicated accelerator */ 848be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_OTHERS = 0, 849be168c0dSopenharmony_ci+ /** CPU device */ 850be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_CPU = 1, 851be168c0dSopenharmony_ci+ /** GPU device */ 
852be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_GPU = 2, 853be168c0dSopenharmony_ci+ /** Dedicated hardware accelerator */ 854be168c0dSopenharmony_ci+ OH_AI_NNRTDEVICE_ACCELERATOR = 3, 855be168c0dSopenharmony_ci+} OH_AI_NNRTDeviceType; 856be168c0dSopenharmony_ci+ 857be168c0dSopenharmony_ci+typedef enum OH_AI_PerformanceMode { 858be168c0dSopenharmony_ci+ /** No performance mode preference */ 859be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_NONE = 0, 860be168c0dSopenharmony_ci+ /** Low power consumption mode*/ 861be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_LOW = 1, 862be168c0dSopenharmony_ci+ /** Medium performance mode */ 863be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_MEDIUM = 2, 864be168c0dSopenharmony_ci+ /** High performance mode */ 865be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_HIGH = 3, 866be168c0dSopenharmony_ci+ /** Ultimate performance mode */ 867be168c0dSopenharmony_ci+ OH_AI_PERFORMANCE_EXTREME = 4 868be168c0dSopenharmony_ci+} OH_AI_PerformanceMode; 869be168c0dSopenharmony_ci+ 870be168c0dSopenharmony_ci+typedef enum OH_AI_Priority { 871be168c0dSopenharmony_ci+ /** No priority preference */ 872be168c0dSopenharmony_ci+ OH_AI_PRIORITY_NONE = 0, 873be168c0dSopenharmony_ci+ /** Low priority */ 874be168c0dSopenharmony_ci+ OH_AI_PRIORITY_LOW = 1, 875be168c0dSopenharmony_ci+ /** Medium priority */ 876be168c0dSopenharmony_ci+ OH_AI_PRIORITY_MEDIUM = 2, 877be168c0dSopenharmony_ci+ /** High priority */ 878be168c0dSopenharmony_ci+ OH_AI_PRIORITY_HIGH = 3 879be168c0dSopenharmony_ci+} OH_AI_Priority; 880be168c0dSopenharmony_ci+ 881be168c0dSopenharmony_ci+typedef struct NNRTDeviceDesc NNRTDeviceDesc; 882be168c0dSopenharmony_ci #ifdef __cplusplus 883be168c0dSopenharmony_ci } 884be168c0dSopenharmony_ci #endif 885be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 886be168c0dSopenharmony_ciindex 7bbc3782..103e53b7 100644 887be168c0dSopenharmony_ci--- 
a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 888be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn 889be168c0dSopenharmony_ci@@ -498,6 +498,9 @@ infer_shape_sources = [ 890be168c0dSopenharmony_ci "infer/crop_infer.c", 891be168c0dSopenharmony_ci "infer/cumsum_infer.c", 892be168c0dSopenharmony_ci "infer/custom_gru_infer.c", 893be168c0dSopenharmony_ci+ "infer/custom_masked_fill_infer.c", 894be168c0dSopenharmony_ci+ "infer/custom_is_inf_infer.c", 895be168c0dSopenharmony_ci+ "infer/custom_tensor_scatter_max_infer.c", 896be168c0dSopenharmony_ci "infer/decoder_layer_infer.c", 897be168c0dSopenharmony_ci "infer/deconv2d_infer.c", 898be168c0dSopenharmony_ci "infer/depth_to_space_infer.c", 899be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 900be168c0dSopenharmony_ciindex c1685a65..6fef44fd 100644 901be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 902be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt 903be168c0dSopenharmony_ci@@ -238,7 +238,7 @@ endif() 904be168c0dSopenharmony_ci if(PLATFORM_ARM) 905be168c0dSopenharmony_ci set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c) 906be168c0dSopenharmony_ci set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C 907be168c0dSopenharmony_ci- COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math") 908be168c0dSopenharmony_ci+ COMPILE_FLAGS "${CMAKE_C_FLAGS} -w -fno-fast-math") 909be168c0dSopenharmony_ci endif() 910be168c0dSopenharmony_ci 911be168c0dSopenharmony_ci add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC}) 912be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 913be168c0dSopenharmony_cinew file mode 100644 
914be168c0dSopenharmony_ciindex 00000000..14bd1d76 915be168c0dSopenharmony_ci--- /dev/null 916be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h 917be168c0dSopenharmony_ci@@ -0,0 +1,66 @@ 918be168c0dSopenharmony_ci+/** 919be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 920be168c0dSopenharmony_ci+* 921be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 922be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License. 923be168c0dSopenharmony_ci+* You may obtain a copy of the License at 924be168c0dSopenharmony_ci+* 925be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 926be168c0dSopenharmony_ci+* 927be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 928be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 929be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 930be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 931be168c0dSopenharmony_ci+* limitations under the License. 
932be168c0dSopenharmony_ci+*/ 933be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 934be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 935be168c0dSopenharmony_ci+ 936be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h" 937be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_avx_instructions.h" 938be168c0dSopenharmony_ci+ 939be168c0dSopenharmony_ci+#ifdef __cplusplus 940be168c0dSopenharmony_ci+extern "C" { 941be168c0dSopenharmony_ci+#endif 942be168c0dSopenharmony_ci+#pragma GCC push_options 943be168c0dSopenharmony_ci+#pragma GCC target("avx", "avx2") 944be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX_INSTRUCTION 945be168c0dSopenharmony_ci+#define BLOCK_NUM 8 946be168c0dSopenharmony_ci+#define MS_SIMD_AVX 947be168c0dSopenharmony_ci+ 948be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32AVX(int index, const float *update, int size, float *output) { 949be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 950be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 951be168c0dSopenharmony_ci+ } 952be168c0dSopenharmony_ci+ return index; 953be168c0dSopenharmony_ci+} 954be168c0dSopenharmony_ci+ 955be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32AVX(int index, const int *update, int size, int *output) { 956be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 957be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 958be168c0dSopenharmony_ci+ } 959be168c0dSopenharmony_ci+ return index; 960be168c0dSopenharmony_ci+} 961be168c0dSopenharmony_ci+ 962be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32AVX(int index, const float *update, int size, float *output) { 963be168c0dSopenharmony_ci+ for (int 
block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 964be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 965be168c0dSopenharmony_ci+ } 966be168c0dSopenharmony_ci+ return index; 967be168c0dSopenharmony_ci+} 968be168c0dSopenharmony_ci+ 969be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32AVX(int index, const int *update, int size, int *output) { 970be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 971be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 972be168c0dSopenharmony_ci+ } 973be168c0dSopenharmony_ci+ return index; 974be168c0dSopenharmony_ci+} 975be168c0dSopenharmony_ci+ 976be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION 977be168c0dSopenharmony_ci+#undef BLOCK_NUM 978be168c0dSopenharmony_ci+#pragma GCC pop_options 979be168c0dSopenharmony_ci+#undef MS_SIMD_AVX 980be168c0dSopenharmony_ci+#ifdef __cplusplus 981be168c0dSopenharmony_ci+} 982be168c0dSopenharmony_ci+#endif 983be168c0dSopenharmony_ci+#endif // NNACL_BASE_SCATTER_ND_BINARY_AVX_H_ 984be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 985be168c0dSopenharmony_cinew file mode 100644 986be168c0dSopenharmony_ciindex 00000000..abf024c5 987be168c0dSopenharmony_ci--- /dev/null 988be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h 989be168c0dSopenharmony_ci@@ -0,0 +1,66 @@ 990be168c0dSopenharmony_ci+/** 991be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 992be168c0dSopenharmony_ci+* 993be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 994be168c0dSopenharmony_ci+* you may not use 
this file except in compliance with the License. 995be168c0dSopenharmony_ci+* You may obtain a copy of the License at 996be168c0dSopenharmony_ci+* 997be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 998be168c0dSopenharmony_ci+* 999be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 1000be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 1001be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1002be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 1003be168c0dSopenharmony_ci+* limitations under the License. 1004be168c0dSopenharmony_ci+*/ 1005be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1006be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1007be168c0dSopenharmony_ci+ 1008be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h" 1009be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_avx512_instructions.h" 1010be168c0dSopenharmony_ci+ 1011be168c0dSopenharmony_ci+#ifdef __cplusplus 1012be168c0dSopenharmony_ci+extern "C" { 1013be168c0dSopenharmony_ci+#endif 1014be168c0dSopenharmony_ci+#pragma GCC push_options 1015be168c0dSopenharmony_ci+#pragma GCC target("avx512f") 1016be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX512_INSTRUCTION 1017be168c0dSopenharmony_ci+#define BLOCK_NUM 16 1018be168c0dSopenharmony_ci+#define MS_SIMD_AVX512 1019be168c0dSopenharmony_ci+ 1020be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32AVX512(int index, const float *update, int size, float *output) { 1021be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1022be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1023be168c0dSopenharmony_ci+ } 1024be168c0dSopenharmony_ci+ 
return index; 1025be168c0dSopenharmony_ci+} 1026be168c0dSopenharmony_ci+ 1027be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32AVX512(int index, const int *update, int size, int *output) { 1028be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1029be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1030be168c0dSopenharmony_ci+ } 1031be168c0dSopenharmony_ci+ return index; 1032be168c0dSopenharmony_ci+} 1033be168c0dSopenharmony_ci+ 1034be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32AVX512(int index, const float *update, int size, float *output) { 1035be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1036be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1037be168c0dSopenharmony_ci+ } 1038be168c0dSopenharmony_ci+ return index; 1039be168c0dSopenharmony_ci+} 1040be168c0dSopenharmony_ci+ 1041be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32AVX512(int index, const int *update, int size, int *output) { 1042be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1043be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1044be168c0dSopenharmony_ci+ } 1045be168c0dSopenharmony_ci+ return index; 1046be168c0dSopenharmony_ci+} 1047be168c0dSopenharmony_ci+ 1048be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION 1049be168c0dSopenharmony_ci+#undef BLOCK_NUM 1050be168c0dSopenharmony_ci+#pragma GCC pop_options 1051be168c0dSopenharmony_ci+#undef MS_SIMD_AVX512 1052be168c0dSopenharmony_ci+#ifdef __cplusplus 1053be168c0dSopenharmony_ci+} 1054be168c0dSopenharmony_ci+#endif 1055be168c0dSopenharmony_ci+#endif // 
NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_ 1056be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1057be168c0dSopenharmony_ciindex bca71f55..e496bb4b 100644 1058be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1059be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c 1060be168c0dSopenharmony_ci@@ -77,3 +77,31 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, 1061be168c0dSopenharmony_ci } 1062be168c0dSopenharmony_ci return NNACL_OK; 1063be168c0dSopenharmony_ci } 1064be168c0dSopenharmony_ci+ 1065be168c0dSopenharmony_ci+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1066be168c0dSopenharmony_ci+ int task_id) { 1067be168c0dSopenharmony_ci+ if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) { 1068be168c0dSopenharmony_ci+ return NNACL_NULL_PTR; 1069be168c0dSopenharmony_ci+ } 1070be168c0dSopenharmony_ci+ if (param->op_parameter.thread_num_ == 0) { 1071be168c0dSopenharmony_ci+ return NNACL_ERR; 1072be168c0dSopenharmony_ci+ } 1073be168c0dSopenharmony_ci+ int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_); 1074be168c0dSopenharmony_ci+ int begin = unit_per_thread * task_id; 1075be168c0dSopenharmony_ci+ int end = MSMIN(begin + unit_per_thread, param->num_unit); 1076be168c0dSopenharmony_ci+ if (type == 0) { 1077be168c0dSopenharmony_ci+ float *update_fp32 = (float *)update; 1078be168c0dSopenharmony_ci+ float *output_fp32 = (float *)output; 1079be168c0dSopenharmony_ci+ for (int i = begin; i < end; i++) { 1080be168c0dSopenharmony_ci+ const float *update_data = update_fp32 + i * param->unit_size; 1081be168c0dSopenharmony_ci+ float *output_data = output_fp32 + output_unit_offsets[i]; 
1082be168c0dSopenharmony_ci+ int j = 0; 1083be168c0dSopenharmony_ci+ 1084be168c0dSopenharmony_ci+ SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data); 1085be168c0dSopenharmony_ci+ for (; j < param->unit_size; j++) { 1086be168c0dSopenharmony_ci+ output_data[j] = fmaxf(update_data[j], output_data[j]); 1087be168c0dSopenharmony_ci+ } 1088be168c0dSopenharmony_ci+ } 1089be168c0dSopenharmony_ci+ } 1090be168c0dSopenharmony_ci+ return NNACL_OK; 1091be168c0dSopenharmony_ci+} 1092be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1093be168c0dSopenharmony_ciindex 3af55335..36657cd9 100644 1094be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1095be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h 1096be168c0dSopenharmony_ci@@ -27,6 +27,9 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets, 1097be168c0dSopenharmony_ci 1098be168c0dSopenharmony_ci int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1099be168c0dSopenharmony_ci int task_id); 1100be168c0dSopenharmony_ci+ 1101be168c0dSopenharmony_ci+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type, 1102be168c0dSopenharmony_ci+ int task_id); 1103be168c0dSopenharmony_ci #ifdef __cplusplus 1104be168c0dSopenharmony_ci } 1105be168c0dSopenharmony_ci #endif 1106be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1107be168c0dSopenharmony_ciindex c72d9cc2..46bb20ce 100644 1108be168c0dSopenharmony_ci--- 
a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1109be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in 1110be168c0dSopenharmony_ci@@ -38,6 +38,20 @@ static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *upda 1111be168c0dSopenharmony_ci return index; 1112be168c0dSopenharmony_ci } 1113be168c0dSopenharmony_ci 1114be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) { 1115be168c0dSopenharmony_ci+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1116be168c0dSopenharmony_ci+SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1117be168c0dSopenharmony_ci+} 1118be168c0dSopenharmony_ci+return index; 1119be168c0dSopenharmony_ci+} 1120be168c0dSopenharmony_ci+ 1121be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) { 1122be168c0dSopenharmony_ci+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1123be168c0dSopenharmony_ci+SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1124be168c0dSopenharmony_ci+} 1125be168c0dSopenharmony_ci+return index; 1126be168c0dSopenharmony_ci+} 1127be168c0dSopenharmony_ci+ 1128be168c0dSopenharmony_ci @SIMD_INSTRUCTION_END@ 1129be168c0dSopenharmony_ci #ifdef __cplusplus 1130be168c0dSopenharmony_ci } 1131be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 1132be168c0dSopenharmony_cinew file mode 100644 1133be168c0dSopenharmony_ciindex 00000000..e1eae394 1134be168c0dSopenharmony_ci--- /dev/null 1135be168c0dSopenharmony_ci+++ 
b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h 1136be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 1137be168c0dSopenharmony_ci+/** 1138be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1139be168c0dSopenharmony_ci+ * 1140be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1141be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1142be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1143be168c0dSopenharmony_ci+ * 1144be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1145be168c0dSopenharmony_ci+ * 1146be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1147be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1148be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1149be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1150be168c0dSopenharmony_ci+ * limitations under the License. 
1151be168c0dSopenharmony_ci+ */ 1152be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1153be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1154be168c0dSopenharmony_ci+ 1155be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 1156be168c0dSopenharmony_ci+ 1157be168c0dSopenharmony_ci+typedef struct CustomIsInfParameter { 1158be168c0dSopenharmony_ci+ // Primitive parameter 1159be168c0dSopenharmony_ci+ OpParameter op_parameter_; 1160be168c0dSopenharmony_ci+} CustomIsInfParameter; 1161be168c0dSopenharmony_ci+ 1162be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_ 1163be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 1164be168c0dSopenharmony_cinew file mode 100644 1165be168c0dSopenharmony_ciindex 00000000..047d3d3f 1166be168c0dSopenharmony_ci--- /dev/null 1167be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h 1168be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 1169be168c0dSopenharmony_ci+/** 1170be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1171be168c0dSopenharmony_ci+ * 1172be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1173be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1174be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1175be168c0dSopenharmony_ci+ * 1176be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1177be168c0dSopenharmony_ci+ * 1178be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1179be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1180be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
1181be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1182be168c0dSopenharmony_ci+ * limitations under the License. 1183be168c0dSopenharmony_ci+ */ 1184be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1185be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1186be168c0dSopenharmony_ci+ 1187be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 1188be168c0dSopenharmony_ci+ 1189be168c0dSopenharmony_ci+typedef struct CustomMaskedFillParameter { 1190be168c0dSopenharmony_ci+ // Primitive parameter 1191be168c0dSopenharmony_ci+ OpParameter op_parameter_; 1192be168c0dSopenharmony_ci+} CustomMaskedFillParameter; 1193be168c0dSopenharmony_ci+ 1194be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_ 1195be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 1196be168c0dSopenharmony_cinew file mode 100644 1197be168c0dSopenharmony_ciindex 00000000..ba6940db 1198be168c0dSopenharmony_ci--- /dev/null 1199be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h 1200be168c0dSopenharmony_ci@@ -0,0 +1,26 @@ 1201be168c0dSopenharmony_ci+/** 1202be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1203be168c0dSopenharmony_ci+ * 1204be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1205be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
1206be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1207be168c0dSopenharmony_ci+ * 1208be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1209be168c0dSopenharmony_ci+ * 1210be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1211be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1212be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1213be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1214be168c0dSopenharmony_ci+ * limitations under the License. 1215be168c0dSopenharmony_ci+ */ 1216be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1217be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1218be168c0dSopenharmony_ci+ 1219be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 1220be168c0dSopenharmony_ci+ 1221be168c0dSopenharmony_ci+typedef struct CustomTensorScatterMaxParameter { 1222be168c0dSopenharmony_ci+ // Primitive parameter 1223be168c0dSopenharmony_ci+ OpParameter op_parameter_; 1224be168c0dSopenharmony_ci+} CustomTensorScatterMaxParameter; 1225be168c0dSopenharmony_ci+ 1226be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_ 1227be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 1228be168c0dSopenharmony_cinew file mode 100644 1229be168c0dSopenharmony_ciindex 00000000..fc87d157 1230be168c0dSopenharmony_ci--- /dev/null 1231be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c 1232be168c0dSopenharmony_ci@@ -0,0 +1,38 @@ 1233be168c0dSopenharmony_ci+/** 1234be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1235be168c0dSopenharmony_ci+ * 
1236be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1237be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1238be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1239be168c0dSopenharmony_ci+ * 1240be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1241be168c0dSopenharmony_ci+ * 1242be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1243be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1244be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1245be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1246be168c0dSopenharmony_ci+ * limitations under the License. 1247be168c0dSopenharmony_ci+ */ 1248be168c0dSopenharmony_ci+ 1249be168c0dSopenharmony_ci+#include "nnacl/infer/custom_is_inf_infer.h" 1250be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h" 1251be168c0dSopenharmony_ci+ 1252be168c0dSopenharmony_ci+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1253be168c0dSopenharmony_ci+ OpParameter *parameter) { 1254be168c0dSopenharmony_ci+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM); 1255be168c0dSopenharmony_ci+ if (check_ret != NNACL_OK) { 1256be168c0dSopenharmony_ci+ return check_ret; 1257be168c0dSopenharmony_ci+ } 1258be168c0dSopenharmony_ci+ 1259be168c0dSopenharmony_ci+ const TensorC *input = inputs[0]; 1260be168c0dSopenharmony_ci+ TensorC *output = outputs[0]; 1261be168c0dSopenharmony_ci+ output->data_type_ = kNumberTypeBool; 1262be168c0dSopenharmony_ci+ output->format_ = input->format_; 1263be168c0dSopenharmony_ci+ if (!InferFlag(inputs, inputs_size)) { 1264be168c0dSopenharmony_ci+ return NNACL_INFER_INVALID; 1265be168c0dSopenharmony_ci+ } 
1266be168c0dSopenharmony_ci+ SetShapeTensor(output, input); 1267be168c0dSopenharmony_ci+ return NNACL_OK; 1268be168c0dSopenharmony_ci+} 1269be168c0dSopenharmony_ci+ 1270be168c0dSopenharmony_ci+REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape) 1271be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 1272be168c0dSopenharmony_cinew file mode 100644 1273be168c0dSopenharmony_ciindex 00000000..d1b4b33d 1274be168c0dSopenharmony_ci--- /dev/null 1275be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h 1276be168c0dSopenharmony_ci@@ -0,0 +1,31 @@ 1277be168c0dSopenharmony_ci+/** 1278be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1279be168c0dSopenharmony_ci+ * 1280be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1281be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1282be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1283be168c0dSopenharmony_ci+ * 1284be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1285be168c0dSopenharmony_ci+ * 1286be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1287be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1288be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1289be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1290be168c0dSopenharmony_ci+ * limitations under the License. 
1291be168c0dSopenharmony_ci+ */ 1292be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1293be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1294be168c0dSopenharmony_ci+ 1295be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h" 1296be168c0dSopenharmony_ci+ 1297be168c0dSopenharmony_ci+#ifdef __cplusplus 1298be168c0dSopenharmony_ci+extern "C" { 1299be168c0dSopenharmony_ci+#endif 1300be168c0dSopenharmony_ci+ 1301be168c0dSopenharmony_ci+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1302be168c0dSopenharmony_ci+ OpParameter *parameter); 1303be168c0dSopenharmony_ci+ 1304be168c0dSopenharmony_ci+#ifdef __cplusplus 1305be168c0dSopenharmony_ci+} 1306be168c0dSopenharmony_ci+#endif 1307be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H 1308be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 1309be168c0dSopenharmony_cinew file mode 100644 1310be168c0dSopenharmony_ciindex 00000000..957a4d4f 1311be168c0dSopenharmony_ci--- /dev/null 1312be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c 1313be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 1314be168c0dSopenharmony_ci+/** 1315be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1316be168c0dSopenharmony_ci+ * 1317be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1318be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
1319be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1320be168c0dSopenharmony_ci+ * 1321be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1322be168c0dSopenharmony_ci+ * 1323be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1324be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1325be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1326be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1327be168c0dSopenharmony_ci+ * limitations under the License. 1328be168c0dSopenharmony_ci+ */ 1329be168c0dSopenharmony_ci+ 1330be168c0dSopenharmony_ci+#include "nnacl/infer/custom_masked_fill_infer.h" 1331be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h" 1332be168c0dSopenharmony_ci+ 1333be168c0dSopenharmony_ci+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1334be168c0dSopenharmony_ci+ OpParameter *parameter) { 1335be168c0dSopenharmony_ci+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 1336be168c0dSopenharmony_ci+ if (check_ret != NNACL_OK) { 1337be168c0dSopenharmony_ci+ return check_ret; 1338be168c0dSopenharmony_ci+ } 1339be168c0dSopenharmony_ci+ 1340be168c0dSopenharmony_ci+ const TensorC *input = inputs[0]; 1341be168c0dSopenharmony_ci+ TensorC *output = outputs[0]; 1342be168c0dSopenharmony_ci+ SetDataTypeFormat(output, input); 1343be168c0dSopenharmony_ci+ if (!InferFlag(inputs, inputs_size)) { 1344be168c0dSopenharmony_ci+ return NNACL_INFER_INVALID; 1345be168c0dSopenharmony_ci+ } 1346be168c0dSopenharmony_ci+ SetShapeTensor(output, input); 1347be168c0dSopenharmony_ci+ return NNACL_OK; 1348be168c0dSopenharmony_ci+} 1349be168c0dSopenharmony_ci+ 1350be168c0dSopenharmony_ci+REG_INFER(CustomMaskedFill, 
PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape) 1351be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 1352be168c0dSopenharmony_cinew file mode 100644 1353be168c0dSopenharmony_ciindex 00000000..a8adbae2 1354be168c0dSopenharmony_ci--- /dev/null 1355be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h 1356be168c0dSopenharmony_ci@@ -0,0 +1,31 @@ 1357be168c0dSopenharmony_ci+/** 1358be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1359be168c0dSopenharmony_ci+ * 1360be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1361be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1362be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1363be168c0dSopenharmony_ci+ * 1364be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1365be168c0dSopenharmony_ci+ * 1366be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1367be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1368be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1369be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1370be168c0dSopenharmony_ci+ * limitations under the License. 
1371be168c0dSopenharmony_ci+ */ 1372be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1373be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1374be168c0dSopenharmony_ci+ 1375be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h" 1376be168c0dSopenharmony_ci+ 1377be168c0dSopenharmony_ci+#ifdef __cplusplus 1378be168c0dSopenharmony_ci+extern "C" { 1379be168c0dSopenharmony_ci+#endif 1380be168c0dSopenharmony_ci+ 1381be168c0dSopenharmony_ci+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, 1382be168c0dSopenharmony_ci+ OpParameter *parameter); 1383be168c0dSopenharmony_ci+ 1384be168c0dSopenharmony_ci+#ifdef __cplusplus 1385be168c0dSopenharmony_ci+} 1386be168c0dSopenharmony_ci+#endif 1387be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H 1388be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 1389be168c0dSopenharmony_cinew file mode 100644 1390be168c0dSopenharmony_ciindex 00000000..be6716ba 1391be168c0dSopenharmony_ci--- /dev/null 1392be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c 1393be168c0dSopenharmony_ci@@ -0,0 +1,37 @@ 1394be168c0dSopenharmony_ci+/** 1395be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1396be168c0dSopenharmony_ci+ * 1397be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1398be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
1399be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1400be168c0dSopenharmony_ci+ * 1401be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1402be168c0dSopenharmony_ci+ * 1403be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1404be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1405be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1406be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1407be168c0dSopenharmony_ci+ * limitations under the License. 1408be168c0dSopenharmony_ci+ */ 1409be168c0dSopenharmony_ci+ 1410be168c0dSopenharmony_ci+#include "nnacl/infer/custom_tensor_scatter_max_infer.h" 1411be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h" 1412be168c0dSopenharmony_ci+ 1413be168c0dSopenharmony_ci+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 1414be168c0dSopenharmony_ci+ size_t outputs_size, OpParameter *parameter) { 1415be168c0dSopenharmony_ci+ int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM); 1416be168c0dSopenharmony_ci+ if (check_ret != NNACL_OK) { 1417be168c0dSopenharmony_ci+ return check_ret; 1418be168c0dSopenharmony_ci+ } 1419be168c0dSopenharmony_ci+ 1420be168c0dSopenharmony_ci+ const TensorC *input = inputs[0]; 1421be168c0dSopenharmony_ci+ TensorC *output = outputs[0]; 1422be168c0dSopenharmony_ci+ SetDataTypeFormat(output, input); 1423be168c0dSopenharmony_ci+ if (!InferFlag(inputs, inputs_size)) { 1424be168c0dSopenharmony_ci+ return NNACL_INFER_INVALID; 1425be168c0dSopenharmony_ci+ } 1426be168c0dSopenharmony_ci+ SetShapeTensor(output, input); 1427be168c0dSopenharmony_ci+ return NNACL_OK; 1428be168c0dSopenharmony_ci+} 1429be168c0dSopenharmony_ci+ 1430be168c0dSopenharmony_ci+REG_INFER(CustomTensorScatterMax, 
PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape) 1431be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 1432be168c0dSopenharmony_cinew file mode 100644 1433be168c0dSopenharmony_ciindex 00000000..641aa483 1434be168c0dSopenharmony_ci--- /dev/null 1435be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h 1436be168c0dSopenharmony_ci@@ -0,0 +1,31 @@ 1437be168c0dSopenharmony_ci+/** 1438be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 1439be168c0dSopenharmony_ci+ * 1440be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 1441be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 1442be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 1443be168c0dSopenharmony_ci+ * 1444be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 1445be168c0dSopenharmony_ci+ * 1446be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 1447be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 1448be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1449be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 1450be168c0dSopenharmony_ci+ * limitations under the License. 
1451be168c0dSopenharmony_ci+ */ 1452be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1453be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1454be168c0dSopenharmony_ci+ 1455be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h" 1456be168c0dSopenharmony_ci+ 1457be168c0dSopenharmony_ci+#ifdef __cplusplus 1458be168c0dSopenharmony_ci+extern "C" { 1459be168c0dSopenharmony_ci+#endif 1460be168c0dSopenharmony_ci+ 1461be168c0dSopenharmony_ci+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, 1462be168c0dSopenharmony_ci+ size_t outputs_size, OpParameter *parameter); 1463be168c0dSopenharmony_ci+ 1464be168c0dSopenharmony_ci+#ifdef __cplusplus 1465be168c0dSopenharmony_ci+} 1466be168c0dSopenharmony_ci+#endif 1467be168c0dSopenharmony_ci+#endif // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H 1468be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 1469be168c0dSopenharmony_cinew file mode 100644 1470be168c0dSopenharmony_ciindex 00000000..d7c34768 1471be168c0dSopenharmony_ci--- /dev/null 1472be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h 1473be168c0dSopenharmony_ci@@ -0,0 +1,65 @@ 1474be168c0dSopenharmony_ci+/** 1475be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 1476be168c0dSopenharmony_ci+* 1477be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 1478be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License. 
1479be168c0dSopenharmony_ci+* You may obtain a copy of the License at 1480be168c0dSopenharmony_ci+* 1481be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 1482be168c0dSopenharmony_ci+* 1483be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 1484be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 1485be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1486be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 1487be168c0dSopenharmony_ci+* limitations under the License. 1488be168c0dSopenharmony_ci+*/ 1489be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 1490be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 1491be168c0dSopenharmony_ci+ 1492be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h" 1493be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_neon_instructions.h" 1494be168c0dSopenharmony_ci+ 1495be168c0dSopenharmony_ci+#ifdef __cplusplus 1496be168c0dSopenharmony_ci+extern "C" { 1497be168c0dSopenharmony_ci+#endif 1498be168c0dSopenharmony_ci+ 1499be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION 1500be168c0dSopenharmony_ci+#define BLOCK_NUM 4 1501be168c0dSopenharmony_ci+#define MS_SIMD_NEON 1502be168c0dSopenharmony_ci+ 1503be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32NEON(int index, const float *update, int size, float *output) { 1504be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1505be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1506be168c0dSopenharmony_ci+ } 1507be168c0dSopenharmony_ci+ return index; 1508be168c0dSopenharmony_ci+} 1509be168c0dSopenharmony_ci+ 1510be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32NEON(int 
index, const int *update, int size, int *output) { 1511be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1512be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1513be168c0dSopenharmony_ci+ } 1514be168c0dSopenharmony_ci+ return index; 1515be168c0dSopenharmony_ci+} 1516be168c0dSopenharmony_ci+ 1517be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32NEON(int index, const float *update, int size, float *output) { 1518be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1519be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1520be168c0dSopenharmony_ci+ } 1521be168c0dSopenharmony_ci+ return index; 1522be168c0dSopenharmony_ci+} 1523be168c0dSopenharmony_ci+ 1524be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32NEON(int index, const int *update, int size, int *output) { 1525be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1526be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1527be168c0dSopenharmony_ci+ } 1528be168c0dSopenharmony_ci+ return index; 1529be168c0dSopenharmony_ci+} 1530be168c0dSopenharmony_ci+ 1531be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION 1532be168c0dSopenharmony_ci+#undef BLOCK_NUM 1533be168c0dSopenharmony_ci+ 1534be168c0dSopenharmony_ci+#undef MS_SIMD_NEON 1535be168c0dSopenharmony_ci+#ifdef __cplusplus 1536be168c0dSopenharmony_ci+} 1537be168c0dSopenharmony_ci+#endif 1538be168c0dSopenharmony_ci+#endif // NNACL_BASE_SCATTER_ND_BINARY_NEON_H_ 1539be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 
1540be168c0dSopenharmony_ciindex 955a70a5..895f7e3d 100644 1541be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 1542be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h 1543be168c0dSopenharmony_ci@@ -558,6 +558,10 @@ enum PrimType { 1544be168c0dSopenharmony_ci PrimType_Inner_CustomGru = 10010, 1545be168c0dSopenharmony_ci PrimType_Inner_CastGatherReduceFusion = 10011, 1546be168c0dSopenharmony_ci PrimType_Inner_ReduceConcatFusion = 10012, 1547be168c0dSopenharmony_ci+ PrimType_Inner_ThirdPartyModel = 10013, 1548be168c0dSopenharmony_ci+ PrimType_Inner_CustomMaskedFill = 10014, 1549be168c0dSopenharmony_ci+ PrimType_Inner_CustomTensorScatterMax = 10015, 1550be168c0dSopenharmony_ci+ PrimType_Inner_CustomIsInf = 10016, 1551be168c0dSopenharmony_ci PrimType_InnerOpMax, 1552be168c0dSopenharmony_ci PrimType_InnerOpMin = PrimType_Inner_ToFormat 1553be168c0dSopenharmony_ci }; 1554be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 1555be168c0dSopenharmony_cinew file mode 100644 1556be168c0dSopenharmony_ciindex 00000000..dd9878f7 1557be168c0dSopenharmony_ci--- /dev/null 1558be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h 1559be168c0dSopenharmony_ci@@ -0,0 +1,36 @@ 1560be168c0dSopenharmony_ci+/** 1561be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 1562be168c0dSopenharmony_ci+* 1563be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 1564be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License. 
1565be168c0dSopenharmony_ci+* You may obtain a copy of the License at 1566be168c0dSopenharmony_ci+* 1567be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 1568be168c0dSopenharmony_ci+* 1569be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 1570be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 1571be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1572be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 1573be168c0dSopenharmony_ci+* limitations under the License. 1574be168c0dSopenharmony_ci+*/ 1575be168c0dSopenharmony_ci+#ifndef NNACL_SCATTER_ND_BINARY_SIMD_H_ 1576be168c0dSopenharmony_ci+#define NNACL_SCATTER_ND_BINARY_SIMD_H_ 1577be168c0dSopenharmony_ci+ 1578be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h" 1579be168c0dSopenharmony_ci+#ifdef ENABLE_AVX512 1580be168c0dSopenharmony_ci+#include "nnacl/avx512/scatter_nd_binary_avx512.h" 1581be168c0dSopenharmony_ci+#endif 1582be168c0dSopenharmony_ci+ 1583be168c0dSopenharmony_ci+#ifdef ENABLE_AVX 1584be168c0dSopenharmony_ci+#include "nnacl/avx/scatter_nd_binary_avx.h" 1585be168c0dSopenharmony_ci+#endif 1586be168c0dSopenharmony_ci+ 1587be168c0dSopenharmony_ci+#ifdef ENABLE_SSE 1588be168c0dSopenharmony_ci+#include "nnacl/sse/scatter_nd_binary_sse.h" 1589be168c0dSopenharmony_ci+#endif 1590be168c0dSopenharmony_ci+ 1591be168c0dSopenharmony_ci+#ifdef ENABLE_ARM 1592be168c0dSopenharmony_ci+#include "nnacl/neon/scatter_nd_binary_neon.h" 1593be168c0dSopenharmony_ci+#endif 1594be168c0dSopenharmony_ci+ 1595be168c0dSopenharmony_ci+#endif 1596be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 1597be168c0dSopenharmony_cinew file mode 100644 1598be168c0dSopenharmony_ciindex 
00000000..983d2923 1599be168c0dSopenharmony_ci--- /dev/null 1600be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h 1601be168c0dSopenharmony_ci@@ -0,0 +1,66 @@ 1602be168c0dSopenharmony_ci+/** 1603be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 1604be168c0dSopenharmony_ci+* 1605be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 1606be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License. 1607be168c0dSopenharmony_ci+* You may obtain a copy of the License at 1608be168c0dSopenharmony_ci+* 1609be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 1610be168c0dSopenharmony_ci+* 1611be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 1612be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 1613be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1614be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 1615be168c0dSopenharmony_ci+* limitations under the License. 
1616be168c0dSopenharmony_ci+*/ 1617be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1618be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1619be168c0dSopenharmony_ci+ 1620be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h" 1621be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_sse_instructions.h" 1622be168c0dSopenharmony_ci+ 1623be168c0dSopenharmony_ci+#ifdef __cplusplus 1624be168c0dSopenharmony_ci+extern "C" { 1625be168c0dSopenharmony_ci+#endif 1626be168c0dSopenharmony_ci+#pragma GCC push_options 1627be168c0dSopenharmony_ci+#pragma GCC target("sse4.1") 1628be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_SSE_INSTRUCTION 1629be168c0dSopenharmony_ci+#define BLOCK_NUM 4 1630be168c0dSopenharmony_ci+#define MS_SIMD_SSE 1631be168c0dSopenharmony_ci+ 1632be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32SSE(int index, const float *update, int size, float *output) { 1633be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1634be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1635be168c0dSopenharmony_ci+ } 1636be168c0dSopenharmony_ci+ return index; 1637be168c0dSopenharmony_ci+} 1638be168c0dSopenharmony_ci+ 1639be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32SSE(int index, const int *update, int size, int *output) { 1640be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1641be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1642be168c0dSopenharmony_ci+ } 1643be168c0dSopenharmony_ci+ return index; 1644be168c0dSopenharmony_ci+} 1645be168c0dSopenharmony_ci+ 1646be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32SSE(int index, const float *update, int size, float *output) { 
1647be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1648be168c0dSopenharmony_ci+ SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index))); 1649be168c0dSopenharmony_ci+ } 1650be168c0dSopenharmony_ci+ return index; 1651be168c0dSopenharmony_ci+} 1652be168c0dSopenharmony_ci+ 1653be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32SSE(int index, const int *update, int size, int *output) { 1654be168c0dSopenharmony_ci+ for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) { 1655be168c0dSopenharmony_ci+ SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index))); 1656be168c0dSopenharmony_ci+ } 1657be168c0dSopenharmony_ci+ return index; 1658be168c0dSopenharmony_ci+} 1659be168c0dSopenharmony_ci+ 1660be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION 1661be168c0dSopenharmony_ci+#undef BLOCK_NUM 1662be168c0dSopenharmony_ci+#pragma GCC pop_options 1663be168c0dSopenharmony_ci+#undef MS_SIMD_SSE 1664be168c0dSopenharmony_ci+#ifdef __cplusplus 1665be168c0dSopenharmony_ci+} 1666be168c0dSopenharmony_ci+#endif 1667be168c0dSopenharmony_ci+#endif // NNACL_BASE_SCATTER_ND_BINARY_SSE_H_ 1668be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/BUILD.gn b/mindspore/core/mindrt/BUILD.gn 1669be168c0dSopenharmony_ciindex b56d5f5c..b0e7c70d 100644 1670be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/BUILD.gn 1671be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/BUILD.gn 1672be168c0dSopenharmony_ci@@ -41,8 +41,15 @@ ohos_source_set("mindrt_obj") { 1673be168c0dSopenharmony_ci "../../core/", 1674be168c0dSopenharmony_ci ] 1675be168c0dSopenharmony_ci 1676be168c0dSopenharmony_ci+ defines = [ 1677be168c0dSopenharmony_ci+ "ENABLE_MINDRT", 1678be168c0dSopenharmony_ci+ "MS_COMPILE_OHOS", 1679be168c0dSopenharmony_ci+ "BUILD_LITE", 1680be168c0dSopenharmony_ci+ ] 1681be168c0dSopenharmony_ci+ 
1682be168c0dSopenharmony_ci+ external_deps = [ "hilog:libhilog" ] 1683be168c0dSopenharmony_ci+ 1684be168c0dSopenharmony_ci remove_configs = [ "//build/config/compiler:no_rtti" ] 1685be168c0dSopenharmony_ci- defines = [ "BUILD_LITE" ] 1686be168c0dSopenharmony_ci 1687be168c0dSopenharmony_ci part_name = "mindspore" 1688be168c0dSopenharmony_ci subsystem_name = "thirdparty" 1689be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1690be168c0dSopenharmony_ciindex 70414757..c50c46e0 100644 1691be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1692be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc 1693be168c0dSopenharmony_ci@@ -32,7 +32,7 @@ void ActorWorker::RunWithSpin() { 1694be168c0dSopenharmony_ci } 1695be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER) 1696be168c0dSopenharmony_ci static std::atomic_int index{0}; 1697be168c0dSopenharmony_ci- (void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str()); 1698be168c0dSopenharmony_ci+ (void)pthread_setname_np(pthread_self(), ("OS_Actor_" + std::to_string(index++)).c_str()); 1699be168c0dSopenharmony_ci #endif 1700be168c0dSopenharmony_ci #ifdef PLATFORM_86 1701be168c0dSopenharmony_ci // Some CPU kernels need set the flush zero mode to improve performance. 
1702be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc 1703be168c0dSopenharmony_ciindex 33bf3529..a3478dff 100644 1704be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/core_affinity.cc 1705be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc 1706be168c0dSopenharmony_ci@@ -344,12 +344,12 @@ int CoreAffinity::InitBindCoreId(size_t thread_num, BindMode bind_mode) { 1707be168c0dSopenharmony_ci int CoreAffinity::SetAffinity() { return THREAD_OK; } 1708be168c0dSopenharmony_ci #elif defined(BIND_CORE) 1709be168c0dSopenharmony_ci int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) { 1710be168c0dSopenharmony_ci-#ifdef __ANDROID__ 1711be168c0dSopenharmony_ci-#if __ANDROID_API__ >= 21 1712be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1713be168c0dSopenharmony_ci+#if (__ANDROID_API__ >= 21) || defined(MS_COMPILE_OHOS) 1714be168c0dSopenharmony_ci THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]); 1715be168c0dSopenharmony_ci int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set); 1716be168c0dSopenharmony_ci if (ret != THREAD_OK) { 1717be168c0dSopenharmony_ci- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret); 1718be168c0dSopenharmony_ci+ THREAD_ERROR("bind thread %d to cpu failed. 
ERROR %{public}d", pthread_gettid_np(thread_id), ret); 1719be168c0dSopenharmony_ci return THREAD_ERROR; 1720be168c0dSopenharmony_ci } 1721be168c0dSopenharmony_ci #endif 1722be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/core_affinity.h b/mindspore/core/mindrt/src/thread/core_affinity.h 1723be168c0dSopenharmony_ciindex 2dd2abd1..28b0967a 100644 1724be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/core_affinity.h 1725be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/core_affinity.h 1726be168c0dSopenharmony_ci@@ -23,7 +23,7 @@ 1727be168c0dSopenharmony_ci #ifdef PARALLEL_INFERENCE 1728be168c0dSopenharmony_ci #define BIND_CORE 1729be168c0dSopenharmony_ci #endif 1730be168c0dSopenharmony_ci-#ifdef __ANDROID__ 1731be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1732be168c0dSopenharmony_ci #define BIND_CORE 1733be168c0dSopenharmony_ci #include <sched.h> 1734be168c0dSopenharmony_ci #endif 1735be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1736be168c0dSopenharmony_ciindex 9e0dd25c..09c39f32 100644 1737be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1738be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc 1739be168c0dSopenharmony_ci@@ -48,7 +48,7 @@ void ParallelWorker::ParallelRun() { 1740be168c0dSopenharmony_ci SetAffinity(); 1741be168c0dSopenharmony_ci } 1742be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER) 1743be168c0dSopenharmony_ci- (void)pthread_setname_np(pthread_self(), ("ParallelThread_" + std::to_string(worker_id_)).c_str()); 1744be168c0dSopenharmony_ci+ (void)pthread_setname_np(pthread_self(), ("OS_Parallel_" + std::to_string(worker_id_)).c_str()); 1745be168c0dSopenharmony_ci #endif 1746be168c0dSopenharmony_ci #ifdef PLATFORM_86 1747be168c0dSopenharmony_ci // Some CPU kernels need set the flush zero 
mode to improve performance. 1748be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/threadlog.h b/mindspore/core/mindrt/src/thread/threadlog.h 1749be168c0dSopenharmony_ciindex 7ed917f1..b212a401 100644 1750be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/threadlog.h 1751be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/threadlog.h 1752be168c0dSopenharmony_ci@@ -16,7 +16,9 @@ 1753be168c0dSopenharmony_ci 1754be168c0dSopenharmony_ci #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_ 1755be168c0dSopenharmony_ci #define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_ 1756be168c0dSopenharmony_ci- 1757be168c0dSopenharmony_ci+#ifdef MS_COMPILE_OHOS 1758be168c0dSopenharmony_ci+#include "hilog/log.h" 1759be168c0dSopenharmony_ci+#endif 1760be168c0dSopenharmony_ci namespace mindspore { 1761be168c0dSopenharmony_ci #ifdef THREAD_POOL_DEBUG 1762be168c0dSopenharmony_ci #include <stdio.h> 1763be168c0dSopenharmony_ci@@ -32,13 +34,35 @@ namespace mindspore { 1764be168c0dSopenharmony_ci } 1765be168c0dSopenharmony_ci #else 1766be168c0dSopenharmony_ci #define THREAD_DEBUG(content, ...) 1767be168c0dSopenharmony_ci-#define THREAD_INFO(content, ...) 1768be168c0dSopenharmony_ci #define THREAD_TEST_TRUE(flag) 1769be168c0dSopenharmony_ci+ 1770be168c0dSopenharmony_ci #if defined(__ANDROID__) 1771be168c0dSopenharmony_ci+#define THREAD_INFO(content, ...) 1772be168c0dSopenharmony_ci #include <android/log.h> 1773be168c0dSopenharmony_ci #define THREAD_ERROR(content, args...) \ 1774be168c0dSopenharmony_ci { __android_log_print(ANDROID_LOG_ERROR, "MS_LITE", "%s|%d: " #content "\r\n", __func__, __LINE__, ##args); } 1775be168c0dSopenharmony_ci+ 1776be168c0dSopenharmony_ci+#elif defined(MS_COMPILE_OHOS) // For OHOS, use hilog. 
1777be168c0dSopenharmony_ci+ 1778be168c0dSopenharmony_ci+#define MINDRT_OHOS_LOG_DOMAIN 0x2102 1779be168c0dSopenharmony_ci+#define MINDRT_OHOS_LOG_TAG "MS_LITE" 1780be168c0dSopenharmony_ci+ 1781be168c0dSopenharmony_ci+#ifdef MS_COMPILE_WITH_OHOS_NDK 1782be168c0dSopenharmony_ci+// When build with OHOS NDK, use public api of hilog module. 1783be168c0dSopenharmony_ci+#define THREAD_INFO(content, args...) \ 1784be168c0dSopenharmony_ci+ { OH_LOG_Print(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1785be168c0dSopenharmony_ci+#define THREAD_ERROR(content, args...) \ 1786be168c0dSopenharmony_ci+ { OH_LOG_Print(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1787be168c0dSopenharmony_ci+#else 1788be168c0dSopenharmony_ci+// When build in OHOS repo, use inner api of hilog module. 1789be168c0dSopenharmony_ci+#define THREAD_INFO(content, args...) \ 1790be168c0dSopenharmony_ci+ { HiLogPrint(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1791be168c0dSopenharmony_ci+#define THREAD_ERROR(content, args...) \ 1792be168c0dSopenharmony_ci+ { HiLogPrint(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); } 1793be168c0dSopenharmony_ci+#endif 1794be168c0dSopenharmony_ci+ 1795be168c0dSopenharmony_ci #else 1796be168c0dSopenharmony_ci+#define THREAD_INFO(content, ...) 1797be168c0dSopenharmony_ci #define THREAD_ERROR(content, ...) 
1798be168c0dSopenharmony_ci #endif 1799be168c0dSopenharmony_ci #endif 1800be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc 1801be168c0dSopenharmony_ciindex c56e0425..2301be8c 100644 1802be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/threadpool.cc 1803be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/threadpool.cc 1804be168c0dSopenharmony_ci@@ -68,10 +68,11 @@ void Worker::SetAffinity() { 1805be168c0dSopenharmony_ci #ifdef _WIN32 1806be168c0dSopenharmony_ci SetWindowsSelfAffinity(core_id_); 1807be168c0dSopenharmony_ci #elif defined(BIND_CORE) 1808be168c0dSopenharmony_ci-#ifdef __ANDROID__ 1809be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS) 1810be168c0dSopenharmony_ci+ THREAD_INFO("thread: %d, mask: %lu", gettid(), mask_.__bits[0]); 1811be168c0dSopenharmony_ci int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_); 1812be168c0dSopenharmony_ci if (ret != THREAD_OK) { 1813be168c0dSopenharmony_ci- THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno); 1814be168c0dSopenharmony_ci+ THREAD_ERROR("bind thread %d to cpu failed. 
ERROR %{public}d", gettid(), errno); 1815be168c0dSopenharmony_ci } 1816be168c0dSopenharmony_ci return; 1817be168c0dSopenharmony_ci #else 1818be168c0dSopenharmony_ci@@ -111,7 +112,7 @@ void Worker::Run() { 1819be168c0dSopenharmony_ci } 1820be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER) 1821be168c0dSopenharmony_ci static std::atomic_int index = {0}; 1822be168c0dSopenharmony_ci- (void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str()); 1823be168c0dSopenharmony_ci+ (void)pthread_setname_np(pthread_self(), ("OS_Kernel_" + std::to_string(index++)).c_str()); 1824be168c0dSopenharmony_ci #endif 1825be168c0dSopenharmony_ci #ifdef PLATFORM_86 1826be168c0dSopenharmony_ci // Some CPU kernels need set the flush zero mode to improve performance. 1827be168c0dSopenharmony_cidiff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn 1828be168c0dSopenharmony_ciindex a774b58c..f7e465e2 100644 1829be168c0dSopenharmony_ci--- a/mindspore/lite/BUILD.gn 1830be168c0dSopenharmony_ci+++ b/mindspore/lite/BUILD.gn 1831be168c0dSopenharmony_ci@@ -71,9 +71,14 @@ 1832be168c0dSopenharmony_ci 1833be168c0dSopenharmony_ci import("//build/ohos.gni") 1834be168c0dSopenharmony_ci 1835be168c0dSopenharmony_ci+declare_args() { 1836be168c0dSopenharmony_ci+ mindspore_feature_nnrt_metagraph = false 1837be168c0dSopenharmony_ci+} 1838be168c0dSopenharmony_ci+ 1839be168c0dSopenharmony_ci ohos_group("mindspore") { 1840be168c0dSopenharmony_ci deps = [ 1841be168c0dSopenharmony_ci ":mindspore_lib", 1842be168c0dSopenharmony_ci+ ":mindspore_ndk", 1843be168c0dSopenharmony_ci ":mindspore_train_lib", 1844be168c0dSopenharmony_ci "mindir:mindir_lib", 1845be168c0dSopenharmony_ci "src/litert/js_api:mindsporelite_napi" 1846be168c0dSopenharmony_ci@@ -180,7 +185,6 @@ lite_mindrt_sources = [ 1847be168c0dSopenharmony_ci ] 1848be168c0dSopenharmony_ci 1849be168c0dSopenharmony_ci all_lite_sources += cxx_api_sources 1850be168c0dSopenharmony_ci-all_lite_sources += 
c_api_sources 1851be168c0dSopenharmony_ci all_lite_sources += api_source 1852be168c0dSopenharmony_ci all_lite_sources += control_flow_kernel_sources 1853be168c0dSopenharmony_ci all_lite_sources += experimental_sources 1854be168c0dSopenharmony_ci@@ -368,7 +372,6 @@ ohos_shared_library("mindspore_lib") { 1855be168c0dSopenharmony_ci sources = all_sources 1856be168c0dSopenharmony_ci 1857be168c0dSopenharmony_ci include_dirs = [ 1858be168c0dSopenharmony_ci- "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1859be168c0dSopenharmony_ci "//third_party/flatbuffers/include", 1860be168c0dSopenharmony_ci "./", 1861be168c0dSopenharmony_ci "../", 1862be168c0dSopenharmony_ci@@ -384,6 +387,7 @@ ohos_shared_library("mindspore_lib") { 1863be168c0dSopenharmony_ci "../ccsrc/", 1864be168c0dSopenharmony_ci "src/litert/kernel/cpu/", 1865be168c0dSopenharmony_ci "../core/mindrt/src/", 1866be168c0dSopenharmony_ci+ "//foundation/ai/neural_network_runtime/", 1867be168c0dSopenharmony_ci ] 1868be168c0dSopenharmony_ci 1869be168c0dSopenharmony_ci defines = [ 1870be168c0dSopenharmony_ci@@ -426,24 +430,29 @@ ohos_shared_library("mindspore_lib") { 1871be168c0dSopenharmony_ci 1872be168c0dSopenharmony_ci external_deps = [ "hilog:libhilog" ] 1873be168c0dSopenharmony_ci 1874be168c0dSopenharmony_ci- output_name = "libmindspore-lite.huawei" 1875be168c0dSopenharmony_ci+ output_name = "libmindspore-lite" 1876be168c0dSopenharmony_ci output_extension = "so" 1877be168c0dSopenharmony_ci innerapi_tags = [ "platformsdk" ] 1878be168c0dSopenharmony_ci SUPPORT_NNRT = true 1879be168c0dSopenharmony_ci if (SUPPORT_NNRT) { 1880be168c0dSopenharmony_ci+ if (mindspore_feature_nnrt_metagraph) { 1881be168c0dSopenharmony_ci+ defines += [ "SUPPORT_NNRT_METAGRAPH" ] 1882be168c0dSopenharmony_ci+ print("enabled feature: mindspore_feature_nnrt_metagraph") 1883be168c0dSopenharmony_ci+ } 1884be168c0dSopenharmony_ci sources += [ 1885be168c0dSopenharmony_ci "src/litert/delegate/nnrt/checker/primitive_check.cc", 
1886be168c0dSopenharmony_ci "src/litert/delegate/nnrt/nnrt_delegate.cc", 1887be168c0dSopenharmony_ci "src/litert/delegate/nnrt/nnrt_model_kernel.cc", 1888be168c0dSopenharmony_ci ] 1889be168c0dSopenharmony_ci include_dirs += [ 1890be168c0dSopenharmony_ci- "//foundation/ai/neural_network_runtime", 1891be168c0dSopenharmony_ci "src/delegate/nnrt/include", 1892be168c0dSopenharmony_ci "../../mindspore/core/ir", 1893be168c0dSopenharmony_ci "mindir/include", 1894be168c0dSopenharmony_ci "mindir/inner_headers", 1895be168c0dSopenharmony_ci ] 1896be168c0dSopenharmony_ci+ 1897be168c0dSopenharmony_ci external_deps += [ "neural_network_runtime:nnrt_target" ] 1898be168c0dSopenharmony_ci+ 1899be168c0dSopenharmony_ci deps += [ "mindir:mindir_lib" ] 1900be168c0dSopenharmony_ci defines += [ "SUPPORT_NNRT" ] 1901be168c0dSopenharmony_ci } 1902be168c0dSopenharmony_ci@@ -461,6 +470,67 @@ ohos_shared_library("mindspore_lib") { 1903be168c0dSopenharmony_ci subsystem_name = "thirdparty" 1904be168c0dSopenharmony_ci } 1905be168c0dSopenharmony_ci 1906be168c0dSopenharmony_ci+# NDK lib 1907be168c0dSopenharmony_ci+ohos_shared_library("mindspore_ndk") { 1908be168c0dSopenharmony_ci+ deps = [ 1909be168c0dSopenharmony_ci+ ":mindspore_lib", 1910be168c0dSopenharmony_ci+ ":mindspore_train_lib" 1911be168c0dSopenharmony_ci+ ] 1912be168c0dSopenharmony_ci+ 1913be168c0dSopenharmony_ci+ sources = c_api_sources 1914be168c0dSopenharmony_ci+ 1915be168c0dSopenharmony_ci+ include_dirs = [ 1916be168c0dSopenharmony_ci+ "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1917be168c0dSopenharmony_ci+ "//third_party/flatbuffers/include", 1918be168c0dSopenharmony_ci+ "./", 1919be168c0dSopenharmony_ci+ "../", 1920be168c0dSopenharmony_ci+ "../../", 1921be168c0dSopenharmony_ci+ "../core", 1922be168c0dSopenharmony_ci+ "src", 1923be168c0dSopenharmony_ci+ "src/c_api/", 1924be168c0dSopenharmony_ci+ "../ccsrc/plugin/device/cpu/kernel/", 1925be168c0dSopenharmony_ci+ "../core/mindrt/src/", 1926be168c0dSopenharmony_ci+ 
"../core/mindrt/include/", 1927be168c0dSopenharmony_ci+ "../../third_party/", 1928be168c0dSopenharmony_ci+ "./schema/", 1929be168c0dSopenharmony_ci+ "../ccsrc/", 1930be168c0dSopenharmony_ci+ "//foundation/ai/neural_network_runtime/", 1931be168c0dSopenharmony_ci+ ] 1932be168c0dSopenharmony_ci+ 1933be168c0dSopenharmony_ci+ defines = [ 1934be168c0dSopenharmony_ci+ "SUPPORT_NNRT", 1935be168c0dSopenharmony_ci+ "MS_COMPILE_OHOS", 1936be168c0dSopenharmony_ci+ "PRIMITIVE_WRITEABLE", 1937be168c0dSopenharmony_ci+ "RUNTIME_PASS_CLIP", 1938be168c0dSopenharmony_ci+ "ENABLE_MULTI_LAYOUT", 1939be168c0dSopenharmony_ci+ "VERSION_STR=\"2.1.0\"", 1940be168c0dSopenharmony_ci+ ] 1941be168c0dSopenharmony_ci+ 1942be168c0dSopenharmony_ci+ configs = [ 1943be168c0dSopenharmony_ci+ ":mindspore_api", 1944be168c0dSopenharmony_ci+ ":disable_android", 1945be168c0dSopenharmony_ci+ ":secure_option", 1946be168c0dSopenharmony_ci+ ] 1947be168c0dSopenharmony_ci+ 1948be168c0dSopenharmony_ci+ external_deps = [ "neural_network_runtime:nnrt_target" ] 1949be168c0dSopenharmony_ci+ 1950be168c0dSopenharmony_ci+ remove_configs = [ "//build/config/compiler:no_rtti" ] 1951be168c0dSopenharmony_ci+ 1952be168c0dSopenharmony_ci+ output_name = "libmindspore_lite_ndk" 1953be168c0dSopenharmony_ci+ output_extension = "so" 1954be168c0dSopenharmony_ci+ innerapi_tags = [ "ndk"] 1955be168c0dSopenharmony_ci+ cflags_cc = [ 1956be168c0dSopenharmony_ci+ "-Wno-ignored-qualifiers", 1957be168c0dSopenharmony_ci+ "-Wunused-private-field", 1958be168c0dSopenharmony_ci+ "-Wno-unused-private-field", 1959be168c0dSopenharmony_ci+ "-Wno-inconsistent-missing-override", 1960be168c0dSopenharmony_ci+ "-Wno-macro-redefined", 1961be168c0dSopenharmony_ci+ "-Wno-constant-conversion", 1962be168c0dSopenharmony_ci+ ] 1963be168c0dSopenharmony_ci+ part_name = "mindspore" 1964be168c0dSopenharmony_ci+ subsystem_name = "thirdparty" 1965be168c0dSopenharmony_ci+} 1966be168c0dSopenharmony_ci+ 1967be168c0dSopenharmony_ci # Train library 
1968be168c0dSopenharmony_ci expression_cxx_api_sources = [ 1969be168c0dSopenharmony_ci "src/litert/cxx_api/expression/net.cc", 1970be168c0dSopenharmony_ci@@ -614,7 +684,6 @@ ohos_shared_library("mindspore_train_lib") { 1971be168c0dSopenharmony_ci sources = all_train_sources 1972be168c0dSopenharmony_ci 1973be168c0dSopenharmony_ci include_dirs = [ 1974be168c0dSopenharmony_ci- "//base/hiviewdfx/hilog/interfaces/native/innerkits/include", 1975be168c0dSopenharmony_ci "//third_party/flatbuffers/include", 1976be168c0dSopenharmony_ci "./", 1977be168c0dSopenharmony_ci "../", 1978be168c0dSopenharmony_ci@@ -698,6 +767,9 @@ config("disable_android") { 1979be168c0dSopenharmony_ci "-U__ANDROID__", 1980be168c0dSopenharmony_ci "-U__ANDROID_API__", 1981be168c0dSopenharmony_ci ] 1982be168c0dSopenharmony_ci+ ldflags = [ 1983be168c0dSopenharmony_ci+ "-Wl,--no-as-needed", 1984be168c0dSopenharmony_ci+ ] 1985be168c0dSopenharmony_ci } 1986be168c0dSopenharmony_ci 1987be168c0dSopenharmony_ci config("secure_option") { 1988be168c0dSopenharmony_cidiff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt 1989be168c0dSopenharmony_ciindex 72337f70..1faf2f38 100644 1990be168c0dSopenharmony_ci--- a/mindspore/lite/CMakeLists.txt 1991be168c0dSopenharmony_ci+++ b/mindspore/lite/CMakeLists.txt 1992be168c0dSopenharmony_ci@@ -298,8 +298,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite") 1993be168c0dSopenharmony_ci elseif(TOOLCHAIN_NAME STREQUAL "ohos") 1994be168c0dSopenharmony_ci set(TARGET_OHOS on) 1995be168c0dSopenharmony_ci add_compile_definitions(MS_COMPILE_OHOS) 1996be168c0dSopenharmony_ci- set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions") 1997be168c0dSopenharmony_ci- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions") 1998be168c0dSopenharmony_ci+ add_compile_definitions(MS_COMPILE_WITH_OHOS_NDK) 1999be168c0dSopenharmony_ci+ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument 
-Wno-c++17-extensions -Wno-deprecated-builtins") 2000be168c0dSopenharmony_ci+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins") 2001be168c0dSopenharmony_ci endif() 2002be168c0dSopenharmony_ci 2003be168c0dSopenharmony_ci if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0 2004be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/lite_types.h b/mindspore/lite/include/lite_types.h 2005be168c0dSopenharmony_ciindex 017e98a8..860390d5 100644 2006be168c0dSopenharmony_ci--- a/mindspore/lite/include/lite_types.h 2007be168c0dSopenharmony_ci+++ b/mindspore/lite/include/lite_types.h 2008be168c0dSopenharmony_ci@@ -42,6 +42,7 @@ typedef enum { 2009be168c0dSopenharmony_ci DT_NPU, /**< NPU device type */ 2010be168c0dSopenharmony_ci DT_ASCEND, /**< ASCEND device type */ 2011be168c0dSopenharmony_ci DT_CUSTOM, /**< EXTEND device type */ 2012be168c0dSopenharmony_ci+ DT_NNRT, /**< NNRT device type */ 2013be168c0dSopenharmony_ci DT_END /**< NO device type */ 2014be168c0dSopenharmony_ci } DeviceType; 2015be168c0dSopenharmony_ci 2016be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h 2017be168c0dSopenharmony_ciindex 93e27ea9..b96c7e35 100644 2018be168c0dSopenharmony_ci--- a/mindspore/lite/include/model.h 2019be168c0dSopenharmony_ci+++ b/mindspore/lite/include/model.h 2020be168c0dSopenharmony_ci@@ -25,6 +25,7 @@ namespace mindspore { 2021be168c0dSopenharmony_ci namespace schema { 2022be168c0dSopenharmony_ci struct Tensor; 2023be168c0dSopenharmony_ci } // namespace schema 2024be168c0dSopenharmony_ci+ 2025be168c0dSopenharmony_ci namespace lite { 2026be168c0dSopenharmony_ci typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType; 2027be168c0dSopenharmony_ci 2028be168c0dSopenharmony_ci@@ -62,7 +63,10 @@ struct MS_API LiteGraph { 2029be168c0dSopenharmony_ci bool model_obfuscated_ = false; 2030be168c0dSopenharmony_ci 
std::vector<unsigned char *> deobf_prims_; 2031be168c0dSopenharmony_ci #endif 2032be168c0dSopenharmony_ci+ 2033be168c0dSopenharmony_ci+ std::string ToString() const; 2034be168c0dSopenharmony_ci }; 2035be168c0dSopenharmony_ci+ 2036be168c0dSopenharmony_ci struct MS_API Model { 2037be168c0dSopenharmony_ci LiteGraph graph_; 2038be168c0dSopenharmony_ci char *buf = nullptr; 2039be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h 2040be168c0dSopenharmony_ciindex 2d72b200..4bc92599 100644 2041be168c0dSopenharmony_ci--- a/mindspore/lite/include/registry/converter_context.h 2042be168c0dSopenharmony_ci+++ b/mindspore/lite/include/registry/converter_context.h 2043be168c0dSopenharmony_ci@@ -39,7 +39,9 @@ enum MS_API FmkType : int { 2044be168c0dSopenharmony_ci kFmkTypeMs = 3, 2045be168c0dSopenharmony_ci kFmkTypeTflite = 4, 2046be168c0dSopenharmony_ci kFmkTypePytorch = 5, 2047be168c0dSopenharmony_ci- kFmkTypeMsLite = 6, 2048be168c0dSopenharmony_ci+ kFmkTypeThirdParty = 6, 2049be168c0dSopenharmony_ci+ kFmkTypeMsLite = 7, 2050be168c0dSopenharmony_ci+ kFmkTypeEnd = 8, // For range check purpose, valid range: [0, kFmkTypeEnd) 2051be168c0dSopenharmony_ci }; 2052be168c0dSopenharmony_ci 2053be168c0dSopenharmony_ci /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser. 
2054be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h 2055be168c0dSopenharmony_ciindex ca811dce..f47cad8c 100644 2056be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/include/mindir.h 2057be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/include/mindir.h 2058be168c0dSopenharmony_ci@@ -151,6 +151,8 @@ int64_t MindIR_Conv2DFusion_GetOutChannel(ConstPrimitivePtr primitive); 2059be168c0dSopenharmony_ci void MindIR_Conv2DFusion_SetOutChannel(PrimitivePtr *primitive, int64_t out_channel); 2060be168c0dSopenharmony_ci ActivationType MindIR_Conv2DFusion_GetActivationType(ConstPrimitivePtr primitive); 2061be168c0dSopenharmony_ci void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type); 2062be168c0dSopenharmony_ci+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive); 2063be168c0dSopenharmony_ci+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format); 2064be168c0dSopenharmony_ci 2065be168c0dSopenharmony_ci // ********** Conv2dTransposeFusion ********** 2066be168c0dSopenharmony_ci PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive( 2067be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc 2068be168c0dSopenharmony_ciindex 7fc9c00e..374bbef5 100644 2069be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/mindir.cc 2070be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/mindir.cc 2071be168c0dSopenharmony_ci@@ -1215,6 +1215,46 @@ void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationTy 2072be168c0dSopenharmony_ci } 2073be168c0dSopenharmony_ci } 2074be168c0dSopenharmony_ci 2075be168c0dSopenharmony_ci+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive) { 2076be168c0dSopenharmony_ci+ if (primitive != nullptr) { 2077be168c0dSopenharmony_ci+ auto prim = static_cast<const schema::Primitive *>(primitive); 2078be168c0dSopenharmony_ci+ auto value 
= prim->value_as_Conv2DFusion(); 2079be168c0dSopenharmony_ci+ if (prim != nullptr && value != nullptr) { 2080be168c0dSopenharmony_ci+ return static_cast<Format>(value->format()); 2081be168c0dSopenharmony_ci+ } else { 2082be168c0dSopenharmony_ci+ Format en = static_cast<Format>(0); 2083be168c0dSopenharmony_ci+ return en; 2084be168c0dSopenharmony_ci+ } 2085be168c0dSopenharmony_ci+ } else { 2086be168c0dSopenharmony_ci+ Format en = static_cast<Format>(0); 2087be168c0dSopenharmony_ci+ return en; 2088be168c0dSopenharmony_ci+ } 2089be168c0dSopenharmony_ci+} 2090be168c0dSopenharmony_ci+ 2091be168c0dSopenharmony_ci+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format) { 2092be168c0dSopenharmony_ci+ if (primitive != nullptr && *primitive != nullptr) { 2093be168c0dSopenharmony_ci+ auto prim = static_cast<schema::Primitive *>(*primitive); 2094be168c0dSopenharmony_ci+ auto value = prim->value_as_Conv2DFusion(); 2095be168c0dSopenharmony_ci+ if (prim != nullptr && value != nullptr) { 2096be168c0dSopenharmony_ci+ flatbuffers::FlatBufferBuilder fbb; 2097be168c0dSopenharmony_ci+ auto ops_offset = schema::CreateConv2DFusion( 2098be168c0dSopenharmony_ci+ fbb, static_cast<schema::Format>(format), 2099be168c0dSopenharmony_ci+ fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()), 2100be168c0dSopenharmony_ci+ fbb.CreateVector(value->stride()->data(), value->stride()->size()), 2101be168c0dSopenharmony_ci+ fbb.CreateVector(value->dilation()->data(), value->dilation()->size()), 2102be168c0dSopenharmony_ci+ static_cast<schema::PadMode>(value->pad_mode()), 2103be168c0dSopenharmony_ci+ fbb.CreateVector(value->pad_list()->data(), value->pad_list()->size()), 0, value->group(), value->in_channel(), 2104be168c0dSopenharmony_ci+ value->out_channel(), static_cast<schema::ActivationType>(value->activation_type())); 2105be168c0dSopenharmony_ci+ auto prim_offset = 2106be168c0dSopenharmony_ci+ schema::CreatePrimitive(fbb, 
static_cast<schema::PrimitiveType>(NODE_TYPE_CONV2D_FUSION), ops_offset.o); 2107be168c0dSopenharmony_ci+ fbb.Finish(prim_offset); 2108be168c0dSopenharmony_ci+ auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim); 2109be168c0dSopenharmony_ci+ auto ret_value = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr); 2110be168c0dSopenharmony_ci+ *primitive = ret_value; 2111be168c0dSopenharmony_ci+ } 2112be168c0dSopenharmony_ci+ } 2113be168c0dSopenharmony_ci+} 2114be168c0dSopenharmony_ci+ 2115be168c0dSopenharmony_ci // ********** Conv2dTransposeFusion ********** 2116be168c0dSopenharmony_ci PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive( 2117be168c0dSopenharmony_ci const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation, 2118be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/mindir_tensor.cc b/mindspore/lite/mindir/src/mindir_tensor.cc 2119be168c0dSopenharmony_ciindex 9ec2d0e4..2db4ce8b 100644 2120be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/mindir_tensor.cc 2121be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/mindir_tensor.cc 2122be168c0dSopenharmony_ci@@ -134,7 +134,7 @@ void MindIR_Tensor_SetDataType(TensorPtr *tensor, DataType data_type) { 2123be168c0dSopenharmony_ci name = fbb.CreateString(value->name()->c_str(), value->name()->size()); 2124be168c0dSopenharmony_ci } 2125be168c0dSopenharmony_ci auto ops_offset = 2126be168c0dSopenharmony_ci- schema::CreateTensor(fbb, 0, value->dataType(), dims, static_cast<schema::Format>(value->format()), 0, 0, data, 2127be168c0dSopenharmony_ci+ schema::CreateTensor(fbb, 0, data_type, dims, static_cast<schema::Format>(value->format()), 0, 0, data, 2128be168c0dSopenharmony_ci ConvertQuantParams(fbb, value->quantParams()), 0, name); 2129be168c0dSopenharmony_ci fbb.Finish(ops_offset); 2130be168c0dSopenharmony_ci auto new_addr = MindIRMemoryManager::GetInstance()->CreateTensorFromBuilder(fbb, value); 
2131be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/utils.cc b/mindspore/lite/mindir/src/utils.cc 2132be168c0dSopenharmony_ciindex 28d66ceb..b044f414 100644 2133be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/utils.cc 2134be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/utils.cc 2135be168c0dSopenharmony_ci@@ -22,7 +22,7 @@ namespace lite { 2136be168c0dSopenharmony_ci 2137be168c0dSopenharmony_ci // ********** PrimitiveBase ********** 2138be168c0dSopenharmony_ci NodeType MindIR_Primitive_GetType(PrimitivePtr primitive) { 2139be168c0dSopenharmony_ci- auto prim = flatbuffers::GetMutableRoot<schema::Primitive>(primitive); 2140be168c0dSopenharmony_ci+ auto prim = static_cast<schema::Primitive *>(primitive); 2141be168c0dSopenharmony_ci auto type = prim->value_type(); 2142be168c0dSopenharmony_ci return static_cast<NodeType>(type); 2143be168c0dSopenharmony_ci } 2144be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt 2145be168c0dSopenharmony_ciindex 5afccc87..de1781cd 100644 2146be168c0dSopenharmony_ci--- a/mindspore/lite/src/CMakeLists.txt 2147be168c0dSopenharmony_ci+++ b/mindspore/lite/src/CMakeLists.txt 2148be168c0dSopenharmony_ci@@ -410,6 +410,11 @@ add_subdirectory(common) 2149be168c0dSopenharmony_ci add_library(lite_src_mid OBJECT ${LITE_SRC}) 2150be168c0dSopenharmony_ci add_dependencies(lite_src_mid lite_src_common_mid fbs_src fbs_inner_src) 2151be168c0dSopenharmony_ci 2152be168c0dSopenharmony_ci+if(SUPPORT_NNRT) 2153be168c0dSopenharmony_ci+ add_subdirectory(litert/delegate/nnrt) 2154be168c0dSopenharmony_ci+ target_link_libraries(lite_src_mid nnrt_mid) 2155be168c0dSopenharmony_ci+endif() 2156be168c0dSopenharmony_ci+ 2157be168c0dSopenharmony_ci if(MSLITE_ENABLE_ACL) 2158be168c0dSopenharmony_ci include_directories(${TOP_DIR}/graphengine/910/inc/external) 2159be168c0dSopenharmony_ci if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)) 2160be168c0dSopenharmony_ci@@ 
-497,7 +502,6 @@ if(MSLITE_ENABLE_MINDRT) 2161be168c0dSopenharmony_ci endif() 2162be168c0dSopenharmony_ci 2163be168c0dSopenharmony_ci if (SUPPORT_NNRT) 2164be168c0dSopenharmony_ci- add_subdirectory(litert/delegate/nnrt) 2165be168c0dSopenharmony_ci target_link_libraries(mindspore-lite nnrt_mid) 2166be168c0dSopenharmony_ci target_link_libraries(mindspore-lite_static nnrt_mid) 2167be168c0dSopenharmony_ci endif() 2168be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/context_util.cc b/mindspore/lite/src/common/context_util.cc 2169be168c0dSopenharmony_ciindex f011e0d7..0fa4ebd0 100644 2170be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/context_util.cc 2171be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/context_util.cc 2172be168c0dSopenharmony_ci@@ -118,6 +118,17 @@ std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceCo 2173be168c0dSopenharmony_ci MS_CHECK_TRUE_RET(device_info != nullptr, nullptr); 2174be168c0dSopenharmony_ci return device_info; 2175be168c0dSopenharmony_ci } 2176be168c0dSopenharmony_ci+ 2177be168c0dSopenharmony_ci+std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext( 2178be168c0dSopenharmony_ci+ const lite::DeviceContext &nnrt_context) { 2179be168c0dSopenharmony_ci+ if (nnrt_context.device_type_ != DT_NNRT) { 2180be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Function input parameter is not NNRt context."; 2181be168c0dSopenharmony_ci+ return nullptr; 2182be168c0dSopenharmony_ci+ } 2183be168c0dSopenharmony_ci+ auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>(); 2184be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr); 2185be168c0dSopenharmony_ci+ return nnrt_info; 2186be168c0dSopenharmony_ci+} 2187be168c0dSopenharmony_ci } // namespace 2188be168c0dSopenharmony_ci 2189be168c0dSopenharmony_ci mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) { 2190be168c0dSopenharmony_ci@@ -144,7 +155,8 @@ mindspore::Context 
*MSContextFromContext(const std::shared_ptr<InnerContext> &co 2191be168c0dSopenharmony_ci {DT_GPU, GPUDeviceInfoFromGPUDeviceContext}, 2192be168c0dSopenharmony_ci {DT_NPU, NPUDeviceInfoFromNPUDeviceContext}, 2193be168c0dSopenharmony_ci {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext}, 2194be168c0dSopenharmony_ci- {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}}; 2195be168c0dSopenharmony_ci+ {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}, 2196be168c0dSopenharmony_ci+ {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}}; 2197be168c0dSopenharmony_ci for (auto &device_context : context->device_list_) { 2198be168c0dSopenharmony_ci auto device_type = device_context.device_type_; 2199be168c0dSopenharmony_ci if (transfer_funcs.find(device_type) == transfer_funcs.end()) { 2200be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/log.cc b/mindspore/lite/src/common/log.cc 2201be168c0dSopenharmony_ciindex 66c0d76b..f1040662 100644 2202be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/log.cc 2203be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/log.cc 2204be168c0dSopenharmony_ci@@ -21,6 +21,13 @@ 2205be168c0dSopenharmony_ci #include <android/log.h> 2206be168c0dSopenharmony_ci #endif 2207be168c0dSopenharmony_ci 2208be168c0dSopenharmony_ci+#ifdef MS_COMPILE_OHOS 2209be168c0dSopenharmony_ci+#define LOG_DOMAIN 0xD002102 2210be168c0dSopenharmony_ci+#define LOG_TAG "MS_LITE" 2211be168c0dSopenharmony_ci+#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2212be168c0dSopenharmony_ci+#include "hilog/log.h" 2213be168c0dSopenharmony_ci+#endif 2214be168c0dSopenharmony_ci+ 2215be168c0dSopenharmony_ci // namespace to support utils module definition namespace mindspore constexpr const char *ANDROID_LOG_TAG = "MS_LITE"; 2216be168c0dSopenharmony_ci namespace mindspore { 2217be168c0dSopenharmony_ci #if defined(__ANDROID__) 2218be168c0dSopenharmony_ci@@ -73,17 +80,33 @@ static int GetAndroidLogLevel(LiteLogLevel level) { 2219be168c0dSopenharmony_ci 
2220be168c0dSopenharmony_ci #ifdef MS_COMPILE_OHOS 2221be168c0dSopenharmony_ci void PrintHiLog(LiteLogLevel level, const char *file, int line, const char *func, const char *msg) { 2222be168c0dSopenharmony_ci+#ifdef MS_COMPILE_WITH_OHOS_NDK 2223be168c0dSopenharmony_ci+ // When build with OHOS NDK, use public api of hilog module. 2224be168c0dSopenharmony_ci if (level == LiteLogLevel::DEBUG) { 2225be168c0dSopenharmony_ci- OHOS::HiviewDFX::HiLog::Debug(MSLite_LABEL, FORMAT, file, line, func, msg); 2226be168c0dSopenharmony_ci+ OH_LOG_Print(LOG_APP, LOG_DEBUG, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2227be168c0dSopenharmony_ci } else if (level == LiteLogLevel::INFO) { 2228be168c0dSopenharmony_ci- OHOS::HiviewDFX::HiLog::Info(MSLite_LABEL, FORMAT, file, line, func, msg); 2229be168c0dSopenharmony_ci+ OH_LOG_Print(LOG_APP, LOG_INFO, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2230be168c0dSopenharmony_ci } else if (level == LiteLogLevel::WARNING) { 2231be168c0dSopenharmony_ci- OHOS::HiviewDFX::HiLog::Warn(MSLite_LABEL, FORMAT, file, line, func, msg); 2232be168c0dSopenharmony_ci+ OH_LOG_Print(LOG_APP, LOG_WARN, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2233be168c0dSopenharmony_ci } else if (level == LiteLogLevel::ERROR) { 2234be168c0dSopenharmony_ci- OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg); 2235be168c0dSopenharmony_ci+ OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2236be168c0dSopenharmony_ci } else { 2237be168c0dSopenharmony_ci- OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg); 2238be168c0dSopenharmony_ci+ OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg); 2239be168c0dSopenharmony_ci } 2240be168c0dSopenharmony_ci+#else 2241be168c0dSopenharmony_ci+ // When build in OHOS repo, use inner api of hilog module. 
2242be168c0dSopenharmony_ci+ if (level == LiteLogLevel::DEBUG) { 2243be168c0dSopenharmony_ci+ HILOG_DEBUG(LOG_APP, FORMAT, file, line, func, msg); 2244be168c0dSopenharmony_ci+ } else if (level == LiteLogLevel::INFO) { 2245be168c0dSopenharmony_ci+ HILOG_INFO(LOG_APP, FORMAT, file, line, func, msg); 2246be168c0dSopenharmony_ci+ } else if (level == LiteLogLevel::WARNING) { 2247be168c0dSopenharmony_ci+ HILOG_WARN(LOG_APP, FORMAT, file, line, func, msg); 2248be168c0dSopenharmony_ci+ } else if (level == LiteLogLevel::ERROR) { 2249be168c0dSopenharmony_ci+ HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg); 2250be168c0dSopenharmony_ci+ } else { 2251be168c0dSopenharmony_ci+ HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg); 2252be168c0dSopenharmony_ci+ } 2253be168c0dSopenharmony_ci+#endif 2254be168c0dSopenharmony_ci } 2255be168c0dSopenharmony_ci #endif 2256be168c0dSopenharmony_ci 2257be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/log.h b/mindspore/lite/src/common/log.h 2258be168c0dSopenharmony_ciindex 3002a454..bea21f01 100644 2259be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/log.h 2260be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/log.h 2261be168c0dSopenharmony_ci@@ -23,12 +23,6 @@ 2262be168c0dSopenharmony_ci #include <unordered_map> 2263be168c0dSopenharmony_ci #include "utils/overload.h" 2264be168c0dSopenharmony_ci 2265be168c0dSopenharmony_ci-#ifdef MS_COMPILE_OHOS 2266be168c0dSopenharmony_ci-#define LOG_DOMAIN 0x2102 2267be168c0dSopenharmony_ci-#define LOG_TAG "MS_Lite" 2268be168c0dSopenharmony_ci-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2269be168c0dSopenharmony_ci-#include "hilog/log.h" 2270be168c0dSopenharmony_ci-#endif 2271be168c0dSopenharmony_ci // NOTICE: when relative path of 'log.h' changed, macro 'LITE_LOG_HEAR_FILE_REL_PATH' must be changed 2272be168c0dSopenharmony_ci #ifndef LITE_LOG_HEAR_FILE_REL_PATH 2273be168c0dSopenharmony_ci #define LITE_LOG_HEAR_FILE_REL_PATH "mindspore/lite/src/common/log.h" 
2274be168c0dSopenharmony_ci@@ -140,6 +134,9 @@ class LiteLogWriter { 2275be168c0dSopenharmony_ci LiteLogLevel log_level_; 2276be168c0dSopenharmony_ci }; 2277be168c0dSopenharmony_ci 2278be168c0dSopenharmony_ci+#define MSLOG_IF(level) \ 2279be168c0dSopenharmony_ci+ mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2280be168c0dSopenharmony_ci+ mindspore::LiteLogStream() 2281be168c0dSopenharmony_ci 2282be168c0dSopenharmony_ci #define MS_LOG(level) MS_LOG_##level 2283be168c0dSopenharmony_ci 2284be168c0dSopenharmony_ci@@ -148,47 +145,6 @@ class LiteLogWriter { 2285be168c0dSopenharmony_ci #define MS_LOG_WARNING MSLOG_IF(mindspore::LiteLogLevel::WARNING) 2286be168c0dSopenharmony_ci #define MS_LOG_ERROR MSLOG_IF(mindspore::LiteLogLevel::ERROR) 2287be168c0dSopenharmony_ci 2288be168c0dSopenharmony_ci- 2289be168c0dSopenharmony_ci-#ifdef MS_COMPILE_OHOS 2290be168c0dSopenharmony_ci-namespace { 2291be168c0dSopenharmony_ci-constexpr unsigned int MSLITE_DOMAIN_ID_START = 0xD0029A0; 2292be168c0dSopenharmony_ci-constexpr unsigned int MSLITE_DOMAIN_ID_END = MSLITE_DOMAIN_ID_START + 32; 2293be168c0dSopenharmony_ci-constexpr unsigned int TEST_DOMAIN_ID = 0xD000F00; 2294be168c0dSopenharmony_ci-} // namespace 2295be168c0dSopenharmony_ci- 2296be168c0dSopenharmony_ci-#define FILE_NAME (__builtin_strrchr(__FILE__, '/') ? 
__builtin_strrchr(__FILE__, '/') + 1 : __FILE__) 2297be168c0dSopenharmony_ci-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s" 2298be168c0dSopenharmony_ci- 2299be168c0dSopenharmony_ci-#define MSLOG_IF(level) \ 2300be168c0dSopenharmony_ci- mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2301be168c0dSopenharmony_ci- mindspore::LiteLogStream() 2302be168c0dSopenharmony_ci- 2303be168c0dSopenharmony_ci-enum MSLiteManagerLogLabel { 2304be168c0dSopenharmony_ci- // Component labels, you can add if needed 2305be168c0dSopenharmony_ci- COMP_FWK = 0, 2306be168c0dSopenharmony_ci- // Test label 2307be168c0dSopenharmony_ci- LABEL_TEST, 2308be168c0dSopenharmony_ci- // The end of labels, max to the domain id range length 32 2309be168c0dSopenharmony_ci- LABEL_END, 2310be168c0dSopenharmony_ci-}; 2311be168c0dSopenharmony_ci- 2312be168c0dSopenharmony_ci-enum MSLiteManagerLogDomain { 2313be168c0dSopenharmony_ci- DOMAIN_FRAMEWORK = MSLITE_DOMAIN_ID_START + COMP_FWK, // 0xD0029A0 2314be168c0dSopenharmony_ci- DOMAIN_TEST = TEST_DOMAIN_ID, // 0xD000F00 2315be168c0dSopenharmony_ci- DOMAIN_END = MSLITE_DOMAIN_ID_END, // Max to 0xD002940, keep the sequence and length same as MSLiteManagerLogLabel 2316be168c0dSopenharmony_ci-}; 2317be168c0dSopenharmony_ci- 2318be168c0dSopenharmony_ci-// Keep the sequence and length same as MSLiteManagerLogDomain 2319be168c0dSopenharmony_ci-static constexpr OHOS::HiviewDFX::HiLogLabel MSLite_LABEL = {LOG_CORE, DOMAIN_FRAMEWORK, "MSLiteFwk"}; 2320be168c0dSopenharmony_ci- 2321be168c0dSopenharmony_ci-#else 2322be168c0dSopenharmony_ci- 2323be168c0dSopenharmony_ci-#define MSLOG_IF(level) \ 2324be168c0dSopenharmony_ci- mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \ 2325be168c0dSopenharmony_ci- mindspore::LiteLogStream() 2326be168c0dSopenharmony_ci- 2327be168c0dSopenharmony_ci-#endif 2328be168c0dSopenharmony_ci- 2329be168c0dSopenharmony_ci } 
// namespace mindspore 2330be168c0dSopenharmony_ci 2331be168c0dSopenharmony_ci #ifdef Debug 2332be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc 2333be168c0dSopenharmony_ciindex 5e1878b9..13957ed7 100644 2334be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc 2335be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc 2336be168c0dSopenharmony_ci@@ -19,6 +19,9 @@ 2337be168c0dSopenharmony_ci #include "nnacl/custom_parameter.h" 2338be168c0dSopenharmony_ci #include "nnacl/split_parameter.h" 2339be168c0dSopenharmony_ci #include "nnacl/custom_gru_parameter.h" 2340be168c0dSopenharmony_ci+#include "nnacl/custom_masked_fill_parameter.h" 2341be168c0dSopenharmony_ci+#include "nnacl/custom_is_inf_parameter.h" 2342be168c0dSopenharmony_ci+#include "nnacl/custom_tensor_scatter_max_parameter.h" 2343be168c0dSopenharmony_ci using mindspore::schema::PrimitiveType_Custom; 2344be168c0dSopenharmony_ci 2345be168c0dSopenharmony_ci namespace mindspore { 2346be168c0dSopenharmony_ci@@ -92,6 +95,39 @@ OpParameter *CreateCustomGruParameter() { 2347be168c0dSopenharmony_ci return reinterpret_cast<OpParameter *>(param); 2348be168c0dSopenharmony_ci } 2349be168c0dSopenharmony_ci 2350be168c0dSopenharmony_ci+OpParameter *CreateCustomIsInfParameter() { 2351be168c0dSopenharmony_ci+ auto *param = static_cast<CustomIsInfParameter *>(malloc(sizeof(CustomIsInfParameter))); 2352be168c0dSopenharmony_ci+ if (param == nullptr) { 2353be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc CustomIsInfParameter failed."; 2354be168c0dSopenharmony_ci+ return nullptr; 2355be168c0dSopenharmony_ci+ } 2356be168c0dSopenharmony_ci+ memset(param, 0, sizeof(CustomIsInfParameter)); 2357be168c0dSopenharmony_ci+ param->op_parameter_.type_ = PrimType_Inner_CustomIsInf; 2358be168c0dSopenharmony_ci+ return reinterpret_cast<OpParameter *>(param); 
2359be168c0dSopenharmony_ci+} 2360be168c0dSopenharmony_ci+ 2361be168c0dSopenharmony_ci+OpParameter *CreateCustomTensorScatterMaxParameter() { 2362be168c0dSopenharmony_ci+ auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter))); 2363be168c0dSopenharmony_ci+ if (param == nullptr) { 2364be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc CustomTensorScatterMaxParameter failed."; 2365be168c0dSopenharmony_ci+ return nullptr; 2366be168c0dSopenharmony_ci+ } 2367be168c0dSopenharmony_ci+ memset(param, 0, sizeof(CustomTensorScatterMaxParameter)); 2368be168c0dSopenharmony_ci+ param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax; 2369be168c0dSopenharmony_ci+ return reinterpret_cast<OpParameter *>(param); 2370be168c0dSopenharmony_ci+} 2371be168c0dSopenharmony_ci+ 2372be168c0dSopenharmony_ci+OpParameter *CreateCustomMaskedFillParameter() { 2373be168c0dSopenharmony_ci+ auto *param = static_cast<CustomMaskedFillParameter *>(malloc(sizeof(CustomMaskedFillParameter))); 2374be168c0dSopenharmony_ci+ if (param == nullptr) { 2375be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc CustomMaskedFillParameter failed."; 2376be168c0dSopenharmony_ci+ return nullptr; 2377be168c0dSopenharmony_ci+ } 2378be168c0dSopenharmony_ci+ memset(param, 0, sizeof(CustomMaskedFillParameter)); 2379be168c0dSopenharmony_ci+ param->op_parameter_.type_ = PrimType_Inner_CustomMaskedFill; 2380be168c0dSopenharmony_ci+ return reinterpret_cast<OpParameter *>(param); 2381be168c0dSopenharmony_ci+} 2382be168c0dSopenharmony_ci+ 2383be168c0dSopenharmony_ci OpParameter *PopulateCustomParameter(const void *prim) { 2384be168c0dSopenharmony_ci MS_CHECK_TRUE_RET(prim != nullptr, nullptr); 2385be168c0dSopenharmony_ci auto primitive = static_cast<const schema::Primitive *>(prim); 2386be168c0dSopenharmony_ci@@ -131,6 +167,23 @@ OpParameter *PopulateCustomParameter(const void *prim) { 2387be168c0dSopenharmony_ci return CreateCustomGruParameter(); 
2388be168c0dSopenharmony_ci } else if (type == "CastGatherReduceFusion") { 2389be168c0dSopenharmony_ci return CreateParam(PrimType_Inner_CastGatherReduceFusion); 2390be168c0dSopenharmony_ci+ } else if (type == "ThirdPartyModel") { 2391be168c0dSopenharmony_ci+ auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter))); 2392be168c0dSopenharmony_ci+ if (param == nullptr) { 2393be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc CustomParameter failed."; 2394be168c0dSopenharmony_ci+ return nullptr; 2395be168c0dSopenharmony_ci+ } 2396be168c0dSopenharmony_ci+ memset(param, 0, sizeof(CustomParameter)); 2397be168c0dSopenharmony_ci+ param->op_parameter_.type_ = PrimType_Inner_ThirdPartyModel; 2398be168c0dSopenharmony_ci+ // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary. 2399be168c0dSopenharmony_ci+ param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim)); 2400be168c0dSopenharmony_ci+ return reinterpret_cast<OpParameter *>(param); 2401be168c0dSopenharmony_ci+ } else if (type == "MaskedFill") { 2402be168c0dSopenharmony_ci+ return CreateCustomMaskedFillParameter(); 2403be168c0dSopenharmony_ci+ } else if (type == "TensorScatterMax") { 2404be168c0dSopenharmony_ci+ return CreateCustomTensorScatterMaxParameter(); 2405be168c0dSopenharmony_ci+ } else if (type == "IsInf") { 2406be168c0dSopenharmony_ci+ return CreateCustomIsInfParameter(); 2407be168c0dSopenharmony_ci } else { 2408be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported custom type: " << type; 2409be168c0dSopenharmony_ci } 2410be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc 2411be168c0dSopenharmony_ciindex f614ef09..c5f825aa 100644 2412be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.cc 2413be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.cc 2414be168c0dSopenharmony_ci@@ -14,12 +14,17 @@ 2415be168c0dSopenharmony_ci * 
limitations under the License. 2416be168c0dSopenharmony_ci */ 2417be168c0dSopenharmony_ci #include "include/c_api/context_c.h" 2418be168c0dSopenharmony_ci-#include "src/litert/c_api/context_c.h" 2419be168c0dSopenharmony_ci+#include "include/api/context.h" 2420be168c0dSopenharmony_ci+#include <string.h> 2421be168c0dSopenharmony_ci+#include "src/litert/c_api/type_c_private.h" 2422be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 2423be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT 2424be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 2425be168c0dSopenharmony_ci+#endif 2426be168c0dSopenharmony_ci 2427be168c0dSopenharmony_ci // ================ Context ================ 2428be168c0dSopenharmony_ci OH_AI_ContextHandle OH_AI_ContextCreate() { 2429be168c0dSopenharmony_ci- auto impl = new (std::nothrow) mindspore::ContextC; 2430be168c0dSopenharmony_ci+ auto impl = new (std::nothrow) mindspore::Context(); 2431be168c0dSopenharmony_ci if (impl == nullptr) { 2432be168c0dSopenharmony_ci MS_LOG(ERROR) << "memory allocation failed."; 2433be168c0dSopenharmony_ci return nullptr; 2434be168c0dSopenharmony_ci@@ -29,7 +34,7 @@ OH_AI_ContextHandle OH_AI_ContextCreate() { 2435be168c0dSopenharmony_ci 2436be168c0dSopenharmony_ci void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) { 2437be168c0dSopenharmony_ci if (context != nullptr && *context != nullptr) { 2438be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(*context); 2439be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(*context); 2440be168c0dSopenharmony_ci delete impl; 2441be168c0dSopenharmony_ci *context = nullptr; 2442be168c0dSopenharmony_ci } 2443be168c0dSopenharmony_ci@@ -40,8 +45,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) 2444be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2445be168c0dSopenharmony_ci return; 2446be168c0dSopenharmony_ci } 2447be168c0dSopenharmony_ci- auto 
impl = static_cast<mindspore::ContextC *>(context); 2448be168c0dSopenharmony_ci- impl->thread_num = thread_num; 2449be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2450be168c0dSopenharmony_ci+ impl->SetThreadNum(thread_num); 2451be168c0dSopenharmony_ci } 2452be168c0dSopenharmony_ci 2453be168c0dSopenharmony_ci int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 2454be168c0dSopenharmony_ci@@ -49,8 +54,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) { 2455be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2456be168c0dSopenharmony_ci return 0; 2457be168c0dSopenharmony_ci } 2458be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2459be168c0dSopenharmony_ci- return impl->thread_num; 2460be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2461be168c0dSopenharmony_ci+ return impl->GetThreadNum(); 2462be168c0dSopenharmony_ci } 2463be168c0dSopenharmony_ci 2464be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 2465be168c0dSopenharmony_ci@@ -58,8 +63,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) { 2466be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2467be168c0dSopenharmony_ci return; 2468be168c0dSopenharmony_ci } 2469be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2470be168c0dSopenharmony_ci- impl->affinity_mode = mode; 2471be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2472be168c0dSopenharmony_ci+ impl->SetThreadAffinity(mode); 2473be168c0dSopenharmony_ci return; 2474be168c0dSopenharmony_ci } 2475be168c0dSopenharmony_ci 2476be168c0dSopenharmony_ci@@ -68,8 +73,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) { 2477be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2478be168c0dSopenharmony_ci return 0; 
2479be168c0dSopenharmony_ci } 2480be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2481be168c0dSopenharmony_ci- return impl->affinity_mode; 2482be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2483be168c0dSopenharmony_ci+ return impl->GetThreadAffinityMode(); 2484be168c0dSopenharmony_ci } 2485be168c0dSopenharmony_ci 2486be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) { 2487be168c0dSopenharmony_ci@@ -78,8 +83,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i 2488be168c0dSopenharmony_ci return; 2489be168c0dSopenharmony_ci } 2490be168c0dSopenharmony_ci const std::vector<int32_t> vec_core_list(core_list, core_list + core_num); 2491be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2492be168c0dSopenharmony_ci- impl->affinity_core_list = vec_core_list; 2493be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2494be168c0dSopenharmony_ci+ impl->SetThreadAffinity(vec_core_list); 2495be168c0dSopenharmony_ci return; 2496be168c0dSopenharmony_ci } 2497be168c0dSopenharmony_ci 2498be168c0dSopenharmony_ci@@ -88,9 +93,18 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle 2499be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2500be168c0dSopenharmony_ci return nullptr; 2501be168c0dSopenharmony_ci } 2502be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2503be168c0dSopenharmony_ci- *core_num = impl->affinity_core_list.size(); 2504be168c0dSopenharmony_ci- return impl->affinity_core_list.data(); 2505be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2506be168c0dSopenharmony_ci+ auto affinity_core_list = impl->GetThreadAffinityCoreList(); 2507be168c0dSopenharmony_ci+ *core_num = affinity_core_list.size(); 2508be168c0dSopenharmony_ci+ 
int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t))); 2509be168c0dSopenharmony_ci+ if (core_list == nullptr) { 2510be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc core_list is null."; 2511be168c0dSopenharmony_ci+ return nullptr; 2512be168c0dSopenharmony_ci+ } 2513be168c0dSopenharmony_ci+ for (size_t i = 0; i < affinity_core_list.size(); i++) { 2514be168c0dSopenharmony_ci+ core_list[i] = affinity_core_list[i]; 2515be168c0dSopenharmony_ci+ } 2516be168c0dSopenharmony_ci+ return core_list; 2517be168c0dSopenharmony_ci } 2518be168c0dSopenharmony_ci 2519be168c0dSopenharmony_ci void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) { 2520be168c0dSopenharmony_ci@@ -98,8 +112,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle 2521be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2522be168c0dSopenharmony_ci return; 2523be168c0dSopenharmony_ci } 2524be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2525be168c0dSopenharmony_ci- impl->enable_parallel = is_parallel; 2526be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2527be168c0dSopenharmony_ci+ impl->SetEnableParallel(is_parallel); 2528be168c0dSopenharmony_ci } 2529be168c0dSopenharmony_ci 2530be168c0dSopenharmony_ci bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 2531be168c0dSopenharmony_ci@@ -107,8 +121,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) { 2532be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2533be168c0dSopenharmony_ci return false; 2534be168c0dSopenharmony_ci } 2535be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2536be168c0dSopenharmony_ci- return impl->enable_parallel; 2537be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2538be168c0dSopenharmony_ci+ return impl->GetEnableParallel(); 2539be168c0dSopenharmony_ci } 
2540be168c0dSopenharmony_ci 2541be168c0dSopenharmony_ci void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) { 2542be168c0dSopenharmony_ci@@ -116,25 +130,36 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan 2543be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2544be168c0dSopenharmony_ci return; 2545be168c0dSopenharmony_ci } 2546be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::ContextC *>(context); 2547be168c0dSopenharmony_ci- std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info)); 2548be168c0dSopenharmony_ci- impl->device_info_list.push_back(device); 2549be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::Context *>(context); 2550be168c0dSopenharmony_ci+ std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info)); 2551be168c0dSopenharmony_ci+ impl->MutableDeviceInfo().push_back(device); 2552be168c0dSopenharmony_ci } 2553be168c0dSopenharmony_ci 2554be168c0dSopenharmony_ci // ================ DeviceInfo ================ 2555be168c0dSopenharmony_ci OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) { 2556be168c0dSopenharmony_ci- mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC; 2557be168c0dSopenharmony_ci+ mindspore::DeviceInfoContext *impl; 2558be168c0dSopenharmony_ci+ if (OH_AI_DEVICETYPE_CPU == device_type) { 2559be168c0dSopenharmony_ci+ impl = new (std::nothrow) mindspore::CPUDeviceInfo(); 2560be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_GPU == device_type) { 2561be168c0dSopenharmony_ci+ impl = new (std::nothrow) mindspore::GPUDeviceInfo(); 2562be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_KIRIN_NPU == device_type) { 2563be168c0dSopenharmony_ci+ impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo(); 2564be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_NNRT == device_type) { 
2565be168c0dSopenharmony_ci+ impl = new (std::nothrow) mindspore::NNRTDeviceInfo(); 2566be168c0dSopenharmony_ci+ } else { 2567be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device_type is invalid."; 2568be168c0dSopenharmony_ci+ impl = nullptr; 2569be168c0dSopenharmony_ci+ } 2570be168c0dSopenharmony_ci if (impl == nullptr) { 2571be168c0dSopenharmony_ci MS_LOG(ERROR) << "memory allocation failed."; 2572be168c0dSopenharmony_ci return nullptr; 2573be168c0dSopenharmony_ci } 2574be168c0dSopenharmony_ci- impl->device_type = device_type; 2575be168c0dSopenharmony_ci return static_cast<OH_AI_DeviceInfoHandle>(impl); 2576be168c0dSopenharmony_ci } 2577be168c0dSopenharmony_ci 2578be168c0dSopenharmony_ci void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) { 2579be168c0dSopenharmony_ci if (device_info != nullptr && *device_info != nullptr) { 2580be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info); 2581be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info); 2582be168c0dSopenharmony_ci delete impl; 2583be168c0dSopenharmony_ci *device_info = nullptr; 2584be168c0dSopenharmony_ci } 2585be168c0dSopenharmony_ci@@ -145,8 +170,8 @@ void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char 2586be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2587be168c0dSopenharmony_ci return; 2588be168c0dSopenharmony_ci } 2589be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2590be168c0dSopenharmony_ci- impl->provider = provider; 2591be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2592be168c0dSopenharmony_ci+ impl->SetProvider(provider); 2593be168c0dSopenharmony_ci } 2594be168c0dSopenharmony_ci 2595be168c0dSopenharmony_ci const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) { 2596be168c0dSopenharmony_ci@@ -154,8 +179,14 @@ const char *OH_AI_DeviceInfoGetProvider(const 
OH_AI_DeviceInfoHandle device_info 2597be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2598be168c0dSopenharmony_ci return nullptr; 2599be168c0dSopenharmony_ci } 2600be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2601be168c0dSopenharmony_ci- return impl->provider.c_str(); 2602be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2603be168c0dSopenharmony_ci+ char *provider = static_cast<char *>(malloc(impl->GetProvider().size() + 1)); 2604be168c0dSopenharmony_ci+ if (provider == nullptr) { 2605be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc provider is null."; 2606be168c0dSopenharmony_ci+ return nullptr; 2607be168c0dSopenharmony_ci+ } 2608be168c0dSopenharmony_ci+ strcpy(provider, impl->GetProvider().c_str()); 2609be168c0dSopenharmony_ci+ return provider; 2610be168c0dSopenharmony_ci } 2611be168c0dSopenharmony_ci 2612be168c0dSopenharmony_ci void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) { 2613be168c0dSopenharmony_ci@@ -163,8 +194,8 @@ void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const 2614be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2615be168c0dSopenharmony_ci return; 2616be168c0dSopenharmony_ci } 2617be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2618be168c0dSopenharmony_ci- impl->provider_device = device; 2619be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2620be168c0dSopenharmony_ci+ impl->SetProviderDevice(device); 2621be168c0dSopenharmony_ci } 2622be168c0dSopenharmony_ci 2623be168c0dSopenharmony_ci const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) { 2624be168c0dSopenharmony_ci@@ -172,8 +203,14 @@ const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle devic 2625be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 
2626be168c0dSopenharmony_ci return nullptr; 2627be168c0dSopenharmony_ci } 2628be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2629be168c0dSopenharmony_ci- return impl->provider_device.c_str(); 2630be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2631be168c0dSopenharmony_ci+ char *provider_device = static_cast<char *>(malloc(impl->GetProviderDevice().size() + 1)); 2632be168c0dSopenharmony_ci+ if (provider_device == nullptr) { 2633be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc provider_device is null."; 2634be168c0dSopenharmony_ci+ return nullptr; 2635be168c0dSopenharmony_ci+ } 2636be168c0dSopenharmony_ci+ strcpy(provider_device, impl->GetProviderDevice().c_str()); 2637be168c0dSopenharmony_ci+ return provider_device; 2638be168c0dSopenharmony_ci } 2639be168c0dSopenharmony_ci 2640be168c0dSopenharmony_ci OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) { 2641be168c0dSopenharmony_ci@@ -181,8 +218,8 @@ OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle devi 2642be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2643be168c0dSopenharmony_ci return OH_AI_DEVICETYPE_INVALID; 2644be168c0dSopenharmony_ci } 2645be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2646be168c0dSopenharmony_ci- return impl->device_type; 2647be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info); 2648be168c0dSopenharmony_ci+ return static_cast<OH_AI_DeviceType>(impl->GetDeviceType()); 2649be168c0dSopenharmony_ci } 2650be168c0dSopenharmony_ci 2651be168c0dSopenharmony_ci void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) { 2652be168c0dSopenharmony_ci@@ -190,9 +227,17 @@ void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_f 2653be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 
2654be168c0dSopenharmony_ci return; 2655be168c0dSopenharmony_ci } 2656be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2657be168c0dSopenharmony_ci- if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) { 2658be168c0dSopenharmony_ci- impl->enable_fp16 = is_fp16; 2659be168c0dSopenharmony_ci+ 2660be168c0dSopenharmony_ci+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2661be168c0dSopenharmony_ci+ if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2662be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info); 2663be168c0dSopenharmony_ci+ impl->SetEnableFP16(is_fp16); 2664be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2665be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info); 2666be168c0dSopenharmony_ci+ impl->SetEnableFP16(is_fp16); 2667be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2668be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info); 2669be168c0dSopenharmony_ci+ impl->SetEnableFP16(is_fp16); 2670be168c0dSopenharmony_ci } else { 2671be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported Feature."; 2672be168c0dSopenharmony_ci } 2673be168c0dSopenharmony_ci@@ -203,11 +248,19 @@ bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) { 2674be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2675be168c0dSopenharmony_ci return false; 2676be168c0dSopenharmony_ci } 2677be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2678be168c0dSopenharmony_ci- if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) { 2679be168c0dSopenharmony_ci- return impl->enable_fp16; 
2680be168c0dSopenharmony_ci+ 2681be168c0dSopenharmony_ci+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2682be168c0dSopenharmony_ci+ if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2683be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info); 2684be168c0dSopenharmony_ci+ return impl->GetEnableFP16(); 2685be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2686be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info); 2687be168c0dSopenharmony_ci+ return impl->GetEnableFP16(); 2688be168c0dSopenharmony_ci+ } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) { 2689be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info); 2690be168c0dSopenharmony_ci+ return impl->GetEnableFP16(); 2691be168c0dSopenharmony_ci } else { 2692be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type; 2693be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported Feature. 
device_type: " << impl_device->GetDeviceType(); 2694be168c0dSopenharmony_ci return false; 2695be168c0dSopenharmony_ci } 2696be168c0dSopenharmony_ci } 2697be168c0dSopenharmony_ci@@ -217,9 +270,10 @@ void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int freque 2698be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2699be168c0dSopenharmony_ci return; 2700be168c0dSopenharmony_ci } 2701be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2702be168c0dSopenharmony_ci- if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) { 2703be168c0dSopenharmony_ci- impl->frequency = frequency; 2704be168c0dSopenharmony_ci+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2705be168c0dSopenharmony_ci+ if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) { 2706be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info); 2707be168c0dSopenharmony_ci+ impl->SetFrequency(frequency); 2708be168c0dSopenharmony_ci } else { 2709be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported Feature."; 2710be168c0dSopenharmony_ci } 2711be168c0dSopenharmony_ci@@ -230,11 +284,231 @@ int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) { // 2712be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 2713be168c0dSopenharmony_ci return -1; 2714be168c0dSopenharmony_ci } 2715be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::DeviceInfoC *>(device_info); 2716be168c0dSopenharmony_ci- if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) { 2717be168c0dSopenharmony_ci- return impl->frequency; 2718be168c0dSopenharmony_ci+ auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info); 2719be168c0dSopenharmony_ci+ if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) { 2720be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::KirinNPUDeviceInfo 
*>(device_info); 2721be168c0dSopenharmony_ci+ return impl->GetFrequency(); 2722be168c0dSopenharmony_ci } else { 2723be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported Feature."; 2724be168c0dSopenharmony_ci return -1; 2725be168c0dSopenharmony_ci } 2726be168c0dSopenharmony_ci } 2727be168c0dSopenharmony_ci+ 2728be168c0dSopenharmony_ci+NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num) { 2729be168c0dSopenharmony_ci+ if (num == nullptr) { 2730be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Input num is null"; 2731be168c0dSopenharmony_ci+ return nullptr; 2732be168c0dSopenharmony_ci+ } 2733be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT 2734be168c0dSopenharmony_ci+ *num = 0; 2735be168c0dSopenharmony_ci+ 2736be168c0dSopenharmony_ci+ const size_t *all_device_ids; 2737be168c0dSopenharmony_ci+ uint32_t device_count; 2738be168c0dSopenharmony_ci+ auto ret = OH_NNDevice_GetAllDevicesID(&all_device_ids, &device_count); 2739be168c0dSopenharmony_ci+ if ((ret != OH_NN_SUCCESS) || (device_count == 0)) { 2740be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNRT get all device id failed, ret: " << ret; 2741be168c0dSopenharmony_ci+ return nullptr; 2742be168c0dSopenharmony_ci+ } 2743be168c0dSopenharmony_ci+ 2744be168c0dSopenharmony_ci+ NNRTDeviceDesc *desc = (NNRTDeviceDesc *)malloc(sizeof(NNRTDeviceDesc) * device_count); 2745be168c0dSopenharmony_ci+ if (desc == nullptr) { 2746be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNRT allocate desc failed"; 2747be168c0dSopenharmony_ci+ return nullptr; 2748be168c0dSopenharmony_ci+ } 2749be168c0dSopenharmony_ci+ 2750be168c0dSopenharmony_ci+ for (uint32_t i = 0; i < device_count; i++) { 2751be168c0dSopenharmony_ci+ desc[i].device_id = all_device_ids[i]; 2752be168c0dSopenharmony_ci+ OH_NN_DeviceType type; 2753be168c0dSopenharmony_ci+ (void)OH_NNDevice_GetType(all_device_ids[i], &type); 2754be168c0dSopenharmony_ci+ desc[i].device_type = static_cast<OH_AI_NNRTDeviceType>(type); 2755be168c0dSopenharmony_ci+ 2756be168c0dSopenharmony_ci+ const char *name = 
nullptr; 2757be168c0dSopenharmony_ci+ (void)OH_NNDevice_GetName(all_device_ids[i], &name); 2758be168c0dSopenharmony_ci+ desc[i].device_name[127] = '\0'; 2759be168c0dSopenharmony_ci+ strncpy(desc[i].device_name, name, 127); 2760be168c0dSopenharmony_ci+ } 2761be168c0dSopenharmony_ci+ *num = device_count; 2762be168c0dSopenharmony_ci+ return desc; 2763be168c0dSopenharmony_ci+#else 2764be168c0dSopenharmony_ci+ return nullptr; 2765be168c0dSopenharmony_ci+#endif 2766be168c0dSopenharmony_ci+} 2767be168c0dSopenharmony_ci+ 2768be168c0dSopenharmony_ci+NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index) { 2769be168c0dSopenharmony_ci+ if (descs == nullptr) { 2770be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "descs is null"; 2771be168c0dSopenharmony_ci+ return nullptr; 2772be168c0dSopenharmony_ci+ } 2773be168c0dSopenharmony_ci+ return descs + index; 2774be168c0dSopenharmony_ci+} 2775be168c0dSopenharmony_ci+ 2776be168c0dSopenharmony_ci+void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc) { 2777be168c0dSopenharmony_ci+ if (desc == nullptr) { 2778be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "desc is null"; 2779be168c0dSopenharmony_ci+ return; 2780be168c0dSopenharmony_ci+ } 2781be168c0dSopenharmony_ci+ free(*desc); 2782be168c0dSopenharmony_ci+ *desc = nullptr; 2783be168c0dSopenharmony_ci+} 2784be168c0dSopenharmony_ci+ 2785be168c0dSopenharmony_ci+size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2786be168c0dSopenharmony_ci+ if (desc == nullptr) { 2787be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNRT desc is null"; 2788be168c0dSopenharmony_ci+ return 0; 2789be168c0dSopenharmony_ci+ } 2790be168c0dSopenharmony_ci+ return desc->device_id; 2791be168c0dSopenharmony_ci+} 2792be168c0dSopenharmony_ci+ 2793be168c0dSopenharmony_ci+const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2794be168c0dSopenharmony_ci+ if (desc == nullptr) { 2795be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNRT desc is null"; 
2796be168c0dSopenharmony_ci+ return nullptr; 2797be168c0dSopenharmony_ci+ } 2798be168c0dSopenharmony_ci+ return desc->device_name; 2799be168c0dSopenharmony_ci+} 2800be168c0dSopenharmony_ci+ 2801be168c0dSopenharmony_ci+OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) { 2802be168c0dSopenharmony_ci+ if (desc == nullptr) { 2803be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNRT desc is null"; 2804be168c0dSopenharmony_ci+ return OH_AI_NNRTDeviceType::OH_AI_NNRTDEVICE_OTHERS; 2805be168c0dSopenharmony_ci+ } 2806be168c0dSopenharmony_ci+ return desc->device_type; 2807be168c0dSopenharmony_ci+} 2808be168c0dSopenharmony_ci+ 2809be168c0dSopenharmony_ci+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name) { 2810be168c0dSopenharmony_ci+ size_t num = 0; 2811be168c0dSopenharmony_ci+ NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num); 2812be168c0dSopenharmony_ci+ if (desc == nullptr) { 2813be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get all device desc failed"; 2814be168c0dSopenharmony_ci+ return nullptr; 2815be168c0dSopenharmony_ci+ } 2816be168c0dSopenharmony_ci+ 2817be168c0dSopenharmony_ci+ OH_AI_DeviceInfoHandle handle = nullptr; 2818be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 2819be168c0dSopenharmony_ci+ if (strncmp(desc[i].device_name, name, NNRT_DEVICE_NAME_MAX - 1) == 0) { 2820be168c0dSopenharmony_ci+ handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 2821be168c0dSopenharmony_ci+ OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id); 2822be168c0dSopenharmony_ci+ break; 2823be168c0dSopenharmony_ci+ } 2824be168c0dSopenharmony_ci+ } 2825be168c0dSopenharmony_ci+ OH_AI_DestroyAllNNRTDeviceDescs(&desc); 2826be168c0dSopenharmony_ci+ return handle; 2827be168c0dSopenharmony_ci+} 2828be168c0dSopenharmony_ci+ 2829be168c0dSopenharmony_ci+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type) { 2830be168c0dSopenharmony_ci+ size_t num = 0; 2831be168c0dSopenharmony_ci+ NNRTDeviceDesc 
*desc = OH_AI_GetAllNNRTDeviceDescs(&num); 2832be168c0dSopenharmony_ci+ if (desc == nullptr) { 2833be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get all device desc failed"; 2834be168c0dSopenharmony_ci+ return nullptr; 2835be168c0dSopenharmony_ci+ } 2836be168c0dSopenharmony_ci+ 2837be168c0dSopenharmony_ci+ OH_AI_DeviceInfoHandle handle = nullptr; 2838be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 2839be168c0dSopenharmony_ci+ if (desc[i].device_type == type) { 2840be168c0dSopenharmony_ci+ handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 2841be168c0dSopenharmony_ci+ OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id); 2842be168c0dSopenharmony_ci+ break; 2843be168c0dSopenharmony_ci+ } 2844be168c0dSopenharmony_ci+ } 2845be168c0dSopenharmony_ci+ OH_AI_DestroyAllNNRTDeviceDescs(&desc); 2846be168c0dSopenharmony_ci+ return handle; 2847be168c0dSopenharmony_ci+} 2848be168c0dSopenharmony_ci+ 2849be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id) { 2850be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2851be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2852be168c0dSopenharmony_ci+ return; 2853be168c0dSopenharmony_ci+ } 2854be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2855be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Set device_id of non-NNRT device is not allowable, ignored"; 2856be168c0dSopenharmony_ci+ return; 2857be168c0dSopenharmony_ci+ } 2858be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2859be168c0dSopenharmony_ci+ impl->SetDeviceID(device_id); 2860be168c0dSopenharmony_ci+} 2861be168c0dSopenharmony_ci+ 2862be168c0dSopenharmony_ci+size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info) { 2863be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2864be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2865be168c0dSopenharmony_ci+ return 0; 
2866be168c0dSopenharmony_ci+ } 2867be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2868be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get device_id of non-NNRT device is not allowable, ignored"; 2869be168c0dSopenharmony_ci+ return 0; 2870be168c0dSopenharmony_ci+ } 2871be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2872be168c0dSopenharmony_ci+ return impl->GetDeviceID(); 2873be168c0dSopenharmony_ci+} 2874be168c0dSopenharmony_ci+ 2875be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode) { 2876be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2877be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2878be168c0dSopenharmony_ci+ return; 2879be168c0dSopenharmony_ci+ } 2880be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2881be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Set performance_mode of non-NNRT device is not allowable, ignored"; 2882be168c0dSopenharmony_ci+ return; 2883be168c0dSopenharmony_ci+ } 2884be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2885be168c0dSopenharmony_ci+ impl->SetPerformanceMode(mode); 2886be168c0dSopenharmony_ci+} 2887be168c0dSopenharmony_ci+ 2888be168c0dSopenharmony_ci+OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info) { 2889be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2890be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2891be168c0dSopenharmony_ci+ return OH_AI_PERFORMANCE_NONE; 2892be168c0dSopenharmony_ci+ } 2893be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2894be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get performance_mode of non-NNRT device is not allowable, ignored"; 2895be168c0dSopenharmony_ci+ return OH_AI_PERFORMANCE_NONE; 
2896be168c0dSopenharmony_ci+ } 2897be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2898be168c0dSopenharmony_ci+ return static_cast<OH_AI_PerformanceMode>(impl->GetPerformanceMode()); 2899be168c0dSopenharmony_ci+} 2900be168c0dSopenharmony_ci+ 2901be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority) { 2902be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2903be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2904be168c0dSopenharmony_ci+ return; 2905be168c0dSopenharmony_ci+ } 2906be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2907be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Set priority of non-NNRT device is not allowable, ignored"; 2908be168c0dSopenharmony_ci+ return; 2909be168c0dSopenharmony_ci+ } 2910be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2911be168c0dSopenharmony_ci+ impl->SetPriority(priority); 2912be168c0dSopenharmony_ci+} 2913be168c0dSopenharmony_ci+ 2914be168c0dSopenharmony_ci+OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info) { 2915be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2916be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2917be168c0dSopenharmony_ci+ return OH_AI_PRIORITY_NONE; 2918be168c0dSopenharmony_ci+ } 2919be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2920be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get priority of non-NNRT device is not allowable, ignored"; 2921be168c0dSopenharmony_ci+ return OH_AI_PRIORITY_NONE; 2922be168c0dSopenharmony_ci+ } 2923be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2924be168c0dSopenharmony_ci+ return static_cast<OH_AI_Priority>(impl->GetPriority()); 2925be168c0dSopenharmony_ci+} 2926be168c0dSopenharmony_ci+ 
2927be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, 2928be168c0dSopenharmony_ci+ const char *name, const char*value, size_t value_size) { 2929be168c0dSopenharmony_ci+ if (device_info == nullptr) { 2930be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "device info is null"; 2931be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 2932be168c0dSopenharmony_ci+ } 2933be168c0dSopenharmony_ci+ if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) { 2934be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Add extension to non-NNRT device is not allowable, ignored"; 2935be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 2936be168c0dSopenharmony_ci+ } 2937be168c0dSopenharmony_ci+ auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info); 2938be168c0dSopenharmony_ci+ mindspore::Extension extension; 2939be168c0dSopenharmony_ci+ extension.name = std::string(name); 2940be168c0dSopenharmony_ci+ extension.value = std::vector<uint8_t>(value, value + value_size); 2941be168c0dSopenharmony_ci+ std::vector<mindspore::Extension> extension_list = impl->GetExtensions(); 2942be168c0dSopenharmony_ci+ extension_list.push_back(extension); 2943be168c0dSopenharmony_ci+ impl->SetExtensions(extension_list); 2944be168c0dSopenharmony_ci+ return OH_AI_STATUS_SUCCESS; 2945be168c0dSopenharmony_ci+} 2946be168c0dSopenharmony_ci\ No newline at end of file 2947be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h 2948be168c0dSopenharmony_ciindex 076f4d1f..dc88b8a4 100644 2949be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.h 2950be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.h 2951be168c0dSopenharmony_ci@@ -21,27 +21,4 @@ 2952be168c0dSopenharmony_ci #include <memory> 2953be168c0dSopenharmony_ci #include "include/c_api/types_c.h" 2954be168c0dSopenharmony_ci 2955be168c0dSopenharmony_ci-namespace 
mindspore { 2956be168c0dSopenharmony_ci-class Allocator; 2957be168c0dSopenharmony_ci-class Delegate; 2958be168c0dSopenharmony_ci- 2959be168c0dSopenharmony_ci-typedef struct DeviceInfoC { 2960be168c0dSopenharmony_ci- OH_AI_DeviceType device_type; 2961be168c0dSopenharmony_ci- bool enable_fp16 = false; 2962be168c0dSopenharmony_ci- int frequency = 3; 2963be168c0dSopenharmony_ci- std::string provider; 2964be168c0dSopenharmony_ci- std::string provider_device; 2965be168c0dSopenharmony_ci- std::shared_ptr<Allocator> allocator = nullptr; 2966be168c0dSopenharmony_ci-} DeviceInfoC; 2967be168c0dSopenharmony_ci- 2968be168c0dSopenharmony_ci-typedef struct ContextC { 2969be168c0dSopenharmony_ci- std::vector<std::shared_ptr<DeviceInfoC>> device_info_list; 2970be168c0dSopenharmony_ci- int32_t thread_num = 2; 2971be168c0dSopenharmony_ci- bool enable_parallel = false; 2972be168c0dSopenharmony_ci- std::vector<int32_t> affinity_core_list; 2973be168c0dSopenharmony_ci- int affinity_mode = 0; 2974be168c0dSopenharmony_ci- int delegate_mode = 0; 2975be168c0dSopenharmony_ci- std::shared_ptr<Delegate> delegate = nullptr; 2976be168c0dSopenharmony_ci-} ContextC; 2977be168c0dSopenharmony_ci-} // namespace mindspore 2978be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_ 2979be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc 2980be168c0dSopenharmony_ciindex 802df6b1..9da52d76 100644 2981be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/model_c.cc 2982be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/model_c.cc 2983be168c0dSopenharmony_ci@@ -17,321 +17,135 @@ 2984be168c0dSopenharmony_ci #include <vector> 2985be168c0dSopenharmony_ci #include <cstdint> 2986be168c0dSopenharmony_ci #include "include/api/context.h" 2987be168c0dSopenharmony_ci+#include <include/api/serialization.h> 2988be168c0dSopenharmony_ci #include "include/api/types.h" 2989be168c0dSopenharmony_ci #include 
"src/litert/cxx_api/tensor/tensor_impl.h" 2990be168c0dSopenharmony_ci #include "src/litert/cxx_api/converters.h" 2991be168c0dSopenharmony_ci-#include "src/litert/lite_session.h" 2992be168c0dSopenharmony_ci-#include "src/litert/cpu_info.h" 2993be168c0dSopenharmony_ci+#include "src/litert//cxx_api/model/model_impl.h" 2994be168c0dSopenharmony_ci 2995be168c0dSopenharmony_ci namespace mindspore { 2996be168c0dSopenharmony_ci class ModelC { 2997be168c0dSopenharmony_ci- public: 2998be168c0dSopenharmony_ci- ModelC() : session_(nullptr), context_(nullptr) {} 2999be168c0dSopenharmony_ci+public: 3000be168c0dSopenharmony_ci+ ModelC() : model_(nullptr) {} 3001be168c0dSopenharmony_ci ~ModelC() { 3002be168c0dSopenharmony_ci- for (auto &impl : tensor_map_) { 3003be168c0dSopenharmony_ci- delete impl.second; 3004be168c0dSopenharmony_ci+ for (auto in : inputs_) { 3005be168c0dSopenharmony_ci+ delete in; 3006be168c0dSopenharmony_ci+ } 3007be168c0dSopenharmony_ci+ for (auto out : outputs_) { 3008be168c0dSopenharmony_ci+ delete out; 3009be168c0dSopenharmony_ci+ } 3010be168c0dSopenharmony_ci+ for (auto out : outputs_train_) { 3011be168c0dSopenharmony_ci+ delete out; 3012be168c0dSopenharmony_ci } 3013be168c0dSopenharmony_ci } 3014be168c0dSopenharmony_ci 3015be168c0dSopenharmony_ci- Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context); 3016be168c0dSopenharmony_ci- Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context); 3017be168c0dSopenharmony_ci- Status Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes); 3018be168c0dSopenharmony_ci- 3019be168c0dSopenharmony_ci- Status Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, size_t *output_num, 3020be168c0dSopenharmony_ci- const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after); 3021be168c0dSopenharmony_ci- 3022be168c0dSopenharmony_ci- LiteTensorImpl 
**GetInputs(size_t *input_num); 3023be168c0dSopenharmony_ci- LiteTensorImpl **GetOutputs(size_t *output_num); 3024be168c0dSopenharmony_ci+ MSTensor **GetInputs(size_t *input_num); 3025be168c0dSopenharmony_ci+ MSTensor **GetOutputs(size_t *output_num); 3026be168c0dSopenharmony_ci+ mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &oh_callback); 3027be168c0dSopenharmony_ci+ std::shared_ptr<Model> model_; 3028be168c0dSopenharmony_ci+ std::shared_ptr<Context> context_; 3029be168c0dSopenharmony_ci 3030be168c0dSopenharmony_ci- private: 3031be168c0dSopenharmony_ci- Status RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after); 3032be168c0dSopenharmony_ci- void ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors); 3033be168c0dSopenharmony_ci- LiteTensorImpl *TensorToTensorImpl(mindspore::lite::Tensor *tensor); 3034be168c0dSopenharmony_ci- 3035be168c0dSopenharmony_ci- private: 3036be168c0dSopenharmony_ci- std::shared_ptr<lite::LiteSession> session_ = nullptr; 3037be168c0dSopenharmony_ci- std::shared_ptr<const ContextC> context_ = nullptr; 3038be168c0dSopenharmony_ci- std::map<mindspore::lite::Tensor *, LiteTensorImpl *> tensor_map_; 3039be168c0dSopenharmony_ci- std::vector<LiteTensorImpl *> inputs_; 3040be168c0dSopenharmony_ci- std::vector<LiteTensorImpl *> outputs_; 3041be168c0dSopenharmony_ci- bool is_already_built = false; 3042be168c0dSopenharmony_ci+private: 3043be168c0dSopenharmony_ci+ MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors); 3044be168c0dSopenharmony_ci+ std::vector<MSTensor *> inputs_; 3045be168c0dSopenharmony_ci+ std::vector<MSTensor *> outputs_; 3046be168c0dSopenharmony_ci+ std::vector<MSTensor *> outputs_train_; 3047be168c0dSopenharmony_ci }; 3048be168c0dSopenharmony_ci 3049be168c0dSopenharmony_ci-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) { 3050be168c0dSopenharmony_ci- if 
(is_already_built) { 3051be168c0dSopenharmony_ci- MS_LOG(ERROR) << "The model is already built."; 3052be168c0dSopenharmony_ci- return kLiteModelRebuild; 3053be168c0dSopenharmony_ci- } 3054be168c0dSopenharmony_ci- if (!PlatformInstructionSetSupportCheck()) { 3055be168c0dSopenharmony_ci- MS_LOG(ERROR) << "The platform exist don't support's instruction."; 3056be168c0dSopenharmony_ci- return kLiteNotSupport; 3057be168c0dSopenharmony_ci- } 3058be168c0dSopenharmony_ci- if(context_.get() != model_context){ 3059be168c0dSopenharmony_ci- context_.reset(model_context); 3060be168c0dSopenharmony_ci- } 3061be168c0dSopenharmony_ci- session_ = std::make_shared<lite::LiteSession>(); 3062be168c0dSopenharmony_ci- if (session_ == nullptr) { 3063be168c0dSopenharmony_ci- MS_LOG(ERROR) << "create session failed"; 3064be168c0dSopenharmony_ci- return kLiteNullptr; 3065be168c0dSopenharmony_ci- } 3066be168c0dSopenharmony_ci- auto ret = session_->Init(ContextUtils::Convert(model_context)); 3067be168c0dSopenharmony_ci- if (ret != mindspore::lite::RET_OK) { 3068be168c0dSopenharmony_ci- MS_LOG(ERROR) << "init session failed"; 3069be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3070be168c0dSopenharmony_ci- } 3071be168c0dSopenharmony_ci- ret = session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size); 3072be168c0dSopenharmony_ci- if (ret != RET_OK) { 3073be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Load and compile failed"; 3074be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3075be168c0dSopenharmony_ci- } 3076be168c0dSopenharmony_ci- is_already_built = true; 3077be168c0dSopenharmony_ci- return static_cast<StatusCode>(kSuccess); 3078be168c0dSopenharmony_ci-} 3079be168c0dSopenharmony_ci- 3080be168c0dSopenharmony_ci-Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) { 3081be168c0dSopenharmony_ci- if (is_already_built) { 3082be168c0dSopenharmony_ci- MS_LOG(ERROR) << "The model is 
already built."; 3083be168c0dSopenharmony_ci- return kLiteModelRebuild; 3084be168c0dSopenharmony_ci- } 3085be168c0dSopenharmony_ci- if (!PlatformInstructionSetSupportCheck()) { 3086be168c0dSopenharmony_ci- MS_LOG(ERROR) << "The platform exist don't support's instruction."; 3087be168c0dSopenharmony_ci- return kLiteNotSupport; 3088be168c0dSopenharmony_ci- } 3089be168c0dSopenharmony_ci- if(context_.get() != model_context){ 3090be168c0dSopenharmony_ci- context_.reset(model_context); 3091be168c0dSopenharmony_ci- } 3092be168c0dSopenharmony_ci- session_ = std::make_shared<lite::LiteSession>(); 3093be168c0dSopenharmony_ci- if (session_ == nullptr) { 3094be168c0dSopenharmony_ci- MS_LOG(ERROR) << "create session failed"; 3095be168c0dSopenharmony_ci- return kLiteNullptr; 3096be168c0dSopenharmony_ci- } 3097be168c0dSopenharmony_ci- auto ret = session_->Init(ContextUtils::Convert(model_context)); 3098be168c0dSopenharmony_ci- if (ret != mindspore::lite::RET_OK) { 3099be168c0dSopenharmony_ci- MS_LOG(ERROR) << "init session failed"; 3100be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3101be168c0dSopenharmony_ci+MSTensor **ModelC::GetInputs(size_t *input_num) { 3102be168c0dSopenharmony_ci+ if (model_ == nullptr) { 3103be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_ is nullptr."; 3104be168c0dSopenharmony_ci+ return nullptr; 3105be168c0dSopenharmony_ci } 3106be168c0dSopenharmony_ci- ret = session_->LoadModelAndCompileByPath(model_path, model_type); 3107be168c0dSopenharmony_ci- if (ret != RET_OK) { 3108be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Load and compile failed"; 3109be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3110be168c0dSopenharmony_ci+ if (!inputs_.empty()) { 3111be168c0dSopenharmony_ci+ *input_num = inputs_.size(); 3112be168c0dSopenharmony_ci+ return inputs_.data(); 3113be168c0dSopenharmony_ci } 3114be168c0dSopenharmony_ci- is_already_built = true; 3115be168c0dSopenharmony_ci- return static_cast<StatusCode>(kSuccess); 
3116be168c0dSopenharmony_ci-} 3117be168c0dSopenharmony_ci 3118be168c0dSopenharmony_ci-Status ModelC::Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) { 3119be168c0dSopenharmony_ci- std::vector<lite::Tensor *> inner_input; 3120be168c0dSopenharmony_ci- size_t input_num = inputs.size(); 3121be168c0dSopenharmony_ci- for (size_t i = 0; i < input_num; i++) { 3122be168c0dSopenharmony_ci- auto input = inputs[i]; 3123be168c0dSopenharmony_ci- if (input == nullptr || input->lite_tensor() == nullptr) { 3124be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Input tensor is null."; 3125be168c0dSopenharmony_ci- return kLiteInputTensorError; 3126be168c0dSopenharmony_ci+ auto inputs = model_->GetInputs(); 3127be168c0dSopenharmony_ci+ *input_num = inputs.size(); 3128be168c0dSopenharmony_ci+ inputs_.resize(inputs.size(), nullptr); 3129be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs.size(); i++) { 3130be168c0dSopenharmony_ci+ inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl()); 3131be168c0dSopenharmony_ci+ if (inputs_[i] == nullptr) { 3132be168c0dSopenharmony_ci+ inputs_.clear(); 3133be168c0dSopenharmony_ci+ return nullptr; 3134be168c0dSopenharmony_ci } 3135be168c0dSopenharmony_ci- inner_input.push_back(input->lite_tensor()); 3136be168c0dSopenharmony_ci } 3137be168c0dSopenharmony_ci- size_t shape_num = shapes.size(); 3138be168c0dSopenharmony_ci- std::vector<std::vector<int32_t>> inner_shapes(shape_num); 3139be168c0dSopenharmony_ci- for (size_t i = 0; i < shape_num; i++) { 3140be168c0dSopenharmony_ci- std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(inner_shapes[i]), 3141be168c0dSopenharmony_ci- [](int64_t value) { return static_cast<int32_t>(value); }); 3142be168c0dSopenharmony_ci- } 3143be168c0dSopenharmony_ci- if (session_ == nullptr) { 3144be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Session implement is null."; 3145be168c0dSopenharmony_ci- return kLiteNullptr; 3146be168c0dSopenharmony_ci- } 
3147be168c0dSopenharmony_ci- auto ret = session_->Resize(inner_input, inner_shapes); 3148be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3149be168c0dSopenharmony_ci+ return inputs_.data(); 3150be168c0dSopenharmony_ci } 3151be168c0dSopenharmony_ci 3152be168c0dSopenharmony_ci-void ModelC::ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors) { 3153be168c0dSopenharmony_ci- for (size_t j = 0; j < old_data.size(); j++) { 3154be168c0dSopenharmony_ci- tensors.at(j)->set_data(old_data.at(j)); 3155be168c0dSopenharmony_ci+MSTensor **ModelC::GetOutputs(size_t *output_num) { 3156be168c0dSopenharmony_ci+ if (model_->GetTrainMode() == true) { 3157be168c0dSopenharmony_ci+ return GetOutputsTensor(output_num, &outputs_train_); 3158be168c0dSopenharmony_ci+ } else { 3159be168c0dSopenharmony_ci+ return GetOutputsTensor(output_num, &outputs_); 3160be168c0dSopenharmony_ci } 3161be168c0dSopenharmony_ci } 3162be168c0dSopenharmony_ci 3163be168c0dSopenharmony_ci-Status ModelC::Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, 3164be168c0dSopenharmony_ci- size_t *output_num, const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) { 3165be168c0dSopenharmony_ci- if (outputs == nullptr || session_ == nullptr) { 3166be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3167be168c0dSopenharmony_ci- return kLiteError; 3168be168c0dSopenharmony_ci+MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) { 3169be168c0dSopenharmony_ci+ if (model_ == nullptr) { 3170be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_ is nullptr."; 3171be168c0dSopenharmony_ci+ return nullptr; 3172be168c0dSopenharmony_ci } 3173be168c0dSopenharmony_ci- auto model_inputs = session_->GetInputs(); 3174be168c0dSopenharmony_ci- if (model_inputs.size() != input_num) { 3175be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Wrong input size."; 3176be168c0dSopenharmony_ci- return kLiteError; 
3177be168c0dSopenharmony_ci+ if (!vec_tensors->empty()) { 3178be168c0dSopenharmony_ci+ *output_num = vec_tensors->size(); 3179be168c0dSopenharmony_ci+ return vec_tensors->data(); 3180be168c0dSopenharmony_ci } 3181be168c0dSopenharmony_ci- std::vector<void *> old_data; 3182be168c0dSopenharmony_ci- for (size_t i = 0; i < input_num; i++) { 3183be168c0dSopenharmony_ci- auto real_input = model_inputs[i]; 3184be168c0dSopenharmony_ci- auto user_input = static_cast<LiteTensorImpl *>(inputs[i]); 3185be168c0dSopenharmony_ci- if (user_input->DataType() != static_cast<DataType>(real_input->data_type())) { 3186be168c0dSopenharmony_ci- ResetTensorData(old_data, model_inputs); 3187be168c0dSopenharmony_ci- MS_LOG(ERROR) << "DataType does not match, input:" << user_input->Name() 3188be168c0dSopenharmony_ci- << ", real:" << real_input->tensor_name(); 3189be168c0dSopenharmony_ci- return kLiteInputTensorError; 3190be168c0dSopenharmony_ci- } 3191be168c0dSopenharmony_ci- if (user_input->Data() == nullptr) { 3192be168c0dSopenharmony_ci- ResetTensorData(old_data, model_inputs); 3193be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has no data."; 3194be168c0dSopenharmony_ci- return kLiteInputTensorError; 3195be168c0dSopenharmony_ci- } 3196be168c0dSopenharmony_ci 3197be168c0dSopenharmony_ci- // GPU tensor can't manipulate CPU memory which the user provides. 3198be168c0dSopenharmony_ci- // When model input is GPU tensor and user input is NOT GPU data, 3199be168c0dSopenharmony_ci- // just free model input's data for late GPU Tensor filling. 3200be168c0dSopenharmony_ci- if (IS_OPENCL_ALLOCATOR(real_input->allocator()) && (!IS_OPENCL_ALLOCATOR(user_input->GetAllocator()))) { 3201be168c0dSopenharmony_ci- real_input->FreeData(); 3202be168c0dSopenharmony_ci- } 3203be168c0dSopenharmony_ci- old_data.push_back(real_input->data()); // Save original data in model tensors. 
3204be168c0dSopenharmony_ci- 3205be168c0dSopenharmony_ci- if (real_input->data_type() == kObjectTypeString) { 3206be168c0dSopenharmony_ci- std::vector<int32_t> shape; 3207be168c0dSopenharmony_ci- std::transform(user_input->Shape().begin(), user_input->Shape().end(), std::back_inserter(shape), 3208be168c0dSopenharmony_ci- [](int64_t value) { return static_cast<int32_t>(value); }); 3209be168c0dSopenharmony_ci- real_input->set_shape(shape); 3210be168c0dSopenharmony_ci- real_input->set_data(user_input->MutableData()); 3211be168c0dSopenharmony_ci- } else { 3212be168c0dSopenharmony_ci- if (user_input->MutableData() != real_input->data()) { 3213be168c0dSopenharmony_ci- if (real_input->Size() != user_input->DataSize()) { 3214be168c0dSopenharmony_ci- ResetTensorData(old_data, model_inputs); 3215be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has wrong data size."; 3216be168c0dSopenharmony_ci- return kLiteInputTensorError; 3217be168c0dSopenharmony_ci- } 3218be168c0dSopenharmony_ci- if (!IS_OPENCL_ALLOCATOR(real_input->allocator())) { 3219be168c0dSopenharmony_ci- real_input->set_data(user_input->MutableData()); 3220be168c0dSopenharmony_ci- } else { 3221be168c0dSopenharmony_ci- // Use outside CPU data to fill GPU Tensor. 
3222be168c0dSopenharmony_ci- auto dst_data = real_input->MutableData(); 3223be168c0dSopenharmony_ci- auto src_data = user_input->MutableData(); 3224be168c0dSopenharmony_ci- (void)memcpy(dst_data, src_data, real_input->Size()); 3225be168c0dSopenharmony_ci- } 3226be168c0dSopenharmony_ci- } 3227be168c0dSopenharmony_ci- } 3228be168c0dSopenharmony_ci- } 3229be168c0dSopenharmony_ci- auto ret = RunGraph(before, after); 3230be168c0dSopenharmony_ci- ResetTensorData(old_data, model_inputs); 3231be168c0dSopenharmony_ci- if (ret != kSuccess) { 3232be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Run graph failed."; 3233be168c0dSopenharmony_ci- return ret; 3234be168c0dSopenharmony_ci- } 3235be168c0dSopenharmony_ci- 3236be168c0dSopenharmony_ci- *outputs = reinterpret_cast<OH_AI_TensorHandle *>(GetOutputs(output_num)); 3237be168c0dSopenharmony_ci- return kSuccess; 3238be168c0dSopenharmony_ci-} 3239be168c0dSopenharmony_ci- 3240be168c0dSopenharmony_ci-Status ModelC::RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) { 3241be168c0dSopenharmony_ci- KernelCallBack before_call_back = nullptr; 3242be168c0dSopenharmony_ci- KernelCallBack after_call_back = nullptr; 3243be168c0dSopenharmony_ci- if (before != nullptr) { 3244be168c0dSopenharmony_ci- before_call_back = [&](const std::vector<mindspore::lite::Tensor *> &before_inputs, 3245be168c0dSopenharmony_ci- const std::vector<mindspore::lite::Tensor *> &before_outputs, 3246be168c0dSopenharmony_ci- const MSCallBackParam &call_param) { 3247be168c0dSopenharmony_ci- std::vector<LiteTensorImpl> inputs_impl; 3248be168c0dSopenharmony_ci- std::vector<LiteTensorImpl> outputs_impl; 3249be168c0dSopenharmony_ci- std::vector<OH_AI_TensorHandle> op_inputs; 3250be168c0dSopenharmony_ci- std::vector<OH_AI_TensorHandle> op_outputs; 3251be168c0dSopenharmony_ci- size_t op_input_num = before_inputs.size(); 3252be168c0dSopenharmony_ci- for (size_t i = 0; i < op_input_num; i++) { 3253be168c0dSopenharmony_ci- 
inputs_impl.emplace_back(before_inputs[i]); 3254be168c0dSopenharmony_ci- op_inputs.push_back(&(inputs_impl.back())); 3255be168c0dSopenharmony_ci- } 3256be168c0dSopenharmony_ci- size_t op_output_num = before_outputs.size(); 3257be168c0dSopenharmony_ci- for (size_t i = 0; i < op_output_num; i++) { 3258be168c0dSopenharmony_ci- outputs_impl.emplace_back(before_outputs[i]); 3259be168c0dSopenharmony_ci- op_outputs.push_back(&(outputs_impl.back())); 3260be168c0dSopenharmony_ci- } 3261be168c0dSopenharmony_ci- const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()), 3262be168c0dSopenharmony_ci- const_cast<char *>(call_param.node_type.c_str())}; 3263be168c0dSopenharmony_ci- OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()}; 3264be168c0dSopenharmony_ci- OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()}; 3265be168c0dSopenharmony_ci- return before(inputs, outputs, op_info); 3266be168c0dSopenharmony_ci- }; 3267be168c0dSopenharmony_ci- } 3268be168c0dSopenharmony_ci- if (after != nullptr) { 3269be168c0dSopenharmony_ci- after_call_back = [&](const std::vector<mindspore::lite::Tensor *> &after_inputs, 3270be168c0dSopenharmony_ci- const std::vector<mindspore::lite::Tensor *> &after_outputs, 3271be168c0dSopenharmony_ci- const MSCallBackParam &call_param) { 3272be168c0dSopenharmony_ci- std::vector<LiteTensorImpl> inputs_impl; 3273be168c0dSopenharmony_ci- std::vector<LiteTensorImpl> outputs_impl; 3274be168c0dSopenharmony_ci- std::vector<OH_AI_TensorHandle> op_inputs; 3275be168c0dSopenharmony_ci- std::vector<OH_AI_TensorHandle> op_outputs; 3276be168c0dSopenharmony_ci- size_t op_input_num = after_inputs.size(); 3277be168c0dSopenharmony_ci- for (size_t i = 0; i < op_input_num; i++) { 3278be168c0dSopenharmony_ci- inputs_impl.emplace_back(after_inputs[i]); 3279be168c0dSopenharmony_ci- op_inputs.push_back(&(inputs_impl.back())); 3280be168c0dSopenharmony_ci- } 3281be168c0dSopenharmony_ci- size_t op_output_num = 
after_outputs.size(); 3282be168c0dSopenharmony_ci- for (size_t i = 0; i < op_output_num; i++) { 3283be168c0dSopenharmony_ci- outputs_impl.emplace_back(after_outputs[i]); 3284be168c0dSopenharmony_ci- op_outputs.push_back(&(outputs_impl.back())); 3285be168c0dSopenharmony_ci- } 3286be168c0dSopenharmony_ci- const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()), 3287be168c0dSopenharmony_ci- const_cast<char *>(call_param.node_type.c_str())}; 3288be168c0dSopenharmony_ci- OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()}; 3289be168c0dSopenharmony_ci- OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()}; 3290be168c0dSopenharmony_ci- return after(inputs, outputs, op_info); 3291be168c0dSopenharmony_ci- }; 3292be168c0dSopenharmony_ci- } 3293be168c0dSopenharmony_ci- auto ret = session_->RunGraph(before_call_back, after_call_back); 3294be168c0dSopenharmony_ci- return static_cast<StatusCode>(ret); 3295be168c0dSopenharmony_ci-} 3296be168c0dSopenharmony_ci- 3297be168c0dSopenharmony_ci-LiteTensorImpl *ModelC::TensorToTensorImpl(mindspore::lite::Tensor *tensor) { 3298be168c0dSopenharmony_ci- LiteTensorImpl *impl = nullptr; 3299be168c0dSopenharmony_ci- auto iter = tensor_map_.find(tensor); 3300be168c0dSopenharmony_ci- if (iter != tensor_map_.end()) { 3301be168c0dSopenharmony_ci- impl = iter->second; 3302be168c0dSopenharmony_ci- } else { 3303be168c0dSopenharmony_ci- impl = new (std::nothrow) LiteTensorImpl(tensor); 3304be168c0dSopenharmony_ci- if (impl == nullptr || impl->lite_tensor() == nullptr) { 3305be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Create tensor failed."; 3306be168c0dSopenharmony_ci+ auto outputs = model_->GetOutputs(); 3307be168c0dSopenharmony_ci+ *output_num = outputs.size(); 3308be168c0dSopenharmony_ci+ vec_tensors->resize(outputs.size(), nullptr); 3309be168c0dSopenharmony_ci+ for (size_t i = 0; i < outputs.size(); i++) { 3310be168c0dSopenharmony_ci+ (*vec_tensors)[i] = new (std::nothrow) 
MSTensor(outputs[i].impl()); 3311be168c0dSopenharmony_ci+ if ((*vec_tensors)[i] == nullptr) { 3312be168c0dSopenharmony_ci+ vec_tensors->clear(); 3313be168c0dSopenharmony_ci return nullptr; 3314be168c0dSopenharmony_ci } 3315be168c0dSopenharmony_ci- tensor_map_[tensor] = impl; 3316be168c0dSopenharmony_ci } 3317be168c0dSopenharmony_ci- return impl; 3318be168c0dSopenharmony_ci+ return vec_tensors->data(); 3319be168c0dSopenharmony_ci } 3320be168c0dSopenharmony_ci 3321be168c0dSopenharmony_ci-LiteTensorImpl **ModelC::GetInputs(size_t *input_num) { 3322be168c0dSopenharmony_ci- if (session_ == nullptr || input_num == nullptr) { 3323be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Session is null."; 3324be168c0dSopenharmony_ci- return nullptr; 3325be168c0dSopenharmony_ci- } 3326be168c0dSopenharmony_ci- auto inputs = session_->GetInputs(); 3327be168c0dSopenharmony_ci- *input_num = inputs.size(); 3328be168c0dSopenharmony_ci- if (inputs_.capacity() < *input_num) { 3329be168c0dSopenharmony_ci- inputs_.reserve(*input_num); 3330be168c0dSopenharmony_ci- } 3331be168c0dSopenharmony_ci- inputs_.clear(); 3332be168c0dSopenharmony_ci- std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_), 3333be168c0dSopenharmony_ci- [&](lite::Tensor *input) { return TensorToTensorImpl(input); }); 3334be168c0dSopenharmony_ci- return inputs_.data(); 3335be168c0dSopenharmony_ci-} 3336be168c0dSopenharmony_ci+mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &oh_callback) { 3337be168c0dSopenharmony_ci+ mindspore::MSKernelCallBack call_back = nullptr; 3338be168c0dSopenharmony_ci+ if (oh_callback != nullptr) { 3339be168c0dSopenharmony_ci+ call_back = [&](const std::vector<mindspore::MSTensor> &inputs, 3340be168c0dSopenharmony_ci+ const std::vector<mindspore::MSTensor> &outputs, 3341be168c0dSopenharmony_ci+ const mindspore::MSCallBackParam &opInfo) { 3342be168c0dSopenharmony_ci+ std::vector<OH_AI_TensorHandle> vec_inputs; 3343be168c0dSopenharmony_ci+ 
std::vector<OH_AI_TensorHandle> vec_outputs; 3344be168c0dSopenharmony_ci+ OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()), 3345be168c0dSopenharmony_ci+ const_cast<char *>(opInfo.node_type.c_str())}; 3346be168c0dSopenharmony_ci+ size_t inputs_handle_num = inputs.size(); 3347be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs_handle_num; i++) { 3348be168c0dSopenharmony_ci+ vec_inputs.push_back( 3349be168c0dSopenharmony_ci+ static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(inputs)[i]))); 3350be168c0dSopenharmony_ci+ } 3351be168c0dSopenharmony_ci+ size_t outputs_handle_num = inputs.size(); 3352be168c0dSopenharmony_ci+ for (size_t i = 0; i < outputs_handle_num; i++) { 3353be168c0dSopenharmony_ci+ vec_outputs.push_back( 3354be168c0dSopenharmony_ci+ static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(outputs)[i]))); 3355be168c0dSopenharmony_ci+ } 3356be168c0dSopenharmony_ci 3357be168c0dSopenharmony_ci-LiteTensorImpl **ModelC::GetOutputs(size_t *output_num) { 3358be168c0dSopenharmony_ci- if (session_ == nullptr || output_num == nullptr) { 3359be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Session is null."; 3360be168c0dSopenharmony_ci- return nullptr; 3361be168c0dSopenharmony_ci- } 3362be168c0dSopenharmony_ci- auto outputs = session_->GetOutputs(); 3363be168c0dSopenharmony_ci- *output_num = outputs.size(); 3364be168c0dSopenharmony_ci- if (outputs_.capacity() < *output_num) { 3365be168c0dSopenharmony_ci- outputs_.reserve(*output_num); 3366be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()}; 3367be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()}; 3368be168c0dSopenharmony_ci+ return oh_callback(handle_inputs, handle_outputs, call_back); 3369be168c0dSopenharmony_ci+ }; 3370be168c0dSopenharmony_ci } 3371be168c0dSopenharmony_ci- outputs_.clear(); 3372be168c0dSopenharmony_ci- 
std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_), 3373be168c0dSopenharmony_ci- [&](std::unordered_map<std::string, mindspore::lite::Tensor *>::value_type iter) { 3374be168c0dSopenharmony_ci- return TensorToTensorImpl(iter.second); 3375be168c0dSopenharmony_ci- }); 3376be168c0dSopenharmony_ci- return outputs_.data(); 3377be168c0dSopenharmony_ci+ return call_back; 3378be168c0dSopenharmony_ci } 3379be168c0dSopenharmony_ci } // namespace mindspore 3380be168c0dSopenharmony_ci 3381be168c0dSopenharmony_ci OH_AI_ModelHandle OH_AI_ModelCreate() { 3382be168c0dSopenharmony_ci auto impl = new (std::nothrow) mindspore::ModelC(); 3383be168c0dSopenharmony_ci if (impl == nullptr) { 3384be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Model implement is null."; 3385be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Model implement is nullptr."; 3386be168c0dSopenharmony_ci+ return nullptr; 3387be168c0dSopenharmony_ci+ } 3388be168c0dSopenharmony_ci+ impl->model_ = std::make_shared<mindspore::Model>(); 3389be168c0dSopenharmony_ci+ if (impl->model_ == nullptr) { 3390be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_ is nullptr."; 3391be168c0dSopenharmony_ci+ delete impl; 3392be168c0dSopenharmony_ci return nullptr; 3393be168c0dSopenharmony_ci } 3394be168c0dSopenharmony_ci return static_cast<OH_AI_ModelHandle>(impl); 3395be168c0dSopenharmony_ci@@ -358,55 +172,59 @@ size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) { 3396be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 3397be168c0dSopenharmony_ci OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) { 3398be168c0dSopenharmony_ci if (model == nullptr || model_data == nullptr || model_context == nullptr) { 3399be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3400be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/model_data/model_context is nullptr."; 3401be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NULLPTR; 
3402be168c0dSopenharmony_ci } 3403be168c0dSopenharmony_ci if (model_type == OH_AI_MODELTYPE_INVALID) { 3404be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is invalid."; 3405be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_type is invalid."; 3406be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_PARAM_INVALID; 3407be168c0dSopenharmony_ci } 3408be168c0dSopenharmony_ci- mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 3409be168c0dSopenharmony_ci+ mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 3410be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3411be168c0dSopenharmony_ci- auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context); 3412be168c0dSopenharmony_ci+ if (impl->context_.get() != context) { 3413be168c0dSopenharmony_ci+ impl->context_.reset(context); 3414be168c0dSopenharmony_ci+ } 3415be168c0dSopenharmony_ci+ auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_); 3416be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 3417be168c0dSopenharmony_ci } 3418be168c0dSopenharmony_ci 3419be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type, 3420be168c0dSopenharmony_ci const OH_AI_ContextHandle model_context) { 3421be168c0dSopenharmony_ci if (model == nullptr || model_path == nullptr || model_context == nullptr) { 3422be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3423be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/model_path/model_context is nullptr."; 3424be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NULLPTR; 3425be168c0dSopenharmony_ci } 3426be168c0dSopenharmony_ci if (model_type == OH_AI_MODELTYPE_INVALID) { 3427be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is invalid."; 3428be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_type is invalid."; 
3429be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_PARAM_INVALID; 3430be168c0dSopenharmony_ci } 3431be168c0dSopenharmony_ci- mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context); 3432be168c0dSopenharmony_ci+ mindspore::Context *context = static_cast<mindspore::Context *>(model_context); 3433be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3434be168c0dSopenharmony_ci- auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context); 3435be168c0dSopenharmony_ci+ if (impl->context_.get() != context) { 3436be168c0dSopenharmony_ci+ impl->context_.reset(context); 3437be168c0dSopenharmony_ci+ } 3438be168c0dSopenharmony_ci+ auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_); 3439be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 3440be168c0dSopenharmony_ci } 3441be168c0dSopenharmony_ci 3442be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs, 3443be168c0dSopenharmony_ci OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) { 3444be168c0dSopenharmony_ci if (model == nullptr || shape_infos == nullptr) { 3445be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3446be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/shape_infos is nullptr."; 3447be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NULLPTR; 3448be168c0dSopenharmony_ci } 3449be168c0dSopenharmony_ci- std::vector<mindspore::LiteTensorImpl *> vec_inputs; 3450be168c0dSopenharmony_ci- std::transform(inputs.handle_list, inputs.handle_list + inputs.handle_num, std::back_inserter(vec_inputs), 3451be168c0dSopenharmony_ci- [](OH_AI_TensorHandle value) { return static_cast<mindspore::LiteTensorImpl *>(value); }); 3452be168c0dSopenharmony_ci+ std::vector<mindspore::MSTensor> vec_inputs; 3453be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs.handle_num; ++i) { 3454be168c0dSopenharmony_ci+ 
vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i])); 3455be168c0dSopenharmony_ci+ } 3456be168c0dSopenharmony_ci+ 3457be168c0dSopenharmony_ci std::vector<std::vector<int64_t>> vec_dims; 3458be168c0dSopenharmony_ci for (size_t i = 0; i < shape_info_num; i++) { 3459be168c0dSopenharmony_ci std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num); 3460be168c0dSopenharmony_ci- if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) { 3461be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]"; 3462be168c0dSopenharmony_ci- return OH_AI_STATUS_LITE_PARAM_INVALID; 3463be168c0dSopenharmony_ci- } 3464be168c0dSopenharmony_ci vec_dims.push_back(shape); 3465be168c0dSopenharmony_ci } 3466be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3467be168c0dSopenharmony_ci- auto ret = impl->Resize(vec_inputs, vec_dims); 3468be168c0dSopenharmony_ci+ auto ret = impl->model_->Resize(vec_inputs, vec_dims); 3469be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 3470be168c0dSopenharmony_ci } 3471be168c0dSopenharmony_ci 3472be168c0dSopenharmony_ci@@ -414,15 +232,25 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl 3473be168c0dSopenharmony_ci OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before, 3474be168c0dSopenharmony_ci const OH_AI_KernelCallBack after) { 3475be168c0dSopenharmony_ci if (model == nullptr) { 3476be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3477be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3478be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NULLPTR; 3479be168c0dSopenharmony_ci } 3480be168c0dSopenharmony_ci+ std::vector<mindspore::MSTensor> ms_tensor_inputs; 3481be168c0dSopenharmony_ci+ for (size_t i = 0; i < inputs.handle_num; i++) { 3482be168c0dSopenharmony_ci+ auto 
user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]); 3483be168c0dSopenharmony_ci+ ms_tensor_inputs.push_back(*user_input); 3484be168c0dSopenharmony_ci+ } 3485be168c0dSopenharmony_ci+ 3486be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3487be168c0dSopenharmony_ci- auto ret = impl->Predict(inputs.handle_list, inputs.handle_num, &(outputs->handle_list), &(outputs->handle_num), 3488be168c0dSopenharmony_ci- before, after); 3489be168c0dSopenharmony_ci+ mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before); 3490be168c0dSopenharmony_ci+ mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after); 3491be168c0dSopenharmony_ci+ 3492be168c0dSopenharmony_ci+ std::vector<mindspore::MSTensor> ms_tensor_outputs; 3493be168c0dSopenharmony_ci+ auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back); 3494be168c0dSopenharmony_ci if (!ret.IsOk()) { 3495be168c0dSopenharmony_ci MS_LOG(ERROR) << "Predict fail, ret :" << ret; 3496be168c0dSopenharmony_ci } 3497be168c0dSopenharmony_ci+ outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&outputs->handle_num)); 3498be168c0dSopenharmony_ci return static_cast<OH_AI_Status>(ret.StatusCode()); 3499be168c0dSopenharmony_ci } 3500be168c0dSopenharmony_ci 3501be168c0dSopenharmony_ci@@ -431,11 +259,6 @@ OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallB 3502be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NOT_SUPPORT; 3503be168c0dSopenharmony_ci } 3504be168c0dSopenharmony_ci 3505be168c0dSopenharmony_ci-OH_AI_Status OH_AI_ModelSetTrainMode(const OH_AI_ModelHandle model, bool train) { 3506be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Unsupported Feature."; 3507be168c0dSopenharmony_ci- return OH_AI_STATUS_LITE_NOT_SUPPORT; 3508be168c0dSopenharmony_ci-} 3509be168c0dSopenharmony_ci- 3510be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle 
model, const char *export_path) { 3511be168c0dSopenharmony_ci MS_LOG(ERROR) << "Unsupported Feature."; 3512be168c0dSopenharmony_ci return OH_AI_STATUS_LITE_NOT_SUPPORT; 3513be168c0dSopenharmony_ci@@ -443,7 +266,7 @@ OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char * 3514be168c0dSopenharmony_ci 3515be168c0dSopenharmony_ci OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 3516be168c0dSopenharmony_ci if (model == nullptr) { 3517be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3518be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3519be168c0dSopenharmony_ci return {0, nullptr}; 3520be168c0dSopenharmony_ci } 3521be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3522be168c0dSopenharmony_ci@@ -454,7 +277,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) { 3523be168c0dSopenharmony_ci 3524be168c0dSopenharmony_ci OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) { 3525be168c0dSopenharmony_ci if (model == nullptr) { 3526be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3527be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3528be168c0dSopenharmony_ci return {0, nullptr}; 3529be168c0dSopenharmony_ci } 3530be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3531be168c0dSopenharmony_ci@@ -465,7 +288,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) { 3532be168c0dSopenharmony_ci 3533be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 3534be168c0dSopenharmony_ci if (model == nullptr || tensor_name == nullptr) { 3535be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3536be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/tensor_name is nullptr."; 3537be168c0dSopenharmony_ci return nullptr; 3538be168c0dSopenharmony_ci } 3539be168c0dSopenharmony_ci 
auto impl = static_cast<mindspore::ModelC *>(model); 3540be168c0dSopenharmony_ci@@ -482,7 +305,7 @@ OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model 3541be168c0dSopenharmony_ci 3542be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) { 3543be168c0dSopenharmony_ci if (model == nullptr || tensor_name == nullptr) { 3544be168c0dSopenharmony_ci- MS_LOG(ERROR) << "param is nullptr."; 3545be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/tensor_name is nullptr."; 3546be168c0dSopenharmony_ci return nullptr; 3547be168c0dSopenharmony_ci } 3548be168c0dSopenharmony_ci auto impl = static_cast<mindspore::ModelC *>(model); 3549be168c0dSopenharmony_ci@@ -496,3 +319,294 @@ OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle mode 3550be168c0dSopenharmony_ci MS_LOG(ERROR) << "tensor is not exist."; 3551be168c0dSopenharmony_ci return nullptr; 3552be168c0dSopenharmony_ci } 3553be168c0dSopenharmony_ci+ 3554be168c0dSopenharmony_ci+OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() { 3555be168c0dSopenharmony_ci+ auto impl = new (std::nothrow) mindspore::TrainCfg(); 3556be168c0dSopenharmony_ci+ if (impl == nullptr) { 3557be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "TrainCfg implement is nullptr."; 3558be168c0dSopenharmony_ci+ return nullptr; 3559be168c0dSopenharmony_ci+ } 3560be168c0dSopenharmony_ci+ return static_cast<OH_AI_TrainCfgHandle>(impl); 3561be168c0dSopenharmony_ci+} 3562be168c0dSopenharmony_ci+ 3563be168c0dSopenharmony_ci+void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) { 3564be168c0dSopenharmony_ci+ if (train_cfg != nullptr && *train_cfg != nullptr) { 3565be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg); 3566be168c0dSopenharmony_ci+ delete impl; 3567be168c0dSopenharmony_ci+ *train_cfg = nullptr; 3568be168c0dSopenharmony_ci+ } 3569be168c0dSopenharmony_ci+} 3570be168c0dSopenharmony_ci+ 
3571be168c0dSopenharmony_ci+char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) { 3572be168c0dSopenharmony_ci+ if (train_cfg == nullptr || num == nullptr) { 3573be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "train_cfg/num is nullptr."; 3574be168c0dSopenharmony_ci+ return nullptr; 3575be168c0dSopenharmony_ci+ } 3576be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::TrainCfg *>(train_cfg); 3577be168c0dSopenharmony_ci+ auto loss_name = impl->GetLossName(); 3578be168c0dSopenharmony_ci+ *num = loss_name.size(); 3579be168c0dSopenharmony_ci+ char **name = static_cast<char **>(malloc(loss_name.size())); 3580be168c0dSopenharmony_ci+ if (name == nullptr) { 3581be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Failed to malloc loss_name."; 3582be168c0dSopenharmony_ci+ return nullptr; 3583be168c0dSopenharmony_ci+ } 3584be168c0dSopenharmony_ci+ for (size_t i = 0; i < loss_name.size(); i++) { 3585be168c0dSopenharmony_ci+ name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1)); 3586be168c0dSopenharmony_ci+ strcpy(name[i], loss_name[i].c_str()); 3587be168c0dSopenharmony_ci+ } 3588be168c0dSopenharmony_ci+ return name; 3589be168c0dSopenharmony_ci+} 3590be168c0dSopenharmony_ci+ 3591be168c0dSopenharmony_ci+void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) { 3592be168c0dSopenharmony_ci+ if (train_cfg == nullptr) { 3593be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "train_cfg is nullptr."; 3594be168c0dSopenharmony_ci+ return; 3595be168c0dSopenharmony_ci+ } 3596be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::TrainCfg *>(train_cfg); 3597be168c0dSopenharmony_ci+ std::vector<std::string> vec_name; 3598be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 3599be168c0dSopenharmony_ci+ vec_name.push_back(loss_name[i]); 3600be168c0dSopenharmony_ci+ } 3601be168c0dSopenharmony_ci+ impl->SetLossName(vec_name); 3602be168c0dSopenharmony_ci+} 3603be168c0dSopenharmony_ci+ 
3604be168c0dSopenharmony_ci+OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) { 3605be168c0dSopenharmony_ci+ if (train_cfg == nullptr) { 3606be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "train_cfg is nullptr."; 3607be168c0dSopenharmony_ci+ return OH_AI_KO0; 3608be168c0dSopenharmony_ci+ } 3609be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::TrainCfg *>(train_cfg); 3610be168c0dSopenharmony_ci+ return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_); 3611be168c0dSopenharmony_ci+} 3612be168c0dSopenharmony_ci+ 3613be168c0dSopenharmony_ci+void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) { 3614be168c0dSopenharmony_ci+ if (train_cfg == nullptr) { 3615be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "train_cfg is nullptr."; 3616be168c0dSopenharmony_ci+ return; 3617be168c0dSopenharmony_ci+ } 3618be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::TrainCfg *>(train_cfg); 3619be168c0dSopenharmony_ci+ impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level); 3620be168c0dSopenharmony_ci+} 3621be168c0dSopenharmony_ci+ 3622be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size, 3623be168c0dSopenharmony_ci+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 3624be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle train_cfg) { 3625be168c0dSopenharmony_ci+ if (model == nullptr || model_data == nullptr || model_context == nullptr) { 3626be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/model_data/model_context is nullptr."; 3627be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 3628be168c0dSopenharmony_ci+ } 3629be168c0dSopenharmony_ci+ if (model_type == OH_AI_MODELTYPE_INVALID) { 3630be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_type is invalid."; 3631be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3632be168c0dSopenharmony_ci+ } 
3633be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3634be168c0dSopenharmony_ci+ 3635be168c0dSopenharmony_ci+ mindspore::Graph graph; 3636be168c0dSopenharmony_ci+ auto status = mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph); 3637be168c0dSopenharmony_ci+ if (status != mindspore::kSuccess) { 3638be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "load ms file failed."; 3639be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 3640be168c0dSopenharmony_ci+ } 3641be168c0dSopenharmony_ci+ auto context = static_cast<mindspore::Context *>(model_context); 3642be168c0dSopenharmony_ci+ auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 3643be168c0dSopenharmony_ci+ if (impl->context_.get() != context) { 3644be168c0dSopenharmony_ci+ impl->context_.reset(context); 3645be168c0dSopenharmony_ci+ } 3646be168c0dSopenharmony_ci+ auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 3647be168c0dSopenharmony_ci+ std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 3648be168c0dSopenharmony_ci+ if (ret != mindspore::kSuccess) { 3649be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Load and compile failed"; 3650be168c0dSopenharmony_ci+ } 3651be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3652be168c0dSopenharmony_ci+} 3653be168c0dSopenharmony_ci+ 3654be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, 3655be168c0dSopenharmony_ci+ OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, 3656be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle train_cfg) { 3657be168c0dSopenharmony_ci+ if (model == nullptr || model_path == nullptr || model_context == nullptr) { 3658be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model/model_path/model_context is nullptr."; 3659be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 3660be168c0dSopenharmony_ci+ } 
3661be168c0dSopenharmony_ci+ if (model_type == OH_AI_MODELTYPE_INVALID) { 3662be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model_type is invalid."; 3663be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3664be168c0dSopenharmony_ci+ } 3665be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3666be168c0dSopenharmony_ci+ 3667be168c0dSopenharmony_ci+ mindspore::Graph graph; 3668be168c0dSopenharmony_ci+ auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph); 3669be168c0dSopenharmony_ci+ if (status != mindspore::kSuccess) { 3670be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "load ms file failed. " << model_path; 3671be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_ERROR; 3672be168c0dSopenharmony_ci+ } 3673be168c0dSopenharmony_ci+ auto context = static_cast<mindspore::Context *>(model_context); 3674be168c0dSopenharmony_ci+ auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg); 3675be168c0dSopenharmony_ci+ if (impl->context_.get() != context) { 3676be168c0dSopenharmony_ci+ impl->context_.reset(context); 3677be168c0dSopenharmony_ci+ } 3678be168c0dSopenharmony_ci+ auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_, 3679be168c0dSopenharmony_ci+ std::shared_ptr<mindspore::TrainCfg>(build_train_cfg)); 3680be168c0dSopenharmony_ci+ if (ret != mindspore::kSuccess) { 3681be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Load and compile failed"; 3682be168c0dSopenharmony_ci+ } 3683be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3684be168c0dSopenharmony_ci+} 3685be168c0dSopenharmony_ci+ 3686be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) { 3687be168c0dSopenharmony_ci+ if (model == nullptr) { 3688be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3689be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3690be168c0dSopenharmony_ci+ } 
3691be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3692be168c0dSopenharmony_ci+ auto ret = impl->model_->SetLearningRate(learning_rate); 3693be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3694be168c0dSopenharmony_ci+} 3695be168c0dSopenharmony_ci+ 3696be168c0dSopenharmony_ci+float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) { 3697be168c0dSopenharmony_ci+ if (model == nullptr) { 3698be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3699be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3700be168c0dSopenharmony_ci+ } 3701be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3702be168c0dSopenharmony_ci+ return impl->model_->GetLearningRate(); 3703be168c0dSopenharmony_ci+} 3704be168c0dSopenharmony_ci+ 3705be168c0dSopenharmony_ci+OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) { 3706be168c0dSopenharmony_ci+ if (model == nullptr) { 3707be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3708be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3709be168c0dSopenharmony_ci+ } 3710be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3711be168c0dSopenharmony_ci+ auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after)); 3712be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3713be168c0dSopenharmony_ci+} 3714be168c0dSopenharmony_ci+ 3715be168c0dSopenharmony_ci+OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) { 3716be168c0dSopenharmony_ci+ if (model == nullptr) { 3717be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3718be168c0dSopenharmony_ci+ return {0, nullptr}; 3719be168c0dSopenharmony_ci+ } 3720be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3721be168c0dSopenharmony_ci+ auto features = impl->model_->GetFeatureMaps(); 
3722be168c0dSopenharmony_ci+ size_t handle_num = features.size(); 3723be168c0dSopenharmony_ci+ 3724be168c0dSopenharmony_ci+ mindspore::MSTensor **handle_list = static_cast<mindspore::MSTensor **>(malloc( 3725be168c0dSopenharmony_ci+ handle_num * sizeof(mindspore::MSTensor *))); 3726be168c0dSopenharmony_ci+ if (handle_list == nullptr) { 3727be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Failed to malloc handle_list."; 3728be168c0dSopenharmony_ci+ return {0, nullptr}; 3729be168c0dSopenharmony_ci+ } 3730be168c0dSopenharmony_ci+ for (size_t i = 0; i < handle_num; i++) { 3731be168c0dSopenharmony_ci+ handle_list[i] = new mindspore::MSTensor(features[i].impl()); 3732be168c0dSopenharmony_ci+ } 3733be168c0dSopenharmony_ci+ return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)}; 3734be168c0dSopenharmony_ci+} 3735be168c0dSopenharmony_ci+ 3736be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) { 3737be168c0dSopenharmony_ci+ if (model == nullptr) { 3738be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3739be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3740be168c0dSopenharmony_ci+ } 3741be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3742be168c0dSopenharmony_ci+ std::vector<mindspore::MSTensor> weights; 3743be168c0dSopenharmony_ci+ for (size_t i = 0; i < new_weights.handle_num; i++) { 3744be168c0dSopenharmony_ci+ weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i])); 3745be168c0dSopenharmony_ci+ } 3746be168c0dSopenharmony_ci+ auto ret = impl->model_->UpdateWeights(weights); 3747be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3748be168c0dSopenharmony_ci+} 3749be168c0dSopenharmony_ci+ 3750be168c0dSopenharmony_ci+bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) { 3751be168c0dSopenharmony_ci+ if (model == nullptr) { 3752be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is 
nullptr."; 3753be168c0dSopenharmony_ci+ return false; 3754be168c0dSopenharmony_ci+ } 3755be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3756be168c0dSopenharmony_ci+ return impl->model_->GetTrainMode(); 3757be168c0dSopenharmony_ci+} 3758be168c0dSopenharmony_ci+ 3759be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) { 3760be168c0dSopenharmony_ci+ if (model == nullptr) { 3761be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3762be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3763be168c0dSopenharmony_ci+ } 3764be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3765be168c0dSopenharmony_ci+ auto ret = impl->model_->SetTrainMode(train); 3766be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3767be168c0dSopenharmony_ci+} 3768be168c0dSopenharmony_ci+ 3769be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) { 3770be168c0dSopenharmony_ci+ if (model == nullptr) { 3771be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3772be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3773be168c0dSopenharmony_ci+ } 3774be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3775be168c0dSopenharmony_ci+ auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum); 3776be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3777be168c0dSopenharmony_ci+} 3778be168c0dSopenharmony_ci+ 3779be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, 3780be168c0dSopenharmony_ci+ OH_AI_QuantizationType quantization_type, bool export_inference_only, 3781be168c0dSopenharmony_ci+ char **output_tensor_name, size_t num) { 3782be168c0dSopenharmony_ci+ if (model == nullptr) { 
3783be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3784be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3785be168c0dSopenharmony_ci+ } 3786be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3787be168c0dSopenharmony_ci+ std::vector<std::string> tensor_name; 3788be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 3789be168c0dSopenharmony_ci+ tensor_name.push_back(output_tensor_name[i]); 3790be168c0dSopenharmony_ci+ } 3791be168c0dSopenharmony_ci+ auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), 3792be168c0dSopenharmony_ci+ model_file, 3793be168c0dSopenharmony_ci+ static_cast<mindspore::QuantizationType>(quantization_type), 3794be168c0dSopenharmony_ci+ export_inference_only, tensor_name); 3795be168c0dSopenharmony_ci+ if (!ret.IsOk()) { 3796be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3797be168c0dSopenharmony_ci+ } 3798be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3799be168c0dSopenharmony_ci+} 3800be168c0dSopenharmony_ci+ 3801be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data, 3802be168c0dSopenharmony_ci+ size_t *data_size, OH_AI_QuantizationType quantization_type, 3803be168c0dSopenharmony_ci+ bool export_inference_only, char **output_tensor_name, size_t num) { 3804be168c0dSopenharmony_ci+ if (model == nullptr) { 3805be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3806be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3807be168c0dSopenharmony_ci+ } 3808be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3809be168c0dSopenharmony_ci+ std::vector<std::string> tensor_name; 3810be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 3811be168c0dSopenharmony_ci+ tensor_name.push_back(output_tensor_name[i]); 3812be168c0dSopenharmony_ci+ } 
3813be168c0dSopenharmony_ci+ mindspore::Buffer buffer; 3814be168c0dSopenharmony_ci+ auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), 3815be168c0dSopenharmony_ci+ &buffer, static_cast<mindspore::QuantizationType>(quantization_type), 3816be168c0dSopenharmony_ci+ export_inference_only, tensor_name); 3817be168c0dSopenharmony_ci+ auto data = static_cast<char *>(buffer.MutableData()); 3818be168c0dSopenharmony_ci+ *model_data = (char *) malloc(buffer.DataSize()); 3819be168c0dSopenharmony_ci+ *data_size = buffer.DataSize(); 3820be168c0dSopenharmony_ci+ memcpy(*model_data, data, buffer.DataSize()); 3821be168c0dSopenharmony_ci+ if (!ret.IsOk()) { 3822be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3823be168c0dSopenharmony_ci+ } 3824be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3825be168c0dSopenharmony_ci+} 3826be168c0dSopenharmony_ci+ 3827be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file, 3828be168c0dSopenharmony_ci+ bool is_inference, bool enable_fp16, char **changeable_weights_name, size_t num) { 3829be168c0dSopenharmony_ci+ if (model == nullptr) { 3830be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "model is nullptr."; 3831be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 3832be168c0dSopenharmony_ci+ } 3833be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::ModelC *>(model); 3834be168c0dSopenharmony_ci+ std::vector<std::string> weights_name; 3835be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 3836be168c0dSopenharmony_ci+ weights_name.push_back(changeable_weights_name[i]); 3837be168c0dSopenharmony_ci+ } 3838be168c0dSopenharmony_ci+ auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16, 
weights_name); 3839be168c0dSopenharmony_ci+ if (!ret.IsOk()) { 3840be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "export model fail, ret :" << ret; 3841be168c0dSopenharmony_ci+ } 3842be168c0dSopenharmony_ci+ return static_cast<OH_AI_Status>(ret.StatusCode()); 3843be168c0dSopenharmony_ci+} 3844be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/tensor_c.cc b/mindspore/lite/src/litert/c_api/tensor_c.cc 3845be168c0dSopenharmony_ciindex 7b5c4c2f..4b1e6aff 100644 3846be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/tensor_c.cc 3847be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/tensor_c.cc 3848be168c0dSopenharmony_ci@@ -17,7 +17,6 @@ 3849be168c0dSopenharmony_ci #include "include/api/status.h" 3850be168c0dSopenharmony_ci #include "src/tensor.h" 3851be168c0dSopenharmony_ci #include "src/litert/cxx_api/tensor/tensor_impl.h" 3852be168c0dSopenharmony_ci-#include "src/litert/inner_allocator.h" 3853be168c0dSopenharmony_ci 3854be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num, 3855be168c0dSopenharmony_ci const void *data, size_t data_len) { 3856be168c0dSopenharmony_ci@@ -31,18 +30,23 @@ OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, con 3857be168c0dSopenharmony_ci } 3858be168c0dSopenharmony_ci auto lite_tensor = 3859be168c0dSopenharmony_ci mindspore::lite::Tensor::CreateTensor(name, static_cast<mindspore::TypeId>(type), vec_shape, data, data_len); 3860be168c0dSopenharmony_ci- auto impl = new (std::nothrow) mindspore::LiteTensorImpl(lite_tensor); 3861be168c0dSopenharmony_ci- if (impl == nullptr || impl->lite_tensor() == nullptr) { 3862be168c0dSopenharmony_ci+ auto lite_tensor_impl = std::make_shared<mindspore::LiteTensorImpl>(lite_tensor); 3863be168c0dSopenharmony_ci+ if (lite_tensor_impl == nullptr || lite_tensor_impl->lite_tensor() == nullptr) { 3864be168c0dSopenharmony_ci MS_LOG(ERROR) << "Failed to allocate tensor 
impl."; 3865be168c0dSopenharmony_ci return nullptr; 3866be168c0dSopenharmony_ci } 3867be168c0dSopenharmony_ci- impl->set_from_session(false); 3868be168c0dSopenharmony_ci+ lite_tensor_impl->set_from_session(false); 3869be168c0dSopenharmony_ci+ auto impl = new (std::nothrow) mindspore::MSTensor(lite_tensor_impl); 3870be168c0dSopenharmony_ci+ if (impl == nullptr) { 3871be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Failed to allocate MSTensor."; 3872be168c0dSopenharmony_ci+ return nullptr; 3873be168c0dSopenharmony_ci+ } 3874be168c0dSopenharmony_ci return impl; 3875be168c0dSopenharmony_ci } 3876be168c0dSopenharmony_ci 3877be168c0dSopenharmony_ci void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) { 3878be168c0dSopenharmony_ci if (tensor != nullptr && *tensor != nullptr) { 3879be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(*tensor); 3880be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(*tensor); 3881be168c0dSopenharmony_ci delete impl; 3882be168c0dSopenharmony_ci *tensor = nullptr; 3883be168c0dSopenharmony_ci } 3884be168c0dSopenharmony_ci@@ -53,20 +57,14 @@ OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) { 3885be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3886be168c0dSopenharmony_ci return nullptr; 3887be168c0dSopenharmony_ci } 3888be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3889be168c0dSopenharmony_ci- auto lite_tensor = static_cast<mindspore::lite::Tensor *>(impl->lite_tensor()); 3890be168c0dSopenharmony_ci- auto clone = mindspore::lite::Tensor::CopyTensor(*lite_tensor, true, lite_tensor->allocator()); 3891be168c0dSopenharmony_ci- if (clone == nullptr) { 3892be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Failed to allocate tensor."; 3893be168c0dSopenharmony_ci- return nullptr; 3894be168c0dSopenharmony_ci- } 3895be168c0dSopenharmony_ci- auto clone_impl = new (std::nothrow) mindspore::LiteTensorImpl(clone); 3896be168c0dSopenharmony_ci+ auto impl = 
static_cast<mindspore::MSTensor *>(tensor); 3897be168c0dSopenharmony_ci+ auto clone_impl = impl->Clone(); 3898be168c0dSopenharmony_ci if (clone_impl == nullptr) { 3899be168c0dSopenharmony_ci- delete clone; 3900be168c0dSopenharmony_ci MS_LOG(ERROR) << "Failed to allocate tensor impl."; 3901be168c0dSopenharmony_ci return nullptr; 3902be168c0dSopenharmony_ci } 3903be168c0dSopenharmony_ci- clone_impl->set_from_session(false); 3904be168c0dSopenharmony_ci+ std::static_pointer_cast<mindspore::LiteTensorImpl>(clone_impl->impl())->set_own_data(false); 3905be168c0dSopenharmony_ci+ clone_impl->SetTensorName(impl->Name() + "_duplicate"); 3906be168c0dSopenharmony_ci return clone_impl; 3907be168c0dSopenharmony_ci } 3908be168c0dSopenharmony_ci 3909be168c0dSopenharmony_ci@@ -75,8 +73,8 @@ void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) { 3910be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3911be168c0dSopenharmony_ci return; 3912be168c0dSopenharmony_ci } 3913be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3914be168c0dSopenharmony_ci- impl->SetName(name); 3915be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3916be168c0dSopenharmony_ci+ impl->SetTensorName(name); 3917be168c0dSopenharmony_ci } 3918be168c0dSopenharmony_ci 3919be168c0dSopenharmony_ci const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 3920be168c0dSopenharmony_ci@@ -84,8 +82,8 @@ const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) { 3921be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3922be168c0dSopenharmony_ci return nullptr; 3923be168c0dSopenharmony_ci } 3924be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3925be168c0dSopenharmony_ci- return impl->Name().c_str(); 3926be168c0dSopenharmony_ci+ auto ms_tensor = static_cast<mindspore::MSTensor *>(tensor); 3927be168c0dSopenharmony_ci+ return 
std::static_pointer_cast<mindspore::LiteTensorImpl>(ms_tensor->impl())->Name().c_str(); 3928be168c0dSopenharmony_ci } 3929be168c0dSopenharmony_ci 3930be168c0dSopenharmony_ci void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 3931be168c0dSopenharmony_ci@@ -93,7 +91,7 @@ void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) { 3932be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3933be168c0dSopenharmony_ci return; 3934be168c0dSopenharmony_ci } 3935be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3936be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3937be168c0dSopenharmony_ci impl->SetDataType(static_cast<mindspore::DataType>(type)); 3938be168c0dSopenharmony_ci } 3939be168c0dSopenharmony_ci 3940be168c0dSopenharmony_ci@@ -102,7 +100,7 @@ OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) { 3941be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3942be168c0dSopenharmony_ci return OH_AI_DATATYPE_UNKNOWN; 3943be168c0dSopenharmony_ci } 3944be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3945be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3946be168c0dSopenharmony_ci auto dtype = impl->DataType(); 3947be168c0dSopenharmony_ci return static_cast<OH_AI_DataType>(dtype); 3948be168c0dSopenharmony_ci } 3949be168c0dSopenharmony_ci@@ -112,7 +110,7 @@ void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_ 3950be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3951be168c0dSopenharmony_ci return; 3952be168c0dSopenharmony_ci } 3953be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3954be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3955be168c0dSopenharmony_ci std::vector<int64_t> vec_shape(shape_num); 3956be168c0dSopenharmony_ci for (size_t i = 0; i < 
shape_num; i++) { 3957be168c0dSopenharmony_ci vec_shape[i] = shape[i]; 3958be168c0dSopenharmony_ci@@ -125,7 +123,7 @@ const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *sha 3959be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3960be168c0dSopenharmony_ci return nullptr; 3961be168c0dSopenharmony_ci } 3962be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3963be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3964be168c0dSopenharmony_ci *shape_num = impl->Shape().size(); 3965be168c0dSopenharmony_ci return impl->Shape().data(); 3966be168c0dSopenharmony_ci } 3967be168c0dSopenharmony_ci@@ -135,7 +133,7 @@ void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) { 3968be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3969be168c0dSopenharmony_ci return; 3970be168c0dSopenharmony_ci } 3971be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3972be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3973be168c0dSopenharmony_ci return impl->SetFormat(static_cast<mindspore::Format>(format)); 3974be168c0dSopenharmony_ci } 3975be168c0dSopenharmony_ci 3976be168c0dSopenharmony_ci@@ -144,8 +142,8 @@ OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) { 3977be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3978be168c0dSopenharmony_ci return OH_AI_FORMAT_NHWC; 3979be168c0dSopenharmony_ci } 3980be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3981be168c0dSopenharmony_ci- return static_cast<OH_AI_Format>(impl->Format()); 3982be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3983be168c0dSopenharmony_ci+ return static_cast<OH_AI_Format>(impl->format()); 3984be168c0dSopenharmony_ci } 3985be168c0dSopenharmony_ci 3986be168c0dSopenharmony_ci void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 
3987be168c0dSopenharmony_ci@@ -153,16 +151,34 @@ void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) { 3988be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 3989be168c0dSopenharmony_ci return; 3990be168c0dSopenharmony_ci } 3991be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 3992be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 3993be168c0dSopenharmony_ci return impl->SetData(data, true); 3994be168c0dSopenharmony_ci } 3995be168c0dSopenharmony_ci 3996be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size) { 3997be168c0dSopenharmony_ci+ if (tensor == nullptr) { 3998be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "param is nullptr."; 3999be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_NULLPTR; 4000be168c0dSopenharmony_ci+ } 4001be168c0dSopenharmony_ci+ 4002be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4003be168c0dSopenharmony_ci+ if ((impl->DataSize() > 0) && (data_size != impl->DataSize())) { 4004be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input data size does not match inner data size"; 4005be168c0dSopenharmony_ci+ return OH_AI_STATUS_LITE_PARAM_INVALID; 4006be168c0dSopenharmony_ci+ } 4007be168c0dSopenharmony_ci+ 4008be168c0dSopenharmony_ci+ // This is one tricky way to represent that the inner data is not owned by tensor itself. 
4009be168c0dSopenharmony_ci+ impl->SetAllocator(nullptr); 4010be168c0dSopenharmony_ci+ impl->SetData(data, false); 4011be168c0dSopenharmony_ci+ return OH_AI_STATUS_SUCCESS; 4012be168c0dSopenharmony_ci+} 4013be168c0dSopenharmony_ci+ 4014be168c0dSopenharmony_ci const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) { 4015be168c0dSopenharmony_ci if (tensor == nullptr) { 4016be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 4017be168c0dSopenharmony_ci return nullptr; 4018be168c0dSopenharmony_ci } 4019be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4020be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4021be168c0dSopenharmony_ci return impl->Data().get(); 4022be168c0dSopenharmony_ci } 4023be168c0dSopenharmony_ci 4024be168c0dSopenharmony_ci@@ -171,7 +187,7 @@ void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) { 4025be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 4026be168c0dSopenharmony_ci return nullptr; 4027be168c0dSopenharmony_ci } 4028be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4029be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4030be168c0dSopenharmony_ci return impl->MutableData(); 4031be168c0dSopenharmony_ci } 4032be168c0dSopenharmony_ci 4033be168c0dSopenharmony_ci@@ -180,7 +196,7 @@ int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) { 4034be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 4035be168c0dSopenharmony_ci return 0; 4036be168c0dSopenharmony_ci } 4037be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4038be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4039be168c0dSopenharmony_ci return impl->ElementNum(); 4040be168c0dSopenharmony_ci } 4041be168c0dSopenharmony_ci 4042be168c0dSopenharmony_ci@@ -189,6 +205,6 @@ size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle 
tensor) { 4043be168c0dSopenharmony_ci MS_LOG(ERROR) << "param is nullptr."; 4044be168c0dSopenharmony_ci return 0; 4045be168c0dSopenharmony_ci } 4046be168c0dSopenharmony_ci- auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor); 4047be168c0dSopenharmony_ci+ auto impl = static_cast<mindspore::MSTensor *>(tensor); 4048be168c0dSopenharmony_ci return impl->DataSize(); 4049be168c0dSopenharmony_ci } 4050be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/type_c_private.h b/mindspore/lite/src/litert/c_api/type_c_private.h 4051be168c0dSopenharmony_cinew file mode 100644 4052be168c0dSopenharmony_ciindex 00000000..2d3b3883 4053be168c0dSopenharmony_ci--- /dev/null 4054be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/type_c_private.h 4055be168c0dSopenharmony_ci@@ -0,0 +1,40 @@ 4056be168c0dSopenharmony_ci+/** 4057be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 4058be168c0dSopenharmony_ci+ * 4059be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 4060be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 4061be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 4062be168c0dSopenharmony_ci+ * 4063be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 4064be168c0dSopenharmony_ci+ * 4065be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 4066be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 4067be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 4068be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 4069be168c0dSopenharmony_ci+ * limitations under the License. 
4070be168c0dSopenharmony_ci+ */ 4071be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4072be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4073be168c0dSopenharmony_ci+ 4074be168c0dSopenharmony_ci+#include <string> 4075be168c0dSopenharmony_ci+#include <vector> 4076be168c0dSopenharmony_ci+#include <memory> 4077be168c0dSopenharmony_ci+#include <stddef.h> 4078be168c0dSopenharmony_ci+#include "include/c_api/types_c.h" 4079be168c0dSopenharmony_ci+ 4080be168c0dSopenharmony_ci+#ifdef __cplusplus 4081be168c0dSopenharmony_ci+extern "C" { 4082be168c0dSopenharmony_ci+#endif 4083be168c0dSopenharmony_ci+ 4084be168c0dSopenharmony_ci+#define NNRT_DEVICE_NAME_MAX (128) 4085be168c0dSopenharmony_ci+ 4086be168c0dSopenharmony_ci+struct NNRTDeviceDesc { 4087be168c0dSopenharmony_ci+ size_t device_id; 4088be168c0dSopenharmony_ci+ OH_AI_NNRTDeviceType device_type; 4089be168c0dSopenharmony_ci+ char device_name[NNRT_DEVICE_NAME_MAX]; 4090be168c0dSopenharmony_ci+}; 4091be168c0dSopenharmony_ci+ 4092be168c0dSopenharmony_ci+#ifdef __cplusplus 4093be168c0dSopenharmony_ci+} 4094be168c0dSopenharmony_ci+#endif 4095be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_ 4096be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/context.cc b/mindspore/lite/src/litert/cxx_api/context.cc 4097be168c0dSopenharmony_ciindex 1371bcf0..e5f19d28 100644 4098be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/context.cc 4099be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/context.cc 4100be168c0dSopenharmony_ci@@ -50,6 +50,11 @@ constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dyn 4101be168c0dSopenharmony_ci constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size"; 4102be168c0dSopenharmony_ci constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize"; 4103be168c0dSopenharmony_ci 
constexpr auto kModelOptionAscendRankID = "mindspore.option.ascend.rank_id"; 4104be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id"; 4105be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode"; 4106be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTPriority = "mindspore.option.nnrt.priority"; 4107be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16"; 4108be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions"; 4109be168c0dSopenharmony_ci #ifdef USE_GLOG 4110be168c0dSopenharmony_ci extern "C" { 4111be168c0dSopenharmony_ci extern void mindspore_log_init(); 4112be168c0dSopenharmony_ci@@ -684,4 +689,84 @@ std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const { 4113be168c0dSopenharmony_ci const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize); 4114be168c0dSopenharmony_ci return StringToChar(ref); 4115be168c0dSopenharmony_ci } 4116be168c0dSopenharmony_ci+ 4117be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetDeviceID(size_t device_id) { 4118be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4119be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4120be168c0dSopenharmony_ci+ return; 4121be168c0dSopenharmony_ci+ } 4122be168c0dSopenharmony_ci+ data_->params[kModelOptionNNRTDeviceID] = device_id; 4123be168c0dSopenharmony_ci+} 4124be168c0dSopenharmony_ci+ 4125be168c0dSopenharmony_ci+size_t NNRTDeviceInfo::GetDeviceID() const { 4126be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4127be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4128be168c0dSopenharmony_ci+ return 0; 4129be168c0dSopenharmony_ci+ } 4130be168c0dSopenharmony_ci+ return GetValue<size_t>(data_, kModelOptionNNRTDeviceID); 4131be168c0dSopenharmony_ci+} 4132be168c0dSopenharmony_ci+ 4133be168c0dSopenharmony_ci+void 
NNRTDeviceInfo::SetPerformanceMode(int performance_mode) { 4134be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4135be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4136be168c0dSopenharmony_ci+ return; 4137be168c0dSopenharmony_ci+ } 4138be168c0dSopenharmony_ci+ data_->params[kModelOptionNNRTPerformanceMode] = performance_mode; 4139be168c0dSopenharmony_ci+} 4140be168c0dSopenharmony_ci+ 4141be168c0dSopenharmony_ci+int NNRTDeviceInfo::GetPerformanceMode() const { 4142be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4143be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4144be168c0dSopenharmony_ci+ return 0; 4145be168c0dSopenharmony_ci+ } 4146be168c0dSopenharmony_ci+ return GetValue<int>(data_, kModelOptionNNRTPerformanceMode); 4147be168c0dSopenharmony_ci+} 4148be168c0dSopenharmony_ci+ 4149be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetPriority(int priority) { 4150be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4151be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4152be168c0dSopenharmony_ci+ return; 4153be168c0dSopenharmony_ci+ } 4154be168c0dSopenharmony_ci+ data_->params[kModelOptionNNRTPriority] = priority; 4155be168c0dSopenharmony_ci+} 4156be168c0dSopenharmony_ci+ 4157be168c0dSopenharmony_ci+int NNRTDeviceInfo::GetPriority() const { 4158be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4159be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4160be168c0dSopenharmony_ci+ return 0; 4161be168c0dSopenharmony_ci+ } 4162be168c0dSopenharmony_ci+ return GetValue<int>(data_, kModelOptionNNRTPriority); 4163be168c0dSopenharmony_ci+} 4164be168c0dSopenharmony_ci+ 4165be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) { 4166be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4167be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4168be168c0dSopenharmony_ci+ return; 4169be168c0dSopenharmony_ci+ } 4170be168c0dSopenharmony_ci+ data_->params[kModelOptionNNRTEnableFP16] = is_fp16; 
4171be168c0dSopenharmony_ci+} 4172be168c0dSopenharmony_ci+ 4173be168c0dSopenharmony_ci+bool NNRTDeviceInfo::GetEnableFP16() const { 4174be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4175be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4176be168c0dSopenharmony_ci+ return false; 4177be168c0dSopenharmony_ci+ } 4178be168c0dSopenharmony_ci+ return GetValue<bool>(data_, kModelOptionNNRTEnableFP16); 4179be168c0dSopenharmony_ci+} 4180be168c0dSopenharmony_ci+ 4181be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) { 4182be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4183be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4184be168c0dSopenharmony_ci+ return; 4185be168c0dSopenharmony_ci+ } 4186be168c0dSopenharmony_ci+ data_->params[kModelOptionNNRTExtensions] = extensions; 4187be168c0dSopenharmony_ci+} 4188be168c0dSopenharmony_ci+ 4189be168c0dSopenharmony_ci+std::vector<Extension> NNRTDeviceInfo::GetExtensions() const { 4190be168c0dSopenharmony_ci+ if (data_ == nullptr) { 4191be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Invalid context."; 4192be168c0dSopenharmony_ci+ return {}; 4193be168c0dSopenharmony_ci+ } 4194be168c0dSopenharmony_ci+ return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions); 4195be168c0dSopenharmony_ci+} 4196be168c0dSopenharmony_ci } // namespace mindspore 4197be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/converters.cc b/mindspore/lite/src/litert/cxx_api/converters.cc 4198be168c0dSopenharmony_ciindex 0ff345cc..e54a36ee 100644 4199be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/converters.cc 4200be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/converters.cc 4201be168c0dSopenharmony_ci@@ -86,6 +86,23 @@ Status ContextUtils::AddCustomDevice(lite::InnerContext *inner_context, 4202be168c0dSopenharmony_ci return kSuccess; 4203be168c0dSopenharmony_ci } 4204be168c0dSopenharmony_ci 4205be168c0dSopenharmony_ci+Status 
ContextUtils::AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, 4206be168c0dSopenharmony_ci+ int priority, bool enable_fp16, const std::vector<Extension> &extensions) { 4207be168c0dSopenharmony_ci+ lite::DeviceInfo device_info = {0}; 4208be168c0dSopenharmony_ci+ device_info.nnrt_device_info_.device_id_ = device_id; 4209be168c0dSopenharmony_ci+ device_info.nnrt_device_info_.performance_mode_ = performance_mode; 4210be168c0dSopenharmony_ci+ device_info.nnrt_device_info_.priority_ = priority; 4211be168c0dSopenharmony_ci+ device_info.nnrt_device_info_.enable_fp16_ = enable_fp16; 4212be168c0dSopenharmony_ci+ for (auto src_extension: extensions) { 4213be168c0dSopenharmony_ci+ lite::Extension dest_extension; 4214be168c0dSopenharmony_ci+ dest_extension.name = src_extension.name; 4215be168c0dSopenharmony_ci+ dest_extension.value = src_extension.value; 4216be168c0dSopenharmony_ci+ device_info.nnrt_device_info_.extensions_.push_back(dest_extension); 4217be168c0dSopenharmony_ci+ } 4218be168c0dSopenharmony_ci+ inner_context->device_list_.push_back({lite::DT_NNRT, device_info}); 4219be168c0dSopenharmony_ci+ return kSuccess; 4220be168c0dSopenharmony_ci+} 4221be168c0dSopenharmony_ci+ 4222be168c0dSopenharmony_ci void ContextUtils::ResetContextDefaultParam(Context *context) { 4223be168c0dSopenharmony_ci if (context->GetInterOpParallelNum() == 0) { 4224be168c0dSopenharmony_ci context->SetInterOpParallelNum(kDefaultInterOpParallelNum); 4225be168c0dSopenharmony_ci@@ -163,44 +180,11 @@ std::shared_ptr<lite::InnerContext> ContextUtils::Convert(Context *context) { 4226be168c0dSopenharmony_ci ret = AddAscendDevice(inner_context.get(), device.get()); 4227be168c0dSopenharmony_ci } else if (device->GetDeviceType() == kCustomDevice) { 4228be168c0dSopenharmony_ci ret = AddCustomDevice(inner_context.get(), device); 4229be168c0dSopenharmony_ci- } 4230be168c0dSopenharmony_ci- if (ret != kSuccess) { 4231be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Add device 
failed!"; 4232be168c0dSopenharmony_ci- return nullptr; 4233be168c0dSopenharmony_ci- } 4234be168c0dSopenharmony_ci- } 4235be168c0dSopenharmony_ci- return inner_context; 4236be168c0dSopenharmony_ci-} 4237be168c0dSopenharmony_ci- 4238be168c0dSopenharmony_ci-std::shared_ptr<lite::InnerContext> ContextUtils::Convert(const ContextC *context_c) { 4239be168c0dSopenharmony_ci- auto inner_context = std::make_shared<lite::InnerContext>(); 4240be168c0dSopenharmony_ci- if ((context_c == nullptr) || (inner_context == nullptr)) { 4241be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Invalid context pointers."; 4242be168c0dSopenharmony_ci- return nullptr; 4243be168c0dSopenharmony_ci- } 4244be168c0dSopenharmony_ci- auto device_list = context_c->device_info_list; 4245be168c0dSopenharmony_ci- if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) { 4246be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices; 4247be168c0dSopenharmony_ci- return nullptr; 4248be168c0dSopenharmony_ci- } 4249be168c0dSopenharmony_ci- SetContextAttr(context_c->thread_num, 1, context_c->enable_parallel, context_c->affinity_core_list, 4250be168c0dSopenharmony_ci- context_c->delegate_mode, context_c->delegate, inner_context.get()); 4251be168c0dSopenharmony_ci- inner_context->device_list_.clear(); 4252be168c0dSopenharmony_ci- Status ret = kLiteError; 4253be168c0dSopenharmony_ci- for (auto &device_info_c : device_list) { 4254be168c0dSopenharmony_ci- MS_CHECK_TRUE_RET(device_info_c != nullptr, nullptr); 4255be168c0dSopenharmony_ci- lite::DeviceInfo device_info = {{0}}; 4256be168c0dSopenharmony_ci- if (device_info_c->device_type == OH_AI_DEVICETYPE_CPU) { 4257be168c0dSopenharmony_ci- if (device_info_c->allocator == nullptr) { 4258be168c0dSopenharmony_ci- device_info_c->allocator = Allocator::Create(); 4259be168c0dSopenharmony_ci- } 4260be168c0dSopenharmony_ci- ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16, 
4261be168c0dSopenharmony_ci- device_info_c->provider, device_info_c->provider_device, inner_context.get()); 4262be168c0dSopenharmony_ci- } else if (device_info_c->device_type == OH_AI_DEVICETYPE_GPU) { 4263be168c0dSopenharmony_ci- ret = AddGpuDevice(device_info_c->enable_fp16, 0, 0, 0, false, nullptr, nullptr, device_info_c->provider, 4264be168c0dSopenharmony_ci- device_info_c->provider_device, device_info_c->allocator, inner_context.get()); 4265be168c0dSopenharmony_ci- } else if (device_info_c->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) { 4266be168c0dSopenharmony_ci- ret = AddNpuDevice(device_info_c->enable_fp16, device_info_c->frequency, inner_context.get()); 4267be168c0dSopenharmony_ci+ } else if (device->GetDeviceType() == kNNRt) { 4268be168c0dSopenharmony_ci+ auto nnrt_device_info = device->Cast<NNRTDeviceInfo>(); 4269be168c0dSopenharmony_ci+ ret = AddNNRtDevice(inner_context.get(), nnrt_device_info->GetDeviceID(), 4270be168c0dSopenharmony_ci+ nnrt_device_info->GetPerformanceMode(), nnrt_device_info->GetPriority(), 4271be168c0dSopenharmony_ci+ nnrt_device_info->GetEnableFP16(), nnrt_device_info->GetExtensions()); 4272be168c0dSopenharmony_ci } 4273be168c0dSopenharmony_ci if (ret != kSuccess) { 4274be168c0dSopenharmony_ci MS_LOG(ERROR) << "Add device failed!"; 4275be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/converters.h b/mindspore/lite/src/litert/cxx_api/converters.h 4276be168c0dSopenharmony_ciindex 0c043fc3..1af7c7df 100644 4277be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/converters.h 4278be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/converters.h 4279be168c0dSopenharmony_ci@@ -24,14 +24,12 @@ 4280be168c0dSopenharmony_ci #include "include/api/cfg.h" 4281be168c0dSopenharmony_ci #include "include/train/train_cfg.h" 4282be168c0dSopenharmony_ci #include "src/litert/inner_context.h" 4283be168c0dSopenharmony_ci-#include "src/litert/c_api/context_c.h" 4284be168c0dSopenharmony_ci #include 
"src/common/log_adapter.h" 4285be168c0dSopenharmony_ci 4286be168c0dSopenharmony_ci namespace mindspore { 4287be168c0dSopenharmony_ci class MS_API ContextUtils { 4288be168c0dSopenharmony_ci public: 4289be168c0dSopenharmony_ci static std::shared_ptr<lite::InnerContext> Convert(Context *context); 4290be168c0dSopenharmony_ci- static std::shared_ptr<lite::InnerContext> Convert(const ContextC *context_c); 4291be168c0dSopenharmony_ci 4292be168c0dSopenharmony_ci private: 4293be168c0dSopenharmony_ci static void SetContextAttr(int32_t thread_num, int32_t inter_op_parallel_num, bool enable_parallel, 4294be168c0dSopenharmony_ci@@ -48,6 +46,8 @@ class MS_API ContextUtils { 4295be168c0dSopenharmony_ci static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context); 4296be168c0dSopenharmony_ci static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device); 4297be168c0dSopenharmony_ci static Status AddCustomDevice(lite::InnerContext *inner_context, const std::shared_ptr<DeviceInfoContext> &device); 4298be168c0dSopenharmony_ci+ static Status AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, int priority, 4299be168c0dSopenharmony_ci+ bool enable_fp16, const std::vector<Extension> &extensions); 4300be168c0dSopenharmony_ci static bool IsAffinityModeValid(int affinity_mode) { 4301be168c0dSopenharmony_ci return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU; 4302be168c0dSopenharmony_ci } 4303be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt 4304be168c0dSopenharmony_ciindex 70aa63f3..625459e2 100644 4305be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt 4306be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt 4307be168c0dSopenharmony_ci@@ -1,30 +1,13 @@ 4308be168c0dSopenharmony_ci include_directories(${DDK_PATH}) 
4309be168c0dSopenharmony_ci include_directories($(CCSRC_DIR)/plugin/device/cpu/kernel) 4310be168c0dSopenharmony_ci+include_directories(${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/) 4311be168c0dSopenharmony_ci 4312be168c0dSopenharmony_ci include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) 4313be168c0dSopenharmony_ci-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include/inner) 4314be168c0dSopenharmony_ci-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include) 4315be168c0dSopenharmony_ci+ 4316be168c0dSopenharmony_ci file(GLOB_RECURSE NNRT_SRC 4317be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/*.cc 4318be168c0dSopenharmony_ci ) 4319be168c0dSopenharmony_ci- 4320be168c0dSopenharmony_ci-#add_library(hiai SHARED IMPORTED) 4321be168c0dSopenharmony_ci-#set_target_properties(hiai PROPERTIES IMPORTED_LOCATION 4322be168c0dSopenharmony_ci-# ${DDK_LIB_PATH}/libhiai.so) 4323be168c0dSopenharmony_ci-#add_library(hiai_ir SHARED IMPORTED) 4324be168c0dSopenharmony_ci-#set_target_properties(hiai_ir PROPERTIES IMPORTED_LOCATION 4325be168c0dSopenharmony_ci-# ${DDK_LIB_PATH}/libhiai_ir.so) 4326be168c0dSopenharmony_ci-#add_library(hiai_ir_build SHARED IMPORTED) 4327be168c0dSopenharmony_ci-#set_target_properties(hiai_ir_build PROPERTIES IMPORTED_LOCATION 4328be168c0dSopenharmony_ci-# ${DDK_LIB_PATH}/libhiai_ir_build.so) 4329be168c0dSopenharmony_ci-#add_library(npu_kernel_mid OBJECT ${NPU_RUNTIME_SRC}) 4330be168c0dSopenharmony_ci-#add_dependencies(npu_kernel_mid fbs_src) 4331be168c0dSopenharmony_ci-#target_link_libraries( 4332be168c0dSopenharmony_ci-# npu_kernel_mid 4333be168c0dSopenharmony_ci-# hiai 4334be168c0dSopenharmony_ci-# hiai_ir 4335be168c0dSopenharmony_ci-# hiai_ir_build 4336be168c0dSopenharmony_ci-#) 4337be168c0dSopenharmony_ci- 4338be168c0dSopenharmony_ci file(GLOB convert_source checker/*.cc) 4339be168c0dSopenharmony_ci-add_library(nnr_mid 
OBJECT ${NNRT_SRC} ${convert_source} ) 4340be168c0dSopenharmony_ci\ No newline at end of file 4341be168c0dSopenharmony_ci+ 4342be168c0dSopenharmony_ci+add_library(nnrt_mid OBJECT ${NNRT_SRC} ${convert_source}) 4343be168c0dSopenharmony_ci+target_include_directories(nnrt_mid PUBLIC ${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/) 4344be168c0dSopenharmony_ci\ No newline at end of file 4345be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4346be168c0dSopenharmony_ciindex 4df7e477..6b191c8e 100644 4347be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4348be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc 4349be168c0dSopenharmony_ci@@ -109,6 +109,8 @@ Status CheckPrimitiveSupported(const schema::Primitive *primitive) { 4350be168c0dSopenharmony_ci return mindspore::kSuccess; 4351be168c0dSopenharmony_ci case schema::PrimitiveType_Unsqueeze: 4352be168c0dSopenharmony_ci return mindspore::kSuccess; 4353be168c0dSopenharmony_ci+ case schema::PrimitiveType_Custom: 4354be168c0dSopenharmony_ci+ return mindspore::kSuccess; 4355be168c0dSopenharmony_ci default: { 4356be168c0dSopenharmony_ci MS_LOG(WARNING) << "No primitive type :" << (int)(type); 4357be168c0dSopenharmony_ci return mindspore::kLiteSuccessExit; 4358be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4359be168c0dSopenharmony_ciindex 34897331..9f012e76 100644 4360be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4361be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc 4362be168c0dSopenharmony_ci@@ -13,144 +13,637 @@ 4363be168c0dSopenharmony_ci * See the License for the specific language governing permissions and 
4364be168c0dSopenharmony_ci * limitations under the License. 4365be168c0dSopenharmony_ci */ 4366be168c0dSopenharmony_ci+ 4367be168c0dSopenharmony_ci+#include <unordered_set> 4368be168c0dSopenharmony_ci+#include <numeric> 4369be168c0dSopenharmony_ci #include "nnrt_delegate.h" 4370be168c0dSopenharmony_ci #include "checker/primitive_check.h" 4371be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 4372be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime.h" 4373be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 4374be168c0dSopenharmony_ci #include "interfaces/innerkits/c/neural_network_runtime_inner.h" 4375be168c0dSopenharmony_ci #include "nnrt_model_kernel.h" 4376be168c0dSopenharmony_ci+#include "schema/model_generated.h" 4377be168c0dSopenharmony_ci+#include "schema/ops_generated.h" 4378be168c0dSopenharmony_ci+#include "flatbuffers/flatbuffers.h" 4379be168c0dSopenharmony_ci+#include "litert/tensor_category.h" 4380be168c0dSopenharmony_ci+ 4381be168c0dSopenharmony_ci+namespace mindspore { 4382be168c0dSopenharmony_ci+namespace lite { 4383be168c0dSopenharmony_ci+void NNRTDelegate::InitCachePath() { 4384be168c0dSopenharmony_ci+ static const std::string kCachePathName = "CachePath"; 4385be168c0dSopenharmony_ci+ static const std::string kCacheVersion = "CacheVersion"; 4386be168c0dSopenharmony_ci+ 4387be168c0dSopenharmony_ci+ const auto &extensions = nnrt_device_info_.extensions_; 4388be168c0dSopenharmony_ci 4389be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) { 4390be168c0dSopenharmony_ci- if (this->nnrt_lite_graph == nullptr) { 4391be168c0dSopenharmony_ci- MS_LOG(ERROR) << "nnrt_lite_graph is nullptr."; 4392be168c0dSopenharmony_ci- return mindspore::kLiteError; 4393be168c0dSopenharmony_ci+ auto iter_path = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) { 4394be168c0dSopenharmony_ci+ return 
extension.name == kCachePathName; 4395be168c0dSopenharmony_ci+ }); 4396be168c0dSopenharmony_ci+ if (iter_path != extensions.end()) { 4397be168c0dSopenharmony_ci+ cache_path_ = std::string(iter_path->value.begin(), iter_path->value.end()); 4398be168c0dSopenharmony_ci } 4399be168c0dSopenharmony_ci- if (this->nnrt_lite_graph->sub_graphs_.empty()) { 4400be168c0dSopenharmony_ci- // must have at lease one subgraph 4401be168c0dSopenharmony_ci- MS_LOG(ERROR) << "must have at lease one subgraph"; 4402be168c0dSopenharmony_ci- return mindspore::kLiteError; 4403be168c0dSopenharmony_ci+ 4404be168c0dSopenharmony_ci+ auto iter_version = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) { 4405be168c0dSopenharmony_ci+ return extension.name == kCacheVersion; 4406be168c0dSopenharmony_ci+ }); 4407be168c0dSopenharmony_ci+ if (iter_version != extensions.end()) { 4408be168c0dSopenharmony_ci+ std::string version_str = std::string(iter_version->value.begin(), iter_version->value.end()); 4409be168c0dSopenharmony_ci+ cache_version_ = static_cast<uint32_t>(std::atol(version_str.c_str())); 4410be168c0dSopenharmony_ci } 4411be168c0dSopenharmony_ci- OH_NN_ReturnCode ret_code; 4412be168c0dSopenharmony_ci- OH_NNModel *oh_nnmodel = OH_NNModel_Construct(); 4413be168c0dSopenharmony_ci- if (oh_nnmodel == nullptr) { 4414be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Construct NNModel failed, oh_nnmodel is nullptr."; 4415be168c0dSopenharmony_ci- return mindspore::kLiteError; 4416be168c0dSopenharmony_ci+} 4417be168c0dSopenharmony_ci+ 4418be168c0dSopenharmony_ci+Status NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) { 4419be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH 4420be168c0dSopenharmony_ci+ if (IsKirinNPU()) { 4421be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Choose to build nnrt model with Metagraph"; 4422be168c0dSopenharmony_ci+ InitCachePath(); 4423be168c0dSopenharmony_ci+ return BuildKirinNPUModel(model); 4424be168c0dSopenharmony_ci } 
4425be168c0dSopenharmony_ci+#endif 4426be168c0dSopenharmony_ci 4427be168c0dSopenharmony_ci- ret_code = OH_NNModel_BuildFromLiteGraph(oh_nnmodel, this->nnrt_lite_graph); 4428be168c0dSopenharmony_ci- if (ret_code != OH_NN_SUCCESS) { 4429be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Build NNModel failed, OH_NN_ReturnCode = " << ret_code; 4430be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4431be168c0dSopenharmony_ci- return mindspore::kLiteError; 4432be168c0dSopenharmony_ci+ return BuildNormalModel(model); 4433be168c0dSopenharmony_ci+} 4434be168c0dSopenharmony_ci+ 4435be168c0dSopenharmony_ci+bool NNRTDelegate::IsCustomModel() const { 4436be168c0dSopenharmony_ci+ // check if there is only one Cutsom kernel in LiteModel. 4437be168c0dSopenharmony_ci+ if (lite_graph_ == nullptr) { 4438be168c0dSopenharmony_ci+ return false; 4439be168c0dSopenharmony_ci+ } 4440be168c0dSopenharmony_ci+ if (lite_graph_->all_nodes_.size() != 1) { 4441be168c0dSopenharmony_ci+ return false; 4442be168c0dSopenharmony_ci+ } 4443be168c0dSopenharmony_ci+ auto node = lite_graph_->all_nodes_[0]; 4444be168c0dSopenharmony_ci+ if (node == nullptr) { 4445be168c0dSopenharmony_ci+ return false; 4446be168c0dSopenharmony_ci+ } 4447be168c0dSopenharmony_ci+ if (node->node_type_ != mindspore::schema::PrimitiveType_Custom) { 4448be168c0dSopenharmony_ci+ return false; 4449be168c0dSopenharmony_ci+ } 4450be168c0dSopenharmony_ci+ return true; 4451be168c0dSopenharmony_ci+} 4452be168c0dSopenharmony_ci+ 4453be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH 4454be168c0dSopenharmony_ci+bool NNRTDelegate::IsKirinNPU() const { 4455be168c0dSopenharmony_ci+ const std::string kirin_npu_name_prefix = "NPU_"; 4456be168c0dSopenharmony_ci+ auto device_id = nnrt_device_info_.device_id_; 4457be168c0dSopenharmony_ci+ const char *device_name; 4458be168c0dSopenharmony_ci+ auto ret = OH_NNDevice_GetName(device_id, &device_name); 4459be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4460be168c0dSopenharmony_ci+ 
MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret; 4461be168c0dSopenharmony_ci+ return false; 4462be168c0dSopenharmony_ci+ } 4463be168c0dSopenharmony_ci+ 4464be168c0dSopenharmony_ci+ if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) { 4465be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "strncmp: " << device_id << " failed, device_name: " << device_name; 4466be168c0dSopenharmony_ci+ return false; 4467be168c0dSopenharmony_ci+ } 4468be168c0dSopenharmony_ci+ return true; 4469be168c0dSopenharmony_ci+} 4470be168c0dSopenharmony_ci+ 4471be168c0dSopenharmony_ci+Status NNRTDelegate::BuildKirinNPUModel(DelegateModel<schema::Primitive> *model) { 4472be168c0dSopenharmony_ci+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4473be168c0dSopenharmony_ci+ if (nn_model == nullptr) { 4474be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create NNModel failed, result is nullptr"; 4475be168c0dSopenharmony_ci+ return kLiteNullptr; 4476be168c0dSopenharmony_ci+ } 4477be168c0dSopenharmony_ci+ 4478be168c0dSopenharmony_ci+ size_t extension_size = nnrt_device_info_.extensions_.size(); 4479be168c0dSopenharmony_ci+ std::vector<OH_NN_Extension> extensions; 4480be168c0dSopenharmony_ci+ MS_LOG_DEBUG << "set extensions, item number: " << extension_size; 4481be168c0dSopenharmony_ci+ const size_t kExtensionNameMax = 128; // This is a length limitation in NNRT API. 
4482be168c0dSopenharmony_ci+ for (size_t i = 0; i < extension_size; i++) { 4483be168c0dSopenharmony_ci+ auto &src_extension = nnrt_device_info_.extensions_[i]; 4484be168c0dSopenharmony_ci+ OH_NN_Extension dst_extension; 4485be168c0dSopenharmony_ci+ dst_extension.name[kExtensionNameMax - 1] = '\0'; 4486be168c0dSopenharmony_ci+ strncpy(dst_extension.name, src_extension.name.c_str(), kExtensionNameMax - 1); 4487be168c0dSopenharmony_ci+ dst_extension.value = (char *)((void *)src_extension.value.data()); 4488be168c0dSopenharmony_ci+ dst_extension.valueSize = src_extension.value.size(); 4489be168c0dSopenharmony_ci+ extensions.push_back(dst_extension); 4490be168c0dSopenharmony_ci+ MS_LOG_DEBUG << "set extension, item name: " << dst_extension.name << ", value size: " << dst_extension.valueSize; 4491be168c0dSopenharmony_ci+ } 4492be168c0dSopenharmony_ci+ 4493be168c0dSopenharmony_ci+ if (IsCustomModel()) { 4494be168c0dSopenharmony_ci+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_); 4495be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4496be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4497be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4498be168c0dSopenharmony_ci+ return kLiteError; 4499be168c0dSopenharmony_ci+ } 4500be168c0dSopenharmony_ci+ } else { 4501be168c0dSopenharmony_ci+ SetKirinModelInputsAndOutputs(nn_model); 4502be168c0dSopenharmony_ci+ auto ret = OH_NNModel_BuildFromMetaGraph(nn_model, meta_graph_, extensions.data(), extensions.size()); 4503be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4504be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4505be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4506be168c0dSopenharmony_ci+ return kLiteError; 4507be168c0dSopenharmony_ci+ } 4508be168c0dSopenharmony_ci+ } 4509be168c0dSopenharmony_ci+ 4510be168c0dSopenharmony_ci+ auto ret2 = CreateFullModelKernel(model, nn_model); 4511be168c0dSopenharmony_ci+ if (ret2 != kSuccess) { 
4512be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create full model kernel failed, ret: " << ret2; 4513be168c0dSopenharmony_ci+ return kLiteError; 4514be168c0dSopenharmony_ci } 4515be168c0dSopenharmony_ci- MS_LOG(INFO) << "NNRTDelegate creates NNModel success."; 4516be168c0dSopenharmony_ci+ return kSuccess; 4517be168c0dSopenharmony_ci+} 4518be168c0dSopenharmony_ci+ 4519be168c0dSopenharmony_ci+std::vector<OH_NN_TensorInfo> NNRTDelegate::CreateNNTensorInfos(const std::vector<uint32_t> &indices) const { 4520be168c0dSopenharmony_ci+ std::vector<OH_NN_TensorInfo> nn_tensor_infos; 4521be168c0dSopenharmony_ci+ for (auto index: indices) { 4522be168c0dSopenharmony_ci+ auto tensor = lite_graph_->all_tensors_[index]; 4523be168c0dSopenharmony_ci+ auto shape = tensor->dims(); 4524be168c0dSopenharmony_ci+ auto data_type = tensor->dataType(); 4525be168c0dSopenharmony_ci+ auto name = tensor->name(); 4526be168c0dSopenharmony_ci+ auto format = tensor->format(); 4527be168c0dSopenharmony_ci 4528be168c0dSopenharmony_ci- OH_NNCompilation *oh_nn_compilation = nullptr; 4529be168c0dSopenharmony_ci- oh_nn_compilation = OH_NNCompilation_Construct(oh_nnmodel); 4530be168c0dSopenharmony_ci+ OH_NN_TensorInfo info; 4531be168c0dSopenharmony_ci+ info.dataType = CastToNNRTDataType(static_cast<mindspore::DataType>(data_type)); 4532be168c0dSopenharmony_ci+ info.dimensions = shape->data(); 4533be168c0dSopenharmony_ci+ info.dimensionCount = shape->size(); 4534be168c0dSopenharmony_ci+ strcpy(info.name, name->c_str()); 4535be168c0dSopenharmony_ci+ info.format = CastToNNRTFormat(static_cast<Format>(format)); 4536be168c0dSopenharmony_ci+ nn_tensor_infos.push_back(info); 4537be168c0dSopenharmony_ci+ } 4538be168c0dSopenharmony_ci+ return nn_tensor_infos; 4539be168c0dSopenharmony_ci+} 4540be168c0dSopenharmony_ci 4541be168c0dSopenharmony_ci- if (oh_nn_compilation == nullptr) { 4542be168c0dSopenharmony_ci+Status NNRTDelegate::SetKirinModelInputsAndOutputs(OH_NNModel *nn_model) { 4543be168c0dSopenharmony_ci+ 
std::vector<OH_NN_TensorInfo> inputInfos; 4544be168c0dSopenharmony_ci+ std::vector<OH_NN_TensorInfo> outputInfos; 4545be168c0dSopenharmony_ci+ auto input_infos = CreateNNTensorInfos(lite_graph_->input_indices_); 4546be168c0dSopenharmony_ci+ auto output_infos = CreateNNTensorInfos(lite_graph_->output_indices_); 4547be168c0dSopenharmony_ci+ OH_NNModel_SetInputsAndOutputsInfo(nn_model, input_infos.data(), input_infos.size(), output_infos.data(), 4548be168c0dSopenharmony_ci+ output_infos.size()); 4549be168c0dSopenharmony_ci+ return kSuccess; 4550be168c0dSopenharmony_ci+} 4551be168c0dSopenharmony_ci+ 4552be168c0dSopenharmony_ci+Status NNRTDelegate::CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model) { 4553be168c0dSopenharmony_ci+ OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model); 4554be168c0dSopenharmony_ci+ if (nn_compilation == nullptr) { 4555be168c0dSopenharmony_ci MS_LOG(ERROR) << "Construct NNCompilation failed"; 4556be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4557be168c0dSopenharmony_ci- return mindspore::kLiteError; 4558be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4559be168c0dSopenharmony_ci+ return kLiteError; 4560be168c0dSopenharmony_ci } 4561be168c0dSopenharmony_ci- MS_LOG(INFO) << "NNRTDelegate creates NNCompilation success."; 4562be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success."; 4563be168c0dSopenharmony_ci 4564be168c0dSopenharmony_ci- const size_t *allDevicesID = nullptr; 4565be168c0dSopenharmony_ci- uint32_t device_count = 0; 4566be168c0dSopenharmony_ci- ret_code = OH_NNDevice_GetAllDevicesID(&allDevicesID, &device_count); 4567be168c0dSopenharmony_ci- if (ret_code != OH_NN_SUCCESS) { 4568be168c0dSopenharmony_ci- MS_LOG(ERROR) << "NNModel GetAllDevicesID failed, OH_NN_ReturnCode = " << ret_code; 4569be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4570be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 
4571be168c0dSopenharmony_ci- return mindspore::kLiteError; 4572be168c0dSopenharmony_ci+ auto ret_code = InitNNCompilation(nn_compilation); 4573be168c0dSopenharmony_ci+ if (ret_code != kSuccess) { 4574be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init NNCompilation failed"; 4575be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4576be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 4577be168c0dSopenharmony_ci+ return kLiteError; 4578be168c0dSopenharmony_ci } 4579be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4580be168c0dSopenharmony_ci 4581be168c0dSopenharmony_ci- if (device_count <= 0) { 4582be168c0dSopenharmony_ci- MS_LOG(WARNING) << "No NNRt Device found, fall back to CPU. "; 4583be168c0dSopenharmony_ci- // OH_NNCompilation_Destroy(&oh_nn_compilation); 4584be168c0dSopenharmony_ci- // OH_NNModel_Destroy(&oh_nnmodel); 4585be168c0dSopenharmony_ci- return mindspore::kSuccess; 4586be168c0dSopenharmony_ci+ OH_NNExecutor *nn_executor = nullptr; 4587be168c0dSopenharmony_ci+ nn_executor = OH_NNExecutor_Construct(nn_compilation); 4588be168c0dSopenharmony_ci+ if (nn_executor == nullptr) { 4589be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code; 4590be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 4591be168c0dSopenharmony_ci+ return kLiteError; 4592be168c0dSopenharmony_ci } 4593be168c0dSopenharmony_ci- MS_LOG(INFO) << "NNRTDelegate GetAllDevicesID success."; 4594be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 4595be168c0dSopenharmony_ci 4596be168c0dSopenharmony_ci- // check if model ops are supported 4597be168c0dSopenharmony_ci- const bool *issupported = nullptr; 4598be168c0dSopenharmony_ci+ auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, model->inputs(), model->outputs()); 4599be168c0dSopenharmony_ci+ if (nnrt_model_kernel == nullptr) { 4600be168c0dSopenharmony_ci+ OH_NNExecutor_Destroy(&nn_executor); 4601be168c0dSopenharmony_ci+ MS_LOG(ERROR) << 
"new NNRTModelKernel failed"; 4602be168c0dSopenharmony_ci+ return kLiteError; 4603be168c0dSopenharmony_ci+ } 4604be168c0dSopenharmony_ci+ model->Replace(model->BeginKernelIterator(), model->EndKernelIterator(), nnrt_model_kernel); 4605be168c0dSopenharmony_ci+ return kSuccess; 4606be168c0dSopenharmony_ci+} 4607be168c0dSopenharmony_ci+#endif 4608be168c0dSopenharmony_ci+ 4609be168c0dSopenharmony_ci+Status NNRTDelegate::BuildNormalModel(DelegateModel<schema::Primitive> *model) { 4610be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Start to build NNRT model."; 4611be168c0dSopenharmony_ci+ if ((lite_graph_ == nullptr) || (lite_graph_->sub_graphs_.size() > 1)) { 4612be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "LiteGraph contains more than one subgraph. NNRT does not support control-flow model yet, fallback to CPU"; 4613be168c0dSopenharmony_ci+ return kSuccess; 4614be168c0dSopenharmony_ci+ } 4615be168c0dSopenharmony_ci+ 4616be168c0dSopenharmony_ci+ OH_NNModel *full_model = CreateFullNNModel(); 4617be168c0dSopenharmony_ci+ if (full_model == nullptr) { 4618be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "Build full NNModel failed, fallback to CPU"; 4619be168c0dSopenharmony_ci+ return kSuccess; 4620be168c0dSopenharmony_ci+ } 4621be168c0dSopenharmony_ci+ std::vector<bool> op_supports = QueryOpSupports(full_model); 4622be168c0dSopenharmony_ci+ if (op_supports.empty()) { 4623be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "Query no op supports for full model, fallback to CPU"; 4624be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&full_model); 4625be168c0dSopenharmony_ci+ return kSuccess; 4626be168c0dSopenharmony_ci+ } 4627be168c0dSopenharmony_ci+ auto nnrt_subgraph_ranges = GetNNRTSubgraphRanges(model, op_supports); 4628be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Found NNRT subgraph count: " << nnrt_subgraph_ranges.size(); 4629be168c0dSopenharmony_ci+ 4630be168c0dSopenharmony_ci+ std::vector<LiteGraph *> sub_lite_graphs; 4631be168c0dSopenharmony_ci+ auto ret = 
CreateLiteGraphForNNRTSubgraph(nnrt_subgraph_ranges, &sub_lite_graphs); 4632be168c0dSopenharmony_ci+ if (ret != kSuccess) { 4633be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&full_model); 4634be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "Create NNRT sub LiteGraph failed, fallback to CPU"; 4635be168c0dSopenharmony_ci+ return kSuccess; 4636be168c0dSopenharmony_ci+ } 4637be168c0dSopenharmony_ci+ 4638be168c0dSopenharmony_ci+ std::vector<NNRTModelKernel *> nnrt_subgraph_kernels; 4639be168c0dSopenharmony_ci+ ret = CreateNNRTSubgraphKernels(model, sub_lite_graphs, nnrt_subgraph_ranges, &nnrt_subgraph_kernels); 4640be168c0dSopenharmony_ci+ if (ret != kSuccess) { 4641be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&full_model); 4642be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "Create NNRT subgraph kernel failed, fallback to CPU"; 4643be168c0dSopenharmony_ci+ return kSuccess; 4644be168c0dSopenharmony_ci+ } 4645be168c0dSopenharmony_ci+ 4646be168c0dSopenharmony_ci+ ReplaceNNRTKernelsInDelegateModel(model, nnrt_subgraph_ranges, nnrt_subgraph_kernels); 4647be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&full_model); 4648be168c0dSopenharmony_ci+ MS_LOG(INFO) << "NNRTDelegate build success."; 4649be168c0dSopenharmony_ci+ return kSuccess; 4650be168c0dSopenharmony_ci+} 4651be168c0dSopenharmony_ci+ 4652be168c0dSopenharmony_ci+OH_NNModel *NNRTDelegate::CreateFullNNModel() { 4653be168c0dSopenharmony_ci+ if (lite_graph_ == nullptr) { 4654be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Lite graph is null"; 4655be168c0dSopenharmony_ci+ return nullptr; 4656be168c0dSopenharmony_ci+ } 4657be168c0dSopenharmony_ci+ 4658be168c0dSopenharmony_ci+ if (lite_graph_->sub_graphs_.empty()) { 4659be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Lite graph must have at lease one subgraph"; 4660be168c0dSopenharmony_ci+ return nullptr; 4661be168c0dSopenharmony_ci+ } 4662be168c0dSopenharmony_ci+ 4663be168c0dSopenharmony_ci+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4664be168c0dSopenharmony_ci+ if (nn_model == nullptr) { 
4665be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create NNModel failed, result is nullptr"; 4666be168c0dSopenharmony_ci+ return nullptr; 4667be168c0dSopenharmony_ci+ } 4668be168c0dSopenharmony_ci+ 4669be168c0dSopenharmony_ci+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_); 4670be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4671be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4672be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4673be168c0dSopenharmony_ci+ return nullptr; 4674be168c0dSopenharmony_ci+ } 4675be168c0dSopenharmony_ci+ return nn_model; 4676be168c0dSopenharmony_ci+} 4677be168c0dSopenharmony_ci+ 4678be168c0dSopenharmony_ci+std::vector<bool> NNRTDelegate::QueryOpSupports(OH_NNModel *nn_model) { 4679be168c0dSopenharmony_ci+ const bool *is_supported = nullptr; // Note: this memory is owned by nn_model, don't free alone. 4680be168c0dSopenharmony_ci uint32_t op_count = 0; 4681be168c0dSopenharmony_ci- ret_code = OH_NNModel_GetAvailableOperations(oh_nnmodel, allDevicesID[0], &issupported, &op_count); 4682be168c0dSopenharmony_ci- if (ret_code != OH_NN_SUCCESS) { 4683be168c0dSopenharmony_ci- MS_LOG(ERROR) << "NNModel GetAvailableOperations failed, OH_NN_ReturnCode = " << ret_code 4684be168c0dSopenharmony_ci- << ", maybe due to dataParcel data length limitaion. 
Fall back to CPU."; 4685be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4686be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4687be168c0dSopenharmony_ci- return mindspore::kSuccess; 4688be168c0dSopenharmony_ci+ auto ret = OH_NNModel_GetAvailableOperations(nn_model, nnrt_device_info_.device_id_, &is_supported, &op_count); 4689be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4690be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "NNModel GetAvailableOperations failed, ret: " << ret 4691be168c0dSopenharmony_ci+ << ", maybe caused by dataParcel data length limitation"; 4692be168c0dSopenharmony_ci+ return {}; 4693be168c0dSopenharmony_ci } 4694be168c0dSopenharmony_ci- uint32_t supported_op_count = 0; 4695be168c0dSopenharmony_ci- for (uint32_t i = 0; i < op_count; i++) { 4696be168c0dSopenharmony_ci- if (issupported[i]) { 4697be168c0dSopenharmony_ci- supported_op_count++; 4698be168c0dSopenharmony_ci+ std::vector<bool> op_supports(is_supported, is_supported + op_count); 4699be168c0dSopenharmony_ci+ return op_supports; 4700be168c0dSopenharmony_ci+} 4701be168c0dSopenharmony_ci+ 4702be168c0dSopenharmony_ci+/* Find continuous sub-sequence in op_supports. 
*/ 4703be168c0dSopenharmony_ci+std::vector<NNRTOpRange> NNRTDelegate::GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model, 4704be168c0dSopenharmony_ci+ const std::vector<bool> &op_supports) { 4705be168c0dSopenharmony_ci+ std::vector<NNRTOpRange> nnrt_subgraph_ranges; 4706be168c0dSopenharmony_ci+ NNRTOpRange op_range; 4707be168c0dSopenharmony_ci+ bool start_count = false; 4708be168c0dSopenharmony_ci+ for (size_t i = 0; i < op_supports.size(); i++) { 4709be168c0dSopenharmony_ci+ if (op_supports[i]) { 4710be168c0dSopenharmony_ci+ if (start_count == false) { 4711be168c0dSopenharmony_ci+ start_count = true; 4712be168c0dSopenharmony_ci+ op_range.begin_index_ = i; 4713be168c0dSopenharmony_ci+ op_range.begin_iter_ = model->BeginKernelIterator() + i; 4714be168c0dSopenharmony_ci+ } 4715be168c0dSopenharmony_ci+ } else { 4716be168c0dSopenharmony_ci+ if (start_count == true) { 4717be168c0dSopenharmony_ci+ start_count = false; 4718be168c0dSopenharmony_ci+ op_range.end_index_ = i; 4719be168c0dSopenharmony_ci+ op_range.end_iter_ = model->BeginKernelIterator() + i; 4720be168c0dSopenharmony_ci+ nnrt_subgraph_ranges.push_back(op_range); 4721be168c0dSopenharmony_ci+ } 4722be168c0dSopenharmony_ci } 4723be168c0dSopenharmony_ci } 4724be168c0dSopenharmony_ci- if (op_count != supported_op_count) { 4725be168c0dSopenharmony_ci- MS_LOG(WARNING) << "this model has " << op_count << "ops, but NNRT only support " << supported_op_count 4726be168c0dSopenharmony_ci- << " ops, fall back to CPU."; 4727be168c0dSopenharmony_ci- // must support all op, else fall back to CPU 4728be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4729be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4730be168c0dSopenharmony_ci- return mindspore::kSuccess; 4731be168c0dSopenharmony_ci+ // handle last true subsequence 4732be168c0dSopenharmony_ci+ if (start_count == true) { 4733be168c0dSopenharmony_ci+ op_range.end_index_ = op_supports.size(); 4734be168c0dSopenharmony_ci+ 
op_range.end_iter_ = model->EndKernelIterator(); 4735be168c0dSopenharmony_ci+ nnrt_subgraph_ranges.push_back(op_range); 4736be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Schedule NNRT subgraph range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")"; 4737be168c0dSopenharmony_ci } 4738be168c0dSopenharmony_ci- MS_LOG(INFO) << "NNRtDelegate supports all op in this model."; 4739be168c0dSopenharmony_ci+ return nnrt_subgraph_ranges; 4740be168c0dSopenharmony_ci+} 4741be168c0dSopenharmony_ci+ 4742be168c0dSopenharmony_ci+/** 4743be168c0dSopenharmony_ci+ * This method ONLY works when the following pre-conditions are satisfied: 4744be168c0dSopenharmony_ci+ * 1. The node order of lite_graph_->all_nodes should be consistent with DelegateModel sequence. 4745be168c0dSopenharmony_ci+ * This ensures the kernel replacement in DelegateModel based on the re-organizing info from lite_graph_ is correct. 4746be168c0dSopenharmony_ci+ * 2. The node indices of lite_graph_->sub_graphs[0].node_indices should be monotonically increasing from 0 to size - 1.
4747be168c0dSopenharmony_ci+ */ 4748be168c0dSopenharmony_ci+Status NNRTDelegate::CreateLiteGraphForNNRTSubgraph( 4749be168c0dSopenharmony_ci+ const std::vector<NNRTOpRange> &nnrt_op_ranges, 4750be168c0dSopenharmony_ci+ std::vector<LiteGraph *> *sub_lite_graphs) { 4751be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Start creating LiteGraph for NNRT subgraph"; 4752be168c0dSopenharmony_ci+ for (const auto &op_range: nnrt_op_ranges) { 4753be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Process op range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")"; 4754be168c0dSopenharmony_ci+ LiteGraph *sub_lite_graph = new (std::nothrow)LiteGraph; 4755be168c0dSopenharmony_ci+ if (sub_lite_graph == nullptr) { 4756be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Allocate LiteGraph failed"; 4757be168c0dSopenharmony_ci+ return kLiteError; 4758be168c0dSopenharmony_ci+ } 4759be168c0dSopenharmony_ci+ sub_lite_graph->name_ = lite_graph_->name_; 4760be168c0dSopenharmony_ci+ sub_lite_graph->version_ = lite_graph_->version_; 4761be168c0dSopenharmony_ci 4762be168c0dSopenharmony_ci- ret_code = OH_NNCompilation_SetDevice(oh_nn_compilation, allDevicesID[0]); 4763be168c0dSopenharmony_ci+ auto sub_graph = new (std::nothrow)LiteGraph::SubGraph; 4764be168c0dSopenharmony_ci+ if (sub_graph == nullptr) { 4765be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Allocate SubGraph failed"; 4766be168c0dSopenharmony_ci+ return kLiteError; 4767be168c0dSopenharmony_ci+ } 4768be168c0dSopenharmony_ci+ sub_graph->name_ = lite_graph_->name_; 4769be168c0dSopenharmony_ci+ sub_lite_graph->sub_graphs_.push_back(sub_graph); 4770be168c0dSopenharmony_ci 4771be168c0dSopenharmony_ci+ // deal with all_nodes 4772be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Assemble all_nodes..."; 4773be168c0dSopenharmony_ci+ int new_node_index = 0; 4774be168c0dSopenharmony_ci+ std::map<uint32_t, schema::Tensor *> in_tensor_index_map; 4775be168c0dSopenharmony_ci+ std::map<uint32_t, schema::Tensor *> out_tensor_index_map; 4776be168c0dSopenharmony_ci+ for 
(size_t index = op_range.begin_index_; index < op_range.end_index_; index++) { 4777be168c0dSopenharmony_ci+ LiteGraph::Node *node = new (std::nothrow)LiteGraph::Node; 4778be168c0dSopenharmony_ci+ if (node == nullptr) { 4779be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Allocate Node failed"; 4780be168c0dSopenharmony_ci+ return kLiteError; 4781be168c0dSopenharmony_ci+ } 4782be168c0dSopenharmony_ci+ *node = *lite_graph_->all_nodes_[index]; 4783be168c0dSopenharmony_ci+ sub_lite_graph->all_nodes_.push_back(node); 4784be168c0dSopenharmony_ci+ sub_graph->node_indices_.push_back(new_node_index++); 4785be168c0dSopenharmony_ci+ 4786be168c0dSopenharmony_ci+ for (auto i: node->input_indices_) { 4787be168c0dSopenharmony_ci+ in_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]); 4788be168c0dSopenharmony_ci+ } 4789be168c0dSopenharmony_ci+ for (auto i: node->output_indices_) { 4790be168c0dSopenharmony_ci+ out_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]); 4791be168c0dSopenharmony_ci+ } 4792be168c0dSopenharmony_ci+ } 4793be168c0dSopenharmony_ci+ 4794be168c0dSopenharmony_ci+ // deal with all_tensors 4795be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Assemble all_tensors..."; 4796be168c0dSopenharmony_ci+ std::set<schema::Tensor *> tensors; 4797be168c0dSopenharmony_ci+ for (auto iter: in_tensor_index_map) { 4798be168c0dSopenharmony_ci+ tensors.emplace(iter.second); 4799be168c0dSopenharmony_ci+ } 4800be168c0dSopenharmony_ci+ for (auto iter: out_tensor_index_map) { 4801be168c0dSopenharmony_ci+ tensors.emplace(iter.second); 4802be168c0dSopenharmony_ci+ } 4803be168c0dSopenharmony_ci+ 4804be168c0dSopenharmony_ci+ uint32_t new_index = 0; 4805be168c0dSopenharmony_ci+ std::map<schema::Tensor *, uint32_t> new_tensor_maps; 4806be168c0dSopenharmony_ci+ for (auto tensor: tensors) { 4807be168c0dSopenharmony_ci+ new_tensor_maps.emplace(tensor, new_index++); 4808be168c0dSopenharmony_ci+ } 4809be168c0dSopenharmony_ci+ 4810be168c0dSopenharmony_ci+ sub_lite_graph->all_tensors_ = 
std::vector<schema::Tensor *>(tensors.begin(), tensors.end()); 4811be168c0dSopenharmony_ci+ 4812be168c0dSopenharmony_ci+ // deal with every node's input/output indices 4813be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Set input/output indices of each node..."; 4814be168c0dSopenharmony_ci+ for (auto node: sub_lite_graph->all_nodes_) { 4815be168c0dSopenharmony_ci+ for (auto &index : node->input_indices_) { 4816be168c0dSopenharmony_ci+ index = new_tensor_maps.at(in_tensor_index_map.at(index)); 4817be168c0dSopenharmony_ci+ } 4818be168c0dSopenharmony_ci+ for (auto &index : node->output_indices_) { 4819be168c0dSopenharmony_ci+ index = new_tensor_maps.at(out_tensor_index_map.at(index)); 4820be168c0dSopenharmony_ci+ } 4821be168c0dSopenharmony_ci+ } 4822be168c0dSopenharmony_ci+ 4823be168c0dSopenharmony_ci+ // deal with subgraph's input/output indices 4824be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Set input/output indices of each subgraph..."; 4825be168c0dSopenharmony_ci+ sub_graph->tensor_indices_ = std::vector<uint32_t>(tensors.size()); 4826be168c0dSopenharmony_ci+ std::iota(sub_graph->tensor_indices_.begin(), sub_graph->tensor_indices_.end(), 0U); 4827be168c0dSopenharmony_ci+ 4828be168c0dSopenharmony_ci+ for (auto iter: in_tensor_index_map) { 4829be168c0dSopenharmony_ci+ auto new_tensor_index = new_tensor_maps[iter.second]; 4830be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "handle input: old: " << iter.first << ", new: " << new_tensor_index << std::endl; 4831be168c0dSopenharmony_ci+ if (IsConstTensor(*iter.second)) { 4832be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." 
<< std::endl; 4833be168c0dSopenharmony_ci+ continue; 4834be168c0dSopenharmony_ci+ } 4835be168c0dSopenharmony_ci+ 4836be168c0dSopenharmony_ci+ bool is_subgraph_input = true; 4837be168c0dSopenharmony_ci+ for (auto node: sub_lite_graph->all_nodes_) { 4838be168c0dSopenharmony_ci+ if (std::find(node->output_indices_.begin(), node->output_indices_.end(), new_tensor_index) != 4839be168c0dSopenharmony_ci+ node->output_indices_.end()) { 4840be168c0dSopenharmony_ci+ is_subgraph_input = false; 4841be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is not subgraph input." << std::endl; 4842be168c0dSopenharmony_ci+ break; 4843be168c0dSopenharmony_ci+ } 4844be168c0dSopenharmony_ci+ } 4845be168c0dSopenharmony_ci+ if (is_subgraph_input) { 4846be168c0dSopenharmony_ci+ sub_graph->input_indices_.push_back(new_tensor_index); 4847be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph input." << std::endl; 4848be168c0dSopenharmony_ci+ } 4849be168c0dSopenharmony_ci+ } 4850be168c0dSopenharmony_ci+ 4851be168c0dSopenharmony_ci+ for (auto iter: out_tensor_index_map) { 4852be168c0dSopenharmony_ci+ int new_tensor_index = new_tensor_maps.at(iter.second); 4853be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "handle output: old: " << iter.first << ", new: " << new_tensor_index << std::endl; 4854be168c0dSopenharmony_ci+ if (IsConstTensor(*iter.second)) { 4855be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." 
<< std::endl; 4856be168c0dSopenharmony_ci+ continue; 4857be168c0dSopenharmony_ci+ } 4858be168c0dSopenharmony_ci+ 4859be168c0dSopenharmony_ci+ bool is_subgraph_output = false; 4860be168c0dSopenharmony_ci+ for (size_t i = 0; i < lite_graph_->all_nodes_.size(); i++) { 4861be168c0dSopenharmony_ci+ if ((i >= op_range.begin_index_) && (i < op_range.end_index_)) { 4862be168c0dSopenharmony_ci+ continue; 4863be168c0dSopenharmony_ci+ } 4864be168c0dSopenharmony_ci+ auto node = lite_graph_->all_nodes_[i]; 4865be168c0dSopenharmony_ci+ if (std::find(node->input_indices_.begin(), node->input_indices_.end(), iter.first) != 4866be168c0dSopenharmony_ci+ node->input_indices_.end()) { // As the input of node which does not belong to the subgraph. 4867be168c0dSopenharmony_ci+ is_subgraph_output = true; 4868be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is original subgraph output. node: " << node->primitive_ << std::endl; 4869be168c0dSopenharmony_ci+ break; 4870be168c0dSopenharmony_ci+ } 4871be168c0dSopenharmony_ci+ } 4872be168c0dSopenharmony_ci+ bool is_graph_output = (std::find(lite_graph_->output_indices_.begin(),lite_graph_->output_indices_.end(), 4873be168c0dSopenharmony_ci+ iter.first) != lite_graph_->output_indices_.end()); 4874be168c0dSopenharmony_ci+ if (is_graph_output) { 4875be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is graph output." << std::endl; 4876be168c0dSopenharmony_ci+ } 4877be168c0dSopenharmony_ci+ if (is_subgraph_output || is_graph_output) { 4878be168c0dSopenharmony_ci+ sub_graph->output_indices_.push_back(new_tensor_index); 4879be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph output." 
<< std::endl; 4880be168c0dSopenharmony_ci+ } 4881be168c0dSopenharmony_ci+ } 4882be168c0dSopenharmony_ci+ 4883be168c0dSopenharmony_ci+ // deal with full-graph's input/output indices 4884be168c0dSopenharmony_ci+ sub_lite_graph->input_indices_ = sub_graph->input_indices_; 4885be168c0dSopenharmony_ci+ sub_lite_graph->output_indices_ = sub_graph->output_indices_; 4886be168c0dSopenharmony_ci+ sub_lite_graphs->push_back(sub_lite_graph); 4887be168c0dSopenharmony_ci+ } 4888be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Finished creating LiteGraph for NNRT subgraph"; 4889be168c0dSopenharmony_ci+ return kSuccess; 4890be168c0dSopenharmony_ci+} 4891be168c0dSopenharmony_ci+ 4892be168c0dSopenharmony_ci+struct TensorLocation { 4893be168c0dSopenharmony_ci+ uint32_t node_index; // the index of node which the tensor belongs to. 4894be168c0dSopenharmony_ci+ uint32_t tensor_index; // the index of node in/out tensors which the tensor is located at. 4895be168c0dSopenharmony_ci+}; 4896be168c0dSopenharmony_ci+ 4897be168c0dSopenharmony_ci+Status NNRTDelegate::InitNNCompilation(OH_NNCompilation *nn_compilation) const { 4898be168c0dSopenharmony_ci+ auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_); 4899be168c0dSopenharmony_ci if (ret_code != OH_NN_SUCCESS) { 4900be168c0dSopenharmony_ci- MS_LOG(ERROR) << "NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code; 4901be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4902be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4903be168c0dSopenharmony_ci- return mindspore::kLiteError; 4904be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code; 4905be168c0dSopenharmony_ci+ return kLiteError; 4906be168c0dSopenharmony_ci+ } 4907be168c0dSopenharmony_ci+ ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation, 4908be168c0dSopenharmony_ci+ (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_)); 4909be168c0dSopenharmony_ci+ if 
((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4910be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code; 4911be168c0dSopenharmony_ci+ return kLiteError; 4912be168c0dSopenharmony_ci+ } 4913be168c0dSopenharmony_ci+ ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_)); 4914be168c0dSopenharmony_ci+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4915be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code; 4916be168c0dSopenharmony_ci+ return kLiteError; 4917be168c0dSopenharmony_ci+ } 4918be168c0dSopenharmony_ci+ ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_); 4919be168c0dSopenharmony_ci+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4920be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code; 4921be168c0dSopenharmony_ci+ return kLiteError; 4922be168c0dSopenharmony_ci } 4923be168c0dSopenharmony_ci 4924be168c0dSopenharmony_ci- ret_code = OH_NNCompilation_Build(oh_nn_compilation); 4925be168c0dSopenharmony_ci+ if (!cache_path_.empty()) { // Set cache path if user indeed set it. 
4926be168c0dSopenharmony_ci+ ret_code = OH_NNCompilation_SetCache(nn_compilation, cache_path_.c_str(), cache_version_); 4927be168c0dSopenharmony_ci+ if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) { 4928be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code; 4929be168c0dSopenharmony_ci+ return kLiteError; 4930be168c0dSopenharmony_ci+ } 4931be168c0dSopenharmony_ci+ } 4932be168c0dSopenharmony_ci 4933be168c0dSopenharmony_ci+ ret_code = OH_NNCompilation_Build(nn_compilation); 4934be168c0dSopenharmony_ci if (ret_code != OH_NN_SUCCESS) { 4935be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Build NNCompilation failed, OH_NN_ReturnCode = " << ret_code; 4936be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4937be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4938be168c0dSopenharmony_ci- return mindspore::kLiteError; 4939be168c0dSopenharmony_ci- } 4940be168c0dSopenharmony_ci- 4941be168c0dSopenharmony_ci- MS_LOG(DEBUG) << "NNRTDelegate SetDevice success."; 4942be168c0dSopenharmony_ci- 4943be168c0dSopenharmony_ci- OH_NNExecutor *oh_nn_executor = nullptr; 4944be168c0dSopenharmony_ci- oh_nn_executor = OH_NNExecutor_Construct(oh_nn_compilation); 4945be168c0dSopenharmony_ci- if (oh_nn_executor == nullptr) { 4946be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Construct NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code; 4947be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4948be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4949be168c0dSopenharmony_ci- return mindspore::kLiteError; 4950be168c0dSopenharmony_ci- } 4951be168c0dSopenharmony_ci- MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success."; 4952be168c0dSopenharmony_ci- mindspore::Status prepare_data_ret; 4953be168c0dSopenharmony_ci- auto nnr_model_kernel = new (std::nothrow) NNRTModelKernel(oh_nn_executor, model->inputs(), model->outputs()); 4954be168c0dSopenharmony_ci- if (nnr_model_kernel == 
nullptr) { 4955be168c0dSopenharmony_ci- MS_LOG(ERROR) << "new NNRTModelKernel failed"; 4956be168c0dSopenharmony_ci- return mindspore::kLiteError; 4957be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code; 4958be168c0dSopenharmony_ci+ return kLiteError; 4959be168c0dSopenharmony_ci } 4960be168c0dSopenharmony_ci- OH_NNCompilation_Destroy(&oh_nn_compilation); 4961be168c0dSopenharmony_ci- OH_NNModel_Destroy(&oh_nnmodel); 4962be168c0dSopenharmony_ci- KernelIter from = model->BeginKernelIterator(); 4963be168c0dSopenharmony_ci- KernelIter end = model->EndKernelIterator(); 4964be168c0dSopenharmony_ci- model->Replace(from, end, nnr_model_kernel); 4965be168c0dSopenharmony_ci+ return kSuccess; 4966be168c0dSopenharmony_ci+} 4967be168c0dSopenharmony_ci+ 4968be168c0dSopenharmony_ci+Status NNRTDelegate::CreateNNRTSubgraphKernels(DelegateModel<schema::Primitive> *model, 4969be168c0dSopenharmony_ci+ const std::vector<LiteGraph *> &sub_lite_graphs, const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 4970be168c0dSopenharmony_ci+ std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels) { 4971be168c0dSopenharmony_ci+ for (size_t i = 0; i < sub_lite_graphs.size(); i++) { 4972be168c0dSopenharmony_ci+ auto sub_lite_graph = sub_lite_graphs[i]; 4973be168c0dSopenharmony_ci+ 4974be168c0dSopenharmony_ci+ OH_NNModel *nn_model = OH_NNModel_Construct(); 4975be168c0dSopenharmony_ci+ auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, sub_lite_graph); 4976be168c0dSopenharmony_ci+ if (ret != OH_NN_SUCCESS) { 4977be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret; 4978be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4979be168c0dSopenharmony_ci+ return kLiteError; 4980be168c0dSopenharmony_ci+ } 4981be168c0dSopenharmony_ci 4982be168c0dSopenharmony_ci- MS_LOG(INFO) << "NNRTDelegate build success."; 4983be168c0dSopenharmony_ci- return mindspore::kSuccess; 4984be168c0dSopenharmony_ci+ OH_NNCompilation *nn_compilation = 
OH_NNCompilation_Construct(nn_model); 4985be168c0dSopenharmony_ci+ if (nn_compilation == nullptr) { 4986be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Construct NNCompilation failed"; 4987be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4988be168c0dSopenharmony_ci+ return kLiteError; 4989be168c0dSopenharmony_ci+ } 4990be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success."; 4991be168c0dSopenharmony_ci+ 4992be168c0dSopenharmony_ci+ auto ret_code = InitNNCompilation(nn_compilation); 4993be168c0dSopenharmony_ci+ if (ret_code != kSuccess) { 4994be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init NNCompilation failed"; 4995be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 4996be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 4997be168c0dSopenharmony_ci+ return kLiteError; 4998be168c0dSopenharmony_ci+ } 4999be168c0dSopenharmony_ci+ 5000be168c0dSopenharmony_ci+ OH_NNExecutor *nn_executor = nullptr; 5001be168c0dSopenharmony_ci+ nn_executor = OH_NNExecutor_Construct(nn_compilation); 5002be168c0dSopenharmony_ci+ if (nn_executor == nullptr) { 5003be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code; 5004be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 5005be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 5006be168c0dSopenharmony_ci+ return kLiteError; 5007be168c0dSopenharmony_ci+ } 5008be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success."; 5009be168c0dSopenharmony_ci+ 5010be168c0dSopenharmony_ci+ bool format_not_support = false; 5011be168c0dSopenharmony_ci+ std::vector<MSTensor> in_tensors; 5012be168c0dSopenharmony_ci+ for (auto index: sub_lite_graph->sub_graphs_[0]->input_indices_) { 5013be168c0dSopenharmony_ci+ TensorLocation location; 5014be168c0dSopenharmony_ci+ for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) { 5015be168c0dSopenharmony_ci+ auto node = sub_lite_graph->all_nodes_[node_index]; 
5016be168c0dSopenharmony_ci+ auto iter = std::find(node->input_indices_.begin(), node->input_indices_.end(), index); 5017be168c0dSopenharmony_ci+ if (iter != node->input_indices_.end()) { 5018be168c0dSopenharmony_ci+ uint32_t tensor_index = iter - node->input_indices_.begin(); 5019be168c0dSopenharmony_ci+ location.node_index = node_index; 5020be168c0dSopenharmony_ci+ location.tensor_index = tensor_index; 5021be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Found graph input index: " << index << " is the " << tensor_index << "th input of the node " << node->primitive_; 5022be168c0dSopenharmony_ci+ break; 5023be168c0dSopenharmony_ci+ } 5024be168c0dSopenharmony_ci+ } 5025be168c0dSopenharmony_ci+ KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index; 5026be168c0dSopenharmony_ci+ in_tensors.push_back((*kernel_iter)->inputs()[location.tensor_index]); 5027be168c0dSopenharmony_ci+ if (in_tensors.back().format() != Format::NHWC) { 5028be168c0dSopenharmony_ci+ format_not_support = true; 5029be168c0dSopenharmony_ci+ break ; 5030be168c0dSopenharmony_ci+ } 5031be168c0dSopenharmony_ci+ } 5032be168c0dSopenharmony_ci+ 5033be168c0dSopenharmony_ci+ std::vector<MSTensor> out_tensors; 5034be168c0dSopenharmony_ci+ for (auto index: sub_lite_graph->sub_graphs_[0]->output_indices_) { 5035be168c0dSopenharmony_ci+ TensorLocation location; 5036be168c0dSopenharmony_ci+ for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) { 5037be168c0dSopenharmony_ci+ auto node = sub_lite_graph->all_nodes_[node_index]; 5038be168c0dSopenharmony_ci+ auto iter = std::find(node->output_indices_.begin(), node->output_indices_.end(), index); 5039be168c0dSopenharmony_ci+ if (iter != node->output_indices_.end()) { 5040be168c0dSopenharmony_ci+ uint32_t tensor_index = iter - node->output_indices_.begin(); 5041be168c0dSopenharmony_ci+ location.node_index = node_index; 5042be168c0dSopenharmony_ci+ location.tensor_index = tensor_index; 5043be168c0dSopenharmony_ci+ MS_LOG(INFO) << 
"Found graph output index: " << index << " is the " << tensor_index << "th output of the node " << node->primitive_; 5044be168c0dSopenharmony_ci+ break; 5045be168c0dSopenharmony_ci+ } 5046be168c0dSopenharmony_ci+ } 5047be168c0dSopenharmony_ci+ KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index; 5048be168c0dSopenharmony_ci+ out_tensors.push_back((*kernel_iter)->outputs()[location.tensor_index]); 5049be168c0dSopenharmony_ci+ if (out_tensors.back().format() != Format::NHWC) { 5050be168c0dSopenharmony_ci+ format_not_support = true; 5051be168c0dSopenharmony_ci+ break ; 5052be168c0dSopenharmony_ci+ } 5053be168c0dSopenharmony_ci+ } 5054be168c0dSopenharmony_ci+ if (format_not_support) { 5055be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "Not support in/out tensor format, skip this subgraph"; 5056be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 5057be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 5058be168c0dSopenharmony_ci+ nnrt_subgraph_kernels->push_back(nullptr); 5059be168c0dSopenharmony_ci+ continue ; 5060be168c0dSopenharmony_ci+ } 5061be168c0dSopenharmony_ci+ 5062be168c0dSopenharmony_ci+ auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, in_tensors, out_tensors); 5063be168c0dSopenharmony_ci+ if (nnrt_model_kernel == nullptr) { 5064be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "new NNRTModelKernel failed"; 5065be168c0dSopenharmony_ci+ return kLiteError; 5066be168c0dSopenharmony_ci+ } 5067be168c0dSopenharmony_ci+ OH_NNCompilation_Destroy(&nn_compilation); 5068be168c0dSopenharmony_ci+ OH_NNModel_Destroy(&nn_model); 5069be168c0dSopenharmony_ci+ nnrt_subgraph_kernels->push_back(nnrt_model_kernel); 5070be168c0dSopenharmony_ci+ } 5071be168c0dSopenharmony_ci+ return kSuccess; 5072be168c0dSopenharmony_ci } 5073be168c0dSopenharmony_ci 5074be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::Init() { 5075be168c0dSopenharmony_ci- MS_LOG(DEBUG) << "NNRTDelegate init success."; 
5076be168c0dSopenharmony_ci- return mindspore::kSuccess; 5077be168c0dSopenharmony_ci+void NNRTDelegate::ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model, 5078be168c0dSopenharmony_ci+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5079be168c0dSopenharmony_ci+ const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels) { 5080be168c0dSopenharmony_ci+ // Here we perform the replacement from back to front intentionally! If replace from front to end, the kernel 5081be168c0dSopenharmony_ci+ // sequence would shrink and the later begin_iter_/end_iter_ may be erased already. 5082be168c0dSopenharmony_ci+ for (int i = nnrt_subgraph_ranges.size() - 1; i >= 0; i--) { 5083be168c0dSopenharmony_ci+ if (nnrt_subgraph_kernels[i] == nullptr) { 5084be168c0dSopenharmony_ci+ continue; 5085be168c0dSopenharmony_ci+ } 5086be168c0dSopenharmony_ci+ auto from = nnrt_subgraph_ranges[i].begin_iter_; 5087be168c0dSopenharmony_ci+ auto end = nnrt_subgraph_ranges[i].end_iter_; 5088be168c0dSopenharmony_ci+ (void)model->Replace(from, end, nnrt_subgraph_kernels[i]); 5089be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Replace nnrt subgraph kernel in range: [" << (from - model->BeginKernelIterator()) 5090be168c0dSopenharmony_ci+ << ", " << (end - model->BeginKernelIterator()) << ")"; 5091be168c0dSopenharmony_ci+ } 5092be168c0dSopenharmony_ci } 5093be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model, 5094be168c0dSopenharmony_ci- OH_NNExecutor *oh_nn_executor) { 5095be168c0dSopenharmony_ci+ 5096be168c0dSopenharmony_ci+Status NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model, 5097be168c0dSopenharmony_ci+ OH_NNExecutor *oh_nn_executor) { 5098be168c0dSopenharmony_ci auto input_tensors = model->inputs(); 5099be168c0dSopenharmony_ci for (size_t i = 0; i < input_tensors.size(); i++) { 5100be168c0dSopenharmony_ci auto tensor = input_tensors[i]; 5101be168c0dSopenharmony_ci@@ -161,10 +654,10 @@ 
mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5102be168c0dSopenharmony_ci std::vector<double> scale; 5103be168c0dSopenharmony_ci std::vector<int32_t> zero_point; 5104be168c0dSopenharmony_ci if (!tmp_quant_param.empty()) { 5105be168c0dSopenharmony_ci- quant_param = new (std::nothrow) OH_NN_QuantParam; 5106be168c0dSopenharmony_ci+ quant_param = new(std::nothrow) OH_NN_QuantParam; 5107be168c0dSopenharmony_ci if (quant_param == nullptr) { 5108be168c0dSopenharmony_ci MS_LOG(ERROR) << "new OH_NN_QuantParam failed."; 5109be168c0dSopenharmony_ci- return mindspore::kLiteError; 5110be168c0dSopenharmony_ci+ return kLiteError; 5111be168c0dSopenharmony_ci } 5112be168c0dSopenharmony_ci for (auto qparam : tmp_quant_param) { 5113be168c0dSopenharmony_ci bit_num.emplace_back(qparam.bit_num); 5114be168c0dSopenharmony_ci@@ -176,12 +669,12 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5115be168c0dSopenharmony_ci quant_param->scale = scale.data(); 5116be168c0dSopenharmony_ci quant_param->zeroPoint = zero_point.data(); 5117be168c0dSopenharmony_ci } 5118be168c0dSopenharmony_ci- auto oprend = new (std::nothrow) OH_NN_Tensor; 5119be168c0dSopenharmony_ci+ auto oprend = new(std::nothrow) OH_NN_Tensor; 5120be168c0dSopenharmony_ci if (oprend == nullptr) { 5121be168c0dSopenharmony_ci MS_LOG(ERROR) << "new OH_NN_Tensor Failed"; 5122be168c0dSopenharmony_ci- return mindspore::kLiteError; 5123be168c0dSopenharmony_ci+ return kLiteError; 5124be168c0dSopenharmony_ci } 5125be168c0dSopenharmony_ci- oprend->dataType = ConvertDataType(tensor.DataType()); 5126be168c0dSopenharmony_ci+ oprend->dataType = CastToNNRTDataType(tensor.DataType()); 5127be168c0dSopenharmony_ci oprend->dimensionCount = tensor_shape.size(); 5128be168c0dSopenharmony_ci 5129be168c0dSopenharmony_ci std::vector<int32_t> dimensions_list; 5130be168c0dSopenharmony_ci@@ -191,14 +684,14 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 
5131be168c0dSopenharmony_ci } else { 5132be168c0dSopenharmony_ci MS_LOG(ERROR) << "NNExecutor SetInput failed,tensor dimension is is too large, max dim = " << INT32_MAX 5133be168c0dSopenharmony_ci << ", but get dimension = " << shape; 5134be168c0dSopenharmony_ci- return mindspore::kLiteError; 5135be168c0dSopenharmony_ci+ return kLiteError; 5136be168c0dSopenharmony_ci } 5137be168c0dSopenharmony_ci } 5138be168c0dSopenharmony_ci oprend->dimensions = dimensions_list.data(); 5139be168c0dSopenharmony_ci oprend->quantParam = quant_param; 5140be168c0dSopenharmony_ci oprend->type = OH_NN_TENSOR; 5141be168c0dSopenharmony_ci OH_NN_ReturnCode ret_code = 5142be168c0dSopenharmony_ci- OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5143be168c0dSopenharmony_ci+ OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5144be168c0dSopenharmony_ci delete (oprend); 5145be168c0dSopenharmony_ci 5146be168c0dSopenharmony_ci if (!tmp_quant_param.empty()) { 5147be168c0dSopenharmony_ci@@ -209,70 +702,41 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P 5148be168c0dSopenharmony_ci if (ret_code != OH_NN_SUCCESS) { 5149be168c0dSopenharmony_ci MS_LOG(ERROR) << "NNExecutor SetInput failed, current input tensor is" << tensor.Name() 5150be168c0dSopenharmony_ci << "OH_NN_ReturnCode = " << ret_code; 5151be168c0dSopenharmony_ci- return mindspore::kLiteError; 5152be168c0dSopenharmony_ci+ return kLiteError; 5153be168c0dSopenharmony_ci } 5154be168c0dSopenharmony_ci } 5155be168c0dSopenharmony_ci+ return kSuccess; 5156be168c0dSopenharmony_ci+} 5157be168c0dSopenharmony_ci+ 5158be168c0dSopenharmony_ci+OH_NN_DataType NNRTDelegate::CastToNNRTDataType(DataType data_type) { 5159be168c0dSopenharmony_ci+ const std::unordered_map<DataType, OH_NN_DataType> kDataTypeMap = { 5160be168c0dSopenharmony_ci+ {DataType::kNumberTypeBool, OH_NN_BOOL}, 5161be168c0dSopenharmony_ci+ {DataType::kNumberTypeInt8, 
OH_NN_INT8}, 5162be168c0dSopenharmony_ci+ {DataType::kNumberTypeInt16, OH_NN_INT16}, 5163be168c0dSopenharmony_ci+ {DataType::kNumberTypeInt32, OH_NN_INT32}, 5164be168c0dSopenharmony_ci+ {DataType::kNumberTypeInt64, OH_NN_INT64}, 5165be168c0dSopenharmony_ci+ {DataType::kNumberTypeUInt8, OH_NN_UINT8}, 5166be168c0dSopenharmony_ci+ {DataType::kNumberTypeUInt16, OH_NN_UINT16}, 5167be168c0dSopenharmony_ci+ {DataType::kNumberTypeUInt32, OH_NN_UINT32}, 5168be168c0dSopenharmony_ci+ {DataType::kNumberTypeUInt64, OH_NN_UINT64}, 5169be168c0dSopenharmony_ci+ {DataType::kNumberTypeFloat16, OH_NN_FLOAT16}, 5170be168c0dSopenharmony_ci+ {DataType::kNumberTypeFloat32, OH_NN_FLOAT32}, 5171be168c0dSopenharmony_ci+ {DataType::kNumberTypeFloat64, OH_NN_FLOAT64}, 5172be168c0dSopenharmony_ci+ }; 5173be168c0dSopenharmony_ci 5174be168c0dSopenharmony_ci- return mindspore::kSuccess; 5175be168c0dSopenharmony_ci+ auto iter = kDataTypeMap.find(data_type); 5176be168c0dSopenharmony_ci+ if (iter == kDataTypeMap.end()) { 5177be168c0dSopenharmony_ci+ return OH_NN_UNKNOWN; 5178be168c0dSopenharmony_ci+ } 5179be168c0dSopenharmony_ci+ return iter->second; 5180be168c0dSopenharmony_ci } 5181be168c0dSopenharmony_ci-OH_NN_DataType mindspore::NNRTDelegate::ConvertDataType(mindspore::DataType data_type) { 5182be168c0dSopenharmony_ci- OH_NN_DataType oh_data_type; 5183be168c0dSopenharmony_ci- switch (data_type) { 5184be168c0dSopenharmony_ci- case mindspore::DataType::kTypeUnknown: 5185be168c0dSopenharmony_ci- case mindspore::DataType::kObjectTypeString: 5186be168c0dSopenharmony_ci- case mindspore::DataType::kObjectTypeList: 5187be168c0dSopenharmony_ci- case mindspore::DataType::kObjectTypeTuple: 5188be168c0dSopenharmony_ci- case mindspore::DataType::kObjectTypeTensorType: 5189be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeBegin: 5190be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeEnd: 5191be168c0dSopenharmony_ci- case mindspore::DataType::kInvalidType: 5192be168c0dSopenharmony_ci- 
oh_data_type = OH_NN_UNKNOWN; 5193be168c0dSopenharmony_ci- break; 5194be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeBool: 5195be168c0dSopenharmony_ci- oh_data_type = OH_NN_BOOL; 5196be168c0dSopenharmony_ci- break; 5197be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt8: 5198be168c0dSopenharmony_ci- oh_data_type = OH_NN_INT8; 5199be168c0dSopenharmony_ci- break; 5200be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt16: 5201be168c0dSopenharmony_ci- oh_data_type = OH_NN_INT16; 5202be168c0dSopenharmony_ci- break; 5203be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt32: 5204be168c0dSopenharmony_ci- oh_data_type = OH_NN_INT32; 5205be168c0dSopenharmony_ci- break; 5206be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt64: 5207be168c0dSopenharmony_ci- oh_data_type = OH_NN_INT64; 5208be168c0dSopenharmony_ci- break; 5209be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeUInt8: 5210be168c0dSopenharmony_ci- oh_data_type = OH_NN_UINT8; 5211be168c0dSopenharmony_ci- break; 5212be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeUInt16: 5213be168c0dSopenharmony_ci- oh_data_type = OH_NN_UINT16; 5214be168c0dSopenharmony_ci- break; 5215be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeUInt32: 5216be168c0dSopenharmony_ci- oh_data_type = OH_NN_UINT32; 5217be168c0dSopenharmony_ci- break; 5218be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeUInt64: 5219be168c0dSopenharmony_ci- oh_data_type = OH_NN_UINT64; 5220be168c0dSopenharmony_ci- break; 5221be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat16: 5222be168c0dSopenharmony_ci- oh_data_type = OH_NN_FLOAT16; 5223be168c0dSopenharmony_ci- break; 5224be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat32: 5225be168c0dSopenharmony_ci- oh_data_type = OH_NN_FLOAT32; 5226be168c0dSopenharmony_ci- break; 5227be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat64: 
5228be168c0dSopenharmony_ci- oh_data_type = OH_NN_FLOAT64; 5229be168c0dSopenharmony_ci- break; 5230be168c0dSopenharmony_ci- default: { 5231be168c0dSopenharmony_ci- oh_data_type = OH_NN_UNKNOWN; 5232be168c0dSopenharmony_ci- } 5233be168c0dSopenharmony_ci- } 5234be168c0dSopenharmony_ci- return oh_data_type; 5235be168c0dSopenharmony_ci+ 5236be168c0dSopenharmony_ci+OH_NN_Format NNRTDelegate::CastToNNRTFormat(Format format) { 5237be168c0dSopenharmony_ci+ return OH_NN_FORMAT_NHWC; 5238be168c0dSopenharmony_ci } 5239be168c0dSopenharmony_ci 5240be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model, 5241be168c0dSopenharmony_ci- OH_NNExecutor *oh_nn_executor) { 5242be168c0dSopenharmony_ci+Status NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model, 5243be168c0dSopenharmony_ci+ OH_NNExecutor *oh_nn_executor) { 5244be168c0dSopenharmony_ci auto output_tensors = model->outputs(); 5245be168c0dSopenharmony_ci for (size_t i = 0; i < output_tensors.size(); i++) { 5246be168c0dSopenharmony_ci auto tensor = output_tensors[i]; 5247be168c0dSopenharmony_ci@@ -280,17 +744,17 @@ mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema:: 5248be168c0dSopenharmony_ci if (ret_code != OH_NN_SUCCESS) { 5249be168c0dSopenharmony_ci MS_LOG(ERROR) << "NNExecutor SetOutput failed, current out tensor is" << tensor.Name() 5250be168c0dSopenharmony_ci << ", OH_NN_ReturnCode = " << ret_code; 5251be168c0dSopenharmony_ci- return mindspore::kLiteError; 5252be168c0dSopenharmony_ci+ return kLiteError; 5253be168c0dSopenharmony_ci } 5254be168c0dSopenharmony_ci } 5255be168c0dSopenharmony_ci- return mindspore::kSuccess; 5256be168c0dSopenharmony_ci+ return kSuccess; 5257be168c0dSopenharmony_ci } 5258be168c0dSopenharmony_ci 5259be168c0dSopenharmony_ci-void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGraph &lite_graph) { 5260be168c0dSopenharmony_ci+void 
NNRTDelegate::ShallowCopyLiteGraph(const lite::LiteGraph &lite_graph) { 5261be168c0dSopenharmony_ci Status ret; 5262be168c0dSopenharmony_ci for (auto node : lite_graph.all_nodes_) { 5263be168c0dSopenharmony_ci ret = lite::CheckPrimitiveSupported(static_cast<const schema::Primitive *>(node->primitive_)); 5264be168c0dSopenharmony_ci- if (ret == mindspore::kLiteError) { 5265be168c0dSopenharmony_ci+ if (ret == kLiteError) { 5266be168c0dSopenharmony_ci MS_LOG(ERROR) << " primitive supported check failed."; 5267be168c0dSopenharmony_ci return; 5268be168c0dSopenharmony_ci } 5269be168c0dSopenharmony_ci@@ -299,7 +763,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5270be168c0dSopenharmony_ci node_list.reserve(lite_graph.all_nodes_.size()); 5271be168c0dSopenharmony_ci // copy node 5272be168c0dSopenharmony_ci for (auto node : lite_graph.all_nodes_) { 5273be168c0dSopenharmony_ci- auto new_node = new (std::nothrow) LiteGraph::Node; 5274be168c0dSopenharmony_ci+ auto new_node = new(std::nothrow) LiteGraph::Node; 5275be168c0dSopenharmony_ci if (new_node == nullptr) { 5276be168c0dSopenharmony_ci MS_LOG(ERROR) << " new LiteGraph::Node failed."; 5277be168c0dSopenharmony_ci return; 5278be168c0dSopenharmony_ci@@ -318,7 +782,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5279be168c0dSopenharmony_ci // copy subgraph 5280be168c0dSopenharmony_ci std::vector<LiteGraph::SubGraph *> subgraph_list; 5281be168c0dSopenharmony_ci for (auto subgraph : lite_graph.sub_graphs_) { 5282be168c0dSopenharmony_ci- auto new_subgraph = new (std::nothrow) LiteGraph::SubGraph; 5283be168c0dSopenharmony_ci+ auto new_subgraph = new(std::nothrow) LiteGraph::SubGraph; 5284be168c0dSopenharmony_ci if (new_subgraph == nullptr) { 5285be168c0dSopenharmony_ci MS_LOG(ERROR) << "new LiteGraph::Subgraph failed."; 5286be168c0dSopenharmony_ci return; 5287be168c0dSopenharmony_ci@@ -331,30 +795,32 @@ void 
mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr 5288be168c0dSopenharmony_ci } 5289be168c0dSopenharmony_ci for (auto tensor : lite_graph.all_tensors_) { 5290be168c0dSopenharmony_ci ret = lite::CheckTensorSupported(static_cast<const schema::Tensor *>(tensor)); 5291be168c0dSopenharmony_ci- if (ret == mindspore::kLiteError) { 5292be168c0dSopenharmony_ci+ if (ret == kLiteError) { 5293be168c0dSopenharmony_ci MS_LOG(ERROR) << "tensor supported check failed."; 5294be168c0dSopenharmony_ci return; 5295be168c0dSopenharmony_ci } 5296be168c0dSopenharmony_ci } 5297be168c0dSopenharmony_ci 5298be168c0dSopenharmony_ci- nnrt_lite_graph = new (std::nothrow) lite::LiteGraph(); 5299be168c0dSopenharmony_ci- if (nnrt_lite_graph == nullptr) { 5300be168c0dSopenharmony_ci+ lite_graph_ = new(std::nothrow) lite::LiteGraph(); 5301be168c0dSopenharmony_ci+ if (lite_graph_ == nullptr) { 5302be168c0dSopenharmony_ci MS_LOG(ERROR) << "new LiteGraph failed."; 5303be168c0dSopenharmony_ci return; 5304be168c0dSopenharmony_ci } 5305be168c0dSopenharmony_ci 5306be168c0dSopenharmony_ci- nnrt_lite_graph->name_ = lite_graph.name_; 5307be168c0dSopenharmony_ci- nnrt_lite_graph->version_ = lite_graph.version_; 5308be168c0dSopenharmony_ci- nnrt_lite_graph->input_indices_ = lite_graph.input_indices_; 5309be168c0dSopenharmony_ci- nnrt_lite_graph->output_indices_ = lite_graph.output_indices_; 5310be168c0dSopenharmony_ci- nnrt_lite_graph->all_tensors_ = lite_graph.all_tensors_; 5311be168c0dSopenharmony_ci- nnrt_lite_graph->all_nodes_ = node_list; 5312be168c0dSopenharmony_ci- nnrt_lite_graph->sub_graphs_ = subgraph_list; 5313be168c0dSopenharmony_ci+ lite_graph_->name_ = lite_graph.name_; 5314be168c0dSopenharmony_ci+ lite_graph_->version_ = lite_graph.version_; 5315be168c0dSopenharmony_ci+ lite_graph_->input_indices_ = lite_graph.input_indices_; 5316be168c0dSopenharmony_ci+ lite_graph_->output_indices_ = lite_graph.output_indices_; 5317be168c0dSopenharmony_ci+ lite_graph_->all_tensors_ = 
lite_graph.all_tensors_; 5318be168c0dSopenharmony_ci+ lite_graph_->all_nodes_ = node_list; 5319be168c0dSopenharmony_ci+ lite_graph_->sub_graphs_ = subgraph_list; 5320be168c0dSopenharmony_ci MS_LOG(INFO) << "ShallowCopyLiteGraph success."; 5321be168c0dSopenharmony_ci } 5322be168c0dSopenharmony_ci 5323be168c0dSopenharmony_ci-mindspore::NNRTDelegate::~NNRTDelegate() { 5324be168c0dSopenharmony_ci- if (this->nnrt_lite_graph != nullptr) { 5325be168c0dSopenharmony_ci+NNRTDelegate::~NNRTDelegate() { 5326be168c0dSopenharmony_ci+ if (lite_graph_ != nullptr) { 5327be168c0dSopenharmony_ci MS_LOG(ERROR) << "Delete NNRTDelegate."; 5328be168c0dSopenharmony_ci } 5329be168c0dSopenharmony_ci-}; 5330be168c0dSopenharmony_ci+} 5331be168c0dSopenharmony_ci+} // namespace lite 5332be168c0dSopenharmony_ci+} // namespace mindspore 5333be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5334be168c0dSopenharmony_ciindex c2847704..52626339 100644 5335be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5336be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h 5337be168c0dSopenharmony_ci@@ -15,37 +15,81 @@ 5338be168c0dSopenharmony_ci */ 5339be168c0dSopenharmony_ci #ifndef MINDSPORE_NNR_DELEGATE_H 5340be168c0dSopenharmony_ci #define MINDSPORE_NNR_DELEGATE_H 5341be168c0dSopenharmony_ci+ 5342be168c0dSopenharmony_ci #include <vector> 5343be168c0dSopenharmony_ci #include <map> 5344be168c0dSopenharmony_ci #include "include/api/delegate.h" 5345be168c0dSopenharmony_ci #include "include/model.h" 5346be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime_type.h" 5347be168c0dSopenharmony_ci-namespace mindspore { 5348be168c0dSopenharmony_ci+#include "src/litert/inner_context.h" 5349be168c0dSopenharmony_ci+#include "nnrt_model_kernel.h" 5350be168c0dSopenharmony_ci+#include "schema/model_generated.h" 
5351be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h" 5352be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5353be168c0dSopenharmony_ci+#include "interfaces/innerkits/c/neural_network_runtime_inner.h" 5354be168c0dSopenharmony_ci 5355be168c0dSopenharmony_ci-using namespace lite; 5356be168c0dSopenharmony_ci+namespace mindspore { 5357be168c0dSopenharmony_ci+namespace lite { 5358be168c0dSopenharmony_ci+struct NNRTOpRange { 5359be168c0dSopenharmony_ci+ /* NNRT kernel range in DelegateModel: [begin_iter_, end_iter_) */ 5360be168c0dSopenharmony_ci+ KernelIter begin_iter_; 5361be168c0dSopenharmony_ci+ KernelIter end_iter_; 5362be168c0dSopenharmony_ci+ /* NNRT node range in lite_graph_: [begin_index_, end_index_) */ 5363be168c0dSopenharmony_ci+ size_t begin_index_; 5364be168c0dSopenharmony_ci+ size_t end_index_; 5365be168c0dSopenharmony_ci+}; 5366be168c0dSopenharmony_ci 5367be168c0dSopenharmony_ci class NNRTDelegate : public Delegate { 5368be168c0dSopenharmony_ci public: 5369be168c0dSopenharmony_ci- NNRTDelegate() : Delegate(){}; 5370be168c0dSopenharmony_ci- 5371be168c0dSopenharmony_ci+ NNRTDelegate() = default; 5372be168c0dSopenharmony_ci+ NNRTDelegate(const NNRtDeviceInfo &nnrt_device_info) : nnrt_device_info_(nnrt_device_info) {} 5373be168c0dSopenharmony_ci ~NNRTDelegate() override; 5374be168c0dSopenharmony_ci- 5375be168c0dSopenharmony_ci- Status Init() override; 5376be168c0dSopenharmony_ci- 5377be168c0dSopenharmony_ci+ Status Init() override { return kSuccess; } 5378be168c0dSopenharmony_ci Status Build(DelegateModel<schema::Primitive> *model) override; 5379be168c0dSopenharmony_ci- 5380be168c0dSopenharmony_ci void ShallowCopyLiteGraph(const lite::LiteGraph &liteGraph); 5381be168c0dSopenharmony_ci- 5382be168c0dSopenharmony_ci- protected: 5383be168c0dSopenharmony_ci- LiteGraph *nnrt_lite_graph = nullptr; 5384be168c0dSopenharmony_ci+ void SetMetaGraph(const void 
*meta_graph) { 5385be168c0dSopenharmony_ci+ meta_graph_ = meta_graph; 5386be168c0dSopenharmony_ci+ } 5387be168c0dSopenharmony_ci+ static std::vector<NNRTOpRange> GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model, 5388be168c0dSopenharmony_ci+ const std::vector<bool> &op_supports); 5389be168c0dSopenharmony_ci 5390be168c0dSopenharmony_ci private: 5391be168c0dSopenharmony_ci- // static LiteGraph* CreateLiteGraph(const LiteGraph &liteGraph); 5392be168c0dSopenharmony_ci+ void InitCachePath(); 5393be168c0dSopenharmony_ci+ Status BuildNormalModel(DelegateModel<schema::Primitive> *model); 5394be168c0dSopenharmony_ci+ OH_NNModel *CreateFullNNModel(); 5395be168c0dSopenharmony_ci+ std::vector<bool> QueryOpSupports(OH_NNModel *nn_model); 5396be168c0dSopenharmony_ci+ Status CreateLiteGraphForNNRTSubgraph( 5397be168c0dSopenharmony_ci+ const std::vector<NNRTOpRange> &nnrt_op_ranges, 5398be168c0dSopenharmony_ci+ std::vector<LiteGraph *> *sub_lite_graphs); 5399be168c0dSopenharmony_ci+ Status CreateNNRTSubgraphKernels( 5400be168c0dSopenharmony_ci+ DelegateModel<schema::Primitive> *model, 5401be168c0dSopenharmony_ci+ const std::vector<LiteGraph *> &sub_lite_graphs, 5402be168c0dSopenharmony_ci+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5403be168c0dSopenharmony_ci+ std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels); 5404be168c0dSopenharmony_ci+ void ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model, 5405be168c0dSopenharmony_ci+ const std::vector<NNRTOpRange> &nnrt_subgraph_ranges, 5406be168c0dSopenharmony_ci+ const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels); 5407be168c0dSopenharmony_ci Status PrepareInputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor); 5408be168c0dSopenharmony_ci Status PrepareOutputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor); 5409be168c0dSopenharmony_ci- OH_NN_DataType ConvertDataType(mindspore::DataType data_type); 5410be168c0dSopenharmony_ci-}; 
5411be168c0dSopenharmony_ci+ Status InitNNCompilation(OH_NNCompilation *nn_compilation) const; 5412be168c0dSopenharmony_ci+ static OH_NN_DataType CastToNNRTDataType(mindspore::DataType data_type); 5413be168c0dSopenharmony_ci+ static OH_NN_Format CastToNNRTFormat(Format format); 5414be168c0dSopenharmony_ci+ bool IsCustomModel() const; 5415be168c0dSopenharmony_ci+ 5416be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH 5417be168c0dSopenharmony_ci+ bool IsKirinNPU() const; 5418be168c0dSopenharmony_ci+ Status BuildKirinNPUModel(DelegateModel<schema::Primitive> *model); 5419be168c0dSopenharmony_ci+ Status SetKirinModelInputsAndOutputs(OH_NNModel *nn_model); 5420be168c0dSopenharmony_ci+ std::vector<OH_NN_TensorInfo> CreateNNTensorInfos(const std::vector<uint32_t> &indices) const; 5421be168c0dSopenharmony_ci+ Status CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model); 5422be168c0dSopenharmony_ci+#endif 5423be168c0dSopenharmony_ci 5424be168c0dSopenharmony_ci+ NNRtDeviceInfo nnrt_device_info_; 5425be168c0dSopenharmony_ci+ LiteGraph *lite_graph_ = nullptr; 5426be168c0dSopenharmony_ci+ const void *meta_graph_ = nullptr; 5427be168c0dSopenharmony_ci+ std::string cache_path_ = ""; 5428be168c0dSopenharmony_ci+ uint32_t cache_version_ = 0; 5429be168c0dSopenharmony_ci+}; 5430be168c0dSopenharmony_ci+} // namespace lite 5431be168c0dSopenharmony_ci } // namespace mindspore 5432be168c0dSopenharmony_ci 5433be168c0dSopenharmony_ci #endif // MINDSPORE_NNR_DELEGATE_H 5434be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5435be168c0dSopenharmony_ciindex 5acf2e9a..67443e08 100644 5436be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5437be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc 5438be168c0dSopenharmony_ci@@ -97,7 +97,7 @@ OH_NN_DataType 
mindspore::NNRTModelKernel::ConvertDataType(mindspore::DataType d 5439be168c0dSopenharmony_ci } 5440be168c0dSopenharmony_ci int mindspore::NNRTModelKernel::PrepareInputs() { 5441be168c0dSopenharmony_ci auto input_tensors = this->inputs(); 5442be168c0dSopenharmony_ci- for (int i = 0; i < input_tensors.size(); i++) { 5443be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_tensors.size(); i++) { 5444be168c0dSopenharmony_ci auto tensor = input_tensors[i]; 5445be168c0dSopenharmony_ci auto tensor_shape = tensor.Shape(); 5446be168c0dSopenharmony_ci auto tmp_quant_param = tensor.QuantParams(); 5447be168c0dSopenharmony_ci@@ -142,6 +142,7 @@ int mindspore::NNRTModelKernel::PrepareInputs() { 5448be168c0dSopenharmony_ci oprend->dimensions = dimensions_list.data(); 5449be168c0dSopenharmony_ci oprend->quantParam = quant_param; 5450be168c0dSopenharmony_ci oprend->type = OH_NN_TENSOR; 5451be168c0dSopenharmony_ci+ MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData() << ", size: " << tensor.DataSize(); 5452be168c0dSopenharmony_ci OH_NN_ReturnCode ret_code = 5453be168c0dSopenharmony_ci OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize()); 5454be168c0dSopenharmony_ci delete (oprend); 5455be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5456be168c0dSopenharmony_ciindex cf9481df..ea15f7ca 100644 5457be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5458be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h 5459be168c0dSopenharmony_ci@@ -20,7 +20,7 @@ 5460be168c0dSopenharmony_ci #include <map> 5461be168c0dSopenharmony_ci #include <utility> 5462be168c0dSopenharmony_ci #include "include/api/kernel.h" 5463be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime.h" 5464be168c0dSopenharmony_ci+#include 
"interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5465be168c0dSopenharmony_ci #include "src/common/log_adapter.h" 5466be168c0dSopenharmony_ci #include "include/errorcode.h" 5467be168c0dSopenharmony_ci 5468be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 5469be168c0dSopenharmony_cinew file mode 100644 5470be168c0dSopenharmony_ciindex 00000000..8ac283af 5471be168c0dSopenharmony_ci--- /dev/null 5472be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc 5473be168c0dSopenharmony_ci@@ -0,0 +1,99 @@ 5474be168c0dSopenharmony_ci+/** 5475be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd 5476be168c0dSopenharmony_ci+* 5477be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License"); 5478be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License. 5479be168c0dSopenharmony_ci+* You may obtain a copy of the License at 5480be168c0dSopenharmony_ci+* 5481be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0 5482be168c0dSopenharmony_ci+* 5483be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software 5484be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS, 5485be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5486be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and 5487be168c0dSopenharmony_ci+* limitations under the License. 
5488be168c0dSopenharmony_ci+*/ 5489be168c0dSopenharmony_ci+ 5490be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h" 5491be168c0dSopenharmony_ci+#include "interfaces/innerkits/c/neural_network_runtime_inner.h" 5492be168c0dSopenharmony_ci+ 5493be168c0dSopenharmony_ci+OH_NNModel *OH_NNModel_Construct(void) { 5494be168c0dSopenharmony_ci+ return NULL; 5495be168c0dSopenharmony_ci+} 5496be168c0dSopenharmony_ci+ 5497be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_Run(OH_NNExecutor *executor) { 5498be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5499be168c0dSopenharmony_ci+} 5500be168c0dSopenharmony_ci+ 5501be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation) { 5502be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5503be168c0dSopenharmony_ci+} 5504be168c0dSopenharmony_ci+ 5505be168c0dSopenharmony_ci+void OH_NNCompilation_Destroy(OH_NNCompilation **compilation) {} 5506be168c0dSopenharmony_ci+ 5507be168c0dSopenharmony_ci+OH_NNExecutor *OH_NNExecutor_Construct(OH_NNCompilation *compilation) { 5508be168c0dSopenharmony_ci+ return NULL; 5509be168c0dSopenharmony_ci+} 5510be168c0dSopenharmony_ci+ 5511be168c0dSopenharmony_ci+void OH_NNExecutor_Destroy(OH_NNExecutor **executor) {} 5512be168c0dSopenharmony_ci+ 5513be168c0dSopenharmony_ci+OH_NNCompilation *OH_NNCompilation_Construct(const OH_NNModel *model) { 5514be168c0dSopenharmony_ci+ return NULL; 5515be168c0dSopenharmony_ci+} 5516be168c0dSopenharmony_ci+ 5517be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetAllDevicesID(const size_t **allDevicesID, uint32_t *deviceCount) { 5518be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5519be168c0dSopenharmony_ci+} 5520be168c0dSopenharmony_ci+ 5521be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_SetOutput(OH_NNExecutor *executor, 5522be168c0dSopenharmony_ci+ uint32_t outputIndex, 5523be168c0dSopenharmony_ci+ void *dataBuffer, 5524be168c0dSopenharmony_ci+ size_t length) { 
5525be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5526be168c0dSopenharmony_ci+} 5527be168c0dSopenharmony_ci+ 5528be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetDevice(OH_NNCompilation *compilation, size_t deviceID) { 5529be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5530be168c0dSopenharmony_ci+} 5531be168c0dSopenharmony_ci+ 5532be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_SetInput(OH_NNExecutor *executor, 5533be168c0dSopenharmony_ci+ uint32_t inputIndex, 5534be168c0dSopenharmony_ci+ const OH_NN_Tensor *tensor, 5535be168c0dSopenharmony_ci+ const void *dataBuffer, 5536be168c0dSopenharmony_ci+ size_t length) { 5537be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5538be168c0dSopenharmony_ci+} 5539be168c0dSopenharmony_ci+ 5540be168c0dSopenharmony_ci+void OH_NNModel_Destroy(OH_NNModel **model) {} 5541be168c0dSopenharmony_ci+ 5542be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNModel_GetAvailableOperations(OH_NNModel *model, 5543be168c0dSopenharmony_ci+ size_t deviceID, 5544be168c0dSopenharmony_ci+ const bool **isSupported, 5545be168c0dSopenharmony_ci+ uint32_t *opCount) { 5546be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5547be168c0dSopenharmony_ci+} 5548be168c0dSopenharmony_ci+ 5549be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const void *liteGraph) { 5550be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5551be168c0dSopenharmony_ci+} 5552be168c0dSopenharmony_ci+ 5553be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name) { 5554be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5555be168c0dSopenharmony_ci+} 5556be168c0dSopenharmony_ci+ 5557be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType *deviceType) { 5558be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5559be168c0dSopenharmony_ci+} 5560be168c0dSopenharmony_ci+ 5561be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetPriority(OH_NNCompilation *compilation, OH_NN_Priority 
priority) { 5562be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5563be168c0dSopenharmony_ci+} 5564be168c0dSopenharmony_ci+ 5565be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_EnableFloat16(OH_NNCompilation *compilation, bool enableFloat16) { 5566be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5567be168c0dSopenharmony_ci+} 5568be168c0dSopenharmony_ci+ 5569be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetPerformanceMode(OH_NNCompilation *compilation, 5570be168c0dSopenharmony_ci+ OH_NN_PerformanceMode performanceMode) { 5571be168c0dSopenharmony_ci+ return OH_NN_SUCCESS; 5572be168c0dSopenharmony_ci+} 5573be168c0dSopenharmony_ci\ No newline at end of file 5574be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/infer_manager.cc b/mindspore/lite/src/litert/infer_manager.cc 5575be168c0dSopenharmony_ciindex 2b21d1ca..908ab122 100644 5576be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/infer_manager.cc 5577be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/infer_manager.cc 5578be168c0dSopenharmony_ci@@ -162,7 +162,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto 5579be168c0dSopenharmony_ci if (parameter->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion) || 5580be168c0dSopenharmony_ci parameter->type_ == static_cast<int>(schema::PrimitiveType_Switch) || 5581be168c0dSopenharmony_ci parameter->type_ == static_cast<int>(schema::PrimitiveType_Call) || 5582be168c0dSopenharmony_ci- parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer)) { 5583be168c0dSopenharmony_ci+ parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer) || 5584be168c0dSopenharmony_ci+ parameter->type_ == static_cast<int>(PrimType_Inner_ThirdPartyModel)) { 5585be168c0dSopenharmony_ci MS_LOG(INFO) << "no need infer shape."; 5586be168c0dSopenharmony_ci return RET_OK; 5587be168c0dSopenharmony_ci } 5588be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/inner_context.cc 
b/mindspore/lite/src/litert/inner_context.cc 5589be168c0dSopenharmony_ciindex 7cbac8f7..bf585ff0 100644 5590be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/inner_context.cc 5591be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/inner_context.cc 5592be168c0dSopenharmony_ci@@ -122,6 +122,10 @@ int InnerContext::Init() { 5593be168c0dSopenharmony_ci #endif 5594be168c0dSopenharmony_ci } 5595be168c0dSopenharmony_ci 5596be168c0dSopenharmony_ci+ if (IsDeviceTypeEnabled(DT_NNRT)) { 5597be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "NNRT enabled."; 5598be168c0dSopenharmony_ci+ } 5599be168c0dSopenharmony_ci+ 5600be168c0dSopenharmony_ci if (CreateThreadPool(false)) { 5601be168c0dSopenharmony_ci MS_LOG(ERROR) << "CreateThreadPool failed."; 5602be168c0dSopenharmony_ci return RET_ERROR; 5603be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/inner_context.h b/mindspore/lite/src/litert/inner_context.h 5604be168c0dSopenharmony_ciindex 88281eb1..8735961c 100644 5605be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/inner_context.h 5606be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/inner_context.h 5607be168c0dSopenharmony_ci@@ -71,12 +71,26 @@ typedef struct CustomDeviceInfo { 5608be168c0dSopenharmony_ci std::shared_ptr<DeviceInfoContext> user_defined_device_info_; 5609be168c0dSopenharmony_ci } CustomDeviceInfo; 5610be168c0dSopenharmony_ci 5611be168c0dSopenharmony_ci+typedef struct Extension { 5612be168c0dSopenharmony_ci+ std::string name; // config name 5613be168c0dSopenharmony_ci+ std::vector<uint8_t> value; // config value 5614be168c0dSopenharmony_ci+} Extension; 5615be168c0dSopenharmony_ci+ 5616be168c0dSopenharmony_ci+typedef struct NNRtDeviceInfo { 5617be168c0dSopenharmony_ci+ size_t device_id_ = 0; 5618be168c0dSopenharmony_ci+ int priority_ = 0; 5619be168c0dSopenharmony_ci+ int performance_mode_ = 0; 5620be168c0dSopenharmony_ci+ bool enable_fp16_ = false; 5621be168c0dSopenharmony_ci+ std::vector<Extension> extensions_; 5622be168c0dSopenharmony_ci+} 
NNRtDeviceInfo; 5623be168c0dSopenharmony_ci+ 5624be168c0dSopenharmony_ci struct DeviceInfo { 5625be168c0dSopenharmony_ci CpuDeviceInfo cpu_device_info_; 5626be168c0dSopenharmony_ci GpuDeviceInfo gpu_device_info_; 5627be168c0dSopenharmony_ci NpuDeviceInfo npu_device_info_; 5628be168c0dSopenharmony_ci AscendDeviceInfo ascend_device_info_; 5629be168c0dSopenharmony_ci CustomDeviceInfo custom_device_info_; 5630be168c0dSopenharmony_ci+ NNRtDeviceInfo nnrt_device_info_; 5631be168c0dSopenharmony_ci }; 5632be168c0dSopenharmony_ci 5633be168c0dSopenharmony_ci struct DeviceContext { 5634be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5635be168c0dSopenharmony_ciindex 48308425..65065b5b 100644 5636be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5637be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn 5638be168c0dSopenharmony_ci@@ -13,6 +13,10 @@ cpu_kernel_sources = [ 5639be168c0dSopenharmony_ci "base/call.cc", 5640be168c0dSopenharmony_ci "base/constant_of_shape.cc", 5641be168c0dSopenharmony_ci "base/convolution_base.cc", 5642be168c0dSopenharmony_ci+ "base/custom_base.cc", 5643be168c0dSopenharmony_ci+ "base/custom_masked_fill.cc", 5644be168c0dSopenharmony_ci+ "base/custom_is_inf.cc", 5645be168c0dSopenharmony_ci+ "base/custom_tensor_scatter.cc", 5646be168c0dSopenharmony_ci "base/detection_post_process_base.cc", 5647be168c0dSopenharmony_ci "base/format_transpose.cc", 5648be168c0dSopenharmony_ci "base/group_convolution_base.cc", 5649be168c0dSopenharmony_ci@@ -37,7 +41,6 @@ cpu_kernel_sources = [ 5650be168c0dSopenharmony_ci "fp32/batchnorm_fp32.cc", 5651be168c0dSopenharmony_ci "fp32/batch_to_space_fp32.cc", 5652be168c0dSopenharmony_ci "fp32/broadcast_to_fp32.cc", 5653be168c0dSopenharmony_ci- "fp32/cast_for_x86_fp16.cc", 5654be168c0dSopenharmony_ci "fp32/cast_fp32.cc", 5655be168c0dSopenharmony_ci "fp32/convolution_1x1_fp32.cc", 5656be168c0dSopenharmony_ci 
"fp32/convolution_delegate_fp32.cc", 5657be168c0dSopenharmony_ci@@ -118,6 +121,10 @@ cpu_kernel_sources = [ 5658be168c0dSopenharmony_ci "fp32/online_fusion/split_reduce_concat_fp32.cc", 5659be168c0dSopenharmony_ci ] 5660be168c0dSopenharmony_ci 5661be168c0dSopenharmony_ci+if ((target_cpu != "arm") && (target_cpu != "arm64")) { 5662be168c0dSopenharmony_ci+ cpu_kernel_sources += [ "src/runtime/kernel/cpu/fp32/cast_for_x86_fp16.cc" ] 5663be168c0dSopenharmony_ci+} 5664be168c0dSopenharmony_ci+ 5665be168c0dSopenharmony_ci arm64_cpu_kernel_sources = [ 5666be168c0dSopenharmony_ci "fp32/convolution_im2col_arm64_fp32.cc", 5667be168c0dSopenharmony_ci "fp32/matmul_fp32_arm64.cc", 5668be168c0dSopenharmony_ci@@ -142,6 +149,42 @@ sse_avx_avx512_kernel_sources = [ 5669be168c0dSopenharmony_ci "fp32/matmul_fp32_avx512.cc", 5670be168c0dSopenharmony_ci ] 5671be168c0dSopenharmony_ci 5672be168c0dSopenharmony_ci+fp16_kernel_sources = [ 5673be168c0dSopenharmony_ci+ "fp16/batchnorm_fp16.cc", 5674be168c0dSopenharmony_ci+ "fp16/biasadd_fp16.cc", 5675be168c0dSopenharmony_ci+ "fp16/cast_fp16.cc", 5676be168c0dSopenharmony_ci+ "fp16/common_fp16.cc", 5677be168c0dSopenharmony_ci+ "fp16/convolution_1x1_fp16.cc", 5678be168c0dSopenharmony_ci+ "fp16/convolution_delegate_fp16.cc", 5679be168c0dSopenharmony_ci+ "fp16/convolution_depthwise_3x3_fp16.cc", 5680be168c0dSopenharmony_ci+ "fp16/convolution_depthwise_fp16.cc", 5681be168c0dSopenharmony_ci+ "fp16/convolution_depthwise_slidewindow_fp16.cc", 5682be168c0dSopenharmony_ci+ "fp16/convolution_fp16.cc", 5683be168c0dSopenharmony_ci+ "fp16/convolution_winograd_fp16.cc", 5684be168c0dSopenharmony_ci+ "fp16/custom_gru_fp16.cc", 5685be168c0dSopenharmony_ci+ "fp16/deconvolution_depthwise_fp16.cc", 5686be168c0dSopenharmony_ci+ "fp16/deconvolution_fp16.cc", 5687be168c0dSopenharmony_ci+ "fp16/deconvolution_winograd_fp16.cc", 5688be168c0dSopenharmony_ci+ "fp16/depth_to_space_fp16.cc", 5689be168c0dSopenharmony_ci+ "fp16/dynamic_quant_fp16.cc", 
5690be168c0dSopenharmony_ci+ "fp16/fullconnection_fp16.cc", 5691be168c0dSopenharmony_ci+ "fp16/fused_batchnorm_fp16.cc", 5692be168c0dSopenharmony_ci+ "fp16/group_convolution_fp16.cc", 5693be168c0dSopenharmony_ci+ "fp16/gru_fp16.cc", 5694be168c0dSopenharmony_ci+ "fp16/instance_norm_fp16.cc", 5695be168c0dSopenharmony_ci+ "fp16/layout_transform_fp16.cc", 5696be168c0dSopenharmony_ci+ "fp16/lstm_fp16.cc", 5697be168c0dSopenharmony_ci+ "fp16/matmul_base_fp16.cc", 5698be168c0dSopenharmony_ci+ "fp16/matmul_fp16.cc", 5699be168c0dSopenharmony_ci+ "fp16/power_fp16.cc", 5700be168c0dSopenharmony_ci+ "fp16/prelu_fp16.cc", 5701be168c0dSopenharmony_ci+ "fp16/quant_dtype_cast_fp16.cc", 5702be168c0dSopenharmony_ci+ "fp16/reduce_fp16.cc", 5703be168c0dSopenharmony_ci+ "fp16/resize_fp16.cc", 5704be168c0dSopenharmony_ci+ "fp16/slice_fp16.cc", 5705be168c0dSopenharmony_ci+ "fp16/where_fp16.cc", 5706be168c0dSopenharmony_ci+] 5707be168c0dSopenharmony_ci+ 5708be168c0dSopenharmony_ci int8_kernel_sources = [ 5709be168c0dSopenharmony_ci "int8/activation_int8.cc", 5710be168c0dSopenharmony_ci "int8/add_int8.cc", 5711be168c0dSopenharmony_ci@@ -227,6 +270,12 @@ all_cpu_kernel_sources += int8_kernel_sources 5712be168c0dSopenharmony_ci all_cpu_kernel_sources += string_kernel_sources 5713be168c0dSopenharmony_ci all_cpu_kernel_sources += control_kernel_sources 5714be168c0dSopenharmony_ci 5715be168c0dSopenharmony_ci+if (target_cpu == "arm64") { 5716be168c0dSopenharmony_ci+ all_cpu_kernel_sources += fp16_kernel_sources 5717be168c0dSopenharmony_ci+} else { 5718be168c0dSopenharmony_ci+ not_needed(fp16_kernel_sources) 5719be168c0dSopenharmony_ci+} 5720be168c0dSopenharmony_ci+ 5721be168c0dSopenharmony_ci if (target_cpu == "arm") { 5722be168c0dSopenharmony_ci all_cpu_kernel_sources -= arm64_cpu_kernel_sources 5723be168c0dSopenharmony_ci all_cpu_kernel_sources -= sse_avx_avx512_kernel_sources 5724be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 
b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 5725be168c0dSopenharmony_cinew file mode 100644 5726be168c0dSopenharmony_ciindex 00000000..9921e063 5727be168c0dSopenharmony_ci--- /dev/null 5728be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc 5729be168c0dSopenharmony_ci@@ -0,0 +1,46 @@ 5730be168c0dSopenharmony_ci+/** 5731be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd 5732be168c0dSopenharmony_ci+ * 5733be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5734be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5735be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5736be168c0dSopenharmony_ci+ * 5737be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5738be168c0dSopenharmony_ci+ * 5739be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5740be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5741be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5742be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5743be168c0dSopenharmony_ci+ * limitations under the License. 
5744be168c0dSopenharmony_ci+ */ 5745be168c0dSopenharmony_ci+ 5746be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_base.h" 5747be168c0dSopenharmony_ci+#include <algorithm> 5748be168c0dSopenharmony_ci+#include <utility> 5749be168c0dSopenharmony_ci+#include <vector> 5750be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h" 5751be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 5752be168c0dSopenharmony_ci+ 5753be168c0dSopenharmony_ci+using mindspore::kernel::KERNEL_ARCH; 5754be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar; 5755be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 5756be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 5757be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Custom; 5758be168c0dSopenharmony_ci+ 5759be168c0dSopenharmony_ci+namespace mindspore::kernel { 5760be168c0dSopenharmony_ci+int CustomBaseCPUKernel::Prepare() { 5761be168c0dSopenharmony_ci+ return RET_OK; 5762be168c0dSopenharmony_ci+} 5763be168c0dSopenharmony_ci+ 5764be168c0dSopenharmony_ci+int CustomBaseCPUKernel::ReSize() { 5765be168c0dSopenharmony_ci+ return RET_OK; 5766be168c0dSopenharmony_ci+} 5767be168c0dSopenharmony_ci+ 5768be168c0dSopenharmony_ci+int CustomBaseCPUKernel::Run() { 5769be168c0dSopenharmony_ci+ return RET_OK; 5770be168c0dSopenharmony_ci+} 5771be168c0dSopenharmony_ci+ 5772be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5773be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5774be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>) 5775be168c0dSopenharmony_ci+} // namespace mindspore::kernel 5776be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 
5777be168c0dSopenharmony_cinew file mode 100644 5778be168c0dSopenharmony_ciindex 00000000..ecb4c72d 5779be168c0dSopenharmony_ci--- /dev/null 5780be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h 5781be168c0dSopenharmony_ci@@ -0,0 +1,43 @@ 5782be168c0dSopenharmony_ci+/** 5783be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd 5784be168c0dSopenharmony_ci+ * 5785be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5786be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5787be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5788be168c0dSopenharmony_ci+ * 5789be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5790be168c0dSopenharmony_ci+ * 5791be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5792be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5793be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5794be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5795be168c0dSopenharmony_ci+ * limitations under the License. 
5796be168c0dSopenharmony_ci+ */ 5797be168c0dSopenharmony_ci+ 5798be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5799be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5800be168c0dSopenharmony_ci+ 5801be168c0dSopenharmony_ci+#include <vector> 5802be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 5803be168c0dSopenharmony_ci+#include "nnacl/custom_parameter.h" 5804be168c0dSopenharmony_ci+ 5805be168c0dSopenharmony_ci+namespace mindspore::kernel { 5806be168c0dSopenharmony_ci+class CustomBaseCPUKernel : public LiteKernel { 5807be168c0dSopenharmony_ci+ public: 5808be168c0dSopenharmony_ci+ CustomBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 5809be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 5810be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) { 5811be168c0dSopenharmony_ci+ custom_param_ = reinterpret_cast<CustomParameter *>(op_parameter_); 5812be168c0dSopenharmony_ci+ } 5813be168c0dSopenharmony_ci+ ~CustomBaseCPUKernel() override = default; 5814be168c0dSopenharmony_ci+ 5815be168c0dSopenharmony_ci+ int Prepare() override; 5816be168c0dSopenharmony_ci+ int ReSize() override; 5817be168c0dSopenharmony_ci+ int Run() override; 5818be168c0dSopenharmony_ci+ 5819be168c0dSopenharmony_ci+ private: 5820be168c0dSopenharmony_ci+ CustomParameter *custom_param_ = nullptr; 5821be168c0dSopenharmony_ci+}; 5822be168c0dSopenharmony_ci+} // namespace mindspore::kernel 5823be168c0dSopenharmony_ci+ 5824be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_ 5825be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 5826be168c0dSopenharmony_cinew file mode 100644 5827be168c0dSopenharmony_ciindex 00000000..edffea42 5828be168c0dSopenharmony_ci--- /dev/null 
5829be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc 5830be168c0dSopenharmony_ci@@ -0,0 +1,61 @@ 5831be168c0dSopenharmony_ci+/** 5832be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 5833be168c0dSopenharmony_ci+ * 5834be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5835be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5836be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5837be168c0dSopenharmony_ci+ * 5838be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5839be168c0dSopenharmony_ci+ * 5840be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5841be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5842be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5843be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5844be168c0dSopenharmony_ci+ * limitations under the License. 
5845be168c0dSopenharmony_ci+ */ 5846be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h" 5847be168c0dSopenharmony_ci+#include "include/errorcode.h" 5848be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_is_inf.h" 5849be168c0dSopenharmony_ci+#include "src/common/tensor_util.h" 5850be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 5851be168c0dSopenharmony_ci+ 5852be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar; 5853be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 5854be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 5855be168c0dSopenharmony_ci+ 5856be168c0dSopenharmony_ci+namespace mindspore::kernel { 5857be168c0dSopenharmony_ci+ 5858be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::Prepare() { 5859be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_tensors_.size(), C1NUM); 5860be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 5861be168c0dSopenharmony_ci+ return RET_OK; 5862be168c0dSopenharmony_ci+} 5863be168c0dSopenharmony_ci+ 5864be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::ReSize() { return RET_OK; } 5865be168c0dSopenharmony_ci+ 5866be168c0dSopenharmony_ci+void CustomIsInfCPUKernel::LaunchKernelFloat(const float *input, bool *output) { 5867be168c0dSopenharmony_ci+ auto elem_num = in_tensors_[FIRST_INPUT]->ElementsNum(); 5868be168c0dSopenharmony_ci+ 5869be168c0dSopenharmony_ci+ for (int i = 0; i < elem_num; i++) { 5870be168c0dSopenharmony_ci+ output[i] = std::isinf(input[i]); 5871be168c0dSopenharmony_ci+ } 5872be168c0dSopenharmony_ci+} 5873be168c0dSopenharmony_ci+ 5874be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::Run() { 5875be168c0dSopenharmony_ci+ auto input = in_tensors_[FIRST_INPUT]; 5876be168c0dSopenharmony_ci+ auto output = out_tensors_[FIRST_INPUT]; 5877be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input); 5878be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output); 5879be168c0dSopenharmony_ci+ 5880be168c0dSopenharmony_ci+ if (input->data_type() == kNumberTypeFloat32 || 
input->data_type() == kNumberTypeFloat) { 5881be168c0dSopenharmony_ci+ LaunchKernelFloat(reinterpret_cast<const float *>(input->data()), reinterpret_cast<bool *>(output->data())); 5882be168c0dSopenharmony_ci+ } else { 5883be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "unsupported input data type " << input->data_type(); 5884be168c0dSopenharmony_ci+ return RET_ERROR; 5885be168c0dSopenharmony_ci+ } 5886be168c0dSopenharmony_ci+ 5887be168c0dSopenharmony_ci+ return RET_OK; 5888be168c0dSopenharmony_ci+} 5889be168c0dSopenharmony_ci+ 5890be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomIsInf, LiteKernelCreator<CustomIsInfCPUKernel>) 5891be168c0dSopenharmony_ci+} // namespace mindspore::kernel 5892be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 5893be168c0dSopenharmony_cinew file mode 100644 5894be168c0dSopenharmony_ciindex 00000000..e63d8ec7 5895be168c0dSopenharmony_ci--- /dev/null 5896be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h 5897be168c0dSopenharmony_ci@@ -0,0 +1,38 @@ 5898be168c0dSopenharmony_ci+/** 5899be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 5900be168c0dSopenharmony_ci+ * 5901be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5902be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5903be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5904be168c0dSopenharmony_ci+ * 5905be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5906be168c0dSopenharmony_ci+ * 5907be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5908be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5909be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
5910be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5911be168c0dSopenharmony_ci+ * limitations under the License. 5912be168c0dSopenharmony_ci+ */ 5913be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5914be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5915be168c0dSopenharmony_ci+ 5916be168c0dSopenharmony_ci+#include <vector> 5917be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 5918be168c0dSopenharmony_ci+ 5919be168c0dSopenharmony_ci+namespace mindspore::kernel { 5920be168c0dSopenharmony_ci+class CustomIsInfCPUKernel : public LiteKernel { 5921be168c0dSopenharmony_ci+ public: 5922be168c0dSopenharmony_ci+ CustomIsInfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 5923be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 5924be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) {} 5925be168c0dSopenharmony_ci+ ~CustomIsInfCPUKernel() override = default; 5926be168c0dSopenharmony_ci+ int Prepare() override; 5927be168c0dSopenharmony_ci+ int ReSize() override; 5928be168c0dSopenharmony_ci+ int Run() override; 5929be168c0dSopenharmony_ci+ 5930be168c0dSopenharmony_ci+ private: 5931be168c0dSopenharmony_ci+ void LaunchKernelFloat(const float *input, bool *output); 5932be168c0dSopenharmony_ci+}; 5933be168c0dSopenharmony_ci+} // namespace mindspore::kernel 5934be168c0dSopenharmony_ci+ 5935be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_ 5936be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 5937be168c0dSopenharmony_cinew file mode 100644 5938be168c0dSopenharmony_ciindex 00000000..9af1af5d 5939be168c0dSopenharmony_ci--- /dev/null 5940be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc 
5941be168c0dSopenharmony_ci@@ -0,0 +1,84 @@ 5942be168c0dSopenharmony_ci+/** 5943be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 5944be168c0dSopenharmony_ci+ * 5945be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 5946be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 5947be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 5948be168c0dSopenharmony_ci+ * 5949be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 5950be168c0dSopenharmony_ci+ * 5951be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 5952be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 5953be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 5954be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 5955be168c0dSopenharmony_ci+ * limitations under the License. 
5956be168c0dSopenharmony_ci+ */ 5957be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h" 5958be168c0dSopenharmony_ci+#include "include/errorcode.h" 5959be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_masked_fill.h" 5960be168c0dSopenharmony_ci+#include "src/common/tensor_util.h" 5961be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 5962be168c0dSopenharmony_ci+ 5963be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar; 5964be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 5965be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 5966be168c0dSopenharmony_ci+ 5967be168c0dSopenharmony_ci+namespace mindspore::kernel { 5968be168c0dSopenharmony_ci+ 5969be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::Prepare() { 5970be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(in_tensors_.size(), C3NUM); 5971be168c0dSopenharmony_ci+ CHECK_LESS_RETURN(out_tensors_.size(), C1NUM); 5972be168c0dSopenharmony_ci+ 5973be168c0dSopenharmony_ci+ // only support input value as a single float value 5974be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat32 || 5975be168c0dSopenharmony_ci+ in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat, 5976be168c0dSopenharmony_ci+ RET_ERROR, "input dtype must be float32"); 5977be168c0dSopenharmony_ci+ if (in_tensors_[THIRD_INPUT]->ElementsNum() != 1) { 5978be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "only support fill value as a single float"; 5979be168c0dSopenharmony_ci+ return RET_ERROR; 5980be168c0dSopenharmony_ci+ } 5981be168c0dSopenharmony_ci+ MS_CHECK_TRUE_MSG(in_tensors_[SECOND_INPUT]->data_type() == mindspore::TypeId::kNumberTypeBool, RET_ERROR, 5982be168c0dSopenharmony_ci+ "mask dtype must be bool"); 5983be168c0dSopenharmony_ci+ if (!InferShapeDone()) { 5984be168c0dSopenharmony_ci+ return RET_OK; 5985be168c0dSopenharmony_ci+ } 5986be168c0dSopenharmony_ci+ return ReSize(); 5987be168c0dSopenharmony_ci+} 
5988be168c0dSopenharmony_ci+ 5989be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::ReSize() { return RET_OK; } 5990be168c0dSopenharmony_ci+ 5991be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::Run() { 5992be168c0dSopenharmony_ci+ auto input = in_tensors_[FIRST_INPUT]; 5993be168c0dSopenharmony_ci+ auto mask = in_tensors_[SECOND_INPUT]; 5994be168c0dSopenharmony_ci+ auto value = in_tensors_[THIRD_INPUT]; 5995be168c0dSopenharmony_ci+ auto output = out_tensors_[FIRST_INPUT]; 5996be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(input); 5997be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(mask); 5998be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(value); 5999be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(output); 6000be168c0dSopenharmony_ci+ 6001be168c0dSopenharmony_ci+ if (input->shape() != mask->shape()) { 6002be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Not support broadcast mask to input"; 6003be168c0dSopenharmony_ci+ return RET_ERROR; 6004be168c0dSopenharmony_ci+ } 6005be168c0dSopenharmony_ci+ 6006be168c0dSopenharmony_ci+ auto value_data = reinterpret_cast<float *>(value->data()); 6007be168c0dSopenharmony_ci+ auto fill_value = value_data[0]; 6008be168c0dSopenharmony_ci+ 6009be168c0dSopenharmony_ci+ auto data_num = input->ElementsNum(); 6010be168c0dSopenharmony_ci+ auto input_data = reinterpret_cast<float *>(input->data()); 6011be168c0dSopenharmony_ci+ auto mask_data = reinterpret_cast<bool *>(mask->data()); 6012be168c0dSopenharmony_ci+ auto output_data = reinterpret_cast<float *>(output->data()); 6013be168c0dSopenharmony_ci+ for (int64_t i = 0; i < data_num; i++) { 6014be168c0dSopenharmony_ci+ if (mask_data[i]) { 6015be168c0dSopenharmony_ci+ output_data[i] = fill_value; 6016be168c0dSopenharmony_ci+ } else { 6017be168c0dSopenharmony_ci+ output_data[i] = input_data[i]; 6018be168c0dSopenharmony_ci+ } 6019be168c0dSopenharmony_ci+ } 6020be168c0dSopenharmony_ci+ 6021be168c0dSopenharmony_ci+ return RET_OK; 6022be168c0dSopenharmony_ci+} 6023be168c0dSopenharmony_ci+ 
6024be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomMaskedFill, LiteKernelCreator<CustomMaskedFillCPUKernel>) 6025be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6026be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 6027be168c0dSopenharmony_cinew file mode 100644 6028be168c0dSopenharmony_ciindex 00000000..04a2dcab 6029be168c0dSopenharmony_ci--- /dev/null 6030be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h 6031be168c0dSopenharmony_ci@@ -0,0 +1,35 @@ 6032be168c0dSopenharmony_ci+/** 6033be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6034be168c0dSopenharmony_ci+ * 6035be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6036be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 6037be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6038be168c0dSopenharmony_ci+ * 6039be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6040be168c0dSopenharmony_ci+ * 6041be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6042be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6043be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6044be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6045be168c0dSopenharmony_ci+ * limitations under the License. 
6046be168c0dSopenharmony_ci+ */ 6047be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6048be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6049be168c0dSopenharmony_ci+ 6050be168c0dSopenharmony_ci+#include <vector> 6051be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h" 6052be168c0dSopenharmony_ci+ 6053be168c0dSopenharmony_ci+namespace mindspore::kernel { 6054be168c0dSopenharmony_ci+class CustomMaskedFillCPUKernel : public LiteKernel { 6055be168c0dSopenharmony_ci+ public: 6056be168c0dSopenharmony_ci+ CustomMaskedFillCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6057be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6058be168c0dSopenharmony_ci+ : LiteKernel(parameter, inputs, outputs, ctx) {} 6059be168c0dSopenharmony_ci+ ~CustomMaskedFillCPUKernel() override = default; 6060be168c0dSopenharmony_ci+ int Prepare() override; 6061be168c0dSopenharmony_ci+ int ReSize() override; 6062be168c0dSopenharmony_ci+ int Run() override; 6063be168c0dSopenharmony_ci+}; 6064be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6065be168c0dSopenharmony_ci+ 6066be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_ 6067be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 6068be168c0dSopenharmony_cinew file mode 100644 6069be168c0dSopenharmony_ciindex 00000000..d52d67d5 6070be168c0dSopenharmony_ci--- /dev/null 6071be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc 6072be168c0dSopenharmony_ci@@ -0,0 +1,75 @@ 6073be168c0dSopenharmony_ci+/** 6074be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd 6075be168c0dSopenharmony_ci+ * 6076be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the 
"License"); 6077be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 6078be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6079be168c0dSopenharmony_ci+ * 6080be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6081be168c0dSopenharmony_ci+ * 6082be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6083be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6084be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6085be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6086be168c0dSopenharmony_ci+ * limitations under the License. 6087be168c0dSopenharmony_ci+ */ 6088be168c0dSopenharmony_ci+ 6089be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_tensor_scatter.h" 6090be168c0dSopenharmony_ci+#include <cstring> 6091be168c0dSopenharmony_ci+#include "schema/model_generated.h" 6092be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h" 6093be168c0dSopenharmony_ci+#include "include/errorcode.h" 6094be168c0dSopenharmony_ci+#include "nnacl/base/scatter_nd_binary.h" 6095be168c0dSopenharmony_ci+ 6096be168c0dSopenharmony_ci+using mindspore::kernel::KERNEL_ARCH; 6097be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar; 6098be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR; 6099be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 6100be168c0dSopenharmony_ci+ 6101be168c0dSopenharmony_ci+namespace mindspore::kernel { 6102be168c0dSopenharmony_ci+namespace { 6103be168c0dSopenharmony_ci+int TensorScatterRun(void *cdata, int task_id, float, float) { 6104be168c0dSopenharmony_ci+ auto kernel = static_cast<CustomTensorScatterCPUKernel *>(cdata); 6105be168c0dSopenharmony_ci+ CHECK_NULL_RETURN(kernel); 6106be168c0dSopenharmony_ci+ return kernel->TensorScatterDispatch(task_id); 6107be168c0dSopenharmony_ci+} 
6108be168c0dSopenharmony_ci+} // namespace 6109be168c0dSopenharmony_ci+ 6110be168c0dSopenharmony_ci+int CustomTensorScatterCPUKernel::TensorScatterDispatch(int task_id) { 6111be168c0dSopenharmony_ci+ auto data_type = in_tensors_[kScatterUpdateInputIndex]->data_type(); 6112be168c0dSopenharmony_ci+ if (data_type != kNumberTypeFloat32) { 6113be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "TensorScatterMax only support float32 input tensor, but got " << data_type; 6114be168c0dSopenharmony_ci+ return RET_ERROR; 6115be168c0dSopenharmony_ci+ } 6116be168c0dSopenharmony_ci+ int type = data_type == kNumberTypeFloat32 ? 0 : 1; 6117be168c0dSopenharmony_ci+ // multi thread have some problems to solve 6118be168c0dSopenharmony_ci+ param_->op_parameter.thread_num_ = 1; 6119be168c0dSopenharmony_ci+ auto ret = ScatterNDMax(in_tensors_[kScatterUpdateIndex]->data(), out_tensors_[kOutputIndex]->data(), 6120be168c0dSopenharmony_ci+ output_unit_offsets_.data(), param_, type, task_id); 6121be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6122be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "ScatterNDMax failed, ret: " << ret; 6123be168c0dSopenharmony_ci+ return RET_ERROR; 6124be168c0dSopenharmony_ci+ } 6125be168c0dSopenharmony_ci+ return RET_OK; 6126be168c0dSopenharmony_ci+} 6127be168c0dSopenharmony_ci+ 6128be168c0dSopenharmony_ci+int CustomTensorScatterCPUKernel::Run() { 6129be168c0dSopenharmony_ci+ auto in_tensor = in_tensors().front(); 6130be168c0dSopenharmony_ci+ auto out_tensor = out_tensors().front(); 6131be168c0dSopenharmony_ci+ (void)memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size()); 6132be168c0dSopenharmony_ci+ auto indices = in_tensors_.at(kScatterIndicesIndex); 6133be168c0dSopenharmony_ci+ if (!indices->IsConst() && ReSize() != RET_OK) { 6134be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "TensorScatterAdd resize failed."; 6135be168c0dSopenharmony_ci+ return RET_ERROR; 6136be168c0dSopenharmony_ci+ } 6137be168c0dSopenharmony_ci+ 6138be168c0dSopenharmony_ci+ auto ret = 
ParallelLaunch(ms_context_, TensorScatterRun, this, op_parameter_->thread_num_); 6139be168c0dSopenharmony_ci+ if (ret != RET_OK) { 6140be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "TensorScatterAdd error error_code[" << ret << "]"; 6141be168c0dSopenharmony_ci+ } 6142be168c0dSopenharmony_ci+ return ret; 6143be168c0dSopenharmony_ci+} 6144be168c0dSopenharmony_ci+ 6145be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomTensorScatterMax, 6146be168c0dSopenharmony_ci+ LiteKernelCreator<CustomTensorScatterCPUKernel>) 6147be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6148be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 6149be168c0dSopenharmony_cinew file mode 100644 6150be168c0dSopenharmony_ciindex 00000000..e39733c5 6151be168c0dSopenharmony_ci--- /dev/null 6152be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h 6153be168c0dSopenharmony_ci@@ -0,0 +1,36 @@ 6154be168c0dSopenharmony_ci+/** 6155be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd 6156be168c0dSopenharmony_ci+ * 6157be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6158be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 6159be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6160be168c0dSopenharmony_ci+ * 6161be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6162be168c0dSopenharmony_ci+ * 6163be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6164be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6165be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
6166be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6167be168c0dSopenharmony_ci+ * limitations under the License. 6168be168c0dSopenharmony_ci+ */ 6169be168c0dSopenharmony_ci+ 6170be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6171be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6172be168c0dSopenharmony_ci+ 6173be168c0dSopenharmony_ci+#include <vector> 6174be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/scatter_nd_binary.h" 6175be168c0dSopenharmony_ci+ 6176be168c0dSopenharmony_ci+namespace mindspore::kernel { 6177be168c0dSopenharmony_ci+class CustomTensorScatterCPUKernel : public ScatterNDBinaryCPUKernel { 6178be168c0dSopenharmony_ci+ public: 6179be168c0dSopenharmony_ci+ explicit CustomTensorScatterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, 6180be168c0dSopenharmony_ci+ const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx) 6181be168c0dSopenharmony_ci+ : ScatterNDBinaryCPUKernel(parameter, inputs, outputs, ctx) {} 6182be168c0dSopenharmony_ci+ ~CustomTensorScatterCPUKernel() override = default; 6183be168c0dSopenharmony_ci+ 6184be168c0dSopenharmony_ci+ int Run() override; 6185be168c0dSopenharmony_ci+ int TensorScatterDispatch(int task_id); 6186be168c0dSopenharmony_ci+}; 6187be168c0dSopenharmony_ci+} // namespace mindspore::kernel 6188be168c0dSopenharmony_ci+ 6189be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_ 6190be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc 6191be168c0dSopenharmony_ciindex 2c5bc658..13652633 100644 6192be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_model.cc 6193be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_model.cc 6194be168c0dSopenharmony_ci@@ -98,6 +98,8 @@ int LiteModel::ConvertSubGraph(const 
schema::SubGraph &sub_graph) { 6195be168c0dSopenharmony_ci if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr || 6196be168c0dSopenharmony_ci sub_graph.tensorIndices() == nullptr) { 6197be168c0dSopenharmony_ci MS_LOG(ERROR) << "sub_graph is invalid"; 6198be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "sub_graph.name() = " << sub_graph.name() << ", sub_graph.inputIndices() = " << sub_graph.inputIndices() 6199be168c0dSopenharmony_ci+ << ", sub_graph.outputIndices() = " << sub_graph.outputIndices() << ", sub_graph.tensorIndices() = " << sub_graph.tensorIndices(); 6200be168c0dSopenharmony_ci return RET_ERROR; 6201be168c0dSopenharmony_ci } 6202be168c0dSopenharmony_ci 6203be168c0dSopenharmony_ci@@ -620,6 +622,33 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, minds 6204be168c0dSopenharmony_ci return model; 6205be168c0dSopenharmony_ci } 6206be168c0dSopenharmony_ci 6207be168c0dSopenharmony_ci+std::string LiteGraph::ToString() const { 6208be168c0dSopenharmony_ci+ std::stringstream ss; 6209be168c0dSopenharmony_ci+ ss << "all_nodes: " << all_nodes_.size() << std::endl; 6210be168c0dSopenharmony_ci+ for (size_t i = 0; i < all_nodes_.size(); i++) { 6211be168c0dSopenharmony_ci+ ss << "- node " << i << ": " << all_nodes_[i]->primitive_ << std::endl; 6212be168c0dSopenharmony_ci+ ss << "- node " << i << " input_indices_: " << all_nodes_[i]->input_indices_ << std::endl; 6213be168c0dSopenharmony_ci+ ss << "- node " << i << " output_indices_: " << all_nodes_[i]->output_indices_ << std::endl; 6214be168c0dSopenharmony_ci+ } 6215be168c0dSopenharmony_ci+ ss << "all_tensors: " << all_tensors_.size() << std::endl; 6216be168c0dSopenharmony_ci+ for (size_t i = 0; i < all_tensors_.size(); i++) { 6217be168c0dSopenharmony_ci+ ss << "- tensor " << i << ": " << all_tensors_[i] << std::endl; 6218be168c0dSopenharmony_ci+ } 6219be168c0dSopenharmony_ci+ ss << "input_indices: " << input_indices_<< std::endl; 
6220be168c0dSopenharmony_ci+ ss << "output_indices: " << output_indices_ << std::endl; 6221be168c0dSopenharmony_ci+ 6222be168c0dSopenharmony_ci+ ss << "subgraphs: " << std::endl; 6223be168c0dSopenharmony_ci+ int count = 0; 6224be168c0dSopenharmony_ci+ for (auto subgraph: sub_graphs_) { 6225be168c0dSopenharmony_ci+ ss << "- subgraph " << count++ << std::endl; 6226be168c0dSopenharmony_ci+ ss << "--- subgraph input " << subgraph->input_indices_ << std::endl; 6227be168c0dSopenharmony_ci+ ss << "--- subgraph output " << subgraph->output_indices_ << std::endl; 6228be168c0dSopenharmony_ci+ ss << "--- subgraph node " << subgraph->node_indices_ << std::endl; 6229be168c0dSopenharmony_ci+ ss << "--- subgraph tensor " << subgraph->tensor_indices_ << std::endl; 6230be168c0dSopenharmony_ci+ } 6231be168c0dSopenharmony_ci+ return ss.str(); 6232be168c0dSopenharmony_ci+} 6233be168c0dSopenharmony_ci+ 6234be168c0dSopenharmony_ci Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); } 6235be168c0dSopenharmony_ci 6236be168c0dSopenharmony_ci Model *Model::Import(const char *filename) { return ImportFromPath(filename); } 6237be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc 6238be168c0dSopenharmony_ciindex 8f54879e..f635c8d2 100644 6239be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_session.cc 6240be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_session.cc 6241be168c0dSopenharmony_ci@@ -67,6 +67,9 @@ 6242be168c0dSopenharmony_ci #include "thread/parallel_thread_pool_manager.h" 6243be168c0dSopenharmony_ci #endif 6244be168c0dSopenharmony_ci #include "src/litert/runtime_packed_node_pass.h" 6245be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT 6246be168c0dSopenharmony_ci+#include "src/litert/delegate/nnrt/nnrt_delegate.h" 6247be168c0dSopenharmony_ci+#endif 6248be168c0dSopenharmony_ci 6249be168c0dSopenharmony_ci using AbstractBaseModel = 
mindspore::infer::AbstractBaseModel; 6250be168c0dSopenharmony_ci 6251be168c0dSopenharmony_ci@@ -635,12 +638,6 @@ int LiteSession::CompileGraph(Model *model) { 6252be168c0dSopenharmony_ci MarkSharedWeight(kernels_); 6253be168c0dSopenharmony_ci FreePackOpWeight(kernels_); 6254be168c0dSopenharmony_ci 6255be168c0dSopenharmony_ci- ret = RuntimeAllocatorInit(); 6256be168c0dSopenharmony_ci- if (ret != RET_OK) { 6257be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Runtime allocator init failed."; 6258be168c0dSopenharmony_ci- is_running_.store(false); 6259be168c0dSopenharmony_ci- return ret; 6260be168c0dSopenharmony_ci- } 6261be168c0dSopenharmony_ci infer_along_running_ = infer_along_running_ && (runtime_allocator_ == nullptr); 6262be168c0dSopenharmony_ci if (infer_along_running_) { 6263be168c0dSopenharmony_ci this->context_->set_infer_checker(InferCheckerAll); 6264be168c0dSopenharmony_ci@@ -1092,6 +1089,27 @@ int LiteSession::CreateCoreMLDelegate() { 6265be168c0dSopenharmony_ci return RET_OK; 6266be168c0dSopenharmony_ci } 6267be168c0dSopenharmony_ci 6268be168c0dSopenharmony_ci+int LiteSession::CreateNNRTDelegate() { 6269be168c0dSopenharmony_ci+#if SUPPORT_NNRT 6270be168c0dSopenharmony_ci+ auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(), 6271be168c0dSopenharmony_ci+ [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; }); 6272be168c0dSopenharmony_ci+ if(iter == context_->device_list_.end()) { 6273be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Found non NNRT device info"; 6274be168c0dSopenharmony_ci+ return RET_ERROR; 6275be168c0dSopenharmony_ci+ } 6276be168c0dSopenharmony_ci+ 6277be168c0dSopenharmony_ci+ delegate_ = std::make_shared<NNRTDelegate>(iter->device_info_.nnrt_device_info_); 6278be168c0dSopenharmony_ci+ if (delegate_ == nullptr) { 6279be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "New NNRT delegate failed"; 6280be168c0dSopenharmony_ci+ return RET_ERROR; 6281be168c0dSopenharmony_ci+ } 6282be168c0dSopenharmony_ci+// 
((NNRTDelegate *)(delegate_.get()))->SetMetaGraph(this->model_->buf); 6283be168c0dSopenharmony_ci+ delegate_device_type_ = DT_NNRT; 6284be168c0dSopenharmony_ci+ this->context_->delegate = delegate_; 6285be168c0dSopenharmony_ci+#endif 6286be168c0dSopenharmony_ci+ return RET_OK; 6287be168c0dSopenharmony_ci+}; 6288be168c0dSopenharmony_ci+ 6289be168c0dSopenharmony_ci int LiteSession::DelegateInit() { 6290be168c0dSopenharmony_ci #ifndef DELEGATE_CLIP 6291be168c0dSopenharmony_ci int ret = RET_OK; 6292be168c0dSopenharmony_ci@@ -1115,6 +1133,8 @@ int LiteSession::DelegateInit() { 6293be168c0dSopenharmony_ci ret = CreateNPUDelegate(); 6294be168c0dSopenharmony_ci } else if (context_->IsDeviceTypeEnabled(DT_GPU)) { 6295be168c0dSopenharmony_ci ret = CreateTensorRTDelegate(); 6296be168c0dSopenharmony_ci+ } else if (context_->IsDeviceTypeEnabled(DT_NNRT)) { 6297be168c0dSopenharmony_ci+ ret = CreateNNRTDelegate(); 6298be168c0dSopenharmony_ci } 6299be168c0dSopenharmony_ci } 6300be168c0dSopenharmony_ci 6301be168c0dSopenharmony_ci@@ -1496,12 +1516,6 @@ int LiteSession::Resize(const std::vector<mindspore::lite::Tensor *> &inputs, 6302be168c0dSopenharmony_ci return ret; 6303be168c0dSopenharmony_ci } 6304be168c0dSopenharmony_ci 6305be168c0dSopenharmony_ci- if (RuntimeAllocatorInit() != RET_OK) { 6306be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Runtime allocator in resize failed."; 6307be168c0dSopenharmony_ci- is_running_.store(false); 6308be168c0dSopenharmony_ci- return RET_ERROR; 6309be168c0dSopenharmony_ci- } 6310be168c0dSopenharmony_ci- 6311be168c0dSopenharmony_ci auto status = GraphOptimizePass(&kernels_); 6312be168c0dSopenharmony_ci if (status != RET_OK) { 6313be168c0dSopenharmony_ci MS_LOG(ERROR) << "GraphOptimizePass failed."; 6314be168c0dSopenharmony_ci@@ -2022,7 +2036,6 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, 6315be168c0dSopenharmony_ci delete model; 6316be168c0dSopenharmony_ci return RET_ERROR; 6317be168c0dSopenharmony_ci } 
6318be168c0dSopenharmony_ci- model->Free(); 6319be168c0dSopenharmony_ci set_model(model); 6320be168c0dSopenharmony_ci return RET_OK; 6321be168c0dSopenharmony_ci } 6322be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h 6323be168c0dSopenharmony_ciindex f8f8fe08..64a5f6d3 100644 6324be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_session.h 6325be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_session.h 6326be168c0dSopenharmony_ci@@ -178,6 +178,7 @@ class MS_API LiteSession { 6327be168c0dSopenharmony_ci int CreateNPUDelegate(); 6328be168c0dSopenharmony_ci int CreateNNAPIDelegate(); 6329be168c0dSopenharmony_ci int CreateCoreMLDelegate(); 6330be168c0dSopenharmony_ci+ int CreateNNRTDelegate(); 6331be168c0dSopenharmony_ci int DelegateInit(); 6332be168c0dSopenharmony_ci int InitGPURuntime(); 6333be168c0dSopenharmony_ci int InitSharedThreadPool(); 6334be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc 6335be168c0dSopenharmony_ciindex 11382b09..199b4361 100644 6336be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/scheduler.cc 6337be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/scheduler.cc 6338be168c0dSopenharmony_ci@@ -60,6 +60,9 @@ 6339be168c0dSopenharmony_ci #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT) 6340be168c0dSopenharmony_ci #include "thread/parallel_thread_pool_manager.h" 6341be168c0dSopenharmony_ci #endif 6342be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT 6343be168c0dSopenharmony_ci+#include "src/litert/delegate/nnrt/nnrt_delegate.h" 6344be168c0dSopenharmony_ci+#endif 6345be168c0dSopenharmony_ci 6346be168c0dSopenharmony_ci using AbstractBaseModel = mindspore::infer::AbstractBaseModel; 6347be168c0dSopenharmony_ci 6348be168c0dSopenharmony_ci@@ -368,6 +371,7 @@ STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *ker 6349be168c0dSopenharmony_ci } 
6350be168c0dSopenharmony_ci 6351be168c0dSopenharmony_ci int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { 6352be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Start schedule."; 6353be168c0dSopenharmony_ci int check_input_ret = CheckInputParam(dst_kernels); 6354be168c0dSopenharmony_ci if (check_input_ret != RET_OK) { 6355be168c0dSopenharmony_ci MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret; 6356be168c0dSopenharmony_ci@@ -404,11 +408,13 @@ int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) { 6357be168c0dSopenharmony_ci } 6358be168c0dSopenharmony_ci shape_fusion_pass_->StoreStateAndReset(); 6359be168c0dSopenharmony_ci 6360be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Start to init delegate kernels."; 6361be168c0dSopenharmony_ci ret = InitDelegateKernels(dst_kernels); 6362be168c0dSopenharmony_ci if (ret != RET_OK) { 6363be168c0dSopenharmony_ci MS_LOG(ERROR) << "Repalce delegate kernels failed."; 6364be168c0dSopenharmony_ci return ret; 6365be168c0dSopenharmony_ci } 6366be168c0dSopenharmony_ci+ MS_LOG(DEBUG) << "Finish to init delegate kernels."; 6367be168c0dSopenharmony_ci 6368be168c0dSopenharmony_ci ret = CheckCpuValid(dst_kernels); 6369be168c0dSopenharmony_ci if (ret != RET_OK) { 6370be168c0dSopenharmony_ci@@ -500,6 +506,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_ker 6371be168c0dSopenharmony_ci MS_LOG(ERROR) << "New delegate model failed."; 6372be168c0dSopenharmony_ci return RET_NULL_PTR; 6373be168c0dSopenharmony_ci } 6374be168c0dSopenharmony_ci+ 6375be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT 6376be168c0dSopenharmony_ci+ if (context_->IsDeviceTypeEnabled(DT_NNRT)) { 6377be168c0dSopenharmony_ci+ auto delegate = static_cast<NNRTDelegate *>(delegate_.get()); 6378be168c0dSopenharmony_ci+ delegate->ShallowCopyLiteGraph(this->src_model_->graph_); 6379be168c0dSopenharmony_ci+ void *meta_graph = reinterpret_cast<void*>(const_cast<mindspore::schema::MetaGraph *>( 
6380be168c0dSopenharmony_ci+ mindspore::schema::GetMetaGraph(this->src_model_->buf))); 6381be168c0dSopenharmony_ci+ delegate->SetMetaGraph(meta_graph); 6382be168c0dSopenharmony_ci+ } 6383be168c0dSopenharmony_ci+#endif 6384be168c0dSopenharmony_ci+ 6385be168c0dSopenharmony_ci auto ret = delegate_->Build(model); 6386be168c0dSopenharmony_ci if (ret != mindspore::kSuccess) { 6387be168c0dSopenharmony_ci delete model; 6388be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/tensor_category.cc b/mindspore/lite/src/litert/tensor_category.cc 6389be168c0dSopenharmony_ciindex 70d13865..e57cdb28 100644 6390be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/tensor_category.cc 6391be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/tensor_category.cc 6392be168c0dSopenharmony_ci@@ -30,5 +30,9 @@ Category TensorCategory(const schema::Tensor &tensor) { 6393be168c0dSopenharmony_ci auto data_size = tensor.data() == nullptr ? 0 : tensor.data()->size(); 6394be168c0dSopenharmony_ci return TensorCategory(tensor.nodeType(), shape_num, TypeId(tensor.dataType()), data_size); 6395be168c0dSopenharmony_ci } 6396be168c0dSopenharmony_ci+ 6397be168c0dSopenharmony_ci+bool IsConstTensor(const schema::Tensor &tensor) { 6398be168c0dSopenharmony_ci+ return TensorCategory(tensor) != Category::VAR; 6399be168c0dSopenharmony_ci+} 6400be168c0dSopenharmony_ci } // namespace lite 6401be168c0dSopenharmony_ci } // namespace mindspore 6402be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/tensor_category.h b/mindspore/lite/src/litert/tensor_category.h 6403be168c0dSopenharmony_ciindex 83273032..70e65b31 100644 6404be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/tensor_category.h 6405be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/tensor_category.h 6406be168c0dSopenharmony_ci@@ -35,6 +35,7 @@ enum Category { 6407be168c0dSopenharmony_ci 6408be168c0dSopenharmony_ci Category TensorCategory(const int node_type, const size_t shape_num, const TypeId data_type, const size_t 
data_size); 6409be168c0dSopenharmony_ci Category TensorCategory(const schema::Tensor &tensor); 6410be168c0dSopenharmony_ci+bool IsConstTensor(const schema::Tensor &tensor); 6411be168c0dSopenharmony_ci } // namespace lite 6412be168c0dSopenharmony_ci } // namespace mindspore 6413be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_SRC_RUNTIME_TENSOR_CATEGORY_H_ 6414be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt 6415be168c0dSopenharmony_ciindex 60e240f0..78dab536 100644 6416be168c0dSopenharmony_ci--- a/mindspore/lite/test/CMakeLists.txt 6417be168c0dSopenharmony_ci+++ b/mindspore/lite/test/CMakeLists.txt 6418be168c0dSopenharmony_ci@@ -28,10 +28,14 @@ file(GLOB_RECURSE TEST_UT_SRC 6419be168c0dSopenharmony_ci ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc 6420be168c0dSopenharmony_ci ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc 6421be168c0dSopenharmony_ci ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc 6422be168c0dSopenharmony_ci- ${TEST_DIR}/ut/src/api/context_c_test.cc 6423be168c0dSopenharmony_ci- ${TEST_DIR}/ut/src/api/model_c_test.cc 6424be168c0dSopenharmony_ci- ${TEST_DIR}/ut/src/api/tensor_c_test.cc` 6425be168c0dSopenharmony_ci+# ${TEST_DIR}/ut/src/api/context_c_test.cc 6426be168c0dSopenharmony_ci+# ${TEST_DIR}/ut/src/api/model_c_test.cc 6427be168c0dSopenharmony_ci+# ${TEST_DIR}/ut/src/api/tensor_c_test.cc` 6428be168c0dSopenharmony_ci ) 6429be168c0dSopenharmony_ci+if(MSLITE_ENABLE_NNRT) 6430be168c0dSopenharmony_ci+ list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/nnrt_delegate/nnrt_delegate_tests.cc) 6431be168c0dSopenharmony_ci+endif() 6432be168c0dSopenharmony_ci+ 6433be168c0dSopenharmony_ci if(MSLITE_ENABLE_SERVER_INFERENCE) 6434be168c0dSopenharmony_ci list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/api/model_parallel_runner_test.cc) 6435be168c0dSopenharmony_ci endif() 6436be168c0dSopenharmony_ci@@ -86,7 +90,7 @@ endif() 6437be168c0dSopenharmony_ci 6438be168c0dSopenharmony_ci if(MSLITE_ENABLE_INT8) 
6439be168c0dSopenharmony_ci file(GLOB_RECURSE TEST_INT8_UT_SRC 6440be168c0dSopenharmony_ci- ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc 6441be168c0dSopenharmony_ci+# ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc 6442be168c0dSopenharmony_ci ${TEST_DIR}/ut/nnacl/int8/*.cc 6443be168c0dSopenharmony_ci ) 6444be168c0dSopenharmony_ci list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC}) 6445be168c0dSopenharmony_ci@@ -118,6 +122,7 @@ if(MSLITE_ENABLE_CONVERTER) 6446be168c0dSopenharmony_ci ${TEST_DIR}/ut/tools/converter/registry/*.cc 6447be168c0dSopenharmony_ci ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc 6448be168c0dSopenharmony_ci ${TEST_DIR}/ut/tools/converter/api/*.cc 6449be168c0dSopenharmony_ci+ ${TEST_DIR}/ut/tools/converter/config_parser/*.cc 6450be168c0dSopenharmony_ci ${TEST_DIR}/st/converter_test.cc 6451be168c0dSopenharmony_ci ${TEST_DIR}/st/delegate_test.cc 6452be168c0dSopenharmony_ci ${TEST_DIR}/st/mindrt_parallel_test.cc 6453be168c0dSopenharmony_ci@@ -232,7 +237,7 @@ endif() 6454be168c0dSopenharmony_ci 6455be168c0dSopenharmony_ci if(MSLITE_ENABLE_CONVERTER) 6456be168c0dSopenharmony_ci target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid 6457be168c0dSopenharmony_ci- onnx_parser_mid tf_parser_mid) 6458be168c0dSopenharmony_ci+ onnx_parser_mid tf_parser_mid third_party_parser_mid) 6459be168c0dSopenharmony_ci endif() 6460be168c0dSopenharmony_ci 6461be168c0dSopenharmony_ci if(MSLITE_ENABLE_MODEL_OBF) 6462be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh 6463be168c0dSopenharmony_ciindex c0d6d843..abdea6f4 100644 6464be168c0dSopenharmony_ci--- a/mindspore/lite/test/runtest.sh 6465be168c0dSopenharmony_ci+++ b/mindspore/lite/test/runtest.sh 6466be168c0dSopenharmony_ci@@ -80,6 +80,7 @@ if [ "$ENABLE_CONVERTER_TEST" = true ]; then 6467be168c0dSopenharmony_ci ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry" 6468be168c0dSopenharmony_ci ./lite-test-converter 
--gtest_filter="TestConverterAPI.*" 6469be168c0dSopenharmony_ci ./lite-test-converter --gtest_filter="SpecifyGraphOutputFormatTest*" 6470be168c0dSopenharmony_ci+ ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*" 6471be168c0dSopenharmony_ci fi 6472be168c0dSopenharmony_ci ./lite-test --gtest_filter="TestRegistry.TestAdd" 6473be168c0dSopenharmony_ci ./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd" 6474be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg 6475be168c0dSopenharmony_cinew file mode 100644 6476be168c0dSopenharmony_ciindex 00000000..b5fcba75 6477be168c0dSopenharmony_ci--- /dev/null 6478be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg 6479be168c0dSopenharmony_ci@@ -0,0 +1,8 @@ 6480be168c0dSopenharmony_ci+[third_party_model] 6481be168c0dSopenharmony_ci+input_names=demo_in_0;demo_in_1;demo_in_2 6482be168c0dSopenharmony_ci+input_dtypes=float32;float16;float64 6483be168c0dSopenharmony_ci+input_shapes=1;2,3;4,5,6 6484be168c0dSopenharmony_ci+output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4 6485be168c0dSopenharmony_ci+output_dtypes=int32;int16;int8;uint8 6486be168c0dSopenharmony_ci+output_shapes=10;20,30;40;50,60,70 6487be168c0dSopenharmony_ci+extended_parameters=foo:foo_value;bar:bar_value 6488be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6489be168c0dSopenharmony_ciindex 549bdd72..e73afc0e 100644 6490be168c0dSopenharmony_ci--- a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6491be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc 6492be168c0dSopenharmony_ci@@ -34,3 +34,13 @@ TEST(TestConverterAPI, ConvertCaffeWithNotExistWeight) { 6493be168c0dSopenharmony_ci mindspore::Converter 
converter(mindspore::converter::FmkType::kFmkTypeCaffe, caffe_model, output_model, caffe_weight); 6494be168c0dSopenharmony_ci ASSERT_FALSE(converter.Convert().IsOk()); 6495be168c0dSopenharmony_ci } 6496be168c0dSopenharmony_ci+ 6497be168c0dSopenharmony_ci+TEST(TestConverterAPI, ConvertThirdParty) { 6498be168c0dSopenharmony_ci+ std::string third_party_model = "./relu.mindir"; 6499be168c0dSopenharmony_ci+ std::string config_model = "./third_party_model.cfg"; 6500be168c0dSopenharmony_ci+ std::string output_model = "./demo_third_party.ms"; 6501be168c0dSopenharmony_ci+ 6502be168c0dSopenharmony_ci+ mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model); 6503be168c0dSopenharmony_ci+ converter.SetConfigFile(config_model); 6504be168c0dSopenharmony_ci+ ASSERT_TRUE(converter.Convert().IsOk()); 6505be168c0dSopenharmony_ci+} 6506be168c0dSopenharmony_ci\ No newline at end of file 6507be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 6508be168c0dSopenharmony_cinew file mode 100644 6509be168c0dSopenharmony_ciindex 00000000..c8eb5536 6510be168c0dSopenharmony_ci--- /dev/null 6511be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc 6512be168c0dSopenharmony_ci@@ -0,0 +1,176 @@ 6513be168c0dSopenharmony_ci+/** 6514be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 6515be168c0dSopenharmony_ci+ * 6516be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 6517be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
6518be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 6519be168c0dSopenharmony_ci+ * 6520be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 6521be168c0dSopenharmony_ci+ * 6522be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 6523be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 6524be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 6525be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 6526be168c0dSopenharmony_ci+ * limitations under the License. 6527be168c0dSopenharmony_ci+ */ 6528be168c0dSopenharmony_ci+ 6529be168c0dSopenharmony_ci+#include "gtest/gtest.h" 6530be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h" 6531be168c0dSopenharmony_ci+ 6532be168c0dSopenharmony_ci+using mindspore::ThirdPartyModelParam; 6533be168c0dSopenharmony_ci+using mindspore::TypeId; 6534be168c0dSopenharmony_ci+using mindspore::lite::RET_OK; 6535be168c0dSopenharmony_ci+using mindspore::lite::ThirdPartyModelString; 6536be168c0dSopenharmony_ci+using mindspore::lite::ThirdPartyParamParser; 6537be168c0dSopenharmony_ci+ 6538be168c0dSopenharmony_ci+const ThirdPartyModelString kDemoSISOParam = { 6539be168c0dSopenharmony_ci+ // SISO is short for single-input-single-output. 
6540be168c0dSopenharmony_ci+ .input_dtypes = "float32", 6541be168c0dSopenharmony_ci+ .input_shapes = "1,2,3,4", 6542be168c0dSopenharmony_ci+ .input_names = "siso_input", 6543be168c0dSopenharmony_ci+ .output_dtypes = "int32", 6544be168c0dSopenharmony_ci+ .output_shapes = "2", 6545be168c0dSopenharmony_ci+ .output_names = "siso_output", 6546be168c0dSopenharmony_ci+ .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value", 6547be168c0dSopenharmony_ci+}; 6548be168c0dSopenharmony_ci+ 6549be168c0dSopenharmony_ci+const ThirdPartyModelString kDemoMIMOParam = { 6550be168c0dSopenharmony_ci+ // MIMO is short for multiple-input-multiple-output. 6551be168c0dSopenharmony_ci+ .input_dtypes = "float32;int8;float16", 6552be168c0dSopenharmony_ci+ .input_shapes = "1,2,3,4;5,6;7,8,9", 6553be168c0dSopenharmony_ci+ .input_names = "mimo_in_0;mimo_in_1;mimo_in_2", 6554be168c0dSopenharmony_ci+ .output_dtypes = "int32;float32", 6555be168c0dSopenharmony_ci+ .output_shapes = "2,4;10,20,30", 6556be168c0dSopenharmony_ci+ .output_names = "mimo_out_0;mimo_out_1", 6557be168c0dSopenharmony_ci+ .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value", 6558be168c0dSopenharmony_ci+}; 6559be168c0dSopenharmony_ci+ 6560be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseSISOParam) { 6561be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6562be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6563be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6564be168c0dSopenharmony_ci+ 6565be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_names, std::vector<std::string>{"siso_input"}); 6566be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_shapes.size(), 1U); 6567be168c0dSopenharmony_ci+ std::vector<int64_t> expect_in_shape = {1, 2, 3, 4}; 6568be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_shapes[0], expect_in_shape); 6569be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_dtypes, 
std::vector<TypeId>{TypeId::kNumberTypeFloat32}); 6570be168c0dSopenharmony_ci+ 6571be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_names, std::vector<std::string>{"siso_output"}); 6572be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_shapes.size(), 1U); 6573be168c0dSopenharmony_ci+ std::vector<int64_t> expect_out_shape = {2}; 6574be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_shapes[0], expect_out_shape); 6575be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_dtypes, std::vector<TypeId>{TypeId::kNumberTypeInt32}); 6576be168c0dSopenharmony_ci+ 6577be168c0dSopenharmony_ci+ const auto &ext_param = result.extended_parameters; 6578be168c0dSopenharmony_ci+ ASSERT_EQ(ext_param.size(), 2U); 6579be168c0dSopenharmony_ci+ ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end()); 6580be168c0dSopenharmony_ci+ auto expect_foo_value = ext_param.at("siso_foo"); 6581be168c0dSopenharmony_ci+ ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value"); 6582be168c0dSopenharmony_ci+ ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end()); 6583be168c0dSopenharmony_ci+ auto expect_bar_value = ext_param.at("siso_bar"); 6584be168c0dSopenharmony_ci+ ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value"); 6585be168c0dSopenharmony_ci+} 6586be168c0dSopenharmony_ci+ 6587be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseValidDtype) { 6588be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6589be168c0dSopenharmony_ci+ const std::vector<std::string> kValidDtypeStrings = { 6590be168c0dSopenharmony_ci+ "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool", 6591be168c0dSopenharmony_ci+ }; 6592be168c0dSopenharmony_ci+ 6593be168c0dSopenharmony_ci+ const std::vector<TypeId> kExpects = { 6594be168c0dSopenharmony_ci+ TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16, 6595be168c0dSopenharmony_ci+ TypeId::kNumberTypeInt64, 
TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt16, 6596be168c0dSopenharmony_ci+ TypeId::kNumberTypeInt8, TypeId::kNumberTypeUInt8, TypeId::kNumberTypeBool}; 6597be168c0dSopenharmony_ci+ 6598be168c0dSopenharmony_ci+ for (size_t i = 0; i < kValidDtypeStrings.size(); i++) { 6599be168c0dSopenharmony_ci+ param_string.input_dtypes = kValidDtypeStrings[i]; 6600be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6601be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6602be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_dtypes[0], kExpects[i]); 6603be168c0dSopenharmony_ci+ } 6604be168c0dSopenharmony_ci+} 6605be168c0dSopenharmony_ci+ 6606be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseInvalidDtype) { 6607be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6608be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6609be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6610be168c0dSopenharmony_ci+ param_string.input_dtypes = "bad_dtype"; 6611be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6612be168c0dSopenharmony_ci+} 6613be168c0dSopenharmony_ci+ 6614be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseValidShape) { 6615be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6616be168c0dSopenharmony_ci+ param_string.input_shapes = "256,256,1024,96"; // Only support fixed shape. 
6617be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6618be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6619be168c0dSopenharmony_ci+ std::vector<int64_t> expect = {256, 256, 1024, 96}; 6620be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_shapes[0], expect); 6621be168c0dSopenharmony_ci+} 6622be168c0dSopenharmony_ci+ 6623be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseInvalidShape) { 6624be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6625be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6626be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6627be168c0dSopenharmony_ci+ 6628be168c0dSopenharmony_ci+ param_string.input_shapes = "256,256,1024,-1"; 6629be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6630be168c0dSopenharmony_ci+ 6631be168c0dSopenharmony_ci+ param_string.input_shapes = "256,256,0,96"; 6632be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6633be168c0dSopenharmony_ci+ 6634be168c0dSopenharmony_ci+ param_string.input_shapes = "256,-256,1024,96"; 6635be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6636be168c0dSopenharmony_ci+ 6637be168c0dSopenharmony_ci+ param_string.input_shapes = "256,foo,1024,96"; 6638be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6639be168c0dSopenharmony_ci+} 6640be168c0dSopenharmony_ci+ 6641be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseDefaultName) { 6642be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6643be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoSISOParam; 6644be168c0dSopenharmony_ci+ param_string.input_names = ""; 6645be168c0dSopenharmony_ci+ param_string.output_names = ""; 6646be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, 
&result), RET_OK); 6647be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_names[0], "in_0"); 6648be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_names[0], "out_0"); 6649be168c0dSopenharmony_ci+} 6650be168c0dSopenharmony_ci+ 6651be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseMIMOParam) { 6652be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoMIMOParam; 6653be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6654be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6655be168c0dSopenharmony_ci+ 6656be168c0dSopenharmony_ci+ std::vector<std::string> expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"}; 6657be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_names, expect_input_names); 6658be168c0dSopenharmony_ci+ std::vector<std::vector<int64_t>> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}}; 6659be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_shapes, expect_input_shapes); 6660be168c0dSopenharmony_ci+ std::vector<TypeId> expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8, 6661be168c0dSopenharmony_ci+ TypeId::kNumberTypeFloat16}; 6662be168c0dSopenharmony_ci+ ASSERT_EQ(result.input_dtypes, expect_input_dtypes); 6663be168c0dSopenharmony_ci+ 6664be168c0dSopenharmony_ci+ std::vector<std::string> expect_output_names = {"mimo_out_0", "mimo_out_1"}; 6665be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_names, expect_output_names); 6666be168c0dSopenharmony_ci+ std::vector<std::vector<int64_t>> expect_output_shapes = {{2, 4}, {10, 20, 30}}; 6667be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_shapes, expect_output_shapes); 6668be168c0dSopenharmony_ci+ std::vector<TypeId> expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32}; 6669be168c0dSopenharmony_ci+ ASSERT_EQ(result.output_dtypes, expect_output_dtypes); 6670be168c0dSopenharmony_ci+} 6671be168c0dSopenharmony_ci+ 6672be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, 
ParseMismatchedShapeAndDtypeSize) { 6673be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoMIMOParam; 6674be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6675be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6676be168c0dSopenharmony_ci+ 6677be168c0dSopenharmony_ci+ param_string.input_shapes = "1,2,3,4;5,6"; // shape size is 2 while dtype size is 3. 6678be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6679be168c0dSopenharmony_ci+} 6680be168c0dSopenharmony_ci+ 6681be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) { 6682be168c0dSopenharmony_ci+ ThirdPartyModelString param_string = kDemoMIMOParam; 6683be168c0dSopenharmony_ci+ ThirdPartyModelParam result; 6684be168c0dSopenharmony_ci+ ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6685be168c0dSopenharmony_ci+ 6686be168c0dSopenharmony_ci+ param_string.input_names = "mimo_in_0;mimo_in_1"; // name size is 2 while dtype size is 3. 
6687be168c0dSopenharmony_ci+ ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK); 6688be168c0dSopenharmony_ci+} 6689be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_base.cc b/mindspore/lite/tools/benchmark/benchmark_base.cc 6690be168c0dSopenharmony_ciindex 16b1e218..ebaa9212 100644 6691be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_base.cc 6692be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_base.cc 6693be168c0dSopenharmony_ci@@ -323,7 +323,7 @@ int BenchmarkBase::CheckThreadNumValid() { 6694be168c0dSopenharmony_ci 6695be168c0dSopenharmony_ci int BenchmarkBase::CheckDeviceTypeValid() { 6696be168c0dSopenharmony_ci if (flags_->device_ != "CPU" && flags_->device_ != "GPU" && flags_->device_ != "NPU" && 6697be168c0dSopenharmony_ci- flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P") { 6698be168c0dSopenharmony_ci+ flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P" && flags_->device_ != "NNRT") { 6699be168c0dSopenharmony_ci MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported."; 6700be168c0dSopenharmony_ci std::cerr << "Device type:" << flags_->device_ << " is not supported." 
<< std::endl; 6701be168c0dSopenharmony_ci return RET_ERROR; 6702be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h 6703be168c0dSopenharmony_ciindex acdea21a..f818270c 100644 6704be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_base.h 6705be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_base.h 6706be168c0dSopenharmony_ci@@ -122,7 +122,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { 6707be168c0dSopenharmony_ci AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 6708be168c0dSopenharmony_ci AddFlag(&BenchmarkFlags::group_info_file_, "GroupInfoFile", "Communication group info file", ""); 6709be168c0dSopenharmony_ci AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", ""); 6710be168c0dSopenharmony_ci- AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | Auto", "CPU"); 6711be168c0dSopenharmony_ci+ AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | NNRT | Auto", "CPU"); 6712be168c0dSopenharmony_ci AddFlag(&BenchmarkFlags::provider_, "provider", "device provider litert | tensorrt | mindrt", "litert"); 6713be168c0dSopenharmony_ci AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1); 6714be168c0dSopenharmony_ci // MarkPerformance 6715be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_c_api.cc b/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6716be168c0dSopenharmony_ciindex 252e65c6..cb0c56b0 100644 6717be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6718be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_c_api.cc 6719be168c0dSopenharmony_ci@@ -125,6 +125,10 @@ int BenchmarkCApi::InitContext() { 6720be168c0dSopenharmony_ci 
OH_AI_DeviceInfoSetFrequency(npu_device_info, kFrequencyDefault); 6721be168c0dSopenharmony_ci OH_AI_ContextAddDeviceInfo(context_, npu_device_info); 6722be168c0dSopenharmony_ci } 6723be168c0dSopenharmony_ci+ if (flags_->device_ == "NNRT") { 6724be168c0dSopenharmony_ci+ OH_AI_DeviceInfoHandle nnrt_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 6725be168c0dSopenharmony_ci+ OH_AI_ContextAddDeviceInfo(context_, nnrt_device_info); 6726be168c0dSopenharmony_ci+ } 6727be168c0dSopenharmony_ci OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 6728be168c0dSopenharmony_ci OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); 6729be168c0dSopenharmony_ci OH_AI_ContextAddDeviceInfo(context_, cpu_device_info); 6730be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6731be168c0dSopenharmony_ciindex bb36c168..c18111b6 100644 6732be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6733be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc 6734be168c0dSopenharmony_ci@@ -521,6 +521,11 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context> 6735be168c0dSopenharmony_ci // InitMSContextForAscend(context, &device_list); 6736be168c0dSopenharmony_ci } 6737be168c0dSopenharmony_ci 6738be168c0dSopenharmony_ci+ if (flags_->device_ == "NNRT" || flags_->device_ == "Auto") { 6739be168c0dSopenharmony_ci+ std::shared_ptr<NNRTDeviceInfo> nnrt_device_info = std::make_shared<NNRTDeviceInfo>(); 6740be168c0dSopenharmony_ci+ device_list.push_back(nnrt_device_info); 6741be168c0dSopenharmony_ci+ } 6742be168c0dSopenharmony_ci+ 6743be168c0dSopenharmony_ci // CPU priority is behind GPU and NPU 6744be168c0dSopenharmony_ci std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>(); 6745be168c0dSopenharmony_ci 
device_info->SetEnableFP16(flags_->enable_fp16_); 6746be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6747be168c0dSopenharmony_ciindex 0c558524..1b9fc347 100644 6748be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6749be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt 6750be168c0dSopenharmony_ci@@ -9,6 +9,9 @@ set(COMMON_SRC 6751be168c0dSopenharmony_ci set(TEST_SRC 6752be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/main.cc 6753be168c0dSopenharmony_ci ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc 6754be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/net_train_base.cc 6755be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/run_net_train.cc 6756be168c0dSopenharmony_ci+ ${CMAKE_CURRENT_SOURCE_DIR}/net_train_c_api.cc 6757be168c0dSopenharmony_ci ) 6758be168c0dSopenharmony_ci 6759be168c0dSopenharmony_ci # add static securec link library 6760be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/main.cc b/mindspore/lite/tools/benchmark_train/main.cc 6761be168c0dSopenharmony_ciindex abf3d9dd..76f85aa7 100644 6762be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/main.cc 6763be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/main.cc 6764be168c0dSopenharmony_ci@@ -17,7 +17,8 @@ 6765be168c0dSopenharmony_ci #include <malloc.h> 6766be168c0dSopenharmony_ci #include <unistd.h> 6767be168c0dSopenharmony_ci #include <fstream> 6768be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_train.h" 6769be168c0dSopenharmony_ci+#include <iostream> 6770be168c0dSopenharmony_ci+#include "tools/benchmark_train/run_net_train.h" 6771be168c0dSopenharmony_ci 6772be168c0dSopenharmony_ci void PrintMem() { 6773be168c0dSopenharmony_ci std::string proc_file = "/proc/" + std::to_string(getpid()) + "/status"; 6774be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc 6775be168c0dSopenharmony_ciindex 9b63d29f..edf3e964 100644 6776be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_runner.cc 6777be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_runner.cc 6778be168c0dSopenharmony_ci@@ -15,7 +15,7 @@ 6779be168c0dSopenharmony_ci */ 6780be168c0dSopenharmony_ci 6781be168c0dSopenharmony_ci #include "tools/benchmark_train/net_runner.h" 6782be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_train.h" 6783be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h" 6784be168c0dSopenharmony_ci #include <getopt.h> 6785be168c0dSopenharmony_ci #include <malloc.h> 6786be168c0dSopenharmony_ci #include <cmath> 6787be168c0dSopenharmony_ci@@ -187,7 +187,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) { 6788be168c0dSopenharmony_ci auto output = tensor.Data(); 6789be168c0dSopenharmony_ci size_t size; 6790be168c0dSopenharmony_ci std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; 6791be168c0dSopenharmony_ci- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size)); 6792be168c0dSopenharmony_ci+ auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(output_file.c_str(), &size)); 6793be168c0dSopenharmony_ci if (bin_buf == nullptr) { 6794be168c0dSopenharmony_ci MS_LOG(ERROR) << "ReadFile return nullptr"; 6795be168c0dSopenharmony_ci std::cout << "ReadFile return nullptr" << std::endl; 6796be168c0dSopenharmony_ci@@ -200,7 +200,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) { 6797be168c0dSopenharmony_ci << ", read size: " << size << std::endl; 6798be168c0dSopenharmony_ci return mindspore::kLiteError; 6799be168c0dSopenharmony_ci } 6800be168c0dSopenharmony_ci- float bias = 
mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(), 6801be168c0dSopenharmony_ci+ float bias = mindspore::lite::NetTrainBase::CompareData<float>(bin_buf.get(), tensor.ElementNum(), 6802be168c0dSopenharmony_ci reinterpret_cast<const float *>(output.get())); 6803be168c0dSopenharmony_ci if (bias >= 0) { 6804be168c0dSopenharmony_ci total_bias += bias; 6805be168c0dSopenharmony_ci@@ -332,7 +332,7 @@ int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) { 6806be168c0dSopenharmony_ci } 6807be168c0dSopenharmony_ci size_t size; 6808be168c0dSopenharmony_ci std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; 6809be168c0dSopenharmony_ci- auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size)); 6810be168c0dSopenharmony_ci+ auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(file_name.c_str(), &size)); 6811be168c0dSopenharmony_ci if (bin_buf == nullptr) { 6812be168c0dSopenharmony_ci MS_LOG(ERROR) << "ReadFile return nullptr"; 6813be168c0dSopenharmony_ci std::cout << "ReadFile return nullptr" << std::endl; 6814be168c0dSopenharmony_ci@@ -368,4 +368,4 @@ int CallBack(mindspore::lite::NetTrainFlags *flags) { 6815be168c0dSopenharmony_ci return nr.Main(); 6816be168c0dSopenharmony_ci } 6817be168c0dSopenharmony_ci 6818be168c0dSopenharmony_ci-int init = mindspore::lite::NetTrain::SetNr(CallBack); 6819be168c0dSopenharmony_ci+int init = mindspore::lite::NetTrainBase::SetNr(CallBack); 6820be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc 6821be168c0dSopenharmony_ciindex d1150043..514bba53 100644 6822be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_train.cc 6823be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train.cc 6824be168c0dSopenharmony_ci@@ -31,74 +31,11 @@ 6825be168c0dSopenharmony_ci 
6826be168c0dSopenharmony_ci namespace mindspore { 6827be168c0dSopenharmony_ci namespace lite { 6828be168c0dSopenharmony_ci-static const char *DELIM_SLASH = "/"; 6829be168c0dSopenharmony_ci-constexpr const char *DELIM_COLON = ":"; 6830be168c0dSopenharmony_ci-constexpr const char *DELIM_COMMA = ","; 6831be168c0dSopenharmony_ci-constexpr int RET_TOO_BIG = -9; 6832be168c0dSopenharmony_ci constexpr int kField0 = 0; 6833be168c0dSopenharmony_ci constexpr int kField1 = 1; 6834be168c0dSopenharmony_ci constexpr int kField2 = 2; 6835be168c0dSopenharmony_ci constexpr int kField3 = 3; 6836be168c0dSopenharmony_ci constexpr int kField4 = 4; 6837be168c0dSopenharmony_ci-constexpr int kFieldsToPrint = 5; 6838be168c0dSopenharmony_ci-constexpr int kPrintOffset = 4; 6839be168c0dSopenharmony_ci-static const int kTHOUSAND = 1000; 6840be168c0dSopenharmony_ci-constexpr int kDumpInputsAndOutputs = 0; 6841be168c0dSopenharmony_ci-constexpr int kDumpOutputs = 2; 6842be168c0dSopenharmony_ci- 6843be168c0dSopenharmony_ci-const std::unordered_map<int, std::string> kTypeIdMap{ 6844be168c0dSopenharmony_ci- {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"}, 6845be168c0dSopenharmony_ci- {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"}, 6846be168c0dSopenharmony_ci- {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"}, 6847be168c0dSopenharmony_ci- {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"}, 6848be168c0dSopenharmony_ci- {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}}; 6849be168c0dSopenharmony_ci- 6850be168c0dSopenharmony_ci-const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{ 6851be168c0dSopenharmony_ci- {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"}, {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"}, 6852be168c0dSopenharmony_ci- {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"}, {mindspore::CKHW, 
"CKHW"}, {mindspore::KHWC, "KHWC"}, 6853be168c0dSopenharmony_ci- {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"}, {mindspore::HW4, "HW4"}, {mindspore::NC, "NC"}, 6854be168c0dSopenharmony_ci- {mindspore::NC4, "NC4"}, {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}}; 6855be168c0dSopenharmony_ci- 6856be168c0dSopenharmony_ci-std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr; 6857be168c0dSopenharmony_ci- 6858be168c0dSopenharmony_ci-int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) { 6859be168c0dSopenharmony_ci- nr_cb_ = param; 6860be168c0dSopenharmony_ci- return 0; 6861be168c0dSopenharmony_ci-} 6862be168c0dSopenharmony_ci- 6863be168c0dSopenharmony_ci-float *NetTrain::ReadFileBuf(const std::string file, size_t *size) { 6864be168c0dSopenharmony_ci- if (file.empty()) { 6865be168c0dSopenharmony_ci- MS_LOG(ERROR) << "file is nullptr"; 6866be168c0dSopenharmony_ci- return nullptr; 6867be168c0dSopenharmony_ci- } 6868be168c0dSopenharmony_ci- MS_ASSERT(size != nullptr); 6869be168c0dSopenharmony_ci- std::string real_path = RealPath(file.c_str()); 6870be168c0dSopenharmony_ci- std::ifstream ifs(real_path); 6871be168c0dSopenharmony_ci- if (!ifs.good()) { 6872be168c0dSopenharmony_ci- MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 6873be168c0dSopenharmony_ci- return nullptr; 6874be168c0dSopenharmony_ci- } 6875be168c0dSopenharmony_ci- 6876be168c0dSopenharmony_ci- if (!ifs.is_open()) { 6877be168c0dSopenharmony_ci- MS_LOG(ERROR) << "file: " << real_path << " open failed"; 6878be168c0dSopenharmony_ci- return nullptr; 6879be168c0dSopenharmony_ci- } 6880be168c0dSopenharmony_ci- 6881be168c0dSopenharmony_ci- ifs.seekg(0, std::ios::end); 6882be168c0dSopenharmony_ci- *size = ifs.tellg(); 6883be168c0dSopenharmony_ci- std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1); 6884be168c0dSopenharmony_ci- if (buf == nullptr) { 6885be168c0dSopenharmony_ci- MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; 
6886be168c0dSopenharmony_ci- ifs.close(); 6887be168c0dSopenharmony_ci- return nullptr; 6888be168c0dSopenharmony_ci- } 6889be168c0dSopenharmony_ci- 6890be168c0dSopenharmony_ci- ifs.seekg(0, std::ios::beg); 6891be168c0dSopenharmony_ci- ifs.read(reinterpret_cast<char *>(buf.get()), *size); 6892be168c0dSopenharmony_ci- ifs.close(); 6893be168c0dSopenharmony_ci- 6894be168c0dSopenharmony_ci- return buf.release(); 6895be168c0dSopenharmony_ci-} 6896be168c0dSopenharmony_ci 6897be168c0dSopenharmony_ci int NetTrain::GenerateInputData() { 6898be168c0dSopenharmony_ci for (auto tensor : ms_inputs_for_api_) { 6899be168c0dSopenharmony_ci@@ -120,28 +57,6 @@ int NetTrain::GenerateInputData() { 6900be168c0dSopenharmony_ci return RET_OK; 6901be168c0dSopenharmony_ci } 6902be168c0dSopenharmony_ci 6903be168c0dSopenharmony_ci-int NetTrain::LoadInput() { 6904be168c0dSopenharmony_ci- inputs_buf_.clear(); 6905be168c0dSopenharmony_ci- inputs_size_.clear(); 6906be168c0dSopenharmony_ci- batch_num_ = 0; 6907be168c0dSopenharmony_ci- if (flags_->in_data_file_.empty()) { 6908be168c0dSopenharmony_ci- auto status = GenerateInputData(); 6909be168c0dSopenharmony_ci- if (status != RET_OK) { 6910be168c0dSopenharmony_ci- std::cerr << "Generate input data error " << status << std::endl; 6911be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Generate input data error " << status; 6912be168c0dSopenharmony_ci- return status; 6913be168c0dSopenharmony_ci- } 6914be168c0dSopenharmony_ci- } else { 6915be168c0dSopenharmony_ci- auto status = ReadInputFile(); 6916be168c0dSopenharmony_ci- if (status != RET_OK) { 6917be168c0dSopenharmony_ci- std::cerr << "Read Input File error, " << status << std::endl; 6918be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Read Input File error, " << status; 6919be168c0dSopenharmony_ci- return status; 6920be168c0dSopenharmony_ci- } 6921be168c0dSopenharmony_ci- } 6922be168c0dSopenharmony_ci- return RET_OK; 6923be168c0dSopenharmony_ci-} 6924be168c0dSopenharmony_ci- 6925be168c0dSopenharmony_ci int 
NetTrain::LoadStepInput(size_t step) { 6926be168c0dSopenharmony_ci if (step >= batch_num_) { 6927be168c0dSopenharmony_ci auto cur_batch = step + 1; 6928be168c0dSopenharmony_ci@@ -269,30 +184,6 @@ int NetTrain::CompareOutput() { 6929be168c0dSopenharmony_ci } 6930be168c0dSopenharmony_ci } 6931be168c0dSopenharmony_ci 6932be168c0dSopenharmony_ci-std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 6933be168c0dSopenharmony_ci- const std::string &file_type, const size_t &idx) { 6934be168c0dSopenharmony_ci- std::string file_name = op_name; 6935be168c0dSopenharmony_ci- auto pos = file_name.find_first_of('/'); 6936be168c0dSopenharmony_ci- while (pos != std::string::npos) { 6937be168c0dSopenharmony_ci- file_name.replace(pos, 1, "."); 6938be168c0dSopenharmony_ci- pos = file_name.find_first_of('/'); 6939be168c0dSopenharmony_ci- } 6940be168c0dSopenharmony_ci- file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_"; 6941be168c0dSopenharmony_ci- for (const auto &dim : tensor->Shape()) { 6942be168c0dSopenharmony_ci- file_name += std::to_string(dim) + "_"; 6943be168c0dSopenharmony_ci- } 6944be168c0dSopenharmony_ci- if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) { 6945be168c0dSopenharmony_ci- file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType())); 6946be168c0dSopenharmony_ci- } 6947be168c0dSopenharmony_ci- auto tensor_format = tensor->format(); 6948be168c0dSopenharmony_ci- if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) { 6949be168c0dSopenharmony_ci- file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin"; 6950be168c0dSopenharmony_ci- } 6951be168c0dSopenharmony_ci- 6952be168c0dSopenharmony_ci- file_name += ".bin"; 6953be168c0dSopenharmony_ci- return file_name; 6954be168c0dSopenharmony_ci-} 6955be168c0dSopenharmony_ci- 6956be168c0dSopenharmony_ci int NetTrain::MarkPerformance() { 6957be168c0dSopenharmony_ci MS_LOG(INFO) << "Running train loops..."; 
6958be168c0dSopenharmony_ci std::cout << "Running train loops..." << std::endl; 6959be168c0dSopenharmony_ci@@ -574,26 +465,6 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string 6960be168c0dSopenharmony_ci return RET_OK; 6961be168c0dSopenharmony_ci } 6962be168c0dSopenharmony_ci 6963be168c0dSopenharmony_ci-int NetTrain::RunNetTrain() { 6964be168c0dSopenharmony_ci- auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1); 6965be168c0dSopenharmony_ci- bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty(); 6966be168c0dSopenharmony_ci- auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_); 6967be168c0dSopenharmony_ci- if (status != RET_OK) { 6968be168c0dSopenharmony_ci- MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status; 6969be168c0dSopenharmony_ci- std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". 
Status is " << status 6970be168c0dSopenharmony_ci- << std::endl; 6971be168c0dSopenharmony_ci- return status; 6972be168c0dSopenharmony_ci- } 6973be168c0dSopenharmony_ci- 6974be168c0dSopenharmony_ci- status = CheckExecutionOfSavedModels(); // re-initialize sessions according to flags 6975be168c0dSopenharmony_ci- if (status != RET_OK) { 6976be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Run CheckExecute error: " << status; 6977be168c0dSopenharmony_ci- std::cout << "Run CheckExecute error: " << status << std::endl; 6978be168c0dSopenharmony_ci- return status; 6979be168c0dSopenharmony_ci- } 6980be168c0dSopenharmony_ci- return RET_OK; 6981be168c0dSopenharmony_ci-} 6982be168c0dSopenharmony_ci- 6983be168c0dSopenharmony_ci int NetTrain::SaveModels() { 6984be168c0dSopenharmony_ci if (!flags_->export_file_.empty()) { 6985be168c0dSopenharmony_ci if (flags_->bb_model_file_.empty()) { 6986be168c0dSopenharmony_ci@@ -635,77 +506,6 @@ int NetTrain::SaveModels() { 6987be168c0dSopenharmony_ci return RET_OK; 6988be168c0dSopenharmony_ci } 6989be168c0dSopenharmony_ci 6990be168c0dSopenharmony_ci-int NetTrain::CheckExecutionOfSavedModels() { 6991be168c0dSopenharmony_ci- int status = RET_OK; 6992be168c0dSopenharmony_ci- if (!flags_->export_file_.empty()) { 6993be168c0dSopenharmony_ci- status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0); 6994be168c0dSopenharmony_ci- if (status != RET_OK) { 6995be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status; 6996be168c0dSopenharmony_ci- std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl; 6997be168c0dSopenharmony_ci- return status; 6998be168c0dSopenharmony_ci- } 6999be168c0dSopenharmony_ci- if (flags_->bb_model_file_.empty()) { 7000be168c0dSopenharmony_ci- status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false); 7001be168c0dSopenharmony_ci- if (status != RET_OK) { 
7002be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status; 7003be168c0dSopenharmony_ci- std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl; 7004be168c0dSopenharmony_ci- return status; 7005be168c0dSopenharmony_ci- } 7006be168c0dSopenharmony_ci- } 7007be168c0dSopenharmony_ci- } 7008be168c0dSopenharmony_ci- if (!flags_->inference_file_.empty()) { 7009be168c0dSopenharmony_ci- status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0); 7010be168c0dSopenharmony_ci- if (status != RET_OK) { 7011be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status; 7012be168c0dSopenharmony_ci- std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl; 7013be168c0dSopenharmony_ci- return status; 7014be168c0dSopenharmony_ci- } 7015be168c0dSopenharmony_ci- status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false); 7016be168c0dSopenharmony_ci- if (status != RET_OK) { 7017be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status; 7018be168c0dSopenharmony_ci- std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl; 7019be168c0dSopenharmony_ci- return status; 7020be168c0dSopenharmony_ci- } 7021be168c0dSopenharmony_ci- } 7022be168c0dSopenharmony_ci- return status; 7023be168c0dSopenharmony_ci-} 7024be168c0dSopenharmony_ci- 7025be168c0dSopenharmony_ci-void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) { 7026be168c0dSopenharmony_ci- if (tensor == nullptr) { 7027be168c0dSopenharmony_ci- MS_LOG(ERROR) << "input tensor is nullptr."; 7028be168c0dSopenharmony_ci- return; 7029be168c0dSopenharmony_ci- } 7030be168c0dSopenharmony_ci- int tensor_size = 
tensor->ElementNum(); 7031be168c0dSopenharmony_ci- void *data = tensor->MutableData(); 7032be168c0dSopenharmony_ci- auto *fdata = reinterpret_cast<float *>(tensor->MutableData()); 7033be168c0dSopenharmony_ci- auto type = tensor->DataType(); 7034be168c0dSopenharmony_ci- std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum="; 7035be168c0dSopenharmony_ci- switch (type) { 7036be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat32: 7037be168c0dSopenharmony_ci- TensorNan(reinterpret_cast<float *>(data), tensor_size); 7038be168c0dSopenharmony_ci- std::cout << TensorSum<float>(data, tensor_size) << std::endl; 7039be168c0dSopenharmony_ci- std::cout << "tensor name: " << tensor->Name() << std::endl; 7040be168c0dSopenharmony_ci- std::cout << "data: "; 7041be168c0dSopenharmony_ci- for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) { 7042be168c0dSopenharmony_ci- std::cout << static_cast<float>(fdata[i]) << ", "; 7043be168c0dSopenharmony_ci- } 7044be168c0dSopenharmony_ci- std::cout << std::endl; 7045be168c0dSopenharmony_ci- break; 7046be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeInt32: 7047be168c0dSopenharmony_ci- std::cout << TensorSum<int>(data, tensor_size) << std::endl; 7048be168c0dSopenharmony_ci- break; 7049be168c0dSopenharmony_ci-#ifdef ENABLE_FP16 7050be168c0dSopenharmony_ci- case mindspore::DataType::kNumberTypeFloat16: 7051be168c0dSopenharmony_ci- std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl; 7052be168c0dSopenharmony_ci- TensorNan(reinterpret_cast<float16_t *>(data), tensor_size); 7053be168c0dSopenharmony_ci- break; 7054be168c0dSopenharmony_ci-#endif 7055be168c0dSopenharmony_ci- default: 7056be168c0dSopenharmony_ci- std::cout << "unsupported type:" << static_cast<int>(type) << std::endl; 7057be168c0dSopenharmony_ci- break; 7058be168c0dSopenharmony_ci- } 7059be168c0dSopenharmony_ci-} 7060be168c0dSopenharmony_ci- 7061be168c0dSopenharmony_ci int 
NetTrain::InitDumpTensorDataCallbackParameter() { 7062be168c0dSopenharmony_ci // before callback 7063be168c0dSopenharmony_ci before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs, 7064be168c0dSopenharmony_ci@@ -815,178 +615,6 @@ int NetTrain::InitTimeProfilingCallbackParameter() { 7065be168c0dSopenharmony_ci return RET_OK; 7066be168c0dSopenharmony_ci } 7067be168c0dSopenharmony_ci 7068be168c0dSopenharmony_ci-int NetTrain::InitCallbackParameter() { 7069be168c0dSopenharmony_ci- int ret = RET_OK; 7070be168c0dSopenharmony_ci- if (flags_->dump_tensor_data_) { 7071be168c0dSopenharmony_ci- ret = InitDumpTensorDataCallbackParameter(); 7072be168c0dSopenharmony_ci- } else if (flags_->time_profiling_) { 7073be168c0dSopenharmony_ci- ret = InitTimeProfilingCallbackParameter(); 7074be168c0dSopenharmony_ci- } 7075be168c0dSopenharmony_ci- return ret; 7076be168c0dSopenharmony_ci-} 7077be168c0dSopenharmony_ci- 7078be168c0dSopenharmony_ci-void NetTrainFlags::InitResizeDimsList() { 7079be168c0dSopenharmony_ci- std::string content = this->resize_dims_in_; 7080be168c0dSopenharmony_ci- std::vector<int> shape; 7081be168c0dSopenharmony_ci- auto shape_strs = StrSplit(content, std::string(DELIM_COLON)); 7082be168c0dSopenharmony_ci- for (const auto &shape_str : shape_strs) { 7083be168c0dSopenharmony_ci- shape.clear(); 7084be168c0dSopenharmony_ci- auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA)); 7085be168c0dSopenharmony_ci- std::cout << "Resize Dims: "; 7086be168c0dSopenharmony_ci- for (const auto &dim_str : dim_strs) { 7087be168c0dSopenharmony_ci- std::cout << dim_str << " "; 7088be168c0dSopenharmony_ci- shape.emplace_back(static_cast<int>(std::stoi(dim_str))); 7089be168c0dSopenharmony_ci- } 7090be168c0dSopenharmony_ci- std::cout << std::endl; 7091be168c0dSopenharmony_ci- this->resize_dims_.emplace_back(shape); 7092be168c0dSopenharmony_ci- } 7093be168c0dSopenharmony_ci-} 7094be168c0dSopenharmony_ci- 7095be168c0dSopenharmony_ci-int NetTrain::Init() { 
7096be168c0dSopenharmony_ci- if (this->flags_ == nullptr) { 7097be168c0dSopenharmony_ci- return 1; 7098be168c0dSopenharmony_ci- } 7099be168c0dSopenharmony_ci- MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_; 7100be168c0dSopenharmony_ci- MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_; 7101be168c0dSopenharmony_ci- MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_; 7102be168c0dSopenharmony_ci- MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_; 7103be168c0dSopenharmony_ci- MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_; 7104be168c0dSopenharmony_ci- MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_; 7105be168c0dSopenharmony_ci- MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; 7106be168c0dSopenharmony_ci- MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; 7107be168c0dSopenharmony_ci- MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; 7108be168c0dSopenharmony_ci- MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; 7109be168c0dSopenharmony_ci- MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_; 7110be168c0dSopenharmony_ci- 7111be168c0dSopenharmony_ci- if (this->flags_->epochs_ < 0) { 7112be168c0dSopenharmony_ci- MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; 7113be168c0dSopenharmony_ci- std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl; 7114be168c0dSopenharmony_ci- return RET_ERROR; 7115be168c0dSopenharmony_ci- } 7116be168c0dSopenharmony_ci- 7117be168c0dSopenharmony_ci- if (this->flags_->num_threads_ < 1) { 7118be168c0dSopenharmony_ci- MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0"; 7119be168c0dSopenharmony_ci- std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl; 7120be168c0dSopenharmony_ci- return RET_ERROR; 
7121be168c0dSopenharmony_ci- } 7122be168c0dSopenharmony_ci- 7123be168c0dSopenharmony_ci- this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary; 7124be168c0dSopenharmony_ci- 7125be168c0dSopenharmony_ci- if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) { 7126be168c0dSopenharmony_ci- MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided"; 7127be168c0dSopenharmony_ci- std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl; 7128be168c0dSopenharmony_ci- return RET_ERROR; 7129be168c0dSopenharmony_ci- } 7130be168c0dSopenharmony_ci- 7131be168c0dSopenharmony_ci- if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) { 7132be168c0dSopenharmony_ci- MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided"; 7133be168c0dSopenharmony_ci- std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl; 7134be168c0dSopenharmony_ci- return RET_ERROR; 7135be168c0dSopenharmony_ci- } 7136be168c0dSopenharmony_ci- 7137be168c0dSopenharmony_ci- if (flags_->model_file_.empty()) { 7138be168c0dSopenharmony_ci- MS_LOG(ERROR) << "modelPath is required"; 7139be168c0dSopenharmony_ci- std::cerr << "modelPath is required" << std::endl; 7140be168c0dSopenharmony_ci- return 1; 7141be168c0dSopenharmony_ci- } 7142be168c0dSopenharmony_ci- 7143be168c0dSopenharmony_ci- // get dump data output path 7144be168c0dSopenharmony_ci- auto dump_cfg_path = std::getenv(dump::kConfigPath); 7145be168c0dSopenharmony_ci- if (dump_cfg_path != nullptr) { 7146be168c0dSopenharmony_ci- flags_->dump_tensor_data_ = true; 7147be168c0dSopenharmony_ci- if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) { 7148be168c0dSopenharmony_ci- MS_LOG(ERROR) << "parse dump config file failed."; 7149be168c0dSopenharmony_ci- return RET_ERROR; 7150be168c0dSopenharmony_ci- } 7151be168c0dSopenharmony_ci- } else { 
7152be168c0dSopenharmony_ci- MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data"; 7153be168c0dSopenharmony_ci- } 7154be168c0dSopenharmony_ci- 7155be168c0dSopenharmony_ci- auto status = InitCallbackParameter(); 7156be168c0dSopenharmony_ci- if (status != RET_OK) { 7157be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Init callback Parameter failed."; 7158be168c0dSopenharmony_ci- std::cerr << "Init callback Parameter failed." << std::endl; 7159be168c0dSopenharmony_ci- return RET_ERROR; 7160be168c0dSopenharmony_ci- } 7161be168c0dSopenharmony_ci- 7162be168c0dSopenharmony_ci- flags_->InitResizeDimsList(); 7163be168c0dSopenharmony_ci- if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && 7164be168c0dSopenharmony_ci- flags_->resize_dims_.size() != flags_->input_data_list_.size()) { 7165be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; 7166be168c0dSopenharmony_ci- std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; 7167be168c0dSopenharmony_ci- return RET_ERROR; 7168be168c0dSopenharmony_ci- } 7169be168c0dSopenharmony_ci- return RET_OK; 7170be168c0dSopenharmony_ci-} 7171be168c0dSopenharmony_ci- 7172be168c0dSopenharmony_ci-namespace { 7173be168c0dSopenharmony_ci-constexpr int kNumToPrint = 5; 7174be168c0dSopenharmony_ci-} 7175be168c0dSopenharmony_ci- 7176be168c0dSopenharmony_ci-int NetTrain::InitDumpConfigFromJson(std::string path) { 7177be168c0dSopenharmony_ci- auto real_path = RealPath(path.c_str()); 7178be168c0dSopenharmony_ci- std::ifstream ifs(real_path); 7179be168c0dSopenharmony_ci- if (!ifs.good()) { 7180be168c0dSopenharmony_ci- MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7181be168c0dSopenharmony_ci- return RET_ERROR; 7182be168c0dSopenharmony_ci- } 7183be168c0dSopenharmony_ci- if (!ifs.is_open()) { 7184be168c0dSopenharmony_ci- MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7185be168c0dSopenharmony_ci- 
return RET_ERROR; 7186be168c0dSopenharmony_ci- } 7187be168c0dSopenharmony_ci- 7188be168c0dSopenharmony_ci- try { 7189be168c0dSopenharmony_ci- dump_cfg_json_ = nlohmann::json::parse(ifs); 7190be168c0dSopenharmony_ci- } catch (const nlohmann::json::parse_error &error) { 7191be168c0dSopenharmony_ci- MS_LOG(ERROR) << "parse json file failed, please check your file."; 7192be168c0dSopenharmony_ci- return RET_ERROR; 7193be168c0dSopenharmony_ci- } 7194be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings] == nullptr) { 7195be168c0dSopenharmony_ci- MS_LOG(ERROR) << "\"common_dump_settings\" is required."; 7196be168c0dSopenharmony_ci- return RET_ERROR; 7197be168c0dSopenharmony_ci- } 7198be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) { 7199be168c0dSopenharmony_ci- MS_LOG(ERROR) << "\"dump_mode\" is required."; 7200be168c0dSopenharmony_ci- return RET_ERROR; 7201be168c0dSopenharmony_ci- } 7202be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) { 7203be168c0dSopenharmony_ci- MS_LOG(ERROR) << "\"path\" is required."; 7204be168c0dSopenharmony_ci- return RET_ERROR; 7205be168c0dSopenharmony_ci- } 7206be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) { 7207be168c0dSopenharmony_ci- dump_cfg_json_[dump::kSettings][dump::kNetName] = "default"; 7208be168c0dSopenharmony_ci- } 7209be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) { 7210be168c0dSopenharmony_ci- dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0; 7211be168c0dSopenharmony_ci- } 7212be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr && 7213be168c0dSopenharmony_ci- !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) { 7214be168c0dSopenharmony_ci- if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) { 7215be168c0dSopenharmony_ci- MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)"; 
7216be168c0dSopenharmony_ci- return RET_ERROR; 7217be168c0dSopenharmony_ci- } 7218be168c0dSopenharmony_ci- } 7219be168c0dSopenharmony_ci- 7220be168c0dSopenharmony_ci- auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>(); 7221be168c0dSopenharmony_ci- auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>(); 7222be168c0dSopenharmony_ci- if (abs_path.back() == '\\' || abs_path.back() == '/') { 7223be168c0dSopenharmony_ci- dump_file_output_dir_ = abs_path + net_name; 7224be168c0dSopenharmony_ci- } else { 7225be168c0dSopenharmony_ci-#ifdef _WIN32 7226be168c0dSopenharmony_ci- dump_file_output_dir_ = abs_path + "\\" + net_name; 7227be168c0dSopenharmony_ci-#else 7228be168c0dSopenharmony_ci- dump_file_output_dir_ = abs_path + "/" + net_name; 7229be168c0dSopenharmony_ci-#endif 7230be168c0dSopenharmony_ci- } 7231be168c0dSopenharmony_ci- 7232be168c0dSopenharmony_ci- auto status = CreateOutputDir(&dump_file_output_dir_); 7233be168c0dSopenharmony_ci- if (status != RET_OK) { 7234be168c0dSopenharmony_ci- MS_LOG(ERROR) << "create data output directory failed."; 7235be168c0dSopenharmony_ci- return RET_ERROR; 7236be168c0dSopenharmony_ci- } 7237be168c0dSopenharmony_ci- return RET_OK; 7238be168c0dSopenharmony_ci-} 7239be168c0dSopenharmony_ci- 7240be168c0dSopenharmony_ci int NetTrain::PrintResult(const std::vector<std::string> &title, 7241be168c0dSopenharmony_ci const std::map<std::string, std::pair<int, float>> &result) { 7242be168c0dSopenharmony_ci std::vector<size_t> columnLenMax(kFieldsToPrint); 7243be168c0dSopenharmony_ci@@ -1035,7 +663,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7244be168c0dSopenharmony_ci } 7245be168c0dSopenharmony_ci 7246be168c0dSopenharmony_ci printf("-------------------------------------------------------------------------\n"); 7247be168c0dSopenharmony_ci- for (int i = 0; i < kNumToPrint; i++) { 7248be168c0dSopenharmony_ci+ for (int i = 0; i < kFieldsToPrint; i++) { 
7249be168c0dSopenharmony_ci auto printBuf = title[i]; 7250be168c0dSopenharmony_ci if (printBuf.size() > columnLenMax.at(i)) { 7251be168c0dSopenharmony_ci columnLenMax.at(i) = printBuf.size(); 7252be168c0dSopenharmony_ci@@ -1045,7 +673,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7253be168c0dSopenharmony_ci } 7254be168c0dSopenharmony_ci printf("\n"); 7255be168c0dSopenharmony_ci for (auto &row : rows) { 7256be168c0dSopenharmony_ci- for (int j = 0; j < kNumToPrint; j++) { 7257be168c0dSopenharmony_ci+ for (int j = 0; j < kFieldsToPrint; j++) { 7258be168c0dSopenharmony_ci auto printBuf = row[j]; 7259be168c0dSopenharmony_ci printBuf.resize(columnLenMax.at(j), ' '); 7260be168c0dSopenharmony_ci printf("%s\t", printBuf.c_str()); 7261be168c0dSopenharmony_ci@@ -1054,47 +682,5 @@ int NetTrain::PrintResult(const std::vector<std::string> &title, 7262be168c0dSopenharmony_ci } 7263be168c0dSopenharmony_ci return RET_OK; 7264be168c0dSopenharmony_ci } 7265be168c0dSopenharmony_ci- 7266be168c0dSopenharmony_ci-int RunNetTrain(int argc, const char **argv) { 7267be168c0dSopenharmony_ci- NetTrainFlags flags; 7268be168c0dSopenharmony_ci- Option<std::string> err = flags.ParseFlags(argc, argv); 7269be168c0dSopenharmony_ci- 7270be168c0dSopenharmony_ci- if (err.IsSome()) { 7271be168c0dSopenharmony_ci- std::cerr << err.Get() << std::endl; 7272be168c0dSopenharmony_ci- std::cerr << flags.Usage() << std::endl; 7273be168c0dSopenharmony_ci- return RET_ERROR; 7274be168c0dSopenharmony_ci- } 7275be168c0dSopenharmony_ci- 7276be168c0dSopenharmony_ci- if (flags.help) { 7277be168c0dSopenharmony_ci- std::cerr << flags.Usage() << std::endl; 7278be168c0dSopenharmony_ci- return RET_OK; 7279be168c0dSopenharmony_ci- } 7280be168c0dSopenharmony_ci- if (flags.unified_api_) { 7281be168c0dSopenharmony_ci- return NetTrain::RunNr(&flags); 7282be168c0dSopenharmony_ci- } 7283be168c0dSopenharmony_ci- NetTrain net_trainer(&flags); 7284be168c0dSopenharmony_ci- auto status = net_trainer.Init(); 
7285be168c0dSopenharmony_ci- if (status != RET_OK) { 7286be168c0dSopenharmony_ci- MS_LOG(ERROR) << "NetTrain init Error : " << status; 7287be168c0dSopenharmony_ci- std::cerr << "NetTrain init Error : " << status << std::endl; 7288be168c0dSopenharmony_ci- return RET_ERROR; 7289be168c0dSopenharmony_ci- } 7290be168c0dSopenharmony_ci- 7291be168c0dSopenharmony_ci- status = net_trainer.RunNetTrain(); 7292be168c0dSopenharmony_ci- if (status != RET_OK) { 7293be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Run NetTrain " 7294be168c0dSopenharmony_ci- << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7295be168c0dSopenharmony_ci- << " Failed : " << status; 7296be168c0dSopenharmony_ci- std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7297be168c0dSopenharmony_ci- << " Failed : " << status << std::endl; 7298be168c0dSopenharmony_ci- return RET_ERROR; 7299be168c0dSopenharmony_ci- } 7300be168c0dSopenharmony_ci- 7301be168c0dSopenharmony_ci- MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7302be168c0dSopenharmony_ci- << " Success."; 7303be168c0dSopenharmony_ci- std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 7304be168c0dSopenharmony_ci- << " Success." 
<< std::endl; 7305be168c0dSopenharmony_ci- return RET_OK; 7306be168c0dSopenharmony_ci-} 7307be168c0dSopenharmony_ci } // namespace lite 7308be168c0dSopenharmony_ci } // namespace mindspore 7309be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h 7310be168c0dSopenharmony_ciindex 67e58a04..bdf0ec88 100644 7311be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_train.h 7312be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train.h 7313be168c0dSopenharmony_ci@@ -42,183 +42,22 @@ 7314be168c0dSopenharmony_ci #include "tools/common/flag_parser.h" 7315be168c0dSopenharmony_ci #include "src/common/file_utils.h" 7316be168c0dSopenharmony_ci #include "src/common/utils.h" 7317be168c0dSopenharmony_ci- 7318be168c0dSopenharmony_ci-#ifdef ENABLE_FP16 7319be168c0dSopenharmony_ci-static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) { 7320be168c0dSopenharmony_ci- volatile float16_t d = var; 7321be168c0dSopenharmony_ci- return d != d; 7322be168c0dSopenharmony_ci-} 7323be168c0dSopenharmony_ci-#endif 7324be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h" 7325be168c0dSopenharmony_ci 7326be168c0dSopenharmony_ci namespace mindspore::lite { 7327be168c0dSopenharmony_ci-enum MS_API DataType { kImage = 0, kBinary = 1 }; 7328be168c0dSopenharmony_ci- 7329be168c0dSopenharmony_ci-constexpr float relativeTolerance = 1e-5; 7330be168c0dSopenharmony_ci-constexpr float absoluteTolerance = 1e-8; 7331be168c0dSopenharmony_ci extern const std::unordered_map<int, std::string> kTypeIdMap; 7332be168c0dSopenharmony_ci extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; 7333be168c0dSopenharmony_ci 7334be168c0dSopenharmony_ci-namespace dump { 7335be168c0dSopenharmony_ci-constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG"; 7336be168c0dSopenharmony_ci-constexpr auto kSettings = "common_dump_settings"; 
7337be168c0dSopenharmony_ci-constexpr auto kMode = "dump_mode"; 7338be168c0dSopenharmony_ci-constexpr auto kPath = "path"; 7339be168c0dSopenharmony_ci-constexpr auto kNetName = "net_name"; 7340be168c0dSopenharmony_ci-constexpr auto kInputOutput = "input_output"; 7341be168c0dSopenharmony_ci-constexpr auto kKernels = "kernels"; 7342be168c0dSopenharmony_ci-} // namespace dump 7343be168c0dSopenharmony_ci- 7344be168c0dSopenharmony_ci-template <typename T> 7345be168c0dSopenharmony_ci-float TensorSum(const void *data, int size) { 7346be168c0dSopenharmony_ci- const T *typed_data = reinterpret_cast<const T *>(data); 7347be168c0dSopenharmony_ci- float sum = 0.f; 7348be168c0dSopenharmony_ci- for (int i = 0; i < size; i++) { 7349be168c0dSopenharmony_ci- sum += static_cast<float>(typed_data[i]); 7350be168c0dSopenharmony_ci- } 7351be168c0dSopenharmony_ci- return sum; 7352be168c0dSopenharmony_ci-} 7353be168c0dSopenharmony_ci- 7354be168c0dSopenharmony_ci-class MS_API NetTrainFlags : public virtual FlagParser { 7355be168c0dSopenharmony_ci+class MS_API NetTrain : public NetTrainBase { 7356be168c0dSopenharmony_ci public: 7357be168c0dSopenharmony_ci- NetTrainFlags() { 7358be168c0dSopenharmony_ci- // common 7359be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", ""); 7360be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", ""); 7361be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 7362be168c0dSopenharmony_ci- // MarkPerformance 7363be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); 7364be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); 7365be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); 7366be168c0dSopenharmony_ci- 
AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1); 7367be168c0dSopenharmony_ci- // MarkAccuracy 7368be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); 7369be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); 7370be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); 7371be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); 7372be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); 7373be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); 7374be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", ""); 7375be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); 7376be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", 7377be168c0dSopenharmony_ci- "Shape of input data, the format should be NHWC. e.g. 
1,32,32,32:1,1,32,32,1", ""); 7378be168c0dSopenharmony_ci- AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false); 7379be168c0dSopenharmony_ci- } 7380be168c0dSopenharmony_ci- 7381be168c0dSopenharmony_ci- ~NetTrainFlags() override = default; 7382be168c0dSopenharmony_ci- void InitResizeDimsList(); 7383be168c0dSopenharmony_ci+ explicit NetTrain(NetTrainFlags *flags) : NetTrainBase(flags) {} 7384be168c0dSopenharmony_ci+ virtual ~NetTrain() {} 7385be168c0dSopenharmony_ci 7386be168c0dSopenharmony_ci- public: 7387be168c0dSopenharmony_ci- // common 7388be168c0dSopenharmony_ci- std::string model_file_; 7389be168c0dSopenharmony_ci- std::string in_data_file_; 7390be168c0dSopenharmony_ci- std::string bb_model_file_; 7391be168c0dSopenharmony_ci- std::vector<std::string> input_data_list_; 7392be168c0dSopenharmony_ci- DataType in_data_type_; 7393be168c0dSopenharmony_ci- std::string in_data_type_in_ = "bin"; 7394be168c0dSopenharmony_ci- int cpu_bind_mode_ = 1; 7395be168c0dSopenharmony_ci- bool enable_fp16_ = false; 7396be168c0dSopenharmony_ci- bool virtual_batch_ = false; 7397be168c0dSopenharmony_ci- // MarkPerformance 7398be168c0dSopenharmony_ci- int num_threads_ = 1; 7399be168c0dSopenharmony_ci- int warm_up_loop_count_ = 0; 7400be168c0dSopenharmony_ci- bool time_profiling_; 7401be168c0dSopenharmony_ci- int epochs_ = 1; 7402be168c0dSopenharmony_ci- // MarkAccuracy 7403be168c0dSopenharmony_ci- std::string data_file_; 7404be168c0dSopenharmony_ci- std::string data_type_ = "FLOAT"; 7405be168c0dSopenharmony_ci- float accuracy_threshold_; 7406be168c0dSopenharmony_ci- // Resize 7407be168c0dSopenharmony_ci- std::string export_file_ = ""; 7408be168c0dSopenharmony_ci- std::string resize_dims_in_ = ""; 7409be168c0dSopenharmony_ci- bool layer_checksum_ = false; 7410be168c0dSopenharmony_ci- std::vector<std::vector<int>> resize_dims_; 7411be168c0dSopenharmony_ci- std::string loss_name_ = ""; 7412be168c0dSopenharmony_ci- std::string inference_file_ = ""; 
7413be168c0dSopenharmony_ci- bool unified_api_ = false; 7414be168c0dSopenharmony_ci- bool dump_tensor_data_ = false; 7415be168c0dSopenharmony_ci-}; 7416be168c0dSopenharmony_ci- 7417be168c0dSopenharmony_ci-class MS_API NetTrain { 7418be168c0dSopenharmony_ci- public: 7419be168c0dSopenharmony_ci- explicit NetTrain(NetTrainFlags *flags) : flags_(flags) {} 7420be168c0dSopenharmony_ci- virtual ~NetTrain() = default; 7421be168c0dSopenharmony_ci- 7422be168c0dSopenharmony_ci- int Init(); 7423be168c0dSopenharmony_ci- int RunNetTrain(); 7424be168c0dSopenharmony_ci- static float *ReadFileBuf(const std::string file, size_t *size); 7425be168c0dSopenharmony_ci- static int SetNr(std::function<int(NetTrainFlags *)> param); 7426be168c0dSopenharmony_ci- static int RunNr(NetTrainFlags *flags) { 7427be168c0dSopenharmony_ci- if (nr_cb_ != nullptr) { 7428be168c0dSopenharmony_ci- return nr_cb_(flags); 7429be168c0dSopenharmony_ci- } 7430be168c0dSopenharmony_ci- MS_LOG(WARNING) << "unified api was not tested"; 7431be168c0dSopenharmony_ci- std::cout << "unified api was not tested"; 7432be168c0dSopenharmony_ci- return RET_OK; 7433be168c0dSopenharmony_ci- } 7434be168c0dSopenharmony_ci- // tensorData need to be converter first 7435be168c0dSopenharmony_ci- template <typename T> 7436be168c0dSopenharmony_ci- static float CompareData(const float *refOutput, int size, const T *msTensorData) { 7437be168c0dSopenharmony_ci- size_t errorCount = 0; 7438be168c0dSopenharmony_ci- float meanError = 0; 7439be168c0dSopenharmony_ci- std::cout << "Out tensor size is: " << size << std::endl; 7440be168c0dSopenharmony_ci- std::cout << "Data of model output: "; 7441be168c0dSopenharmony_ci- for (int j = 0; j < std::min(50, size); j++) { 7442be168c0dSopenharmony_ci- std::cout << static_cast<float>(msTensorData[j]) << " "; 7443be168c0dSopenharmony_ci- } 7444be168c0dSopenharmony_ci- std::cout << std::endl; 7445be168c0dSopenharmony_ci- std::cout << "Data of Ref output : "; 7446be168c0dSopenharmony_ci- for (int j = 0; j < 
std::min(50, size); j++) { 7447be168c0dSopenharmony_ci- std::cout << refOutput[j] << " "; 7448be168c0dSopenharmony_ci- } 7449be168c0dSopenharmony_ci- std::cout << std::endl; 7450be168c0dSopenharmony_ci- for (int j = 0; j < size; j++) { 7451be168c0dSopenharmony_ci- if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { 7452be168c0dSopenharmony_ci- std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; 7453be168c0dSopenharmony_ci- MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; 7454be168c0dSopenharmony_ci- return RET_ERROR; 7455be168c0dSopenharmony_ci- } 7456be168c0dSopenharmony_ci- 7457be168c0dSopenharmony_ci- auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); 7458be168c0dSopenharmony_ci- auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]); 7459be168c0dSopenharmony_ci- if (absoluteError > tolerance) { 7460be168c0dSopenharmony_ci- if (fabs(refOutput[j]) == 0) { 7461be168c0dSopenharmony_ci- if (absoluteError > 1e-5) { 7462be168c0dSopenharmony_ci- meanError += absoluteError; 7463be168c0dSopenharmony_ci- errorCount++; 7464be168c0dSopenharmony_ci- } else { 7465be168c0dSopenharmony_ci- continue; 7466be168c0dSopenharmony_ci- } 7467be168c0dSopenharmony_ci- } else { 7468be168c0dSopenharmony_ci- // just assume that atol = rtol 7469be168c0dSopenharmony_ci- meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN); 7470be168c0dSopenharmony_ci- errorCount++; 7471be168c0dSopenharmony_ci- } 7472be168c0dSopenharmony_ci- } 7473be168c0dSopenharmony_ci- } 7474be168c0dSopenharmony_ci- std::cout << std::endl; 7475be168c0dSopenharmony_ci- if (meanError > 0.0f) { 7476be168c0dSopenharmony_ci- meanError /= errorCount; 7477be168c0dSopenharmony_ci- } 7478be168c0dSopenharmony_ci- 7479be168c0dSopenharmony_ci- if (meanError <= 0.0000001) { 7480be168c0dSopenharmony_ci- std::cout << "Mean bias of tensor: 0%" << std::endl; 7481be168c0dSopenharmony_ci- } else { 7482be168c0dSopenharmony_ci- 
std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; 7483be168c0dSopenharmony_ci- } 7484be168c0dSopenharmony_ci- return meanError; 7485be168c0dSopenharmony_ci- } 7486be168c0dSopenharmony_ci- int InitDumpConfigFromJson(std::string path); 7487be168c0dSopenharmony_ci- 7488be168c0dSopenharmony_ci- private: 7489be168c0dSopenharmony_ci- // call GenerateInputData or ReadInputFile to init inputTensors 7490be168c0dSopenharmony_ci- int LoadInput(); 7491be168c0dSopenharmony_ci- void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out); 7492be168c0dSopenharmony_ci+ protected: 7493be168c0dSopenharmony_ci // call GenerateRandomData to fill inputTensors 7494be168c0dSopenharmony_ci- int GenerateInputData(); 7495be168c0dSopenharmony_ci+ int GenerateInputData() override; 7496be168c0dSopenharmony_ci 7497be168c0dSopenharmony_ci- int GenerateRandomData(mindspore::MSTensor *tensor); 7498be168c0dSopenharmony_ci- 7499be168c0dSopenharmony_ci- int ReadInputFile(); 7500be168c0dSopenharmony_ci+ int ReadInputFile() override; 7501be168c0dSopenharmony_ci 7502be168c0dSopenharmony_ci int LoadStepInput(size_t step); 7503be168c0dSopenharmony_ci 7504be168c0dSopenharmony_ci@@ -227,20 +66,19 @@ class MS_API NetTrain { 7505be168c0dSopenharmony_ci void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg); 7506be168c0dSopenharmony_ci 7507be168c0dSopenharmony_ci int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs, 7508be168c0dSopenharmony_ci- bool check_accuracy = true); 7509be168c0dSopenharmony_ci+ bool check_accuracy = true) override; 7510be168c0dSopenharmony_ci 7511be168c0dSopenharmony_ci int CreateAndRunNetworkForInference(const std::string &filename, const std::shared_ptr<mindspore::Context> &context); 7512be168c0dSopenharmony_ci 7513be168c0dSopenharmony_ci int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 7514be168c0dSopenharmony_ci 
const std::shared_ptr<mindspore::Context> &context, 7515be168c0dSopenharmony_ci const std::shared_ptr<TrainCfg> &train_cfg, int epochs); 7516be168c0dSopenharmony_ci- int InitCallbackParameter(); 7517be168c0dSopenharmony_ci 7518be168c0dSopenharmony_ci- int InitDumpTensorDataCallbackParameter(); 7519be168c0dSopenharmony_ci+ int InitDumpTensorDataCallbackParameter() override; 7520be168c0dSopenharmony_ci 7521be168c0dSopenharmony_ci- int InitTimeProfilingCallbackParameter(); 7522be168c0dSopenharmony_ci+ int InitTimeProfilingCallbackParameter() override; 7523be168c0dSopenharmony_ci 7524be168c0dSopenharmony_ci- int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result); 7525be168c0dSopenharmony_ci+ int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override; 7526be168c0dSopenharmony_ci 7527be168c0dSopenharmony_ci template <typename T> 7528be168c0dSopenharmony_ci void PrintInputData(mindspore::MSTensor *input) { 7529be168c0dSopenharmony_ci@@ -256,39 +94,11 @@ class MS_API NetTrain { 7530be168c0dSopenharmony_ci std::cout << std::endl; 7531be168c0dSopenharmony_ci } 7532be168c0dSopenharmony_ci 7533be168c0dSopenharmony_ci- template <typename T> 7534be168c0dSopenharmony_ci- std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) { 7535be168c0dSopenharmony_ci- std::vector<int64_t> dims; 7536be168c0dSopenharmony_ci- for (auto shape : srcDims) { 7537be168c0dSopenharmony_ci- dims.push_back(static_cast<int64_t>(shape)); 7538be168c0dSopenharmony_ci- } 7539be168c0dSopenharmony_ci- return dims; 7540be168c0dSopenharmony_ci- } 7541be168c0dSopenharmony_ci- int MarkPerformance(); 7542be168c0dSopenharmony_ci- int MarkAccuracy(bool enforce_accuracy = true); 7543be168c0dSopenharmony_ci- int CompareOutput(); 7544be168c0dSopenharmony_ci- int SaveModels(); 7545be168c0dSopenharmony_ci- int CheckExecutionOfSavedModels(); 7546be168c0dSopenharmony_ci- void 
TensorNan(const float *data, int size) { 7547be168c0dSopenharmony_ci- for (int i = 0; i < size; i++) { 7548be168c0dSopenharmony_ci- if (std::isnan(data[i])) { 7549be168c0dSopenharmony_ci- std::cout << "nan value of index=" << i << ", " << data[i] << std::endl; 7550be168c0dSopenharmony_ci- break; 7551be168c0dSopenharmony_ci- } 7552be168c0dSopenharmony_ci- } 7553be168c0dSopenharmony_ci- } 7554be168c0dSopenharmony_ci-#ifdef ENABLE_FP16 7555be168c0dSopenharmony_ci- void TensorNan(float16_t *data, int size) { 7556be168c0dSopenharmony_ci- for (int i = 0; i < size; i++) { 7557be168c0dSopenharmony_ci- if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) { 7558be168c0dSopenharmony_ci- std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl; 7559be168c0dSopenharmony_ci- break; 7560be168c0dSopenharmony_ci- } 7561be168c0dSopenharmony_ci- } 7562be168c0dSopenharmony_ci- } 7563be168c0dSopenharmony_ci-#endif 7564be168c0dSopenharmony_ci- NetTrainFlags *flags_{nullptr}; 7565be168c0dSopenharmony_ci- static std::function<int(NetTrainFlags *)> nr_cb_; 7566be168c0dSopenharmony_ci+ int MarkPerformance() override; 7567be168c0dSopenharmony_ci+ int MarkAccuracy(bool enforce_accuracy = true) override; 7568be168c0dSopenharmony_ci+ int CompareOutput() override; 7569be168c0dSopenharmony_ci+ int SaveModels() override; 7570be168c0dSopenharmony_ci+ 7571be168c0dSopenharmony_ci // callback parameters 7572be168c0dSopenharmony_ci uint64_t op_begin_ = 0; 7573be168c0dSopenharmony_ci int op_call_times_total_ = 0; 7574be168c0dSopenharmony_ci@@ -301,13 +111,6 @@ class MS_API NetTrain { 7575be168c0dSopenharmony_ci 7576be168c0dSopenharmony_ci mindspore::MSKernelCallBack before_call_back_{nullptr}; 7577be168c0dSopenharmony_ci mindspore::MSKernelCallBack after_call_back_{nullptr}; 7578be168c0dSopenharmony_ci- nlohmann::json dump_cfg_json_; 7579be168c0dSopenharmony_ci- std::string dump_file_output_dir_; 7580be168c0dSopenharmony_ci- std::vector<std::shared_ptr<char>> inputs_buf_; 
7581be168c0dSopenharmony_ci- std::vector<size_t> inputs_size_; 7582be168c0dSopenharmony_ci- size_t batch_num_ = 0; 7583be168c0dSopenharmony_ci }; 7584be168c0dSopenharmony_ci- 7585be168c0dSopenharmony_ci-int MS_API RunNetTrain(int argc, const char **argv); 7586be168c0dSopenharmony_ci } // namespace mindspore::lite 7587be168c0dSopenharmony_ci #endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ 7588be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_base.cc b/mindspore/lite/tools/benchmark_train/net_train_base.cc 7589be168c0dSopenharmony_cinew file mode 100644 7590be168c0dSopenharmony_ciindex 00000000..8d3c75de 7591be168c0dSopenharmony_ci--- /dev/null 7592be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_base.cc 7593be168c0dSopenharmony_ci@@ -0,0 +1,410 @@ 7594be168c0dSopenharmony_ci+/** 7595be168c0dSopenharmony_ci+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 7596be168c0dSopenharmony_ci+ * 7597be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 7598be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 7599be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 7600be168c0dSopenharmony_ci+ * 7601be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 7602be168c0dSopenharmony_ci+ * 7603be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 7604be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 7605be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 7606be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 7607be168c0dSopenharmony_ci+ * limitations under the License. 
7608be168c0dSopenharmony_ci+ */ 7609be168c0dSopenharmony_ci+ 7610be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h" 7611be168c0dSopenharmony_ci+#define __STDC_FORMAT_MACROS 7612be168c0dSopenharmony_ci+#undef __STDC_FORMAT_MACROS 7613be168c0dSopenharmony_ci+#include <algorithm> 7614be168c0dSopenharmony_ci+#include <cstring> 7615be168c0dSopenharmony_ci+#ifdef ENABLE_NEON 7616be168c0dSopenharmony_ci+#include <arm_neon.h> 7617be168c0dSopenharmony_ci+#endif 7618be168c0dSopenharmony_ci+#include "src/common/common.h" 7619be168c0dSopenharmony_ci+#include "include/api/serialization.h" 7620be168c0dSopenharmony_ci+ 7621be168c0dSopenharmony_ci+namespace mindspore { 7622be168c0dSopenharmony_ci+namespace lite { 7623be168c0dSopenharmony_ci+const std::unordered_map<int, std::string> kTypeIdMap{ 7624be168c0dSopenharmony_ci+ {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"}, {kNumberTypeFloat32, "Float32"}, 7625be168c0dSopenharmony_ci+ {kNumberTypeInt8, "Int8"}, {kNumberTypeInt16, "Int16"}, {kNumberTypeInt, "Int32"}, 7626be168c0dSopenharmony_ci+ {kNumberTypeInt32, "Int32"}, {kNumberTypeUInt8, "UInt8"}, {kNumberTypeUInt16, "UInt16"}, 7627be168c0dSopenharmony_ci+ {kNumberTypeUInt, "UInt32"}, {kNumberTypeUInt32, "UInt32"}, {kObjectTypeString, "String"}, 7628be168c0dSopenharmony_ci+ {kNumberTypeBool, "Bool"}, {kObjectTypeTensorType, "Tensor"}}; 7629be168c0dSopenharmony_ci+ 7630be168c0dSopenharmony_ci+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{ 7631be168c0dSopenharmony_ci+ {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"}, {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"}, 7632be168c0dSopenharmony_ci+ {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"}, {mindspore::CKHW, "CKHW"}, {mindspore::KHWC, "KHWC"}, 7633be168c0dSopenharmony_ci+ {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"}, {mindspore::HW4, "HW4"}, {mindspore::NC, "NC"}, 7634be168c0dSopenharmony_ci+ {mindspore::NC4, "NC4"}, {mindspore::NC4HW4, 
"NC4HW4"}, {mindspore::NCDHW, "NCDHW"}}; 7635be168c0dSopenharmony_ci+ 7636be168c0dSopenharmony_ci+std::function<int(NetTrainFlags *)> NetTrainBase::nr_cb_ = nullptr; 7637be168c0dSopenharmony_ci+ 7638be168c0dSopenharmony_ci+int NetTrainBase::SetNr(std::function<int(NetTrainFlags *)> param) { 7639be168c0dSopenharmony_ci+ nr_cb_ = param; 7640be168c0dSopenharmony_ci+ return 0; 7641be168c0dSopenharmony_ci+} 7642be168c0dSopenharmony_ci+ 7643be168c0dSopenharmony_ci+float *NetTrainBase::ReadFileBuf(const std::string file, size_t *size) { 7644be168c0dSopenharmony_ci+ if (file.empty()) { 7645be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "file is nullptr"; 7646be168c0dSopenharmony_ci+ return nullptr; 7647be168c0dSopenharmony_ci+ } 7648be168c0dSopenharmony_ci+ MS_ASSERT(size != nullptr); 7649be168c0dSopenharmony_ci+ std::string real_path = RealPath(file.c_str()); 7650be168c0dSopenharmony_ci+ std::ifstream ifs(real_path); 7651be168c0dSopenharmony_ci+ if (!ifs.good()) { 7652be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7653be168c0dSopenharmony_ci+ return nullptr; 7654be168c0dSopenharmony_ci+ } 7655be168c0dSopenharmony_ci+ 7656be168c0dSopenharmony_ci+ if (!ifs.is_open()) { 7657be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7658be168c0dSopenharmony_ci+ return nullptr; 7659be168c0dSopenharmony_ci+ } 7660be168c0dSopenharmony_ci+ 7661be168c0dSopenharmony_ci+ ifs.seekg(0, std::ios::end); 7662be168c0dSopenharmony_ci+ *size = ifs.tellg(); 7663be168c0dSopenharmony_ci+ std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1); 7664be168c0dSopenharmony_ci+ if (buf == nullptr) { 7665be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "malloc buf failed, file: " << real_path; 7666be168c0dSopenharmony_ci+ ifs.close(); 7667be168c0dSopenharmony_ci+ return nullptr; 7668be168c0dSopenharmony_ci+ } 7669be168c0dSopenharmony_ci+ 7670be168c0dSopenharmony_ci+ ifs.seekg(0, std::ios::beg); 7671be168c0dSopenharmony_ci+ 
ifs.read(reinterpret_cast<char *>(buf.get()), *size); 7672be168c0dSopenharmony_ci+ ifs.close(); 7673be168c0dSopenharmony_ci+ 7674be168c0dSopenharmony_ci+ return buf.release(); 7675be168c0dSopenharmony_ci+} 7676be168c0dSopenharmony_ci+ 7677be168c0dSopenharmony_ci+int NetTrainBase::GenerateRandomData(mindspore::MSTensor *tensor) { 7678be168c0dSopenharmony_ci+ auto input_data = tensor->MutableData(); 7679be168c0dSopenharmony_ci+ if (input_data == nullptr) { 7680be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "MallocData for inTensor failed"; 7681be168c0dSopenharmony_ci+ return RET_ERROR; 7682be168c0dSopenharmony_ci+ } 7683be168c0dSopenharmony_ci+ auto tensor_byte_size = tensor->DataSize(); 7684be168c0dSopenharmony_ci+ char *casted_data = static_cast<char *>(input_data); 7685be168c0dSopenharmony_ci+ for (size_t i = 0; i < tensor_byte_size; i++) { 7686be168c0dSopenharmony_ci+ casted_data[i] = 7687be168c0dSopenharmony_ci+ (tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0); 7688be168c0dSopenharmony_ci+ } 7689be168c0dSopenharmony_ci+ return RET_OK; 7690be168c0dSopenharmony_ci+} 7691be168c0dSopenharmony_ci+ 7692be168c0dSopenharmony_ci+int NetTrainBase::LoadInput() { 7693be168c0dSopenharmony_ci+ inputs_buf_.clear(); 7694be168c0dSopenharmony_ci+ inputs_size_.clear(); 7695be168c0dSopenharmony_ci+ batch_num_ = 0; 7696be168c0dSopenharmony_ci+ if (flags_->in_data_file_.empty()) { 7697be168c0dSopenharmony_ci+ auto status = GenerateInputData(); 7698be168c0dSopenharmony_ci+ if (status != RET_OK) { 7699be168c0dSopenharmony_ci+ std::cerr << "Generate input data error " << status << std::endl; 7700be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Generate input data error " << status; 7701be168c0dSopenharmony_ci+ return status; 7702be168c0dSopenharmony_ci+ } 7703be168c0dSopenharmony_ci+ } else { 7704be168c0dSopenharmony_ci+ auto status = ReadInputFile(); 7705be168c0dSopenharmony_ci+ if (status != RET_OK) { 7706be168c0dSopenharmony_ci+ 
std::cerr << "Read Input File error, " << status << std::endl; 7707be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Read Input File error, " << status; 7708be168c0dSopenharmony_ci+ return status; 7709be168c0dSopenharmony_ci+ } 7710be168c0dSopenharmony_ci+ } 7711be168c0dSopenharmony_ci+ return RET_OK; 7712be168c0dSopenharmony_ci+} 7713be168c0dSopenharmony_ci+ 7714be168c0dSopenharmony_ci+int NetTrainBase::RunNetTrain() { 7715be168c0dSopenharmony_ci+ auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1); 7716be168c0dSopenharmony_ci+ bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty(); 7717be168c0dSopenharmony_ci+ auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_); 7718be168c0dSopenharmony_ci+ if (status != RET_OK) { 7719be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status; 7720be168c0dSopenharmony_ci+ std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". 
Status is " << status 7721be168c0dSopenharmony_ci+ << std::endl; 7722be168c0dSopenharmony_ci+ return status; 7723be168c0dSopenharmony_ci+ } 7724be168c0dSopenharmony_ci+ 7725be168c0dSopenharmony_ci+ status = CheckExecutionOfSavedModels(); // re-initialize sessions according to flags 7726be168c0dSopenharmony_ci+ if (status != RET_OK) { 7727be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Run CheckExecute error: " << status; 7728be168c0dSopenharmony_ci+ std::cout << "Run CheckExecute error: " << status << std::endl; 7729be168c0dSopenharmony_ci+ return status; 7730be168c0dSopenharmony_ci+ } 7731be168c0dSopenharmony_ci+ return RET_OK; 7732be168c0dSopenharmony_ci+} 7733be168c0dSopenharmony_ci+ 7734be168c0dSopenharmony_ci+int NetTrainBase::CheckExecutionOfSavedModels() { 7735be168c0dSopenharmony_ci+ int status = RET_OK; 7736be168c0dSopenharmony_ci+ if (!flags_->export_file_.empty()) { 7737be168c0dSopenharmony_ci+ status = CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0); 7738be168c0dSopenharmony_ci+ if (status != RET_OK) { 7739be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status; 7740be168c0dSopenharmony_ci+ std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl; 7741be168c0dSopenharmony_ci+ return status; 7742be168c0dSopenharmony_ci+ } 7743be168c0dSopenharmony_ci+ if (flags_->bb_model_file_.empty()) { 7744be168c0dSopenharmony_ci+ status = CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false); 7745be168c0dSopenharmony_ci+ if (status != RET_OK) { 7746be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status; 7747be168c0dSopenharmony_ci+ std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl; 7748be168c0dSopenharmony_ci+ return status; 7749be168c0dSopenharmony_ci+ } 7750be168c0dSopenharmony_ci+ } 7751be168c0dSopenharmony_ci+ } 
7752be168c0dSopenharmony_ci+ if (!flags_->inference_file_.empty()) { 7753be168c0dSopenharmony_ci+ status = CreateAndRunNetwork(flags_->inference_file_, "", false, 0); 7754be168c0dSopenharmony_ci+ if (status != RET_OK) { 7755be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status; 7756be168c0dSopenharmony_ci+ std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl; 7757be168c0dSopenharmony_ci+ return status; 7758be168c0dSopenharmony_ci+ } 7759be168c0dSopenharmony_ci+ status = CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false); 7760be168c0dSopenharmony_ci+ if (status != RET_OK) { 7761be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status; 7762be168c0dSopenharmony_ci+ std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl; 7763be168c0dSopenharmony_ci+ return status; 7764be168c0dSopenharmony_ci+ } 7765be168c0dSopenharmony_ci+ } 7766be168c0dSopenharmony_ci+ return status; 7767be168c0dSopenharmony_ci+} 7768be168c0dSopenharmony_ci+ 7769be168c0dSopenharmony_ci+void NetTrainBase::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) { 7770be168c0dSopenharmony_ci+ if (tensor == nullptr) { 7771be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "input tensor is nullptr."; 7772be168c0dSopenharmony_ci+ return; 7773be168c0dSopenharmony_ci+ } 7774be168c0dSopenharmony_ci+ int tensor_size = tensor->ElementNum(); 7775be168c0dSopenharmony_ci+ void *data = tensor->MutableData(); 7776be168c0dSopenharmony_ci+ auto *fdata = reinterpret_cast<float *>(tensor->MutableData()); 7777be168c0dSopenharmony_ci+ auto type = tensor->DataType(); 7778be168c0dSopenharmony_ci+ std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum="; 7779be168c0dSopenharmony_ci+ switch (type) { 
7780be168c0dSopenharmony_ci+ case mindspore::DataType::kNumberTypeFloat32: 7781be168c0dSopenharmony_ci+ TensorNan(reinterpret_cast<float *>(data), tensor_size); 7782be168c0dSopenharmony_ci+ std::cout << TensorSum<float>(data, tensor_size) << std::endl; 7783be168c0dSopenharmony_ci+ std::cout << "tensor name: " << tensor->Name() << std::endl; 7784be168c0dSopenharmony_ci+ std::cout << "data: "; 7785be168c0dSopenharmony_ci+ for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) { 7786be168c0dSopenharmony_ci+ std::cout << static_cast<float>(fdata[i]) << ", "; 7787be168c0dSopenharmony_ci+ } 7788be168c0dSopenharmony_ci+ std::cout << std::endl; 7789be168c0dSopenharmony_ci+ break; 7790be168c0dSopenharmony_ci+ case mindspore::DataType::kNumberTypeInt32: 7791be168c0dSopenharmony_ci+ std::cout << TensorSum<int>(data, tensor_size) << std::endl; 7792be168c0dSopenharmony_ci+ break; 7793be168c0dSopenharmony_ci+#ifdef ENABLE_FP16 7794be168c0dSopenharmony_ci+ case mindspore::DataType::kNumberTypeFloat16: 7795be168c0dSopenharmony_ci+ std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl; 7796be168c0dSopenharmony_ci+ TensorNan(reinterpret_cast<float16_t *>(data), tensor_size); 7797be168c0dSopenharmony_ci+ break; 7798be168c0dSopenharmony_ci+#endif 7799be168c0dSopenharmony_ci+ default: 7800be168c0dSopenharmony_ci+ std::cout << "unsupported type:" << static_cast<int>(type) << std::endl; 7801be168c0dSopenharmony_ci+ break; 7802be168c0dSopenharmony_ci+ } 7803be168c0dSopenharmony_ci+} 7804be168c0dSopenharmony_ci+ 7805be168c0dSopenharmony_ci+std::string NetTrainBase::GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 7806be168c0dSopenharmony_ci+ const std::string &file_type, const size_t &idx) { 7807be168c0dSopenharmony_ci+ std::string file_name = op_name; 7808be168c0dSopenharmony_ci+ auto pos = file_name.find_first_of('/'); 7809be168c0dSopenharmony_ci+ while (pos != std::string::npos) { 7810be168c0dSopenharmony_ci+ file_name.replace(pos, 1, "."); 
7811be168c0dSopenharmony_ci+ pos = file_name.find_first_of('/'); 7812be168c0dSopenharmony_ci+ } 7813be168c0dSopenharmony_ci+ file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_"; 7814be168c0dSopenharmony_ci+ for (const auto &dim : tensor->Shape()) { 7815be168c0dSopenharmony_ci+ file_name += std::to_string(dim) + "_"; 7816be168c0dSopenharmony_ci+ } 7817be168c0dSopenharmony_ci+ if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) { 7818be168c0dSopenharmony_ci+ file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType())); 7819be168c0dSopenharmony_ci+ } 7820be168c0dSopenharmony_ci+ auto tensor_format = tensor->format(); 7821be168c0dSopenharmony_ci+ if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) { 7822be168c0dSopenharmony_ci+ file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin"; 7823be168c0dSopenharmony_ci+ } 7824be168c0dSopenharmony_ci+ 7825be168c0dSopenharmony_ci+ file_name += ".bin"; 7826be168c0dSopenharmony_ci+ return file_name; 7827be168c0dSopenharmony_ci+} 7828be168c0dSopenharmony_ci+ 7829be168c0dSopenharmony_ci+int NetTrainBase::InitCallbackParameter() { 7830be168c0dSopenharmony_ci+ int ret = RET_OK; 7831be168c0dSopenharmony_ci+ if (flags_->dump_tensor_data_) { 7832be168c0dSopenharmony_ci+ ret = InitDumpTensorDataCallbackParameter(); 7833be168c0dSopenharmony_ci+ } else if (flags_->time_profiling_) { 7834be168c0dSopenharmony_ci+ ret = InitTimeProfilingCallbackParameter(); 7835be168c0dSopenharmony_ci+ } 7836be168c0dSopenharmony_ci+ return ret; 7837be168c0dSopenharmony_ci+} 7838be168c0dSopenharmony_ci+ 7839be168c0dSopenharmony_ci+void NetTrainFlags::InitResizeDimsList() { 7840be168c0dSopenharmony_ci+ std::string content = this->resize_dims_in_; 7841be168c0dSopenharmony_ci+ if (content.empty()) { 7842be168c0dSopenharmony_ci+ return; 7843be168c0dSopenharmony_ci+ } 7844be168c0dSopenharmony_ci+ std::vector<int> shape; 7845be168c0dSopenharmony_ci+ auto shape_strs = StrSplit(content, 
std::string(DELIM_COLON)); 7846be168c0dSopenharmony_ci+ for (const auto &shape_str : shape_strs) { 7847be168c0dSopenharmony_ci+ shape.clear(); 7848be168c0dSopenharmony_ci+ auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA)); 7849be168c0dSopenharmony_ci+ std::cout << "Resize Dims: "; 7850be168c0dSopenharmony_ci+ for (const auto &dim_str : dim_strs) { 7851be168c0dSopenharmony_ci+ std::cout << dim_str << " "; 7852be168c0dSopenharmony_ci+ shape.emplace_back(static_cast<int>(std::stoi(dim_str))); 7853be168c0dSopenharmony_ci+ } 7854be168c0dSopenharmony_ci+ std::cout << std::endl; 7855be168c0dSopenharmony_ci+ this->resize_dims_.emplace_back(shape); 7856be168c0dSopenharmony_ci+ } 7857be168c0dSopenharmony_ci+} 7858be168c0dSopenharmony_ci+ 7859be168c0dSopenharmony_ci+int NetTrainBase::Init() { 7860be168c0dSopenharmony_ci+ if (this->flags_ == nullptr) { 7861be168c0dSopenharmony_ci+ return 1; 7862be168c0dSopenharmony_ci+ } 7863be168c0dSopenharmony_ci+ MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_; 7864be168c0dSopenharmony_ci+ MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_; 7865be168c0dSopenharmony_ci+ MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_; 7866be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_; 7867be168c0dSopenharmony_ci+ MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_; 7868be168c0dSopenharmony_ci+ MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_; 7869be168c0dSopenharmony_ci+ MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_; 7870be168c0dSopenharmony_ci+ MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_; 7871be168c0dSopenharmony_ci+ MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_; 7872be168c0dSopenharmony_ci+ MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_; 7873be168c0dSopenharmony_ci+ MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_; 
7874be168c0dSopenharmony_ci+ 7875be168c0dSopenharmony_ci+ if (this->flags_->epochs_ < 0) { 7876be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0"; 7877be168c0dSopenharmony_ci+ std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl; 7878be168c0dSopenharmony_ci+ return RET_ERROR; 7879be168c0dSopenharmony_ci+ } 7880be168c0dSopenharmony_ci+ 7881be168c0dSopenharmony_ci+ if (this->flags_->num_threads_ < 1) { 7882be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0"; 7883be168c0dSopenharmony_ci+ std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl; 7884be168c0dSopenharmony_ci+ return RET_ERROR; 7885be168c0dSopenharmony_ci+ } 7886be168c0dSopenharmony_ci+ 7887be168c0dSopenharmony_ci+ this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary; 7888be168c0dSopenharmony_ci+ 7889be168c0dSopenharmony_ci+ if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) { 7890be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided"; 7891be168c0dSopenharmony_ci+ std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl; 7892be168c0dSopenharmony_ci+ return RET_ERROR; 7893be168c0dSopenharmony_ci+ } 7894be168c0dSopenharmony_ci+ 7895be168c0dSopenharmony_ci+ if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) { 7896be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided"; 7897be168c0dSopenharmony_ci+ std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl; 7898be168c0dSopenharmony_ci+ return RET_ERROR; 7899be168c0dSopenharmony_ci+ } 7900be168c0dSopenharmony_ci+ 7901be168c0dSopenharmony_ci+ if (flags_->model_file_.empty()) { 
7902be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "modelPath is required"; 7903be168c0dSopenharmony_ci+ std::cerr << "modelPath is required" << std::endl; 7904be168c0dSopenharmony_ci+ return 1; 7905be168c0dSopenharmony_ci+ } 7906be168c0dSopenharmony_ci+ 7907be168c0dSopenharmony_ci+ // get dump data output path 7908be168c0dSopenharmony_ci+ auto dump_cfg_path = std::getenv(dump::kConfigPath); 7909be168c0dSopenharmony_ci+ if (dump_cfg_path != nullptr) { 7910be168c0dSopenharmony_ci+ flags_->dump_tensor_data_ = true; 7911be168c0dSopenharmony_ci+ if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) { 7912be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "parse dump config file failed."; 7913be168c0dSopenharmony_ci+ return RET_ERROR; 7914be168c0dSopenharmony_ci+ } 7915be168c0dSopenharmony_ci+ } else { 7916be168c0dSopenharmony_ci+ MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data"; 7917be168c0dSopenharmony_ci+ } 7918be168c0dSopenharmony_ci+ 7919be168c0dSopenharmony_ci+ auto status = InitCallbackParameter(); 7920be168c0dSopenharmony_ci+ if (status != RET_OK) { 7921be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init callback Parameter failed."; 7922be168c0dSopenharmony_ci+ std::cerr << "Init callback Parameter failed." 
<< std::endl; 7923be168c0dSopenharmony_ci+ return RET_ERROR; 7924be168c0dSopenharmony_ci+ } 7925be168c0dSopenharmony_ci+ 7926be168c0dSopenharmony_ci+ flags_->InitResizeDimsList(); 7927be168c0dSopenharmony_ci+ if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() && 7928be168c0dSopenharmony_ci+ flags_->resize_dims_.size() != flags_->input_data_list_.size()) { 7929be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath"; 7930be168c0dSopenharmony_ci+ std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl; 7931be168c0dSopenharmony_ci+ return RET_ERROR; 7932be168c0dSopenharmony_ci+ } 7933be168c0dSopenharmony_ci+ return RET_OK; 7934be168c0dSopenharmony_ci+} 7935be168c0dSopenharmony_ci+ 7936be168c0dSopenharmony_ci+int NetTrainBase::InitDumpConfigFromJson(std::string path) { 7937be168c0dSopenharmony_ci+ auto real_path = RealPath(path.c_str()); 7938be168c0dSopenharmony_ci+ std::ifstream ifs(real_path); 7939be168c0dSopenharmony_ci+ if (!ifs.good()) { 7940be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "file: " << real_path << " is not exist"; 7941be168c0dSopenharmony_ci+ return RET_ERROR; 7942be168c0dSopenharmony_ci+ } 7943be168c0dSopenharmony_ci+ if (!ifs.is_open()) { 7944be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "file: " << real_path << " open failed"; 7945be168c0dSopenharmony_ci+ return RET_ERROR; 7946be168c0dSopenharmony_ci+ } 7947be168c0dSopenharmony_ci+ 7948be168c0dSopenharmony_ci+ try { 7949be168c0dSopenharmony_ci+ dump_cfg_json_ = nlohmann::json::parse(ifs); 7950be168c0dSopenharmony_ci+ } catch (const nlohmann::json::parse_error &error) { 7951be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "parse json file failed, please check your file."; 7952be168c0dSopenharmony_ci+ return RET_ERROR; 7953be168c0dSopenharmony_ci+ } 7954be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings] == nullptr) { 7955be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "\"common_dump_settings\" 
is required."; 7956be168c0dSopenharmony_ci+ return RET_ERROR; 7957be168c0dSopenharmony_ci+ } 7958be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) { 7959be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "\"dump_mode\" is required."; 7960be168c0dSopenharmony_ci+ return RET_ERROR; 7961be168c0dSopenharmony_ci+ } 7962be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) { 7963be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "\"path\" is required."; 7964be168c0dSopenharmony_ci+ return RET_ERROR; 7965be168c0dSopenharmony_ci+ } 7966be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) { 7967be168c0dSopenharmony_ci+ dump_cfg_json_[dump::kSettings][dump::kNetName] = "default"; 7968be168c0dSopenharmony_ci+ } 7969be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) { 7970be168c0dSopenharmony_ci+ dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0; 7971be168c0dSopenharmony_ci+ } 7972be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr && 7973be168c0dSopenharmony_ci+ !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) { 7974be168c0dSopenharmony_ci+ if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) { 7975be168c0dSopenharmony_ci+ MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)"; 7976be168c0dSopenharmony_ci+ return RET_ERROR; 7977be168c0dSopenharmony_ci+ } 7978be168c0dSopenharmony_ci+ } 7979be168c0dSopenharmony_ci+ 7980be168c0dSopenharmony_ci+ auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>(); 7981be168c0dSopenharmony_ci+ auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>(); 7982be168c0dSopenharmony_ci+ if (abs_path.back() == '\\' || abs_path.back() == '/') { 7983be168c0dSopenharmony_ci+ dump_file_output_dir_ = abs_path + net_name; 7984be168c0dSopenharmony_ci+ } else { 7985be168c0dSopenharmony_ci+#ifdef _WIN32 
7986be168c0dSopenharmony_ci+ dump_file_output_dir_ = abs_path + "\\" + net_name; 7987be168c0dSopenharmony_ci+#else 7988be168c0dSopenharmony_ci+ dump_file_output_dir_ = abs_path + "/" + net_name; 7989be168c0dSopenharmony_ci+#endif 7990be168c0dSopenharmony_ci+ } 7991be168c0dSopenharmony_ci+ 7992be168c0dSopenharmony_ci+ auto status = CreateOutputDir(&dump_file_output_dir_); 7993be168c0dSopenharmony_ci+ if (status != RET_OK) { 7994be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "create data output directory failed."; 7995be168c0dSopenharmony_ci+ return RET_ERROR; 7996be168c0dSopenharmony_ci+ } 7997be168c0dSopenharmony_ci+ return RET_OK; 7998be168c0dSopenharmony_ci+} 7999be168c0dSopenharmony_ci+ 8000be168c0dSopenharmony_ci+NetTrainBase:: ~NetTrainBase() { 8001be168c0dSopenharmony_ci+} 8002be168c0dSopenharmony_ci+} // namespace lite 8003be168c0dSopenharmony_ci+} // namespace mindspore 8004be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_base.h b/mindspore/lite/tools/benchmark_train/net_train_base.h 8005be168c0dSopenharmony_cinew file mode 100644 8006be168c0dSopenharmony_ciindex 00000000..e3d5f39a 8007be168c0dSopenharmony_ci--- /dev/null 8008be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_base.h 8009be168c0dSopenharmony_ci@@ -0,0 +1,288 @@ 8010be168c0dSopenharmony_ci+/** 8011be168c0dSopenharmony_ci+ * Copyright 2020-2023 Huawei Technologies Co., Ltd 8012be168c0dSopenharmony_ci+ * 8013be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8014be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 
8015be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8016be168c0dSopenharmony_ci+ * 8017be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8018be168c0dSopenharmony_ci+ * 8019be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8020be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8021be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8022be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8023be168c0dSopenharmony_ci+ * limitations under the License. 8024be168c0dSopenharmony_ci+ */ 8025be168c0dSopenharmony_ci+ 8026be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8027be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8028be168c0dSopenharmony_ci+ 8029be168c0dSopenharmony_ci+#include <getopt.h> 8030be168c0dSopenharmony_ci+#include <csignal> 8031be168c0dSopenharmony_ci+#include <unordered_map> 8032be168c0dSopenharmony_ci+#include <fstream> 8033be168c0dSopenharmony_ci+#include <iostream> 8034be168c0dSopenharmony_ci+#include <map> 8035be168c0dSopenharmony_ci+#include <cmath> 8036be168c0dSopenharmony_ci+#include <string> 8037be168c0dSopenharmony_ci+#include <vector> 8038be168c0dSopenharmony_ci+#include <memory> 8039be168c0dSopenharmony_ci+#include <cfloat> 8040be168c0dSopenharmony_ci+#include <utility> 8041be168c0dSopenharmony_ci+#include <algorithm> 8042be168c0dSopenharmony_ci+#include <nlohmann/json.hpp> 8043be168c0dSopenharmony_ci+#include "include/api/model.h" 8044be168c0dSopenharmony_ci+#include "include/api/types.h" 8045be168c0dSopenharmony_ci+#include "include/api/context.h" 8046be168c0dSopenharmony_ci+#include "include/api/cfg.h" 8047be168c0dSopenharmony_ci+ 8048be168c0dSopenharmony_ci+#ifdef ENABLE_FP16 8049be168c0dSopenharmony_ci+#include <arm_neon.h> 
8050be168c0dSopenharmony_ci+#endif 8051be168c0dSopenharmony_ci+#include "tools/common/flag_parser.h" 8052be168c0dSopenharmony_ci+#include "src/common/file_utils.h" 8053be168c0dSopenharmony_ci+#include "src/common/utils.h" 8054be168c0dSopenharmony_ci+ 8055be168c0dSopenharmony_ci+#ifdef ENABLE_FP16 8056be168c0dSopenharmony_ci+static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) { 8057be168c0dSopenharmony_ci+ volatile float16_t d = var; 8058be168c0dSopenharmony_ci+ return d != d; 8059be168c0dSopenharmony_ci+} 8060be168c0dSopenharmony_ci+#endif 8061be168c0dSopenharmony_ci+ 8062be168c0dSopenharmony_ci+namespace mindspore::lite { 8063be168c0dSopenharmony_ci+enum MS_API DataType { kImage = 0, kBinary = 1 }; 8064be168c0dSopenharmony_ci+ 8065be168c0dSopenharmony_ci+constexpr float relativeTolerance = 1e-5; 8066be168c0dSopenharmony_ci+constexpr float absoluteTolerance = 1e-8; 8067be168c0dSopenharmony_ci+extern const std::unordered_map<int, std::string> kTypeIdMap; 8068be168c0dSopenharmony_ci+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap; 8069be168c0dSopenharmony_ci+ 8070be168c0dSopenharmony_ci+constexpr const char *DELIM_SLASH = "/"; 8071be168c0dSopenharmony_ci+constexpr const char *DELIM_COLON = ":"; 8072be168c0dSopenharmony_ci+constexpr const char *DELIM_COMMA = ","; 8073be168c0dSopenharmony_ci+ 8074be168c0dSopenharmony_ci+constexpr int RET_TOO_BIG = -9; 8075be168c0dSopenharmony_ci+constexpr int kFieldsToPrint = 5; 8076be168c0dSopenharmony_ci+constexpr int kPrintOffset = 4; 8077be168c0dSopenharmony_ci+constexpr int kDumpInputsAndOutputs = 0; 8078be168c0dSopenharmony_ci+constexpr int kDumpOutputs = 2; 8079be168c0dSopenharmony_ci+constexpr int kTHOUSAND = 1000; 8080be168c0dSopenharmony_ci+ 8081be168c0dSopenharmony_ci+namespace dump { 8082be168c0dSopenharmony_ci+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG"; 8083be168c0dSopenharmony_ci+constexpr auto kSettings = "common_dump_settings"; 
8084be168c0dSopenharmony_ci+constexpr auto kMode = "dump_mode"; 8085be168c0dSopenharmony_ci+constexpr auto kPath = "path"; 8086be168c0dSopenharmony_ci+constexpr auto kNetName = "net_name"; 8087be168c0dSopenharmony_ci+constexpr auto kInputOutput = "input_output"; 8088be168c0dSopenharmony_ci+constexpr auto kKernels = "kernels"; 8089be168c0dSopenharmony_ci+} // namespace dump 8090be168c0dSopenharmony_ci+ 8091be168c0dSopenharmony_ci+template <typename T> 8092be168c0dSopenharmony_ci+float TensorSum(const void *data, int size) { 8093be168c0dSopenharmony_ci+ const T *typed_data = reinterpret_cast<const T *>(data); 8094be168c0dSopenharmony_ci+ float sum = 0.f; 8095be168c0dSopenharmony_ci+ for (int i = 0; i < size; i++) { 8096be168c0dSopenharmony_ci+ sum += static_cast<float>(typed_data[i]); 8097be168c0dSopenharmony_ci+ } 8098be168c0dSopenharmony_ci+ return sum; 8099be168c0dSopenharmony_ci+} 8100be168c0dSopenharmony_ci+ 8101be168c0dSopenharmony_ci+class MS_API NetTrainFlags : public virtual FlagParser { 8102be168c0dSopenharmony_ci+ public: 8103be168c0dSopenharmony_ci+ NetTrainFlags() { 8104be168c0dSopenharmony_ci+ // common 8105be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", ""); 8106be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", ""); 8107be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", ""); 8108be168c0dSopenharmony_ci+ // MarkPerformance 8109be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); 8110be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); 8111be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); 8112be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 
1); 8113be168c0dSopenharmony_ci+ // MarkAccuracy 8114be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); 8115be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); 8116be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); 8117be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); 8118be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false); 8119be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", ""); 8120be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", ""); 8121be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false); 8122be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes", 8123be168c0dSopenharmony_ci+ "Shape of input data, the format should be NHWC. e.g. 
1,32,32,32:1,1,32,32,1", ""); 8124be168c0dSopenharmony_ci+ AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false); 8125be168c0dSopenharmony_ci+ } 8126be168c0dSopenharmony_ci+ 8127be168c0dSopenharmony_ci+ ~NetTrainFlags() override = default; 8128be168c0dSopenharmony_ci+ void InitResizeDimsList(); 8129be168c0dSopenharmony_ci+ 8130be168c0dSopenharmony_ci+ public: 8131be168c0dSopenharmony_ci+ // common 8132be168c0dSopenharmony_ci+ std::string model_file_; 8133be168c0dSopenharmony_ci+ std::string in_data_file_; 8134be168c0dSopenharmony_ci+ std::string bb_model_file_; 8135be168c0dSopenharmony_ci+ std::vector<std::string> input_data_list_; 8136be168c0dSopenharmony_ci+ DataType in_data_type_; 8137be168c0dSopenharmony_ci+ std::string in_data_type_in_ = "bin"; 8138be168c0dSopenharmony_ci+ int cpu_bind_mode_ = 1; 8139be168c0dSopenharmony_ci+ bool enable_fp16_ = false; 8140be168c0dSopenharmony_ci+ bool virtual_batch_ = false; 8141be168c0dSopenharmony_ci+ // MarkPerformance 8142be168c0dSopenharmony_ci+ int num_threads_ = 1; 8143be168c0dSopenharmony_ci+ int warm_up_loop_count_ = 0; 8144be168c0dSopenharmony_ci+ bool time_profiling_; 8145be168c0dSopenharmony_ci+ int epochs_ = 1; 8146be168c0dSopenharmony_ci+ // MarkAccuracy 8147be168c0dSopenharmony_ci+ std::string data_file_; 8148be168c0dSopenharmony_ci+ std::string data_type_ = "FLOAT"; 8149be168c0dSopenharmony_ci+ float accuracy_threshold_; 8150be168c0dSopenharmony_ci+ // Resize 8151be168c0dSopenharmony_ci+ std::string export_file_ = ""; 8152be168c0dSopenharmony_ci+ std::string resize_dims_in_ = ""; 8153be168c0dSopenharmony_ci+ bool layer_checksum_ = false; 8154be168c0dSopenharmony_ci+ std::vector<std::vector<int>> resize_dims_; 8155be168c0dSopenharmony_ci+ std::string loss_name_ = ""; 8156be168c0dSopenharmony_ci+ std::string inference_file_ = ""; 8157be168c0dSopenharmony_ci+ bool unified_api_ = false; 8158be168c0dSopenharmony_ci+ bool dump_tensor_data_ = false; 8159be168c0dSopenharmony_ci+}; 
8160be168c0dSopenharmony_ci+ 8161be168c0dSopenharmony_ci+class MS_API NetTrainBase { 8162be168c0dSopenharmony_ci+ public: 8163be168c0dSopenharmony_ci+ explicit NetTrainBase(NetTrainFlags *flags) : flags_(flags) {} 8164be168c0dSopenharmony_ci+ virtual ~NetTrainBase(); 8165be168c0dSopenharmony_ci+ 8166be168c0dSopenharmony_ci+ int Init(); 8167be168c0dSopenharmony_ci+ int RunNetTrain(); 8168be168c0dSopenharmony_ci+ static float *ReadFileBuf(const std::string file, size_t *size); 8169be168c0dSopenharmony_ci+ static int SetNr(std::function<int(NetTrainFlags *)> param); 8170be168c0dSopenharmony_ci+ static int RunNr(NetTrainFlags *flags) { 8171be168c0dSopenharmony_ci+ if (nr_cb_ != nullptr) { 8172be168c0dSopenharmony_ci+ return nr_cb_(flags); 8173be168c0dSopenharmony_ci+ } 8174be168c0dSopenharmony_ci+ MS_LOG(WARNING) << "unified api was not tested"; 8175be168c0dSopenharmony_ci+ std::cout << "unified api was not tested"; 8176be168c0dSopenharmony_ci+ return RET_OK; 8177be168c0dSopenharmony_ci+ } 8178be168c0dSopenharmony_ci+ // tensorData need to be converter first 8179be168c0dSopenharmony_ci+ template <typename T> 8180be168c0dSopenharmony_ci+ static float CompareData(const float *refOutput, int size, const T *msTensorData) { 8181be168c0dSopenharmony_ci+ size_t errorCount = 0; 8182be168c0dSopenharmony_ci+ float meanError = 0; 8183be168c0dSopenharmony_ci+ std::cout << "Out tensor size is: " << size << std::endl; 8184be168c0dSopenharmony_ci+ std::cout << "Data of model output: "; 8185be168c0dSopenharmony_ci+ for (int j = 0; j < std::min(50, size); j++) { 8186be168c0dSopenharmony_ci+ std::cout << static_cast<float>(msTensorData[j]) << " "; 8187be168c0dSopenharmony_ci+ } 8188be168c0dSopenharmony_ci+ std::cout << std::endl; 8189be168c0dSopenharmony_ci+ std::cout << "Data of Ref output : "; 8190be168c0dSopenharmony_ci+ for (int j = 0; j < std::min(50, size); j++) { 8191be168c0dSopenharmony_ci+ std::cout << refOutput[j] << " "; 8192be168c0dSopenharmony_ci+ } 
8193be168c0dSopenharmony_ci+ std::cout << std::endl; 8194be168c0dSopenharmony_ci+ for (int j = 0; j < size; j++) { 8195be168c0dSopenharmony_ci+ if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { 8196be168c0dSopenharmony_ci+ std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; 8197be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; 8198be168c0dSopenharmony_ci+ return RET_ERROR; 8199be168c0dSopenharmony_ci+ } 8200be168c0dSopenharmony_ci+ 8201be168c0dSopenharmony_ci+ auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]); 8202be168c0dSopenharmony_ci+ auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]); 8203be168c0dSopenharmony_ci+ if (absoluteError > tolerance) { 8204be168c0dSopenharmony_ci+ if (fabs(refOutput[j]) == 0) { 8205be168c0dSopenharmony_ci+ if (absoluteError > 1e-5) { 8206be168c0dSopenharmony_ci+ meanError += absoluteError; 8207be168c0dSopenharmony_ci+ errorCount++; 8208be168c0dSopenharmony_ci+ } else { 8209be168c0dSopenharmony_ci+ continue; 8210be168c0dSopenharmony_ci+ } 8211be168c0dSopenharmony_ci+ } else { 8212be168c0dSopenharmony_ci+ // just assume that atol = rtol 8213be168c0dSopenharmony_ci+ meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN); 8214be168c0dSopenharmony_ci+ errorCount++; 8215be168c0dSopenharmony_ci+ } 8216be168c0dSopenharmony_ci+ } 8217be168c0dSopenharmony_ci+ } 8218be168c0dSopenharmony_ci+ std::cout << std::endl; 8219be168c0dSopenharmony_ci+ if (meanError > 0.0f) { 8220be168c0dSopenharmony_ci+ meanError /= errorCount; 8221be168c0dSopenharmony_ci+ } 8222be168c0dSopenharmony_ci+ 8223be168c0dSopenharmony_ci+ if (meanError <= 0.0000001) { 8224be168c0dSopenharmony_ci+ std::cout << "Mean bias of tensor: 0%" << std::endl; 8225be168c0dSopenharmony_ci+ } else { 8226be168c0dSopenharmony_ci+ std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl; 8227be168c0dSopenharmony_ci+ } 
8228be168c0dSopenharmony_ci+ return meanError; 8229be168c0dSopenharmony_ci+ } 8230be168c0dSopenharmony_ci+ int InitDumpConfigFromJson(std::string path); 8231be168c0dSopenharmony_ci+ 8232be168c0dSopenharmony_ci+ protected: 8233be168c0dSopenharmony_ci+ // call GenerateInputData or ReadInputFile to init inputTensors 8234be168c0dSopenharmony_ci+ int LoadInput(); 8235be168c0dSopenharmony_ci+ void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out); 8236be168c0dSopenharmony_ci+ // call GenerateRandomData to fill inputTensors 8237be168c0dSopenharmony_ci+ virtual int GenerateInputData() = 0; 8238be168c0dSopenharmony_ci+ 8239be168c0dSopenharmony_ci+ int GenerateRandomData(mindspore::MSTensor *tensor); 8240be168c0dSopenharmony_ci+ 8241be168c0dSopenharmony_ci+ std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name, 8242be168c0dSopenharmony_ci+ const std::string &file_type, const size_t &idx); 8243be168c0dSopenharmony_ci+ virtual int ReadInputFile() = 0; 8244be168c0dSopenharmony_ci+ 8245be168c0dSopenharmony_ci+ virtual int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs, 8246be168c0dSopenharmony_ci+ bool check_accuracy = true) = 0; 8247be168c0dSopenharmony_ci+ 8248be168c0dSopenharmony_ci+ int InitCallbackParameter(); 8249be168c0dSopenharmony_ci+ 8250be168c0dSopenharmony_ci+ virtual int InitDumpTensorDataCallbackParameter() = 0; 8251be168c0dSopenharmony_ci+ 8252be168c0dSopenharmony_ci+ virtual int InitTimeProfilingCallbackParameter() = 0; 8253be168c0dSopenharmony_ci+ 8254be168c0dSopenharmony_ci+ virtual int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) = 0; 8255be168c0dSopenharmony_ci+ 8256be168c0dSopenharmony_ci+ template <typename T> 8257be168c0dSopenharmony_ci+ std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) { 8258be168c0dSopenharmony_ci+ 
std::vector<int64_t> dims; 8259be168c0dSopenharmony_ci+ for (auto shape : srcDims) { 8260be168c0dSopenharmony_ci+ dims.push_back(static_cast<int64_t>(shape)); 8261be168c0dSopenharmony_ci+ } 8262be168c0dSopenharmony_ci+ return dims; 8263be168c0dSopenharmony_ci+ } 8264be168c0dSopenharmony_ci+ virtual int MarkPerformance() = 0; 8265be168c0dSopenharmony_ci+ virtual int MarkAccuracy(bool enforce_accuracy = true) = 0; 8266be168c0dSopenharmony_ci+ virtual int CompareOutput() = 0; 8267be168c0dSopenharmony_ci+ virtual int SaveModels() = 0; 8268be168c0dSopenharmony_ci+ int CheckExecutionOfSavedModels(); 8269be168c0dSopenharmony_ci+ void TensorNan(const float *data, int size) { 8270be168c0dSopenharmony_ci+ for (int i = 0; i < size; i++) { 8271be168c0dSopenharmony_ci+ if (std::isnan(data[i])) { 8272be168c0dSopenharmony_ci+ std::cout << "nan value of index=" << i << ", " << data[i] << std::endl; 8273be168c0dSopenharmony_ci+ break; 8274be168c0dSopenharmony_ci+ } 8275be168c0dSopenharmony_ci+ } 8276be168c0dSopenharmony_ci+ } 8277be168c0dSopenharmony_ci+#ifdef ENABLE_FP16 8278be168c0dSopenharmony_ci+ void TensorNan(float16_t *data, int size) { 8279be168c0dSopenharmony_ci+ for (int i = 0; i < size; i++) { 8280be168c0dSopenharmony_ci+ if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) { 8281be168c0dSopenharmony_ci+ std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl; 8282be168c0dSopenharmony_ci+ break; 8283be168c0dSopenharmony_ci+ } 8284be168c0dSopenharmony_ci+ } 8285be168c0dSopenharmony_ci+ } 8286be168c0dSopenharmony_ci+#endif 8287be168c0dSopenharmony_ci+ NetTrainFlags *flags_{nullptr}; 8288be168c0dSopenharmony_ci+ static std::function<int(NetTrainFlags *)> nr_cb_; 8289be168c0dSopenharmony_ci+ 8290be168c0dSopenharmony_ci+ nlohmann::json dump_cfg_json_; 8291be168c0dSopenharmony_ci+ std::string dump_file_output_dir_; 8292be168c0dSopenharmony_ci+ std::vector<std::shared_ptr<char>> inputs_buf_; 8293be168c0dSopenharmony_ci+ std::vector<size_t> inputs_size_; 
8294be168c0dSopenharmony_ci+ size_t batch_num_ = 0; 8295be168c0dSopenharmony_ci+}; 8296be168c0dSopenharmony_ci+} // namespace mindspore::lite 8297be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_ 8298be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.cc b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc 8299be168c0dSopenharmony_cinew file mode 100644 8300be168c0dSopenharmony_ciindex 00000000..4dcf3af6 8301be168c0dSopenharmony_ci--- /dev/null 8302be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc 8303be168c0dSopenharmony_ci@@ -0,0 +1,659 @@ 8304be168c0dSopenharmony_ci+/** 8305be168c0dSopenharmony_ci+ * Copyright 2023-2023 Huawei Technologies Co., Ltd 8306be168c0dSopenharmony_ci+ * 8307be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8308be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8309be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8310be168c0dSopenharmony_ci+ * 8311be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8312be168c0dSopenharmony_ci+ * 8313be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8314be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8315be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8316be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8317be168c0dSopenharmony_ci+ * limitations under the License. 
8318be168c0dSopenharmony_ci+ */ 8319be168c0dSopenharmony_ci+ 8320be168c0dSopenharmony_ci+#include "net_train_c_api.h" 8321be168c0dSopenharmony_ci+#include "securec/include/securec.h" 8322be168c0dSopenharmony_ci+ 8323be168c0dSopenharmony_ci+namespace mindspore { 8324be168c0dSopenharmony_ci+namespace lite { 8325be168c0dSopenharmony_ci+uint64_t g_op_begin_ = 0; 8326be168c0dSopenharmony_ci+int g_op_call_times_total_ = 0; 8327be168c0dSopenharmony_ci+float g_op_cost_total_ = 0.0f; 8328be168c0dSopenharmony_ci+ 8329be168c0dSopenharmony_ci+int NetTrainCApi::GenerateInputData() { 8330be168c0dSopenharmony_ci+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8331be168c0dSopenharmony_ci+ OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i]; 8332be168c0dSopenharmony_ci+ auto data_type = OH_AI_TensorGetDataType(tensor); 8333be168c0dSopenharmony_ci+ if (data_type == OH_AI_DATATYPE_OBJECTTYPE_STRING) { 8334be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING"; 8335be168c0dSopenharmony_ci+ return RET_ERROR; 8336be168c0dSopenharmony_ci+ } else { 8337be168c0dSopenharmony_ci+ (void)GenerateRandomData(static_cast<mindspore::MSTensor *>(tensor)); 8338be168c0dSopenharmony_ci+ } 8339be168c0dSopenharmony_ci+ } 8340be168c0dSopenharmony_ci+ return RET_OK; 8341be168c0dSopenharmony_ci+} 8342be168c0dSopenharmony_ci+ 8343be168c0dSopenharmony_ci+int NetTrainCApi::SaveModels() { 8344be168c0dSopenharmony_ci+ if (!flags_->export_file_.empty()) { 8345be168c0dSopenharmony_ci+ if (flags_->bb_model_file_.empty()) { 8346be168c0dSopenharmony_ci+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, false, 8347be168c0dSopenharmony_ci+ nullptr, 0); 8348be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8349be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt"; 8350be168c0dSopenharmony_ci+ std::cout << "Export quantized 
model error " << flags_->export_file_ + "_qt" << std::endl; 8351be168c0dSopenharmony_ci+ return RET_ERROR; 8352be168c0dSopenharmony_ci+ } 8353be168c0dSopenharmony_ci+ } 8354be168c0dSopenharmony_ci+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_).c_str(), OH_AI_NO_QUANT, false, 8355be168c0dSopenharmony_ci+ nullptr, 0); 8356be168c0dSopenharmony_ci+ 8357be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8358be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_; 8359be168c0dSopenharmony_ci+ std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl; 8360be168c0dSopenharmony_ci+ return RET_ERROR; 8361be168c0dSopenharmony_ci+ } 8362be168c0dSopenharmony_ci+ } 8363be168c0dSopenharmony_ci+ if (!flags_->inference_file_.empty()) { 8364be168c0dSopenharmony_ci+ auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, true, 8365be168c0dSopenharmony_ci+ nullptr, 0); 8366be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8367be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt"; 8368be168c0dSopenharmony_ci+ std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl; 8369be168c0dSopenharmony_ci+ return RET_ERROR; 8370be168c0dSopenharmony_ci+ } 8371be168c0dSopenharmony_ci+ 8372be168c0dSopenharmony_ci+ auto tick = GetTimeUs(); 8373be168c0dSopenharmony_ci+ status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_).c_str(), OH_AI_NO_QUANT, true, 8374be168c0dSopenharmony_ci+ nullptr, 0); 8375be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8376be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_; 8377be168c0dSopenharmony_ci+ std::cout << "Export non quantized 
inference model error " << flags_->inference_file_ << std::endl; 8378be168c0dSopenharmony_ci+ return RET_ERROR; 8379be168c0dSopenharmony_ci+ } 8380be168c0dSopenharmony_ci+ std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n"; 8381be168c0dSopenharmony_ci+ } 8382be168c0dSopenharmony_ci+ return RET_OK; 8383be168c0dSopenharmony_ci+} 8384be168c0dSopenharmony_ci+ 8385be168c0dSopenharmony_ci+int NetTrainCApi::LoadStepInput(size_t step) { 8386be168c0dSopenharmony_ci+ if (step >= batch_num_) { 8387be168c0dSopenharmony_ci+ auto cur_batch = step + 1; 8388be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Max input Batch is:" << batch_num_ << " but got batch :" << cur_batch; 8389be168c0dSopenharmony_ci+ return RET_ERROR; 8390be168c0dSopenharmony_ci+ } 8391be168c0dSopenharmony_ci+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8392be168c0dSopenharmony_ci+ OH_AI_TensorHandle cur_tensor = ms_inputs_for_api_.handle_list[i]; 8393be168c0dSopenharmony_ci+ MS_ASSERT(cur_tensor != nullptr); 8394be168c0dSopenharmony_ci+ auto tensor_data_size = OH_AI_TensorGetDataSize(cur_tensor); 8395be168c0dSopenharmony_ci+ auto input_data = OH_AI_TensorGetMutableData(cur_tensor); 8396be168c0dSopenharmony_ci+ MS_ASSERT(input_data != nullptr); 8397be168c0dSopenharmony_ci+ memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size); 8398be168c0dSopenharmony_ci+ } 8399be168c0dSopenharmony_ci+ return RET_OK; 8400be168c0dSopenharmony_ci+} 8401be168c0dSopenharmony_ci+ 8402be168c0dSopenharmony_ci+int NetTrainCApi::ReadInputFile() { 8403be168c0dSopenharmony_ci+ if (this->flags_->in_data_type_ == lite::kImage) { 8404be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported image input"; 8405be168c0dSopenharmony_ci+ return RET_ERROR; 8406be168c0dSopenharmony_ci+ } else { 8407be168c0dSopenharmony_ci+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8408be168c0dSopenharmony_ci+ OH_AI_TensorHandle tensor = 
ms_inputs_for_api_.handle_list[i]; 8409be168c0dSopenharmony_ci+ MS_ASSERT(tensor != nullptr); 8410be168c0dSopenharmony_ci+ size_t size; 8411be168c0dSopenharmony_ci+ std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin"; 8412be168c0dSopenharmony_ci+ auto bin_buf = lite::ReadFile(file_name.c_str(), &size); 8413be168c0dSopenharmony_ci+ if (bin_buf == nullptr) { 8414be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "ReadFile failed"; 8415be168c0dSopenharmony_ci+ return RET_ERROR; 8416be168c0dSopenharmony_ci+ } 8417be168c0dSopenharmony_ci+ auto tensor_data_size = OH_AI_TensorGetDataSize(tensor); 8418be168c0dSopenharmony_ci+ MS_ASSERT(tensor_data_size != 0); 8419be168c0dSopenharmony_ci+ if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) { 8420be168c0dSopenharmony_ci+ std::cerr << "Input binary file size error, required :N * " << tensor_data_size << ", in fact: " << size 8421be168c0dSopenharmony_ci+ << " ,file_name: " << file_name.c_str() << std::endl; 8422be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size 8423be168c0dSopenharmony_ci+ << " ,file_name: " << file_name.c_str(); 8424be168c0dSopenharmony_ci+ delete bin_buf; 8425be168c0dSopenharmony_ci+ return RET_ERROR; 8426be168c0dSopenharmony_ci+ } 8427be168c0dSopenharmony_ci+ inputs_buf_.emplace_back(bin_buf); 8428be168c0dSopenharmony_ci+ inputs_size_.emplace_back(size); 8429be168c0dSopenharmony_ci+ batch_num_ = size / tensor_data_size; 8430be168c0dSopenharmony_ci+ } 8431be168c0dSopenharmony_ci+ } 8432be168c0dSopenharmony_ci+ return RET_OK; 8433be168c0dSopenharmony_ci+} 8434be168c0dSopenharmony_ci+ 8435be168c0dSopenharmony_ci+int NetTrainCApi::InitDumpTensorDataCallbackParameter() { 8436be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported feature."; 8437be168c0dSopenharmony_ci+ return RET_ERROR; 8438be168c0dSopenharmony_ci+} 8439be168c0dSopenharmony_ci+ 
8440be168c0dSopenharmony_ci+int NetTrainCApi::InitTimeProfilingCallbackParameter() { 8441be168c0dSopenharmony_ci+ before_call_back_ = TimeProfilingBeforeCallback; 8442be168c0dSopenharmony_ci+ after_call_back_ = TimeProfilingAfterCallback; 8443be168c0dSopenharmony_ci+ return RET_OK; 8444be168c0dSopenharmony_ci+} 8445be168c0dSopenharmony_ci+ 8446be168c0dSopenharmony_ci+int NetTrainCApi::InitMSContext() { 8447be168c0dSopenharmony_ci+ context_ = OH_AI_ContextCreate(); 8448be168c0dSopenharmony_ci+ if (context_ == nullptr) { 8449be168c0dSopenharmony_ci+ MS_LOG(INFO) << "OH_AI_ContextCreate failed"; 8450be168c0dSopenharmony_ci+ return RET_ERROR; 8451be168c0dSopenharmony_ci+ } 8452be168c0dSopenharmony_ci+ OH_AI_ContextSetThreadNum(context_, flags_->num_threads_); 8453be168c0dSopenharmony_ci+ OH_AI_ContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_); 8454be168c0dSopenharmony_ci+ 8455be168c0dSopenharmony_ci+ OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 8456be168c0dSopenharmony_ci+ OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_); 8457be168c0dSopenharmony_ci+ OH_AI_ContextAddDeviceInfo(context_, cpu_device_info); 8458be168c0dSopenharmony_ci+ return RET_OK; 8459be168c0dSopenharmony_ci+} 8460be168c0dSopenharmony_ci+ 8461be168c0dSopenharmony_ci+char **NetTrainCApi::TransStrVectorToCharArrays(const std::vector<std::string> &s) { 8462be168c0dSopenharmony_ci+ char **char_arr = static_cast<char **>(malloc(s.size() * sizeof(char *))); 8463be168c0dSopenharmony_ci+ for (size_t i = 0; i < s.size(); i++) { 8464be168c0dSopenharmony_ci+ char_arr[i] = static_cast<char *>(malloc((s[i].size() + 1))); 8465be168c0dSopenharmony_ci+ strcpy(char_arr[i], s[i].c_str()); 8466be168c0dSopenharmony_ci+ } 8467be168c0dSopenharmony_ci+ return char_arr; 8468be168c0dSopenharmony_ci+} 8469be168c0dSopenharmony_ci+ 8470be168c0dSopenharmony_ci+std::vector<std::string> NetTrainCApi::TransCharArraysToStrVector(char **c, const size_t &num) { 
8471be168c0dSopenharmony_ci+ std::vector<std::string> str; 8472be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 8473be168c0dSopenharmony_ci+ str.push_back(std::string(c[i])); 8474be168c0dSopenharmony_ci+ } 8475be168c0dSopenharmony_ci+ return str; 8476be168c0dSopenharmony_ci+} 8477be168c0dSopenharmony_ci+ 8478be168c0dSopenharmony_ci+void NetTrainCApi::InitTrainCfg() { 8479be168c0dSopenharmony_ci+ if (flags_->loss_name_.empty()) { 8480be168c0dSopenharmony_ci+ return; 8481be168c0dSopenharmony_ci+ } 8482be168c0dSopenharmony_ci+ 8483be168c0dSopenharmony_ci+ std::string delimiter = ","; 8484be168c0dSopenharmony_ci+ size_t pos = 0; 8485be168c0dSopenharmony_ci+ std::string token; 8486be168c0dSopenharmony_ci+ train_cfg_ = OH_AI_TrainCfgCreate(); 8487be168c0dSopenharmony_ci+ size_t num = 0; 8488be168c0dSopenharmony_ci+ std::vector<std::string> train_cfg_loss_name; 8489be168c0dSopenharmony_ci+ OH_AI_TrainCfgSetLossName(train_cfg_, nullptr, train_cfg_loss_name.size()); 8490be168c0dSopenharmony_ci+ while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) { 8491be168c0dSopenharmony_ci+ token = flags_->loss_name_.substr(0, pos); 8492be168c0dSopenharmony_ci+ flags_->loss_name_.erase(0, pos + delimiter.length()); // change to delim without deletion 8493be168c0dSopenharmony_ci+ char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num); 8494be168c0dSopenharmony_ci+ train_cfg_loss_name = TransCharArraysToStrVector(name, num); 8495be168c0dSopenharmony_ci+ train_cfg_loss_name.push_back(token); 8496be168c0dSopenharmony_ci+ char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name); 8497be168c0dSopenharmony_ci+ OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size()); 8498be168c0dSopenharmony_ci+ for (size_t i = 0; i < train_cfg_loss_name.size(); i++) { 8499be168c0dSopenharmony_ci+ free(loss_name[i]); 8500be168c0dSopenharmony_ci+ } 8501be168c0dSopenharmony_ci+ free(loss_name); 8502be168c0dSopenharmony_ci+ for 
(size_t i = 0; i < num; i++) { 8503be168c0dSopenharmony_ci+ free(name[i]); 8504be168c0dSopenharmony_ci+ } 8505be168c0dSopenharmony_ci+ free(name); 8506be168c0dSopenharmony_ci+ } 8507be168c0dSopenharmony_ci+ if (!(flags_->loss_name_.empty())) { 8508be168c0dSopenharmony_ci+ char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num); 8509be168c0dSopenharmony_ci+ train_cfg_loss_name = TransCharArraysToStrVector(name, num); 8510be168c0dSopenharmony_ci+ train_cfg_loss_name.push_back(flags_->loss_name_); 8511be168c0dSopenharmony_ci+ char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name); 8512be168c0dSopenharmony_ci+ OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size()); 8513be168c0dSopenharmony_ci+ for (size_t i = 0; i < train_cfg_loss_name.size(); i++) { 8514be168c0dSopenharmony_ci+ free(loss_name[i]); 8515be168c0dSopenharmony_ci+ } 8516be168c0dSopenharmony_ci+ free(loss_name); 8517be168c0dSopenharmony_ci+ for (size_t i = 0; i < num; i++) { 8518be168c0dSopenharmony_ci+ free(name[i]); 8519be168c0dSopenharmony_ci+ } 8520be168c0dSopenharmony_ci+ free(name); 8521be168c0dSopenharmony_ci+ } 8522be168c0dSopenharmony_ci+} 8523be168c0dSopenharmony_ci+ 8524be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetworkForInference(const std::string &filename, 8525be168c0dSopenharmony_ci+ const OH_AI_ContextHandle &context) { 8526be168c0dSopenharmony_ci+ std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); 8527be168c0dSopenharmony_ci+ std::string filenamems = filename; 8528be168c0dSopenharmony_ci+ if (filenamems.substr(filenamems.find_last_of('.') + 1) != "ms") { 8529be168c0dSopenharmony_ci+ filenamems = filenamems + ".ms"; 8530be168c0dSopenharmony_ci+ } 8531be168c0dSopenharmony_ci+ MS_LOG(INFO) << "start reading model file " << filenamems.c_str(); 8532be168c0dSopenharmony_ci+ std::cout << "start reading model file " << filenamems.c_str() << std::endl; 8533be168c0dSopenharmony_ci+ auto status = 
OH_AI_ModelBuildFromFile(ms_model_, filenamems.c_str(), 8534be168c0dSopenharmony_ci+ static_cast<OH_AI_ModelType>(mindspore::kMindIR), context); 8535be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8536be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "ms model build failed. " << model_name; 8537be168c0dSopenharmony_ci+ return RET_ERROR; 8538be168c0dSopenharmony_ci+ } 8539be168c0dSopenharmony_ci+ return RET_OK; 8540be168c0dSopenharmony_ci+} 8541be168c0dSopenharmony_ci+ 8542be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 8543be168c0dSopenharmony_ci+ const OH_AI_ContextHandle &context, 8544be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle &train_cfg, int epochs) { 8545be168c0dSopenharmony_ci+ std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1); 8546be168c0dSopenharmony_ci+ OH_AI_Status status; 8547be168c0dSopenharmony_ci+ if (!bb_filename.empty()) { 8548be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "build transfer learning not supported. " << model_name; 8549be168c0dSopenharmony_ci+ return RET_ERROR; 8550be168c0dSopenharmony_ci+ } else { 8551be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Build mindspore model from model file" << filename.c_str(); 8552be168c0dSopenharmony_ci+ std::cout << "Build mindspore model from model file" << filename.c_str() << std::endl; 8553be168c0dSopenharmony_ci+ status = OH_AI_TrainModelBuildFromFile(ms_model_, filename.c_str(), OH_AI_MODELTYPE_MINDIR, context, train_cfg); 8554be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8555be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "build transfer learning failed. 
" << model_name; 8556be168c0dSopenharmony_ci+ return RET_ERROR; 8557be168c0dSopenharmony_ci+ } 8558be168c0dSopenharmony_ci+ } 8559be168c0dSopenharmony_ci+ if (epochs > 0) { 8560be168c0dSopenharmony_ci+ if (flags_->virtual_batch_) { 8561be168c0dSopenharmony_ci+ OH_AI_ModelSetupVirtualBatch(ms_model_, epochs, -1.0f, -1.0f); 8562be168c0dSopenharmony_ci+ } 8563be168c0dSopenharmony_ci+ status = OH_AI_ModelSetTrainMode(ms_model_, true); 8564be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8565be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "set train mode failed. "; 8566be168c0dSopenharmony_ci+ return RET_ERROR; 8567be168c0dSopenharmony_ci+ } 8568be168c0dSopenharmony_ci+ } 8569be168c0dSopenharmony_ci+ return RET_OK; 8570be168c0dSopenharmony_ci+} 8571be168c0dSopenharmony_ci+ 8572be168c0dSopenharmony_ci+int NetTrainCApi::CompareOutput() { 8573be168c0dSopenharmony_ci+ std::cout << "================ Comparing Forward Output data ================" << std::endl; 8574be168c0dSopenharmony_ci+ float total_bias = 0; 8575be168c0dSopenharmony_ci+ int total_size = 0; 8576be168c0dSopenharmony_ci+ bool has_error = false; 8577be168c0dSopenharmony_ci+ auto output_tensors_handle = OH_AI_ModelGetOutputs(ms_model_); 8578be168c0dSopenharmony_ci+ 8579be168c0dSopenharmony_ci+ std::vector<mindspore::MSTensor> output_tensors; 8580be168c0dSopenharmony_ci+ for (size_t i = 0; i < output_tensors_handle.handle_num; i++) { 8581be168c0dSopenharmony_ci+ output_tensors.push_back(*static_cast<mindspore::MSTensor *>(output_tensors_handle.handle_list[i])); 8582be168c0dSopenharmony_ci+ } 8583be168c0dSopenharmony_ci+ if (output_tensors.empty()) { 8584be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Cannot find output tensors, get model output failed"; 8585be168c0dSopenharmony_ci+ return RET_ERROR; 8586be168c0dSopenharmony_ci+ } 8587be168c0dSopenharmony_ci+ std::map<std::string, MSTensor> ordered_outputs; 8588be168c0dSopenharmony_ci+ for (const auto &output_tensor : output_tensors) { 
8589be168c0dSopenharmony_ci+ ordered_outputs.insert({output_tensor.Name(), output_tensor}); 8590be168c0dSopenharmony_ci+ } 8591be168c0dSopenharmony_ci+ int i = 1; 8592be168c0dSopenharmony_ci+ mindspore::MSTensor tensor; 8593be168c0dSopenharmony_ci+ for (auto &ordered_output : ordered_outputs) { 8594be168c0dSopenharmony_ci+ tensor = ordered_output.second; 8595be168c0dSopenharmony_ci+ std::cout << "output is tensor " << ordered_output.first << "\n"; 8596be168c0dSopenharmony_ci+ auto outputs = tensor.MutableData(); 8597be168c0dSopenharmony_ci+ size_t size; 8598be168c0dSopenharmony_ci+ std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin"; 8599be168c0dSopenharmony_ci+ auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size)); 8600be168c0dSopenharmony_ci+ if (bin_buf == nullptr) { 8601be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "ReadFile return nullptr"; 8602be168c0dSopenharmony_ci+ std::cout << "ReadFile return nullptr" << std::endl; 8603be168c0dSopenharmony_ci+ return RET_ERROR; 8604be168c0dSopenharmony_ci+ } 8605be168c0dSopenharmony_ci+ if (size != tensor.DataSize()) { 8606be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize() 8607be168c0dSopenharmony_ci+ << ", read size: " << size; 8608be168c0dSopenharmony_ci+ std::cout << "Output buffer and output file differ by size. 
Tensor size: " << tensor.DataSize() 8609be168c0dSopenharmony_ci+ << ", read size: " << size << std::endl; 8610be168c0dSopenharmony_ci+ return RET_ERROR; 8611be168c0dSopenharmony_ci+ } 8612be168c0dSopenharmony_ci+ float bias = CompareData<float>(bin_buf.get(), tensor.ElementNum(), reinterpret_cast<float *>(outputs)); 8613be168c0dSopenharmony_ci+ if (bias >= 0) { 8614be168c0dSopenharmony_ci+ total_bias += bias; 8615be168c0dSopenharmony_ci+ total_size++; 8616be168c0dSopenharmony_ci+ } else { 8617be168c0dSopenharmony_ci+ has_error = true; 8618be168c0dSopenharmony_ci+ break; 8619be168c0dSopenharmony_ci+ } 8620be168c0dSopenharmony_ci+ i++; 8621be168c0dSopenharmony_ci+ } 8622be168c0dSopenharmony_ci+ 8623be168c0dSopenharmony_ci+ if (!has_error) { 8624be168c0dSopenharmony_ci+ float mean_bias; 8625be168c0dSopenharmony_ci+ if (total_size != 0) { 8626be168c0dSopenharmony_ci+ mean_bias = total_bias / total_size * 100; 8627be168c0dSopenharmony_ci+ } else { 8628be168c0dSopenharmony_ci+ mean_bias = 0; 8629be168c0dSopenharmony_ci+ } 8630be168c0dSopenharmony_ci+ 8631be168c0dSopenharmony_ci+ std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%" 8632be168c0dSopenharmony_ci+ << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl; 8633be168c0dSopenharmony_ci+ std::cout << "=======================================================" << std::endl << std::endl; 8634be168c0dSopenharmony_ci+ 8635be168c0dSopenharmony_ci+ if (mean_bias > this->flags_->accuracy_threshold_) { 8636be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%"; 8637be168c0dSopenharmony_ci+ std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl; 8638be168c0dSopenharmony_ci+ return RET_TOO_BIG; 8639be168c0dSopenharmony_ci+ } else { 8640be168c0dSopenharmony_ci+ return RET_OK; 8641be168c0dSopenharmony_ci+ } 8642be168c0dSopenharmony_ci+ } else { 8643be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Error in CompareData"; 
8644be168c0dSopenharmony_ci+ std::cerr << "Error in CompareData" << std::endl; 8645be168c0dSopenharmony_ci+ std::cout << "=======================================================" << std::endl << std::endl; 8646be168c0dSopenharmony_ci+ return RET_ERROR; 8647be168c0dSopenharmony_ci+ } 8648be168c0dSopenharmony_ci+} 8649be168c0dSopenharmony_ci+ 8650be168c0dSopenharmony_ci+int NetTrainCApi::MarkPerformance() { 8651be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Running train loops..."; 8652be168c0dSopenharmony_ci+ std::cout << "Running train loops..." << std::endl; 8653be168c0dSopenharmony_ci+ uint64_t time_min = 0xFFFFFFFFFFFFFFFF; 8654be168c0dSopenharmony_ci+ uint64_t time_max = 0; 8655be168c0dSopenharmony_ci+ uint64_t time_avg = 0; 8656be168c0dSopenharmony_ci+ std::vector<MSTensor> outputs; 8657be168c0dSopenharmony_ci+ 8658be168c0dSopenharmony_ci+ for (int i = 0; i < flags_->epochs_; i++) { 8659be168c0dSopenharmony_ci+ auto start = GetTimeUs(); 8660be168c0dSopenharmony_ci+ for (size_t step = 0; step < batch_num_; step++) { 8661be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step; 8662be168c0dSopenharmony_ci+ auto ret = LoadStepInput(step); 8663be168c0dSopenharmony_ci+ if (ret != RET_OK) { 8664be168c0dSopenharmony_ci+ return ret; 8665be168c0dSopenharmony_ci+ } 8666be168c0dSopenharmony_ci+ auto status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_); 8667be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8668be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Inference error " << status; 8669be168c0dSopenharmony_ci+ std::cerr << "Inference error " << status; 8670be168c0dSopenharmony_ci+ return RET_ERROR; 8671be168c0dSopenharmony_ci+ } 8672be168c0dSopenharmony_ci+ } 8673be168c0dSopenharmony_ci+ 8674be168c0dSopenharmony_ci+ auto end = GetTimeUs(); 8675be168c0dSopenharmony_ci+ auto time = end - start; 8676be168c0dSopenharmony_ci+ time_min = std::min(time_min, time); 8677be168c0dSopenharmony_ci+ time_max = std::max(time_max, time); 
8678be168c0dSopenharmony_ci+ time_avg += time; 8679be168c0dSopenharmony_ci+ } 8680be168c0dSopenharmony_ci+ 8681be168c0dSopenharmony_ci+ if (flags_->time_profiling_) { 8682be168c0dSopenharmony_ci+ const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; 8683be168c0dSopenharmony_ci+ const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"}; 8684be168c0dSopenharmony_ci+ PrintResult(per_op_name, g_c_op_times_by_name_); 8685be168c0dSopenharmony_ci+ PrintResult(per_op_type, g_c_op_times_by_type_); 8686be168c0dSopenharmony_ci+ } 8687be168c0dSopenharmony_ci+ 8688be168c0dSopenharmony_ci+ if (flags_->epochs_ > 0) { 8689be168c0dSopenharmony_ci+ time_avg /= static_cast<size_t>(flags_->epochs_); 8690be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str() 8691be168c0dSopenharmony_ci+ << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f 8692be168c0dSopenharmony_ci+ << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f; 8693be168c0dSopenharmony_ci+ printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n", 8694be168c0dSopenharmony_ci+ flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_, 8695be168c0dSopenharmony_ci+ time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f); 8696be168c0dSopenharmony_ci+ } 8697be168c0dSopenharmony_ci+ return RET_OK; 8698be168c0dSopenharmony_ci+} 8699be168c0dSopenharmony_ci+ 8700be168c0dSopenharmony_ci+int NetTrainCApi::MarkAccuracy(bool enforce_accuracy) { 8701be168c0dSopenharmony_ci+ MS_LOG(INFO) << "MarkAccuracy"; 8702be168c0dSopenharmony_ci+ auto load_ret = LoadStepInput(0); 8703be168c0dSopenharmony_ci+ if (load_ret != RET_OK) { 8704be168c0dSopenharmony_ci+ return load_ret; 
8705be168c0dSopenharmony_ci+ } 8706be168c0dSopenharmony_ci+ auto status = PrintInputData(); 8707be168c0dSopenharmony_ci+ if (status != RET_OK) { 8708be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "PrintInputData failed, ret: " << status; 8709be168c0dSopenharmony_ci+ return status; 8710be168c0dSopenharmony_ci+ } 8711be168c0dSopenharmony_ci+ status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_); 8712be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8713be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Inference error " << status; 8714be168c0dSopenharmony_ci+ std::cerr << "Inference error " << status << std::endl; 8715be168c0dSopenharmony_ci+ return RET_ERROR; 8716be168c0dSopenharmony_ci+ } 8717be168c0dSopenharmony_ci+ 8718be168c0dSopenharmony_ci+ auto ret = CompareOutput(); 8719be168c0dSopenharmony_ci+ if (ret == RET_TOO_BIG && !enforce_accuracy) { 8720be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Accuracy Error is big but not enforced"; 8721be168c0dSopenharmony_ci+ std::cout << "Accuracy Error is big but not enforced" << std::endl; 8722be168c0dSopenharmony_ci+ return RET_OK; 8723be168c0dSopenharmony_ci+ } 8724be168c0dSopenharmony_ci+ 8725be168c0dSopenharmony_ci+ if (ret != RET_OK) { 8726be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Compare output error " << ret; 8727be168c0dSopenharmony_ci+ std::cerr << "Compare output error " << ret << std::endl; 8728be168c0dSopenharmony_ci+ return ret; 8729be168c0dSopenharmony_ci+ } 8730be168c0dSopenharmony_ci+ return RET_OK; 8731be168c0dSopenharmony_ci+} 8732be168c0dSopenharmony_ci+ 8733be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, 8734be168c0dSopenharmony_ci+ int epochs, bool check_accuracy) { 8735be168c0dSopenharmony_ci+ auto start_prepare_time = GetTimeUs(); 8736be168c0dSopenharmony_ci+ 8737be168c0dSopenharmony_ci+ int ret = InitMSContext(); 8738be168c0dSopenharmony_ci+ if (ret != RET_OK) { 8739be168c0dSopenharmony_ci+ 
MS_LOG(ERROR) << "InitContext failed, ret: " << ret; 8740be168c0dSopenharmony_ci+ return ret; 8741be168c0dSopenharmony_ci+ } 8742be168c0dSopenharmony_ci+ 8743be168c0dSopenharmony_ci+ InitTrainCfg(); 8744be168c0dSopenharmony_ci+ ms_model_ = OH_AI_ModelCreate(); 8745be168c0dSopenharmony_ci+ 8746be168c0dSopenharmony_ci+ if (is_train) { 8747be168c0dSopenharmony_ci+ ret = CreateAndRunNetworkForTrain(filename, bb_filename, context_ , train_cfg_, epochs); 8748be168c0dSopenharmony_ci+ if (ret != RET_OK) { 8749be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed."; 8750be168c0dSopenharmony_ci+ return RET_ERROR; 8751be168c0dSopenharmony_ci+ } 8752be168c0dSopenharmony_ci+ } else { 8753be168c0dSopenharmony_ci+ ret = CreateAndRunNetworkForInference(filename, context_); 8754be168c0dSopenharmony_ci+ if (ret != RET_OK) { 8755be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed."; 8756be168c0dSopenharmony_ci+ return RET_ERROR; 8757be168c0dSopenharmony_ci+ } 8758be168c0dSopenharmony_ci+ } 8759be168c0dSopenharmony_ci+ 8760be168c0dSopenharmony_ci+ ms_inputs_for_api_ = OH_AI_ModelGetInputs(ms_model_); 8761be168c0dSopenharmony_ci+ if (ms_inputs_for_api_.handle_list == nullptr) { 8762be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "OH_AI_ModelGetInputs failed, ret: "; 8763be168c0dSopenharmony_ci+ return RET_ERROR; 8764be168c0dSopenharmony_ci+ } 8765be168c0dSopenharmony_ci+ 8766be168c0dSopenharmony_ci+ if (!flags_->resize_dims_.empty()) { 8767be168c0dSopenharmony_ci+ std::vector<OH_AI_ShapeInfo> shape_infos; 8768be168c0dSopenharmony_ci+ std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(shape_infos), 8769be168c0dSopenharmony_ci+ [&](auto &shapes) { 8770be168c0dSopenharmony_ci+ OH_AI_ShapeInfo shape_info; 8771be168c0dSopenharmony_ci+ shape_info.shape_num = shapes.size(); 8772be168c0dSopenharmony_ci+ for (size_t i = 0; i < shape_info.shape_num; i++) { 8773be168c0dSopenharmony_ci+ shape_info.shape[i] = 
shapes[i]; 8774be168c0dSopenharmony_ci+ } 8775be168c0dSopenharmony_ci+ return shape_info; 8776be168c0dSopenharmony_ci+ }); 8777be168c0dSopenharmony_ci+ auto status = OH_AI_ModelResize(ms_model_, ms_inputs_for_api_, shape_infos.data(), shape_infos.size()); 8778be168c0dSopenharmony_ci+ if (status != OH_AI_STATUS_SUCCESS) { 8779be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Input tensor resize failed."; 8780be168c0dSopenharmony_ci+ std::cout << "Input tensor resize failed."; 8781be168c0dSopenharmony_ci+ return RET_ERROR; 8782be168c0dSopenharmony_ci+ } 8783be168c0dSopenharmony_ci+ } 8784be168c0dSopenharmony_ci+ 8785be168c0dSopenharmony_ci+ auto end_prepare_time = GetTimeUs(); 8786be168c0dSopenharmony_ci+ MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms"; 8787be168c0dSopenharmony_ci+ std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl; 8788be168c0dSopenharmony_ci+ // Load input 8789be168c0dSopenharmony_ci+ MS_LOG(INFO) << "Load input data"; 8790be168c0dSopenharmony_ci+ auto status = LoadInput(); 8791be168c0dSopenharmony_ci+ if (status != RET_OK) { 8792be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Load input data error"; 8793be168c0dSopenharmony_ci+ std::cout << "Load input data error" << std::endl; 8794be168c0dSopenharmony_ci+ return status; 8795be168c0dSopenharmony_ci+ } 8796be168c0dSopenharmony_ci+ 8797be168c0dSopenharmony_ci+ if ((epochs > 0) && is_train) { 8798be168c0dSopenharmony_ci+ status = MarkPerformance(); 8799be168c0dSopenharmony_ci+ if (status != RET_OK) { 8800be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Run MarkPerformance error: " << status; 8801be168c0dSopenharmony_ci+ std::cout << "Run MarkPerformance error: " << status << std::endl; 8802be168c0dSopenharmony_ci+ return status; 8803be168c0dSopenharmony_ci+ } 8804be168c0dSopenharmony_ci+ SaveModels(); // save file if flags are on 8805be168c0dSopenharmony_ci+ } 8806be168c0dSopenharmony_ci+ if 
(!flags_->data_file_.empty()) { 8807be168c0dSopenharmony_ci+ auto res = OH_AI_ModelSetTrainMode(ms_model_, false); 8808be168c0dSopenharmony_ci+ if (res != OH_AI_STATUS_SUCCESS) { 8809be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "set eval mode failed. "; 8810be168c0dSopenharmony_ci+ return RET_ERROR; 8811be168c0dSopenharmony_ci+ } 8812be168c0dSopenharmony_ci+ 8813be168c0dSopenharmony_ci+ status = MarkAccuracy(check_accuracy); 8814be168c0dSopenharmony_ci+ if (status != RET_OK) { 8815be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Run MarkAccuracy error: " << status; 8816be168c0dSopenharmony_ci+ std::cout << "Run MarkAccuracy error: " << status << std::endl; 8817be168c0dSopenharmony_ci+ return status; 8818be168c0dSopenharmony_ci+ } 8819be168c0dSopenharmony_ci+ } 8820be168c0dSopenharmony_ci+ return RET_OK; 8821be168c0dSopenharmony_ci+} 8822be168c0dSopenharmony_ci+ 8823be168c0dSopenharmony_ci+int NetTrainCApi::PrintInputData() { 8824be168c0dSopenharmony_ci+ constexpr int64_t kPrintDataNum = 20; 8825be168c0dSopenharmony_ci+ for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) { 8826be168c0dSopenharmony_ci+ auto input = ms_inputs_for_api_.handle_list[i]; 8827be168c0dSopenharmony_ci+ std::cout << "InData" << i << ": "; 8828be168c0dSopenharmony_ci+ auto data_type = static_cast<TypeId>(OH_AI_TensorGetDataType(input)); 8829be168c0dSopenharmony_ci+ if (data_type == TypeId::kObjectTypeString) { 8830be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING."; 8831be168c0dSopenharmony_ci+ return RET_ERROR; 8832be168c0dSopenharmony_ci+ } 8833be168c0dSopenharmony_ci+ auto tensor_data = OH_AI_TensorGetData(input); 8834be168c0dSopenharmony_ci+ size_t print_num = std::min(OH_AI_TensorGetElementNum(input), kPrintDataNum); 8835be168c0dSopenharmony_ci+ for (size_t j = 0; j < print_num; j++) { 8836be168c0dSopenharmony_ci+ if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat) { 8837be168c0dSopenharmony_ci+ std::cout << 
static_cast<const float *>(tensor_data)[j] << " "; 8838be168c0dSopenharmony_ci+ } else if (data_type == TypeId::kNumberTypeInt8) { 8839be168c0dSopenharmony_ci+ std::cout << static_cast<const int8_t *>(tensor_data)[j] << " "; 8840be168c0dSopenharmony_ci+ } else if (data_type == TypeId::kNumberTypeUInt8) { 8841be168c0dSopenharmony_ci+ std::cout << static_cast<const uint8_t *>(tensor_data)[j] << " "; 8842be168c0dSopenharmony_ci+ } else if (data_type == TypeId::kNumberTypeInt32) { 8843be168c0dSopenharmony_ci+ std::cout << static_cast<const int32_t *>(tensor_data)[j] << " "; 8844be168c0dSopenharmony_ci+ } else if (data_type == TypeId::kNumberTypeInt64) { 8845be168c0dSopenharmony_ci+ std::cout << static_cast<const int64_t *>(tensor_data)[j] << " "; 8846be168c0dSopenharmony_ci+ } else if (data_type == TypeId::kNumberTypeBool) { 8847be168c0dSopenharmony_ci+ std::cout << static_cast<const bool *>(tensor_data)[j] << " "; 8848be168c0dSopenharmony_ci+ } else { 8849be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Datatype: " << data_type << " is not supported."; 8850be168c0dSopenharmony_ci+ return RET_ERROR; 8851be168c0dSopenharmony_ci+ } 8852be168c0dSopenharmony_ci+ } 8853be168c0dSopenharmony_ci+ std::cout << std::endl; 8854be168c0dSopenharmony_ci+ } 8855be168c0dSopenharmony_ci+ return RET_OK; 8856be168c0dSopenharmony_ci+} 8857be168c0dSopenharmony_ci+ 8858be168c0dSopenharmony_ci+int NetTrainCApi::PrintResult(const std::vector<std::string> &title, 8859be168c0dSopenharmony_ci+ const std::map<std::string, std::pair<int, float>> &result) { 8860be168c0dSopenharmony_ci+ std::vector<size_t> columnLenMax(kFieldsToPrint); 8861be168c0dSopenharmony_ci+ std::vector<std::vector<std::string>> rows; 8862be168c0dSopenharmony_ci+ 8863be168c0dSopenharmony_ci+ for (auto &iter : result) { 8864be168c0dSopenharmony_ci+ std::string stringBuf[kFieldsToPrint]; 8865be168c0dSopenharmony_ci+ std::vector<std::string> columns; 8866be168c0dSopenharmony_ci+ size_t len = 0; 8867be168c0dSopenharmony_ci+ int index = 
0; 8868be168c0dSopenharmony_ci+ len = iter.first.size(); 8869be168c0dSopenharmony_ci+ if (len > columnLenMax.at(index)) { 8870be168c0dSopenharmony_ci+ columnLenMax.at(index) = len + kPrintOffset; 8871be168c0dSopenharmony_ci+ } 8872be168c0dSopenharmony_ci+ columns.push_back(iter.first); 8873be168c0dSopenharmony_ci+ 8874be168c0dSopenharmony_ci+ index++; 8875be168c0dSopenharmony_ci+ if (title[0] == "opName") { 8876be168c0dSopenharmony_ci+ stringBuf[index] = std::to_string(iter.second.second / flags_->epochs_); 8877be168c0dSopenharmony_ci+ } else { 8878be168c0dSopenharmony_ci+ stringBuf[index] = std::to_string(iter.second.second / iter.second.first); 8879be168c0dSopenharmony_ci+ } 8880be168c0dSopenharmony_ci+ len = stringBuf[index].length(); 8881be168c0dSopenharmony_ci+ if (len > columnLenMax.at(index)) { 8882be168c0dSopenharmony_ci+ columnLenMax.at(index) = len + kPrintOffset; 8883be168c0dSopenharmony_ci+ } 8884be168c0dSopenharmony_ci+ columns.emplace_back(stringBuf[index]); 8885be168c0dSopenharmony_ci+ 8886be168c0dSopenharmony_ci+ index++; 8887be168c0dSopenharmony_ci+ stringBuf[index] = std::to_string(iter.second.second / g_op_cost_total_); 8888be168c0dSopenharmony_ci+ len = stringBuf[index].length(); 8889be168c0dSopenharmony_ci+ if (len > columnLenMax.at(index)) { 8890be168c0dSopenharmony_ci+ columnLenMax.at(index) = len + kPrintOffset; 8891be168c0dSopenharmony_ci+ } 8892be168c0dSopenharmony_ci+ columns.emplace_back(stringBuf[index]); 8893be168c0dSopenharmony_ci+ 8894be168c0dSopenharmony_ci+ index++; 8895be168c0dSopenharmony_ci+ stringBuf[index] = std::to_string(iter.second.first); 8896be168c0dSopenharmony_ci+ len = stringBuf[index].length(); 8897be168c0dSopenharmony_ci+ if (len > columnLenMax.at(index)) { 8898be168c0dSopenharmony_ci+ columnLenMax.at(index) = len + kPrintOffset; 8899be168c0dSopenharmony_ci+ } 8900be168c0dSopenharmony_ci+ columns.emplace_back(stringBuf[index]); 8901be168c0dSopenharmony_ci+ 8902be168c0dSopenharmony_ci+ index++; 
8903be168c0dSopenharmony_ci+ stringBuf[index] = std::to_string(iter.second.second); 8904be168c0dSopenharmony_ci+ len = stringBuf[index].length(); 8905be168c0dSopenharmony_ci+ if (len > columnLenMax.at(index)) { 8906be168c0dSopenharmony_ci+ columnLenMax.at(index) = len + kPrintOffset; 8907be168c0dSopenharmony_ci+ } 8908be168c0dSopenharmony_ci+ columns.emplace_back(stringBuf[index]); 8909be168c0dSopenharmony_ci+ 8910be168c0dSopenharmony_ci+ rows.push_back(columns); 8911be168c0dSopenharmony_ci+ } 8912be168c0dSopenharmony_ci+ 8913be168c0dSopenharmony_ci+ printf("-------------------------------------------------------------------------\n"); 8914be168c0dSopenharmony_ci+ for (int i = 0; i < kFieldsToPrint; i++) { 8915be168c0dSopenharmony_ci+ auto printBuf = title[i]; 8916be168c0dSopenharmony_ci+ if (printBuf.size() > columnLenMax.at(i)) { 8917be168c0dSopenharmony_ci+ columnLenMax.at(i) = printBuf.size(); 8918be168c0dSopenharmony_ci+ } 8919be168c0dSopenharmony_ci+ printBuf.resize(columnLenMax.at(i), ' '); 8920be168c0dSopenharmony_ci+ printf("%s\t", printBuf.c_str()); 8921be168c0dSopenharmony_ci+ } 8922be168c0dSopenharmony_ci+ printf("\n"); 8923be168c0dSopenharmony_ci+ for (auto &row : rows) { 8924be168c0dSopenharmony_ci+ for (int j = 0; j < kFieldsToPrint; j++) { 8925be168c0dSopenharmony_ci+ auto printBuf = row[j]; 8926be168c0dSopenharmony_ci+ printBuf.resize(columnLenMax.at(j), ' '); 8927be168c0dSopenharmony_ci+ printf("%s\t", printBuf.c_str()); 8928be168c0dSopenharmony_ci+ } 8929be168c0dSopenharmony_ci+ printf("\n"); 8930be168c0dSopenharmony_ci+ } 8931be168c0dSopenharmony_ci+ return RET_OK; 8932be168c0dSopenharmony_ci+} 8933be168c0dSopenharmony_ci+ 8934be168c0dSopenharmony_ci+bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 8935be168c0dSopenharmony_ci+ const OH_AI_CallBackParam kernel_Info) { 8936be168c0dSopenharmony_ci+ if (g_c_op_times_by_type_.find(kernel_Info.node_type) == g_c_op_times_by_type_.end()) { 
8937be168c0dSopenharmony_ci+ g_c_op_times_by_type_.insert(std::make_pair(kernel_Info.node_type, std::make_pair(0, 0.0f))); 8938be168c0dSopenharmony_ci+ } 8939be168c0dSopenharmony_ci+ if (g_c_op_times_by_name_.find(kernel_Info.node_name) == g_c_op_times_by_name_.end()) { 8940be168c0dSopenharmony_ci+ g_c_op_times_by_name_.insert(std::make_pair(kernel_Info.node_name, std::make_pair(0, 0.0f))); 8941be168c0dSopenharmony_ci+ } 8942be168c0dSopenharmony_ci+ 8943be168c0dSopenharmony_ci+ g_op_call_times_total_++; 8944be168c0dSopenharmony_ci+ g_op_begin_ = mindspore::lite::GetTimeUs(); 8945be168c0dSopenharmony_ci+ return true; 8946be168c0dSopenharmony_ci+} 8947be168c0dSopenharmony_ci+ 8948be168c0dSopenharmony_ci+bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 8949be168c0dSopenharmony_ci+ const OH_AI_CallBackParam kernel_Info) { 8950be168c0dSopenharmony_ci+ uint64_t opEnd = mindspore::lite::GetTimeUs(); 8951be168c0dSopenharmony_ci+ float cost = static_cast<float>(opEnd - g_op_begin_) / 1000.0f; 8952be168c0dSopenharmony_ci+ g_op_cost_total_ += cost; 8953be168c0dSopenharmony_ci+ g_c_op_times_by_type_[kernel_Info.node_type].first++; 8954be168c0dSopenharmony_ci+ g_c_op_times_by_type_[kernel_Info.node_type].second += cost; 8955be168c0dSopenharmony_ci+ g_c_op_times_by_name_[kernel_Info.node_name].first++; 8956be168c0dSopenharmony_ci+ g_c_op_times_by_name_[kernel_Info.node_name].second += cost; 8957be168c0dSopenharmony_ci+ return true; 8958be168c0dSopenharmony_ci+} 8959be168c0dSopenharmony_ci+} // namespace lite 8960be168c0dSopenharmony_ci+} // namespace mindspore 8961be168c0dSopenharmony_ci+ 8962be168c0dSopenharmony_ci+ 8963be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.h b/mindspore/lite/tools/benchmark_train/net_train_c_api.h 8964be168c0dSopenharmony_cinew file mode 100644 8965be168c0dSopenharmony_ciindex 00000000..bb84d3c1 8966be168c0dSopenharmony_ci--- /dev/null 
8967be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.h 8968be168c0dSopenharmony_ci@@ -0,0 +1,121 @@ 8969be168c0dSopenharmony_ci+/** 8970be168c0dSopenharmony_ci+ * Copyright 2023-2023 Huawei Technologies Co., Ltd 8971be168c0dSopenharmony_ci+ * 8972be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 8973be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 8974be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 8975be168c0dSopenharmony_ci+ * 8976be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 8977be168c0dSopenharmony_ci+ * 8978be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 8979be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 8980be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8981be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 8982be168c0dSopenharmony_ci+ * limitations under the License. 
8983be168c0dSopenharmony_ci+ */ 8984be168c0dSopenharmony_ci+ 8985be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H 8986be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H 8987be168c0dSopenharmony_ci+ 8988be168c0dSopenharmony_ci+#include <getopt.h> 8989be168c0dSopenharmony_ci+#include <csignal> 8990be168c0dSopenharmony_ci+#include <unordered_map> 8991be168c0dSopenharmony_ci+#include <fstream> 8992be168c0dSopenharmony_ci+#include <iostream> 8993be168c0dSopenharmony_ci+#include <map> 8994be168c0dSopenharmony_ci+#include <cmath> 8995be168c0dSopenharmony_ci+#include <string> 8996be168c0dSopenharmony_ci+#include <vector> 8997be168c0dSopenharmony_ci+#include <memory> 8998be168c0dSopenharmony_ci+#include <cfloat> 8999be168c0dSopenharmony_ci+#include <utility> 9000be168c0dSopenharmony_ci+#include <algorithm> 9001be168c0dSopenharmony_ci+#include <nlohmann/json.hpp> 9002be168c0dSopenharmony_ci+#include "include/api/model.h" 9003be168c0dSopenharmony_ci+#include "include/api/types.h" 9004be168c0dSopenharmony_ci+#include "include/api/context.h" 9005be168c0dSopenharmony_ci+#include "include/api/cfg.h" 9006be168c0dSopenharmony_ci+ 9007be168c0dSopenharmony_ci+#include "include/c_api/model_c.h" 9008be168c0dSopenharmony_ci+#include "include/c_api/context_c.h" 9009be168c0dSopenharmony_ci+ 9010be168c0dSopenharmony_ci+#ifdef ENABLE_FP16 9011be168c0dSopenharmony_ci+#include <arm_neon.h> 9012be168c0dSopenharmony_ci+#endif 9013be168c0dSopenharmony_ci+#include "tools/common/flag_parser.h" 9014be168c0dSopenharmony_ci+#include "src/common/file_utils.h" 9015be168c0dSopenharmony_ci+#include "src/common/utils.h" 9016be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h" 9017be168c0dSopenharmony_ci+ 9018be168c0dSopenharmony_ci+namespace mindspore::lite { 9019be168c0dSopenharmony_ci+ namespace { 9020be168c0dSopenharmony_ci+ std::map<std::string, std::pair<int, float>> g_c_op_times_by_type_; 9021be168c0dSopenharmony_ci+ 
std::map<std::string, std::pair<int, float>> g_c_op_times_by_name_; 9022be168c0dSopenharmony_ci+ } 9023be168c0dSopenharmony_ci+#ifdef __cplusplus 9024be168c0dSopenharmony_ci+ extern "C" { 9025be168c0dSopenharmony_ci+#endif 9026be168c0dSopenharmony_ci+ bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 9027be168c0dSopenharmony_ci+ const OH_AI_CallBackParam kernel_Info); 9028be168c0dSopenharmony_ci+ bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 9029be168c0dSopenharmony_ci+ const OH_AI_CallBackParam kernel_Info); 9030be168c0dSopenharmony_ci+#ifdef __cplusplus 9031be168c0dSopenharmony_ci+ } 9032be168c0dSopenharmony_ci+#endif 9033be168c0dSopenharmony_ci+ 9034be168c0dSopenharmony_ci+class MS_API NetTrainCApi : public NetTrainBase { 9035be168c0dSopenharmony_ci+ public: 9036be168c0dSopenharmony_ci+ explicit NetTrainCApi(NetTrainFlags *flags) : NetTrainBase(flags) {} 9037be168c0dSopenharmony_ci+ virtual ~NetTrainCApi() {}; 9038be168c0dSopenharmony_ci+ 9039be168c0dSopenharmony_ci+ protected: 9040be168c0dSopenharmony_ci+ // call GenerateRandomData to fill inputTensors 9041be168c0dSopenharmony_ci+ int GenerateInputData() override; 9042be168c0dSopenharmony_ci+ 9043be168c0dSopenharmony_ci+ int ReadInputFile() override; 9044be168c0dSopenharmony_ci+ 9045be168c0dSopenharmony_ci+ int LoadStepInput(size_t step); 9046be168c0dSopenharmony_ci+ 9047be168c0dSopenharmony_ci+ int InitMSContext(); 9048be168c0dSopenharmony_ci+ 9049be168c0dSopenharmony_ci+ void InitTrainCfg(); 9050be168c0dSopenharmony_ci+ 9051be168c0dSopenharmony_ci+ char **TransStrVectorToCharArrays(const std::vector<std::string> &s); 9052be168c0dSopenharmony_ci+ 9053be168c0dSopenharmony_ci+ std::vector<std::string> TransCharArraysToStrVector(char **c, const size_t &num); 9054be168c0dSopenharmony_ci+ 9055be168c0dSopenharmony_ci+ int CreateAndRunNetwork(const std::string &filename, const std::string 
&bb_filename, bool is_train, int epochs, 9056be168c0dSopenharmony_ci+ bool check_accuracy = true) override; 9057be168c0dSopenharmony_ci+ 9058be168c0dSopenharmony_ci+ int CreateAndRunNetworkForInference(const std::string &filename, const OH_AI_ContextHandle &context); 9059be168c0dSopenharmony_ci+ 9060be168c0dSopenharmony_ci+ int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename, 9061be168c0dSopenharmony_ci+ const OH_AI_ContextHandle &context, 9062be168c0dSopenharmony_ci+ const OH_AI_TrainCfgHandle &train_cfg, int epochs); 9063be168c0dSopenharmony_ci+ 9064be168c0dSopenharmony_ci+ int InitDumpTensorDataCallbackParameter() override; 9065be168c0dSopenharmony_ci+ 9066be168c0dSopenharmony_ci+ int InitTimeProfilingCallbackParameter() override; 9067be168c0dSopenharmony_ci+ 9068be168c0dSopenharmony_ci+ int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override; 9069be168c0dSopenharmony_ci+ 9070be168c0dSopenharmony_ci+ int PrintInputData(); 9071be168c0dSopenharmony_ci+ 9072be168c0dSopenharmony_ci+ int MarkPerformance() override; 9073be168c0dSopenharmony_ci+ 9074be168c0dSopenharmony_ci+ int MarkAccuracy(bool enforce_accuracy = true) override; 9075be168c0dSopenharmony_ci+ 9076be168c0dSopenharmony_ci+ int CompareOutput() override; 9077be168c0dSopenharmony_ci+ 9078be168c0dSopenharmony_ci+ int SaveModels() override; 9079be168c0dSopenharmony_ci+ 9080be168c0dSopenharmony_ci+ OH_AI_ModelHandle ms_model_; 9081be168c0dSopenharmony_ci+ OH_AI_TensorHandleArray ms_inputs_for_api_; 9082be168c0dSopenharmony_ci+ OH_AI_ContextHandle context_ = nullptr; 9083be168c0dSopenharmony_ci+ OH_AI_TrainCfgHandle train_cfg_ = nullptr; 9084be168c0dSopenharmony_ci+ OH_AI_KernelCallBack before_call_back_{nullptr}; 9085be168c0dSopenharmony_ci+ OH_AI_KernelCallBack after_call_back_{nullptr}; 9086be168c0dSopenharmony_ci+}; 9087be168c0dSopenharmony_ci+} // namespace mindspore::lite 9088be168c0dSopenharmony_ci+ 
+#endif //MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.cc b/mindspore/lite/tools/benchmark_train/run_net_train.cc
new file mode 100644
index 00000000..37a7e602
--- /dev/null
+++ b/mindspore/lite/tools/benchmark_train/run_net_train.cc
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/benchmark_train/run_net_train.h"
+#include <memory>
+#include "tools/benchmark_train/net_train.h"
+#include "tools/benchmark_train/net_train_c_api.h"
+
+namespace mindspore {
+namespace lite {
+/**
+ * Parses the command line into NetTrainFlags, picks a trainer implementation and runs it.
+ *
+ * When flags.unified_api_ is set the call is forwarded to NetTrain::RunNr(). Otherwise the trainer is
+ * chosen by the MSLITE_API_TYPE environment variable: unset or "NEW" creates a NetTrain, "C" creates a
+ * NetTrainCApi, and any other value is rejected.
+ *
+ * @param argc Argument count, passed to NetTrainFlags::ParseFlags().
+ * @param argv Argument vector, passed to NetTrainFlags::ParseFlags().
+ * @return RET_OK if help was requested or the run succeeded, the result of NetTrain::RunNr() on the
+ *         unified-API path, and RET_ERROR for every other failure.
+ */
+int RunNetTrain(int argc, const char **argv) {
+  NetTrainFlags flags;
+  Option<std::string> err = flags.ParseFlags(argc, argv);
+
+  if (err.IsSome()) {
+    std::cerr << err.Get() << std::endl;
+    std::cerr << flags.Usage() << std::endl;
+    return RET_ERROR;
+  }
+
+  if (flags.help) {
+    std::cerr << flags.Usage() << std::endl;
+    return RET_OK;
+  }
+  if (flags.unified_api_) {
+    return NetTrain::RunNr(&flags);
+  }
+
+  auto api_type = std::getenv("MSLITE_API_TYPE");
+  if (api_type != nullptr) {
+    MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type;
+    std::cout << "MSLITE_API_TYPE = " << api_type << std::endl;
+  }
+
+  // Owned by a unique_ptr so that the error returns below release the trainer as well.
+  std::unique_ptr<NetTrainBase> net_trainer;
+  if (api_type == nullptr || std::string(api_type) == "NEW") {
+    net_trainer.reset(new (std::nothrow) NetTrain(&flags));
+  } else if (std::string(api_type) == "C") {
+    net_trainer.reset(new (std::nothrow) NetTrainCApi(&flags));
+  } else {
+    MS_LOG(ERROR) << "Invalid MSLITE_API_TYPE, (NEW/C, default:NEW)";
+    return RET_ERROR;
+  }
+
+  if (net_trainer == nullptr) {
+    MS_LOG(ERROR) << "new net_trainer failed.";
+    return RET_ERROR;
+  }
+  auto status = net_trainer->Init();
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "NetTrain init Error : " << status;
+    std::cerr << "NetTrain init Error : " << status << std::endl;
+    return RET_ERROR;
+  }
+
+  status = net_trainer->RunNetTrain();
+  // Only the file name of the model (the part after the last '/') is reported.
+  const std::string model_name = flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Run NetTrain " << model_name.c_str() << " Failed : " << status;
+    std::cerr << "Run NetTrain " << model_name.c_str() << " Failed : " << status << std::endl;
+    return RET_ERROR;
+  }
+
+  MS_LOG(INFO) << "Run NetTrain " << model_name.c_str() << " Success.";
+  std::cout << "Run NetTrain " << model_name.c_str() << " Success." << std::endl;
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
\ No newline at end of file
diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.h b/mindspore/lite/tools/benchmark_train/run_net_train.h
new file mode 100644
index 00000000..9ca2d73c
--- /dev/null
+++ b/mindspore/lite/tools/benchmark_train/run_net_train.h
@@ -0,0 +1,22 @@
+/**
+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
9203be168c0dSopenharmony_ci+ */ 9204be168c0dSopenharmony_ci+ 9205be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9206be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9207be168c0dSopenharmony_ci+namespace mindspore::lite { 9208be168c0dSopenharmony_ci+int RunNetTrain(int argc, const char **argv); 9209be168c0dSopenharmony_ci+} // namespace mindspore::lite 9210be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H 9211be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt 9212be168c0dSopenharmony_ciindex 1e09d2ed..f854620f 100644 9213be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/CMakeLists.txt 9214be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/CMakeLists.txt 9215be168c0dSopenharmony_ci@@ -7,6 +7,8 @@ endif() 9216be168c0dSopenharmony_ci 9217be168c0dSopenharmony_ci set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src) 9218be168c0dSopenharmony_ci 9219be168c0dSopenharmony_ci+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/) 9220be168c0dSopenharmony_ci+ 9221be168c0dSopenharmony_ci if(ENABLE_GPU) 9222be168c0dSopenharmony_ci add_compile_definitions(ENABLE_GPU) 9223be168c0dSopenharmony_ci endif() 9224be168c0dSopenharmony_ci@@ -70,6 +72,7 @@ add_subdirectory(parser/caffe) 9225be168c0dSopenharmony_ci add_subdirectory(parser/tflite) 9226be168c0dSopenharmony_ci add_subdirectory(parser/onnx) 9227be168c0dSopenharmony_ci add_subdirectory(parser/tf) 9228be168c0dSopenharmony_ci+add_subdirectory(parser/third_party) 9229be168c0dSopenharmony_ci if(ENABLE_CONVERT_PYTORCH_MODEL) 9230be168c0dSopenharmony_ci add_subdirectory(parser/pytorch) 9231be168c0dSopenharmony_ci endif() 9232be168c0dSopenharmony_ci@@ -363,6 +366,7 @@ target_link_libraries(mindspore_converter 9233be168c0dSopenharmony_ci tf_parser_mid 9234be168c0dSopenharmony_ci caffe_parser_mid 9235be168c0dSopenharmony_ci onnx_parser_mid 
+        third_party_parser_mid
         lite_exporter_mid
         graph_pass_mid
         fusion_mid
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
index fecc56d9..2e7ca749 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
@@ -34,6 +34,7 @@ constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
 constexpr auto kDataPreprocessParam = "data_preprocess_param";
 constexpr auto kRegistry = "registry";
 constexpr auto kMicroParam = "micro_param";
+constexpr auto kThirdPartyModelParam = "third_party_model";
 constexpr auto kCpuOptionParam = "cpu_option_cfg_param";
 constexpr auto kCustomOppPath = "custom_opp_path";
 constexpr auto kTransformQuantParam = "transform_quant_param";
@@ -330,6 +331,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
     MS_LOG(ERROR) << "ParseMicroParamString failed.";
     return ret;
   }
+  ret = ParseThirdPartyParamString(*maps);
+  (void)maps->erase(kThirdPartyModelParam);  // the section is erased whether or not parsing succeeded
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ParseThirdPartyParamString failed.";
+    return ret;
+  }
   ret = ParseWeightQuantString(*maps);
9263be168c0dSopenharmony_ci (void)maps->erase(kWeightQuantParam); 9264be168c0dSopenharmony_ci if (ret != RET_OK) { 9265be168c0dSopenharmony_ci@@ -594,5 +601,25 @@ int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::ma 9266be168c0dSopenharmony_ci } 9267be168c0dSopenharmony_ci return RET_OK; 9268be168c0dSopenharmony_ci } 9269be168c0dSopenharmony_ci+ 9270be168c0dSopenharmony_ci+int ConfigFileParser::ParseThirdPartyParamString( 9271be168c0dSopenharmony_ci+ const std::map<std::string, std::map<std::string, std::string>> §ions) { 9272be168c0dSopenharmony_ci+ if (sections.find(kThirdPartyModelParam) == sections.end()) { 9273be168c0dSopenharmony_ci+ return RET_OK; 9274be168c0dSopenharmony_ci+ } 9275be168c0dSopenharmony_ci+ const auto &input_args = sections.at(kThirdPartyModelParam); 9276be168c0dSopenharmony_ci+ const std::map<std::string, std::string &> kValidArgs = { 9277be168c0dSopenharmony_ci+ {"input_shapes", third_party_model_string_.input_shapes}, 9278be168c0dSopenharmony_ci+ {"input_dtypes", third_party_model_string_.input_dtypes}, 9279be168c0dSopenharmony_ci+ {"input_names", third_party_model_string_.input_names}, 9280be168c0dSopenharmony_ci+ {"input_formats", third_party_model_string_.input_formats}, 9281be168c0dSopenharmony_ci+ {"output_shapes", third_party_model_string_.output_shapes}, 9282be168c0dSopenharmony_ci+ {"output_dtypes", third_party_model_string_.output_dtypes}, 9283be168c0dSopenharmony_ci+ {"output_names", third_party_model_string_.output_names}, 9284be168c0dSopenharmony_ci+ {"output_formats", third_party_model_string_.output_formats}, 9285be168c0dSopenharmony_ci+ {"extended_parameters", third_party_model_string_.extended_parameters}, 9286be168c0dSopenharmony_ci+ }; 9287be168c0dSopenharmony_ci+ return SetMapData(input_args, kValidArgs, kThirdPartyModelParam); 9288be168c0dSopenharmony_ci+} 9289be168c0dSopenharmony_ci } // namespace lite 9290be168c0dSopenharmony_ci } // namespace mindspore 9291be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
index 31269816..6997bac8 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
@@ -110,6 +110,18 @@ struct MicroParamString {
   std::string changeable_weights_name;
 };
 
+struct ThirdPartyModelString {
+  std::string input_dtypes;         // ';'-separated dtype names, e.g. "float32;int32"
+  std::string input_shapes;         // ';'-separated shapes with ','-separated dims, e.g. "1,256,256,3;3,96"
+  std::string input_names;          // optional, default: ""
+  std::string input_formats;        // optional, default: NHWC
+  std::string output_dtypes;        // same syntax as input_dtypes
+  std::string output_shapes;        // same syntax as input_shapes
+  std::string output_names;         // optional, default: ""
+  std::string output_formats;       // optional, default: NHWC
+  std::string extended_parameters;  // format: key1:value1;key2:value2
+};
+
 struct CpuOptionCfgString {
   std::string architecture;
   std::string instruction;
@@ -144,6 +156,7 @@ class ConfigFileParser {
   RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; }
   AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
   MicroParamString GetMicroParamString() { return this->micro_param_string_; }
+  lite::ThirdPartyModelString
GetThirdPartyModelString() const { return this->third_party_model_string_; } 9319be168c0dSopenharmony_ci CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; } 9320be168c0dSopenharmony_ci TransformQuantString GetTransformQuantString() const { return this->transform_quant_string_; } 9321be168c0dSopenharmony_ci AscendQuantString GetAscendQuantString() const { return this->ascend_quant_string_; } 9322be168c0dSopenharmony_ci@@ -161,6 +174,7 @@ class ConfigFileParser { 9323be168c0dSopenharmony_ci int SetMapData(const std::map<std::string, std::string> &input_map, 9324be168c0dSopenharmony_ci const std::map<std::string, std::string &> &parse_map, const std::string §ion); 9325be168c0dSopenharmony_ci int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9326be168c0dSopenharmony_ci+ int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> §ions); 9327be168c0dSopenharmony_ci int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9328be168c0dSopenharmony_ci int ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9329be168c0dSopenharmony_ci int ParseAscendQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps); 9330be168c0dSopenharmony_ci@@ -176,6 +190,7 @@ class ConfigFileParser { 9331be168c0dSopenharmony_ci RegistryInfoString registry_info_string_; 9332be168c0dSopenharmony_ci AclOptionCfgString acl_option_cfg_string_; 9333be168c0dSopenharmony_ci MicroParamString micro_param_string_; 9334be168c0dSopenharmony_ci+ lite::ThirdPartyModelString third_party_model_string_; 9335be168c0dSopenharmony_ci CpuOptionCfgString cpu_option_cfg_string_; 9336be168c0dSopenharmony_ci TransformQuantString transform_quant_string_; 9337be168c0dSopenharmony_ci AscendQuantString ascend_quant_string_; 9338be168c0dSopenharmony_cidiff --git 
a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 9339be168c0dSopenharmony_cinew file mode 100644 9340be168c0dSopenharmony_ciindex 00000000..aee6a29c 9341be168c0dSopenharmony_ci--- /dev/null 9342be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc 9343be168c0dSopenharmony_ci@@ -0,0 +1,299 @@ 9344be168c0dSopenharmony_ci+/** 9345be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 9346be168c0dSopenharmony_ci+ * 9347be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 9348be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 9349be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 9350be168c0dSopenharmony_ci+ * 9351be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 9352be168c0dSopenharmony_ci+ * 9353be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 9354be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 9355be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9356be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 9357be168c0dSopenharmony_ci+ * limitations under the License. 
+ */
+
+#include "tools/converter/config_parser/third_party_param_parser.h"
+#include <vector>
+#include <string>
+#include <map>
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#include "nnacl/op_base.h"
+#include "tools/common/string_util.h"
+
+namespace mindspore {
+namespace lite {
+namespace {
+const std::map<std::string, TypeId> kDataTypeMap = {  // dtype names accepted by ConvertDataType()
+  {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32},
+  {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64},
+  {"int32", TypeId::kNumberTypeInt32},     {"int16", TypeId::kNumberTypeInt16},
+  {"int8", TypeId::kNumberTypeInt8},       {"uint8", TypeId::kNumberTypeUInt8},
+  {"bool", TypeId::kNumberTypeBool},
+};
+
+TypeId ConvertDataType(const std::string &type) {
+  const auto found = kDataTypeMap.find(type);
+  if (found != kDataTypeMap.end()) {
+    return found->second;
+  }
+  return TypeId::kTypeUnknown;  // names missing from kDataTypeMap are reported as unknown
+}
+}  // namespace
+
+/**
+ * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]].
+ */
+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
+  MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
+  dst_shapes->clear();
+
+  const auto shape_strings = SplitStringToVector(src, ";");
+  for (const auto &shape_string : shape_strings) {  // one "d0,d1,..." group per tensor
+    const auto dim_strings = SplitStringToVector(shape_string, ",");
+    std::vector<int64_t> shape = {};
+    for (const auto &dim_string : dim_strings) {
+      int dim = 0;
+      if (!ConvertIntNum(dim_string, &dim)) {
+        MS_LOG(ERROR) << "Found error when convert shape string to integer";
+        return RET_ERROR;
+      }
+      if (dim <= 0) {  // Every dimension must be positive; anything else is rejected as a non-fixed shape.
+        MS_LOG(ERROR) << "Only support fixed shapes in third party param";
+        return RET_ERROR;
+      }
+      shape.push_back(dim);
+    }
+    dst_shapes->push_back(shape);
+  }
+  return RET_OK;
+}
+
+/**
+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
+ */
+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
+                                                     std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
+  MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
+  constexpr size_t kKeyIndex = 0U;
+  constexpr size_t kValueIndex = 1U;
+  constexpr size_t kKeyValueSize = 2U;
+
+  if (src.empty()) {  // 'extended_parameters' is not configured: leave *dst_ext_param untouched.
+    return RET_OK;
+  }
+
+  const auto entries = SplitStringToVector(src, ";");
+  std::map<std::string, std::vector<uint8_t>> parsed = {};
+  for (const auto &entry : entries) {
+    const auto key_and_value = SplitStringToVector(entry, ":");
+    if (key_and_value.size() != kKeyValueSize) {
+      MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
+      return RET_ERROR;
+    }
+    const auto &key = key_and_value[kKeyIndex];
+    const auto &value = key_and_value[kValueIndex];
+    if (parsed.find(key) != parsed.end()) {
+      MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
+      return RET_ERROR;
+    }
+    parsed.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));  // value kept as raw bytes
+  }
+
+  *dst_ext_param = parsed;  // published only after the whole string parsed successfully
+  return RET_OK;
+}
+
+/**
+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
+ */
+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
+  MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
+  dst_dtypes->clear();
+  const auto dtype_names = SplitStringToVector(src, ";");
+  for (const auto &dtype_name : dtype_names) {
+    const TypeId type = ConvertDataType(dtype_name);
+    if (type == kTypeUnknown) {  // the name is not one of the supported dtype strings
+      MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
+      return RET_ERROR;
+    }
+    dst_dtypes->push_back(type);
+  }
+  return RET_OK;
+}
+
+/**
+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
+ * If names are not provided in config, generate them from the default prefix, e.g. "in_0;in_1;..;in_{num-1}"
+ */
+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
+                                        std::vector<std::string> *dst_names) {
+  MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
+  std::string names = src;
+  if (names.empty()) {
+    // Nothing configured: build "<prefix>_0;<prefix>_1;..." with one entry per tensor.
+    for (size_t i = 0; i < num; i++) {
+      names += default_prefix + "_" + std::to_string(i);
+      if (i + 1 < num) {
+        names += ";";
+      }
+    }
+  }
+
+  *dst_names = SplitStringToVector(names, ";");
+  if (dst_names->size() != num) {
+    MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+namespace {
+// Maps a format name such as "NCHW" to the matching schema::Format value.
+// Returns RET_NULL_PTR if `format` is null, RET_PARAM_INVALID if the name is not in the table and
+// RET_OK otherwise.
+int StringToFormat(const std::string &format_string, schema::Format *format) {
+  static const std::unordered_map<std::string, schema::Format> kFormatTable = {
+    {"NCHW", schema::Format::Format_NCHW},
+    {"NHWC", schema::Format::Format_NHWC},
+    {"NHWC4", schema::Format::Format_NHWC4},
+    {"HWKC", schema::Format::Format_HWKC},
+    {"HWCK", schema::Format::Format_HWCK},
+    {"KCHW", schema::Format::Format_KCHW},
+    {"CKHW", schema::Format::Format_CKHW},
+    {"KHWC", schema::Format::Format_KHWC},
+    {"CHWK", schema::Format::Format_CHWK},
+    {"HW", schema::Format::Format_HW},
+    {"HW4", schema::Format::Format_HW4},
+    {"NC", schema::Format::Format_NC},
+    {"NC4", schema::Format::Format_NC4},
+    {"NC4HW4", schema::Format::Format_NC4HW4},
+    {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},  // NOTE(review): looks like a sentinel - confirm
+    {"NCDHW", schema::Format::Format_NCDHW},
+    {"NWC", schema::Format::Format_NWC},
+    {"NCW", schema::Format::Format_NCW},
+  };
+  if (format == nullptr) {
+    return RET_NULL_PTR;
+  }
+
+  const auto entry = kFormatTable.find(format_string);
+  if (entry == kFormatTable.end()) {
+    return RET_PARAM_INVALID;
+  }
+  *format = entry->second;
+  return RET_OK;
+}
+}  // namespace
+
+/**
+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC]
+ */
+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
+                                          std::vector<schema::Format> *result_formats) {
+  MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
+  if (src.empty()) {  // formats are optional: every tensor defaults to NHWC
+    result_formats->assign(num, schema::Format::Format_NHWC);
+    return RET_OK;
+  }
+
+  const auto format_strings = SplitStringToVector(src, ";");
+  if (format_strings.size() != num) {
+    MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal";
+    return RET_ERROR;
+  }
+
+  std::vector<schema::Format> result(num);
+  for (size_t i = 0; i < num; i++) {
+    if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
+      MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
+      return RET_PARAM_INVALID;
+    }
+  }
+  *result_formats = result;
+  return RET_OK;
+}
+
+// Converts the raw config strings into typed values; the shape count fixes how many dtypes, formats
+// and names are expected on each side. Returns RET_OK on success and RET_ERROR on any failure.
+int ThirdPartyParamParser::Parse(const ThirdPartyModelString &param_string, ThirdPartyModelParam *param) {
+  MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
+
+  auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input shapes of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
+    return RET_ERROR;
+  }
+
+  auto input_shape_num = param->input_shapes.size();
+  auto input_dtype_num = param->input_dtypes.size();
+  if (input_shape_num != input_dtype_num) {
+    MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
+                  << " are not equal";
+    return RET_ERROR;
+  }
+
+  ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input formats of third party param failed";
+    return RET_ERROR;
+  }
+
+  const std::string kInputNamePrefix = "in";
+  ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse input names of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output shapes of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
+    return RET_ERROR;
+  }
+
+  auto output_shape_num = param->output_shapes.size();
+  auto output_dtype_num = param->output_dtypes.size();
+  if (output_shape_num != output_dtype_num) {
+    MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num
+                  << " are not equal";
+    return RET_ERROR;
+  }
+
+  ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output formats of third party param failed";
+    return RET_ERROR;
+  }
+
+  const std::string kOutputNamePrefix = "out";
+  ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output names of third party param failed";
+    return RET_ERROR;
+  }
+
+  ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
+    return RET_ERROR;
+  }
+
+  return RET_OK;
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
new file mode 100644
index 00000000..5cf6e8fb
9646be168c0dSopenharmony_ci--- /dev/null 9647be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h 9648be168c0dSopenharmony_ci@@ -0,0 +1,44 @@ 9649be168c0dSopenharmony_ci+/** 9650be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 9651be168c0dSopenharmony_ci+ * 9652be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 9653be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 9654be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 9655be168c0dSopenharmony_ci+ * 9656be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 9657be168c0dSopenharmony_ci+ * 9658be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 9659be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 9660be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9661be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 9662be168c0dSopenharmony_ci+ * limitations under the License. 
9663be168c0dSopenharmony_ci+ */ 9664be168c0dSopenharmony_ci+ 9665be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9666be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9667be168c0dSopenharmony_ci+#include <string> 9668be168c0dSopenharmony_ci+#include <vector> 9669be168c0dSopenharmony_ci+#include <map> 9670be168c0dSopenharmony_ci+#include "include/errorcode.h" 9671be168c0dSopenharmony_ci+#include "tools/converter/cxx_api/converter_para.h" 9672be168c0dSopenharmony_ci+#include "tools/converter/config_parser/config_file_parser.h" 9673be168c0dSopenharmony_ci+ 9674be168c0dSopenharmony_ci+namespace mindspore { 9675be168c0dSopenharmony_ci+namespace lite { 9676be168c0dSopenharmony_ci+class ThirdPartyParamParser { 9677be168c0dSopenharmony_ci+ public: 9678be168c0dSopenharmony_ci+ static int Parse(const lite::ThirdPartyModelString ¶m_string, ThirdPartyModelParam *param); 9679be168c0dSopenharmony_ci+ 9680be168c0dSopenharmony_ci+ private: 9681be168c0dSopenharmony_ci+ static int DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes); 9682be168c0dSopenharmony_ci+ static int DoParseExtendedParameters(const std::string &src, 9683be168c0dSopenharmony_ci+ std::map<std::string, std::vector<uint8_t>> *dst_ext_param); 9684be168c0dSopenharmony_ci+ static int DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes); 9685be168c0dSopenharmony_ci+ static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix, 9686be168c0dSopenharmony_ci+ std::vector<std::string> *dst_names); 9687be168c0dSopenharmony_ci+ static int DoParseFormats(const std::string &src, size_t num, std::vector<schema::Format> *result_formats); 9688be168c0dSopenharmony_ci+}; 9689be168c0dSopenharmony_ci+} // namespace lite 9690be168c0dSopenharmony_ci+} // namespace mindspore 9691be168c0dSopenharmony_ci+ 9692be168c0dSopenharmony_ci+#endif // 
MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_ 9693be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc 9694be168c0dSopenharmony_ciindex df3176c2..a61bd51c 100644 9695be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter.cc 9696be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter.cc 9697be168c0dSopenharmony_ci@@ -49,6 +49,7 @@ 9698be168c0dSopenharmony_ci #include "tools/converter/config_parser/preprocess_parser.h" 9699be168c0dSopenharmony_ci #include "tools/converter/config_parser/quant_param_parser.h" 9700be168c0dSopenharmony_ci #include "tools/converter/config_parser/graph_kernel_param_parser.h" 9701be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h" 9702be168c0dSopenharmony_ci #include "tools/converter/converter_funcgraph.h" 9703be168c0dSopenharmony_ci #include "tools/converter/converter_metagraph.h" 9704be168c0dSopenharmony_ci #include "tools/common/string_util.h" 9705be168c0dSopenharmony_ci@@ -472,6 +473,12 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std:: 9706be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse mixed bit weight quant param failed."; 9707be168c0dSopenharmony_ci return ret; 9708be168c0dSopenharmony_ci } 9709be168c0dSopenharmony_ci+ ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(), 9710be168c0dSopenharmony_ci+ ¶m->thirdPartyModelParam); 9711be168c0dSopenharmony_ci+ if (ret != RET_OK) { 9712be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Parse third party param failed."; 9713be168c0dSopenharmony_ci+ return ret; 9714be168c0dSopenharmony_ci+ } 9715be168c0dSopenharmony_ci ret = InitExtendedIntegrationInfo(param, *config_parser); 9716be168c0dSopenharmony_ci if (ret != RET_OK) { 9717be168c0dSopenharmony_ci MS_LOG(ERROR) << "Parse extended integration info failed."; 9718be168c0dSopenharmony_ci@@ -699,19 +706,20 @@ std::string 
ConverterImpl::GetStrFromConfigFile(const std::string &file, const s 9719be168c0dSopenharmony_ci 9720be168c0dSopenharmony_ci int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) { 9721be168c0dSopenharmony_ci if (param != nullptr) { 9722be168c0dSopenharmony_ci- std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9723be168c0dSopenharmony_ci- FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9724be168c0dSopenharmony_ci- FmkType::kFmkTypeMsLite}; 9725be168c0dSopenharmony_ci- if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) { 9726be168c0dSopenharmony_ci- MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9727be168c0dSopenharmony_ci- "kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite|kFmkTypeMsLite" 9728be168c0dSopenharmony_ci- << ", but got " << param->fmk_type; 9729be168c0dSopenharmony_ci- return RET_INPUT_PARAM_INVALID; 9730be168c0dSopenharmony_ci- } 9731be168c0dSopenharmony_ci- if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { 9732be168c0dSopenharmony_ci- MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; 9733be168c0dSopenharmony_ci- return RET_INPUT_PARAM_INVALID; 9734be168c0dSopenharmony_ci- } 9735be168c0dSopenharmony_ci+ return RET_OK; 9736be168c0dSopenharmony_ci+ } 9737be168c0dSopenharmony_ci+ std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx, 9738be168c0dSopenharmony_ci+ FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch, 9739be168c0dSopenharmony_ci+ FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty}; 9740be168c0dSopenharmony_ci+ if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) { 9741be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be " 9742be168c0dSopenharmony_ci+ "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY" 9743be168c0dSopenharmony_ci+ << ", but got " << param->fmk_type; 
9744be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9745be168c0dSopenharmony_ci+ } 9746be168c0dSopenharmony_ci+ if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) { 9747be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag"; 9748be168c0dSopenharmony_ci+ return RET_INPUT_PARAM_INVALID; 9749be168c0dSopenharmony_ci } 9750be168c0dSopenharmony_ci return RET_OK; 9751be168c0dSopenharmony_ci } 9752be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter_funcgraph.cc b/mindspore/lite/tools/converter/converter_funcgraph.cc 9753be168c0dSopenharmony_ciindex f03f995c..61d5c463 100644 9754be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter_funcgraph.cc 9755be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter_funcgraph.cc 9756be168c0dSopenharmony_ci@@ -90,6 +90,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C 9757be168c0dSopenharmony_ci converter_parameters.save_type = param->save_type; 9758be168c0dSopenharmony_ci converter_parameters.model_file = param->model_file; 9759be168c0dSopenharmony_ci converter_parameters.weight_file = param->weight_file; 9760be168c0dSopenharmony_ci+ converter_parameters.attrs.emplace("config_file", param->config_file); 9761be168c0dSopenharmony_ci func_graph_base = model_parser->Parse(converter_parameters); 9762be168c0dSopenharmony_ci if (func_graph_base == nullptr) { 9763be168c0dSopenharmony_ci delete model_parser; 9764be168c0dSopenharmony_ci@@ -447,11 +448,13 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> ¶m, 9765be168c0dSopenharmony_ci return status; 9766be168c0dSopenharmony_ci } 9767be168c0dSopenharmony_ci 9768be168c0dSopenharmony_ci- AnfTransform funcgraph_transform; 9769be168c0dSopenharmony_ci- status = funcgraph_transform.Transform(func_graph, param); 9770be168c0dSopenharmony_ci- if (status != RET_OK) { 9771be168c0dSopenharmony_ci- MS_LOG(ERROR) << 
"Transform anf graph failed."; 9772be168c0dSopenharmony_ci- return status; 9773be168c0dSopenharmony_ci+ if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) { 9774be168c0dSopenharmony_ci+ AnfTransform funcgraph_transform; 9775be168c0dSopenharmony_ci+ status = funcgraph_transform.Transform(func_graph, param); 9776be168c0dSopenharmony_ci+ if (status != RET_OK) { 9777be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Transform anf graph failed."; 9778be168c0dSopenharmony_ci+ return status; 9779be168c0dSopenharmony_ci+ } 9780be168c0dSopenharmony_ci } 9781be168c0dSopenharmony_ci 9782be168c0dSopenharmony_ci status = UnifyFuncGraphOutputFormat(param, func_graph); 9783be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9784be168c0dSopenharmony_ciindex 4883c48d..024e209f 100644 9785be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9786be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc 9787be168c0dSopenharmony_ci@@ -138,11 +138,11 @@ int Flags::InitFmk() { 9788be168c0dSopenharmony_ci // value check not here, it is in converter c++ API's CheckValueParam method. 
9789be168c0dSopenharmony_ci std::map<std::string, FmkType> StrToEnumFmkTypeMap = { 9790be168c0dSopenharmony_ci {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs}, {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx}, 9791be168c0dSopenharmony_ci- {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}}; 9792be168c0dSopenharmony_ci+ {"TF", kFmkTypeTf}, {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}, {"THIRDPARTY", kFmkTypeThirdParty}}; 9793be168c0dSopenharmony_ci if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) { 9794be168c0dSopenharmony_ci this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn); 9795be168c0dSopenharmony_ci } else { 9796be168c0dSopenharmony_ci- std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl; 9797be168c0dSopenharmony_ci+ std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl; 9798be168c0dSopenharmony_ci return RET_INPUT_PARAM_INVALID; 9799be168c0dSopenharmony_ci } 9800be168c0dSopenharmony_ci 9801be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h 9802be168c0dSopenharmony_ciindex a4f72a69..33210fd0 100644 9803be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h 9804be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h 9805be168c0dSopenharmony_ci@@ -21,6 +21,7 @@ 9806be168c0dSopenharmony_ci #include <vector> 9807be168c0dSopenharmony_ci #include <set> 9808be168c0dSopenharmony_ci #include "include/converter.h" 9809be168c0dSopenharmony_ci+#include "mindapi/base/type_id.h" 9810be168c0dSopenharmony_ci #include "tools/converter/quantizer/quant_params.h" 9811be168c0dSopenharmony_ci #include "tools/converter/preprocess/preprocess_param.h" 9812be168c0dSopenharmony_ci #include "tools/converter/adapter/acl/common/acl_types.h" 9813be168c0dSopenharmony_ci@@ -35,6 +36,18 @@ struct 
ParallelSplitConfig { 9814be168c0dSopenharmony_ci std::vector<std::string> parallel_devices_; 9815be168c0dSopenharmony_ci }; 9816be168c0dSopenharmony_ci 9817be168c0dSopenharmony_ci+struct ThirdPartyModelParam { 9818be168c0dSopenharmony_ci+ std::vector<TypeId> input_dtypes; 9819be168c0dSopenharmony_ci+ std::vector<std::vector<int64_t>> input_shapes; 9820be168c0dSopenharmony_ci+ std::vector<std::string> input_names; 9821be168c0dSopenharmony_ci+ std::vector<schema::Format> input_formats; 9822be168c0dSopenharmony_ci+ std::vector<TypeId> output_dtypes; 9823be168c0dSopenharmony_ci+ std::vector<std::vector<int64_t>> output_shapes; 9824be168c0dSopenharmony_ci+ std::vector<std::string> output_names; 9825be168c0dSopenharmony_ci+ std::vector<schema::Format> output_formats; 9826be168c0dSopenharmony_ci+ std::map<std::string, std::vector<uint8_t>> extended_parameters; 9827be168c0dSopenharmony_ci+}; 9828be168c0dSopenharmony_ci+ 9829be168c0dSopenharmony_ci struct CpuOptionCfg { 9830be168c0dSopenharmony_ci std::string architecture; 9831be168c0dSopenharmony_ci std::string instruction; 9832be168c0dSopenharmony_ci@@ -97,6 +110,7 @@ struct ConverterPara { 9833be168c0dSopenharmony_ci lite::acl::AclModelOptionCfg aclModelOptionCfgParam; 9834be168c0dSopenharmony_ci lite::micro::MicroParam microParam; 9835be168c0dSopenharmony_ci ParallelSplitConfig parallel_split_config; 9836be168c0dSopenharmony_ci+ ThirdPartyModelParam thirdPartyModelParam; 9837be168c0dSopenharmony_ci AscendGeOptionCfg ascendGeOptionCfg; 9838be168c0dSopenharmony_ci std::string device; 9839be168c0dSopenharmony_ci std::string provider; 9840be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc 9841be168c0dSopenharmony_ciindex 90b744e5..bf1a82ae 100644 9842be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/graphdef_transform.cc 9843be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/graphdef_transform.cc 
9844be168c0dSopenharmony_ci@@ -76,11 +76,55 @@ int QuantTransform(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGrap 9845be168c0dSopenharmony_ci } 9846be168c0dSopenharmony_ci return RET_OK; 9847be168c0dSopenharmony_ci } 9848be168c0dSopenharmony_ci+ 9849be168c0dSopenharmony_ci+int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> output_shapes) { 9850be168c0dSopenharmony_ci+ const auto &out_indices = meta_graph->outputIndex; 9851be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_indices.size(); i++) { 9852be168c0dSopenharmony_ci+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 9853be168c0dSopenharmony_ci+ out_tensor->dims = {}; 9854be168c0dSopenharmony_ci+ for (size_t k = 0; k < output_shapes[i].size(); k++) { 9855be168c0dSopenharmony_ci+ out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k])); 9856be168c0dSopenharmony_ci+ } 9857be168c0dSopenharmony_ci+ } 9858be168c0dSopenharmony_ci+ return RET_OK; 9859be168c0dSopenharmony_ci+} 9860be168c0dSopenharmony_ci+ 9861be168c0dSopenharmony_ci+void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara ¶) { 9862be168c0dSopenharmony_ci+ const auto &in_indices = meta_graph->inputIndex; 9863be168c0dSopenharmony_ci+ for (size_t i = 0; i < in_indices.size(); i++) { 9864be168c0dSopenharmony_ci+ auto &in_tensor = meta_graph->allTensors[in_indices[i]]; 9865be168c0dSopenharmony_ci+ in_tensor->format = para.thirdPartyModelParam.input_formats[i]; 9866be168c0dSopenharmony_ci+ MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format); 9867be168c0dSopenharmony_ci+ } 9868be168c0dSopenharmony_ci+ 9869be168c0dSopenharmony_ci+ const auto &out_indices = meta_graph->outputIndex; 9870be168c0dSopenharmony_ci+ for (size_t i = 0; i < out_indices.size(); i++) { 9871be168c0dSopenharmony_ci+ auto &out_tensor = meta_graph->allTensors[out_indices[i]]; 9872be168c0dSopenharmony_ci+ out_tensor->format = para.thirdPartyModelParam.output_formats[i]; 
9873be168c0dSopenharmony_ci+ MS_LOG_DEBUG << "output " << i << " format: " << EnumNameFormat(out_tensor->format); 9874be168c0dSopenharmony_ci+ } 9875be168c0dSopenharmony_ci+} 9876be168c0dSopenharmony_ci } // namespace 9877be168c0dSopenharmony_ci 9878be168c0dSopenharmony_ci int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> ¶m) { 9879be168c0dSopenharmony_ci MS_ASSERT(param != nullptr); 9880be168c0dSopenharmony_ci STATUS status; 9881be168c0dSopenharmony_ci+ 9882be168c0dSopenharmony_ci+ if (param->fmk_type == converter::kFmkTypeThirdParty) { 9883be168c0dSopenharmony_ci+ 9884be168c0dSopenharmony_ci+ // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function. 9885be168c0dSopenharmony_ci+ // So we don't perform legacy optimization for kFmkTypeThirdParty case. 9886be168c0dSopenharmony_ci+ auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes); 9887be168c0dSopenharmony_ci+ if (ret != RET_OK) { 9888be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret; 9889be168c0dSopenharmony_ci+ return ret; 9890be168c0dSopenharmony_ci+ } 9891be168c0dSopenharmony_ci+ 9892be168c0dSopenharmony_ci+ // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph. 
9893be168c0dSopenharmony_ci+ FillGraphInputAndOutputFormats(graph_defT_, *param); 9894be168c0dSopenharmony_ci+ return RET_OK; 9895be168c0dSopenharmony_ci+ } 9896be168c0dSopenharmony_ci+ 9897be168c0dSopenharmony_ci { 9898be168c0dSopenharmony_ci auto old_nodes = GetGraphNodes(*graph_defT_); 9899be168c0dSopenharmony_ci Optimizer unused_op_remove_optimizer; 9900be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 9901be168c0dSopenharmony_cinew file mode 100644 9902be168c0dSopenharmony_ciindex 00000000..b55e0194 9903be168c0dSopenharmony_ci--- /dev/null 9904be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt 9905be168c0dSopenharmony_ci@@ -0,0 +1,4 @@ 9906be168c0dSopenharmony_ci+add_library(third_party_parser_mid OBJECT third_party_model_parser.cc) 9907be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid proto_mid) 9908be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid fbs_src) 9909be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid fbs_inner_src) 9910be168c0dSopenharmony_ci\ No newline at end of file 9911be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 9912be168c0dSopenharmony_cinew file mode 100644 9913be168c0dSopenharmony_ciindex 00000000..652db4af 9914be168c0dSopenharmony_ci--- /dev/null 9915be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc 9916be168c0dSopenharmony_ci@@ -0,0 +1,277 @@ 9917be168c0dSopenharmony_ci+/** 9918be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 9919be168c0dSopenharmony_ci+ * 9920be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 9921be168c0dSopenharmony_ci+ * you may not use this file except in 
compliance with the License. 9922be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 9923be168c0dSopenharmony_ci+ * 9924be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 9925be168c0dSopenharmony_ci+ * 9926be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 9927be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 9928be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9929be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 9930be168c0dSopenharmony_ci+ * limitations under the License. 9931be168c0dSopenharmony_ci+ */ 9932be168c0dSopenharmony_ci+#include "tools/converter/parser/third_party/third_party_model_parser.h" 9933be168c0dSopenharmony_ci+#include <string> 9934be168c0dSopenharmony_ci+#include <vector> 9935be168c0dSopenharmony_ci+#include <memory> 9936be168c0dSopenharmony_ci+#include "ir/value.h" 9937be168c0dSopenharmony_ci+#include "mindapi/base/type_id.h" 9938be168c0dSopenharmony_ci+#include "src/common/log_util.h" 9939be168c0dSopenharmony_ci+#include "src/common/file_utils.h" 9940be168c0dSopenharmony_ci+#include "nnacl/op_base.h" 9941be168c0dSopenharmony_ci+#include "ops/primitive_c.h" 9942be168c0dSopenharmony_ci+#include "ops/custom.h" 9943be168c0dSopenharmony_ci+#include "ops/tuple_get_item.h" 9944be168c0dSopenharmony_ci+#include "ops/make_tuple.h" 9945be168c0dSopenharmony_ci+#include "ops/return.h" 9946be168c0dSopenharmony_ci+#include "tools/converter/config_parser/config_file_parser.h" 9947be168c0dSopenharmony_ci+#include "include/registry/model_parser_registry.h" 9948be168c0dSopenharmony_ci+#include "tools/common/graph_util.h" 9949be168c0dSopenharmony_ci+#include "tools/common/tensor_util.h" 9950be168c0dSopenharmony_ci+#include "tools/converter/converter_context.h" 9951be168c0dSopenharmony_ci+#include 
"tools/converter/parser/lite_model_parser_creator.h" 9952be168c0dSopenharmony_ci+ 9953be168c0dSopenharmony_ci+using mindspore::converter::kFmkTypeThirdParty; 9954be168c0dSopenharmony_ci+ 9955be168c0dSopenharmony_ci+namespace mindspore { 9956be168c0dSopenharmony_ci+namespace lite { 9957be168c0dSopenharmony_ci+api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) { 9958be168c0dSopenharmony_ci+ model_file_ = flag.model_file; 9959be168c0dSopenharmony_ci+ auto &attrs = flag.attrs; 9960be168c0dSopenharmony_ci+ auto iter = attrs.find("config_file"); 9961be168c0dSopenharmony_ci+ if (iter == attrs.end()) { 9962be168c0dSopenharmony_ci+ return nullptr; 9963be168c0dSopenharmony_ci+ } 9964be168c0dSopenharmony_ci+ auto config_file = iter->second; 9965be168c0dSopenharmony_ci+ 9966be168c0dSopenharmony_ci+ auto ret = InitConfig(config_file); 9967be168c0dSopenharmony_ci+ if (ret != RET_OK) { 9968be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Init config for third party model parsing failed"; 9969be168c0dSopenharmony_ci+ return nullptr; 9970be168c0dSopenharmony_ci+ } 9971be168c0dSopenharmony_ci+ 9972be168c0dSopenharmony_ci+ return CreateFuncGraph(); 9973be168c0dSopenharmony_ci+} 9974be168c0dSopenharmony_ci+ 9975be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) { 9976be168c0dSopenharmony_ci+ lite::ConfigFileParser config_parser; 9977be168c0dSopenharmony_ci+ if (config_file.empty()) { 9978be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Missing config file in converting third party model"; 9979be168c0dSopenharmony_ci+ return RET_ERROR; 9980be168c0dSopenharmony_ci+ } 9981be168c0dSopenharmony_ci+ auto ret = config_parser.ParseConfigFile(config_file); 9982be168c0dSopenharmony_ci+ if (ret != RET_OK) { 9983be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Get third party model section from config file failed"; 9984be168c0dSopenharmony_ci+ return RET_ERROR; 9985be168c0dSopenharmony_ci+ } 9986be168c0dSopenharmony_ci+ 
9987be168c0dSopenharmony_ci+ ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m_); 9988be168c0dSopenharmony_ci+ if (ret != RET_OK) { 9989be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Parse third party model param failed."; 9990be168c0dSopenharmony_ci+ return ret; 9991be168c0dSopenharmony_ci+ } 9992be168c0dSopenharmony_ci+ return RET_OK; 9993be168c0dSopenharmony_ci+} 9994be168c0dSopenharmony_ci+ 9995be168c0dSopenharmony_ci+api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() { 9996be168c0dSopenharmony_ci+ auto func_graph = std::make_shared<FuncGraph>(); 9997be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr); 9998be168c0dSopenharmony_ci+ auto type_value = MakeValue(static_cast<int>(converter::kFmkTypeThirdParty)); 9999be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(type_value != nullptr, nullptr); 10000be168c0dSopenharmony_ci+ func_graph->set_attr("fmk", type_value); 10001be168c0dSopenharmony_ci+ auto attr_value = MakeValue("third_party"); 10002be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr); 10003be168c0dSopenharmony_ci+ func_graph->set_attr("graph_name", attr_value); 10004be168c0dSopenharmony_ci+ 10005be168c0dSopenharmony_ci+ std::vector<AnfNodePtr> input_nodes = {}; 10006be168c0dSopenharmony_ci+ auto ret = BuildGraphInputs(func_graph, &input_nodes); 10007be168c0dSopenharmony_ci+ if (ret != RET_OK) { 10008be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create func graph input nodes failed"; 10009be168c0dSopenharmony_ci+ return nullptr; 10010be168c0dSopenharmony_ci+ } 10011be168c0dSopenharmony_ci+ 10012be168c0dSopenharmony_ci+ CNodePtr custom_node = nullptr; 10013be168c0dSopenharmony_ci+ ret = BuildCustomOp(func_graph, input_nodes, &custom_node); 10014be168c0dSopenharmony_ci+ if (ret != RET_OK) { 10015be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create func graph custom op node failed"; 10016be168c0dSopenharmony_ci+ return nullptr; 10017be168c0dSopenharmony_ci+ } 10018be168c0dSopenharmony_ci+ 
10019be168c0dSopenharmony_ci+ ret = BuildGraphOutputs(func_graph, custom_node); 10020be168c0dSopenharmony_ci+ if (ret != RET_OK) { 10021be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create func graph output nodes failed"; 10022be168c0dSopenharmony_ci+ return nullptr; 10023be168c0dSopenharmony_ci+ } 10024be168c0dSopenharmony_ci+ 10025be168c0dSopenharmony_ci+ static auto manager = Manage(func_graph); 10026be168c0dSopenharmony_ci+ func_graph->set_manager(manager); 10027be168c0dSopenharmony_ci+ 10028be168c0dSopenharmony_ci+ auto result_graph = api::MakeShared<api::FuncGraph>(func_graph); 10029be168c0dSopenharmony_ci+ return result_graph; 10030be168c0dSopenharmony_ci+} 10031be168c0dSopenharmony_ci+ 10032be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs) { 10033be168c0dSopenharmony_ci+ MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr); 10034be168c0dSopenharmony_ci+ auto &dtypes = param_.input_dtypes; 10035be168c0dSopenharmony_ci+ auto &shapes = param_.input_shapes; 10036be168c0dSopenharmony_ci+ auto &names = param_.input_names; 10037be168c0dSopenharmony_ci+ 10038be168c0dSopenharmony_ci+ auto input_size = dtypes.size(); 10039be168c0dSopenharmony_ci+ 10040be168c0dSopenharmony_ci+ // Create parameter nodes for graph inputs 10041be168c0dSopenharmony_ci+ for (size_t i = 0; i < input_size; i++) { 10042be168c0dSopenharmony_ci+ auto parameter = func_graph->add_parameter(); 10043be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(parameter); 10044be168c0dSopenharmony_ci+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 10045be168c0dSopenharmony_ci+ if (abstract_tensor == nullptr) { 10046be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create tensor abstract failed"; 10047be168c0dSopenharmony_ci+ return RET_ERROR; 10048be168c0dSopenharmony_ci+ } 10049be168c0dSopenharmony_ci+ parameter->set_abstract(abstract_tensor); 10050be168c0dSopenharmony_ci+ parameter->set_name(names[i]); 
10051be168c0dSopenharmony_ci+ op_inputs->push_back(parameter); 10052be168c0dSopenharmony_ci+ } 10053be168c0dSopenharmony_ci+ 10054be168c0dSopenharmony_ci+ // Create parameter nodes for const tensor which wrapped third model buffer. 10055be168c0dSopenharmony_ci+ size_t model_size = 0U; 10056be168c0dSopenharmony_ci+ auto model_data = ReadFile(model_file_.c_str(), &model_size); 10057be168c0dSopenharmony_ci+ std::vector<int64_t> model_shape = {static_cast<int64_t>(model_size)}; 10058be168c0dSopenharmony_ci+ auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8); 10059be168c0dSopenharmony_ci+ if (tensor_info == nullptr) { 10060be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "init tensor info failed"; 10061be168c0dSopenharmony_ci+ delete model_data; 10062be168c0dSopenharmony_ci+ return RET_NULL_PTR; 10063be168c0dSopenharmony_ci+ } 10064be168c0dSopenharmony_ci+ auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c()); 10065be168c0dSopenharmony_ci+ if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) { 10066be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "memcpy failed."; 10067be168c0dSopenharmony_ci+ delete model_data; 10068be168c0dSopenharmony_ci+ return RET_ERROR; 10069be168c0dSopenharmony_ci+ } 10070be168c0dSopenharmony_ci+ delete model_data; 10071be168c0dSopenharmony_ci+ auto parameter = func_graph->add_parameter(); 10072be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(parameter); 10073be168c0dSopenharmony_ci+ auto status = InitParameterFromTensorInfo(parameter, tensor_info); 10074be168c0dSopenharmony_ci+ if (status != RET_OK) { 10075be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "init parameter from tensor info failed."; 10076be168c0dSopenharmony_ci+ return RET_ERROR; 10077be168c0dSopenharmony_ci+ } 10078be168c0dSopenharmony_ci+ parameter->set_name("ThirdPartyModel"); 10079be168c0dSopenharmony_ci+ op_inputs->push_back(parameter); 10080be168c0dSopenharmony_ci+ return RET_OK; 10081be168c0dSopenharmony_ci+} 
10082be168c0dSopenharmony_ci+// Creates the Custom CNode (type "ThirdPartyModel", attrs from extended_parameters) and stores it in *operator_node. 10083be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 10084be168c0dSopenharmony_ci+ CNodePtr *operator_node) { 10085be168c0dSopenharmony_ci+ MS_ASSERT(func_graph != nullptr && operator_node != nullptr); 10086be168c0dSopenharmony_ci+ NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY"); 10087be168c0dSopenharmony_ci+ STATUS status = RET_OK; 10088be168c0dSopenharmony_ci+ 10089be168c0dSopenharmony_ci+ // create primitive and build CNode of CUSTOM operator 10090be168c0dSopenharmony_ci+ ops::PrimitiveCPtr primitive_c; 10091be168c0dSopenharmony_ci+ auto prim = std::make_unique<ops::Custom>(); 10092be168c0dSopenharmony_ci+ MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR); 10093be168c0dSopenharmony_ci+ prim->set_type("ThirdPartyModel"); 10094be168c0dSopenharmony_ci+ 10095be168c0dSopenharmony_ci+ const auto &attr = param_.extended_parameters; 10096be168c0dSopenharmony_ci+ prim->set_attr(attr); 10097be168c0dSopenharmony_ci+ primitive_c = prim->GetPrim(); 10098be168c0dSopenharmony_ci+ if (primitive_c == nullptr) { 10099be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "failed to create primitive: custom"; 10100be168c0dSopenharmony_ci+ return RET_ERROR; 10101be168c0dSopenharmony_ci+ } 10102be168c0dSopenharmony_ci+ 10103be168c0dSopenharmony_ci+ auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs); 10104be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(operator_cnode); 10105be168c0dSopenharmony_ci+ operator_cnode->set_fullname_with_scope("Custom"); 10106be168c0dSopenharmony_ci+ *operator_node = operator_cnode; 10107be168c0dSopenharmony_ci+ return status; 10108be168c0dSopenharmony_ci+} 10109be168c0dSopenharmony_ci+// Wraps each configured output in a TupleGetItem, then chains MakeTuple and Return as the graph's return node. 10110be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) { 10111be168c0dSopenharmony_ci+ MS_ASSERT(func_graph != nullptr && operator_node != nullptr); 
10112be168c0dSopenharmony_ci+ 10113be168c0dSopenharmony_ci+ const auto &dtypes = param_.output_dtypes; 10114be168c0dSopenharmony_ci+ const auto &shapes = param_.output_shapes; 10115be168c0dSopenharmony_ci+ const auto &names = param_.output_names; 10116be168c0dSopenharmony_ci+ 10117be168c0dSopenharmony_ci+ auto output_size = dtypes.size(); 10118be168c0dSopenharmony_ci+ std::vector<AnfNodePtr> output_nodes = {}; 10119be168c0dSopenharmony_ci+ 10120be168c0dSopenharmony_ci+ // Use TupleGetItem to wrap op outputs. 10121be168c0dSopenharmony_ci+ AbstractBasePtrList abstract_list; 10122be168c0dSopenharmony_ci+ for (size_t i = 0; i < output_size; i++) { 10123be168c0dSopenharmony_ci+ auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]); 10124be168c0dSopenharmony_ci+ if (abstract_tensor == nullptr) { 10125be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create tensor abstract failed"; 10126be168c0dSopenharmony_ci+ return RET_ERROR; 10127be168c0dSopenharmony_ci+ } 10128be168c0dSopenharmony_ci+ abstract_list.emplace_back(abstract_tensor); 10129be168c0dSopenharmony_ci+ auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>(); 10130be168c0dSopenharmony_ci+ if (tuple_get_item_prim_ptr == nullptr) { 10131be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "new TupleGetItem failed"; 10132be168c0dSopenharmony_ci+ return RET_NULL_PTR; 10133be168c0dSopenharmony_ci+ } 10134be168c0dSopenharmony_ci+ auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim(); 10135be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(tuple_get_item_prim_c); 10136be168c0dSopenharmony_ci+ auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c); 10137be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(tuple_get_item_prim); 10138be168c0dSopenharmony_ci+ auto get_item_value = NewValueNode(MakeValue<int>(i)); 10139be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(get_item_value); 10140be168c0dSopenharmony_ci+ std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value}; 10141be168c0dSopenharmony_ci+ CNodePtr 
get_item_cnode = func_graph->NewCNode(inputs); 10142be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(get_item_cnode); 10143be168c0dSopenharmony_ci+ std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i); 10144be168c0dSopenharmony_ci+ auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); // NOTE(review): placeholder? confirm 10145be168c0dSopenharmony_ci+ if (get_item_abstract == nullptr) { 10146be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "Create tensor abstarct failed"; 10147be168c0dSopenharmony_ci+ return RET_ERROR; 10148be168c0dSopenharmony_ci+ } 10149be168c0dSopenharmony_ci+ get_item_cnode->set_fullname_with_scope(output_item_name); 10150be168c0dSopenharmony_ci+ get_item_cnode->set_abstract(get_item_abstract); 10151be168c0dSopenharmony_ci+ output_nodes.push_back(get_item_cnode); 10152be168c0dSopenharmony_ci+ } 10153be168c0dSopenharmony_ci+ auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list); 10154be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(abstract_tuple); 10155be168c0dSopenharmony_ci+ operator_node->set_abstract(abstract_tuple); 10156be168c0dSopenharmony_ci+ 10157be168c0dSopenharmony_ci+ // Use MakeTuple node to wrap all outputs as single input of Return node. 
10158be168c0dSopenharmony_ci+ auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>(); 10159be168c0dSopenharmony_ci+ if (make_tuple_prim_ptr == nullptr) { 10160be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "new MakeTuple failed"; 10161be168c0dSopenharmony_ci+ return RET_NULL_PTR; 10162be168c0dSopenharmony_ci+ } 10163be168c0dSopenharmony_ci+ auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim(); 10164be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(make_tuple_prim_c); 10165be168c0dSopenharmony_ci+ auto make_tuple_prim = NewValueNode(make_tuple_prim_c); 10166be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(make_tuple_prim); 10167be168c0dSopenharmony_ci+ std::vector<AnfNodePtr> make_tuple_inputs = output_nodes; 10168be168c0dSopenharmony_ci+ make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim); 10169be168c0dSopenharmony_ci+ auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs); 10170be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(make_tuple_cnode); 10171be168c0dSopenharmony_ci+ make_tuple_cnode->set_fullname_with_scope("return_tuple"); 10172be168c0dSopenharmony_ci+ 10173be168c0dSopenharmony_ci+ auto return_prim_ptr = std::make_shared<ops::Return>(); 10174be168c0dSopenharmony_ci+ if (return_prim_ptr == nullptr) { 10175be168c0dSopenharmony_ci+ MS_LOG(ERROR) << "new Return failed"; 10176be168c0dSopenharmony_ci+ return RET_NULL_PTR; 10177be168c0dSopenharmony_ci+ } 10178be168c0dSopenharmony_ci+ auto return_prim_c = return_prim_ptr->GetPrim(); 10179be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(return_prim_c); 10180be168c0dSopenharmony_ci+ std::vector<AnfNodePtr> op_inputs{make_tuple_cnode}; 10181be168c0dSopenharmony_ci+ auto cnode = func_graph->NewCNode(return_prim_c, op_inputs); 10182be168c0dSopenharmony_ci+ MSLITE_CHECK_PTR(cnode); 10183be168c0dSopenharmony_ci+ cnode->set_fullname_with_scope("Return"); 10184be168c0dSopenharmony_ci+ func_graph->set_return(cnode); 10185be168c0dSopenharmony_ci+ 10186be168c0dSopenharmony_ci+ // Save original output tensor names. 
10187be168c0dSopenharmony_ci+ ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names); 10188be168c0dSopenharmony_ci+ return RET_OK; 10189be168c0dSopenharmony_ci+} 10190be168c0dSopenharmony_ci+ 10191be168c0dSopenharmony_ci+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>) 10192be168c0dSopenharmony_ci+} // namespace lite 10193be168c0dSopenharmony_ci+} // namespace mindspore 10194be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 10195be168c0dSopenharmony_cinew file mode 100644 10196be168c0dSopenharmony_ciindex 00000000..c4b197b8 10197be168c0dSopenharmony_ci--- /dev/null 10198be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h 10199be168c0dSopenharmony_ci@@ -0,0 +1,50 @@ 10200be168c0dSopenharmony_ci+/** 10201be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd 10202be168c0dSopenharmony_ci+ * 10203be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License"); 10204be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License. 10205be168c0dSopenharmony_ci+ * You may obtain a copy of the License at 10206be168c0dSopenharmony_ci+ * 10207be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0 10208be168c0dSopenharmony_ci+ * 10209be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software 10210be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS, 10211be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10212be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and 10213be168c0dSopenharmony_ci+ * limitations under the License. 
10214be168c0dSopenharmony_ci+ */ 10215be168c0dSopenharmony_ci+ 10216be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10217be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10218be168c0dSopenharmony_ci+ 10219be168c0dSopenharmony_ci+#include <string> 10220be168c0dSopenharmony_ci+#include <vector> 10221be168c0dSopenharmony_ci+#include "schema/inner/model_generated.h" 10222be168c0dSopenharmony_ci+#include "base/base.h" 10223be168c0dSopenharmony_ci+#include "ir/anf.h" 10224be168c0dSopenharmony_ci+#include "ir/func_graph.h" 10225be168c0dSopenharmony_ci+#include "include/errorcode.h" 10226be168c0dSopenharmony_ci+#include "include/registry/model_parser.h" 10227be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h" 10228be168c0dSopenharmony_ci+ 10229be168c0dSopenharmony_ci+namespace mindspore { 10230be168c0dSopenharmony_ci+namespace lite { 10231be168c0dSopenharmony_ci+class ThirdPartyModelParser : public converter::ModelParser { // Registered for kFmkTypeThirdParty in the .cc file. 10232be168c0dSopenharmony_ci+ public: 10233be168c0dSopenharmony_ci+ api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override; 10234be168c0dSopenharmony_ci+ 10235be168c0dSopenharmony_ci+ private: 10236be168c0dSopenharmony_ci+ STATUS InitConfig(const std::string &config_file); 10237be168c0dSopenharmony_ci+ api::FuncGraphPtr CreateFuncGraph(); // Returns nullptr when building the graph outputs fails. 10238be168c0dSopenharmony_ci+ STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs); 10239be168c0dSopenharmony_ci+ STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs, 10240be168c0dSopenharmony_ci+ CNodePtr *operator_node); 10241be168c0dSopenharmony_ci+ STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node); 10242be168c0dSopenharmony_ci+ 10243be168c0dSopenharmony_ci+ std::string model_file_ = ""; // Path of the model file that BuildGraphInputs reads via ReadFile. 10244be168c0dSopenharmony_ci+ 
ThirdPartyModelParam param_; // I/O dtypes, shapes, names and extended_parameters read by the Build* helpers. 10245be168c0dSopenharmony_ci+}; 10246be168c0dSopenharmony_ci+} // namespace lite 10247be168c0dSopenharmony_ci+} // namespace mindspore 10248be168c0dSopenharmony_ci+ 10249be168c0dSopenharmony_ci+#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_ 10250be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10251be168c0dSopenharmony_ciindex 832fb92d..6bc2d4d3 100644 10252be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10253be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc 10254be168c0dSopenharmony_ci@@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room; 10255be168c0dSopenharmony_ci } // namespace 10256be168c0dSopenharmony_ci 10257be168c0dSopenharmony_ci ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { 10258be168c0dSopenharmony_ci- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 10259be168c0dSopenharmony_ci+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 10260be168c0dSopenharmony_ci MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 10261be168c0dSopenharmony_ci return; 10262be168c0dSopenharmony_ci } 10263be168c0dSopenharmony_ci@@ -38,7 +38,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator 10264be168c0dSopenharmony_ci } 10265be168c0dSopenharmony_ci 10266be168c0dSopenharmony_ci converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { 10267be168c0dSopenharmony_ci- if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) { 10268be168c0dSopenharmony_ci+ if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) { 10269be168c0dSopenharmony_ci MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; 10270be168c0dSopenharmony_ci return 
nullptr; 10271be168c0dSopenharmony_ci } 10272be168c0dSopenharmony_ci-- 10273be168c0dSopenharmony_ci2.17.1 10274be168c0dSopenharmony_ci 10275