1be168c0dSopenharmony_ciFrom baf2daaebd70448cddd35f5011642fe585d071b5 Mon Sep 17 00:00:00 2001
2be168c0dSopenharmony_ciFrom: chengfeng27 <chengfeng27@huawei.com>
3be168c0dSopenharmony_ciDate: Tue, 5 Mar 2024 20:00:24 +0800
4be168c0dSopenharmony_ciSubject: [PATCH] hilog use macro definition api
5be168c0dSopenharmony_ci
6be168c0dSopenharmony_ci---
7be168c0dSopenharmony_ci cmake/external_libs/flatbuffers.cmake         |   4 +-
8be168c0dSopenharmony_ci include/api/context.h                         |  65 ++
9be168c0dSopenharmony_ci include/c_api/context_c.h                     | 111 +++
10be168c0dSopenharmony_ci include/c_api/model_c.h                       | 178 ++++
11be168c0dSopenharmony_ci include/c_api/tensor_c.h                      |  14 +
12be168c0dSopenharmony_ci include/c_api/types_c.h                       |  57 +-
13be168c0dSopenharmony_ci include/sdk_api/context.h                     | 103 +++
14be168c0dSopenharmony_ci include/sdk_api/tensor.h                      |  13 +
15be168c0dSopenharmony_ci include/sdk_api/types.h                       |  38 +-
16be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/BUILD.gn   |   3 +
17be168c0dSopenharmony_ci .../device/cpu/kernel/nnacl/CMakeLists.txt    |   2 +-
18be168c0dSopenharmony_ci .../kernel/nnacl/avx/scatter_nd_binary_avx.h  |  66 ++
19be168c0dSopenharmony_ci .../nnacl/avx512/scatter_nd_binary_avx512.h   |  66 ++
20be168c0dSopenharmony_ci .../cpu/kernel/nnacl/base/scatter_nd_binary.c |  28 +
21be168c0dSopenharmony_ci .../cpu/kernel/nnacl/base/scatter_nd_binary.h |   3 +
22be168c0dSopenharmony_ci .../nnacl/base/scatter_nd_binary_simd.h.in    |  14 +
23be168c0dSopenharmony_ci .../kernel/nnacl/custom_is_inf_parameter.h    |  26 +
24be168c0dSopenharmony_ci .../nnacl/custom_masked_fill_parameter.h      |  26 +
25be168c0dSopenharmony_ci .../custom_tensor_scatter_max_parameter.h     |  26 +
26be168c0dSopenharmony_ci .../kernel/nnacl/infer/custom_is_inf_infer.c  |  38 +
27be168c0dSopenharmony_ci .../kernel/nnacl/infer/custom_is_inf_infer.h  |  31 +
28be168c0dSopenharmony_ci .../nnacl/infer/custom_masked_fill_infer.c    |  37 +
29be168c0dSopenharmony_ci .../nnacl/infer/custom_masked_fill_infer.h    |  31 +
30be168c0dSopenharmony_ci .../infer/custom_tensor_scatter_max_infer.c   |  37 +
31be168c0dSopenharmony_ci .../infer/custom_tensor_scatter_max_infer.h   |  31 +
32be168c0dSopenharmony_ci .../nnacl/neon/scatter_nd_binary_neon.h       |  65 ++
33be168c0dSopenharmony_ci .../plugin/device/cpu/kernel/nnacl/op_base.h  |   4 +
34be168c0dSopenharmony_ci .../cpu/kernel/nnacl/scatter_nd_binary_simd.h |  36 +
35be168c0dSopenharmony_ci .../kernel/nnacl/sse/scatter_nd_binary_sse.h  |  66 ++
36be168c0dSopenharmony_ci mindspore/core/mindrt/BUILD.gn                |   9 +-
37be168c0dSopenharmony_ci .../mindrt/src/thread/actor_threadpool.cc     |   2 +-
38be168c0dSopenharmony_ci .../core/mindrt/src/thread/core_affinity.cc   |   6 +-
39be168c0dSopenharmony_ci .../core/mindrt/src/thread/core_affinity.h    |   2 +-
40be168c0dSopenharmony_ci .../mindrt/src/thread/parallel_threadpool.cc  |   2 +-
41be168c0dSopenharmony_ci mindspore/core/mindrt/src/thread/threadlog.h  |  28 +-
42be168c0dSopenharmony_ci .../core/mindrt/src/thread/threadpool.cc      |   7 +-
43be168c0dSopenharmony_ci mindspore/lite/BUILD.gn                       |  82 +-
44be168c0dSopenharmony_ci mindspore/lite/CMakeLists.txt                 |   5 +-
45be168c0dSopenharmony_ci mindspore/lite/include/lite_types.h           |   1 +
46be168c0dSopenharmony_ci mindspore/lite/include/model.h                |   4 +
47be168c0dSopenharmony_ci .../lite/include/registry/converter_context.h |   4 +-
48be168c0dSopenharmony_ci mindspore/lite/mindir/include/mindir.h        |   2 +
49be168c0dSopenharmony_ci mindspore/lite/mindir/src/mindir.cc           |  40 +
50be168c0dSopenharmony_ci mindspore/lite/mindir/src/mindir_tensor.cc    |   2 +-
51be168c0dSopenharmony_ci mindspore/lite/mindir/src/utils.cc            |   2 +-
52be168c0dSopenharmony_ci mindspore/lite/src/CMakeLists.txt             |   6 +-
53be168c0dSopenharmony_ci mindspore/lite/src/common/context_util.cc     |  14 +-
54be168c0dSopenharmony_ci mindspore/lite/src/common/log.cc              |  33 +-
55be168c0dSopenharmony_ci mindspore/lite/src/common/log.h               |  50 +-
56be168c0dSopenharmony_ci .../common/ops/populate/custom_populate.cc    |  53 ++
57be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.cc  | 372 +++++++-
58be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/context_c.h   |  23 -
59be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/model_c.cc    | 724 ++++++++-------
60be168c0dSopenharmony_ci mindspore/lite/src/litert/c_api/tensor_c.cc   |  78 +-
61be168c0dSopenharmony_ci .../lite/src/litert/c_api/type_c_private.h    |  40 +
62be168c0dSopenharmony_ci mindspore/lite/src/litert/cxx_api/context.cc  |  85 ++
63be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/converters.cc     |  60 +-
64be168c0dSopenharmony_ci .../lite/src/litert/cxx_api/converters.h      |   4 +-
65be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/CMakeLists.txt   |  27 +-
66be168c0dSopenharmony_ci .../delegate/nnrt/checker/primitive_check.cc  |   2 +
67be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_delegate.cc | 836 ++++++++++++++----
68be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_delegate.h  |  74 +-
69be168c0dSopenharmony_ci .../litert/delegate/nnrt/nnrt_model_kernel.cc |   3 +-
70be168c0dSopenharmony_ci .../litert/delegate/nnrt/nnrt_model_kernel.h  |   2 +-
71be168c0dSopenharmony_ci .../src/litert/delegate/nnrt/nnrt_stub.cc     |  99 +++
72be168c0dSopenharmony_ci mindspore/lite/src/litert/infer_manager.cc    |   3 +-
73be168c0dSopenharmony_ci mindspore/lite/src/litert/inner_context.cc    |   4 +
74be168c0dSopenharmony_ci mindspore/lite/src/litert/inner_context.h     |  14 +
75be168c0dSopenharmony_ci mindspore/lite/src/litert/kernel/cpu/BUILD.gn |  51 +-
76be168c0dSopenharmony_ci .../src/litert/kernel/cpu/base/custom_base.cc |  46 +
77be168c0dSopenharmony_ci .../src/litert/kernel/cpu/base/custom_base.h  |  43 +
78be168c0dSopenharmony_ci .../litert/kernel/cpu/base/custom_is_inf.cc   |  61 ++
79be168c0dSopenharmony_ci .../litert/kernel/cpu/base/custom_is_inf.h    |  38 +
80be168c0dSopenharmony_ci .../kernel/cpu/base/custom_masked_fill.cc     |  84 ++
81be168c0dSopenharmony_ci .../kernel/cpu/base/custom_masked_fill.h      |  35 +
82be168c0dSopenharmony_ci .../kernel/cpu/base/custom_tensor_scatter.cc  |  75 ++
83be168c0dSopenharmony_ci .../kernel/cpu/base/custom_tensor_scatter.h   |  36 +
84be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_model.cc       |  29 +
85be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_session.cc     |  39 +-
86be168c0dSopenharmony_ci mindspore/lite/src/litert/lite_session.h      |   1 +
87be168c0dSopenharmony_ci mindspore/lite/src/litert/scheduler.cc        |  17 +
88be168c0dSopenharmony_ci mindspore/lite/src/litert/tensor_category.cc  |   4 +
89be168c0dSopenharmony_ci mindspore/lite/src/litert/tensor_category.h   |   1 +
90be168c0dSopenharmony_ci mindspore/lite/test/CMakeLists.txt            |  15 +-
91be168c0dSopenharmony_ci mindspore/lite/test/runtest.sh                |   1 +
92be168c0dSopenharmony_ci .../test/ut/test_data/third_party_model.cfg   |   8 +
93be168c0dSopenharmony_ci .../tools/converter/api/converter_api_test.cc |  10 +
94be168c0dSopenharmony_ci .../third_party_param_parser_test.cc          | 176 ++++
95be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_base.cc    |   2 +-
96be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_base.h     |   2 +-
97be168c0dSopenharmony_ci .../lite/tools/benchmark/benchmark_c_api.cc   |   4 +
98be168c0dSopenharmony_ci .../tools/benchmark/benchmark_unified_api.cc  |   5 +
99be168c0dSopenharmony_ci .../lite/tools/benchmark_train/CMakeLists.txt |   3 +
100be168c0dSopenharmony_ci mindspore/lite/tools/benchmark_train/main.cc  |   3 +-
101be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_runner.cc  |  10 +-
102be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_train.cc   | 418 +--------
103be168c0dSopenharmony_ci .../lite/tools/benchmark_train/net_train.h    | 229 +----
104be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_base.cc   | 410 +++++++++
105be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_base.h    | 288 ++++++
106be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_c_api.cc  | 659 ++++++++++++++
107be168c0dSopenharmony_ci .../tools/benchmark_train/net_train_c_api.h   | 121 +++
108be168c0dSopenharmony_ci .../tools/benchmark_train/run_net_train.cc    |  86 ++
109be168c0dSopenharmony_ci .../tools/benchmark_train/run_net_train.h     |  22 +
110be168c0dSopenharmony_ci mindspore/lite/tools/converter/CMakeLists.txt |   4 +
111be168c0dSopenharmony_ci .../config_parser/config_file_parser.cc       |  27 +
112be168c0dSopenharmony_ci .../config_parser/config_file_parser.h        |  15 +
113be168c0dSopenharmony_ci .../config_parser/third_party_param_parser.cc | 299 +++++++
114be168c0dSopenharmony_ci .../config_parser/third_party_param_parser.h  |  44 +
115be168c0dSopenharmony_ci mindspore/lite/tools/converter/converter.cc   |  34 +-
116be168c0dSopenharmony_ci .../tools/converter/converter_funcgraph.cc    |  13 +-
117be168c0dSopenharmony_ci .../converter_lite/converter_flags.cc         |   4 +-
118be168c0dSopenharmony_ci .../tools/converter/cxx_api/converter_para.h  |  14 +
119be168c0dSopenharmony_ci .../tools/converter/graphdef_transform.cc     |  44 +
120be168c0dSopenharmony_ci .../parser/third_party/CMakeLists.txt         |   4 +
121be168c0dSopenharmony_ci .../third_party/third_party_model_parser.cc   | 277 ++++++
122be168c0dSopenharmony_ci .../third_party/third_party_model_parser.h    |  50 ++
123be168c0dSopenharmony_ci .../registry/model_parser_registry.cc         |   4 +-
124be168c0dSopenharmony_ci 117 files changed, 6456 insertions(+), 1432 deletions(-)
125be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
126be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
127be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
128be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
129be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
130be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
131be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
132be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
133be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
134be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
135be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
136be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
137be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
138be168c0dSopenharmony_ci create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
139be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/c_api/type_c_private.h
140be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
141be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
142be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
143be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
144be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
145be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
146be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
147be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
148be168c0dSopenharmony_ci create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
149be168c0dSopenharmony_ci create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg
150be168c0dSopenharmony_ci create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
151be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.cc
152be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.h
153be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.cc
154be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.h
155be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.cc
156be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.h
157be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
158be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
159be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
160be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
161be168c0dSopenharmony_ci create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
162be168c0dSopenharmony_ci
163be168c0dSopenharmony_cidiff --git a/cmake/external_libs/flatbuffers.cmake b/cmake/external_libs/flatbuffers.cmake
164be168c0dSopenharmony_ciindex 2fde4311..87f0425b 100644
165be168c0dSopenharmony_ci--- a/cmake/external_libs/flatbuffers.cmake
166be168c0dSopenharmony_ci+++ b/cmake/external_libs/flatbuffers.cmake
167be168c0dSopenharmony_ci@@ -21,8 +21,8 @@ else()
168be168c0dSopenharmony_ci         # flatbuffers.lib cimplied by msvc
169be168c0dSopenharmony_ci         set(CMAKE_STATIC_LIBRARY_PREFIX "")
170be168c0dSopenharmony_ci     else()
171be168c0dSopenharmony_ci-        set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
172be168c0dSopenharmony_ci-        set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
173be168c0dSopenharmony_ci+        set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable")
174be168c0dSopenharmony_ci+        set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable")
175be168c0dSopenharmony_ci     endif()
176be168c0dSopenharmony_ci 
177be168c0dSopenharmony_ci     if(WIN32)
178be168c0dSopenharmony_cidiff --git a/include/api/context.h b/include/api/context.h
179be168c0dSopenharmony_ciindex c9fb11f0..eb704d44 100644
180be168c0dSopenharmony_ci--- a/include/api/context.h
181be168c0dSopenharmony_ci+++ b/include/api/context.h
182be168c0dSopenharmony_ci@@ -39,6 +39,8 @@ enum DeviceType {
183be168c0dSopenharmony_ci   kAscend310,
184be168c0dSopenharmony_ci   kCustomDevice,
185be168c0dSopenharmony_ci   kAllDevice,
186be168c0dSopenharmony_ci+  //ohos-only device range[60,80)
187be168c0dSopenharmony_ci+  kNNRt = 60,
188be168c0dSopenharmony_ci   // add new type here
189be168c0dSopenharmony_ci   kInvalidDeviceType = 100,
190be168c0dSopenharmony_ci };
191be168c0dSopenharmony_ci@@ -598,5 +600,68 @@ void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_
192be168c0dSopenharmony_ci   SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
193be168c0dSopenharmony_ci }
194be168c0dSopenharmony_ci std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
195be168c0dSopenharmony_ci+
196be168c0dSopenharmony_ci+struct Extension {
197be168c0dSopenharmony_ci+  std::string name;
198be168c0dSopenharmony_ci+  std::vector<uint8_t> value;
199be168c0dSopenharmony_ci+};
200be168c0dSopenharmony_ci+
201be168c0dSopenharmony_ci+class MS_API NNRTDeviceInfo : public DeviceInfoContext {
202be168c0dSopenharmony_ci+ public:
203be168c0dSopenharmony_ci+  /// \brief Get the type of this DeviceInfoContext.
204be168c0dSopenharmony_ci+  ///
205be168c0dSopenharmony_ci+  /// \return Type of this DeviceInfoContext.
206be168c0dSopenharmony_ci+  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; };
207be168c0dSopenharmony_ci+
208be168c0dSopenharmony_ci+  /// \brief Set device id.
209be168c0dSopenharmony_ci+  ///
210be168c0dSopenharmony_ci+  /// \param[in] device_id The device id.
211be168c0dSopenharmony_ci+  void SetDeviceID(size_t device_id);
212be168c0dSopenharmony_ci+
213be168c0dSopenharmony_ci+  /// \brief Get the device id.
214be168c0dSopenharmony_ci+  ///
215be168c0dSopenharmony_ci+  /// \return The device id.
216be168c0dSopenharmony_ci+  size_t GetDeviceID() const;
217be168c0dSopenharmony_ci+
218be168c0dSopenharmony_ci+  /// \brief Set performance mode.
219be168c0dSopenharmony_ci+  ///
220be168c0dSopenharmony_ci+  /// \param[in] performance_mode The performance mode.
221be168c0dSopenharmony_ci+  void SetPerformanceMode(int performance_mode);
222be168c0dSopenharmony_ci+
223be168c0dSopenharmony_ci+  /// \brief Get performance mode.
224be168c0dSopenharmony_ci+  ///
225be168c0dSopenharmony_ci+  /// \return The priority.
226be168c0dSopenharmony_ci+  int GetPerformanceMode() const;
227be168c0dSopenharmony_ci+
228be168c0dSopenharmony_ci+  /// \brief Set priority.
229be168c0dSopenharmony_ci+  ///
230be168c0dSopenharmony_ci+  /// \param[in] priority The priority.
231be168c0dSopenharmony_ci+  void SetPriority(int priority);
232be168c0dSopenharmony_ci+
233be168c0dSopenharmony_ci+  /// \brief Get priority.
234be168c0dSopenharmony_ci+  ///
235be168c0dSopenharmony_ci+  /// \return The priority.
236be168c0dSopenharmony_ci+  int GetPriority() const;
237be168c0dSopenharmony_ci+
238be168c0dSopenharmony_ci+  /// \brief Set enables to perform the float16 inference
239be168c0dSopenharmony_ci+  ///
240be168c0dSopenharmony_ci+  /// \param[in] is_fp16 Enable float16 inference or not.
241be168c0dSopenharmony_ci+  void SetEnableFP16(bool is_fp16);
242be168c0dSopenharmony_ci+
243be168c0dSopenharmony_ci+  /// \brief Get enables to perform the float16 inference
244be168c0dSopenharmony_ci+  ///
245be168c0dSopenharmony_ci+  /// \return Whether enable float16 inference.
246be168c0dSopenharmony_ci+  bool GetEnableFP16() const;
247be168c0dSopenharmony_ci+
248be168c0dSopenharmony_ci+  /// \brief Set extensions
249be168c0dSopenharmony_ci+  ///
250be168c0dSopenharmony_ci+  /// \param[in] extension array.
251be168c0dSopenharmony_ci+  void SetExtensions(const std::vector<Extension> &extensions);
252be168c0dSopenharmony_ci+
253be168c0dSopenharmony_ci+  /// \brief Get extensions
254be168c0dSopenharmony_ci+  ///
255be168c0dSopenharmony_ci+  /// \return extension array.
256be168c0dSopenharmony_ci+  std::vector<Extension> GetExtensions() const;
257be168c0dSopenharmony_ci+};
258be168c0dSopenharmony_ci }  // namespace mindspore
259be168c0dSopenharmony_ci #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
260be168c0dSopenharmony_cidiff --git a/include/c_api/context_c.h b/include/c_api/context_c.h
261be168c0dSopenharmony_ciindex 53839e80..8951da25 100644
262be168c0dSopenharmony_ci--- a/include/c_api/context_c.h
263be168c0dSopenharmony_ci+++ b/include/c_api/context_c.h
264be168c0dSopenharmony_ci@@ -19,6 +19,7 @@
265be168c0dSopenharmony_ci #include <stddef.h>
266be168c0dSopenharmony_ci #include <stdint.h>
267be168c0dSopenharmony_ci #include <stdbool.h>
268be168c0dSopenharmony_ci+#include "include/c_api/status_c.h"
269be168c0dSopenharmony_ci #include "include/c_api/types_c.h"
270be168c0dSopenharmony_ci 
271be168c0dSopenharmony_ci #ifdef __cplusplus
272be168c0dSopenharmony_ci@@ -173,6 +174,116 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
273be168c0dSopenharmony_ci /// \return NPU frequency
274be168c0dSopenharmony_ci OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);
275be168c0dSopenharmony_ci 
276be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT.
277be168c0dSopenharmony_ci+///
278be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description.
279be168c0dSopenharmony_ci+///
280be168c0dSopenharmony_ci+/// \return NNRT device description array.
281be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
282be168c0dSopenharmony_ci+
283be168c0dSopenharmony_ci+/// \brief Obtain the specified element in NNRt device description array.
284be168c0dSopenharmony_ci+///
285be168c0dSopenharmony_ci+/// \param[in] descs NNRT device description array.
286be168c0dSopenharmony_ci+/// \param[in] index Element index.
287be168c0dSopenharmony_ci+///
288be168c0dSopenharmony_ci+/// \return NNRT device description.
289be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
290be168c0dSopenharmony_ci+
291be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT.
292be168c0dSopenharmony_ci+///
293be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description.
294be168c0dSopenharmony_ci+///
295be168c0dSopenharmony_ci+/// \return NNRT device description array.
296be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
297be168c0dSopenharmony_ci+
298be168c0dSopenharmony_ci+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs().
299be168c0dSopenharmony_ci+///
300be168c0dSopenharmony_ci+/// \param[in] desc NNRT device description array.
301be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
302be168c0dSopenharmony_ci+
303be168c0dSopenharmony_ci+/// \brief Obtain the device id in NNRT device description.
304be168c0dSopenharmony_ci+///
305be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
306be168c0dSopenharmony_ci+///
307be168c0dSopenharmony_ci+/// \return NNRT device id.
308be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
309be168c0dSopenharmony_ci+
310be168c0dSopenharmony_ci+/// \brief Obtain the device name in NNRT device description.
311be168c0dSopenharmony_ci+///
312be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
313be168c0dSopenharmony_ci+///
314be168c0dSopenharmony_ci+/// \return NNRT device name.
315be168c0dSopenharmony_ci+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
316be168c0dSopenharmony_ci+
317be168c0dSopenharmony_ci+/// \brief Obtain the device type in NNRT device description.
318be168c0dSopenharmony_ci+///
319be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
320be168c0dSopenharmony_ci+///
321be168c0dSopenharmony_ci+/// \return NNRT device type.
322be168c0dSopenharmony_ci+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
323be168c0dSopenharmony_ci+
324be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by exactly matching the specific device name.
325be168c0dSopenharmony_ci+///
326be168c0dSopenharmony_ci+/// \param[in] name NNRt device name.
327be168c0dSopenharmony_ci+///
328be168c0dSopenharmony_ci+/// \return Device info object handle.
329be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
330be168c0dSopenharmony_ci+
331be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by finding the first device with the specific device type.
332be168c0dSopenharmony_ci+///
333be168c0dSopenharmony_ci+/// \param[in] name NNRt device type.
334be168c0dSopenharmony_ci+///
335be168c0dSopenharmony_ci+/// \return Device info object handle.
336be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
337be168c0dSopenharmony_ci+
338be168c0dSopenharmony_ci+/// \brief Set the NNRT device id, Only valid for NNRT.
339be168c0dSopenharmony_ci+///
340be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
341be168c0dSopenharmony_ci+/// \param[in] device_id NNRT device id.
342be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
343be168c0dSopenharmony_ci+
344be168c0dSopenharmony_ci+/// \brief Obtain the NNRT device id, Only valid for NNRT.
345be168c0dSopenharmony_ci+///
346be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
347be168c0dSopenharmony_ci+///
348be168c0dSopenharmony_ci+/// \return NNRT device id.
349be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
350be168c0dSopenharmony_ci+
351be168c0dSopenharmony_ci+/// \brief Set the NNRT performance mode, Only valid for NNRT.
352be168c0dSopenharmony_ci+///
353be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
354be168c0dSopenharmony_ci+/// \param[in] device_id NNRT performance mode.
355be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
356be168c0dSopenharmony_ci+
357be168c0dSopenharmony_ci+/// \brief Obtain the NNRT performance mode, Only valid for NNRT.
358be168c0dSopenharmony_ci+///
359be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
360be168c0dSopenharmony_ci+///
361be168c0dSopenharmony_ci+/// \return NNRT performance mode.
362be168c0dSopenharmony_ci+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
363be168c0dSopenharmony_ci+
364be168c0dSopenharmony_ci+/// \brief Set the NNRT priority, Only valid for NNRT.
365be168c0dSopenharmony_ci+///
366be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
367be168c0dSopenharmony_ci+/// \param[in] device_id NNRT priority.
368be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
369be168c0dSopenharmony_ci+
370be168c0dSopenharmony_ci+/// \brief Obtain the NNRT priority, Only valid for NNRT.
371be168c0dSopenharmony_ci+///
372be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
373be168c0dSopenharmony_ci+///
374be168c0dSopenharmony_ci+/// \return NNRT priority.
375be168c0dSopenharmony_ci+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info);
376be168c0dSopenharmony_ci+
377be168c0dSopenharmony_ci+/// \brief Add extension of key/value format to device info, Only valid for NNRT.
378be168c0dSopenharmony_ci+///
379be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
380be168c0dSopenharmony_ci+/// \param[in] name The content of key as a C string.
381be168c0dSopenharmony_ci+/// \param[in] value The pointer to the value, which is a byte array.
382be168c0dSopenharmony_ci+/// \param[in] value_size The size of the value, which is a byte array.
383be168c0dSopenharmony_ci+///
384be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
385be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size);
386be168c0dSopenharmony_ci #ifdef __cplusplus
387be168c0dSopenharmony_ci }
388be168c0dSopenharmony_ci #endif
389be168c0dSopenharmony_cidiff --git a/include/c_api/model_c.h b/include/c_api/model_c.h
390be168c0dSopenharmony_ciindex 12a46bcd..2286e673 100644
391be168c0dSopenharmony_ci--- a/include/c_api/model_c.h
392be168c0dSopenharmony_ci+++ b/include/c_api/model_c.h
393be168c0dSopenharmony_ci@@ -26,6 +26,8 @@ extern "C" {
394be168c0dSopenharmony_ci 
395be168c0dSopenharmony_ci typedef void *OH_AI_ModelHandle;
396be168c0dSopenharmony_ci 
397be168c0dSopenharmony_ci+typedef void *OH_AI_TrainCfgHandle;
398be168c0dSopenharmony_ci+
399be168c0dSopenharmony_ci typedef struct OH_AI_TensorHandleArray {
400be168c0dSopenharmony_ci   size_t handle_num;
401be168c0dSopenharmony_ci   OH_AI_TensorHandle *handle_list;
402be168c0dSopenharmony_ci@@ -168,6 +170,182 @@ OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHa
403be168c0dSopenharmony_ci /// \return The output tensor handle with the given name, if the name is not found, an NULL is returned.
404be168c0dSopenharmony_ci OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);
405be168c0dSopenharmony_ci 
406be168c0dSopenharmony_ci+/// \brief Create a TrainCfg object. Only valid for Lite Train.
407be168c0dSopenharmony_ci+///
408be168c0dSopenharmony_ci+/// \return TrainCfg object handle.
409be168c0dSopenharmony_ci+OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate();
410be168c0dSopenharmony_ci+
411be168c0dSopenharmony_ci+/// \brief Destroy the train_cfg object. Only valid for Lite Train.
412be168c0dSopenharmony_ci+///
413be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle.
414be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg);
415be168c0dSopenharmony_ci+
416be168c0dSopenharmony_ci+/// \brief Obtains part of the name that identify a loss kernel. Only valid for Lite Train.
417be168c0dSopenharmony_ci+///
418be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle.
419be168c0dSopenharmony_ci+/// \param[in] num The num of loss_name.
420be168c0dSopenharmony_ci+///
421be168c0dSopenharmony_ci+/// \return loss_name.
422be168c0dSopenharmony_ci+OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num);
423be168c0dSopenharmony_ci+
424be168c0dSopenharmony_ci+/// \brief Set part of the name that identify a loss kernel. Only valid for Lite Train.
425be168c0dSopenharmony_ci+///
426be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle.
427be168c0dSopenharmony_ci+/// \param[in] loss_name define part of the name that identify a loss kernel.
428be168c0dSopenharmony_ci+/// \param[in] num The num of loss_name.
429be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num);
430be168c0dSopenharmony_ci+
431be168c0dSopenharmony_ci+/// \brief Obtains optimization level of the train_cfg. Only valid for Lite Train.
432be168c0dSopenharmony_ci+///
433be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle.
434be168c0dSopenharmony_ci+///
435be168c0dSopenharmony_ci+/// \return OH_AI_OptimizationLevel.
436be168c0dSopenharmony_ci+OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg);
437be168c0dSopenharmony_ci+
438be168c0dSopenharmony_ci+/// \brief Set optimization level of the train_cfg. Only valid for Lite Train.
439be168c0dSopenharmony_ci+///
440be168c0dSopenharmony_ci+/// \param[in] train_cfg TrainCfg object handle.
441be168c0dSopenharmony_ci+/// \param[in] level The optimization level of train_cfg.
442be168c0dSopenharmony_ci+OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level);
443be168c0dSopenharmony_ci+
444be168c0dSopenharmony_ci+/// \brief Build the train model from model buffer so that it can run on a device. Only valid for Lite Train.
445be168c0dSopenharmony_ci+///
446be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
447be168c0dSopenharmony_ci+/// \param[in] model_data Define the buffer read from a model file.
448be168c0dSopenharmony_ci+/// \param[in] data_size Define bytes number of model file buffer.
449be168c0dSopenharmony_ci+/// \param[in] model_type Define The type of model file.
450be168c0dSopenharmony_ci+/// \param[in] model_context Define the context used to store options during execution.
451be168c0dSopenharmony_ci+/// \param[in] train_cfg Define the config used by training.
452be168c0dSopenharmony_ci+///
453be168c0dSopenharmony_ci+/// \return OH_AI_Status.
454be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
455be168c0dSopenharmony_ci+                                     OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
456be168c0dSopenharmony_ci+                                     const OH_AI_TrainCfgHandle train_cfg);
457be168c0dSopenharmony_ci+
458be168c0dSopenharmony_ci+/// \brief Build the train model from model file buffer so that it can run on a device. Only valid for Lite Train.
459be168c0dSopenharmony_ci+///
460be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
461be168c0dSopenharmony_ci+/// \param[in] model_path Define the model path.
462be168c0dSopenharmony_ci+/// \param[in] model_type Define The type of model file.
463be168c0dSopenharmony_ci+/// \param[in] model_context Define the context used to store options during execution.
464be168c0dSopenharmony_ci+/// \param[in] train_cfg Define the config used by training.
465be168c0dSopenharmony_ci+///
466be168c0dSopenharmony_ci+/// \return OH_AI_Status.
467be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
468be168c0dSopenharmony_ci+                                             OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
469be168c0dSopenharmony_ci+                                             const OH_AI_TrainCfgHandle train_cfg);
470be168c0dSopenharmony_ci+
471be168c0dSopenharmony_ci+/// \brief Train model by step. Only valid for Lite Train.
472be168c0dSopenharmony_ci+///
473be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
474be168c0dSopenharmony_ci+/// \param[in] before CallBack before predict.
475be168c0dSopenharmony_ci+/// \param[in] after CallBack after predict.
476be168c0dSopenharmony_ci+///
477be168c0dSopenharmony_ci+/// \return OH_AI_Status.
478be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
479be168c0dSopenharmony_ci+                                     const OH_AI_KernelCallBack after);
480be168c0dSopenharmony_ci+
481be168c0dSopenharmony_ci+/// \brief Sets the Learning Rate of the training. Only valid for Lite Train.
482be168c0dSopenharmony_ci+///
483be168c0dSopenharmony_ci+/// \param[in] learning_rate to set.
484be168c0dSopenharmony_ci+///
485be168c0dSopenharmony_ci+/// \return OH_AI_Status of operation.
486be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate);
487be168c0dSopenharmony_ci+
488be168c0dSopenharmony_ci+/// \brief Obtains the Learning Rate of the optimizer. Only valid for Lite Train.
489be168c0dSopenharmony_ci+///
490be168c0dSopenharmony_ci+/// \return Learning rate. 0.0 if no optimizer was found.
491be168c0dSopenharmony_ci+OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model);
492be168c0dSopenharmony_ci+
493be168c0dSopenharmony_ci+/// \brief Obtains all weights tensors of the model. Only valid for Lite Train.
494be168c0dSopenharmony_ci+///
495be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
496be168c0dSopenharmony_ci+///
497be168c0dSopenharmony_ci+/// \return The vector that includes all gradient tensors.
498be168c0dSopenharmony_ci+OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model);
499be168c0dSopenharmony_ci+
500be168c0dSopenharmony_ci+/// \brief update weights tensors of the model. Only valid for Lite Train.
501be168c0dSopenharmony_ci+///
502be168c0dSopenharmony_ci+/// \param[in] new_weights A vector new weights.
503be168c0dSopenharmony_ci+///
504be168c0dSopenharmony_ci+/// \return OH_AI_Status
505be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights);
506be168c0dSopenharmony_ci+
507be168c0dSopenharmony_ci+/// \brief Get the model running mode.
508be168c0dSopenharmony_ci+///
509be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
510be168c0dSopenharmony_ci+///
511be168c0dSopenharmony_ci+/// \return Is Train Mode or not.
512be168c0dSopenharmony_ci+OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model);
513be168c0dSopenharmony_ci+
514be168c0dSopenharmony_ci+/// \brief Set the model running mode. Only valid for Lite Train.
515be168c0dSopenharmony_ci+///
516be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
517be168c0dSopenharmony_ci+/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
518be168c0dSopenharmony_ci+///
519be168c0dSopenharmony_ci+/// \return OH_AI_Status.
520be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train);
521be168c0dSopenharmony_ci+
522be168c0dSopenharmony_ci+/// \brief Setup training with virtual batches. Only valid for Lite Train.
523be168c0dSopenharmony_ci+///
524be168c0dSopenharmony_ci+/// \param[in] model Model object handle.
525be168c0dSopenharmony_ci+/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable.
526be168c0dSopenharmony_ci+/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration.
527be168c0dSopenharmony_ci+/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration.
528be168c0dSopenharmony_ci+///
529be168c0dSopenharmony_ci+/// \return OH_AI_Status.
530be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
531be168c0dSopenharmony_ci+                                                    float momentum);
532be168c0dSopenharmony_ci+
533be168c0dSopenharmony_ci+/// \brief Export training model from file. Only valid for Lite Train.
534be168c0dSopenharmony_ci+///
535be168c0dSopenharmony_ci+/// \param[in] model The model data.
536be168c0dSopenharmony_ci+/// \param[in] model_type The model file type.
537be168c0dSopenharmony_ci+/// \param[in] model_file The exported model file.
538be168c0dSopenharmony_ci+/// \param[in] quantization_type The quantification type.
539be168c0dSopenharmony_ci+/// \param[in] export_inference_only Whether to export a reasoning only model.
540be168c0dSopenharmony_ci+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
541be168c0dSopenharmony_ci+/// empty, and export the complete reasoning model.
542be168c0dSopenharmony_ci+/// \param[in] num The number of output_tensor_name.
543be168c0dSopenharmony_ci+///
544be168c0dSopenharmony_ci+/// \return OH_AI_Status.
545be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
546be168c0dSopenharmony_ci+                                         OH_AI_QuantizationType quantization_type, bool export_inference_only,
547be168c0dSopenharmony_ci+                                         char **output_tensor_name, size_t num);
548be168c0dSopenharmony_ci+
549be168c0dSopenharmony_ci+/// \brief Export training model from buffer. Only valid for Lite Train.
550be168c0dSopenharmony_ci+///
551be168c0dSopenharmony_ci+/// \param[in] model The model data.
552be168c0dSopenharmony_ci+/// \param[in] model_type The model file type.
553be168c0dSopenharmony_ci+/// \param[in] model_data The exported model buffer.
554be168c0dSopenharmony_ci+/// \param[in] data_size The exported model buffer size.
555be168c0dSopenharmony_ci+/// \param[in] quantization_type The quantification type.
556be168c0dSopenharmony_ci+/// \param[in] export_inference_only Whether to export a reasoning only model.
557be168c0dSopenharmony_ci+/// \param[in] output_tensor_name The set the name of the output tensor of the exported reasoning model, default as
558be168c0dSopenharmony_ci+/// empty, and export the complete reasoning model.
559be168c0dSopenharmony_ci+/// \param[in] num The number of output_tensor_name.
560be168c0dSopenharmony_ci+///
561be168c0dSopenharmony_ci+/// \return OH_AI_Status.
562be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
563be168c0dSopenharmony_ci+                                               size_t *data_size, OH_AI_QuantizationType quantization_type,
564be168c0dSopenharmony_ci+                                               bool export_inference_only, char **output_tensor_name, size_t num);
565be168c0dSopenharmony_ci+
566be168c0dSopenharmony_ci+/// \brief Export model's weights, which can be used in micro only. Only valid for Lite Train.
567be168c0dSopenharmony_ci+///
568be168c0dSopenharmony_ci+/// \param[in] model The model data.
569be168c0dSopenharmony_ci+/// \param[in] model_type The model file type.
570be168c0dSopenharmony_ci+/// \param[in] weight_file The path of exported weight file.
571be168c0dSopenharmony_ci+/// \param[in] is_inference Whether to export weights from a reasoning model. Currently, only support this is `true`.
572be168c0dSopenharmony_ci+/// \param[in] enable_fp16 Float-weight is whether to be saved in float16 format.
573be168c0dSopenharmony_ci+/// \param[in] changeable_weights_name The set the name of these weight tensors, whose shape is changeable.
574be168c0dSopenharmony_ci+/// \param[in] num The number of changeable_weights_name.
575be168c0dSopenharmony_ci+///
576be168c0dSopenharmony_ci+/// \return OH_AI_Status.
577be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
578be168c0dSopenharmony_ci+                                                               const char *weight_file, bool is_inference,
579be168c0dSopenharmony_ci+                                                               bool enable_fp16, char **changeable_weights_name,
580be168c0dSopenharmony_ci+                                                               size_t num);
581be168c0dSopenharmony_ci+
582be168c0dSopenharmony_ci #ifdef __cplusplus
583be168c0dSopenharmony_ci }
584be168c0dSopenharmony_ci #endif
585be168c0dSopenharmony_cidiff --git a/include/c_api/tensor_c.h b/include/c_api/tensor_c.h
586be168c0dSopenharmony_ciindex f18ba163..6d2aaab6 100644
587be168c0dSopenharmony_ci--- a/include/c_api/tensor_c.h
588be168c0dSopenharmony_ci+++ b/include/c_api/tensor_c.h
589be168c0dSopenharmony_ci@@ -17,6 +17,7 @@
590be168c0dSopenharmony_ci #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H
591be168c0dSopenharmony_ci 
592be168c0dSopenharmony_ci #include <stddef.h>
593be168c0dSopenharmony_ci+#include "include/c_api/status_c.h"
594be168c0dSopenharmony_ci #include "include/c_api/types_c.h"
595be168c0dSopenharmony_ci #include "include/c_api/data_type_c.h"
596be168c0dSopenharmony_ci #include "include/c_api/format_c.h"
597be168c0dSopenharmony_ci@@ -112,6 +113,19 @@ OH_AI_API OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor);
598be168c0dSopenharmony_ci /// \param[in] data A pointer to the data of the tensor.
599be168c0dSopenharmony_ci OH_AI_API void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data);
600be168c0dSopenharmony_ci 
601be168c0dSopenharmony_ci+/// \brief Set the data for the tensor with user-allocated data buffer.
602be168c0dSopenharmony_ci+/// The main purpose of this interface is providing a way of using memory already allocated by user as the Model's
603be168c0dSopenharmony_ci+/// input, but not which allocated inside the Model object. It can reduce one copy.
604be168c0dSopenharmony_ci+/// Note: The tensor won't free the data provided by invoker. Invoker has the responsibility to free it. And this
605be168c0dSopenharmony_ci+/// free action should not be preformed before destruction of the tensor.
606be168c0dSopenharmony_ci+///
607be168c0dSopenharmony_ci+/// \param[in] tensor Tensor object handle.
608be168c0dSopenharmony_ci+/// \param[in] data A pointer to the user data buffer.
609be168c0dSopenharmony_ci+/// \param[in] data the byte size of the user data buffer.
610be168c0dSopenharmony_ci+///
611be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
612be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
613be168c0dSopenharmony_ci+
614be168c0dSopenharmony_ci /// \brief Obtain the data pointer of the tensor.
615be168c0dSopenharmony_ci ///
616be168c0dSopenharmony_ci /// \param[in] tensor Tensor object handle.
617be168c0dSopenharmony_cidiff --git a/include/c_api/types_c.h b/include/c_api/types_c.h
618be168c0dSopenharmony_ciindex dba54ffa..e520e336 100644
619be168c0dSopenharmony_ci--- a/include/c_api/types_c.h
620be168c0dSopenharmony_ci+++ b/include/c_api/types_c.h
621be168c0dSopenharmony_ci@@ -40,10 +40,65 @@ typedef enum OH_AI_DeviceType {
622be168c0dSopenharmony_ci   OH_AI_DEVICETYPE_KIRIN_NPU,
623be168c0dSopenharmony_ci   // add new type here
624be168c0dSopenharmony_ci   // ohos-only device range: [60, 80)
625be168c0dSopenharmony_ci-  OH_AI_DEVICETYPE__NNRT = 60,
626be168c0dSopenharmony_ci+  OH_AI_DEVICETYPE_NNRT = 60,
627be168c0dSopenharmony_ci   OH_AI_DEVICETYPE_INVALID = 100,
628be168c0dSopenharmony_ci } OH_AI_DeviceType;
629be168c0dSopenharmony_ci 
630be168c0dSopenharmony_ci+typedef enum OH_AI_NNRTDeviceType {
631be168c0dSopenharmony_ci+  /** Devices that are not CPU, GPU, or dedicated accelerator */
632be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_OTHERS = 0,
633be168c0dSopenharmony_ci+  /** CPU device */
634be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_CPU = 1,
635be168c0dSopenharmony_ci+  /** GPU device */
636be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_GPU = 2,
637be168c0dSopenharmony_ci+  /** Dedicated hardware accelerator */
638be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
639be168c0dSopenharmony_ci+} OH_AI_NNRTDeviceType;
640be168c0dSopenharmony_ci+
641be168c0dSopenharmony_ci+typedef enum OH_AI_PerformanceMode {
642be168c0dSopenharmony_ci+  /** No performance mode preference */
643be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_NONE = 0,
644be168c0dSopenharmony_ci+  /** Low power consumption mode*/
645be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_LOW = 1,
646be168c0dSopenharmony_ci+  /** Medium performance mode */
647be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_MEDIUM = 2,
648be168c0dSopenharmony_ci+  /** High performance mode */
649be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_HIGH = 3,
650be168c0dSopenharmony_ci+  /** Ultimate performance mode */
651be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_EXTREME = 4
652be168c0dSopenharmony_ci+} OH_AI_PerformanceMode;
653be168c0dSopenharmony_ci+
654be168c0dSopenharmony_ci+typedef enum OH_AI_Priority {
655be168c0dSopenharmony_ci+  /** No priority preference */
656be168c0dSopenharmony_ci+  OH_AI_PRIORITY_NONE = 0,
657be168c0dSopenharmony_ci+  /** Low priority */
658be168c0dSopenharmony_ci+  OH_AI_PRIORITY_LOW = 1,
659be168c0dSopenharmony_ci+  /** Medium priority */
660be168c0dSopenharmony_ci+  OH_AI_PRIORITY_MEDIUM = 2,
661be168c0dSopenharmony_ci+  /** High priority */
662be168c0dSopenharmony_ci+  OH_AI_PRIORITY_HIGH = 3
663be168c0dSopenharmony_ci+} OH_AI_Priority;
664be168c0dSopenharmony_ci+
665be168c0dSopenharmony_ci+typedef enum OH_AI_OptimizationLevel {
666be168c0dSopenharmony_ci+  /** Do not change */
667be168c0dSopenharmony_ci+  OH_AI_KO0 = 0,
668be168c0dSopenharmony_ci+  /** Cast network to float16, keep batchnorm and loss in float32 */
669be168c0dSopenharmony_ci+  OH_AI_KO2 = 2,
670be168c0dSopenharmony_ci+  /** Cast network to float16, including bacthnorm */
671be168c0dSopenharmony_ci+  OH_AI_KO3 = 3,
672be168c0dSopenharmony_ci+  /** Choose optimization based on device */
673be168c0dSopenharmony_ci+  OH_AI_KAUTO = 4,
674be168c0dSopenharmony_ci+  OH_AI_KOPTIMIZATIONTYPE = 0xFFFFFFFF
675be168c0dSopenharmony_ci+} OH_AI_OptimizationLevel;
676be168c0dSopenharmony_ci+
677be168c0dSopenharmony_ci+typedef enum OH_AI_QuantizationType {
678be168c0dSopenharmony_ci+  OH_AI_NO_QUANT = 0,
679be168c0dSopenharmony_ci+  OH_AI_WEIGHT_QUANT = 1,
680be168c0dSopenharmony_ci+  OH_AI_FULL_QUANT = 2,
681be168c0dSopenharmony_ci+  OH_AI_UNKNOWN_QUANT_TYPE = 0xFFFFFFFF
682be168c0dSopenharmony_ci+} OH_AI_QuantizationType;
683be168c0dSopenharmony_ci+
684be168c0dSopenharmony_ci+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
685be168c0dSopenharmony_ci #ifdef __cplusplus
686be168c0dSopenharmony_ci }
687be168c0dSopenharmony_ci #endif
688be168c0dSopenharmony_cidiff --git a/include/sdk_api/context.h b/include/sdk_api/context.h
689be168c0dSopenharmony_ciindex 5bfc9279..e12b8d6f 100644
690be168c0dSopenharmony_ci--- a/include/sdk_api/context.h
691be168c0dSopenharmony_ci+++ b/include/sdk_api/context.h
692be168c0dSopenharmony_ci@@ -174,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
693be168c0dSopenharmony_ci /// \return NPU frequency
694be168c0dSopenharmony_ci OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);
695be168c0dSopenharmony_ci 
696be168c0dSopenharmony_ci+/// \brief Obtain the all device descriptions in NNRT.
697be168c0dSopenharmony_ci+///
698be168c0dSopenharmony_ci+/// \param[out] num Number of NNRT device description.
699be168c0dSopenharmony_ci+///
700be168c0dSopenharmony_ci+/// \return NNRT device description array.
701be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
702be168c0dSopenharmony_ci+
703be168c0dSopenharmony_ci+/// \brief Obtain the specified element in NNRt device description array.
704be168c0dSopenharmony_ci+///
705be168c0dSopenharmony_ci+/// \param[in] descs NNRT device description array.
706be168c0dSopenharmony_ci+/// \param[in] index Element index.
707be168c0dSopenharmony_ci+///
708be168c0dSopenharmony_ci+/// \return NNRT device description.
709be168c0dSopenharmony_ci+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
710be168c0dSopenharmony_ci+
711be168c0dSopenharmony_ci+/// \brief Destroy the NNRT device descriptions returned by OH_AI_NNRTGetAllDeviceDescs().
712be168c0dSopenharmony_ci+///
713be168c0dSopenharmony_ci+/// \param[in] desc NNRT device description array.
714be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
715be168c0dSopenharmony_ci+
716be168c0dSopenharmony_ci+/// \brief Obtain the device id in NNRT device description.
717be168c0dSopenharmony_ci+///
718be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
719be168c0dSopenharmony_ci+///
720be168c0dSopenharmony_ci+/// \return NNRT device id.
721be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
722be168c0dSopenharmony_ci+
723be168c0dSopenharmony_ci+/// \brief Obtain the device name in NNRT device description.
724be168c0dSopenharmony_ci+///
725be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
726be168c0dSopenharmony_ci+///
727be168c0dSopenharmony_ci+/// \return NNRT device name.
728be168c0dSopenharmony_ci+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
729be168c0dSopenharmony_ci+
730be168c0dSopenharmony_ci+/// \brief Obtain the device type in NNRT device description.
731be168c0dSopenharmony_ci+///
732be168c0dSopenharmony_ci+/// \param[in] desc pointer to the NNRT device description instance.
733be168c0dSopenharmony_ci+///
734be168c0dSopenharmony_ci+/// \return NNRT device type.
735be168c0dSopenharmony_ci+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
736be168c0dSopenharmony_ci+
737be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by exactly matching the specific device name.
738be168c0dSopenharmony_ci+///
739be168c0dSopenharmony_ci+/// \param[in] name NNRt device name.
740be168c0dSopenharmony_ci+///
741be168c0dSopenharmony_ci+/// \return Device info object handle.
742be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
743be168c0dSopenharmony_ci+
744be168c0dSopenharmony_ci+/// \brief Create the NNRT device info by finding the first device with the specific device type.
745be168c0dSopenharmony_ci+///
746be168c0dSopenharmony_ci+/// \param[in] name NNRt device type.
747be168c0dSopenharmony_ci+///
748be168c0dSopenharmony_ci+/// \return Device info object handle.
749be168c0dSopenharmony_ci+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
750be168c0dSopenharmony_ci+
751be168c0dSopenharmony_ci+/// \brief Set the NNRT device id, Only valid for NNRT.
752be168c0dSopenharmony_ci+///
753be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
754be168c0dSopenharmony_ci+/// \param[in] device_id NNRT device id.
755be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
756be168c0dSopenharmony_ci+
757be168c0dSopenharmony_ci+/// \brief Obtain the NNRT device id, Only valid for NNRT.
758be168c0dSopenharmony_ci+///
759be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
760be168c0dSopenharmony_ci+///
761be168c0dSopenharmony_ci+/// \return NNRT device id.
762be168c0dSopenharmony_ci+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
763be168c0dSopenharmony_ci+
764be168c0dSopenharmony_ci+/// \brief Set the NNRT performance mode, Only valid for NNRT.
765be168c0dSopenharmony_ci+///
766be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
767be168c0dSopenharmony_ci+/// \param[in] device_id NNRT performance mode.
768be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
769be168c0dSopenharmony_ci+
770be168c0dSopenharmony_ci+/// \brief Obtain the NNRT performance mode, Only valid for NNRT.
771be168c0dSopenharmony_ci+///
772be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
773be168c0dSopenharmony_ci+///
774be168c0dSopenharmony_ci+/// \return NNRT performance mode.
775be168c0dSopenharmony_ci+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
776be168c0dSopenharmony_ci+
777be168c0dSopenharmony_ci+/// \brief Set the NNRT priority, Only valid for NNRT.
778be168c0dSopenharmony_ci+///
779be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
780be168c0dSopenharmony_ci+/// \param[in] device_id NNRT priority.
781be168c0dSopenharmony_ci+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
782be168c0dSopenharmony_ci+
783be168c0dSopenharmony_ci+/// \brief Obtain the NNRT priority, Only valid for NNRT.
784be168c0dSopenharmony_ci+///
785be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
786be168c0dSopenharmony_ci+///
787be168c0dSopenharmony_ci+/// \return NNRT priority.
788be168c0dSopenharmony_ci+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info);
789be168c0dSopenharmony_ci+
790be168c0dSopenharmony_ci+/// \brief Add extension of key/value format to device info, Only valid for NNRT.
791be168c0dSopenharmony_ci+///
792be168c0dSopenharmony_ci+/// \param[in] device_info Device info object handle.
793be168c0dSopenharmony_ci+/// \param[in] name The content of key as a C string.
794be168c0dSopenharmony_ci+/// \param[in] value The pointer to the value, which is a byte array.
795be168c0dSopenharmony_ci+/// \param[in] value_size The size of the value, which is a byte array.
796be168c0dSopenharmony_ci+///
797be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
798be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size);
799be168c0dSopenharmony_ci #ifdef __cplusplus
800be168c0dSopenharmony_ci }
801be168c0dSopenharmony_ci #endif
802be168c0dSopenharmony_cidiff --git a/include/sdk_api/tensor.h b/include/sdk_api/tensor.h
803be168c0dSopenharmony_ciindex f6ba02cd..3dad04ac 100644
804be168c0dSopenharmony_ci--- a/include/sdk_api/tensor.h
805be168c0dSopenharmony_ci+++ b/include/sdk_api/tensor.h
806be168c0dSopenharmony_ci@@ -17,6 +17,7 @@
807be168c0dSopenharmony_ci #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H
808be168c0dSopenharmony_ci 
809be168c0dSopenharmony_ci #include <stddef.h>
810be168c0dSopenharmony_ci+#include "mindspore/status.h"
811be168c0dSopenharmony_ci #include "mindspore/types.h"
812be168c0dSopenharmony_ci #include "mindspore/data_type.h"
813be168c0dSopenharmony_ci #include "mindspore/format.h"
814be168c0dSopenharmony_ci@@ -140,6 +141,18 @@ OH_AI_API int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor);
815be168c0dSopenharmony_ci /// \return The data size of the tensor.
816be168c0dSopenharmony_ci OH_AI_API size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor);
817be168c0dSopenharmony_ci 
818be168c0dSopenharmony_ci+/// \brief Set the data for the tensor with user-allocated data buffer.
819be168c0dSopenharmony_ci+/// The main purpose of this interface is providing a way of using memory already allocated by user as the Model's
820be168c0dSopenharmony_ci+/// input, but not which allocated inside the Model object. It can reduce one copy.
821be168c0dSopenharmony_ci+/// Note: The tensor won't free the data provided by invoker. Invoker has the responsibility to free it. And this
822be168c0dSopenharmony_ci+/// free action should not be preformed before destruction of the tensor.
823be168c0dSopenharmony_ci+///
824be168c0dSopenharmony_ci+/// \param[in] tensor Tensor object handle.
825be168c0dSopenharmony_ci+/// \param[in] data A pointer to the user data buffer.
826be168c0dSopenharmony_ci+/// \param[in] data the byte size of the user data buffer.
827be168c0dSopenharmony_ci+///
828be168c0dSopenharmony_ci+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
829be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
830be168c0dSopenharmony_ci #ifdef __cplusplus
831be168c0dSopenharmony_ci }
832be168c0dSopenharmony_ci #endif
833be168c0dSopenharmony_cidiff --git a/include/sdk_api/types.h b/include/sdk_api/types.h
834be168c0dSopenharmony_ciindex a39c6daa..d38660b0 100644
835be168c0dSopenharmony_ci--- a/include/sdk_api/types.h
836be168c0dSopenharmony_ci+++ b/include/sdk_api/types.h
837be168c0dSopenharmony_ci@@ -40,10 +40,46 @@ typedef enum OH_AI_DeviceType {
838be168c0dSopenharmony_ci   OH_AI_DEVICETYPE_KIRIN_NPU,
839be168c0dSopenharmony_ci   // add new type here
840be168c0dSopenharmony_ci   // ohos-only device range: [60, 80)
841be168c0dSopenharmony_ci-  OH_AI_DeviceType_NNRT = 60,
842be168c0dSopenharmony_ci+  OH_AI_DEVICETYPE_NNRT = 60,
843be168c0dSopenharmony_ci   OH_AI_DEVICETYPE_INVALID = 100,
844be168c0dSopenharmony_ci } OH_AI_DeviceType;
845be168c0dSopenharmony_ci 
846be168c0dSopenharmony_ci+typedef enum OH_AI_NNRTDeviceType {
847be168c0dSopenharmony_ci+  /** Devices that are not CPU, GPU, or dedicated accelerator */
848be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_OTHERS = 0,
849be168c0dSopenharmony_ci+  /** CPU device */
850be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_CPU = 1,
851be168c0dSopenharmony_ci+  /** GPU device */
852be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_GPU = 2,
853be168c0dSopenharmony_ci+  /** Dedicated hardware accelerator */
854be168c0dSopenharmony_ci+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
855be168c0dSopenharmony_ci+} OH_AI_NNRTDeviceType;
856be168c0dSopenharmony_ci+
857be168c0dSopenharmony_ci+typedef enum OH_AI_PerformanceMode {
858be168c0dSopenharmony_ci+  /** No performance mode preference */
859be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_NONE = 0,
860be168c0dSopenharmony_ci+  /** Low power consumption mode*/
861be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_LOW = 1,
862be168c0dSopenharmony_ci+  /** Medium performance mode */
863be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_MEDIUM = 2,
864be168c0dSopenharmony_ci+  /** High performance mode */
865be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_HIGH = 3,
866be168c0dSopenharmony_ci+  /** Ultimate performance mode */
867be168c0dSopenharmony_ci+  OH_AI_PERFORMANCE_EXTREME = 4
868be168c0dSopenharmony_ci+} OH_AI_PerformanceMode;
869be168c0dSopenharmony_ci+
870be168c0dSopenharmony_ci+typedef enum OH_AI_Priority {
871be168c0dSopenharmony_ci+  /** No priority preference */
872be168c0dSopenharmony_ci+  OH_AI_PRIORITY_NONE = 0,
873be168c0dSopenharmony_ci+  /** Low priority */
874be168c0dSopenharmony_ci+  OH_AI_PRIORITY_LOW = 1,
875be168c0dSopenharmony_ci+  /** Medium priority */
876be168c0dSopenharmony_ci+  OH_AI_PRIORITY_MEDIUM = 2,
877be168c0dSopenharmony_ci+  /** High priority */
878be168c0dSopenharmony_ci+  OH_AI_PRIORITY_HIGH = 3
879be168c0dSopenharmony_ci+} OH_AI_Priority;
880be168c0dSopenharmony_ci+
881be168c0dSopenharmony_ci+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
882be168c0dSopenharmony_ci #ifdef __cplusplus
883be168c0dSopenharmony_ci }
884be168c0dSopenharmony_ci #endif
885be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
886be168c0dSopenharmony_ciindex 7bbc3782..103e53b7 100644
887be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
888be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
889be168c0dSopenharmony_ci@@ -498,6 +498,9 @@ infer_shape_sources = [
890be168c0dSopenharmony_ci   "infer/crop_infer.c",
891be168c0dSopenharmony_ci   "infer/cumsum_infer.c",
892be168c0dSopenharmony_ci   "infer/custom_gru_infer.c",
893be168c0dSopenharmony_ci+  "infer/custom_masked_fill_infer.c",
894be168c0dSopenharmony_ci+  "infer/custom_is_inf_infer.c",
895be168c0dSopenharmony_ci+  "infer/custom_tensor_scatter_max_infer.c",
896be168c0dSopenharmony_ci   "infer/decoder_layer_infer.c",
897be168c0dSopenharmony_ci   "infer/deconv2d_infer.c",
898be168c0dSopenharmony_ci   "infer/depth_to_space_infer.c",
899be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
900be168c0dSopenharmony_ciindex c1685a65..6fef44fd 100644
901be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
902be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
903be168c0dSopenharmony_ci@@ -238,7 +238,7 @@ endif()
904be168c0dSopenharmony_ci if(PLATFORM_ARM)
905be168c0dSopenharmony_ci     set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c)
906be168c0dSopenharmony_ci     set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C
907be168c0dSopenharmony_ci-        COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math")
908be168c0dSopenharmony_ci+        COMPILE_FLAGS "${CMAKE_C_FLAGS} -w -fno-fast-math")
909be168c0dSopenharmony_ci endif()
910be168c0dSopenharmony_ci 
911be168c0dSopenharmony_ci add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC})
912be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
913be168c0dSopenharmony_cinew file mode 100644
914be168c0dSopenharmony_ciindex 00000000..14bd1d76
915be168c0dSopenharmony_ci--- /dev/null
916be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
917be168c0dSopenharmony_ci@@ -0,0 +1,66 @@
918be168c0dSopenharmony_ci+/**
919be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
920be168c0dSopenharmony_ci+*
921be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
922be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
923be168c0dSopenharmony_ci+* You may obtain a copy of the License at
924be168c0dSopenharmony_ci+*
925be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
926be168c0dSopenharmony_ci+*
927be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
928be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
929be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
930be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
931be168c0dSopenharmony_ci+* limitations under the License.
932be168c0dSopenharmony_ci+*/
933be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
934be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
935be168c0dSopenharmony_ci+
936be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h"
937be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_avx_instructions.h"
938be168c0dSopenharmony_ci+
939be168c0dSopenharmony_ci+#ifdef __cplusplus
940be168c0dSopenharmony_ci+extern "C" {
941be168c0dSopenharmony_ci+#endif
942be168c0dSopenharmony_ci+#pragma GCC push_options
943be168c0dSopenharmony_ci+#pragma GCC target("avx", "avx2")
944be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX_INSTRUCTION
945be168c0dSopenharmony_ci+#define BLOCK_NUM 8
946be168c0dSopenharmony_ci+#define MS_SIMD_AVX
947be168c0dSopenharmony_ci+
948be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32AVX(int index, const float *update, int size, float *output) {
949be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
950be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
951be168c0dSopenharmony_ci+  }
952be168c0dSopenharmony_ci+  return index;
953be168c0dSopenharmony_ci+}
954be168c0dSopenharmony_ci+
955be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32AVX(int index, const int *update, int size, int *output) {
956be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
957be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
958be168c0dSopenharmony_ci+  }
959be168c0dSopenharmony_ci+  return index;
960be168c0dSopenharmony_ci+}
961be168c0dSopenharmony_ci+
962be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32AVX(int index, const float *update, int size, float *output) {
963be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
964be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
965be168c0dSopenharmony_ci+  }
966be168c0dSopenharmony_ci+  return index;
967be168c0dSopenharmony_ci+}
968be168c0dSopenharmony_ci+
969be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32AVX(int index, const int *update, int size, int *output) {
970be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
971be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
972be168c0dSopenharmony_ci+  }
973be168c0dSopenharmony_ci+  return index;
974be168c0dSopenharmony_ci+}
975be168c0dSopenharmony_ci+
976be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION
977be168c0dSopenharmony_ci+#undef BLOCK_NUM
978be168c0dSopenharmony_ci+#pragma GCC pop_options
979be168c0dSopenharmony_ci+#undef MS_SIMD_AVX
980be168c0dSopenharmony_ci+#ifdef __cplusplus
981be168c0dSopenharmony_ci+}
982be168c0dSopenharmony_ci+#endif
983be168c0dSopenharmony_ci+#endif  // NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
984be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
985be168c0dSopenharmony_cinew file mode 100644
986be168c0dSopenharmony_ciindex 00000000..abf024c5
987be168c0dSopenharmony_ci--- /dev/null
988be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
989be168c0dSopenharmony_ci@@ -0,0 +1,66 @@
990be168c0dSopenharmony_ci+/**
991be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
992be168c0dSopenharmony_ci+*
993be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
994be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
995be168c0dSopenharmony_ci+* You may obtain a copy of the License at
996be168c0dSopenharmony_ci+*
997be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
998be168c0dSopenharmony_ci+*
999be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
1000be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
1001be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1002be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
1003be168c0dSopenharmony_ci+* limitations under the License.
1004be168c0dSopenharmony_ci+*/
1005be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1006be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1007be168c0dSopenharmony_ci+
1008be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h"
1009be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
1010be168c0dSopenharmony_ci+
1011be168c0dSopenharmony_ci+#ifdef __cplusplus
1012be168c0dSopenharmony_ci+extern "C" {
1013be168c0dSopenharmony_ci+#endif
1014be168c0dSopenharmony_ci+#pragma GCC push_options
1015be168c0dSopenharmony_ci+#pragma GCC target("avx512f")
1016be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX512_INSTRUCTION
1017be168c0dSopenharmony_ci+#define BLOCK_NUM 16
1018be168c0dSopenharmony_ci+#define MS_SIMD_AVX512
1019be168c0dSopenharmony_ci+
1020be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32AVX512(int index, const float *update, int size, float *output) {
1021be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1022be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1023be168c0dSopenharmony_ci+  }
1024be168c0dSopenharmony_ci+  return index;
1025be168c0dSopenharmony_ci+}
1026be168c0dSopenharmony_ci+
1027be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32AVX512(int index, const int *update, int size, int *output) {
1028be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1029be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1030be168c0dSopenharmony_ci+  }
1031be168c0dSopenharmony_ci+  return index;
1032be168c0dSopenharmony_ci+}
1033be168c0dSopenharmony_ci+
1034be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32AVX512(int index, const float *update, int size, float *output) {
1035be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1036be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1037be168c0dSopenharmony_ci+  }
1038be168c0dSopenharmony_ci+  return index;
1039be168c0dSopenharmony_ci+}
1040be168c0dSopenharmony_ci+
1041be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32AVX512(int index, const int *update, int size, int *output) {
1042be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1043be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1044be168c0dSopenharmony_ci+  }
1045be168c0dSopenharmony_ci+  return index;
1046be168c0dSopenharmony_ci+}
1047be168c0dSopenharmony_ci+
1048be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION
1049be168c0dSopenharmony_ci+#undef BLOCK_NUM
1050be168c0dSopenharmony_ci+#pragma GCC pop_options
1051be168c0dSopenharmony_ci+#undef MS_SIMD_AVX512
1052be168c0dSopenharmony_ci+#ifdef __cplusplus
1053be168c0dSopenharmony_ci+}
1054be168c0dSopenharmony_ci+#endif
1055be168c0dSopenharmony_ci+#endif  // NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1056be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1057be168c0dSopenharmony_ciindex bca71f55..e496bb4b 100644
1058be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1059be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1060be168c0dSopenharmony_ci@@ -77,3 +77,31 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets,
1061be168c0dSopenharmony_ci   }
1062be168c0dSopenharmony_ci   return NNACL_OK;
1063be168c0dSopenharmony_ci }
1064be168c0dSopenharmony_ci+
1065be168c0dSopenharmony_ci+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1066be168c0dSopenharmony_ci+                 int task_id) {
1067be168c0dSopenharmony_ci+  if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) {
1068be168c0dSopenharmony_ci+    return NNACL_NULL_PTR;
1069be168c0dSopenharmony_ci+  }
1070be168c0dSopenharmony_ci+  if (param->op_parameter.thread_num_ == 0) {
1071be168c0dSopenharmony_ci+    return NNACL_ERR;
1072be168c0dSopenharmony_ci+  }
1073be168c0dSopenharmony_ci+  int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
1074be168c0dSopenharmony_ci+  int begin = unit_per_thread * task_id;
1075be168c0dSopenharmony_ci+  int end = MSMIN(begin + unit_per_thread, param->num_unit);
1076be168c0dSopenharmony_ci+  if (type == 0) {
1077be168c0dSopenharmony_ci+    float *update_fp32 = (float *)update;
1078be168c0dSopenharmony_ci+    float *output_fp32 = (float *)output;
1079be168c0dSopenharmony_ci+    for (int i = begin; i < end; i++) {
1080be168c0dSopenharmony_ci+      const float *update_data = update_fp32 + i * param->unit_size;
1081be168c0dSopenharmony_ci+      float *output_data = output_fp32 + output_unit_offsets[i];
1082be168c0dSopenharmony_ci+      int j = 0;
1083be168c0dSopenharmony_ci+
1084be168c0dSopenharmony_ci+      SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data);
1085be168c0dSopenharmony_ci+      for (; j < param->unit_size; j++) {
1086be168c0dSopenharmony_ci+        output_data[j] = fmaxf(update_data[j], output_data[j]);
1087be168c0dSopenharmony_ci+      }
1088be168c0dSopenharmony_ci+    }
1089be168c0dSopenharmony_ci+  }
1090be168c0dSopenharmony_ci+  return NNACL_OK;
1091be168c0dSopenharmony_ci+}
1092be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1093be168c0dSopenharmony_ciindex 3af55335..36657cd9 100644
1094be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1095be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1096be168c0dSopenharmony_ci@@ -27,6 +27,9 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets,
1097be168c0dSopenharmony_ci 
1098be168c0dSopenharmony_ci int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1099be168c0dSopenharmony_ci                  int task_id);
1100be168c0dSopenharmony_ci+
1101be168c0dSopenharmony_ci+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1102be168c0dSopenharmony_ci+                 int task_id);
1103be168c0dSopenharmony_ci #ifdef __cplusplus
1104be168c0dSopenharmony_ci }
1105be168c0dSopenharmony_ci #endif
1106be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1107be168c0dSopenharmony_ciindex c72d9cc2..46bb20ce 100644
1108be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1109be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1110be168c0dSopenharmony_ci@@ -38,6 +38,20 @@ static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *upda
1111be168c0dSopenharmony_ci   return index;
1112be168c0dSopenharmony_ci }
1113be168c0dSopenharmony_ci 
1114be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) {
1115be168c0dSopenharmony_ci+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1116be168c0dSopenharmony_ci+SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1117be168c0dSopenharmony_ci+}
1118be168c0dSopenharmony_ci+return index;
1119be168c0dSopenharmony_ci+}
1120be168c0dSopenharmony_ci+
1121be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) {
1122be168c0dSopenharmony_ci+for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1123be168c0dSopenharmony_ci+SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1124be168c0dSopenharmony_ci+}
1125be168c0dSopenharmony_ci+return index;
1126be168c0dSopenharmony_ci+}
1127be168c0dSopenharmony_ci+
1128be168c0dSopenharmony_ci @SIMD_INSTRUCTION_END@
1129be168c0dSopenharmony_ci #ifdef __cplusplus
1130be168c0dSopenharmony_ci }
1131be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
1132be168c0dSopenharmony_cinew file mode 100644
1133be168c0dSopenharmony_ciindex 00000000..e1eae394
1134be168c0dSopenharmony_ci--- /dev/null
1135be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
1136be168c0dSopenharmony_ci@@ -0,0 +1,26 @@
1137be168c0dSopenharmony_ci+/**
1138be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1139be168c0dSopenharmony_ci+ *
1140be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1141be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1142be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1143be168c0dSopenharmony_ci+ *
1144be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1145be168c0dSopenharmony_ci+ *
1146be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1147be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1148be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1149be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1150be168c0dSopenharmony_ci+ * limitations under the License.
1151be168c0dSopenharmony_ci+ */
1152be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1153be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1154be168c0dSopenharmony_ci+
1155be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
1156be168c0dSopenharmony_ci+
1157be168c0dSopenharmony_ci+typedef struct CustomIsInfParameter {
1158be168c0dSopenharmony_ci+  // Primitive parameter
1159be168c0dSopenharmony_ci+  OpParameter op_parameter_;
1160be168c0dSopenharmony_ci+} CustomIsInfParameter;
1161be168c0dSopenharmony_ci+
1162be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1163be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
1164be168c0dSopenharmony_cinew file mode 100644
1165be168c0dSopenharmony_ciindex 00000000..047d3d3f
1166be168c0dSopenharmony_ci--- /dev/null
1167be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
1168be168c0dSopenharmony_ci@@ -0,0 +1,26 @@
1169be168c0dSopenharmony_ci+/**
1170be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1171be168c0dSopenharmony_ci+ *
1172be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1173be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1174be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1175be168c0dSopenharmony_ci+ *
1176be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1177be168c0dSopenharmony_ci+ *
1178be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1179be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1180be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1181be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1182be168c0dSopenharmony_ci+ * limitations under the License.
1183be168c0dSopenharmony_ci+ */
1184be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1185be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1186be168c0dSopenharmony_ci+
1187be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
1188be168c0dSopenharmony_ci+
1189be168c0dSopenharmony_ci+typedef struct CustomMaskedFillParameter {
1190be168c0dSopenharmony_ci+  // Primitive parameter
1191be168c0dSopenharmony_ci+  OpParameter op_parameter_;
1192be168c0dSopenharmony_ci+} CustomMaskedFillParameter;
1193be168c0dSopenharmony_ci+
1194be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1195be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
1196be168c0dSopenharmony_cinew file mode 100644
1197be168c0dSopenharmony_ciindex 00000000..ba6940db
1198be168c0dSopenharmony_ci--- /dev/null
1199be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
1200be168c0dSopenharmony_ci@@ -0,0 +1,26 @@
1201be168c0dSopenharmony_ci+/**
1202be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1203be168c0dSopenharmony_ci+ *
1204be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1205be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1206be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1207be168c0dSopenharmony_ci+ *
1208be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1209be168c0dSopenharmony_ci+ *
1210be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1211be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1212be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1214be168c0dSopenharmony_ci+ * limitations under the License.
1215be168c0dSopenharmony_ci+ */
1216be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1217be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1218be168c0dSopenharmony_ci+
1219be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
1220be168c0dSopenharmony_ci+
1221be168c0dSopenharmony_ci+typedef struct CustomTensorScatterMaxParameter {
1222be168c0dSopenharmony_ci+  // Primitive parameter
1223be168c0dSopenharmony_ci+  OpParameter op_parameter_;
1224be168c0dSopenharmony_ci+} CustomTensorScatterMaxParameter;
1225be168c0dSopenharmony_ci+
1226be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1227be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
1228be168c0dSopenharmony_cinew file mode 100644
1229be168c0dSopenharmony_ciindex 00000000..fc87d157
1230be168c0dSopenharmony_ci--- /dev/null
1231be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
1232be168c0dSopenharmony_ci@@ -0,0 +1,38 @@
1233be168c0dSopenharmony_ci+/**
1234be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1235be168c0dSopenharmony_ci+ *
1236be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1237be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1238be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1239be168c0dSopenharmony_ci+ *
1240be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1241be168c0dSopenharmony_ci+ *
1242be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1243be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1244be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1245be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1246be168c0dSopenharmony_ci+ * limitations under the License.
1247be168c0dSopenharmony_ci+ */
1248be168c0dSopenharmony_ci+
1249be168c0dSopenharmony_ci+#include "nnacl/infer/custom_is_inf_infer.h"
1250be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h"
1251be168c0dSopenharmony_ci+
1252be168c0dSopenharmony_ci+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1253be168c0dSopenharmony_ci+                          OpParameter *parameter) {
1254be168c0dSopenharmony_ci+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM);
1255be168c0dSopenharmony_ci+  if (check_ret != NNACL_OK) {
1256be168c0dSopenharmony_ci+    return check_ret;
1257be168c0dSopenharmony_ci+  }
1258be168c0dSopenharmony_ci+
1259be168c0dSopenharmony_ci+  const TensorC *input = inputs[0];
1260be168c0dSopenharmony_ci+  TensorC *output = outputs[0];
1261be168c0dSopenharmony_ci+  output->data_type_ = kNumberTypeBool;
1262be168c0dSopenharmony_ci+  output->format_ = input->format_;
1263be168c0dSopenharmony_ci+  if (!InferFlag(inputs, inputs_size)) {
1264be168c0dSopenharmony_ci+    return NNACL_INFER_INVALID;
1265be168c0dSopenharmony_ci+  }
1266be168c0dSopenharmony_ci+  SetShapeTensor(output, input);
1267be168c0dSopenharmony_ci+  return NNACL_OK;
1268be168c0dSopenharmony_ci+}
1269be168c0dSopenharmony_ci+
1270be168c0dSopenharmony_ci+REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape)
1271be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
1272be168c0dSopenharmony_cinew file mode 100644
1273be168c0dSopenharmony_ciindex 00000000..d1b4b33d
1274be168c0dSopenharmony_ci--- /dev/null
1275be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
1276be168c0dSopenharmony_ci@@ -0,0 +1,31 @@
1277be168c0dSopenharmony_ci+/**
1278be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1279be168c0dSopenharmony_ci+ *
1280be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1281be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1282be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1283be168c0dSopenharmony_ci+ *
1284be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1285be168c0dSopenharmony_ci+ *
1286be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1287be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1288be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1289be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1290be168c0dSopenharmony_ci+ * limitations under the License.
1291be168c0dSopenharmony_ci+ */
1292be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1293be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1294be168c0dSopenharmony_ci+
1295be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h"
1296be168c0dSopenharmony_ci+
1297be168c0dSopenharmony_ci+#ifdef __cplusplus
1298be168c0dSopenharmony_ci+extern "C" {
1299be168c0dSopenharmony_ci+#endif
1300be168c0dSopenharmony_ci+
1301be168c0dSopenharmony_ci+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1302be168c0dSopenharmony_ci+                          OpParameter *parameter);
1303be168c0dSopenharmony_ci+
1304be168c0dSopenharmony_ci+#ifdef __cplusplus
1305be168c0dSopenharmony_ci+}
1306be168c0dSopenharmony_ci+#endif
1307be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1308be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
1309be168c0dSopenharmony_cinew file mode 100644
1310be168c0dSopenharmony_ciindex 00000000..957a4d4f
1311be168c0dSopenharmony_ci--- /dev/null
1312be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
1313be168c0dSopenharmony_ci@@ -0,0 +1,37 @@
1314be168c0dSopenharmony_ci+/**
1315be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1316be168c0dSopenharmony_ci+ *
1317be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1318be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1319be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1320be168c0dSopenharmony_ci+ *
1321be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1322be168c0dSopenharmony_ci+ *
1323be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1324be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1325be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1326be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1327be168c0dSopenharmony_ci+ * limitations under the License.
1328be168c0dSopenharmony_ci+ */
1329be168c0dSopenharmony_ci+
1330be168c0dSopenharmony_ci+#include "nnacl/infer/custom_masked_fill_infer.h"
1331be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h"
1332be168c0dSopenharmony_ci+
1333be168c0dSopenharmony_ci+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1334be168c0dSopenharmony_ci+                               OpParameter *parameter) {
1335be168c0dSopenharmony_ci+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM);
1336be168c0dSopenharmony_ci+  if (check_ret != NNACL_OK) {
1337be168c0dSopenharmony_ci+    return check_ret;
1338be168c0dSopenharmony_ci+  }
1339be168c0dSopenharmony_ci+
1340be168c0dSopenharmony_ci+  const TensorC *input = inputs[0];
1341be168c0dSopenharmony_ci+  TensorC *output = outputs[0];
1342be168c0dSopenharmony_ci+  SetDataTypeFormat(output, input);
1343be168c0dSopenharmony_ci+  if (!InferFlag(inputs, inputs_size)) {
1344be168c0dSopenharmony_ci+    return NNACL_INFER_INVALID;
1345be168c0dSopenharmony_ci+  }
1346be168c0dSopenharmony_ci+  SetShapeTensor(output, input);
1347be168c0dSopenharmony_ci+  return NNACL_OK;
1348be168c0dSopenharmony_ci+}
1349be168c0dSopenharmony_ci+
1350be168c0dSopenharmony_ci+REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape)
1351be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
1352be168c0dSopenharmony_cinew file mode 100644
1353be168c0dSopenharmony_ciindex 00000000..a8adbae2
1354be168c0dSopenharmony_ci--- /dev/null
1355be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
1356be168c0dSopenharmony_ci@@ -0,0 +1,31 @@
1357be168c0dSopenharmony_ci+/**
1358be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1359be168c0dSopenharmony_ci+ *
1360be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1361be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1362be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1363be168c0dSopenharmony_ci+ *
1364be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1365be168c0dSopenharmony_ci+ *
1366be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1367be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1368be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1369be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1370be168c0dSopenharmony_ci+ * limitations under the License.
1371be168c0dSopenharmony_ci+ */
1372be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1373be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1374be168c0dSopenharmony_ci+
1375be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h"
1376be168c0dSopenharmony_ci+
1377be168c0dSopenharmony_ci+#ifdef __cplusplus
1378be168c0dSopenharmony_ci+extern "C" {
1379be168c0dSopenharmony_ci+#endif
1380be168c0dSopenharmony_ci+
1381be168c0dSopenharmony_ci+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1382be168c0dSopenharmony_ci+                               OpParameter *parameter);
1383be168c0dSopenharmony_ci+
1384be168c0dSopenharmony_ci+#ifdef __cplusplus
1385be168c0dSopenharmony_ci+}
1386be168c0dSopenharmony_ci+#endif
1387be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1388be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
1389be168c0dSopenharmony_cinew file mode 100644
1390be168c0dSopenharmony_ciindex 00000000..be6716ba
1391be168c0dSopenharmony_ci--- /dev/null
1392be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
1393be168c0dSopenharmony_ci@@ -0,0 +1,37 @@
1394be168c0dSopenharmony_ci+/**
1395be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1396be168c0dSopenharmony_ci+ *
1397be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1398be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1399be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1400be168c0dSopenharmony_ci+ *
1401be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1402be168c0dSopenharmony_ci+ *
1403be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1404be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1405be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1406be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1407be168c0dSopenharmony_ci+ * limitations under the License.
1408be168c0dSopenharmony_ci+ */
1409be168c0dSopenharmony_ci+
1410be168c0dSopenharmony_ci+#include "nnacl/infer/custom_tensor_scatter_max_infer.h"
1411be168c0dSopenharmony_ci+#include "nnacl/infer/infer_register.h"
1412be168c0dSopenharmony_ci+
1413be168c0dSopenharmony_ci+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
1414be168c0dSopenharmony_ci+                                     size_t outputs_size, OpParameter *parameter) {
1415be168c0dSopenharmony_ci+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM);
1416be168c0dSopenharmony_ci+  if (check_ret != NNACL_OK) {
1417be168c0dSopenharmony_ci+    return check_ret;
1418be168c0dSopenharmony_ci+  }
1419be168c0dSopenharmony_ci+
1420be168c0dSopenharmony_ci+  const TensorC *input = inputs[0];
1421be168c0dSopenharmony_ci+  TensorC *output = outputs[0];
1422be168c0dSopenharmony_ci+  SetDataTypeFormat(output, input);
1423be168c0dSopenharmony_ci+  if (!InferFlag(inputs, inputs_size)) {
1424be168c0dSopenharmony_ci+    return NNACL_INFER_INVALID;
1425be168c0dSopenharmony_ci+  }
1426be168c0dSopenharmony_ci+  SetShapeTensor(output, input);
1427be168c0dSopenharmony_ci+  return NNACL_OK;
1428be168c0dSopenharmony_ci+}
1429be168c0dSopenharmony_ci+
1430be168c0dSopenharmony_ci+REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape)
1431be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
1432be168c0dSopenharmony_cinew file mode 100644
1433be168c0dSopenharmony_ciindex 00000000..641aa483
1434be168c0dSopenharmony_ci--- /dev/null
1435be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
1436be168c0dSopenharmony_ci@@ -0,0 +1,31 @@
1437be168c0dSopenharmony_ci+/**
1438be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
1439be168c0dSopenharmony_ci+ *
1440be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
1441be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
1442be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
1443be168c0dSopenharmony_ci+ *
1444be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
1445be168c0dSopenharmony_ci+ *
1446be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
1447be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
1448be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1449be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
1450be168c0dSopenharmony_ci+ * limitations under the License.
1451be168c0dSopenharmony_ci+ */
1452be168c0dSopenharmony_ci+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1453be168c0dSopenharmony_ci+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1454be168c0dSopenharmony_ci+
1455be168c0dSopenharmony_ci+#include "nnacl/infer/common_infer.h"
1456be168c0dSopenharmony_ci+
1457be168c0dSopenharmony_ci+#ifdef __cplusplus
1458be168c0dSopenharmony_ci+extern "C" {
1459be168c0dSopenharmony_ci+#endif
1460be168c0dSopenharmony_ci+
1461be168c0dSopenharmony_ci+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
1462be168c0dSopenharmony_ci+                                     size_t outputs_size, OpParameter *parameter);
1463be168c0dSopenharmony_ci+
1464be168c0dSopenharmony_ci+#ifdef __cplusplus
1465be168c0dSopenharmony_ci+}
1466be168c0dSopenharmony_ci+#endif
1467be168c0dSopenharmony_ci+#endif  // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1468be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
1469be168c0dSopenharmony_cinew file mode 100644
1470be168c0dSopenharmony_ciindex 00000000..d7c34768
1471be168c0dSopenharmony_ci--- /dev/null
1472be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
1473be168c0dSopenharmony_ci@@ -0,0 +1,65 @@
1474be168c0dSopenharmony_ci+/**
1475be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
1476be168c0dSopenharmony_ci+*
1477be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
1478be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
1479be168c0dSopenharmony_ci+* You may obtain a copy of the License at
1480be168c0dSopenharmony_ci+*
1481be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
1482be168c0dSopenharmony_ci+*
1483be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
1484be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
1485be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1486be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
1487be168c0dSopenharmony_ci+* limitations under the License.
1488be168c0dSopenharmony_ci+*/
1489be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
1490be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
1491be168c0dSopenharmony_ci+
1492be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h"
1493be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_neon_instructions.h"
1494be168c0dSopenharmony_ci+
1495be168c0dSopenharmony_ci+#ifdef __cplusplus
1496be168c0dSopenharmony_ci+extern "C" {
1497be168c0dSopenharmony_ci+#endif
1498be168c0dSopenharmony_ci+
1499be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION
1500be168c0dSopenharmony_ci+#define BLOCK_NUM 4
1501be168c0dSopenharmony_ci+#define MS_SIMD_NEON
1502be168c0dSopenharmony_ci+
1503be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32NEON(int index, const float *update, int size, float *output) {
1504be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1505be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1506be168c0dSopenharmony_ci+  }
1507be168c0dSopenharmony_ci+  return index;
1508be168c0dSopenharmony_ci+}
1509be168c0dSopenharmony_ci+
1510be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32NEON(int index, const int *update, int size, int *output) {
1511be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1512be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1513be168c0dSopenharmony_ci+  }
1514be168c0dSopenharmony_ci+  return index;
1515be168c0dSopenharmony_ci+}
1516be168c0dSopenharmony_ci+
1517be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32NEON(int index, const float *update, int size, float *output) {
1518be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1519be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1520be168c0dSopenharmony_ci+  }
1521be168c0dSopenharmony_ci+  return index;
1522be168c0dSopenharmony_ci+}
1523be168c0dSopenharmony_ci+
1524be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32NEON(int index, const int *update, int size, int *output) {
1525be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1526be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1527be168c0dSopenharmony_ci+  }
1528be168c0dSopenharmony_ci+  return index;
1529be168c0dSopenharmony_ci+}
1530be168c0dSopenharmony_ci+
1531be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION
1532be168c0dSopenharmony_ci+#undef BLOCK_NUM
1533be168c0dSopenharmony_ci+
1534be168c0dSopenharmony_ci+#undef MS_SIMD_NEON
1535be168c0dSopenharmony_ci+#ifdef __cplusplus
1536be168c0dSopenharmony_ci+}
1537be168c0dSopenharmony_ci+#endif
1538be168c0dSopenharmony_ci+#endif  // NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
1539be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1540be168c0dSopenharmony_ciindex 955a70a5..895f7e3d 100644
1541be168c0dSopenharmony_ci--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1542be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1543be168c0dSopenharmony_ci@@ -558,6 +558,10 @@ enum PrimType {
1544be168c0dSopenharmony_ci   PrimType_Inner_CustomGru = 10010,
1545be168c0dSopenharmony_ci   PrimType_Inner_CastGatherReduceFusion = 10011,
1546be168c0dSopenharmony_ci   PrimType_Inner_ReduceConcatFusion = 10012,
1547be168c0dSopenharmony_ci+  PrimType_Inner_ThirdPartyModel = 10013,
1548be168c0dSopenharmony_ci+  PrimType_Inner_CustomMaskedFill = 10014,
1549be168c0dSopenharmony_ci+  PrimType_Inner_CustomTensorScatterMax = 10015,
1550be168c0dSopenharmony_ci+  PrimType_Inner_CustomIsInf = 10016,
1551be168c0dSopenharmony_ci   PrimType_InnerOpMax,
1552be168c0dSopenharmony_ci   PrimType_InnerOpMin = PrimType_Inner_ToFormat
1553be168c0dSopenharmony_ci };
1554be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
1555be168c0dSopenharmony_cinew file mode 100644
1556be168c0dSopenharmony_ciindex 00000000..dd9878f7
1557be168c0dSopenharmony_ci--- /dev/null
1558be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
1559be168c0dSopenharmony_ci@@ -0,0 +1,36 @@
1560be168c0dSopenharmony_ci+/**
1561be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
1562be168c0dSopenharmony_ci+*
1563be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
1564be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
1565be168c0dSopenharmony_ci+* You may obtain a copy of the License at
1566be168c0dSopenharmony_ci+*
1567be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
1568be168c0dSopenharmony_ci+*
1569be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
1570be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
1571be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1572be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
1573be168c0dSopenharmony_ci+* limitations under the License.
1574be168c0dSopenharmony_ci+*/
1575be168c0dSopenharmony_ci+#ifndef NNACL_SCATTER_ND_BINARY_SIMD_H_
1576be168c0dSopenharmony_ci+#define NNACL_SCATTER_ND_BINARY_SIMD_H_
1577be168c0dSopenharmony_ci+
1578be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h"
1579be168c0dSopenharmony_ci+#ifdef ENABLE_AVX512
1580be168c0dSopenharmony_ci+#include "nnacl/avx512/scatter_nd_binary_avx512.h"
1581be168c0dSopenharmony_ci+#endif
1582be168c0dSopenharmony_ci+
1583be168c0dSopenharmony_ci+#ifdef ENABLE_AVX
1584be168c0dSopenharmony_ci+#include "nnacl/avx/scatter_nd_binary_avx.h"
1585be168c0dSopenharmony_ci+#endif
1586be168c0dSopenharmony_ci+
1587be168c0dSopenharmony_ci+#ifdef ENABLE_SSE
1588be168c0dSopenharmony_ci+#include "nnacl/sse/scatter_nd_binary_sse.h"
1589be168c0dSopenharmony_ci+#endif
1590be168c0dSopenharmony_ci+
1591be168c0dSopenharmony_ci+#ifdef ENABLE_ARM
1592be168c0dSopenharmony_ci+#include "nnacl/neon/scatter_nd_binary_neon.h"
1593be168c0dSopenharmony_ci+#endif
1594be168c0dSopenharmony_ci+
1595be168c0dSopenharmony_ci+#endif
1596be168c0dSopenharmony_cidiff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
1597be168c0dSopenharmony_cinew file mode 100644
1598be168c0dSopenharmony_ciindex 00000000..983d2923
1599be168c0dSopenharmony_ci--- /dev/null
1600be168c0dSopenharmony_ci+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
1601be168c0dSopenharmony_ci@@ -0,0 +1,66 @@
1602be168c0dSopenharmony_ci+/**
1603be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
1604be168c0dSopenharmony_ci+*
1605be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
1606be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
1607be168c0dSopenharmony_ci+* You may obtain a copy of the License at
1608be168c0dSopenharmony_ci+*
1609be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
1610be168c0dSopenharmony_ci+*
1611be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
1612be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
1613be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1614be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
1615be168c0dSopenharmony_ci+* limitations under the License.
1616be168c0dSopenharmony_ci+*/
1617be168c0dSopenharmony_ci+#ifndef NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1618be168c0dSopenharmony_ci+#define NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1619be168c0dSopenharmony_ci+
1620be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_instructions.h"
1621be168c0dSopenharmony_ci+#include "nnacl/intrinsics/ms_simd_sse_instructions.h"
1622be168c0dSopenharmony_ci+
1623be168c0dSopenharmony_ci+#ifdef __cplusplus
1624be168c0dSopenharmony_ci+extern "C" {
1625be168c0dSopenharmony_ci+#endif
1626be168c0dSopenharmony_ci+#pragma GCC push_options
1627be168c0dSopenharmony_ci+#pragma GCC target("sse4.1")
1628be168c0dSopenharmony_ci+#define MS_SIMD_INSTRUCTION MS_SIMD_SSE_INSTRUCTION
1629be168c0dSopenharmony_ci+#define BLOCK_NUM 4
1630be168c0dSopenharmony_ci+#define MS_SIMD_SSE
1631be168c0dSopenharmony_ci+
1632be168c0dSopenharmony_ci+static inline int ScatterNDAddFp32SSE(int index, const float *update, int size, float *output) {
1633be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1634be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1635be168c0dSopenharmony_ci+  }
1636be168c0dSopenharmony_ci+  return index;
1637be168c0dSopenharmony_ci+}
1638be168c0dSopenharmony_ci+
1639be168c0dSopenharmony_ci+static inline int ScatterNDAddInt32SSE(int index, const int *update, int size, int *output) {
1640be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1641be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1642be168c0dSopenharmony_ci+  }
1643be168c0dSopenharmony_ci+  return index;
1644be168c0dSopenharmony_ci+}
1645be168c0dSopenharmony_ci+
1646be168c0dSopenharmony_ci+static inline int ScatterNDMaxFp32SSE(int index, const float *update, int size, float *output) {
1647be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1648be168c0dSopenharmony_ci+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1649be168c0dSopenharmony_ci+  }
1650be168c0dSopenharmony_ci+  return index;
1651be168c0dSopenharmony_ci+}
1652be168c0dSopenharmony_ci+
1653be168c0dSopenharmony_ci+static inline int ScatterNDMaxInt32SSE(int index, const int *update, int size, int *output) {
1654be168c0dSopenharmony_ci+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1655be168c0dSopenharmony_ci+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1656be168c0dSopenharmony_ci+  }
1657be168c0dSopenharmony_ci+  return index;
1658be168c0dSopenharmony_ci+}
1659be168c0dSopenharmony_ci+
1660be168c0dSopenharmony_ci+#undef MS_SIMD_INSTRUCTION
1661be168c0dSopenharmony_ci+#undef BLOCK_NUM
1662be168c0dSopenharmony_ci+#pragma GCC pop_options
1663be168c0dSopenharmony_ci+#undef MS_SIMD_SSE
1664be168c0dSopenharmony_ci+#ifdef __cplusplus
1665be168c0dSopenharmony_ci+}
1666be168c0dSopenharmony_ci+#endif
1667be168c0dSopenharmony_ci+#endif  // NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1668be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/BUILD.gn b/mindspore/core/mindrt/BUILD.gn
1669be168c0dSopenharmony_ciindex b56d5f5c..b0e7c70d 100644
1670be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/BUILD.gn
1671be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/BUILD.gn
1672be168c0dSopenharmony_ci@@ -41,8 +41,15 @@ ohos_source_set("mindrt_obj") {
1673be168c0dSopenharmony_ci     "../../core/",
1674be168c0dSopenharmony_ci   ]
1675be168c0dSopenharmony_ci 
1676be168c0dSopenharmony_ci+  defines = [
1677be168c0dSopenharmony_ci+    "ENABLE_MINDRT",
1678be168c0dSopenharmony_ci+    "MS_COMPILE_OHOS",
1679be168c0dSopenharmony_ci+    "BUILD_LITE",
1680be168c0dSopenharmony_ci+  ]
1681be168c0dSopenharmony_ci+
1682be168c0dSopenharmony_ci+  external_deps = [ "hilog:libhilog" ]
1683be168c0dSopenharmony_ci+
1684be168c0dSopenharmony_ci   remove_configs = [ "//build/config/compiler:no_rtti" ]
1685be168c0dSopenharmony_ci-  defines = [ "BUILD_LITE" ]
1686be168c0dSopenharmony_ci 
1687be168c0dSopenharmony_ci   part_name = "mindspore"
1688be168c0dSopenharmony_ci   subsystem_name = "thirdparty"
1689be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1690be168c0dSopenharmony_ciindex 70414757..c50c46e0 100644
1691be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1692be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1693be168c0dSopenharmony_ci@@ -32,7 +32,7 @@ void ActorWorker::RunWithSpin() {
1694be168c0dSopenharmony_ci   }
1695be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER)
1696be168c0dSopenharmony_ci   static std::atomic_int index{0};
1697be168c0dSopenharmony_ci-  (void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str());
1698be168c0dSopenharmony_ci+  (void)pthread_setname_np(pthread_self(), ("OS_Actor_" + std::to_string(index++)).c_str());
1699be168c0dSopenharmony_ci #endif
1700be168c0dSopenharmony_ci #ifdef PLATFORM_86
1701be168c0dSopenharmony_ci   // Some CPU kernels need set the flush zero mode to improve performance.
1702be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc
1703be168c0dSopenharmony_ciindex 33bf3529..a3478dff 100644
1704be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/core_affinity.cc
1705be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc
1706be168c0dSopenharmony_ci@@ -344,12 +344,12 @@ int CoreAffinity::InitBindCoreId(size_t thread_num, BindMode bind_mode) {
1707be168c0dSopenharmony_ci int CoreAffinity::SetAffinity() { return THREAD_OK; }
1708be168c0dSopenharmony_ci #elif defined(BIND_CORE)
1709be168c0dSopenharmony_ci int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) {
1710be168c0dSopenharmony_ci-#ifdef __ANDROID__
1711be168c0dSopenharmony_ci-#if __ANDROID_API__ >= 21
1712be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1713be168c0dSopenharmony_ci+#if (__ANDROID_API__ >= 21) || defined(MS_COMPILE_OHOS)
1714be168c0dSopenharmony_ci   THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]);
1715be168c0dSopenharmony_ci   int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set);
1716be168c0dSopenharmony_ci   if (ret != THREAD_OK) {
1717be168c0dSopenharmony_ci-    THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret);
1718be168c0dSopenharmony_ci+    THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", pthread_gettid_np(thread_id), ret);
1719be168c0dSopenharmony_ci     return THREAD_ERROR;
1720be168c0dSopenharmony_ci   }
1721be168c0dSopenharmony_ci #endif
1722be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/core_affinity.h b/mindspore/core/mindrt/src/thread/core_affinity.h
1723be168c0dSopenharmony_ciindex 2dd2abd1..28b0967a 100644
1724be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/core_affinity.h
1725be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/core_affinity.h
1726be168c0dSopenharmony_ci@@ -23,7 +23,7 @@
1727be168c0dSopenharmony_ci #ifdef PARALLEL_INFERENCE
1728be168c0dSopenharmony_ci #define BIND_CORE
1729be168c0dSopenharmony_ci #endif
1730be168c0dSopenharmony_ci-#ifdef __ANDROID__
1731be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1732be168c0dSopenharmony_ci #define BIND_CORE
1733be168c0dSopenharmony_ci #include <sched.h>
1734be168c0dSopenharmony_ci #endif
1735be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1736be168c0dSopenharmony_ciindex 9e0dd25c..09c39f32 100644
1737be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1738be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1739be168c0dSopenharmony_ci@@ -48,7 +48,7 @@ void ParallelWorker::ParallelRun() {
1740be168c0dSopenharmony_ci     SetAffinity();
1741be168c0dSopenharmony_ci   }
1742be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER)
1743be168c0dSopenharmony_ci-  (void)pthread_setname_np(pthread_self(), ("ParallelThread_" + std::to_string(worker_id_)).c_str());
1744be168c0dSopenharmony_ci+  (void)pthread_setname_np(pthread_self(), ("OS_Parallel_" + std::to_string(worker_id_)).c_str());
1745be168c0dSopenharmony_ci #endif
1746be168c0dSopenharmony_ci #ifdef PLATFORM_86
1747be168c0dSopenharmony_ci   // Some CPU kernels need set the flush zero mode to improve performance.
1748be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/threadlog.h b/mindspore/core/mindrt/src/thread/threadlog.h
1749be168c0dSopenharmony_ciindex 7ed917f1..b212a401 100644
1750be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/threadlog.h
1751be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/threadlog.h
1752be168c0dSopenharmony_ci@@ -16,7 +16,9 @@
1753be168c0dSopenharmony_ci 
1754be168c0dSopenharmony_ci #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_
1755be168c0dSopenharmony_ci #define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_
1756be168c0dSopenharmony_ci-
1757be168c0dSopenharmony_ci+#ifdef MS_COMPILE_OHOS
1758be168c0dSopenharmony_ci+#include "hilog/log.h"
1759be168c0dSopenharmony_ci+#endif
1760be168c0dSopenharmony_ci namespace mindspore {
1761be168c0dSopenharmony_ci #ifdef THREAD_POOL_DEBUG
1762be168c0dSopenharmony_ci #include <stdio.h>
1763be168c0dSopenharmony_ci@@ -32,13 +34,35 @@ namespace mindspore {
1764be168c0dSopenharmony_ci   }
1765be168c0dSopenharmony_ci #else
1766be168c0dSopenharmony_ci #define THREAD_DEBUG(content, ...)
1767be168c0dSopenharmony_ci-#define THREAD_INFO(content, ...)
1768be168c0dSopenharmony_ci #define THREAD_TEST_TRUE(flag)
1769be168c0dSopenharmony_ci+
1770be168c0dSopenharmony_ci #if defined(__ANDROID__)
1771be168c0dSopenharmony_ci+#define THREAD_INFO(content, ...)
1772be168c0dSopenharmony_ci #include <android/log.h>
1773be168c0dSopenharmony_ci #define THREAD_ERROR(content, args...) \
1774be168c0dSopenharmony_ci   { __android_log_print(ANDROID_LOG_ERROR, "MS_LITE", "%s|%d: " #content "\r\n", __func__, __LINE__, ##args); }
1775be168c0dSopenharmony_ci+
1776be168c0dSopenharmony_ci+#elif defined(MS_COMPILE_OHOS) // For OHOS, use hilog.
1777be168c0dSopenharmony_ci+
1778be168c0dSopenharmony_ci+#define MINDRT_OHOS_LOG_DOMAIN 0x2102
1779be168c0dSopenharmony_ci+#define MINDRT_OHOS_LOG_TAG "MS_LITE"
1780be168c0dSopenharmony_ci+
1781be168c0dSopenharmony_ci+#ifdef MS_COMPILE_WITH_OHOS_NDK
1782be168c0dSopenharmony_ci+// When build with OHOS NDK, use public api of hilog module.
1783be168c0dSopenharmony_ci+#define THREAD_INFO(content, args...) \
1784be168c0dSopenharmony_ci+  { OH_LOG_Print(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1785be168c0dSopenharmony_ci+#define THREAD_ERROR(content, args...) \
1786be168c0dSopenharmony_ci+  { OH_LOG_Print(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1787be168c0dSopenharmony_ci+#else
1788be168c0dSopenharmony_ci+// When build in OHOS repo, use inner api of hilog module.
1789be168c0dSopenharmony_ci+#define THREAD_INFO(content, args...) \
1790be168c0dSopenharmony_ci+  { HiLogPrint(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1791be168c0dSopenharmony_ci+#define THREAD_ERROR(content, args...) \
1792be168c0dSopenharmony_ci+  { HiLogPrint(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1793be168c0dSopenharmony_ci+#endif
1794be168c0dSopenharmony_ci+
1795be168c0dSopenharmony_ci #else
1796be168c0dSopenharmony_ci+#define THREAD_INFO(content, ...)
1797be168c0dSopenharmony_ci #define THREAD_ERROR(content, ...)
1798be168c0dSopenharmony_ci #endif
1799be168c0dSopenharmony_ci #endif
1800be168c0dSopenharmony_cidiff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc
1801be168c0dSopenharmony_ciindex c56e0425..2301be8c 100644
1802be168c0dSopenharmony_ci--- a/mindspore/core/mindrt/src/thread/threadpool.cc
1803be168c0dSopenharmony_ci+++ b/mindspore/core/mindrt/src/thread/threadpool.cc
1804be168c0dSopenharmony_ci@@ -68,10 +68,11 @@ void Worker::SetAffinity() {
1805be168c0dSopenharmony_ci #ifdef _WIN32
1806be168c0dSopenharmony_ci   SetWindowsSelfAffinity(core_id_);
1807be168c0dSopenharmony_ci #elif defined(BIND_CORE)
1808be168c0dSopenharmony_ci-#ifdef __ANDROID__
1809be168c0dSopenharmony_ci+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1810be168c0dSopenharmony_ci+  THREAD_INFO("thread: %d, mask: %lu", gettid(), mask_.__bits[0]);
1811be168c0dSopenharmony_ci   int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_);
1812be168c0dSopenharmony_ci   if (ret != THREAD_OK) {
1813be168c0dSopenharmony_ci-    THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno);
1814be168c0dSopenharmony_ci+    THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", gettid(), errno);
1815be168c0dSopenharmony_ci   }
1816be168c0dSopenharmony_ci   return;
1817be168c0dSopenharmony_ci #else
1818be168c0dSopenharmony_ci@@ -111,7 +112,7 @@ void Worker::Run() {
1819be168c0dSopenharmony_ci   }
1820be168c0dSopenharmony_ci #if !defined(__APPLE__) && !defined(_MSC_VER)
1821be168c0dSopenharmony_ci   static std::atomic_int index = {0};
1822be168c0dSopenharmony_ci-  (void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str());
1823be168c0dSopenharmony_ci+  (void)pthread_setname_np(pthread_self(), ("OS_Kernel_" + std::to_string(index++)).c_str());
1824be168c0dSopenharmony_ci #endif
1825be168c0dSopenharmony_ci #ifdef PLATFORM_86
1826be168c0dSopenharmony_ci   // Some CPU kernels need set the flush zero mode to improve performance.
1827be168c0dSopenharmony_cidiff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
1828be168c0dSopenharmony_ciindex a774b58c..f7e465e2 100644
1829be168c0dSopenharmony_ci--- a/mindspore/lite/BUILD.gn
1830be168c0dSopenharmony_ci+++ b/mindspore/lite/BUILD.gn
1831be168c0dSopenharmony_ci@@ -71,9 +71,14 @@
1832be168c0dSopenharmony_ci 
1833be168c0dSopenharmony_ci import("//build/ohos.gni")
1834be168c0dSopenharmony_ci 
1835be168c0dSopenharmony_ci+declare_args() {
1836be168c0dSopenharmony_ci+    mindspore_feature_nnrt_metagraph = false
1837be168c0dSopenharmony_ci+}
1838be168c0dSopenharmony_ci+
1839be168c0dSopenharmony_ci ohos_group("mindspore") {
1840be168c0dSopenharmony_ci   deps = [
1841be168c0dSopenharmony_ci     ":mindspore_lib",
1842be168c0dSopenharmony_ci+    ":mindspore_ndk",
1843be168c0dSopenharmony_ci     ":mindspore_train_lib",
1844be168c0dSopenharmony_ci     "mindir:mindir_lib",
1845be168c0dSopenharmony_ci     "src/litert/js_api:mindsporelite_napi"
1846be168c0dSopenharmony_ci@@ -180,7 +185,6 @@ lite_mindrt_sources = [
1847be168c0dSopenharmony_ci ]
1848be168c0dSopenharmony_ci 
1849be168c0dSopenharmony_ci all_lite_sources += cxx_api_sources
1850be168c0dSopenharmony_ci-all_lite_sources += c_api_sources
1851be168c0dSopenharmony_ci all_lite_sources += api_source
1852be168c0dSopenharmony_ci all_lite_sources += control_flow_kernel_sources
1853be168c0dSopenharmony_ci all_lite_sources += experimental_sources
1854be168c0dSopenharmony_ci@@ -368,7 +372,6 @@ ohos_shared_library("mindspore_lib") {
1855be168c0dSopenharmony_ci   sources = all_sources
1856be168c0dSopenharmony_ci 
1857be168c0dSopenharmony_ci   include_dirs = [
1858be168c0dSopenharmony_ci-    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1859be168c0dSopenharmony_ci     "//third_party/flatbuffers/include",
1860be168c0dSopenharmony_ci     "./",
1861be168c0dSopenharmony_ci     "../",
1862be168c0dSopenharmony_ci@@ -384,6 +387,7 @@ ohos_shared_library("mindspore_lib") {
1863be168c0dSopenharmony_ci     "../ccsrc/",
1864be168c0dSopenharmony_ci     "src/litert/kernel/cpu/",
1865be168c0dSopenharmony_ci     "../core/mindrt/src/",
1866be168c0dSopenharmony_ci+    "//foundation/ai/neural_network_runtime/",
1867be168c0dSopenharmony_ci   ]
1868be168c0dSopenharmony_ci 
1869be168c0dSopenharmony_ci   defines = [
1870be168c0dSopenharmony_ci@@ -426,24 +430,29 @@ ohos_shared_library("mindspore_lib") {
1871be168c0dSopenharmony_ci 
1872be168c0dSopenharmony_ci   external_deps = [ "hilog:libhilog" ]
1873be168c0dSopenharmony_ci 
1874be168c0dSopenharmony_ci-  output_name = "libmindspore-lite.huawei"
1875be168c0dSopenharmony_ci+  output_name = "libmindspore-lite"
1876be168c0dSopenharmony_ci   output_extension = "so"
1877be168c0dSopenharmony_ci   innerapi_tags = [ "platformsdk" ]
1878be168c0dSopenharmony_ci   SUPPORT_NNRT = true
1879be168c0dSopenharmony_ci   if (SUPPORT_NNRT) {
1880be168c0dSopenharmony_ci+    if (mindspore_feature_nnrt_metagraph) {
1881be168c0dSopenharmony_ci+      defines += [ "SUPPORT_NNRT_METAGRAPH" ]
1882be168c0dSopenharmony_ci+      print("enabled feature: mindspore_feature_nnrt_metagraph")
1883be168c0dSopenharmony_ci+    }
1884be168c0dSopenharmony_ci     sources += [
1885be168c0dSopenharmony_ci       "src/litert/delegate/nnrt/checker/primitive_check.cc",
1886be168c0dSopenharmony_ci       "src/litert/delegate/nnrt/nnrt_delegate.cc",
1887be168c0dSopenharmony_ci       "src/litert/delegate/nnrt/nnrt_model_kernel.cc",
1888be168c0dSopenharmony_ci     ]
1889be168c0dSopenharmony_ci     include_dirs += [
1890be168c0dSopenharmony_ci-      "//foundation/ai/neural_network_runtime",
1891be168c0dSopenharmony_ci       "src/delegate/nnrt/include",
1892be168c0dSopenharmony_ci       "../../mindspore/core/ir",
1893be168c0dSopenharmony_ci       "mindir/include",
1894be168c0dSopenharmony_ci       "mindir/inner_headers",
1895be168c0dSopenharmony_ci     ]
1896be168c0dSopenharmony_ci+
1897be168c0dSopenharmony_ci     external_deps += [ "neural_network_runtime:nnrt_target" ]
1898be168c0dSopenharmony_ci+
1899be168c0dSopenharmony_ci     deps += [ "mindir:mindir_lib" ]
1900be168c0dSopenharmony_ci     defines += [ "SUPPORT_NNRT" ]
1901be168c0dSopenharmony_ci   }
1902be168c0dSopenharmony_ci@@ -461,6 +470,67 @@ ohos_shared_library("mindspore_lib") {
1903be168c0dSopenharmony_ci   subsystem_name = "thirdparty"
1904be168c0dSopenharmony_ci }
1905be168c0dSopenharmony_ci 
1906be168c0dSopenharmony_ci+# NDK lib
1907be168c0dSopenharmony_ci+ohos_shared_library("mindspore_ndk") {
1908be168c0dSopenharmony_ci+  deps = [
1909be168c0dSopenharmony_ci+    ":mindspore_lib",
1910be168c0dSopenharmony_ci+    ":mindspore_train_lib"
1911be168c0dSopenharmony_ci+  ]
1912be168c0dSopenharmony_ci+
1913be168c0dSopenharmony_ci+  sources = c_api_sources
1914be168c0dSopenharmony_ci+
1915be168c0dSopenharmony_ci+  include_dirs = [
1916be168c0dSopenharmony_ci+    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1917be168c0dSopenharmony_ci+    "//third_party/flatbuffers/include",
1918be168c0dSopenharmony_ci+    "./",
1919be168c0dSopenharmony_ci+    "../",
1920be168c0dSopenharmony_ci+    "../../",
1921be168c0dSopenharmony_ci+    "../core",
1922be168c0dSopenharmony_ci+    "src",
1923be168c0dSopenharmony_ci+    "src/c_api/",
1924be168c0dSopenharmony_ci+    "../ccsrc/plugin/device/cpu/kernel/",
1925be168c0dSopenharmony_ci+    "../core/mindrt/src/",
1926be168c0dSopenharmony_ci+    "../core/mindrt/include/",
1927be168c0dSopenharmony_ci+    "../../third_party/",
1928be168c0dSopenharmony_ci+    "./schema/",
1929be168c0dSopenharmony_ci+    "../ccsrc/",
1930be168c0dSopenharmony_ci+    "//foundation/ai/neural_network_runtime/",
1931be168c0dSopenharmony_ci+  ]
1932be168c0dSopenharmony_ci+
1933be168c0dSopenharmony_ci+  defines = [
1934be168c0dSopenharmony_ci+    "SUPPORT_NNRT",
1935be168c0dSopenharmony_ci+    "MS_COMPILE_OHOS",
1936be168c0dSopenharmony_ci+    "PRIMITIVE_WRITEABLE",
1937be168c0dSopenharmony_ci+    "RUNTIME_PASS_CLIP",
1938be168c0dSopenharmony_ci+    "ENABLE_MULTI_LAYOUT",
1939be168c0dSopenharmony_ci+    "VERSION_STR=\"2.1.0\"",
1940be168c0dSopenharmony_ci+  ]
1941be168c0dSopenharmony_ci+
1942be168c0dSopenharmony_ci+  configs = [
1943be168c0dSopenharmony_ci+    ":mindspore_api",
1944be168c0dSopenharmony_ci+    ":disable_android",
1945be168c0dSopenharmony_ci+    ":secure_option",
1946be168c0dSopenharmony_ci+  ]
1947be168c0dSopenharmony_ci+
1948be168c0dSopenharmony_ci+  external_deps = [ "neural_network_runtime:nnrt_target" ]
1949be168c0dSopenharmony_ci+
1950be168c0dSopenharmony_ci+  remove_configs = [ "//build/config/compiler:no_rtti" ]
1951be168c0dSopenharmony_ci+
1952be168c0dSopenharmony_ci+  output_name = "libmindspore_lite_ndk"
1953be168c0dSopenharmony_ci+  output_extension = "so"
1954be168c0dSopenharmony_ci+  innerapi_tags = [ "ndk"]
1955be168c0dSopenharmony_ci+  cflags_cc = [
1956be168c0dSopenharmony_ci+    "-Wno-ignored-qualifiers",
1957be168c0dSopenharmony_ci+    "-Wunused-private-field",
1958be168c0dSopenharmony_ci+    "-Wno-unused-private-field",
1959be168c0dSopenharmony_ci+    "-Wno-inconsistent-missing-override",
1960be168c0dSopenharmony_ci+    "-Wno-macro-redefined",
1961be168c0dSopenharmony_ci+    "-Wno-constant-conversion",
1962be168c0dSopenharmony_ci+  ]
1963be168c0dSopenharmony_ci+  part_name = "mindspore"
1964be168c0dSopenharmony_ci+  subsystem_name = "thirdparty"
1965be168c0dSopenharmony_ci+}
1966be168c0dSopenharmony_ci+
1967be168c0dSopenharmony_ci # Train library
1968be168c0dSopenharmony_ci expression_cxx_api_sources = [
1969be168c0dSopenharmony_ci   "src/litert/cxx_api/expression/net.cc",
1970be168c0dSopenharmony_ci@@ -614,7 +684,6 @@ ohos_shared_library("mindspore_train_lib") {
1971be168c0dSopenharmony_ci   sources = all_train_sources
1972be168c0dSopenharmony_ci 
1973be168c0dSopenharmony_ci   include_dirs = [
1974be168c0dSopenharmony_ci-    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1975be168c0dSopenharmony_ci     "//third_party/flatbuffers/include",
1976be168c0dSopenharmony_ci     "./",
1977be168c0dSopenharmony_ci     "../",
1978be168c0dSopenharmony_ci@@ -698,6 +767,9 @@ config("disable_android") {
1979be168c0dSopenharmony_ci     "-U__ANDROID__",
1980be168c0dSopenharmony_ci     "-U__ANDROID_API__",
1981be168c0dSopenharmony_ci   ]
1982be168c0dSopenharmony_ci+  ldflags = [
1983be168c0dSopenharmony_ci+    "-Wl,--no-as-needed",
1984be168c0dSopenharmony_ci+  ]
1985be168c0dSopenharmony_ci }
1986be168c0dSopenharmony_ci 
1987be168c0dSopenharmony_ci config("secure_option") {
1988be168c0dSopenharmony_cidiff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
1989be168c0dSopenharmony_ciindex 72337f70..1faf2f38 100644
1990be168c0dSopenharmony_ci--- a/mindspore/lite/CMakeLists.txt
1991be168c0dSopenharmony_ci+++ b/mindspore/lite/CMakeLists.txt
1992be168c0dSopenharmony_ci@@ -298,8 +298,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
1993be168c0dSopenharmony_ci elseif(TOOLCHAIN_NAME STREQUAL "ohos")
1994be168c0dSopenharmony_ci     set(TARGET_OHOS on)
1995be168c0dSopenharmony_ci     add_compile_definitions(MS_COMPILE_OHOS)
1996be168c0dSopenharmony_ci-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions")
1997be168c0dSopenharmony_ci-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions")
1998be168c0dSopenharmony_ci+    add_compile_definitions(MS_COMPILE_WITH_OHOS_NDK)
1999be168c0dSopenharmony_ci+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins")
2000be168c0dSopenharmony_ci+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins")
2001be168c0dSopenharmony_ci endif()
2002be168c0dSopenharmony_ci 
2003be168c0dSopenharmony_ci if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
2004be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/lite_types.h b/mindspore/lite/include/lite_types.h
2005be168c0dSopenharmony_ciindex 017e98a8..860390d5 100644
2006be168c0dSopenharmony_ci--- a/mindspore/lite/include/lite_types.h
2007be168c0dSopenharmony_ci+++ b/mindspore/lite/include/lite_types.h
2008be168c0dSopenharmony_ci@@ -42,6 +42,7 @@ typedef enum {
2009be168c0dSopenharmony_ci   DT_NPU,    /**< NPU device type */
2010be168c0dSopenharmony_ci   DT_ASCEND, /**< ASCEND device type */
2011be168c0dSopenharmony_ci   DT_CUSTOM, /**< EXTEND device type */
2012be168c0dSopenharmony_ci+  DT_NNRT,   /**< NNRT device type */
2013be168c0dSopenharmony_ci   DT_END     /**< NO device type */
2014be168c0dSopenharmony_ci } DeviceType;
2015be168c0dSopenharmony_ci 
2016be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
2017be168c0dSopenharmony_ciindex 93e27ea9..b96c7e35 100644
2018be168c0dSopenharmony_ci--- a/mindspore/lite/include/model.h
2019be168c0dSopenharmony_ci+++ b/mindspore/lite/include/model.h
2020be168c0dSopenharmony_ci@@ -25,6 +25,7 @@ namespace mindspore {
2021be168c0dSopenharmony_ci namespace schema {
2022be168c0dSopenharmony_ci struct Tensor;
2023be168c0dSopenharmony_ci }  // namespace schema
2024be168c0dSopenharmony_ci+
2025be168c0dSopenharmony_ci namespace lite {
2026be168c0dSopenharmony_ci typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType;
2027be168c0dSopenharmony_ci 
2028be168c0dSopenharmony_ci@@ -62,7 +63,10 @@ struct MS_API LiteGraph {
2029be168c0dSopenharmony_ci   bool model_obfuscated_ = false;
2030be168c0dSopenharmony_ci   std::vector<unsigned char *> deobf_prims_;
2031be168c0dSopenharmony_ci #endif
2032be168c0dSopenharmony_ci+
2033be168c0dSopenharmony_ci+  std::string ToString() const;
2034be168c0dSopenharmony_ci };
2035be168c0dSopenharmony_ci+
2036be168c0dSopenharmony_ci struct MS_API Model {
2037be168c0dSopenharmony_ci   LiteGraph graph_;
2038be168c0dSopenharmony_ci   char *buf = nullptr;
2039be168c0dSopenharmony_cidiff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h
2040be168c0dSopenharmony_ciindex 2d72b200..4bc92599 100644
2041be168c0dSopenharmony_ci--- a/mindspore/lite/include/registry/converter_context.h
2042be168c0dSopenharmony_ci+++ b/mindspore/lite/include/registry/converter_context.h
2043be168c0dSopenharmony_ci@@ -39,7 +39,9 @@ enum MS_API FmkType : int {
2044be168c0dSopenharmony_ci   kFmkTypeMs = 3,
2045be168c0dSopenharmony_ci   kFmkTypeTflite = 4,
2046be168c0dSopenharmony_ci   kFmkTypePytorch = 5,
2047be168c0dSopenharmony_ci-  kFmkTypeMsLite = 6,
2048be168c0dSopenharmony_ci+  kFmkTypeThirdParty = 6,
2049be168c0dSopenharmony_ci+  kFmkTypeMsLite = 7,
2050be168c0dSopenharmony_ci+  kFmkTypeEnd = 8,  // For range check purpose, valid range: [0, kFmkTypeEnd)
2051be168c0dSopenharmony_ci };
2052be168c0dSopenharmony_ci 
2053be168c0dSopenharmony_ci /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
2054be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h
2055be168c0dSopenharmony_ciindex ca811dce..f47cad8c 100644
2056be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/include/mindir.h
2057be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/include/mindir.h
2058be168c0dSopenharmony_ci@@ -151,6 +151,8 @@ int64_t MindIR_Conv2DFusion_GetOutChannel(ConstPrimitivePtr primitive);
2059be168c0dSopenharmony_ci void MindIR_Conv2DFusion_SetOutChannel(PrimitivePtr *primitive, int64_t out_channel);
2060be168c0dSopenharmony_ci ActivationType MindIR_Conv2DFusion_GetActivationType(ConstPrimitivePtr primitive);
2061be168c0dSopenharmony_ci void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type);
2062be168c0dSopenharmony_ci+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive);
2063be168c0dSopenharmony_ci+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format);
2064be168c0dSopenharmony_ci 
2065be168c0dSopenharmony_ci // ********** Conv2dTransposeFusion **********
2066be168c0dSopenharmony_ci PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive(
2067be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc
2068be168c0dSopenharmony_ciindex 7fc9c00e..374bbef5 100644
2069be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/mindir.cc
2070be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/mindir.cc
2071be168c0dSopenharmony_ci@@ -1215,6 +1215,46 @@ void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationTy
2072be168c0dSopenharmony_ci   }
2073be168c0dSopenharmony_ci }
2074be168c0dSopenharmony_ci 
2075be168c0dSopenharmony_ci+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive) {
2076be168c0dSopenharmony_ci+  if (primitive != nullptr) {
2077be168c0dSopenharmony_ci+    auto prim = static_cast<const schema::Primitive *>(primitive);
2078be168c0dSopenharmony_ci+    auto value = prim->value_as_Conv2DFusion();
2079be168c0dSopenharmony_ci+    if (prim != nullptr && value != nullptr) {
2080be168c0dSopenharmony_ci+      return static_cast<Format>(value->format());
2081be168c0dSopenharmony_ci+    } else {
2082be168c0dSopenharmony_ci+      Format en = static_cast<Format>(0);
2083be168c0dSopenharmony_ci+      return en;
2084be168c0dSopenharmony_ci+    }
2085be168c0dSopenharmony_ci+  } else {
2086be168c0dSopenharmony_ci+    Format en = static_cast<Format>(0);
2087be168c0dSopenharmony_ci+    return en;
2088be168c0dSopenharmony_ci+  }
2089be168c0dSopenharmony_ci+}
2090be168c0dSopenharmony_ci+
2091be168c0dSopenharmony_ci+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format) {
2092be168c0dSopenharmony_ci+  if (primitive != nullptr && *primitive != nullptr) {
2093be168c0dSopenharmony_ci+    auto prim = static_cast<schema::Primitive *>(*primitive);
2094be168c0dSopenharmony_ci+    auto value = prim->value_as_Conv2DFusion();
2095be168c0dSopenharmony_ci+    if (prim != nullptr && value != nullptr) {
2096be168c0dSopenharmony_ci+      flatbuffers::FlatBufferBuilder fbb;
2097be168c0dSopenharmony_ci+      auto ops_offset = schema::CreateConv2DFusion(
2098be168c0dSopenharmony_ci+        fbb, static_cast<schema::Format>(format),
2099be168c0dSopenharmony_ci+        fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()),
2100be168c0dSopenharmony_ci+        fbb.CreateVector(value->stride()->data(), value->stride()->size()),
2101be168c0dSopenharmony_ci+        fbb.CreateVector(value->dilation()->data(), value->dilation()->size()),
2102be168c0dSopenharmony_ci+        static_cast<schema::PadMode>(value->pad_mode()),
2103be168c0dSopenharmony_ci+        fbb.CreateVector(value->pad_list()->data(), value->pad_list()->size()), 0, value->group(), value->in_channel(),
2104be168c0dSopenharmony_ci+        value->out_channel(), static_cast<schema::ActivationType>(value->activation_type()));
2105be168c0dSopenharmony_ci+      auto prim_offset =
2106be168c0dSopenharmony_ci+        schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(NODE_TYPE_CONV2D_FUSION), ops_offset.o);
2107be168c0dSopenharmony_ci+      fbb.Finish(prim_offset);
2108be168c0dSopenharmony_ci+      auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim);
2109be168c0dSopenharmony_ci+      auto ret_value = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr);
2110be168c0dSopenharmony_ci+      *primitive = ret_value;
2111be168c0dSopenharmony_ci+    }
2112be168c0dSopenharmony_ci+  }
2113be168c0dSopenharmony_ci+}
2114be168c0dSopenharmony_ci+
2115be168c0dSopenharmony_ci // ********** Conv2dTransposeFusion **********
2116be168c0dSopenharmony_ci PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive(
2117be168c0dSopenharmony_ci   const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
2118be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/mindir_tensor.cc b/mindspore/lite/mindir/src/mindir_tensor.cc
2119be168c0dSopenharmony_ciindex 9ec2d0e4..2db4ce8b 100644
2120be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/mindir_tensor.cc
2121be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/mindir_tensor.cc
2122be168c0dSopenharmony_ci@@ -134,7 +134,7 @@ void MindIR_Tensor_SetDataType(TensorPtr *tensor, DataType data_type) {
2123be168c0dSopenharmony_ci         name = fbb.CreateString(value->name()->c_str(), value->name()->size());
2124be168c0dSopenharmony_ci       }
2125be168c0dSopenharmony_ci       auto ops_offset =
2126be168c0dSopenharmony_ci-        schema::CreateTensor(fbb, 0, value->dataType(), dims, static_cast<schema::Format>(value->format()), 0, 0, data,
2127be168c0dSopenharmony_ci+        schema::CreateTensor(fbb, 0, data_type, dims, static_cast<schema::Format>(value->format()), 0, 0, data,
2128be168c0dSopenharmony_ci                              ConvertQuantParams(fbb, value->quantParams()), 0, name);
2129be168c0dSopenharmony_ci       fbb.Finish(ops_offset);
2130be168c0dSopenharmony_ci       auto new_addr = MindIRMemoryManager::GetInstance()->CreateTensorFromBuilder(fbb, value);
2131be168c0dSopenharmony_cidiff --git a/mindspore/lite/mindir/src/utils.cc b/mindspore/lite/mindir/src/utils.cc
2132be168c0dSopenharmony_ciindex 28d66ceb..b044f414 100644
2133be168c0dSopenharmony_ci--- a/mindspore/lite/mindir/src/utils.cc
2134be168c0dSopenharmony_ci+++ b/mindspore/lite/mindir/src/utils.cc
2135be168c0dSopenharmony_ci@@ -22,7 +22,7 @@ namespace lite {
2136be168c0dSopenharmony_ci 
2137be168c0dSopenharmony_ci // ********** PrimitiveBase **********
2138be168c0dSopenharmony_ci NodeType MindIR_Primitive_GetType(PrimitivePtr primitive) {
2139be168c0dSopenharmony_ci-  auto prim = flatbuffers::GetMutableRoot<schema::Primitive>(primitive);
2140be168c0dSopenharmony_ci+  auto prim = static_cast<schema::Primitive *>(primitive);
2141be168c0dSopenharmony_ci   auto type = prim->value_type();
2142be168c0dSopenharmony_ci   return static_cast<NodeType>(type);
2143be168c0dSopenharmony_ci }
2144be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
2145be168c0dSopenharmony_ciindex 5afccc87..de1781cd 100644
2146be168c0dSopenharmony_ci--- a/mindspore/lite/src/CMakeLists.txt
2147be168c0dSopenharmony_ci+++ b/mindspore/lite/src/CMakeLists.txt
2148be168c0dSopenharmony_ci@@ -410,6 +410,11 @@ add_subdirectory(common)
2149be168c0dSopenharmony_ci add_library(lite_src_mid OBJECT ${LITE_SRC})
2150be168c0dSopenharmony_ci add_dependencies(lite_src_mid lite_src_common_mid fbs_src fbs_inner_src)
2151be168c0dSopenharmony_ci 
2152be168c0dSopenharmony_ci+if(SUPPORT_NNRT)
2153be168c0dSopenharmony_ci+    add_subdirectory(litert/delegate/nnrt)
2154be168c0dSopenharmony_ci+    target_link_libraries(lite_src_mid nnrt_mid)
2155be168c0dSopenharmony_ci+endif()
2156be168c0dSopenharmony_ci+
2157be168c0dSopenharmony_ci if(MSLITE_ENABLE_ACL)
2158be168c0dSopenharmony_ci     include_directories(${TOP_DIR}/graphengine/910/inc/external)
2159be168c0dSopenharmony_ci     if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE))
2160be168c0dSopenharmony_ci@@ -497,7 +502,6 @@ if(MSLITE_ENABLE_MINDRT)
2161be168c0dSopenharmony_ci endif()
2162be168c0dSopenharmony_ci 
2163be168c0dSopenharmony_ci if (SUPPORT_NNRT)
2164be168c0dSopenharmony_ci-    add_subdirectory(litert/delegate/nnrt)
2165be168c0dSopenharmony_ci     target_link_libraries(mindspore-lite nnrt_mid)
2166be168c0dSopenharmony_ci     target_link_libraries(mindspore-lite_static nnrt_mid)
2167be168c0dSopenharmony_ci endif()
2168be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/context_util.cc b/mindspore/lite/src/common/context_util.cc
2169be168c0dSopenharmony_ciindex f011e0d7..0fa4ebd0 100644
2170be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/context_util.cc
2171be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/context_util.cc
2172be168c0dSopenharmony_ci@@ -118,6 +118,17 @@ std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceCo
2173be168c0dSopenharmony_ci   MS_CHECK_TRUE_RET(device_info != nullptr, nullptr);
2174be168c0dSopenharmony_ci   return device_info;
2175be168c0dSopenharmony_ci }
2176be168c0dSopenharmony_ci+
2177be168c0dSopenharmony_ci+std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext(
2178be168c0dSopenharmony_ci+  const lite::DeviceContext &nnrt_context) {
2179be168c0dSopenharmony_ci+  if (nnrt_context.device_type_ != DT_NNRT) {
2180be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Function input parameter is not NNRt context.";
2181be168c0dSopenharmony_ci+    return nullptr;
2182be168c0dSopenharmony_ci+  }
2183be168c0dSopenharmony_ci+  auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
2184be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr);
2185be168c0dSopenharmony_ci+  return nnrt_info;
2186be168c0dSopenharmony_ci+}
2187be168c0dSopenharmony_ci }  // namespace
2188be168c0dSopenharmony_ci 
2189be168c0dSopenharmony_ci mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
2190be168c0dSopenharmony_ci@@ -144,7 +155,8 @@ mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &co
2191be168c0dSopenharmony_ci                       {DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
2192be168c0dSopenharmony_ci                       {DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
2193be168c0dSopenharmony_ci                       {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext},
2194be168c0dSopenharmony_ci-                      {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}};
2195be168c0dSopenharmony_ci+                      {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext},
2196be168c0dSopenharmony_ci+                      {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}};
2197be168c0dSopenharmony_ci   for (auto &device_context : context->device_list_) {
2198be168c0dSopenharmony_ci     auto device_type = device_context.device_type_;
2199be168c0dSopenharmony_ci     if (transfer_funcs.find(device_type) == transfer_funcs.end()) {
2200be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/log.cc b/mindspore/lite/src/common/log.cc
2201be168c0dSopenharmony_ciindex 66c0d76b..f1040662 100644
2202be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/log.cc
2203be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/log.cc
2204be168c0dSopenharmony_ci@@ -21,6 +21,13 @@
2205be168c0dSopenharmony_ci #include <android/log.h>
2206be168c0dSopenharmony_ci #endif
2207be168c0dSopenharmony_ci 
2208be168c0dSopenharmony_ci+#ifdef MS_COMPILE_OHOS
2209be168c0dSopenharmony_ci+#define LOG_DOMAIN 0xD002102
2210be168c0dSopenharmony_ci+#define LOG_TAG "MS_LITE"
2211be168c0dSopenharmony_ci+#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2212be168c0dSopenharmony_ci+#include "hilog/log.h"
2213be168c0dSopenharmony_ci+#endif
2214be168c0dSopenharmony_ci+
2215be168c0dSopenharmony_ci // namespace to support utils module definition namespace mindspore constexpr const char *ANDROID_LOG_TAG = "MS_LITE";
2216be168c0dSopenharmony_ci namespace mindspore {
2217be168c0dSopenharmony_ci #if defined(__ANDROID__)
2218be168c0dSopenharmony_ci@@ -73,17 +80,33 @@ static int GetAndroidLogLevel(LiteLogLevel level) {
2219be168c0dSopenharmony_ci 
2220be168c0dSopenharmony_ci #ifdef MS_COMPILE_OHOS
2221be168c0dSopenharmony_ci void PrintHiLog(LiteLogLevel level, const char *file, int line, const char *func, const char *msg) {
2222be168c0dSopenharmony_ci+#ifdef MS_COMPILE_WITH_OHOS_NDK
2223be168c0dSopenharmony_ci+  // When build with OHOS NDK, use public api of hilog module.
2224be168c0dSopenharmony_ci   if (level == LiteLogLevel::DEBUG) {
2225be168c0dSopenharmony_ci-    OHOS::HiviewDFX::HiLog::Debug(MSLite_LABEL, FORMAT, file, line, func, msg);
2226be168c0dSopenharmony_ci+    OH_LOG_Print(LOG_APP, LOG_DEBUG, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2227be168c0dSopenharmony_ci   } else if (level == LiteLogLevel::INFO) {
2228be168c0dSopenharmony_ci-    OHOS::HiviewDFX::HiLog::Info(MSLite_LABEL, FORMAT, file, line, func, msg);
2229be168c0dSopenharmony_ci+    OH_LOG_Print(LOG_APP, LOG_INFO, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2230be168c0dSopenharmony_ci   } else if (level == LiteLogLevel::WARNING) {
2231be168c0dSopenharmony_ci-    OHOS::HiviewDFX::HiLog::Warn(MSLite_LABEL, FORMAT, file, line, func, msg);
2232be168c0dSopenharmony_ci+    OH_LOG_Print(LOG_APP, LOG_WARN, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2233be168c0dSopenharmony_ci   } else if (level == LiteLogLevel::ERROR) {
2234be168c0dSopenharmony_ci-    OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg);
2235be168c0dSopenharmony_ci+    OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2236be168c0dSopenharmony_ci   } else {
2237be168c0dSopenharmony_ci-    OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg);
2238be168c0dSopenharmony_ci+    OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2239be168c0dSopenharmony_ci   }
2240be168c0dSopenharmony_ci+#else
2241be168c0dSopenharmony_ci+  // When build in OHOS repo, use inner api of hilog module.
2242be168c0dSopenharmony_ci+  if (level == LiteLogLevel::DEBUG) {
2243be168c0dSopenharmony_ci+    HILOG_DEBUG(LOG_APP, FORMAT, file, line, func, msg);
2244be168c0dSopenharmony_ci+  } else if (level == LiteLogLevel::INFO) {
2245be168c0dSopenharmony_ci+    HILOG_INFO(LOG_APP, FORMAT, file, line, func, msg);
2246be168c0dSopenharmony_ci+  } else if (level == LiteLogLevel::WARNING) {
2247be168c0dSopenharmony_ci+    HILOG_WARN(LOG_APP, FORMAT, file, line, func, msg);
2248be168c0dSopenharmony_ci+  } else if (level == LiteLogLevel::ERROR) {
2249be168c0dSopenharmony_ci+    HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg);
2250be168c0dSopenharmony_ci+  } else {
2251be168c0dSopenharmony_ci+    HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg);
2252be168c0dSopenharmony_ci+  }
2253be168c0dSopenharmony_ci+#endif
2254be168c0dSopenharmony_ci }
2255be168c0dSopenharmony_ci #endif
2256be168c0dSopenharmony_ci 
2257be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/log.h b/mindspore/lite/src/common/log.h
2258be168c0dSopenharmony_ciindex 3002a454..bea21f01 100644
2259be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/log.h
2260be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/log.h
2261be168c0dSopenharmony_ci@@ -23,12 +23,6 @@
2262be168c0dSopenharmony_ci #include <unordered_map>
2263be168c0dSopenharmony_ci #include "utils/overload.h"
2264be168c0dSopenharmony_ci 
2265be168c0dSopenharmony_ci-#ifdef MS_COMPILE_OHOS
2266be168c0dSopenharmony_ci-#define LOG_DOMAIN 0x2102
2267be168c0dSopenharmony_ci-#define LOG_TAG "MS_Lite"
2268be168c0dSopenharmony_ci-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2269be168c0dSopenharmony_ci-#include "hilog/log.h"
2270be168c0dSopenharmony_ci-#endif
2271be168c0dSopenharmony_ci // NOTICE: when relative path of 'log.h' changed, macro 'LITE_LOG_HEAR_FILE_REL_PATH' must be changed
2272be168c0dSopenharmony_ci #ifndef LITE_LOG_HEAR_FILE_REL_PATH
2273be168c0dSopenharmony_ci #define LITE_LOG_HEAR_FILE_REL_PATH "mindspore/lite/src/common/log.h"
2274be168c0dSopenharmony_ci@@ -140,6 +134,9 @@ class LiteLogWriter {
2275be168c0dSopenharmony_ci   LiteLogLevel log_level_;
2276be168c0dSopenharmony_ci };
2277be168c0dSopenharmony_ci 
2278be168c0dSopenharmony_ci+#define MSLOG_IF(level)                                                                                  \
2279be168c0dSopenharmony_ci+  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2280be168c0dSopenharmony_ci+    mindspore::LiteLogStream()
2281be168c0dSopenharmony_ci 
2282be168c0dSopenharmony_ci #define MS_LOG(level) MS_LOG_##level
2283be168c0dSopenharmony_ci 
2284be168c0dSopenharmony_ci@@ -148,47 +145,6 @@ class LiteLogWriter {
2285be168c0dSopenharmony_ci #define MS_LOG_WARNING MSLOG_IF(mindspore::LiteLogLevel::WARNING)
2286be168c0dSopenharmony_ci #define MS_LOG_ERROR MSLOG_IF(mindspore::LiteLogLevel::ERROR)
2287be168c0dSopenharmony_ci 
2288be168c0dSopenharmony_ci-
2289be168c0dSopenharmony_ci-#ifdef MS_COMPILE_OHOS
2290be168c0dSopenharmony_ci-namespace {
2291be168c0dSopenharmony_ci-constexpr unsigned int MSLITE_DOMAIN_ID_START = 0xD0029A0;
2292be168c0dSopenharmony_ci-constexpr unsigned int MSLITE_DOMAIN_ID_END = MSLITE_DOMAIN_ID_START + 32;
2293be168c0dSopenharmony_ci-constexpr unsigned int TEST_DOMAIN_ID = 0xD000F00;
2294be168c0dSopenharmony_ci-}  // namespace
2295be168c0dSopenharmony_ci-
2296be168c0dSopenharmony_ci-#define FILE_NAME (__builtin_strrchr(__FILE__, '/') ? __builtin_strrchr(__FILE__, '/') + 1 : __FILE__)
2297be168c0dSopenharmony_ci-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2298be168c0dSopenharmony_ci-
2299be168c0dSopenharmony_ci-#define MSLOG_IF(level)                                                                                  \
2300be168c0dSopenharmony_ci-  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2301be168c0dSopenharmony_ci-    mindspore::LiteLogStream()
2302be168c0dSopenharmony_ci-
2303be168c0dSopenharmony_ci-enum MSLiteManagerLogLabel {
2304be168c0dSopenharmony_ci-  // Component labels, you can add if needed
2305be168c0dSopenharmony_ci-  COMP_FWK = 0,
2306be168c0dSopenharmony_ci-  // Test label
2307be168c0dSopenharmony_ci-  LABEL_TEST,
2308be168c0dSopenharmony_ci-  // The end of labels, max to the domain id range length 32
2309be168c0dSopenharmony_ci-  LABEL_END,
2310be168c0dSopenharmony_ci-};
2311be168c0dSopenharmony_ci-
2312be168c0dSopenharmony_ci-enum MSLiteManagerLogDomain {
2313be168c0dSopenharmony_ci-  DOMAIN_FRAMEWORK = MSLITE_DOMAIN_ID_START + COMP_FWK,  // 0xD0029A0
2314be168c0dSopenharmony_ci-  DOMAIN_TEST = TEST_DOMAIN_ID,                          // 0xD000F00
2315be168c0dSopenharmony_ci-  DOMAIN_END = MSLITE_DOMAIN_ID_END,  // Max to 0xD002940, keep the sequence and length same as MSLiteManagerLogLabel
2316be168c0dSopenharmony_ci-};
2317be168c0dSopenharmony_ci-
2318be168c0dSopenharmony_ci-// Keep the sequence and length same as MSLiteManagerLogDomain
2319be168c0dSopenharmony_ci-static constexpr OHOS::HiviewDFX::HiLogLabel MSLite_LABEL = {LOG_CORE, DOMAIN_FRAMEWORK, "MSLiteFwk"};
2320be168c0dSopenharmony_ci-
2321be168c0dSopenharmony_ci-#else
2322be168c0dSopenharmony_ci-
2323be168c0dSopenharmony_ci-#define MSLOG_IF(level)                                                                                  \
2324be168c0dSopenharmony_ci-  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2325be168c0dSopenharmony_ci-    mindspore::LiteLogStream()
2326be168c0dSopenharmony_ci-
2327be168c0dSopenharmony_ci-#endif
2328be168c0dSopenharmony_ci-
2329be168c0dSopenharmony_ci }  // namespace mindspore
2330be168c0dSopenharmony_ci 
2331be168c0dSopenharmony_ci #ifdef Debug
2332be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc
2333be168c0dSopenharmony_ciindex 5e1878b9..13957ed7 100644
2334be168c0dSopenharmony_ci--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc
2335be168c0dSopenharmony_ci+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc
2336be168c0dSopenharmony_ci@@ -19,6 +19,9 @@
2337be168c0dSopenharmony_ci #include "nnacl/custom_parameter.h"
2338be168c0dSopenharmony_ci #include "nnacl/split_parameter.h"
2339be168c0dSopenharmony_ci #include "nnacl/custom_gru_parameter.h"
2340be168c0dSopenharmony_ci+#include "nnacl/custom_masked_fill_parameter.h"
2341be168c0dSopenharmony_ci+#include "nnacl/custom_is_inf_parameter.h"
2342be168c0dSopenharmony_ci+#include "nnacl/custom_tensor_scatter_max_parameter.h"
2343be168c0dSopenharmony_ci using mindspore::schema::PrimitiveType_Custom;
2344be168c0dSopenharmony_ci 
2345be168c0dSopenharmony_ci namespace mindspore {
2346be168c0dSopenharmony_ci@@ -92,6 +95,39 @@ OpParameter *CreateCustomGruParameter() {
2347be168c0dSopenharmony_ci   return reinterpret_cast<OpParameter *>(param);
2348be168c0dSopenharmony_ci }
2349be168c0dSopenharmony_ci 
2350be168c0dSopenharmony_ci+OpParameter *CreateCustomIsInfParameter() {
2351be168c0dSopenharmony_ci+  auto *param = static_cast<CustomIsInfParameter *>(malloc(sizeof(CustomIsInfParameter)));
2352be168c0dSopenharmony_ci+  if (param == nullptr) {
2353be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc CustomIsInfParameter failed.";
2354be168c0dSopenharmony_ci+    return nullptr;
2355be168c0dSopenharmony_ci+  }
2356be168c0dSopenharmony_ci+  memset(param, 0, sizeof(CustomIsInfParameter));
2357be168c0dSopenharmony_ci+  param->op_parameter_.type_ = PrimType_Inner_CustomIsInf;
2358be168c0dSopenharmony_ci+  return reinterpret_cast<OpParameter *>(param);
2359be168c0dSopenharmony_ci+}
2360be168c0dSopenharmony_ci+
2361be168c0dSopenharmony_ci+OpParameter *CreateCustomTensorScatterMaxParameter() {
2362be168c0dSopenharmony_ci+  auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter)));
2363be168c0dSopenharmony_ci+  if (param == nullptr) {
2364be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc CustomTensorScatterMaxParameter failed.";
2365be168c0dSopenharmony_ci+    return nullptr;
2366be168c0dSopenharmony_ci+  }
2367be168c0dSopenharmony_ci+  memset(param, 0, sizeof(CustomTensorScatterMaxParameter));
2368be168c0dSopenharmony_ci+  param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax;
2369be168c0dSopenharmony_ci+  return reinterpret_cast<OpParameter *>(param);
2370be168c0dSopenharmony_ci+}
2371be168c0dSopenharmony_ci+
2372be168c0dSopenharmony_ci+OpParameter *CreateCustomMaskedFillParameter() {
2373be168c0dSopenharmony_ci+  auto *param = static_cast<CustomMaskedFillParameter *>(malloc(sizeof(CustomMaskedFillParameter)));
2374be168c0dSopenharmony_ci+  if (param == nullptr) {
2375be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc CustomMaskedFillParameter failed.";
2376be168c0dSopenharmony_ci+    return nullptr;
2377be168c0dSopenharmony_ci+  }
2378be168c0dSopenharmony_ci+  memset(param, 0, sizeof(CustomMaskedFillParameter));
2379be168c0dSopenharmony_ci+  param->op_parameter_.type_ = PrimType_Inner_CustomMaskedFill;
2380be168c0dSopenharmony_ci+  return reinterpret_cast<OpParameter *>(param);
2381be168c0dSopenharmony_ci+}
2382be168c0dSopenharmony_ci+
2383be168c0dSopenharmony_ci OpParameter *PopulateCustomParameter(const void *prim) {
2384be168c0dSopenharmony_ci   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
2385be168c0dSopenharmony_ci   auto primitive = static_cast<const schema::Primitive *>(prim);
2386be168c0dSopenharmony_ci@@ -131,6 +167,23 @@ OpParameter *PopulateCustomParameter(const void *prim) {
2387be168c0dSopenharmony_ci     return CreateCustomGruParameter();
2388be168c0dSopenharmony_ci   } else if (type == "CastGatherReduceFusion") {
2389be168c0dSopenharmony_ci     return CreateParam(PrimType_Inner_CastGatherReduceFusion);
2390be168c0dSopenharmony_ci+  } else if (type == "ThirdPartyModel") {
2391be168c0dSopenharmony_ci+    auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter)));
2392be168c0dSopenharmony_ci+    if (param == nullptr) {
2393be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "malloc CustomParameter failed.";
2394be168c0dSopenharmony_ci+      return nullptr;
2395be168c0dSopenharmony_ci+    }
2396be168c0dSopenharmony_ci+    memset(param, 0, sizeof(CustomParameter));
2397be168c0dSopenharmony_ci+    param->op_parameter_.type_ = PrimType_Inner_ThirdPartyModel;
2398be168c0dSopenharmony_ci+    // Just use the attr_data pointer to save the prim directly, the inner value is parsed as necessary.
2399be168c0dSopenharmony_ci+    param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim));
2400be168c0dSopenharmony_ci+    return reinterpret_cast<OpParameter *>(param);
2401be168c0dSopenharmony_ci+  } else if (type == "MaskedFill") {
2402be168c0dSopenharmony_ci+    return CreateCustomMaskedFillParameter();
2403be168c0dSopenharmony_ci+  } else if (type == "TensorScatterMax") {
2404be168c0dSopenharmony_ci+    return CreateCustomTensorScatterMaxParameter();
2405be168c0dSopenharmony_ci+  } else if (type == "IsInf") {
2406be168c0dSopenharmony_ci+    return CreateCustomIsInfParameter();
2407be168c0dSopenharmony_ci   } else {
2408be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Unsupported custom type: " << type;
2409be168c0dSopenharmony_ci   }
2410be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc
2411be168c0dSopenharmony_ciindex f614ef09..c5f825aa 100644
2412be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.cc
2413be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.cc
2414be168c0dSopenharmony_ci@@ -14,12 +14,17 @@
2415be168c0dSopenharmony_ci  * limitations under the License.
2416be168c0dSopenharmony_ci  */
2417be168c0dSopenharmony_ci #include "include/c_api/context_c.h"
2418be168c0dSopenharmony_ci-#include "src/litert/c_api/context_c.h"
2419be168c0dSopenharmony_ci+#include "include/api/context.h"
2420be168c0dSopenharmony_ci+#include <string.h>
2421be168c0dSopenharmony_ci+#include "src/litert/c_api/type_c_private.h"
2422be168c0dSopenharmony_ci #include "src/common/log_adapter.h"
2423be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT
2424be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
2425be168c0dSopenharmony_ci+#endif
2426be168c0dSopenharmony_ci 
2427be168c0dSopenharmony_ci // ================ Context ================
2428be168c0dSopenharmony_ci OH_AI_ContextHandle OH_AI_ContextCreate() {
2429be168c0dSopenharmony_ci-  auto impl = new (std::nothrow) mindspore::ContextC;
2430be168c0dSopenharmony_ci+  auto impl = new (std::nothrow) mindspore::Context();
2431be168c0dSopenharmony_ci   if (impl == nullptr) {
2432be168c0dSopenharmony_ci     MS_LOG(ERROR) << "memory allocation failed.";
2433be168c0dSopenharmony_ci     return nullptr;
2434be168c0dSopenharmony_ci@@ -29,7 +34,7 @@ OH_AI_ContextHandle OH_AI_ContextCreate() {
2435be168c0dSopenharmony_ci 
2436be168c0dSopenharmony_ci void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
2437be168c0dSopenharmony_ci   if (context != nullptr && *context != nullptr) {
2438be168c0dSopenharmony_ci-    auto impl = static_cast<mindspore::ContextC *>(*context);
2439be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::Context *>(*context);
2440be168c0dSopenharmony_ci     delete impl;
2441be168c0dSopenharmony_ci     *context = nullptr;
2442be168c0dSopenharmony_ci   }
2443be168c0dSopenharmony_ci@@ -40,8 +45,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num)
2444be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2445be168c0dSopenharmony_ci     return;
2446be168c0dSopenharmony_ci   }
2447be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2448be168c0dSopenharmony_ci-  impl->thread_num = thread_num;
2449be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2450be168c0dSopenharmony_ci+  impl->SetThreadNum(thread_num);
2451be168c0dSopenharmony_ci }
2452be168c0dSopenharmony_ci 
2453be168c0dSopenharmony_ci int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
2454be168c0dSopenharmony_ci@@ -49,8 +54,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
2455be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2456be168c0dSopenharmony_ci     return 0;
2457be168c0dSopenharmony_ci   }
2458be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2459be168c0dSopenharmony_ci-  return impl->thread_num;
2460be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2461be168c0dSopenharmony_ci+  return impl->GetThreadNum();
2462be168c0dSopenharmony_ci }
2463be168c0dSopenharmony_ci 
2464be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
2465be168c0dSopenharmony_ci@@ -58,8 +63,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
2466be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2467be168c0dSopenharmony_ci     return;
2468be168c0dSopenharmony_ci   }
2469be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2470be168c0dSopenharmony_ci-  impl->affinity_mode = mode;
2471be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2472be168c0dSopenharmony_ci+  impl->SetThreadAffinity(mode);
2473be168c0dSopenharmony_ci   return;
2474be168c0dSopenharmony_ci }
2475be168c0dSopenharmony_ci 
2476be168c0dSopenharmony_ci@@ -68,8 +73,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
2477be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2478be168c0dSopenharmony_ci     return 0;
2479be168c0dSopenharmony_ci   }
2480be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2481be168c0dSopenharmony_ci-  return impl->affinity_mode;
2482be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2483be168c0dSopenharmony_ci+  return impl->GetThreadAffinityMode();
2484be168c0dSopenharmony_ci }
2485be168c0dSopenharmony_ci 
2486be168c0dSopenharmony_ci void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
2487be168c0dSopenharmony_ci@@ -78,8 +83,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i
2488be168c0dSopenharmony_ci     return;
2489be168c0dSopenharmony_ci   }
2490be168c0dSopenharmony_ci   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
2491be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2492be168c0dSopenharmony_ci-  impl->affinity_core_list = vec_core_list;
2493be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2494be168c0dSopenharmony_ci+  impl->SetThreadAffinity(vec_core_list);
2495be168c0dSopenharmony_ci   return;
2496be168c0dSopenharmony_ci }
2497be168c0dSopenharmony_ci 
2498be168c0dSopenharmony_ci@@ -88,9 +93,18 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle
2499be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2500be168c0dSopenharmony_ci     return nullptr;
2501be168c0dSopenharmony_ci   }
2502be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2503be168c0dSopenharmony_ci-  *core_num = impl->affinity_core_list.size();
2504be168c0dSopenharmony_ci-  return impl->affinity_core_list.data();
2505be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2506be168c0dSopenharmony_ci+  auto affinity_core_list = impl->GetThreadAffinityCoreList();
2507be168c0dSopenharmony_ci+  *core_num = affinity_core_list.size();
2508be168c0dSopenharmony_ci+  int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t)));
2509be168c0dSopenharmony_ci+  if (core_list == nullptr) {
2510be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc core_list is null.";
2511be168c0dSopenharmony_ci+    return nullptr;
2512be168c0dSopenharmony_ci+  }
2513be168c0dSopenharmony_ci+  for (size_t i = 0; i < affinity_core_list.size(); i++) {
2514be168c0dSopenharmony_ci+    core_list[i] = affinity_core_list[i];
2515be168c0dSopenharmony_ci+  }
2516be168c0dSopenharmony_ci+  return core_list;
2517be168c0dSopenharmony_ci }
2518be168c0dSopenharmony_ci 
2519be168c0dSopenharmony_ci void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) {
2520be168c0dSopenharmony_ci@@ -98,8 +112,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle
2521be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2522be168c0dSopenharmony_ci     return;
2523be168c0dSopenharmony_ci   }
2524be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2525be168c0dSopenharmony_ci-  impl->enable_parallel = is_parallel;
2526be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2527be168c0dSopenharmony_ci+  impl->SetEnableParallel(is_parallel);
2528be168c0dSopenharmony_ci }
2529be168c0dSopenharmony_ci 
2530be168c0dSopenharmony_ci bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
2531be168c0dSopenharmony_ci@@ -107,8 +121,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
2532be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2533be168c0dSopenharmony_ci     return false;
2534be168c0dSopenharmony_ci   }
2535be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2536be168c0dSopenharmony_ci-  return impl->enable_parallel;
2537be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2538be168c0dSopenharmony_ci+  return impl->GetEnableParallel();
2539be168c0dSopenharmony_ci }
2540be168c0dSopenharmony_ci 
2541be168c0dSopenharmony_ci void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
2542be168c0dSopenharmony_ci@@ -116,25 +130,36 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan
2543be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2544be168c0dSopenharmony_ci     return;
2545be168c0dSopenharmony_ci   }
2546be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::ContextC *>(context);
2547be168c0dSopenharmony_ci-  std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
2548be168c0dSopenharmony_ci-  impl->device_info_list.push_back(device);
2549be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::Context *>(context);
2550be168c0dSopenharmony_ci+  std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
2551be168c0dSopenharmony_ci+  impl->MutableDeviceInfo().push_back(device);
2552be168c0dSopenharmony_ci }
2553be168c0dSopenharmony_ci 
2554be168c0dSopenharmony_ci // ================ DeviceInfo ================
2555be168c0dSopenharmony_ci OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) {
2556be168c0dSopenharmony_ci-  mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
2557be168c0dSopenharmony_ci+  mindspore::DeviceInfoContext *impl;
2558be168c0dSopenharmony_ci+  if (OH_AI_DEVICETYPE_CPU == device_type) {
2559be168c0dSopenharmony_ci+    impl = new (std::nothrow) mindspore::CPUDeviceInfo();
2560be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_GPU == device_type) {
2561be168c0dSopenharmony_ci+    impl = new (std::nothrow) mindspore::GPUDeviceInfo();
2562be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_KIRIN_NPU == device_type) {
2563be168c0dSopenharmony_ci+    impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo();
2564be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_NNRT == device_type) {
2565be168c0dSopenharmony_ci+    impl = new (std::nothrow) mindspore::NNRTDeviceInfo();
2566be168c0dSopenharmony_ci+  } else {
2567be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device_type is invalid.";
2568be168c0dSopenharmony_ci+    impl = nullptr;
2569be168c0dSopenharmony_ci+  }
2570be168c0dSopenharmony_ci   if (impl == nullptr) {
2571be168c0dSopenharmony_ci     MS_LOG(ERROR) << "memory allocation failed.";
2572be168c0dSopenharmony_ci     return nullptr;
2573be168c0dSopenharmony_ci   }
2574be168c0dSopenharmony_ci-  impl->device_type = device_type;
2575be168c0dSopenharmony_ci   return static_cast<OH_AI_DeviceInfoHandle>(impl);
2576be168c0dSopenharmony_ci }
2577be168c0dSopenharmony_ci 
2578be168c0dSopenharmony_ci void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) {
2579be168c0dSopenharmony_ci   if (device_info != nullptr && *device_info != nullptr) {
2580be168c0dSopenharmony_ci-    auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
2581be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info);
2582be168c0dSopenharmony_ci     delete impl;
2583be168c0dSopenharmony_ci     *device_info = nullptr;
2584be168c0dSopenharmony_ci   }
2585be168c0dSopenharmony_ci@@ -145,8 +170,8 @@ void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char
2586be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2587be168c0dSopenharmony_ci     return;
2588be168c0dSopenharmony_ci   }
2589be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2590be168c0dSopenharmony_ci-  impl->provider = provider;
2591be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2592be168c0dSopenharmony_ci+  impl->SetProvider(provider);
2593be168c0dSopenharmony_ci }
2594be168c0dSopenharmony_ci 
2595be168c0dSopenharmony_ci const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) {
2596be168c0dSopenharmony_ci@@ -154,8 +179,14 @@ const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info
2597be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2598be168c0dSopenharmony_ci     return nullptr;
2599be168c0dSopenharmony_ci   }
2600be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2601be168c0dSopenharmony_ci-  return impl->provider.c_str();
2602be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2603be168c0dSopenharmony_ci+  char *provider = static_cast<char *>(malloc(impl->GetProvider().size() + 1));
2604be168c0dSopenharmony_ci+  if (provider == nullptr) {
2605be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc provider is null.";
2606be168c0dSopenharmony_ci+    return nullptr;
2607be168c0dSopenharmony_ci+  }
2608be168c0dSopenharmony_ci+  strcpy(provider, impl->GetProvider().c_str());
2609be168c0dSopenharmony_ci+  return provider;
2610be168c0dSopenharmony_ci }
2611be168c0dSopenharmony_ci 
2612be168c0dSopenharmony_ci void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) {
2613be168c0dSopenharmony_ci@@ -163,8 +194,8 @@ void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const
2614be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2615be168c0dSopenharmony_ci     return;
2616be168c0dSopenharmony_ci   }
2617be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2618be168c0dSopenharmony_ci-  impl->provider_device = device;
2619be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2620be168c0dSopenharmony_ci+  impl->SetProviderDevice(device);
2621be168c0dSopenharmony_ci }
2622be168c0dSopenharmony_ci 
2623be168c0dSopenharmony_ci const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) {
2624be168c0dSopenharmony_ci@@ -172,8 +203,14 @@ const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle devic
2625be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2626be168c0dSopenharmony_ci     return nullptr;
2627be168c0dSopenharmony_ci   }
2628be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2629be168c0dSopenharmony_ci-  return impl->provider_device.c_str();
2630be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2631be168c0dSopenharmony_ci+  char *provider_device = static_cast<char *>(malloc(impl->GetProviderDevice().size() + 1));
2632be168c0dSopenharmony_ci+  if (provider_device == nullptr) {
2633be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc provider_device is null.";
2634be168c0dSopenharmony_ci+    return nullptr;
2635be168c0dSopenharmony_ci+  }
2636be168c0dSopenharmony_ci+  strcpy(provider_device, impl->GetProviderDevice().c_str());
2637be168c0dSopenharmony_ci+  return provider_device;
2638be168c0dSopenharmony_ci }
2639be168c0dSopenharmony_ci 
2640be168c0dSopenharmony_ci OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) {
2641be168c0dSopenharmony_ci@@ -181,8 +218,8 @@ OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle devi
2642be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2643be168c0dSopenharmony_ci     return OH_AI_DEVICETYPE_INVALID;
2644be168c0dSopenharmony_ci   }
2645be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2646be168c0dSopenharmony_ci-  return impl->device_type;
2647be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2648be168c0dSopenharmony_ci+  return static_cast<OH_AI_DeviceType>(impl->GetDeviceType());
2649be168c0dSopenharmony_ci }
2650be168c0dSopenharmony_ci 
2651be168c0dSopenharmony_ci void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) {
2652be168c0dSopenharmony_ci@@ -190,9 +227,17 @@ void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_f
2653be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2654be168c0dSopenharmony_ci     return;
2655be168c0dSopenharmony_ci   }
2656be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2657be168c0dSopenharmony_ci-  if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
2658be168c0dSopenharmony_ci-    impl->enable_fp16 = is_fp16;
2659be168c0dSopenharmony_ci+
2660be168c0dSopenharmony_ci+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2661be168c0dSopenharmony_ci+  if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2662be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
2663be168c0dSopenharmony_ci+    impl->SetEnableFP16(is_fp16);
2664be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2665be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
2666be168c0dSopenharmony_ci+    impl->SetEnableFP16(is_fp16);
2667be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2668be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info);
2669be168c0dSopenharmony_ci+    impl->SetEnableFP16(is_fp16);
2670be168c0dSopenharmony_ci   } else {
2671be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Unsupported Feature.";
2672be168c0dSopenharmony_ci   }
2673be168c0dSopenharmony_ci@@ -203,11 +248,19 @@ bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) {
2674be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2675be168c0dSopenharmony_ci     return false;
2676be168c0dSopenharmony_ci   }
2677be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2678be168c0dSopenharmony_ci-  if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
2679be168c0dSopenharmony_ci-    return impl->enable_fp16;
2680be168c0dSopenharmony_ci+
2681be168c0dSopenharmony_ci+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2682be168c0dSopenharmony_ci+  if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2683be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
2684be168c0dSopenharmony_ci+    return impl->GetEnableFP16();
2685be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2686be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
2687be168c0dSopenharmony_ci+    return impl->GetEnableFP16();
2688be168c0dSopenharmony_ci+  } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2689be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info);
2690be168c0dSopenharmony_ci+    return impl->GetEnableFP16();
2691be168c0dSopenharmony_ci   } else {
2692be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
2693be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl_device->GetDeviceType();
2694be168c0dSopenharmony_ci     return false;
2695be168c0dSopenharmony_ci   }
2696be168c0dSopenharmony_ci }
2697be168c0dSopenharmony_ci@@ -217,9 +270,10 @@ void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int freque
2698be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2699be168c0dSopenharmony_ci     return;
2700be168c0dSopenharmony_ci   }
2701be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2702be168c0dSopenharmony_ci-  if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
2703be168c0dSopenharmony_ci-    impl->frequency = frequency;
2704be168c0dSopenharmony_ci+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2705be168c0dSopenharmony_ci+  if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) {
2706be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
2707be168c0dSopenharmony_ci+    impl->SetFrequency(frequency);
2708be168c0dSopenharmony_ci   } else {
2709be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Unsupported Feature.";
2710be168c0dSopenharmony_ci   }
2711be168c0dSopenharmony_ci@@ -230,11 +284,231 @@ int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) {  //
2712be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
2713be168c0dSopenharmony_ci     return -1;
2714be168c0dSopenharmony_ci   }
2715be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2716be168c0dSopenharmony_ci-  if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
2717be168c0dSopenharmony_ci-    return impl->frequency;
2718be168c0dSopenharmony_ci+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2719be168c0dSopenharmony_ci+  if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) {
2720be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
2721be168c0dSopenharmony_ci+    return impl->GetFrequency();
2722be168c0dSopenharmony_ci   } else {
2723be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Unsupported Feature.";
2724be168c0dSopenharmony_ci     return -1;
2725be168c0dSopenharmony_ci   }
2726be168c0dSopenharmony_ci }
2727be168c0dSopenharmony_ci+
2728be168c0dSopenharmony_ci+NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num) {
2729be168c0dSopenharmony_ci+  if (num == nullptr) {
2730be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Input num is null";
2731be168c0dSopenharmony_ci+    return nullptr;
2732be168c0dSopenharmony_ci+  }
2733be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT
2734be168c0dSopenharmony_ci+  *num = 0;
2735be168c0dSopenharmony_ci+
2736be168c0dSopenharmony_ci+  const size_t *all_device_ids;
2737be168c0dSopenharmony_ci+  uint32_t device_count;
2738be168c0dSopenharmony_ci+  auto ret = OH_NNDevice_GetAllDevicesID(&all_device_ids, &device_count);
2739be168c0dSopenharmony_ci+  if ((ret != OH_NN_SUCCESS) || (device_count == 0)) {
2740be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNRT get all device id failed, ret: " << ret;
2741be168c0dSopenharmony_ci+    return nullptr;
2742be168c0dSopenharmony_ci+  }
2743be168c0dSopenharmony_ci+
2744be168c0dSopenharmony_ci+  NNRTDeviceDesc *desc = (NNRTDeviceDesc *)malloc(sizeof(NNRTDeviceDesc) * device_count);
2745be168c0dSopenharmony_ci+  if (desc == nullptr) {
2746be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNRT allocate desc failed";
2747be168c0dSopenharmony_ci+    return nullptr;
2748be168c0dSopenharmony_ci+  }
2749be168c0dSopenharmony_ci+
2750be168c0dSopenharmony_ci+  for (uint32_t i = 0; i < device_count; i++) {
2751be168c0dSopenharmony_ci+    desc[i].device_id = all_device_ids[i];
2752be168c0dSopenharmony_ci+    OH_NN_DeviceType type;
2753be168c0dSopenharmony_ci+    (void)OH_NNDevice_GetType(all_device_ids[i], &type);
2754be168c0dSopenharmony_ci+    desc[i].device_type = static_cast<OH_AI_NNRTDeviceType>(type);
2755be168c0dSopenharmony_ci+
2756be168c0dSopenharmony_ci+    const char *name = nullptr;
2757be168c0dSopenharmony_ci+    (void)OH_NNDevice_GetName(all_device_ids[i], &name);
2758be168c0dSopenharmony_ci+    desc[i].device_name[127] = '\0';
2759be168c0dSopenharmony_ci+    strncpy(desc[i].device_name, name, 127);
2760be168c0dSopenharmony_ci+  }
2761be168c0dSopenharmony_ci+  *num = device_count;
2762be168c0dSopenharmony_ci+  return desc;
2763be168c0dSopenharmony_ci+#else
2764be168c0dSopenharmony_ci+  return nullptr;
2765be168c0dSopenharmony_ci+#endif
2766be168c0dSopenharmony_ci+}
2767be168c0dSopenharmony_ci+
2768be168c0dSopenharmony_ci+NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index) {
2769be168c0dSopenharmony_ci+  if (descs == nullptr) {
2770be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "descs is null";
2771be168c0dSopenharmony_ci+    return nullptr;
2772be168c0dSopenharmony_ci+  }
2773be168c0dSopenharmony_ci+  return descs + index;
2774be168c0dSopenharmony_ci+}
2775be168c0dSopenharmony_ci+
2776be168c0dSopenharmony_ci+void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc) {
2777be168c0dSopenharmony_ci+  if (desc == nullptr) {
2778be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "desc is null";
2779be168c0dSopenharmony_ci+    return;
2780be168c0dSopenharmony_ci+  }
2781be168c0dSopenharmony_ci+  free(*desc);
2782be168c0dSopenharmony_ci+  *desc = nullptr;
2783be168c0dSopenharmony_ci+}
2784be168c0dSopenharmony_ci+
2785be168c0dSopenharmony_ci+size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2786be168c0dSopenharmony_ci+  if (desc == nullptr) {
2787be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNRT desc is null";
2788be168c0dSopenharmony_ci+    return 0;
2789be168c0dSopenharmony_ci+  }
2790be168c0dSopenharmony_ci+  return desc->device_id;
2791be168c0dSopenharmony_ci+}
2792be168c0dSopenharmony_ci+
2793be168c0dSopenharmony_ci+const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2794be168c0dSopenharmony_ci+  if (desc == nullptr) {
2795be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNRT desc is null";
2796be168c0dSopenharmony_ci+    return nullptr;
2797be168c0dSopenharmony_ci+  }
2798be168c0dSopenharmony_ci+  return desc->device_name;
2799be168c0dSopenharmony_ci+}
2800be168c0dSopenharmony_ci+
2801be168c0dSopenharmony_ci+OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2802be168c0dSopenharmony_ci+  if (desc == nullptr) {
2803be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNRT desc is null";
2804be168c0dSopenharmony_ci+    return OH_AI_NNRTDeviceType::OH_AI_NNRTDEVICE_OTHERS;
2805be168c0dSopenharmony_ci+  }
2806be168c0dSopenharmony_ci+  return desc->device_type;
2807be168c0dSopenharmony_ci+}
2808be168c0dSopenharmony_ci+
2809be168c0dSopenharmony_ci+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name) {
2810be168c0dSopenharmony_ci+  size_t num = 0;
2811be168c0dSopenharmony_ci+  NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
2812be168c0dSopenharmony_ci+  if (desc == nullptr) {
2813be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get all device desc failed";
2814be168c0dSopenharmony_ci+    return nullptr;
2815be168c0dSopenharmony_ci+  }
2816be168c0dSopenharmony_ci+
2817be168c0dSopenharmony_ci+  OH_AI_DeviceInfoHandle handle = nullptr;
2818be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
2819be168c0dSopenharmony_ci+    if (strncmp(desc[i].device_name, name, NNRT_DEVICE_NAME_MAX - 1) == 0) {
2820be168c0dSopenharmony_ci+      handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
2821be168c0dSopenharmony_ci+      OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id);
2822be168c0dSopenharmony_ci+      break;
2823be168c0dSopenharmony_ci+    }
2824be168c0dSopenharmony_ci+  }
2825be168c0dSopenharmony_ci+  OH_AI_DestroyAllNNRTDeviceDescs(&desc);
2826be168c0dSopenharmony_ci+  return handle;
2827be168c0dSopenharmony_ci+}
2828be168c0dSopenharmony_ci+
2829be168c0dSopenharmony_ci+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type) {
2830be168c0dSopenharmony_ci+  size_t num = 0;
2831be168c0dSopenharmony_ci+  NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
2832be168c0dSopenharmony_ci+  if (desc == nullptr) {
2833be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get all device desc failed";
2834be168c0dSopenharmony_ci+    return nullptr;
2835be168c0dSopenharmony_ci+  }
2836be168c0dSopenharmony_ci+
2837be168c0dSopenharmony_ci+  OH_AI_DeviceInfoHandle handle = nullptr;
2838be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
2839be168c0dSopenharmony_ci+    if (desc[i].device_type == type) {
2840be168c0dSopenharmony_ci+      handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
2841be168c0dSopenharmony_ci+      OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id);
2842be168c0dSopenharmony_ci+      break;
2843be168c0dSopenharmony_ci+    }
2844be168c0dSopenharmony_ci+  }
2845be168c0dSopenharmony_ci+  OH_AI_DestroyAllNNRTDeviceDescs(&desc);
2846be168c0dSopenharmony_ci+  return handle;
2847be168c0dSopenharmony_ci+}
2848be168c0dSopenharmony_ci+
2849be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id) {
2850be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2851be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2852be168c0dSopenharmony_ci+    return;
2853be168c0dSopenharmony_ci+  }
2854be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2855be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Set device_id of non-NNRT device is not allowable, ignored";
2856be168c0dSopenharmony_ci+    return;
2857be168c0dSopenharmony_ci+  }
2858be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2859be168c0dSopenharmony_ci+  impl->SetDeviceID(device_id);
2860be168c0dSopenharmony_ci+}
2861be168c0dSopenharmony_ci+
2862be168c0dSopenharmony_ci+size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info) {
2863be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2864be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2865be168c0dSopenharmony_ci+    return 0;
2866be168c0dSopenharmony_ci+  }
2867be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2868be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get device_id of non-NNRT device is not allowable, ignored";
2869be168c0dSopenharmony_ci+    return 0;
2870be168c0dSopenharmony_ci+  }
2871be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2872be168c0dSopenharmony_ci+  return impl->GetDeviceID();
2873be168c0dSopenharmony_ci+}
2874be168c0dSopenharmony_ci+
2875be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode) {
2876be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2877be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2878be168c0dSopenharmony_ci+    return;
2879be168c0dSopenharmony_ci+  }
2880be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2881be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Set performance_mode of non-NNRT device is not allowable, ignored";
2882be168c0dSopenharmony_ci+    return;
2883be168c0dSopenharmony_ci+  }
2884be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2885be168c0dSopenharmony_ci+  impl->SetPerformanceMode(mode);
2886be168c0dSopenharmony_ci+}
2887be168c0dSopenharmony_ci+
2888be168c0dSopenharmony_ci+OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info) {
2889be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2890be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2891be168c0dSopenharmony_ci+    return OH_AI_PERFORMANCE_NONE;
2892be168c0dSopenharmony_ci+  }
2893be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2894be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get performance_mode of non-NNRT device is not allowable, ignored";
2895be168c0dSopenharmony_ci+    return OH_AI_PERFORMANCE_NONE;
2896be168c0dSopenharmony_ci+  }
2897be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2898be168c0dSopenharmony_ci+  return static_cast<OH_AI_PerformanceMode>(impl->GetPerformanceMode());
2899be168c0dSopenharmony_ci+}
2900be168c0dSopenharmony_ci+
2901be168c0dSopenharmony_ci+void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority) {
2902be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2903be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2904be168c0dSopenharmony_ci+    return;
2905be168c0dSopenharmony_ci+  }
2906be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2907be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Set priority of non-NNRT device is not allowable, ignored";
2908be168c0dSopenharmony_ci+    return;
2909be168c0dSopenharmony_ci+  }
2910be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2911be168c0dSopenharmony_ci+  impl->SetPriority(priority);
2912be168c0dSopenharmony_ci+}
2913be168c0dSopenharmony_ci+
2914be168c0dSopenharmony_ci+OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info) {
2915be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2916be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2917be168c0dSopenharmony_ci+    return OH_AI_PRIORITY_NONE;
2918be168c0dSopenharmony_ci+  }
2919be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2920be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get priority of non-NNRT device is not allowable, ignored";
2921be168c0dSopenharmony_ci+    return OH_AI_PRIORITY_NONE;
2922be168c0dSopenharmony_ci+  }
2923be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2924be168c0dSopenharmony_ci+  return static_cast<OH_AI_Priority>(impl->GetPriority());
2925be168c0dSopenharmony_ci+}
2926be168c0dSopenharmony_ci+
2927be168c0dSopenharmony_ci+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info,
2928be168c0dSopenharmony_ci+                                                    const char *name, const char*value, size_t value_size) {
2929be168c0dSopenharmony_ci+  if (device_info == nullptr) {
2930be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "device info is null";
2931be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_NULLPTR;
2932be168c0dSopenharmony_ci+  }
2933be168c0dSopenharmony_ci+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
2934be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Add extension to non-NNRT device is not allowable, ignored";
2935be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_ERROR;
2936be168c0dSopenharmony_ci+  }
2937be168c0dSopenharmony_ci+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2938be168c0dSopenharmony_ci+  mindspore::Extension extension;
2939be168c0dSopenharmony_ci+  extension.name = std::string(name);
2940be168c0dSopenharmony_ci+  extension.value = std::vector<uint8_t>(value, value + value_size);
2941be168c0dSopenharmony_ci+  std::vector<mindspore::Extension> extension_list = impl->GetExtensions();
2942be168c0dSopenharmony_ci+  extension_list.push_back(extension);
2943be168c0dSopenharmony_ci+  impl->SetExtensions(extension_list);
2944be168c0dSopenharmony_ci+  return OH_AI_STATUS_SUCCESS;
2945be168c0dSopenharmony_ci+}
2946be168c0dSopenharmony_ci\ No newline at end of file
2947be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h
2948be168c0dSopenharmony_ciindex 076f4d1f..dc88b8a4 100644
2949be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/context_c.h
2950be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/context_c.h
2951be168c0dSopenharmony_ci@@ -21,27 +21,4 @@
2952be168c0dSopenharmony_ci #include <memory>
2953be168c0dSopenharmony_ci #include "include/c_api/types_c.h"
2954be168c0dSopenharmony_ci 
2955be168c0dSopenharmony_ci-namespace mindspore {
2956be168c0dSopenharmony_ci-class Allocator;
2957be168c0dSopenharmony_ci-class Delegate;
2958be168c0dSopenharmony_ci-
2959be168c0dSopenharmony_ci-typedef struct DeviceInfoC {
2960be168c0dSopenharmony_ci-  OH_AI_DeviceType device_type;
2961be168c0dSopenharmony_ci-  bool enable_fp16 = false;
2962be168c0dSopenharmony_ci-  int frequency = 3;
2963be168c0dSopenharmony_ci-  std::string provider;
2964be168c0dSopenharmony_ci-  std::string provider_device;
2965be168c0dSopenharmony_ci-  std::shared_ptr<Allocator> allocator = nullptr;
2966be168c0dSopenharmony_ci-} DeviceInfoC;
2967be168c0dSopenharmony_ci-
2968be168c0dSopenharmony_ci-typedef struct ContextC {
2969be168c0dSopenharmony_ci-  std::vector<std::shared_ptr<DeviceInfoC>> device_info_list;
2970be168c0dSopenharmony_ci-  int32_t thread_num = 2;
2971be168c0dSopenharmony_ci-  bool enable_parallel = false;
2972be168c0dSopenharmony_ci-  std::vector<int32_t> affinity_core_list;
2973be168c0dSopenharmony_ci-  int affinity_mode = 0;
2974be168c0dSopenharmony_ci-  int delegate_mode = 0;
2975be168c0dSopenharmony_ci-  std::shared_ptr<Delegate> delegate = nullptr;
2976be168c0dSopenharmony_ci-} ContextC;
2977be168c0dSopenharmony_ci-}  // namespace mindspore
2978be168c0dSopenharmony_ci #endif  // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_
2979be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
2980be168c0dSopenharmony_ciindex 802df6b1..9da52d76 100644
2981be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/model_c.cc
2982be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/model_c.cc
2983be168c0dSopenharmony_ci@@ -17,321 +17,135 @@
2984be168c0dSopenharmony_ci #include <vector>
2985be168c0dSopenharmony_ci #include <cstdint>
2986be168c0dSopenharmony_ci #include "include/api/context.h"
2987be168c0dSopenharmony_ci+#include <include/api/serialization.h>
2988be168c0dSopenharmony_ci #include "include/api/types.h"
2989be168c0dSopenharmony_ci #include "src/litert/cxx_api/tensor/tensor_impl.h"
2990be168c0dSopenharmony_ci #include "src/litert/cxx_api/converters.h"
2991be168c0dSopenharmony_ci-#include "src/litert/lite_session.h"
2992be168c0dSopenharmony_ci-#include "src/litert/cpu_info.h"
2993be168c0dSopenharmony_ci+#include "src/litert//cxx_api/model/model_impl.h"
2994be168c0dSopenharmony_ci 
2995be168c0dSopenharmony_ci namespace mindspore {
2996be168c0dSopenharmony_ci class ModelC {
2997be168c0dSopenharmony_ci- public:
2998be168c0dSopenharmony_ci-  ModelC() : session_(nullptr), context_(nullptr) {}
2999be168c0dSopenharmony_ci+public:
3000be168c0dSopenharmony_ci+  ModelC() : model_(nullptr) {}
3001be168c0dSopenharmony_ci   ~ModelC() {
3002be168c0dSopenharmony_ci-    for (auto &impl : tensor_map_) {
3003be168c0dSopenharmony_ci-      delete impl.second;
3004be168c0dSopenharmony_ci+    for (auto in : inputs_) {
3005be168c0dSopenharmony_ci+      delete in;
3006be168c0dSopenharmony_ci+    }
3007be168c0dSopenharmony_ci+    for (auto out : outputs_) {
3008be168c0dSopenharmony_ci+      delete out;
3009be168c0dSopenharmony_ci+    }
3010be168c0dSopenharmony_ci+    for (auto out : outputs_train_) {
3011be168c0dSopenharmony_ci+      delete out;
3012be168c0dSopenharmony_ci     }
3013be168c0dSopenharmony_ci   }
3014be168c0dSopenharmony_ci 
3015be168c0dSopenharmony_ci-  Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
3016be168c0dSopenharmony_ci-  Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
3017be168c0dSopenharmony_ci-  Status Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);
3018be168c0dSopenharmony_ci-
3019be168c0dSopenharmony_ci-  Status Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, size_t *output_num,
3020be168c0dSopenharmony_ci-                 const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after);
3021be168c0dSopenharmony_ci-
3022be168c0dSopenharmony_ci-  LiteTensorImpl **GetInputs(size_t *input_num);
3023be168c0dSopenharmony_ci-  LiteTensorImpl **GetOutputs(size_t *output_num);
3024be168c0dSopenharmony_ci+  MSTensor **GetInputs(size_t *input_num);
3025be168c0dSopenharmony_ci+  MSTensor **GetOutputs(size_t *output_num);
3026be168c0dSopenharmony_ci+  mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &oh_callback);
3027be168c0dSopenharmony_ci+  std::shared_ptr<Model> model_;
3028be168c0dSopenharmony_ci+  std::shared_ptr<Context> context_;
3029be168c0dSopenharmony_ci 
3030be168c0dSopenharmony_ci- private:
3031be168c0dSopenharmony_ci-  Status RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after);
3032be168c0dSopenharmony_ci-  void ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors);
3033be168c0dSopenharmony_ci-  LiteTensorImpl *TensorToTensorImpl(mindspore::lite::Tensor *tensor);
3034be168c0dSopenharmony_ci-
3035be168c0dSopenharmony_ci- private:
3036be168c0dSopenharmony_ci-  std::shared_ptr<lite::LiteSession> session_ = nullptr;
3037be168c0dSopenharmony_ci-  std::shared_ptr<const ContextC> context_ = nullptr;
3038be168c0dSopenharmony_ci-  std::map<mindspore::lite::Tensor *, LiteTensorImpl *> tensor_map_;
3039be168c0dSopenharmony_ci-  std::vector<LiteTensorImpl *> inputs_;
3040be168c0dSopenharmony_ci-  std::vector<LiteTensorImpl *> outputs_;
3041be168c0dSopenharmony_ci-  bool is_already_built = false;
3042be168c0dSopenharmony_ci+private:
3043be168c0dSopenharmony_ci+  MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
3044be168c0dSopenharmony_ci+  std::vector<MSTensor *> inputs_;
3045be168c0dSopenharmony_ci+  std::vector<MSTensor *> outputs_;
3046be168c0dSopenharmony_ci+  std::vector<MSTensor *> outputs_train_;
3047be168c0dSopenharmony_ci };
3048be168c0dSopenharmony_ci 
3049be168c0dSopenharmony_ci-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
3050be168c0dSopenharmony_ci-  if (is_already_built) {
3051be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "The model is already built.";
3052be168c0dSopenharmony_ci-    return kLiteModelRebuild;
3053be168c0dSopenharmony_ci-  }
3054be168c0dSopenharmony_ci-  if (!PlatformInstructionSetSupportCheck()) {
3055be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "The platform exist don't support's instruction.";
3056be168c0dSopenharmony_ci-    return kLiteNotSupport;
3057be168c0dSopenharmony_ci-  }
3058be168c0dSopenharmony_ci-  if(context_.get() != model_context){
3059be168c0dSopenharmony_ci-    context_.reset(model_context);
3060be168c0dSopenharmony_ci-  }
3061be168c0dSopenharmony_ci-  session_ = std::make_shared<lite::LiteSession>();
3062be168c0dSopenharmony_ci-  if (session_ == nullptr) {
3063be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "create session failed";
3064be168c0dSopenharmony_ci-    return kLiteNullptr;
3065be168c0dSopenharmony_ci-  }
3066be168c0dSopenharmony_ci-  auto ret = session_->Init(ContextUtils::Convert(model_context));
3067be168c0dSopenharmony_ci-  if (ret != mindspore::lite::RET_OK) {
3068be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "init session failed";
3069be168c0dSopenharmony_ci-    return static_cast<StatusCode>(ret);
3070be168c0dSopenharmony_ci-  }
3071be168c0dSopenharmony_ci-  ret = session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size);
3072be168c0dSopenharmony_ci-  if (ret != RET_OK) {
3073be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Load and compile failed";
3074be168c0dSopenharmony_ci-    return static_cast<StatusCode>(ret);
3075be168c0dSopenharmony_ci-  }
3076be168c0dSopenharmony_ci-  is_already_built = true;
3077be168c0dSopenharmony_ci-  return static_cast<StatusCode>(kSuccess);
3078be168c0dSopenharmony_ci-}
3079be168c0dSopenharmony_ci-
3080be168c0dSopenharmony_ci-Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
3081be168c0dSopenharmony_ci-  if (is_already_built) {
3082be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "The model is already built.";
3083be168c0dSopenharmony_ci-    return kLiteModelRebuild;
3084be168c0dSopenharmony_ci-  }
3085be168c0dSopenharmony_ci-  if (!PlatformInstructionSetSupportCheck()) {
3086be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "The platform exist don't support's instruction.";
3087be168c0dSopenharmony_ci-    return kLiteNotSupport;
3088be168c0dSopenharmony_ci-  }
3089be168c0dSopenharmony_ci-  if(context_.get() != model_context){
3090be168c0dSopenharmony_ci-    context_.reset(model_context);
3091be168c0dSopenharmony_ci-  }
3092be168c0dSopenharmony_ci-  session_ = std::make_shared<lite::LiteSession>();
3093be168c0dSopenharmony_ci-  if (session_ == nullptr) {
3094be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "create session failed";
3095be168c0dSopenharmony_ci-    return kLiteNullptr;
3096be168c0dSopenharmony_ci-  }
3097be168c0dSopenharmony_ci-  auto ret = session_->Init(ContextUtils::Convert(model_context));
3098be168c0dSopenharmony_ci-  if (ret != mindspore::lite::RET_OK) {
3099be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "init session failed";
3100be168c0dSopenharmony_ci-    return static_cast<StatusCode>(ret);
3101be168c0dSopenharmony_ci+MSTensor **ModelC::GetInputs(size_t *input_num) {
3102be168c0dSopenharmony_ci+  if (model_ == nullptr) {
3103be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_ is nullptr.";
3104be168c0dSopenharmony_ci+    return nullptr;
3105be168c0dSopenharmony_ci   }
3106be168c0dSopenharmony_ci-  ret = session_->LoadModelAndCompileByPath(model_path, model_type);
3107be168c0dSopenharmony_ci-  if (ret != RET_OK) {
3108be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Load and compile failed";
3109be168c0dSopenharmony_ci-    return static_cast<StatusCode>(ret);
3110be168c0dSopenharmony_ci+  if (!inputs_.empty()) {
3111be168c0dSopenharmony_ci+    *input_num = inputs_.size();
3112be168c0dSopenharmony_ci+    return inputs_.data();
3113be168c0dSopenharmony_ci   }
3114be168c0dSopenharmony_ci-  is_already_built = true;
3115be168c0dSopenharmony_ci-  return static_cast<StatusCode>(kSuccess);
3116be168c0dSopenharmony_ci-}
3117be168c0dSopenharmony_ci 
3118be168c0dSopenharmony_ci-Status ModelC::Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) {
3119be168c0dSopenharmony_ci-  std::vector<lite::Tensor *> inner_input;
3120be168c0dSopenharmony_ci-  size_t input_num = inputs.size();
3121be168c0dSopenharmony_ci-  for (size_t i = 0; i < input_num; i++) {
3122be168c0dSopenharmony_ci-    auto input = inputs[i];
3123be168c0dSopenharmony_ci-    if (input == nullptr || input->lite_tensor() == nullptr) {
3124be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Input tensor is null.";
3125be168c0dSopenharmony_ci-      return kLiteInputTensorError;
3126be168c0dSopenharmony_ci+  auto inputs = model_->GetInputs();
3127be168c0dSopenharmony_ci+  *input_num = inputs.size();
3128be168c0dSopenharmony_ci+  inputs_.resize(inputs.size(), nullptr);
3129be168c0dSopenharmony_ci+  for (size_t i = 0; i < inputs.size(); i++) {
3130be168c0dSopenharmony_ci+    inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
3131be168c0dSopenharmony_ci+    if (inputs_[i] == nullptr) {
3132be168c0dSopenharmony_ci+      inputs_.clear();
3133be168c0dSopenharmony_ci+      return nullptr;
3134be168c0dSopenharmony_ci     }
3135be168c0dSopenharmony_ci-    inner_input.push_back(input->lite_tensor());
3136be168c0dSopenharmony_ci   }
3137be168c0dSopenharmony_ci-  size_t shape_num = shapes.size();
3138be168c0dSopenharmony_ci-  std::vector<std::vector<int32_t>> inner_shapes(shape_num);
3139be168c0dSopenharmony_ci-  for (size_t i = 0; i < shape_num; i++) {
3140be168c0dSopenharmony_ci-    std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(inner_shapes[i]),
3141be168c0dSopenharmony_ci-                   [](int64_t value) { return static_cast<int32_t>(value); });
3142be168c0dSopenharmony_ci-  }
3143be168c0dSopenharmony_ci-  if (session_ == nullptr) {
3144be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Session implement is null.";
3145be168c0dSopenharmony_ci-    return kLiteNullptr;
3146be168c0dSopenharmony_ci-  }
3147be168c0dSopenharmony_ci-  auto ret = session_->Resize(inner_input, inner_shapes);
3148be168c0dSopenharmony_ci-  return static_cast<StatusCode>(ret);
3149be168c0dSopenharmony_ci+  return inputs_.data();
3150be168c0dSopenharmony_ci }
3151be168c0dSopenharmony_ci 
3152be168c0dSopenharmony_ci-void ModelC::ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors) {
3153be168c0dSopenharmony_ci-  for (size_t j = 0; j < old_data.size(); j++) {
3154be168c0dSopenharmony_ci-    tensors.at(j)->set_data(old_data.at(j));
3155be168c0dSopenharmony_ci+MSTensor **ModelC::GetOutputs(size_t *output_num) {
3156be168c0dSopenharmony_ci+  if (model_->GetTrainMode() == true) {
3157be168c0dSopenharmony_ci+    return GetOutputsTensor(output_num, &outputs_train_);
3158be168c0dSopenharmony_ci+  } else {
3159be168c0dSopenharmony_ci+    return GetOutputsTensor(output_num, &outputs_);
3160be168c0dSopenharmony_ci   }
3161be168c0dSopenharmony_ci }
3162be168c0dSopenharmony_ci 
3163be168c0dSopenharmony_ci-Status ModelC::Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs,
3164be168c0dSopenharmony_ci-                       size_t *output_num, const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) {
3165be168c0dSopenharmony_ci-  if (outputs == nullptr || session_ == nullptr) {
3166be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3167be168c0dSopenharmony_ci-    return kLiteError;
3168be168c0dSopenharmony_ci+MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
3169be168c0dSopenharmony_ci+  if (model_ == nullptr) {
3170be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_ is nullptr.";
3171be168c0dSopenharmony_ci+    return nullptr;
3172be168c0dSopenharmony_ci   }
3173be168c0dSopenharmony_ci-  auto model_inputs = session_->GetInputs();
3174be168c0dSopenharmony_ci-  if (model_inputs.size() != input_num) {
3175be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Wrong input size.";
3176be168c0dSopenharmony_ci-    return kLiteError;
3177be168c0dSopenharmony_ci+  if (!vec_tensors->empty()) {
3178be168c0dSopenharmony_ci+    *output_num = vec_tensors->size();
3179be168c0dSopenharmony_ci+    return vec_tensors->data();
3180be168c0dSopenharmony_ci   }
3181be168c0dSopenharmony_ci-  std::vector<void *> old_data;
3182be168c0dSopenharmony_ci-  for (size_t i = 0; i < input_num; i++) {
3183be168c0dSopenharmony_ci-    auto real_input = model_inputs[i];
3184be168c0dSopenharmony_ci-    auto user_input = static_cast<LiteTensorImpl *>(inputs[i]);
3185be168c0dSopenharmony_ci-    if (user_input->DataType() != static_cast<DataType>(real_input->data_type())) {
3186be168c0dSopenharmony_ci-      ResetTensorData(old_data, model_inputs);
3187be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "DataType does not match, input:" << user_input->Name()
3188be168c0dSopenharmony_ci-                    << ", real:" << real_input->tensor_name();
3189be168c0dSopenharmony_ci-      return kLiteInputTensorError;
3190be168c0dSopenharmony_ci-    }
3191be168c0dSopenharmony_ci-    if (user_input->Data() == nullptr) {
3192be168c0dSopenharmony_ci-      ResetTensorData(old_data, model_inputs);
3193be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has no data.";
3194be168c0dSopenharmony_ci-      return kLiteInputTensorError;
3195be168c0dSopenharmony_ci-    }
3196be168c0dSopenharmony_ci 
3197be168c0dSopenharmony_ci-    // GPU tensor can't manipulate CPU memory which the user provides.
3198be168c0dSopenharmony_ci-    // When model input is GPU tensor and user input is NOT GPU data,
3199be168c0dSopenharmony_ci-    // just free model input's data for late GPU Tensor filling.
3200be168c0dSopenharmony_ci-    if (IS_OPENCL_ALLOCATOR(real_input->allocator()) && (!IS_OPENCL_ALLOCATOR(user_input->GetAllocator()))) {
3201be168c0dSopenharmony_ci-      real_input->FreeData();
3202be168c0dSopenharmony_ci-    }
3203be168c0dSopenharmony_ci-    old_data.push_back(real_input->data());  // Save original data in model tensors.
3204be168c0dSopenharmony_ci-
3205be168c0dSopenharmony_ci-    if (real_input->data_type() == kObjectTypeString) {
3206be168c0dSopenharmony_ci-      std::vector<int32_t> shape;
3207be168c0dSopenharmony_ci-      std::transform(user_input->Shape().begin(), user_input->Shape().end(), std::back_inserter(shape),
3208be168c0dSopenharmony_ci-                     [](int64_t value) { return static_cast<int32_t>(value); });
3209be168c0dSopenharmony_ci-      real_input->set_shape(shape);
3210be168c0dSopenharmony_ci-      real_input->set_data(user_input->MutableData());
3211be168c0dSopenharmony_ci-    } else {
3212be168c0dSopenharmony_ci-      if (user_input->MutableData() != real_input->data()) {
3213be168c0dSopenharmony_ci-        if (real_input->Size() != user_input->DataSize()) {
3214be168c0dSopenharmony_ci-          ResetTensorData(old_data, model_inputs);
3215be168c0dSopenharmony_ci-          MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has wrong data size.";
3216be168c0dSopenharmony_ci-          return kLiteInputTensorError;
3217be168c0dSopenharmony_ci-        }
3218be168c0dSopenharmony_ci-        if (!IS_OPENCL_ALLOCATOR(real_input->allocator())) {
3219be168c0dSopenharmony_ci-          real_input->set_data(user_input->MutableData());
3220be168c0dSopenharmony_ci-        } else {
3221be168c0dSopenharmony_ci-          // Use outside CPU data to fill GPU Tensor.
3222be168c0dSopenharmony_ci-          auto dst_data = real_input->MutableData();
3223be168c0dSopenharmony_ci-          auto src_data = user_input->MutableData();
3224be168c0dSopenharmony_ci-          (void)memcpy(dst_data, src_data, real_input->Size());
3225be168c0dSopenharmony_ci-        }
3226be168c0dSopenharmony_ci-      }
3227be168c0dSopenharmony_ci-    }
3228be168c0dSopenharmony_ci-  }
3229be168c0dSopenharmony_ci-  auto ret = RunGraph(before, after);
3230be168c0dSopenharmony_ci-  ResetTensorData(old_data, model_inputs);
3231be168c0dSopenharmony_ci-  if (ret != kSuccess) {
3232be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Run graph failed.";
3233be168c0dSopenharmony_ci-    return ret;
3234be168c0dSopenharmony_ci-  }
3235be168c0dSopenharmony_ci-
3236be168c0dSopenharmony_ci-  *outputs = reinterpret_cast<OH_AI_TensorHandle *>(GetOutputs(output_num));
3237be168c0dSopenharmony_ci-  return kSuccess;
3238be168c0dSopenharmony_ci-}
3239be168c0dSopenharmony_ci-
3240be168c0dSopenharmony_ci-Status ModelC::RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) {
3241be168c0dSopenharmony_ci-  KernelCallBack before_call_back = nullptr;
3242be168c0dSopenharmony_ci-  KernelCallBack after_call_back = nullptr;
3243be168c0dSopenharmony_ci-  if (before != nullptr) {
3244be168c0dSopenharmony_ci-    before_call_back = [&](const std::vector<mindspore::lite::Tensor *> &before_inputs,
3245be168c0dSopenharmony_ci-                           const std::vector<mindspore::lite::Tensor *> &before_outputs,
3246be168c0dSopenharmony_ci-                           const MSCallBackParam &call_param) {
3247be168c0dSopenharmony_ci-      std::vector<LiteTensorImpl> inputs_impl;
3248be168c0dSopenharmony_ci-      std::vector<LiteTensorImpl> outputs_impl;
3249be168c0dSopenharmony_ci-      std::vector<OH_AI_TensorHandle> op_inputs;
3250be168c0dSopenharmony_ci-      std::vector<OH_AI_TensorHandle> op_outputs;
3251be168c0dSopenharmony_ci-    size_t op_input_num = before_inputs.size();
3252be168c0dSopenharmony_ci-    for (size_t i = 0; i < op_input_num; i++) {
3253be168c0dSopenharmony_ci-      inputs_impl.emplace_back(before_inputs[i]);
3254be168c0dSopenharmony_ci-      op_inputs.push_back(&(inputs_impl.back()));
3255be168c0dSopenharmony_ci-    }
3256be168c0dSopenharmony_ci-    size_t op_output_num = before_outputs.size();
3257be168c0dSopenharmony_ci-    for (size_t i = 0; i < op_output_num; i++) {
3258be168c0dSopenharmony_ci-      outputs_impl.emplace_back(before_outputs[i]);
3259be168c0dSopenharmony_ci-      op_outputs.push_back(&(outputs_impl.back()));
3260be168c0dSopenharmony_ci-    }
3261be168c0dSopenharmony_ci-      const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()),
3262be168c0dSopenharmony_ci-                                         const_cast<char *>(call_param.node_type.c_str())};
3263be168c0dSopenharmony_ci-      OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()};
3264be168c0dSopenharmony_ci-      OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()};
3265be168c0dSopenharmony_ci-    return before(inputs, outputs, op_info);
3266be168c0dSopenharmony_ci-  };
3267be168c0dSopenharmony_ci-  }
3268be168c0dSopenharmony_ci-  if (after != nullptr) {
3269be168c0dSopenharmony_ci-    after_call_back = [&](const std::vector<mindspore::lite::Tensor *> &after_inputs,
3270be168c0dSopenharmony_ci-                          const std::vector<mindspore::lite::Tensor *> &after_outputs,
3271be168c0dSopenharmony_ci-                          const MSCallBackParam &call_param) {
3272be168c0dSopenharmony_ci-      std::vector<LiteTensorImpl> inputs_impl;
3273be168c0dSopenharmony_ci-      std::vector<LiteTensorImpl> outputs_impl;
3274be168c0dSopenharmony_ci-      std::vector<OH_AI_TensorHandle> op_inputs;
3275be168c0dSopenharmony_ci-      std::vector<OH_AI_TensorHandle> op_outputs;
3276be168c0dSopenharmony_ci-    size_t op_input_num = after_inputs.size();
3277be168c0dSopenharmony_ci-    for (size_t i = 0; i < op_input_num; i++) {
3278be168c0dSopenharmony_ci-      inputs_impl.emplace_back(after_inputs[i]);
3279be168c0dSopenharmony_ci-      op_inputs.push_back(&(inputs_impl.back()));
3280be168c0dSopenharmony_ci-    }
3281be168c0dSopenharmony_ci-    size_t op_output_num = after_outputs.size();
3282be168c0dSopenharmony_ci-    for (size_t i = 0; i < op_output_num; i++) {
3283be168c0dSopenharmony_ci-      outputs_impl.emplace_back(after_outputs[i]);
3284be168c0dSopenharmony_ci-      op_outputs.push_back(&(outputs_impl.back()));
3285be168c0dSopenharmony_ci-    }
3286be168c0dSopenharmony_ci-    const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()),
3287be168c0dSopenharmony_ci-                                         const_cast<char *>(call_param.node_type.c_str())};
3288be168c0dSopenharmony_ci-    OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()};
3289be168c0dSopenharmony_ci-    OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()};
3290be168c0dSopenharmony_ci-    return after(inputs, outputs, op_info);
3291be168c0dSopenharmony_ci-  };
3292be168c0dSopenharmony_ci-  }
3293be168c0dSopenharmony_ci-  auto ret = session_->RunGraph(before_call_back, after_call_back);
3294be168c0dSopenharmony_ci-  return static_cast<StatusCode>(ret);
3295be168c0dSopenharmony_ci-}
3296be168c0dSopenharmony_ci-
3297be168c0dSopenharmony_ci-LiteTensorImpl *ModelC::TensorToTensorImpl(mindspore::lite::Tensor *tensor) {
3298be168c0dSopenharmony_ci-  LiteTensorImpl *impl = nullptr;
3299be168c0dSopenharmony_ci-  auto iter = tensor_map_.find(tensor);
3300be168c0dSopenharmony_ci-  if (iter != tensor_map_.end()) {
3301be168c0dSopenharmony_ci-    impl = iter->second;
3302be168c0dSopenharmony_ci-  } else {
3303be168c0dSopenharmony_ci-    impl = new (std::nothrow) LiteTensorImpl(tensor);
3304be168c0dSopenharmony_ci-    if (impl == nullptr || impl->lite_tensor() == nullptr) {
3305be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Create tensor failed.";
3306be168c0dSopenharmony_ci+  auto outputs = model_->GetOutputs();
3307be168c0dSopenharmony_ci+  *output_num = outputs.size();
3308be168c0dSopenharmony_ci+  vec_tensors->resize(outputs.size(), nullptr);
3309be168c0dSopenharmony_ci+  for (size_t i = 0; i < outputs.size(); i++) {
3310be168c0dSopenharmony_ci+    (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
3311be168c0dSopenharmony_ci+    if ((*vec_tensors)[i] == nullptr) {
3312be168c0dSopenharmony_ci+      vec_tensors->clear();
3313be168c0dSopenharmony_ci       return nullptr;
3314be168c0dSopenharmony_ci     }
3315be168c0dSopenharmony_ci-    tensor_map_[tensor] = impl;
3316be168c0dSopenharmony_ci   }
3317be168c0dSopenharmony_ci-  return impl;
3318be168c0dSopenharmony_ci+  return vec_tensors->data();
3319be168c0dSopenharmony_ci }
3320be168c0dSopenharmony_ci 
3321be168c0dSopenharmony_ci-LiteTensorImpl **ModelC::GetInputs(size_t *input_num) {
3322be168c0dSopenharmony_ci-  if (session_ == nullptr || input_num == nullptr) {
3323be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Session is null.";
3324be168c0dSopenharmony_ci-    return nullptr;
3325be168c0dSopenharmony_ci-  }
3326be168c0dSopenharmony_ci-  auto inputs = session_->GetInputs();
3327be168c0dSopenharmony_ci-  *input_num = inputs.size();
3328be168c0dSopenharmony_ci-  if (inputs_.capacity() < *input_num) {
3329be168c0dSopenharmony_ci-    inputs_.reserve(*input_num);
3330be168c0dSopenharmony_ci-  }
3331be168c0dSopenharmony_ci-  inputs_.clear();
3332be168c0dSopenharmony_ci-  std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_),
3333be168c0dSopenharmony_ci-                 [&](lite::Tensor *input) { return TensorToTensorImpl(input); });
3334be168c0dSopenharmony_ci-  return inputs_.data();
3335be168c0dSopenharmony_ci-}
3336be168c0dSopenharmony_ci+mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &oh_callback) {
3337be168c0dSopenharmony_ci+  mindspore::MSKernelCallBack call_back = nullptr;
3338be168c0dSopenharmony_ci+  if (oh_callback != nullptr) {
3339be168c0dSopenharmony_ci+    call_back = [&](const std::vector<mindspore::MSTensor> &inputs,
3340be168c0dSopenharmony_ci+                    const std::vector<mindspore::MSTensor> &outputs,
3341be168c0dSopenharmony_ci+                    const mindspore::MSCallBackParam &opInfo) {
3342be168c0dSopenharmony_ci+      std::vector<OH_AI_TensorHandle> vec_inputs;
3343be168c0dSopenharmony_ci+      std::vector<OH_AI_TensorHandle> vec_outputs;
3344be168c0dSopenharmony_ci+      OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()),
3345be168c0dSopenharmony_ci+                                       const_cast<char *>(opInfo.node_type.c_str())};
3346be168c0dSopenharmony_ci+      size_t inputs_handle_num = inputs.size();
3347be168c0dSopenharmony_ci+      for (size_t i = 0; i < inputs_handle_num; i++) {
3348be168c0dSopenharmony_ci+        vec_inputs.push_back(
3349be168c0dSopenharmony_ci+          static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(inputs)[i])));
3350be168c0dSopenharmony_ci+      }
3351be168c0dSopenharmony_ci+      size_t outputs_handle_num = inputs.size();
3352be168c0dSopenharmony_ci+      for (size_t i = 0; i < outputs_handle_num; i++) {
3353be168c0dSopenharmony_ci+        vec_outputs.push_back(
3354be168c0dSopenharmony_ci+          static_cast<OH_AI_TensorHandle>(&(static_cast<std::vector<mindspore::MSTensor>>(outputs)[i])));
3355be168c0dSopenharmony_ci+      }
3356be168c0dSopenharmony_ci 
3357be168c0dSopenharmony_ci-LiteTensorImpl **ModelC::GetOutputs(size_t *output_num) {
3358be168c0dSopenharmony_ci-  if (session_ == nullptr || output_num == nullptr) {
3359be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Session is null.";
3360be168c0dSopenharmony_ci-    return nullptr;
3361be168c0dSopenharmony_ci-  }
3362be168c0dSopenharmony_ci-  auto outputs = session_->GetOutputs();
3363be168c0dSopenharmony_ci-  *output_num = outputs.size();
3364be168c0dSopenharmony_ci-  if (outputs_.capacity() < *output_num) {
3365be168c0dSopenharmony_ci-    outputs_.reserve(*output_num);
3366be168c0dSopenharmony_ci+      OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
3367be168c0dSopenharmony_ci+      OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
3368be168c0dSopenharmony_ci+      return oh_callback(handle_inputs, handle_outputs, call_back);
3369be168c0dSopenharmony_ci+    };
3370be168c0dSopenharmony_ci   }
3371be168c0dSopenharmony_ci-  outputs_.clear();
3372be168c0dSopenharmony_ci-  std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_),
3373be168c0dSopenharmony_ci-                 [&](std::unordered_map<std::string, mindspore::lite::Tensor *>::value_type iter) {
3374be168c0dSopenharmony_ci-                   return TensorToTensorImpl(iter.second);
3375be168c0dSopenharmony_ci-                 });
3376be168c0dSopenharmony_ci-  return outputs_.data();
3377be168c0dSopenharmony_ci+  return call_back;
3378be168c0dSopenharmony_ci }
3379be168c0dSopenharmony_ci }  // namespace mindspore
3380be168c0dSopenharmony_ci 
3381be168c0dSopenharmony_ci OH_AI_ModelHandle OH_AI_ModelCreate() {
3382be168c0dSopenharmony_ci   auto impl = new (std::nothrow) mindspore::ModelC();
3383be168c0dSopenharmony_ci   if (impl == nullptr) {
3384be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Model implement is null.";
3385be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Model implement is nullptr.";
3386be168c0dSopenharmony_ci+    return nullptr;
3387be168c0dSopenharmony_ci+  }
3388be168c0dSopenharmony_ci+  impl->model_ = std::make_shared<mindspore::Model>();
3389be168c0dSopenharmony_ci+  if (impl->model_ == nullptr) {
3390be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_ is nullptr.";
3391be168c0dSopenharmony_ci+    delete impl;
3392be168c0dSopenharmony_ci     return nullptr;
3393be168c0dSopenharmony_ci   }
3394be168c0dSopenharmony_ci   return static_cast<OH_AI_ModelHandle>(impl);
3395be168c0dSopenharmony_ci@@ -358,55 +172,59 @@ size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
3396be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
3397be168c0dSopenharmony_ci                               OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) {
3398be168c0dSopenharmony_ci   if (model == nullptr || model_data == nullptr || model_context == nullptr) {
3399be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3400be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
3401be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_NULLPTR;
3402be168c0dSopenharmony_ci   }
3403be168c0dSopenharmony_ci   if (model_type == OH_AI_MODELTYPE_INVALID) {
3404be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is invalid.";
3405be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_type is invalid.";
3406be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_PARAM_INVALID;
3407be168c0dSopenharmony_ci   }
3408be168c0dSopenharmony_ci-  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
3409be168c0dSopenharmony_ci+  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
3410be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3411be168c0dSopenharmony_ci-  auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
3412be168c0dSopenharmony_ci+  if (impl->context_.get() != context) {
3413be168c0dSopenharmony_ci+    impl->context_.reset(context);
3414be168c0dSopenharmony_ci+  }
3415be168c0dSopenharmony_ci+  auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
3416be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
3417be168c0dSopenharmony_ci }
3418be168c0dSopenharmony_ci 
3419be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
3420be168c0dSopenharmony_ci                                       const OH_AI_ContextHandle model_context) {
3421be168c0dSopenharmony_ci   if (model == nullptr || model_path == nullptr || model_context == nullptr) {
3422be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3423be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
3424be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_NULLPTR;
3425be168c0dSopenharmony_ci   }
3426be168c0dSopenharmony_ci   if (model_type == OH_AI_MODELTYPE_INVALID) {
3427be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is invalid.";
3428be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_type is invalid.";
3429be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_PARAM_INVALID;
3430be168c0dSopenharmony_ci   }
3431be168c0dSopenharmony_ci-  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
3432be168c0dSopenharmony_ci+  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
3433be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3434be168c0dSopenharmony_ci-  auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
3435be168c0dSopenharmony_ci+  if (impl->context_.get() != context) {
3436be168c0dSopenharmony_ci+    impl->context_.reset(context);
3437be168c0dSopenharmony_ci+  }
3438be168c0dSopenharmony_ci+  auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
3439be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
3440be168c0dSopenharmony_ci }
3441be168c0dSopenharmony_ci 
3442be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
3443be168c0dSopenharmony_ci                                OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) {
3444be168c0dSopenharmony_ci   if (model == nullptr || shape_infos == nullptr) {
3445be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3446be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/shape_infos is nullptr.";
3447be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_NULLPTR;
3448be168c0dSopenharmony_ci   }
3449be168c0dSopenharmony_ci-  std::vector<mindspore::LiteTensorImpl *> vec_inputs;
3450be168c0dSopenharmony_ci-  std::transform(inputs.handle_list, inputs.handle_list + inputs.handle_num, std::back_inserter(vec_inputs),
3451be168c0dSopenharmony_ci-                 [](OH_AI_TensorHandle value) { return static_cast<mindspore::LiteTensorImpl *>(value); });
3452be168c0dSopenharmony_ci+  std::vector<mindspore::MSTensor> vec_inputs;
3453be168c0dSopenharmony_ci+  for (size_t i = 0; i < inputs.handle_num; ++i) {
3454be168c0dSopenharmony_ci+    vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
3455be168c0dSopenharmony_ci+  }
3456be168c0dSopenharmony_ci+
3457be168c0dSopenharmony_ci   std::vector<std::vector<int64_t>> vec_dims;
3458be168c0dSopenharmony_ci   for (size_t i = 0; i < shape_info_num; i++) {
3459be168c0dSopenharmony_ci     std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
3460be168c0dSopenharmony_ci-    if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
3461be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
3462be168c0dSopenharmony_ci-      return OH_AI_STATUS_LITE_PARAM_INVALID;
3463be168c0dSopenharmony_ci-    }
3464be168c0dSopenharmony_ci     vec_dims.push_back(shape);
3465be168c0dSopenharmony_ci   }
3466be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3467be168c0dSopenharmony_ci-  auto ret = impl->Resize(vec_inputs, vec_dims);
3468be168c0dSopenharmony_ci+  auto ret = impl->model_->Resize(vec_inputs, vec_dims);
3469be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
3470be168c0dSopenharmony_ci }
3471be168c0dSopenharmony_ci 
3472be168c0dSopenharmony_ci@@ -414,15 +232,25 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
3473be168c0dSopenharmony_ci                                 OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
3474be168c0dSopenharmony_ci                                 const OH_AI_KernelCallBack after) {
3475be168c0dSopenharmony_ci   if (model == nullptr) {
3476be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3477be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3478be168c0dSopenharmony_ci     return OH_AI_STATUS_LITE_NULLPTR;
3479be168c0dSopenharmony_ci   }
3480be168c0dSopenharmony_ci+  std::vector<mindspore::MSTensor> ms_tensor_inputs;
3481be168c0dSopenharmony_ci+  for (size_t i = 0; i < inputs.handle_num; i++) {
3482be168c0dSopenharmony_ci+    auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
3483be168c0dSopenharmony_ci+    ms_tensor_inputs.push_back(*user_input);
3484be168c0dSopenharmony_ci+  }
3485be168c0dSopenharmony_ci+
3486be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3487be168c0dSopenharmony_ci-  auto ret = impl->Predict(inputs.handle_list, inputs.handle_num, &(outputs->handle_list), &(outputs->handle_num),
3488be168c0dSopenharmony_ci-                           before, after);
3489be168c0dSopenharmony_ci+  mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
3490be168c0dSopenharmony_ci+  mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
3491be168c0dSopenharmony_ci+
3492be168c0dSopenharmony_ci+  std::vector<mindspore::MSTensor> ms_tensor_outputs;
3493be168c0dSopenharmony_ci+  auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
3494be168c0dSopenharmony_ci   if (!ret.IsOk()) {
3495be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Predict fail, ret :" << ret;
3496be168c0dSopenharmony_ci   }
3497be168c0dSopenharmony_ci+  outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&outputs->handle_num));
3498be168c0dSopenharmony_ci   return static_cast<OH_AI_Status>(ret.StatusCode());
3499be168c0dSopenharmony_ci }
3500be168c0dSopenharmony_ci 
3501be168c0dSopenharmony_ci@@ -431,11 +259,6 @@ OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallB
3502be168c0dSopenharmony_ci   return OH_AI_STATUS_LITE_NOT_SUPPORT;
3503be168c0dSopenharmony_ci }
3504be168c0dSopenharmony_ci 
3505be168c0dSopenharmony_ci-OH_AI_Status OH_AI_ModelSetTrainMode(const OH_AI_ModelHandle model, bool train) {
3506be168c0dSopenharmony_ci-  MS_LOG(ERROR) << "Unsupported Feature.";
3507be168c0dSopenharmony_ci-  return OH_AI_STATUS_LITE_NOT_SUPPORT;
3508be168c0dSopenharmony_ci-}
3509be168c0dSopenharmony_ci-
3510be168c0dSopenharmony_ci OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
3511be168c0dSopenharmony_ci   MS_LOG(ERROR) << "Unsupported Feature.";
3512be168c0dSopenharmony_ci   return OH_AI_STATUS_LITE_NOT_SUPPORT;
3513be168c0dSopenharmony_ci@@ -443,7 +266,7 @@ OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *
3514be168c0dSopenharmony_ci 
3515be168c0dSopenharmony_ci OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
3516be168c0dSopenharmony_ci   if (model == nullptr) {
3517be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3518be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3519be168c0dSopenharmony_ci     return {0, nullptr};
3520be168c0dSopenharmony_ci   }
3521be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3522be168c0dSopenharmony_ci@@ -454,7 +277,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
3523be168c0dSopenharmony_ci 
3524be168c0dSopenharmony_ci OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
3525be168c0dSopenharmony_ci   if (model == nullptr) {
3526be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3527be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3528be168c0dSopenharmony_ci     return {0, nullptr};
3529be168c0dSopenharmony_ci   }
3530be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3531be168c0dSopenharmony_ci@@ -465,7 +288,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
3532be168c0dSopenharmony_ci 
3533be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
3534be168c0dSopenharmony_ci   if (model == nullptr || tensor_name == nullptr) {
3535be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3536be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
3537be168c0dSopenharmony_ci     return nullptr;
3538be168c0dSopenharmony_ci   }
3539be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3540be168c0dSopenharmony_ci@@ -482,7 +305,7 @@ OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model
3541be168c0dSopenharmony_ci 
3542be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
3543be168c0dSopenharmony_ci   if (model == nullptr || tensor_name == nullptr) {
3544be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "param is nullptr.";
3545be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
3546be168c0dSopenharmony_ci     return nullptr;
3547be168c0dSopenharmony_ci   }
3548be168c0dSopenharmony_ci   auto impl = static_cast<mindspore::ModelC *>(model);
3549be168c0dSopenharmony_ci@@ -496,3 +319,294 @@ OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle mode
3550be168c0dSopenharmony_ci   MS_LOG(ERROR) << "tensor is not exist.";
3551be168c0dSopenharmony_ci   return nullptr;
3552be168c0dSopenharmony_ci }
3553be168c0dSopenharmony_ci+
3554be168c0dSopenharmony_ci+OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
3555be168c0dSopenharmony_ci+  auto impl = new (std::nothrow) mindspore::TrainCfg();
3556be168c0dSopenharmony_ci+  if (impl == nullptr) {
3557be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
3558be168c0dSopenharmony_ci+    return nullptr;
3559be168c0dSopenharmony_ci+  }
3560be168c0dSopenharmony_ci+  return static_cast<OH_AI_TrainCfgHandle>(impl);
3561be168c0dSopenharmony_ci+}
3562be168c0dSopenharmony_ci+
3563be168c0dSopenharmony_ci+void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
3564be168c0dSopenharmony_ci+  if (train_cfg != nullptr && *train_cfg != nullptr) {
3565be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
3566be168c0dSopenharmony_ci+    delete impl;
3567be168c0dSopenharmony_ci+    *train_cfg = nullptr;
3568be168c0dSopenharmony_ci+  }
3569be168c0dSopenharmony_ci+}
3570be168c0dSopenharmony_ci+
3571be168c0dSopenharmony_ci+char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
3572be168c0dSopenharmony_ci+  if (train_cfg == nullptr || num == nullptr) {
3573be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "train_cfg/num is nullptr.";
3574be168c0dSopenharmony_ci+    return nullptr;
3575be168c0dSopenharmony_ci+  }
3576be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3577be168c0dSopenharmony_ci+  auto loss_name = impl->GetLossName();
3578be168c0dSopenharmony_ci+  *num = loss_name.size();
3579be168c0dSopenharmony_ci+  char **name = static_cast<char **>(malloc(loss_name.size()));
3580be168c0dSopenharmony_ci+  if (name == nullptr) {
3581be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Failed to malloc loss_name.";
3582be168c0dSopenharmony_ci+    return nullptr;
3583be168c0dSopenharmony_ci+  }
3584be168c0dSopenharmony_ci+  for (size_t i = 0; i < loss_name.size(); i++) {
3585be168c0dSopenharmony_ci+    name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
3586be168c0dSopenharmony_ci+    strcpy(name[i], loss_name[i].c_str());
3587be168c0dSopenharmony_ci+  }
3588be168c0dSopenharmony_ci+  return name;
3589be168c0dSopenharmony_ci+}
3590be168c0dSopenharmony_ci+
3591be168c0dSopenharmony_ci+void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
3592be168c0dSopenharmony_ci+  if (train_cfg == nullptr) {
3593be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3594be168c0dSopenharmony_ci+    return;
3595be168c0dSopenharmony_ci+  }
3596be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3597be168c0dSopenharmony_ci+  std::vector<std::string> vec_name;
3598be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
3599be168c0dSopenharmony_ci+    vec_name.push_back(loss_name[i]);
3600be168c0dSopenharmony_ci+  }
3601be168c0dSopenharmony_ci+  impl->SetLossName(vec_name);
3602be168c0dSopenharmony_ci+}
3603be168c0dSopenharmony_ci+
3604be168c0dSopenharmony_ci+OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
3605be168c0dSopenharmony_ci+  if (train_cfg == nullptr) {
3606be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3607be168c0dSopenharmony_ci+    return OH_AI_KO0;
3608be168c0dSopenharmony_ci+  }
3609be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3610be168c0dSopenharmony_ci+  return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
3611be168c0dSopenharmony_ci+}
3612be168c0dSopenharmony_ci+
3613be168c0dSopenharmony_ci+void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
3614be168c0dSopenharmony_ci+  if (train_cfg == nullptr) {
3615be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3616be168c0dSopenharmony_ci+    return;
3617be168c0dSopenharmony_ci+  }
3618be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3619be168c0dSopenharmony_ci+  impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
3620be168c0dSopenharmony_ci+}
3621be168c0dSopenharmony_ci+
3622be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
3623be168c0dSopenharmony_ci+                                   OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
3624be168c0dSopenharmony_ci+                                   const OH_AI_TrainCfgHandle train_cfg) {
3625be168c0dSopenharmony_ci+  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
3626be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
3627be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_NULLPTR;
3628be168c0dSopenharmony_ci+  }
3629be168c0dSopenharmony_ci+  if (model_type == OH_AI_MODELTYPE_INVALID) {
3630be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_type is invalid.";
3631be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3632be168c0dSopenharmony_ci+  }
3633be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3634be168c0dSopenharmony_ci+
3635be168c0dSopenharmony_ci+  mindspore::Graph graph;
3636be168c0dSopenharmony_ci+  auto status = mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
3637be168c0dSopenharmony_ci+  if (status != mindspore::kSuccess) {
3638be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "load ms file failed.";
3639be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_ERROR;
3640be168c0dSopenharmony_ci+  }
3641be168c0dSopenharmony_ci+  auto context = static_cast<mindspore::Context *>(model_context);
3642be168c0dSopenharmony_ci+  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
3643be168c0dSopenharmony_ci+  if (impl->context_.get() != context) {
3644be168c0dSopenharmony_ci+    impl->context_.reset(context);
3645be168c0dSopenharmony_ci+  }
3646be168c0dSopenharmony_ci+  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
3647be168c0dSopenharmony_ci+                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
3648be168c0dSopenharmony_ci+  if (ret != mindspore::kSuccess) {
3649be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Load and compile failed";
3650be168c0dSopenharmony_ci+  }
3651be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3652be168c0dSopenharmony_ci+}
3653be168c0dSopenharmony_ci+
3654be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
3655be168c0dSopenharmony_ci+                                           OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
3656be168c0dSopenharmony_ci+                                           const OH_AI_TrainCfgHandle train_cfg) {
3657be168c0dSopenharmony_ci+  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
3658be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
3659be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_NULLPTR;
3660be168c0dSopenharmony_ci+  }
3661be168c0dSopenharmony_ci+  if (model_type == OH_AI_MODELTYPE_INVALID) {
3662be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model_type is invalid.";
3663be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3664be168c0dSopenharmony_ci+  }
3665be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3666be168c0dSopenharmony_ci+
3667be168c0dSopenharmony_ci+  mindspore::Graph graph;
3668be168c0dSopenharmony_ci+  auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
3669be168c0dSopenharmony_ci+  if (status != mindspore::kSuccess) {
3670be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "load ms file failed. " << model_path;
3671be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_ERROR;
3672be168c0dSopenharmony_ci+  }
3673be168c0dSopenharmony_ci+  auto context = static_cast<mindspore::Context *>(model_context);
3674be168c0dSopenharmony_ci+  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
3675be168c0dSopenharmony_ci+  if (impl->context_.get() != context) {
3676be168c0dSopenharmony_ci+    impl->context_.reset(context);
3677be168c0dSopenharmony_ci+  }
3678be168c0dSopenharmony_ci+  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
3679be168c0dSopenharmony_ci+                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
3680be168c0dSopenharmony_ci+  if (ret != mindspore::kSuccess) {
3681be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Load and compile failed";
3682be168c0dSopenharmony_ci+  }
3683be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3684be168c0dSopenharmony_ci+}
3685be168c0dSopenharmony_ci+
3686be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
3687be168c0dSopenharmony_ci+  if (model == nullptr) {
3688be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3689be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3690be168c0dSopenharmony_ci+  }
3691be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3692be168c0dSopenharmony_ci+  auto ret = impl->model_->SetLearningRate(learning_rate);
3693be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3694be168c0dSopenharmony_ci+}
3695be168c0dSopenharmony_ci+
3696be168c0dSopenharmony_ci+float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
3697be168c0dSopenharmony_ci+  if (model == nullptr) {
3698be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3699be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3700be168c0dSopenharmony_ci+  }
3701be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3702be168c0dSopenharmony_ci+  return impl->model_->GetLearningRate();
3703be168c0dSopenharmony_ci+}
3704be168c0dSopenharmony_ci+
3705be168c0dSopenharmony_ci+OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
3706be168c0dSopenharmony_ci+  if (model == nullptr) {
3707be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3708be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3709be168c0dSopenharmony_ci+  }
3710be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3711be168c0dSopenharmony_ci+  auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
3712be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3713be168c0dSopenharmony_ci+}
3714be168c0dSopenharmony_ci+
3715be168c0dSopenharmony_ci+OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
3716be168c0dSopenharmony_ci+  if (model == nullptr) {
3717be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3718be168c0dSopenharmony_ci+    return {0, nullptr};
3719be168c0dSopenharmony_ci+  }
3720be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3721be168c0dSopenharmony_ci+  auto features = impl->model_->GetFeatureMaps();
3722be168c0dSopenharmony_ci+  size_t handle_num = features.size();
3723be168c0dSopenharmony_ci+
3724be168c0dSopenharmony_ci+  mindspore::MSTensor **handle_list = static_cast<mindspore::MSTensor **>(malloc(
3725be168c0dSopenharmony_ci+    handle_num * sizeof(mindspore::MSTensor *)));
3726be168c0dSopenharmony_ci+  if (handle_list == nullptr) {
3727be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Failed to malloc handle_list.";
3728be168c0dSopenharmony_ci+    return {0, nullptr};
3729be168c0dSopenharmony_ci+  }
3730be168c0dSopenharmony_ci+  for (size_t i = 0; i < handle_num; i++) {
3731be168c0dSopenharmony_ci+    handle_list[i] = new mindspore::MSTensor(features[i].impl());
3732be168c0dSopenharmony_ci+  }
3733be168c0dSopenharmony_ci+  return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
3734be168c0dSopenharmony_ci+}
3735be168c0dSopenharmony_ci+
3736be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
3737be168c0dSopenharmony_ci+  if (model == nullptr) {
3738be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3739be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3740be168c0dSopenharmony_ci+  }
3741be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3742be168c0dSopenharmony_ci+  std::vector<mindspore::MSTensor> weights;
3743be168c0dSopenharmony_ci+  for (size_t i = 0; i < new_weights.handle_num; i++) {
3744be168c0dSopenharmony_ci+    weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
3745be168c0dSopenharmony_ci+  }
3746be168c0dSopenharmony_ci+  auto ret = impl->model_->UpdateWeights(weights);
3747be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3748be168c0dSopenharmony_ci+}
3749be168c0dSopenharmony_ci+
3750be168c0dSopenharmony_ci+bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
3751be168c0dSopenharmony_ci+  if (model == nullptr) {
3752be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3753be168c0dSopenharmony_ci+    return false;
3754be168c0dSopenharmony_ci+  }
3755be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3756be168c0dSopenharmony_ci+  return impl->model_->GetTrainMode();
3757be168c0dSopenharmony_ci+}
3758be168c0dSopenharmony_ci+
3759be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
3760be168c0dSopenharmony_ci+  if (model == nullptr) {
3761be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3762be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3763be168c0dSopenharmony_ci+  }
3764be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3765be168c0dSopenharmony_ci+  auto ret = impl->model_->SetTrainMode(train);
3766be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3767be168c0dSopenharmony_ci+}
3768be168c0dSopenharmony_ci+
3769be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) {
3770be168c0dSopenharmony_ci+  if (model == nullptr) {
3771be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3772be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3773be168c0dSopenharmony_ci+  }
3774be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3775be168c0dSopenharmony_ci+  auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
3776be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3777be168c0dSopenharmony_ci+}
3778be168c0dSopenharmony_ci+
3779be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
3780be168c0dSopenharmony_ci+                               OH_AI_QuantizationType quantization_type, bool export_inference_only,
3781be168c0dSopenharmony_ci+                               char **output_tensor_name, size_t num) {
3782be168c0dSopenharmony_ci+  if (model == nullptr) {
3783be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3784be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3785be168c0dSopenharmony_ci+  }
3786be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3787be168c0dSopenharmony_ci+  std::vector<std::string> tensor_name;
3788be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
3789be168c0dSopenharmony_ci+    tensor_name.push_back(output_tensor_name[i]);
3790be168c0dSopenharmony_ci+  }
3791be168c0dSopenharmony_ci+  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
3792be168c0dSopenharmony_ci+                                                   model_file,
3793be168c0dSopenharmony_ci+                                                   static_cast<mindspore::QuantizationType>(quantization_type),
3794be168c0dSopenharmony_ci+                                                   export_inference_only, tensor_name);
3795be168c0dSopenharmony_ci+  if (!ret.IsOk()) {
3796be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "export model fail, ret :" << ret;
3797be168c0dSopenharmony_ci+  }
3798be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3799be168c0dSopenharmony_ci+}
3800be168c0dSopenharmony_ci+
3801be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
3802be168c0dSopenharmony_ci+                                     size_t *data_size, OH_AI_QuantizationType quantization_type,
3803be168c0dSopenharmony_ci+                                     bool export_inference_only, char **output_tensor_name, size_t num) {
3804be168c0dSopenharmony_ci+  if (model == nullptr) {
3805be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3806be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3807be168c0dSopenharmony_ci+  }
3808be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3809be168c0dSopenharmony_ci+  std::vector<std::string> tensor_name;
3810be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
3811be168c0dSopenharmony_ci+    tensor_name.push_back(output_tensor_name[i]);
3812be168c0dSopenharmony_ci+  }
3813be168c0dSopenharmony_ci+  mindspore::Buffer buffer;
3814be168c0dSopenharmony_ci+  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
3815be168c0dSopenharmony_ci+                                                   &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
3816be168c0dSopenharmony_ci+                                                   export_inference_only, tensor_name);
3817be168c0dSopenharmony_ci+  auto data = static_cast<char *>(buffer.MutableData());
3818be168c0dSopenharmony_ci+  *model_data = (char *) malloc(buffer.DataSize());
3819be168c0dSopenharmony_ci+  *data_size = buffer.DataSize();
3820be168c0dSopenharmony_ci+  memcpy(*model_data, data, buffer.DataSize());
3821be168c0dSopenharmony_ci+  if (!ret.IsOk()) {
3822be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "export model fail, ret :" << ret;
3823be168c0dSopenharmony_ci+  }
3824be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3825be168c0dSopenharmony_ci+}
3826be168c0dSopenharmony_ci+
3827be168c0dSopenharmony_ci+OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file,
3828be168c0dSopenharmony_ci+  bool is_inference, bool enable_fp16, char **changeable_weights_name, size_t num) {
3829be168c0dSopenharmony_ci+  if (model == nullptr) {
3830be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "model is nullptr.";
3831be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3832be168c0dSopenharmony_ci+  }
3833be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::ModelC *>(model);
3834be168c0dSopenharmony_ci+  std::vector<std::string> weights_name;
3835be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
3836be168c0dSopenharmony_ci+    weights_name.push_back(changeable_weights_name[i]);
3837be168c0dSopenharmony_ci+  }
3838be168c0dSopenharmony_ci+  auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16, weights_name);
3839be168c0dSopenharmony_ci+  if (!ret.IsOk()) {
3840be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "export model fail, ret :" << ret;
3841be168c0dSopenharmony_ci+  }
3842be168c0dSopenharmony_ci+  return static_cast<OH_AI_Status>(ret.StatusCode());
3843be168c0dSopenharmony_ci+}
3844be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/tensor_c.cc b/mindspore/lite/src/litert/c_api/tensor_c.cc
3845be168c0dSopenharmony_ciindex 7b5c4c2f..4b1e6aff 100644
3846be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/c_api/tensor_c.cc
3847be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/tensor_c.cc
3848be168c0dSopenharmony_ci@@ -17,7 +17,6 @@
3849be168c0dSopenharmony_ci #include "include/api/status.h"
3850be168c0dSopenharmony_ci #include "src/tensor.h"
3851be168c0dSopenharmony_ci #include "src/litert/cxx_api/tensor/tensor_impl.h"
3852be168c0dSopenharmony_ci-#include "src/litert/inner_allocator.h"
3853be168c0dSopenharmony_ci 
3854be168c0dSopenharmony_ci OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num,
3855be168c0dSopenharmony_ci                                       const void *data, size_t data_len) {
3856be168c0dSopenharmony_ci@@ -31,18 +30,23 @@ OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, con
3857be168c0dSopenharmony_ci   }
3858be168c0dSopenharmony_ci   auto lite_tensor =
3859be168c0dSopenharmony_ci     mindspore::lite::Tensor::CreateTensor(name, static_cast<mindspore::TypeId>(type), vec_shape, data, data_len);
3860be168c0dSopenharmony_ci-  auto impl = new (std::nothrow) mindspore::LiteTensorImpl(lite_tensor);
3861be168c0dSopenharmony_ci-  if (impl == nullptr || impl->lite_tensor() == nullptr) {
3862be168c0dSopenharmony_ci+  auto lite_tensor_impl = std::make_shared<mindspore::LiteTensorImpl>(lite_tensor);
3863be168c0dSopenharmony_ci+  if (lite_tensor_impl == nullptr || lite_tensor_impl->lite_tensor() == nullptr) {
3864be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Failed to allocate tensor impl.";
3865be168c0dSopenharmony_ci     return nullptr;
3866be168c0dSopenharmony_ci   }
3867be168c0dSopenharmony_ci-  impl->set_from_session(false);
3868be168c0dSopenharmony_ci+  lite_tensor_impl->set_from_session(false);
3869be168c0dSopenharmony_ci+  auto impl = new (std::nothrow) mindspore::MSTensor(lite_tensor_impl);
3870be168c0dSopenharmony_ci+  if (impl == nullptr) {
3871be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Failed to allocate MSTensor.";
3872be168c0dSopenharmony_ci+    return nullptr;
3873be168c0dSopenharmony_ci+  }
3874be168c0dSopenharmony_ci   return impl;
3875be168c0dSopenharmony_ci }
3876be168c0dSopenharmony_ci 
3877be168c0dSopenharmony_ci void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) {
3878be168c0dSopenharmony_ci   if (tensor != nullptr && *tensor != nullptr) {
3879be168c0dSopenharmony_ci-    auto impl = static_cast<mindspore::LiteTensorImpl *>(*tensor);
3880be168c0dSopenharmony_ci+    auto impl = static_cast<mindspore::MSTensor *>(*tensor);
3881be168c0dSopenharmony_ci     delete impl;
3882be168c0dSopenharmony_ci     *tensor = nullptr;
3883be168c0dSopenharmony_ci   }
3884be168c0dSopenharmony_ci@@ -53,20 +57,14 @@ OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) {
3885be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3886be168c0dSopenharmony_ci     return nullptr;
3887be168c0dSopenharmony_ci   }
3888be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3889be168c0dSopenharmony_ci-  auto lite_tensor = static_cast<mindspore::lite::Tensor *>(impl->lite_tensor());
3890be168c0dSopenharmony_ci-  auto clone = mindspore::lite::Tensor::CopyTensor(*lite_tensor, true, lite_tensor->allocator());
3891be168c0dSopenharmony_ci-  if (clone == nullptr) {
3892be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Failed to allocate tensor.";
3893be168c0dSopenharmony_ci-    return nullptr;
3894be168c0dSopenharmony_ci-  }
3895be168c0dSopenharmony_ci-  auto clone_impl = new (std::nothrow) mindspore::LiteTensorImpl(clone);
3896be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3897be168c0dSopenharmony_ci+  auto clone_impl = impl->Clone();
3898be168c0dSopenharmony_ci   if (clone_impl == nullptr) {
3899be168c0dSopenharmony_ci-    delete clone;
3900be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Failed to allocate tensor impl.";
3901be168c0dSopenharmony_ci     return nullptr;
3902be168c0dSopenharmony_ci   }
3903be168c0dSopenharmony_ci-  clone_impl->set_from_session(false);
3904be168c0dSopenharmony_ci+  std::static_pointer_cast<mindspore::LiteTensorImpl>(clone_impl->impl())->set_own_data(false);
3905be168c0dSopenharmony_ci+  clone_impl->SetTensorName(impl->Name() + "_duplicate");
3906be168c0dSopenharmony_ci   return clone_impl;
3907be168c0dSopenharmony_ci }
3908be168c0dSopenharmony_ci 
3909be168c0dSopenharmony_ci@@ -75,8 +73,8 @@ void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) {
3910be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3911be168c0dSopenharmony_ci     return;
3912be168c0dSopenharmony_ci   }
3913be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3914be168c0dSopenharmony_ci-  impl->SetName(name);
3915be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3916be168c0dSopenharmony_ci+  impl->SetTensorName(name);
3917be168c0dSopenharmony_ci }
3918be168c0dSopenharmony_ci 
3919be168c0dSopenharmony_ci const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) {
3920be168c0dSopenharmony_ci@@ -84,8 +82,8 @@ const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) {
3921be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3922be168c0dSopenharmony_ci     return nullptr;
3923be168c0dSopenharmony_ci   }
3924be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3925be168c0dSopenharmony_ci-  return impl->Name().c_str();
3926be168c0dSopenharmony_ci+  auto ms_tensor = static_cast<mindspore::MSTensor *>(tensor);
3927be168c0dSopenharmony_ci+  return std::static_pointer_cast<mindspore::LiteTensorImpl>(ms_tensor->impl())->Name().c_str();
3928be168c0dSopenharmony_ci }
3929be168c0dSopenharmony_ci 
3930be168c0dSopenharmony_ci void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) {
3931be168c0dSopenharmony_ci@@ -93,7 +91,7 @@ void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) {
3932be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3933be168c0dSopenharmony_ci     return;
3934be168c0dSopenharmony_ci   }
3935be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3936be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3937be168c0dSopenharmony_ci   impl->SetDataType(static_cast<mindspore::DataType>(type));
3938be168c0dSopenharmony_ci }
3939be168c0dSopenharmony_ci 
3940be168c0dSopenharmony_ci@@ -102,7 +100,7 @@ OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) {
3941be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3942be168c0dSopenharmony_ci     return OH_AI_DATATYPE_UNKNOWN;
3943be168c0dSopenharmony_ci   }
3944be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3945be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3946be168c0dSopenharmony_ci   auto dtype = impl->DataType();
3947be168c0dSopenharmony_ci   return static_cast<OH_AI_DataType>(dtype);
3948be168c0dSopenharmony_ci }
3949be168c0dSopenharmony_ci@@ -112,7 +110,7 @@ void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_
3950be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3951be168c0dSopenharmony_ci     return;
3952be168c0dSopenharmony_ci   }
3953be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3954be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3955be168c0dSopenharmony_ci   std::vector<int64_t> vec_shape(shape_num);
3956be168c0dSopenharmony_ci   for (size_t i = 0; i < shape_num; i++) {
3957be168c0dSopenharmony_ci     vec_shape[i] = shape[i];
3958be168c0dSopenharmony_ci@@ -125,7 +123,7 @@ const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *sha
3959be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3960be168c0dSopenharmony_ci     return nullptr;
3961be168c0dSopenharmony_ci   }
3962be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3963be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3964be168c0dSopenharmony_ci   *shape_num = impl->Shape().size();
3965be168c0dSopenharmony_ci   return impl->Shape().data();
3966be168c0dSopenharmony_ci }
3967be168c0dSopenharmony_ci@@ -135,7 +133,7 @@ void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) {
3968be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3969be168c0dSopenharmony_ci     return;
3970be168c0dSopenharmony_ci   }
3971be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3972be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3973be168c0dSopenharmony_ci   return impl->SetFormat(static_cast<mindspore::Format>(format));
3974be168c0dSopenharmony_ci }
3975be168c0dSopenharmony_ci 
3976be168c0dSopenharmony_ci@@ -144,8 +142,8 @@ OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) {
3977be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3978be168c0dSopenharmony_ci     return OH_AI_FORMAT_NHWC;
3979be168c0dSopenharmony_ci   }
3980be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3981be168c0dSopenharmony_ci-  return static_cast<OH_AI_Format>(impl->Format());
3982be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3983be168c0dSopenharmony_ci+  return static_cast<OH_AI_Format>(impl->format());
3984be168c0dSopenharmony_ci }
3985be168c0dSopenharmony_ci 
3986be168c0dSopenharmony_ci void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) {
3987be168c0dSopenharmony_ci@@ -153,16 +151,34 @@ void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) {
3988be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
3989be168c0dSopenharmony_ci     return;
3990be168c0dSopenharmony_ci   }
3991be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3992be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3993be168c0dSopenharmony_ci   return impl->SetData(data, true);
3994be168c0dSopenharmony_ci }
3995be168c0dSopenharmony_ci 
3996be168c0dSopenharmony_ci+OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size) {
3997be168c0dSopenharmony_ci+  if (tensor == nullptr) {
3998be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "param is nullptr.";
3999be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_NULLPTR;
4000be168c0dSopenharmony_ci+  }
4001be168c0dSopenharmony_ci+
4002be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4003be168c0dSopenharmony_ci+  if ((impl->DataSize() > 0) && (data_size != impl->DataSize())) {
4004be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "input data size does not match inner data size";
4005be168c0dSopenharmony_ci+    return OH_AI_STATUS_LITE_PARAM_INVALID;
4006be168c0dSopenharmony_ci+  }
4007be168c0dSopenharmony_ci+
4008be168c0dSopenharmony_ci+  // This is one tricky way to represent that the inner data is not owned by tensor itself.
4009be168c0dSopenharmony_ci+  impl->SetAllocator(nullptr);
4010be168c0dSopenharmony_ci+  impl->SetData(data, false);
4011be168c0dSopenharmony_ci+  return OH_AI_STATUS_SUCCESS;
4012be168c0dSopenharmony_ci+}
4013be168c0dSopenharmony_ci+
4014be168c0dSopenharmony_ci const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) {
4015be168c0dSopenharmony_ci   if (tensor == nullptr) {
4016be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
4017be168c0dSopenharmony_ci     return nullptr;
4018be168c0dSopenharmony_ci   }
4019be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4020be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4021be168c0dSopenharmony_ci   return impl->Data().get();
4022be168c0dSopenharmony_ci }
4023be168c0dSopenharmony_ci 
4024be168c0dSopenharmony_ci@@ -171,7 +187,7 @@ void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) {
4025be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
4026be168c0dSopenharmony_ci     return nullptr;
4027be168c0dSopenharmony_ci   }
4028be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4029be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4030be168c0dSopenharmony_ci   return impl->MutableData();
4031be168c0dSopenharmony_ci }
4032be168c0dSopenharmony_ci 
4033be168c0dSopenharmony_ci@@ -180,7 +196,7 @@ int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) {
4034be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
4035be168c0dSopenharmony_ci     return 0;
4036be168c0dSopenharmony_ci   }
4037be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4038be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4039be168c0dSopenharmony_ci   return impl->ElementNum();
4040be168c0dSopenharmony_ci }
4041be168c0dSopenharmony_ci 
4042be168c0dSopenharmony_ci@@ -189,6 +205,6 @@ size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) {
4043be168c0dSopenharmony_ci     MS_LOG(ERROR) << "param is nullptr.";
4044be168c0dSopenharmony_ci     return 0;
4045be168c0dSopenharmony_ci   }
4046be168c0dSopenharmony_ci-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4047be168c0dSopenharmony_ci+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4048be168c0dSopenharmony_ci   return impl->DataSize();
4049be168c0dSopenharmony_ci }
4050be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/c_api/type_c_private.h b/mindspore/lite/src/litert/c_api/type_c_private.h
4051be168c0dSopenharmony_cinew file mode 100644
4052be168c0dSopenharmony_ciindex 00000000..2d3b3883
4053be168c0dSopenharmony_ci--- /dev/null
4054be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/c_api/type_c_private.h
4055be168c0dSopenharmony_ci@@ -0,0 +1,40 @@
4056be168c0dSopenharmony_ci+/**
4057be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
4058be168c0dSopenharmony_ci+ *
4059be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
4060be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
4061be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
4062be168c0dSopenharmony_ci+ *
4063be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
4064be168c0dSopenharmony_ci+ *
4065be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
4066be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
4067be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4068be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
4069be168c0dSopenharmony_ci+ * limitations under the License.
4070be168c0dSopenharmony_ci+ */
4071be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
4072be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
4073be168c0dSopenharmony_ci+
4074be168c0dSopenharmony_ci+#include <string>
4075be168c0dSopenharmony_ci+#include <vector>
4076be168c0dSopenharmony_ci+#include <memory>
4077be168c0dSopenharmony_ci+#include <stddef.h>
4078be168c0dSopenharmony_ci+#include "include/c_api/types_c.h"
4079be168c0dSopenharmony_ci+
4080be168c0dSopenharmony_ci+#ifdef __cplusplus
4081be168c0dSopenharmony_ci+extern "C" {
4082be168c0dSopenharmony_ci+#endif
4083be168c0dSopenharmony_ci+
4084be168c0dSopenharmony_ci+#define NNRT_DEVICE_NAME_MAX (128)
4085be168c0dSopenharmony_ci+
4086be168c0dSopenharmony_ci+struct NNRTDeviceDesc {
4087be168c0dSopenharmony_ci+  size_t device_id;
4088be168c0dSopenharmony_ci+  OH_AI_NNRTDeviceType device_type;
4089be168c0dSopenharmony_ci+  char device_name[NNRT_DEVICE_NAME_MAX];
4090be168c0dSopenharmony_ci+};
4091be168c0dSopenharmony_ci+
4092be168c0dSopenharmony_ci+#ifdef __cplusplus
4093be168c0dSopenharmony_ci+}
4094be168c0dSopenharmony_ci+#endif
4095be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
4096be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/context.cc b/mindspore/lite/src/litert/cxx_api/context.cc
4097be168c0dSopenharmony_ciindex 1371bcf0..e5f19d28 100644
4098be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/context.cc
4099be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/context.cc
4100be168c0dSopenharmony_ci@@ -50,6 +50,11 @@ constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dyn
4101be168c0dSopenharmony_ci constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
4102be168c0dSopenharmony_ci constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
4103be168c0dSopenharmony_ci constexpr auto kModelOptionAscendRankID = "mindspore.option.ascend.rank_id";
4104be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id";
4105be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode";
4106be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTPriority = "mindspore.option.nnrt.priority";
4107be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16";
4108be168c0dSopenharmony_ci+constexpr auto kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions";
4109be168c0dSopenharmony_ci #ifdef USE_GLOG
4110be168c0dSopenharmony_ci extern "C" {
4111be168c0dSopenharmony_ci extern void mindspore_log_init();
4112be168c0dSopenharmony_ci@@ -684,4 +689,84 @@ std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
4113be168c0dSopenharmony_ci   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
4114be168c0dSopenharmony_ci   return StringToChar(ref);
4115be168c0dSopenharmony_ci }
4116be168c0dSopenharmony_ci+
4117be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetDeviceID(size_t device_id) {
4118be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4119be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4120be168c0dSopenharmony_ci+    return;
4121be168c0dSopenharmony_ci+  }
4122be168c0dSopenharmony_ci+  data_->params[kModelOptionNNRTDeviceID] = device_id;
4123be168c0dSopenharmony_ci+}
4124be168c0dSopenharmony_ci+
4125be168c0dSopenharmony_ci+size_t NNRTDeviceInfo::GetDeviceID() const {
4126be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4127be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4128be168c0dSopenharmony_ci+    return 0;
4129be168c0dSopenharmony_ci+  }
4130be168c0dSopenharmony_ci+  return GetValue<size_t>(data_, kModelOptionNNRTDeviceID);
4131be168c0dSopenharmony_ci+}
4132be168c0dSopenharmony_ci+
4133be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetPerformanceMode(int performance_mode) {
4134be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4135be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4136be168c0dSopenharmony_ci+    return;
4137be168c0dSopenharmony_ci+  }
4138be168c0dSopenharmony_ci+  data_->params[kModelOptionNNRTPerformanceMode] = performance_mode;
4139be168c0dSopenharmony_ci+}
4140be168c0dSopenharmony_ci+
4141be168c0dSopenharmony_ci+int NNRTDeviceInfo::GetPerformanceMode() const {
4142be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4143be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4144be168c0dSopenharmony_ci+    return 0;
4145be168c0dSopenharmony_ci+  }
4146be168c0dSopenharmony_ci+  return GetValue<int>(data_, kModelOptionNNRTPerformanceMode);
4147be168c0dSopenharmony_ci+}
4148be168c0dSopenharmony_ci+
4149be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetPriority(int priority) {
4150be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4151be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4152be168c0dSopenharmony_ci+    return;
4153be168c0dSopenharmony_ci+  }
4154be168c0dSopenharmony_ci+  data_->params[kModelOptionNNRTPriority] = priority;
4155be168c0dSopenharmony_ci+}
4156be168c0dSopenharmony_ci+
4157be168c0dSopenharmony_ci+int NNRTDeviceInfo::GetPriority() const {
4158be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4159be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4160be168c0dSopenharmony_ci+    return 0;
4161be168c0dSopenharmony_ci+  }
4162be168c0dSopenharmony_ci+  return GetValue<int>(data_, kModelOptionNNRTPriority);
4163be168c0dSopenharmony_ci+}
4164be168c0dSopenharmony_ci+
4165be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) {
4166be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4167be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4168be168c0dSopenharmony_ci+    return;
4169be168c0dSopenharmony_ci+  }
4170be168c0dSopenharmony_ci+  data_->params[kModelOptionNNRTEnableFP16] = is_fp16;
4171be168c0dSopenharmony_ci+}
4172be168c0dSopenharmony_ci+
4173be168c0dSopenharmony_ci+bool NNRTDeviceInfo::GetEnableFP16() const {
4174be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4175be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4176be168c0dSopenharmony_ci+    return false;
4177be168c0dSopenharmony_ci+  }
4178be168c0dSopenharmony_ci+  return GetValue<bool>(data_, kModelOptionNNRTEnableFP16);
4179be168c0dSopenharmony_ci+}
4180be168c0dSopenharmony_ci+
4181be168c0dSopenharmony_ci+void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) {
4182be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4183be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4184be168c0dSopenharmony_ci+    return;
4185be168c0dSopenharmony_ci+  }
4186be168c0dSopenharmony_ci+  data_->params[kModelOptionNNRTExtensions] = extensions;
4187be168c0dSopenharmony_ci+}
4188be168c0dSopenharmony_ci+
4189be168c0dSopenharmony_ci+std::vector<Extension> NNRTDeviceInfo::GetExtensions() const {
4190be168c0dSopenharmony_ci+  if (data_ == nullptr) {
4191be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid context.";
4192be168c0dSopenharmony_ci+    return {};
4193be168c0dSopenharmony_ci+  }
4194be168c0dSopenharmony_ci+  return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions);
4195be168c0dSopenharmony_ci+}
4196be168c0dSopenharmony_ci }  // namespace mindspore
4197be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/converters.cc b/mindspore/lite/src/litert/cxx_api/converters.cc
4198be168c0dSopenharmony_ciindex 0ff345cc..e54a36ee 100644
4199be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/converters.cc
4200be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/converters.cc
4201be168c0dSopenharmony_ci@@ -86,6 +86,23 @@ Status ContextUtils::AddCustomDevice(lite::InnerContext *inner_context,
4202be168c0dSopenharmony_ci   return kSuccess;
4203be168c0dSopenharmony_ci }
4204be168c0dSopenharmony_ci 
4205be168c0dSopenharmony_ci+Status ContextUtils::AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode,
4206be168c0dSopenharmony_ci+                                   int priority, bool enable_fp16, const std::vector<Extension> &extensions) {
4207be168c0dSopenharmony_ci+  lite::DeviceInfo device_info = {0};
4208be168c0dSopenharmony_ci+  device_info.nnrt_device_info_.device_id_ = device_id;
4209be168c0dSopenharmony_ci+  device_info.nnrt_device_info_.performance_mode_ = performance_mode;
4210be168c0dSopenharmony_ci+  device_info.nnrt_device_info_.priority_ = priority;
4211be168c0dSopenharmony_ci+  device_info.nnrt_device_info_.enable_fp16_ = enable_fp16;
4212be168c0dSopenharmony_ci+  for (auto src_extension: extensions) {
4213be168c0dSopenharmony_ci+    lite::Extension dest_extension;
4214be168c0dSopenharmony_ci+    dest_extension.name = src_extension.name;
4215be168c0dSopenharmony_ci+    dest_extension.value = src_extension.value;
4216be168c0dSopenharmony_ci+    device_info.nnrt_device_info_.extensions_.push_back(dest_extension);
4217be168c0dSopenharmony_ci+  }
4218be168c0dSopenharmony_ci+  inner_context->device_list_.push_back({lite::DT_NNRT, device_info});
4219be168c0dSopenharmony_ci+  return kSuccess;
4220be168c0dSopenharmony_ci+}
4221be168c0dSopenharmony_ci+
4222be168c0dSopenharmony_ci void ContextUtils::ResetContextDefaultParam(Context *context) {
4223be168c0dSopenharmony_ci   if (context->GetInterOpParallelNum() == 0) {
4224be168c0dSopenharmony_ci     context->SetInterOpParallelNum(kDefaultInterOpParallelNum);
4225be168c0dSopenharmony_ci@@ -163,44 +180,11 @@ std::shared_ptr<lite::InnerContext> ContextUtils::Convert(Context *context) {
4226be168c0dSopenharmony_ci       ret = AddAscendDevice(inner_context.get(), device.get());
4227be168c0dSopenharmony_ci     } else if (device->GetDeviceType() == kCustomDevice) {
4228be168c0dSopenharmony_ci       ret = AddCustomDevice(inner_context.get(), device);
4229be168c0dSopenharmony_ci-    }
4230be168c0dSopenharmony_ci-    if (ret != kSuccess) {
4231be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Add device failed!";
4232be168c0dSopenharmony_ci-      return nullptr;
4233be168c0dSopenharmony_ci-    }
4234be168c0dSopenharmony_ci-  }
4235be168c0dSopenharmony_ci-  return inner_context;
4236be168c0dSopenharmony_ci-}
4237be168c0dSopenharmony_ci-
4238be168c0dSopenharmony_ci-std::shared_ptr<lite::InnerContext> ContextUtils::Convert(const ContextC *context_c) {
4239be168c0dSopenharmony_ci-  auto inner_context = std::make_shared<lite::InnerContext>();
4240be168c0dSopenharmony_ci-  if ((context_c == nullptr) || (inner_context == nullptr)) {
4241be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Invalid context pointers.";
4242be168c0dSopenharmony_ci-    return nullptr;
4243be168c0dSopenharmony_ci-  }
4244be168c0dSopenharmony_ci-  auto device_list = context_c->device_info_list;
4245be168c0dSopenharmony_ci-  if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
4246be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
4247be168c0dSopenharmony_ci-    return nullptr;
4248be168c0dSopenharmony_ci-  }
4249be168c0dSopenharmony_ci-  SetContextAttr(context_c->thread_num, 1, context_c->enable_parallel, context_c->affinity_core_list,
4250be168c0dSopenharmony_ci-                 context_c->delegate_mode, context_c->delegate, inner_context.get());
4251be168c0dSopenharmony_ci-  inner_context->device_list_.clear();
4252be168c0dSopenharmony_ci-  Status ret = kLiteError;
4253be168c0dSopenharmony_ci-  for (auto &device_info_c : device_list) {
4254be168c0dSopenharmony_ci-    MS_CHECK_TRUE_RET(device_info_c != nullptr, nullptr);
4255be168c0dSopenharmony_ci-    lite::DeviceInfo device_info = {{0}};
4256be168c0dSopenharmony_ci-    if (device_info_c->device_type == OH_AI_DEVICETYPE_CPU) {
4257be168c0dSopenharmony_ci-      if (device_info_c->allocator == nullptr) {
4258be168c0dSopenharmony_ci-        device_info_c->allocator = Allocator::Create();
4259be168c0dSopenharmony_ci-      }
4260be168c0dSopenharmony_ci-      ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16,
4261be168c0dSopenharmony_ci-                         device_info_c->provider, device_info_c->provider_device, inner_context.get());
4262be168c0dSopenharmony_ci-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_GPU) {
4263be168c0dSopenharmony_ci-      ret = AddGpuDevice(device_info_c->enable_fp16, 0, 0, 0, false, nullptr, nullptr, device_info_c->provider,
4264be168c0dSopenharmony_ci-                         device_info_c->provider_device, device_info_c->allocator, inner_context.get());
4265be168c0dSopenharmony_ci-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
4266be168c0dSopenharmony_ci-      ret = AddNpuDevice(device_info_c->enable_fp16, device_info_c->frequency, inner_context.get());
4267be168c0dSopenharmony_ci+    } else if (device->GetDeviceType() == kNNRt) {
4268be168c0dSopenharmony_ci+      auto nnrt_device_info = device->Cast<NNRTDeviceInfo>();
4269be168c0dSopenharmony_ci+      ret = AddNNRtDevice(inner_context.get(), nnrt_device_info->GetDeviceID(),
4270be168c0dSopenharmony_ci+                          nnrt_device_info->GetPerformanceMode(), nnrt_device_info->GetPriority(),
4271be168c0dSopenharmony_ci+                          nnrt_device_info->GetEnableFP16(), nnrt_device_info->GetExtensions());
4272be168c0dSopenharmony_ci     }
4273be168c0dSopenharmony_ci     if (ret != kSuccess) {
4274be168c0dSopenharmony_ci       MS_LOG(ERROR) << "Add device failed!";
4275be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/cxx_api/converters.h b/mindspore/lite/src/litert/cxx_api/converters.h
4276be168c0dSopenharmony_ciindex 0c043fc3..1af7c7df 100644
4277be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/cxx_api/converters.h
4278be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/cxx_api/converters.h
4279be168c0dSopenharmony_ci@@ -24,14 +24,12 @@
4280be168c0dSopenharmony_ci #include "include/api/cfg.h"
4281be168c0dSopenharmony_ci #include "include/train/train_cfg.h"
4282be168c0dSopenharmony_ci #include "src/litert/inner_context.h"
4283be168c0dSopenharmony_ci-#include "src/litert/c_api/context_c.h"
4284be168c0dSopenharmony_ci #include "src/common/log_adapter.h"
4285be168c0dSopenharmony_ci 
4286be168c0dSopenharmony_ci namespace mindspore {
4287be168c0dSopenharmony_ci class MS_API ContextUtils {
4288be168c0dSopenharmony_ci  public:
4289be168c0dSopenharmony_ci   static std::shared_ptr<lite::InnerContext> Convert(Context *context);
4290be168c0dSopenharmony_ci-  static std::shared_ptr<lite::InnerContext> Convert(const ContextC *context_c);
4291be168c0dSopenharmony_ci 
4292be168c0dSopenharmony_ci  private:
4293be168c0dSopenharmony_ci   static void SetContextAttr(int32_t thread_num, int32_t inter_op_parallel_num, bool enable_parallel,
4294be168c0dSopenharmony_ci@@ -48,6 +46,8 @@ class MS_API ContextUtils {
4295be168c0dSopenharmony_ci   static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context);
4296be168c0dSopenharmony_ci   static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
4297be168c0dSopenharmony_ci   static Status AddCustomDevice(lite::InnerContext *inner_context, const std::shared_ptr<DeviceInfoContext> &device);
4298be168c0dSopenharmony_ci+  static Status AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, int priority,
4299be168c0dSopenharmony_ci+                              bool enable_fp16, const std::vector<Extension> &extensions);
4300be168c0dSopenharmony_ci   static bool IsAffinityModeValid(int affinity_mode) {
4301be168c0dSopenharmony_ci     return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
4302be168c0dSopenharmony_ci   }
4303be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4304be168c0dSopenharmony_ciindex 70aa63f3..625459e2 100644
4305be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4306be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4307be168c0dSopenharmony_ci@@ -1,30 +1,13 @@
4308be168c0dSopenharmony_ci include_directories(${DDK_PATH})
4309be168c0dSopenharmony_ci include_directories($(CCSRC_DIR)/plugin/device/cpu/kernel)
4310be168c0dSopenharmony_ci+include_directories(${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/)
4311be168c0dSopenharmony_ci 
4312be168c0dSopenharmony_ci include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
4313be168c0dSopenharmony_ci-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include/inner)
4314be168c0dSopenharmony_ci-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include)
4315be168c0dSopenharmony_ci+
4316be168c0dSopenharmony_ci file(GLOB_RECURSE NNRT_SRC
4317be168c0dSopenharmony_ci         ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
4318be168c0dSopenharmony_ci )
4319be168c0dSopenharmony_ci-
4320be168c0dSopenharmony_ci-#add_library(hiai SHARED IMPORTED)
4321be168c0dSopenharmony_ci-#set_target_properties(hiai PROPERTIES IMPORTED_LOCATION
4322be168c0dSopenharmony_ci-#        ${DDK_LIB_PATH}/libhiai.so)
4323be168c0dSopenharmony_ci-#add_library(hiai_ir SHARED IMPORTED)
4324be168c0dSopenharmony_ci-#set_target_properties(hiai_ir PROPERTIES IMPORTED_LOCATION
4325be168c0dSopenharmony_ci-#        ${DDK_LIB_PATH}/libhiai_ir.so)
4326be168c0dSopenharmony_ci-#add_library(hiai_ir_build SHARED IMPORTED)
4327be168c0dSopenharmony_ci-#set_target_properties(hiai_ir_build PROPERTIES IMPORTED_LOCATION
4328be168c0dSopenharmony_ci-#        ${DDK_LIB_PATH}/libhiai_ir_build.so)
4329be168c0dSopenharmony_ci-#add_library(npu_kernel_mid OBJECT ${NPU_RUNTIME_SRC})
4330be168c0dSopenharmony_ci-#add_dependencies(npu_kernel_mid fbs_src)
4331be168c0dSopenharmony_ci-#target_link_libraries(
4332be168c0dSopenharmony_ci-#        npu_kernel_mid
4333be168c0dSopenharmony_ci-#        hiai
4334be168c0dSopenharmony_ci-#        hiai_ir
4335be168c0dSopenharmony_ci-#        hiai_ir_build
4336be168c0dSopenharmony_ci-#)
4337be168c0dSopenharmony_ci-
4338be168c0dSopenharmony_ci file(GLOB convert_source checker/*.cc)
4339be168c0dSopenharmony_ci-add_library(nnr_mid OBJECT ${NNRT_SRC} ${convert_source} )
4340be168c0dSopenharmony_ci\ No newline at end of file
4341be168c0dSopenharmony_ci+
4342be168c0dSopenharmony_ci+add_library(nnrt_mid OBJECT ${NNRT_SRC} ${convert_source})
4343be168c0dSopenharmony_ci+target_include_directories(nnrt_mid PUBLIC ${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/)
4344be168c0dSopenharmony_ci\ No newline at end of file
4345be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4346be168c0dSopenharmony_ciindex 4df7e477..6b191c8e 100644
4347be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4348be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4349be168c0dSopenharmony_ci@@ -109,6 +109,8 @@ Status CheckPrimitiveSupported(const schema::Primitive *primitive) {
4350be168c0dSopenharmony_ci         return mindspore::kSuccess;
4351be168c0dSopenharmony_ci       case schema::PrimitiveType_Unsqueeze:
4352be168c0dSopenharmony_ci         return mindspore::kSuccess;
4353be168c0dSopenharmony_ci+      case schema::PrimitiveType_Custom:
4354be168c0dSopenharmony_ci+        return mindspore::kSuccess;
4355be168c0dSopenharmony_ci       default: {
4356be168c0dSopenharmony_ci         MS_LOG(WARNING) << "No primitive type :" << (int)(type);
4357be168c0dSopenharmony_ci         return mindspore::kLiteSuccessExit;
4358be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4359be168c0dSopenharmony_ciindex 34897331..9f012e76 100644
4360be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4361be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4362be168c0dSopenharmony_ci@@ -13,144 +13,637 @@
4363be168c0dSopenharmony_ci  * See the License for the specific language governing permissions and
4364be168c0dSopenharmony_ci  * limitations under the License.
4365be168c0dSopenharmony_ci  */
4366be168c0dSopenharmony_ci+
4367be168c0dSopenharmony_ci+#include <unordered_set>
4368be168c0dSopenharmony_ci+#include <numeric>
4369be168c0dSopenharmony_ci #include "nnrt_delegate.h"
4370be168c0dSopenharmony_ci #include "checker/primitive_check.h"
4371be168c0dSopenharmony_ci #include "src/common/log_adapter.h"
4372be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime.h"
4373be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
4374be168c0dSopenharmony_ci #include "interfaces/innerkits/c/neural_network_runtime_inner.h"
4375be168c0dSopenharmony_ci #include "nnrt_model_kernel.h"
4376be168c0dSopenharmony_ci+#include "schema/model_generated.h"
4377be168c0dSopenharmony_ci+#include "schema/ops_generated.h"
4378be168c0dSopenharmony_ci+#include "flatbuffers/flatbuffers.h"
4379be168c0dSopenharmony_ci+#include "litert/tensor_category.h"
4380be168c0dSopenharmony_ci+
4381be168c0dSopenharmony_ci+namespace mindspore {
4382be168c0dSopenharmony_ci+namespace lite {
4383be168c0dSopenharmony_ci+void NNRTDelegate::InitCachePath() {
4384be168c0dSopenharmony_ci+  static const std::string kCachePathName = "CachePath";
4385be168c0dSopenharmony_ci+  static const std::string kCacheVersion = "CacheVersion";
4386be168c0dSopenharmony_ci+
4387be168c0dSopenharmony_ci+  const auto &extensions = nnrt_device_info_.extensions_;
4388be168c0dSopenharmony_ci 
4389be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
4390be168c0dSopenharmony_ci-  if (this->nnrt_lite_graph == nullptr) {
4391be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "nnrt_lite_graph is nullptr.";
4392be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4393be168c0dSopenharmony_ci+  auto iter_path = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
4394be168c0dSopenharmony_ci+    return extension.name == kCachePathName;
4395be168c0dSopenharmony_ci+  });
4396be168c0dSopenharmony_ci+  if (iter_path != extensions.end()) {
4397be168c0dSopenharmony_ci+    cache_path_ = std::string(iter_path->value.begin(), iter_path->value.end());
4398be168c0dSopenharmony_ci   }
4399be168c0dSopenharmony_ci-  if (this->nnrt_lite_graph->sub_graphs_.empty()) {
4400be168c0dSopenharmony_ci-    // must have at lease one subgraph
4401be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "must have at lease one subgraph";
4402be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4403be168c0dSopenharmony_ci+
4404be168c0dSopenharmony_ci+  auto iter_version = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
4405be168c0dSopenharmony_ci+    return extension.name == kCacheVersion;
4406be168c0dSopenharmony_ci+  });
4407be168c0dSopenharmony_ci+  if (iter_version != extensions.end()) {
4408be168c0dSopenharmony_ci+    std::string version_str = std::string(iter_version->value.begin(), iter_version->value.end());
4409be168c0dSopenharmony_ci+    cache_version_ = static_cast<uint32_t>(std::atol(version_str.c_str()));
4410be168c0dSopenharmony_ci   }
4411be168c0dSopenharmony_ci-  OH_NN_ReturnCode ret_code;
4412be168c0dSopenharmony_ci-  OH_NNModel *oh_nnmodel = OH_NNModel_Construct();
4413be168c0dSopenharmony_ci-  if (oh_nnmodel == nullptr) {
4414be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Construct NNModel failed, oh_nnmodel is nullptr.";
4415be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4416be168c0dSopenharmony_ci+}
4417be168c0dSopenharmony_ci+
4418be168c0dSopenharmony_ci+Status NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
4419be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH
4420be168c0dSopenharmony_ci+  if (IsKirinNPU()) {
4421be168c0dSopenharmony_ci+    MS_LOG(DEBUG) << "Choose to build nnrt model with Metagraph";
4422be168c0dSopenharmony_ci+    InitCachePath();
4423be168c0dSopenharmony_ci+    return BuildKirinNPUModel(model);
4424be168c0dSopenharmony_ci   }
4425be168c0dSopenharmony_ci+#endif
4426be168c0dSopenharmony_ci 
4427be168c0dSopenharmony_ci-  ret_code = OH_NNModel_BuildFromLiteGraph(oh_nnmodel, this->nnrt_lite_graph);
4428be168c0dSopenharmony_ci-  if (ret_code != OH_NN_SUCCESS) {
4429be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Build NNModel failed, OH_NN_ReturnCode = " << ret_code;
4430be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4431be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4432be168c0dSopenharmony_ci+  return BuildNormalModel(model);
4433be168c0dSopenharmony_ci+}
4434be168c0dSopenharmony_ci+
4435be168c0dSopenharmony_ci+bool NNRTDelegate::IsCustomModel() const {
4436be168c0dSopenharmony_ci+  // check if there is only one Cutsom kernel in LiteModel.
4437be168c0dSopenharmony_ci+  if (lite_graph_ == nullptr) {
4438be168c0dSopenharmony_ci+    return false;
4439be168c0dSopenharmony_ci+  }
4440be168c0dSopenharmony_ci+  if (lite_graph_->all_nodes_.size() != 1) {
4441be168c0dSopenharmony_ci+    return false;
4442be168c0dSopenharmony_ci+  }
4443be168c0dSopenharmony_ci+  auto node = lite_graph_->all_nodes_[0];
4444be168c0dSopenharmony_ci+  if (node == nullptr) {
4445be168c0dSopenharmony_ci+    return false;
4446be168c0dSopenharmony_ci+  }
4447be168c0dSopenharmony_ci+  if (node->node_type_ != mindspore::schema::PrimitiveType_Custom) {
4448be168c0dSopenharmony_ci+    return false;
4449be168c0dSopenharmony_ci+  }
4450be168c0dSopenharmony_ci+  return true;
4451be168c0dSopenharmony_ci+}
4452be168c0dSopenharmony_ci+
4453be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH
4454be168c0dSopenharmony_ci+bool NNRTDelegate::IsKirinNPU() const {
4455be168c0dSopenharmony_ci+  const std::string kirin_npu_name_prefix = "NPU_";
4456be168c0dSopenharmony_ci+  auto device_id = nnrt_device_info_.device_id_;
4457be168c0dSopenharmony_ci+  const char *device_name;
4458be168c0dSopenharmony_ci+  auto ret = OH_NNDevice_GetName(device_id, &device_name);
4459be168c0dSopenharmony_ci+  if (ret != OH_NN_SUCCESS) {
4460be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret;
4461be168c0dSopenharmony_ci+    return false;
4462be168c0dSopenharmony_ci+  }
4463be168c0dSopenharmony_ci+
4464be168c0dSopenharmony_ci+  if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) {
4465be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "strncmp: " << device_id << " failed, device_name: " << device_name;
4466be168c0dSopenharmony_ci+    return false;
4467be168c0dSopenharmony_ci+  }
4468be168c0dSopenharmony_ci+  return true;
4469be168c0dSopenharmony_ci+}
4470be168c0dSopenharmony_ci+
4471be168c0dSopenharmony_ci+Status NNRTDelegate::BuildKirinNPUModel(DelegateModel<schema::Primitive> *model) {
4472be168c0dSopenharmony_ci+  OH_NNModel *nn_model = OH_NNModel_Construct();
4473be168c0dSopenharmony_ci+  if (nn_model == nullptr) {
4474be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create NNModel failed, result is nullptr";
4475be168c0dSopenharmony_ci+    return kLiteNullptr;
4476be168c0dSopenharmony_ci+  }
4477be168c0dSopenharmony_ci+
4478be168c0dSopenharmony_ci+  size_t extension_size = nnrt_device_info_.extensions_.size();
4479be168c0dSopenharmony_ci+  std::vector<OH_NN_Extension> extensions;
4480be168c0dSopenharmony_ci+  MS_LOG_DEBUG << "set extensions, item number: " << extension_size;
4481be168c0dSopenharmony_ci+  const size_t kExtensionNameMax = 128; // This is a length limitation in NNRT API.
4482be168c0dSopenharmony_ci+  for (size_t i = 0; i < extension_size; i++) {
4483be168c0dSopenharmony_ci+    auto &src_extension = nnrt_device_info_.extensions_[i];
4484be168c0dSopenharmony_ci+    OH_NN_Extension dst_extension;
4485be168c0dSopenharmony_ci+    dst_extension.name[kExtensionNameMax - 1] = '\0';
4486be168c0dSopenharmony_ci+    strncpy(dst_extension.name, src_extension.name.c_str(), kExtensionNameMax - 1);
4487be168c0dSopenharmony_ci+    dst_extension.value = (char *)((void *)src_extension.value.data());
4488be168c0dSopenharmony_ci+    dst_extension.valueSize = src_extension.value.size();
4489be168c0dSopenharmony_ci+    extensions.push_back(dst_extension);
4490be168c0dSopenharmony_ci+    MS_LOG_DEBUG << "set extension, item name: " << dst_extension.name << ", value size: " << dst_extension.valueSize;
4491be168c0dSopenharmony_ci+  }
4492be168c0dSopenharmony_ci+
4493be168c0dSopenharmony_ci+  if (IsCustomModel()) {
4494be168c0dSopenharmony_ci+    auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_);
4495be168c0dSopenharmony_ci+    if (ret != OH_NN_SUCCESS) {
4496be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4497be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
4498be168c0dSopenharmony_ci+      return kLiteError;
4499be168c0dSopenharmony_ci+    }
4500be168c0dSopenharmony_ci+  } else {
4501be168c0dSopenharmony_ci+    SetKirinModelInputsAndOutputs(nn_model);
4502be168c0dSopenharmony_ci+    auto ret = OH_NNModel_BuildFromMetaGraph(nn_model, meta_graph_, extensions.data(), extensions.size());
4503be168c0dSopenharmony_ci+    if (ret != OH_NN_SUCCESS) {
4504be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4505be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
4506be168c0dSopenharmony_ci+      return kLiteError;
4507be168c0dSopenharmony_ci+    }
4508be168c0dSopenharmony_ci+  }
4509be168c0dSopenharmony_ci+
4510be168c0dSopenharmony_ci+  auto ret2 =  CreateFullModelKernel(model, nn_model);
4511be168c0dSopenharmony_ci+  if (ret2 != kSuccess) {
4512be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create full model kernel failed, ret: " << ret2;
4513be168c0dSopenharmony_ci+    return kLiteError;
4514be168c0dSopenharmony_ci   }
4515be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NNRTDelegate creates NNModel success.";
4516be168c0dSopenharmony_ci+  return kSuccess;
4517be168c0dSopenharmony_ci+}
4518be168c0dSopenharmony_ci+
4519be168c0dSopenharmony_ci+std::vector<OH_NN_TensorInfo> NNRTDelegate::CreateNNTensorInfos(const std::vector<uint32_t> &indices) const {
4520be168c0dSopenharmony_ci+  std::vector<OH_NN_TensorInfo> nn_tensor_infos;
4521be168c0dSopenharmony_ci+  for (auto index: indices) {
4522be168c0dSopenharmony_ci+    auto tensor = lite_graph_->all_tensors_[index];
4523be168c0dSopenharmony_ci+    auto shape = tensor->dims();
4524be168c0dSopenharmony_ci+    auto data_type = tensor->dataType();
4525be168c0dSopenharmony_ci+    auto name = tensor->name();
4526be168c0dSopenharmony_ci+    auto format = tensor->format();
4527be168c0dSopenharmony_ci 
4528be168c0dSopenharmony_ci-  OH_NNCompilation *oh_nn_compilation = nullptr;
4529be168c0dSopenharmony_ci-  oh_nn_compilation = OH_NNCompilation_Construct(oh_nnmodel);
4530be168c0dSopenharmony_ci+    OH_NN_TensorInfo info;
4531be168c0dSopenharmony_ci+    info.dataType = CastToNNRTDataType(static_cast<mindspore::DataType>(data_type));
4532be168c0dSopenharmony_ci+    info.dimensions = shape->data();
4533be168c0dSopenharmony_ci+    info.dimensionCount = shape->size();
4534be168c0dSopenharmony_ci+    strcpy(info.name, name->c_str());
4535be168c0dSopenharmony_ci+    info.format = CastToNNRTFormat(static_cast<Format>(format));
4536be168c0dSopenharmony_ci+    nn_tensor_infos.push_back(info);
4537be168c0dSopenharmony_ci+  }
4538be168c0dSopenharmony_ci+  return nn_tensor_infos;
4539be168c0dSopenharmony_ci+}
4540be168c0dSopenharmony_ci 
4541be168c0dSopenharmony_ci-  if (oh_nn_compilation == nullptr) {
4542be168c0dSopenharmony_ci+Status NNRTDelegate::SetKirinModelInputsAndOutputs(OH_NNModel *nn_model) {
4543be168c0dSopenharmony_ci+  std::vector<OH_NN_TensorInfo> inputInfos;
4544be168c0dSopenharmony_ci+  std::vector<OH_NN_TensorInfo> outputInfos;
4545be168c0dSopenharmony_ci+  auto input_infos = CreateNNTensorInfos(lite_graph_->input_indices_);
4546be168c0dSopenharmony_ci+  auto output_infos = CreateNNTensorInfos(lite_graph_->output_indices_);
4547be168c0dSopenharmony_ci+  OH_NNModel_SetInputsAndOutputsInfo(nn_model, input_infos.data(), input_infos.size(), output_infos.data(),
4548be168c0dSopenharmony_ci+                                     output_infos.size());
4549be168c0dSopenharmony_ci+  return kSuccess;
4550be168c0dSopenharmony_ci+}
4551be168c0dSopenharmony_ci+
4552be168c0dSopenharmony_ci+Status NNRTDelegate::CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model) {
4553be168c0dSopenharmony_ci+  OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model);
4554be168c0dSopenharmony_ci+  if (nn_compilation == nullptr) {
4555be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Construct NNCompilation failed";
4556be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4557be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4558be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&nn_model);
4559be168c0dSopenharmony_ci+    return kLiteError;
4560be168c0dSopenharmony_ci   }
4561be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NNRTDelegate creates NNCompilation success.";
4562be168c0dSopenharmony_ci+  MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
4563be168c0dSopenharmony_ci 
4564be168c0dSopenharmony_ci-  const size_t *allDevicesID = nullptr;
4565be168c0dSopenharmony_ci-  uint32_t device_count = 0;
4566be168c0dSopenharmony_ci-  ret_code = OH_NNDevice_GetAllDevicesID(&allDevicesID, &device_count);
4567be168c0dSopenharmony_ci-  if (ret_code != OH_NN_SUCCESS) {
4568be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "NNModel GetAllDevicesID failed, OH_NN_ReturnCode = " << ret_code;
4569be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4570be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4571be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4572be168c0dSopenharmony_ci+  auto ret_code = InitNNCompilation(nn_compilation);
4573be168c0dSopenharmony_ci+  if (ret_code != kSuccess) {
4574be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Init NNCompilation failed";
4575be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&nn_model);
4576be168c0dSopenharmony_ci+    OH_NNCompilation_Destroy(&nn_compilation);
4577be168c0dSopenharmony_ci+    return kLiteError;
4578be168c0dSopenharmony_ci   }
4579be168c0dSopenharmony_ci+  OH_NNModel_Destroy(&nn_model);
4580be168c0dSopenharmony_ci 
4581be168c0dSopenharmony_ci-  if (device_count <= 0) {
4582be168c0dSopenharmony_ci-    MS_LOG(WARNING) << "No NNRt Device found, fall back to CPU. ";
4583be168c0dSopenharmony_ci-    // OH_NNCompilation_Destroy(&oh_nn_compilation);
4584be168c0dSopenharmony_ci-    // OH_NNModel_Destroy(&oh_nnmodel);
4585be168c0dSopenharmony_ci-    return mindspore::kSuccess;
4586be168c0dSopenharmony_ci+  OH_NNExecutor *nn_executor = nullptr;
4587be168c0dSopenharmony_ci+  nn_executor = OH_NNExecutor_Construct(nn_compilation);
4588be168c0dSopenharmony_ci+  if (nn_executor == nullptr) {
4589be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
4590be168c0dSopenharmony_ci+    OH_NNCompilation_Destroy(&nn_compilation);
4591be168c0dSopenharmony_ci+    return kLiteError;
4592be168c0dSopenharmony_ci   }
4593be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NNRTDelegate GetAllDevicesID success.";
4594be168c0dSopenharmony_ci+  OH_NNCompilation_Destroy(&nn_compilation);
4595be168c0dSopenharmony_ci 
4596be168c0dSopenharmony_ci-  // check if model ops are supported
4597be168c0dSopenharmony_ci-  const bool *issupported = nullptr;
4598be168c0dSopenharmony_ci+  auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, model->inputs(), model->outputs());
4599be168c0dSopenharmony_ci+  if (nnrt_model_kernel == nullptr) {
4600be168c0dSopenharmony_ci+    OH_NNExecutor_Destroy(&nn_executor);
4601be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "new NNRTModelKernel failed";
4602be168c0dSopenharmony_ci+    return kLiteError;
4603be168c0dSopenharmony_ci+  }
4604be168c0dSopenharmony_ci+  model->Replace(model->BeginKernelIterator(), model->EndKernelIterator(), nnrt_model_kernel);
4605be168c0dSopenharmony_ci+  return kSuccess;
4606be168c0dSopenharmony_ci+}
4607be168c0dSopenharmony_ci+#endif
4608be168c0dSopenharmony_ci+
4609be168c0dSopenharmony_ci+Status NNRTDelegate::BuildNormalModel(DelegateModel<schema::Primitive> *model) {
4610be168c0dSopenharmony_ci+  MS_LOG(DEBUG) << "Start to build NNRT model.";
4611be168c0dSopenharmony_ci+  if ((lite_graph_ == nullptr) || (lite_graph_->sub_graphs_.size() > 1)) {
4612be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "LiteGraph contains more than one subgraph. NNRT does not support control-flow model yet, fallback to CPU";
4613be168c0dSopenharmony_ci+    return kSuccess;
4614be168c0dSopenharmony_ci+  }
4615be168c0dSopenharmony_ci+
4616be168c0dSopenharmony_ci+  OH_NNModel *full_model = CreateFullNNModel();
4617be168c0dSopenharmony_ci+  if (full_model == nullptr) {
4618be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "Build full NNModel failed, fallback to CPU";
4619be168c0dSopenharmony_ci+    return kSuccess;
4620be168c0dSopenharmony_ci+  }
4621be168c0dSopenharmony_ci+  std::vector<bool> op_supports = QueryOpSupports(full_model);
4622be168c0dSopenharmony_ci+  if (op_supports.empty()) {
4623be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "Query no op supports for full model, fallback to CPU";
4624be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&full_model);
4625be168c0dSopenharmony_ci+    return kSuccess;
4626be168c0dSopenharmony_ci+  }
4627be168c0dSopenharmony_ci+  auto nnrt_subgraph_ranges = GetNNRTSubgraphRanges(model, op_supports);
4628be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Found NNRT subgraph count: " << nnrt_subgraph_ranges.size();
4629be168c0dSopenharmony_ci+
4630be168c0dSopenharmony_ci+  std::vector<LiteGraph *> sub_lite_graphs;
4631be168c0dSopenharmony_ci+  auto ret = CreateLiteGraphForNNRTSubgraph(nnrt_subgraph_ranges, &sub_lite_graphs);
4632be168c0dSopenharmony_ci+  if (ret != kSuccess) {
4633be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&full_model);
4634be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "Create NNRT sub LiteGraph failed, fallback to CPU";
4635be168c0dSopenharmony_ci+    return kSuccess;
4636be168c0dSopenharmony_ci+  }
4637be168c0dSopenharmony_ci+
4638be168c0dSopenharmony_ci+  std::vector<NNRTModelKernel *> nnrt_subgraph_kernels;
4639be168c0dSopenharmony_ci+  ret = CreateNNRTSubgraphKernels(model, sub_lite_graphs, nnrt_subgraph_ranges, &nnrt_subgraph_kernels);
4640be168c0dSopenharmony_ci+  if (ret != kSuccess) {
4641be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&full_model);
4642be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "Create NNRT subgraph kernel failed, fallback to CPU";
4643be168c0dSopenharmony_ci+    return kSuccess;
4644be168c0dSopenharmony_ci+  }
4645be168c0dSopenharmony_ci+
4646be168c0dSopenharmony_ci+  ReplaceNNRTKernelsInDelegateModel(model, nnrt_subgraph_ranges, nnrt_subgraph_kernels);
4647be168c0dSopenharmony_ci+  OH_NNModel_Destroy(&full_model);
4648be168c0dSopenharmony_ci+  MS_LOG(INFO) << "NNRTDelegate build success.";
4649be168c0dSopenharmony_ci+  return kSuccess;
4650be168c0dSopenharmony_ci+}
4651be168c0dSopenharmony_ci+
4652be168c0dSopenharmony_ci+OH_NNModel *NNRTDelegate::CreateFullNNModel() {
4653be168c0dSopenharmony_ci+  if (lite_graph_ == nullptr) {
4654be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Lite graph is null";
4655be168c0dSopenharmony_ci+    return nullptr;
4656be168c0dSopenharmony_ci+  }
4657be168c0dSopenharmony_ci+
4658be168c0dSopenharmony_ci+  if (lite_graph_->sub_graphs_.empty()) {
4659be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Lite graph must have at lease one subgraph";
4660be168c0dSopenharmony_ci+    return nullptr;
4661be168c0dSopenharmony_ci+  }
4662be168c0dSopenharmony_ci+
4663be168c0dSopenharmony_ci+  OH_NNModel *nn_model = OH_NNModel_Construct();
4664be168c0dSopenharmony_ci+  if (nn_model == nullptr) {
4665be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create NNModel failed, result is nullptr";
4666be168c0dSopenharmony_ci+    return nullptr;
4667be168c0dSopenharmony_ci+  }
4668be168c0dSopenharmony_ci+
4669be168c0dSopenharmony_ci+  auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_);
4670be168c0dSopenharmony_ci+  if (ret != OH_NN_SUCCESS) {
4671be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4672be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&nn_model);
4673be168c0dSopenharmony_ci+    return nullptr;
4674be168c0dSopenharmony_ci+  }
4675be168c0dSopenharmony_ci+  return nn_model;
4676be168c0dSopenharmony_ci+}
4677be168c0dSopenharmony_ci+
4678be168c0dSopenharmony_ci+std::vector<bool> NNRTDelegate::QueryOpSupports(OH_NNModel *nn_model) {
4679be168c0dSopenharmony_ci+  const bool *is_supported = nullptr; // Note: this memory is owned by nn_model, don't free alone.
4680be168c0dSopenharmony_ci   uint32_t op_count = 0;
4681be168c0dSopenharmony_ci-  ret_code = OH_NNModel_GetAvailableOperations(oh_nnmodel, allDevicesID[0], &issupported, &op_count);
4682be168c0dSopenharmony_ci-  if (ret_code != OH_NN_SUCCESS) {
4683be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "NNModel GetAvailableOperations failed, OH_NN_ReturnCode = " << ret_code
4684be168c0dSopenharmony_ci-                  << ", maybe due to dataParcel data length limitaion. Fall back to CPU.";
4685be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4686be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4687be168c0dSopenharmony_ci-    return mindspore::kSuccess;
4688be168c0dSopenharmony_ci+  auto ret = OH_NNModel_GetAvailableOperations(nn_model, nnrt_device_info_.device_id_, &is_supported, &op_count);
4689be168c0dSopenharmony_ci+  if (ret != OH_NN_SUCCESS) {
4690be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "NNModel GetAvailableOperations failed, ret: " << ret
4691be168c0dSopenharmony_ci+                  << ", maybe caused by dataParcel data length limitation";
4692be168c0dSopenharmony_ci+    return {};
4693be168c0dSopenharmony_ci   }
4694be168c0dSopenharmony_ci-  uint32_t supported_op_count = 0;
4695be168c0dSopenharmony_ci-  for (uint32_t i = 0; i < op_count; i++) {
4696be168c0dSopenharmony_ci-    if (issupported[i]) {
4697be168c0dSopenharmony_ci-      supported_op_count++;
4698be168c0dSopenharmony_ci+  std::vector<bool> op_supports(is_supported, is_supported + op_count);
4699be168c0dSopenharmony_ci+  return op_supports;
4700be168c0dSopenharmony_ci+}
4701be168c0dSopenharmony_ci+
4702be168c0dSopenharmony_ci+/* Find continuous sub-sequence in op_supports. */
4703be168c0dSopenharmony_ci+std::vector<NNRTOpRange> NNRTDelegate::GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model,
4704be168c0dSopenharmony_ci+                                                             const std::vector<bool> &op_supports) {
4705be168c0dSopenharmony_ci+  std::vector<NNRTOpRange> nnrt_subgraph_ranges;
4706be168c0dSopenharmony_ci+  NNRTOpRange op_range;
4707be168c0dSopenharmony_ci+  bool start_count = false;
4708be168c0dSopenharmony_ci+  for (size_t i = 0; i < op_supports.size(); i++) {
4709be168c0dSopenharmony_ci+    if (op_supports[i]) {
4710be168c0dSopenharmony_ci+      if (start_count == false) {
4711be168c0dSopenharmony_ci+        start_count = true;
4712be168c0dSopenharmony_ci+        op_range.begin_index_ = i;
4713be168c0dSopenharmony_ci+        op_range.begin_iter_ = model->BeginKernelIterator() + i;
4714be168c0dSopenharmony_ci+      }
4715be168c0dSopenharmony_ci+    } else {
4716be168c0dSopenharmony_ci+      if (start_count == true) {
4717be168c0dSopenharmony_ci+        start_count = false;
4718be168c0dSopenharmony_ci+        op_range.end_index_ = i;
4719be168c0dSopenharmony_ci+        op_range.end_iter_ = model->BeginKernelIterator() + i;
4720be168c0dSopenharmony_ci+        nnrt_subgraph_ranges.push_back(op_range);
4721be168c0dSopenharmony_ci+      }
4722be168c0dSopenharmony_ci     }
4723be168c0dSopenharmony_ci   }
4724be168c0dSopenharmony_ci-  if (op_count != supported_op_count) {
4725be168c0dSopenharmony_ci-    MS_LOG(WARNING) << "this model has " << op_count << "ops, but NNRT only support " << supported_op_count
4726be168c0dSopenharmony_ci-                    << " ops, fall back to CPU.";
4727be168c0dSopenharmony_ci-    // must support all op, else fall back to CPU
4728be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4729be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4730be168c0dSopenharmony_ci-    return mindspore::kSuccess;
4731be168c0dSopenharmony_ci+  // handle last true subsequence
4732be168c0dSopenharmony_ci+  if (start_count == true) {
4733be168c0dSopenharmony_ci+    op_range.end_index_ = op_supports.size();
4734be168c0dSopenharmony_ci+    op_range.end_iter_ = model->EndKernelIterator();
4735be168c0dSopenharmony_ci+    nnrt_subgraph_ranges.push_back(op_range);
4736be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Schedule NNRT subgraph range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")";
4737be168c0dSopenharmony_ci   }
4738be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NNRtDelegate supports all op in this model.";
4739be168c0dSopenharmony_ci+  return nnrt_subgraph_ranges;
4740be168c0dSopenharmony_ci+}
4741be168c0dSopenharmony_ci+
4742be168c0dSopenharmony_ci+/**
4743be168c0dSopenharmony_ci+ * This method ONLY works when the follow pre-conditions are satisfied:
4744be168c0dSopenharmony_ci+ * 1. The node order of lite_graph_->all_nodes should be consistent with DelegateModel sequence.
4745be168c0dSopenharmony_ci+ *  This ensures the kernel replacement in DelegateModel based on the re-organizing info from lite_graph_ is correct.
4746be168c0dSopenharmony_ci+ * 2. The node indices of lite_graph_->sub_graphs[0].node_indices should be monotonically increasing from 0 to size - 1.
4747be168c0dSopenharmony_ci+ */
4748be168c0dSopenharmony_ci+Status NNRTDelegate::CreateLiteGraphForNNRTSubgraph(
4749be168c0dSopenharmony_ci+    const std::vector<NNRTOpRange> &nnrt_op_ranges,
4750be168c0dSopenharmony_ci+    std::vector<LiteGraph *> *sub_lite_graphs) {
4751be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Start creating LiteGraph for NNRT subgraph";
4752be168c0dSopenharmony_ci+  for (const auto &op_range: nnrt_op_ranges) {
4753be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Process op range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")";
4754be168c0dSopenharmony_ci+    LiteGraph *sub_lite_graph = new (std::nothrow)LiteGraph;
4755be168c0dSopenharmony_ci+    if (sub_lite_graph == nullptr) {
4756be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Allocate LiteGraph failed";
4757be168c0dSopenharmony_ci+      return kLiteError;
4758be168c0dSopenharmony_ci+    }
4759be168c0dSopenharmony_ci+    sub_lite_graph->name_ = lite_graph_->name_;
4760be168c0dSopenharmony_ci+    sub_lite_graph->version_ = lite_graph_->version_;
4761be168c0dSopenharmony_ci 
4762be168c0dSopenharmony_ci-  ret_code = OH_NNCompilation_SetDevice(oh_nn_compilation, allDevicesID[0]);
4763be168c0dSopenharmony_ci+    auto sub_graph = new (std::nothrow)LiteGraph::SubGraph;
4764be168c0dSopenharmony_ci+    if (sub_graph == nullptr) {
4765be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Allocate SubGraph failed";
4766be168c0dSopenharmony_ci+      return kLiteError;
4767be168c0dSopenharmony_ci+    }
4768be168c0dSopenharmony_ci+    sub_graph->name_ = lite_graph_->name_;
4769be168c0dSopenharmony_ci+    sub_lite_graph->sub_graphs_.push_back(sub_graph);
4770be168c0dSopenharmony_ci 
4771be168c0dSopenharmony_ci+    // deal with all_nodes
4772be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Assemble all_nodes...";
4773be168c0dSopenharmony_ci+    int new_node_index = 0;
4774be168c0dSopenharmony_ci+    std::map<uint32_t, schema::Tensor *> in_tensor_index_map;
4775be168c0dSopenharmony_ci+    std::map<uint32_t, schema::Tensor *> out_tensor_index_map;
4776be168c0dSopenharmony_ci+    for (size_t index = op_range.begin_index_; index < op_range.end_index_; index++) {
4777be168c0dSopenharmony_ci+      LiteGraph::Node *node = new (std::nothrow)LiteGraph::Node;
4778be168c0dSopenharmony_ci+      if (node == nullptr) {
4779be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Allocate Node failed";
4780be168c0dSopenharmony_ci+        return kLiteError;
4781be168c0dSopenharmony_ci+      }
4782be168c0dSopenharmony_ci+      *node = *lite_graph_->all_nodes_[index];
4783be168c0dSopenharmony_ci+      sub_lite_graph->all_nodes_.push_back(node);
4784be168c0dSopenharmony_ci+      sub_graph->node_indices_.push_back(new_node_index++);
4785be168c0dSopenharmony_ci+
4786be168c0dSopenharmony_ci+      for (auto i: node->input_indices_) {
4787be168c0dSopenharmony_ci+        in_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]);
4788be168c0dSopenharmony_ci+      }
4789be168c0dSopenharmony_ci+      for (auto i: node->output_indices_) {
4790be168c0dSopenharmony_ci+        out_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]);
4791be168c0dSopenharmony_ci+      }
4792be168c0dSopenharmony_ci+    }
4793be168c0dSopenharmony_ci+
4794be168c0dSopenharmony_ci+    // deal with all_tensors
4795be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Assemble all_tensors...";
4796be168c0dSopenharmony_ci+    std::set<schema::Tensor *> tensors;
4797be168c0dSopenharmony_ci+    for (auto iter: in_tensor_index_map) {
4798be168c0dSopenharmony_ci+      tensors.emplace(iter.second);
4799be168c0dSopenharmony_ci+    }
4800be168c0dSopenharmony_ci+    for (auto iter: out_tensor_index_map) {
4801be168c0dSopenharmony_ci+      tensors.emplace(iter.second);
4802be168c0dSopenharmony_ci+    }
4803be168c0dSopenharmony_ci+
4804be168c0dSopenharmony_ci+    uint32_t new_index = 0;
4805be168c0dSopenharmony_ci+    std::map<schema::Tensor *, uint32_t> new_tensor_maps;
4806be168c0dSopenharmony_ci+    for (auto tensor: tensors) {
4807be168c0dSopenharmony_ci+      new_tensor_maps.emplace(tensor, new_index++);
4808be168c0dSopenharmony_ci+    }
4809be168c0dSopenharmony_ci+
4810be168c0dSopenharmony_ci+    sub_lite_graph->all_tensors_ = std::vector<schema::Tensor *>(tensors.begin(), tensors.end());
4811be168c0dSopenharmony_ci+
4812be168c0dSopenharmony_ci+    // deal with every node's input/output indices
4813be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Set input/output indices of each node...";
4814be168c0dSopenharmony_ci+    for (auto node: sub_lite_graph->all_nodes_) {
4815be168c0dSopenharmony_ci+      for (auto &index : node->input_indices_) {
4816be168c0dSopenharmony_ci+        index = new_tensor_maps.at(in_tensor_index_map.at(index));
4817be168c0dSopenharmony_ci+      }
4818be168c0dSopenharmony_ci+      for (auto &index : node->output_indices_) {
4819be168c0dSopenharmony_ci+        index = new_tensor_maps.at(out_tensor_index_map.at(index));
4820be168c0dSopenharmony_ci+      }
4821be168c0dSopenharmony_ci+    }
4822be168c0dSopenharmony_ci+
4823be168c0dSopenharmony_ci+    // deal with subgraph's input/output indices
4824be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Set input/output indices of each subgraph...";
4825be168c0dSopenharmony_ci+    sub_graph->tensor_indices_ = std::vector<uint32_t>(tensors.size());
4826be168c0dSopenharmony_ci+    std::iota(sub_graph->tensor_indices_.begin(), sub_graph->tensor_indices_.end(), 0U);
4827be168c0dSopenharmony_ci+
4828be168c0dSopenharmony_ci+    for (auto iter: in_tensor_index_map) {
4829be168c0dSopenharmony_ci+      auto new_tensor_index = new_tensor_maps[iter.second];
4830be168c0dSopenharmony_ci+      MS_LOG(DEBUG) << "handle input: old: " << iter.first << ", new: " << new_tensor_index << std::endl;
4831be168c0dSopenharmony_ci+      if (IsConstTensor(*iter.second)) {
4832be168c0dSopenharmony_ci+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl;
4833be168c0dSopenharmony_ci+        continue;
4834be168c0dSopenharmony_ci+      }
4835be168c0dSopenharmony_ci+
4836be168c0dSopenharmony_ci+      bool is_subgraph_input = true;
4837be168c0dSopenharmony_ci+      for (auto node: sub_lite_graph->all_nodes_) {
4838be168c0dSopenharmony_ci+        if (std::find(node->output_indices_.begin(), node->output_indices_.end(), new_tensor_index) !=
4839be168c0dSopenharmony_ci+            node->output_indices_.end()) {
4840be168c0dSopenharmony_ci+          is_subgraph_input = false;
4841be168c0dSopenharmony_ci+          MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is not subgraph input." << std::endl;
4842be168c0dSopenharmony_ci+          break;
4843be168c0dSopenharmony_ci+        }
4844be168c0dSopenharmony_ci+      }
4845be168c0dSopenharmony_ci+      if (is_subgraph_input) {
4846be168c0dSopenharmony_ci+        sub_graph->input_indices_.push_back(new_tensor_index);
4847be168c0dSopenharmony_ci+        MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph input." << std::endl;
4848be168c0dSopenharmony_ci+      }
4849be168c0dSopenharmony_ci+    }
4850be168c0dSopenharmony_ci+
4851be168c0dSopenharmony_ci+    for (auto iter: out_tensor_index_map) {
4852be168c0dSopenharmony_ci+      int new_tensor_index = new_tensor_maps.at(iter.second);
4853be168c0dSopenharmony_ci+      MS_LOG(DEBUG) << "handle output: old: " << iter.first << ", new: " << new_tensor_index << std::endl;
4854be168c0dSopenharmony_ci+      if (IsConstTensor(*iter.second)) {
4855be168c0dSopenharmony_ci+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl;
4856be168c0dSopenharmony_ci+        continue;
4857be168c0dSopenharmony_ci+      }
4858be168c0dSopenharmony_ci+
4859be168c0dSopenharmony_ci+      bool is_subgraph_output = false;
4860be168c0dSopenharmony_ci+      for (size_t i = 0; i < lite_graph_->all_nodes_.size(); i++) {
4861be168c0dSopenharmony_ci+        if ((i >= op_range.begin_index_) && (i < op_range.end_index_)) {
4862be168c0dSopenharmony_ci+          continue;
4863be168c0dSopenharmony_ci+        }
4864be168c0dSopenharmony_ci+        auto node = lite_graph_->all_nodes_[i];
4865be168c0dSopenharmony_ci+        if (std::find(node->input_indices_.begin(), node->input_indices_.end(), iter.first) !=
4866be168c0dSopenharmony_ci+            node->input_indices_.end()) { // As the input of node which does not belong to the subgraph.
4867be168c0dSopenharmony_ci+          is_subgraph_output = true;
4868be168c0dSopenharmony_ci+          MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is original subgraph output. node: " << node->primitive_ << std::endl;
4869be168c0dSopenharmony_ci+          break;
4870be168c0dSopenharmony_ci+        }
4871be168c0dSopenharmony_ci+      }
4872be168c0dSopenharmony_ci+      bool is_graph_output = (std::find(lite_graph_->output_indices_.begin(),lite_graph_->output_indices_.end(),
4873be168c0dSopenharmony_ci+                                        iter.first) != lite_graph_->output_indices_.end());
4874be168c0dSopenharmony_ci+      if (is_graph_output) {
4875be168c0dSopenharmony_ci+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is graph output." << std::endl;
4876be168c0dSopenharmony_ci+      }
4877be168c0dSopenharmony_ci+      if (is_subgraph_output || is_graph_output) {
4878be168c0dSopenharmony_ci+        sub_graph->output_indices_.push_back(new_tensor_index);
4879be168c0dSopenharmony_ci+        MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph output." << std::endl;
4880be168c0dSopenharmony_ci+      }
4881be168c0dSopenharmony_ci+    }
4882be168c0dSopenharmony_ci+
4883be168c0dSopenharmony_ci+    // deal with full-graph's input/output indices
4884be168c0dSopenharmony_ci+    sub_lite_graph->input_indices_ = sub_graph->input_indices_;
4885be168c0dSopenharmony_ci+    sub_lite_graph->output_indices_ = sub_graph->output_indices_;
4886be168c0dSopenharmony_ci+    sub_lite_graphs->push_back(sub_lite_graph);
4887be168c0dSopenharmony_ci+  }
4888be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Finished creating LiteGraph for NNRT subgraph";
4889be168c0dSopenharmony_ci+  return kSuccess;
4890be168c0dSopenharmony_ci+}
4891be168c0dSopenharmony_ci+
4892be168c0dSopenharmony_ci+struct TensorLocation {
4893be168c0dSopenharmony_ci+  uint32_t node_index; // the index of node which the tensor belongs to.
4894be168c0dSopenharmony_ci+  uint32_t tensor_index; // the index of node in/out tensors which the tensor is located at.
4895be168c0dSopenharmony_ci+};
4896be168c0dSopenharmony_ci+
4897be168c0dSopenharmony_ci+Status NNRTDelegate::InitNNCompilation(OH_NNCompilation *nn_compilation) const {
4898be168c0dSopenharmony_ci+  auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_);
4899be168c0dSopenharmony_ci   if (ret_code != OH_NN_SUCCESS) {
4900be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code;
4901be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4902be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4903be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4904be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code;
4905be168c0dSopenharmony_ci+    return kLiteError;
4906be168c0dSopenharmony_ci+  }
4907be168c0dSopenharmony_ci+  ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation,
4908be168c0dSopenharmony_ci+                                                 (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_));
4909be168c0dSopenharmony_ci+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4910be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code;
4911be168c0dSopenharmony_ci+    return kLiteError;
4912be168c0dSopenharmony_ci+  }
4913be168c0dSopenharmony_ci+  ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_));
4914be168c0dSopenharmony_ci+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4915be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code;
4916be168c0dSopenharmony_ci+    return kLiteError;
4917be168c0dSopenharmony_ci+  }
4918be168c0dSopenharmony_ci+  ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_);
4919be168c0dSopenharmony_ci+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4920be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code;
4921be168c0dSopenharmony_ci+    return kLiteError;
4922be168c0dSopenharmony_ci   }
4923be168c0dSopenharmony_ci 
4924be168c0dSopenharmony_ci-  ret_code = OH_NNCompilation_Build(oh_nn_compilation);
4925be168c0dSopenharmony_ci+  if (!cache_path_.empty()) { // Set cache path if user indeed set it.
4926be168c0dSopenharmony_ci+    ret_code = OH_NNCompilation_SetCache(nn_compilation, cache_path_.c_str(), cache_version_);
4927be168c0dSopenharmony_ci+    if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4928be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code;
4929be168c0dSopenharmony_ci+      return kLiteError;
4930be168c0dSopenharmony_ci+    }
4931be168c0dSopenharmony_ci+  }
4932be168c0dSopenharmony_ci 
4933be168c0dSopenharmony_ci+  ret_code = OH_NNCompilation_Build(nn_compilation);
4934be168c0dSopenharmony_ci   if (ret_code != OH_NN_SUCCESS) {
4935be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Build NNCompilation failed, OH_NN_ReturnCode = " << ret_code;
4936be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4937be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4938be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4939be168c0dSopenharmony_ci-  }
4940be168c0dSopenharmony_ci-
4941be168c0dSopenharmony_ci-  MS_LOG(DEBUG) << "NNRTDelegate SetDevice success.";
4942be168c0dSopenharmony_ci-
4943be168c0dSopenharmony_ci-  OH_NNExecutor *oh_nn_executor = nullptr;
4944be168c0dSopenharmony_ci-  oh_nn_executor = OH_NNExecutor_Construct(oh_nn_compilation);
4945be168c0dSopenharmony_ci-  if (oh_nn_executor == nullptr) {
4946be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Construct NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code;
4947be168c0dSopenharmony_ci-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4948be168c0dSopenharmony_ci-    OH_NNModel_Destroy(&oh_nnmodel);
4949be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4950be168c0dSopenharmony_ci-  }
4951be168c0dSopenharmony_ci-  MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success.";
4952be168c0dSopenharmony_ci-  mindspore::Status prepare_data_ret;
4953be168c0dSopenharmony_ci-  auto nnr_model_kernel = new (std::nothrow) NNRTModelKernel(oh_nn_executor, model->inputs(), model->outputs());
4954be168c0dSopenharmony_ci-  if (nnr_model_kernel == nullptr) {
4955be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "new NNRTModelKernel failed";
4956be168c0dSopenharmony_ci-    return mindspore::kLiteError;
4957be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code;
4958be168c0dSopenharmony_ci+    return kLiteError;
4959be168c0dSopenharmony_ci   }
4960be168c0dSopenharmony_ci-  OH_NNCompilation_Destroy(&oh_nn_compilation);
4961be168c0dSopenharmony_ci-  OH_NNModel_Destroy(&oh_nnmodel);
4962be168c0dSopenharmony_ci-  KernelIter from = model->BeginKernelIterator();
4963be168c0dSopenharmony_ci-  KernelIter end = model->EndKernelIterator();
4964be168c0dSopenharmony_ci-  model->Replace(from, end, nnr_model_kernel);
4965be168c0dSopenharmony_ci+  return kSuccess;
4966be168c0dSopenharmony_ci+}
4967be168c0dSopenharmony_ci+
4968be168c0dSopenharmony_ci+Status NNRTDelegate::CreateNNRTSubgraphKernels(DelegateModel<schema::Primitive> *model,
4969be168c0dSopenharmony_ci+                                               const std::vector<LiteGraph *> &sub_lite_graphs, const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
4970be168c0dSopenharmony_ci+                                               std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels) {
4971be168c0dSopenharmony_ci+  for (size_t i = 0; i < sub_lite_graphs.size(); i++) {
4972be168c0dSopenharmony_ci+    auto sub_lite_graph = sub_lite_graphs[i];
4973be168c0dSopenharmony_ci+
4974be168c0dSopenharmony_ci+    OH_NNModel *nn_model = OH_NNModel_Construct();
4975be168c0dSopenharmony_ci+    auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, sub_lite_graph);
4976be168c0dSopenharmony_ci+    if (ret != OH_NN_SUCCESS) {
4977be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4978be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
4979be168c0dSopenharmony_ci+      return kLiteError;
4980be168c0dSopenharmony_ci+    }
4981be168c0dSopenharmony_ci 
4982be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NNRTDelegate build  success.";
4983be168c0dSopenharmony_ci-  return mindspore::kSuccess;
4984be168c0dSopenharmony_ci+    OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model);
4985be168c0dSopenharmony_ci+    if (nn_compilation == nullptr) {
4986be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Construct NNCompilation failed";
4987be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
4988be168c0dSopenharmony_ci+      return kLiteError;
4989be168c0dSopenharmony_ci+    }
4990be168c0dSopenharmony_ci+    MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
4991be168c0dSopenharmony_ci+
4992be168c0dSopenharmony_ci+    auto ret_code = InitNNCompilation(nn_compilation);
4993be168c0dSopenharmony_ci+    if (ret_code != kSuccess) {
4994be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Init NNCompilation failed";
4995be168c0dSopenharmony_ci+      OH_NNCompilation_Destroy(&nn_compilation);
4996be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
4997be168c0dSopenharmony_ci+      return kLiteError;
4998be168c0dSopenharmony_ci+    }
4999be168c0dSopenharmony_ci+
5000be168c0dSopenharmony_ci+    OH_NNExecutor *nn_executor = nullptr;
5001be168c0dSopenharmony_ci+    nn_executor = OH_NNExecutor_Construct(nn_compilation);
5002be168c0dSopenharmony_ci+    if (nn_executor == nullptr) {
5003be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
5004be168c0dSopenharmony_ci+      OH_NNCompilation_Destroy(&nn_compilation);
5005be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
5006be168c0dSopenharmony_ci+      return kLiteError;
5007be168c0dSopenharmony_ci+    }
5008be168c0dSopenharmony_ci+    MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success.";
5009be168c0dSopenharmony_ci+
5010be168c0dSopenharmony_ci+    bool format_not_support = false;
5011be168c0dSopenharmony_ci+    std::vector<MSTensor> in_tensors;
5012be168c0dSopenharmony_ci+    for (auto index: sub_lite_graph->sub_graphs_[0]->input_indices_) {
5013be168c0dSopenharmony_ci+      TensorLocation location;
5014be168c0dSopenharmony_ci+      for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) {
5015be168c0dSopenharmony_ci+        auto node = sub_lite_graph->all_nodes_[node_index];
5016be168c0dSopenharmony_ci+        auto iter = std::find(node->input_indices_.begin(), node->input_indices_.end(), index);
5017be168c0dSopenharmony_ci+        if (iter != node->input_indices_.end()) {
5018be168c0dSopenharmony_ci+          uint32_t tensor_index = iter - node->input_indices_.begin();
5019be168c0dSopenharmony_ci+          location.node_index = node_index;
5020be168c0dSopenharmony_ci+          location.tensor_index = tensor_index;
5021be168c0dSopenharmony_ci+          MS_LOG(INFO) << "Found graph input index: " << index << " is the " << tensor_index << "th input of the node " << node->primitive_;
5022be168c0dSopenharmony_ci+          break;
5023be168c0dSopenharmony_ci+        }
5024be168c0dSopenharmony_ci+      }
5025be168c0dSopenharmony_ci+      KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index;
5026be168c0dSopenharmony_ci+      in_tensors.push_back((*kernel_iter)->inputs()[location.tensor_index]);
5027be168c0dSopenharmony_ci+      if (in_tensors.back().format() != Format::NHWC) {
5028be168c0dSopenharmony_ci+        format_not_support = true;
5029be168c0dSopenharmony_ci+        break ;
5030be168c0dSopenharmony_ci+      }
5031be168c0dSopenharmony_ci+    }
5032be168c0dSopenharmony_ci+
5033be168c0dSopenharmony_ci+    std::vector<MSTensor> out_tensors;
5034be168c0dSopenharmony_ci+    for (auto index: sub_lite_graph->sub_graphs_[0]->output_indices_) {
5035be168c0dSopenharmony_ci+      TensorLocation location;
5036be168c0dSopenharmony_ci+      for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) {
5037be168c0dSopenharmony_ci+        auto node = sub_lite_graph->all_nodes_[node_index];
5038be168c0dSopenharmony_ci+        auto iter = std::find(node->output_indices_.begin(), node->output_indices_.end(), index);
5039be168c0dSopenharmony_ci+        if (iter != node->output_indices_.end()) {
5040be168c0dSopenharmony_ci+          uint32_t tensor_index = iter - node->output_indices_.begin();
5041be168c0dSopenharmony_ci+          location.node_index = node_index;
5042be168c0dSopenharmony_ci+          location.tensor_index = tensor_index;
5043be168c0dSopenharmony_ci+          MS_LOG(INFO) << "Found graph output index: " << index << " is the " << tensor_index << "th output of the node " << node->primitive_;
5044be168c0dSopenharmony_ci+          break;
5045be168c0dSopenharmony_ci+        }
5046be168c0dSopenharmony_ci+      }
5047be168c0dSopenharmony_ci+      KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index;
5048be168c0dSopenharmony_ci+      out_tensors.push_back((*kernel_iter)->outputs()[location.tensor_index]);
5049be168c0dSopenharmony_ci+      if (out_tensors.back().format() != Format::NHWC) {
5050be168c0dSopenharmony_ci+        format_not_support = true;
5051be168c0dSopenharmony_ci+        break ;
5052be168c0dSopenharmony_ci+      }
5053be168c0dSopenharmony_ci+    }
5054be168c0dSopenharmony_ci+    if (format_not_support) {
5055be168c0dSopenharmony_ci+      MS_LOG(WARNING) << "Not support in/out tensor format, skip this subgraph";
5056be168c0dSopenharmony_ci+      OH_NNCompilation_Destroy(&nn_compilation);
5057be168c0dSopenharmony_ci+      OH_NNModel_Destroy(&nn_model);
5058be168c0dSopenharmony_ci+      nnrt_subgraph_kernels->push_back(nullptr);
5059be168c0dSopenharmony_ci+      continue ;
5060be168c0dSopenharmony_ci+    }
5061be168c0dSopenharmony_ci+
5062be168c0dSopenharmony_ci+    auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, in_tensors, out_tensors);
5063be168c0dSopenharmony_ci+    if (nnrt_model_kernel == nullptr) {
5064be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "new NNRTModelKernel failed";
5065be168c0dSopenharmony_ci+      return kLiteError;
5066be168c0dSopenharmony_ci+    }
5067be168c0dSopenharmony_ci+    OH_NNCompilation_Destroy(&nn_compilation);
5068be168c0dSopenharmony_ci+    OH_NNModel_Destroy(&nn_model);
5069be168c0dSopenharmony_ci+    nnrt_subgraph_kernels->push_back(nnrt_model_kernel);
5070be168c0dSopenharmony_ci+  }
5071be168c0dSopenharmony_ci+  return kSuccess;
5072be168c0dSopenharmony_ci }
5073be168c0dSopenharmony_ci 
5074be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::Init() {
5075be168c0dSopenharmony_ci-  MS_LOG(DEBUG) << "NNRTDelegate init success.";
5076be168c0dSopenharmony_ci-  return mindspore::kSuccess;
5077be168c0dSopenharmony_ci+void NNRTDelegate::ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model,
5078be168c0dSopenharmony_ci+                                       const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5079be168c0dSopenharmony_ci+                                       const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels) {
5080be168c0dSopenharmony_ci+  // Here we perform the replacement from back to front intentionally! If replace from front to end, the kernel
5081be168c0dSopenharmony_ci+  // sequence would shrink and the later begin_iter_/end_iter_ may be erased already.
5082be168c0dSopenharmony_ci+  for (int i = nnrt_subgraph_ranges.size() - 1; i >= 0; i--) {
5083be168c0dSopenharmony_ci+    if (nnrt_subgraph_kernels[i] == nullptr) {
5084be168c0dSopenharmony_ci+      continue;
5085be168c0dSopenharmony_ci+    }
5086be168c0dSopenharmony_ci+    auto from = nnrt_subgraph_ranges[i].begin_iter_;
5087be168c0dSopenharmony_ci+    auto end = nnrt_subgraph_ranges[i].end_iter_;
5088be168c0dSopenharmony_ci+    (void)model->Replace(from, end, nnrt_subgraph_kernels[i]);
5089be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Replace nnrt subgraph kernel in range: [" << (from - model->BeginKernelIterator())
5090be168c0dSopenharmony_ci+      << ", " << (end - model->BeginKernelIterator()) << ")";
5091be168c0dSopenharmony_ci+  }
5092be168c0dSopenharmony_ci }
5093be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model,
5094be168c0dSopenharmony_ci-                                                         OH_NNExecutor *oh_nn_executor) {
5095be168c0dSopenharmony_ci+
5096be168c0dSopenharmony_ci+Status NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model,
5097be168c0dSopenharmony_ci+                                   OH_NNExecutor *oh_nn_executor) {
5098be168c0dSopenharmony_ci   auto input_tensors = model->inputs();
5099be168c0dSopenharmony_ci   for (size_t i = 0; i < input_tensors.size(); i++) {
5100be168c0dSopenharmony_ci     auto tensor = input_tensors[i];
5101be168c0dSopenharmony_ci@@ -161,10 +654,10 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5102be168c0dSopenharmony_ci     std::vector<double> scale;
5103be168c0dSopenharmony_ci     std::vector<int32_t> zero_point;
5104be168c0dSopenharmony_ci     if (!tmp_quant_param.empty()) {
5105be168c0dSopenharmony_ci-      quant_param = new (std::nothrow) OH_NN_QuantParam;
5106be168c0dSopenharmony_ci+      quant_param = new(std::nothrow) OH_NN_QuantParam;
5107be168c0dSopenharmony_ci       if (quant_param == nullptr) {
5108be168c0dSopenharmony_ci         MS_LOG(ERROR) << "new OH_NN_QuantParam failed.";
5109be168c0dSopenharmony_ci-        return mindspore::kLiteError;
5110be168c0dSopenharmony_ci+        return kLiteError;
5111be168c0dSopenharmony_ci       }
5112be168c0dSopenharmony_ci       for (auto qparam : tmp_quant_param) {
5113be168c0dSopenharmony_ci         bit_num.emplace_back(qparam.bit_num);
5114be168c0dSopenharmony_ci@@ -176,12 +669,12 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5115be168c0dSopenharmony_ci       quant_param->scale = scale.data();
5116be168c0dSopenharmony_ci       quant_param->zeroPoint = zero_point.data();
5117be168c0dSopenharmony_ci     }
5118be168c0dSopenharmony_ci-    auto oprend = new (std::nothrow) OH_NN_Tensor;
5119be168c0dSopenharmony_ci+    auto oprend = new(std::nothrow) OH_NN_Tensor;
5120be168c0dSopenharmony_ci     if (oprend == nullptr) {
5121be168c0dSopenharmony_ci       MS_LOG(ERROR) << "new OH_NN_Tensor Failed";
5122be168c0dSopenharmony_ci-      return mindspore::kLiteError;
5123be168c0dSopenharmony_ci+      return kLiteError;
5124be168c0dSopenharmony_ci     }
5125be168c0dSopenharmony_ci-    oprend->dataType = ConvertDataType(tensor.DataType());
5126be168c0dSopenharmony_ci+    oprend->dataType = CastToNNRTDataType(tensor.DataType());
5127be168c0dSopenharmony_ci     oprend->dimensionCount = tensor_shape.size();
5128be168c0dSopenharmony_ci 
5129be168c0dSopenharmony_ci     std::vector<int32_t> dimensions_list;
5130be168c0dSopenharmony_ci@@ -191,14 +684,14 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5131be168c0dSopenharmony_ci       } else {
5132be168c0dSopenharmony_ci         MS_LOG(ERROR) << "NNExecutor SetInput failed,tensor dimension is is too large, max dim = " << INT32_MAX
5133be168c0dSopenharmony_ci                       << ", but get dimension = " << shape;
5134be168c0dSopenharmony_ci-        return mindspore::kLiteError;
5135be168c0dSopenharmony_ci+        return kLiteError;
5136be168c0dSopenharmony_ci       }
5137be168c0dSopenharmony_ci     }
5138be168c0dSopenharmony_ci     oprend->dimensions = dimensions_list.data();
5139be168c0dSopenharmony_ci     oprend->quantParam = quant_param;
5140be168c0dSopenharmony_ci     oprend->type = OH_NN_TENSOR;
5141be168c0dSopenharmony_ci     OH_NN_ReturnCode ret_code =
5142be168c0dSopenharmony_ci-      OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5143be168c0dSopenharmony_ci+        OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5144be168c0dSopenharmony_ci     delete (oprend);
5145be168c0dSopenharmony_ci 
5146be168c0dSopenharmony_ci     if (!tmp_quant_param.empty()) {
5147be168c0dSopenharmony_ci@@ -209,70 +702,41 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5148be168c0dSopenharmony_ci     if (ret_code != OH_NN_SUCCESS) {
5149be168c0dSopenharmony_ci       MS_LOG(ERROR) << "NNExecutor SetInput failed, current input tensor is" << tensor.Name()
5150be168c0dSopenharmony_ci                     << "OH_NN_ReturnCode = " << ret_code;
5151be168c0dSopenharmony_ci-      return mindspore::kLiteError;
5152be168c0dSopenharmony_ci+      return kLiteError;
5153be168c0dSopenharmony_ci     }
5154be168c0dSopenharmony_ci   }
5155be168c0dSopenharmony_ci+  return kSuccess;
5156be168c0dSopenharmony_ci+}
5157be168c0dSopenharmony_ci+
5158be168c0dSopenharmony_ci+OH_NN_DataType NNRTDelegate::CastToNNRTDataType(DataType data_type) {
5159be168c0dSopenharmony_ci+  const std::unordered_map<DataType, OH_NN_DataType> kDataTypeMap = {
5160be168c0dSopenharmony_ci+      {DataType::kNumberTypeBool, OH_NN_BOOL},
5161be168c0dSopenharmony_ci+      {DataType::kNumberTypeInt8, OH_NN_INT8},
5162be168c0dSopenharmony_ci+      {DataType::kNumberTypeInt16, OH_NN_INT16},
5163be168c0dSopenharmony_ci+      {DataType::kNumberTypeInt32, OH_NN_INT32},
5164be168c0dSopenharmony_ci+      {DataType::kNumberTypeInt64, OH_NN_INT64},
5165be168c0dSopenharmony_ci+      {DataType::kNumberTypeUInt8, OH_NN_UINT8},
5166be168c0dSopenharmony_ci+      {DataType::kNumberTypeUInt16, OH_NN_UINT16},
5167be168c0dSopenharmony_ci+      {DataType::kNumberTypeUInt32, OH_NN_UINT32},
5168be168c0dSopenharmony_ci+      {DataType::kNumberTypeUInt64, OH_NN_UINT64},
5169be168c0dSopenharmony_ci+      {DataType::kNumberTypeFloat16, OH_NN_FLOAT16},
5170be168c0dSopenharmony_ci+      {DataType::kNumberTypeFloat32, OH_NN_FLOAT32},
5171be168c0dSopenharmony_ci+      {DataType::kNumberTypeFloat64, OH_NN_FLOAT64},
5172be168c0dSopenharmony_ci+  };
5173be168c0dSopenharmony_ci 
5174be168c0dSopenharmony_ci-  return mindspore::kSuccess;
5175be168c0dSopenharmony_ci+  auto iter = kDataTypeMap.find(data_type);
5176be168c0dSopenharmony_ci+  if (iter == kDataTypeMap.end()) {
5177be168c0dSopenharmony_ci+    return OH_NN_UNKNOWN;
5178be168c0dSopenharmony_ci+  }
5179be168c0dSopenharmony_ci+  return iter->second;
5180be168c0dSopenharmony_ci }
5181be168c0dSopenharmony_ci-OH_NN_DataType mindspore::NNRTDelegate::ConvertDataType(mindspore::DataType data_type) {
5182be168c0dSopenharmony_ci-  OH_NN_DataType oh_data_type;
5183be168c0dSopenharmony_ci-  switch (data_type) {
5184be168c0dSopenharmony_ci-    case mindspore::DataType::kTypeUnknown:
5185be168c0dSopenharmony_ci-    case mindspore::DataType::kObjectTypeString:
5186be168c0dSopenharmony_ci-    case mindspore::DataType::kObjectTypeList:
5187be168c0dSopenharmony_ci-    case mindspore::DataType::kObjectTypeTuple:
5188be168c0dSopenharmony_ci-    case mindspore::DataType::kObjectTypeTensorType:
5189be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeBegin:
5190be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeEnd:
5191be168c0dSopenharmony_ci-    case mindspore::DataType::kInvalidType:
5192be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UNKNOWN;
5193be168c0dSopenharmony_ci-      break;
5194be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeBool:
5195be168c0dSopenharmony_ci-      oh_data_type = OH_NN_BOOL;
5196be168c0dSopenharmony_ci-      break;
5197be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeInt8:
5198be168c0dSopenharmony_ci-      oh_data_type = OH_NN_INT8;
5199be168c0dSopenharmony_ci-      break;
5200be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeInt16:
5201be168c0dSopenharmony_ci-      oh_data_type = OH_NN_INT16;
5202be168c0dSopenharmony_ci-      break;
5203be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeInt32:
5204be168c0dSopenharmony_ci-      oh_data_type = OH_NN_INT32;
5205be168c0dSopenharmony_ci-      break;
5206be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeInt64:
5207be168c0dSopenharmony_ci-      oh_data_type = OH_NN_INT64;
5208be168c0dSopenharmony_ci-      break;
5209be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeUInt8:
5210be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UINT8;
5211be168c0dSopenharmony_ci-      break;
5212be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeUInt16:
5213be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UINT16;
5214be168c0dSopenharmony_ci-      break;
5215be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeUInt32:
5216be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UINT32;
5217be168c0dSopenharmony_ci-      break;
5218be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeUInt64:
5219be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UINT64;
5220be168c0dSopenharmony_ci-      break;
5221be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeFloat16:
5222be168c0dSopenharmony_ci-      oh_data_type = OH_NN_FLOAT16;
5223be168c0dSopenharmony_ci-      break;
5224be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeFloat32:
5225be168c0dSopenharmony_ci-      oh_data_type = OH_NN_FLOAT32;
5226be168c0dSopenharmony_ci-      break;
5227be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeFloat64:
5228be168c0dSopenharmony_ci-      oh_data_type = OH_NN_FLOAT64;
5229be168c0dSopenharmony_ci-      break;
5230be168c0dSopenharmony_ci-    default: {
5231be168c0dSopenharmony_ci-      oh_data_type = OH_NN_UNKNOWN;
5232be168c0dSopenharmony_ci-    }
5233be168c0dSopenharmony_ci-  }
5234be168c0dSopenharmony_ci-  return oh_data_type;
5235be168c0dSopenharmony_ci+
5236be168c0dSopenharmony_ci+OH_NN_Format NNRTDelegate::CastToNNRTFormat(Format format) {
5237be168c0dSopenharmony_ci+  return OH_NN_FORMAT_NHWC;
5238be168c0dSopenharmony_ci }
5239be168c0dSopenharmony_ci 
5240be168c0dSopenharmony_ci-mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model,
5241be168c0dSopenharmony_ci-                                                          OH_NNExecutor *oh_nn_executor) {
5242be168c0dSopenharmony_ci+Status NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model,
5243be168c0dSopenharmony_ci+                                    OH_NNExecutor *oh_nn_executor) {
5244be168c0dSopenharmony_ci   auto output_tensors = model->outputs();
5245be168c0dSopenharmony_ci   for (size_t i = 0; i < output_tensors.size(); i++) {
5246be168c0dSopenharmony_ci     auto tensor = output_tensors[i];
5247be168c0dSopenharmony_ci@@ -280,17 +744,17 @@ mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::
5248be168c0dSopenharmony_ci     if (ret_code != OH_NN_SUCCESS) {
5249be168c0dSopenharmony_ci       MS_LOG(ERROR) << "NNExecutor SetOutput failed, current out tensor is" << tensor.Name()
5250be168c0dSopenharmony_ci                     << ", OH_NN_ReturnCode = " << ret_code;
5251be168c0dSopenharmony_ci-      return mindspore::kLiteError;
5252be168c0dSopenharmony_ci+      return kLiteError;
5253be168c0dSopenharmony_ci     }
5254be168c0dSopenharmony_ci   }
5255be168c0dSopenharmony_ci-  return mindspore::kSuccess;
5256be168c0dSopenharmony_ci+  return kSuccess;
5257be168c0dSopenharmony_ci }
5258be168c0dSopenharmony_ci 
5259be168c0dSopenharmony_ci-void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGraph &lite_graph) {
5260be168c0dSopenharmony_ci+void NNRTDelegate::ShallowCopyLiteGraph(const lite::LiteGraph &lite_graph) {
5261be168c0dSopenharmony_ci   Status ret;
5262be168c0dSopenharmony_ci   for (auto node : lite_graph.all_nodes_) {
5263be168c0dSopenharmony_ci     ret = lite::CheckPrimitiveSupported(static_cast<const schema::Primitive *>(node->primitive_));
5264be168c0dSopenharmony_ci-    if (ret == mindspore::kLiteError) {
5265be168c0dSopenharmony_ci+    if (ret == kLiteError) {
5266be168c0dSopenharmony_ci       MS_LOG(ERROR) << " primitive supported check failed.";
5267be168c0dSopenharmony_ci       return;
5268be168c0dSopenharmony_ci     }
5269be168c0dSopenharmony_ci@@ -299,7 +763,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5270be168c0dSopenharmony_ci   node_list.reserve(lite_graph.all_nodes_.size());
5271be168c0dSopenharmony_ci   // copy node
5272be168c0dSopenharmony_ci   for (auto node : lite_graph.all_nodes_) {
5273be168c0dSopenharmony_ci-    auto new_node = new (std::nothrow) LiteGraph::Node;
5274be168c0dSopenharmony_ci+    auto new_node = new(std::nothrow) LiteGraph::Node;
5275be168c0dSopenharmony_ci     if (new_node == nullptr) {
5276be168c0dSopenharmony_ci       MS_LOG(ERROR) << " new LiteGraph::Node failed.";
5277be168c0dSopenharmony_ci       return;
5278be168c0dSopenharmony_ci@@ -318,7 +782,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5279be168c0dSopenharmony_ci   // copy subgraph
5280be168c0dSopenharmony_ci   std::vector<LiteGraph::SubGraph *> subgraph_list;
5281be168c0dSopenharmony_ci   for (auto subgraph : lite_graph.sub_graphs_) {
5282be168c0dSopenharmony_ci-    auto new_subgraph = new (std::nothrow) LiteGraph::SubGraph;
5283be168c0dSopenharmony_ci+    auto new_subgraph = new(std::nothrow) LiteGraph::SubGraph;
5284be168c0dSopenharmony_ci     if (new_subgraph == nullptr) {
5285be168c0dSopenharmony_ci       MS_LOG(ERROR) << "new LiteGraph::Subgraph failed.";
5286be168c0dSopenharmony_ci       return;
5287be168c0dSopenharmony_ci@@ -331,30 +795,32 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5288be168c0dSopenharmony_ci   }
5289be168c0dSopenharmony_ci   for (auto tensor : lite_graph.all_tensors_) {
5290be168c0dSopenharmony_ci     ret = lite::CheckTensorSupported(static_cast<const schema::Tensor *>(tensor));
5291be168c0dSopenharmony_ci-    if (ret == mindspore::kLiteError) {
5292be168c0dSopenharmony_ci+    if (ret == kLiteError) {
5293be168c0dSopenharmony_ci       MS_LOG(ERROR) << "tensor supported check failed.";
5294be168c0dSopenharmony_ci       return;
5295be168c0dSopenharmony_ci     }
5296be168c0dSopenharmony_ci   }
5297be168c0dSopenharmony_ci 
5298be168c0dSopenharmony_ci-  nnrt_lite_graph = new (std::nothrow) lite::LiteGraph();
5299be168c0dSopenharmony_ci-  if (nnrt_lite_graph == nullptr) {
5300be168c0dSopenharmony_ci+  lite_graph_ = new(std::nothrow) lite::LiteGraph();
5301be168c0dSopenharmony_ci+  if (lite_graph_ == nullptr) {
5302be168c0dSopenharmony_ci     MS_LOG(ERROR) << "new LiteGraph failed.";
5303be168c0dSopenharmony_ci     return;
5304be168c0dSopenharmony_ci   }
5305be168c0dSopenharmony_ci 
5306be168c0dSopenharmony_ci-  nnrt_lite_graph->name_ = lite_graph.name_;
5307be168c0dSopenharmony_ci-  nnrt_lite_graph->version_ = lite_graph.version_;
5308be168c0dSopenharmony_ci-  nnrt_lite_graph->input_indices_ = lite_graph.input_indices_;
5309be168c0dSopenharmony_ci-  nnrt_lite_graph->output_indices_ = lite_graph.output_indices_;
5310be168c0dSopenharmony_ci-  nnrt_lite_graph->all_tensors_ = lite_graph.all_tensors_;
5311be168c0dSopenharmony_ci-  nnrt_lite_graph->all_nodes_ = node_list;
5312be168c0dSopenharmony_ci-  nnrt_lite_graph->sub_graphs_ = subgraph_list;
5313be168c0dSopenharmony_ci+  lite_graph_->name_ = lite_graph.name_;
5314be168c0dSopenharmony_ci+  lite_graph_->version_ = lite_graph.version_;
5315be168c0dSopenharmony_ci+  lite_graph_->input_indices_ = lite_graph.input_indices_;
5316be168c0dSopenharmony_ci+  lite_graph_->output_indices_ = lite_graph.output_indices_;
5317be168c0dSopenharmony_ci+  lite_graph_->all_tensors_ = lite_graph.all_tensors_;
5318be168c0dSopenharmony_ci+  lite_graph_->all_nodes_ = node_list;
5319be168c0dSopenharmony_ci+  lite_graph_->sub_graphs_ = subgraph_list;
5320be168c0dSopenharmony_ci   MS_LOG(INFO) << "ShallowCopyLiteGraph success.";
5321be168c0dSopenharmony_ci }
5322be168c0dSopenharmony_ci 
5323be168c0dSopenharmony_ci-mindspore::NNRTDelegate::~NNRTDelegate() {
5324be168c0dSopenharmony_ci-  if (this->nnrt_lite_graph != nullptr) {
5325be168c0dSopenharmony_ci+NNRTDelegate::~NNRTDelegate() {
5326be168c0dSopenharmony_ci+  if (lite_graph_ != nullptr) {
5327be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Delete NNRTDelegate.";
5328be168c0dSopenharmony_ci   }
5329be168c0dSopenharmony_ci-};
5330be168c0dSopenharmony_ci+}
5331be168c0dSopenharmony_ci+}  // namespace lite
5332be168c0dSopenharmony_ci+}  // namespace mindspore
5333be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5334be168c0dSopenharmony_ciindex c2847704..52626339 100644
5335be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5336be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5337be168c0dSopenharmony_ci@@ -15,37 +15,81 @@
5338be168c0dSopenharmony_ci  */
5339be168c0dSopenharmony_ci #ifndef MINDSPORE_NNR_DELEGATE_H
5340be168c0dSopenharmony_ci #define MINDSPORE_NNR_DELEGATE_H
5341be168c0dSopenharmony_ci+
5342be168c0dSopenharmony_ci #include <vector>
5343be168c0dSopenharmony_ci #include <map>
5344be168c0dSopenharmony_ci #include "include/api/delegate.h"
5345be168c0dSopenharmony_ci #include "include/model.h"
5346be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime_type.h"
5347be168c0dSopenharmony_ci-namespace mindspore {
5348be168c0dSopenharmony_ci+#include "src/litert/inner_context.h"
5349be168c0dSopenharmony_ci+#include "nnrt_model_kernel.h"
5350be168c0dSopenharmony_ci+#include "schema/model_generated.h"
5351be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
5352be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5353be168c0dSopenharmony_ci+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
5354be168c0dSopenharmony_ci 
5355be168c0dSopenharmony_ci-using namespace lite;
5356be168c0dSopenharmony_ci+namespace mindspore {
5357be168c0dSopenharmony_ci+namespace lite {
5358be168c0dSopenharmony_ci+struct NNRTOpRange {
5359be168c0dSopenharmony_ci+  /* NNRT kernel range in DelegateModel: [begin_iter_, end_iter_) */
5360be168c0dSopenharmony_ci+  KernelIter begin_iter_;
5361be168c0dSopenharmony_ci+  KernelIter end_iter_;
5362be168c0dSopenharmony_ci+  /* NNRT node range in lite_graph_: [begin_index_, end_index_) */
5363be168c0dSopenharmony_ci+  size_t begin_index_;
5364be168c0dSopenharmony_ci+  size_t end_index_;
5365be168c0dSopenharmony_ci+};
5366be168c0dSopenharmony_ci 
5367be168c0dSopenharmony_ci class NNRTDelegate : public Delegate {
5368be168c0dSopenharmony_ci  public:
5369be168c0dSopenharmony_ci-  NNRTDelegate() : Delegate(){};
5370be168c0dSopenharmony_ci-
5371be168c0dSopenharmony_ci+  NNRTDelegate() = default;
5372be168c0dSopenharmony_ci+  NNRTDelegate(const NNRtDeviceInfo &nnrt_device_info) : nnrt_device_info_(nnrt_device_info) {}
5373be168c0dSopenharmony_ci   ~NNRTDelegate() override;
5374be168c0dSopenharmony_ci-
5375be168c0dSopenharmony_ci-  Status Init() override;
5376be168c0dSopenharmony_ci-
5377be168c0dSopenharmony_ci+  Status Init() override { return kSuccess; }
5378be168c0dSopenharmony_ci   Status Build(DelegateModel<schema::Primitive> *model) override;
5379be168c0dSopenharmony_ci-
5380be168c0dSopenharmony_ci   void ShallowCopyLiteGraph(const lite::LiteGraph &liteGraph);
5381be168c0dSopenharmony_ci-
5382be168c0dSopenharmony_ci- protected:
5383be168c0dSopenharmony_ci-  LiteGraph *nnrt_lite_graph = nullptr;
5384be168c0dSopenharmony_ci+  void SetMetaGraph(const void *meta_graph) {
5385be168c0dSopenharmony_ci+    meta_graph_ = meta_graph;
5386be168c0dSopenharmony_ci+  }
5387be168c0dSopenharmony_ci+  static std::vector<NNRTOpRange> GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model,
5388be168c0dSopenharmony_ci+                                                        const std::vector<bool> &op_supports);
5389be168c0dSopenharmony_ci 
5390be168c0dSopenharmony_ci  private:
5391be168c0dSopenharmony_ci-  //  static LiteGraph* CreateLiteGraph(const LiteGraph &liteGraph);
5392be168c0dSopenharmony_ci+  void InitCachePath();
5393be168c0dSopenharmony_ci+  Status BuildNormalModel(DelegateModel<schema::Primitive> *model);
5394be168c0dSopenharmony_ci+  OH_NNModel *CreateFullNNModel();
5395be168c0dSopenharmony_ci+  std::vector<bool> QueryOpSupports(OH_NNModel *nn_model);
5396be168c0dSopenharmony_ci+  Status CreateLiteGraphForNNRTSubgraph(
5397be168c0dSopenharmony_ci+    const std::vector<NNRTOpRange> &nnrt_op_ranges,
5398be168c0dSopenharmony_ci+    std::vector<LiteGraph *> *sub_lite_graphs);
5399be168c0dSopenharmony_ci+  Status CreateNNRTSubgraphKernels(
5400be168c0dSopenharmony_ci+    DelegateModel<schema::Primitive> *model,
5401be168c0dSopenharmony_ci+    const std::vector<LiteGraph *> &sub_lite_graphs,
5402be168c0dSopenharmony_ci+    const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5403be168c0dSopenharmony_ci+    std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels);
5404be168c0dSopenharmony_ci+  void ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model,
5405be168c0dSopenharmony_ci+                                         const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5406be168c0dSopenharmony_ci+                                         const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels);
5407be168c0dSopenharmony_ci   Status PrepareInputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor);
5408be168c0dSopenharmony_ci   Status PrepareOutputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor);
5409be168c0dSopenharmony_ci-  OH_NN_DataType ConvertDataType(mindspore::DataType data_type);
5410be168c0dSopenharmony_ci-};
5411be168c0dSopenharmony_ci+  Status InitNNCompilation(OH_NNCompilation *nn_compilation) const;
5412be168c0dSopenharmony_ci+  static OH_NN_DataType CastToNNRTDataType(mindspore::DataType data_type);
5413be168c0dSopenharmony_ci+  static OH_NN_Format CastToNNRTFormat(Format format);
5414be168c0dSopenharmony_ci+  bool IsCustomModel() const;
5415be168c0dSopenharmony_ci+
5416be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT_METAGRAPH
5417be168c0dSopenharmony_ci+  bool IsKirinNPU() const;
5418be168c0dSopenharmony_ci+  Status BuildKirinNPUModel(DelegateModel<schema::Primitive> *model);
5419be168c0dSopenharmony_ci+  Status SetKirinModelInputsAndOutputs(OH_NNModel *nn_model);
5420be168c0dSopenharmony_ci+  std::vector<OH_NN_TensorInfo> CreateNNTensorInfos(const std::vector<uint32_t> &indices) const;
5421be168c0dSopenharmony_ci+  Status CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model);
5422be168c0dSopenharmony_ci+#endif
5423be168c0dSopenharmony_ci 
5424be168c0dSopenharmony_ci+  NNRtDeviceInfo nnrt_device_info_;
5425be168c0dSopenharmony_ci+  LiteGraph *lite_graph_ = nullptr;
5426be168c0dSopenharmony_ci+  const void *meta_graph_ = nullptr;
5427be168c0dSopenharmony_ci+  std::string cache_path_ = "";
5428be168c0dSopenharmony_ci+  uint32_t cache_version_ = 0;
5429be168c0dSopenharmony_ci+};
5430be168c0dSopenharmony_ci+}  // namespace lite
5431be168c0dSopenharmony_ci }  // namespace mindspore
5432be168c0dSopenharmony_ci 
5433be168c0dSopenharmony_ci #endif  // MINDSPORE_NNR_DELEGATE_H
5434be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5435be168c0dSopenharmony_ciindex 5acf2e9a..67443e08 100644
5436be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5437be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5438be168c0dSopenharmony_ci@@ -97,7 +97,7 @@ OH_NN_DataType mindspore::NNRTModelKernel::ConvertDataType(mindspore::DataType d
5439be168c0dSopenharmony_ci }
5440be168c0dSopenharmony_ci int mindspore::NNRTModelKernel::PrepareInputs() {
5441be168c0dSopenharmony_ci   auto input_tensors = this->inputs();
5442be168c0dSopenharmony_ci-  for (int i = 0; i < input_tensors.size(); i++) {
5443be168c0dSopenharmony_ci+  for (size_t i = 0; i < input_tensors.size(); i++) {
5444be168c0dSopenharmony_ci     auto tensor = input_tensors[i];
5445be168c0dSopenharmony_ci     auto tensor_shape = tensor.Shape();
5446be168c0dSopenharmony_ci     auto tmp_quant_param = tensor.QuantParams();
5447be168c0dSopenharmony_ci@@ -142,6 +142,7 @@ int mindspore::NNRTModelKernel::PrepareInputs() {
5448be168c0dSopenharmony_ci     oprend->dimensions = dimensions_list.data();
5449be168c0dSopenharmony_ci     oprend->quantParam = quant_param;
5450be168c0dSopenharmony_ci     oprend->type = OH_NN_TENSOR;
5451be168c0dSopenharmony_ci+    MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData() << ", size: " << tensor.DataSize();
5452be168c0dSopenharmony_ci     OH_NN_ReturnCode ret_code =
5453be168c0dSopenharmony_ci       OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5454be168c0dSopenharmony_ci     delete (oprend);
5455be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5456be168c0dSopenharmony_ciindex cf9481df..ea15f7ca 100644
5457be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5458be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5459be168c0dSopenharmony_ci@@ -20,7 +20,7 @@
5460be168c0dSopenharmony_ci #include <map>
5461be168c0dSopenharmony_ci #include <utility>
5462be168c0dSopenharmony_ci #include "include/api/kernel.h"
5463be168c0dSopenharmony_ci-#include "interfaces/kits/c/neural_network_runtime.h"
5464be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5465be168c0dSopenharmony_ci #include "src/common/log_adapter.h"
5466be168c0dSopenharmony_ci #include "include/errorcode.h"
5467be168c0dSopenharmony_ci 
5468be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
5469be168c0dSopenharmony_cinew file mode 100644
5470be168c0dSopenharmony_ciindex 00000000..8ac283af
5471be168c0dSopenharmony_ci--- /dev/null
5472be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
5473be168c0dSopenharmony_ci@@ -0,0 +1,99 @@
5474be168c0dSopenharmony_ci+/**
5475be168c0dSopenharmony_ci+* Copyright 2023 Huawei Technologies Co., Ltd
5476be168c0dSopenharmony_ci+*
5477be168c0dSopenharmony_ci+* Licensed under the Apache License, Version 2.0 (the "License");
5478be168c0dSopenharmony_ci+* you may not use this file except in compliance with the License.
5479be168c0dSopenharmony_ci+* You may obtain a copy of the License at
5480be168c0dSopenharmony_ci+*
5481be168c0dSopenharmony_ci+* http://www.apache.org/licenses/LICENSE-2.0
5482be168c0dSopenharmony_ci+*
5483be168c0dSopenharmony_ci+* Unless required by applicable law or agreed to in writing, software
5484be168c0dSopenharmony_ci+* distributed under the License is distributed on an "AS IS" BASIS,
5485be168c0dSopenharmony_ci+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5486be168c0dSopenharmony_ci+* See the License for the specific language governing permissions and
5487be168c0dSopenharmony_ci+* limitations under the License.
5488be168c0dSopenharmony_ci+*/
5489be168c0dSopenharmony_ci+
5490be168c0dSopenharmony_ci+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5491be168c0dSopenharmony_ci+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
5492be168c0dSopenharmony_ci+
5493be168c0dSopenharmony_ci+OH_NNModel *OH_NNModel_Construct(void) {
5494be168c0dSopenharmony_ci+  return NULL;
5495be168c0dSopenharmony_ci+}
5496be168c0dSopenharmony_ci+
5497be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_Run(OH_NNExecutor *executor) {
5498be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5499be168c0dSopenharmony_ci+}
5500be168c0dSopenharmony_ci+
5501be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation) {
5502be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5503be168c0dSopenharmony_ci+}
5504be168c0dSopenharmony_ci+
5505be168c0dSopenharmony_ci+void OH_NNCompilation_Destroy(OH_NNCompilation **compilation) {}
5506be168c0dSopenharmony_ci+
5507be168c0dSopenharmony_ci+OH_NNExecutor *OH_NNExecutor_Construct(OH_NNCompilation *compilation) {
5508be168c0dSopenharmony_ci+  return NULL;
5509be168c0dSopenharmony_ci+}
5510be168c0dSopenharmony_ci+
5511be168c0dSopenharmony_ci+void OH_NNExecutor_Destroy(OH_NNExecutor **executor) {}
5512be168c0dSopenharmony_ci+
5513be168c0dSopenharmony_ci+OH_NNCompilation *OH_NNCompilation_Construct(const OH_NNModel *model) {
5514be168c0dSopenharmony_ci+  return NULL;
5515be168c0dSopenharmony_ci+}
5516be168c0dSopenharmony_ci+
5517be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetAllDevicesID(const size_t **allDevicesID, uint32_t *deviceCount) {
5518be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5519be168c0dSopenharmony_ci+}
5520be168c0dSopenharmony_ci+
5521be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_SetOutput(OH_NNExecutor *executor,
5522be168c0dSopenharmony_ci+                                         uint32_t outputIndex,
5523be168c0dSopenharmony_ci+                                         void *dataBuffer,
5524be168c0dSopenharmony_ci+                                         size_t length) {
5525be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5526be168c0dSopenharmony_ci+}
5527be168c0dSopenharmony_ci+
5528be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetDevice(OH_NNCompilation *compilation, size_t deviceID) {
5529be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5530be168c0dSopenharmony_ci+}
5531be168c0dSopenharmony_ci+
5532be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNExecutor_SetInput(OH_NNExecutor *executor,
5533be168c0dSopenharmony_ci+                                        uint32_t inputIndex,
5534be168c0dSopenharmony_ci+                                        const OH_NN_Tensor *tensor,
5535be168c0dSopenharmony_ci+                                        const void *dataBuffer,
5536be168c0dSopenharmony_ci+                                        size_t length) {
5537be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5538be168c0dSopenharmony_ci+}
5539be168c0dSopenharmony_ci+
5540be168c0dSopenharmony_ci+void OH_NNModel_Destroy(OH_NNModel **model) {}
5541be168c0dSopenharmony_ci+
5542be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNModel_GetAvailableOperations(OH_NNModel *model,
5543be168c0dSopenharmony_ci+                                                   size_t deviceID,
5544be168c0dSopenharmony_ci+                                                   const bool **isSupported,
5545be168c0dSopenharmony_ci+                                                   uint32_t *opCount) {
5546be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5547be168c0dSopenharmony_ci+}
5548be168c0dSopenharmony_ci+
5549be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const void *liteGraph) {
5550be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5551be168c0dSopenharmony_ci+}
5552be168c0dSopenharmony_ci+
5553be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name) {
5554be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5555be168c0dSopenharmony_ci+}
5556be168c0dSopenharmony_ci+
5557be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType *deviceType) {
5558be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5559be168c0dSopenharmony_ci+}
5560be168c0dSopenharmony_ci+
5561be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetPriority(OH_NNCompilation *compilation, OH_NN_Priority priority) {
5562be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5563be168c0dSopenharmony_ci+}
5564be168c0dSopenharmony_ci+
5565be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_EnableFloat16(OH_NNCompilation *compilation, bool enableFloat16) {
5566be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5567be168c0dSopenharmony_ci+}
5568be168c0dSopenharmony_ci+
5569be168c0dSopenharmony_ci+OH_NN_ReturnCode OH_NNCompilation_SetPerformanceMode(OH_NNCompilation *compilation,
5570be168c0dSopenharmony_ci+                                                     OH_NN_PerformanceMode performanceMode) {
5571be168c0dSopenharmony_ci+  return OH_NN_SUCCESS;
5572be168c0dSopenharmony_ci+}
5573be168c0dSopenharmony_ci\ No newline at end of file
5574be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/infer_manager.cc b/mindspore/lite/src/litert/infer_manager.cc
5575be168c0dSopenharmony_ciindex 2b21d1ca..908ab122 100644
5576be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/infer_manager.cc
5577be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/infer_manager.cc
5578be168c0dSopenharmony_ci@@ -162,7 +162,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
5579be168c0dSopenharmony_ci   if (parameter->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion) ||
5580be168c0dSopenharmony_ci       parameter->type_ == static_cast<int>(schema::PrimitiveType_Switch) ||
5581be168c0dSopenharmony_ci       parameter->type_ == static_cast<int>(schema::PrimitiveType_Call) ||
5582be168c0dSopenharmony_ci-      parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer)) {
5583be168c0dSopenharmony_ci+      parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer) ||
5584be168c0dSopenharmony_ci+      parameter->type_ == static_cast<int>(PrimType_Inner_ThirdPartyModel)) {
5585be168c0dSopenharmony_ci     MS_LOG(INFO) << "no need infer shape.";
5586be168c0dSopenharmony_ci     return RET_OK;
5587be168c0dSopenharmony_ci   }
5588be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/inner_context.cc b/mindspore/lite/src/litert/inner_context.cc
5589be168c0dSopenharmony_ciindex 7cbac8f7..bf585ff0 100644
5590be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/inner_context.cc
5591be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/inner_context.cc
5592be168c0dSopenharmony_ci@@ -122,6 +122,10 @@ int InnerContext::Init() {
5593be168c0dSopenharmony_ci #endif
5594be168c0dSopenharmony_ci   }
5595be168c0dSopenharmony_ci 
5596be168c0dSopenharmony_ci+  if (IsDeviceTypeEnabled(DT_NNRT)) {
5597be168c0dSopenharmony_ci+    MS_LOG(DEBUG) << "NNRT enabled.";
5598be168c0dSopenharmony_ci+  }
5599be168c0dSopenharmony_ci+
5600be168c0dSopenharmony_ci   if (CreateThreadPool(false)) {
5601be168c0dSopenharmony_ci     MS_LOG(ERROR) << "CreateThreadPool failed.";
5602be168c0dSopenharmony_ci     return RET_ERROR;
5603be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/inner_context.h b/mindspore/lite/src/litert/inner_context.h
5604be168c0dSopenharmony_ciindex 88281eb1..8735961c 100644
5605be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/inner_context.h
5606be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/inner_context.h
5607be168c0dSopenharmony_ci@@ -71,12 +71,26 @@ typedef struct CustomDeviceInfo {
5608be168c0dSopenharmony_ci   std::shared_ptr<DeviceInfoContext> user_defined_device_info_;
5609be168c0dSopenharmony_ci } CustomDeviceInfo;
5610be168c0dSopenharmony_ci 
5611be168c0dSopenharmony_ci+typedef struct Extension {
5612be168c0dSopenharmony_ci+  std::string name; // config name
5613be168c0dSopenharmony_ci+  std::vector<uint8_t> value; // config value
5614be168c0dSopenharmony_ci+} Extension;
5615be168c0dSopenharmony_ci+
5616be168c0dSopenharmony_ci+typedef struct NNRtDeviceInfo {
5617be168c0dSopenharmony_ci+  size_t device_id_ = 0;
5618be168c0dSopenharmony_ci+  int priority_ = 0;
5619be168c0dSopenharmony_ci+  int performance_mode_ = 0;
5620be168c0dSopenharmony_ci+  bool enable_fp16_ = false;
5621be168c0dSopenharmony_ci+  std::vector<Extension> extensions_;
5622be168c0dSopenharmony_ci+} NNRtDeviceInfo;
5623be168c0dSopenharmony_ci+
5624be168c0dSopenharmony_ci struct DeviceInfo {
5625be168c0dSopenharmony_ci   CpuDeviceInfo cpu_device_info_;
5626be168c0dSopenharmony_ci   GpuDeviceInfo gpu_device_info_;
5627be168c0dSopenharmony_ci   NpuDeviceInfo npu_device_info_;
5628be168c0dSopenharmony_ci   AscendDeviceInfo ascend_device_info_;
5629be168c0dSopenharmony_ci   CustomDeviceInfo custom_device_info_;
5630be168c0dSopenharmony_ci+  NNRtDeviceInfo nnrt_device_info_;
5631be168c0dSopenharmony_ci };
5632be168c0dSopenharmony_ci 
5633be168c0dSopenharmony_ci struct DeviceContext {
5634be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5635be168c0dSopenharmony_ciindex 48308425..65065b5b 100644
5636be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5637be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5638be168c0dSopenharmony_ci@@ -13,6 +13,10 @@ cpu_kernel_sources = [
5639be168c0dSopenharmony_ci     "base/call.cc",
5640be168c0dSopenharmony_ci     "base/constant_of_shape.cc",
5641be168c0dSopenharmony_ci     "base/convolution_base.cc",
5642be168c0dSopenharmony_ci+    "base/custom_base.cc",
5643be168c0dSopenharmony_ci+    "base/custom_masked_fill.cc",
5644be168c0dSopenharmony_ci+    "base/custom_is_inf.cc",
5645be168c0dSopenharmony_ci+    "base/custom_tensor_scatter.cc",
5646be168c0dSopenharmony_ci     "base/detection_post_process_base.cc",
5647be168c0dSopenharmony_ci     "base/format_transpose.cc",
5648be168c0dSopenharmony_ci     "base/group_convolution_base.cc",
5649be168c0dSopenharmony_ci@@ -37,7 +41,6 @@ cpu_kernel_sources = [
5650be168c0dSopenharmony_ci     "fp32/batchnorm_fp32.cc",
5651be168c0dSopenharmony_ci     "fp32/batch_to_space_fp32.cc",
5652be168c0dSopenharmony_ci     "fp32/broadcast_to_fp32.cc",
5653be168c0dSopenharmony_ci-    "fp32/cast_for_x86_fp16.cc",
5654be168c0dSopenharmony_ci     "fp32/cast_fp32.cc",
5655be168c0dSopenharmony_ci     "fp32/convolution_1x1_fp32.cc",
5656be168c0dSopenharmony_ci     "fp32/convolution_delegate_fp32.cc",
5657be168c0dSopenharmony_ci@@ -118,6 +121,10 @@ cpu_kernel_sources = [
5658be168c0dSopenharmony_ci     "fp32/online_fusion/split_reduce_concat_fp32.cc",
5659be168c0dSopenharmony_ci ]
5660be168c0dSopenharmony_ci 
5661be168c0dSopenharmony_ci+if ((target_cpu != "arm") && (target_cpu != "arm64")) {
5662be168c0dSopenharmony_ci+    cpu_kernel_sources += [ "src/runtime/kernel/cpu/fp32/cast_for_x86_fp16.cc" ]
5663be168c0dSopenharmony_ci+}
5664be168c0dSopenharmony_ci+
5665be168c0dSopenharmony_ci arm64_cpu_kernel_sources = [
5666be168c0dSopenharmony_ci   "fp32/convolution_im2col_arm64_fp32.cc",
5667be168c0dSopenharmony_ci   "fp32/matmul_fp32_arm64.cc",
5668be168c0dSopenharmony_ci@@ -142,6 +149,42 @@ sse_avx_avx512_kernel_sources = [
5669be168c0dSopenharmony_ci   "fp32/matmul_fp32_avx512.cc",
5670be168c0dSopenharmony_ci ]
5671be168c0dSopenharmony_ci 
5672be168c0dSopenharmony_ci+fp16_kernel_sources = [
5673be168c0dSopenharmony_ci+  "fp16/batchnorm_fp16.cc",
5674be168c0dSopenharmony_ci+  "fp16/biasadd_fp16.cc",
5675be168c0dSopenharmony_ci+  "fp16/cast_fp16.cc",
5676be168c0dSopenharmony_ci+  "fp16/common_fp16.cc",
5677be168c0dSopenharmony_ci+  "fp16/convolution_1x1_fp16.cc",
5678be168c0dSopenharmony_ci+  "fp16/convolution_delegate_fp16.cc",
5679be168c0dSopenharmony_ci+  "fp16/convolution_depthwise_3x3_fp16.cc",
5680be168c0dSopenharmony_ci+  "fp16/convolution_depthwise_fp16.cc",
5681be168c0dSopenharmony_ci+  "fp16/convolution_depthwise_slidewindow_fp16.cc",
5682be168c0dSopenharmony_ci+  "fp16/convolution_fp16.cc",
5683be168c0dSopenharmony_ci+  "fp16/convolution_winograd_fp16.cc",
5684be168c0dSopenharmony_ci+  "fp16/custom_gru_fp16.cc",
5685be168c0dSopenharmony_ci+  "fp16/deconvolution_depthwise_fp16.cc",
5686be168c0dSopenharmony_ci+  "fp16/deconvolution_fp16.cc",
5687be168c0dSopenharmony_ci+  "fp16/deconvolution_winograd_fp16.cc",
5688be168c0dSopenharmony_ci+  "fp16/depth_to_space_fp16.cc",
5689be168c0dSopenharmony_ci+  "fp16/dynamic_quant_fp16.cc",
5690be168c0dSopenharmony_ci+  "fp16/fullconnection_fp16.cc",
5691be168c0dSopenharmony_ci+  "fp16/fused_batchnorm_fp16.cc",
5692be168c0dSopenharmony_ci+  "fp16/group_convolution_fp16.cc",
5693be168c0dSopenharmony_ci+  "fp16/gru_fp16.cc",
5694be168c0dSopenharmony_ci+  "fp16/instance_norm_fp16.cc",
5695be168c0dSopenharmony_ci+  "fp16/layout_transform_fp16.cc",
5696be168c0dSopenharmony_ci+  "fp16/lstm_fp16.cc",
5697be168c0dSopenharmony_ci+  "fp16/matmul_base_fp16.cc",
5698be168c0dSopenharmony_ci+  "fp16/matmul_fp16.cc",
5699be168c0dSopenharmony_ci+  "fp16/power_fp16.cc",
5700be168c0dSopenharmony_ci+  "fp16/prelu_fp16.cc",
5701be168c0dSopenharmony_ci+  "fp16/quant_dtype_cast_fp16.cc",
5702be168c0dSopenharmony_ci+  "fp16/reduce_fp16.cc",
5703be168c0dSopenharmony_ci+  "fp16/resize_fp16.cc",
5704be168c0dSopenharmony_ci+  "fp16/slice_fp16.cc",
5705be168c0dSopenharmony_ci+  "fp16/where_fp16.cc",
5706be168c0dSopenharmony_ci+]
5707be168c0dSopenharmony_ci+
5708be168c0dSopenharmony_ci int8_kernel_sources = [
5709be168c0dSopenharmony_ci     "int8/activation_int8.cc",
5710be168c0dSopenharmony_ci     "int8/add_int8.cc",
5711be168c0dSopenharmony_ci@@ -227,6 +270,12 @@ all_cpu_kernel_sources += int8_kernel_sources
5712be168c0dSopenharmony_ci all_cpu_kernel_sources += string_kernel_sources
5713be168c0dSopenharmony_ci all_cpu_kernel_sources += control_kernel_sources
5714be168c0dSopenharmony_ci 
5715be168c0dSopenharmony_ci+if (target_cpu == "arm64") {
5716be168c0dSopenharmony_ci+    all_cpu_kernel_sources += fp16_kernel_sources
5717be168c0dSopenharmony_ci+} else {
5718be168c0dSopenharmony_ci+    not_needed(fp16_kernel_sources)
5719be168c0dSopenharmony_ci+}
5720be168c0dSopenharmony_ci+
5721be168c0dSopenharmony_ci if (target_cpu == "arm") {
5722be168c0dSopenharmony_ci   all_cpu_kernel_sources -= arm64_cpu_kernel_sources
5723be168c0dSopenharmony_ci   all_cpu_kernel_sources -= sse_avx_avx512_kernel_sources
5724be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
5725be168c0dSopenharmony_cinew file mode 100644
5726be168c0dSopenharmony_ciindex 00000000..9921e063
5727be168c0dSopenharmony_ci--- /dev/null
5728be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
5729be168c0dSopenharmony_ci@@ -0,0 +1,46 @@
5730be168c0dSopenharmony_ci+/**
5731be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd
5732be168c0dSopenharmony_ci+ *
5733be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
5734be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
5735be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
5736be168c0dSopenharmony_ci+ *
5737be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
5738be168c0dSopenharmony_ci+ *
5739be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
5740be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
5741be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5742be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
5743be168c0dSopenharmony_ci+ * limitations under the License.
5744be168c0dSopenharmony_ci+ */
5745be168c0dSopenharmony_ci+
5746be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_base.h"
5747be168c0dSopenharmony_ci+#include <algorithm>
5748be168c0dSopenharmony_ci+#include <utility>
5749be168c0dSopenharmony_ci+#include <vector>
5750be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h"
5751be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
5752be168c0dSopenharmony_ci+
5753be168c0dSopenharmony_ci+using mindspore::kernel::KERNEL_ARCH;
5754be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar;
5755be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR;
5756be168c0dSopenharmony_ci+using mindspore::lite::RET_OK;
5757be168c0dSopenharmony_ci+using mindspore::schema::PrimitiveType_Custom;
5758be168c0dSopenharmony_ci+
5759be168c0dSopenharmony_ci+namespace mindspore::kernel {
5760be168c0dSopenharmony_ci+int CustomBaseCPUKernel::Prepare() {
5761be168c0dSopenharmony_ci+  return RET_OK;
5762be168c0dSopenharmony_ci+}
5763be168c0dSopenharmony_ci+
5764be168c0dSopenharmony_ci+int CustomBaseCPUKernel::ReSize() {
5765be168c0dSopenharmony_ci+  return RET_OK;
5766be168c0dSopenharmony_ci+}
5767be168c0dSopenharmony_ci+
5768be168c0dSopenharmony_ci+int CustomBaseCPUKernel::Run() {
5769be168c0dSopenharmony_ci+  return RET_OK;
5770be168c0dSopenharmony_ci+}
5771be168c0dSopenharmony_ci+
5772be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5773be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5774be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5775be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
5776be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
5777be168c0dSopenharmony_cinew file mode 100644
5778be168c0dSopenharmony_ciindex 00000000..ecb4c72d
5779be168c0dSopenharmony_ci--- /dev/null
5780be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
5781be168c0dSopenharmony_ci@@ -0,0 +1,43 @@
5782be168c0dSopenharmony_ci+/**
5783be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd
5784be168c0dSopenharmony_ci+ *
5785be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
5786be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
5787be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
5788be168c0dSopenharmony_ci+ *
5789be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
5790be168c0dSopenharmony_ci+ *
5791be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
5792be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
5793be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5794be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
5795be168c0dSopenharmony_ci+ * limitations under the License.
5796be168c0dSopenharmony_ci+ */
5797be168c0dSopenharmony_ci+
5798be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5799be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5800be168c0dSopenharmony_ci+
5801be168c0dSopenharmony_ci+#include <vector>
5802be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h"
5803be168c0dSopenharmony_ci+#include "nnacl/custom_parameter.h"
5804be168c0dSopenharmony_ci+
5805be168c0dSopenharmony_ci+namespace mindspore::kernel {
5806be168c0dSopenharmony_ci+class CustomBaseCPUKernel : public LiteKernel {
5807be168c0dSopenharmony_ci+ public:
5808be168c0dSopenharmony_ci+  CustomBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
5809be168c0dSopenharmony_ci+                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
5810be168c0dSopenharmony_ci+      : LiteKernel(parameter, inputs, outputs, ctx) {
5811be168c0dSopenharmony_ci+    custom_param_ = reinterpret_cast<CustomParameter *>(op_parameter_);
5812be168c0dSopenharmony_ci+  }
5813be168c0dSopenharmony_ci+  ~CustomBaseCPUKernel() override = default;
5814be168c0dSopenharmony_ci+
5815be168c0dSopenharmony_ci+  int Prepare() override;
5816be168c0dSopenharmony_ci+  int ReSize() override;
5817be168c0dSopenharmony_ci+  int Run() override;
5818be168c0dSopenharmony_ci+
5819be168c0dSopenharmony_ci+ private:
5820be168c0dSopenharmony_ci+  CustomParameter *custom_param_ = nullptr;
5821be168c0dSopenharmony_ci+};
5822be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
5823be168c0dSopenharmony_ci+
5824be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5825be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
5826be168c0dSopenharmony_cinew file mode 100644
5827be168c0dSopenharmony_ciindex 00000000..edffea42
5828be168c0dSopenharmony_ci--- /dev/null
5829be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
5830be168c0dSopenharmony_ci@@ -0,0 +1,61 @@
5831be168c0dSopenharmony_ci+/**
5832be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
5833be168c0dSopenharmony_ci+ *
5834be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
5835be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
5836be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
5837be168c0dSopenharmony_ci+ *
5838be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
5839be168c0dSopenharmony_ci+ *
5840be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
5841be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
5842be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5843be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
5844be168c0dSopenharmony_ci+ * limitations under the License.
5845be168c0dSopenharmony_ci+ */
5846be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h"
5847be168c0dSopenharmony_ci+#include "include/errorcode.h"
5848be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_is_inf.h"
5849be168c0dSopenharmony_ci+#include "src/common/tensor_util.h"
5850be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
5851be168c0dSopenharmony_ci+
5852be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar;
5853be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR;
5854be168c0dSopenharmony_ci+using mindspore::lite::RET_OK;
5855be168c0dSopenharmony_ci+
5856be168c0dSopenharmony_ci+namespace mindspore::kernel {
5857be168c0dSopenharmony_ci+
5858be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::Prepare() {
5859be168c0dSopenharmony_ci+  CHECK_LESS_RETURN(in_tensors_.size(), C1NUM);
5860be168c0dSopenharmony_ci+  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
5861be168c0dSopenharmony_ci+  return RET_OK;
5862be168c0dSopenharmony_ci+}
5863be168c0dSopenharmony_ci+
5864be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::ReSize() { return RET_OK; }
5865be168c0dSopenharmony_ci+
5866be168c0dSopenharmony_ci+void CustomIsInfCPUKernel::LaunchKernelFloat(const float *input, bool *output) {
5867be168c0dSopenharmony_ci+  auto elem_num = in_tensors_[FIRST_INPUT]->ElementsNum();
5868be168c0dSopenharmony_ci+
5869be168c0dSopenharmony_ci+  for (int i = 0; i < elem_num; i++) {
5870be168c0dSopenharmony_ci+    output[i] = std::isinf(input[i]);
5871be168c0dSopenharmony_ci+  }
5872be168c0dSopenharmony_ci+}
5873be168c0dSopenharmony_ci+
5874be168c0dSopenharmony_ci+int CustomIsInfCPUKernel::Run() {
5875be168c0dSopenharmony_ci+  auto input = in_tensors_[FIRST_INPUT];
5876be168c0dSopenharmony_ci+  auto output = out_tensors_[FIRST_INPUT];
5877be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(input);
5878be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(output);
5879be168c0dSopenharmony_ci+
5880be168c0dSopenharmony_ci+  if (input->data_type() == kNumberTypeFloat32 || input->data_type() == kNumberTypeFloat) {
5881be168c0dSopenharmony_ci+    LaunchKernelFloat(reinterpret_cast<const float *>(input->data()), reinterpret_cast<bool *>(output->data()));
5882be168c0dSopenharmony_ci+  } else {
5883be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "unsupported input data type " << input->data_type();
5884be168c0dSopenharmony_ci+    return RET_ERROR;
5885be168c0dSopenharmony_ci+  }
5886be168c0dSopenharmony_ci+
5887be168c0dSopenharmony_ci+  return RET_OK;
5888be168c0dSopenharmony_ci+}
5889be168c0dSopenharmony_ci+
5890be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomIsInf, LiteKernelCreator<CustomIsInfCPUKernel>)
5891be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
5892be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
5893be168c0dSopenharmony_cinew file mode 100644
5894be168c0dSopenharmony_ciindex 00000000..e63d8ec7
5895be168c0dSopenharmony_ci--- /dev/null
5896be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
5897be168c0dSopenharmony_ci@@ -0,0 +1,38 @@
5898be168c0dSopenharmony_ci+/**
5899be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
5900be168c0dSopenharmony_ci+ *
5901be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
5902be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
5903be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
5904be168c0dSopenharmony_ci+ *
5905be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
5906be168c0dSopenharmony_ci+ *
5907be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
5908be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
5909be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5910be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
5911be168c0dSopenharmony_ci+ * limitations under the License.
5912be168c0dSopenharmony_ci+ */
5913be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5914be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5915be168c0dSopenharmony_ci+
5916be168c0dSopenharmony_ci+#include <vector>
5917be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h"
5918be168c0dSopenharmony_ci+
5919be168c0dSopenharmony_ci+namespace mindspore::kernel {
5920be168c0dSopenharmony_ci+class CustomIsInfCPUKernel : public LiteKernel {
5921be168c0dSopenharmony_ci+ public:
5922be168c0dSopenharmony_ci+  CustomIsInfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
5923be168c0dSopenharmony_ci+                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
5924be168c0dSopenharmony_ci+      : LiteKernel(parameter, inputs, outputs, ctx) {}
5925be168c0dSopenharmony_ci+  ~CustomIsInfCPUKernel() override = default;
5926be168c0dSopenharmony_ci+  int Prepare() override;
5927be168c0dSopenharmony_ci+  int ReSize() override;
5928be168c0dSopenharmony_ci+  int Run() override;
5929be168c0dSopenharmony_ci+
5930be168c0dSopenharmony_ci+ private:
5931be168c0dSopenharmony_ci+  void LaunchKernelFloat(const float *input, bool *output);
5932be168c0dSopenharmony_ci+};
5933be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
5934be168c0dSopenharmony_ci+
5935be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5936be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
5937be168c0dSopenharmony_cinew file mode 100644
5938be168c0dSopenharmony_ciindex 00000000..9af1af5d
5939be168c0dSopenharmony_ci--- /dev/null
5940be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
5941be168c0dSopenharmony_ci@@ -0,0 +1,84 @@
5942be168c0dSopenharmony_ci+/**
5943be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
5944be168c0dSopenharmony_ci+ *
5945be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
5946be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
5947be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
5948be168c0dSopenharmony_ci+ *
5949be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
5950be168c0dSopenharmony_ci+ *
5951be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
5952be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
5953be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5954be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
5955be168c0dSopenharmony_ci+ * limitations under the License.
5956be168c0dSopenharmony_ci+ */
5957be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h"
5958be168c0dSopenharmony_ci+#include "include/errorcode.h"
5959be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_masked_fill.h"
5960be168c0dSopenharmony_ci+#include "src/common/tensor_util.h"
5961be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
5962be168c0dSopenharmony_ci+
5963be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar;
5964be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR;
5965be168c0dSopenharmony_ci+using mindspore::lite::RET_OK;
5966be168c0dSopenharmony_ci+
5967be168c0dSopenharmony_ci+namespace mindspore::kernel {
5968be168c0dSopenharmony_ci+
5969be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::Prepare() {
5970be168c0dSopenharmony_ci+  CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
5971be168c0dSopenharmony_ci+  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
5972be168c0dSopenharmony_ci+
5973be168c0dSopenharmony_ci+  // only support input value as a single float value
5974be168c0dSopenharmony_ci+  MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat32 ||
5975be168c0dSopenharmony_ci+                      in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat,
5976be168c0dSopenharmony_ci+                    RET_ERROR, "input dtype must be float32");
5977be168c0dSopenharmony_ci+  if (in_tensors_[THIRD_INPUT]->ElementsNum() != 1) {
5978be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "only support fill value as a single float";
5979be168c0dSopenharmony_ci+    return RET_ERROR;
5980be168c0dSopenharmony_ci+  }
5981be168c0dSopenharmony_ci+  MS_CHECK_TRUE_MSG(in_tensors_[SECOND_INPUT]->data_type() == mindspore::TypeId::kNumberTypeBool, RET_ERROR,
5982be168c0dSopenharmony_ci+                    "mask dtype must be bool");
5983be168c0dSopenharmony_ci+  if (!InferShapeDone()) {
5984be168c0dSopenharmony_ci+    return RET_OK;
5985be168c0dSopenharmony_ci+  }
5986be168c0dSopenharmony_ci+  return ReSize();
5987be168c0dSopenharmony_ci+}
5988be168c0dSopenharmony_ci+
5989be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::ReSize() { return RET_OK; }
5990be168c0dSopenharmony_ci+
5991be168c0dSopenharmony_ci+int CustomMaskedFillCPUKernel::Run() {
5992be168c0dSopenharmony_ci+  auto input = in_tensors_[FIRST_INPUT];
5993be168c0dSopenharmony_ci+  auto mask = in_tensors_[SECOND_INPUT];
5994be168c0dSopenharmony_ci+  auto value = in_tensors_[THIRD_INPUT];
5995be168c0dSopenharmony_ci+  auto output = out_tensors_[FIRST_INPUT];
5996be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(input);
5997be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(mask);
5998be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(value);
5999be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(output);
6000be168c0dSopenharmony_ci+
6001be168c0dSopenharmony_ci+  if (input->shape() != mask->shape()) {
6002be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Not support broadcast mask to input";
6003be168c0dSopenharmony_ci+    return RET_ERROR;
6004be168c0dSopenharmony_ci+  }
6005be168c0dSopenharmony_ci+
6006be168c0dSopenharmony_ci+  auto value_data = reinterpret_cast<float *>(value->data());
6007be168c0dSopenharmony_ci+  auto fill_value = value_data[0];
6008be168c0dSopenharmony_ci+
6009be168c0dSopenharmony_ci+  auto data_num = input->ElementsNum();
6010be168c0dSopenharmony_ci+  auto input_data = reinterpret_cast<float *>(input->data());
6011be168c0dSopenharmony_ci+  auto mask_data = reinterpret_cast<bool *>(mask->data());
6012be168c0dSopenharmony_ci+  auto output_data = reinterpret_cast<float *>(output->data());
6013be168c0dSopenharmony_ci+  for (int64_t i = 0; i < data_num; i++) {
6014be168c0dSopenharmony_ci+    if (mask_data[i]) {
6015be168c0dSopenharmony_ci+      output_data[i] = fill_value;
6016be168c0dSopenharmony_ci+    } else {
6017be168c0dSopenharmony_ci+      output_data[i] = input_data[i];
6018be168c0dSopenharmony_ci+    }
6019be168c0dSopenharmony_ci+  }
6020be168c0dSopenharmony_ci+
6021be168c0dSopenharmony_ci+  return RET_OK;
6022be168c0dSopenharmony_ci+}
6023be168c0dSopenharmony_ci+
6024be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomMaskedFill, LiteKernelCreator<CustomMaskedFillCPUKernel>)
6025be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
6026be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
6027be168c0dSopenharmony_cinew file mode 100644
6028be168c0dSopenharmony_ciindex 00000000..04a2dcab
6029be168c0dSopenharmony_ci--- /dev/null
6030be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
6031be168c0dSopenharmony_ci@@ -0,0 +1,35 @@
6032be168c0dSopenharmony_ci+/**
6033be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
6034be168c0dSopenharmony_ci+ *
6035be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
6036be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
6037be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
6038be168c0dSopenharmony_ci+ *
6039be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
6040be168c0dSopenharmony_ci+ *
6041be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
6042be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
6043be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6044be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
6045be168c0dSopenharmony_ci+ * limitations under the License.
6046be168c0dSopenharmony_ci+ */
6047be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6048be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6049be168c0dSopenharmony_ci+
6050be168c0dSopenharmony_ci+#include <vector>
6051be168c0dSopenharmony_ci+#include "src/litert/lite_kernel.h"
6052be168c0dSopenharmony_ci+
6053be168c0dSopenharmony_ci+namespace mindspore::kernel {
6054be168c0dSopenharmony_ci+class CustomMaskedFillCPUKernel : public LiteKernel {
6055be168c0dSopenharmony_ci+ public:
6056be168c0dSopenharmony_ci+  CustomMaskedFillCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
6057be168c0dSopenharmony_ci+                            const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
6058be168c0dSopenharmony_ci+      : LiteKernel(parameter, inputs, outputs, ctx) {}
6059be168c0dSopenharmony_ci+  ~CustomMaskedFillCPUKernel() override = default;
6060be168c0dSopenharmony_ci+  int Prepare() override;
6061be168c0dSopenharmony_ci+  int ReSize() override;
6062be168c0dSopenharmony_ci+  int Run() override;
6063be168c0dSopenharmony_ci+};
6064be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
6065be168c0dSopenharmony_ci+
6066be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6067be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
6068be168c0dSopenharmony_cinew file mode 100644
6069be168c0dSopenharmony_ciindex 00000000..d52d67d5
6070be168c0dSopenharmony_ci--- /dev/null
6071be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
6072be168c0dSopenharmony_ci@@ -0,0 +1,75 @@
6073be168c0dSopenharmony_ci+/**
6074be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd
6075be168c0dSopenharmony_ci+ *
6076be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
6077be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
6078be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
6079be168c0dSopenharmony_ci+ *
6080be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
6081be168c0dSopenharmony_ci+ *
6082be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
6083be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
6084be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6085be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
6086be168c0dSopenharmony_ci+ * limitations under the License.
6087be168c0dSopenharmony_ci+ */
6088be168c0dSopenharmony_ci+
6089be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/custom_tensor_scatter.h"
6090be168c0dSopenharmony_ci+#include <cstring>
6091be168c0dSopenharmony_ci+#include "schema/model_generated.h"
6092be168c0dSopenharmony_ci+#include "src/litert/kernel_registry.h"
6093be168c0dSopenharmony_ci+#include "include/errorcode.h"
6094be168c0dSopenharmony_ci+#include "nnacl/base/scatter_nd_binary.h"
6095be168c0dSopenharmony_ci+
6096be168c0dSopenharmony_ci+using mindspore::kernel::KERNEL_ARCH;
6097be168c0dSopenharmony_ci+using mindspore::lite::KernelRegistrar;
6098be168c0dSopenharmony_ci+using mindspore::lite::RET_ERROR;
6099be168c0dSopenharmony_ci+using mindspore::lite::RET_OK;
6100be168c0dSopenharmony_ci+
6101be168c0dSopenharmony_ci+namespace mindspore::kernel {
6102be168c0dSopenharmony_ci+namespace {
6103be168c0dSopenharmony_ci+int TensorScatterRun(void *cdata, int task_id, float, float) {
6104be168c0dSopenharmony_ci+  auto kernel = static_cast<CustomTensorScatterCPUKernel *>(cdata);
6105be168c0dSopenharmony_ci+  CHECK_NULL_RETURN(kernel);
6106be168c0dSopenharmony_ci+  return kernel->TensorScatterDispatch(task_id);
6107be168c0dSopenharmony_ci+}
6108be168c0dSopenharmony_ci+}  // namespace
6109be168c0dSopenharmony_ci+
6110be168c0dSopenharmony_ci+int CustomTensorScatterCPUKernel::TensorScatterDispatch(int task_id) {
6111be168c0dSopenharmony_ci+  auto data_type = in_tensors_[kScatterUpdateInputIndex]->data_type();
6112be168c0dSopenharmony_ci+  if (data_type != kNumberTypeFloat32) {
6113be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "TensorScatterMax only support float32 input tensor, but got " << data_type;
6114be168c0dSopenharmony_ci+    return RET_ERROR;
6115be168c0dSopenharmony_ci+  }
6116be168c0dSopenharmony_ci+  int type = data_type == kNumberTypeFloat32 ? 0 : 1;
6117be168c0dSopenharmony_ci+  // multi thread have some problems to solve
6118be168c0dSopenharmony_ci+  param_->op_parameter.thread_num_ = 1;
6119be168c0dSopenharmony_ci+  auto ret = ScatterNDMax(in_tensors_[kScatterUpdateIndex]->data(), out_tensors_[kOutputIndex]->data(),
6120be168c0dSopenharmony_ci+                          output_unit_offsets_.data(), param_, type, task_id);
6121be168c0dSopenharmony_ci+  if (ret != RET_OK) {
6122be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "ScatterNDMax failed, ret: " << ret;
6123be168c0dSopenharmony_ci+    return RET_ERROR;
6124be168c0dSopenharmony_ci+  }
6125be168c0dSopenharmony_ci+  return RET_OK;
6126be168c0dSopenharmony_ci+}
6127be168c0dSopenharmony_ci+
6128be168c0dSopenharmony_ci+int CustomTensorScatterCPUKernel::Run() {
6129be168c0dSopenharmony_ci+  auto in_tensor = in_tensors().front();
6130be168c0dSopenharmony_ci+  auto out_tensor = out_tensors().front();
6131be168c0dSopenharmony_ci+  (void)memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
6132be168c0dSopenharmony_ci+  auto indices = in_tensors_.at(kScatterIndicesIndex);
6133be168c0dSopenharmony_ci+  if (!indices->IsConst() && ReSize() != RET_OK) {
6134be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "TensorScatterAdd resize failed.";
6135be168c0dSopenharmony_ci+    return RET_ERROR;
6136be168c0dSopenharmony_ci+  }
6137be168c0dSopenharmony_ci+
6138be168c0dSopenharmony_ci+  auto ret = ParallelLaunch(ms_context_, TensorScatterRun, this, op_parameter_->thread_num_);
6139be168c0dSopenharmony_ci+  if (ret != RET_OK) {
6140be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "TensorScatterAdd error error_code[" << ret << "]";
6141be168c0dSopenharmony_ci+  }
6142be168c0dSopenharmony_ci+  return ret;
6143be168c0dSopenharmony_ci+}
6144be168c0dSopenharmony_ci+
6145be168c0dSopenharmony_ci+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomTensorScatterMax,
6146be168c0dSopenharmony_ci+           LiteKernelCreator<CustomTensorScatterCPUKernel>)
6147be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
6148be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
6149be168c0dSopenharmony_cinew file mode 100644
6150be168c0dSopenharmony_ciindex 00000000..e39733c5
6151be168c0dSopenharmony_ci--- /dev/null
6152be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
6153be168c0dSopenharmony_ci@@ -0,0 +1,36 @@
6154be168c0dSopenharmony_ci+/**
6155be168c0dSopenharmony_ci+ * Copyright 2022 Huawei Technologies Co., Ltd
6156be168c0dSopenharmony_ci+ *
6157be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
6158be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
6159be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
6160be168c0dSopenharmony_ci+ *
6161be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
6162be168c0dSopenharmony_ci+ *
6163be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
6164be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
6165be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6166be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
6167be168c0dSopenharmony_ci+ * limitations under the License.
6168be168c0dSopenharmony_ci+ */
6169be168c0dSopenharmony_ci+
6170be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_
6171be168c0dSopenharmony_ci+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_
6172be168c0dSopenharmony_ci+
6173be168c0dSopenharmony_ci+#include <vector>
6174be168c0dSopenharmony_ci+#include "src/litert/kernel/cpu/base/scatter_nd_binary.h"
6175be168c0dSopenharmony_ci+
6176be168c0dSopenharmony_ci+namespace mindspore::kernel {
6177be168c0dSopenharmony_ci+class CustomTensorScatterCPUKernel : public ScatterNDBinaryCPUKernel {
6178be168c0dSopenharmony_ci+ public:
6179be168c0dSopenharmony_ci+  explicit CustomTensorScatterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
6180be168c0dSopenharmony_ci+                                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
6181be168c0dSopenharmony_ci+      : ScatterNDBinaryCPUKernel(parameter, inputs, outputs, ctx) {}
6182be168c0dSopenharmony_ci+  ~CustomTensorScatterCPUKernel() override = default;
6183be168c0dSopenharmony_ci+
6184be168c0dSopenharmony_ci+  int Run() override;
6185be168c0dSopenharmony_ci+  int TensorScatterDispatch(int task_id);
6186be168c0dSopenharmony_ci+};
6187be168c0dSopenharmony_ci+}  // namespace mindspore::kernel
6188be168c0dSopenharmony_ci+
6189be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_TENSOR_SCATTER_ADD_H_
6190be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc
6191be168c0dSopenharmony_ciindex 2c5bc658..13652633 100644
6192be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_model.cc
6193be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_model.cc
6194be168c0dSopenharmony_ci@@ -98,6 +98,8 @@ int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
6195be168c0dSopenharmony_ci   if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
6196be168c0dSopenharmony_ci       sub_graph.tensorIndices() == nullptr) {
6197be168c0dSopenharmony_ci     MS_LOG(ERROR) << "sub_graph is invalid";
6198be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "sub_graph.name() = " << sub_graph.name() << ", sub_graph.inputIndices() = " << sub_graph.inputIndices()
6199be168c0dSopenharmony_ci+      << ", sub_graph.outputIndices() = " << sub_graph.outputIndices() << ", sub_graph.tensorIndices() = " << sub_graph.tensorIndices();
6200be168c0dSopenharmony_ci     return RET_ERROR;
6201be168c0dSopenharmony_ci   }
6202be168c0dSopenharmony_ci 
6203be168c0dSopenharmony_ci@@ -620,6 +622,33 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, minds
6204be168c0dSopenharmony_ci   return model;
6205be168c0dSopenharmony_ci }
6206be168c0dSopenharmony_ci 
6207be168c0dSopenharmony_ci+std::string LiteGraph::ToString() const {
6208be168c0dSopenharmony_ci+  std::stringstream ss;
6209be168c0dSopenharmony_ci+  ss << "all_nodes: " << all_nodes_.size() << std::endl;
6210be168c0dSopenharmony_ci+  for (size_t i = 0; i < all_nodes_.size(); i++) {
6211be168c0dSopenharmony_ci+    ss << "- node " << i << ": " << all_nodes_[i]->primitive_ << std::endl;
6212be168c0dSopenharmony_ci+    ss << "- node " << i << " input_indices_: " << all_nodes_[i]->input_indices_ << std::endl;
6213be168c0dSopenharmony_ci+    ss << "- node " << i << " output_indices_: " << all_nodes_[i]->output_indices_ << std::endl;
6214be168c0dSopenharmony_ci+  }
6215be168c0dSopenharmony_ci+  ss << "all_tensors: " << all_tensors_.size() << std::endl;
6216be168c0dSopenharmony_ci+  for (size_t i = 0; i < all_tensors_.size(); i++) {
6217be168c0dSopenharmony_ci+    ss << "- tensor " << i << ": " << all_tensors_[i] << std::endl;
6218be168c0dSopenharmony_ci+  }
6219be168c0dSopenharmony_ci+  ss << "input_indices: " << input_indices_<< std::endl;
6220be168c0dSopenharmony_ci+  ss << "output_indices: " << output_indices_ << std::endl;
6221be168c0dSopenharmony_ci+
6222be168c0dSopenharmony_ci+  ss << "subgraphs: " << std::endl;
6223be168c0dSopenharmony_ci+  int count = 0;
6224be168c0dSopenharmony_ci+  for (auto subgraph: sub_graphs_) {
6225be168c0dSopenharmony_ci+    ss << "- subgraph " << count++ << std::endl;
6226be168c0dSopenharmony_ci+    ss << "--- subgraph input " << subgraph->input_indices_ << std::endl;
6227be168c0dSopenharmony_ci+    ss << "--- subgraph output " << subgraph->output_indices_ << std::endl;
6228be168c0dSopenharmony_ci+    ss << "--- subgraph node " << subgraph->node_indices_ << std::endl;
6229be168c0dSopenharmony_ci+    ss << "--- subgraph tensor " << subgraph->tensor_indices_ << std::endl;
6230be168c0dSopenharmony_ci+  }
6231be168c0dSopenharmony_ci+  return ss.str();
6232be168c0dSopenharmony_ci+}
6233be168c0dSopenharmony_ci+
6234be168c0dSopenharmony_ci Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
6235be168c0dSopenharmony_ci 
6236be168c0dSopenharmony_ci Model *Model::Import(const char *filename) { return ImportFromPath(filename); }
6237be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc
6238be168c0dSopenharmony_ciindex 8f54879e..f635c8d2 100644
6239be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_session.cc
6240be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_session.cc
6241be168c0dSopenharmony_ci@@ -67,6 +67,9 @@
6242be168c0dSopenharmony_ci #include "thread/parallel_thread_pool_manager.h"
6243be168c0dSopenharmony_ci #endif
6244be168c0dSopenharmony_ci #include "src/litert/runtime_packed_node_pass.h"
6245be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT
6246be168c0dSopenharmony_ci+#include "src/litert/delegate/nnrt/nnrt_delegate.h"
6247be168c0dSopenharmony_ci+#endif
6248be168c0dSopenharmony_ci 
6249be168c0dSopenharmony_ci using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
6250be168c0dSopenharmony_ci 
6251be168c0dSopenharmony_ci@@ -635,12 +638,6 @@ int LiteSession::CompileGraph(Model *model) {
6252be168c0dSopenharmony_ci   MarkSharedWeight(kernels_);
6253be168c0dSopenharmony_ci   FreePackOpWeight(kernels_);
6254be168c0dSopenharmony_ci 
6255be168c0dSopenharmony_ci-  ret = RuntimeAllocatorInit();
6256be168c0dSopenharmony_ci-  if (ret != RET_OK) {
6257be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Runtime allocator init failed.";
6258be168c0dSopenharmony_ci-    is_running_.store(false);
6259be168c0dSopenharmony_ci-    return ret;
6260be168c0dSopenharmony_ci-  }
6261be168c0dSopenharmony_ci   infer_along_running_ = infer_along_running_ && (runtime_allocator_ == nullptr);
6262be168c0dSopenharmony_ci   if (infer_along_running_) {
6263be168c0dSopenharmony_ci     this->context_->set_infer_checker(InferCheckerAll);
6264be168c0dSopenharmony_ci@@ -1092,6 +1089,27 @@ int LiteSession::CreateCoreMLDelegate() {
6265be168c0dSopenharmony_ci   return RET_OK;
6266be168c0dSopenharmony_ci }
6267be168c0dSopenharmony_ci 
6268be168c0dSopenharmony_ci+int LiteSession::CreateNNRTDelegate() {
6269be168c0dSopenharmony_ci+#if SUPPORT_NNRT
6270be168c0dSopenharmony_ci+  auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(),
6271be168c0dSopenharmony_ci+                           [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
6272be168c0dSopenharmony_ci+  if(iter == context_->device_list_.end()) {
6273be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Found non NNRT device info";
6274be168c0dSopenharmony_ci+    return RET_ERROR;
6275be168c0dSopenharmony_ci+  }
6276be168c0dSopenharmony_ci+
6277be168c0dSopenharmony_ci+  delegate_ = std::make_shared<NNRTDelegate>(iter->device_info_.nnrt_device_info_);
6278be168c0dSopenharmony_ci+  if (delegate_ == nullptr) {
6279be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "New NNRT delegate failed";
6280be168c0dSopenharmony_ci+    return RET_ERROR;
6281be168c0dSopenharmony_ci+  }
6282be168c0dSopenharmony_ci+//  ((NNRTDelegate *)(delegate_.get()))->SetMetaGraph(this->model_->buf);
6283be168c0dSopenharmony_ci+  delegate_device_type_ = DT_NNRT;
6284be168c0dSopenharmony_ci+  this->context_->delegate = delegate_;
6285be168c0dSopenharmony_ci+#endif
6286be168c0dSopenharmony_ci+  return RET_OK;
6287be168c0dSopenharmony_ci+};
6288be168c0dSopenharmony_ci+
6289be168c0dSopenharmony_ci int LiteSession::DelegateInit() {
6290be168c0dSopenharmony_ci #ifndef DELEGATE_CLIP
6291be168c0dSopenharmony_ci   int ret = RET_OK;
6292be168c0dSopenharmony_ci@@ -1115,6 +1133,8 @@ int LiteSession::DelegateInit() {
6293be168c0dSopenharmony_ci       ret = CreateNPUDelegate();
6294be168c0dSopenharmony_ci     } else if (context_->IsDeviceTypeEnabled(DT_GPU)) {
6295be168c0dSopenharmony_ci       ret = CreateTensorRTDelegate();
6296be168c0dSopenharmony_ci+    } else if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
6297be168c0dSopenharmony_ci+      ret = CreateNNRTDelegate();
6298be168c0dSopenharmony_ci     }
6299be168c0dSopenharmony_ci   }
6300be168c0dSopenharmony_ci 
6301be168c0dSopenharmony_ci@@ -1496,12 +1516,6 @@ int LiteSession::Resize(const std::vector<mindspore::lite::Tensor *> &inputs,
6302be168c0dSopenharmony_ci     return ret;
6303be168c0dSopenharmony_ci   }
6304be168c0dSopenharmony_ci 
6305be168c0dSopenharmony_ci-  if (RuntimeAllocatorInit() != RET_OK) {
6306be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Runtime allocator in resize failed.";
6307be168c0dSopenharmony_ci-    is_running_.store(false);
6308be168c0dSopenharmony_ci-    return RET_ERROR;
6309be168c0dSopenharmony_ci-  }
6310be168c0dSopenharmony_ci-
6311be168c0dSopenharmony_ci   auto status = GraphOptimizePass(&kernels_);
6312be168c0dSopenharmony_ci   if (status != RET_OK) {
6313be168c0dSopenharmony_ci     MS_LOG(ERROR) << "GraphOptimizePass failed.";
6314be168c0dSopenharmony_ci@@ -2022,7 +2036,6 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
6315be168c0dSopenharmony_ci     delete model;
6316be168c0dSopenharmony_ci     return RET_ERROR;
6317be168c0dSopenharmony_ci   }
6318be168c0dSopenharmony_ci-  model->Free();
6319be168c0dSopenharmony_ci   set_model(model);
6320be168c0dSopenharmony_ci   return RET_OK;
6321be168c0dSopenharmony_ci }
6322be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h
6323be168c0dSopenharmony_ciindex f8f8fe08..64a5f6d3 100644
6324be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/lite_session.h
6325be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/lite_session.h
6326be168c0dSopenharmony_ci@@ -178,6 +178,7 @@ class MS_API LiteSession {
6327be168c0dSopenharmony_ci   int CreateNPUDelegate();
6328be168c0dSopenharmony_ci   int CreateNNAPIDelegate();
6329be168c0dSopenharmony_ci   int CreateCoreMLDelegate();
6330be168c0dSopenharmony_ci+  int CreateNNRTDelegate();
6331be168c0dSopenharmony_ci   int DelegateInit();
6332be168c0dSopenharmony_ci   int InitGPURuntime();
6333be168c0dSopenharmony_ci   int InitSharedThreadPool();
6334be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc
6335be168c0dSopenharmony_ciindex 11382b09..199b4361 100644
6336be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/scheduler.cc
6337be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/scheduler.cc
6338be168c0dSopenharmony_ci@@ -60,6 +60,9 @@
6339be168c0dSopenharmony_ci #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
6340be168c0dSopenharmony_ci #include "thread/parallel_thread_pool_manager.h"
6341be168c0dSopenharmony_ci #endif
6342be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT
6343be168c0dSopenharmony_ci+#include "src/litert/delegate/nnrt/nnrt_delegate.h"
6344be168c0dSopenharmony_ci+#endif
6345be168c0dSopenharmony_ci 
6346be168c0dSopenharmony_ci using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
6347be168c0dSopenharmony_ci 
6348be168c0dSopenharmony_ci@@ -368,6 +371,7 @@ STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *ker
6349be168c0dSopenharmony_ci }
6350be168c0dSopenharmony_ci 
6351be168c0dSopenharmony_ci int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
6352be168c0dSopenharmony_ci+  MS_LOG(DEBUG) << "Start schedule.";
6353be168c0dSopenharmony_ci   int check_input_ret = CheckInputParam(dst_kernels);
6354be168c0dSopenharmony_ci   if (check_input_ret != RET_OK) {
6355be168c0dSopenharmony_ci     MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
6356be168c0dSopenharmony_ci@@ -404,11 +408,13 @@ int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
6357be168c0dSopenharmony_ci   }
6358be168c0dSopenharmony_ci   shape_fusion_pass_->StoreStateAndReset();
6359be168c0dSopenharmony_ci 
6360be168c0dSopenharmony_ci+  MS_LOG(DEBUG) << "Start to init delegate kernels.";
6361be168c0dSopenharmony_ci   ret = InitDelegateKernels(dst_kernels);
6362be168c0dSopenharmony_ci   if (ret != RET_OK) {
6363be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Repalce delegate kernels failed.";
6364be168c0dSopenharmony_ci     return ret;
6365be168c0dSopenharmony_ci   }
6366be168c0dSopenharmony_ci+  MS_LOG(DEBUG) << "Finish to init delegate kernels.";
6367be168c0dSopenharmony_ci 
6368be168c0dSopenharmony_ci   ret = CheckCpuValid(dst_kernels);
6369be168c0dSopenharmony_ci   if (ret != RET_OK) {
6370be168c0dSopenharmony_ci@@ -500,6 +506,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_ker
6371be168c0dSopenharmony_ci     MS_LOG(ERROR) << "New delegate model failed.";
6372be168c0dSopenharmony_ci     return RET_NULL_PTR;
6373be168c0dSopenharmony_ci   }
6374be168c0dSopenharmony_ci+
6375be168c0dSopenharmony_ci+#ifdef SUPPORT_NNRT
6376be168c0dSopenharmony_ci+  if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
6377be168c0dSopenharmony_ci+    auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
6378be168c0dSopenharmony_ci+    delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
6379be168c0dSopenharmony_ci+    void *meta_graph = reinterpret_cast<void*>(const_cast<mindspore::schema::MetaGraph *>(
6380be168c0dSopenharmony_ci+      mindspore::schema::GetMetaGraph(this->src_model_->buf)));
6381be168c0dSopenharmony_ci+    delegate->SetMetaGraph(meta_graph);
6382be168c0dSopenharmony_ci+  }
6383be168c0dSopenharmony_ci+#endif
6384be168c0dSopenharmony_ci+
6385be168c0dSopenharmony_ci   auto ret = delegate_->Build(model);
6386be168c0dSopenharmony_ci   if (ret != mindspore::kSuccess) {
6387be168c0dSopenharmony_ci     delete model;
6388be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/tensor_category.cc b/mindspore/lite/src/litert/tensor_category.cc
6389be168c0dSopenharmony_ciindex 70d13865..e57cdb28 100644
6390be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/tensor_category.cc
6391be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/tensor_category.cc
6392be168c0dSopenharmony_ci@@ -30,5 +30,9 @@ Category TensorCategory(const schema::Tensor &tensor) {
6393be168c0dSopenharmony_ci   auto data_size = tensor.data() == nullptr ? 0 : tensor.data()->size();
6394be168c0dSopenharmony_ci   return TensorCategory(tensor.nodeType(), shape_num, TypeId(tensor.dataType()), data_size);
6395be168c0dSopenharmony_ci }
6396be168c0dSopenharmony_ci+
6397be168c0dSopenharmony_ci+bool IsConstTensor(const schema::Tensor &tensor) {
6398be168c0dSopenharmony_ci+  return TensorCategory(tensor) != Category::VAR;
6399be168c0dSopenharmony_ci+}
6400be168c0dSopenharmony_ci }  // namespace lite
6401be168c0dSopenharmony_ci }  // namespace mindspore
6402be168c0dSopenharmony_cidiff --git a/mindspore/lite/src/litert/tensor_category.h b/mindspore/lite/src/litert/tensor_category.h
6403be168c0dSopenharmony_ciindex 83273032..70e65b31 100644
6404be168c0dSopenharmony_ci--- a/mindspore/lite/src/litert/tensor_category.h
6405be168c0dSopenharmony_ci+++ b/mindspore/lite/src/litert/tensor_category.h
6406be168c0dSopenharmony_ci@@ -35,6 +35,7 @@ enum Category {
6407be168c0dSopenharmony_ci 
6408be168c0dSopenharmony_ci Category TensorCategory(const int node_type, const size_t shape_num, const TypeId data_type, const size_t data_size);
6409be168c0dSopenharmony_ci Category TensorCategory(const schema::Tensor &tensor);
6410be168c0dSopenharmony_ci+bool IsConstTensor(const schema::Tensor &tensor);
6411be168c0dSopenharmony_ci }  // namespace lite
6412be168c0dSopenharmony_ci }  // namespace mindspore
6413be168c0dSopenharmony_ci #endif  // MINDSPORE_LITE_SRC_RUNTIME_TENSOR_CATEGORY_H_
6414be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
6415be168c0dSopenharmony_ciindex 60e240f0..78dab536 100644
6416be168c0dSopenharmony_ci--- a/mindspore/lite/test/CMakeLists.txt
6417be168c0dSopenharmony_ci+++ b/mindspore/lite/test/CMakeLists.txt
6418be168c0dSopenharmony_ci@@ -28,10 +28,14 @@ file(GLOB_RECURSE TEST_UT_SRC
6419be168c0dSopenharmony_ci         ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
6420be168c0dSopenharmony_ci         ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
6421be168c0dSopenharmony_ci         ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
6422be168c0dSopenharmony_ci-        ${TEST_DIR}/ut/src/api/context_c_test.cc
6423be168c0dSopenharmony_ci-        ${TEST_DIR}/ut/src/api/model_c_test.cc
6424be168c0dSopenharmony_ci-        ${TEST_DIR}/ut/src/api/tensor_c_test.cc`
6425be168c0dSopenharmony_ci+#        ${TEST_DIR}/ut/src/api/context_c_test.cc
6426be168c0dSopenharmony_ci+#        ${TEST_DIR}/ut/src/api/model_c_test.cc
6427be168c0dSopenharmony_ci+#        ${TEST_DIR}/ut/src/api/tensor_c_test.cc`
6428be168c0dSopenharmony_ci         )
6429be168c0dSopenharmony_ci+if(MSLITE_ENABLE_NNRT)
6430be168c0dSopenharmony_ci+    list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/nnrt_delegate/nnrt_delegate_tests.cc)
6431be168c0dSopenharmony_ci+endif()
6432be168c0dSopenharmony_ci+
6433be168c0dSopenharmony_ci if(MSLITE_ENABLE_SERVER_INFERENCE)
6434be168c0dSopenharmony_ci     list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/api/model_parallel_runner_test.cc)
6435be168c0dSopenharmony_ci endif()
6436be168c0dSopenharmony_ci@@ -86,7 +90,7 @@ endif()
6437be168c0dSopenharmony_ci 
6438be168c0dSopenharmony_ci if(MSLITE_ENABLE_INT8)
6439be168c0dSopenharmony_ci     file(GLOB_RECURSE TEST_INT8_UT_SRC
6440be168c0dSopenharmony_ci-            ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
6441be168c0dSopenharmony_ci+#            ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
6442be168c0dSopenharmony_ci             ${TEST_DIR}/ut/nnacl/int8/*.cc
6443be168c0dSopenharmony_ci             )
6444be168c0dSopenharmony_ci     list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC})
6445be168c0dSopenharmony_ci@@ -118,6 +122,7 @@ if(MSLITE_ENABLE_CONVERTER)
6446be168c0dSopenharmony_ci             ${TEST_DIR}/ut/tools/converter/registry/*.cc
6447be168c0dSopenharmony_ci             ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc
6448be168c0dSopenharmony_ci             ${TEST_DIR}/ut/tools/converter/api/*.cc
6449be168c0dSopenharmony_ci+            ${TEST_DIR}/ut/tools/converter/config_parser/*.cc
6450be168c0dSopenharmony_ci             ${TEST_DIR}/st/converter_test.cc
6451be168c0dSopenharmony_ci             ${TEST_DIR}/st/delegate_test.cc
6452be168c0dSopenharmony_ci             ${TEST_DIR}/st/mindrt_parallel_test.cc
6453be168c0dSopenharmony_ci@@ -232,7 +237,7 @@ endif()
6454be168c0dSopenharmony_ci 
6455be168c0dSopenharmony_ci if(MSLITE_ENABLE_CONVERTER)
6456be168c0dSopenharmony_ci     target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid
6457be168c0dSopenharmony_ci-                                    onnx_parser_mid tf_parser_mid)
6458be168c0dSopenharmony_ci+                                    onnx_parser_mid tf_parser_mid third_party_parser_mid)
6459be168c0dSopenharmony_ci endif()
6460be168c0dSopenharmony_ci 
6461be168c0dSopenharmony_ci if(MSLITE_ENABLE_MODEL_OBF)
6462be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh
6463be168c0dSopenharmony_ciindex c0d6d843..abdea6f4 100644
6464be168c0dSopenharmony_ci--- a/mindspore/lite/test/runtest.sh
6465be168c0dSopenharmony_ci+++ b/mindspore/lite/test/runtest.sh
6466be168c0dSopenharmony_ci@@ -80,6 +80,7 @@ if [ "$ENABLE_CONVERTER_TEST" = true ]; then
6467be168c0dSopenharmony_ci   ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry"
6468be168c0dSopenharmony_ci   ./lite-test-converter --gtest_filter="TestConverterAPI.*"
6469be168c0dSopenharmony_ci   ./lite-test-converter --gtest_filter="SpecifyGraphOutputFormatTest*"
6470be168c0dSopenharmony_ci+  ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*"
6471be168c0dSopenharmony_ci fi
6472be168c0dSopenharmony_ci ./lite-test --gtest_filter="TestRegistry.TestAdd"
6473be168c0dSopenharmony_ci ./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd"
6474be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg
6475be168c0dSopenharmony_cinew file mode 100644
6476be168c0dSopenharmony_ciindex 00000000..b5fcba75
6477be168c0dSopenharmony_ci--- /dev/null
6478be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg
6479be168c0dSopenharmony_ci@@ -0,0 +1,8 @@
6480be168c0dSopenharmony_ci+[third_party_model]
6481be168c0dSopenharmony_ci+input_names=demo_in_0;demo_in_1;demo_in_2
6482be168c0dSopenharmony_ci+input_dtypes=float32;float16;float64
6483be168c0dSopenharmony_ci+input_shapes=1;2,3;4,5,6
6484be168c0dSopenharmony_ci+output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4
6485be168c0dSopenharmony_ci+output_dtypes=int32;int16;int8;uint8
6486be168c0dSopenharmony_ci+output_shapes=10;20,30;40;50,60,70
6487be168c0dSopenharmony_ci+extended_parameters=foo:foo_value;bar:bar_value
6488be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6489be168c0dSopenharmony_ciindex 549bdd72..e73afc0e 100644
6490be168c0dSopenharmony_ci--- a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6491be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6492be168c0dSopenharmony_ci@@ -34,3 +34,13 @@ TEST(TestConverterAPI, ConvertCaffeWithNotExistWeight) {
6493be168c0dSopenharmony_ci   mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeCaffe, caffe_model, output_model, caffe_weight);
6494be168c0dSopenharmony_ci   ASSERT_FALSE(converter.Convert().IsOk());
6495be168c0dSopenharmony_ci }
6496be168c0dSopenharmony_ci+
6497be168c0dSopenharmony_ci+TEST(TestConverterAPI, ConvertThirdParty) {
6498be168c0dSopenharmony_ci+  std::string third_party_model = "./relu.mindir";
6499be168c0dSopenharmony_ci+  std::string config_model = "./third_party_model.cfg";
6500be168c0dSopenharmony_ci+  std::string output_model = "./demo_third_party.ms";
6501be168c0dSopenharmony_ci+
6502be168c0dSopenharmony_ci+  mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model);
6503be168c0dSopenharmony_ci+  converter.SetConfigFile(config_model);
6504be168c0dSopenharmony_ci+  ASSERT_TRUE(converter.Convert().IsOk());
6505be168c0dSopenharmony_ci+}
6506be168c0dSopenharmony_ci\ No newline at end of file
6507be168c0dSopenharmony_cidiff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
6508be168c0dSopenharmony_cinew file mode 100644
6509be168c0dSopenharmony_ciindex 00000000..c8eb5536
6510be168c0dSopenharmony_ci--- /dev/null
6511be168c0dSopenharmony_ci+++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
6512be168c0dSopenharmony_ci@@ -0,0 +1,176 @@
6513be168c0dSopenharmony_ci+/**
6514be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
6515be168c0dSopenharmony_ci+ *
6516be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
6517be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
6518be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
6519be168c0dSopenharmony_ci+ *
6520be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
6521be168c0dSopenharmony_ci+ *
6522be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
6523be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
6524be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6525be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
6526be168c0dSopenharmony_ci+ * limitations under the License.
6527be168c0dSopenharmony_ci+ */
6528be168c0dSopenharmony_ci+
6529be168c0dSopenharmony_ci+#include "gtest/gtest.h"
6530be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h"
6531be168c0dSopenharmony_ci+
6532be168c0dSopenharmony_ci+using mindspore::ThirdPartyModelParam;
6533be168c0dSopenharmony_ci+using mindspore::TypeId;
6534be168c0dSopenharmony_ci+using mindspore::lite::RET_OK;
6535be168c0dSopenharmony_ci+using mindspore::lite::ThirdPartyModelString;
6536be168c0dSopenharmony_ci+using mindspore::lite::ThirdPartyParamParser;
6537be168c0dSopenharmony_ci+
6538be168c0dSopenharmony_ci+const ThirdPartyModelString kDemoSISOParam = {
6539be168c0dSopenharmony_ci+  // SISO is short for single-input-single-output.
6540be168c0dSopenharmony_ci+  .input_dtypes = "float32",
6541be168c0dSopenharmony_ci+  .input_shapes = "1,2,3,4",
6542be168c0dSopenharmony_ci+  .input_names = "siso_input",
6543be168c0dSopenharmony_ci+  .output_dtypes = "int32",
6544be168c0dSopenharmony_ci+  .output_shapes = "2",
6545be168c0dSopenharmony_ci+  .output_names = "siso_output",
6546be168c0dSopenharmony_ci+  .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value",
6547be168c0dSopenharmony_ci+};
6548be168c0dSopenharmony_ci+
6549be168c0dSopenharmony_ci+const ThirdPartyModelString kDemoMIMOParam = {
6550be168c0dSopenharmony_ci+  // MIMO is short for multiple-input-multiple-output.
6551be168c0dSopenharmony_ci+  .input_dtypes = "float32;int8;float16",
6552be168c0dSopenharmony_ci+  .input_shapes = "1,2,3,4;5,6;7,8,9",
6553be168c0dSopenharmony_ci+  .input_names = "mimo_in_0;mimo_in_1;mimo_in_2",
6554be168c0dSopenharmony_ci+  .output_dtypes = "int32;float32",
6555be168c0dSopenharmony_ci+  .output_shapes = "2,4;10,20,30",
6556be168c0dSopenharmony_ci+  .output_names = "mimo_out_0;mimo_out_1",
6557be168c0dSopenharmony_ci+  .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value",
6558be168c0dSopenharmony_ci+};
6559be168c0dSopenharmony_ci+
6560be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseSISOParam) {
6561be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6562be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6563be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6564be168c0dSopenharmony_ci+
6565be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_names, std::vector<std::string>{"siso_input"});
6566be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_shapes.size(), 1U);
6567be168c0dSopenharmony_ci+  std::vector<int64_t> expect_in_shape = {1, 2, 3, 4};
6568be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_shapes[0], expect_in_shape);
6569be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_dtypes, std::vector<TypeId>{TypeId::kNumberTypeFloat32});
6570be168c0dSopenharmony_ci+
6571be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_names, std::vector<std::string>{"siso_output"});
6572be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_shapes.size(), 1U);
6573be168c0dSopenharmony_ci+  std::vector<int64_t> expect_out_shape = {2};
6574be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_shapes[0], expect_out_shape);
6575be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_dtypes, std::vector<TypeId>{TypeId::kNumberTypeInt32});
6576be168c0dSopenharmony_ci+
6577be168c0dSopenharmony_ci+  const auto &ext_param = result.extended_parameters;
6578be168c0dSopenharmony_ci+  ASSERT_EQ(ext_param.size(), 2U);
6579be168c0dSopenharmony_ci+  ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end());
6580be168c0dSopenharmony_ci+  auto expect_foo_value = ext_param.at("siso_foo");
6581be168c0dSopenharmony_ci+  ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value");
6582be168c0dSopenharmony_ci+  ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end());
6583be168c0dSopenharmony_ci+  auto expect_bar_value = ext_param.at("siso_bar");
6584be168c0dSopenharmony_ci+  ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value");
6585be168c0dSopenharmony_ci+}
6586be168c0dSopenharmony_ci+
6587be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseValidDtype) {
6588be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6589be168c0dSopenharmony_ci+  const std::vector<std::string> kValidDtypeStrings = {
6590be168c0dSopenharmony_ci+    "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool",
6591be168c0dSopenharmony_ci+  };
6592be168c0dSopenharmony_ci+
6593be168c0dSopenharmony_ci+  const std::vector<TypeId> kExpects = {
6594be168c0dSopenharmony_ci+    TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16,
6595be168c0dSopenharmony_ci+    TypeId::kNumberTypeInt64,   TypeId::kNumberTypeInt32,   TypeId::kNumberTypeInt16,
6596be168c0dSopenharmony_ci+    TypeId::kNumberTypeInt8,    TypeId::kNumberTypeUInt8,   TypeId::kNumberTypeBool};
6597be168c0dSopenharmony_ci+
6598be168c0dSopenharmony_ci+  for (size_t i = 0; i < kValidDtypeStrings.size(); i++) {
6599be168c0dSopenharmony_ci+    param_string.input_dtypes = kValidDtypeStrings[i];
6600be168c0dSopenharmony_ci+    ThirdPartyModelParam result;
6601be168c0dSopenharmony_ci+    ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6602be168c0dSopenharmony_ci+    ASSERT_EQ(result.input_dtypes[0], kExpects[i]);
6603be168c0dSopenharmony_ci+  }
6604be168c0dSopenharmony_ci+}
6605be168c0dSopenharmony_ci+
6606be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseInvalidDtype) {
6607be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6608be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6609be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6610be168c0dSopenharmony_ci+  param_string.input_dtypes = "bad_dtype";
6611be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6612be168c0dSopenharmony_ci+}
6613be168c0dSopenharmony_ci+
6614be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseValidShape) {
6615be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6616be168c0dSopenharmony_ci+  param_string.input_shapes = "256,256,1024,96";  // Only support fixed shape.
6617be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6618be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6619be168c0dSopenharmony_ci+  std::vector<int64_t> expect = {256, 256, 1024, 96};
6620be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_shapes[0], expect);
6621be168c0dSopenharmony_ci+}
6622be168c0dSopenharmony_ci+
6623be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseInvalidShape) {
6624be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6625be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6626be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6627be168c0dSopenharmony_ci+
6628be168c0dSopenharmony_ci+  param_string.input_shapes = "256,256,1024,-1";
6629be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6630be168c0dSopenharmony_ci+
6631be168c0dSopenharmony_ci+  param_string.input_shapes = "256,256,0,96";
6632be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6633be168c0dSopenharmony_ci+
6634be168c0dSopenharmony_ci+  param_string.input_shapes = "256,-256,1024,96";
6635be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6636be168c0dSopenharmony_ci+
6637be168c0dSopenharmony_ci+  param_string.input_shapes = "256,foo,1024,96";
6638be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6639be168c0dSopenharmony_ci+}
6640be168c0dSopenharmony_ci+
6641be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseDefaultName) {
6642be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6643be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoSISOParam;
6644be168c0dSopenharmony_ci+  param_string.input_names = "";
6645be168c0dSopenharmony_ci+  param_string.output_names = "";
6646be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6647be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_names[0], "in_0");
6648be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_names[0], "out_0");
6649be168c0dSopenharmony_ci+}
6650be168c0dSopenharmony_ci+
6651be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseMIMOParam) {
6652be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoMIMOParam;
6653be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6654be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6655be168c0dSopenharmony_ci+
6656be168c0dSopenharmony_ci+  std::vector<std::string> expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"};
6657be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_names, expect_input_names);
6658be168c0dSopenharmony_ci+  std::vector<std::vector<int64_t>> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}};
6659be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_shapes, expect_input_shapes);
6660be168c0dSopenharmony_ci+  std::vector<TypeId> expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8,
6661be168c0dSopenharmony_ci+                                             TypeId::kNumberTypeFloat16};
6662be168c0dSopenharmony_ci+  ASSERT_EQ(result.input_dtypes, expect_input_dtypes);
6663be168c0dSopenharmony_ci+
6664be168c0dSopenharmony_ci+  std::vector<std::string> expect_output_names = {"mimo_out_0", "mimo_out_1"};
6665be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_names, expect_output_names);
6666be168c0dSopenharmony_ci+  std::vector<std::vector<int64_t>> expect_output_shapes = {{2, 4}, {10, 20, 30}};
6667be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_shapes, expect_output_shapes);
6668be168c0dSopenharmony_ci+  std::vector<TypeId> expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32};
6669be168c0dSopenharmony_ci+  ASSERT_EQ(result.output_dtypes, expect_output_dtypes);
6670be168c0dSopenharmony_ci+}
6671be168c0dSopenharmony_ci+
6672be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseMismatchedShapeAndDtypeSize) {
6673be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoMIMOParam;
6674be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6675be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6676be168c0dSopenharmony_ci+
6677be168c0dSopenharmony_ci+  param_string.input_shapes = "1,2,3,4;5,6";  // shape size is 2 while dtype size is 3.
6678be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6679be168c0dSopenharmony_ci+}
6680be168c0dSopenharmony_ci+
6681be168c0dSopenharmony_ci+TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) {
6682be168c0dSopenharmony_ci+  ThirdPartyModelString param_string = kDemoMIMOParam;
6683be168c0dSopenharmony_ci+  ThirdPartyModelParam result;
6684be168c0dSopenharmony_ci+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6685be168c0dSopenharmony_ci+
6686be168c0dSopenharmony_ci+  param_string.input_names = "mimo_in_0;mimo_in_1";  // name size is 2 while dtype size is 3.
6687be168c0dSopenharmony_ci+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6688be168c0dSopenharmony_ci+}
6689be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_base.cc b/mindspore/lite/tools/benchmark/benchmark_base.cc
6690be168c0dSopenharmony_ciindex 16b1e218..ebaa9212 100644
6691be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_base.cc
6692be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_base.cc
6693be168c0dSopenharmony_ci@@ -323,7 +323,7 @@ int BenchmarkBase::CheckThreadNumValid() {
6694be168c0dSopenharmony_ci 
6695be168c0dSopenharmony_ci int BenchmarkBase::CheckDeviceTypeValid() {
6696be168c0dSopenharmony_ci   if (flags_->device_ != "CPU" && flags_->device_ != "GPU" && flags_->device_ != "NPU" &&
6697be168c0dSopenharmony_ci-      flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P") {
6698be168c0dSopenharmony_ci+      flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P" && flags_->device_ != "NNRT") {
6699be168c0dSopenharmony_ci     MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported.";
6700be168c0dSopenharmony_ci     std::cerr << "Device type:" << flags_->device_ << " is not supported." << std::endl;
6701be168c0dSopenharmony_ci     return RET_ERROR;
6702be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h
6703be168c0dSopenharmony_ciindex acdea21a..f818270c 100644
6704be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_base.h
6705be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_base.h
6706be168c0dSopenharmony_ci@@ -122,7 +122,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
6707be168c0dSopenharmony_ci     AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
6708be168c0dSopenharmony_ci     AddFlag(&BenchmarkFlags::group_info_file_, "GroupInfoFile", "Communication group info file", "");
6709be168c0dSopenharmony_ci     AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", "");
6710be168c0dSopenharmony_ci-    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | Auto", "CPU");
6711be168c0dSopenharmony_ci+    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | NNRT | Auto", "CPU");
6712be168c0dSopenharmony_ci     AddFlag(&BenchmarkFlags::provider_, "provider", "device provider litert | tensorrt | mindrt", "litert");
6713be168c0dSopenharmony_ci     AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1);
6714be168c0dSopenharmony_ci     // MarkPerformance
6715be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_c_api.cc b/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6716be168c0dSopenharmony_ciindex 252e65c6..cb0c56b0 100644
6717be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6718be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6719be168c0dSopenharmony_ci@@ -125,6 +125,10 @@ int BenchmarkCApi::InitContext() {
6720be168c0dSopenharmony_ci     OH_AI_DeviceInfoSetFrequency(npu_device_info, kFrequencyDefault);
6721be168c0dSopenharmony_ci     OH_AI_ContextAddDeviceInfo(context_, npu_device_info);
6722be168c0dSopenharmony_ci   }
6723be168c0dSopenharmony_ci+  if (flags_->device_ == "NNRT") {
6724be168c0dSopenharmony_ci+    OH_AI_DeviceInfoHandle nnrt_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
6725be168c0dSopenharmony_ci+    OH_AI_ContextAddDeviceInfo(context_, nnrt_device_info);
6726be168c0dSopenharmony_ci+  }
6727be168c0dSopenharmony_ci   OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
6728be168c0dSopenharmony_ci   OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
6729be168c0dSopenharmony_ci   OH_AI_ContextAddDeviceInfo(context_, cpu_device_info);
6730be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6731be168c0dSopenharmony_ciindex bb36c168..c18111b6 100644
6732be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6733be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6734be168c0dSopenharmony_ci@@ -521,6 +521,11 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context>
6735be168c0dSopenharmony_ci     // InitMSContextForAscend(context, &device_list);
6736be168c0dSopenharmony_ci   }
6737be168c0dSopenharmony_ci 
6738be168c0dSopenharmony_ci+  if (flags_->device_ == "NNRT" || flags_->device_ == "Auto") {
6739be168c0dSopenharmony_ci+    std::shared_ptr<NNRTDeviceInfo> nnrt_device_info = std::make_shared<NNRTDeviceInfo>();
6740be168c0dSopenharmony_ci+    device_list.push_back(nnrt_device_info);
6741be168c0dSopenharmony_ci+  }
6742be168c0dSopenharmony_ci+
6743be168c0dSopenharmony_ci   // CPU priority is behind GPU and NPU
6744be168c0dSopenharmony_ci   std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
6745be168c0dSopenharmony_ci   device_info->SetEnableFP16(flags_->enable_fp16_);
6746be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6747be168c0dSopenharmony_ciindex 0c558524..1b9fc347 100644
6748be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6749be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6750be168c0dSopenharmony_ci@@ -9,6 +9,9 @@ set(COMMON_SRC
6751be168c0dSopenharmony_ci set(TEST_SRC
6752be168c0dSopenharmony_ci     ${CMAKE_CURRENT_SOURCE_DIR}/main.cc
6753be168c0dSopenharmony_ci     ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
6754be168c0dSopenharmony_ci+    ${CMAKE_CURRENT_SOURCE_DIR}/net_train_base.cc
6755be168c0dSopenharmony_ci+    ${CMAKE_CURRENT_SOURCE_DIR}/run_net_train.cc
6756be168c0dSopenharmony_ci+    ${CMAKE_CURRENT_SOURCE_DIR}/net_train_c_api.cc
6757be168c0dSopenharmony_ci     )
6758be168c0dSopenharmony_ci 
6759be168c0dSopenharmony_ci # add static securec link library
6760be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/main.cc b/mindspore/lite/tools/benchmark_train/main.cc
6761be168c0dSopenharmony_ciindex abf3d9dd..76f85aa7 100644
6762be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/main.cc
6763be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/main.cc
6764be168c0dSopenharmony_ci@@ -17,7 +17,8 @@
6765be168c0dSopenharmony_ci #include <malloc.h>
6766be168c0dSopenharmony_ci #include <unistd.h>
6767be168c0dSopenharmony_ci #include <fstream>
6768be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_train.h"
6769be168c0dSopenharmony_ci+#include <iostream>
6770be168c0dSopenharmony_ci+#include "tools/benchmark_train/run_net_train.h"
6771be168c0dSopenharmony_ci 
6772be168c0dSopenharmony_ci void PrintMem() {
6773be168c0dSopenharmony_ci   std::string proc_file = "/proc/" + std::to_string(getpid()) + "/status";
6774be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc
6775be168c0dSopenharmony_ciindex 9b63d29f..edf3e964 100644
6776be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_runner.cc
6777be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_runner.cc
6778be168c0dSopenharmony_ci@@ -15,7 +15,7 @@
6779be168c0dSopenharmony_ci  */
6780be168c0dSopenharmony_ci 
6781be168c0dSopenharmony_ci #include "tools/benchmark_train/net_runner.h"
6782be168c0dSopenharmony_ci-#include "tools/benchmark_train/net_train.h"
6783be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h"
6784be168c0dSopenharmony_ci #include <getopt.h>
6785be168c0dSopenharmony_ci #include <malloc.h>
6786be168c0dSopenharmony_ci #include <cmath>
6787be168c0dSopenharmony_ci@@ -187,7 +187,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
6788be168c0dSopenharmony_ci     auto output = tensor.Data();
6789be168c0dSopenharmony_ci     size_t size;
6790be168c0dSopenharmony_ci     std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
6791be168c0dSopenharmony_ci-    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size));
6792be168c0dSopenharmony_ci+    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(output_file.c_str(), &size));
6793be168c0dSopenharmony_ci     if (bin_buf == nullptr) {
6794be168c0dSopenharmony_ci       MS_LOG(ERROR) << "ReadFile return nullptr";
6795be168c0dSopenharmony_ci       std::cout << "ReadFile return nullptr" << std::endl;
6796be168c0dSopenharmony_ci@@ -200,7 +200,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
6797be168c0dSopenharmony_ci                 << ", read size: " << size << std::endl;
6798be168c0dSopenharmony_ci       return mindspore::kLiteError;
6799be168c0dSopenharmony_ci     }
6800be168c0dSopenharmony_ci-    float bias = mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
6801be168c0dSopenharmony_ci+    float bias = mindspore::lite::NetTrainBase::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
6802be168c0dSopenharmony_ci                                                                reinterpret_cast<const float *>(output.get()));
6803be168c0dSopenharmony_ci     if (bias >= 0) {
6804be168c0dSopenharmony_ci       total_bias += bias;
6805be168c0dSopenharmony_ci@@ -332,7 +332,7 @@ int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) {
6806be168c0dSopenharmony_ci     }
6807be168c0dSopenharmony_ci     size_t size;
6808be168c0dSopenharmony_ci     std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
6809be168c0dSopenharmony_ci-    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size));
6810be168c0dSopenharmony_ci+    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(file_name.c_str(), &size));
6811be168c0dSopenharmony_ci     if (bin_buf == nullptr) {
6812be168c0dSopenharmony_ci       MS_LOG(ERROR) << "ReadFile return nullptr";
6813be168c0dSopenharmony_ci       std::cout << "ReadFile return nullptr" << std::endl;
6814be168c0dSopenharmony_ci@@ -368,4 +368,4 @@ int CallBack(mindspore::lite::NetTrainFlags *flags) {
6815be168c0dSopenharmony_ci   return nr.Main();
6816be168c0dSopenharmony_ci }
6817be168c0dSopenharmony_ci 
6818be168c0dSopenharmony_ci-int init = mindspore::lite::NetTrain::SetNr(CallBack);
6819be168c0dSopenharmony_ci+int init = mindspore::lite::NetTrainBase::SetNr(CallBack);
6820be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
6821be168c0dSopenharmony_ciindex d1150043..514bba53 100644
6822be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_train.cc
6823be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
6824be168c0dSopenharmony_ci@@ -31,74 +31,11 @@
6825be168c0dSopenharmony_ci 
6826be168c0dSopenharmony_ci namespace mindspore {
6827be168c0dSopenharmony_ci namespace lite {
6828be168c0dSopenharmony_ci-static const char *DELIM_SLASH = "/";
6829be168c0dSopenharmony_ci-constexpr const char *DELIM_COLON = ":";
6830be168c0dSopenharmony_ci-constexpr const char *DELIM_COMMA = ",";
6831be168c0dSopenharmony_ci-constexpr int RET_TOO_BIG = -9;
6832be168c0dSopenharmony_ci constexpr int kField0 = 0;
6833be168c0dSopenharmony_ci constexpr int kField1 = 1;
6834be168c0dSopenharmony_ci constexpr int kField2 = 2;
6835be168c0dSopenharmony_ci constexpr int kField3 = 3;
6836be168c0dSopenharmony_ci constexpr int kField4 = 4;
6837be168c0dSopenharmony_ci-constexpr int kFieldsToPrint = 5;
6838be168c0dSopenharmony_ci-constexpr int kPrintOffset = 4;
6839be168c0dSopenharmony_ci-static const int kTHOUSAND = 1000;
6840be168c0dSopenharmony_ci-constexpr int kDumpInputsAndOutputs = 0;
6841be168c0dSopenharmony_ci-constexpr int kDumpOutputs = 2;
6842be168c0dSopenharmony_ci-
6843be168c0dSopenharmony_ci-const std::unordered_map<int, std::string> kTypeIdMap{
6844be168c0dSopenharmony_ci-  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
6845be168c0dSopenharmony_ci-  {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
6846be168c0dSopenharmony_ci-  {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
6847be168c0dSopenharmony_ci-  {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
6848be168c0dSopenharmony_ci-  {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
6849be168c0dSopenharmony_ci-
6850be168c0dSopenharmony_ci-const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{
6851be168c0dSopenharmony_ci-  {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"},     {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"},
6852be168c0dSopenharmony_ci-  {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"},     {mindspore::CKHW, "CKHW"},   {mindspore::KHWC, "KHWC"},
6853be168c0dSopenharmony_ci-  {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"},         {mindspore::HW4, "HW4"},     {mindspore::NC, "NC"},
6854be168c0dSopenharmony_ci-  {mindspore::NC4, "NC4"},   {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}};
6855be168c0dSopenharmony_ci-
6856be168c0dSopenharmony_ci-std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr;
6857be168c0dSopenharmony_ci-
6858be168c0dSopenharmony_ci-int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) {
6859be168c0dSopenharmony_ci-  nr_cb_ = param;
6860be168c0dSopenharmony_ci-  return 0;
6861be168c0dSopenharmony_ci-}
6862be168c0dSopenharmony_ci-
6863be168c0dSopenharmony_ci-float *NetTrain::ReadFileBuf(const std::string file, size_t *size) {
6864be168c0dSopenharmony_ci-  if (file.empty()) {
6865be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "file is nullptr";
6866be168c0dSopenharmony_ci-    return nullptr;
6867be168c0dSopenharmony_ci-  }
6868be168c0dSopenharmony_ci-  MS_ASSERT(size != nullptr);
6869be168c0dSopenharmony_ci-  std::string real_path = RealPath(file.c_str());
6870be168c0dSopenharmony_ci-  std::ifstream ifs(real_path);
6871be168c0dSopenharmony_ci-  if (!ifs.good()) {
6872be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
6873be168c0dSopenharmony_ci-    return nullptr;
6874be168c0dSopenharmony_ci-  }
6875be168c0dSopenharmony_ci-
6876be168c0dSopenharmony_ci-  if (!ifs.is_open()) {
6877be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "file: " << real_path << " open failed";
6878be168c0dSopenharmony_ci-    return nullptr;
6879be168c0dSopenharmony_ci-  }
6880be168c0dSopenharmony_ci-
6881be168c0dSopenharmony_ci-  ifs.seekg(0, std::ios::end);
6882be168c0dSopenharmony_ci-  *size = ifs.tellg();
6883be168c0dSopenharmony_ci-  std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1);
6884be168c0dSopenharmony_ci-  if (buf == nullptr) {
6885be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
6886be168c0dSopenharmony_ci-    ifs.close();
6887be168c0dSopenharmony_ci-    return nullptr;
6888be168c0dSopenharmony_ci-  }
6889be168c0dSopenharmony_ci-
6890be168c0dSopenharmony_ci-  ifs.seekg(0, std::ios::beg);
6891be168c0dSopenharmony_ci-  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
6892be168c0dSopenharmony_ci-  ifs.close();
6893be168c0dSopenharmony_ci-
6894be168c0dSopenharmony_ci-  return buf.release();
6895be168c0dSopenharmony_ci-}
6896be168c0dSopenharmony_ci 
6897be168c0dSopenharmony_ci int NetTrain::GenerateInputData() {
6898be168c0dSopenharmony_ci   for (auto tensor : ms_inputs_for_api_) {
6899be168c0dSopenharmony_ci@@ -120,28 +57,6 @@ int NetTrain::GenerateInputData() {
6900be168c0dSopenharmony_ci   return RET_OK;
6901be168c0dSopenharmony_ci }
6902be168c0dSopenharmony_ci 
6903be168c0dSopenharmony_ci-int NetTrain::LoadInput() {
6904be168c0dSopenharmony_ci-  inputs_buf_.clear();
6905be168c0dSopenharmony_ci-  inputs_size_.clear();
6906be168c0dSopenharmony_ci-  batch_num_ = 0;
6907be168c0dSopenharmony_ci-  if (flags_->in_data_file_.empty()) {
6908be168c0dSopenharmony_ci-    auto status = GenerateInputData();
6909be168c0dSopenharmony_ci-    if (status != RET_OK) {
6910be168c0dSopenharmony_ci-      std::cerr << "Generate input data error " << status << std::endl;
6911be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Generate input data error " << status;
6912be168c0dSopenharmony_ci-      return status;
6913be168c0dSopenharmony_ci-    }
6914be168c0dSopenharmony_ci-  } else {
6915be168c0dSopenharmony_ci-    auto status = ReadInputFile();
6916be168c0dSopenharmony_ci-    if (status != RET_OK) {
6917be168c0dSopenharmony_ci-      std::cerr << "Read Input File error, " << status << std::endl;
6918be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Read Input File error, " << status;
6919be168c0dSopenharmony_ci-      return status;
6920be168c0dSopenharmony_ci-    }
6921be168c0dSopenharmony_ci-  }
6922be168c0dSopenharmony_ci-  return RET_OK;
6923be168c0dSopenharmony_ci-}
6924be168c0dSopenharmony_ci-
6925be168c0dSopenharmony_ci int NetTrain::LoadStepInput(size_t step) {
6926be168c0dSopenharmony_ci   if (step >= batch_num_) {
6927be168c0dSopenharmony_ci     auto cur_batch = step + 1;
6928be168c0dSopenharmony_ci@@ -269,30 +184,6 @@ int NetTrain::CompareOutput() {
6929be168c0dSopenharmony_ci   }
6930be168c0dSopenharmony_ci }
6931be168c0dSopenharmony_ci 
6932be168c0dSopenharmony_ci-std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
6933be168c0dSopenharmony_ci-                                   const std::string &file_type, const size_t &idx) {
6934be168c0dSopenharmony_ci-  std::string file_name = op_name;
6935be168c0dSopenharmony_ci-  auto pos = file_name.find_first_of('/');
6936be168c0dSopenharmony_ci-  while (pos != std::string::npos) {
6937be168c0dSopenharmony_ci-    file_name.replace(pos, 1, ".");
6938be168c0dSopenharmony_ci-    pos = file_name.find_first_of('/');
6939be168c0dSopenharmony_ci-  }
6940be168c0dSopenharmony_ci-  file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
6941be168c0dSopenharmony_ci-  for (const auto &dim : tensor->Shape()) {
6942be168c0dSopenharmony_ci-    file_name += std::to_string(dim) + "_";
6943be168c0dSopenharmony_ci-  }
6944be168c0dSopenharmony_ci-  if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
6945be168c0dSopenharmony_ci-    file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
6946be168c0dSopenharmony_ci-  }
6947be168c0dSopenharmony_ci-  auto tensor_format = tensor->format();
6948be168c0dSopenharmony_ci-  if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
6949be168c0dSopenharmony_ci-    file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin";
6950be168c0dSopenharmony_ci-  }
6951be168c0dSopenharmony_ci-
6952be168c0dSopenharmony_ci-  file_name += ".bin";
6953be168c0dSopenharmony_ci-  return file_name;
6954be168c0dSopenharmony_ci-}
6955be168c0dSopenharmony_ci-
6956be168c0dSopenharmony_ci int NetTrain::MarkPerformance() {
6957be168c0dSopenharmony_ci   MS_LOG(INFO) << "Running train loops...";
6958be168c0dSopenharmony_ci   std::cout << "Running train loops..." << std::endl;
6959be168c0dSopenharmony_ci@@ -574,26 +465,6 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
6960be168c0dSopenharmony_ci   return RET_OK;
6961be168c0dSopenharmony_ci }
6962be168c0dSopenharmony_ci 
6963be168c0dSopenharmony_ci-int NetTrain::RunNetTrain() {
6964be168c0dSopenharmony_ci-  auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
6965be168c0dSopenharmony_ci-  bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty();
6966be168c0dSopenharmony_ci-  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_);
6967be168c0dSopenharmony_ci-  if (status != RET_OK) {
6968be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
6969be168c0dSopenharmony_ci-    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
6970be168c0dSopenharmony_ci-              << std::endl;
6971be168c0dSopenharmony_ci-    return status;
6972be168c0dSopenharmony_ci-  }
6973be168c0dSopenharmony_ci-
6974be168c0dSopenharmony_ci-  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
6975be168c0dSopenharmony_ci-  if (status != RET_OK) {
6976be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
6977be168c0dSopenharmony_ci-    std::cout << "Run CheckExecute error: " << status << std::endl;
6978be168c0dSopenharmony_ci-    return status;
6979be168c0dSopenharmony_ci-  }
6980be168c0dSopenharmony_ci-  return RET_OK;
6981be168c0dSopenharmony_ci-}
6982be168c0dSopenharmony_ci-
6983be168c0dSopenharmony_ci int NetTrain::SaveModels() {
6984be168c0dSopenharmony_ci   if (!flags_->export_file_.empty()) {
6985be168c0dSopenharmony_ci     if (flags_->bb_model_file_.empty()) {
6986be168c0dSopenharmony_ci@@ -635,77 +506,6 @@ int NetTrain::SaveModels() {
6987be168c0dSopenharmony_ci   return RET_OK;
6988be168c0dSopenharmony_ci }
6989be168c0dSopenharmony_ci 
6990be168c0dSopenharmony_ci-int NetTrain::CheckExecutionOfSavedModels() {
6991be168c0dSopenharmony_ci-  int status = RET_OK;
6992be168c0dSopenharmony_ci-  if (!flags_->export_file_.empty()) {
6993be168c0dSopenharmony_ci-    status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
6994be168c0dSopenharmony_ci-    if (status != RET_OK) {
6995be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
6996be168c0dSopenharmony_ci-      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
6997be168c0dSopenharmony_ci-      return status;
6998be168c0dSopenharmony_ci-    }
6999be168c0dSopenharmony_ci-    if (flags_->bb_model_file_.empty()) {
7000be168c0dSopenharmony_ci-      status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
7001be168c0dSopenharmony_ci-      if (status != RET_OK) {
7002be168c0dSopenharmony_ci-        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
7003be168c0dSopenharmony_ci-        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
7004be168c0dSopenharmony_ci-        return status;
7005be168c0dSopenharmony_ci-      }
7006be168c0dSopenharmony_ci-    }
7007be168c0dSopenharmony_ci-  }
7008be168c0dSopenharmony_ci-  if (!flags_->inference_file_.empty()) {
7009be168c0dSopenharmony_ci-    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
7010be168c0dSopenharmony_ci-    if (status != RET_OK) {
7011be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
7012be168c0dSopenharmony_ci-      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
7013be168c0dSopenharmony_ci-      return status;
7014be168c0dSopenharmony_ci-    }
7015be168c0dSopenharmony_ci-    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
7016be168c0dSopenharmony_ci-    if (status != RET_OK) {
7017be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
7018be168c0dSopenharmony_ci-      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
7019be168c0dSopenharmony_ci-      return status;
7020be168c0dSopenharmony_ci-    }
7021be168c0dSopenharmony_ci-  }
7022be168c0dSopenharmony_ci-  return status;
7023be168c0dSopenharmony_ci-}
7024be168c0dSopenharmony_ci-
7025be168c0dSopenharmony_ci-void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) {
7026be168c0dSopenharmony_ci-  if (tensor == nullptr) {
7027be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "input tensor is nullptr.";
7028be168c0dSopenharmony_ci-    return;
7029be168c0dSopenharmony_ci-  }
7030be168c0dSopenharmony_ci-  int tensor_size = tensor->ElementNum();
7031be168c0dSopenharmony_ci-  void *data = tensor->MutableData();
7032be168c0dSopenharmony_ci-  auto *fdata = reinterpret_cast<float *>(tensor->MutableData());
7033be168c0dSopenharmony_ci-  auto type = tensor->DataType();
7034be168c0dSopenharmony_ci-  std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum=";
7035be168c0dSopenharmony_ci-  switch (type) {
7036be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeFloat32:
7037be168c0dSopenharmony_ci-      TensorNan(reinterpret_cast<float *>(data), tensor_size);
7038be168c0dSopenharmony_ci-      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
7039be168c0dSopenharmony_ci-      std::cout << "tensor name: " << tensor->Name() << std::endl;
7040be168c0dSopenharmony_ci-      std::cout << "data: ";
7041be168c0dSopenharmony_ci-      for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) {
7042be168c0dSopenharmony_ci-        std::cout << static_cast<float>(fdata[i]) << ", ";
7043be168c0dSopenharmony_ci-      }
7044be168c0dSopenharmony_ci-      std::cout << std::endl;
7045be168c0dSopenharmony_ci-      break;
7046be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeInt32:
7047be168c0dSopenharmony_ci-      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
7048be168c0dSopenharmony_ci-      break;
7049be168c0dSopenharmony_ci-#ifdef ENABLE_FP16
7050be168c0dSopenharmony_ci-    case mindspore::DataType::kNumberTypeFloat16:
7051be168c0dSopenharmony_ci-      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
7052be168c0dSopenharmony_ci-      TensorNan(reinterpret_cast<float16_t *>(data), tensor_size);
7053be168c0dSopenharmony_ci-      break;
7054be168c0dSopenharmony_ci-#endif
7055be168c0dSopenharmony_ci-    default:
7056be168c0dSopenharmony_ci-      std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
7057be168c0dSopenharmony_ci-      break;
7058be168c0dSopenharmony_ci-  }
7059be168c0dSopenharmony_ci-}
7060be168c0dSopenharmony_ci-
7061be168c0dSopenharmony_ci int NetTrain::InitDumpTensorDataCallbackParameter() {
7062be168c0dSopenharmony_ci   // before callback
7063be168c0dSopenharmony_ci   before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs,
7064be168c0dSopenharmony_ci@@ -815,178 +615,6 @@ int NetTrain::InitTimeProfilingCallbackParameter() {
7065be168c0dSopenharmony_ci   return RET_OK;
7066be168c0dSopenharmony_ci }
7067be168c0dSopenharmony_ci 
7068be168c0dSopenharmony_ci-int NetTrain::InitCallbackParameter() {
7069be168c0dSopenharmony_ci-  int ret = RET_OK;
7070be168c0dSopenharmony_ci-  if (flags_->dump_tensor_data_) {
7071be168c0dSopenharmony_ci-    ret = InitDumpTensorDataCallbackParameter();
7072be168c0dSopenharmony_ci-  } else if (flags_->time_profiling_) {
7073be168c0dSopenharmony_ci-    ret = InitTimeProfilingCallbackParameter();
7074be168c0dSopenharmony_ci-  }
7075be168c0dSopenharmony_ci-  return ret;
7076be168c0dSopenharmony_ci-}
7077be168c0dSopenharmony_ci-
7078be168c0dSopenharmony_ci-void NetTrainFlags::InitResizeDimsList() {
7079be168c0dSopenharmony_ci-  std::string content = this->resize_dims_in_;
7080be168c0dSopenharmony_ci-  std::vector<int> shape;
7081be168c0dSopenharmony_ci-  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
7082be168c0dSopenharmony_ci-  for (const auto &shape_str : shape_strs) {
7083be168c0dSopenharmony_ci-    shape.clear();
7084be168c0dSopenharmony_ci-    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
7085be168c0dSopenharmony_ci-    std::cout << "Resize Dims: ";
7086be168c0dSopenharmony_ci-    for (const auto &dim_str : dim_strs) {
7087be168c0dSopenharmony_ci-      std::cout << dim_str << " ";
7088be168c0dSopenharmony_ci-      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
7089be168c0dSopenharmony_ci-    }
7090be168c0dSopenharmony_ci-    std::cout << std::endl;
7091be168c0dSopenharmony_ci-    this->resize_dims_.emplace_back(shape);
7092be168c0dSopenharmony_ci-  }
7093be168c0dSopenharmony_ci-}
7094be168c0dSopenharmony_ci-
7095be168c0dSopenharmony_ci-int NetTrain::Init() {
7096be168c0dSopenharmony_ci-  if (this->flags_ == nullptr) {
7097be168c0dSopenharmony_ci-    return 1;
7098be168c0dSopenharmony_ci-  }
7099be168c0dSopenharmony_ci-  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
7100be168c0dSopenharmony_ci-  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
7101be168c0dSopenharmony_ci-  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
7102be168c0dSopenharmony_ci-  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
7103be168c0dSopenharmony_ci-  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
7104be168c0dSopenharmony_ci-  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
7105be168c0dSopenharmony_ci-  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
7106be168c0dSopenharmony_ci-  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
7107be168c0dSopenharmony_ci-  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
7108be168c0dSopenharmony_ci-  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
7109be168c0dSopenharmony_ci-  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;
7110be168c0dSopenharmony_ci-
7111be168c0dSopenharmony_ci-  if (this->flags_->epochs_ < 0) {
7112be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";
7113be168c0dSopenharmony_ci-    std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl;
7114be168c0dSopenharmony_ci-    return RET_ERROR;
7115be168c0dSopenharmony_ci-  }
7116be168c0dSopenharmony_ci-
7117be168c0dSopenharmony_ci-  if (this->flags_->num_threads_ < 1) {
7118be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
7119be168c0dSopenharmony_ci-    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
7120be168c0dSopenharmony_ci-    return RET_ERROR;
7121be168c0dSopenharmony_ci-  }
7122be168c0dSopenharmony_ci-
7123be168c0dSopenharmony_ci-  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
7124be168c0dSopenharmony_ci-
7125be168c0dSopenharmony_ci-  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
7126be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided";
7127be168c0dSopenharmony_ci-    std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl;
7128be168c0dSopenharmony_ci-    return RET_ERROR;
7129be168c0dSopenharmony_ci-  }
7130be168c0dSopenharmony_ci-
7131be168c0dSopenharmony_ci-  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
7132be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided";
7133be168c0dSopenharmony_ci-    std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl;
7134be168c0dSopenharmony_ci-    return RET_ERROR;
7135be168c0dSopenharmony_ci-  }
7136be168c0dSopenharmony_ci-
7137be168c0dSopenharmony_ci-  if (flags_->model_file_.empty()) {
7138be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "modelPath is required";
7139be168c0dSopenharmony_ci-    std::cerr << "modelPath is required" << std::endl;
7140be168c0dSopenharmony_ci-    return 1;
7141be168c0dSopenharmony_ci-  }
7142be168c0dSopenharmony_ci-
7143be168c0dSopenharmony_ci-  // get dump data output path
7144be168c0dSopenharmony_ci-  auto dump_cfg_path = std::getenv(dump::kConfigPath);
7145be168c0dSopenharmony_ci-  if (dump_cfg_path != nullptr) {
7146be168c0dSopenharmony_ci-    flags_->dump_tensor_data_ = true;
7147be168c0dSopenharmony_ci-    if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
7148be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "parse dump config file failed.";
7149be168c0dSopenharmony_ci-      return RET_ERROR;
7150be168c0dSopenharmony_ci-    }
7151be168c0dSopenharmony_ci-  } else {
7152be168c0dSopenharmony_ci-    MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data";
7153be168c0dSopenharmony_ci-  }
7154be168c0dSopenharmony_ci-
7155be168c0dSopenharmony_ci-  auto status = InitCallbackParameter();
7156be168c0dSopenharmony_ci-  if (status != RET_OK) {
7157be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Init callback Parameter failed.";
7158be168c0dSopenharmony_ci-    std::cerr << "Init callback Parameter failed." << std::endl;
7159be168c0dSopenharmony_ci-    return RET_ERROR;
7160be168c0dSopenharmony_ci-  }
7161be168c0dSopenharmony_ci-
7162be168c0dSopenharmony_ci-  flags_->InitResizeDimsList();
7163be168c0dSopenharmony_ci-  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
7164be168c0dSopenharmony_ci-      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
7165be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
7166be168c0dSopenharmony_ci-    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
7167be168c0dSopenharmony_ci-    return RET_ERROR;
7168be168c0dSopenharmony_ci-  }
7169be168c0dSopenharmony_ci-  return RET_OK;
7170be168c0dSopenharmony_ci-}
7171be168c0dSopenharmony_ci-
7172be168c0dSopenharmony_ci-namespace {
7173be168c0dSopenharmony_ci-constexpr int kNumToPrint = 5;
7174be168c0dSopenharmony_ci-}
7175be168c0dSopenharmony_ci-
7176be168c0dSopenharmony_ci-int NetTrain::InitDumpConfigFromJson(std::string path) {
7177be168c0dSopenharmony_ci-  auto real_path = RealPath(path.c_str());
7178be168c0dSopenharmony_ci-  std::ifstream ifs(real_path);
7179be168c0dSopenharmony_ci-  if (!ifs.good()) {
7180be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
7181be168c0dSopenharmony_ci-    return RET_ERROR;
7182be168c0dSopenharmony_ci-  }
7183be168c0dSopenharmony_ci-  if (!ifs.is_open()) {
7184be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7185be168c0dSopenharmony_ci-    return RET_ERROR;
7186be168c0dSopenharmony_ci-  }
7187be168c0dSopenharmony_ci-
7188be168c0dSopenharmony_ci-  try {
7189be168c0dSopenharmony_ci-    dump_cfg_json_ = nlohmann::json::parse(ifs);
7190be168c0dSopenharmony_ci-  } catch (const nlohmann::json::parse_error &error) {
7191be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "parse json file failed, please check your file.";
7192be168c0dSopenharmony_ci-    return RET_ERROR;
7193be168c0dSopenharmony_ci-  }
7194be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings] == nullptr) {
7195be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
7196be168c0dSopenharmony_ci-    return RET_ERROR;
7197be168c0dSopenharmony_ci-  }
7198be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
7199be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "\"dump_mode\" is required.";
7200be168c0dSopenharmony_ci-    return RET_ERROR;
7201be168c0dSopenharmony_ci-  }
7202be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
7203be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "\"path\" is required.";
7204be168c0dSopenharmony_ci-    return RET_ERROR;
7205be168c0dSopenharmony_ci-  }
7206be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
7207be168c0dSopenharmony_ci-    dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
7208be168c0dSopenharmony_ci-  }
7209be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
7210be168c0dSopenharmony_ci-    dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
7211be168c0dSopenharmony_ci-  }
7212be168c0dSopenharmony_ci-  if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
7213be168c0dSopenharmony_ci-      !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
7214be168c0dSopenharmony_ci-    if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
7215be168c0dSopenharmony_ci-      MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
7216be168c0dSopenharmony_ci-      return RET_ERROR;
7217be168c0dSopenharmony_ci-    }
7218be168c0dSopenharmony_ci-  }
7219be168c0dSopenharmony_ci-
7220be168c0dSopenharmony_ci-  auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
7221be168c0dSopenharmony_ci-  auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
7222be168c0dSopenharmony_ci-  if (abs_path.back() == '\\' || abs_path.back() == '/') {
7223be168c0dSopenharmony_ci-    dump_file_output_dir_ = abs_path + net_name;
7224be168c0dSopenharmony_ci-  } else {
7225be168c0dSopenharmony_ci-#ifdef _WIN32
7226be168c0dSopenharmony_ci-    dump_file_output_dir_ = abs_path + "\\" + net_name;
7227be168c0dSopenharmony_ci-#else
7228be168c0dSopenharmony_ci-    dump_file_output_dir_ = abs_path + "/" + net_name;
7229be168c0dSopenharmony_ci-#endif
7230be168c0dSopenharmony_ci-  }
7231be168c0dSopenharmony_ci-
7232be168c0dSopenharmony_ci-  auto status = CreateOutputDir(&dump_file_output_dir_);
7233be168c0dSopenharmony_ci-  if (status != RET_OK) {
7234be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "create data output directory failed.";
7235be168c0dSopenharmony_ci-    return RET_ERROR;
7236be168c0dSopenharmony_ci-  }
7237be168c0dSopenharmony_ci-  return RET_OK;
7238be168c0dSopenharmony_ci-}
7239be168c0dSopenharmony_ci-
7240be168c0dSopenharmony_ci int NetTrain::PrintResult(const std::vector<std::string> &title,
7241be168c0dSopenharmony_ci                           const std::map<std::string, std::pair<int, float>> &result) {
7242be168c0dSopenharmony_ci   std::vector<size_t> columnLenMax(kFieldsToPrint);
7243be168c0dSopenharmony_ci@@ -1035,7 +663,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7244be168c0dSopenharmony_ci   }
7245be168c0dSopenharmony_ci 
7246be168c0dSopenharmony_ci   printf("-------------------------------------------------------------------------\n");
7247be168c0dSopenharmony_ci-  for (int i = 0; i < kNumToPrint; i++) {
7248be168c0dSopenharmony_ci+  for (int i = 0; i < kFieldsToPrint; i++) {
7249be168c0dSopenharmony_ci     auto printBuf = title[i];
7250be168c0dSopenharmony_ci     if (printBuf.size() > columnLenMax.at(i)) {
7251be168c0dSopenharmony_ci       columnLenMax.at(i) = printBuf.size();
7252be168c0dSopenharmony_ci@@ -1045,7 +673,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7253be168c0dSopenharmony_ci   }
7254be168c0dSopenharmony_ci   printf("\n");
7255be168c0dSopenharmony_ci   for (auto &row : rows) {
7256be168c0dSopenharmony_ci-    for (int j = 0; j < kNumToPrint; j++) {
7257be168c0dSopenharmony_ci+    for (int j = 0; j < kFieldsToPrint; j++) {
7258be168c0dSopenharmony_ci       auto printBuf = row[j];
7259be168c0dSopenharmony_ci       printBuf.resize(columnLenMax.at(j), ' ');
7260be168c0dSopenharmony_ci       printf("%s\t", printBuf.c_str());
7261be168c0dSopenharmony_ci@@ -1054,47 +682,5 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7262be168c0dSopenharmony_ci   }
7263be168c0dSopenharmony_ci   return RET_OK;
7264be168c0dSopenharmony_ci }
7265be168c0dSopenharmony_ci-
7266be168c0dSopenharmony_ci-int RunNetTrain(int argc, const char **argv) {
7267be168c0dSopenharmony_ci-  NetTrainFlags flags;
7268be168c0dSopenharmony_ci-  Option<std::string> err = flags.ParseFlags(argc, argv);
7269be168c0dSopenharmony_ci-
7270be168c0dSopenharmony_ci-  if (err.IsSome()) {
7271be168c0dSopenharmony_ci-    std::cerr << err.Get() << std::endl;
7272be168c0dSopenharmony_ci-    std::cerr << flags.Usage() << std::endl;
7273be168c0dSopenharmony_ci-    return RET_ERROR;
7274be168c0dSopenharmony_ci-  }
7275be168c0dSopenharmony_ci-
7276be168c0dSopenharmony_ci-  if (flags.help) {
7277be168c0dSopenharmony_ci-    std::cerr << flags.Usage() << std::endl;
7278be168c0dSopenharmony_ci-    return RET_OK;
7279be168c0dSopenharmony_ci-  }
7280be168c0dSopenharmony_ci-  if (flags.unified_api_) {
7281be168c0dSopenharmony_ci-    return NetTrain::RunNr(&flags);
7282be168c0dSopenharmony_ci-  }
7283be168c0dSopenharmony_ci-  NetTrain net_trainer(&flags);
7284be168c0dSopenharmony_ci-  auto status = net_trainer.Init();
7285be168c0dSopenharmony_ci-  if (status != RET_OK) {
7286be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "NetTrain init Error : " << status;
7287be168c0dSopenharmony_ci-    std::cerr << "NetTrain init Error : " << status << std::endl;
7288be168c0dSopenharmony_ci-    return RET_ERROR;
7289be168c0dSopenharmony_ci-  }
7290be168c0dSopenharmony_ci-
7291be168c0dSopenharmony_ci-  status = net_trainer.RunNetTrain();
7292be168c0dSopenharmony_ci-  if (status != RET_OK) {
7293be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Run NetTrain "
7294be168c0dSopenharmony_ci-                  << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7295be168c0dSopenharmony_ci-                  << " Failed : " << status;
7296be168c0dSopenharmony_ci-    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7297be168c0dSopenharmony_ci-              << " Failed : " << status << std::endl;
7298be168c0dSopenharmony_ci-    return RET_ERROR;
7299be168c0dSopenharmony_ci-  }
7300be168c0dSopenharmony_ci-
7301be168c0dSopenharmony_ci-  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7302be168c0dSopenharmony_ci-               << " Success.";
7303be168c0dSopenharmony_ci-  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7304be168c0dSopenharmony_ci-            << " Success." << std::endl;
7305be168c0dSopenharmony_ci-  return RET_OK;
7306be168c0dSopenharmony_ci-}
7307be168c0dSopenharmony_ci }  // namespace lite
7308be168c0dSopenharmony_ci }  // namespace mindspore
7309be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h
7310be168c0dSopenharmony_ciindex 67e58a04..bdf0ec88 100644
7311be168c0dSopenharmony_ci--- a/mindspore/lite/tools/benchmark_train/net_train.h
7312be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train.h
7313be168c0dSopenharmony_ci@@ -42,183 +42,22 @@
7314be168c0dSopenharmony_ci #include "tools/common/flag_parser.h"
7315be168c0dSopenharmony_ci #include "src/common/file_utils.h"
7316be168c0dSopenharmony_ci #include "src/common/utils.h"
7317be168c0dSopenharmony_ci-
7318be168c0dSopenharmony_ci-#ifdef ENABLE_FP16
7319be168c0dSopenharmony_ci-static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) {
7320be168c0dSopenharmony_ci-  volatile float16_t d = var;
7321be168c0dSopenharmony_ci-  return d != d;
7322be168c0dSopenharmony_ci-}
7323be168c0dSopenharmony_ci-#endif
7324be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h"
7325be168c0dSopenharmony_ci 
7326be168c0dSopenharmony_ci namespace mindspore::lite {
7327be168c0dSopenharmony_ci-enum MS_API DataType { kImage = 0, kBinary = 1 };
7328be168c0dSopenharmony_ci-
7329be168c0dSopenharmony_ci-constexpr float relativeTolerance = 1e-5;
7330be168c0dSopenharmony_ci-constexpr float absoluteTolerance = 1e-8;
7331be168c0dSopenharmony_ci extern const std::unordered_map<int, std::string> kTypeIdMap;
7332be168c0dSopenharmony_ci extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
7333be168c0dSopenharmony_ci 
7334be168c0dSopenharmony_ci-namespace dump {
7335be168c0dSopenharmony_ci-constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
7336be168c0dSopenharmony_ci-constexpr auto kSettings = "common_dump_settings";
7337be168c0dSopenharmony_ci-constexpr auto kMode = "dump_mode";
7338be168c0dSopenharmony_ci-constexpr auto kPath = "path";
7339be168c0dSopenharmony_ci-constexpr auto kNetName = "net_name";
7340be168c0dSopenharmony_ci-constexpr auto kInputOutput = "input_output";
7341be168c0dSopenharmony_ci-constexpr auto kKernels = "kernels";
7342be168c0dSopenharmony_ci-}  // namespace dump
7343be168c0dSopenharmony_ci-
7344be168c0dSopenharmony_ci-template <typename T>
7345be168c0dSopenharmony_ci-float TensorSum(const void *data, int size) {
7346be168c0dSopenharmony_ci-  const T *typed_data = reinterpret_cast<const T *>(data);
7347be168c0dSopenharmony_ci-  float sum = 0.f;
7348be168c0dSopenharmony_ci-  for (int i = 0; i < size; i++) {
7349be168c0dSopenharmony_ci-    sum += static_cast<float>(typed_data[i]);
7350be168c0dSopenharmony_ci-  }
7351be168c0dSopenharmony_ci-  return sum;
7352be168c0dSopenharmony_ci-}
7353be168c0dSopenharmony_ci-
7354be168c0dSopenharmony_ci-class MS_API NetTrainFlags : public virtual FlagParser {
7355be168c0dSopenharmony_ci+class MS_API NetTrain : public NetTrainBase {
7356be168c0dSopenharmony_ci  public:
7357be168c0dSopenharmony_ci-  NetTrainFlags() {
7358be168c0dSopenharmony_ci-    // common
7359be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", "");
7360be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", "");
7361be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
7362be168c0dSopenharmony_ci-    // MarkPerformance
7363be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0);
7364be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false);
7365be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1);
7366be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1);
7367be168c0dSopenharmony_ci-    // MarkAccuracy
7368be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", "");
7369be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
7370be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
7371be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
7372be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
7373be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", "");
7374be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", "");
7375be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
7376be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
7377be168c0dSopenharmony_ci-            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
7378be168c0dSopenharmony_ci-    AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false);
7379be168c0dSopenharmony_ci-  }
7380be168c0dSopenharmony_ci-
7381be168c0dSopenharmony_ci-  ~NetTrainFlags() override = default;
7382be168c0dSopenharmony_ci-  void InitResizeDimsList();
7383be168c0dSopenharmony_ci+  explicit NetTrain(NetTrainFlags *flags) : NetTrainBase(flags) {}
7384be168c0dSopenharmony_ci+  virtual ~NetTrain() {}
7385be168c0dSopenharmony_ci 
7386be168c0dSopenharmony_ci- public:
7387be168c0dSopenharmony_ci-  // common
7388be168c0dSopenharmony_ci-  std::string model_file_;
7389be168c0dSopenharmony_ci-  std::string in_data_file_;
7390be168c0dSopenharmony_ci-  std::string bb_model_file_;
7391be168c0dSopenharmony_ci-  std::vector<std::string> input_data_list_;
7392be168c0dSopenharmony_ci-  DataType in_data_type_;
7393be168c0dSopenharmony_ci-  std::string in_data_type_in_ = "bin";
7394be168c0dSopenharmony_ci-  int cpu_bind_mode_ = 1;
7395be168c0dSopenharmony_ci-  bool enable_fp16_ = false;
7396be168c0dSopenharmony_ci-  bool virtual_batch_ = false;
7397be168c0dSopenharmony_ci-  // MarkPerformance
7398be168c0dSopenharmony_ci-  int num_threads_ = 1;
7399be168c0dSopenharmony_ci-  int warm_up_loop_count_ = 0;
7400be168c0dSopenharmony_ci-  bool time_profiling_;
7401be168c0dSopenharmony_ci-  int epochs_ = 1;
7402be168c0dSopenharmony_ci-  // MarkAccuracy
7403be168c0dSopenharmony_ci-  std::string data_file_;
7404be168c0dSopenharmony_ci-  std::string data_type_ = "FLOAT";
7405be168c0dSopenharmony_ci-  float accuracy_threshold_;
7406be168c0dSopenharmony_ci-  // Resize
7407be168c0dSopenharmony_ci-  std::string export_file_ = "";
7408be168c0dSopenharmony_ci-  std::string resize_dims_in_ = "";
7409be168c0dSopenharmony_ci-  bool layer_checksum_ = false;
7410be168c0dSopenharmony_ci-  std::vector<std::vector<int>> resize_dims_;
7411be168c0dSopenharmony_ci-  std::string loss_name_ = "";
7412be168c0dSopenharmony_ci-  std::string inference_file_ = "";
7413be168c0dSopenharmony_ci-  bool unified_api_ = false;
7414be168c0dSopenharmony_ci-  bool dump_tensor_data_ = false;
7415be168c0dSopenharmony_ci-};
7416be168c0dSopenharmony_ci-
7417be168c0dSopenharmony_ci-class MS_API NetTrain {
7418be168c0dSopenharmony_ci- public:
7419be168c0dSopenharmony_ci-  explicit NetTrain(NetTrainFlags *flags) : flags_(flags) {}
7420be168c0dSopenharmony_ci-  virtual ~NetTrain() = default;
7421be168c0dSopenharmony_ci-
7422be168c0dSopenharmony_ci-  int Init();
7423be168c0dSopenharmony_ci-  int RunNetTrain();
7424be168c0dSopenharmony_ci-  static float *ReadFileBuf(const std::string file, size_t *size);
7425be168c0dSopenharmony_ci-  static int SetNr(std::function<int(NetTrainFlags *)> param);
7426be168c0dSopenharmony_ci-  static int RunNr(NetTrainFlags *flags) {
7427be168c0dSopenharmony_ci-    if (nr_cb_ != nullptr) {
7428be168c0dSopenharmony_ci-      return nr_cb_(flags);
7429be168c0dSopenharmony_ci-    }
7430be168c0dSopenharmony_ci-    MS_LOG(WARNING) << "unified api was not tested";
7431be168c0dSopenharmony_ci-    std::cout << "unified api was not tested";
7432be168c0dSopenharmony_ci-    return RET_OK;
7433be168c0dSopenharmony_ci-  }
7434be168c0dSopenharmony_ci-  // tensorData need to be converter first
7435be168c0dSopenharmony_ci-  template <typename T>
7436be168c0dSopenharmony_ci-  static float CompareData(const float *refOutput, int size, const T *msTensorData) {
7437be168c0dSopenharmony_ci-    size_t errorCount = 0;
7438be168c0dSopenharmony_ci-    float meanError = 0;
7439be168c0dSopenharmony_ci-    std::cout << "Out tensor size is: " << size << std::endl;
7440be168c0dSopenharmony_ci-    std::cout << "Data of model output: ";
7441be168c0dSopenharmony_ci-    for (int j = 0; j < std::min(50, size); j++) {
7442be168c0dSopenharmony_ci-      std::cout << static_cast<float>(msTensorData[j]) << " ";
7443be168c0dSopenharmony_ci-    }
7444be168c0dSopenharmony_ci-    std::cout << std::endl;
7445be168c0dSopenharmony_ci-    std::cout << "Data of Ref output  : ";
7446be168c0dSopenharmony_ci-    for (int j = 0; j < std::min(50, size); j++) {
7447be168c0dSopenharmony_ci-      std::cout << refOutput[j] << " ";
7448be168c0dSopenharmony_ci-    }
7449be168c0dSopenharmony_ci-    std::cout << std::endl;
7450be168c0dSopenharmony_ci-    for (int j = 0; j < size; j++) {
7451be168c0dSopenharmony_ci-      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
7452be168c0dSopenharmony_ci-        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
7453be168c0dSopenharmony_ci-        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
7454be168c0dSopenharmony_ci-        return RET_ERROR;
7455be168c0dSopenharmony_ci-      }
7456be168c0dSopenharmony_ci-
7457be168c0dSopenharmony_ci-      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
7458be168c0dSopenharmony_ci-      auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]);
7459be168c0dSopenharmony_ci-      if (absoluteError > tolerance) {
7460be168c0dSopenharmony_ci-        if (fabs(refOutput[j]) == 0) {
7461be168c0dSopenharmony_ci-          if (absoluteError > 1e-5) {
7462be168c0dSopenharmony_ci-            meanError += absoluteError;
7463be168c0dSopenharmony_ci-            errorCount++;
7464be168c0dSopenharmony_ci-          } else {
7465be168c0dSopenharmony_ci-            continue;
7466be168c0dSopenharmony_ci-          }
7467be168c0dSopenharmony_ci-        } else {
7468be168c0dSopenharmony_ci-          // just assume that atol = rtol
7469be168c0dSopenharmony_ci-          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
7470be168c0dSopenharmony_ci-          errorCount++;
7471be168c0dSopenharmony_ci-        }
7472be168c0dSopenharmony_ci-      }
7473be168c0dSopenharmony_ci-    }
7474be168c0dSopenharmony_ci-    std::cout << std::endl;
7475be168c0dSopenharmony_ci-    if (meanError > 0.0f) {
7476be168c0dSopenharmony_ci-      meanError /= errorCount;
7477be168c0dSopenharmony_ci-    }
7478be168c0dSopenharmony_ci-
7479be168c0dSopenharmony_ci-    if (meanError <= 0.0000001) {
7480be168c0dSopenharmony_ci-      std::cout << "Mean bias of tensor: 0%" << std::endl;
7481be168c0dSopenharmony_ci-    } else {
7482be168c0dSopenharmony_ci-      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
7483be168c0dSopenharmony_ci-    }
7484be168c0dSopenharmony_ci-    return meanError;
7485be168c0dSopenharmony_ci-  }
7486be168c0dSopenharmony_ci-  int InitDumpConfigFromJson(std::string path);
7487be168c0dSopenharmony_ci-
7488be168c0dSopenharmony_ci- private:
7489be168c0dSopenharmony_ci-  // call GenerateInputData or ReadInputFile to init inputTensors
7490be168c0dSopenharmony_ci-  int LoadInput();
7491be168c0dSopenharmony_ci-  void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out);
7492be168c0dSopenharmony_ci+ protected:
7493be168c0dSopenharmony_ci   // call GenerateRandomData to fill inputTensors
7494be168c0dSopenharmony_ci-  int GenerateInputData();
7495be168c0dSopenharmony_ci+  int GenerateInputData() override;
7496be168c0dSopenharmony_ci 
7497be168c0dSopenharmony_ci-  int GenerateRandomData(mindspore::MSTensor *tensor);
7498be168c0dSopenharmony_ci-
7499be168c0dSopenharmony_ci-  int ReadInputFile();
7500be168c0dSopenharmony_ci+  int ReadInputFile() override;
7501be168c0dSopenharmony_ci 
7502be168c0dSopenharmony_ci   int LoadStepInput(size_t step);
7503be168c0dSopenharmony_ci 
7504be168c0dSopenharmony_ci@@ -227,20 +66,19 @@ class MS_API NetTrain {
7505be168c0dSopenharmony_ci   void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg);
7506be168c0dSopenharmony_ci 
7507be168c0dSopenharmony_ci   int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
7508be168c0dSopenharmony_ci-                          bool check_accuracy = true);
7509be168c0dSopenharmony_ci+                          bool check_accuracy = true) override;
7510be168c0dSopenharmony_ci 
7511be168c0dSopenharmony_ci   int CreateAndRunNetworkForInference(const std::string &filename, const std::shared_ptr<mindspore::Context> &context);
7512be168c0dSopenharmony_ci 
7513be168c0dSopenharmony_ci   int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
7514be168c0dSopenharmony_ci                                   const std::shared_ptr<mindspore::Context> &context,
7515be168c0dSopenharmony_ci                                   const std::shared_ptr<TrainCfg> &train_cfg, int epochs);
7516be168c0dSopenharmony_ci-  int InitCallbackParameter();
7517be168c0dSopenharmony_ci 
7518be168c0dSopenharmony_ci-  int InitDumpTensorDataCallbackParameter();
7519be168c0dSopenharmony_ci+  int InitDumpTensorDataCallbackParameter() override;
7520be168c0dSopenharmony_ci 
7521be168c0dSopenharmony_ci-  int InitTimeProfilingCallbackParameter();
7522be168c0dSopenharmony_ci+  int InitTimeProfilingCallbackParameter() override;
7523be168c0dSopenharmony_ci 
7524be168c0dSopenharmony_ci-  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
7525be168c0dSopenharmony_ci+  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override;
7526be168c0dSopenharmony_ci 
7527be168c0dSopenharmony_ci   template <typename T>
7528be168c0dSopenharmony_ci   void PrintInputData(mindspore::MSTensor *input) {
7529be168c0dSopenharmony_ci@@ -256,39 +94,11 @@ class MS_API NetTrain {
7530be168c0dSopenharmony_ci     std::cout << std::endl;
7531be168c0dSopenharmony_ci   }
7532be168c0dSopenharmony_ci 
7533be168c0dSopenharmony_ci-  template <typename T>
7534be168c0dSopenharmony_ci-  std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) {
7535be168c0dSopenharmony_ci-    std::vector<int64_t> dims;
7536be168c0dSopenharmony_ci-    for (auto shape : srcDims) {
7537be168c0dSopenharmony_ci-      dims.push_back(static_cast<int64_t>(shape));
7538be168c0dSopenharmony_ci-    }
7539be168c0dSopenharmony_ci-    return dims;
7540be168c0dSopenharmony_ci-  }
7541be168c0dSopenharmony_ci-  int MarkPerformance();
7542be168c0dSopenharmony_ci-  int MarkAccuracy(bool enforce_accuracy = true);
7543be168c0dSopenharmony_ci-  int CompareOutput();
7544be168c0dSopenharmony_ci-  int SaveModels();
7545be168c0dSopenharmony_ci-  int CheckExecutionOfSavedModels();
7546be168c0dSopenharmony_ci-  void TensorNan(const float *data, int size) {
7547be168c0dSopenharmony_ci-    for (int i = 0; i < size; i++) {
7548be168c0dSopenharmony_ci-      if (std::isnan(data[i])) {
7549be168c0dSopenharmony_ci-        std::cout << "nan value of index=" << i << ", " << data[i] << std::endl;
7550be168c0dSopenharmony_ci-        break;
7551be168c0dSopenharmony_ci-      }
7552be168c0dSopenharmony_ci-    }
7553be168c0dSopenharmony_ci-  }
7554be168c0dSopenharmony_ci-#ifdef ENABLE_FP16
7555be168c0dSopenharmony_ci-  void TensorNan(float16_t *data, int size) {
7556be168c0dSopenharmony_ci-    for (int i = 0; i < size; i++) {
7557be168c0dSopenharmony_ci-      if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) {
7558be168c0dSopenharmony_ci-        std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl;
7559be168c0dSopenharmony_ci-        break;
7560be168c0dSopenharmony_ci-      }
7561be168c0dSopenharmony_ci-    }
7562be168c0dSopenharmony_ci-  }
7563be168c0dSopenharmony_ci-#endif
7564be168c0dSopenharmony_ci-  NetTrainFlags *flags_{nullptr};
7565be168c0dSopenharmony_ci-  static std::function<int(NetTrainFlags *)> nr_cb_;
7566be168c0dSopenharmony_ci+  int MarkPerformance() override;
7567be168c0dSopenharmony_ci+  int MarkAccuracy(bool enforce_accuracy = true) override;
7568be168c0dSopenharmony_ci+  int CompareOutput() override;
7569be168c0dSopenharmony_ci+  int SaveModels() override;
7570be168c0dSopenharmony_ci+
7571be168c0dSopenharmony_ci   // callback parameters
7572be168c0dSopenharmony_ci   uint64_t op_begin_ = 0;
7573be168c0dSopenharmony_ci   int op_call_times_total_ = 0;
7574be168c0dSopenharmony_ci@@ -301,13 +111,6 @@ class MS_API NetTrain {
7575be168c0dSopenharmony_ci 
7576be168c0dSopenharmony_ci   mindspore::MSKernelCallBack before_call_back_{nullptr};
7577be168c0dSopenharmony_ci   mindspore::MSKernelCallBack after_call_back_{nullptr};
7578be168c0dSopenharmony_ci-  nlohmann::json dump_cfg_json_;
7579be168c0dSopenharmony_ci-  std::string dump_file_output_dir_;
7580be168c0dSopenharmony_ci-  std::vector<std::shared_ptr<char>> inputs_buf_;
7581be168c0dSopenharmony_ci-  std::vector<size_t> inputs_size_;
7582be168c0dSopenharmony_ci-  size_t batch_num_ = 0;
7583be168c0dSopenharmony_ci };
7584be168c0dSopenharmony_ci-
7585be168c0dSopenharmony_ci-int MS_API RunNetTrain(int argc, const char **argv);
7586be168c0dSopenharmony_ci }  // namespace mindspore::lite
7587be168c0dSopenharmony_ci #endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_
7588be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_base.cc b/mindspore/lite/tools/benchmark_train/net_train_base.cc
7589be168c0dSopenharmony_cinew file mode 100644
7590be168c0dSopenharmony_ciindex 00000000..8d3c75de
7591be168c0dSopenharmony_ci--- /dev/null
7592be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_base.cc
7593be168c0dSopenharmony_ci@@ -0,0 +1,410 @@
7594be168c0dSopenharmony_ci+/**
7595be168c0dSopenharmony_ci+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
7596be168c0dSopenharmony_ci+ *
7597be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
7598be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
7599be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
7600be168c0dSopenharmony_ci+ *
7601be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
7602be168c0dSopenharmony_ci+ *
7603be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
7604be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
7605be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7606be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
7607be168c0dSopenharmony_ci+ * limitations under the License.
7608be168c0dSopenharmony_ci+ */
7609be168c0dSopenharmony_ci+
7610be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h"
7611be168c0dSopenharmony_ci+#define __STDC_FORMAT_MACROS
7612be168c0dSopenharmony_ci+#undef __STDC_FORMAT_MACROS
7613be168c0dSopenharmony_ci+#include <algorithm>
7614be168c0dSopenharmony_ci+#include <cstring>
7615be168c0dSopenharmony_ci+#ifdef ENABLE_NEON
7616be168c0dSopenharmony_ci+#include <arm_neon.h>
7617be168c0dSopenharmony_ci+#endif
7618be168c0dSopenharmony_ci+#include "src/common/common.h"
7619be168c0dSopenharmony_ci+#include "include/api/serialization.h"
7620be168c0dSopenharmony_ci+
7621be168c0dSopenharmony_ci+namespace mindspore {
7622be168c0dSopenharmony_ci+namespace lite {
7623be168c0dSopenharmony_ci+const std::unordered_map<int, std::string> kTypeIdMap{
7624be168c0dSopenharmony_ci+  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
7625be168c0dSopenharmony_ci+  {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
7626be168c0dSopenharmony_ci+  {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
7627be168c0dSopenharmony_ci+  {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
7628be168c0dSopenharmony_ci+  {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
7629be168c0dSopenharmony_ci+
7630be168c0dSopenharmony_ci+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{
7631be168c0dSopenharmony_ci+  {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"},     {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"},
7632be168c0dSopenharmony_ci+  {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"},     {mindspore::CKHW, "CKHW"},   {mindspore::KHWC, "KHWC"},
7633be168c0dSopenharmony_ci+  {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"},         {mindspore::HW4, "HW4"},     {mindspore::NC, "NC"},
7634be168c0dSopenharmony_ci+  {mindspore::NC4, "NC4"},   {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}};
7635be168c0dSopenharmony_ci+
7636be168c0dSopenharmony_ci+std::function<int(NetTrainFlags *)> NetTrainBase::nr_cb_ = nullptr;
7637be168c0dSopenharmony_ci+
7638be168c0dSopenharmony_ci+int NetTrainBase::SetNr(std::function<int(NetTrainFlags *)> param) {
7639be168c0dSopenharmony_ci+  nr_cb_ = param;
7640be168c0dSopenharmony_ci+  return 0;
7641be168c0dSopenharmony_ci+}
7642be168c0dSopenharmony_ci+
7643be168c0dSopenharmony_ci+float *NetTrainBase::ReadFileBuf(const std::string file, size_t *size) {
7644be168c0dSopenharmony_ci+  if (file.empty()) {
7645be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "file is nullptr";
7646be168c0dSopenharmony_ci+    return nullptr;
7647be168c0dSopenharmony_ci+  }
7648be168c0dSopenharmony_ci+  MS_ASSERT(size != nullptr);
7649be168c0dSopenharmony_ci+  std::string real_path = RealPath(file.c_str());
7650be168c0dSopenharmony_ci+  std::ifstream ifs(real_path);
7651be168c0dSopenharmony_ci+  if (!ifs.good()) {
7652be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
7653be168c0dSopenharmony_ci+    return nullptr;
7654be168c0dSopenharmony_ci+  }
7655be168c0dSopenharmony_ci+
7656be168c0dSopenharmony_ci+  if (!ifs.is_open()) {
7657be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7658be168c0dSopenharmony_ci+    return nullptr;
7659be168c0dSopenharmony_ci+  }
7660be168c0dSopenharmony_ci+
7661be168c0dSopenharmony_ci+  ifs.seekg(0, std::ios::end);
7662be168c0dSopenharmony_ci+  *size = ifs.tellg();
7663be168c0dSopenharmony_ci+  std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1);
7664be168c0dSopenharmony_ci+  if (buf == nullptr) {
7665be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
7666be168c0dSopenharmony_ci+    ifs.close();
7667be168c0dSopenharmony_ci+    return nullptr;
7668be168c0dSopenharmony_ci+  }
7669be168c0dSopenharmony_ci+
7670be168c0dSopenharmony_ci+  ifs.seekg(0, std::ios::beg);
7671be168c0dSopenharmony_ci+  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
7672be168c0dSopenharmony_ci+  ifs.close();
7673be168c0dSopenharmony_ci+
7674be168c0dSopenharmony_ci+  return buf.release();
7675be168c0dSopenharmony_ci+}
7676be168c0dSopenharmony_ci+
7677be168c0dSopenharmony_ci+int NetTrainBase::GenerateRandomData(mindspore::MSTensor *tensor) {
7678be168c0dSopenharmony_ci+  auto input_data = tensor->MutableData();
7679be168c0dSopenharmony_ci+  if (input_data == nullptr) {
7680be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "MallocData for inTensor failed";
7681be168c0dSopenharmony_ci+    return RET_ERROR;
7682be168c0dSopenharmony_ci+  }
7683be168c0dSopenharmony_ci+  auto tensor_byte_size = tensor->DataSize();
7684be168c0dSopenharmony_ci+  char *casted_data = static_cast<char *>(input_data);
7685be168c0dSopenharmony_ci+  for (size_t i = 0; i < tensor_byte_size; i++) {
7686be168c0dSopenharmony_ci+    casted_data[i] =
7687be168c0dSopenharmony_ci+      (tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
7688be168c0dSopenharmony_ci+  }
7689be168c0dSopenharmony_ci+  return RET_OK;
7690be168c0dSopenharmony_ci+}
7691be168c0dSopenharmony_ci+
7692be168c0dSopenharmony_ci+int NetTrainBase::LoadInput() {
7693be168c0dSopenharmony_ci+  inputs_buf_.clear();
7694be168c0dSopenharmony_ci+  inputs_size_.clear();
7695be168c0dSopenharmony_ci+  batch_num_ = 0;
7696be168c0dSopenharmony_ci+  if (flags_->in_data_file_.empty()) {
7697be168c0dSopenharmony_ci+    auto status = GenerateInputData();
7698be168c0dSopenharmony_ci+    if (status != RET_OK) {
7699be168c0dSopenharmony_ci+      std::cerr << "Generate input data error " << status << std::endl;
7700be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Generate input data error " << status;
7701be168c0dSopenharmony_ci+      return status;
7702be168c0dSopenharmony_ci+    }
7703be168c0dSopenharmony_ci+  } else {
7704be168c0dSopenharmony_ci+    auto status = ReadInputFile();
7705be168c0dSopenharmony_ci+    if (status != RET_OK) {
7706be168c0dSopenharmony_ci+      std::cerr << "Read Input File error, " << status << std::endl;
7707be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Read Input File error, " << status;
7708be168c0dSopenharmony_ci+      return status;
7709be168c0dSopenharmony_ci+    }
7710be168c0dSopenharmony_ci+  }
7711be168c0dSopenharmony_ci+  return RET_OK;
7712be168c0dSopenharmony_ci+}
7713be168c0dSopenharmony_ci+
7714be168c0dSopenharmony_ci+int NetTrainBase::RunNetTrain() {
7715be168c0dSopenharmony_ci+  auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
7716be168c0dSopenharmony_ci+  bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty();
7717be168c0dSopenharmony_ci+  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_);
7718be168c0dSopenharmony_ci+  if (status != RET_OK) {
7719be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
7720be168c0dSopenharmony_ci+    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
7721be168c0dSopenharmony_ci+              << std::endl;
7722be168c0dSopenharmony_ci+    return status;
7723be168c0dSopenharmony_ci+  }
7724be168c0dSopenharmony_ci+
7725be168c0dSopenharmony_ci+  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
7726be168c0dSopenharmony_ci+  if (status != RET_OK) {
7727be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
7728be168c0dSopenharmony_ci+    std::cout << "Run CheckExecute error: " << status << std::endl;
7729be168c0dSopenharmony_ci+    return status;
7730be168c0dSopenharmony_ci+  }
7731be168c0dSopenharmony_ci+  return RET_OK;
7732be168c0dSopenharmony_ci+}
7733be168c0dSopenharmony_ci+
7734be168c0dSopenharmony_ci+int NetTrainBase::CheckExecutionOfSavedModels() {
7735be168c0dSopenharmony_ci+  int status = RET_OK;
7736be168c0dSopenharmony_ci+  if (!flags_->export_file_.empty()) {
7737be168c0dSopenharmony_ci+    status = CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
7738be168c0dSopenharmony_ci+    if (status != RET_OK) {
7739be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
7740be168c0dSopenharmony_ci+      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
7741be168c0dSopenharmony_ci+      return status;
7742be168c0dSopenharmony_ci+    }
7743be168c0dSopenharmony_ci+    if (flags_->bb_model_file_.empty()) {
7744be168c0dSopenharmony_ci+      status = CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
7745be168c0dSopenharmony_ci+      if (status != RET_OK) {
7746be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
7747be168c0dSopenharmony_ci+        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
7748be168c0dSopenharmony_ci+        return status;
7749be168c0dSopenharmony_ci+      }
7750be168c0dSopenharmony_ci+    }
7751be168c0dSopenharmony_ci+  }
7752be168c0dSopenharmony_ci+  if (!flags_->inference_file_.empty()) {
7753be168c0dSopenharmony_ci+    status = CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
7754be168c0dSopenharmony_ci+    if (status != RET_OK) {
7755be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
7756be168c0dSopenharmony_ci+      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
7757be168c0dSopenharmony_ci+      return status;
7758be168c0dSopenharmony_ci+    }
7759be168c0dSopenharmony_ci+    status = CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
7760be168c0dSopenharmony_ci+    if (status != RET_OK) {
7761be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
7762be168c0dSopenharmony_ci+      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
7763be168c0dSopenharmony_ci+      return status;
7764be168c0dSopenharmony_ci+    }
7765be168c0dSopenharmony_ci+  }
7766be168c0dSopenharmony_ci+  return status;
7767be168c0dSopenharmony_ci+}
7768be168c0dSopenharmony_ci+
7769be168c0dSopenharmony_ci+void NetTrainBase::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) {
7770be168c0dSopenharmony_ci+  if (tensor == nullptr) {
7771be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "input tensor is nullptr.";
7772be168c0dSopenharmony_ci+    return;
7773be168c0dSopenharmony_ci+  }
7774be168c0dSopenharmony_ci+  int tensor_size = tensor->ElementNum();
7775be168c0dSopenharmony_ci+  void *data = tensor->MutableData();
7776be168c0dSopenharmony_ci+  auto *fdata = reinterpret_cast<float *>(tensor->MutableData());
7777be168c0dSopenharmony_ci+  auto type = tensor->DataType();
7778be168c0dSopenharmony_ci+  std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum=";
7779be168c0dSopenharmony_ci+  switch (type) {
7780be168c0dSopenharmony_ci+    case mindspore::DataType::kNumberTypeFloat32:
7781be168c0dSopenharmony_ci+      TensorNan(reinterpret_cast<float *>(data), tensor_size);
7782be168c0dSopenharmony_ci+      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
7783be168c0dSopenharmony_ci+      std::cout << "tensor name: " << tensor->Name() << std::endl;
7784be168c0dSopenharmony_ci+      std::cout << "data: ";
7785be168c0dSopenharmony_ci+      for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) {
7786be168c0dSopenharmony_ci+        std::cout << static_cast<float>(fdata[i]) << ", ";
7787be168c0dSopenharmony_ci+      }
7788be168c0dSopenharmony_ci+      std::cout << std::endl;
7789be168c0dSopenharmony_ci+      break;
7790be168c0dSopenharmony_ci+    case mindspore::DataType::kNumberTypeInt32:
7791be168c0dSopenharmony_ci+      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
7792be168c0dSopenharmony_ci+      break;
7793be168c0dSopenharmony_ci+#ifdef ENABLE_FP16
7794be168c0dSopenharmony_ci+    case mindspore::DataType::kNumberTypeFloat16:
7795be168c0dSopenharmony_ci+      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
7796be168c0dSopenharmony_ci+      TensorNan(reinterpret_cast<float16_t *>(data), tensor_size);
7797be168c0dSopenharmony_ci+      break;
7798be168c0dSopenharmony_ci+#endif
7799be168c0dSopenharmony_ci+    default:
7800be168c0dSopenharmony_ci+      std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
7801be168c0dSopenharmony_ci+      break;
7802be168c0dSopenharmony_ci+  }
7803be168c0dSopenharmony_ci+}
7804be168c0dSopenharmony_ci+
7805be168c0dSopenharmony_ci+std::string NetTrainBase::GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
7806be168c0dSopenharmony_ci+                                   const std::string &file_type, const size_t &idx) {
7807be168c0dSopenharmony_ci+  std::string file_name = op_name;
7808be168c0dSopenharmony_ci+  auto pos = file_name.find_first_of('/');
7809be168c0dSopenharmony_ci+  while (pos != std::string::npos) {
7810be168c0dSopenharmony_ci+    file_name.replace(pos, 1, ".");
7811be168c0dSopenharmony_ci+    pos = file_name.find_first_of('/');
7812be168c0dSopenharmony_ci+  }
7813be168c0dSopenharmony_ci+  file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
7814be168c0dSopenharmony_ci+  for (const auto &dim : tensor->Shape()) {
7815be168c0dSopenharmony_ci+    file_name += std::to_string(dim) + "_";
7816be168c0dSopenharmony_ci+  }
7817be168c0dSopenharmony_ci+  if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
7818be168c0dSopenharmony_ci+    file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
7819be168c0dSopenharmony_ci+  }
7820be168c0dSopenharmony_ci+  auto tensor_format = tensor->format();
7821be168c0dSopenharmony_ci+  if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
7822be168c0dSopenharmony_ci+    file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin";
7823be168c0dSopenharmony_ci+  }
7824be168c0dSopenharmony_ci+
7825be168c0dSopenharmony_ci+  file_name += ".bin";
7826be168c0dSopenharmony_ci+  return file_name;
7827be168c0dSopenharmony_ci+}
7828be168c0dSopenharmony_ci+
7829be168c0dSopenharmony_ci+int NetTrainBase::InitCallbackParameter() {
7830be168c0dSopenharmony_ci+  int ret = RET_OK;
7831be168c0dSopenharmony_ci+  if (flags_->dump_tensor_data_) {
7832be168c0dSopenharmony_ci+    ret = InitDumpTensorDataCallbackParameter();
7833be168c0dSopenharmony_ci+  } else if (flags_->time_profiling_) {
7834be168c0dSopenharmony_ci+    ret = InitTimeProfilingCallbackParameter();
7835be168c0dSopenharmony_ci+  }
7836be168c0dSopenharmony_ci+  return ret;
7837be168c0dSopenharmony_ci+}
7838be168c0dSopenharmony_ci+
7839be168c0dSopenharmony_ci+void NetTrainFlags::InitResizeDimsList() {
7840be168c0dSopenharmony_ci+  std::string content = this->resize_dims_in_;
7841be168c0dSopenharmony_ci+  if (content.empty()) {
7842be168c0dSopenharmony_ci+    return;
7843be168c0dSopenharmony_ci+  }
7844be168c0dSopenharmony_ci+  std::vector<int> shape;
7845be168c0dSopenharmony_ci+  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
7846be168c0dSopenharmony_ci+  for (const auto &shape_str : shape_strs) {
7847be168c0dSopenharmony_ci+    shape.clear();
7848be168c0dSopenharmony_ci+    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
7849be168c0dSopenharmony_ci+    std::cout << "Resize Dims: ";
7850be168c0dSopenharmony_ci+    for (const auto &dim_str : dim_strs) {
7851be168c0dSopenharmony_ci+      std::cout << dim_str << " ";
7852be168c0dSopenharmony_ci+      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
7853be168c0dSopenharmony_ci+    }
7854be168c0dSopenharmony_ci+    std::cout << std::endl;
7855be168c0dSopenharmony_ci+    this->resize_dims_.emplace_back(shape);
7856be168c0dSopenharmony_ci+  }
7857be168c0dSopenharmony_ci+}
7858be168c0dSopenharmony_ci+
7859be168c0dSopenharmony_ci+int NetTrainBase::Init() {
7860be168c0dSopenharmony_ci+  if (this->flags_ == nullptr) {
7861be168c0dSopenharmony_ci+    return 1;
7862be168c0dSopenharmony_ci+  }
7863be168c0dSopenharmony_ci+  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
7864be168c0dSopenharmony_ci+  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
7865be168c0dSopenharmony_ci+  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
7866be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
7867be168c0dSopenharmony_ci+  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
7868be168c0dSopenharmony_ci+  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
7869be168c0dSopenharmony_ci+  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
7870be168c0dSopenharmony_ci+  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
7871be168c0dSopenharmony_ci+  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
7872be168c0dSopenharmony_ci+  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
7873be168c0dSopenharmony_ci+  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;
7874be168c0dSopenharmony_ci+
7875be168c0dSopenharmony_ci+  if (this->flags_->epochs_ < 0) {
7876be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";
7877be168c0dSopenharmony_ci+    std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl;
7878be168c0dSopenharmony_ci+    return RET_ERROR;
7879be168c0dSopenharmony_ci+  }
7880be168c0dSopenharmony_ci+
7881be168c0dSopenharmony_ci+  if (this->flags_->num_threads_ < 1) {
7882be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
7883be168c0dSopenharmony_ci+    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
7884be168c0dSopenharmony_ci+    return RET_ERROR;
7885be168c0dSopenharmony_ci+  }
7886be168c0dSopenharmony_ci+
7887be168c0dSopenharmony_ci+  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
7888be168c0dSopenharmony_ci+
7889be168c0dSopenharmony_ci+  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
7890be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided";
7891be168c0dSopenharmony_ci+    std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl;
7892be168c0dSopenharmony_ci+    return RET_ERROR;
7893be168c0dSopenharmony_ci+  }
7894be168c0dSopenharmony_ci+
7895be168c0dSopenharmony_ci+  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
7896be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided";
7897be168c0dSopenharmony_ci+    std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl;
7898be168c0dSopenharmony_ci+    return RET_ERROR;
7899be168c0dSopenharmony_ci+  }
7900be168c0dSopenharmony_ci+
7901be168c0dSopenharmony_ci+  if (flags_->model_file_.empty()) {
7902be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "modelPath is required";
7903be168c0dSopenharmony_ci+    std::cerr << "modelPath is required" << std::endl;
7904be168c0dSopenharmony_ci+    return 1;
7905be168c0dSopenharmony_ci+  }
7906be168c0dSopenharmony_ci+
7907be168c0dSopenharmony_ci+  // get dump data output path
7908be168c0dSopenharmony_ci+  auto dump_cfg_path = std::getenv(dump::kConfigPath);
7909be168c0dSopenharmony_ci+  if (dump_cfg_path != nullptr) {
7910be168c0dSopenharmony_ci+    flags_->dump_tensor_data_ = true;
7911be168c0dSopenharmony_ci+    if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
7912be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "parse dump config file failed.";
7913be168c0dSopenharmony_ci+      return RET_ERROR;
7914be168c0dSopenharmony_ci+    }
7915be168c0dSopenharmony_ci+  } else {
7916be168c0dSopenharmony_ci+    MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data";
7917be168c0dSopenharmony_ci+  }
7918be168c0dSopenharmony_ci+
7919be168c0dSopenharmony_ci+  auto status = InitCallbackParameter();
7920be168c0dSopenharmony_ci+  if (status != RET_OK) {
7921be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Init callback Parameter failed.";
7922be168c0dSopenharmony_ci+    std::cerr << "Init callback Parameter failed." << std::endl;
7923be168c0dSopenharmony_ci+    return RET_ERROR;
7924be168c0dSopenharmony_ci+  }
7925be168c0dSopenharmony_ci+
7926be168c0dSopenharmony_ci+  flags_->InitResizeDimsList();
7927be168c0dSopenharmony_ci+  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
7928be168c0dSopenharmony_ci+      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
7929be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
7930be168c0dSopenharmony_ci+    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
7931be168c0dSopenharmony_ci+    return RET_ERROR;
7932be168c0dSopenharmony_ci+  }
7933be168c0dSopenharmony_ci+  return RET_OK;
7934be168c0dSopenharmony_ci+}
7935be168c0dSopenharmony_ci+
7936be168c0dSopenharmony_ci+int NetTrainBase::InitDumpConfigFromJson(std::string path) {
7937be168c0dSopenharmony_ci+  auto real_path = RealPath(path.c_str());
7938be168c0dSopenharmony_ci+  std::ifstream ifs(real_path);
7939be168c0dSopenharmony_ci+  if (!ifs.good()) {
7940be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
7941be168c0dSopenharmony_ci+    return RET_ERROR;
7942be168c0dSopenharmony_ci+  }
7943be168c0dSopenharmony_ci+  if (!ifs.is_open()) {
7944be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7945be168c0dSopenharmony_ci+    return RET_ERROR;
7946be168c0dSopenharmony_ci+  }
7947be168c0dSopenharmony_ci+
7948be168c0dSopenharmony_ci+  try {
7949be168c0dSopenharmony_ci+    dump_cfg_json_ = nlohmann::json::parse(ifs);
7950be168c0dSopenharmony_ci+  } catch (const nlohmann::json::parse_error &error) {
7951be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "parse json file failed, please check your file.";
7952be168c0dSopenharmony_ci+    return RET_ERROR;
7953be168c0dSopenharmony_ci+  }
7954be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings] == nullptr) {
7955be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
7956be168c0dSopenharmony_ci+    return RET_ERROR;
7957be168c0dSopenharmony_ci+  }
7958be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
7959be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "\"dump_mode\" is required.";
7960be168c0dSopenharmony_ci+    return RET_ERROR;
7961be168c0dSopenharmony_ci+  }
7962be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
7963be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "\"path\" is required.";
7964be168c0dSopenharmony_ci+    return RET_ERROR;
7965be168c0dSopenharmony_ci+  }
7966be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
7967be168c0dSopenharmony_ci+    dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
7968be168c0dSopenharmony_ci+  }
7969be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
7970be168c0dSopenharmony_ci+    dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
7971be168c0dSopenharmony_ci+  }
7972be168c0dSopenharmony_ci+  if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
7973be168c0dSopenharmony_ci+      !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
7974be168c0dSopenharmony_ci+    if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
7975be168c0dSopenharmony_ci+      MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
7976be168c0dSopenharmony_ci+      return RET_ERROR;
7977be168c0dSopenharmony_ci+    }
7978be168c0dSopenharmony_ci+  }
7979be168c0dSopenharmony_ci+
7980be168c0dSopenharmony_ci+  auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
7981be168c0dSopenharmony_ci+  auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
7982be168c0dSopenharmony_ci+  if (abs_path.back() == '\\' || abs_path.back() == '/') {
7983be168c0dSopenharmony_ci+    dump_file_output_dir_ = abs_path + net_name;
7984be168c0dSopenharmony_ci+  } else {
7985be168c0dSopenharmony_ci+#ifdef _WIN32
7986be168c0dSopenharmony_ci+    dump_file_output_dir_ = abs_path + "\\" + net_name;
7987be168c0dSopenharmony_ci+#else
7988be168c0dSopenharmony_ci+    dump_file_output_dir_ = abs_path + "/" + net_name;
7989be168c0dSopenharmony_ci+#endif
7990be168c0dSopenharmony_ci+  }
7991be168c0dSopenharmony_ci+
7992be168c0dSopenharmony_ci+  auto status = CreateOutputDir(&dump_file_output_dir_);
7993be168c0dSopenharmony_ci+  if (status != RET_OK) {
7994be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "create data output directory failed.";
7995be168c0dSopenharmony_ci+    return RET_ERROR;
7996be168c0dSopenharmony_ci+  }
7997be168c0dSopenharmony_ci+  return RET_OK;
7998be168c0dSopenharmony_ci+}
7999be168c0dSopenharmony_ci+
8000be168c0dSopenharmony_ci+NetTrainBase:: ~NetTrainBase() {
8001be168c0dSopenharmony_ci+}
8002be168c0dSopenharmony_ci+}  // namespace lite
8003be168c0dSopenharmony_ci+}  // namespace mindspore
8004be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_base.h b/mindspore/lite/tools/benchmark_train/net_train_base.h
8005be168c0dSopenharmony_cinew file mode 100644
8006be168c0dSopenharmony_ciindex 00000000..e3d5f39a
8007be168c0dSopenharmony_ci--- /dev/null
8008be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_base.h
8009be168c0dSopenharmony_ci@@ -0,0 +1,288 @@
8010be168c0dSopenharmony_ci+/**
8011be168c0dSopenharmony_ci+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
8012be168c0dSopenharmony_ci+ *
8013be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
8014be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
8015be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
8016be168c0dSopenharmony_ci+ *
8017be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
8018be168c0dSopenharmony_ci+ *
8019be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
8020be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
8021be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8022be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
8023be168c0dSopenharmony_ci+ * limitations under the License.
8024be168c0dSopenharmony_ci+ */
8025be168c0dSopenharmony_ci+
8026be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8027be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8028be168c0dSopenharmony_ci+
8029be168c0dSopenharmony_ci+#include <getopt.h>
8030be168c0dSopenharmony_ci+#include <csignal>
8031be168c0dSopenharmony_ci+#include <unordered_map>
8032be168c0dSopenharmony_ci+#include <fstream>
8033be168c0dSopenharmony_ci+#include <iostream>
8034be168c0dSopenharmony_ci+#include <map>
8035be168c0dSopenharmony_ci+#include <cmath>
8036be168c0dSopenharmony_ci+#include <string>
8037be168c0dSopenharmony_ci+#include <vector>
8038be168c0dSopenharmony_ci+#include <memory>
8039be168c0dSopenharmony_ci+#include <cfloat>
8040be168c0dSopenharmony_ci+#include <utility>
8041be168c0dSopenharmony_ci+#include <algorithm>
8042be168c0dSopenharmony_ci+#include <nlohmann/json.hpp>
8043be168c0dSopenharmony_ci+#include "include/api/model.h"
8044be168c0dSopenharmony_ci+#include "include/api/types.h"
8045be168c0dSopenharmony_ci+#include "include/api/context.h"
8046be168c0dSopenharmony_ci+#include "include/api/cfg.h"
8047be168c0dSopenharmony_ci+
8048be168c0dSopenharmony_ci+#ifdef ENABLE_FP16
8049be168c0dSopenharmony_ci+#include <arm_neon.h>
8050be168c0dSopenharmony_ci+#endif
8051be168c0dSopenharmony_ci+#include "tools/common/flag_parser.h"
8052be168c0dSopenharmony_ci+#include "src/common/file_utils.h"
8053be168c0dSopenharmony_ci+#include "src/common/utils.h"
8054be168c0dSopenharmony_ci+
8055be168c0dSopenharmony_ci+#ifdef ENABLE_FP16
8056be168c0dSopenharmony_ci+static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) {
8057be168c0dSopenharmony_ci+  volatile float16_t d = var;
8058be168c0dSopenharmony_ci+  return d != d;
8059be168c0dSopenharmony_ci+}
8060be168c0dSopenharmony_ci+#endif
8061be168c0dSopenharmony_ci+
8062be168c0dSopenharmony_ci+namespace mindspore::lite {
8063be168c0dSopenharmony_ci+enum MS_API DataType { kImage = 0, kBinary = 1 };
8064be168c0dSopenharmony_ci+
8065be168c0dSopenharmony_ci+constexpr float relativeTolerance = 1e-5;
8066be168c0dSopenharmony_ci+constexpr float absoluteTolerance = 1e-8;
8067be168c0dSopenharmony_ci+extern const std::unordered_map<int, std::string> kTypeIdMap;
8068be168c0dSopenharmony_ci+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
8069be168c0dSopenharmony_ci+
8070be168c0dSopenharmony_ci+constexpr const char *DELIM_SLASH = "/";
8071be168c0dSopenharmony_ci+constexpr const char *DELIM_COLON = ":";
8072be168c0dSopenharmony_ci+constexpr const char *DELIM_COMMA = ",";
8073be168c0dSopenharmony_ci+
8074be168c0dSopenharmony_ci+constexpr int RET_TOO_BIG = -9;
8075be168c0dSopenharmony_ci+constexpr int kFieldsToPrint = 5;
8076be168c0dSopenharmony_ci+constexpr int kPrintOffset = 4;
8077be168c0dSopenharmony_ci+constexpr int kDumpInputsAndOutputs = 0;
8078be168c0dSopenharmony_ci+constexpr int kDumpOutputs = 2;
8079be168c0dSopenharmony_ci+constexpr int kTHOUSAND = 1000;
8080be168c0dSopenharmony_ci+
8081be168c0dSopenharmony_ci+namespace dump {
8082be168c0dSopenharmony_ci+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
8083be168c0dSopenharmony_ci+constexpr auto kSettings = "common_dump_settings";
8084be168c0dSopenharmony_ci+constexpr auto kMode = "dump_mode";
8085be168c0dSopenharmony_ci+constexpr auto kPath = "path";
8086be168c0dSopenharmony_ci+constexpr auto kNetName = "net_name";
8087be168c0dSopenharmony_ci+constexpr auto kInputOutput = "input_output";
8088be168c0dSopenharmony_ci+constexpr auto kKernels = "kernels";
8089be168c0dSopenharmony_ci+}  // namespace dump
8090be168c0dSopenharmony_ci+
8091be168c0dSopenharmony_ci+template <typename T>
8092be168c0dSopenharmony_ci+float TensorSum(const void *data, int size) {
8093be168c0dSopenharmony_ci+  const T *typed_data = reinterpret_cast<const T *>(data);
8094be168c0dSopenharmony_ci+  float sum = 0.f;
8095be168c0dSopenharmony_ci+  for (int i = 0; i < size; i++) {
8096be168c0dSopenharmony_ci+    sum += static_cast<float>(typed_data[i]);
8097be168c0dSopenharmony_ci+  }
8098be168c0dSopenharmony_ci+  return sum;
8099be168c0dSopenharmony_ci+}
8100be168c0dSopenharmony_ci+
8101be168c0dSopenharmony_ci+class MS_API NetTrainFlags : public virtual FlagParser {
8102be168c0dSopenharmony_ci+ public:
8103be168c0dSopenharmony_ci+  NetTrainFlags() {
8104be168c0dSopenharmony_ci+    // common
8105be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", "");
8106be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", "");
8107be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
8108be168c0dSopenharmony_ci+    // MarkPerformance
8109be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0);
8110be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false);
8111be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1);
8112be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1);
8113be168c0dSopenharmony_ci+    // MarkAccuracy
8114be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", "");
8115be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
8116be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
8117be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
8118be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
8119be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", "");
8120be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", "");
8121be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
8122be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
8123be168c0dSopenharmony_ci+            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
8124be168c0dSopenharmony_ci+    AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false);
8125be168c0dSopenharmony_ci+  }
8126be168c0dSopenharmony_ci+
8127be168c0dSopenharmony_ci+  ~NetTrainFlags() override = default;
8128be168c0dSopenharmony_ci+  void InitResizeDimsList();
8129be168c0dSopenharmony_ci+
8130be168c0dSopenharmony_ci+ public:
8131be168c0dSopenharmony_ci+  // common
8132be168c0dSopenharmony_ci+  std::string model_file_;
8133be168c0dSopenharmony_ci+  std::string in_data_file_;
8134be168c0dSopenharmony_ci+  std::string bb_model_file_;
8135be168c0dSopenharmony_ci+  std::vector<std::string> input_data_list_;
8136be168c0dSopenharmony_ci+  DataType in_data_type_;
8137be168c0dSopenharmony_ci+  std::string in_data_type_in_ = "bin";
8138be168c0dSopenharmony_ci+  int cpu_bind_mode_ = 1;
8139be168c0dSopenharmony_ci+  bool enable_fp16_ = false;
8140be168c0dSopenharmony_ci+  bool virtual_batch_ = false;
8141be168c0dSopenharmony_ci+  // MarkPerformance
8142be168c0dSopenharmony_ci+  int num_threads_ = 1;
8143be168c0dSopenharmony_ci+  int warm_up_loop_count_ = 0;
8144be168c0dSopenharmony_ci+  bool time_profiling_;
8145be168c0dSopenharmony_ci+  int epochs_ = 1;
8146be168c0dSopenharmony_ci+  // MarkAccuracy
8147be168c0dSopenharmony_ci+  std::string data_file_;
8148be168c0dSopenharmony_ci+  std::string data_type_ = "FLOAT";
8149be168c0dSopenharmony_ci+  float accuracy_threshold_;
8150be168c0dSopenharmony_ci+  // Resize
8151be168c0dSopenharmony_ci+  std::string export_file_ = "";
8152be168c0dSopenharmony_ci+  std::string resize_dims_in_ = "";
8153be168c0dSopenharmony_ci+  bool layer_checksum_ = false;
8154be168c0dSopenharmony_ci+  std::vector<std::vector<int>> resize_dims_;
8155be168c0dSopenharmony_ci+  std::string loss_name_ = "";
8156be168c0dSopenharmony_ci+  std::string inference_file_ = "";
8157be168c0dSopenharmony_ci+  bool unified_api_ = false;
8158be168c0dSopenharmony_ci+  bool dump_tensor_data_ = false;
8159be168c0dSopenharmony_ci+};
8160be168c0dSopenharmony_ci+
8161be168c0dSopenharmony_ci+class MS_API NetTrainBase {
8162be168c0dSopenharmony_ci+ public:
8163be168c0dSopenharmony_ci+  explicit NetTrainBase(NetTrainFlags *flags) : flags_(flags) {}
8164be168c0dSopenharmony_ci+  virtual ~NetTrainBase();
8165be168c0dSopenharmony_ci+
8166be168c0dSopenharmony_ci+  int Init();
8167be168c0dSopenharmony_ci+  int RunNetTrain();
8168be168c0dSopenharmony_ci+  static float *ReadFileBuf(const std::string file, size_t *size);
8169be168c0dSopenharmony_ci+  static int SetNr(std::function<int(NetTrainFlags *)> param);
8170be168c0dSopenharmony_ci+  static int RunNr(NetTrainFlags *flags) {
8171be168c0dSopenharmony_ci+    if (nr_cb_ != nullptr) {
8172be168c0dSopenharmony_ci+      return nr_cb_(flags);
8173be168c0dSopenharmony_ci+    }
8174be168c0dSopenharmony_ci+    MS_LOG(WARNING) << "unified api was not tested";
8175be168c0dSopenharmony_ci+    std::cout << "unified api was not tested";
8176be168c0dSopenharmony_ci+    return RET_OK;
8177be168c0dSopenharmony_ci+  }
8178be168c0dSopenharmony_ci+  // tensorData need to be converter first
8179be168c0dSopenharmony_ci+  template <typename T>
8180be168c0dSopenharmony_ci+  static float CompareData(const float *refOutput, int size, const T *msTensorData) {
8181be168c0dSopenharmony_ci+    size_t errorCount = 0;
8182be168c0dSopenharmony_ci+    float meanError = 0;
8183be168c0dSopenharmony_ci+    std::cout << "Out tensor size is: " << size << std::endl;
8184be168c0dSopenharmony_ci+    std::cout << "Data of model output: ";
8185be168c0dSopenharmony_ci+    for (int j = 0; j < std::min(50, size); j++) {
8186be168c0dSopenharmony_ci+      std::cout << static_cast<float>(msTensorData[j]) << " ";
8187be168c0dSopenharmony_ci+    }
8188be168c0dSopenharmony_ci+    std::cout << std::endl;
8189be168c0dSopenharmony_ci+    std::cout << "Data of Ref output  : ";
8190be168c0dSopenharmony_ci+    for (int j = 0; j < std::min(50, size); j++) {
8191be168c0dSopenharmony_ci+      std::cout << refOutput[j] << " ";
8192be168c0dSopenharmony_ci+    }
8193be168c0dSopenharmony_ci+    std::cout << std::endl;
8194be168c0dSopenharmony_ci+    for (int j = 0; j < size; j++) {
8195be168c0dSopenharmony_ci+      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
8196be168c0dSopenharmony_ci+        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
8197be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
8198be168c0dSopenharmony_ci+        return RET_ERROR;
8199be168c0dSopenharmony_ci+      }
8200be168c0dSopenharmony_ci+
8201be168c0dSopenharmony_ci+      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
8202be168c0dSopenharmony_ci+      auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]);
8203be168c0dSopenharmony_ci+      if (absoluteError > tolerance) {
8204be168c0dSopenharmony_ci+        if (fabs(refOutput[j]) == 0) {
8205be168c0dSopenharmony_ci+          if (absoluteError > 1e-5) {
8206be168c0dSopenharmony_ci+            meanError += absoluteError;
8207be168c0dSopenharmony_ci+            errorCount++;
8208be168c0dSopenharmony_ci+          } else {
8209be168c0dSopenharmony_ci+            continue;
8210be168c0dSopenharmony_ci+          }
8211be168c0dSopenharmony_ci+        } else {
8212be168c0dSopenharmony_ci+          // just assume that atol = rtol
8213be168c0dSopenharmony_ci+          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
8214be168c0dSopenharmony_ci+          errorCount++;
8215be168c0dSopenharmony_ci+        }
8216be168c0dSopenharmony_ci+      }
8217be168c0dSopenharmony_ci+    }
8218be168c0dSopenharmony_ci+    std::cout << std::endl;
8219be168c0dSopenharmony_ci+    if (meanError > 0.0f) {
8220be168c0dSopenharmony_ci+      meanError /= errorCount;
8221be168c0dSopenharmony_ci+    }
8222be168c0dSopenharmony_ci+
8223be168c0dSopenharmony_ci+    if (meanError <= 0.0000001) {
8224be168c0dSopenharmony_ci+      std::cout << "Mean bias of tensor: 0%" << std::endl;
8225be168c0dSopenharmony_ci+    } else {
8226be168c0dSopenharmony_ci+      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
8227be168c0dSopenharmony_ci+    }
8228be168c0dSopenharmony_ci+    return meanError;
8229be168c0dSopenharmony_ci+  }
8230be168c0dSopenharmony_ci+  int InitDumpConfigFromJson(std::string path);
8231be168c0dSopenharmony_ci+
8232be168c0dSopenharmony_ci+ protected:
8233be168c0dSopenharmony_ci+  // call GenerateInputData or ReadInputFile to init inputTensors
8234be168c0dSopenharmony_ci+  int LoadInput();
8235be168c0dSopenharmony_ci+  void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out);
8236be168c0dSopenharmony_ci+  // call GenerateRandomData to fill inputTensors
8237be168c0dSopenharmony_ci+  virtual int GenerateInputData() = 0;
8238be168c0dSopenharmony_ci+
8239be168c0dSopenharmony_ci+  int GenerateRandomData(mindspore::MSTensor *tensor);
8240be168c0dSopenharmony_ci+
8241be168c0dSopenharmony_ci+  std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
8242be168c0dSopenharmony_ci+                                     const std::string &file_type, const size_t &idx);
8243be168c0dSopenharmony_ci+  virtual int ReadInputFile() = 0;
8244be168c0dSopenharmony_ci+
8245be168c0dSopenharmony_ci+  virtual int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
8246be168c0dSopenharmony_ci+                          bool check_accuracy = true) = 0;
8247be168c0dSopenharmony_ci+
8248be168c0dSopenharmony_ci+  int InitCallbackParameter();
8249be168c0dSopenharmony_ci+
8250be168c0dSopenharmony_ci+  virtual int InitDumpTensorDataCallbackParameter() = 0;
8251be168c0dSopenharmony_ci+
8252be168c0dSopenharmony_ci+  virtual int InitTimeProfilingCallbackParameter() = 0;
8253be168c0dSopenharmony_ci+
8254be168c0dSopenharmony_ci+  virtual int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) = 0;
8255be168c0dSopenharmony_ci+
8256be168c0dSopenharmony_ci+  template <typename T>
8257be168c0dSopenharmony_ci+  std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) {
8258be168c0dSopenharmony_ci+    std::vector<int64_t> dims;
8259be168c0dSopenharmony_ci+    for (auto shape : srcDims) {
8260be168c0dSopenharmony_ci+      dims.push_back(static_cast<int64_t>(shape));
8261be168c0dSopenharmony_ci+    }
8262be168c0dSopenharmony_ci+    return dims;
8263be168c0dSopenharmony_ci+  }
8264be168c0dSopenharmony_ci+  virtual int MarkPerformance() = 0;
8265be168c0dSopenharmony_ci+  virtual int MarkAccuracy(bool enforce_accuracy = true) = 0;
8266be168c0dSopenharmony_ci+  virtual int CompareOutput() = 0;
8267be168c0dSopenharmony_ci+  virtual int SaveModels() = 0;
8268be168c0dSopenharmony_ci+  int CheckExecutionOfSavedModels();
8269be168c0dSopenharmony_ci+  void TensorNan(const float *data, int size) {
8270be168c0dSopenharmony_ci+    for (int i = 0; i < size; i++) {
8271be168c0dSopenharmony_ci+      if (std::isnan(data[i])) {
8272be168c0dSopenharmony_ci+        std::cout << "nan value of index=" << i << ", " << data[i] << std::endl;
8273be168c0dSopenharmony_ci+        break;
8274be168c0dSopenharmony_ci+      }
8275be168c0dSopenharmony_ci+    }
8276be168c0dSopenharmony_ci+  }
8277be168c0dSopenharmony_ci+#ifdef ENABLE_FP16
8278be168c0dSopenharmony_ci+  void TensorNan(float16_t *data, int size) {
8279be168c0dSopenharmony_ci+    for (int i = 0; i < size; i++) {
8280be168c0dSopenharmony_ci+      if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) {
8281be168c0dSopenharmony_ci+        std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl;
8282be168c0dSopenharmony_ci+        break;
8283be168c0dSopenharmony_ci+      }
8284be168c0dSopenharmony_ci+    }
8285be168c0dSopenharmony_ci+  }
8286be168c0dSopenharmony_ci+#endif
8287be168c0dSopenharmony_ci+  NetTrainFlags *flags_{nullptr};
8288be168c0dSopenharmony_ci+  static std::function<int(NetTrainFlags *)> nr_cb_;
8289be168c0dSopenharmony_ci+
8290be168c0dSopenharmony_ci+  nlohmann::json dump_cfg_json_;
8291be168c0dSopenharmony_ci+  std::string dump_file_output_dir_;
8292be168c0dSopenharmony_ci+  std::vector<std::shared_ptr<char>> inputs_buf_;
8293be168c0dSopenharmony_ci+  std::vector<size_t> inputs_size_;
8294be168c0dSopenharmony_ci+  size_t batch_num_ = 0;
8295be168c0dSopenharmony_ci+};
8296be168c0dSopenharmony_ci+}  // namespace mindspore::lite
8297be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8298be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.cc b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc
8299be168c0dSopenharmony_cinew file mode 100644
8300be168c0dSopenharmony_ciindex 00000000..4dcf3af6
8301be168c0dSopenharmony_ci--- /dev/null
8302be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc
8303be168c0dSopenharmony_ci@@ -0,0 +1,659 @@
8304be168c0dSopenharmony_ci+/**
8305be168c0dSopenharmony_ci+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
8306be168c0dSopenharmony_ci+ *
8307be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
8308be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
8309be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
8310be168c0dSopenharmony_ci+ *
8311be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
8312be168c0dSopenharmony_ci+ *
8313be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
8314be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
8315be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8316be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
8317be168c0dSopenharmony_ci+ * limitations under the License.
8318be168c0dSopenharmony_ci+ */
8319be168c0dSopenharmony_ci+
8320be168c0dSopenharmony_ci+#include "net_train_c_api.h"
8321be168c0dSopenharmony_ci+#include "securec/include/securec.h"
8322be168c0dSopenharmony_ci+
8323be168c0dSopenharmony_ci+namespace mindspore {
8324be168c0dSopenharmony_ci+namespace lite {
8325be168c0dSopenharmony_ci+uint64_t g_op_begin_ = 0;
8326be168c0dSopenharmony_ci+int g_op_call_times_total_ = 0;
8327be168c0dSopenharmony_ci+float g_op_cost_total_ = 0.0f;
8328be168c0dSopenharmony_ci+
8329be168c0dSopenharmony_ci+int NetTrainCApi::GenerateInputData() {
8330be168c0dSopenharmony_ci+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8331be168c0dSopenharmony_ci+    OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i];
8332be168c0dSopenharmony_ci+    auto data_type = OH_AI_TensorGetDataType(tensor);
8333be168c0dSopenharmony_ci+    if (data_type == OH_AI_DATATYPE_OBJECTTYPE_STRING) {
8334be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING";
8335be168c0dSopenharmony_ci+      return RET_ERROR;
8336be168c0dSopenharmony_ci+    } else {
8337be168c0dSopenharmony_ci+      (void)GenerateRandomData(static_cast<mindspore::MSTensor *>(tensor));
8338be168c0dSopenharmony_ci+    }
8339be168c0dSopenharmony_ci+  }
8340be168c0dSopenharmony_ci+  return RET_OK;
8341be168c0dSopenharmony_ci+}
8342be168c0dSopenharmony_ci+
8343be168c0dSopenharmony_ci+int NetTrainCApi::SaveModels() {
8344be168c0dSopenharmony_ci+  if (!flags_->export_file_.empty()) {
8345be168c0dSopenharmony_ci+    if (flags_->bb_model_file_.empty()) {
8346be168c0dSopenharmony_ci+      auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, false,
8347be168c0dSopenharmony_ci+                        nullptr, 0);
8348be168c0dSopenharmony_ci+      if (status != OH_AI_STATUS_SUCCESS) {
8349be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt";
8350be168c0dSopenharmony_ci+        std::cout << "Export quantized model error " << flags_->export_file_ + "_qt" << std::endl;
8351be168c0dSopenharmony_ci+        return RET_ERROR;
8352be168c0dSopenharmony_ci+      }
8353be168c0dSopenharmony_ci+    }
8354be168c0dSopenharmony_ci+    auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_).c_str(), OH_AI_NO_QUANT, false,
8355be168c0dSopenharmony_ci+                                    nullptr, 0);
8356be168c0dSopenharmony_ci+
8357be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8358be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_;
8359be168c0dSopenharmony_ci+      std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl;
8360be168c0dSopenharmony_ci+      return RET_ERROR;
8361be168c0dSopenharmony_ci+    }
8362be168c0dSopenharmony_ci+  }
8363be168c0dSopenharmony_ci+  if (!flags_->inference_file_.empty()) {
8364be168c0dSopenharmony_ci+    auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_ + "_qt").c_str(), OH_AI_WEIGHT_QUANT, true,
8365be168c0dSopenharmony_ci+                                    nullptr, 0);
8366be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8367be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt";
8368be168c0dSopenharmony_ci+      std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl;
8369be168c0dSopenharmony_ci+      return RET_ERROR;
8370be168c0dSopenharmony_ci+    }
8371be168c0dSopenharmony_ci+
8372be168c0dSopenharmony_ci+    auto tick = GetTimeUs();
8373be168c0dSopenharmony_ci+    status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_).c_str(), OH_AI_NO_QUANT, true,
8374be168c0dSopenharmony_ci+                                    nullptr, 0);
8375be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8376be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_;
8377be168c0dSopenharmony_ci+      std::cout << "Export non quantized inference model error " << flags_->inference_file_ << std::endl;
8378be168c0dSopenharmony_ci+      return RET_ERROR;
8379be168c0dSopenharmony_ci+    }
8380be168c0dSopenharmony_ci+    std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
8381be168c0dSopenharmony_ci+  }
8382be168c0dSopenharmony_ci+  return RET_OK;
8383be168c0dSopenharmony_ci+}
8384be168c0dSopenharmony_ci+
8385be168c0dSopenharmony_ci+int NetTrainCApi::LoadStepInput(size_t step) {
8386be168c0dSopenharmony_ci+  if (step >= batch_num_) {
8387be168c0dSopenharmony_ci+    auto cur_batch = step + 1;
8388be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Max input Batch is:" << batch_num_ << " but got batch :" << cur_batch;
8389be168c0dSopenharmony_ci+    return RET_ERROR;
8390be168c0dSopenharmony_ci+  }
8391be168c0dSopenharmony_ci+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8392be168c0dSopenharmony_ci+    OH_AI_TensorHandle cur_tensor = ms_inputs_for_api_.handle_list[i];
8393be168c0dSopenharmony_ci+    MS_ASSERT(cur_tensor != nullptr);
8394be168c0dSopenharmony_ci+    auto tensor_data_size = OH_AI_TensorGetDataSize(cur_tensor);
8395be168c0dSopenharmony_ci+    auto input_data = OH_AI_TensorGetMutableData(cur_tensor);
8396be168c0dSopenharmony_ci+    MS_ASSERT(input_data != nullptr);
8397be168c0dSopenharmony_ci+    memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size);
8398be168c0dSopenharmony_ci+  }
8399be168c0dSopenharmony_ci+  return RET_OK;
8400be168c0dSopenharmony_ci+}
8401be168c0dSopenharmony_ci+
8402be168c0dSopenharmony_ci+int NetTrainCApi::ReadInputFile() {
8403be168c0dSopenharmony_ci+  if (this->flags_->in_data_type_ == lite::kImage) {
8404be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Unsupported image input";
8405be168c0dSopenharmony_ci+    return RET_ERROR;
8406be168c0dSopenharmony_ci+  } else {
8407be168c0dSopenharmony_ci+    for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8408be168c0dSopenharmony_ci+      OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i];
8409be168c0dSopenharmony_ci+      MS_ASSERT(tensor != nullptr);
8410be168c0dSopenharmony_ci+      size_t size;
8411be168c0dSopenharmony_ci+      std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
8412be168c0dSopenharmony_ci+      auto bin_buf = lite::ReadFile(file_name.c_str(), &size);
8413be168c0dSopenharmony_ci+      if (bin_buf == nullptr) {
8414be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "ReadFile failed";
8415be168c0dSopenharmony_ci+        return RET_ERROR;
8416be168c0dSopenharmony_ci+      }
8417be168c0dSopenharmony_ci+      auto tensor_data_size = OH_AI_TensorGetDataSize(tensor);
8418be168c0dSopenharmony_ci+      MS_ASSERT(tensor_data_size != 0);
8419be168c0dSopenharmony_ci+      if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) {
8420be168c0dSopenharmony_ci+        std::cerr << "Input binary file size error, required :N * " << tensor_data_size << ", in fact: " << size
8421be168c0dSopenharmony_ci+                  << " ,file_name: " << file_name.c_str() << std::endl;
8422be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size
8423be168c0dSopenharmony_ci+                      << " ,file_name: " << file_name.c_str();
8424be168c0dSopenharmony_ci+        delete bin_buf;
8425be168c0dSopenharmony_ci+        return RET_ERROR;
8426be168c0dSopenharmony_ci+      }
8427be168c0dSopenharmony_ci+      inputs_buf_.emplace_back(bin_buf);
8428be168c0dSopenharmony_ci+      inputs_size_.emplace_back(size);
8429be168c0dSopenharmony_ci+      batch_num_ = size / tensor_data_size;
8430be168c0dSopenharmony_ci+    }
8431be168c0dSopenharmony_ci+  }
8432be168c0dSopenharmony_ci+  return RET_OK;
8433be168c0dSopenharmony_ci+}
8434be168c0dSopenharmony_ci+
8435be168c0dSopenharmony_ci+int NetTrainCApi::InitDumpTensorDataCallbackParameter() {
8436be168c0dSopenharmony_ci+  MS_LOG(ERROR) << "Unsupported feature.";
8437be168c0dSopenharmony_ci+  return RET_ERROR;
8438be168c0dSopenharmony_ci+}
8439be168c0dSopenharmony_ci+
8440be168c0dSopenharmony_ci+int NetTrainCApi::InitTimeProfilingCallbackParameter() {
8441be168c0dSopenharmony_ci+  before_call_back_ = TimeProfilingBeforeCallback;
8442be168c0dSopenharmony_ci+  after_call_back_ = TimeProfilingAfterCallback;
8443be168c0dSopenharmony_ci+  return RET_OK;
8444be168c0dSopenharmony_ci+}
8445be168c0dSopenharmony_ci+
8446be168c0dSopenharmony_ci+int NetTrainCApi::InitMSContext() {
8447be168c0dSopenharmony_ci+  context_ = OH_AI_ContextCreate();
8448be168c0dSopenharmony_ci+  if (context_ == nullptr) {
8449be168c0dSopenharmony_ci+    MS_LOG(INFO) << "OH_AI_ContextCreate failed";
8450be168c0dSopenharmony_ci+    return RET_ERROR;
8451be168c0dSopenharmony_ci+  }
8452be168c0dSopenharmony_ci+  OH_AI_ContextSetThreadNum(context_, flags_->num_threads_);
8453be168c0dSopenharmony_ci+  OH_AI_ContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_);
8454be168c0dSopenharmony_ci+
8455be168c0dSopenharmony_ci+  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
8456be168c0dSopenharmony_ci+  OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
8457be168c0dSopenharmony_ci+  OH_AI_ContextAddDeviceInfo(context_, cpu_device_info);
8458be168c0dSopenharmony_ci+  return RET_OK;
8459be168c0dSopenharmony_ci+}
8460be168c0dSopenharmony_ci+
8461be168c0dSopenharmony_ci+char **NetTrainCApi::TransStrVectorToCharArrays(const std::vector<std::string> &s) {
8462be168c0dSopenharmony_ci+  char **char_arr = static_cast<char **>(malloc(s.size() * sizeof(char *)));
8463be168c0dSopenharmony_ci+  for (size_t i = 0; i < s.size(); i++) {
8464be168c0dSopenharmony_ci+    char_arr[i] = static_cast<char *>(malloc((s[i].size() + 1)));
8465be168c0dSopenharmony_ci+    strcpy(char_arr[i], s[i].c_str());
8466be168c0dSopenharmony_ci+  }
8467be168c0dSopenharmony_ci+  return char_arr;
8468be168c0dSopenharmony_ci+}
8469be168c0dSopenharmony_ci+
8470be168c0dSopenharmony_ci+std::vector<std::string> NetTrainCApi::TransCharArraysToStrVector(char **c, const size_t &num) {
8471be168c0dSopenharmony_ci+  std::vector<std::string> str;
8472be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
8473be168c0dSopenharmony_ci+    str.push_back(std::string(c[i]));
8474be168c0dSopenharmony_ci+  }
8475be168c0dSopenharmony_ci+  return str;
8476be168c0dSopenharmony_ci+}
8477be168c0dSopenharmony_ci+
8478be168c0dSopenharmony_ci+void NetTrainCApi::InitTrainCfg() {
8479be168c0dSopenharmony_ci+  if (flags_->loss_name_.empty()) {
8480be168c0dSopenharmony_ci+    return;
8481be168c0dSopenharmony_ci+  }
8482be168c0dSopenharmony_ci+
8483be168c0dSopenharmony_ci+  std::string delimiter = ",";
8484be168c0dSopenharmony_ci+  size_t pos = 0;
8485be168c0dSopenharmony_ci+  std::string token;
8486be168c0dSopenharmony_ci+  train_cfg_ = OH_AI_TrainCfgCreate();
8487be168c0dSopenharmony_ci+  size_t num = 0;
8488be168c0dSopenharmony_ci+  std::vector<std::string> train_cfg_loss_name;
8489be168c0dSopenharmony_ci+  OH_AI_TrainCfgSetLossName(train_cfg_, nullptr, train_cfg_loss_name.size());
8490be168c0dSopenharmony_ci+  while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) {
8491be168c0dSopenharmony_ci+    token = flags_->loss_name_.substr(0, pos);
8492be168c0dSopenharmony_ci+    flags_->loss_name_.erase(0, pos + delimiter.length());  // change to delim without deletion
8493be168c0dSopenharmony_ci+    char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num);
8494be168c0dSopenharmony_ci+    train_cfg_loss_name = TransCharArraysToStrVector(name, num);
8495be168c0dSopenharmony_ci+    train_cfg_loss_name.push_back(token);
8496be168c0dSopenharmony_ci+    char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name);
8497be168c0dSopenharmony_ci+    OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size());
8498be168c0dSopenharmony_ci+    for (size_t i = 0; i < train_cfg_loss_name.size(); i++) {
8499be168c0dSopenharmony_ci+      free(loss_name[i]);
8500be168c0dSopenharmony_ci+    }
8501be168c0dSopenharmony_ci+    free(loss_name);
8502be168c0dSopenharmony_ci+    for (size_t i = 0; i < num; i++) {
8503be168c0dSopenharmony_ci+      free(name[i]);
8504be168c0dSopenharmony_ci+    }
8505be168c0dSopenharmony_ci+    free(name);
8506be168c0dSopenharmony_ci+  }
8507be168c0dSopenharmony_ci+  if (!(flags_->loss_name_.empty())) {
8508be168c0dSopenharmony_ci+    char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num);
8509be168c0dSopenharmony_ci+    train_cfg_loss_name = TransCharArraysToStrVector(name, num);
8510be168c0dSopenharmony_ci+    train_cfg_loss_name.push_back(flags_->loss_name_);
8511be168c0dSopenharmony_ci+    char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name);
8512be168c0dSopenharmony_ci+    OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size());
8513be168c0dSopenharmony_ci+    for (size_t i = 0; i < train_cfg_loss_name.size(); i++) {
8514be168c0dSopenharmony_ci+      free(loss_name[i]);
8515be168c0dSopenharmony_ci+    }
8516be168c0dSopenharmony_ci+    free(loss_name);
8517be168c0dSopenharmony_ci+    for (size_t i = 0; i < num; i++) {
8518be168c0dSopenharmony_ci+      free(name[i]);
8519be168c0dSopenharmony_ci+    }
8520be168c0dSopenharmony_ci+    free(name);
8521be168c0dSopenharmony_ci+  }
8522be168c0dSopenharmony_ci+}
8523be168c0dSopenharmony_ci+
8524be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetworkForInference(const std::string &filename,
8525be168c0dSopenharmony_ci+                                              const OH_AI_ContextHandle &context) {
8526be168c0dSopenharmony_ci+  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
8527be168c0dSopenharmony_ci+  std::string filenamems = filename;
8528be168c0dSopenharmony_ci+  if (filenamems.substr(filenamems.find_last_of('.') + 1) != "ms") {
8529be168c0dSopenharmony_ci+    filenamems = filenamems + ".ms";
8530be168c0dSopenharmony_ci+  }
8531be168c0dSopenharmony_ci+  MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
8532be168c0dSopenharmony_ci+  std::cout << "start reading model file " << filenamems.c_str() << std::endl;
8533be168c0dSopenharmony_ci+  auto status = OH_AI_ModelBuildFromFile(ms_model_, filenamems.c_str(),
8534be168c0dSopenharmony_ci+                                         static_cast<OH_AI_ModelType>(mindspore::kMindIR), context);
8535be168c0dSopenharmony_ci+  if (status != OH_AI_STATUS_SUCCESS) {
8536be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "ms model build failed. " << model_name;
8537be168c0dSopenharmony_ci+    return RET_ERROR;
8538be168c0dSopenharmony_ci+  }
8539be168c0dSopenharmony_ci+  return RET_OK;
8540be168c0dSopenharmony_ci+}
8541be168c0dSopenharmony_ci+
8542be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
8543be168c0dSopenharmony_ci+                                          const OH_AI_ContextHandle &context,
8544be168c0dSopenharmony_ci+                                          const OH_AI_TrainCfgHandle &train_cfg, int epochs) {
8545be168c0dSopenharmony_ci+  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
8546be168c0dSopenharmony_ci+  OH_AI_Status status;
8547be168c0dSopenharmony_ci+  if (!bb_filename.empty()) {
8548be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "build transfer learning not supported. " << model_name;
8549be168c0dSopenharmony_ci+      return RET_ERROR;
8550be168c0dSopenharmony_ci+  } else {
8551be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Build mindspore model from model file" << filename.c_str();
8552be168c0dSopenharmony_ci+    std::cout << "Build mindspore model from model file" << filename.c_str() << std::endl;
8553be168c0dSopenharmony_ci+    status = OH_AI_TrainModelBuildFromFile(ms_model_, filename.c_str(), OH_AI_MODELTYPE_MINDIR, context, train_cfg);
8554be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8555be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "build transfer learning failed. " << model_name;
8556be168c0dSopenharmony_ci+      return RET_ERROR;
8557be168c0dSopenharmony_ci+    }
8558be168c0dSopenharmony_ci+  }
8559be168c0dSopenharmony_ci+  if (epochs > 0) {
8560be168c0dSopenharmony_ci+    if (flags_->virtual_batch_) {
8561be168c0dSopenharmony_ci+      OH_AI_ModelSetupVirtualBatch(ms_model_, epochs, -1.0f, -1.0f);
8562be168c0dSopenharmony_ci+    }
8563be168c0dSopenharmony_ci+    status = OH_AI_ModelSetTrainMode(ms_model_, true);
8564be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8565be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "set train mode failed. ";
8566be168c0dSopenharmony_ci+      return RET_ERROR;
8567be168c0dSopenharmony_ci+    }
8568be168c0dSopenharmony_ci+  }
8569be168c0dSopenharmony_ci+  return RET_OK;
8570be168c0dSopenharmony_ci+}
8571be168c0dSopenharmony_ci+
8572be168c0dSopenharmony_ci+int NetTrainCApi::CompareOutput() {
8573be168c0dSopenharmony_ci+  std::cout << "================ Comparing Forward Output data ================" << std::endl;
8574be168c0dSopenharmony_ci+  float total_bias = 0;
8575be168c0dSopenharmony_ci+  int total_size = 0;
8576be168c0dSopenharmony_ci+  bool has_error = false;
8577be168c0dSopenharmony_ci+  auto output_tensors_handle = OH_AI_ModelGetOutputs(ms_model_);
8578be168c0dSopenharmony_ci+
8579be168c0dSopenharmony_ci+  std::vector<mindspore::MSTensor> output_tensors;
8580be168c0dSopenharmony_ci+  for (size_t i = 0; i < output_tensors_handle.handle_num; i++) {
8581be168c0dSopenharmony_ci+    output_tensors.push_back(*static_cast<mindspore::MSTensor *>(output_tensors_handle.handle_list[i]));
8582be168c0dSopenharmony_ci+  }
8583be168c0dSopenharmony_ci+  if (output_tensors.empty()) {
8584be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
8585be168c0dSopenharmony_ci+    return RET_ERROR;
8586be168c0dSopenharmony_ci+  }
8587be168c0dSopenharmony_ci+  std::map<std::string, MSTensor> ordered_outputs;
8588be168c0dSopenharmony_ci+  for (const auto &output_tensor : output_tensors) {
8589be168c0dSopenharmony_ci+    ordered_outputs.insert({output_tensor.Name(), output_tensor});
8590be168c0dSopenharmony_ci+  }
8591be168c0dSopenharmony_ci+  int i = 1;
8592be168c0dSopenharmony_ci+  mindspore::MSTensor tensor;
8593be168c0dSopenharmony_ci+  for (auto &ordered_output : ordered_outputs) {
8594be168c0dSopenharmony_ci+    tensor = ordered_output.second;
8595be168c0dSopenharmony_ci+    std::cout << "output is tensor " << ordered_output.first << "\n";
8596be168c0dSopenharmony_ci+    auto outputs = tensor.MutableData();
8597be168c0dSopenharmony_ci+    size_t size;
8598be168c0dSopenharmony_ci+    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
8599be168c0dSopenharmony_ci+    auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size));
8600be168c0dSopenharmony_ci+    if (bin_buf == nullptr) {
8601be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "ReadFile return nullptr";
8602be168c0dSopenharmony_ci+      std::cout << "ReadFile return nullptr" << std::endl;
8603be168c0dSopenharmony_ci+      return RET_ERROR;
8604be168c0dSopenharmony_ci+    }
8605be168c0dSopenharmony_ci+    if (size != tensor.DataSize()) {
8606be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
8607be168c0dSopenharmony_ci+                    << ", read size: " << size;
8608be168c0dSopenharmony_ci+      std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
8609be168c0dSopenharmony_ci+                << ", read size: " << size << std::endl;
8610be168c0dSopenharmony_ci+      return RET_ERROR;
8611be168c0dSopenharmony_ci+    }
8612be168c0dSopenharmony_ci+    float bias = CompareData<float>(bin_buf.get(), tensor.ElementNum(), reinterpret_cast<float *>(outputs));
8613be168c0dSopenharmony_ci+    if (bias >= 0) {
8614be168c0dSopenharmony_ci+      total_bias += bias;
8615be168c0dSopenharmony_ci+      total_size++;
8616be168c0dSopenharmony_ci+    } else {
8617be168c0dSopenharmony_ci+      has_error = true;
8618be168c0dSopenharmony_ci+      break;
8619be168c0dSopenharmony_ci+    }
8620be168c0dSopenharmony_ci+    i++;
8621be168c0dSopenharmony_ci+  }
8622be168c0dSopenharmony_ci+
8623be168c0dSopenharmony_ci+  if (!has_error) {
8624be168c0dSopenharmony_ci+    float mean_bias;
8625be168c0dSopenharmony_ci+    if (total_size != 0) {
8626be168c0dSopenharmony_ci+      mean_bias = total_bias / total_size * 100;
8627be168c0dSopenharmony_ci+    } else {
8628be168c0dSopenharmony_ci+      mean_bias = 0;
8629be168c0dSopenharmony_ci+    }
8630be168c0dSopenharmony_ci+
8631be168c0dSopenharmony_ci+    std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
8632be168c0dSopenharmony_ci+              << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
8633be168c0dSopenharmony_ci+    std::cout << "=======================================================" << std::endl << std::endl;
8634be168c0dSopenharmony_ci+
8635be168c0dSopenharmony_ci+    if (mean_bias > this->flags_->accuracy_threshold_) {
8636be168c0dSopenharmony_ci+      MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
8637be168c0dSopenharmony_ci+      std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
8638be168c0dSopenharmony_ci+      return RET_TOO_BIG;
8639be168c0dSopenharmony_ci+    } else {
8640be168c0dSopenharmony_ci+      return RET_OK;
8641be168c0dSopenharmony_ci+    }
8642be168c0dSopenharmony_ci+  } else {
8643be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Error in CompareData";
8644be168c0dSopenharmony_ci+    std::cerr << "Error in CompareData" << std::endl;
8645be168c0dSopenharmony_ci+    std::cout << "=======================================================" << std::endl << std::endl;
8646be168c0dSopenharmony_ci+    return RET_ERROR;
8647be168c0dSopenharmony_ci+  }
8648be168c0dSopenharmony_ci+}
8649be168c0dSopenharmony_ci+
8650be168c0dSopenharmony_ci+int NetTrainCApi::MarkPerformance() {
8651be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Running train loops...";
8652be168c0dSopenharmony_ci+  std::cout << "Running train loops..." << std::endl;
8653be168c0dSopenharmony_ci+  uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
8654be168c0dSopenharmony_ci+  uint64_t time_max = 0;
8655be168c0dSopenharmony_ci+  uint64_t time_avg = 0;
8656be168c0dSopenharmony_ci+  std::vector<MSTensor> outputs;
8657be168c0dSopenharmony_ci+
8658be168c0dSopenharmony_ci+  for (int i = 0; i < flags_->epochs_; i++) {
8659be168c0dSopenharmony_ci+    auto start = GetTimeUs();
8660be168c0dSopenharmony_ci+    for (size_t step = 0; step < batch_num_; step++) {
8661be168c0dSopenharmony_ci+      MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step;
8662be168c0dSopenharmony_ci+      auto ret = LoadStepInput(step);
8663be168c0dSopenharmony_ci+      if (ret != RET_OK) {
8664be168c0dSopenharmony_ci+        return ret;
8665be168c0dSopenharmony_ci+      }
8666be168c0dSopenharmony_ci+      auto status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_);
8667be168c0dSopenharmony_ci+      if (status != OH_AI_STATUS_SUCCESS) {
8668be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Inference error " << status;
8669be168c0dSopenharmony_ci+        std::cerr << "Inference error " << status;
8670be168c0dSopenharmony_ci+        return RET_ERROR;
8671be168c0dSopenharmony_ci+      }
8672be168c0dSopenharmony_ci+    }
8673be168c0dSopenharmony_ci+
8674be168c0dSopenharmony_ci+    auto end = GetTimeUs();
8675be168c0dSopenharmony_ci+    auto time = end - start;
8676be168c0dSopenharmony_ci+    time_min = std::min(time_min, time);
8677be168c0dSopenharmony_ci+    time_max = std::max(time_max, time);
8678be168c0dSopenharmony_ci+    time_avg += time;
8679be168c0dSopenharmony_ci+  }
8680be168c0dSopenharmony_ci+
8681be168c0dSopenharmony_ci+  if (flags_->time_profiling_) {
8682be168c0dSopenharmony_ci+    const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
8683be168c0dSopenharmony_ci+    const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
8684be168c0dSopenharmony_ci+    PrintResult(per_op_name, g_c_op_times_by_name_);
8685be168c0dSopenharmony_ci+    PrintResult(per_op_type, g_c_op_times_by_type_);
8686be168c0dSopenharmony_ci+  }
8687be168c0dSopenharmony_ci+
8688be168c0dSopenharmony_ci+  if (flags_->epochs_ > 0) {
8689be168c0dSopenharmony_ci+    time_avg /= static_cast<size_t>(flags_->epochs_);
8690be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
8691be168c0dSopenharmony_ci+                 << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
8692be168c0dSopenharmony_ci+                 << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
8693be168c0dSopenharmony_ci+    printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n",
8694be168c0dSopenharmony_ci+           flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
8695be168c0dSopenharmony_ci+           time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
8696be168c0dSopenharmony_ci+  }
8697be168c0dSopenharmony_ci+  return RET_OK;
8698be168c0dSopenharmony_ci+}
8699be168c0dSopenharmony_ci+
8700be168c0dSopenharmony_ci+int NetTrainCApi::MarkAccuracy(bool enforce_accuracy) {
8701be168c0dSopenharmony_ci+  MS_LOG(INFO) << "MarkAccuracy";
8702be168c0dSopenharmony_ci+  auto load_ret = LoadStepInput(0);
8703be168c0dSopenharmony_ci+  if (load_ret != RET_OK) {
8704be168c0dSopenharmony_ci+    return load_ret;
8705be168c0dSopenharmony_ci+  }
8706be168c0dSopenharmony_ci+  auto status = PrintInputData();
8707be168c0dSopenharmony_ci+  if (status != RET_OK) {
8708be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "PrintInputData failed, ret: " << status;
8709be168c0dSopenharmony_ci+    return status;
8710be168c0dSopenharmony_ci+  }
8711be168c0dSopenharmony_ci+  status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_);
8712be168c0dSopenharmony_ci+  if (status != OH_AI_STATUS_SUCCESS) {
8713be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Inference error " << status;
8714be168c0dSopenharmony_ci+    std::cerr << "Inference error " << status << std::endl;
8715be168c0dSopenharmony_ci+    return RET_ERROR;
8716be168c0dSopenharmony_ci+  }
8717be168c0dSopenharmony_ci+
8718be168c0dSopenharmony_ci+  auto ret = CompareOutput();
8719be168c0dSopenharmony_ci+  if (ret == RET_TOO_BIG && !enforce_accuracy) {
8720be168c0dSopenharmony_ci+    MS_LOG(INFO) << "Accuracy Error is big but not enforced";
8721be168c0dSopenharmony_ci+    std::cout << "Accuracy Error is big but not enforced" << std::endl;
8722be168c0dSopenharmony_ci+    return RET_OK;
8723be168c0dSopenharmony_ci+  }
8724be168c0dSopenharmony_ci+
8725be168c0dSopenharmony_ci+  if (ret != RET_OK) {
8726be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Compare output error " << ret;
8727be168c0dSopenharmony_ci+    std::cerr << "Compare output error " << ret << std::endl;
8728be168c0dSopenharmony_ci+    return ret;
8729be168c0dSopenharmony_ci+  }
8730be168c0dSopenharmony_ci+  return RET_OK;
8731be168c0dSopenharmony_ci+}
8732be168c0dSopenharmony_ci+
8733be168c0dSopenharmony_ci+int NetTrainCApi::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train,
8734be168c0dSopenharmony_ci+                                  int epochs, bool check_accuracy) {
8735be168c0dSopenharmony_ci+  auto start_prepare_time = GetTimeUs();
8736be168c0dSopenharmony_ci+
8737be168c0dSopenharmony_ci+  int ret = InitMSContext();
8738be168c0dSopenharmony_ci+  if (ret != RET_OK) {
8739be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "InitContext failed, ret: " << ret;
8740be168c0dSopenharmony_ci+    return ret;
8741be168c0dSopenharmony_ci+  }
8742be168c0dSopenharmony_ci+
8743be168c0dSopenharmony_ci+  InitTrainCfg();
8744be168c0dSopenharmony_ci+  ms_model_ = OH_AI_ModelCreate();
8745be168c0dSopenharmony_ci+
8746be168c0dSopenharmony_ci+  if (is_train) {
8747be168c0dSopenharmony_ci+    ret = CreateAndRunNetworkForTrain(filename, bb_filename, context_ , train_cfg_, epochs);
8748be168c0dSopenharmony_ci+    if (ret != RET_OK) {
8749be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed.";
8750be168c0dSopenharmony_ci+      return RET_ERROR;
8751be168c0dSopenharmony_ci+    }
8752be168c0dSopenharmony_ci+  } else {
8753be168c0dSopenharmony_ci+    ret = CreateAndRunNetworkForInference(filename, context_);
8754be168c0dSopenharmony_ci+    if (ret != RET_OK) {
8755be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed.";
8756be168c0dSopenharmony_ci+      return RET_ERROR;
8757be168c0dSopenharmony_ci+    }
8758be168c0dSopenharmony_ci+  }
8759be168c0dSopenharmony_ci+
8760be168c0dSopenharmony_ci+  ms_inputs_for_api_ = OH_AI_ModelGetInputs(ms_model_);
8761be168c0dSopenharmony_ci+  if (ms_inputs_for_api_.handle_list == nullptr) {
8762be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "OH_AI_ModelGetInputs failed, ret: ";
8763be168c0dSopenharmony_ci+    return RET_ERROR;
8764be168c0dSopenharmony_ci+  }
8765be168c0dSopenharmony_ci+
8766be168c0dSopenharmony_ci+  if (!flags_->resize_dims_.empty()) {
8767be168c0dSopenharmony_ci+    std::vector<OH_AI_ShapeInfo> shape_infos;
8768be168c0dSopenharmony_ci+    std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(shape_infos),
8769be168c0dSopenharmony_ci+                   [&](auto &shapes) {
8770be168c0dSopenharmony_ci+                     OH_AI_ShapeInfo shape_info;
8771be168c0dSopenharmony_ci+                     shape_info.shape_num = shapes.size();
8772be168c0dSopenharmony_ci+                     for (size_t i = 0; i < shape_info.shape_num; i++) {
8773be168c0dSopenharmony_ci+                       shape_info.shape[i] = shapes[i];
8774be168c0dSopenharmony_ci+                     }
8775be168c0dSopenharmony_ci+                     return shape_info;
8776be168c0dSopenharmony_ci+                   });
8777be168c0dSopenharmony_ci+    auto status = OH_AI_ModelResize(ms_model_, ms_inputs_for_api_, shape_infos.data(), shape_infos.size());
8778be168c0dSopenharmony_ci+    if (status != OH_AI_STATUS_SUCCESS) {
8779be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Input tensor resize failed.";
8780be168c0dSopenharmony_ci+      std::cout << "Input tensor resize failed.";
8781be168c0dSopenharmony_ci+      return RET_ERROR;
8782be168c0dSopenharmony_ci+    }
8783be168c0dSopenharmony_ci+  }
8784be168c0dSopenharmony_ci+
8785be168c0dSopenharmony_ci+  auto end_prepare_time = GetTimeUs();
8786be168c0dSopenharmony_ci+  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
8787be168c0dSopenharmony_ci+  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
8788be168c0dSopenharmony_ci+  // Load input
8789be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Load input data";
8790be168c0dSopenharmony_ci+  auto status = LoadInput();
8791be168c0dSopenharmony_ci+  if (status != RET_OK) {
8792be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Load input data error";
8793be168c0dSopenharmony_ci+    std::cout << "Load input data error" << std::endl;
8794be168c0dSopenharmony_ci+    return status;
8795be168c0dSopenharmony_ci+  }
8796be168c0dSopenharmony_ci+
8797be168c0dSopenharmony_ci+  if ((epochs > 0) && is_train) {
8798be168c0dSopenharmony_ci+    status = MarkPerformance();
8799be168c0dSopenharmony_ci+    if (status != RET_OK) {
8800be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
8801be168c0dSopenharmony_ci+      std::cout << "Run MarkPerformance error: " << status << std::endl;
8802be168c0dSopenharmony_ci+      return status;
8803be168c0dSopenharmony_ci+    }
8804be168c0dSopenharmony_ci+    SaveModels();  // save file if flags are on
8805be168c0dSopenharmony_ci+  }
8806be168c0dSopenharmony_ci+  if (!flags_->data_file_.empty()) {
8807be168c0dSopenharmony_ci+    auto res = OH_AI_ModelSetTrainMode(ms_model_, false);
8808be168c0dSopenharmony_ci+    if (res != OH_AI_STATUS_SUCCESS) {
8809be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "set eval mode failed. ";
8810be168c0dSopenharmony_ci+      return RET_ERROR;
8811be168c0dSopenharmony_ci+    }
8812be168c0dSopenharmony_ci+
8813be168c0dSopenharmony_ci+    status = MarkAccuracy(check_accuracy);
8814be168c0dSopenharmony_ci+    if (status != RET_OK) {
8815be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
8816be168c0dSopenharmony_ci+      std::cout << "Run MarkAccuracy error: " << status << std::endl;
8817be168c0dSopenharmony_ci+      return status;
8818be168c0dSopenharmony_ci+    }
8819be168c0dSopenharmony_ci+  }
8820be168c0dSopenharmony_ci+  return RET_OK;
8821be168c0dSopenharmony_ci+}
8822be168c0dSopenharmony_ci+
8823be168c0dSopenharmony_ci+int NetTrainCApi::PrintInputData() {
8824be168c0dSopenharmony_ci+  constexpr int64_t kPrintDataNum = 20;
8825be168c0dSopenharmony_ci+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8826be168c0dSopenharmony_ci+    auto input = ms_inputs_for_api_.handle_list[i];
8827be168c0dSopenharmony_ci+    std::cout << "InData" << i << ": ";
8828be168c0dSopenharmony_ci+    auto data_type = static_cast<TypeId>(OH_AI_TensorGetDataType(input));
8829be168c0dSopenharmony_ci+    if (data_type == TypeId::kObjectTypeString) {
8830be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING.";
8831be168c0dSopenharmony_ci+      return RET_ERROR;
8832be168c0dSopenharmony_ci+    }
8833be168c0dSopenharmony_ci+    auto tensor_data = OH_AI_TensorGetData(input);
8834be168c0dSopenharmony_ci+    size_t print_num = std::min(OH_AI_TensorGetElementNum(input), kPrintDataNum);
8835be168c0dSopenharmony_ci+    for (size_t j = 0; j < print_num; j++) {
8836be168c0dSopenharmony_ci+      if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat) {
8837be168c0dSopenharmony_ci+        std::cout << static_cast<const float *>(tensor_data)[j] << " ";
8838be168c0dSopenharmony_ci+      } else if (data_type == TypeId::kNumberTypeInt8) {
8839be168c0dSopenharmony_ci+        std::cout << static_cast<const int8_t *>(tensor_data)[j] << " ";
8840be168c0dSopenharmony_ci+      } else if (data_type == TypeId::kNumberTypeUInt8) {
8841be168c0dSopenharmony_ci+        std::cout << static_cast<const uint8_t *>(tensor_data)[j] << " ";
8842be168c0dSopenharmony_ci+      } else if (data_type == TypeId::kNumberTypeInt32) {
8843be168c0dSopenharmony_ci+        std::cout << static_cast<const int32_t *>(tensor_data)[j] << " ";
8844be168c0dSopenharmony_ci+      } else if (data_type == TypeId::kNumberTypeInt64) {
8845be168c0dSopenharmony_ci+        std::cout << static_cast<const int64_t *>(tensor_data)[j] << " ";
8846be168c0dSopenharmony_ci+      } else if (data_type == TypeId::kNumberTypeBool) {
8847be168c0dSopenharmony_ci+        std::cout << static_cast<const bool *>(tensor_data)[j] << " ";
8848be168c0dSopenharmony_ci+      } else {
8849be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Datatype: " << data_type << " is not supported.";
8850be168c0dSopenharmony_ci+        return RET_ERROR;
8851be168c0dSopenharmony_ci+      }
8852be168c0dSopenharmony_ci+    }
8853be168c0dSopenharmony_ci+    std::cout << std::endl;
8854be168c0dSopenharmony_ci+  }
8855be168c0dSopenharmony_ci+  return RET_OK;
8856be168c0dSopenharmony_ci+}
8857be168c0dSopenharmony_ci+
8858be168c0dSopenharmony_ci+int NetTrainCApi::PrintResult(const std::vector<std::string> &title,
8859be168c0dSopenharmony_ci+                          const std::map<std::string, std::pair<int, float>> &result) {
8860be168c0dSopenharmony_ci+  std::vector<size_t> columnLenMax(kFieldsToPrint);
8861be168c0dSopenharmony_ci+  std::vector<std::vector<std::string>> rows;
8862be168c0dSopenharmony_ci+
8863be168c0dSopenharmony_ci+  for (auto &iter : result) {
8864be168c0dSopenharmony_ci+    std::string stringBuf[kFieldsToPrint];
8865be168c0dSopenharmony_ci+    std::vector<std::string> columns;
8866be168c0dSopenharmony_ci+    size_t len = 0;
8867be168c0dSopenharmony_ci+    int index = 0;
8868be168c0dSopenharmony_ci+    len = iter.first.size();
8869be168c0dSopenharmony_ci+    if (len > columnLenMax.at(index)) {
8870be168c0dSopenharmony_ci+      columnLenMax.at(index) = len + kPrintOffset;
8871be168c0dSopenharmony_ci+    }
8872be168c0dSopenharmony_ci+    columns.push_back(iter.first);
8873be168c0dSopenharmony_ci+
8874be168c0dSopenharmony_ci+    index++;
8875be168c0dSopenharmony_ci+    if (title[0] == "opName") {
8876be168c0dSopenharmony_ci+      stringBuf[index] = std::to_string(iter.second.second / flags_->epochs_);
8877be168c0dSopenharmony_ci+    } else {
8878be168c0dSopenharmony_ci+      stringBuf[index] = std::to_string(iter.second.second / iter.second.first);
8879be168c0dSopenharmony_ci+    }
8880be168c0dSopenharmony_ci+    len = stringBuf[index].length();
8881be168c0dSopenharmony_ci+    if (len > columnLenMax.at(index)) {
8882be168c0dSopenharmony_ci+      columnLenMax.at(index) = len + kPrintOffset;
8883be168c0dSopenharmony_ci+    }
8884be168c0dSopenharmony_ci+    columns.emplace_back(stringBuf[index]);
8885be168c0dSopenharmony_ci+
8886be168c0dSopenharmony_ci+    index++;
8887be168c0dSopenharmony_ci+    stringBuf[index] = std::to_string(iter.second.second / g_op_cost_total_);
8888be168c0dSopenharmony_ci+    len = stringBuf[index].length();
8889be168c0dSopenharmony_ci+    if (len > columnLenMax.at(index)) {
8890be168c0dSopenharmony_ci+      columnLenMax.at(index) = len + kPrintOffset;
8891be168c0dSopenharmony_ci+    }
8892be168c0dSopenharmony_ci+    columns.emplace_back(stringBuf[index]);
8893be168c0dSopenharmony_ci+
8894be168c0dSopenharmony_ci+    index++;
8895be168c0dSopenharmony_ci+    stringBuf[index] = std::to_string(iter.second.first);
8896be168c0dSopenharmony_ci+    len = stringBuf[index].length();
8897be168c0dSopenharmony_ci+    if (len > columnLenMax.at(index)) {
8898be168c0dSopenharmony_ci+      columnLenMax.at(index) = len + kPrintOffset;
8899be168c0dSopenharmony_ci+    }
8900be168c0dSopenharmony_ci+    columns.emplace_back(stringBuf[index]);
8901be168c0dSopenharmony_ci+
8902be168c0dSopenharmony_ci+    index++;
8903be168c0dSopenharmony_ci+    stringBuf[index] = std::to_string(iter.second.second);
8904be168c0dSopenharmony_ci+    len = stringBuf[index].length();
8905be168c0dSopenharmony_ci+    if (len > columnLenMax.at(index)) {
8906be168c0dSopenharmony_ci+      columnLenMax.at(index) = len + kPrintOffset;
8907be168c0dSopenharmony_ci+    }
8908be168c0dSopenharmony_ci+    columns.emplace_back(stringBuf[index]);
8909be168c0dSopenharmony_ci+
8910be168c0dSopenharmony_ci+    rows.push_back(columns);
8911be168c0dSopenharmony_ci+  }
8912be168c0dSopenharmony_ci+
8913be168c0dSopenharmony_ci+  printf("-------------------------------------------------------------------------\n");
8914be168c0dSopenharmony_ci+  for (int i = 0; i < kFieldsToPrint; i++) {
8915be168c0dSopenharmony_ci+    auto printBuf = title[i];
8916be168c0dSopenharmony_ci+    if (printBuf.size() > columnLenMax.at(i)) {
8917be168c0dSopenharmony_ci+      columnLenMax.at(i) = printBuf.size();
8918be168c0dSopenharmony_ci+    }
8919be168c0dSopenharmony_ci+    printBuf.resize(columnLenMax.at(i), ' ');
8920be168c0dSopenharmony_ci+    printf("%s\t", printBuf.c_str());
8921be168c0dSopenharmony_ci+  }
8922be168c0dSopenharmony_ci+  printf("\n");
8923be168c0dSopenharmony_ci+  for (auto &row : rows) {
8924be168c0dSopenharmony_ci+    for (int j = 0; j < kFieldsToPrint; j++) {
8925be168c0dSopenharmony_ci+      auto printBuf = row[j];
8926be168c0dSopenharmony_ci+      printBuf.resize(columnLenMax.at(j), ' ');
8927be168c0dSopenharmony_ci+      printf("%s\t", printBuf.c_str());
8928be168c0dSopenharmony_ci+    }
8929be168c0dSopenharmony_ci+    printf("\n");
8930be168c0dSopenharmony_ci+  }
8931be168c0dSopenharmony_ci+  return RET_OK;
8932be168c0dSopenharmony_ci+}
8933be168c0dSopenharmony_ci+
8934be168c0dSopenharmony_ci+bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
8935be168c0dSopenharmony_ci+                                 const OH_AI_CallBackParam kernel_Info) {
8936be168c0dSopenharmony_ci+  if (g_c_op_times_by_type_.find(kernel_Info.node_type) == g_c_op_times_by_type_.end()) {
8937be168c0dSopenharmony_ci+    g_c_op_times_by_type_.insert(std::make_pair(kernel_Info.node_type, std::make_pair(0, 0.0f)));
8938be168c0dSopenharmony_ci+  }
8939be168c0dSopenharmony_ci+  if (g_c_op_times_by_name_.find(kernel_Info.node_name) == g_c_op_times_by_name_.end()) {
8940be168c0dSopenharmony_ci+    g_c_op_times_by_name_.insert(std::make_pair(kernel_Info.node_name, std::make_pair(0, 0.0f)));
8941be168c0dSopenharmony_ci+  }
8942be168c0dSopenharmony_ci+
8943be168c0dSopenharmony_ci+  g_op_call_times_total_++;
8944be168c0dSopenharmony_ci+  g_op_begin_ = mindspore::lite::GetTimeUs();
8945be168c0dSopenharmony_ci+  return true;
8946be168c0dSopenharmony_ci+}
8947be168c0dSopenharmony_ci+
8948be168c0dSopenharmony_ci+bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
8949be168c0dSopenharmony_ci+                                const OH_AI_CallBackParam kernel_Info) {
8950be168c0dSopenharmony_ci+  uint64_t opEnd = mindspore::lite::GetTimeUs();
8951be168c0dSopenharmony_ci+  float cost = static_cast<float>(opEnd - g_op_begin_) / 1000.0f;
8952be168c0dSopenharmony_ci+  g_op_cost_total_ += cost;
8953be168c0dSopenharmony_ci+  g_c_op_times_by_type_[kernel_Info.node_type].first++;
8954be168c0dSopenharmony_ci+  g_c_op_times_by_type_[kernel_Info.node_type].second += cost;
8955be168c0dSopenharmony_ci+  g_c_op_times_by_name_[kernel_Info.node_name].first++;
8956be168c0dSopenharmony_ci+  g_c_op_times_by_name_[kernel_Info.node_name].second += cost;
8957be168c0dSopenharmony_ci+  return true;
8958be168c0dSopenharmony_ci+}
8959be168c0dSopenharmony_ci+}  // namespace lite
8960be168c0dSopenharmony_ci+}  // namespace mindspore
8961be168c0dSopenharmony_ci+
8962be168c0dSopenharmony_ci+
8963be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.h b/mindspore/lite/tools/benchmark_train/net_train_c_api.h
8964be168c0dSopenharmony_cinew file mode 100644
8965be168c0dSopenharmony_ciindex 00000000..bb84d3c1
8966be168c0dSopenharmony_ci--- /dev/null
8967be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.h
8968be168c0dSopenharmony_ci@@ -0,0 +1,121 @@
8969be168c0dSopenharmony_ci+/**
8970be168c0dSopenharmony_ci+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
8971be168c0dSopenharmony_ci+ *
8972be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
8973be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
8974be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
8975be168c0dSopenharmony_ci+ *
8976be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
8977be168c0dSopenharmony_ci+ *
8978be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
8979be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
8980be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8981be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
8982be168c0dSopenharmony_ci+ * limitations under the License.
8983be168c0dSopenharmony_ci+ */
8984be168c0dSopenharmony_ci+
8985be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
8986be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
8987be168c0dSopenharmony_ci+
8988be168c0dSopenharmony_ci+#include <getopt.h>
8989be168c0dSopenharmony_ci+#include <csignal>
8990be168c0dSopenharmony_ci+#include <unordered_map>
8991be168c0dSopenharmony_ci+#include <fstream>
8992be168c0dSopenharmony_ci+#include <iostream>
8993be168c0dSopenharmony_ci+#include <map>
8994be168c0dSopenharmony_ci+#include <cmath>
8995be168c0dSopenharmony_ci+#include <string>
8996be168c0dSopenharmony_ci+#include <vector>
8997be168c0dSopenharmony_ci+#include <memory>
8998be168c0dSopenharmony_ci+#include <cfloat>
8999be168c0dSopenharmony_ci+#include <utility>
9000be168c0dSopenharmony_ci+#include <algorithm>
9001be168c0dSopenharmony_ci+#include <nlohmann/json.hpp>
9002be168c0dSopenharmony_ci+#include "include/api/model.h"
9003be168c0dSopenharmony_ci+#include "include/api/types.h"
9004be168c0dSopenharmony_ci+#include "include/api/context.h"
9005be168c0dSopenharmony_ci+#include "include/api/cfg.h"
9006be168c0dSopenharmony_ci+
9007be168c0dSopenharmony_ci+#include "include/c_api/model_c.h"
9008be168c0dSopenharmony_ci+#include "include/c_api/context_c.h"
9009be168c0dSopenharmony_ci+
9010be168c0dSopenharmony_ci+#ifdef ENABLE_FP16
9011be168c0dSopenharmony_ci+#include <arm_neon.h>
9012be168c0dSopenharmony_ci+#endif
9013be168c0dSopenharmony_ci+#include "tools/common/flag_parser.h"
9014be168c0dSopenharmony_ci+#include "src/common/file_utils.h"
9015be168c0dSopenharmony_ci+#include "src/common/utils.h"
9016be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_base.h"
9017be168c0dSopenharmony_ci+
9018be168c0dSopenharmony_ci+namespace mindspore::lite {
9019be168c0dSopenharmony_ci+  namespace {
9020be168c0dSopenharmony_ci+    std::map<std::string, std::pair<int, float>> g_c_op_times_by_type_;
9021be168c0dSopenharmony_ci+    std::map<std::string, std::pair<int, float>> g_c_op_times_by_name_;
9022be168c0dSopenharmony_ci+  }
9023be168c0dSopenharmony_ci+#ifdef __cplusplus
9024be168c0dSopenharmony_ci+  extern "C" {
9025be168c0dSopenharmony_ci+#endif
9026be168c0dSopenharmony_ci+  bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
9027be168c0dSopenharmony_ci+                                   const OH_AI_CallBackParam kernel_Info);
9028be168c0dSopenharmony_ci+  bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
9029be168c0dSopenharmony_ci+                                  const OH_AI_CallBackParam kernel_Info);
9030be168c0dSopenharmony_ci+#ifdef __cplusplus
9031be168c0dSopenharmony_ci+  }
9032be168c0dSopenharmony_ci+#endif
9033be168c0dSopenharmony_ci+
9034be168c0dSopenharmony_ci+class MS_API NetTrainCApi : public NetTrainBase  {
9035be168c0dSopenharmony_ci+ public:
9036be168c0dSopenharmony_ci+  explicit NetTrainCApi(NetTrainFlags *flags) : NetTrainBase(flags) {}
9037be168c0dSopenharmony_ci+  virtual ~NetTrainCApi() {};
9038be168c0dSopenharmony_ci+
9039be168c0dSopenharmony_ci+ protected:
9040be168c0dSopenharmony_ci+  // call GenerateRandomData to fill inputTensors
9041be168c0dSopenharmony_ci+  int GenerateInputData() override;
9042be168c0dSopenharmony_ci+
9043be168c0dSopenharmony_ci+  int ReadInputFile() override;
9044be168c0dSopenharmony_ci+
9045be168c0dSopenharmony_ci+  int LoadStepInput(size_t step);
9046be168c0dSopenharmony_ci+
9047be168c0dSopenharmony_ci+  int InitMSContext();
9048be168c0dSopenharmony_ci+
9049be168c0dSopenharmony_ci+  void InitTrainCfg();
9050be168c0dSopenharmony_ci+
9051be168c0dSopenharmony_ci+  char **TransStrVectorToCharArrays(const std::vector<std::string> &s);
9052be168c0dSopenharmony_ci+
9053be168c0dSopenharmony_ci+  std::vector<std::string> TransCharArraysToStrVector(char **c, const size_t &num);
9054be168c0dSopenharmony_ci+
9055be168c0dSopenharmony_ci+  int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
9056be168c0dSopenharmony_ci+    bool check_accuracy = true) override;
9057be168c0dSopenharmony_ci+
9058be168c0dSopenharmony_ci+  int CreateAndRunNetworkForInference(const std::string &filename, const OH_AI_ContextHandle &context);
9059be168c0dSopenharmony_ci+
9060be168c0dSopenharmony_ci+  int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
9061be168c0dSopenharmony_ci+    const OH_AI_ContextHandle &context,
9062be168c0dSopenharmony_ci+    const OH_AI_TrainCfgHandle &train_cfg, int epochs);
9063be168c0dSopenharmony_ci+
9064be168c0dSopenharmony_ci+  int InitDumpTensorDataCallbackParameter() override;
9065be168c0dSopenharmony_ci+
9066be168c0dSopenharmony_ci+  int InitTimeProfilingCallbackParameter() override;
9067be168c0dSopenharmony_ci+
9068be168c0dSopenharmony_ci+  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override;
9069be168c0dSopenharmony_ci+
9070be168c0dSopenharmony_ci+  int PrintInputData();
9071be168c0dSopenharmony_ci+
9072be168c0dSopenharmony_ci+  int MarkPerformance() override;
9073be168c0dSopenharmony_ci+
9074be168c0dSopenharmony_ci+  int MarkAccuracy(bool enforce_accuracy = true) override;
9075be168c0dSopenharmony_ci+
9076be168c0dSopenharmony_ci+  int CompareOutput() override;
9077be168c0dSopenharmony_ci+
9078be168c0dSopenharmony_ci+  int SaveModels() override;
9079be168c0dSopenharmony_ci+
9080be168c0dSopenharmony_ci+  OH_AI_ModelHandle ms_model_;
9081be168c0dSopenharmony_ci+  OH_AI_TensorHandleArray ms_inputs_for_api_;
9082be168c0dSopenharmony_ci+  OH_AI_ContextHandle context_ = nullptr;
9083be168c0dSopenharmony_ci+  OH_AI_TrainCfgHandle train_cfg_ = nullptr;
9084be168c0dSopenharmony_ci+  OH_AI_KernelCallBack before_call_back_{nullptr};
9085be168c0dSopenharmony_ci+  OH_AI_KernelCallBack after_call_back_{nullptr};
9086be168c0dSopenharmony_ci+};
9087be168c0dSopenharmony_ci+}  // namespace mindspore::lite
9088be168c0dSopenharmony_ci+
9089be168c0dSopenharmony_ci+#endif //MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
9090be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/run_net_train.cc b/mindspore/lite/tools/benchmark_train/run_net_train.cc
9091be168c0dSopenharmony_cinew file mode 100644
9092be168c0dSopenharmony_ciindex 00000000..37a7e602
9093be168c0dSopenharmony_ci--- /dev/null
9094be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/run_net_train.cc
9095be168c0dSopenharmony_ci@@ -0,0 +1,86 @@
9096be168c0dSopenharmony_ci+/**
9097be168c0dSopenharmony_ci+ * Copyright 2020 Huawei Technologies Co., Ltd
9098be168c0dSopenharmony_ci+ *
9099be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
9100be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
9101be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
9102be168c0dSopenharmony_ci+ *
9103be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
9104be168c0dSopenharmony_ci+ *
9105be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
9106be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
9107be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9108be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
9109be168c0dSopenharmony_ci+ * limitations under the License.
9110be168c0dSopenharmony_ci+ */
9111be168c0dSopenharmony_ci+
9112be168c0dSopenharmony_ci+#include "tools/benchmark_train/run_net_train.h"
9113be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train.h"
9114be168c0dSopenharmony_ci+#include "tools/benchmark_train/net_train_c_api.h"
9115be168c0dSopenharmony_ci+
9116be168c0dSopenharmony_ci+namespace mindspore {
9117be168c0dSopenharmony_ci+namespace lite {
9118be168c0dSopenharmony_ci+int RunNetTrain(int argc, const char **argv) {
9119be168c0dSopenharmony_ci+  NetTrainFlags flags;
9120be168c0dSopenharmony_ci+  Option<std::string> err = flags.ParseFlags(argc, argv);
9121be168c0dSopenharmony_ci+
9122be168c0dSopenharmony_ci+  if (err.IsSome()) {
9123be168c0dSopenharmony_ci+    std::cerr << err.Get() << std::endl;
9124be168c0dSopenharmony_ci+    std::cerr << flags.Usage() << std::endl;
9125be168c0dSopenharmony_ci+    return RET_ERROR;
9126be168c0dSopenharmony_ci+  }
9127be168c0dSopenharmony_ci+
9128be168c0dSopenharmony_ci+  if (flags.help) {
9129be168c0dSopenharmony_ci+    std::cerr << flags.Usage() << std::endl;
9130be168c0dSopenharmony_ci+    return RET_OK;
9131be168c0dSopenharmony_ci+  }
9132be168c0dSopenharmony_ci+  if (flags.unified_api_) {
9133be168c0dSopenharmony_ci+    return NetTrain::RunNr(&flags);
9134be168c0dSopenharmony_ci+  }
9135be168c0dSopenharmony_ci+
9136be168c0dSopenharmony_ci+  auto api_type = std::getenv("MSLITE_API_TYPE");
9137be168c0dSopenharmony_ci+  if (api_type != nullptr) {
9138be168c0dSopenharmony_ci+    MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type;
9139be168c0dSopenharmony_ci+    std::cout << "MSLITE_API_TYPE = " << api_type << std::endl;
9140be168c0dSopenharmony_ci+  }
9141be168c0dSopenharmony_ci+
9142be168c0dSopenharmony_ci+  NetTrainBase *net_trainer = nullptr;
9143be168c0dSopenharmony_ci+  if (api_type == nullptr || std::string(api_type) == "NEW") {
9144be168c0dSopenharmony_ci+    net_trainer = new (std::nothrow) NetTrain(&flags);
9145be168c0dSopenharmony_ci+  } else if (std::string(api_type) == "C") {
9146be168c0dSopenharmony_ci+    net_trainer = new (std::nothrow) NetTrainCApi(&flags);
9147be168c0dSopenharmony_ci+  } else {
9148be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Invalid MSLITE_API_TYPE, (NEW/C, default:NEW)";
9149be168c0dSopenharmony_ci+    return RET_ERROR;
9150be168c0dSopenharmony_ci+  }
9151be168c0dSopenharmony_ci+
9152be168c0dSopenharmony_ci+  if (net_trainer == nullptr) {
9153be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "new net_trainer failed.";
9154be168c0dSopenharmony_ci+    return RET_ERROR;
9155be168c0dSopenharmony_ci+  }
9156be168c0dSopenharmony_ci+  auto status = net_trainer->Init();
9157be168c0dSopenharmony_ci+  if (status != RET_OK) {
9158be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "NetTrain init Error : " << status;
9159be168c0dSopenharmony_ci+    std::cerr << "NetTrain init Error : " << status << std::endl;
9160be168c0dSopenharmony_ci+    return RET_ERROR;
9161be168c0dSopenharmony_ci+  }
9162be168c0dSopenharmony_ci+
9163be168c0dSopenharmony_ci+  status = net_trainer->RunNetTrain();
9164be168c0dSopenharmony_ci+  if (status != RET_OK) {
9165be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Run NetTrain "
9166be168c0dSopenharmony_ci+                  << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9167be168c0dSopenharmony_ci+                  << " Failed : " << status;
9168be168c0dSopenharmony_ci+    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9169be168c0dSopenharmony_ci+              << " Failed : " << status << std::endl;
9170be168c0dSopenharmony_ci+    return RET_ERROR;
9171be168c0dSopenharmony_ci+  }
9172be168c0dSopenharmony_ci+
9173be168c0dSopenharmony_ci+  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9174be168c0dSopenharmony_ci+               << " Success.";
9175be168c0dSopenharmony_ci+  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9176be168c0dSopenharmony_ci+            << " Success." << std::endl;
9177be168c0dSopenharmony_ci+  delete net_trainer;
9178be168c0dSopenharmony_ci+  return RET_OK;
9179be168c0dSopenharmony_ci+}
9180be168c0dSopenharmony_ci+}  // namespace lite
9181be168c0dSopenharmony_ci+}  // namespace mindspore
9182be168c0dSopenharmony_ci\ No newline at end of file
9183be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/benchmark_train/run_net_train.h b/mindspore/lite/tools/benchmark_train/run_net_train.h
9184be168c0dSopenharmony_cinew file mode 100644
9185be168c0dSopenharmony_ciindex 00000000..9ca2d73c
9186be168c0dSopenharmony_ci--- /dev/null
9187be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/benchmark_train/run_net_train.h
9188be168c0dSopenharmony_ci@@ -0,0 +1,22 @@
9189be168c0dSopenharmony_ci+/**
9190be168c0dSopenharmony_ci+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
9191be168c0dSopenharmony_ci+ *
9192be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
9193be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
9194be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
9195be168c0dSopenharmony_ci+ *
9196be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
9197be168c0dSopenharmony_ci+ *
9198be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
9199be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
9200be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9201be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
9202be168c0dSopenharmony_ci+ * limitations under the License.
9203be168c0dSopenharmony_ci+ */
9204be168c0dSopenharmony_ci+
9205be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9206be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9207be168c0dSopenharmony_ci+namespace mindspore::lite {
9208be168c0dSopenharmony_ci+int RunNetTrain(int argc, const char **argv);
9209be168c0dSopenharmony_ci+}  // namespace mindspore::lite
9210be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9211be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
9212be168c0dSopenharmony_ciindex 1e09d2ed..f854620f 100644
9213be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/CMakeLists.txt
9214be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/CMakeLists.txt
9215be168c0dSopenharmony_ci@@ -7,6 +7,8 @@ endif()
9216be168c0dSopenharmony_ci 
9217be168c0dSopenharmony_ci set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
9218be168c0dSopenharmony_ci 
9219be168c0dSopenharmony_ci+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/)
9220be168c0dSopenharmony_ci+
9221be168c0dSopenharmony_ci if(ENABLE_GPU)
9222be168c0dSopenharmony_ci     add_compile_definitions(ENABLE_GPU)
9223be168c0dSopenharmony_ci endif()
9224be168c0dSopenharmony_ci@@ -70,6 +72,7 @@ add_subdirectory(parser/caffe)
9225be168c0dSopenharmony_ci add_subdirectory(parser/tflite)
9226be168c0dSopenharmony_ci add_subdirectory(parser/onnx)
9227be168c0dSopenharmony_ci add_subdirectory(parser/tf)
9228be168c0dSopenharmony_ci+add_subdirectory(parser/third_party)
9229be168c0dSopenharmony_ci if(ENABLE_CONVERT_PYTORCH_MODEL)
9230be168c0dSopenharmony_ci     add_subdirectory(parser/pytorch)
9231be168c0dSopenharmony_ci endif()
9232be168c0dSopenharmony_ci@@ -363,6 +366,7 @@ target_link_libraries(mindspore_converter
9233be168c0dSopenharmony_ci         tf_parser_mid
9234be168c0dSopenharmony_ci         caffe_parser_mid
9235be168c0dSopenharmony_ci         onnx_parser_mid
9236be168c0dSopenharmony_ci+        third_party_parser_mid
9237be168c0dSopenharmony_ci         lite_exporter_mid
9238be168c0dSopenharmony_ci         graph_pass_mid
9239be168c0dSopenharmony_ci         fusion_mid
9240be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9241be168c0dSopenharmony_ciindex fecc56d9..2e7ca749 100644
9242be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9243be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9244be168c0dSopenharmony_ci@@ -34,6 +34,7 @@ constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
9245be168c0dSopenharmony_ci constexpr auto kDataPreprocessParam = "data_preprocess_param";
9246be168c0dSopenharmony_ci constexpr auto kRegistry = "registry";
9247be168c0dSopenharmony_ci constexpr auto kMicroParam = "micro_param";
9248be168c0dSopenharmony_ci+constexpr auto kThirdPartyModelParam = "third_party_model";
9249be168c0dSopenharmony_ci constexpr auto kCpuOptionParam = "cpu_option_cfg_param";
9250be168c0dSopenharmony_ci constexpr auto kCustomOppPath = "custom_opp_path";
9251be168c0dSopenharmony_ci constexpr auto kTransformQuantParam = "transform_quant_param";
9252be168c0dSopenharmony_ci@@ -330,6 +331,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
9253be168c0dSopenharmony_ci     MS_LOG(ERROR) << "ParseMicroParamString failed.";
9254be168c0dSopenharmony_ci     return ret;
9255be168c0dSopenharmony_ci   }
9256be168c0dSopenharmony_ci+  ret = ParseThirdPartyParamString(*maps);
9257be168c0dSopenharmony_ci+  (void)maps->erase(kThirdPartyModelParam);
9258be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9259be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "ParseTransformQuantString failed.";
9260be168c0dSopenharmony_ci+    return ret;
9261be168c0dSopenharmony_ci+  }
9262be168c0dSopenharmony_ci   ret = ParseWeightQuantString(*maps);
9263be168c0dSopenharmony_ci   (void)maps->erase(kWeightQuantParam);
9264be168c0dSopenharmony_ci   if (ret != RET_OK) {
9265be168c0dSopenharmony_ci@@ -594,5 +601,25 @@ int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::ma
9266be168c0dSopenharmony_ci   }
9267be168c0dSopenharmony_ci   return RET_OK;
9268be168c0dSopenharmony_ci }
9269be168c0dSopenharmony_ci+
9270be168c0dSopenharmony_ci+int ConfigFileParser::ParseThirdPartyParamString(
9271be168c0dSopenharmony_ci+  const std::map<std::string, std::map<std::string, std::string>> &sections) {
9272be168c0dSopenharmony_ci+  if (sections.find(kThirdPartyModelParam) == sections.end()) {
9273be168c0dSopenharmony_ci+    return RET_OK;
9274be168c0dSopenharmony_ci+  }
9275be168c0dSopenharmony_ci+  const auto &input_args = sections.at(kThirdPartyModelParam);
9276be168c0dSopenharmony_ci+  const std::map<std::string, std::string &> kValidArgs = {
9277be168c0dSopenharmony_ci+    {"input_shapes", third_party_model_string_.input_shapes},
9278be168c0dSopenharmony_ci+    {"input_dtypes", third_party_model_string_.input_dtypes},
9279be168c0dSopenharmony_ci+    {"input_names", third_party_model_string_.input_names},
9280be168c0dSopenharmony_ci+    {"input_formats", third_party_model_string_.input_formats},
9281be168c0dSopenharmony_ci+    {"output_shapes", third_party_model_string_.output_shapes},
9282be168c0dSopenharmony_ci+    {"output_dtypes", third_party_model_string_.output_dtypes},
9283be168c0dSopenharmony_ci+    {"output_names", third_party_model_string_.output_names},
9284be168c0dSopenharmony_ci+    {"output_formats", third_party_model_string_.output_formats},
9285be168c0dSopenharmony_ci+    {"extended_parameters", third_party_model_string_.extended_parameters},
9286be168c0dSopenharmony_ci+  };
9287be168c0dSopenharmony_ci+  return SetMapData(input_args, kValidArgs, kThirdPartyModelParam);
9288be168c0dSopenharmony_ci+}
9289be168c0dSopenharmony_ci }  // namespace lite
9290be168c0dSopenharmony_ci }  // namespace mindspore
9291be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9292be168c0dSopenharmony_ciindex 31269816..6997bac8 100644
9293be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9294be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9295be168c0dSopenharmony_ci@@ -110,6 +110,18 @@ struct MicroParamString {
9296be168c0dSopenharmony_ci   std::string changeable_weights_name;
9297be168c0dSopenharmony_ci };
9298be168c0dSopenharmony_ci 
9299be168c0dSopenharmony_ci+struct ThirdPartyModelString {
9300be168c0dSopenharmony_ci+  std::string input_dtypes;
9301be168c0dSopenharmony_ci+  std::string input_shapes;
9302be168c0dSopenharmony_ci+  std::string input_names;  // optional, default: ""
9303be168c0dSopenharmony_ci+  std::string input_formats;  // optional, default: NHWC
9304be168c0dSopenharmony_ci+  std::string output_dtypes;
9305be168c0dSopenharmony_ci+  std::string output_shapes;
9306be168c0dSopenharmony_ci+  std::string output_names;  // optional, default: ""
9307be168c0dSopenharmony_ci+  std::string output_formats;  // optional, default: NHWC
9308be168c0dSopenharmony_ci+  std::string extended_parameters;  // format: {key1:value1;ker2:value2}
9309be168c0dSopenharmony_ci+};
9310be168c0dSopenharmony_ci+
9311be168c0dSopenharmony_ci struct CpuOptionCfgString {
9312be168c0dSopenharmony_ci   std::string architecture;
9313be168c0dSopenharmony_ci   std::string instruction;
9314be168c0dSopenharmony_ci@@ -144,6 +156,7 @@ class ConfigFileParser {
9315be168c0dSopenharmony_ci   RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; }
9316be168c0dSopenharmony_ci   AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
9317be168c0dSopenharmony_ci   MicroParamString GetMicroParamString() { return this->micro_param_string_; }
9318be168c0dSopenharmony_ci+  lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; }
9319be168c0dSopenharmony_ci   CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; }
9320be168c0dSopenharmony_ci   TransformQuantString GetTransformQuantString() const { return this->transform_quant_string_; }
9321be168c0dSopenharmony_ci   AscendQuantString GetAscendQuantString() const { return this->ascend_quant_string_; }
9322be168c0dSopenharmony_ci@@ -161,6 +174,7 @@ class ConfigFileParser {
9323be168c0dSopenharmony_ci   int SetMapData(const std::map<std::string, std::string> &input_map,
9324be168c0dSopenharmony_ci                  const std::map<std::string, std::string &> &parse_map, const std::string &section);
9325be168c0dSopenharmony_ci   int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9326be168c0dSopenharmony_ci+  int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> &sections);
9327be168c0dSopenharmony_ci   int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9328be168c0dSopenharmony_ci   int ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9329be168c0dSopenharmony_ci   int ParseAscendQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9330be168c0dSopenharmony_ci@@ -176,6 +190,7 @@ class ConfigFileParser {
9331be168c0dSopenharmony_ci   RegistryInfoString registry_info_string_;
9332be168c0dSopenharmony_ci   AclOptionCfgString acl_option_cfg_string_;
9333be168c0dSopenharmony_ci   MicroParamString micro_param_string_;
9334be168c0dSopenharmony_ci+  lite::ThirdPartyModelString third_party_model_string_;
9335be168c0dSopenharmony_ci   CpuOptionCfgString cpu_option_cfg_string_;
9336be168c0dSopenharmony_ci   TransformQuantString transform_quant_string_;
9337be168c0dSopenharmony_ci   AscendQuantString ascend_quant_string_;
9338be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
9339be168c0dSopenharmony_cinew file mode 100644
9340be168c0dSopenharmony_ciindex 00000000..aee6a29c
9341be168c0dSopenharmony_ci--- /dev/null
9342be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
9343be168c0dSopenharmony_ci@@ -0,0 +1,299 @@
9344be168c0dSopenharmony_ci+/**
9345be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
9346be168c0dSopenharmony_ci+ *
9347be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
9348be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
9349be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
9350be168c0dSopenharmony_ci+ *
9351be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
9352be168c0dSopenharmony_ci+ *
9353be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
9354be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
9355be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9356be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
9357be168c0dSopenharmony_ci+ * limitations under the License.
9358be168c0dSopenharmony_ci+ */
9359be168c0dSopenharmony_ci+
9360be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h"
9361be168c0dSopenharmony_ci+#include <vector>
9362be168c0dSopenharmony_ci+#include <string>
9363be168c0dSopenharmony_ci+#include <map>
9364be168c0dSopenharmony_ci+#include "include/errorcode.h"
9365be168c0dSopenharmony_ci+#include "src/common/log_adapter.h"
9366be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
9367be168c0dSopenharmony_ci+#include "tools/common/string_util.h"
9368be168c0dSopenharmony_ci+
9369be168c0dSopenharmony_ci+namespace mindspore {
9370be168c0dSopenharmony_ci+namespace lite {
9371be168c0dSopenharmony_ci+namespace {
9372be168c0dSopenharmony_ci+const std::map<std::string, TypeId> kDataTypeMap = {
9373be168c0dSopenharmony_ci+  {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32},
9374be168c0dSopenharmony_ci+  {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64},
9375be168c0dSopenharmony_ci+  {"int32", TypeId::kNumberTypeInt32},     {"int16", TypeId::kNumberTypeInt16},
9376be168c0dSopenharmony_ci+  {"int8", TypeId::kNumberTypeInt8},       {"uint8", TypeId::kNumberTypeUInt8},
9377be168c0dSopenharmony_ci+  {"bool", TypeId::kNumberTypeBool},
9378be168c0dSopenharmony_ci+};
9379be168c0dSopenharmony_ci+
9380be168c0dSopenharmony_ci+TypeId ConvertDataType(const std::string &type) {
9381be168c0dSopenharmony_ci+  auto iter = kDataTypeMap.find(type);
9382be168c0dSopenharmony_ci+  if (iter == kDataTypeMap.end()) {
9383be168c0dSopenharmony_ci+    return TypeId::kTypeUnknown;
9384be168c0dSopenharmony_ci+  }
9385be168c0dSopenharmony_ci+  return iter->second;
9386be168c0dSopenharmony_ci+}
9387be168c0dSopenharmony_ci+}  // namespace
9388be168c0dSopenharmony_ci+
9389be168c0dSopenharmony_ci+/**
9390be168c0dSopenharmony_ci+ * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]].
9391be168c0dSopenharmony_ci+ */
9392be168c0dSopenharmony_ci+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
9393be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
9394be168c0dSopenharmony_ci+  dst_shapes->clear();
9395be168c0dSopenharmony_ci+
9396be168c0dSopenharmony_ci+  auto tmp_shapes = SplitStringToVector(src, ";");
9397be168c0dSopenharmony_ci+  for (auto tmp_shape : tmp_shapes) {
9398be168c0dSopenharmony_ci+    auto tmp = SplitStringToVector(tmp_shape, ",");
9399be168c0dSopenharmony_ci+    std::vector<int64_t> shape = {};
9400be168c0dSopenharmony_ci+    for (auto t : tmp) {
9401be168c0dSopenharmony_ci+      int value = 0;
9402be168c0dSopenharmony_ci+      if (!ConvertIntNum(t, &value)) {
9403be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Found error when convert shape string to integer";
9404be168c0dSopenharmony_ci+        return RET_ERROR;
9405be168c0dSopenharmony_ci+      }
9406be168c0dSopenharmony_ci+      if (value <= 0) {  // Valid shape value should be greater than 0.
9407be168c0dSopenharmony_ci+        MS_LOG(ERROR) << "Only support fixed shapes in third party param";
9408be168c0dSopenharmony_ci+        return RET_ERROR;
9409be168c0dSopenharmony_ci+      }
9410be168c0dSopenharmony_ci+      shape.push_back(value);
9411be168c0dSopenharmony_ci+    }
9412be168c0dSopenharmony_ci+    dst_shapes->push_back(shape);
9413be168c0dSopenharmony_ci+  }
9414be168c0dSopenharmony_ci+  return RET_OK;
9415be168c0dSopenharmony_ci+}
9416be168c0dSopenharmony_ci+
9417be168c0dSopenharmony_ci+/**
9418be168c0dSopenharmony_ci+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
9419be168c0dSopenharmony_ci+ */
9420be168c0dSopenharmony_ci+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
9421be168c0dSopenharmony_ci+                                                     std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
9422be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
9423be168c0dSopenharmony_ci+  constexpr size_t kKeyIndex = 0U;
9424be168c0dSopenharmony_ci+  constexpr size_t kValueIndex = 1U;
9425be168c0dSopenharmony_ci+  constexpr size_t kKeyValueSize = 2U;
9426be168c0dSopenharmony_ci+
9427be168c0dSopenharmony_ci+  if (src == "") {  // Just return if 'extended_parameters' is configured.
9428be168c0dSopenharmony_ci+    return RET_OK;
9429be168c0dSopenharmony_ci+  }
9430be168c0dSopenharmony_ci+
9431be168c0dSopenharmony_ci+  auto tmp_list = SplitStringToVector(src, ";");
9432be168c0dSopenharmony_ci+  std::map<std::string, std::vector<uint8_t>> tmp_map = {};
9433be168c0dSopenharmony_ci+  for (auto tmp : tmp_list) {
9434be168c0dSopenharmony_ci+    auto key_and_value = SplitStringToVector(tmp, ":");
9435be168c0dSopenharmony_ci+    if (key_and_value.size() != kKeyValueSize) {
9436be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
9437be168c0dSopenharmony_ci+      return RET_ERROR;
9438be168c0dSopenharmony_ci+    }
9439be168c0dSopenharmony_ci+    auto key = key_and_value[kKeyIndex];
9440be168c0dSopenharmony_ci+    auto value = key_and_value[kValueIndex];
9441be168c0dSopenharmony_ci+    if (tmp_map.find(key) != tmp_map.end()) {
9442be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
9443be168c0dSopenharmony_ci+      return RET_ERROR;
9444be168c0dSopenharmony_ci+    }
9445be168c0dSopenharmony_ci+    tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));
9446be168c0dSopenharmony_ci+  }
9447be168c0dSopenharmony_ci+
9448be168c0dSopenharmony_ci+  *dst_ext_param = tmp_map;
9449be168c0dSopenharmony_ci+  return RET_OK;
9450be168c0dSopenharmony_ci+}
9451be168c0dSopenharmony_ci+
9452be168c0dSopenharmony_ci+/**
9453be168c0dSopenharmony_ci+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
9454be168c0dSopenharmony_ci+ */
9455be168c0dSopenharmony_ci+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
9456be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
9457be168c0dSopenharmony_ci+  dst_dtypes->clear();
9458be168c0dSopenharmony_ci+  auto tmp_dtypes = SplitStringToVector(src, ";");
9459be168c0dSopenharmony_ci+  for (auto tmp_dtype : tmp_dtypes) {
9460be168c0dSopenharmony_ci+    TypeId type = ConvertDataType(tmp_dtype);
9461be168c0dSopenharmony_ci+    if (type == kTypeUnknown) {
9462be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
9463be168c0dSopenharmony_ci+      return RET_ERROR;
9464be168c0dSopenharmony_ci+    }
9465be168c0dSopenharmony_ci+    dst_dtypes->push_back(type);
9466be168c0dSopenharmony_ci+  }
9467be168c0dSopenharmony_ci+  return RET_OK;
9468be168c0dSopenharmony_ci+}
9469be168c0dSopenharmony_ci+
9470be168c0dSopenharmony_ci+/**
9471be168c0dSopenharmony_ci+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
9472be168c0dSopenharmony_ci+ * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n"
9473be168c0dSopenharmony_ci+ */
9474be168c0dSopenharmony_ci+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
9475be168c0dSopenharmony_ci+                                        std::vector<std::string> *dst_names) {
9476be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
9477be168c0dSopenharmony_ci+  std::string tmp_names = src;
9478be168c0dSopenharmony_ci+  if (tmp_names.empty()) {
9479be168c0dSopenharmony_ci+    std::string tmp = "";
9480be168c0dSopenharmony_ci+    for (size_t i = 0; i < num; i++) {
9481be168c0dSopenharmony_ci+      tmp += default_prefix + "_" + std::to_string(i);
9482be168c0dSopenharmony_ci+      if (i + 1 < num) {
9483be168c0dSopenharmony_ci+        tmp += ";";
9484be168c0dSopenharmony_ci+      }
9485be168c0dSopenharmony_ci+    }
9486be168c0dSopenharmony_ci+    tmp_names = tmp;
9487be168c0dSopenharmony_ci+  }
9488be168c0dSopenharmony_ci+
9489be168c0dSopenharmony_ci+  *dst_names = SplitStringToVector(tmp_names, ";");
9490be168c0dSopenharmony_ci+  if (dst_names->size() != num) {
9491be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
9492be168c0dSopenharmony_ci+    return RET_ERROR;
9493be168c0dSopenharmony_ci+  }
9494be168c0dSopenharmony_ci+  return RET_OK;
9495be168c0dSopenharmony_ci+}
9496be168c0dSopenharmony_ci+
9497be168c0dSopenharmony_ci+/**
9498be168c0dSopenharmony_ci+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC]
9499be168c0dSopenharmony_ci+ */
9500be168c0dSopenharmony_ci+namespace {
9501be168c0dSopenharmony_ci+  int StringToFormat(const std::string &format_string, schema::Format *format) {
9502be168c0dSopenharmony_ci+    static const std::unordered_map<std::string, schema::Format> kFormatTable = {
9503be168c0dSopenharmony_ci+      {"NCHW", schema::Format::Format_NCHW},
9504be168c0dSopenharmony_ci+      {"NHWC", schema::Format::Format_NHWC},
9505be168c0dSopenharmony_ci+      {"NHWC4", schema::Format::Format_NHWC4},
9506be168c0dSopenharmony_ci+      {"HWKC", schema::Format::Format_HWKC},
9507be168c0dSopenharmony_ci+      {"HWCK", schema::Format::Format_HWCK},
9508be168c0dSopenharmony_ci+      {"KCHW", schema::Format::Format_KCHW},
9509be168c0dSopenharmony_ci+      {"CKHW", schema::Format::Format_CKHW},
9510be168c0dSopenharmony_ci+      {"KHWC", schema::Format::Format_KHWC},
9511be168c0dSopenharmony_ci+      {"CHWK", schema::Format::Format_CHWK},
9512be168c0dSopenharmony_ci+      {"HW", schema::Format::Format_HW},
9513be168c0dSopenharmony_ci+      {"HW4", schema::Format::Format_HW4},
9514be168c0dSopenharmony_ci+      {"NC", schema::Format::Format_NC},
9515be168c0dSopenharmony_ci+      {"NC4", schema::Format::Format_NC4},
9516be168c0dSopenharmony_ci+      {"NC4HW4", schema::Format::Format_NC4HW4},
9517be168c0dSopenharmony_ci+      {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},
9518be168c0dSopenharmony_ci+      {"NCDHW", schema::Format::Format_NCDHW},
9519be168c0dSopenharmony_ci+      {"NWC", schema::Format::Format_NWC},
9520be168c0dSopenharmony_ci+      {"NCW", schema::Format::Format_NCW},
9521be168c0dSopenharmony_ci+    };
9522be168c0dSopenharmony_ci+
9523be168c0dSopenharmony_ci+    if (format == nullptr) {
9524be168c0dSopenharmony_ci+      return RET_NULL_PTR;
9525be168c0dSopenharmony_ci+    }
9526be168c0dSopenharmony_ci+
9527be168c0dSopenharmony_ci+    auto iter = kFormatTable.find(format_string);
9528be168c0dSopenharmony_ci+    if (iter == kFormatTable.end()) {
9529be168c0dSopenharmony_ci+      return RET_PARAM_INVALID;
9530be168c0dSopenharmony_ci+    }
9531be168c0dSopenharmony_ci+
9532be168c0dSopenharmony_ci+    *format = iter->second;
9533be168c0dSopenharmony_ci+    return RET_OK;
9534be168c0dSopenharmony_ci+  }
9535be168c0dSopenharmony_ci+}
9536be168c0dSopenharmony_ci+
9537be168c0dSopenharmony_ci+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
9538be168c0dSopenharmony_ci+                                          std::vector<schema::Format> *result_formats) {
9539be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
9540be168c0dSopenharmony_ci+  std::string tmp_names = src;
9541be168c0dSopenharmony_ci+  if (tmp_names.empty()) {
9542be168c0dSopenharmony_ci+    std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC);
9543be168c0dSopenharmony_ci+    *result_formats = default_formats;
9544be168c0dSopenharmony_ci+    return RET_OK;
9545be168c0dSopenharmony_ci+  }
9546be168c0dSopenharmony_ci+
9547be168c0dSopenharmony_ci+  auto format_strings = SplitStringToVector(tmp_names, ";");
9548be168c0dSopenharmony_ci+  if (format_strings.size() != num) {
9549be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Number of format: " << format_strings.size() << " and number of tensor: " << num << " are not equal";
9550be168c0dSopenharmony_ci+    return RET_ERROR;
9551be168c0dSopenharmony_ci+  }
9552be168c0dSopenharmony_ci+
9553be168c0dSopenharmony_ci+  std::vector<schema::Format> result(num);
9554be168c0dSopenharmony_ci+  for (size_t i = 0; i < num; i++) {
9555be168c0dSopenharmony_ci+    if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
9556be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
9557be168c0dSopenharmony_ci+      return RET_PARAM_INVALID;
9558be168c0dSopenharmony_ci+    }
9559be168c0dSopenharmony_ci+  }
9560be168c0dSopenharmony_ci+  *result_formats = result;
9561be168c0dSopenharmony_ci+  return RET_OK;
9562be168c0dSopenharmony_ci+}
9563be168c0dSopenharmony_ci+
9564be168c0dSopenharmony_ci+int ThirdPartyParamParser::Parse(const ThirdPartyModelString &param_string, ThirdPartyModelParam *param) {
9565be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
9566be168c0dSopenharmony_ci+
9567be168c0dSopenharmony_ci+  auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
9568be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9569be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse input shapes of third party param failed";
9570be168c0dSopenharmony_ci+    return RET_ERROR;
9571be168c0dSopenharmony_ci+  }
9572be168c0dSopenharmony_ci+
9573be168c0dSopenharmony_ci+  ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
9574be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9575be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
9576be168c0dSopenharmony_ci+    return RET_ERROR;
9577be168c0dSopenharmony_ci+  }
9578be168c0dSopenharmony_ci+
9579be168c0dSopenharmony_ci+  auto input_shape_num = param->input_shapes.size();
9580be168c0dSopenharmony_ci+  auto input_dtype_num = param->input_dtypes.size();
9581be168c0dSopenharmony_ci+  if (input_shape_num != input_dtype_num) {
9582be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
9583be168c0dSopenharmony_ci+                  << " are not equal";
9584be168c0dSopenharmony_ci+    return RET_ERROR;
9585be168c0dSopenharmony_ci+  }
9586be168c0dSopenharmony_ci+
9587be168c0dSopenharmony_ci+  ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
9588be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9589be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse input formats of third party param failed";
9590be168c0dSopenharmony_ci+    return RET_ERROR;
9591be168c0dSopenharmony_ci+  }
9592be168c0dSopenharmony_ci+
9593be168c0dSopenharmony_ci+  const std::string kInputNamePrefix = "in";
9594be168c0dSopenharmony_ci+  ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
9595be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9596be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse input names of third party param failed";
9597be168c0dSopenharmony_ci+    return RET_ERROR;
9598be168c0dSopenharmony_ci+  }
9599be168c0dSopenharmony_ci+
9600be168c0dSopenharmony_ci+  ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
9601be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9602be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse output shaped of third party param failed";
9603be168c0dSopenharmony_ci+    return RET_ERROR;
9604be168c0dSopenharmony_ci+  }
9605be168c0dSopenharmony_ci+
9606be168c0dSopenharmony_ci+  ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
9607be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9608be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
9609be168c0dSopenharmony_ci+    return RET_ERROR;
9610be168c0dSopenharmony_ci+  }
9611be168c0dSopenharmony_ci+
9612be168c0dSopenharmony_ci+  auto output_shape_num = param->output_shapes.size();
9613be168c0dSopenharmony_ci+  auto output_dtype_num = param->output_dtypes.size();
9614be168c0dSopenharmony_ci+  if (output_shape_num != output_dtype_num) {
9615be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num
9616be168c0dSopenharmony_ci+                  << " are not equal";
9617be168c0dSopenharmony_ci+    return RET_ERROR;
9618be168c0dSopenharmony_ci+  }
9619be168c0dSopenharmony_ci+
9620be168c0dSopenharmony_ci+  ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
9621be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9622be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse output formats of third party param failed";
9623be168c0dSopenharmony_ci+    return RET_ERROR;
9624be168c0dSopenharmony_ci+  }
9625be168c0dSopenharmony_ci+
9626be168c0dSopenharmony_ci+  const std::string kOutputNamePrefix = "out";
9627be168c0dSopenharmony_ci+  ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
9628be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9629be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse output names of third party param failed";
9630be168c0dSopenharmony_ci+    return RET_ERROR;
9631be168c0dSopenharmony_ci+  }
9632be168c0dSopenharmony_ci+
9633be168c0dSopenharmony_ci+  ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
9634be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9635be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
9636be168c0dSopenharmony_ci+    return RET_ERROR;
9637be168c0dSopenharmony_ci+  }
9638be168c0dSopenharmony_ci+
9639be168c0dSopenharmony_ci+  return RET_OK;
9640be168c0dSopenharmony_ci+}
9641be168c0dSopenharmony_ci+}  // namespace lite
9642be168c0dSopenharmony_ci+}  // namespace mindspore
9643be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
9644be168c0dSopenharmony_cinew file mode 100644
9645be168c0dSopenharmony_ciindex 00000000..5cf6e8fb
9646be168c0dSopenharmony_ci--- /dev/null
9647be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
9648be168c0dSopenharmony_ci@@ -0,0 +1,44 @@
9649be168c0dSopenharmony_ci+/**
9650be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
9651be168c0dSopenharmony_ci+ *
9652be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
9653be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
9654be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
9655be168c0dSopenharmony_ci+ *
9656be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
9657be168c0dSopenharmony_ci+ *
9658be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
9659be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
9660be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9661be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
9662be168c0dSopenharmony_ci+ * limitations under the License.
9663be168c0dSopenharmony_ci+ */
9664be168c0dSopenharmony_ci+
9665be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9666be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9667be168c0dSopenharmony_ci+#include <string>
9668be168c0dSopenharmony_ci+#include <vector>
9669be168c0dSopenharmony_ci+#include <map>
9670be168c0dSopenharmony_ci+#include "include/errorcode.h"
9671be168c0dSopenharmony_ci+#include "tools/converter/cxx_api/converter_para.h"
9672be168c0dSopenharmony_ci+#include "tools/converter/config_parser/config_file_parser.h"
9673be168c0dSopenharmony_ci+
9674be168c0dSopenharmony_ci+namespace mindspore {
9675be168c0dSopenharmony_ci+namespace lite {
9676be168c0dSopenharmony_ci+class ThirdPartyParamParser {
9677be168c0dSopenharmony_ci+ public:
9678be168c0dSopenharmony_ci+  static int Parse(const lite::ThirdPartyModelString &param_string, ThirdPartyModelParam *param);
9679be168c0dSopenharmony_ci+
9680be168c0dSopenharmony_ci+ private:
9681be168c0dSopenharmony_ci+  static int DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes);
9682be168c0dSopenharmony_ci+  static int DoParseExtendedParameters(const std::string &src,
9683be168c0dSopenharmony_ci+                                       std::map<std::string, std::vector<uint8_t>> *dst_ext_param);
9684be168c0dSopenharmony_ci+  static int DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes);
9685be168c0dSopenharmony_ci+  static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
9686be168c0dSopenharmony_ci+                          std::vector<std::string> *dst_names);
9687be168c0dSopenharmony_ci+  static int DoParseFormats(const std::string &src, size_t num, std::vector<schema::Format> *result_formats);
9688be168c0dSopenharmony_ci+};
9689be168c0dSopenharmony_ci+}  // namespace lite
9690be168c0dSopenharmony_ci+}  // namespace mindspore
9691be168c0dSopenharmony_ci+
9692be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9693be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
9694be168c0dSopenharmony_ciindex df3176c2..a61bd51c 100644
9695be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter.cc
9696be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter.cc
9697be168c0dSopenharmony_ci@@ -49,6 +49,7 @@
9698be168c0dSopenharmony_ci #include "tools/converter/config_parser/preprocess_parser.h"
9699be168c0dSopenharmony_ci #include "tools/converter/config_parser/quant_param_parser.h"
9700be168c0dSopenharmony_ci #include "tools/converter/config_parser/graph_kernel_param_parser.h"
9701be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h"
9702be168c0dSopenharmony_ci #include "tools/converter/converter_funcgraph.h"
9703be168c0dSopenharmony_ci #include "tools/converter/converter_metagraph.h"
9704be168c0dSopenharmony_ci #include "tools/common/string_util.h"
9705be168c0dSopenharmony_ci@@ -472,6 +473,12 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std::
9706be168c0dSopenharmony_ci       MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
9707be168c0dSopenharmony_ci       return ret;
9708be168c0dSopenharmony_ci     }
9709be168c0dSopenharmony_ci+    ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(),
9710be168c0dSopenharmony_ci+                                             &param->thirdPartyModelParam);
9711be168c0dSopenharmony_ci+    if (ret != RET_OK) {
9712be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Parse third party param failed.";
9713be168c0dSopenharmony_ci+      return ret;
9714be168c0dSopenharmony_ci+    }
9715be168c0dSopenharmony_ci     ret = InitExtendedIntegrationInfo(param, *config_parser);
9716be168c0dSopenharmony_ci     if (ret != RET_OK) {
9717be168c0dSopenharmony_ci       MS_LOG(ERROR) << "Parse extended integration info failed.";
9718be168c0dSopenharmony_ci@@ -699,19 +706,20 @@ std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const s
9719be168c0dSopenharmony_ci 
9720be168c0dSopenharmony_ci int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
9721be168c0dSopenharmony_ci   if (param != nullptr) {
9722be168c0dSopenharmony_ci-    std::set valid_values = {FmkType::kFmkTypeTf,    FmkType::kFmkTypeCaffe,  FmkType::kFmkTypeOnnx,
9723be168c0dSopenharmony_ci-                             FmkType::kFmkTypeMs,    FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
9724be168c0dSopenharmony_ci-                             FmkType::kFmkTypeMsLite};
9725be168c0dSopenharmony_ci-    if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) {
9726be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
9727be168c0dSopenharmony_ci-                       "kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite|kFmkTypeMsLite"
9728be168c0dSopenharmony_ci-                    << ", but got " << param->fmk_type;
9729be168c0dSopenharmony_ci-      return RET_INPUT_PARAM_INVALID;
9730be168c0dSopenharmony_ci-    }
9731be168c0dSopenharmony_ci-    if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
9732be168c0dSopenharmony_ci-      MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag";
9733be168c0dSopenharmony_ci-      return RET_INPUT_PARAM_INVALID;
9734be168c0dSopenharmony_ci-    }
9735be168c0dSopenharmony_ci+    return RET_OK;
9736be168c0dSopenharmony_ci+  }
9737be168c0dSopenharmony_ci+  std::set kValidFmkTypes = {FmkType::kFmkTypeTf,    FmkType::kFmkTypeCaffe,  FmkType::kFmkTypeOnnx,
9738be168c0dSopenharmony_ci+                           FmkType::kFmkTypeMs,    FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
9739be168c0dSopenharmony_ci+                           FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty};
9740be168c0dSopenharmony_ci+  if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
9741be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
9742be168c0dSopenharmony_ci+                     "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY"
9743be168c0dSopenharmony_ci+                  << ", but got " << param->fmk_type;
9744be168c0dSopenharmony_ci+    return RET_INPUT_PARAM_INVALID;
9745be168c0dSopenharmony_ci+  }
9746be168c0dSopenharmony_ci+  if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
9747be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag";
9748be168c0dSopenharmony_ci+    return RET_INPUT_PARAM_INVALID;
9749be168c0dSopenharmony_ci   }
9750be168c0dSopenharmony_ci   return RET_OK;
9751be168c0dSopenharmony_ci }
9752be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter_funcgraph.cc b/mindspore/lite/tools/converter/converter_funcgraph.cc
9753be168c0dSopenharmony_ciindex f03f995c..61d5c463 100644
9754be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter_funcgraph.cc
9755be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter_funcgraph.cc
9756be168c0dSopenharmony_ci@@ -90,6 +90,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C
9757be168c0dSopenharmony_ci   converter_parameters.save_type = param->save_type;
9758be168c0dSopenharmony_ci   converter_parameters.model_file = param->model_file;
9759be168c0dSopenharmony_ci   converter_parameters.weight_file = param->weight_file;
9760be168c0dSopenharmony_ci+  converter_parameters.attrs.emplace("config_file", param->config_file);
9761be168c0dSopenharmony_ci   func_graph_base = model_parser->Parse(converter_parameters);
9762be168c0dSopenharmony_ci   if (func_graph_base == nullptr) {
9763be168c0dSopenharmony_ci     delete model_parser;
9764be168c0dSopenharmony_ci@@ -447,11 +448,13 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> &param,
9765be168c0dSopenharmony_ci     return status;
9766be168c0dSopenharmony_ci   }
9767be168c0dSopenharmony_ci 
9768be168c0dSopenharmony_ci-  AnfTransform funcgraph_transform;
9769be168c0dSopenharmony_ci-  status = funcgraph_transform.Transform(func_graph, param);
9770be168c0dSopenharmony_ci-  if (status != RET_OK) {
9771be168c0dSopenharmony_ci-    MS_LOG(ERROR) << "Transform anf graph failed.";
9772be168c0dSopenharmony_ci-    return status;
9773be168c0dSopenharmony_ci+  if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) {
9774be168c0dSopenharmony_ci+    AnfTransform funcgraph_transform;
9775be168c0dSopenharmony_ci+    status = funcgraph_transform.Transform(func_graph, param);
9776be168c0dSopenharmony_ci+    if (status != RET_OK) {
9777be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Transform anf graph failed.";
9778be168c0dSopenharmony_ci+      return status;
9779be168c0dSopenharmony_ci+    }
9780be168c0dSopenharmony_ci   }
9781be168c0dSopenharmony_ci 
9782be168c0dSopenharmony_ci   status = UnifyFuncGraphOutputFormat(param, func_graph);
9783be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9784be168c0dSopenharmony_ciindex 4883c48d..024e209f 100644
9785be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9786be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9787be168c0dSopenharmony_ci@@ -138,11 +138,11 @@ int Flags::InitFmk() {
9788be168c0dSopenharmony_ci   // value check not here, it is in converter c++ API's CheckValueParam method.
9789be168c0dSopenharmony_ci   std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
9790be168c0dSopenharmony_ci     {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs},       {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx},
9791be168c0dSopenharmony_ci-    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}};
9792be168c0dSopenharmony_ci+    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}, {"THIRDPARTY", kFmkTypeThirdParty}};
9793be168c0dSopenharmony_ci   if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) {
9794be168c0dSopenharmony_ci     this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn);
9795be168c0dSopenharmony_ci   } else {
9796be168c0dSopenharmony_ci-    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
9797be168c0dSopenharmony_ci+    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl;
9798be168c0dSopenharmony_ci     return RET_INPUT_PARAM_INVALID;
9799be168c0dSopenharmony_ci   }
9800be168c0dSopenharmony_ci 
9801be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h
9802be168c0dSopenharmony_ciindex a4f72a69..33210fd0 100644
9803be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h
9804be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h
9805be168c0dSopenharmony_ci@@ -21,6 +21,7 @@
9806be168c0dSopenharmony_ci #include <vector>
9807be168c0dSopenharmony_ci #include <set>
9808be168c0dSopenharmony_ci #include "include/converter.h"
9809be168c0dSopenharmony_ci+#include "mindapi/base/type_id.h"
9810be168c0dSopenharmony_ci #include "tools/converter/quantizer/quant_params.h"
9811be168c0dSopenharmony_ci #include "tools/converter/preprocess/preprocess_param.h"
9812be168c0dSopenharmony_ci #include "tools/converter/adapter/acl/common/acl_types.h"
9813be168c0dSopenharmony_ci@@ -35,6 +36,18 @@ struct ParallelSplitConfig {
9814be168c0dSopenharmony_ci   std::vector<std::string> parallel_devices_;
9815be168c0dSopenharmony_ci };
9816be168c0dSopenharmony_ci 
9817be168c0dSopenharmony_ci+struct ThirdPartyModelParam {
9818be168c0dSopenharmony_ci+  std::vector<TypeId> input_dtypes;
9819be168c0dSopenharmony_ci+  std::vector<std::vector<int64_t>> input_shapes;
9820be168c0dSopenharmony_ci+  std::vector<std::string> input_names;
9821be168c0dSopenharmony_ci+  std::vector<schema::Format> input_formats;
9822be168c0dSopenharmony_ci+  std::vector<TypeId> output_dtypes;
9823be168c0dSopenharmony_ci+  std::vector<std::vector<int64_t>> output_shapes;
9824be168c0dSopenharmony_ci+  std::vector<std::string> output_names;
9825be168c0dSopenharmony_ci+  std::vector<schema::Format> output_formats;
9826be168c0dSopenharmony_ci+  std::map<std::string, std::vector<uint8_t>> extended_parameters;
9827be168c0dSopenharmony_ci+};
9828be168c0dSopenharmony_ci+
9829be168c0dSopenharmony_ci struct CpuOptionCfg {
9830be168c0dSopenharmony_ci   std::string architecture;
9831be168c0dSopenharmony_ci   std::string instruction;
9832be168c0dSopenharmony_ci@@ -97,6 +110,7 @@ struct ConverterPara {
9833be168c0dSopenharmony_ci   lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
9834be168c0dSopenharmony_ci   lite::micro::MicroParam microParam;
9835be168c0dSopenharmony_ci   ParallelSplitConfig parallel_split_config;
9836be168c0dSopenharmony_ci+  ThirdPartyModelParam thirdPartyModelParam;
9837be168c0dSopenharmony_ci   AscendGeOptionCfg ascendGeOptionCfg;
9838be168c0dSopenharmony_ci   std::string device;
9839be168c0dSopenharmony_ci   std::string provider;
9840be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
9841be168c0dSopenharmony_ciindex 90b744e5..bf1a82ae 100644
9842be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/graphdef_transform.cc
9843be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
9844be168c0dSopenharmony_ci@@ -76,11 +76,55 @@ int QuantTransform(const std::shared_ptr<ConverterPara> &param, schema::MetaGrap
9845be168c0dSopenharmony_ci   }
9846be168c0dSopenharmony_ci   return RET_OK;
9847be168c0dSopenharmony_ci }
9848be168c0dSopenharmony_ci+
9849be168c0dSopenharmony_ci+int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> output_shapes) {
9850be168c0dSopenharmony_ci+  const auto &out_indices = meta_graph->outputIndex;
9851be168c0dSopenharmony_ci+  for (size_t i = 0; i < out_indices.size(); i++) {
9852be168c0dSopenharmony_ci+    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
9853be168c0dSopenharmony_ci+    out_tensor->dims = {};
9854be168c0dSopenharmony_ci+    for (size_t k = 0; k < output_shapes[i].size(); k++) {
9855be168c0dSopenharmony_ci+      out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k]));
9856be168c0dSopenharmony_ci+    }
9857be168c0dSopenharmony_ci+  }
9858be168c0dSopenharmony_ci+  return RET_OK;
9859be168c0dSopenharmony_ci+}
9860be168c0dSopenharmony_ci+
9861be168c0dSopenharmony_ci+void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara &para) {
9862be168c0dSopenharmony_ci+  const auto &in_indices = meta_graph->inputIndex;
9863be168c0dSopenharmony_ci+  for (size_t i = 0; i < in_indices.size(); i++) {
9864be168c0dSopenharmony_ci+    auto &in_tensor = meta_graph->allTensors[in_indices[i]];
9865be168c0dSopenharmony_ci+    in_tensor->format = para.thirdPartyModelParam.input_formats[i];
9866be168c0dSopenharmony_ci+    MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format);
9867be168c0dSopenharmony_ci+  }
9868be168c0dSopenharmony_ci+
9869be168c0dSopenharmony_ci+  const auto &out_indices = meta_graph->outputIndex;
9870be168c0dSopenharmony_ci+  for (size_t i = 0; i < out_indices.size(); i++) {
9871be168c0dSopenharmony_ci+    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
9872be168c0dSopenharmony_ci+    out_tensor->format = para.thirdPartyModelParam.output_formats[i];
9873be168c0dSopenharmony_ci+    MS_LOG_DEBUG << "output " << i << " format: " << EnumNameFormat(out_tensor->format);
9874be168c0dSopenharmony_ci+  }
9875be168c0dSopenharmony_ci+}
9876be168c0dSopenharmony_ci }  // namespace
9877be168c0dSopenharmony_ci 
9878be168c0dSopenharmony_ci int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> &param) {
9879be168c0dSopenharmony_ci   MS_ASSERT(param != nullptr);
9880be168c0dSopenharmony_ci   STATUS status;
9881be168c0dSopenharmony_ci+
9882be168c0dSopenharmony_ci+  if (param->fmk_type == converter::kFmkTypeThirdParty) {
9883be168c0dSopenharmony_ci+
9884be168c0dSopenharmony_ci+    // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function.
9885be168c0dSopenharmony_ci+    // So we don't perform legacy optimization for kFmkTypeThirdParty case.
9886be168c0dSopenharmony_ci+    auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes);
9887be168c0dSopenharmony_ci+    if (ret != RET_OK) {
9888be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret;
9889be168c0dSopenharmony_ci+      return ret;
9890be168c0dSopenharmony_ci+    }
9891be168c0dSopenharmony_ci+
9892be168c0dSopenharmony_ci+    // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph.
9893be168c0dSopenharmony_ci+    FillGraphInputAndOutputFormats(graph_defT_, *param);
9894be168c0dSopenharmony_ci+    return RET_OK;
9895be168c0dSopenharmony_ci+  }
9896be168c0dSopenharmony_ci+
9897be168c0dSopenharmony_ci   {
9898be168c0dSopenharmony_ci     auto old_nodes = GetGraphNodes(*graph_defT_);
9899be168c0dSopenharmony_ci     Optimizer unused_op_remove_optimizer;
9900be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
9901be168c0dSopenharmony_cinew file mode 100644
9902be168c0dSopenharmony_ciindex 00000000..b55e0194
9903be168c0dSopenharmony_ci--- /dev/null
9904be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
9905be168c0dSopenharmony_ci@@ -0,0 +1,4 @@
9906be168c0dSopenharmony_ci+add_library(third_party_parser_mid OBJECT third_party_model_parser.cc)
9907be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid proto_mid)
9908be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid fbs_src)
9909be168c0dSopenharmony_ci+add_dependencies(third_party_parser_mid fbs_inner_src)
9910be168c0dSopenharmony_ci\ No newline at end of file
9911be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
9912be168c0dSopenharmony_cinew file mode 100644
9913be168c0dSopenharmony_ciindex 00000000..652db4af
9914be168c0dSopenharmony_ci--- /dev/null
9915be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
9916be168c0dSopenharmony_ci@@ -0,0 +1,277 @@
9917be168c0dSopenharmony_ci+/**
9918be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
9919be168c0dSopenharmony_ci+ *
9920be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
9921be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
9922be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
9923be168c0dSopenharmony_ci+ *
9924be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
9925be168c0dSopenharmony_ci+ *
9926be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
9927be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
9928be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9929be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
9930be168c0dSopenharmony_ci+ * limitations under the License.
9931be168c0dSopenharmony_ci+ */
9932be168c0dSopenharmony_ci+#include "tools/converter/parser/third_party/third_party_model_parser.h"
9933be168c0dSopenharmony_ci+#include <string>
9934be168c0dSopenharmony_ci+#include <vector>
9935be168c0dSopenharmony_ci+#include <memory>
9936be168c0dSopenharmony_ci+#include "ir/value.h"
9937be168c0dSopenharmony_ci+#include "mindapi/base/type_id.h"
9938be168c0dSopenharmony_ci+#include "src/common/log_util.h"
9939be168c0dSopenharmony_ci+#include "src/common/file_utils.h"
9940be168c0dSopenharmony_ci+#include "nnacl/op_base.h"
9941be168c0dSopenharmony_ci+#include "ops/primitive_c.h"
9942be168c0dSopenharmony_ci+#include "ops/custom.h"
9943be168c0dSopenharmony_ci+#include "ops/tuple_get_item.h"
9944be168c0dSopenharmony_ci+#include "ops/make_tuple.h"
9945be168c0dSopenharmony_ci+#include "ops/return.h"
9946be168c0dSopenharmony_ci+#include "tools/converter/config_parser/config_file_parser.h"
9947be168c0dSopenharmony_ci+#include "include/registry/model_parser_registry.h"
9948be168c0dSopenharmony_ci+#include "tools/common/graph_util.h"
9949be168c0dSopenharmony_ci+#include "tools/common/tensor_util.h"
9950be168c0dSopenharmony_ci+#include "tools/converter/converter_context.h"
9951be168c0dSopenharmony_ci+#include "tools/converter/parser/lite_model_parser_creator.h"
9952be168c0dSopenharmony_ci+
9953be168c0dSopenharmony_ci+using mindspore::converter::kFmkTypeThirdParty;
9954be168c0dSopenharmony_ci+
9955be168c0dSopenharmony_ci+namespace mindspore {
9956be168c0dSopenharmony_ci+namespace lite {
9957be168c0dSopenharmony_ci+api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) {
9958be168c0dSopenharmony_ci+  model_file_ = flag.model_file;
9959be168c0dSopenharmony_ci+  auto &attrs = flag.attrs;
9960be168c0dSopenharmony_ci+  auto iter = attrs.find("config_file");
9961be168c0dSopenharmony_ci+  if (iter == attrs.end()) {
9962be168c0dSopenharmony_ci+    return nullptr;
9963be168c0dSopenharmony_ci+  }
9964be168c0dSopenharmony_ci+  auto config_file = iter->second;
9965be168c0dSopenharmony_ci+
9966be168c0dSopenharmony_ci+  auto ret = InitConfig(config_file);
9967be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9968be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Init config for third party model parsing failed";
9969be168c0dSopenharmony_ci+    return nullptr;
9970be168c0dSopenharmony_ci+  }
9971be168c0dSopenharmony_ci+
9972be168c0dSopenharmony_ci+  return CreateFuncGraph();
9973be168c0dSopenharmony_ci+}
9974be168c0dSopenharmony_ci+
9975be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) {
9976be168c0dSopenharmony_ci+  lite::ConfigFileParser config_parser;
9977be168c0dSopenharmony_ci+  if (config_file.empty()) {
9978be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Missing config file in converting third party model";
9979be168c0dSopenharmony_ci+    return RET_ERROR;
9980be168c0dSopenharmony_ci+  }
9981be168c0dSopenharmony_ci+  auto ret = config_parser.ParseConfigFile(config_file);
9982be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9983be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Get third party model section from config file failed";
9984be168c0dSopenharmony_ci+    return RET_ERROR;
9985be168c0dSopenharmony_ci+  }
9986be168c0dSopenharmony_ci+
9987be168c0dSopenharmony_ci+  ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), &param_);
9988be168c0dSopenharmony_ci+  if (ret != RET_OK) {
9989be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Parse third party model param failed.";
9990be168c0dSopenharmony_ci+    return ret;
9991be168c0dSopenharmony_ci+  }
9992be168c0dSopenharmony_ci+  return RET_OK;
9993be168c0dSopenharmony_ci+}
9994be168c0dSopenharmony_ci+
9995be168c0dSopenharmony_ci+api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() {
9996be168c0dSopenharmony_ci+  auto func_graph = std::make_shared<FuncGraph>();
9997be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
9998be168c0dSopenharmony_ci+  auto type_value = MakeValue(static_cast<int>(converter::kFmkTypeThirdParty));
9999be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(type_value != nullptr, nullptr);
10000be168c0dSopenharmony_ci+  func_graph->set_attr("fmk", type_value);
10001be168c0dSopenharmony_ci+  auto attr_value = MakeValue("third_party");
10002be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr);
10003be168c0dSopenharmony_ci+  func_graph->set_attr("graph_name", attr_value);
10004be168c0dSopenharmony_ci+
10005be168c0dSopenharmony_ci+  std::vector<AnfNodePtr> input_nodes = {};
10006be168c0dSopenharmony_ci+  auto ret = BuildGraphInputs(func_graph, &input_nodes);
10007be168c0dSopenharmony_ci+  if (ret != RET_OK) {
10008be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create func graph input nodes failed";
10009be168c0dSopenharmony_ci+    return nullptr;
10010be168c0dSopenharmony_ci+  }
10011be168c0dSopenharmony_ci+
10012be168c0dSopenharmony_ci+  CNodePtr custom_node = nullptr;
10013be168c0dSopenharmony_ci+  ret = BuildCustomOp(func_graph, input_nodes, &custom_node);
10014be168c0dSopenharmony_ci+  if (ret != RET_OK) {
10015be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create func graph custom op node failed";
10016be168c0dSopenharmony_ci+    return nullptr;
10017be168c0dSopenharmony_ci+  }
10018be168c0dSopenharmony_ci+
10019be168c0dSopenharmony_ci+  ret = BuildGraphOutputs(func_graph, custom_node);
10020be168c0dSopenharmony_ci+  if (ret != RET_OK) {
10021be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "Create func graph output nodes failed";
10022be168c0dSopenharmony_ci+    return nullptr;
10023be168c0dSopenharmony_ci+  }
10024be168c0dSopenharmony_ci+
10025be168c0dSopenharmony_ci+  static auto manager = Manage(func_graph);
10026be168c0dSopenharmony_ci+  func_graph->set_manager(manager);
10027be168c0dSopenharmony_ci+
10028be168c0dSopenharmony_ci+  auto result_graph = api::MakeShared<api::FuncGraph>(func_graph);
10029be168c0dSopenharmony_ci+  return result_graph;
10030be168c0dSopenharmony_ci+}
10031be168c0dSopenharmony_ci+
10032be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs) {
10033be168c0dSopenharmony_ci+  MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr);
10034be168c0dSopenharmony_ci+  auto &dtypes = param_.input_dtypes;
10035be168c0dSopenharmony_ci+  auto &shapes = param_.input_shapes;
10036be168c0dSopenharmony_ci+  auto &names = param_.input_names;
10037be168c0dSopenharmony_ci+
10038be168c0dSopenharmony_ci+  auto input_size = dtypes.size();
10039be168c0dSopenharmony_ci+
10040be168c0dSopenharmony_ci+  // Create parameter nodes for graph inputs
10041be168c0dSopenharmony_ci+  for (size_t i = 0; i < input_size; i++) {
10042be168c0dSopenharmony_ci+    auto parameter = func_graph->add_parameter();
10043be168c0dSopenharmony_ci+    MSLITE_CHECK_PTR(parameter);
10044be168c0dSopenharmony_ci+    auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]);
10045be168c0dSopenharmony_ci+    if (abstract_tensor == nullptr) {
10046be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Create tensor abstract failed";
10047be168c0dSopenharmony_ci+      return RET_ERROR;
10048be168c0dSopenharmony_ci+    }
10049be168c0dSopenharmony_ci+    parameter->set_abstract(abstract_tensor);
10050be168c0dSopenharmony_ci+    parameter->set_name(names[i]);
10051be168c0dSopenharmony_ci+    op_inputs->push_back(parameter);
10052be168c0dSopenharmony_ci+  }
10053be168c0dSopenharmony_ci+
10054be168c0dSopenharmony_ci+  // Create parameter nodes for const tensor which wrapped third model buffer.
10055be168c0dSopenharmony_ci+  size_t model_size = 0U;
10056be168c0dSopenharmony_ci+  auto model_data = ReadFile(model_file_.c_str(), &model_size);
10057be168c0dSopenharmony_ci+  std::vector<int64_t> model_shape = {static_cast<int64_t>(model_size)};
10058be168c0dSopenharmony_ci+  auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8);
10059be168c0dSopenharmony_ci+  if (tensor_info == nullptr) {
10060be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "init tensor info failed";
10061be168c0dSopenharmony_ci+    delete model_data;
10062be168c0dSopenharmony_ci+    return RET_NULL_PTR;
10063be168c0dSopenharmony_ci+  }
10064be168c0dSopenharmony_ci+  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
10065be168c0dSopenharmony_ci+  if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) {
10066be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "memcpy failed.";
10067be168c0dSopenharmony_ci+    delete model_data;
10068be168c0dSopenharmony_ci+    return RET_ERROR;
10069be168c0dSopenharmony_ci+  }
10070be168c0dSopenharmony_ci+  delete model_data;
10071be168c0dSopenharmony_ci+  auto parameter = func_graph->add_parameter();
10072be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(parameter);
10073be168c0dSopenharmony_ci+  auto status = InitParameterFromTensorInfo(parameter, tensor_info);
10074be168c0dSopenharmony_ci+  if (status != RET_OK) {
10075be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "init parameter from tensor info failed.";
10076be168c0dSopenharmony_ci+    return RET_ERROR;
10077be168c0dSopenharmony_ci+  }
10078be168c0dSopenharmony_ci+  parameter->set_name("ThirdPartyModel");
10079be168c0dSopenharmony_ci+  op_inputs->push_back(parameter);
10080be168c0dSopenharmony_ci+  return RET_OK;
10081be168c0dSopenharmony_ci+}
10082be168c0dSopenharmony_ci+
10083be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs,
10084be168c0dSopenharmony_ci+                                            CNodePtr *operator_node) {
10085be168c0dSopenharmony_ci+  MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr);
10086be168c0dSopenharmony_ci+  NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY");
10087be168c0dSopenharmony_ci+  STATUS status = RET_OK;
10088be168c0dSopenharmony_ci+
10089be168c0dSopenharmony_ci+  // create primitive and build CNode of CUSTOM operator
10090be168c0dSopenharmony_ci+  ops::PrimitiveCPtr primitive_c;
10091be168c0dSopenharmony_ci+  auto prim = std::make_unique<ops::Custom>();
10092be168c0dSopenharmony_ci+  MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR);
10093be168c0dSopenharmony_ci+  prim->set_type("ThirdPartyModel");
10094be168c0dSopenharmony_ci+
10095be168c0dSopenharmony_ci+  const auto &attr = param_.extended_parameters;
10096be168c0dSopenharmony_ci+  prim->set_attr(attr);
10097be168c0dSopenharmony_ci+  primitive_c = prim->GetPrim();
10098be168c0dSopenharmony_ci+  if (primitive_c == nullptr) {
10099be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "failed to create primitive: custom";
10100be168c0dSopenharmony_ci+    return RET_ERROR;
10101be168c0dSopenharmony_ci+  }
10102be168c0dSopenharmony_ci+
10103be168c0dSopenharmony_ci+  auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs);
10104be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(operator_cnode);
10105be168c0dSopenharmony_ci+  operator_cnode->set_fullname_with_scope("Custom");
10106be168c0dSopenharmony_ci+  *operator_node = operator_cnode;
10107be168c0dSopenharmony_ci+  return status;
10108be168c0dSopenharmony_ci+}
10109be168c0dSopenharmony_ci+
10110be168c0dSopenharmony_ci+STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) {
10111be168c0dSopenharmony_ci+  MS_ASSERT(anf_node_map != nullptr && func_graph != nullptr);
10112be168c0dSopenharmony_ci+
10113be168c0dSopenharmony_ci+  auto dtypes = param_.output_dtypes;
10114be168c0dSopenharmony_ci+  auto shapes = param_.output_shapes;
10115be168c0dSopenharmony_ci+  auto names = param_.output_names;
10116be168c0dSopenharmony_ci+
10117be168c0dSopenharmony_ci+  auto output_size = dtypes.size();
10118be168c0dSopenharmony_ci+  std::vector<AnfNodePtr> output_nodes = {};
10119be168c0dSopenharmony_ci+
10120be168c0dSopenharmony_ci+  // Use TupleGetItem to wrap op outputs.
10121be168c0dSopenharmony_ci+  AbstractBasePtrList abstract_list;
10122be168c0dSopenharmony_ci+  for (size_t i = 0; i < output_size; i++) {
10123be168c0dSopenharmony_ci+    auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]);
10124be168c0dSopenharmony_ci+    if (abstract_tensor == nullptr) {
10125be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Create tensor abstract failed";
10126be168c0dSopenharmony_ci+      return RET_ERROR;
10127be168c0dSopenharmony_ci+    }
10128be168c0dSopenharmony_ci+    abstract_list.emplace_back(abstract_tensor);
10129be168c0dSopenharmony_ci+    auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
10130be168c0dSopenharmony_ci+    if (tuple_get_item_prim_ptr == nullptr) {
10131be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "new TupleGetItem failed";
10132be168c0dSopenharmony_ci+      return RET_NULL_PTR;
10133be168c0dSopenharmony_ci+    }
10134be168c0dSopenharmony_ci+    auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim();
10135be168c0dSopenharmony_ci+    MSLITE_CHECK_PTR(tuple_get_item_prim_c);
10136be168c0dSopenharmony_ci+    auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c);
10137be168c0dSopenharmony_ci+    MSLITE_CHECK_PTR(tuple_get_item_prim);
10138be168c0dSopenharmony_ci+    auto get_item_value = NewValueNode(MakeValue<int>(i));
10139be168c0dSopenharmony_ci+    MSLITE_CHECK_PTR(get_item_value);
10140be168c0dSopenharmony_ci+    std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value};
10141be168c0dSopenharmony_ci+    CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
10142be168c0dSopenharmony_ci+    MSLITE_CHECK_PTR(get_item_cnode);
10143be168c0dSopenharmony_ci+    std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i);
10144be168c0dSopenharmony_ci+    auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
10145be168c0dSopenharmony_ci+    if (get_item_abstract == nullptr) {
10146be168c0dSopenharmony_ci+      MS_LOG(ERROR) << "Create tensor abstarct failed";
10147be168c0dSopenharmony_ci+      return RET_ERROR;
10148be168c0dSopenharmony_ci+    }
10149be168c0dSopenharmony_ci+    get_item_cnode->set_fullname_with_scope(output_item_name);
10150be168c0dSopenharmony_ci+    get_item_cnode->set_abstract(get_item_abstract);
10151be168c0dSopenharmony_ci+    output_nodes.push_back(get_item_cnode);
10152be168c0dSopenharmony_ci+  }
10153be168c0dSopenharmony_ci+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
10154be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(abstract_tuple);
10155be168c0dSopenharmony_ci+  operator_node->set_abstract(abstract_tuple);
10156be168c0dSopenharmony_ci+
10157be168c0dSopenharmony_ci+  // Use MakeTuple node to wrap all outputs as single input of Return node.
10158be168c0dSopenharmony_ci+  auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
10159be168c0dSopenharmony_ci+  if (make_tuple_prim_ptr == nullptr) {
10160be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "new MakeTuple failed";
10161be168c0dSopenharmony_ci+    return RET_NULL_PTR;
10162be168c0dSopenharmony_ci+  }
10163be168c0dSopenharmony_ci+  auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
10164be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(make_tuple_prim_c);
10165be168c0dSopenharmony_ci+  auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
10166be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(make_tuple_prim);
10167be168c0dSopenharmony_ci+  std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
10168be168c0dSopenharmony_ci+  make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
10169be168c0dSopenharmony_ci+  auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
10170be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(make_tuple_cnode);
10171be168c0dSopenharmony_ci+  make_tuple_cnode->set_fullname_with_scope("return_tuple");
10172be168c0dSopenharmony_ci+
10173be168c0dSopenharmony_ci+  auto return_prim_ptr = std::make_shared<ops::Return>();
10174be168c0dSopenharmony_ci+  if (return_prim_ptr == nullptr) {
10175be168c0dSopenharmony_ci+    MS_LOG(ERROR) << "new Return failed";
10176be168c0dSopenharmony_ci+    return RET_NULL_PTR;
10177be168c0dSopenharmony_ci+  }
10178be168c0dSopenharmony_ci+  auto return_prim_c = return_prim_ptr->GetPrim();
10179be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(return_prim_c);
10180be168c0dSopenharmony_ci+  std::vector<AnfNodePtr> op_inputs{make_tuple_cnode};
10181be168c0dSopenharmony_ci+  auto cnode = func_graph->NewCNode(return_prim_c, op_inputs);
10182be168c0dSopenharmony_ci+  MSLITE_CHECK_PTR(cnode);
10183be168c0dSopenharmony_ci+  cnode->set_fullname_with_scope("Return");
10184be168c0dSopenharmony_ci+  func_graph->set_return(cnode);
10185be168c0dSopenharmony_ci+
10186be168c0dSopenharmony_ci+  // Save original output tensor names.
10187be168c0dSopenharmony_ci+  ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names);
10188be168c0dSopenharmony_ci+  return RET_OK;
10189be168c0dSopenharmony_ci+}
10190be168c0dSopenharmony_ci+
10191be168c0dSopenharmony_ci+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>)
10192be168c0dSopenharmony_ci+}  // namespace lite
10193be168c0dSopenharmony_ci+}  // namespace mindspore
10194be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
10195be168c0dSopenharmony_cinew file mode 100644
10196be168c0dSopenharmony_ciindex 00000000..c4b197b8
10197be168c0dSopenharmony_ci--- /dev/null
10198be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
10199be168c0dSopenharmony_ci@@ -0,0 +1,50 @@
10200be168c0dSopenharmony_ci+/**
10201be168c0dSopenharmony_ci+ * Copyright 2023 Huawei Technologies Co., Ltd
10202be168c0dSopenharmony_ci+ *
10203be168c0dSopenharmony_ci+ * Licensed under the Apache License, Version 2.0 (the "License");
10204be168c0dSopenharmony_ci+ * you may not use this file except in compliance with the License.
10205be168c0dSopenharmony_ci+ * You may obtain a copy of the License at
10206be168c0dSopenharmony_ci+ *
10207be168c0dSopenharmony_ci+ * http://www.apache.org/licenses/LICENSE-2.0
10208be168c0dSopenharmony_ci+ *
10209be168c0dSopenharmony_ci+ * Unless required by applicable law or agreed to in writing, software
10210be168c0dSopenharmony_ci+ * distributed under the License is distributed on an "AS IS" BASIS,
10211be168c0dSopenharmony_ci+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10212be168c0dSopenharmony_ci+ * See the License for the specific language governing permissions and
10213be168c0dSopenharmony_ci+ * limitations under the License.
10214be168c0dSopenharmony_ci+ */
10215be168c0dSopenharmony_ci+
10216be168c0dSopenharmony_ci+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10217be168c0dSopenharmony_ci+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10218be168c0dSopenharmony_ci+
10219be168c0dSopenharmony_ci+#include <string>
10220be168c0dSopenharmony_ci+#include <vector>
10221be168c0dSopenharmony_ci+#include "schema/inner/model_generated.h"
10222be168c0dSopenharmony_ci+#include "base/base.h"
10223be168c0dSopenharmony_ci+#include "ir/anf.h"
10224be168c0dSopenharmony_ci+#include "ir/func_graph.h"
10225be168c0dSopenharmony_ci+#include "include/errorcode.h"
10226be168c0dSopenharmony_ci+#include "include/registry/model_parser.h"
10227be168c0dSopenharmony_ci+#include "tools/converter/config_parser/third_party_param_parser.h"
10228be168c0dSopenharmony_ci+
10229be168c0dSopenharmony_ci+namespace mindspore {
10230be168c0dSopenharmony_ci+namespace lite {
10231be168c0dSopenharmony_ci+class ThirdPartyModelParser : public converter::ModelParser {
10232be168c0dSopenharmony_ci+ public:
10233be168c0dSopenharmony_ci+  api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
10234be168c0dSopenharmony_ci+
10235be168c0dSopenharmony_ci+ private:
10236be168c0dSopenharmony_ci+  STATUS InitConfig(const std::string &config_file);
10237be168c0dSopenharmony_ci+  api::FuncGraphPtr CreateFuncGraph();
10238be168c0dSopenharmony_ci+  STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs);
10239be168c0dSopenharmony_ci+  STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs,
10240be168c0dSopenharmony_ci+                       CNodePtr *operator_node);
10241be168c0dSopenharmony_ci+  STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node);
10242be168c0dSopenharmony_ci+
10243be168c0dSopenharmony_ci+  std::string model_file_ = "";
10244be168c0dSopenharmony_ci+  ThirdPartyModelParam param_;
10245be168c0dSopenharmony_ci+};
10246be168c0dSopenharmony_ci+}  // namespace lite
10247be168c0dSopenharmony_ci+}  // namespace mindspore
10248be168c0dSopenharmony_ci+
10249be168c0dSopenharmony_ci+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10250be168c0dSopenharmony_cidiff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10251be168c0dSopenharmony_ciindex 832fb92d..6bc2d4d3 100644
10252be168c0dSopenharmony_ci--- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10253be168c0dSopenharmony_ci+++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10254be168c0dSopenharmony_ci@@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
10255be168c0dSopenharmony_ci }  // namespace
10256be168c0dSopenharmony_ci 
10257be168c0dSopenharmony_ci ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
10258be168c0dSopenharmony_ci-  if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
10259be168c0dSopenharmony_ci+  if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) {
10260be168c0dSopenharmony_ci     MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
10261be168c0dSopenharmony_ci     return;
10262be168c0dSopenharmony_ci   }
10263be168c0dSopenharmony_ci@@ -38,7 +38,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator
10264be168c0dSopenharmony_ci }
10265be168c0dSopenharmony_ci 
10266be168c0dSopenharmony_ci converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
10267be168c0dSopenharmony_ci-  if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
10268be168c0dSopenharmony_ci+  if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) {
10269be168c0dSopenharmony_ci     MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
10270be168c0dSopenharmony_ci     return nullptr;
10271be168c0dSopenharmony_ci   }
10272be168c0dSopenharmony_ci-- 
10273be168c0dSopenharmony_ci2.17.1
10274be168c0dSopenharmony_ci
10275