1From baf2daaebd70448cddd35f5011642fe585d071b5 Mon Sep 17 00:00:00 2001
2From: chengfeng27 <chengfeng27@huawei.com>
3Date: Tue, 5 Mar 2024 20:00:24 +0800
Subject: [PATCH] hilog: use macro definition API
5
6---
7 cmake/external_libs/flatbuffers.cmake         |   4 +-
8 include/api/context.h                         |  65 ++
 include/c_api/context_c.h                     | 104 +++
10 include/c_api/model_c.h                       | 178 ++++
11 include/c_api/tensor_c.h                      |  14 +
12 include/c_api/types_c.h                       |  57 +-
13 include/sdk_api/context.h                     | 103 +++
14 include/sdk_api/tensor.h                      |  13 +
15 include/sdk_api/types.h                       |  38 +-
16 .../plugin/device/cpu/kernel/nnacl/BUILD.gn   |   3 +
17 .../device/cpu/kernel/nnacl/CMakeLists.txt    |   2 +-
18 .../kernel/nnacl/avx/scatter_nd_binary_avx.h  |  66 ++
19 .../nnacl/avx512/scatter_nd_binary_avx512.h   |  66 ++
20 .../cpu/kernel/nnacl/base/scatter_nd_binary.c |  28 +
21 .../cpu/kernel/nnacl/base/scatter_nd_binary.h |   3 +
22 .../nnacl/base/scatter_nd_binary_simd.h.in    |  14 +
23 .../kernel/nnacl/custom_is_inf_parameter.h    |  26 +
24 .../nnacl/custom_masked_fill_parameter.h      |  26 +
25 .../custom_tensor_scatter_max_parameter.h     |  26 +
26 .../kernel/nnacl/infer/custom_is_inf_infer.c  |  38 +
27 .../kernel/nnacl/infer/custom_is_inf_infer.h  |  31 +
28 .../nnacl/infer/custom_masked_fill_infer.c    |  37 +
29 .../nnacl/infer/custom_masked_fill_infer.h    |  31 +
30 .../infer/custom_tensor_scatter_max_infer.c   |  37 +
31 .../infer/custom_tensor_scatter_max_infer.h   |  31 +
32 .../nnacl/neon/scatter_nd_binary_neon.h       |  65 ++
33 .../plugin/device/cpu/kernel/nnacl/op_base.h  |   4 +
34 .../cpu/kernel/nnacl/scatter_nd_binary_simd.h |  36 +
35 .../kernel/nnacl/sse/scatter_nd_binary_sse.h  |  66 ++
36 mindspore/core/mindrt/BUILD.gn                |   9 +-
37 .../mindrt/src/thread/actor_threadpool.cc     |   2 +-
38 .../core/mindrt/src/thread/core_affinity.cc   |   6 +-
39 .../core/mindrt/src/thread/core_affinity.h    |   2 +-
40 .../mindrt/src/thread/parallel_threadpool.cc  |   2 +-
41 mindspore/core/mindrt/src/thread/threadlog.h  |  28 +-
42 .../core/mindrt/src/thread/threadpool.cc      |   7 +-
43 mindspore/lite/BUILD.gn                       |  82 +-
44 mindspore/lite/CMakeLists.txt                 |   5 +-
45 mindspore/lite/include/lite_types.h           |   1 +
46 mindspore/lite/include/model.h                |   4 +
47 .../lite/include/registry/converter_context.h |   4 +-
48 mindspore/lite/mindir/include/mindir.h        |   2 +
49 mindspore/lite/mindir/src/mindir.cc           |  40 +
50 mindspore/lite/mindir/src/mindir_tensor.cc    |   2 +-
51 mindspore/lite/mindir/src/utils.cc            |   2 +-
52 mindspore/lite/src/CMakeLists.txt             |   6 +-
53 mindspore/lite/src/common/context_util.cc     |  14 +-
54 mindspore/lite/src/common/log.cc              |  33 +-
55 mindspore/lite/src/common/log.h               |  50 +-
56 .../common/ops/populate/custom_populate.cc    |  53 ++
57 mindspore/lite/src/litert/c_api/context_c.cc  | 372 +++++++-
58 mindspore/lite/src/litert/c_api/context_c.h   |  23 -
59 mindspore/lite/src/litert/c_api/model_c.cc    | 724 ++++++++-------
60 mindspore/lite/src/litert/c_api/tensor_c.cc   |  78 +-
61 .../lite/src/litert/c_api/type_c_private.h    |  40 +
62 mindspore/lite/src/litert/cxx_api/context.cc  |  85 ++
63 .../lite/src/litert/cxx_api/converters.cc     |  60 +-
64 .../lite/src/litert/cxx_api/converters.h      |   4 +-
65 .../src/litert/delegate/nnrt/CMakeLists.txt   |  27 +-
66 .../delegate/nnrt/checker/primitive_check.cc  |   2 +
67 .../src/litert/delegate/nnrt/nnrt_delegate.cc | 836 ++++++++++++++----
68 .../src/litert/delegate/nnrt/nnrt_delegate.h  |  74 +-
69 .../litert/delegate/nnrt/nnrt_model_kernel.cc |   3 +-
70 .../litert/delegate/nnrt/nnrt_model_kernel.h  |   2 +-
71 .../src/litert/delegate/nnrt/nnrt_stub.cc     |  99 +++
72 mindspore/lite/src/litert/infer_manager.cc    |   3 +-
73 mindspore/lite/src/litert/inner_context.cc    |   4 +
74 mindspore/lite/src/litert/inner_context.h     |  14 +
75 mindspore/lite/src/litert/kernel/cpu/BUILD.gn |  51 +-
76 .../src/litert/kernel/cpu/base/custom_base.cc |  46 +
77 .../src/litert/kernel/cpu/base/custom_base.h  |  43 +
78 .../litert/kernel/cpu/base/custom_is_inf.cc   |  61 ++
79 .../litert/kernel/cpu/base/custom_is_inf.h    |  38 +
80 .../kernel/cpu/base/custom_masked_fill.cc     |  84 ++
81 .../kernel/cpu/base/custom_masked_fill.h      |  35 +
82 .../kernel/cpu/base/custom_tensor_scatter.cc  |  75 ++
83 .../kernel/cpu/base/custom_tensor_scatter.h   |  36 +
84 mindspore/lite/src/litert/lite_model.cc       |  29 +
85 mindspore/lite/src/litert/lite_session.cc     |  39 +-
86 mindspore/lite/src/litert/lite_session.h      |   1 +
87 mindspore/lite/src/litert/scheduler.cc        |  17 +
88 mindspore/lite/src/litert/tensor_category.cc  |   4 +
89 mindspore/lite/src/litert/tensor_category.h   |   1 +
90 mindspore/lite/test/CMakeLists.txt            |  15 +-
91 mindspore/lite/test/runtest.sh                |   1 +
92 .../test/ut/test_data/third_party_model.cfg   |   8 +
93 .../tools/converter/api/converter_api_test.cc |  10 +
94 .../third_party_param_parser_test.cc          | 176 ++++
95 .../lite/tools/benchmark/benchmark_base.cc    |   2 +-
96 .../lite/tools/benchmark/benchmark_base.h     |   2 +-
97 .../lite/tools/benchmark/benchmark_c_api.cc   |   4 +
98 .../tools/benchmark/benchmark_unified_api.cc  |   5 +
99 .../lite/tools/benchmark_train/CMakeLists.txt |   3 +
100 mindspore/lite/tools/benchmark_train/main.cc  |   3 +-
101 .../lite/tools/benchmark_train/net_runner.cc  |  10 +-
102 .../lite/tools/benchmark_train/net_train.cc   | 418 +--------
103 .../lite/tools/benchmark_train/net_train.h    | 229 +----
104 .../tools/benchmark_train/net_train_base.cc   | 410 +++++++++
105 .../tools/benchmark_train/net_train_base.h    | 288 ++++++
106 .../tools/benchmark_train/net_train_c_api.cc  | 659 ++++++++++++++
107 .../tools/benchmark_train/net_train_c_api.h   | 121 +++
108 .../tools/benchmark_train/run_net_train.cc    |  86 ++
109 .../tools/benchmark_train/run_net_train.h     |  22 +
110 mindspore/lite/tools/converter/CMakeLists.txt |   4 +
111 .../config_parser/config_file_parser.cc       |  27 +
112 .../config_parser/config_file_parser.h        |  15 +
113 .../config_parser/third_party_param_parser.cc | 299 +++++++
114 .../config_parser/third_party_param_parser.h  |  44 +
115 mindspore/lite/tools/converter/converter.cc   |  34 +-
116 .../tools/converter/converter_funcgraph.cc    |  13 +-
117 .../converter_lite/converter_flags.cc         |   4 +-
118 .../tools/converter/cxx_api/converter_para.h  |  14 +
119 .../tools/converter/graphdef_transform.cc     |  44 +
120 .../parser/third_party/CMakeLists.txt         |   4 +
121 .../third_party/third_party_model_parser.cc   | 277 ++++++
122 .../third_party/third_party_model_parser.h    |  50 ++
123 .../registry/model_parser_registry.cc         |   4 +-
 117 files changed, 6449 insertions(+), 1432 deletions(-)
125 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
126 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
127 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
128 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
129 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
130 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
131 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
132 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
133 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
134 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
135 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
136 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
137 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
138 create mode 100644 mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
139 create mode 100644 mindspore/lite/src/litert/c_api/type_c_private.h
140 create mode 100644 mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
141 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
142 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
143 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
144 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
145 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
146 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
147 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
148 create mode 100644 mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
149 create mode 100644 mindspore/lite/test/ut/test_data/third_party_model.cfg
150 create mode 100644 mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
151 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.cc
152 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_base.h
153 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.cc
154 create mode 100644 mindspore/lite/tools/benchmark_train/net_train_c_api.h
155 create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.cc
156 create mode 100644 mindspore/lite/tools/benchmark_train/run_net_train.h
157 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
158 create mode 100644 mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
159 create mode 100644 mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
160 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
161 create mode 100644 mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
162
163diff --git a/cmake/external_libs/flatbuffers.cmake b/cmake/external_libs/flatbuffers.cmake
164index 2fde4311..87f0425b 100644
165--- a/cmake/external_libs/flatbuffers.cmake
166+++ b/cmake/external_libs/flatbuffers.cmake
167@@ -21,8 +21,8 @@ else()
168         # flatbuffers.lib cimplied by msvc
169         set(CMAKE_STATIC_LIBRARY_PREFIX "")
170     else()
171-        set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
172-        set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong")
173+        set(flatbuffers_CXXFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable")
174+        set(flatbuffers_CFLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -fstack-protector-strong -Wno-error=unused-but-set-variable")
175     endif()
176 
177     if(WIN32)
178diff --git a/include/api/context.h b/include/api/context.h
179index c9fb11f0..eb704d44 100644
180--- a/include/api/context.h
181+++ b/include/api/context.h
182@@ -39,6 +39,8 @@ enum DeviceType {
183   kAscend310,
184   kCustomDevice,
185   kAllDevice,
+  // ohos-only device range: [60, 80)
187+  kNNRt = 60,
188   // add new type here
189   kInvalidDeviceType = 100,
190 };
191@@ -598,5 +600,68 @@ void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_
192   SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
193 }
194 std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
195+
196+struct Extension {
197+  std::string name;
198+  std::vector<uint8_t> value;
199+};
200+
201+class MS_API NNRTDeviceInfo : public DeviceInfoContext {
202+ public:
203+  /// \brief Get the type of this DeviceInfoContext.
204+  ///
205+  /// \return Type of this DeviceInfoContext.
+  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; }
207+
208+  /// \brief Set device id.
209+  ///
210+  /// \param[in] device_id The device id.
211+  void SetDeviceID(size_t device_id);
212+
213+  /// \brief Get the device id.
214+  ///
215+  /// \return The device id.
216+  size_t GetDeviceID() const;
217+
218+  /// \brief Set performance mode.
219+  ///
220+  /// \param[in] performance_mode The performance mode.
221+  void SetPerformanceMode(int performance_mode);
222+
223+  /// \brief Get performance mode.
224+  ///
+  /// \return The performance mode.
226+  int GetPerformanceMode() const;
227+
228+  /// \brief Set priority.
229+  ///
230+  /// \param[in] priority The priority.
231+  void SetPriority(int priority);
232+
233+  /// \brief Get priority.
234+  ///
235+  /// \return The priority.
236+  int GetPriority() const;
237+
+  /// \brief Set whether to enable float16 inference.
+  ///
+  /// \param[in] is_fp16 Whether to enable float16 inference.
+  void SetEnableFP16(bool is_fp16);
+
+  /// \brief Get whether float16 inference is enabled.
+  ///
+  /// \return Whether float16 inference is enabled.
+  bool GetEnableFP16() const;
247+
+  /// \brief Set the extensions.
+  ///
+  /// \param[in] extensions The extension array.
+  void SetExtensions(const std::vector<Extension> &extensions);
+
+  /// \brief Get the extensions.
+  ///
+  /// \return The extension array.
256+  std::vector<Extension> GetExtensions() const;
257+};
258 }  // namespace mindspore
259 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
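
For context, a minimal sketch of how the new NNRTDeviceInfo and Extension types are expected to be used from the C++ API. Context, MutableDeviceInfos() and the shared_ptr-based device-info list belong to the existing MindSpore Lite API and are assumptions here; the extension key "CachePath" is purely illustrative, not something this patch defines.

    // Hedged usage sketch, not part of the patch itself.
    #include <memory>
    #include "include/api/context.h"

    std::shared_ptr<mindspore::Context> BuildNNRTContext() {
      auto context = std::make_shared<mindspore::Context>();
      auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
      nnrt_info->SetDeviceID(0);         // first NNRT device
      nnrt_info->SetPerformanceMode(3);  // performance mode is a plain int in this API
      nnrt_info->SetPriority(2);
      nnrt_info->SetEnableFP16(true);    // allow float16 inference
      mindspore::Extension ext;
      ext.name = "CachePath";            // hypothetical vendor key, for illustration only
      ext.value = {'/', 'd', 'a', 't', 'a'};
      nnrt_info->SetExtensions({ext});
      context->MutableDeviceInfos().push_back(nnrt_info);
      return context;
    }
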
260diff --git a/include/c_api/context_c.h b/include/c_api/context_c.h
261index 53839e80..8951da25 100644
262--- a/include/c_api/context_c.h
263+++ b/include/c_api/context_c.h
264@@ -19,6 +19,7 @@
265 #include <stddef.h>
266 #include <stdint.h>
267 #include <stdbool.h>
268+#include "include/c_api/status_c.h"
269 #include "include/c_api/types_c.h"
270 
271 #ifdef __cplusplus
@@ -173,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
273 /// \return NPU frequency
274 OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);
275 
+/// \brief Obtain all the device descriptions in NNRT.
+///
+/// \param[out] num Number of NNRT device descriptions.
+///
+/// \return NNRT device description array.
281+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
282+
+/// \brief Obtain the specified element in the NNRT device description array.
284+///
285+/// \param[in] descs NNRT device description array.
286+/// \param[in] index Element index.
287+///
288+/// \return NNRT device description.
289+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
290+
298+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs().
299+///
300+/// \param[in] desc NNRT device description array.
301+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
302+
303+/// \brief Obtain the device id in NNRT device description.
304+///
305+/// \param[in] desc pointer to the NNRT device description instance.
306+///
307+/// \return NNRT device id.
308+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
309+
310+/// \brief Obtain the device name in NNRT device description.
311+///
312+/// \param[in] desc pointer to the NNRT device description instance.
313+///
314+/// \return NNRT device name.
315+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
316+
317+/// \brief Obtain the device type in NNRT device description.
318+///
319+/// \param[in] desc pointer to the NNRT device description instance.
320+///
321+/// \return NNRT device type.
322+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
323+
324+/// \brief Create the NNRT device info by exactly matching the specific device name.
325+///
+/// \param[in] name NNRT device name.
327+///
328+/// \return Device info object handle.
329+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
330+
331+/// \brief Create the NNRT device info by finding the first device with the specific device type.
332+///
+/// \param[in] type NNRT device type.
334+///
335+/// \return Device info object handle.
336+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
337+
+/// \brief Set the NNRT device id. Only valid for NNRT.
339+///
340+/// \param[in] device_info Device info object handle.
341+/// \param[in] device_id NNRT device id.
342+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
343+
+/// \brief Obtain the NNRT device id. Only valid for NNRT.
345+///
346+/// \param[in] device_info Device info object handle.
347+///
348+/// \return NNRT device id.
349+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
350+
+/// \brief Set the NNRT performance mode. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] mode NNRT performance mode.
+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
+
+/// \brief Obtain the NNRT performance mode. Only valid for NNRT.
358+///
359+/// \param[in] device_info Device info object handle.
360+///
361+/// \return NNRT performance mode.
362+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
363+
+/// \brief Set the NNRT priority. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] priority NNRT priority.
+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
+
+/// \brief Obtain the NNRT priority. Only valid for NNRT.
371+///
372+/// \param[in] device_info Device info object handle.
373+///
374+/// \return NNRT priority.
375+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info);
376+
+/// \brief Add an extension of key/value format to the device info. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] name The content of the key as a C string.
+/// \param[in] value The pointer to the value, which is a byte array.
+/// \param[in] value_size The byte size of the value.
383+///
384+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
385+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size);
386 #ifdef __cplusplus
387 }
388 #endif
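
A hedged sketch of the intended discovery-and-selection flow for the NNRT declarations above. OH_AI_ContextCreate() and OH_AI_ContextAddDeviceInfo() are pre-existing C API calls and are assumptions here, as is the ownership note in the final comment.

    // Hedged device-selection sketch, not part of the patch.
    #include <stdio.h>
    #include "include/c_api/context_c.h"

    OH_AI_ContextHandle CreateContextWithNNRT(void) {
      size_t num = 0;
      NNRTDeviceDesc *descs = OH_AI_GetAllNNRTDeviceDescs(&num);
      if (descs == NULL || num == 0) {
        return NULL;  // no NNRT backend available on this device
      }
      NNRTDeviceDesc *first = OH_AI_GetElementOfNNRTDeviceDescs(descs, 0);
      printf("NNRT device: %s (id=%zu, type=%d)\n", OH_AI_GetNameFromNNRTDeviceDesc(first),
             OH_AI_GetDeviceIdFromNNRTDeviceDesc(first), (int)OH_AI_GetTypeFromNNRTDeviceDesc(first));
      OH_AI_DeviceInfoHandle info = OH_AI_CreateNNRTDeviceInfoByName(OH_AI_GetNameFromNNRTDeviceDesc(first));
      OH_AI_DestroyAllNNRTDeviceDescs(&descs);
      if (info == NULL) {
        return NULL;
      }
      OH_AI_DeviceInfoSetPerformanceMode(info, OH_AI_PERFORMANCE_HIGH);
      OH_AI_DeviceInfoSetPriority(info, OH_AI_PRIORITY_MEDIUM);
      OH_AI_ContextHandle context = OH_AI_ContextCreate();
      OH_AI_ContextAddDeviceInfo(context, info);  // the context manages the device info from here on
      return context;
    }
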
389diff --git a/include/c_api/model_c.h b/include/c_api/model_c.h
390index 12a46bcd..2286e673 100644
391--- a/include/c_api/model_c.h
392+++ b/include/c_api/model_c.h
393@@ -26,6 +26,8 @@ extern "C" {
394 
395 typedef void *OH_AI_ModelHandle;
396 
397+typedef void *OH_AI_TrainCfgHandle;
398+
399 typedef struct OH_AI_TensorHandleArray {
400   size_t handle_num;
401   OH_AI_TensorHandle *handle_list;
402@@ -168,6 +170,182 @@ OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHa
403 /// \return The output tensor handle with the given name, if the name is not found, an NULL is returned.
404 OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);
405 
406+/// \brief Create a TrainCfg object. Only valid for Lite Train.
407+///
408+/// \return TrainCfg object handle.
409+OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate();
410+
411+/// \brief Destroy the train_cfg object. Only valid for Lite Train.
412+///
413+/// \param[in] train_cfg TrainCfg object handle.
414+OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg);
415+
+/// \brief Obtains part of the name that identifies a loss kernel. Only valid for Lite Train.
+///
+/// \param[in] train_cfg TrainCfg object handle.
+/// \param[out] num The number of loss names.
+///
+/// \return The loss name array.
422+OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num);
423+
+/// \brief Set part of the name that identifies a loss kernel. Only valid for Lite Train.
+///
+/// \param[in] train_cfg TrainCfg object handle.
+/// \param[in] loss_name Define part of the name that identifies a loss kernel.
+/// \param[in] num The number of loss names.
429+OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num);
430+
431+/// \brief Obtains optimization level of the train_cfg. Only valid for Lite Train.
432+///
433+/// \param[in] train_cfg TrainCfg object handle.
434+///
435+/// \return OH_AI_OptimizationLevel.
436+OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg);
437+
438+/// \brief Set optimization level of the train_cfg. Only valid for Lite Train.
439+///
440+/// \param[in] train_cfg TrainCfg object handle.
441+/// \param[in] level The optimization level of train_cfg.
442+OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level);
443+
444+/// \brief Build the train model from model buffer so that it can run on a device. Only valid for Lite Train.
445+///
446+/// \param[in] model Model object handle.
+/// \param[in] model_data Define the buffer read from a model file.
+/// \param[in] data_size Define the byte number of the model file buffer.
+/// \param[in] model_type Define the type of model file.
450+/// \param[in] model_context Define the context used to store options during execution.
451+/// \param[in] train_cfg Define the config used by training.
452+///
453+/// \return OH_AI_Status.
454+OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
455+                                     OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
456+                                     const OH_AI_TrainCfgHandle train_cfg);
457+
+/// \brief Build the train model from a model file so that it can run on a device. Only valid for Lite Train.
459+///
460+/// \param[in] model Model object handle.
461+/// \param[in] model_path Define the model path.
+/// \param[in] model_type Define the type of model file.
463+/// \param[in] model_context Define the context used to store options during execution.
464+/// \param[in] train_cfg Define the config used by training.
465+///
466+/// \return OH_AI_Status.
467+OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
468+                                             OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
469+                                             const OH_AI_TrainCfgHandle train_cfg);
470+
471+/// \brief Train model by step. Only valid for Lite Train.
472+///
473+/// \param[in] model Model object handle.
474+/// \param[in] before CallBack before predict.
475+/// \param[in] after CallBack after predict.
476+///
477+/// \return OH_AI_Status.
478+OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
479+                                     const OH_AI_KernelCallBack after);
480+
481+/// \brief Sets the Learning Rate of the training. Only valid for Lite Train.
482+///
+/// \param[in] learning_rate The learning rate to set.
484+///
485+/// \return OH_AI_Status of operation.
486+OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate);
487+
488+/// \brief Obtains the Learning Rate of the optimizer. Only valid for Lite Train.
489+///
490+/// \return Learning rate. 0.0 if no optimizer was found.
491+OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model);
492+
+/// \brief Obtains all weight tensors of the model. Only valid for Lite Train.
+///
+/// \param[in] model Model object handle.
+///
+/// \return The vector that includes all weight tensors.
498+OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model);
499+
+/// \brief Update the weight tensors of the model. Only valid for Lite Train.
+///
+/// \param[in] new_weights A vector of new weights.
503+///
504+/// \return OH_AI_Status
505+OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights);
506+
507+/// \brief Get the model running mode.
508+///
509+/// \param[in] model Model object handle.
510+///
+/// \return Whether the model runs in Train Mode.
512+OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model);
513+
514+/// \brief Set the model running mode. Only valid for Lite Train.
515+///
516+/// \param[in] model Model object handle.
517+/// \param[in] train True means model runs in Train Mode, otherwise Eval Mode.
518+///
519+/// \return OH_AI_Status.
520+OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train);
521+
522+/// \brief Setup training with virtual batches. Only valid for Lite Train.
523+///
524+/// \param[in] model Model object handle.
525+/// \param[in] virtual_batch_multiplier - virtual batch multiplier, use any number < 1 to disable.
526+/// \param[in] lr - learning rate to use for virtual batch, -1 for internal configuration.
527+/// \param[in] momentum - batch norm momentum to use for virtual batch, -1 for internal configuration.
528+///
529+/// \return OH_AI_Status.
530+OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
531+                                                    float momentum);
532+
+/// \brief Export the training model to a file. Only valid for Lite Train.
+///
+/// \param[in] model Model object handle.
+/// \param[in] model_type The model file type.
+/// \param[in] model_file The path of the exported model file.
+/// \param[in] quantization_type The quantization type.
+/// \param[in] export_inference_only Whether to export an inference-only model.
+/// \param[in] output_tensor_name Set the names of the output tensors of the exported inference model; default is
+/// empty, which exports the complete inference model.
+/// \param[in] num The number of output tensor names.
543+///
544+/// \return OH_AI_Status.
545+OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
546+                                         OH_AI_QuantizationType quantization_type, bool export_inference_only,
547+                                         char **output_tensor_name, size_t num);
548+
+/// \brief Export the training model to a buffer. Only valid for Lite Train.
+///
+/// \param[in] model Model object handle.
+/// \param[in] model_type The model file type.
+/// \param[out] model_data The exported model buffer.
+/// \param[out] data_size The exported model buffer size.
+/// \param[in] quantization_type The quantization type.
+/// \param[in] export_inference_only Whether to export an inference-only model.
+/// \param[in] output_tensor_name Set the names of the output tensors of the exported inference model; default is
+/// empty, which exports the complete inference model.
+/// \param[in] num The number of output tensor names.
560+///
561+/// \return OH_AI_Status.
562+OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
563+                                               size_t *data_size, OH_AI_QuantizationType quantization_type,
564+                                               bool export_inference_only, char **output_tensor_name, size_t num);
565+
+/// \brief Export the model's weights, which can be used in micro only. Only valid for Lite Train.
+///
+/// \param[in] model Model object handle.
+/// \param[in] model_type The model file type.
+/// \param[in] weight_file The path of the exported weight file.
+/// \param[in] is_inference Whether to export weights from an inference model. Currently, only `true` is supported.
+/// \param[in] enable_fp16 Whether float weights are saved in float16 format.
+/// \param[in] changeable_weights_name Set the names of the weight tensors whose shapes are changeable.
+/// \param[in] num The number of changeable weight names.
575+///
576+/// \return OH_AI_Status.
577+OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
578+                                                               const char *weight_file, bool is_inference,
579+                                                               bool enable_fp16, char **changeable_weights_name,
580+                                                               size_t num);
581+
582 #ifdef __cplusplus
583 }
584 #endif
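
A hedged end-to-end sketch of the Lite Train C API added above. OH_AI_ModelCreate(), OH_AI_ModelDestroy() and OH_AI_MODELTYPE_MINDIR come from the existing C API and are assumptions here; the step count and the "trained.ms" output path are placeholders.

    // Hedged training-loop sketch, not part of the patch.
    #include "include/c_api/model_c.h"

    OH_AI_Status TrainOneEpoch(const char *model_path, OH_AI_ContextHandle context) {
      OH_AI_TrainCfgHandle cfg = OH_AI_TrainCfgCreate();
      OH_AI_TrainCfgSetOptimizationLevel(cfg, OH_AI_KO2);  // fp16 body, fp32 batchnorm/loss
      OH_AI_ModelHandle model = OH_AI_ModelCreate();       // pre-existing C API call (assumption)
      OH_AI_Status ret = OH_AI_TrainModelBuildFromFile(model, model_path, OH_AI_MODELTYPE_MINDIR, context, cfg);
      if (ret == OH_AI_STATUS_SUCCESS) {
        (void)OH_AI_ModelSetTrainMode(model, true);
        (void)OH_AI_ModelSetLearningRate(model, 0.01f);
        for (int step = 0; step < 100 && ret == OH_AI_STATUS_SUCCESS; ++step) {
          // Input tensors are expected to be filled with the next batch before each step.
          ret = OH_AI_RunStep(model, NULL, NULL);
        }
      }
      if (ret == OH_AI_STATUS_SUCCESS) {
        // Export an inference-only model; "trained.ms" is a placeholder path.
        ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, "trained.ms", OH_AI_NO_QUANT, true, NULL, 0);
      }
      OH_AI_ModelDestroy(&model);
      OH_AI_TrainCfgDestroy(&cfg);
      return ret;
    }
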
585diff --git a/include/c_api/tensor_c.h b/include/c_api/tensor_c.h
586index f18ba163..6d2aaab6 100644
587--- a/include/c_api/tensor_c.h
588+++ b/include/c_api/tensor_c.h
589@@ -17,6 +17,7 @@
590 #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H
591 
592 #include <stddef.h>
593+#include "include/c_api/status_c.h"
594 #include "include/c_api/types_c.h"
595 #include "include/c_api/data_type_c.h"
596 #include "include/c_api/format_c.h"
597@@ -112,6 +113,19 @@ OH_AI_API OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor);
598 /// \param[in] data A pointer to the data of the tensor.
599 OH_AI_API void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data);
600 
+/// \brief Set the data for the tensor with a user-allocated data buffer.
+/// The main purpose of this interface is to provide a way of using memory already allocated by the user as the
+/// Model's input, instead of memory allocated inside the Model object, which saves one copy.
+/// Note: The tensor won't free the data provided by the invoker. The invoker is responsible for freeing it, and this
+/// free action should not be performed before the tensor is destroyed.
+///
+/// \param[in] tensor Tensor object handle.
+/// \param[in] data A pointer to the user data buffer.
+/// \param[in] data_size The byte size of the user data buffer.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
612+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
613+
614 /// \brief Obtain the data pointer of the tensor.
615 ///
616 /// \param[in] tensor Tensor object handle.
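
A hedged sketch of the zero-copy input path that OH_AI_TensorSetUserData() enables. OH_AI_ModelGetInputs(), OH_AI_TensorGetDataSize() and OH_AI_ModelPredict() are pre-existing C API calls assumed here; the lifetime rule in the comment restates the note above.

    // Hedged zero-copy input sketch, not part of the patch.
    #include "include/c_api/model_c.h"

    OH_AI_Status PredictWithUserBuffer(OH_AI_ModelHandle model, void *user_buf, size_t buf_size) {
      OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
      if (inputs.handle_num == 0 || OH_AI_TensorGetDataSize(inputs.handle_list[0]) != buf_size) {
        return OH_AI_STATUS_LITE_ERROR;
      }
      // The tensor borrows user_buf instead of copying it; the caller owns the buffer
      // and must keep it alive until the tensor is destroyed.
      OH_AI_Status ret = OH_AI_TensorSetUserData(inputs.handle_list[0], user_buf, buf_size);
      if (ret != OH_AI_STATUS_SUCCESS) {
        return ret;
      }
      OH_AI_TensorHandleArray outputs;
      return OH_AI_ModelPredict(model, inputs, &outputs, NULL, NULL);
    }
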
617diff --git a/include/c_api/types_c.h b/include/c_api/types_c.h
618index dba54ffa..e520e336 100644
619--- a/include/c_api/types_c.h
620+++ b/include/c_api/types_c.h
621@@ -40,10 +40,65 @@ typedef enum OH_AI_DeviceType {
622   OH_AI_DEVICETYPE_KIRIN_NPU,
623   // add new type here
624   // ohos-only device range: [60, 80)
625-  OH_AI_DEVICETYPE__NNRT = 60,
626+  OH_AI_DEVICETYPE_NNRT = 60,
627   OH_AI_DEVICETYPE_INVALID = 100,
628 } OH_AI_DeviceType;
629 
630+typedef enum OH_AI_NNRTDeviceType {
631+  /** Devices that are not CPU, GPU, or dedicated accelerator */
632+  OH_AI_NNRTDEVICE_OTHERS = 0,
633+  /** CPU device */
634+  OH_AI_NNRTDEVICE_CPU = 1,
635+  /** GPU device */
636+  OH_AI_NNRTDEVICE_GPU = 2,
637+  /** Dedicated hardware accelerator */
638+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
639+} OH_AI_NNRTDeviceType;
640+
641+typedef enum OH_AI_PerformanceMode {
642+  /** No performance mode preference */
643+  OH_AI_PERFORMANCE_NONE = 0,
+  /** Low power consumption mode */
645+  OH_AI_PERFORMANCE_LOW = 1,
646+  /** Medium performance mode */
647+  OH_AI_PERFORMANCE_MEDIUM = 2,
648+  /** High performance mode */
649+  OH_AI_PERFORMANCE_HIGH = 3,
650+  /** Ultimate performance mode */
651+  OH_AI_PERFORMANCE_EXTREME = 4
652+} OH_AI_PerformanceMode;
653+
654+typedef enum OH_AI_Priority {
655+  /** No priority preference */
656+  OH_AI_PRIORITY_NONE = 0,
657+  /** Low priority */
658+  OH_AI_PRIORITY_LOW = 1,
659+  /** Medium priority */
660+  OH_AI_PRIORITY_MEDIUM = 2,
661+  /** High priority */
662+  OH_AI_PRIORITY_HIGH = 3
663+} OH_AI_Priority;
664+
665+typedef enum OH_AI_OptimizationLevel {
666+  /** Do not change */
667+  OH_AI_KO0 = 0,
668+  /** Cast network to float16, keep batchnorm and loss in float32 */
669+  OH_AI_KO2 = 2,
+  /** Cast network to float16, including batchnorm */
671+  OH_AI_KO3 = 3,
672+  /** Choose optimization based on device */
673+  OH_AI_KAUTO = 4,
674+  OH_AI_KOPTIMIZATIONTYPE = 0xFFFFFFFF
675+} OH_AI_OptimizationLevel;
676+
677+typedef enum OH_AI_QuantizationType {
678+  OH_AI_NO_QUANT = 0,
679+  OH_AI_WEIGHT_QUANT = 1,
680+  OH_AI_FULL_QUANT = 2,
681+  OH_AI_UNKNOWN_QUANT_TYPE = 0xFFFFFFFF
682+} OH_AI_QuantizationType;
683+
684+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
685 #ifdef __cplusplus
686 }
687 #endif
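
The new enums are ordinary C enums, so callers can branch on them directly. A small illustrative helper (assumed naming, not part of any API) might look like:

    #include "include/c_api/types_c.h"

    // Map a device type to a log string; purely illustrative.
    static const char *NNRTDeviceTypeToString(OH_AI_NNRTDeviceType type) {
      switch (type) {
        case OH_AI_NNRTDEVICE_CPU:         return "CPU";
        case OH_AI_NNRTDEVICE_GPU:         return "GPU";
        case OH_AI_NNRTDEVICE_ACCELERATOR: return "ACCELERATOR";
        case OH_AI_NNRTDEVICE_OTHERS:
        default:                           return "OTHERS";
      }
    }

    // A plausible policy sketch: only dedicated accelerators get the extreme performance hint.
    static OH_AI_PerformanceMode ChoosePerformanceMode(OH_AI_NNRTDeviceType type) {
      return type == OH_AI_NNRTDEVICE_ACCELERATOR ? OH_AI_PERFORMANCE_EXTREME : OH_AI_PERFORMANCE_MEDIUM;
    }
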
688diff --git a/include/sdk_api/context.h b/include/sdk_api/context.h
689index 5bfc9279..e12b8d6f 100644
690--- a/include/sdk_api/context.h
691+++ b/include/sdk_api/context.h
692@@ -174,6 +174,109 @@ OH_AI_API void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,
693 /// \return NPU frequency
694 OH_AI_API int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info);
695 
+/// \brief Obtain all the device descriptions in NNRT.
+///
+/// \param[out] num Number of NNRT device descriptions.
+///
+/// \return NNRT device description array.
701+OH_AI_API NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num);
702+
+/// \brief Obtain the specified element in the NNRT device description array.
704+///
705+/// \param[in] descs NNRT device description array.
706+/// \param[in] index Element index.
707+///
708+/// \return NNRT device description.
709+OH_AI_API NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index);
710+
+/// \brief Destroy the NNRT device descriptions returned by OH_AI_GetAllNNRTDeviceDescs().
712+///
713+/// \param[in] desc NNRT device description array.
714+OH_AI_API void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc);
715+
716+/// \brief Obtain the device id in NNRT device description.
717+///
718+/// \param[in] desc pointer to the NNRT device description instance.
719+///
720+/// \return NNRT device id.
721+OH_AI_API size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
722+
723+/// \brief Obtain the device name in NNRT device description.
724+///
725+/// \param[in] desc pointer to the NNRT device description instance.
726+///
727+/// \return NNRT device name.
728+OH_AI_API const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
729+
730+/// \brief Obtain the device type in NNRT device description.
731+///
732+/// \param[in] desc pointer to the NNRT device description instance.
733+///
734+/// \return NNRT device type.
735+OH_AI_API OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc);
736+
737+/// \brief Create the NNRT device info by exactly matching the specific device name.
738+///
+/// \param[in] name NNRT device name.
740+///
741+/// \return Device info object handle.
742+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name);
743+
744+/// \brief Create the NNRT device info by finding the first device with the specific device type.
745+///
+/// \param[in] type NNRT device type.
747+///
748+/// \return Device info object handle.
749+OH_AI_API OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type);
750+
+/// \brief Set the NNRT device id. Only valid for NNRT.
752+///
753+/// \param[in] device_info Device info object handle.
754+/// \param[in] device_id NNRT device id.
755+OH_AI_API void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id);
756+
+/// \brief Obtain the NNRT device id. Only valid for NNRT.
758+///
759+/// \param[in] device_info Device info object handle.
760+///
761+/// \return NNRT device id.
762+OH_AI_API size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info);
763+
+/// \brief Set the NNRT performance mode. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] mode NNRT performance mode.
+OH_AI_API void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode);
+
+/// \brief Obtain the NNRT performance mode. Only valid for NNRT.
771+///
772+/// \param[in] device_info Device info object handle.
773+///
774+/// \return NNRT performance mode.
775+OH_AI_API OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info);
776+
+/// \brief Set the NNRT priority. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] priority NNRT priority.
+OH_AI_API void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority);
+
+/// \brief Obtain the NNRT priority. Only valid for NNRT.
784+///
785+/// \param[in] device_info Device info object handle.
786+///
787+/// \return NNRT priority.
788+OH_AI_API OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info);
789+
+/// \brief Add an extension of key/value format to the device info. Only valid for NNRT.
+///
+/// \param[in] device_info Device info object handle.
+/// \param[in] name The content of the key as a C string.
+/// \param[in] value The pointer to the value, which is a byte array.
+/// \param[in] value_size The byte size of the value.
796+///
797+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
798+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info, const char *name, const char *value, size_t value_size);
799 #ifdef __cplusplus
800 }
801 #endif
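
A hedged sketch of OH_AI_DeviceInfoAddExtension(): the key is a C string and the value an opaque byte array passed through to the NNRT backend. The "mindspore/context.h" include path follows the sdk_api convention above and is an assumption, and the key "ModelName" is purely illustrative rather than a documented NNRT extension.

    #include <string.h>
    #include "mindspore/context.h"

    OH_AI_Status TagDeviceInfo(OH_AI_DeviceInfoHandle nnrt_info) {
      const char *value = "mobilenet_v2";
      // Returns a detail error code when nnrt_info is not an NNRT device info handle.
      return OH_AI_DeviceInfoAddExtension(nnrt_info, "ModelName", value, strlen(value));
    }
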
802diff --git a/include/sdk_api/tensor.h b/include/sdk_api/tensor.h
803index f6ba02cd..3dad04ac 100644
804--- a/include/sdk_api/tensor.h
805+++ b/include/sdk_api/tensor.h
806@@ -17,6 +17,7 @@
807 #define MINDSPORE_INCLUDE_C_API_TENSOE_C_H
808 
809 #include <stddef.h>
810+#include "mindspore/status.h"
811 #include "mindspore/types.h"
812 #include "mindspore/data_type.h"
813 #include "mindspore/format.h"
814@@ -140,6 +141,18 @@ OH_AI_API int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor);
815 /// \return The data size of the tensor.
816 OH_AI_API size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor);
817 
+/// \brief Set the data for the tensor with a user-allocated data buffer.
+/// The main purpose of this interface is to provide a way of using memory already allocated by the user as the
+/// Model's input, instead of memory allocated inside the Model object, which saves one copy.
+/// Note: The tensor won't free the data provided by the invoker. The invoker is responsible for freeing it, and this
+/// free action should not be performed before the tensor is destroyed.
+///
+/// \param[in] tensor Tensor object handle.
+/// \param[in] data A pointer to the user data buffer.
+/// \param[in] data_size The byte size of the user data buffer.
+///
+/// \return OH_AI_STATUS_SUCCESS if success, or detail error code if failed.
829+OH_AI_API OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size);
830 #ifdef __cplusplus
831 }
832 #endif
833diff --git a/include/sdk_api/types.h b/include/sdk_api/types.h
834index a39c6daa..d38660b0 100644
835--- a/include/sdk_api/types.h
836+++ b/include/sdk_api/types.h
837@@ -40,10 +40,46 @@ typedef enum OH_AI_DeviceType {
838   OH_AI_DEVICETYPE_KIRIN_NPU,
839   // add new type here
840   // ohos-only device range: [60, 80)
841-  OH_AI_DeviceType_NNRT = 60,
842+  OH_AI_DEVICETYPE_NNRT = 60,
843   OH_AI_DEVICETYPE_INVALID = 100,
844 } OH_AI_DeviceType;
845 
846+typedef enum OH_AI_NNRTDeviceType {
847+  /** Devices that are not CPU, GPU, or dedicated accelerator */
848+  OH_AI_NNRTDEVICE_OTHERS = 0,
849+  /** CPU device */
850+  OH_AI_NNRTDEVICE_CPU = 1,
851+  /** GPU device */
852+  OH_AI_NNRTDEVICE_GPU = 2,
853+  /** Dedicated hardware accelerator */
854+  OH_AI_NNRTDEVICE_ACCELERATOR = 3,
855+} OH_AI_NNRTDeviceType;
856+
857+typedef enum OH_AI_PerformanceMode {
858+  /** No performance mode preference */
859+  OH_AI_PERFORMANCE_NONE = 0,
+  /** Low power consumption mode */
861+  OH_AI_PERFORMANCE_LOW = 1,
862+  /** Medium performance mode */
863+  OH_AI_PERFORMANCE_MEDIUM = 2,
864+  /** High performance mode */
865+  OH_AI_PERFORMANCE_HIGH = 3,
866+  /** Ultimate performance mode */
867+  OH_AI_PERFORMANCE_EXTREME = 4
868+} OH_AI_PerformanceMode;
869+
870+typedef enum OH_AI_Priority {
871+  /** No priority preference */
872+  OH_AI_PRIORITY_NONE = 0,
873+  /** Low priority */
874+  OH_AI_PRIORITY_LOW = 1,
875+  /** Medium priority */
876+  OH_AI_PRIORITY_MEDIUM = 2,
877+  /** High priority */
878+  OH_AI_PRIORITY_HIGH = 3
879+} OH_AI_Priority;
880+
881+typedef struct NNRTDeviceDesc NNRTDeviceDesc;
882 #ifdef __cplusplus
883 }
884 #endif
885diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
886index 7bbc3782..103e53b7 100644
887--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
888+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/BUILD.gn
889@@ -498,6 +498,9 @@ infer_shape_sources = [
890   "infer/crop_infer.c",
891   "infer/cumsum_infer.c",
892   "infer/custom_gru_infer.c",
893+  "infer/custom_masked_fill_infer.c",
894+  "infer/custom_is_inf_infer.c",
895+  "infer/custom_tensor_scatter_max_infer.c",
896   "infer/decoder_layer_infer.c",
897   "infer/deconv2d_infer.c",
898   "infer/depth_to_space_infer.c",
899diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
900index c1685a65..6fef44fd 100644
901--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
902+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/CMakeLists.txt
903@@ -238,7 +238,7 @@ endif()
904 if(PLATFORM_ARM)
905     set(NO_FAST_MATH_OPTI ${NNACL_DIR}/fp32/resize_fp32.c)
906     set_source_files_properties(${NO_FAST_MATH_OPTI} PROPERTIES LANGUAGE C
907-        COMPILE_FLAGS "${CMAKE_C_FLAGS} -fno-fast-math")
908+        COMPILE_FLAGS "${CMAKE_C_FLAGS} -w -fno-fast-math")
909 endif()
910 
911 add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC} ${MS_X86_SIMD_SRC})
912diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
913new file mode 100644
914index 00000000..14bd1d76
915--- /dev/null
916+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx/scatter_nd_binary_avx.h
917@@ -0,0 +1,66 @@
918+/**
919+* Copyright 2023 Huawei Technologies Co., Ltd
920+*
921+* Licensed under the Apache License, Version 2.0 (the "License");
922+* you may not use this file except in compliance with the License.
923+* You may obtain a copy of the License at
924+*
925+* http://www.apache.org/licenses/LICENSE-2.0
926+*
927+* Unless required by applicable law or agreed to in writing, software
928+* distributed under the License is distributed on an "AS IS" BASIS,
929+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
930+* See the License for the specific language governing permissions and
931+* limitations under the License.
932+*/
933+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
934+#define NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
935+
936+#include "nnacl/intrinsics/ms_simd_instructions.h"
937+#include "nnacl/intrinsics/ms_simd_avx_instructions.h"
938+
939+#ifdef __cplusplus
940+extern "C" {
941+#endif
942+#pragma GCC push_options
943+#pragma GCC target("avx", "avx2")
944+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX_INSTRUCTION
945+#define BLOCK_NUM 8
946+#define MS_SIMD_AVX
947+
948+static inline int ScatterNDAddFp32AVX(int index, const float *update, int size, float *output) {
949+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
950+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
951+  }
952+  return index;
953+}
954+
955+static inline int ScatterNDAddInt32AVX(int index, const int *update, int size, int *output) {
956+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
957+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
958+  }
959+  return index;
960+}
961+
962+static inline int ScatterNDMaxFp32AVX(int index, const float *update, int size, float *output) {
963+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
964+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
965+  }
966+  return index;
967+}
968+
969+static inline int ScatterNDMaxInt32AVX(int index, const int *update, int size, int *output) {
970+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
971+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
972+  }
973+  return index;
974+}
975+
976+#undef MS_SIMD_INSTRUCTION
977+#undef BLOCK_NUM
978+#pragma GCC pop_options
979+#undef MS_SIMD_AVX
980+#ifdef __cplusplus
981+}
982+#endif
983+#endif  // NNACL_BASE_SCATTER_ND_BINARY_AVX_H_
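
A hedged sketch of how these helpers are meant to be driven, mirroring the caller pattern in scatter_nd_binary.c: the AVX body consumes 8 floats per iteration and returns the first unprocessed index, and a scalar tail finishes the remainder. The ENABLE_AVX guard is the usual nnacl compile-time switch and is an assumption here.

    #include <math.h>
    #include "nnacl/avx/scatter_nd_binary_avx.h"

    static void ScatterMaxOneUnit(const float *update_data, float *output_data, int unit_size) {
      int i = 0;
    #ifdef ENABLE_AVX
      i = ScatterNDMaxFp32AVX(i, update_data, unit_size, output_data);  // vector body, 8 floats per step
    #endif
      for (; i < unit_size; ++i) {  // scalar tail, also the portable fallback
        output_data[i] = fmaxf(update_data[i], output_data[i]);
      }
    }
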
984diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
985new file mode 100644
986index 00000000..abf024c5
987--- /dev/null
988+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/avx512/scatter_nd_binary_avx512.h
989@@ -0,0 +1,66 @@
990+/**
991+* Copyright 2023 Huawei Technologies Co., Ltd
992+*
993+* Licensed under the Apache License, Version 2.0 (the "License");
994+* you may not use this file except in compliance with the License.
995+* You may obtain a copy of the License at
996+*
997+* http://www.apache.org/licenses/LICENSE-2.0
998+*
999+* Unless required by applicable law or agreed to in writing, software
1000+* distributed under the License is distributed on an "AS IS" BASIS,
1001+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1002+* See the License for the specific language governing permissions and
1003+* limitations under the License.
1004+*/
1005+#ifndef NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1006+#define NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1007+
1008+#include "nnacl/intrinsics/ms_simd_instructions.h"
1009+#include "nnacl/intrinsics/ms_simd_avx512_instructions.h"
1010+
1011+#ifdef __cplusplus
1012+extern "C" {
1013+#endif
1014+#pragma GCC push_options
1015+#pragma GCC target("avx512f")
1016+#define MS_SIMD_INSTRUCTION MS_SIMD_AVX512_INSTRUCTION
1017+#define BLOCK_NUM 16
1018+#define MS_SIMD_AVX512
1019+
1020+static inline int ScatterNDAddFp32AVX512(int index, const float *update, int size, float *output) {
1021+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1022+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1023+  }
1024+  return index;
1025+}
1026+
1027+static inline int ScatterNDAddInt32AVX512(int index, const int *update, int size, int *output) {
1028+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1029+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1030+  }
1031+  return index;
1032+}
1033+
1034+static inline int ScatterNDMaxFp32AVX512(int index, const float *update, int size, float *output) {
1035+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1036+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1037+  }
1038+  return index;
1039+}
1040+
1041+static inline int ScatterNDMaxInt32AVX512(int index, const int *update, int size, int *output) {
1042+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1043+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1044+  }
1045+  return index;
1046+}
1047+
1048+#undef MS_SIMD_INSTRUCTION
1049+#undef BLOCK_NUM
1050+#pragma GCC pop_options
1051+#undef MS_SIMD_AVX512
1052+#ifdef __cplusplus
1053+}
1054+#endif
1055+#endif  // NNACL_BASE_SCATTER_ND_BINARY_AVX512_H_
1056diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1057index bca71f55..e496bb4b 100644
1058--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1059+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.c
1060@@ -77,3 +77,31 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets,
1061   }
1062   return NNACL_OK;
1063 }
1064+
1065+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1066+                 int task_id) {
1067+  if (update == NULL || output == NULL || output_unit_offsets == NULL || param == NULL) {
1068+    return NNACL_NULL_PTR;
1069+  }
1070+  if (param->op_parameter.thread_num_ == 0) {
1071+    return NNACL_ERR;
1072+  }
1073+  int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
1074+  int begin = unit_per_thread * task_id;
1075+  int end = MSMIN(begin + unit_per_thread, param->num_unit);
1076+  if (type == 0) {
1077+    float *update_fp32 = (float *)update;
1078+    float *output_fp32 = (float *)output;
1079+    for (int i = begin; i < end; i++) {
1080+      const float *update_data = update_fp32 + i * param->unit_size;
1081+      float *output_data = output_fp32 + output_unit_offsets[i];
1082+      int j = 0;
1083+
1084+      SIMD_RUN_NO_SCALAR(ScatterNDMaxFp32, j, update_data, param->unit_size, output_data);
1085+      for (; j < param->unit_size; j++) {
1086+        output_data[j] = fmaxf(update_data[j], output_data[j]);
1087+      }
1088+    }
1089+  }
1090+  return NNACL_OK;
+}
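
The threading contract of ScatterNDMax() is that param->num_unit units are split evenly across op_parameter.thread_num_ workers by task_id. A hedged single-threaded sketch that simply walks every task id (in the runtime, each task id is handed to a thread-pool worker instead):

    #include "nnacl/errorcode.h"
    #include "nnacl/base/scatter_nd_binary.h"

    int ScatterNDMaxSerial(const void *update, void *output, int *offsets, const ScatterNDParameter *param) {
      for (int task_id = 0; task_id < param->op_parameter.thread_num_; ++task_id) {
        int ret = ScatterNDMax(update, output, offsets, param, 0 /* fp32 */, task_id);
        if (ret != NNACL_OK) {
          return ret;
        }
      }
      return NNACL_OK;
    }
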
1092diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1093index 3af55335..36657cd9 100644
1094--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1095+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary.h
1096@@ -27,6 +27,9 @@ int ScatterNDUpdate(void *output, const void *update, int *output_unit_offsets,
1097 
1098 int ScatterNDAdd(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1099                  int task_id);
1100+
1101+int ScatterNDMax(const void *update, void *output, int *output_unit_offsets, const ScatterNDParameter *param, int type,
1102+                 int task_id);
1103 #ifdef __cplusplus
1104 }
1105 #endif
1106diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1107index c72d9cc2..46bb20ce 100644
1108--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1109+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/base/scatter_nd_binary_simd.h.in
1110@@ -38,6 +38,20 @@ static inline int ScatterNDAddInt32@SIMD_INSTRUCTION@(int index, const int *upda
1111   return index;
1112 }
1113 
+static inline int ScatterNDMaxFp32@SIMD_INSTRUCTION@(int index, const float *update, int size, float *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
+  }
+  return index;
+}
+
+static inline int ScatterNDMaxInt32@SIMD_INSTRUCTION@(int index, const int *update, int size, int *output) {
+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
+  }
+  return index;
+}
1127+
1128 @SIMD_INSTRUCTION_END@
1129 #ifdef __cplusplus
1130 }
1131diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
1132new file mode 100644
1133index 00000000..e1eae394
1134--- /dev/null
1135+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_is_inf_parameter.h
1136@@ -0,0 +1,26 @@
1137+/**
1138+ * Copyright 2023 Huawei Technologies Co., Ltd
1139+ *
1140+ * Licensed under the Apache License, Version 2.0 (the "License");
1141+ * you may not use this file except in compliance with the License.
1142+ * You may obtain a copy of the License at
1143+ *
1144+ * http://www.apache.org/licenses/LICENSE-2.0
1145+ *
1146+ * Unless required by applicable law or agreed to in writing, software
1147+ * distributed under the License is distributed on an "AS IS" BASIS,
1148+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1149+ * See the License for the specific language governing permissions and
1150+ * limitations under the License.
1151+ */
1152+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1153+#define MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1154+
1155+#include "nnacl/op_base.h"
1156+
1157+typedef struct CustomIsInfParameter {
1158+  // Primitive parameter
1159+  OpParameter op_parameter_;
1160+} CustomIsInfParameter;
1161+
1162+#endif  // MINDSPORE_NNACL_CUSTOM_IS_INF_PARAMETER_H_
1163diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
1164new file mode 100644
1165index 00000000..047d3d3f
1166--- /dev/null
1167+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_masked_fill_parameter.h
1168@@ -0,0 +1,26 @@
1169+/**
1170+ * Copyright 2023 Huawei Technologies Co., Ltd
1171+ *
1172+ * Licensed under the Apache License, Version 2.0 (the "License");
1173+ * you may not use this file except in compliance with the License.
1174+ * You may obtain a copy of the License at
1175+ *
1176+ * http://www.apache.org/licenses/LICENSE-2.0
1177+ *
1178+ * Unless required by applicable law or agreed to in writing, software
1179+ * distributed under the License is distributed on an "AS IS" BASIS,
1180+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1181+ * See the License for the specific language governing permissions and
1182+ * limitations under the License.
1183+ */
1184+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1185+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1186+
1187+#include "nnacl/op_base.h"
1188+
1189+typedef struct CustomMaskedFillParameter {
1190+  // Primitive parameter
1191+  OpParameter op_parameter_;
1192+} CustomMaskedFillParameter;
1193+
1194+#endif  // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_PARAMETER_H_
1195diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
1196new file mode 100644
1197index 00000000..ba6940db
1198--- /dev/null
1199+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/custom_tensor_scatter_max_parameter.h
1200@@ -0,0 +1,26 @@
1201+/**
1202+ * Copyright 2023 Huawei Technologies Co., Ltd
1203+ *
1204+ * Licensed under the Apache License, Version 2.0 (the "License");
1205+ * you may not use this file except in compliance with the License.
1206+ * You may obtain a copy of the License at
1207+ *
1208+ * http://www.apache.org/licenses/LICENSE-2.0
1209+ *
1210+ * Unless required by applicable law or agreed to in writing, software
1211+ * distributed under the License is distributed on an "AS IS" BASIS,
1212+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1213+ * See the License for the specific language governing permissions and
1214+ * limitations under the License.
1215+ */
1216+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1217+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1218+
1219+#include "nnacl/op_base.h"
1220+
1221+typedef struct CustomTensorScatterMaxParameter {
1222+  // Primitive parameter
1223+  OpParameter op_parameter_;
1224+} CustomTensorScatterMaxParameter;
1225+
1226+#endif  // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_PARAMETER_H_
1227diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
1228new file mode 100644
1229index 00000000..fc87d157
1230--- /dev/null
1231+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.c
1232@@ -0,0 +1,38 @@
1233+/**
1234+ * Copyright 2023 Huawei Technologies Co., Ltd
1235+ *
1236+ * Licensed under the Apache License, Version 2.0 (the "License");
1237+ * you may not use this file except in compliance with the License.
1238+ * You may obtain a copy of the License at
1239+ *
1240+ * http://www.apache.org/licenses/LICENSE-2.0
1241+ *
1242+ * Unless required by applicable law or agreed to in writing, software
1243+ * distributed under the License is distributed on an "AS IS" BASIS,
1244+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1245+ * See the License for the specific language governing permissions and
1246+ * limitations under the License.
1247+ */
1248+
1249+#include "nnacl/infer/custom_is_inf_infer.h"
1250+#include "nnacl/infer/infer_register.h"
1251+
1252+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1253+                          OpParameter *parameter) {
1254+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C1NUM, C1NUM);
1255+  if (check_ret != NNACL_OK) {
1256+    return check_ret;
1257+  }
1258+
1259+  const TensorC *input = inputs[0];
1260+  TensorC *output = outputs[0];
1261+  output->data_type_ = kNumberTypeBool;
1262+  output->format_ = input->format_;
1263+  if (!InferFlag(inputs, inputs_size)) {
1264+    return NNACL_INFER_INVALID;
1265+  }
1266+  SetShapeTensor(output, input);
1267+  return NNACL_OK;
1268+}
1269+
1270+REG_INFER(CustomIsInf, PrimType_Inner_CustomIsInf, CustomIsInfInferShape)
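A hedged usage sketch of the infer function registered above: the caller fills in minimal TensorC descriptors and calls CustomIsInfInferShape directly (a real runtime would dispatch through the REG_INFER table). The concrete shape and dtype values, and the example function name, are illustrative assumptions.

    #include <string.h>
    #include "nnacl/custom_is_inf_parameter.h"
    #include "nnacl/infer/custom_is_inf_infer.h"

    int RunIsInfInferExample(void) {
      TensorC in, out;
      memset(&in, 0, sizeof(TensorC));
      memset(&out, 0, sizeof(TensorC));
      in.data_type_ = kNumberTypeFloat32;  // element type of the tensor being checked
      in.shape_size_ = 2;
      in.shape_[0] = 4;
      in.shape_[1] = 8;

      CustomIsInfParameter param;
      memset(&param, 0, sizeof(param));
      param.op_parameter_.type_ = PrimType_Inner_CustomIsInf;

      const TensorC *inputs[] = {&in};
      TensorC *outputs[] = {&out};
      int ret = CustomIsInfInferShape(inputs, 1, outputs, 1, &param.op_parameter_);
      // On NNACL_OK the output carries kNumberTypeBool and a copy of the input shape.
      return ret;
    }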
1271diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
1272new file mode 100644
1273index 00000000..d1b4b33d
1274--- /dev/null
1275+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_is_inf_infer.h
1276@@ -0,0 +1,31 @@
1277+/**
1278+ * Copyright 2023 Huawei Technologies Co., Ltd
1279+ *
1280+ * Licensed under the Apache License, Version 2.0 (the "License");
1281+ * you may not use this file except in compliance with the License.
1282+ * You may obtain a copy of the License at
1283+ *
1284+ * http://www.apache.org/licenses/LICENSE-2.0
1285+ *
1286+ * Unless required by applicable law or agreed to in writing, software
1287+ * distributed under the License is distributed on an "AS IS" BASIS,
1288+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1289+ * See the License for the specific language governing permissions and
1290+ * limitations under the License.
1291+ */
1292+#ifndef MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1293+#define MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1294+
1295+#include "nnacl/infer/common_infer.h"
1296+
1297+#ifdef __cplusplus
1298+extern "C" {
1299+#endif
1300+
1301+int CustomIsInfInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1302+                          OpParameter *parameter);
1303+
1304+#ifdef __cplusplus
1305+}
1306+#endif
1307+#endif  // MINDSPORE_NNACL_CUSTOM_IS_INF_INFER_H
1308diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
1309new file mode 100644
1310index 00000000..957a4d4f
1311--- /dev/null
1312+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.c
1313@@ -0,0 +1,37 @@
1314+/**
1315+ * Copyright 2023 Huawei Technologies Co., Ltd
1316+ *
1317+ * Licensed under the Apache License, Version 2.0 (the "License");
1318+ * you may not use this file except in compliance with the License.
1319+ * You may obtain a copy of the License at
1320+ *
1321+ * http://www.apache.org/licenses/LICENSE-2.0
1322+ *
1323+ * Unless required by applicable law or agreed to in writing, software
1324+ * distributed under the License is distributed on an "AS IS" BASIS,
1325+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1326+ * See the License for the specific language governing permissions and
1327+ * limitations under the License.
1328+ */
1329+
1330+#include "nnacl/infer/custom_masked_fill_infer.h"
1331+#include "nnacl/infer/infer_register.h"
1332+
1333+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1334+                               OpParameter *parameter) {
1335+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM);
1336+  if (check_ret != NNACL_OK) {
1337+    return check_ret;
1338+  }
1339+
1340+  const TensorC *input = inputs[0];
1341+  TensorC *output = outputs[0];
1342+  SetDataTypeFormat(output, input);
1343+  if (!InferFlag(inputs, inputs_size)) {
1344+    return NNACL_INFER_INVALID;
1345+  }
1346+  SetShapeTensor(output, input);
1347+  return NNACL_OK;
1348+}
1349+
1350+REG_INFER(CustomMaskedFill, PrimType_Inner_CustomMaskedFill, CustomMaskedFillInferShape)
1351diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
1352new file mode 100644
1353index 00000000..a8adbae2
1354--- /dev/null
1355+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_masked_fill_infer.h
1356@@ -0,0 +1,31 @@
1357+/**
1358+ * Copyright 2023 Huawei Technologies Co., Ltd
1359+ *
1360+ * Licensed under the Apache License, Version 2.0 (the "License");
1361+ * you may not use this file except in compliance with the License.
1362+ * You may obtain a copy of the License at
1363+ *
1364+ * http://www.apache.org/licenses/LICENSE-2.0
1365+ *
1366+ * Unless required by applicable law or agreed to in writing, software
1367+ * distributed under the License is distributed on an "AS IS" BASIS,
1368+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1369+ * See the License for the specific language governing permissions and
1370+ * limitations under the License.
1371+ */
1372+#ifndef MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1373+#define MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1374+
1375+#include "nnacl/infer/common_infer.h"
1376+
1377+#ifdef __cplusplus
1378+extern "C" {
1379+#endif
1380+
1381+int CustomMaskedFillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
1382+                               OpParameter *parameter);
1383+
1384+#ifdef __cplusplus
1385+}
1386+#endif
1387+#endif  // MINDSPORE_NNACL_CUSTOM_MASKED_FILL_INFER_H
1388diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
1389new file mode 100644
1390index 00000000..be6716ba
1391--- /dev/null
1392+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.c
1393@@ -0,0 +1,37 @@
1394+/**
1395+ * Copyright 2023 Huawei Technologies Co., Ltd
1396+ *
1397+ * Licensed under the Apache License, Version 2.0 (the "License");
1398+ * you may not use this file except in compliance with the License.
1399+ * You may obtain a copy of the License at
1400+ *
1401+ * http://www.apache.org/licenses/LICENSE-2.0
1402+ *
1403+ * Unless required by applicable law or agreed to in writing, software
1404+ * distributed under the License is distributed on an "AS IS" BASIS,
1405+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1406+ * See the License for the specific language governing permissions and
1407+ * limitations under the License.
1408+ */
1409+
1410+#include "nnacl/infer/custom_tensor_scatter_max_infer.h"
1411+#include "nnacl/infer/infer_register.h"
1412+
1413+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
1414+                                     size_t outputs_size, OpParameter *parameter) {
1415+  int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, C3NUM, C1NUM);
1416+  if (check_ret != NNACL_OK) {
1417+    return check_ret;
1418+  }
1419+
1420+  const TensorC *input = inputs[0];
1421+  TensorC *output = outputs[0];
1422+  SetDataTypeFormat(output, input);
1423+  if (!InferFlag(inputs, inputs_size)) {
1424+    return NNACL_INFER_INVALID;
1425+  }
1426+  SetShapeTensor(output, input);
1427+  return NNACL_OK;
1428+}
1429+
1430+REG_INFER(CustomTensorScatterMax, PrimType_Inner_CustomTensorScatterMax, CustomTensorScatterMaxInferShape)
1431diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
1432new file mode 100644
1433index 00000000..641aa483
1434--- /dev/null
1435+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/custom_tensor_scatter_max_infer.h
1436@@ -0,0 +1,31 @@
1437+/**
1438+ * Copyright 2023 Huawei Technologies Co., Ltd
1439+ *
1440+ * Licensed under the Apache License, Version 2.0 (the "License");
1441+ * you may not use this file except in compliance with the License.
1442+ * You may obtain a copy of the License at
1443+ *
1444+ * http://www.apache.org/licenses/LICENSE-2.0
1445+ *
1446+ * Unless required by applicable law or agreed to in writing, software
1447+ * distributed under the License is distributed on an "AS IS" BASIS,
1448+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1449+ * See the License for the specific language governing permissions and
1450+ * limitations under the License.
1451+ */
1452+#ifndef MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1453+#define MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1454+
1455+#include "nnacl/infer/common_infer.h"
1456+
1457+#ifdef __cplusplus
1458+extern "C" {
1459+#endif
1460+
1461+int CustomTensorScatterMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
1462+                                     size_t outputs_size, OpParameter *parameter);
1463+
1464+#ifdef __cplusplus
1465+}
1466+#endif
1467+#endif  // MINDSPORE_NNACL_CUSTOM_TENSOR_SCATTER_MAX_INFER_H
1468diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
1469new file mode 100644
1470index 00000000..d7c34768
1471--- /dev/null
1472+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/neon/scatter_nd_binary_neon.h
1473@@ -0,0 +1,65 @@
1474+/**
1475+* Copyright 2023 Huawei Technologies Co., Ltd
1476+*
1477+* Licensed under the Apache License, Version 2.0 (the "License");
1478+* you may not use this file except in compliance with the License.
1479+* You may obtain a copy of the License at
1480+*
1481+* http://www.apache.org/licenses/LICENSE-2.0
1482+*
1483+* Unless required by applicable law or agreed to in writing, software
1484+* distributed under the License is distributed on an "AS IS" BASIS,
1485+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1486+* See the License for the specific language governing permissions and
1487+* limitations under the License.
1488+*/
1489+#ifndef NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
1490+#define NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
1491+
1492+#include "nnacl/intrinsics/ms_simd_instructions.h"
1493+#include "nnacl/intrinsics/ms_simd_neon_instructions.h"
1494+
1495+#ifdef __cplusplus
1496+extern "C" {
1497+#endif
1498+
1499+#define MS_SIMD_INSTRUCTION MS_SIMD_NEON_INSTRUCTION
1500+#define BLOCK_NUM 4
1501+#define MS_SIMD_NEON
1502+
1503+static inline int ScatterNDAddFp32NEON(int index, const float *update, int size, float *output) {
1504+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1505+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1506+  }
1507+  return index;
1508+}
1509+
1510+static inline int ScatterNDAddInt32NEON(int index, const int *update, int size, int *output) {
1511+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1512+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1513+  }
1514+  return index;
1515+}
1516+
1517+static inline int ScatterNDMaxFp32NEON(int index, const float *update, int size, float *output) {
1518+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1519+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1520+  }
1521+  return index;
1522+}
1523+
1524+static inline int ScatterNDMaxInt32NEON(int index, const int *update, int size, int *output) {
1525+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1526+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1527+  }
1528+  return index;
1529+}
1530+
1531+#undef MS_SIMD_INSTRUCTION
1532+#undef BLOCK_NUM
1533+
1534+#undef MS_SIMD_NEON
1535+#ifdef __cplusplus
1536+}
1537+#endif
1538+#endif  // NNACL_BASE_SCATTER_ND_BINARY_NEON_H_
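The NEON helpers above deliberately stop before the last partial block and return the index they reached; a minimal sketch of the intended calling pattern (the wrapper name is hypothetical, and the scalar tail mirrors what the base scatter_nd_binary code is expected to do):

    static void ScatterRowAddFp32(const float *update, int size, float *output) {
      int index = 0;
    #ifdef ENABLE_ARM
      // Consume the SIMD-friendly prefix, 4 floats per iteration.
      index = ScatterNDAddFp32NEON(index, update, size, output);
    #endif
      // Finish the remaining (size % 4) elements with scalar code.
      for (; index < size; index++) {
        output[index] += update[index];
      }
    }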
1539diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1540index 955a70a5..895f7e3d 100644
1541--- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1542+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h
1543@@ -558,6 +558,10 @@ enum PrimType {
1544   PrimType_Inner_CustomGru = 10010,
1545   PrimType_Inner_CastGatherReduceFusion = 10011,
1546   PrimType_Inner_ReduceConcatFusion = 10012,
1547+  PrimType_Inner_ThirdPartyModel = 10013,
1548+  PrimType_Inner_CustomMaskedFill = 10014,
1549+  PrimType_Inner_CustomTensorScatterMax = 10015,
1550+  PrimType_Inner_CustomIsInf = 10016,
1551   PrimType_InnerOpMax,
1552   PrimType_InnerOpMin = PrimType_Inner_ToFormat
1553 };
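A small sketch of the invariant the new entries rely on: they are appended between PrimType_Inner_ToFormat (aliased as PrimType_InnerOpMin) and PrimType_InnerOpMax, so the usual inner-type range check keeps covering them. The helper name below is hypothetical.

    static inline bool IsInnerPrimType(int type) {
      return type >= PrimType_InnerOpMin && type < PrimType_InnerOpMax;
    }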
1554diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
1555new file mode 100644
1556index 00000000..dd9878f7
1557--- /dev/null
1558+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/scatter_nd_binary_simd.h
1559@@ -0,0 +1,36 @@
1560+/**
1561+* Copyright 2023 Huawei Technologies Co., Ltd
1562+*
1563+* Licensed under the Apache License, Version 2.0 (the "License");
1564+* you may not use this file except in compliance with the License.
1565+* You may obtain a copy of the License at
1566+*
1567+* http://www.apache.org/licenses/LICENSE-2.0
1568+*
1569+* Unless required by applicable law or agreed to in writing, software
1570+* distributed under the License is distributed on an "AS IS" BASIS,
1571+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1572+* See the License for the specific language governing permissions and
1573+* limitations under the License.
1574+*/
1575+#ifndef NNACL_SCATTER_ND_BINARY_SIMD_H_
1576+#define NNACL_SCATTER_ND_BINARY_SIMD_H_
1577+
1578+#include "nnacl/intrinsics/ms_simd_instructions.h"
1579+#ifdef ENABLE_AVX512
1580+#include "nnacl/avx512/scatter_nd_binary_avx512.h"
1581+#endif
1582+
1583+#ifdef ENABLE_AVX
1584+#include "nnacl/avx/scatter_nd_binary_avx.h"
1585+#endif
1586+
1587+#ifdef ENABLE_SSE
1588+#include "nnacl/sse/scatter_nd_binary_sse.h"
1589+#endif
1590+
1591+#ifdef ENABLE_ARM
1592+#include "nnacl/neon/scatter_nd_binary_neon.h"
1593+#endif
1594+
1595+#endif
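A hedged sketch of how a base kernel could chain the ISA-specific helpers this umbrella header pulls in, widest vectors first; the AVX512/AVX function names are assumed to mirror the SSE/NEON naming, and the real kernel may instead go through the generated SIMD dispatch macros.

    static int ScatterNDAddFp32AllISA(int index, const float *update, int size, float *output) {
    #ifdef ENABLE_AVX512
      index = ScatterNDAddFp32AVX512(index, update, size, output);  // assumed name
    #endif
    #ifdef ENABLE_AVX
      index = ScatterNDAddFp32AVX(index, update, size, output);     // assumed name
    #endif
    #ifdef ENABLE_SSE
      index = ScatterNDAddFp32SSE(index, update, size, output);
    #endif
    #ifdef ENABLE_ARM
      index = ScatterNDAddFp32NEON(index, update, size, output);
    #endif
      return index;  // the scalar remainder is handled by the caller
    }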
1596diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
1597new file mode 100644
1598index 00000000..983d2923
1599--- /dev/null
1600+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/sse/scatter_nd_binary_sse.h
1601@@ -0,0 +1,66 @@
1602+/**
1603+* Copyright 2023 Huawei Technologies Co., Ltd
1604+*
1605+* Licensed under the Apache License, Version 2.0 (the "License");
1606+* you may not use this file except in compliance with the License.
1607+* You may obtain a copy of the License at
1608+*
1609+* http://www.apache.org/licenses/LICENSE-2.0
1610+*
1611+* Unless required by applicable law or agreed to in writing, software
1612+* distributed under the License is distributed on an "AS IS" BASIS,
1613+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1614+* See the License for the specific language governing permissions and
1615+* limitations under the License.
1616+*/
1617+#ifndef NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1618+#define NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1619+
1620+#include "nnacl/intrinsics/ms_simd_instructions.h"
1621+#include "nnacl/intrinsics/ms_simd_sse_instructions.h"
1622+
1623+#ifdef __cplusplus
1624+extern "C" {
1625+#endif
1626+#pragma GCC push_options
1627+#pragma GCC target("sse4.1")
1628+#define MS_SIMD_INSTRUCTION MS_SIMD_SSE_INSTRUCTION
1629+#define BLOCK_NUM 4
1630+#define MS_SIMD_SSE
1631+
1632+static inline int ScatterNDAddFp32SSE(int index, const float *update, int size, float *output) {
1633+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1634+    SIMD_ST_F32(output + index, SIMD_ADD_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1635+  }
1636+  return index;
1637+}
1638+
1639+static inline int ScatterNDAddInt32SSE(int index, const int *update, int size, int *output) {
1640+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1641+    SIMD_ST_EPI32(output + index, SIMD_ADD_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1642+  }
1643+  return index;
1644+}
1645+
1646+static inline int ScatterNDMaxFp32SSE(int index, const float *update, int size, float *output) {
1647+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1648+    SIMD_ST_F32(output + index, SIMD_MAX_F32(SIMD_LD_F32(output + index), SIMD_LD_F32(update + index)));
1649+  }
1650+  return index;
1651+}
1652+
1653+static inline int ScatterNDMaxInt32SSE(int index, const int *update, int size, int *output) {
1654+  for (int block_max_size = size - BLOCK_NUM + 1; index < block_max_size; index += BLOCK_NUM) {
1655+    SIMD_ST_EPI32(output + index, SIMD_MAX_EPI32(SIMD_LD_EPI32(output + index), SIMD_LD_EPI32(update + index)));
1656+  }
1657+  return index;
1658+}
1659+
1660+#undef MS_SIMD_INSTRUCTION
1661+#undef BLOCK_NUM
1662+#pragma GCC pop_options
1663+#undef MS_SIMD_SSE
1664+#ifdef __cplusplus
1665+}
1666+#endif
1667+#endif  // NNACL_BASE_SCATTER_ND_BINARY_SSE_H_
1668diff --git a/mindspore/core/mindrt/BUILD.gn b/mindspore/core/mindrt/BUILD.gn
1669index b56d5f5c..b0e7c70d 100644
1670--- a/mindspore/core/mindrt/BUILD.gn
1671+++ b/mindspore/core/mindrt/BUILD.gn
1672@@ -41,8 +41,15 @@ ohos_source_set("mindrt_obj") {
1673     "../../core/",
1674   ]
1675 
1676+  defines = [
1677+    "ENABLE_MINDRT",
1678+    "MS_COMPILE_OHOS",
1679+    "BUILD_LITE",
1680+  ]
1681+
1682+  external_deps = [ "hilog:libhilog" ]
1683+
1684   remove_configs = [ "//build/config/compiler:no_rtti" ]
1685-  defines = [ "BUILD_LITE" ]
1686 
1687   part_name = "mindspore"
1688   subsystem_name = "thirdparty"
1689diff --git a/mindspore/core/mindrt/src/thread/actor_threadpool.cc b/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1690index 70414757..c50c46e0 100644
1691--- a/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1692+++ b/mindspore/core/mindrt/src/thread/actor_threadpool.cc
1693@@ -32,7 +32,7 @@ void ActorWorker::RunWithSpin() {
1694   }
1695 #if !defined(__APPLE__) && !defined(_MSC_VER)
1696   static std::atomic_int index{0};
1697-  (void)pthread_setname_np(pthread_self(), ("ActorThread_" + std::to_string(index++)).c_str());
1698+  (void)pthread_setname_np(pthread_self(), ("OS_Actor_" + std::to_string(index++)).c_str());
1699 #endif
1700 #ifdef PLATFORM_86
1701   // Some CPU kernels need set the flush zero mode to improve performance.
1702diff --git a/mindspore/core/mindrt/src/thread/core_affinity.cc b/mindspore/core/mindrt/src/thread/core_affinity.cc
1703index 33bf3529..a3478dff 100644
1704--- a/mindspore/core/mindrt/src/thread/core_affinity.cc
1705+++ b/mindspore/core/mindrt/src/thread/core_affinity.cc
1706@@ -344,12 +344,12 @@ int CoreAffinity::InitBindCoreId(size_t thread_num, BindMode bind_mode) {
1707 int CoreAffinity::SetAffinity() { return THREAD_OK; }
1708 #elif defined(BIND_CORE)
1709 int CoreAffinity::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) {
1710-#ifdef __ANDROID__
1711-#if __ANDROID_API__ >= 21
1712+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1713+#if (__ANDROID_API__ >= 21) || defined(MS_COMPILE_OHOS)
1714   THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]);
1715   int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set);
1716   if (ret != THREAD_OK) {
1717-    THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret);
1718+    THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", pthread_gettid_np(thread_id), ret);
1719     return THREAD_ERROR;
1720   }
1721 #endif
1722diff --git a/mindspore/core/mindrt/src/thread/core_affinity.h b/mindspore/core/mindrt/src/thread/core_affinity.h
1723index 2dd2abd1..28b0967a 100644
1724--- a/mindspore/core/mindrt/src/thread/core_affinity.h
1725+++ b/mindspore/core/mindrt/src/thread/core_affinity.h
1726@@ -23,7 +23,7 @@
1727 #ifdef PARALLEL_INFERENCE
1728 #define BIND_CORE
1729 #endif
1730-#ifdef __ANDROID__
1731+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1732 #define BIND_CORE
1733 #include <sched.h>
1734 #endif
1735diff --git a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1736index 9e0dd25c..09c39f32 100644
1737--- a/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1738+++ b/mindspore/core/mindrt/src/thread/parallel_threadpool.cc
1739@@ -48,7 +48,7 @@ void ParallelWorker::ParallelRun() {
1740     SetAffinity();
1741   }
1742 #if !defined(__APPLE__) && !defined(_MSC_VER)
1743-  (void)pthread_setname_np(pthread_self(), ("ParallelThread_" + std::to_string(worker_id_)).c_str());
1744+  (void)pthread_setname_np(pthread_self(), ("OS_Parallel_" + std::to_string(worker_id_)).c_str());
1745 #endif
1746 #ifdef PLATFORM_86
1747   // Some CPU kernels need set the flush zero mode to improve performance.
1748diff --git a/mindspore/core/mindrt/src/thread/threadlog.h b/mindspore/core/mindrt/src/thread/threadlog.h
1749index 7ed917f1..b212a401 100644
1750--- a/mindspore/core/mindrt/src/thread/threadlog.h
1751+++ b/mindspore/core/mindrt/src/thread/threadlog.h
1752@@ -16,7 +16,9 @@
1753 
1754 #ifndef MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_
1755 #define MINDSPORE_CORE_MINDRT_RUNTIME_THREADPOOL_LOG_H_
1756-
1757+#ifdef MS_COMPILE_OHOS
1758+#include "hilog/log.h"
1759+#endif
1760 namespace mindspore {
1761 #ifdef THREAD_POOL_DEBUG
1762 #include <stdio.h>
1763@@ -32,13 +34,35 @@ namespace mindspore {
1764   }
1765 #else
1766 #define THREAD_DEBUG(content, ...)
1767-#define THREAD_INFO(content, ...)
1768 #define THREAD_TEST_TRUE(flag)
1769+
1770 #if defined(__ANDROID__)
1771+#define THREAD_INFO(content, ...)
1772 #include <android/log.h>
1773 #define THREAD_ERROR(content, args...) \
1774   { __android_log_print(ANDROID_LOG_ERROR, "MS_LITE", "%s|%d: " #content "\r\n", __func__, __LINE__, ##args); }
1775+
1776+#elif defined(MS_COMPILE_OHOS)  // For OHOS, use hilog.
1777+
1778+#define MINDRT_OHOS_LOG_DOMAIN 0x2102
1779+#define MINDRT_OHOS_LOG_TAG "MS_LITE"
1780+
1781+#ifdef MS_COMPILE_WITH_OHOS_NDK
1782+// When building with the OHOS NDK, use the public API of the hilog module.
1783+#define THREAD_INFO(content, args...) \
1784+  { OH_LOG_Print(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1785+#define THREAD_ERROR(content, args...) \
1786+  { OH_LOG_Print(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1787+#else
1788+// When building in the OHOS repo, use the inner API of the hilog module.
1789+#define THREAD_INFO(content, args...) \
1790+  { HiLogPrint(LOG_APP, LOG_INFO, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1791+#define THREAD_ERROR(content, args...) \
1792+  { HiLogPrint(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG, "%s:%d " #content, __func__, __LINE__, ##args); }
1793+#endif
1794+
1795 #else
1796+#define THREAD_INFO(content, ...)
1797 #define THREAD_ERROR(content, ...)
1798 #endif
1799 #endif
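For orientation, a rough expansion of the NDK branch above (the extra quoting produced by the # operator is elided): a call such as THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", tid, errno) turns into something close to the call below. %{public} is hilog's privacy specifier; arguments without it are redacted as <private> in the log.

    OH_LOG_Print(LOG_APP, LOG_ERROR, MINDRT_OHOS_LOG_DOMAIN, MINDRT_OHOS_LOG_TAG,
                 "%s:%d bind thread %d to cpu failed. ERROR %{public}d",
                 __func__, __LINE__, tid, errno);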
1800diff --git a/mindspore/core/mindrt/src/thread/threadpool.cc b/mindspore/core/mindrt/src/thread/threadpool.cc
1801index c56e0425..2301be8c 100644
1802--- a/mindspore/core/mindrt/src/thread/threadpool.cc
1803+++ b/mindspore/core/mindrt/src/thread/threadpool.cc
1804@@ -68,10 +68,11 @@ void Worker::SetAffinity() {
1805 #ifdef _WIN32
1806   SetWindowsSelfAffinity(core_id_);
1807 #elif defined(BIND_CORE)
1808-#ifdef __ANDROID__
1809+#if defined(__ANDROID__) || defined(MS_COMPILE_OHOS)
1810+  THREAD_INFO("thread: %d, mask: %lu", gettid(), mask_.__bits[0]);
1811   int ret = sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask_);
1812   if (ret != THREAD_OK) {
1813-    THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", gettid(), errno);
1814+    THREAD_ERROR("bind thread %d to cpu failed. ERROR %{public}d", gettid(), errno);
1815   }
1816   return;
1817 #else
1818@@ -111,7 +112,7 @@ void Worker::Run() {
1819   }
1820 #if !defined(__APPLE__) && !defined(_MSC_VER)
1821   static std::atomic_int index = {0};
1822-  (void)pthread_setname_np(pthread_self(), ("KernelThread_" + std::to_string(index++)).c_str());
1823+  (void)pthread_setname_np(pthread_self(), ("OS_Kernel_" + std::to_string(index++)).c_str());
1824 #endif
1825 #ifdef PLATFORM_86
1826   // Some CPU kernels need set the flush zero mode to improve performance.
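A self-contained sketch of the binding call the worker now issues on OHOS as well as Android; BindSelfToCore is a hypothetical helper, while gettid() and sched_setaffinity() are the same calls used above.

    #define _GNU_SOURCE
    #include <sched.h>   // cpu_set_t, CPU_ZERO, CPU_SET, sched_setaffinity
    #include <unistd.h>  // gettid

    static int BindSelfToCore(int core_id) {
      cpu_set_t mask;
      CPU_ZERO(&mask);
      CPU_SET(core_id, &mask);
      // Bind the calling thread (kernel tid) to the selected core.
      return sched_setaffinity(gettid(), sizeof(cpu_set_t), &mask);
    }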
1827diff --git a/mindspore/lite/BUILD.gn b/mindspore/lite/BUILD.gn
1828index a774b58c..f7e465e2 100644
1829--- a/mindspore/lite/BUILD.gn
1830+++ b/mindspore/lite/BUILD.gn
1831@@ -71,9 +71,14 @@
1832 
1833 import("//build/ohos.gni")
1834 
1835+declare_args() {
1836+  mindspore_feature_nnrt_metagraph = false
1837+}
1838+
1839 ohos_group("mindspore") {
1840   deps = [
1841     ":mindspore_lib",
1842+    ":mindspore_ndk",
1843     ":mindspore_train_lib",
1844     "mindir:mindir_lib",
1845     "src/litert/js_api:mindsporelite_napi"
1846@@ -180,7 +185,6 @@ lite_mindrt_sources = [
1847 ]
1848 
1849 all_lite_sources += cxx_api_sources
1850-all_lite_sources += c_api_sources
1851 all_lite_sources += api_source
1852 all_lite_sources += control_flow_kernel_sources
1853 all_lite_sources += experimental_sources
1854@@ -368,7 +372,6 @@ ohos_shared_library("mindspore_lib") {
1855   sources = all_sources
1856 
1857   include_dirs = [
1858-    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1859     "//third_party/flatbuffers/include",
1860     "./",
1861     "../",
1862@@ -384,6 +387,7 @@ ohos_shared_library("mindspore_lib") {
1863     "../ccsrc/",
1864     "src/litert/kernel/cpu/",
1865     "../core/mindrt/src/",
1866+    "//foundation/ai/neural_network_runtime/",
1867   ]
1868 
1869   defines = [
1870@@ -426,24 +430,29 @@ ohos_shared_library("mindspore_lib") {
1871 
1872   external_deps = [ "hilog:libhilog" ]
1873 
1874-  output_name = "libmindspore-lite.huawei"
1875+  output_name = "libmindspore-lite"
1876   output_extension = "so"
1877   innerapi_tags = [ "platformsdk" ]
1878   SUPPORT_NNRT = true
1879   if (SUPPORT_NNRT) {
1880+    if (mindspore_feature_nnrt_metagraph) {
1881+      defines += [ "SUPPORT_NNRT_METAGRAPH" ]
1882+      print("enabled feature: mindspore_feature_nnrt_metagraph")
1883+    }
1884     sources += [
1885       "src/litert/delegate/nnrt/checker/primitive_check.cc",
1886       "src/litert/delegate/nnrt/nnrt_delegate.cc",
1887       "src/litert/delegate/nnrt/nnrt_model_kernel.cc",
1888     ]
1889     include_dirs += [
1890-      "//foundation/ai/neural_network_runtime",
1891       "src/delegate/nnrt/include",
1892       "../../mindspore/core/ir",
1893       "mindir/include",
1894       "mindir/inner_headers",
1895     ]
1896+
1897     external_deps += [ "neural_network_runtime:nnrt_target" ]
1898+
1899     deps += [ "mindir:mindir_lib" ]
1900     defines += [ "SUPPORT_NNRT" ]
1901   }
1902@@ -461,6 +470,67 @@ ohos_shared_library("mindspore_lib") {
1903   subsystem_name = "thirdparty"
1904 }
1905 
1906+# NDK lib
1907+ohos_shared_library("mindspore_ndk") {
1908+  deps = [
1909+    ":mindspore_lib",
1910+    ":mindspore_train_lib"
1911+  ]
1912+
1913+  sources = c_api_sources
1914+
1915+  include_dirs = [
1916+    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1917+    "//third_party/flatbuffers/include",
1918+    "./",
1919+    "../",
1920+    "../../",
1921+    "../core",
1922+    "src",
1923+    "src/c_api/",
1924+    "../ccsrc/plugin/device/cpu/kernel/",
1925+    "../core/mindrt/src/",
1926+    "../core/mindrt/include/",
1927+    "../../third_party/",
1928+    "./schema/",
1929+    "../ccsrc/",
1930+    "//foundation/ai/neural_network_runtime/",
1931+  ]
1932+
1933+  defines = [
1934+    "SUPPORT_NNRT",
1935+    "MS_COMPILE_OHOS",
1936+    "PRIMITIVE_WRITEABLE",
1937+    "RUNTIME_PASS_CLIP",
1938+    "ENABLE_MULTI_LAYOUT",
1939+    "VERSION_STR=\"2.1.0\"",
1940+  ]
1941+
1942+  configs = [
1943+    ":mindspore_api",
1944+    ":disable_android",
1945+    ":secure_option",
1946+  ]
1947+
1948+  external_deps = [ "neural_network_runtime:nnrt_target" ]
1949+
1950+  remove_configs = [ "//build/config/compiler:no_rtti" ]
1951+
1952+  output_name = "libmindspore_lite_ndk"
1953+  output_extension = "so"
1954+  innerapi_tags = [ "ndk" ]
1955+  cflags_cc = [
1956+    "-Wno-ignored-qualifiers",
1957+    "-Wunused-private-field",
1958+    "-Wno-unused-private-field",
1959+    "-Wno-inconsistent-missing-override",
1960+    "-Wno-macro-redefined",
1961+    "-Wno-constant-conversion",
1962+  ]
1963+  part_name = "mindspore"
1964+  subsystem_name = "thirdparty"
1965+}
1966+
1967 # Train library
1968 expression_cxx_api_sources = [
1969   "src/litert/cxx_api/expression/net.cc",
1970@@ -614,7 +684,6 @@ ohos_shared_library("mindspore_train_lib") {
1971   sources = all_train_sources
1972 
1973   include_dirs = [
1974-    "//base/hiviewdfx/hilog/interfaces/native/innerkits/include",
1975     "//third_party/flatbuffers/include",
1976     "./",
1977     "../",
1978@@ -698,6 +767,9 @@ config("disable_android") {
1979     "-U__ANDROID__",
1980     "-U__ANDROID_API__",
1981   ]
1982+  ldflags = [
1983+    "-Wl,--no-as-needed",
1984+  ]
1985 }
1986 
1987 config("secure_option") {
1988diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt
1989index 72337f70..1faf2f38 100644
1990--- a/mindspore/lite/CMakeLists.txt
1991+++ b/mindspore/lite/CMakeLists.txt
1992@@ -298,8 +298,9 @@ elseif(TOOLCHAIN_NAME STREQUAL "ohos-lite")
1993 elseif(TOOLCHAIN_NAME STREQUAL "ohos")
1994     set(TARGET_OHOS on)
1995     add_compile_definitions(MS_COMPILE_OHOS)
1996-    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions")
1997-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions")
1998+    add_compile_definitions(MS_COMPILE_WITH_OHOS_NDK)
1999+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins")
2000+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-command-line-argument -Wno-c++17-extensions -Wno-deprecated-builtins")
2001 endif()
2002 
2003 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0
2004diff --git a/mindspore/lite/include/lite_types.h b/mindspore/lite/include/lite_types.h
2005index 017e98a8..860390d5 100644
2006--- a/mindspore/lite/include/lite_types.h
2007+++ b/mindspore/lite/include/lite_types.h
2008@@ -42,6 +42,7 @@ typedef enum {
2009   DT_NPU,    /**< NPU device type */
2010   DT_ASCEND, /**< ASCEND device type */
2011   DT_CUSTOM, /**< EXTEND device type */
2012+  DT_NNRT,   /**< NNRT device type */
2013   DT_END     /**< NO device type */
2014 } DeviceType;
2015 
2016diff --git a/mindspore/lite/include/model.h b/mindspore/lite/include/model.h
2017index 93e27ea9..b96c7e35 100644
2018--- a/mindspore/lite/include/model.h
2019+++ b/mindspore/lite/include/model.h
2020@@ -25,6 +25,7 @@ namespace mindspore {
2021 namespace schema {
2022 struct Tensor;
2023 }  // namespace schema
2024+
2025 namespace lite {
2026 typedef enum { ModelType_MSLite, ModelType_MindIR } LiteModelType;
2027 
2028@@ -62,7 +63,10 @@ struct MS_API LiteGraph {
2029   bool model_obfuscated_ = false;
2030   std::vector<unsigned char *> deobf_prims_;
2031 #endif
2032+
2033+  std::string ToString() const;
2034 };
2035+
2036 struct MS_API Model {
2037   LiteGraph graph_;
2038   char *buf = nullptr;
2039diff --git a/mindspore/lite/include/registry/converter_context.h b/mindspore/lite/include/registry/converter_context.h
2040index 2d72b200..4bc92599 100644
2041--- a/mindspore/lite/include/registry/converter_context.h
2042+++ b/mindspore/lite/include/registry/converter_context.h
2043@@ -39,7 +39,9 @@ enum MS_API FmkType : int {
2044   kFmkTypeMs = 3,
2045   kFmkTypeTflite = 4,
2046   kFmkTypePytorch = 5,
2047-  kFmkTypeMsLite = 6,
2048+  kFmkTypeThirdParty = 6,
2049+  kFmkTypeMsLite = 7,
2050+  kFmkTypeEnd = 8,  // For range-check purposes; valid range: [0, kFmkTypeEnd)
2051 };
2052 
2053 /// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
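A minimal sketch of the range check that kFmkTypeEnd is introduced for; IsValidFmkType is a hypothetical helper, not part of the patch.

    inline bool IsValidFmkType(int fmk) {
      // Valid framework types occupy the contiguous range [0, kFmkTypeEnd).
      return fmk >= 0 && fmk < static_cast<int>(kFmkTypeEnd);
    }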
2054diff --git a/mindspore/lite/mindir/include/mindir.h b/mindspore/lite/mindir/include/mindir.h
2055index ca811dce..f47cad8c 100644
2056--- a/mindspore/lite/mindir/include/mindir.h
2057+++ b/mindspore/lite/mindir/include/mindir.h
2058@@ -151,6 +151,8 @@ int64_t MindIR_Conv2DFusion_GetOutChannel(ConstPrimitivePtr primitive);
2059 void MindIR_Conv2DFusion_SetOutChannel(PrimitivePtr *primitive, int64_t out_channel);
2060 ActivationType MindIR_Conv2DFusion_GetActivationType(ConstPrimitivePtr primitive);
2061 void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationType activation_type);
2062+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive);
2063+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format);
2064 
2065 // ********** Conv2dTransposeFusion **********
2066 PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive(
2067diff --git a/mindspore/lite/mindir/src/mindir.cc b/mindspore/lite/mindir/src/mindir.cc
2068index 7fc9c00e..374bbef5 100644
2069--- a/mindspore/lite/mindir/src/mindir.cc
2070+++ b/mindspore/lite/mindir/src/mindir.cc
2071@@ -1215,6 +1215,46 @@ void MindIR_Conv2DFusion_SetActivationType(PrimitivePtr *primitive, ActivationTy
2072   }
2073 }
2074 
2075+Format MindIR_Conv2DFusion_GetFormat(ConstPrimitivePtr primitive) {
2076+  if (primitive != nullptr) {
2077+    auto prim = static_cast<const schema::Primitive *>(primitive);
2078+    auto value = prim->value_as_Conv2DFusion();
2079+    if (prim != nullptr && value != nullptr) {
2080+      return static_cast<Format>(value->format());
2081+    } else {
2082+      Format en = static_cast<Format>(0);
2083+      return en;
2084+    }
2085+  } else {
2086+    Format en = static_cast<Format>(0);
2087+    return en;
2088+  }
2089+}
2090+
2091+void MindIR_Conv2DFusion_SetFormat(PrimitivePtr *primitive, Format format) {
2092+  if (primitive != nullptr && *primitive != nullptr) {
2093+    auto prim = static_cast<schema::Primitive *>(*primitive);
2094+    auto value = prim->value_as_Conv2DFusion();
2095+    if (prim != nullptr && value != nullptr) {
2096+      flatbuffers::FlatBufferBuilder fbb;
2097+      auto ops_offset = schema::CreateConv2DFusion(
2098+        fbb, static_cast<schema::Format>(format),
2099+        fbb.CreateVector(value->kernel_size()->data(), value->kernel_size()->size()),
2100+        fbb.CreateVector(value->stride()->data(), value->stride()->size()),
2101+        fbb.CreateVector(value->dilation()->data(), value->dilation()->size()),
2102+        static_cast<schema::PadMode>(value->pad_mode()),
2103+        fbb.CreateVector(value->pad_list()->data(), value->pad_list()->size()), 0, value->group(), value->in_channel(),
2104+        value->out_channel(), static_cast<schema::ActivationType>(value->activation_type()));
2105+      auto prim_offset =
2106+        schema::CreatePrimitive(fbb, static_cast<schema::PrimitiveType>(NODE_TYPE_CONV2D_FUSION), ops_offset.o);
2107+      fbb.Finish(prim_offset);
2108+      auto new_addr = MindIRMemoryManager::GetInstance()->CreatePrimitiveFromBuilder(fbb, prim);
2109+      auto ret_value = flatbuffers::GetMutableRoot<schema::Primitive>(new_addr);
2110+      *primitive = ret_value;
2111+    }
2112+  }
2113+}
2114+
2115 // ********** Conv2dTransposeFusion **********
2116 PrimitivePtr MindIR_Conv2dTransposeFusion_CreatePrimitive(
2117   const std::vector<int64_t> &kernel_size, const std::vector<int64_t> &stride, const std::vector<int64_t> &dilation,
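A hedged usage sketch of the new accessors: because the underlying flatbuffer is immutable, MindIR_Conv2DFusion_SetFormat rebuilds the primitive and may swap the pointer, which is why it takes PrimitivePtr *. The helper name below is hypothetical.

    void RoundTripConvFormat(PrimitivePtr *prim) {
      // Read the current layout, then write it back; *prim may afterwards point
      // at a freshly built primitive managed by MindIRMemoryManager.
      Format fmt = MindIR_Conv2DFusion_GetFormat(*prim);
      MindIR_Conv2DFusion_SetFormat(prim, fmt);
    }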
2118diff --git a/mindspore/lite/mindir/src/mindir_tensor.cc b/mindspore/lite/mindir/src/mindir_tensor.cc
2119index 9ec2d0e4..2db4ce8b 100644
2120--- a/mindspore/lite/mindir/src/mindir_tensor.cc
2121+++ b/mindspore/lite/mindir/src/mindir_tensor.cc
2122@@ -134,7 +134,7 @@ void MindIR_Tensor_SetDataType(TensorPtr *tensor, DataType data_type) {
2123         name = fbb.CreateString(value->name()->c_str(), value->name()->size());
2124       }
2125       auto ops_offset =
2126-        schema::CreateTensor(fbb, 0, value->dataType(), dims, static_cast<schema::Format>(value->format()), 0, 0, data,
2127+        schema::CreateTensor(fbb, 0, data_type, dims, static_cast<schema::Format>(value->format()), 0, 0, data,
2128                              ConvertQuantParams(fbb, value->quantParams()), 0, name);
2129       fbb.Finish(ops_offset);
2130       auto new_addr = MindIRMemoryManager::GetInstance()->CreateTensorFromBuilder(fbb, value);
2131diff --git a/mindspore/lite/mindir/src/utils.cc b/mindspore/lite/mindir/src/utils.cc
2132index 28d66ceb..b044f414 100644
2133--- a/mindspore/lite/mindir/src/utils.cc
2134+++ b/mindspore/lite/mindir/src/utils.cc
2135@@ -22,7 +22,7 @@ namespace lite {
2136 
2137 // ********** PrimitiveBase **********
2138 NodeType MindIR_Primitive_GetType(PrimitivePtr primitive) {
2139-  auto prim = flatbuffers::GetMutableRoot<schema::Primitive>(primitive);
2140+  auto prim = static_cast<schema::Primitive *>(primitive);
2141   auto type = prim->value_type();
2142   return static_cast<NodeType>(type);
2143 }
2144diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt
2145index 5afccc87..de1781cd 100644
2146--- a/mindspore/lite/src/CMakeLists.txt
2147+++ b/mindspore/lite/src/CMakeLists.txt
2148@@ -410,6 +410,11 @@ add_subdirectory(common)
2149 add_library(lite_src_mid OBJECT ${LITE_SRC})
2150 add_dependencies(lite_src_mid lite_src_common_mid fbs_src fbs_inner_src)
2151 
2152+if(SUPPORT_NNRT)
2153+    add_subdirectory(litert/delegate/nnrt)
2154+    target_link_libraries(lite_src_mid nnrt_mid)
2155+endif()
2156+
2157 if(MSLITE_ENABLE_ACL)
2158     include_directories(${TOP_DIR}/graphengine/910/inc/external)
2159     if(NOT (MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE))
2160@@ -497,7 +502,6 @@ if(MSLITE_ENABLE_MINDRT)
2161 endif()
2162 
2163 if (SUPPORT_NNRT)
2164-    add_subdirectory(litert/delegate/nnrt)
2165     target_link_libraries(mindspore-lite nnrt_mid)
2166     target_link_libraries(mindspore-lite_static nnrt_mid)
2167 endif()
2168diff --git a/mindspore/lite/src/common/context_util.cc b/mindspore/lite/src/common/context_util.cc
2169index f011e0d7..0fa4ebd0 100644
2170--- a/mindspore/lite/src/common/context_util.cc
2171+++ b/mindspore/lite/src/common/context_util.cc
2172@@ -118,6 +118,17 @@ std::shared_ptr<mindspore::DeviceInfoContext> CustomDeviceInfoFromCustomDeviceCo
2173   MS_CHECK_TRUE_RET(device_info != nullptr, nullptr);
2174   return device_info;
2175 }
2176+
2177+std::shared_ptr<mindspore::NNRTDeviceInfo> NNRtDeviceInfoFromNNRtDeviceContext(
2178+  const lite::DeviceContext &nnrt_context) {
2179+  if (nnrt_context.device_type_ != DT_NNRT) {
2180+    MS_LOG(ERROR) << "Function input parameter is not NNRt context.";
2181+    return nullptr;
2182+  }
2183+  auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
2184+  MS_CHECK_TRUE_RET(nnrt_info != nullptr, nullptr);
2185+  return nnrt_info;
2186+}
2187 }  // namespace
2188 
2189 mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &context) {
2190@@ -144,7 +155,8 @@ mindspore::Context *MSContextFromContext(const std::shared_ptr<InnerContext> &co
2191                       {DT_GPU, GPUDeviceInfoFromGPUDeviceContext},
2192                       {DT_NPU, NPUDeviceInfoFromNPUDeviceContext},
2193                       {DT_ASCEND, AscendDeviceInfoFromAscendDeviceContext},
2194-                      {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext}};
2195+                      {DT_CUSTOM, CustomDeviceInfoFromCustomDeviceContext},
2196+                      {DT_NNRT, NNRtDeviceInfoFromNNRtDeviceContext}};
2197   for (auto &device_context : context->device_list_) {
2198     auto device_type = device_context.device_type_;
2199     if (transfer_funcs.find(device_type) == transfer_funcs.end()) {
2200diff --git a/mindspore/lite/src/common/log.cc b/mindspore/lite/src/common/log.cc
2201index 66c0d76b..f1040662 100644
2202--- a/mindspore/lite/src/common/log.cc
2203+++ b/mindspore/lite/src/common/log.cc
2204@@ -21,6 +21,13 @@
2205 #include <android/log.h>
2206 #endif
2207 
2208+#ifdef MS_COMPILE_OHOS
2209+#define LOG_DOMAIN 0xD002102
2210+#define LOG_TAG "MS_LITE"
2211+#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2212+#include "hilog/log.h"
2213+#endif
2214+
2215 // namespace to support utils module definition namespace mindspore constexpr const char *ANDROID_LOG_TAG = "MS_LITE";
2216 namespace mindspore {
2217 #if defined(__ANDROID__)
2218@@ -73,17 +80,33 @@ static int GetAndroidLogLevel(LiteLogLevel level) {
2219 
2220 #ifdef MS_COMPILE_OHOS
2221 void PrintHiLog(LiteLogLevel level, const char *file, int line, const char *func, const char *msg) {
2222+#ifdef MS_COMPILE_WITH_OHOS_NDK
2223+  // When building with the OHOS NDK, use the public API of the hilog module.
2224   if (level == LiteLogLevel::DEBUG) {
2225-    OHOS::HiviewDFX::HiLog::Debug(MSLite_LABEL, FORMAT, file, line, func, msg);
2226+    OH_LOG_Print(LOG_APP, LOG_DEBUG, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2227   } else if (level == LiteLogLevel::INFO) {
2228-    OHOS::HiviewDFX::HiLog::Info(MSLite_LABEL, FORMAT, file, line, func, msg);
2229+    OH_LOG_Print(LOG_APP, LOG_INFO, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2230   } else if (level == LiteLogLevel::WARNING) {
2231-    OHOS::HiviewDFX::HiLog::Warn(MSLite_LABEL, FORMAT, file, line, func, msg);
2232+    OH_LOG_Print(LOG_APP, LOG_WARN, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2233   } else if (level == LiteLogLevel::ERROR) {
2234-    OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg);
2235+    OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2236   } else {
2237-    OHOS::HiviewDFX::HiLog::Error(MSLite_LABEL, FORMAT, file, line, func, msg);
2238+    OH_LOG_Print(LOG_APP, LOG_ERROR, LOG_DOMAIN, LOG_TAG, FORMAT, file, line, func, msg);
2239   }
2240+#else
2241+  // When building in the OHOS repo, use the inner API of the hilog module.
2242+  if (level == LiteLogLevel::DEBUG) {
2243+    HILOG_DEBUG(LOG_APP, FORMAT, file, line, func, msg);
2244+  } else if (level == LiteLogLevel::INFO) {
2245+    HILOG_INFO(LOG_APP, FORMAT, file, line, func, msg);
2246+  } else if (level == LiteLogLevel::WARNING) {
2247+    HILOG_WARN(LOG_APP, FORMAT, file, line, func, msg);
2248+  } else if (level == LiteLogLevel::ERROR) {
2249+    HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg);
2250+  } else {
2251+    HILOG_ERROR(LOG_APP, FORMAT, file, line, func, msg);
2252+  }
2253+#endif
2254 }
2255 #endif
2256 
2257diff --git a/mindspore/lite/src/common/log.h b/mindspore/lite/src/common/log.h
2258index 3002a454..bea21f01 100644
2259--- a/mindspore/lite/src/common/log.h
2260+++ b/mindspore/lite/src/common/log.h
2261@@ -23,12 +23,6 @@
2262 #include <unordered_map>
2263 #include "utils/overload.h"
2264 
2265-#ifdef MS_COMPILE_OHOS
2266-#define LOG_DOMAIN 0x2102
2267-#define LOG_TAG "MS_Lite"
2268-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2269-#include "hilog/log.h"
2270-#endif
2271 // NOTICE: when relative path of 'log.h' changed, macro 'LITE_LOG_HEAR_FILE_REL_PATH' must be changed
2272 #ifndef LITE_LOG_HEAR_FILE_REL_PATH
2273 #define LITE_LOG_HEAR_FILE_REL_PATH "mindspore/lite/src/common/log.h"
2274@@ -140,6 +134,9 @@ class LiteLogWriter {
2275   LiteLogLevel log_level_;
2276 };
2277 
2278+#define MSLOG_IF(level)                                                                                  \
2279+  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2280+    mindspore::LiteLogStream()
2281 
2282 #define MS_LOG(level) MS_LOG_##level
2283 
2284@@ -148,47 +145,6 @@ class LiteLogWriter {
2285 #define MS_LOG_WARNING MSLOG_IF(mindspore::LiteLogLevel::WARNING)
2286 #define MS_LOG_ERROR MSLOG_IF(mindspore::LiteLogLevel::ERROR)
2287 
2288-
2289-#ifdef MS_COMPILE_OHOS
2290-namespace {
2291-constexpr unsigned int MSLITE_DOMAIN_ID_START = 0xD0029A0;
2292-constexpr unsigned int MSLITE_DOMAIN_ID_END = MSLITE_DOMAIN_ID_START + 32;
2293-constexpr unsigned int TEST_DOMAIN_ID = 0xD000F00;
2294-}  // namespace
2295-
2296-#define FILE_NAME (__builtin_strrchr(__FILE__, '/') ? __builtin_strrchr(__FILE__, '/') + 1 : __FILE__)
2297-#define FORMAT "[%{public}s:%{public}d] %{public}s# %{public}s"
2298-
2299-#define MSLOG_IF(level)                                                                                  \
2300-  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2301-    mindspore::LiteLogStream()
2302-
2303-enum MSLiteManagerLogLabel {
2304-  // Component labels, you can add if needed
2305-  COMP_FWK = 0,
2306-  // Test label
2307-  LABEL_TEST,
2308-  // The end of labels, max to the domain id range length 32
2309-  LABEL_END,
2310-};
2311-
2312-enum MSLiteManagerLogDomain {
2313-  DOMAIN_FRAMEWORK = MSLITE_DOMAIN_ID_START + COMP_FWK,  // 0xD0029A0
2314-  DOMAIN_TEST = TEST_DOMAIN_ID,                          // 0xD000F00
2315-  DOMAIN_END = MSLITE_DOMAIN_ID_END,  // Max to 0xD002940, keep the sequence and length same as MSLiteManagerLogLabel
2316-};
2317-
2318-// Keep the sequence and length same as MSLiteManagerLogDomain
2319-static constexpr OHOS::HiviewDFX::HiLogLabel MSLite_LABEL = {LOG_CORE, DOMAIN_FRAMEWORK, "MSLiteFwk"};
2320-
2321-#else
2322-
2323-#define MSLOG_IF(level)                                                                                  \
2324-  mindspore::LiteLogWriter(mindspore::LiteLocationInfo(LITE_FILE_NAME, __LINE__, __FUNCTION__), level) < \
2325-    mindspore::LiteLogStream()
2326-
2327-#endif
2328-
2329 }  // namespace mindspore
2330 
2331 #ifdef Debug
2332diff --git a/mindspore/lite/src/common/ops/populate/custom_populate.cc b/mindspore/lite/src/common/ops/populate/custom_populate.cc
2333index 5e1878b9..13957ed7 100644
2334--- a/mindspore/lite/src/common/ops/populate/custom_populate.cc
2335+++ b/mindspore/lite/src/common/ops/populate/custom_populate.cc
2336@@ -19,6 +19,9 @@
2337 #include "nnacl/custom_parameter.h"
2338 #include "nnacl/split_parameter.h"
2339 #include "nnacl/custom_gru_parameter.h"
2340+#include "nnacl/custom_masked_fill_parameter.h"
2341+#include "nnacl/custom_is_inf_parameter.h"
2342+#include "nnacl/custom_tensor_scatter_max_parameter.h"
2343 using mindspore::schema::PrimitiveType_Custom;
2344 
2345 namespace mindspore {
2346@@ -92,6 +95,39 @@ OpParameter *CreateCustomGruParameter() {
2347   return reinterpret_cast<OpParameter *>(param);
2348 }
2349 
2350+OpParameter *CreateCustomIsInfParameter() {
2351+  auto *param = static_cast<CustomIsInfParameter *>(malloc(sizeof(CustomIsInfParameter)));
2352+  if (param == nullptr) {
2353+    MS_LOG(ERROR) << "malloc CustomIsInfParameter failed.";
2354+    return nullptr;
2355+  }
2356+  memset(param, 0, sizeof(CustomIsInfParameter));
2357+  param->op_parameter_.type_ = PrimType_Inner_CustomIsInf;
2358+  return reinterpret_cast<OpParameter *>(param);
2359+}
2360+
2361+OpParameter *CreateCustomTensorScatterMaxParameter() {
2362+  auto *param = static_cast<CustomTensorScatterMaxParameter *>(malloc(sizeof(CustomTensorScatterMaxParameter)));
2363+  if (param == nullptr) {
2364+    MS_LOG(ERROR) << "malloc CustomTensorScatterMaxParameter failed.";
2365+    return nullptr;
2366+  }
2367+  memset(param, 0, sizeof(CustomTensorScatterMaxParameter));
2368+  param->op_parameter_.type_ = PrimType_Inner_CustomTensorScatterMax;
2369+  return reinterpret_cast<OpParameter *>(param);
2370+}
2371+
2372+OpParameter *CreateCustomMaskedFillParameter() {
2373+  auto *param = static_cast<CustomMaskedFillParameter *>(malloc(sizeof(CustomMaskedFillParameter)));
2374+  if (param == nullptr) {
2375+    MS_LOG(ERROR) << "malloc CustomMaskedFillParameter failed.";
2376+    return nullptr;
2377+  }
2378+  memset(param, 0, sizeof(CustomMaskedFillParameter));
2379+  param->op_parameter_.type_ = PrimType_Inner_CustomMaskedFill;
2380+  return reinterpret_cast<OpParameter *>(param);
2381+}
2382+
2383 OpParameter *PopulateCustomParameter(const void *prim) {
2384   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
2385   auto primitive = static_cast<const schema::Primitive *>(prim);
2386@@ -131,6 +167,23 @@ OpParameter *PopulateCustomParameter(const void *prim) {
2387     return CreateCustomGruParameter();
2388   } else if (type == "CastGatherReduceFusion") {
2389     return CreateParam(PrimType_Inner_CastGatherReduceFusion);
2390+  } else if (type == "ThirdPartyModel") {
2391+    auto *param = static_cast<CustomParameter *>(malloc(sizeof(CustomParameter)));
2392+    if (param == nullptr) {
2393+      MS_LOG(ERROR) << "malloc CustomParameter failed.";
2394+      return nullptr;
2395+    }
2396+    memset(param, 0, sizeof(CustomParameter));
2397+    param->op_parameter_.type_ = PrimType_Inner_ThirdPartyModel;
2398+    // Save the prim pointer directly in attr_data[0]; the inner values are parsed when needed.
2399+    param->attr_data[0] = static_cast<char *>(const_cast<void *>(prim));
2400+    return reinterpret_cast<OpParameter *>(param);
2401+  } else if (type == "MaskedFill") {
2402+    return CreateCustomMaskedFillParameter();
2403+  } else if (type == "TensorScatterMax") {
2404+    return CreateCustomTensorScatterMaxParameter();
2405+  } else if (type == "IsInf") {
2406+    return CreateCustomIsInfParameter();
2407   } else {
2408     MS_LOG(ERROR) << "Unsupported custom type: " << type;
2409   }
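The three Create*Parameter helpers added above repeat one malloc/memset/type-tag pattern; a hedged sketch of a generic helper that would capture it (hypothetical, not part of the patch):

    template <typename ParamT>
    OpParameter *CreateZeroedParameter(int prim_type) {
      auto *param = static_cast<ParamT *>(malloc(sizeof(ParamT)));
      if (param == nullptr) {
        MS_LOG(ERROR) << "malloc parameter failed.";
        return nullptr;
      }
      memset(param, 0, sizeof(ParamT));
      param->op_parameter_.type_ = prim_type;
      return reinterpret_cast<OpParameter *>(param);
    }
    // e.g. return CreateZeroedParameter<CustomIsInfParameter>(PrimType_Inner_CustomIsInf);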
2410diff --git a/mindspore/lite/src/litert/c_api/context_c.cc b/mindspore/lite/src/litert/c_api/context_c.cc
2411index f614ef09..c5f825aa 100644
2412--- a/mindspore/lite/src/litert/c_api/context_c.cc
2413+++ b/mindspore/lite/src/litert/c_api/context_c.cc
2414@@ -14,12 +14,17 @@
2415  * limitations under the License.
2416  */
2417 #include "include/c_api/context_c.h"
2418-#include "src/litert/c_api/context_c.h"
2419+#include "include/api/context.h"
2420+#include <string.h>
2421+#include "src/litert/c_api/type_c_private.h"
2422 #include "src/common/log_adapter.h"
2423+#ifdef SUPPORT_NNRT
2424+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
2425+#endif
2426 
2427 // ================ Context ================
2428 OH_AI_ContextHandle OH_AI_ContextCreate() {
2429-  auto impl = new (std::nothrow) mindspore::ContextC;
2430+  auto impl = new (std::nothrow) mindspore::Context();
2431   if (impl == nullptr) {
2432     MS_LOG(ERROR) << "memory allocation failed.";
2433     return nullptr;
2434@@ -29,7 +34,7 @@ OH_AI_ContextHandle OH_AI_ContextCreate() {
2435 
2436 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
2437   if (context != nullptr && *context != nullptr) {
2438-    auto impl = static_cast<mindspore::ContextC *>(*context);
2439+    auto impl = static_cast<mindspore::Context *>(*context);
2440     delete impl;
2441     *context = nullptr;
2442   }
2443@@ -40,8 +45,8 @@ void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num)
2444     MS_LOG(ERROR) << "param is nullptr.";
2445     return;
2446   }
2447-  auto impl = static_cast<mindspore::ContextC *>(context);
2448-  impl->thread_num = thread_num;
2449+  auto impl = static_cast<mindspore::Context *>(context);
2450+  impl->SetThreadNum(thread_num);
2451 }
2452 
2453 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
2454@@ -49,8 +54,8 @@ int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
2455     MS_LOG(ERROR) << "param is nullptr.";
2456     return 0;
2457   }
2458-  auto impl = static_cast<mindspore::ContextC *>(context);
2459-  return impl->thread_num;
2460+  auto impl = static_cast<mindspore::Context *>(context);
2461+  return impl->GetThreadNum();
2462 }
2463 
2464 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
2465@@ -58,8 +63,8 @@ void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
2466     MS_LOG(ERROR) << "param is nullptr.";
2467     return;
2468   }
2469-  auto impl = static_cast<mindspore::ContextC *>(context);
2470-  impl->affinity_mode = mode;
2471+  auto impl = static_cast<mindspore::Context *>(context);
2472+  impl->SetThreadAffinity(mode);
2473   return;
2474 }
2475 
2476@@ -68,8 +73,8 @@ int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
2477     MS_LOG(ERROR) << "param is nullptr.";
2478     return 0;
2479   }
2480-  auto impl = static_cast<mindspore::ContextC *>(context);
2481-  return impl->affinity_mode;
2482+  auto impl = static_cast<mindspore::Context *>(context);
2483+  return impl->GetThreadAffinityMode();
2484 }
2485 
2486 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
2487@@ -78,8 +83,8 @@ void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const i
2488     return;
2489   }
2490   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
2491-  auto impl = static_cast<mindspore::ContextC *>(context);
2492-  impl->affinity_core_list = vec_core_list;
2493+  auto impl = static_cast<mindspore::Context *>(context);
2494+  impl->SetThreadAffinity(vec_core_list);
2495   return;
2496 }
2497 
2498@@ -88,9 +93,18 @@ const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle
2499     MS_LOG(ERROR) << "param is nullptr.";
2500     return nullptr;
2501   }
2502-  auto impl = static_cast<mindspore::ContextC *>(context);
2503-  *core_num = impl->affinity_core_list.size();
2504-  return impl->affinity_core_list.data();
2505+  auto impl = static_cast<mindspore::Context *>(context);
2506+  auto affinity_core_list = impl->GetThreadAffinityCoreList();
2507+  *core_num = affinity_core_list.size();
2508+  int32_t *core_list = static_cast<int32_t *>(malloc((*core_num) * sizeof(int32_t)));
2509+  if (core_list == nullptr) {
2510+    MS_LOG(ERROR) << "malloc core_list is null.";
2511+    return nullptr;
2512+  }
2513+  for (size_t i = 0; i < affinity_core_list.size(); i++) {
2514+    core_list[i] = affinity_core_list[i];
2515+  }
2516+  return core_list;
2517 }
2518 
2519 void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) {
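With this change OH_AI_ContextGetThreadAffinityCoreList returns a freshly malloc'd copy instead of a pointer into the context, so the caller now owns the buffer; a hedged usage sketch:

    size_t core_num = 0;
    const int32_t *core_list = OH_AI_ContextGetThreadAffinityCoreList(context, &core_num);
    if (core_list != NULL) {
      for (size_t i = 0; i < core_num; ++i) {
        // ... consume core_list[i] ...
      }
      free((void *)core_list);  // the copy is owned by the caller after this change
    }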
2520@@ -98,8 +112,8 @@ void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_paralle
2521     MS_LOG(ERROR) << "param is nullptr.";
2522     return;
2523   }
2524-  auto impl = static_cast<mindspore::ContextC *>(context);
2525-  impl->enable_parallel = is_parallel;
2526+  auto impl = static_cast<mindspore::Context *>(context);
2527+  impl->SetEnableParallel(is_parallel);
2528 }
2529 
2530 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
2531@@ -107,8 +121,8 @@ bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
2532     MS_LOG(ERROR) << "param is nullptr.";
2533     return false;
2534   }
2535-  auto impl = static_cast<mindspore::ContextC *>(context);
2536-  return impl->enable_parallel;
2537+  auto impl = static_cast<mindspore::Context *>(context);
2538+  return impl->GetEnableParallel();
2539 }
2540 
2541 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
2542@@ -116,25 +130,36 @@ void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHan
2543     MS_LOG(ERROR) << "param is nullptr.";
2544     return;
2545   }
2546-  auto impl = static_cast<mindspore::ContextC *>(context);
2547-  std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
2548-  impl->device_info_list.push_back(device);
2549+  auto impl = static_cast<mindspore::Context *>(context);
2550+  std::shared_ptr<mindspore::DeviceInfoContext> device(static_cast<mindspore::DeviceInfoContext *>(device_info));
2551+  impl->MutableDeviceInfo().push_back(device);
2552 }
2553 
2554 // ================ DeviceInfo ================
2555 OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) {
2556-  mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
2557+  mindspore::DeviceInfoContext *impl;
2558+  if (OH_AI_DEVICETYPE_CPU == device_type) {
2559+    impl = new (std::nothrow) mindspore::CPUDeviceInfo();
2560+  } else if (OH_AI_DEVICETYPE_GPU == device_type) {
2561+    impl = new (std::nothrow) mindspore::GPUDeviceInfo();
2562+  } else if (OH_AI_DEVICETYPE_KIRIN_NPU == device_type) {
2563+    impl = new (std::nothrow) mindspore::KirinNPUDeviceInfo();
2564+  } else if (OH_AI_DEVICETYPE_NNRT == device_type) {
2565+    impl = new (std::nothrow) mindspore::NNRTDeviceInfo();
2566+  } else {
2567+    MS_LOG(ERROR) << "device_type is invalid.";
2568+    impl = nullptr;
2569+  }
2570   if (impl == nullptr) {
2571     MS_LOG(ERROR) << "memory allocation failed.";
2572     return nullptr;
2573   }
2574-  impl->device_type = device_type;
2575   return static_cast<OH_AI_DeviceInfoHandle>(impl);
2576 }
2577 
2578 void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) {
2579   if (device_info != nullptr && *device_info != nullptr) {
2580-    auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
2581+    auto impl = static_cast<mindspore::DeviceInfoContext *>(*device_info);
2582     delete impl;
2583     *device_info = nullptr;
2584   }
2585@@ -145,8 +170,8 @@ void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char
2586     MS_LOG(ERROR) << "param is nullptr.";
2587     return;
2588   }
2589-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2590-  impl->provider = provider;
2591+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2592+  impl->SetProvider(provider);
2593 }
2594 
2595 const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) {
2596@@ -154,8 +179,14 @@ const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info
2597     MS_LOG(ERROR) << "param is nullptr.";
2598     return nullptr;
2599   }
2600-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2601-  return impl->provider.c_str();
2602+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
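+  // Return a heap-allocated copy of the provider string; the caller is responsible for freeing it.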
2603+  char *provider = static_cast<char *>(malloc(impl->GetProvider().size() + 1));
2604+  if (provider == nullptr) {
+    MS_LOG(ERROR) << "malloc provider failed.";
2606+    return nullptr;
2607+  }
2608+  strcpy(provider, impl->GetProvider().c_str());
2609+  return provider;
2610 }
2611 
2612 void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) {
2613@@ -163,8 +194,8 @@ void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const
2614     MS_LOG(ERROR) << "param is nullptr.";
2615     return;
2616   }
2617-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2618-  impl->provider_device = device;
2619+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2620+  impl->SetProviderDevice(device);
2621 }
2622 
2623 const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) {
2624@@ -172,8 +203,14 @@ const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle devic
2625     MS_LOG(ERROR) << "param is nullptr.";
2626     return nullptr;
2627   }
2628-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2629-  return impl->provider_device.c_str();
2630+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
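+  // As with the provider above, the returned string is heap-allocated and owned by the caller.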
2631+  char *provider_device = static_cast<char *>(malloc(impl->GetProviderDevice().size() + 1));
2632+  if (provider_device == nullptr) {
+    MS_LOG(ERROR) << "malloc provider_device failed.";
2634+    return nullptr;
2635+  }
2636+  strcpy(provider_device, impl->GetProviderDevice().c_str());
2637+  return provider_device;
2638 }
2639 
2640 OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) {
2641@@ -181,8 +218,8 @@ OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle devi
2642     MS_LOG(ERROR) << "param is nullptr.";
2643     return OH_AI_DEVICETYPE_INVALID;
2644   }
2645-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2646-  return impl->device_type;
2647+  auto impl = static_cast<mindspore::DeviceInfoContext *>(device_info);
2648+  return static_cast<OH_AI_DeviceType>(impl->GetDeviceType());
2649 }
2650 
2651 void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) {
2652@@ -190,9 +227,17 @@ void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_f
2653     MS_LOG(ERROR) << "param is nullptr.";
2654     return;
2655   }
2656-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2657-  if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
2658-    impl->enable_fp16 = is_fp16;
2659+
2660+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2661+  if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2662+    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
2663+    impl->SetEnableFP16(is_fp16);
2664+  } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2665+    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
2666+    impl->SetEnableFP16(is_fp16);
2667+  } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2668+    auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info);
2669+    impl->SetEnableFP16(is_fp16);
2670   } else {
2671     MS_LOG(ERROR) << "Unsupported Feature.";
2672   }
2673@@ -203,11 +248,19 @@ bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) {
2674     MS_LOG(ERROR) << "param is nullptr.";
2675     return false;
2676   }
2677-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2678-  if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
2679-    return impl->enable_fp16;
2680+
2681+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2682+  if (OH_AI_DEVICETYPE_CPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2683+    auto impl = static_cast<mindspore::CPUDeviceInfo *>(device_info);
2684+    return impl->GetEnableFP16();
2685+  } else if (OH_AI_DEVICETYPE_GPU == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2686+    auto impl = static_cast<mindspore::GPUDeviceInfo *>(device_info);
2687+    return impl->GetEnableFP16();
2688+  } else if (OH_AI_DEVICETYPE_NNRT == static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType())) {
2689+    auto impl = static_cast<mindspore::NNRTDeviceInfo *>(device_info);
2690+    return impl->GetEnableFP16();
2691   } else {
2692-    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
2693+    MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl_device->GetDeviceType();
2694     return false;
2695   }
2696 }
2697@@ -217,9 +270,10 @@ void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int freque
2698     MS_LOG(ERROR) << "param is nullptr.";
2699     return;
2700   }
2701-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2702-  if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
2703-    impl->frequency = frequency;
2704+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2705+  if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) {
2706+    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
2707+    impl->SetFrequency(frequency);
2708   } else {
2709     MS_LOG(ERROR) << "Unsupported Feature.";
2710   }
2711@@ -230,11 +284,231 @@ int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) {  //
2712     MS_LOG(ERROR) << "param is nullptr.";
2713     return -1;
2714   }
2715-  auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
2716-  if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
2717-    return impl->frequency;
2718+  auto impl_device = static_cast<mindspore::DeviceInfoContext *>(device_info);
2719+  if (static_cast<OH_AI_DeviceType>(impl_device->GetDeviceType()) == OH_AI_DEVICETYPE_KIRIN_NPU) {
2720+    auto impl = static_cast<mindspore::KirinNPUDeviceInfo *>(device_info);
2721+    return impl->GetFrequency();
2722   } else {
2723     MS_LOG(ERROR) << "Unsupported Feature.";
2724     return -1;
2725   }
2726 }
2727+
2728+NNRTDeviceDesc *OH_AI_GetAllNNRTDeviceDescs(size_t *num) {
2729+  if (num == nullptr) {
2730+    MS_LOG(ERROR) << "Input num is null";
2731+    return nullptr;
2732+  }
2733+#ifdef SUPPORT_NNRT
2734+  *num = 0;
2735+
2736+  const size_t *all_device_ids;
2737+  uint32_t device_count;
2738+  auto ret = OH_NNDevice_GetAllDevicesID(&all_device_ids, &device_count);
2739+  if ((ret != OH_NN_SUCCESS) || (device_count == 0)) {
2740+    MS_LOG(ERROR) << "NNRT get all device id failed, ret: " << ret;
2741+    return nullptr;
2742+  }
2743+
2744+  NNRTDeviceDesc *desc = (NNRTDeviceDesc *)malloc(sizeof(NNRTDeviceDesc) * device_count);
2745+  if (desc == nullptr) {
2746+    MS_LOG(ERROR) << "NNRT allocate desc failed";
2747+    return nullptr;
2748+  }
2749+
2750+  for (uint32_t i = 0; i < device_count; i++) {
2751+    desc[i].device_id = all_device_ids[i];
2752+    OH_NN_DeviceType type;
2753+    (void)OH_NNDevice_GetType(all_device_ids[i], &type);
2754+    desc[i].device_type = static_cast<OH_AI_NNRTDeviceType>(type);
2755+
2756+    const char *name = nullptr;
2757+    (void)OH_NNDevice_GetName(all_device_ids[i], &name);
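+    // Pre-set the terminator: strncpy does not NUL-terminate when the source fills the buffer.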
2758+    desc[i].device_name[127] = '\0';
2759+    strncpy(desc[i].device_name, name, 127);
2760+  }
2761+  *num = device_count;
2762+  return desc;
2763+#else
2764+  return nullptr;
2765+#endif
2766+}
2767+
2768+NNRTDeviceDesc *OH_AI_GetElementOfNNRTDeviceDescs(NNRTDeviceDesc *descs, size_t index) {
2769+  if (descs == nullptr) {
2770+    MS_LOG(ERROR) << "descs is null";
2771+    return nullptr;
2772+  }
2773+  return descs + index;
2774+}
2775+
2776+void OH_AI_DestroyAllNNRTDeviceDescs(NNRTDeviceDesc **desc) {
2777+  if (desc == nullptr) {
2778+    MS_LOG(WARNING) << "desc is null";
2779+    return;
2780+  }
2781+  free(*desc);
2782+  *desc = nullptr;
2783+}
2784+
2785+size_t OH_AI_GetDeviceIdFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2786+  if (desc == nullptr) {
2787+    MS_LOG(ERROR) << "NNRT desc is null";
2788+    return 0;
2789+  }
2790+  return desc->device_id;
2791+}
2792+
2793+const char *OH_AI_GetNameFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2794+  if (desc == nullptr) {
2795+    MS_LOG(ERROR) << "NNRT desc is null";
2796+    return nullptr;
2797+  }
2798+  return desc->device_name;
2799+}
2800+
2801+OH_AI_NNRTDeviceType OH_AI_GetTypeFromNNRTDeviceDesc(const NNRTDeviceDesc *desc) {
2802+  if (desc == nullptr) {
2803+    MS_LOG(ERROR) << "NNRT desc is null";
2804+    return OH_AI_NNRTDeviceType::OH_AI_NNRTDEVICE_OTHERS;
2805+  }
2806+  return desc->device_type;
2807+}
2808+
2809+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByName(const char *name) {
2810+  size_t num = 0;
2811+  NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
2812+  if (desc == nullptr) {
2813+    MS_LOG(ERROR) << "Get all device desc failed";
2814+    return nullptr;
2815+  }
2816+
2817+  OH_AI_DeviceInfoHandle handle = nullptr;
2818+  for (size_t i = 0; i < num; i++) {
2819+    if (strncmp(desc[i].device_name, name, NNRT_DEVICE_NAME_MAX - 1) == 0) {
2820+      handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
2821+      OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id);
2822+      break;
2823+    }
2824+  }
2825+  OH_AI_DestroyAllNNRTDeviceDescs(&desc);
2826+  return handle;
2827+}
2828+
2829+OH_AI_DeviceInfoHandle OH_AI_CreateNNRTDeviceInfoByType(OH_AI_NNRTDeviceType type) {
2830+  size_t num = 0;
2831+  NNRTDeviceDesc *desc = OH_AI_GetAllNNRTDeviceDescs(&num);
2832+  if (desc == nullptr) {
2833+    MS_LOG(ERROR) << "Get all device desc failed";
2834+    return nullptr;
2835+  }
2836+
2837+  OH_AI_DeviceInfoHandle handle = nullptr;
2838+  for (size_t i = 0; i < num; i++) {
2839+    if (desc[i].device_type == type) {
2840+      handle = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
2841+      OH_AI_DeviceInfoSetDeviceId(handle, desc[i].device_id);
2842+      break;
2843+    }
2844+  }
2845+  OH_AI_DestroyAllNNRTDeviceDescs(&desc);
2846+  return handle;
2847+}
2848+
2849+void OH_AI_DeviceInfoSetDeviceId(OH_AI_DeviceInfoHandle device_info, size_t device_id) {
2850+  if (device_info == nullptr) {
2851+    MS_LOG(ERROR) << "device info is null";
2852+    return;
2853+  }
2854+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Setting device_id on a non-NNRT device is not allowed, ignored";
2856+    return;
2857+  }
2858+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2859+  impl->SetDeviceID(device_id);
2860+}
2861+
2862+size_t OH_AI_DeviceInfoGetDeviceId(const OH_AI_DeviceInfoHandle device_info) {
2863+  if (device_info == nullptr) {
2864+    MS_LOG(ERROR) << "device info is null";
2865+    return 0;
2866+  }
2867+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Getting device_id of a non-NNRT device is not allowed, ignored";
2869+    return 0;
2870+  }
2871+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2872+  return impl->GetDeviceID();
2873+}
2874+
2875+void OH_AI_DeviceInfoSetPerformanceMode(OH_AI_DeviceInfoHandle device_info, OH_AI_PerformanceMode mode) {
2876+  if (device_info == nullptr) {
2877+    MS_LOG(ERROR) << "device info is null";
2878+    return;
2879+  }
2880+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Setting performance_mode on a non-NNRT device is not allowed, ignored";
2882+    return;
2883+  }
2884+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2885+  impl->SetPerformanceMode(mode);
2886+}
2887+
2888+OH_AI_PerformanceMode OH_AI_DeviceInfoGetPerformanceMode(const OH_AI_DeviceInfoHandle device_info) {
2889+  if (device_info == nullptr) {
2890+    MS_LOG(ERROR) << "device info is null";
2891+    return OH_AI_PERFORMANCE_NONE;
2892+  }
2893+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Getting performance_mode of a non-NNRT device is not allowed, ignored";
2895+    return OH_AI_PERFORMANCE_NONE;
2896+  }
2897+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2898+  return static_cast<OH_AI_PerformanceMode>(impl->GetPerformanceMode());
2899+}
2900+
2901+void OH_AI_DeviceInfoSetPriority(OH_AI_DeviceInfoHandle device_info, OH_AI_Priority priority) {
2902+  if (device_info == nullptr) {
2903+    MS_LOG(ERROR) << "device info is null";
2904+    return;
2905+  }
2906+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Setting priority on a non-NNRT device is not allowed, ignored";
2908+    return;
2909+  }
2910+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2911+  impl->SetPriority(priority);
2912+}
2913+
2914+OH_AI_Priority OH_AI_DeviceInfoGetPriority(const OH_AI_DeviceInfoHandle device_info) {
2915+  if (device_info == nullptr) {
2916+    MS_LOG(ERROR) << "device info is null";
2917+    return OH_AI_PRIORITY_NONE;
2918+  }
2919+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Getting priority of a non-NNRT device is not allowed, ignored";
2921+    return OH_AI_PRIORITY_NONE;
2922+  }
2923+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2924+  return static_cast<OH_AI_Priority>(impl->GetPriority());
2925+}
2926+
2927+OH_AI_API OH_AI_Status OH_AI_DeviceInfoAddExtension(OH_AI_DeviceInfoHandle device_info,
+                                                    const char *name, const char *value, size_t value_size) {
+  if (device_info == nullptr || name == nullptr || value == nullptr) {
+    MS_LOG(ERROR) << "device_info/name/value is null";
2931+    return OH_AI_STATUS_LITE_NULLPTR;
2932+  }
2933+  if (OH_AI_DeviceInfoGetDeviceType(device_info) != OH_AI_DEVICETYPE_NNRT) {
+    MS_LOG(ERROR) << "Adding an extension to a non-NNRT device is not allowed, ignored";
2935+    return OH_AI_STATUS_LITE_ERROR;
2936+  }
2937+  auto impl = reinterpret_cast<mindspore::NNRTDeviceInfo *>(device_info);
2938+  mindspore::Extension extension;
2939+  extension.name = std::string(name);
2940+  extension.value = std::vector<uint8_t>(value, value + value_size);
2941+  std::vector<mindspore::Extension> extension_list = impl->GetExtensions();
2942+  extension_list.push_back(extension);
2943+  impl->SetExtensions(extension_list);
2944+  return OH_AI_STATUS_SUCCESS;
2945+}
2946\ No newline at end of file
2947diff --git a/mindspore/lite/src/litert/c_api/context_c.h b/mindspore/lite/src/litert/c_api/context_c.h
2948index 076f4d1f..dc88b8a4 100644
2949--- a/mindspore/lite/src/litert/c_api/context_c.h
2950+++ b/mindspore/lite/src/litert/c_api/context_c.h
2951@@ -21,27 +21,4 @@
2952 #include <memory>
2953 #include "include/c_api/types_c.h"
2954 
2955-namespace mindspore {
2956-class Allocator;
2957-class Delegate;
2958-
2959-typedef struct DeviceInfoC {
2960-  OH_AI_DeviceType device_type;
2961-  bool enable_fp16 = false;
2962-  int frequency = 3;
2963-  std::string provider;
2964-  std::string provider_device;
2965-  std::shared_ptr<Allocator> allocator = nullptr;
2966-} DeviceInfoC;
2967-
2968-typedef struct ContextC {
2969-  std::vector<std::shared_ptr<DeviceInfoC>> device_info_list;
2970-  int32_t thread_num = 2;
2971-  bool enable_parallel = false;
2972-  std::vector<int32_t> affinity_core_list;
2973-  int affinity_mode = 0;
2974-  int delegate_mode = 0;
2975-  std::shared_ptr<Delegate> delegate = nullptr;
2976-} ContextC;
2977-}  // namespace mindspore
2978 #endif  // MINDSPORE_LITE_SRC_RUNTIME_C_API_CONTEXT_C_H_
2979diff --git a/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore/lite/src/litert/c_api/model_c.cc
2980index 802df6b1..9da52d76 100644
2981--- a/mindspore/lite/src/litert/c_api/model_c.cc
2982+++ b/mindspore/lite/src/litert/c_api/model_c.cc
2983@@ -17,321 +17,135 @@
2984 #include <vector>
2985 #include <cstdint>
2986 #include "include/api/context.h"
+#include "include/api/serialization.h"
2988 #include "include/api/types.h"
2989 #include "src/litert/cxx_api/tensor/tensor_impl.h"
2990 #include "src/litert/cxx_api/converters.h"
2991-#include "src/litert/lite_session.h"
2992-#include "src/litert/cpu_info.h"
2993+#include "src/litert//cxx_api/model/model_impl.h"
2994 
2995 namespace mindspore {
2996 class ModelC {
2997- public:
2998-  ModelC() : session_(nullptr), context_(nullptr) {}
+ public:
3000+  ModelC() : model_(nullptr) {}
3001   ~ModelC() {
3002-    for (auto &impl : tensor_map_) {
3003-      delete impl.second;
3004+    for (auto in : inputs_) {
3005+      delete in;
3006+    }
3007+    for (auto out : outputs_) {
3008+      delete out;
3009+    }
3010+    for (auto out : outputs_train_) {
3011+      delete out;
3012     }
3013   }
3014 
3015-  Status Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context);
3016-  Status Build(const std::string &model_path, ModelType model_type, const ContextC *model_context);
3017-  Status Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes);
3018-
3019-  Status Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs, size_t *output_num,
3020-                 const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after);
3021-
3022-  LiteTensorImpl **GetInputs(size_t *input_num);
3023-  LiteTensorImpl **GetOutputs(size_t *output_num);
3024+  MSTensor **GetInputs(size_t *input_num);
3025+  MSTensor **GetOutputs(size_t *output_num);
3026+  mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &oh_callback);
3027+  std::shared_ptr<Model> model_;
3028+  std::shared_ptr<Context> context_;
3029 
3030- private:
3031-  Status RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after);
3032-  void ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors);
3033-  LiteTensorImpl *TensorToTensorImpl(mindspore::lite::Tensor *tensor);
3034-
3035- private:
3036-  std::shared_ptr<lite::LiteSession> session_ = nullptr;
3037-  std::shared_ptr<const ContextC> context_ = nullptr;
3038-  std::map<mindspore::lite::Tensor *, LiteTensorImpl *> tensor_map_;
3039-  std::vector<LiteTensorImpl *> inputs_;
3040-  std::vector<LiteTensorImpl *> outputs_;
3041-  bool is_already_built = false;
+ private:
3043+  MSTensor **GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors);
3044+  std::vector<MSTensor *> inputs_;
3045+  std::vector<MSTensor *> outputs_;
3046+  std::vector<MSTensor *> outputs_train_;
3047 };
3048 
3049-Status ModelC::Build(const void *model_data, size_t data_size, ModelType model_type, const ContextC *model_context) {
3050-  if (is_already_built) {
3051-    MS_LOG(ERROR) << "The model is already built.";
3052-    return kLiteModelRebuild;
3053-  }
3054-  if (!PlatformInstructionSetSupportCheck()) {
3055-    MS_LOG(ERROR) << "The platform exist don't support's instruction.";
3056-    return kLiteNotSupport;
3057-  }
3058-  if(context_.get() != model_context){
3059-    context_.reset(model_context);
3060-  }
3061-  session_ = std::make_shared<lite::LiteSession>();
3062-  if (session_ == nullptr) {
3063-    MS_LOG(ERROR) << "create session failed";
3064-    return kLiteNullptr;
3065-  }
3066-  auto ret = session_->Init(ContextUtils::Convert(model_context));
3067-  if (ret != mindspore::lite::RET_OK) {
3068-    MS_LOG(ERROR) << "init session failed";
3069-    return static_cast<StatusCode>(ret);
3070-  }
3071-  ret = session_->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size);
3072-  if (ret != RET_OK) {
3073-    MS_LOG(ERROR) << "Load and compile failed";
3074-    return static_cast<StatusCode>(ret);
3075-  }
3076-  is_already_built = true;
3077-  return static_cast<StatusCode>(kSuccess);
3078-}
3079-
3080-Status ModelC::Build(const std::string &model_path, ModelType model_type, const ContextC *model_context) {
3081-  if (is_already_built) {
3082-    MS_LOG(ERROR) << "The model is already built.";
3083-    return kLiteModelRebuild;
3084-  }
3085-  if (!PlatformInstructionSetSupportCheck()) {
3086-    MS_LOG(ERROR) << "The platform exist don't support's instruction.";
3087-    return kLiteNotSupport;
3088-  }
3089-  if(context_.get() != model_context){
3090-    context_.reset(model_context);
3091-  }
3092-  session_ = std::make_shared<lite::LiteSession>();
3093-  if (session_ == nullptr) {
3094-    MS_LOG(ERROR) << "create session failed";
3095-    return kLiteNullptr;
3096-  }
3097-  auto ret = session_->Init(ContextUtils::Convert(model_context));
3098-  if (ret != mindspore::lite::RET_OK) {
3099-    MS_LOG(ERROR) << "init session failed";
3100-    return static_cast<StatusCode>(ret);
3101+MSTensor **ModelC::GetInputs(size_t *input_num) {
3102+  if (model_ == nullptr) {
3103+    MS_LOG(ERROR) << "model_ is nullptr.";
3104+    return nullptr;
3105   }
3106-  ret = session_->LoadModelAndCompileByPath(model_path, model_type);
3107-  if (ret != RET_OK) {
3108-    MS_LOG(ERROR) << "Load and compile failed";
3109-    return static_cast<StatusCode>(ret);
3110+  if (!inputs_.empty()) {
3111+    *input_num = inputs_.size();
3112+    return inputs_.data();
3113   }
3114-  is_already_built = true;
3115-  return static_cast<StatusCode>(kSuccess);
3116-}
3117 
3118-Status ModelC::Resize(const std::vector<LiteTensorImpl *> &inputs, const std::vector<std::vector<int64_t>> &shapes) {
3119-  std::vector<lite::Tensor *> inner_input;
3120-  size_t input_num = inputs.size();
3121-  for (size_t i = 0; i < input_num; i++) {
3122-    auto input = inputs[i];
3123-    if (input == nullptr || input->lite_tensor() == nullptr) {
3124-      MS_LOG(ERROR) << "Input tensor is null.";
3125-      return kLiteInputTensorError;
3126+  auto inputs = model_->GetInputs();
3127+  *input_num = inputs.size();
3128+  inputs_.resize(inputs.size(), nullptr);
3129+  for (size_t i = 0; i < inputs.size(); i++) {
3130+    inputs_[i] = new (std::nothrow) MSTensor(inputs[i].impl());
3131+    if (inputs_[i] == nullptr) {
3132+      inputs_.clear();
3133+      return nullptr;
3134     }
3135-    inner_input.push_back(input->lite_tensor());
3136   }
3137-  size_t shape_num = shapes.size();
3138-  std::vector<std::vector<int32_t>> inner_shapes(shape_num);
3139-  for (size_t i = 0; i < shape_num; i++) {
3140-    std::transform(shapes[i].begin(), shapes[i].end(), std::back_inserter(inner_shapes[i]),
3141-                   [](int64_t value) { return static_cast<int32_t>(value); });
3142-  }
3143-  if (session_ == nullptr) {
3144-    MS_LOG(ERROR) << "Session implement is null.";
3145-    return kLiteNullptr;
3146-  }
3147-  auto ret = session_->Resize(inner_input, inner_shapes);
3148-  return static_cast<StatusCode>(ret);
3149+  return inputs_.data();
3150 }
3151 
3152-void ModelC::ResetTensorData(std::vector<void *> old_data, std::vector<lite::Tensor *> tensors) {
3153-  for (size_t j = 0; j < old_data.size(); j++) {
3154-    tensors.at(j)->set_data(old_data.at(j));
3155+MSTensor **ModelC::GetOutputs(size_t *output_num) {
+  if (model_->GetTrainMode()) {
3157+    return GetOutputsTensor(output_num, &outputs_train_);
3158+  } else {
3159+    return GetOutputsTensor(output_num, &outputs_);
3160   }
3161 }
3162 
3163-Status ModelC::Predict(const OH_AI_TensorHandle *inputs, size_t input_num, OH_AI_TensorHandle **outputs,
3164-                       size_t *output_num, const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) {
3165-  if (outputs == nullptr || session_ == nullptr) {
3166-    MS_LOG(ERROR) << "param is nullptr.";
3167-    return kLiteError;
3168+MSTensor **ModelC::GetOutputsTensor(size_t *output_num, std::vector<MSTensor *> *vec_tensors) {
3169+  if (model_ == nullptr) {
3170+    MS_LOG(ERROR) << "model_ is nullptr.";
3171+    return nullptr;
3172   }
3173-  auto model_inputs = session_->GetInputs();
3174-  if (model_inputs.size() != input_num) {
3175-    MS_LOG(ERROR) << "Wrong input size.";
3176-    return kLiteError;
3177+  if (!vec_tensors->empty()) {
3178+    *output_num = vec_tensors->size();
3179+    return vec_tensors->data();
3180   }
3181-  std::vector<void *> old_data;
3182-  for (size_t i = 0; i < input_num; i++) {
3183-    auto real_input = model_inputs[i];
3184-    auto user_input = static_cast<LiteTensorImpl *>(inputs[i]);
3185-    if (user_input->DataType() != static_cast<DataType>(real_input->data_type())) {
3186-      ResetTensorData(old_data, model_inputs);
3187-      MS_LOG(ERROR) << "DataType does not match, input:" << user_input->Name()
3188-                    << ", real:" << real_input->tensor_name();
3189-      return kLiteInputTensorError;
3190-    }
3191-    if (user_input->Data() == nullptr) {
3192-      ResetTensorData(old_data, model_inputs);
3193-      MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has no data.";
3194-      return kLiteInputTensorError;
3195-    }
3196 
3197-    // GPU tensor can't manipulate CPU memory which the user provides.
3198-    // When model input is GPU tensor and user input is NOT GPU data,
3199-    // just free model input's data for late GPU Tensor filling.
3200-    if (IS_OPENCL_ALLOCATOR(real_input->allocator()) && (!IS_OPENCL_ALLOCATOR(user_input->GetAllocator()))) {
3201-      real_input->FreeData();
3202-    }
3203-    old_data.push_back(real_input->data());  // Save original data in model tensors.
3204-
3205-    if (real_input->data_type() == kObjectTypeString) {
3206-      std::vector<int32_t> shape;
3207-      std::transform(user_input->Shape().begin(), user_input->Shape().end(), std::back_inserter(shape),
3208-                     [](int64_t value) { return static_cast<int32_t>(value); });
3209-      real_input->set_shape(shape);
3210-      real_input->set_data(user_input->MutableData());
3211-    } else {
3212-      if (user_input->MutableData() != real_input->data()) {
3213-        if (real_input->Size() != user_input->DataSize()) {
3214-          ResetTensorData(old_data, model_inputs);
3215-          MS_LOG(ERROR) << "Tensor " << user_input->Name() << " has wrong data size.";
3216-          return kLiteInputTensorError;
3217-        }
3218-        if (!IS_OPENCL_ALLOCATOR(real_input->allocator())) {
3219-          real_input->set_data(user_input->MutableData());
3220-        } else {
3221-          // Use outside CPU data to fill GPU Tensor.
3222-          auto dst_data = real_input->MutableData();
3223-          auto src_data = user_input->MutableData();
3224-          (void)memcpy(dst_data, src_data, real_input->Size());
3225-        }
3226-      }
3227-    }
3228-  }
3229-  auto ret = RunGraph(before, after);
3230-  ResetTensorData(old_data, model_inputs);
3231-  if (ret != kSuccess) {
3232-    MS_LOG(ERROR) << "Run graph failed.";
3233-    return ret;
3234-  }
3235-
3236-  *outputs = reinterpret_cast<OH_AI_TensorHandle *>(GetOutputs(output_num));
3237-  return kSuccess;
3238-}
3239-
3240-Status ModelC::RunGraph(const OH_AI_KernelCallBack &before, const OH_AI_KernelCallBack &after) {
3241-  KernelCallBack before_call_back = nullptr;
3242-  KernelCallBack after_call_back = nullptr;
3243-  if (before != nullptr) {
3244-    before_call_back = [&](const std::vector<mindspore::lite::Tensor *> &before_inputs,
3245-                           const std::vector<mindspore::lite::Tensor *> &before_outputs,
3246-                           const MSCallBackParam &call_param) {
3247-      std::vector<LiteTensorImpl> inputs_impl;
3248-      std::vector<LiteTensorImpl> outputs_impl;
3249-      std::vector<OH_AI_TensorHandle> op_inputs;
3250-      std::vector<OH_AI_TensorHandle> op_outputs;
3251-    size_t op_input_num = before_inputs.size();
3252-    for (size_t i = 0; i < op_input_num; i++) {
3253-      inputs_impl.emplace_back(before_inputs[i]);
3254-      op_inputs.push_back(&(inputs_impl.back()));
3255-    }
3256-    size_t op_output_num = before_outputs.size();
3257-    for (size_t i = 0; i < op_output_num; i++) {
3258-      outputs_impl.emplace_back(before_outputs[i]);
3259-      op_outputs.push_back(&(outputs_impl.back()));
3260-    }
3261-      const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()),
3262-                                         const_cast<char *>(call_param.node_type.c_str())};
3263-      OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()};
3264-      OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()};
3265-    return before(inputs, outputs, op_info);
3266-  };
3267-  }
3268-  if (after != nullptr) {
3269-    after_call_back = [&](const std::vector<mindspore::lite::Tensor *> &after_inputs,
3270-                          const std::vector<mindspore::lite::Tensor *> &after_outputs,
3271-                          const MSCallBackParam &call_param) {
3272-      std::vector<LiteTensorImpl> inputs_impl;
3273-      std::vector<LiteTensorImpl> outputs_impl;
3274-      std::vector<OH_AI_TensorHandle> op_inputs;
3275-      std::vector<OH_AI_TensorHandle> op_outputs;
3276-    size_t op_input_num = after_inputs.size();
3277-    for (size_t i = 0; i < op_input_num; i++) {
3278-      inputs_impl.emplace_back(after_inputs[i]);
3279-      op_inputs.push_back(&(inputs_impl.back()));
3280-    }
3281-    size_t op_output_num = after_outputs.size();
3282-    for (size_t i = 0; i < op_output_num; i++) {
3283-      outputs_impl.emplace_back(after_outputs[i]);
3284-      op_outputs.push_back(&(outputs_impl.back()));
3285-    }
3286-    const OH_AI_CallBackParam op_info = {const_cast<char *>(call_param.node_name.c_str()),
3287-                                         const_cast<char *>(call_param.node_type.c_str())};
3288-    OH_AI_TensorHandleArray inputs = {op_input_num, op_inputs.data()};
3289-    OH_AI_TensorHandleArray outputs = {op_output_num, op_outputs.data()};
3290-    return after(inputs, outputs, op_info);
3291-  };
3292-  }
3293-  auto ret = session_->RunGraph(before_call_back, after_call_back);
3294-  return static_cast<StatusCode>(ret);
3295-}
3296-
3297-LiteTensorImpl *ModelC::TensorToTensorImpl(mindspore::lite::Tensor *tensor) {
3298-  LiteTensorImpl *impl = nullptr;
3299-  auto iter = tensor_map_.find(tensor);
3300-  if (iter != tensor_map_.end()) {
3301-    impl = iter->second;
3302-  } else {
3303-    impl = new (std::nothrow) LiteTensorImpl(tensor);
3304-    if (impl == nullptr || impl->lite_tensor() == nullptr) {
3305-      MS_LOG(ERROR) << "Create tensor failed.";
3306+  auto outputs = model_->GetOutputs();
3307+  *output_num = outputs.size();
3308+  vec_tensors->resize(outputs.size(), nullptr);
3309+  for (size_t i = 0; i < outputs.size(); i++) {
3310+    (*vec_tensors)[i] = new (std::nothrow) MSTensor(outputs[i].impl());
3311+    if ((*vec_tensors)[i] == nullptr) {
3312+      vec_tensors->clear();
3313       return nullptr;
3314     }
3315-    tensor_map_[tensor] = impl;
3316   }
3317-  return impl;
3318+  return vec_tensors->data();
3319 }
3320 
3321-LiteTensorImpl **ModelC::GetInputs(size_t *input_num) {
3322-  if (session_ == nullptr || input_num == nullptr) {
3323-    MS_LOG(ERROR) << "Session is null.";
3324-    return nullptr;
3325-  }
3326-  auto inputs = session_->GetInputs();
3327-  *input_num = inputs.size();
3328-  if (inputs_.capacity() < *input_num) {
3329-    inputs_.reserve(*input_num);
3330-  }
3331-  inputs_.clear();
3332-  std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_),
3333-                 [&](lite::Tensor *input) { return TensorToTensorImpl(input); });
3334-  return inputs_.data();
3335-}
3336+mindspore::MSKernelCallBack ModelC::TransCallBack(const OH_AI_KernelCallBack &oh_callback) {
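+  // Bridge the C kernel callback to the cxx_api MSKernelCallBack: wrap the tensors as handles and forward the node info.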
3337+  mindspore::MSKernelCallBack call_back = nullptr;
3338+  if (oh_callback != nullptr) {
3339+    call_back = [&](const std::vector<mindspore::MSTensor> &inputs,
3340+                    const std::vector<mindspore::MSTensor> &outputs,
3341+                    const mindspore::MSCallBackParam &opInfo) {
3342+      std::vector<OH_AI_TensorHandle> vec_inputs;
3343+      std::vector<OH_AI_TensorHandle> vec_outputs;
3344+      OH_AI_CallBackParam call_back = {const_cast<char *>(opInfo.node_name.c_str()),
3345+                                       const_cast<char *>(opInfo.node_type.c_str())};
3346+      size_t inputs_handle_num = inputs.size();
3347+      for (size_t i = 0; i < inputs_handle_num; i++) {
+        // Point the handle at the caller's tensor directly; casting the const vector to a temporary
+        // copy would leave the handle dangling once that copy is destroyed.
+        vec_inputs.push_back(
+          static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&inputs[i])));
3350+      }
+      size_t outputs_handle_num = outputs.size();
3352+      for (size_t i = 0; i < outputs_handle_num; i++) {
+        vec_outputs.push_back(
+          static_cast<OH_AI_TensorHandle>(const_cast<mindspore::MSTensor *>(&outputs[i])));
3355+      }
3356 
3357-LiteTensorImpl **ModelC::GetOutputs(size_t *output_num) {
3358-  if (session_ == nullptr || output_num == nullptr) {
3359-    MS_LOG(ERROR) << "Session is null.";
3360-    return nullptr;
3361-  }
3362-  auto outputs = session_->GetOutputs();
3363-  *output_num = outputs.size();
3364-  if (outputs_.capacity() < *output_num) {
3365-    outputs_.reserve(*output_num);
3366+      OH_AI_TensorHandleArray handle_inputs = {inputs_handle_num, vec_inputs.data()};
3367+      OH_AI_TensorHandleArray handle_outputs = {outputs_handle_num, vec_outputs.data()};
3368+      return oh_callback(handle_inputs, handle_outputs, call_back);
3369+    };
3370   }
3371-  outputs_.clear();
3372-  std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_),
3373-                 [&](std::unordered_map<std::string, mindspore::lite::Tensor *>::value_type iter) {
3374-                   return TensorToTensorImpl(iter.second);
3375-                 });
3376-  return outputs_.data();
3377+  return call_back;
3378 }
3379 }  // namespace mindspore
3380 
3381 OH_AI_ModelHandle OH_AI_ModelCreate() {
3382   auto impl = new (std::nothrow) mindspore::ModelC();
3383   if (impl == nullptr) {
3384-    MS_LOG(ERROR) << "Model implement is null.";
3385+    MS_LOG(ERROR) << "Model implement is nullptr.";
3386+    return nullptr;
3387+  }
3388+  impl->model_ = std::make_shared<mindspore::Model>();
3389+  if (impl->model_ == nullptr) {
3390+    MS_LOG(ERROR) << "model_ is nullptr.";
3391+    delete impl;
3392     return nullptr;
3393   }
3394   return static_cast<OH_AI_ModelHandle>(impl);
3395@@ -358,55 +172,59 @@ size_t OH_AI_ModelCalcWorkspaceSize(OH_AI_ModelHandle model) {
3396 OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
3397                               OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context) {
3398   if (model == nullptr || model_data == nullptr || model_context == nullptr) {
3399-    MS_LOG(ERROR) << "param is nullptr.";
3400+    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
3401     return OH_AI_STATUS_LITE_NULLPTR;
3402   }
3403   if (model_type == OH_AI_MODELTYPE_INVALID) {
3404-    MS_LOG(ERROR) << "param is invalid.";
3405+    MS_LOG(ERROR) << "model_type is invalid.";
3406     return OH_AI_STATUS_LITE_PARAM_INVALID;
3407   }
3408-  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
3409+  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
3410   auto impl = static_cast<mindspore::ModelC *>(model);
3411-  auto ret = impl->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), context);
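+  // The shared_ptr takes ownership of the caller's context handle from here on.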
3412+  if (impl->context_.get() != context) {
3413+    impl->context_.reset(context);
3414+  }
3415+  auto ret = impl->model_->Build(model_data, data_size, static_cast<mindspore::ModelType>(model_type), impl->context_);
3416   return static_cast<OH_AI_Status>(ret.StatusCode());
3417 }
3418 
3419 OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type,
3420                                       const OH_AI_ContextHandle model_context) {
3421   if (model == nullptr || model_path == nullptr || model_context == nullptr) {
3422-    MS_LOG(ERROR) << "param is nullptr.";
3423+    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
3424     return OH_AI_STATUS_LITE_NULLPTR;
3425   }
3426   if (model_type == OH_AI_MODELTYPE_INVALID) {
3427-    MS_LOG(ERROR) << "param is invalid.";
3428+    MS_LOG(ERROR) << "model_type is invalid.";
3429     return OH_AI_STATUS_LITE_PARAM_INVALID;
3430   }
3431-  mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
3432+  mindspore::Context *context = static_cast<mindspore::Context *>(model_context);
3433   auto impl = static_cast<mindspore::ModelC *>(model);
3434-  auto ret = impl->Build(model_path, static_cast<mindspore::ModelType>(model_type), context);
3435+  if (impl->context_.get() != context) {
3436+    impl->context_.reset(context);
3437+  }
3438+  auto ret = impl->model_->Build(model_path, static_cast<mindspore::ModelType>(model_type), impl->context_);
3439   return static_cast<OH_AI_Status>(ret.StatusCode());
3440 }
3441 
3442 OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
3443                                OH_AI_ShapeInfo *shape_infos, size_t shape_info_num) {
3444   if (model == nullptr || shape_infos == nullptr) {
3445-    MS_LOG(ERROR) << "param is nullptr.";
3446+    MS_LOG(ERROR) << "model/shape_infos is nullptr.";
3447     return OH_AI_STATUS_LITE_NULLPTR;
3448   }
3449-  std::vector<mindspore::LiteTensorImpl *> vec_inputs;
3450-  std::transform(inputs.handle_list, inputs.handle_list + inputs.handle_num, std::back_inserter(vec_inputs),
3451-                 [](OH_AI_TensorHandle value) { return static_cast<mindspore::LiteTensorImpl *>(value); });
3452+  std::vector<mindspore::MSTensor> vec_inputs;
3453+  for (size_t i = 0; i < inputs.handle_num; ++i) {
3454+    vec_inputs.push_back(*static_cast<mindspore::MSTensor *>(inputs.handle_list[i]));
3455+  }
3456+
3457   std::vector<std::vector<int64_t>> vec_dims;
3458   for (size_t i = 0; i < shape_info_num; i++) {
3459     std::vector<int64_t> shape(shape_infos[i].shape, shape_infos[i].shape + shape_infos[i].shape_num);
3460-    if (std::any_of(shape.begin(), shape.end(), [](int64_t val) { return val < 0 || val > INT32_MAX; })) {
3461-      MS_LOG(ERROR) << "Invalid shape: " << shape << ", each dimension must be in [0, INT32_MAX]";
3462-      return OH_AI_STATUS_LITE_PARAM_INVALID;
3463-    }
3464     vec_dims.push_back(shape);
3465   }
3466   auto impl = static_cast<mindspore::ModelC *>(model);
3467-  auto ret = impl->Resize(vec_inputs, vec_dims);
3468+  auto ret = impl->model_->Resize(vec_inputs, vec_dims);
3469   return static_cast<OH_AI_Status>(ret.StatusCode());
3470 }
3471 
3472@@ -414,15 +232,25 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
3473                                 OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
3474                                 const OH_AI_KernelCallBack after) {
3475   if (model == nullptr) {
3476-    MS_LOG(ERROR) << "param is nullptr.";
3477+    MS_LOG(ERROR) << "model is nullptr.";
3478     return OH_AI_STATUS_LITE_NULLPTR;
3479   }
3480+  std::vector<mindspore::MSTensor> ms_tensor_inputs;
3481+  for (size_t i = 0; i < inputs.handle_num; i++) {
3482+    auto user_input = static_cast<mindspore::MSTensor *>(inputs.handle_list[i]);
3483+    ms_tensor_inputs.push_back(*user_input);
3484+  }
3485+
3486   auto impl = static_cast<mindspore::ModelC *>(model);
3487-  auto ret = impl->Predict(inputs.handle_list, inputs.handle_num, &(outputs->handle_list), &(outputs->handle_num),
3488-                           before, after);
3489+  mindspore::MSKernelCallBack before_call_back = impl->TransCallBack(before);
3490+  mindspore::MSKernelCallBack after_call_back = impl->TransCallBack(after);
3491+
3492+  std::vector<mindspore::MSTensor> ms_tensor_outputs;
3493+  auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
3494   if (!ret.IsOk()) {
3495     MS_LOG(ERROR) << "Predict fail, ret :" << ret;
3496   }
3497+  outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&outputs->handle_num));
3498   return static_cast<OH_AI_Status>(ret.StatusCode());
3499 }
3500 
3501@@ -431,11 +259,6 @@ OH_AI_Status OH_AI_ModelRunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallB
3502   return OH_AI_STATUS_LITE_NOT_SUPPORT;
3503 }
3504 
3505-OH_AI_Status OH_AI_ModelSetTrainMode(const OH_AI_ModelHandle model, bool train) {
3506-  MS_LOG(ERROR) << "Unsupported Feature.";
3507-  return OH_AI_STATUS_LITE_NOT_SUPPORT;
3508-}
3509-
3510 OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *export_path) {
3511   MS_LOG(ERROR) << "Unsupported Feature.";
3512   return OH_AI_STATUS_LITE_NOT_SUPPORT;
3513@@ -443,7 +266,7 @@ OH_AI_Status OH_AI_ModelExportWeight(const OH_AI_ModelHandle model, const char *
3514 
3515 OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
3516   if (model == nullptr) {
3517-    MS_LOG(ERROR) << "param is nullptr.";
3518+    MS_LOG(ERROR) << "model is nullptr.";
3519     return {0, nullptr};
3520   }
3521   auto impl = static_cast<mindspore::ModelC *>(model);
3522@@ -454,7 +277,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model) {
3523 
3524 OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
3525   if (model == nullptr) {
3526-    MS_LOG(ERROR) << "param is nullptr.";
3527+    MS_LOG(ERROR) << "model is nullptr.";
3528     return {0, nullptr};
3529   }
3530   auto impl = static_cast<mindspore::ModelC *>(model);
3531@@ -465,7 +288,7 @@ OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model) {
3532 
3533 OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
3534   if (model == nullptr || tensor_name == nullptr) {
3535-    MS_LOG(ERROR) << "param is nullptr.";
3536+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
3537     return nullptr;
3538   }
3539   auto impl = static_cast<mindspore::ModelC *>(model);
3540@@ -482,7 +305,7 @@ OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model
3541 
3542 OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name) {
3543   if (model == nullptr || tensor_name == nullptr) {
3544-    MS_LOG(ERROR) << "param is nullptr.";
3545+    MS_LOG(ERROR) << "model/tensor_name is nullptr.";
3546     return nullptr;
3547   }
3548   auto impl = static_cast<mindspore::ModelC *>(model);
3549@@ -496,3 +319,294 @@ OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle mode
3550   MS_LOG(ERROR) << "tensor is not exist.";
3551   return nullptr;
3552 }
3553+
3554+OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate() {
3555+  auto impl = new (std::nothrow) mindspore::TrainCfg();
3556+  if (impl == nullptr) {
3557+    MS_LOG(ERROR) << "TrainCfg implement is nullptr.";
3558+    return nullptr;
3559+  }
3560+  return static_cast<OH_AI_TrainCfgHandle>(impl);
3561+}
3562+
3563+void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg) {
3564+  if (train_cfg != nullptr && *train_cfg != nullptr) {
3565+    auto impl = static_cast<mindspore::TrainCfg *>(*train_cfg);
3566+    delete impl;
3567+    *train_cfg = nullptr;
3568+  }
3569+}
3570+
3571+char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num) {
3572+  if (train_cfg == nullptr || num == nullptr) {
3573+    MS_LOG(ERROR) << "train_cfg/num is nullptr.";
3574+    return nullptr;
3575+  }
3576+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3577+  auto loss_name = impl->GetLossName();
3578+  *num = loss_name.size();
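+  // Both the pointer array and each name string are heap-allocated; the caller owns and frees them.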
+  char **name = static_cast<char **>(malloc(loss_name.size() * sizeof(char *)));
3580+  if (name == nullptr) {
3581+    MS_LOG(ERROR) << "Failed to malloc loss_name.";
3582+    return nullptr;
3583+  }
3584+  for (size_t i = 0; i < loss_name.size(); i++) {
3585+    name[i] = static_cast<char *>(malloc(loss_name[i].size() + 1));
3586+    strcpy(name[i], loss_name[i].c_str());
3587+  }
3588+  return name;
3589+}
3590+
3591+void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num) {
3592+  if (train_cfg == nullptr) {
3593+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3594+    return;
3595+  }
3596+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3597+  std::vector<std::string> vec_name;
3598+  for (size_t i = 0; i < num; i++) {
3599+    vec_name.push_back(loss_name[i]);
3600+  }
3601+  impl->SetLossName(vec_name);
3602+}
3603+
3604+OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg) {
3605+  if (train_cfg == nullptr) {
3606+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3607+    return OH_AI_KO0;
3608+  }
3609+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3610+  return static_cast<OH_AI_OptimizationLevel>(impl->optimization_level_);
3611+}
3612+
3613+void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level) {
3614+  if (train_cfg == nullptr) {
3615+    MS_LOG(ERROR) << "train_cfg is nullptr.";
3616+    return;
3617+  }
3618+  auto impl = static_cast<mindspore::TrainCfg *>(train_cfg);
3619+  impl->optimization_level_ = static_cast<mindspore::OptimizationLevel>(level);
3620+}
3621+
3622+OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
3623+                                   OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
3624+                                   const OH_AI_TrainCfgHandle train_cfg) {
3625+  if (model == nullptr || model_data == nullptr || model_context == nullptr) {
3626+    MS_LOG(ERROR) << "model/model_data/model_context is nullptr.";
3627+    return OH_AI_STATUS_LITE_NULLPTR;
3628+  }
3629+  if (model_type == OH_AI_MODELTYPE_INVALID) {
3630+    MS_LOG(ERROR) << "model_type is invalid.";
3631+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3632+  }
3633+  auto impl = static_cast<mindspore::ModelC *>(model);
3634+
3635+  mindspore::Graph graph;
3636+  auto status = mindspore::Serialization::Load(model_data, data_size, static_cast<mindspore::ModelType>(model_type), &graph);
3637+  if (status != mindspore::kSuccess) {
3638+    MS_LOG(ERROR) << "load ms file failed.";
3639+    return OH_AI_STATUS_LITE_ERROR;
3640+  }
3641+  auto context = static_cast<mindspore::Context *>(model_context);
3642+  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
3643+  if (impl->context_.get() != context) {
3644+    impl->context_.reset(context);
3645+  }
3646+  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
3647+                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
3648+  if (ret != mindspore::kSuccess) {
3649+    MS_LOG(ERROR) << "Load and compile failed";
3650+  }
3651+  return static_cast<OH_AI_Status>(ret.StatusCode());
3652+}
3653+
3654+OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
3655+                                           OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
3656+                                           const OH_AI_TrainCfgHandle train_cfg) {
3657+  if (model == nullptr || model_path == nullptr || model_context == nullptr) {
3658+    MS_LOG(ERROR) << "model/model_path/model_context is nullptr.";
3659+    return OH_AI_STATUS_LITE_NULLPTR;
3660+  }
3661+  if (model_type == OH_AI_MODELTYPE_INVALID) {
3662+    MS_LOG(ERROR) << "model_type is invalid.";
3663+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3664+  }
3665+  auto impl = static_cast<mindspore::ModelC *>(model);
3666+
3667+  mindspore::Graph graph;
3668+  auto status = mindspore::Serialization::Load(model_path, static_cast<mindspore::ModelType>(model_type), &graph);
3669+  if (status != mindspore::kSuccess) {
3670+    MS_LOG(ERROR) << "load ms file failed. " << model_path;
3671+    return OH_AI_STATUS_LITE_ERROR;
3672+  }
3673+  auto context = static_cast<mindspore::Context *>(model_context);
3674+  auto build_train_cfg = static_cast<mindspore::TrainCfg *>(train_cfg);
3675+  if (impl->context_.get() != context) {
3676+    impl->context_.reset(context);
3677+  }
3678+  auto ret = impl->model_->Build(static_cast<mindspore::GraphCell>(graph), impl->context_,
3679+                                 std::shared_ptr<mindspore::TrainCfg>(build_train_cfg));
3680+  if (ret != mindspore::kSuccess) {
3681+    MS_LOG(ERROR) << "Load and compile failed";
3682+  }
3683+  return static_cast<OH_AI_Status>(ret.StatusCode());
3684+}
3685+
3686+OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate) {
3687+  if (model == nullptr) {
3688+    MS_LOG(ERROR) << "model is nullptr.";
3689+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3690+  }
3691+  auto impl = static_cast<mindspore::ModelC *>(model);
3692+  auto ret = impl->model_->SetLearningRate(learning_rate);
3693+  return static_cast<OH_AI_Status>(ret.StatusCode());
3694+}
3695+
3696+float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model) {
3697+  if (model == nullptr) {
3698+    MS_LOG(ERROR) << "model is nullptr.";
3699+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3700+  }
3701+  auto impl = static_cast<mindspore::ModelC *>(model);
3702+  return impl->model_->GetLearningRate();
3703+}
3704+
3705+OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after) {
3706+  if (model == nullptr) {
3707+    MS_LOG(ERROR) << "model is nullptr.";
3708+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3709+  }
3710+  auto impl = static_cast<mindspore::ModelC *>(model);
3711+  auto ret = impl->model_->RunStep(impl->TransCallBack(before), impl->TransCallBack(after));
3712+  return static_cast<OH_AI_Status>(ret.StatusCode());
3713+}
3714+
3715+OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model) {
3716+  if (model == nullptr) {
3717+    MS_LOG(ERROR) << "model is nullptr.";
3718+    return {0, nullptr};
3719+  }
3720+  auto impl = static_cast<mindspore::ModelC *>(model);
3721+  auto features = impl->model_->GetFeatureMaps();
3722+  size_t handle_num = features.size();
3723+
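+  // The handle array and the MSTensor wrappers are allocated here and owned by the caller.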
3724+  mindspore::MSTensor **handle_list = static_cast<mindspore::MSTensor **>(malloc(
3725+    handle_num * sizeof(mindspore::MSTensor *)));
3726+  if (handle_list == nullptr) {
3727+    MS_LOG(ERROR) << "Failed to malloc handle_list.";
3728+    return {0, nullptr};
3729+  }
3730+  for (size_t i = 0; i < handle_num; i++) {
3731+    handle_list[i] = new mindspore::MSTensor(features[i].impl());
3732+  }
3733+  return {handle_num, reinterpret_cast<OH_AI_TensorHandle *>(handle_list)};
3734+}
3735+
3736+OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights) {
3737+  if (model == nullptr) {
3738+    MS_LOG(ERROR) << "model is nullptr.";
3739+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3740+  }
3741+  auto impl = static_cast<mindspore::ModelC *>(model);
3742+  std::vector<mindspore::MSTensor> weights;
3743+  for (size_t i = 0; i < new_weights.handle_num; i++) {
3744+    weights.push_back(*static_cast<mindspore::MSTensor *>(new_weights.handle_list[i]));
3745+  }
3746+  auto ret = impl->model_->UpdateWeights(weights);
3747+  return static_cast<OH_AI_Status>(ret.StatusCode());
3748+}
3749+
3750+bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model) {
3751+  if (model == nullptr) {
3752+    MS_LOG(ERROR) << "model is nullptr.";
3753+    return false;
3754+  }
3755+  auto impl = static_cast<mindspore::ModelC *>(model);
3756+  return impl->model_->GetTrainMode();
3757+}
3758+
3759+OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train) {
3760+  if (model == nullptr) {
3761+    MS_LOG(ERROR) << "model is nullptr.";
3762+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3763+  }
3764+  auto impl = static_cast<mindspore::ModelC *>(model);
3765+  auto ret = impl->model_->SetTrainMode(train);
3766+  return static_cast<OH_AI_Status>(ret.StatusCode());
3767+}
3768+
3769+OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr, float momentum) {
3770+  if (model == nullptr) {
3771+    MS_LOG(ERROR) << "model is nullptr.";
3772+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3773+  }
3774+  auto impl = static_cast<mindspore::ModelC *>(model);
3775+  auto ret = impl->model_->SetupVirtualBatch(virtual_batch_multiplier, lr, momentum);
3776+  return static_cast<OH_AI_Status>(ret.StatusCode());
3777+}
3778+
3779+OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
3780+                               OH_AI_QuantizationType quantization_type, bool export_inference_only,
3781+                               char **output_tensor_name, size_t num) {
3782+  if (model == nullptr) {
3783+    MS_LOG(ERROR) << "model is nullptr.";
3784+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3785+  }
3786+  auto impl = static_cast<mindspore::ModelC *>(model);
3787+  std::vector<std::string> tensor_name;
3788+  for (size_t i = 0; i < num; i++) {
3789+    tensor_name.push_back(output_tensor_name[i]);
3790+  }
3791+  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
3792+                                                   model_file,
3793+                                                   static_cast<mindspore::QuantizationType>(quantization_type),
3794+                                                   export_inference_only, tensor_name);
3795+  if (!ret.IsOk()) {
3796+    MS_LOG(ERROR) << "export model failed, ret: " << ret;
3797+  }
3798+  return static_cast<OH_AI_Status>(ret.StatusCode());
3799+}
3800+
3801+OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
3802+                                     size_t *data_size, OH_AI_QuantizationType quantization_type,
3803+                                     bool export_inference_only, char **output_tensor_name, size_t num) {
3804+  if (model == nullptr) {
3805+    MS_LOG(ERROR) << "model is nullptr.";
3806+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3807+  }
3808+  auto impl = static_cast<mindspore::ModelC *>(model);
3809+  std::vector<std::string> tensor_name;
3810+  for (size_t i = 0; i < num; i++) {
3811+    tensor_name.push_back(output_tensor_name[i]);
3812+  }
3813+  mindspore::Buffer buffer;
3814+  auto ret = mindspore::Serialization::ExportModel(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type),
3815+                                                   &buffer, static_cast<mindspore::QuantizationType>(quantization_type),
3816+                                                   export_inference_only, tensor_name);
3817+  auto data = static_cast<char *>(buffer.MutableData());
3818+  *model_data = static_cast<char *>(malloc(buffer.DataSize()));
+  if (*model_data == nullptr) {
+    MS_LOG(ERROR) << "Failed to malloc model data buffer.";
+    return OH_AI_STATUS_LITE_NULLPTR;
+  }
3819+  *data_size = buffer.DataSize();
3820+  memcpy(*model_data, data, buffer.DataSize());
3821+  if (!ret.IsOk()) {
3822+    MS_LOG(ERROR) << "export model failed, ret: " << ret;
3823+  }
3824+  return static_cast<OH_AI_Status>(ret.StatusCode());
3825+}
3826+
3827+OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *weight_file,
3828+  bool is_inference, bool enable_fp16, char **changeable_weights_name, size_t num) {
3829+  if (model == nullptr) {
3830+    MS_LOG(ERROR) << "model is nullptr.";
3831+    return OH_AI_STATUS_LITE_PARAM_INVALID;
3832+  }
3833+  auto impl = static_cast<mindspore::ModelC *>(model);
3834+  std::vector<std::string> weights_name;
3835+  for (size_t i = 0; i < num; i++) {
3836+    weights_name.push_back(changeable_weights_name[i]);
3837+  }
3838+  auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(*(impl->model_.get()), static_cast<mindspore::ModelType>(model_type), weight_file, is_inference, enable_fp16, weights_name);
3839+  if (!ret.IsOk()) {
3840+    MS_LOG(ERROR) << "export weights failed, ret: " << ret;
3841+  }
3842+  return static_cast<OH_AI_Status>(ret.StatusCode());
3843+}
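A hedged caller-side sketch (not part of this patch) of OH_AI_ExportModelBuffer as implemented above: the buffer is malloc'ed inside the API, so the caller owns and frees it. The enum spellings OH_AI_MODELTYPE_MINDIR and OH_AI_NO_QUANT are assumed from types_c.h:

    #include <cstdlib>
    #include "include/c_api/model_c.h"

    OH_AI_Status DumpModelToMemory(OH_AI_ModelHandle model) {
      char *buf = nullptr;
      size_t size = 0;
      // Export the inference-only graph without extra quantization; no output-tensor filter.
      OH_AI_Status ret = OH_AI_ExportModelBuffer(model, OH_AI_MODELTYPE_MINDIR, &buf, &size,
                                                 OH_AI_NO_QUANT, true, nullptr, 0);
      if (ret != OH_AI_STATUS_SUCCESS) {
        return ret;
      }
      // ... persist or transmit the `size` bytes at `buf` ...
      free(buf);  // ownership of the exported buffer is transferred to the caller
      return OH_AI_STATUS_SUCCESS;
    }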
3844diff --git a/mindspore/lite/src/litert/c_api/tensor_c.cc b/mindspore/lite/src/litert/c_api/tensor_c.cc
3845index 7b5c4c2f..4b1e6aff 100644
3846--- a/mindspore/lite/src/litert/c_api/tensor_c.cc
3847+++ b/mindspore/lite/src/litert/c_api/tensor_c.cc
3848@@ -17,7 +17,6 @@
3849 #include "include/api/status.h"
3850 #include "src/tensor.h"
3851 #include "src/litert/cxx_api/tensor/tensor_impl.h"
3852-#include "src/litert/inner_allocator.h"
3853 
3854 OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, const int64_t *shape, size_t shape_num,
3855                                       const void *data, size_t data_len) {
3856@@ -31,18 +30,23 @@ OH_AI_TensorHandle OH_AI_TensorCreate(const char *name, OH_AI_DataType type, con
3857   }
3858   auto lite_tensor =
3859     mindspore::lite::Tensor::CreateTensor(name, static_cast<mindspore::TypeId>(type), vec_shape, data, data_len);
3860-  auto impl = new (std::nothrow) mindspore::LiteTensorImpl(lite_tensor);
3861-  if (impl == nullptr || impl->lite_tensor() == nullptr) {
3862+  auto lite_tensor_impl = std::make_shared<mindspore::LiteTensorImpl>(lite_tensor);
3863+  if (lite_tensor_impl == nullptr || lite_tensor_impl->lite_tensor() == nullptr) {
3864     MS_LOG(ERROR) << "Failed to allocate tensor impl.";
3865     return nullptr;
3866   }
3867-  impl->set_from_session(false);
3868+  lite_tensor_impl->set_from_session(false);
3869+  auto impl = new (std::nothrow) mindspore::MSTensor(lite_tensor_impl);
3870+  if (impl == nullptr) {
3871+    MS_LOG(ERROR) << "Failed to allocate MSTensor.";
3872+    return nullptr;
3873+  }
3874   return impl;
3875 }
3876 
3877 void OH_AI_TensorDestroy(OH_AI_TensorHandle *tensor) {
3878   if (tensor != nullptr && *tensor != nullptr) {
3879-    auto impl = static_cast<mindspore::LiteTensorImpl *>(*tensor);
3880+    auto impl = static_cast<mindspore::MSTensor *>(*tensor);
3881     delete impl;
3882     *tensor = nullptr;
3883   }
3884@@ -53,20 +57,14 @@ OH_AI_TensorHandle OH_AI_TensorClone(OH_AI_TensorHandle tensor) {
3885     MS_LOG(ERROR) << "param is nullptr.";
3886     return nullptr;
3887   }
3888-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3889-  auto lite_tensor = static_cast<mindspore::lite::Tensor *>(impl->lite_tensor());
3890-  auto clone = mindspore::lite::Tensor::CopyTensor(*lite_tensor, true, lite_tensor->allocator());
3891-  if (clone == nullptr) {
3892-    MS_LOG(ERROR) << "Failed to allocate tensor.";
3893-    return nullptr;
3894-  }
3895-  auto clone_impl = new (std::nothrow) mindspore::LiteTensorImpl(clone);
3896+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3897+  auto clone_impl = impl->Clone();
3898   if (clone_impl == nullptr) {
3899-    delete clone;
3900     MS_LOG(ERROR) << "Failed to allocate tensor impl.";
3901     return nullptr;
3902   }
3903-  clone_impl->set_from_session(false);
3904+  std::static_pointer_cast<mindspore::LiteTensorImpl>(clone_impl->impl())->set_own_data(false);
3905+  clone_impl->SetTensorName(impl->Name() + "_duplicate");
3906   return clone_impl;
3907 }
3908 
3909@@ -75,8 +73,8 @@ void OH_AI_TensorSetName(OH_AI_TensorHandle tensor, const char *name) {
3910     MS_LOG(ERROR) << "param is nullptr.";
3911     return;
3912   }
3913-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3914-  impl->SetName(name);
3915+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3916+  impl->SetTensorName(name);
3917 }
3918 
3919 const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) {
3920@@ -84,8 +82,8 @@ const char *OH_AI_TensorGetName(const OH_AI_TensorHandle tensor) {
3921     MS_LOG(ERROR) << "param is nullptr.";
3922     return nullptr;
3923   }
3924-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3925-  return impl->Name().c_str();
3926+  auto ms_tensor = static_cast<mindspore::MSTensor *>(tensor);
3927+  return std::static_pointer_cast<mindspore::LiteTensorImpl>(ms_tensor->impl())->Name().c_str();
3928 }
3929 
3930 void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) {
3931@@ -93,7 +91,7 @@ void OH_AI_TensorSetDataType(OH_AI_TensorHandle tensor, OH_AI_DataType type) {
3932     MS_LOG(ERROR) << "param is nullptr.";
3933     return;
3934   }
3935-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3936+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3937   impl->SetDataType(static_cast<mindspore::DataType>(type));
3938 }
3939 
3940@@ -102,7 +100,7 @@ OH_AI_DataType OH_AI_TensorGetDataType(const OH_AI_TensorHandle tensor) {
3941     MS_LOG(ERROR) << "param is nullptr.";
3942     return OH_AI_DATATYPE_UNKNOWN;
3943   }
3944-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3945+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3946   auto dtype = impl->DataType();
3947   return static_cast<OH_AI_DataType>(dtype);
3948 }
3949@@ -112,7 +110,7 @@ void OH_AI_TensorSetShape(OH_AI_TensorHandle tensor, const int64_t *shape, size_
3950     MS_LOG(ERROR) << "param is nullptr.";
3951     return;
3952   }
3953-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3954+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3955   std::vector<int64_t> vec_shape(shape_num);
3956   for (size_t i = 0; i < shape_num; i++) {
3957     vec_shape[i] = shape[i];
3958@@ -125,7 +123,7 @@ const int64_t *OH_AI_TensorGetShape(const OH_AI_TensorHandle tensor, size_t *sha
3959     MS_LOG(ERROR) << "param is nullptr.";
3960     return nullptr;
3961   }
3962-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3963+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3964   *shape_num = impl->Shape().size();
3965   return impl->Shape().data();
3966 }
3967@@ -135,7 +133,7 @@ void OH_AI_TensorSetFormat(OH_AI_TensorHandle tensor, OH_AI_Format format) {
3968     MS_LOG(ERROR) << "param is nullptr.";
3969     return;
3970   }
3971-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3972+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3973   return impl->SetFormat(static_cast<mindspore::Format>(format));
3974 }
3975 
3976@@ -144,8 +142,8 @@ OH_AI_Format OH_AI_TensorGetFormat(const OH_AI_TensorHandle tensor) {
3977     MS_LOG(ERROR) << "param is nullptr.";
3978     return OH_AI_FORMAT_NHWC;
3979   }
3980-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3981-  return static_cast<OH_AI_Format>(impl->Format());
3982+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3983+  return static_cast<OH_AI_Format>(impl->format());
3984 }
3985 
3986 void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) {
3987@@ -153,16 +151,34 @@ void OH_AI_TensorSetData(OH_AI_TensorHandle tensor, void *data) {
3988     MS_LOG(ERROR) << "param is nullptr.";
3989     return;
3990   }
3991-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
3992+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
3993   return impl->SetData(data, true);
3994 }
3995 
3996+OH_AI_Status OH_AI_TensorSetUserData(OH_AI_TensorHandle tensor, void *data, size_t data_size) {
3997+  if (tensor == nullptr) {
3998+    MS_LOG(ERROR) << "param is nullptr.";
3999+    return OH_AI_STATUS_LITE_NULLPTR;
4000+  }
4001+
4002+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4003+  if ((impl->DataSize() > 0) && (data_size != impl->DataSize())) {
4004+    MS_LOG(ERROR) << "input data size does not match inner data size";
4005+    return OH_AI_STATUS_LITE_PARAM_INVALID;
4006+  }
4007+
4008+  // Setting a null allocator is a trick to indicate that the inner data is not owned by the tensor itself.
4009+  impl->SetAllocator(nullptr);
4010+  impl->SetData(data, false);
4011+  return OH_AI_STATUS_SUCCESS;
4012+}
4013+
4014 const void *OH_AI_TensorGetData(const OH_AI_TensorHandle tensor) {
4015   if (tensor == nullptr) {
4016     MS_LOG(ERROR) << "param is nullptr.";
4017     return nullptr;
4018   }
4019-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4020+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4021   return impl->Data().get();
4022 }
4023 
4024@@ -171,7 +187,7 @@ void *OH_AI_TensorGetMutableData(const OH_AI_TensorHandle tensor) {
4025     MS_LOG(ERROR) << "param is nullptr.";
4026     return nullptr;
4027   }
4028-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4029+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4030   return impl->MutableData();
4031 }
4032 
4033@@ -180,7 +196,7 @@ int64_t OH_AI_TensorGetElementNum(const OH_AI_TensorHandle tensor) {
4034     MS_LOG(ERROR) << "param is nullptr.";
4035     return 0;
4036   }
4037-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4038+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4039   return impl->ElementNum();
4040 }
4041 
4042@@ -189,6 +205,6 @@ size_t OH_AI_TensorGetDataSize(const OH_AI_TensorHandle tensor) {
4043     MS_LOG(ERROR) << "param is nullptr.";
4044     return 0;
4045   }
4046-  auto impl = static_cast<mindspore::LiteTensorImpl *>(tensor);
4047+  auto impl = static_cast<mindspore::MSTensor *>(tensor);
4048   return impl->DataSize();
4049 }
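A minimal sketch (not part of this patch) of the new OH_AI_TensorSetUserData: it binds caller-owned memory without transferring ownership and rejects buffers whose size does not match the tensor's data size. The data-type enum spelling is assumed from data_type_c.h:

    #include <vector>
    #include "include/c_api/tensor_c.h"

    void BindUserBuffer() {
      const int64_t shape[] = {1, 2, 2, 1};
      OH_AI_TensorHandle tensor =
        OH_AI_TensorCreate("input", OH_AI_DATATYPE_NUMBERTYPE_FLOAT32, shape, 4, nullptr, 0);
      // Caller-owned data; it must stay alive for as long as the tensor is used.
      std::vector<float> buffer(4, 0.0f);
      OH_AI_Status ret = OH_AI_TensorSetUserData(tensor, buffer.data(), buffer.size() * sizeof(float));
      if (ret != OH_AI_STATUS_SUCCESS) {
        // null handle or size mismatch
      }
      OH_AI_TensorDestroy(&tensor);  // frees the wrapper only; `buffer` is not touched
    }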
4050diff --git a/mindspore/lite/src/litert/c_api/type_c_private.h b/mindspore/lite/src/litert/c_api/type_c_private.h
4051new file mode 100644
4052index 00000000..2d3b3883
4053--- /dev/null
4054+++ b/mindspore/lite/src/litert/c_api/type_c_private.h
4055@@ -0,0 +1,40 @@
4056+/**
4057+ * Copyright 2023 Huawei Technologies Co., Ltd
4058+ *
4059+ * Licensed under the Apache License, Version 2.0 (the "License");
4060+ * you may not use this file except in compliance with the License.
4061+ * You may obtain a copy of the License at
4062+ *
4063+ * http://www.apache.org/licenses/LICENSE-2.0
4064+ *
4065+ * Unless required by applicable law or agreed to in writing, software
4066+ * distributed under the License is distributed on an "AS IS" BASIS,
4067+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
4068+ * See the License for the specific language governing permissions and
4069+ * limitations under the License.
4070+ */
4071+#ifndef MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
4072+#define MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
4073+
4074+#include <string>
4075+#include <vector>
4076+#include <memory>
4077+#include <stddef.h>
4078+#include "include/c_api/types_c.h"
4079+
4080+#ifdef __cplusplus
4081+extern "C" {
4082+#endif
4083+
4084+#define NNRT_DEVICE_NAME_MAX (128)
4085+
4086+struct NNRTDeviceDesc {
4087+  size_t device_id;
4088+  OH_AI_NNRTDeviceType device_type;
4089+  char device_name[NNRT_DEVICE_NAME_MAX];
4090+};
4091+
4092+#ifdef __cplusplus
4093+}
4094+#endif
4095+#endif  // MINDSPORE_LITE_SRC_LITERT_C_API_TYPE_C_PRIVATE_H_
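A small sketch (not part of this patch) of filling the NNRTDeviceDesc defined above while respecting the fixed-size name buffer; the inputs are hypothetical:

    #include <cstring>
    #include "src/litert/c_api/type_c_private.h"

    NNRTDeviceDesc MakeDesc(size_t id, OH_AI_NNRTDeviceType type, const char *name) {
      NNRTDeviceDesc desc{};
      desc.device_id = id;
      desc.device_type = type;
      // Bounded copy; the last byte is kept as the terminator for over-long names.
      strncpy(desc.device_name, name, NNRT_DEVICE_NAME_MAX - 1);
      desc.device_name[NNRT_DEVICE_NAME_MAX - 1] = '\0';
      return desc;
    }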
4096diff --git a/mindspore/lite/src/litert/cxx_api/context.cc b/mindspore/lite/src/litert/cxx_api/context.cc
4097index 1371bcf0..e5f19d28 100644
4098--- a/mindspore/lite/src/litert/cxx_api/context.cc
4099+++ b/mindspore/lite/src/litert/cxx_api/context.cc
4100@@ -50,6 +50,11 @@ constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dyn
4101 constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
4102 constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
4103 constexpr auto kModelOptionAscendRankID = "mindspore.option.ascend.rank_id";
4104+constexpr auto kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id";
4105+constexpr auto kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode";
4106+constexpr auto kModelOptionNNRTPriority = "mindspore.option.nnrt.priority";
4107+constexpr auto kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16";
4108+constexpr auto kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions";
4109 #ifdef USE_GLOG
4110 extern "C" {
4111 extern void mindspore_log_init();
4112@@ -684,4 +689,84 @@ std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
4113   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
4114   return StringToChar(ref);
4115 }
4116+
4117+void NNRTDeviceInfo::SetDeviceID(size_t device_id) {
4118+  if (data_ == nullptr) {
4119+    MS_LOG(ERROR) << "Invalid context.";
4120+    return;
4121+  }
4122+  data_->params[kModelOptionNNRTDeviceID] = device_id;
4123+}
4124+
4125+size_t NNRTDeviceInfo::GetDeviceID() const {
4126+  if (data_ == nullptr) {
4127+    MS_LOG(ERROR) << "Invalid context.";
4128+    return 0;
4129+  }
4130+  return GetValue<size_t>(data_, kModelOptionNNRTDeviceID);
4131+}
4132+
4133+void NNRTDeviceInfo::SetPerformanceMode(int performance_mode) {
4134+  if (data_ == nullptr) {
4135+    MS_LOG(ERROR) << "Invalid context.";
4136+    return;
4137+  }
4138+  data_->params[kModelOptionNNRTPerformanceMode] = performance_mode;
4139+}
4140+
4141+int NNRTDeviceInfo::GetPerformanceMode() const {
4142+  if (data_ == nullptr) {
4143+    MS_LOG(ERROR) << "Invalid context.";
4144+    return 0;
4145+  }
4146+  return GetValue<int>(data_, kModelOptionNNRTPerformanceMode);
4147+}
4148+
4149+void NNRTDeviceInfo::SetPriority(int priority) {
4150+  if (data_ == nullptr) {
4151+    MS_LOG(ERROR) << "Invalid context.";
4152+    return;
4153+  }
4154+  data_->params[kModelOptionNNRTPriority] = priority;
4155+}
4156+
4157+int NNRTDeviceInfo::GetPriority() const {
4158+  if (data_ == nullptr) {
4159+    MS_LOG(ERROR) << "Invalid context.";
4160+    return 0;
4161+  }
4162+  return GetValue<int>(data_, kModelOptionNNRTPriority);
4163+}
4164+
4165+void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) {
4166+  if (data_ == nullptr) {
4167+    MS_LOG(ERROR) << "Invalid context.";
4168+    return;
4169+  }
4170+  data_->params[kModelOptionNNRTEnableFP16] = is_fp16;
4171+}
4172+
4173+bool NNRTDeviceInfo::GetEnableFP16() const {
4174+  if (data_ == nullptr) {
4175+    MS_LOG(ERROR) << "Invalid context.";
4176+    return false;
4177+  }
4178+  return GetValue<bool>(data_, kModelOptionNNRTEnableFP16);
4179+}
4180+
4181+void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) {
4182+  if (data_ == nullptr) {
4183+    MS_LOG(ERROR) << "Invalid context.";
4184+    return;
4185+  }
4186+  data_->params[kModelOptionNNRTExtensions] = extensions;
4187+}
4188+
4189+std::vector<Extension> NNRTDeviceInfo::GetExtensions() const {
4190+  if (data_ == nullptr) {
4191+    MS_LOG(ERROR) << "Invalid context.";
4192+    return {};
4193+  }
4194+  return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions);
4195+}
4196 }  // namespace mindspore
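A hedged usage sketch (not part of this patch) of the new NNRTDeviceInfo options through the C++ Context API; the concrete performance-mode and priority values are assumptions of this sketch, not values defined by the patch:

    #include <memory>
    #include "include/api/context.h"

    std::shared_ptr<mindspore::Context> BuildNNRTContext(size_t device_id) {
      auto context = std::make_shared<mindspore::Context>();
      auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
      nnrt_info->SetDeviceID(device_id);
      nnrt_info->SetPerformanceMode(3);  // assumed value for a "high performance" mode
      nnrt_info->SetPriority(1);         // assumed value for a "low" priority
      nnrt_info->SetEnableFP16(true);
      context->MutableDeviceInfo().push_back(nnrt_info);
      return context;
    }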
4197diff --git a/mindspore/lite/src/litert/cxx_api/converters.cc b/mindspore/lite/src/litert/cxx_api/converters.cc
4198index 0ff345cc..e54a36ee 100644
4199--- a/mindspore/lite/src/litert/cxx_api/converters.cc
4200+++ b/mindspore/lite/src/litert/cxx_api/converters.cc
4201@@ -86,6 +86,23 @@ Status ContextUtils::AddCustomDevice(lite::InnerContext *inner_context,
4202   return kSuccess;
4203 }
4204 
4205+Status ContextUtils::AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode,
4206+                                   int priority, bool enable_fp16, const std::vector<Extension> &extensions) {
4207+  lite::DeviceInfo device_info = {0};
4208+  device_info.nnrt_device_info_.device_id_ = device_id;
4209+  device_info.nnrt_device_info_.performance_mode_ = performance_mode;
4210+  device_info.nnrt_device_info_.priority_ = priority;
4211+  device_info.nnrt_device_info_.enable_fp16_ = enable_fp16;
4212+  for (const auto &src_extension : extensions) {
4213+    lite::Extension dest_extension;
4214+    dest_extension.name = src_extension.name;
4215+    dest_extension.value = src_extension.value;
4216+    device_info.nnrt_device_info_.extensions_.push_back(dest_extension);
4217+  }
4218+  inner_context->device_list_.push_back({lite::DT_NNRT, device_info});
4219+  return kSuccess;
4220+}
4221+
4222 void ContextUtils::ResetContextDefaultParam(Context *context) {
4223   if (context->GetInterOpParallelNum() == 0) {
4224     context->SetInterOpParallelNum(kDefaultInterOpParallelNum);
4225@@ -163,44 +180,11 @@ std::shared_ptr<lite::InnerContext> ContextUtils::Convert(Context *context) {
4226       ret = AddAscendDevice(inner_context.get(), device.get());
4227     } else if (device->GetDeviceType() == kCustomDevice) {
4228       ret = AddCustomDevice(inner_context.get(), device);
4229-    }
4230-    if (ret != kSuccess) {
4231-      MS_LOG(ERROR) << "Add device failed!";
4232-      return nullptr;
4233-    }
4234-  }
4235-  return inner_context;
4236-}
4237-
4238-std::shared_ptr<lite::InnerContext> ContextUtils::Convert(const ContextC *context_c) {
4239-  auto inner_context = std::make_shared<lite::InnerContext>();
4240-  if ((context_c == nullptr) || (inner_context == nullptr)) {
4241-    MS_LOG(ERROR) << "Invalid context pointers.";
4242-    return nullptr;
4243-  }
4244-  auto device_list = context_c->device_info_list;
4245-  if (device_list.size() == 0 || device_list.size() > kMaxNumOfDevices) {
4246-    MS_LOG(ERROR) << "Device num, support min: 1, max: " << kMaxNumOfDevices;
4247-    return nullptr;
4248-  }
4249-  SetContextAttr(context_c->thread_num, 1, context_c->enable_parallel, context_c->affinity_core_list,
4250-                 context_c->delegate_mode, context_c->delegate, inner_context.get());
4251-  inner_context->device_list_.clear();
4252-  Status ret = kLiteError;
4253-  for (auto &device_info_c : device_list) {
4254-    MS_CHECK_TRUE_RET(device_info_c != nullptr, nullptr);
4255-    lite::DeviceInfo device_info = {{0}};
4256-    if (device_info_c->device_type == OH_AI_DEVICETYPE_CPU) {
4257-      if (device_info_c->allocator == nullptr) {
4258-        device_info_c->allocator = Allocator::Create();
4259-      }
4260-      ret = AddCpuDevice(device_info_c->allocator, context_c->affinity_mode, device_info_c->enable_fp16,
4261-                         device_info_c->provider, device_info_c->provider_device, inner_context.get());
4262-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_GPU) {
4263-      ret = AddGpuDevice(device_info_c->enable_fp16, 0, 0, 0, false, nullptr, nullptr, device_info_c->provider,
4264-                         device_info_c->provider_device, device_info_c->allocator, inner_context.get());
4265-    } else if (device_info_c->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
4266-      ret = AddNpuDevice(device_info_c->enable_fp16, device_info_c->frequency, inner_context.get());
4267+    } else if (device->GetDeviceType() == kNNRt) {
4268+      auto nnrt_device_info = device->Cast<NNRTDeviceInfo>();
4269+      ret = AddNNRtDevice(inner_context.get(), nnrt_device_info->GetDeviceID(),
4270+                          nnrt_device_info->GetPerformanceMode(), nnrt_device_info->GetPriority(),
4271+                          nnrt_device_info->GetEnableFP16(), nnrt_device_info->GetExtensions());
4272     }
4273     if (ret != kSuccess) {
4274       MS_LOG(ERROR) << "Add device failed!";
4275diff --git a/mindspore/lite/src/litert/cxx_api/converters.h b/mindspore/lite/src/litert/cxx_api/converters.h
4276index 0c043fc3..1af7c7df 100644
4277--- a/mindspore/lite/src/litert/cxx_api/converters.h
4278+++ b/mindspore/lite/src/litert/cxx_api/converters.h
4279@@ -24,14 +24,12 @@
4280 #include "include/api/cfg.h"
4281 #include "include/train/train_cfg.h"
4282 #include "src/litert/inner_context.h"
4283-#include "src/litert/c_api/context_c.h"
4284 #include "src/common/log_adapter.h"
4285 
4286 namespace mindspore {
4287 class MS_API ContextUtils {
4288  public:
4289   static std::shared_ptr<lite::InnerContext> Convert(Context *context);
4290-  static std::shared_ptr<lite::InnerContext> Convert(const ContextC *context_c);
4291 
4292  private:
4293   static void SetContextAttr(int32_t thread_num, int32_t inter_op_parallel_num, bool enable_parallel,
4294@@ -48,6 +46,8 @@ class MS_API ContextUtils {
4295   static Status AddNpuDevice(bool enable_fp16, int frequency, lite::InnerContext *inner_context);
4296   static Status AddAscendDevice(lite::InnerContext *inner_context, DeviceInfoContext *device);
4297   static Status AddCustomDevice(lite::InnerContext *inner_context, const std::shared_ptr<DeviceInfoContext> &device);
4298+  static Status AddNNRtDevice(lite::InnerContext *inner_context, size_t device_id, int performance_mode, int priority,
4299+                              bool enable_fp16, const std::vector<Extension> &extensions);
4300   static bool IsAffinityModeValid(int affinity_mode) {
4301     return affinity_mode >= lite::NO_BIND && affinity_mode <= lite::MID_CPU;
4302   }
4303diff --git a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4304index 70aa63f3..625459e2 100644
4305--- a/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4306+++ b/mindspore/lite/src/litert/delegate/nnrt/CMakeLists.txt
4307@@ -1,30 +1,13 @@
4308 include_directories(${DDK_PATH})
4309 include_directories($(CCSRC_DIR)/plugin/device/cpu/kernel)
4310+include_directories(${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/)
4311 
4312 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
4313-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include/inner)
4314-#include_directories(/home/tony/wty/workspace/ohos/third_party/mindspore/mindspore/lite/mindir/include)
4315+
4316 file(GLOB_RECURSE NNRT_SRC
4317         ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
4318 )
4319-
4320-#add_library(hiai SHARED IMPORTED)
4321-#set_target_properties(hiai PROPERTIES IMPORTED_LOCATION
4322-#        ${DDK_LIB_PATH}/libhiai.so)
4323-#add_library(hiai_ir SHARED IMPORTED)
4324-#set_target_properties(hiai_ir PROPERTIES IMPORTED_LOCATION
4325-#        ${DDK_LIB_PATH}/libhiai_ir.so)
4326-#add_library(hiai_ir_build SHARED IMPORTED)
4327-#set_target_properties(hiai_ir_build PROPERTIES IMPORTED_LOCATION
4328-#        ${DDK_LIB_PATH}/libhiai_ir_build.so)
4329-#add_library(npu_kernel_mid OBJECT ${NPU_RUNTIME_SRC})
4330-#add_dependencies(npu_kernel_mid fbs_src)
4331-#target_link_libraries(
4332-#        npu_kernel_mid
4333-#        hiai
4334-#        hiai_ir
4335-#        hiai_ir_build
4336-#)
4337-
4338 file(GLOB convert_source checker/*.cc)
4339-add_library(nnr_mid OBJECT ${NNRT_SRC} ${convert_source} )
4340\ No newline at end of file
4341+
4342+add_library(nnrt_mid OBJECT ${NNRT_SRC} ${convert_source})
4343+target_include_directories(nnrt_mid PUBLIC ${CMAKE_SOURCE_DIR}/../../../../../../foundation/ai/neural_network_runtime/)
4344\ No newline at end of file
4345diff --git a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4346index 4df7e477..6b191c8e 100644
4347--- a/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4348+++ b/mindspore/lite/src/litert/delegate/nnrt/checker/primitive_check.cc
4349@@ -109,6 +109,8 @@ Status CheckPrimitiveSupported(const schema::Primitive *primitive) {
4350         return mindspore::kSuccess;
4351       case schema::PrimitiveType_Unsqueeze:
4352         return mindspore::kSuccess;
4353+      case schema::PrimitiveType_Custom:
4354+        return mindspore::kSuccess;
4355       default: {
4356         MS_LOG(WARNING) << "No primitive type :" << (int)(type);
4357         return mindspore::kLiteSuccessExit;
4358diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4359index 34897331..9f012e76 100644
4360--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4361+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.cc
4362@@ -13,144 +13,637 @@
4363  * See the License for the specific language governing permissions and
4364  * limitations under the License.
4365  */
4366+
4367+#include <unordered_set>
4368+#include <numeric>
4369 #include "nnrt_delegate.h"
4370 #include "checker/primitive_check.h"
4371 #include "src/common/log_adapter.h"
4372-#include "interfaces/kits/c/neural_network_runtime.h"
4373+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
4374 #include "interfaces/innerkits/c/neural_network_runtime_inner.h"
4375 #include "nnrt_model_kernel.h"
4376+#include "schema/model_generated.h"
4377+#include "schema/ops_generated.h"
4378+#include "flatbuffers/flatbuffers.h"
4379+#include "litert/tensor_category.h"
4380+
4381+namespace mindspore {
4382+namespace lite {
4383+void NNRTDelegate::InitCachePath() {
4384+  static const std::string kCachePathName = "CachePath";
4385+  static const std::string kCacheVersion = "CacheVersion";
4386+
4387+  const auto &extensions = nnrt_device_info_.extensions_;
4388 
4389-mindspore::Status mindspore::NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
4390-  if (this->nnrt_lite_graph == nullptr) {
4391-    MS_LOG(ERROR) << "nnrt_lite_graph is nullptr.";
4392-    return mindspore::kLiteError;
4393+  auto iter_path = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
4394+    return extension.name == kCachePathName;
4395+  });
4396+  if (iter_path != extensions.end()) {
4397+    cache_path_ = std::string(iter_path->value.begin(), iter_path->value.end());
4398   }
4399-  if (this->nnrt_lite_graph->sub_graphs_.empty()) {
4400-    // must have at lease one subgraph
4401-    MS_LOG(ERROR) << "must have at lease one subgraph";
4402-    return mindspore::kLiteError;
4403+
4404+  auto iter_version = std::find_if(extensions.begin(), extensions.end(), [](const Extension &extension) {
4405+    return extension.name == kCacheVersion;
4406+  });
4407+  if (iter_version != extensions.end()) {
4408+    std::string version_str = std::string(iter_version->value.begin(), iter_version->value.end());
4409+    cache_version_ = static_cast<uint32_t>(std::atol(version_str.c_str()));
4410   }
4411-  OH_NN_ReturnCode ret_code;
4412-  OH_NNModel *oh_nnmodel = OH_NNModel_Construct();
4413-  if (oh_nnmodel == nullptr) {
4414-    MS_LOG(ERROR) << "Construct NNModel failed, oh_nnmodel is nullptr.";
4415-    return mindspore::kLiteError;
4416+}
4417+
4418+Status NNRTDelegate::Build(DelegateModel<schema::Primitive> *model) {
4419+#ifdef SUPPORT_NNRT_METAGRAPH
4420+  if (IsKirinNPU()) {
4421+    MS_LOG(DEBUG) << "Choose to build nnrt model with Metagraph";
4422+    InitCachePath();
4423+    return BuildKirinNPUModel(model);
4424   }
4425+#endif
4426 
4427-  ret_code = OH_NNModel_BuildFromLiteGraph(oh_nnmodel, this->nnrt_lite_graph);
4428-  if (ret_code != OH_NN_SUCCESS) {
4429-    MS_LOG(ERROR) << "Build NNModel failed, OH_NN_ReturnCode = " << ret_code;
4430-    OH_NNModel_Destroy(&oh_nnmodel);
4431-    return mindspore::kLiteError;
4432+  return BuildNormalModel(model);
4433+}
4434+
4435+bool NNRTDelegate::IsCustomModel() const {
4436+  // Check whether the LiteGraph consists of exactly one Custom node.
4437+  if (lite_graph_ == nullptr) {
4438+    return false;
4439+  }
4440+  if (lite_graph_->all_nodes_.size() != 1) {
4441+    return false;
4442+  }
4443+  auto node = lite_graph_->all_nodes_[0];
4444+  if (node == nullptr) {
4445+    return false;
4446+  }
4447+  if (node->node_type_ != mindspore::schema::PrimitiveType_Custom) {
4448+    return false;
4449+  }
4450+  return true;
4451+}
4452+
4453+#ifdef SUPPORT_NNRT_METAGRAPH
4454+bool NNRTDelegate::IsKirinNPU() const {
4455+  const std::string kirin_npu_name_prefix = "NPU_";
4456+  auto device_id = nnrt_device_info_.device_id_;
4457+  const char *device_name;
4458+  auto ret = OH_NNDevice_GetName(device_id, &device_name);
4459+  if (ret != OH_NN_SUCCESS) {
4460+    MS_LOG(WARNING) << "Get name of device: " << device_id << " failed, error: " << ret;
4461+    return false;
4462+  }
4463+
4464+  if (strncmp(kirin_npu_name_prefix.c_str(), device_name, kirin_npu_name_prefix.size()) != 0) {
4465+    MS_LOG(WARNING) << "Device " << device_id << " is not a Kirin NPU, device_name: " << device_name;
4466+    return false;
4467+  }
4468+  return true;
4469+}
4470+
4471+Status NNRTDelegate::BuildKirinNPUModel(DelegateModel<schema::Primitive> *model) {
4472+  OH_NNModel *nn_model = OH_NNModel_Construct();
4473+  if (nn_model == nullptr) {
4474+    MS_LOG(ERROR) << "Create NNModel failed, result is nullptr";
4475+    return kLiteNullptr;
4476+  }
4477+
4478+  size_t extension_size = nnrt_device_info_.extensions_.size();
4479+  std::vector<OH_NN_Extension> extensions;
4480+  MS_LOG_DEBUG << "set extensions, item number: " << extension_size;
4481+  const size_t kExtensionNameMax = 128; // This is a length limitation in NNRT API.
4482+  for (size_t i = 0; i < extension_size; i++) {
4483+    auto &src_extension = nnrt_device_info_.extensions_[i];
4484+    OH_NN_Extension dst_extension;
4485+    dst_extension.name[kExtensionNameMax - 1] = '\0';
4486+    strncpy(dst_extension.name, src_extension.name.c_str(), kExtensionNameMax - 1);
4487+    dst_extension.value = (char *)((void *)src_extension.value.data());
4488+    dst_extension.valueSize = src_extension.value.size();
4489+    extensions.push_back(dst_extension);
4490+    MS_LOG_DEBUG << "set extension, item name: " << dst_extension.name << ", value size: " << dst_extension.valueSize;
4491+  }
4492+
4493+  if (IsCustomModel()) {
4494+    auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_);
4495+    if (ret != OH_NN_SUCCESS) {
4496+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4497+      OH_NNModel_Destroy(&nn_model);
4498+      return kLiteError;
4499+    }
4500+  } else {
4501+    SetKirinModelInputsAndOutputs(nn_model);
4502+    auto ret = OH_NNModel_BuildFromMetaGraph(nn_model, meta_graph_, extensions.data(), extensions.size());
4503+    if (ret != OH_NN_SUCCESS) {
4504+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4505+      OH_NNModel_Destroy(&nn_model);
4506+      return kLiteError;
4507+    }
4508+  }
4509+
4510+  auto ret2 = CreateFullModelKernel(model, nn_model);
4511+  if (ret2 != kSuccess) {
4512+    MS_LOG(ERROR) << "Create full model kernel failed, ret: " << ret2;
4513+    return kLiteError;
4514   }
4515-  MS_LOG(INFO) << "NNRTDelegate creates NNModel success.";
4516+  return kSuccess;
4517+}
4518+
4519+std::vector<OH_NN_TensorInfo> NNRTDelegate::CreateNNTensorInfos(const std::vector<uint32_t> &indices) const {
4520+  std::vector<OH_NN_TensorInfo> nn_tensor_infos;
4521+  for (auto index: indices) {
4522+    auto tensor = lite_graph_->all_tensors_[index];
4523+    auto shape = tensor->dims();
4524+    auto data_type = tensor->dataType();
4525+    auto name = tensor->name();
4526+    auto format = tensor->format();
4527 
4528-  OH_NNCompilation *oh_nn_compilation = nullptr;
4529-  oh_nn_compilation = OH_NNCompilation_Construct(oh_nnmodel);
4530+    OH_NN_TensorInfo info;
4531+    info.dataType = CastToNNRTDataType(static_cast<mindspore::DataType>(data_type));
4532+    info.dimensions = shape->data();
4533+    info.dimensionCount = shape->size();
4534+    info.name[sizeof(info.name) - 1] = '\0';
+    strncpy(info.name, name->c_str(), sizeof(info.name) - 1);  // bounded copy into the fixed-size name field
4535+    info.format = CastToNNRTFormat(static_cast<Format>(format));
4536+    nn_tensor_infos.push_back(info);
4537+  }
4538+  return nn_tensor_infos;
4539+}
4540 
4541-  if (oh_nn_compilation == nullptr) {
4542+Status NNRTDelegate::SetKirinModelInputsAndOutputs(OH_NNModel *nn_model) {
4545+  auto input_infos = CreateNNTensorInfos(lite_graph_->input_indices_);
4546+  auto output_infos = CreateNNTensorInfos(lite_graph_->output_indices_);
4547+  OH_NNModel_SetInputsAndOutputsInfo(nn_model, input_infos.data(), input_infos.size(), output_infos.data(),
4548+                                     output_infos.size());
4549+  return kSuccess;
4550+}
4551+
4552+Status NNRTDelegate::CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model) {
4553+  OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model);
4554+  if (nn_compilation == nullptr) {
4555     MS_LOG(ERROR) << "Construct NNCompilation failed";
4556-    OH_NNModel_Destroy(&oh_nnmodel);
4557-    return mindspore::kLiteError;
4558+    OH_NNModel_Destroy(&nn_model);
4559+    return kLiteError;
4560   }
4561-  MS_LOG(INFO) << "NNRTDelegate creates NNCompilation success.";
4562+  MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
4563 
4564-  const size_t *allDevicesID = nullptr;
4565-  uint32_t device_count = 0;
4566-  ret_code = OH_NNDevice_GetAllDevicesID(&allDevicesID, &device_count);
4567-  if (ret_code != OH_NN_SUCCESS) {
4568-    MS_LOG(ERROR) << "NNModel GetAllDevicesID failed, OH_NN_ReturnCode = " << ret_code;
4569-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4570-    OH_NNModel_Destroy(&oh_nnmodel);
4571-    return mindspore::kLiteError;
4572+  auto ret_code = InitNNCompilation(nn_compilation);
4573+  if (ret_code != kSuccess) {
4574+    MS_LOG(ERROR) << "Init NNCompilation failed";
4575+    OH_NNModel_Destroy(&nn_model);
4576+    OH_NNCompilation_Destroy(&nn_compilation);
4577+    return kLiteError;
4578   }
4579+  OH_NNModel_Destroy(&nn_model);
4580 
4581-  if (device_count <= 0) {
4582-    MS_LOG(WARNING) << "No NNRt Device found, fall back to CPU. ";
4583-    // OH_NNCompilation_Destroy(&oh_nn_compilation);
4584-    // OH_NNModel_Destroy(&oh_nnmodel);
4585-    return mindspore::kSuccess;
4586+  OH_NNExecutor *nn_executor = nullptr;
4587+  nn_executor = OH_NNExecutor_Construct(nn_compilation);
4588+  if (nn_executor == nullptr) {
4589+    MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
4590+    OH_NNCompilation_Destroy(&nn_compilation);
4591+    return kLiteError;
4592   }
4593-  MS_LOG(INFO) << "NNRTDelegate GetAllDevicesID success.";
4594+  OH_NNCompilation_Destroy(&nn_compilation);
4595 
4596-  // check if model ops are supported
4597-  const bool *issupported = nullptr;
4598+  auto nnrt_model_kernel = new (std::nothrow)NNRTModelKernel(nn_executor, model->inputs(), model->outputs());
4599+  if (nnrt_model_kernel == nullptr) {
4600+    OH_NNExecutor_Destroy(&nn_executor);
4601+    MS_LOG(ERROR) << "new NNRTModelKernel failed";
4602+    return kLiteError;
4603+  }
4604+  model->Replace(model->BeginKernelIterator(), model->EndKernelIterator(), nnrt_model_kernel);
4605+  return kSuccess;
4606+}
4607+#endif
4608+
4609+Status NNRTDelegate::BuildNormalModel(DelegateModel<schema::Primitive> *model) {
4610+  MS_LOG(DEBUG) << "Start to build NNRT model.";
4611+  if ((lite_graph_ == nullptr) || (lite_graph_->sub_graphs_.size() > 1)) {
4612+    MS_LOG(WARNING) << "LiteGraph is null or contains more than one subgraph. NNRT does not support control-flow models yet, fall back to CPU";
4613+    return kSuccess;
4614+  }
4615+
4616+  OH_NNModel *full_model = CreateFullNNModel();
4617+  if (full_model == nullptr) {
4618+    MS_LOG(WARNING) << "Build full NNModel failed, fallback to CPU";
4619+    return kSuccess;
4620+  }
4621+  std::vector<bool> op_supports = QueryOpSupports(full_model);
4622+  if (op_supports.empty()) {
4623+    MS_LOG(WARNING) << "Query no op supports for full model, fallback to CPU";
4624+    OH_NNModel_Destroy(&full_model);
4625+    return kSuccess;
4626+  }
4627+  auto nnrt_subgraph_ranges = GetNNRTSubgraphRanges(model, op_supports);
4628+  MS_LOG(INFO) << "Found NNRT subgraph count: " << nnrt_subgraph_ranges.size();
4629+
4630+  std::vector<LiteGraph *> sub_lite_graphs;
4631+  auto ret = CreateLiteGraphForNNRTSubgraph(nnrt_subgraph_ranges, &sub_lite_graphs);
4632+  if (ret != kSuccess) {
4633+    OH_NNModel_Destroy(&full_model);
4634+    MS_LOG(WARNING) << "Create NNRT sub LiteGraph failed, fallback to CPU";
4635+    return kSuccess;
4636+  }
4637+
4638+  std::vector<NNRTModelKernel *> nnrt_subgraph_kernels;
4639+  ret = CreateNNRTSubgraphKernels(model, sub_lite_graphs, nnrt_subgraph_ranges, &nnrt_subgraph_kernels);
4640+  if (ret != kSuccess) {
4641+    OH_NNModel_Destroy(&full_model);
4642+    MS_LOG(WARNING) << "Create NNRT subgraph kernel failed, fallback to CPU";
4643+    return kSuccess;
4644+  }
4645+
4646+  ReplaceNNRTKernelsInDelegateModel(model, nnrt_subgraph_ranges, nnrt_subgraph_kernels);
4647+  OH_NNModel_Destroy(&full_model);
4648+  MS_LOG(INFO) << "NNRTDelegate build success.";
4649+  return kSuccess;
4650+}
4651+
4652+OH_NNModel *NNRTDelegate::CreateFullNNModel() {
4653+  if (lite_graph_ == nullptr) {
4654+    MS_LOG(ERROR) << "Lite graph is null";
4655+    return nullptr;
4656+  }
4657+
4658+  if (lite_graph_->sub_graphs_.empty()) {
4659+    MS_LOG(ERROR) << "Lite graph must have at lease one subgraph";
4660+    return nullptr;
4661+  }
4662+
4663+  OH_NNModel *nn_model = OH_NNModel_Construct();
4664+  if (nn_model == nullptr) {
4665+    MS_LOG(ERROR) << "Create NNModel failed, result is nullptr";
4666+    return nullptr;
4667+  }
4668+
4669+  auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, lite_graph_);
4670+  if (ret != OH_NN_SUCCESS) {
4671+    MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4672+    OH_NNModel_Destroy(&nn_model);
4673+    return nullptr;
4674+  }
4675+  return nn_model;
4676+}
4677+
4678+std::vector<bool> NNRTDelegate::QueryOpSupports(OH_NNModel *nn_model) {
4679+  const bool *is_supported = nullptr;  // Note: this memory is owned by nn_model, do not free it separately.
4680   uint32_t op_count = 0;
4681-  ret_code = OH_NNModel_GetAvailableOperations(oh_nnmodel, allDevicesID[0], &issupported, &op_count);
4682-  if (ret_code != OH_NN_SUCCESS) {
4683-    MS_LOG(ERROR) << "NNModel GetAvailableOperations failed, OH_NN_ReturnCode = " << ret_code
4684-                  << ", maybe due to dataParcel data length limitaion. Fall back to CPU.";
4685-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4686-    OH_NNModel_Destroy(&oh_nnmodel);
4687-    return mindspore::kSuccess;
4688+  auto ret = OH_NNModel_GetAvailableOperations(nn_model, nnrt_device_info_.device_id_, &is_supported, &op_count);
4689+  if (ret != OH_NN_SUCCESS) {
4690+    MS_LOG(WARNING) << "NNModel GetAvailableOperations failed, ret: " << ret
4691+                  << ", maybe caused by dataParcel data length limitation";
4692+    return {};
4693   }
4694-  uint32_t supported_op_count = 0;
4695-  for (uint32_t i = 0; i < op_count; i++) {
4696-    if (issupported[i]) {
4697-      supported_op_count++;
4698+  std::vector<bool> op_supports(is_supported, is_supported + op_count);
4699+  return op_supports;
4700+}
4701+
4702+/* Find continuous sub-sequence in op_supports. */
4703+std::vector<NNRTOpRange> NNRTDelegate::GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model,
4704+                                                             const std::vector<bool> &op_supports) {
4705+  std::vector<NNRTOpRange> nnrt_subgraph_ranges;
4706+  NNRTOpRange op_range;
4707+  bool start_count = false;
4708+  for (size_t i = 0; i < op_supports.size(); i++) {
4709+    if (op_supports[i]) {
4710+      if (start_count == false) {
4711+        start_count = true;
4712+        op_range.begin_index_ = i;
4713+        op_range.begin_iter_ = model->BeginKernelIterator() + i;
4714+      }
4715+    } else {
4716+      if (start_count == true) {
4717+        start_count = false;
4718+        op_range.end_index_ = i;
4719+        op_range.end_iter_ = model->BeginKernelIterator() + i;
4720+        nnrt_subgraph_ranges.push_back(op_range);
4721+      }
4722     }
4723   }
4724-  if (op_count != supported_op_count) {
4725-    MS_LOG(WARNING) << "this model has " << op_count << "ops, but NNRT only support " << supported_op_count
4726-                    << " ops, fall back to CPU.";
4727-    // must support all op, else fall back to CPU
4728-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4729-    OH_NNModel_Destroy(&oh_nnmodel);
4730-    return mindspore::kSuccess;
4731+  // handle last true subsequence
4732+  if (start_count == true) {
4733+    op_range.end_index_ = op_supports.size();
4734+    op_range.end_iter_ = model->EndKernelIterator();
4735+    nnrt_subgraph_ranges.push_back(op_range);
4736+    MS_LOG(INFO) << "Schedule NNRT subgraph range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")";
4737   }
4738-  MS_LOG(INFO) << "NNRtDelegate supports all op in this model.";
4739+  return nnrt_subgraph_ranges;
4740+}
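To illustrate the partitioning above, a self-contained sketch (not part of this patch) that mirrors the run extraction on a toy support mask: ops 0-1 and 3-5 are supported, op 2 is not, yielding the half-open ranges [0, 2) and [3, 6):

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Collect [begin, end) runs of consecutive `true` entries, as GetNNRTSubgraphRanges does on kernels.
    std::vector<std::pair<size_t, size_t>> ExtractRanges(const std::vector<bool> &supports) {
      std::vector<std::pair<size_t, size_t>> ranges;
      size_t begin = 0;
      bool in_run = false;
      for (size_t i = 0; i < supports.size(); i++) {
        if (supports[i] && !in_run) {
          in_run = true;
          begin = i;
        } else if (!supports[i] && in_run) {
          in_run = false;
          ranges.emplace_back(begin, i);
        }
      }
      if (in_run) {
        ranges.emplace_back(begin, supports.size());  // close the trailing run
      }
      return ranges;
    }

    void RangeExample() {
      // Ops 0-1 and 3-5 are supported by NNRT; op 2 stays on CPU.
      auto ranges = ExtractRanges({true, true, false, true, true, true});
      assert(ranges.size() == 2);
      assert(ranges[0].first == 0 && ranges[0].second == 2);
      assert(ranges[1].first == 3 && ranges[1].second == 6);
    }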
4741+
4742+/**
4743+ * This method ONLY works when the following pre-conditions are satisfied:
4744+ * 1. The node order of lite_graph_->all_nodes_ is consistent with the DelegateModel kernel sequence.
4745+ *    This ensures that kernels in the DelegateModel are replaced correctly, based on the re-organizing info from lite_graph_.
4746+ * 2. The node indices in lite_graph_->sub_graphs_[0]->node_indices_ increase monotonically from 0 to size - 1.
4747+ */
4748+Status NNRTDelegate::CreateLiteGraphForNNRTSubgraph(
4749+    const std::vector<NNRTOpRange> &nnrt_op_ranges,
4750+    std::vector<LiteGraph *> *sub_lite_graphs) {
4751+  MS_LOG(INFO) << "Start creating LiteGraph for NNRT subgraph";
4752+  for (const auto &op_range: nnrt_op_ranges) {
4753+    MS_LOG(INFO) << "Process op range: [" << op_range.begin_index_ << ", " << op_range.end_index_ << ")";
4754+    LiteGraph *sub_lite_graph = new (std::nothrow)LiteGraph;
4755+    if (sub_lite_graph == nullptr) {
4756+      MS_LOG(ERROR) << "Allocate LiteGraph failed";
4757+      return kLiteError;
4758+    }
4759+    sub_lite_graph->name_ = lite_graph_->name_;
4760+    sub_lite_graph->version_ = lite_graph_->version_;
4761 
4762-  ret_code = OH_NNCompilation_SetDevice(oh_nn_compilation, allDevicesID[0]);
4763+    auto sub_graph = new (std::nothrow)LiteGraph::SubGraph;
4764+    if (sub_graph == nullptr) {
4765+      MS_LOG(ERROR) << "Allocate SubGraph failed";
4766+      return kLiteError;
4767+    }
4768+    sub_graph->name_ = lite_graph_->name_;
4769+    sub_lite_graph->sub_graphs_.push_back(sub_graph);
4770 
4771+    // deal with all_nodes
4772+    MS_LOG(INFO) << "Assemble all_nodes...";
4773+    int new_node_index = 0;
4774+    std::map<uint32_t, schema::Tensor *> in_tensor_index_map;
4775+    std::map<uint32_t, schema::Tensor *> out_tensor_index_map;
4776+    for (size_t index = op_range.begin_index_; index < op_range.end_index_; index++) {
4777+      LiteGraph::Node *node = new (std::nothrow)LiteGraph::Node;
4778+      if (node == nullptr) {
4779+        MS_LOG(ERROR) << "Allocate Node failed";
4780+        return kLiteError;
4781+      }
4782+      *node = *lite_graph_->all_nodes_[index];
4783+      sub_lite_graph->all_nodes_.push_back(node);
4784+      sub_graph->node_indices_.push_back(new_node_index++);
4785+
4786+      for (auto i: node->input_indices_) {
4787+        in_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]);
4788+      }
4789+      for (auto i: node->output_indices_) {
4790+        out_tensor_index_map.emplace(i, lite_graph_->all_tensors_[i]);
4791+      }
4792+    }
4793+
4794+    // deal with all_tensors
4795+    MS_LOG(INFO) << "Assemble all_tensors...";
4796+    std::set<schema::Tensor *> tensors;
4797+    for (auto iter: in_tensor_index_map) {
4798+      tensors.emplace(iter.second);
4799+    }
4800+    for (auto iter: out_tensor_index_map) {
4801+      tensors.emplace(iter.second);
4802+    }
4803+
4804+    uint32_t new_index = 0;
4805+    std::map<schema::Tensor *, uint32_t> new_tensor_maps;
4806+    for (auto tensor: tensors) {
4807+      new_tensor_maps.emplace(tensor, new_index++);
4808+    }
4809+
4810+    sub_lite_graph->all_tensors_ = std::vector<schema::Tensor *>(tensors.begin(), tensors.end());
4811+
4812+    // deal with every node's input/output indices
4813+    MS_LOG(INFO) << "Set input/output indices of each node...";
4814+    for (auto node: sub_lite_graph->all_nodes_) {
4815+      for (auto &index : node->input_indices_) {
4816+        index = new_tensor_maps.at(in_tensor_index_map.at(index));
4817+      }
4818+      for (auto &index : node->output_indices_) {
4819+        index = new_tensor_maps.at(out_tensor_index_map.at(index));
4820+      }
4821+    }
4822+
4823+    // deal with subgraph's input/output indices
4824+    MS_LOG(INFO) << "Set input/output indices of each subgraph...";
4825+    sub_graph->tensor_indices_ = std::vector<uint32_t>(tensors.size());
4826+    std::iota(sub_graph->tensor_indices_.begin(), sub_graph->tensor_indices_.end(), 0U);
4827+
4828+    for (auto iter: in_tensor_index_map) {
4829+      auto new_tensor_index = new_tensor_maps[iter.second];
4830+      MS_LOG(DEBUG) << "handle input: old: " << iter.first << ", new: " << new_tensor_index << std::endl;
4831+      if (IsConstTensor(*iter.second)) {
4832+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl;
4833+        continue;
4834+      }
4835+
4836+      bool is_subgraph_input = true;
4837+      for (auto node: sub_lite_graph->all_nodes_) {
4838+        if (std::find(node->output_indices_.begin(), node->output_indices_.end(), new_tensor_index) !=
4839+            node->output_indices_.end()) {
4840+          is_subgraph_input = false;
4841+          MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is not subgraph input." << std::endl;
4842+          break;
4843+        }
4844+      }
4845+      if (is_subgraph_input) {
4846+        sub_graph->input_indices_.push_back(new_tensor_index);
4847+        MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph input." << std::endl;
4848+      }
4849+    }
4850+
4851+    for (auto iter: out_tensor_index_map) {
4852+      int new_tensor_index = new_tensor_maps.at(iter.second);
4853+      MS_LOG(DEBUG) << "handle output: old: " << iter.first << ", new: " << new_tensor_index << std::endl;
4854+      if (IsConstTensor(*iter.second)) {
4855+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is const." << std::endl;
4856+        continue;
4857+      }
4858+
4859+      bool is_subgraph_output = false;
4860+      for (size_t i = 0; i < lite_graph_->all_nodes_.size(); i++) {
4861+        if ((i >= op_range.begin_index_) && (i < op_range.end_index_)) {
4862+          continue;
4863+        }
4864+        auto node = lite_graph_->all_nodes_[i];
4865+        if (std::find(node->input_indices_.begin(), node->input_indices_.end(), iter.first) !=
4866+            node->input_indices_.end()) { // As the input of node which does not belong to the subgraph.
4867+          is_subgraph_output = true;
4868+          MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is original subgraph output. node: " << node->primitive_ << std::endl;
4869+          break;
4870+        }
4871+      }
4872+      bool is_graph_output = (std::find(lite_graph_->output_indices_.begin(),lite_graph_->output_indices_.end(),
4873+                                        iter.first) != lite_graph_->output_indices_.end());
4874+      if (is_graph_output) {
4875+        MS_LOG(DEBUG) << "- tensor: " << new_tensor_index << " is graph output." << std::endl;
4876+      }
4877+      if (is_subgraph_output || is_graph_output) {
4878+        sub_graph->output_indices_.push_back(new_tensor_index);
4879+        MS_LOG(DEBUG) << "- select tensor: " << new_tensor_index << " as subgraph output." << std::endl;
4880+      }
4881+    }
4882+
4883+    // deal with full-graph's input/output indices
4884+    sub_lite_graph->input_indices_ = sub_graph->input_indices_;
4885+    sub_lite_graph->output_indices_ = sub_graph->output_indices_;
4886+    sub_lite_graphs->push_back(sub_lite_graph);
4887+  }
4888+  MS_LOG(INFO) << "Finished creating LiteGraph for NNRT subgraph";
4889+  return kSuccess;
4890+}
4891+
4892+struct TensorLocation {
4893+  uint32_t node_index;    // index of the node that the tensor belongs to.
4894+  uint32_t tensor_index;  // position of the tensor within that node's input/output tensors.
4895+};
4896+
4897+Status NNRTDelegate::InitNNCompilation(OH_NNCompilation *nn_compilation) const {
4898+  auto ret_code = OH_NNCompilation_SetDevice(nn_compilation, nnrt_device_info_.device_id_);
4899   if (ret_code != OH_NN_SUCCESS) {
4900-    MS_LOG(ERROR) << "NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code;
4901-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4902-    OH_NNModel_Destroy(&oh_nnmodel);
4903-    return mindspore::kLiteError;
4904+    MS_LOG(ERROR) << "NNCompilation set device id failed, ret: " << ret_code;
4905+    return kLiteError;
4906+  }
4907+  ret_code = OH_NNCompilation_SetPerformanceMode(nn_compilation,
4908+                                                 (OH_NN_PerformanceMode)(nnrt_device_info_.performance_mode_));
4909+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4910+    MS_LOG(ERROR) << "NNCompilation set performance mode failed, ret: " << ret_code;
4911+    return kLiteError;
4912+  }
4913+  ret_code = OH_NNCompilation_SetPriority(nn_compilation, (OH_NN_Priority)(nnrt_device_info_.priority_));
4914+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4915+    MS_LOG(ERROR) << "NNCompilation set priority failed, ret: " << ret_code;
4916+    return kLiteError;
4917+  }
4918+  ret_code = OH_NNCompilation_EnableFloat16(nn_compilation, nnrt_device_info_.enable_fp16_);
4919+  if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4920+    MS_LOG(ERROR) << "NNCompilation enable fp16 failed, ret: " << ret_code;
4921+    return kLiteError;
4922   }
4923 
4924-  ret_code = OH_NNCompilation_Build(oh_nn_compilation);
4925+  if (!cache_path_.empty()) { // Set cache path if user indeed set it.
4926+    ret_code = OH_NNCompilation_SetCache(nn_compilation, cache_path_.c_str(), cache_version_);
4927+    if ((ret_code != OH_NN_SUCCESS) && (ret_code != OH_NN_OPERATION_FORBIDDEN)) {
4928+      MS_LOG(ERROR) << "NNCompilation set cache failed, ret: " << ret_code;
4929+      return kLiteError;
4930+    }
4931+  }
4932 
4933+  ret_code = OH_NNCompilation_Build(nn_compilation);
4934   if (ret_code != OH_NN_SUCCESS) {
4935-    MS_LOG(ERROR) << "Build NNCompilation failed, OH_NN_ReturnCode = " << ret_code;
4936-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4937-    OH_NNModel_Destroy(&oh_nnmodel);
4938-    return mindspore::kLiteError;
4939-  }
4940-
4941-  MS_LOG(DEBUG) << "NNRTDelegate SetDevice success.";
4942-
4943-  OH_NNExecutor *oh_nn_executor = nullptr;
4944-  oh_nn_executor = OH_NNExecutor_Construct(oh_nn_compilation);
4945-  if (oh_nn_executor == nullptr) {
4946-    MS_LOG(ERROR) << "Construct NNCompilation SetDevice failed, OH_NN_ReturnCode = " << ret_code;
4947-    OH_NNCompilation_Destroy(&oh_nn_compilation);
4948-    OH_NNModel_Destroy(&oh_nnmodel);
4949-    return mindspore::kLiteError;
4950-  }
4951-  MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success.";
4952-  mindspore::Status prepare_data_ret;
4953-  auto nnr_model_kernel = new (std::nothrow) NNRTModelKernel(oh_nn_executor, model->inputs(), model->outputs());
4954-  if (nnr_model_kernel == nullptr) {
4955-    MS_LOG(ERROR) << "new NNRTModelKernel failed";
4956-    return mindspore::kLiteError;
4957+    MS_LOG(ERROR) << "Build NNCompilation failed, ret: " << ret_code;
4958+    return kLiteError;
4959   }
4960-  OH_NNCompilation_Destroy(&oh_nn_compilation);
4961-  OH_NNModel_Destroy(&oh_nnmodel);
4962-  KernelIter from = model->BeginKernelIterator();
4963-  KernelIter end = model->EndKernelIterator();
4964-  model->Replace(from, end, nnr_model_kernel);
4965+  return kSuccess;
4966+}
4967+
4968+Status NNRTDelegate::CreateNNRTSubgraphKernels(DelegateModel<schema::Primitive> *model,
4969+                                               const std::vector<LiteGraph *> &sub_lite_graphs, const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
4970+                                               std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels) {
4971+  for (size_t i = 0; i < sub_lite_graphs.size(); i++) {
4972+    auto sub_lite_graph = sub_lite_graphs[i];
4973+
4974+    OH_NNModel *nn_model = OH_NNModel_Construct();
4975+    auto ret = OH_NNModel_BuildFromLiteGraph(nn_model, sub_lite_graph);
4976+    if (ret != OH_NN_SUCCESS) {
4977+      MS_LOG(ERROR) << "Build NNModel failed, ret: " << ret;
4978+      OH_NNModel_Destroy(&nn_model);
4979+      return kLiteError;
4980+    }
4981 
4982-  MS_LOG(INFO) << "NNRTDelegate build  success.";
4983-  return mindspore::kSuccess;
4984+    OH_NNCompilation *nn_compilation = OH_NNCompilation_Construct(nn_model);
4985+    if (nn_compilation == nullptr) {
4986+      MS_LOG(ERROR) << "Construct NNCompilation failed";
4987+      OH_NNModel_Destroy(&nn_model);
4988+      return kLiteError;
4989+    }
4990+    MS_LOG(DEBUG) << "NNRTDelegate creates NNCompilation success.";
4991+
4992+    auto ret_code = InitNNCompilation(nn_compilation);
4993+    if (ret_code != kSuccess) {
4994+      MS_LOG(ERROR) << "Init NNCompilation failed";
4995+      OH_NNCompilation_Destroy(&nn_compilation);
4996+      OH_NNModel_Destroy(&nn_model);
4997+      return kLiteError;
4998+    }
4999+
5000+    OH_NNExecutor *nn_executor = nullptr;
5001+    nn_executor = OH_NNExecutor_Construct(nn_compilation);
5002+    if (nn_executor == nullptr) {
5003+      MS_LOG(ERROR) << "Construct NNExecutor failed, ret: " << ret_code;
5004+      OH_NNCompilation_Destroy(&nn_compilation);
5005+      OH_NNModel_Destroy(&nn_model);
5006+      return kLiteError;
5007+    }
5008+    MS_LOG(DEBUG) << "NNRTDelegate creates NNExecutor success.";
5009+
5010+    bool format_not_support = false;
5011+    std::vector<MSTensor> in_tensors;
5012+    for (auto index: sub_lite_graph->sub_graphs_[0]->input_indices_) {
5013+      TensorLocation location;
5014+      for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) {
5015+        auto node = sub_lite_graph->all_nodes_[node_index];
5016+        auto iter = std::find(node->input_indices_.begin(), node->input_indices_.end(), index);
5017+        if (iter != node->input_indices_.end()) {
5018+          uint32_t tensor_index = iter - node->input_indices_.begin();
5019+          location.node_index = node_index;
5020+          location.tensor_index = tensor_index;
5021+          MS_LOG(INFO) << "Found graph input index: " << index << " is the " << tensor_index << "th input of the node " << node->primitive_;
5022+          break;
5023+        }
5024+      }
5025+      KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index;
5026+      in_tensors.push_back((*kernel_iter)->inputs()[location.tensor_index]);
5027+      if (in_tensors.back().format() != Format::NHWC) {
5028+        format_not_support = true;
5029+        break;
5030+      }
5031+    }
5032+
5033+    std::vector<MSTensor> out_tensors;
5034+    for (auto index: sub_lite_graph->sub_graphs_[0]->output_indices_) {
5035+      TensorLocation location;
5036+      for (auto node_index: sub_lite_graph->sub_graphs_[0]->node_indices_) {
5037+        auto node = sub_lite_graph->all_nodes_[node_index];
5038+        auto iter = std::find(node->output_indices_.begin(), node->output_indices_.end(), index);
5039+        if (iter != node->output_indices_.end()) {
5040+          uint32_t tensor_index = iter - node->output_indices_.begin();
5041+          location.node_index = node_index;
5042+          location.tensor_index = tensor_index;
5043+          MS_LOG(INFO) << "Found graph output index: " << index << " is the " << tensor_index << "th output of the node " << node->primitive_;
5044+          break;
5045+        }
5046+      }
5047+      KernelIter kernel_iter = nnrt_subgraph_ranges[i].begin_iter_ + location.node_index;
5048+      out_tensors.push_back((*kernel_iter)->outputs()[location.tensor_index]);
5049+      if (out_tensors.back().format() != Format::NHWC) {
5050+        format_not_support = true;
5051+        break;
5052+      }
5053+    }
5054+    if (format_not_support) {
5055+      MS_LOG(WARNING) << "Not support in/out tensor format, skip this subgraph";
5056+      OH_NNCompilation_Destroy(&nn_compilation);
5057+      OH_NNModel_Destroy(&nn_model);
5058+      nnrt_subgraph_kernels->push_back(nullptr);
5059+      continue;
5060+    }
5061+
5062+    auto nnrt_model_kernel = new (std::nothrow) NNRTModelKernel(nn_executor, in_tensors, out_tensors);
5063+    if (nnrt_model_kernel == nullptr) {
5064+      MS_LOG(ERROR) << "new NNRTModelKernel failed";
5065+      return kLiteError;
5066+    }
5067+    OH_NNCompilation_Destroy(&nn_compilation);
5068+    OH_NNModel_Destroy(&nn_model);
5069+    nnrt_subgraph_kernels->push_back(nnrt_model_kernel);
5070+  }
5071+  return kSuccess;
5072 }
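For illustration only, a standalone sketch (hypothetical Node/Location/FindInputLocation names, not code from this patch) of the lookup performed inside CreateNNRTSubgraphKernels above: given a graph-level tensor index, find the node that consumes it and the position of that tensor among the node's inputs. Returning an empty optional makes the "not found" case explicit, which the in-tree loop leaves implicit.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

struct Node {
  std::vector<uint32_t> input_indices;
};
struct Location {
  std::size_t node_index;
  std::size_t tensor_index;
};

// Returns which node consumes `graph_input` and at which argument position, if any node does.
std::optional<Location> FindInputLocation(const std::vector<Node> &nodes, uint32_t graph_input) {
  for (std::size_t n = 0; n < nodes.size(); ++n) {
    const auto &inputs = nodes[n].input_indices;
    auto it = std::find(inputs.begin(), inputs.end(), graph_input);
    if (it != inputs.end()) {
      return Location{n, static_cast<std::size_t>(it - inputs.begin())};
    }
  }
  return std::nullopt;  // the index is not consumed by any node in this subgraph
}

int main() {
  std::vector<Node> nodes = {{{5, 7}}, {{7, 2}}};
  auto loc = FindInputLocation(nodes, 2);
  return (loc && loc->node_index == 1 && loc->tensor_index == 1) ? 0 : 1;
}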
5073 
5074-mindspore::Status mindspore::NNRTDelegate::Init() {
5075-  MS_LOG(DEBUG) << "NNRTDelegate init success.";
5076-  return mindspore::kSuccess;
5077+void NNRTDelegate::ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model,
5078+                                       const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5079+                                       const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels) {
5080+  // Replace from back to front intentionally: each replacement shrinks the kernel sequence, so the
5081+  // begin_iter_/end_iter_ recorded for later ranges would already be invalid if we went front to back.
5082+  for (int i = nnrt_subgraph_ranges.size() - 1; i >= 0; i--) {
5083+    if (nnrt_subgraph_kernels[i] == nullptr) {
5084+      continue;
5085+    }
5086+    auto from = nnrt_subgraph_ranges[i].begin_iter_;
5087+    auto end = nnrt_subgraph_ranges[i].end_iter_;
5088+    (void)model->Replace(from, end, nnrt_subgraph_kernels[i]);
5089+    MS_LOG(INFO) << "Replace nnrt subgraph kernel in range: [" << (from - model->BeginKernelIterator())
5090+      << ", " << (end - model->BeginKernelIterator()) << ")";
5091+  }
5092 }
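A minimal standalone sketch (made-up Range/CollapseRanges names, a plain std::vector<int> standing in for the kernel list) of why ReplaceNNRTKernelsInDelegateModel walks the recorded ranges from back to front: collapsing a range to a single element shrinks the sequence, so offsets recorded for later ranges stay valid only if those later ranges are processed first.

#include <cstddef>
#include <iostream>
#include <vector>

struct Range { std::size_t begin; std::size_t end; };  // half-open [begin, end), analogous to NNRTOpRange

void CollapseRanges(std::vector<int> *seq, const std::vector<Range> &ranges, int replacement) {
  // Walk the recorded ranges from the last one to the first so earlier offsets stay valid.
  for (auto it = ranges.rbegin(); it != ranges.rend(); ++it) {
    seq->erase(seq->begin() + it->begin, seq->begin() + it->end);
    seq->insert(seq->begin() + it->begin, replacement);
  }
}

int main() {
  std::vector<int> kernels = {1, 2, 3, 4, 5, 6};
  std::vector<Range> ranges = {{1, 3}, {4, 6}};  // two disjoint subgraph ranges
  CollapseRanges(&kernels, ranges, -1);
  for (int k : kernels) std::cout << k << ' ';   // prints: 1 -1 4 -1
  std::cout << std::endl;
  return 0;
}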
5093-mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model,
5094-                                                         OH_NNExecutor *oh_nn_executor) {
5095+
5096+Status NNRTDelegate::PrepareInputs(DelegateModel<schema::Primitive> *model,
5097+                                   OH_NNExecutor *oh_nn_executor) {
5098   auto input_tensors = model->inputs();
5099   for (size_t i = 0; i < input_tensors.size(); i++) {
5100     auto tensor = input_tensors[i];
5101@@ -161,10 +654,10 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5102     std::vector<double> scale;
5103     std::vector<int32_t> zero_point;
5104     if (!tmp_quant_param.empty()) {
5105-      quant_param = new (std::nothrow) OH_NN_QuantParam;
5106+      quant_param = new(std::nothrow) OH_NN_QuantParam;
5107       if (quant_param == nullptr) {
5108         MS_LOG(ERROR) << "new OH_NN_QuantParam failed.";
5109-        return mindspore::kLiteError;
5110+        return kLiteError;
5111       }
5112       for (auto qparam : tmp_quant_param) {
5113         bit_num.emplace_back(qparam.bit_num);
5114@@ -176,12 +669,12 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5115       quant_param->scale = scale.data();
5116       quant_param->zeroPoint = zero_point.data();
5117     }
5118-    auto oprend = new (std::nothrow) OH_NN_Tensor;
5119+    auto oprend = new(std::nothrow) OH_NN_Tensor;
5120     if (oprend == nullptr) {
5121       MS_LOG(ERROR) << "new OH_NN_Tensor Failed";
5122-      return mindspore::kLiteError;
5123+      return kLiteError;
5124     }
5125-    oprend->dataType = ConvertDataType(tensor.DataType());
5126+    oprend->dataType = CastToNNRTDataType(tensor.DataType());
5127     oprend->dimensionCount = tensor_shape.size();
5128 
5129     std::vector<int32_t> dimensions_list;
5130@@ -191,14 +684,14 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5131       } else {
5132         MS_LOG(ERROR) << "NNExecutor SetInput failed, tensor dimension is too large, max dim = " << INT32_MAX
5133                       << ", but get dimension = " << shape;
5134-        return mindspore::kLiteError;
5135+        return kLiteError;
5136       }
5137     }
5138     oprend->dimensions = dimensions_list.data();
5139     oprend->quantParam = quant_param;
5140     oprend->type = OH_NN_TENSOR;
5141     OH_NN_ReturnCode ret_code =
5142-      OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5143+        OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5144     delete (oprend);
5145 
5146     if (!tmp_quant_param.empty()) {
5147@@ -209,70 +702,41 @@ mindspore::Status mindspore::NNRTDelegate::PrepareInputs(DelegateModel<schema::P
5148     if (ret_code != OH_NN_SUCCESS) {
5149       MS_LOG(ERROR) << "NNExecutor SetInput failed, current input tensor is" << tensor.Name()
5150                     << "OH_NN_ReturnCode = " << ret_code;
5151-      return mindspore::kLiteError;
5152+      return kLiteError;
5153     }
5154   }
5155+  return kSuccess;
5156+}
5157+
5158+OH_NN_DataType NNRTDelegate::CastToNNRTDataType(DataType data_type) {
5159+  const std::unordered_map<DataType, OH_NN_DataType> kDataTypeMap = {
5160+      {DataType::kNumberTypeBool, OH_NN_BOOL},
5161+      {DataType::kNumberTypeInt8, OH_NN_INT8},
5162+      {DataType::kNumberTypeInt16, OH_NN_INT16},
5163+      {DataType::kNumberTypeInt32, OH_NN_INT32},
5164+      {DataType::kNumberTypeInt64, OH_NN_INT64},
5165+      {DataType::kNumberTypeUInt8, OH_NN_UINT8},
5166+      {DataType::kNumberTypeUInt16, OH_NN_UINT16},
5167+      {DataType::kNumberTypeUInt32, OH_NN_UINT32},
5168+      {DataType::kNumberTypeUInt64, OH_NN_UINT64},
5169+      {DataType::kNumberTypeFloat16, OH_NN_FLOAT16},
5170+      {DataType::kNumberTypeFloat32, OH_NN_FLOAT32},
5171+      {DataType::kNumberTypeFloat64, OH_NN_FLOAT64},
5172+  };
5173 
5174-  return mindspore::kSuccess;
5175+  auto iter = kDataTypeMap.find(data_type);
5176+  if (iter == kDataTypeMap.end()) {
5177+    return OH_NN_UNKNOWN;
5178+  }
5179+  return iter->second;
5180 }
5181-OH_NN_DataType mindspore::NNRTDelegate::ConvertDataType(mindspore::DataType data_type) {
5182-  OH_NN_DataType oh_data_type;
5183-  switch (data_type) {
5184-    case mindspore::DataType::kTypeUnknown:
5185-    case mindspore::DataType::kObjectTypeString:
5186-    case mindspore::DataType::kObjectTypeList:
5187-    case mindspore::DataType::kObjectTypeTuple:
5188-    case mindspore::DataType::kObjectTypeTensorType:
5189-    case mindspore::DataType::kNumberTypeBegin:
5190-    case mindspore::DataType::kNumberTypeEnd:
5191-    case mindspore::DataType::kInvalidType:
5192-      oh_data_type = OH_NN_UNKNOWN;
5193-      break;
5194-    case mindspore::DataType::kNumberTypeBool:
5195-      oh_data_type = OH_NN_BOOL;
5196-      break;
5197-    case mindspore::DataType::kNumberTypeInt8:
5198-      oh_data_type = OH_NN_INT8;
5199-      break;
5200-    case mindspore::DataType::kNumberTypeInt16:
5201-      oh_data_type = OH_NN_INT16;
5202-      break;
5203-    case mindspore::DataType::kNumberTypeInt32:
5204-      oh_data_type = OH_NN_INT32;
5205-      break;
5206-    case mindspore::DataType::kNumberTypeInt64:
5207-      oh_data_type = OH_NN_INT64;
5208-      break;
5209-    case mindspore::DataType::kNumberTypeUInt8:
5210-      oh_data_type = OH_NN_UINT8;
5211-      break;
5212-    case mindspore::DataType::kNumberTypeUInt16:
5213-      oh_data_type = OH_NN_UINT16;
5214-      break;
5215-    case mindspore::DataType::kNumberTypeUInt32:
5216-      oh_data_type = OH_NN_UINT32;
5217-      break;
5218-    case mindspore::DataType::kNumberTypeUInt64:
5219-      oh_data_type = OH_NN_UINT64;
5220-      break;
5221-    case mindspore::DataType::kNumberTypeFloat16:
5222-      oh_data_type = OH_NN_FLOAT16;
5223-      break;
5224-    case mindspore::DataType::kNumberTypeFloat32:
5225-      oh_data_type = OH_NN_FLOAT32;
5226-      break;
5227-    case mindspore::DataType::kNumberTypeFloat64:
5228-      oh_data_type = OH_NN_FLOAT64;
5229-      break;
5230-    default: {
5231-      oh_data_type = OH_NN_UNKNOWN;
5232-    }
5233-  }
5234-  return oh_data_type;
5235+
5236+OH_NN_Format NNRTDelegate::CastToNNRTFormat(Format format) {
5237+  return OH_NN_FORMAT_NHWC;
5238 }
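The data-type conversion above follows a simple lookup-with-fallback pattern; a self-contained sketch with placeholder enums (not the real mindspore DataType or OH_NN_DataType values) shows the intended behaviour for unmapped types:

#include <cassert>
#include <unordered_map>

enum class SrcType { kBool, kFloat32, kString };
enum class DstType { kUnknown, kBool, kFloat32 };

DstType Cast(SrcType t) {
  static const std::unordered_map<SrcType, DstType> kMap = {
      {SrcType::kBool, DstType::kBool},
      {SrcType::kFloat32, DstType::kFloat32},
  };
  auto it = kMap.find(t);
  return it == kMap.end() ? DstType::kUnknown : it->second;
}

int main() {
  assert(Cast(SrcType::kFloat32) == DstType::kFloat32);
  assert(Cast(SrcType::kString) == DstType::kUnknown);  // unmapped types fall back to "unknown"
  return 0;
}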
5239 
5240-mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model,
5241-                                                          OH_NNExecutor *oh_nn_executor) {
5242+Status NNRTDelegate::PrepareOutputs(DelegateModel<schema::Primitive> *model,
5243+                                    OH_NNExecutor *oh_nn_executor) {
5244   auto output_tensors = model->outputs();
5245   for (size_t i = 0; i < output_tensors.size(); i++) {
5246     auto tensor = output_tensors[i];
5247@@ -280,17 +744,17 @@ mindspore::Status mindspore::NNRTDelegate::PrepareOutputs(DelegateModel<schema::
5248     if (ret_code != OH_NN_SUCCESS) {
5249       MS_LOG(ERROR) << "NNExecutor SetOutput failed, current out tensor is" << tensor.Name()
5250                     << ", OH_NN_ReturnCode = " << ret_code;
5251-      return mindspore::kLiteError;
5252+      return kLiteError;
5253     }
5254   }
5255-  return mindspore::kSuccess;
5256+  return kSuccess;
5257 }
5258 
5259-void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGraph &lite_graph) {
5260+void NNRTDelegate::ShallowCopyLiteGraph(const lite::LiteGraph &lite_graph) {
5261   Status ret;
5262   for (auto node : lite_graph.all_nodes_) {
5263     ret = lite::CheckPrimitiveSupported(static_cast<const schema::Primitive *>(node->primitive_));
5264-    if (ret == mindspore::kLiteError) {
5265+    if (ret == kLiteError) {
5266       MS_LOG(ERROR) << " primitive supported check failed.";
5267       return;
5268     }
5269@@ -299,7 +763,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5270   node_list.reserve(lite_graph.all_nodes_.size());
5271   // copy node
5272   for (auto node : lite_graph.all_nodes_) {
5273-    auto new_node = new (std::nothrow) LiteGraph::Node;
5274+    auto new_node = new(std::nothrow) LiteGraph::Node;
5275     if (new_node == nullptr) {
5276       MS_LOG(ERROR) << " new LiteGraph::Node failed.";
5277       return;
5278@@ -318,7 +782,7 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5279   // copy subgraph
5280   std::vector<LiteGraph::SubGraph *> subgraph_list;
5281   for (auto subgraph : lite_graph.sub_graphs_) {
5282-    auto new_subgraph = new (std::nothrow) LiteGraph::SubGraph;
5283+    auto new_subgraph = new(std::nothrow) LiteGraph::SubGraph;
5284     if (new_subgraph == nullptr) {
5285       MS_LOG(ERROR) << "new LiteGraph::Subgraph failed.";
5286       return;
5287@@ -331,30 +795,32 @@ void mindspore::NNRTDelegate::ShallowCopyLiteGraph(const mindspore::lite::LiteGr
5288   }
5289   for (auto tensor : lite_graph.all_tensors_) {
5290     ret = lite::CheckTensorSupported(static_cast<const schema::Tensor *>(tensor));
5291-    if (ret == mindspore::kLiteError) {
5292+    if (ret == kLiteError) {
5293       MS_LOG(ERROR) << "tensor supported check failed.";
5294       return;
5295     }
5296   }
5297 
5298-  nnrt_lite_graph = new (std::nothrow) lite::LiteGraph();
5299-  if (nnrt_lite_graph == nullptr) {
5300+  lite_graph_ = new(std::nothrow) lite::LiteGraph();
5301+  if (lite_graph_ == nullptr) {
5302     MS_LOG(ERROR) << "new LiteGraph failed.";
5303     return;
5304   }
5305 
5306-  nnrt_lite_graph->name_ = lite_graph.name_;
5307-  nnrt_lite_graph->version_ = lite_graph.version_;
5308-  nnrt_lite_graph->input_indices_ = lite_graph.input_indices_;
5309-  nnrt_lite_graph->output_indices_ = lite_graph.output_indices_;
5310-  nnrt_lite_graph->all_tensors_ = lite_graph.all_tensors_;
5311-  nnrt_lite_graph->all_nodes_ = node_list;
5312-  nnrt_lite_graph->sub_graphs_ = subgraph_list;
5313+  lite_graph_->name_ = lite_graph.name_;
5314+  lite_graph_->version_ = lite_graph.version_;
5315+  lite_graph_->input_indices_ = lite_graph.input_indices_;
5316+  lite_graph_->output_indices_ = lite_graph.output_indices_;
5317+  lite_graph_->all_tensors_ = lite_graph.all_tensors_;
5318+  lite_graph_->all_nodes_ = node_list;
5319+  lite_graph_->sub_graphs_ = subgraph_list;
5320   MS_LOG(INFO) << "ShallowCopyLiteGraph success.";
5321 }
5322 
5323-mindspore::NNRTDelegate::~NNRTDelegate() {
5324-  if (this->nnrt_lite_graph != nullptr) {
5325+NNRTDelegate::~NNRTDelegate() {
5326+  if (lite_graph_ != nullptr) {
5327     MS_LOG(ERROR) << "Delete NNRTDelegate.";
5328   }
5329-};
5330+}
5331+}  // namespace lite
5332+}  // namespace mindspore
5333diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5334index c2847704..52626339 100644
5335--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5336+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_delegate.h
5337@@ -15,37 +15,81 @@
5338  */
5339 #ifndef MINDSPORE_NNR_DELEGATE_H
5340 #define MINDSPORE_NNR_DELEGATE_H
5341+
5342 #include <vector>
5343 #include <map>
5344 #include "include/api/delegate.h"
5345 #include "include/model.h"
5346-#include "interfaces/kits/c/neural_network_runtime_type.h"
5347-namespace mindspore {
5348+#include "src/litert/inner_context.h"
5349+#include "nnrt_model_kernel.h"
5350+#include "schema/model_generated.h"
5351+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime_type.h"
5352+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5353+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
5354 
5355-using namespace lite;
5356+namespace mindspore {
5357+namespace lite {
5358+struct NNRTOpRange {
5359+  /* NNRT kernel range in DelegateModel: [begin_iter_, end_iter_) */
5360+  KernelIter begin_iter_;
5361+  KernelIter end_iter_;
5362+  /* NNRT node range in lite_graph_: [begin_index_, end_index_) */
5363+  size_t begin_index_;
5364+  size_t end_index_;
5365+};
5366 
5367 class NNRTDelegate : public Delegate {
5368  public:
5369-  NNRTDelegate() : Delegate(){};
5370-
5371+  NNRTDelegate() = default;
5372+  explicit NNRTDelegate(const NNRtDeviceInfo &nnrt_device_info) : nnrt_device_info_(nnrt_device_info) {}
5373   ~NNRTDelegate() override;
5374-
5375-  Status Init() override;
5376-
5377+  Status Init() override { return kSuccess; }
5378   Status Build(DelegateModel<schema::Primitive> *model) override;
5379-
5380   void ShallowCopyLiteGraph(const lite::LiteGraph &liteGraph);
5381-
5382- protected:
5383-  LiteGraph *nnrt_lite_graph = nullptr;
5384+  void SetMetaGraph(const void *meta_graph) {
5385+    meta_graph_ = meta_graph;
5386+  }
5387+  static std::vector<NNRTOpRange> GetNNRTSubgraphRanges(DelegateModel<schema::Primitive> *model,
5388+                                                        const std::vector<bool> &op_supports);
5389 
5390  private:
5391-  //  static LiteGraph* CreateLiteGraph(const LiteGraph &liteGraph);
5392+  void InitCachePath();
5393+  Status BuildNormalModel(DelegateModel<schema::Primitive> *model);
5394+  OH_NNModel *CreateFullNNModel();
5395+  std::vector<bool> QueryOpSupports(OH_NNModel *nn_model);
5396+  Status CreateLiteGraphForNNRTSubgraph(
5397+    const std::vector<NNRTOpRange> &nnrt_op_ranges,
5398+    std::vector<LiteGraph *> *sub_lite_graphs);
5399+  Status CreateNNRTSubgraphKernels(
5400+    DelegateModel<schema::Primitive> *model,
5401+    const std::vector<LiteGraph *> &sub_lite_graphs,
5402+    const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5403+    std::vector<NNRTModelKernel *> *nnrt_subgraph_kernels);
5404+  void ReplaceNNRTKernelsInDelegateModel(DelegateModel<schema::Primitive> *model,
5405+                                         const std::vector<NNRTOpRange> &nnrt_subgraph_ranges,
5406+                                         const std::vector<NNRTModelKernel *> &nnrt_subgraph_kernels);
5407   Status PrepareInputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor);
5408   Status PrepareOutputs(DelegateModel<schema::Primitive> *model, OH_NNExecutor *oh_nn_executor);
5409-  OH_NN_DataType ConvertDataType(mindspore::DataType data_type);
5410-};
5411+  Status InitNNCompilation(OH_NNCompilation *nn_compilation) const;
5412+  static OH_NN_DataType CastToNNRTDataType(mindspore::DataType data_type);
5413+  static OH_NN_Format CastToNNRTFormat(Format format);
5414+  bool IsCustomModel() const;
5415+
5416+#ifdef SUPPORT_NNRT_METAGRAPH
5417+  bool IsKirinNPU() const;
5418+  Status BuildKirinNPUModel(DelegateModel<schema::Primitive> *model);
5419+  Status SetKirinModelInputsAndOutputs(OH_NNModel *nn_model);
5420+  std::vector<OH_NN_TensorInfo> CreateNNTensorInfos(const std::vector<uint32_t> &indices) const;
5421+  Status CreateFullModelKernel(DelegateModel<schema::Primitive> *model, OH_NNModel *nn_model);
5422+#endif
5423 
5424+  NNRtDeviceInfo nnrt_device_info_;
5425+  LiteGraph *lite_graph_ = nullptr;
5426+  const void *meta_graph_ = nullptr;
5427+  std::string cache_path_ = "";
5428+  uint32_t cache_version_ = 0;
5429+};
5430+}  // namespace lite
5431 }  // namespace mindspore
5432 
5433 #endif  // MINDSPORE_NNR_DELEGATE_H
5434diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5435index 5acf2e9a..67443e08 100644
5436--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5437+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.cc
5438@@ -97,7 +97,7 @@ OH_NN_DataType mindspore::NNRTModelKernel::ConvertDataType(mindspore::DataType d
5439 }
5440 int mindspore::NNRTModelKernel::PrepareInputs() {
5441   auto input_tensors = this->inputs();
5442-  for (int i = 0; i < input_tensors.size(); i++) {
5443+  for (size_t i = 0; i < input_tensors.size(); i++) {
5444     auto tensor = input_tensors[i];
5445     auto tensor_shape = tensor.Shape();
5446     auto tmp_quant_param = tensor.QuantParams();
5447@@ -142,6 +142,7 @@ int mindspore::NNRTModelKernel::PrepareInputs() {
5448     oprend->dimensions = dimensions_list.data();
5449     oprend->quantParam = quant_param;
5450     oprend->type = OH_NN_TENSOR;
5451+    MS_LOG_INFO << "input tensor: " << tensor.Name() << ", data: " << (void *)tensor.MutableData() << ", size: " << tensor.DataSize();
5452     OH_NN_ReturnCode ret_code =
5453       OH_NNExecutor_SetInput(oh_nn_executor, i, oprend, tensor.MutableData(), tensor.DataSize());
5454     delete (oprend);
5455diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5456index cf9481df..ea15f7ca 100644
5457--- a/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5458+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_model_kernel.h
5459@@ -20,7 +20,7 @@
5460 #include <map>
5461 #include <utility>
5462 #include "include/api/kernel.h"
5463-#include "interfaces/kits/c/neural_network_runtime.h"
5464+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5465 #include "src/common/log_adapter.h"
5466 #include "include/errorcode.h"
5467 
5468diff --git a/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
5469new file mode 100644
5470index 00000000..8ac283af
5471--- /dev/null
5472+++ b/mindspore/lite/src/litert/delegate/nnrt/nnrt_stub.cc
5473@@ -0,0 +1,99 @@
5474+/**
5475+* Copyright 2023 Huawei Technologies Co., Ltd
5476+*
5477+* Licensed under the Apache License, Version 2.0 (the "License");
5478+* you may not use this file except in compliance with the License.
5479+* You may obtain a copy of the License at
5480+*
5481+* http://www.apache.org/licenses/LICENSE-2.0
5482+*
5483+* Unless required by applicable law or agreed to in writing, software
5484+* distributed under the License is distributed on an "AS IS" BASIS,
5485+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5486+* See the License for the specific language governing permissions and
5487+* limitations under the License.
5488+*/
5489+
5490+#include "interfaces/kits/c/neural_network_runtime/neural_network_runtime.h"
5491+#include "interfaces/innerkits/c/neural_network_runtime_inner.h"
5492+
5493+OH_NNModel *OH_NNModel_Construct(void) {
5494+  return NULL;
5495+}
5496+
5497+OH_NN_ReturnCode OH_NNExecutor_Run(OH_NNExecutor *executor) {
5498+  return OH_NN_SUCCESS;
5499+}
5500+
5501+OH_NN_ReturnCode OH_NNCompilation_Build(OH_NNCompilation *compilation) {
5502+  return OH_NN_SUCCESS;
5503+}
5504+
5505+void OH_NNCompilation_Destroy(OH_NNCompilation **compilation) {}
5506+
5507+OH_NNExecutor *OH_NNExecutor_Construct(OH_NNCompilation *compilation) {
5508+  return NULL;
5509+}
5510+
5511+void OH_NNExecutor_Destroy(OH_NNExecutor **executor) {}
5512+
5513+OH_NNCompilation *OH_NNCompilation_Construct(const OH_NNModel *model) {
5514+  return NULL;
5515+}
5516+
5517+OH_NN_ReturnCode OH_NNDevice_GetAllDevicesID(const size_t **allDevicesID, uint32_t *deviceCount) {
5518+  return OH_NN_SUCCESS;
5519+}
5520+
5521+OH_NN_ReturnCode OH_NNExecutor_SetOutput(OH_NNExecutor *executor,
5522+                                         uint32_t outputIndex,
5523+                                         void *dataBuffer,
5524+                                         size_t length) {
5525+  return OH_NN_SUCCESS;
5526+}
5527+
5528+OH_NN_ReturnCode OH_NNCompilation_SetDevice(OH_NNCompilation *compilation, size_t deviceID) {
5529+  return OH_NN_SUCCESS;
5530+}
5531+
5532+OH_NN_ReturnCode OH_NNExecutor_SetInput(OH_NNExecutor *executor,
5533+                                        uint32_t inputIndex,
5534+                                        const OH_NN_Tensor *tensor,
5535+                                        const void *dataBuffer,
5536+                                        size_t length) {
5537+  return OH_NN_SUCCESS;
5538+}
5539+
5540+void OH_NNModel_Destroy(OH_NNModel **model) {}
5541+
5542+OH_NN_ReturnCode OH_NNModel_GetAvailableOperations(OH_NNModel *model,
5543+                                                   size_t deviceID,
5544+                                                   const bool **isSupported,
5545+                                                   uint32_t *opCount) {
5546+  return OH_NN_SUCCESS;
5547+}
5548+
5549+OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const void *liteGraph) {
5550+  return OH_NN_SUCCESS;
5551+}
5552+
5553+OH_NN_ReturnCode OH_NNDevice_GetName(size_t deviceID, const char **name) {
5554+  return OH_NN_SUCCESS;
5555+}
5556+
5557+OH_NN_ReturnCode OH_NNDevice_GetType(size_t deviceID, OH_NN_DeviceType *deviceType) {
5558+  return OH_NN_SUCCESS;
5559+}
5560+
5561+OH_NN_ReturnCode OH_NNCompilation_SetPriority(OH_NNCompilation *compilation, OH_NN_Priority priority) {
5562+  return OH_NN_SUCCESS;
5563+}
5564+
5565+OH_NN_ReturnCode OH_NNCompilation_EnableFloat16(OH_NNCompilation *compilation, bool enableFloat16) {
5566+  return OH_NN_SUCCESS;
5567+}
5568+
5569+OH_NN_ReturnCode OH_NNCompilation_SetPerformanceMode(OH_NNCompilation *compilation,
5570+                                                     OH_NN_PerformanceMode performanceMode) {
5571+  return OH_NN_SUCCESS;
5572+}
5573\ No newline at end of file
5574diff --git a/mindspore/lite/src/litert/infer_manager.cc b/mindspore/lite/src/litert/infer_manager.cc
5575index 2b21d1ca..908ab122 100644
5576--- a/mindspore/lite/src/litert/infer_manager.cc
5577+++ b/mindspore/lite/src/litert/infer_manager.cc
5578@@ -162,7 +162,8 @@ int KernelInferShape(const std::vector<lite::Tensor *> &inputs, const std::vecto
5579   if (parameter->type_ == static_cast<int>(schema::PrimitiveType_PartialFusion) ||
5580       parameter->type_ == static_cast<int>(schema::PrimitiveType_Switch) ||
5581       parameter->type_ == static_cast<int>(schema::PrimitiveType_Call) ||
5582-      parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer)) {
5583+      parameter->type_ == static_cast<int>(schema::PrimitiveType_SwitchLayer) ||
5584+      parameter->type_ == static_cast<int>(PrimType_Inner_ThirdPartyModel)) {
5585     MS_LOG(INFO) << "no need infer shape.";
5586     return RET_OK;
5587   }
5588diff --git a/mindspore/lite/src/litert/inner_context.cc b/mindspore/lite/src/litert/inner_context.cc
5589index 7cbac8f7..bf585ff0 100644
5590--- a/mindspore/lite/src/litert/inner_context.cc
5591+++ b/mindspore/lite/src/litert/inner_context.cc
5592@@ -122,6 +122,10 @@ int InnerContext::Init() {
5593 #endif
5594   }
5595 
5596+  if (IsDeviceTypeEnabled(DT_NNRT)) {
5597+    MS_LOG(DEBUG) << "NNRT enabled.";
5598+  }
5599+
5600   if (CreateThreadPool(false)) {
5601     MS_LOG(ERROR) << "CreateThreadPool failed.";
5602     return RET_ERROR;
5603diff --git a/mindspore/lite/src/litert/inner_context.h b/mindspore/lite/src/litert/inner_context.h
5604index 88281eb1..8735961c 100644
5605--- a/mindspore/lite/src/litert/inner_context.h
5606+++ b/mindspore/lite/src/litert/inner_context.h
5607@@ -71,12 +71,26 @@ typedef struct CustomDeviceInfo {
5608   std::shared_ptr<DeviceInfoContext> user_defined_device_info_;
5609 } CustomDeviceInfo;
5610 
5611+typedef struct Extension {
5612+  std::string name; // config name
5613+  std::vector<uint8_t> value; // config value
5614+} Extension;
5615+
5616+typedef struct NNRtDeviceInfo {
5617+  size_t device_id_ = 0;
5618+  int priority_ = 0;
5619+  int performance_mode_ = 0;
5620+  bool enable_fp16_ = false;
5621+  std::vector<Extension> extensions_;
5622+} NNRtDeviceInfo;
5623+
5624 struct DeviceInfo {
5625   CpuDeviceInfo cpu_device_info_;
5626   GpuDeviceInfo gpu_device_info_;
5627   NpuDeviceInfo npu_device_info_;
5628   AscendDeviceInfo ascend_device_info_;
5629   CustomDeviceInfo custom_device_info_;
5630+  NNRtDeviceInfo nnrt_device_info_;
5631 };
5632 
5633 struct DeviceContext {
5634diff --git a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5635index 48308425..65065b5b 100644
5636--- a/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5637+++ b/mindspore/lite/src/litert/kernel/cpu/BUILD.gn
5638@@ -13,6 +13,10 @@ cpu_kernel_sources = [
5639     "base/call.cc",
5640     "base/constant_of_shape.cc",
5641     "base/convolution_base.cc",
5642+    "base/custom_base.cc",
5643+    "base/custom_masked_fill.cc",
5644+    "base/custom_is_inf.cc",
5645+    "base/custom_tensor_scatter.cc",
5646     "base/detection_post_process_base.cc",
5647     "base/format_transpose.cc",
5648     "base/group_convolution_base.cc",
5649@@ -37,7 +41,6 @@ cpu_kernel_sources = [
5650     "fp32/batchnorm_fp32.cc",
5651     "fp32/batch_to_space_fp32.cc",
5652     "fp32/broadcast_to_fp32.cc",
5653-    "fp32/cast_for_x86_fp16.cc",
5654     "fp32/cast_fp32.cc",
5655     "fp32/convolution_1x1_fp32.cc",
5656     "fp32/convolution_delegate_fp32.cc",
5657@@ -118,6 +121,10 @@ cpu_kernel_sources = [
5658     "fp32/online_fusion/split_reduce_concat_fp32.cc",
5659 ]
5660 
5661+if ((target_cpu != "arm") && (target_cpu != "arm64")) {
5662+    cpu_kernel_sources += [ "src/runtime/kernel/cpu/fp32/cast_for_x86_fp16.cc" ]
5663+}
5664+
5665 arm64_cpu_kernel_sources = [
5666   "fp32/convolution_im2col_arm64_fp32.cc",
5667   "fp32/matmul_fp32_arm64.cc",
5668@@ -142,6 +149,42 @@ sse_avx_avx512_kernel_sources = [
5669   "fp32/matmul_fp32_avx512.cc",
5670 ]
5671 
5672+fp16_kernel_sources = [
5673+  "fp16/batchnorm_fp16.cc",
5674+  "fp16/biasadd_fp16.cc",
5675+  "fp16/cast_fp16.cc",
5676+  "fp16/common_fp16.cc",
5677+  "fp16/convolution_1x1_fp16.cc",
5678+  "fp16/convolution_delegate_fp16.cc",
5679+  "fp16/convolution_depthwise_3x3_fp16.cc",
5680+  "fp16/convolution_depthwise_fp16.cc",
5681+  "fp16/convolution_depthwise_slidewindow_fp16.cc",
5682+  "fp16/convolution_fp16.cc",
5683+  "fp16/convolution_winograd_fp16.cc",
5684+  "fp16/custom_gru_fp16.cc",
5685+  "fp16/deconvolution_depthwise_fp16.cc",
5686+  "fp16/deconvolution_fp16.cc",
5687+  "fp16/deconvolution_winograd_fp16.cc",
5688+  "fp16/depth_to_space_fp16.cc",
5689+  "fp16/dynamic_quant_fp16.cc",
5690+  "fp16/fullconnection_fp16.cc",
5691+  "fp16/fused_batchnorm_fp16.cc",
5692+  "fp16/group_convolution_fp16.cc",
5693+  "fp16/gru_fp16.cc",
5694+  "fp16/instance_norm_fp16.cc",
5695+  "fp16/layout_transform_fp16.cc",
5696+  "fp16/lstm_fp16.cc",
5697+  "fp16/matmul_base_fp16.cc",
5698+  "fp16/matmul_fp16.cc",
5699+  "fp16/power_fp16.cc",
5700+  "fp16/prelu_fp16.cc",
5701+  "fp16/quant_dtype_cast_fp16.cc",
5702+  "fp16/reduce_fp16.cc",
5703+  "fp16/resize_fp16.cc",
5704+  "fp16/slice_fp16.cc",
5705+  "fp16/where_fp16.cc",
5706+]
5707+
5708 int8_kernel_sources = [
5709     "int8/activation_int8.cc",
5710     "int8/add_int8.cc",
5711@@ -227,6 +270,12 @@ all_cpu_kernel_sources += int8_kernel_sources
5712 all_cpu_kernel_sources += string_kernel_sources
5713 all_cpu_kernel_sources += control_kernel_sources
5714 
5715+if (target_cpu == "arm64") {
5716+    all_cpu_kernel_sources += fp16_kernel_sources
5717+} else {
5718+    not_needed(fp16_kernel_sources)
5719+}
5720+
5721 if (target_cpu == "arm") {
5722   all_cpu_kernel_sources -= arm64_cpu_kernel_sources
5723   all_cpu_kernel_sources -= sse_avx_avx512_kernel_sources
5724diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
5725new file mode 100644
5726index 00000000..9921e063
5727--- /dev/null
5728+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.cc
5729@@ -0,0 +1,46 @@
5730+/**
5731+ * Copyright 2022 Huawei Technologies Co., Ltd
5732+ *
5733+ * Licensed under the Apache License, Version 2.0 (the "License");
5734+ * you may not use this file except in compliance with the License.
5735+ * You may obtain a copy of the License at
5736+ *
5737+ * http://www.apache.org/licenses/LICENSE-2.0
5738+ *
5739+ * Unless required by applicable law or agreed to in writing, software
5740+ * distributed under the License is distributed on an "AS IS" BASIS,
5741+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5742+ * See the License for the specific language governing permissions and
5743+ * limitations under the License.
5744+ */
5745+
5746+#include "src/litert/kernel/cpu/base/custom_base.h"
5747+#include <algorithm>
5748+#include <utility>
5749+#include <vector>
5750+#include "src/litert/kernel_registry.h"
5751+#include "nnacl/op_base.h"
5752+
5753+using mindspore::kernel::KERNEL_ARCH;
5754+using mindspore::lite::KernelRegistrar;
5755+using mindspore::lite::RET_ERROR;
5756+using mindspore::lite::RET_OK;
5757+using mindspore::schema::PrimitiveType_Custom;
5758+
5759+namespace mindspore::kernel {
5760+int CustomBaseCPUKernel::Prepare() {
5761+  return RET_OK;
5762+}
5763+
5764+int CustomBaseCPUKernel::ReSize() {
5765+  return RET_OK;
5766+}
5767+
5768+int CustomBaseCPUKernel::Run() {
5769+  return RET_OK;
5770+}
5771+
5772+REG_KERNEL(kCPU, kNumberTypeInt32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5773+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5774+REG_KERNEL(kCPU, kNumberTypeBool, PrimType_Inner_ThirdPartyModel, LiteKernelCreator<CustomBaseCPUKernel>)
5775+}  // namespace mindspore::kernel
5776diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
5777new file mode 100644
5778index 00000000..ecb4c72d
5779--- /dev/null
5780+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_base.h
5781@@ -0,0 +1,43 @@
5782+/**
5783+ * Copyright 2022 Huawei Technologies Co., Ltd
5784+ *
5785+ * Licensed under the Apache License, Version 2.0 (the "License");
5786+ * you may not use this file except in compliance with the License.
5787+ * You may obtain a copy of the License at
5788+ *
5789+ * http://www.apache.org/licenses/LICENSE-2.0
5790+ *
5791+ * Unless required by applicable law or agreed to in writing, software
5792+ * distributed under the License is distributed on an "AS IS" BASIS,
5793+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5794+ * See the License for the specific language governing permissions and
5795+ * limitations under the License.
5796+ */
5797+
5798+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5799+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5800+
5801+#include <vector>
5802+#include "src/litert/lite_kernel.h"
5803+#include "nnacl/custom_parameter.h"
5804+
5805+namespace mindspore::kernel {
5806+class CustomBaseCPUKernel : public LiteKernel {
5807+ public:
5808+  CustomBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
5809+                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
5810+      : LiteKernel(parameter, inputs, outputs, ctx) {
5811+    custom_param_ = reinterpret_cast<CustomParameter *>(op_parameter_);
5812+  }
5813+  ~CustomBaseCPUKernel() override = default;
5814+
5815+  int Prepare() override;
5816+  int ReSize() override;
5817+  int Run() override;
5818+
5819+ private:
5820+  CustomParameter *custom_param_ = nullptr;
5821+};
5822+}  // namespace mindspore::kernel
5823+
5824+#endif  // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_BASE_H_
5825diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
5826new file mode 100644
5827index 00000000..edffea42
5828--- /dev/null
5829+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.cc
5830@@ -0,0 +1,61 @@
5831+/**
5832+ * Copyright 2023 Huawei Technologies Co., Ltd
5833+ *
5834+ * Licensed under the Apache License, Version 2.0 (the "License");
5835+ * you may not use this file except in compliance with the License.
5836+ * You may obtain a copy of the License at
5837+ *
5838+ * http://www.apache.org/licenses/LICENSE-2.0
5839+ *
5840+ * Unless required by applicable law or agreed to in writing, software
5841+ * distributed under the License is distributed on an "AS IS" BASIS,
5842+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5843+ * See the License for the specific language governing permissions and
5844+ * limitations under the License.
5845+ */
5846+#include "src/litert/kernel_registry.h"
5847+#include "include/errorcode.h"
5848+#include "src/litert/kernel/cpu/base/custom_is_inf.h"
5849+#include "src/common/tensor_util.h"
5850+#include "nnacl/op_base.h"
5851+
5852+using mindspore::lite::KernelRegistrar;
5853+using mindspore::lite::RET_ERROR;
5854+using mindspore::lite::RET_OK;
5855+
5856+namespace mindspore::kernel {
5857+
5858+int CustomIsInfCPUKernel::Prepare() {
5859+  CHECK_LESS_RETURN(in_tensors_.size(), C1NUM);
5860+  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
5861+  return RET_OK;
5862+}
5863+
5864+int CustomIsInfCPUKernel::ReSize() { return RET_OK; }
5865+
5866+void CustomIsInfCPUKernel::LaunchKernelFloat(const float *input, bool *output) {
5867+  auto elem_num = in_tensors_[FIRST_INPUT]->ElementsNum();
5868+
5869+  for (int64_t i = 0; i < elem_num; i++) {
5870+    output[i] = std::isinf(input[i]);
5871+  }
5872+}
5873+
5874+int CustomIsInfCPUKernel::Run() {
5875+  auto input = in_tensors_[FIRST_INPUT];
5876+  auto output = out_tensors_[FIRST_INPUT];
5877+  CHECK_NULL_RETURN(input);
5878+  CHECK_NULL_RETURN(output);
5879+
5880+  if (input->data_type() == kNumberTypeFloat32 || input->data_type() == kNumberTypeFloat) {
5881+    LaunchKernelFloat(reinterpret_cast<const float *>(input->data()), reinterpret_cast<bool *>(output->data()));
5882+  } else {
5883+    MS_LOG(ERROR) << "unsupported input data type " << input->data_type();
5884+    return RET_ERROR;
5885+  }
5886+
5887+  return RET_OK;
5888+}
5889+
5890+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomIsInf, LiteKernelCreator<CustomIsInfCPUKernel>)
5891+}  // namespace mindspore::kernel
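A toy, self-contained check (made-up values) of the element-wise predicate that LaunchKernelFloat computes above:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float in[3] = {1.0f, std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity()};
  bool out[3];
  for (int i = 0; i < 3; ++i) out[i] = std::isinf(in[i]);  // same predicate as the kernel
  assert(!out[0] && out[1] && out[2]);
  return 0;
}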
5892diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
5893new file mode 100644
5894index 00000000..e63d8ec7
5895--- /dev/null
5896+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_is_inf.h
5897@@ -0,0 +1,38 @@
5898+/**
5899+ * Copyright 2023 Huawei Technologies Co., Ltd
5900+ *
5901+ * Licensed under the Apache License, Version 2.0 (the "License");
5902+ * you may not use this file except in compliance with the License.
5903+ * You may obtain a copy of the License at
5904+ *
5905+ * http://www.apache.org/licenses/LICENSE-2.0
5906+ *
5907+ * Unless required by applicable law or agreed to in writing, software
5908+ * distributed under the License is distributed on an "AS IS" BASIS,
5909+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5910+ * See the License for the specific language governing permissions and
5911+ * limitations under the License.
5912+ */
5913+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5914+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5915+
5916+#include <vector>
5917+#include "src/litert/lite_kernel.h"
5918+
5919+namespace mindspore::kernel {
5920+class CustomIsInfCPUKernel : public LiteKernel {
5921+ public:
5922+  CustomIsInfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
5923+                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
5924+      : LiteKernel(parameter, inputs, outputs, ctx) {}
5925+  ~CustomIsInfCPUKernel() override = default;
5926+  int Prepare() override;
5927+  int ReSize() override;
5928+  int Run() override;
5929+
5930+ private:
5931+  void LaunchKernelFloat(const float *input, bool *output);
5932+};
5933+}  // namespace mindspore::kernel
5934+
5935+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_IS_INF_CPU_H_
5936diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
5937new file mode 100644
5938index 00000000..9af1af5d
5939--- /dev/null
5940+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.cc
5941@@ -0,0 +1,84 @@
5942+/**
5943+ * Copyright 2023 Huawei Technologies Co., Ltd
5944+ *
5945+ * Licensed under the Apache License, Version 2.0 (the "License");
5946+ * you may not use this file except in compliance with the License.
5947+ * You may obtain a copy of the License at
5948+ *
5949+ * http://www.apache.org/licenses/LICENSE-2.0
5950+ *
5951+ * Unless required by applicable law or agreed to in writing, software
5952+ * distributed under the License is distributed on an "AS IS" BASIS,
5953+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
5954+ * See the License for the specific language governing permissions and
5955+ * limitations under the License.
5956+ */
5957+#include "src/litert/kernel_registry.h"
5958+#include "include/errorcode.h"
5959+#include "src/litert/kernel/cpu/base/custom_masked_fill.h"
5960+#include "src/common/tensor_util.h"
5961+#include "nnacl/op_base.h"
5962+
5963+using mindspore::lite::KernelRegistrar;
5964+using mindspore::lite::RET_ERROR;
5965+using mindspore::lite::RET_OK;
5966+
5967+namespace mindspore::kernel {
5968+
5969+int CustomMaskedFillCPUKernel::Prepare() {
5970+  CHECK_LESS_RETURN(in_tensors_.size(), C3NUM);
5971+  CHECK_LESS_RETURN(out_tensors_.size(), C1NUM);
5972+
5973+  // only supports float32 input and a single scalar fill value
5974+  MS_CHECK_TRUE_MSG(in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat32 ||
5975+                      in_tensors_[FIRST_INPUT]->data_type() == mindspore::TypeId::kNumberTypeFloat,
5976+                    RET_ERROR, "input dtype must be float32");
5977+  if (in_tensors_[THIRD_INPUT]->ElementsNum() != 1) {
5978+    MS_LOG(ERROR) << "only support fill value as a single float";
5979+    return RET_ERROR;
5980+  }
5981+  MS_CHECK_TRUE_MSG(in_tensors_[SECOND_INPUT]->data_type() == mindspore::TypeId::kNumberTypeBool, RET_ERROR,
5982+                    "mask dtype must be bool");
5983+  if (!InferShapeDone()) {
5984+    return RET_OK;
5985+  }
5986+  return ReSize();
5987+}
5988+
5989+int CustomMaskedFillCPUKernel::ReSize() { return RET_OK; }
5990+
5991+int CustomMaskedFillCPUKernel::Run() {
5992+  auto input = in_tensors_[FIRST_INPUT];
5993+  auto mask = in_tensors_[SECOND_INPUT];
5994+  auto value = in_tensors_[THIRD_INPUT];
5995+  auto output = out_tensors_[FIRST_INPUT];
5996+  CHECK_NULL_RETURN(input);
5997+  CHECK_NULL_RETURN(mask);
5998+  CHECK_NULL_RETURN(value);
5999+  CHECK_NULL_RETURN(output);
6000+
6001+  if (input->shape() != mask->shape()) {
6002+    MS_LOG(ERROR) << "Not support broadcast mask to input";
6003+    return RET_ERROR;
6004+  }
6005+
6006+  auto value_data = reinterpret_cast<float *>(value->data());
6007+  auto fill_value = value_data[0];
6008+
6009+  auto data_num = input->ElementsNum();
6010+  auto input_data = reinterpret_cast<float *>(input->data());
6011+  auto mask_data = reinterpret_cast<bool *>(mask->data());
6012+  auto output_data = reinterpret_cast<float *>(output->data());
6013+  for (int64_t i = 0; i < data_num; i++) {
6014+    if (mask_data[i]) {
6015+      output_data[i] = fill_value;
6016+    } else {
6017+      output_data[i] = input_data[i];
6018+    }
6019+  }
6020+
6021+  return RET_OK;
6022+}
6023+
6024+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomMaskedFill, LiteKernelCreator<CustomMaskedFillCPUKernel>)
6025+}  // namespace mindspore::kernel
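A worked, self-contained example (made-up values) of the element-wise selection performed in CustomMaskedFillCPUKernel::Run above: wherever the mask is true the fill value is written, otherwise the input value passes through.

#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> MaskedFill(const std::vector<float> &input, const std::vector<bool> &mask, float value) {
  std::vector<float> out(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    out[i] = mask[i] ? value : input[i];  // same per-element rule as the kernel loop
  }
  return out;
}

int main() {
  auto out = MaskedFill({1.f, 2.f, 3.f, 4.f}, {true, false, true, false}, 9.5f);
  assert(out[0] == 9.5f && out[1] == 2.f && out[2] == 9.5f && out[3] == 4.f);
  return 0;
}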
6026diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
6027new file mode 100644
6028index 00000000..04a2dcab
6029--- /dev/null
6030+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_masked_fill.h
6031@@ -0,0 +1,35 @@
6032+/**
6033+ * Copyright 2023 Huawei Technologies Co., Ltd
6034+ *
6035+ * Licensed under the Apache License, Version 2.0 (the "License");
6036+ * you may not use this file except in compliance with the License.
6037+ * You may obtain a copy of the License at
6038+ *
6039+ * http://www.apache.org/licenses/LICENSE-2.0
6040+ *
6041+ * Unless required by applicable law or agreed to in writing, software
6042+ * distributed under the License is distributed on an "AS IS" BASIS,
6043+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6044+ * See the License for the specific language governing permissions and
6045+ * limitations under the License.
6046+ */
6047+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6048+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6049+
6050+#include <vector>
6051+#include "src/litert/lite_kernel.h"
6052+
6053+namespace mindspore::kernel {
6054+class CustomMaskedFillCPUKernel : public LiteKernel {
6055+ public:
6056+  CustomMaskedFillCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
6057+                            const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
6058+      : LiteKernel(parameter, inputs, outputs, ctx) {}
6059+  ~CustomMaskedFillCPUKernel() override = default;
6060+  int Prepare() override;
6061+  int ReSize() override;
6062+  int Run() override;
6063+};
6064+}  // namespace mindspore::kernel
6065+
6066+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_CUSTOM_MASKED_FILL_H_
6067diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
6068new file mode 100644
6069index 00000000..d52d67d5
6070--- /dev/null
6071+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.cc
6072@@ -0,0 +1,75 @@
6073+/**
6074+ * Copyright 2022 Huawei Technologies Co., Ltd
6075+ *
6076+ * Licensed under the Apache License, Version 2.0 (the "License");
6077+ * you may not use this file except in compliance with the License.
6078+ * You may obtain a copy of the License at
6079+ *
6080+ * http://www.apache.org/licenses/LICENSE-2.0
6081+ *
6082+ * Unless required by applicable law or agreed to in writing, software
6083+ * distributed under the License is distributed on an "AS IS" BASIS,
6084+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6085+ * See the License for the specific language governing permissions and
6086+ * limitations under the License.
6087+ */
6088+
6089+#include "src/litert/kernel/cpu/base/custom_tensor_scatter.h"
6090+#include <cstring>
6091+#include "schema/model_generated.h"
6092+#include "src/litert/kernel_registry.h"
6093+#include "include/errorcode.h"
6094+#include "nnacl/base/scatter_nd_binary.h"
6095+
6096+using mindspore::kernel::KERNEL_ARCH;
6097+using mindspore::lite::KernelRegistrar;
6098+using mindspore::lite::RET_ERROR;
6099+using mindspore::lite::RET_OK;
6100+
6101+namespace mindspore::kernel {
6102+namespace {
6103+int TensorScatterRun(void *cdata, int task_id, float, float) {
6104+  auto kernel = static_cast<CustomTensorScatterCPUKernel *>(cdata);
6105+  CHECK_NULL_RETURN(kernel);
6106+  return kernel->TensorScatterDispatch(task_id);
6107+}
6108+}  // namespace
6109+
6110+int CustomTensorScatterCPUKernel::TensorScatterDispatch(int task_id) {
6111+  auto data_type = in_tensors_[kScatterUpdateInputIndex]->data_type();
6112+  if (data_type != kNumberTypeFloat32) {
6113+    MS_LOG(ERROR) << "TensorScatterMax only support float32 input tensor, but got " << data_type;
6114+    return RET_ERROR;
6115+  }
6116+  int type = data_type == kNumberTypeFloat32 ? 0 : 1;
6117+  // multi-thread execution still has issues, so force single-thread for now
6118+  param_->op_parameter.thread_num_ = 1;
6119+  auto ret = ScatterNDMax(in_tensors_[kScatterUpdateIndex]->data(), out_tensors_[kOutputIndex]->data(),
6120+                          output_unit_offsets_.data(), param_, type, task_id);
6121+  if (ret != RET_OK) {
6122+    MS_LOG(ERROR) << "ScatterNDMax failed, ret: " << ret;
6123+    return RET_ERROR;
6124+  }
6125+  return RET_OK;
6126+}
6127+
6128+int CustomTensorScatterCPUKernel::Run() {
6129+  auto in_tensor = in_tensors().front();
6130+  auto out_tensor = out_tensors().front();
6131+  (void)memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
6132+  auto indices = in_tensors_.at(kScatterIndicesIndex);
6133+  if (!indices->IsConst() && ReSize() != RET_OK) {
6134+    MS_LOG(ERROR) << "TensorScatterAdd resize failed.";
6135+    return RET_ERROR;
6136+  }
6137+
6138+  auto ret = ParallelLaunch(ms_context_, TensorScatterRun, this, op_parameter_->thread_num_);
6139+  if (ret != RET_OK) {
6140+    MS_LOG(ERROR) << "TensorScatterAdd error error_code[" << ret << "]";
6141+  }
6142+  return ret;
6143+}
6144+
6145+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimType_Inner_CustomTensorScatterMax,
6146+           LiteKernelCreator<CustomTensorScatterCPUKernel>)
6147+}  // namespace mindspore::kernel
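For reference, a toy sketch (made-up data, plain loops instead of the nnacl ScatterNDMax routine) of the scatter-max semantics the kernel above delegates: the output starts as a copy of the input, then each indexed slot keeps the maximum of its current value and the corresponding update.

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

int main() {
  std::array<float, 5> output = {1.f, 2.f, 3.f, 4.f, 5.f};   // copy of the input tensor
  const std::array<int, 2> indices = {1, 3};                  // positions to update
  const std::array<float, 2> updates = {7.f, 0.5f};
  for (std::size_t i = 0; i < indices.size(); ++i) {
    float &slot = output[indices[i]];
    slot = std::max(slot, updates[i]);
  }
  assert(output[1] == 7.f && output[3] == 4.f);  // 0.5 loses to the existing 4.0
  return 0;
}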
6148diff --git a/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
6149new file mode 100644
6150index 00000000..e39733c5
6151--- /dev/null
6152+++ b/mindspore/lite/src/litert/kernel/cpu/base/custom_tensor_scatter.h
6153@@ -0,0 +1,36 @@
6154+/**
6155+ * Copyright 2022 Huawei Technologies Co., Ltd
6156+ *
6157+ * Licensed under the Apache License, Version 2.0 (the "License");
6158+ * you may not use this file except in compliance with the License.
6159+ * You may obtain a copy of the License at
6160+ *
6161+ * http://www.apache.org/licenses/LICENSE-2.0
6162+ *
6163+ * Unless required by applicable law or agreed to in writing, software
6164+ * distributed under the License is distributed on an "AS IS" BASIS,
6165+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6166+ * See the License for the specific language governing permissions and
6167+ * limitations under the License.
6168+ */
6169+
6170+#ifndef MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_TENSOR_SCATTER_H_
6171+#define MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_TENSOR_SCATTER_H_
6172+
6173+#include <vector>
6174+#include "src/litert/kernel/cpu/base/scatter_nd_binary.h"
6175+
6176+namespace mindspore::kernel {
6177+class CustomTensorScatterCPUKernel : public ScatterNDBinaryCPUKernel {
6178+ public:
6179+  explicit CustomTensorScatterCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
6180+                                        const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
6181+      : ScatterNDBinaryCPUKernel(parameter, inputs, outputs, ctx) {}
6182+  ~CustomTensorScatterCPUKernel() override = default;
6183+
6184+  int Run() override;
6185+  int TensorScatterDispatch(int task_id);
6186+};
6187+}  // namespace mindspore::kernel
6188+
6189+#endif  // MINDSPORE_LITE_SRC_LITERT_KERNEL_CPU_BASE_CUSTOM_TENSOR_SCATTER_H_
6190diff --git a/mindspore/lite/src/litert/lite_model.cc b/mindspore/lite/src/litert/lite_model.cc
6191index 2c5bc658..13652633 100644
6192--- a/mindspore/lite/src/litert/lite_model.cc
6193+++ b/mindspore/lite/src/litert/lite_model.cc
6194@@ -98,6 +98,8 @@ int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
6195   if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
6196       sub_graph.tensorIndices() == nullptr) {
6197     MS_LOG(ERROR) << "sub_graph is invalid";
6198+    MS_LOG(ERROR) << "sub_graph.name() = " << sub_graph.name() << ", sub_graph.inputIndices() = " << sub_graph.inputIndices()
6199+      << ", sub_graph.outputIndices() = " << sub_graph.outputIndices() << ", sub_graph.tensorIndices() = " << sub_graph.tensorIndices();
6200     return RET_ERROR;
6201   }
6202 
6203@@ -620,6 +622,33 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf, minds
6204   return model;
6205 }
6206 
6207+std::string LiteGraph::ToString() const {
6208+  std::stringstream ss;
6209+  ss << "all_nodes: " << all_nodes_.size() << std::endl;
6210+  for (size_t i = 0; i < all_nodes_.size(); i++) {
6211+    ss << "- node " << i << ": " << all_nodes_[i]->primitive_ << std::endl;
6212+    ss << "- node " << i << " input_indices_: " << all_nodes_[i]->input_indices_ << std::endl;
6213+    ss << "- node " << i << " output_indices_: " << all_nodes_[i]->output_indices_ << std::endl;
6214+  }
6215+  ss << "all_tensors: " << all_tensors_.size() << std::endl;
6216+  for (size_t i = 0; i < all_tensors_.size(); i++) {
6217+    ss << "- tensor " << i << ": " << all_tensors_[i] << std::endl;
6218+  }
6219+  ss << "input_indices: " << input_indices_<< std::endl;
6220+  ss << "output_indices: " << output_indices_ << std::endl;
6221+
6222+  ss << "subgraphs: " << std::endl;
6223+  int count = 0;
6224+  for (auto subgraph: sub_graphs_) {
6225+    ss << "- subgraph " << count++ << std::endl;
6226+    ss << "--- subgraph input " << subgraph->input_indices_ << std::endl;
6227+    ss << "--- subgraph output " << subgraph->output_indices_ << std::endl;
6228+    ss << "--- subgraph node " << subgraph->node_indices_ << std::endl;
6229+    ss << "--- subgraph tensor " << subgraph->tensor_indices_ << std::endl;
6230+  }
6231+  return ss.str();
6232+}
6233+
6234 Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
6235 
6236 Model *Model::Import(const char *filename) { return ImportFromPath(filename); }
6237diff --git a/mindspore/lite/src/litert/lite_session.cc b/mindspore/lite/src/litert/lite_session.cc
6238index 8f54879e..f635c8d2 100644
6239--- a/mindspore/lite/src/litert/lite_session.cc
6240+++ b/mindspore/lite/src/litert/lite_session.cc
6241@@ -67,6 +67,9 @@
6242 #include "thread/parallel_thread_pool_manager.h"
6243 #endif
6244 #include "src/litert/runtime_packed_node_pass.h"
6245+#ifdef SUPPORT_NNRT
6246+#include "src/litert/delegate/nnrt/nnrt_delegate.h"
6247+#endif
6248 
6249 using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
6250 
6251@@ -635,12 +638,6 @@ int LiteSession::CompileGraph(Model *model) {
6252   MarkSharedWeight(kernels_);
6253   FreePackOpWeight(kernels_);
6254 
6255-  ret = RuntimeAllocatorInit();
6256-  if (ret != RET_OK) {
6257-    MS_LOG(ERROR) << "Runtime allocator init failed.";
6258-    is_running_.store(false);
6259-    return ret;
6260-  }
6261   infer_along_running_ = infer_along_running_ && (runtime_allocator_ == nullptr);
6262   if (infer_along_running_) {
6263     this->context_->set_infer_checker(InferCheckerAll);
6264@@ -1092,6 +1089,27 @@ int LiteSession::CreateCoreMLDelegate() {
6265   return RET_OK;
6266 }
6267 
6268+int LiteSession::CreateNNRTDelegate() {
6269+#if SUPPORT_NNRT
6270+  auto iter = std::find_if(context_->device_list_.begin(), context_->device_list_.end(),
6271+                           [](DeviceContext &device) { return device.device_type_ == lite::DT_NNRT; });
6272+  if (iter == context_->device_list_.end()) {
6273+    MS_LOG(ERROR) << "NNRT device info not found in device list";
6274+    return RET_ERROR;
6275+  }
6276+
6277+  delegate_ = std::make_shared<NNRTDelegate>(iter->device_info_.nnrt_device_info_);
6278+  if (delegate_ == nullptr) {
6279+    MS_LOG(ERROR) << "New NNRT delegate failed";
6280+    return RET_ERROR;
6281+  }
6282+//  ((NNRTDelegate *)(delegate_.get()))->SetMetaGraph(this->model_->buf);
6283+  delegate_device_type_ = DT_NNRT;
6284+  this->context_->delegate = delegate_;
6285+#endif
6286+  return RET_OK;
6287+}
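A hypothetical sketch of the device entry that makes the branch above fire; field names are taken from inner_context.h in this patch, the values are illustrative only, and whether such an entry is filled directly or via the public Context API is outside this excerpt.

#include <cstddef>
#include "src/litert/inner_context.h"

namespace {
// Builds the kind of device entry the find_if in CreateNNRTDelegate is looking for.
mindspore::lite::DeviceContext MakeNNRTDevice(size_t device_id) {
  mindspore::lite::DeviceContext device;
  device.device_type_ = mindspore::lite::DT_NNRT;
  device.device_info_.nnrt_device_info_.device_id_ = device_id;
  device.device_info_.nnrt_device_info_.enable_fp16_ = false;  // illustrative default
  return device;
}
}  // namespace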
6288+
6289 int LiteSession::DelegateInit() {
6290 #ifndef DELEGATE_CLIP
6291   int ret = RET_OK;
6292@@ -1115,6 +1133,8 @@ int LiteSession::DelegateInit() {
6293       ret = CreateNPUDelegate();
6294     } else if (context_->IsDeviceTypeEnabled(DT_GPU)) {
6295       ret = CreateTensorRTDelegate();
6296+    } else if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
6297+      ret = CreateNNRTDelegate();
6298     }
6299   }
6300 
6301@@ -1496,12 +1516,6 @@ int LiteSession::Resize(const std::vector<mindspore::lite::Tensor *> &inputs,
6302     return ret;
6303   }
6304 
6305-  if (RuntimeAllocatorInit() != RET_OK) {
6306-    MS_LOG(ERROR) << "Runtime allocator in resize failed.";
6307-    is_running_.store(false);
6308-    return RET_ERROR;
6309-  }
6310-
6311   auto status = GraphOptimizePass(&kernels_);
6312   if (status != RET_OK) {
6313     MS_LOG(ERROR) << "GraphOptimizePass failed.";
6314@@ -2022,7 +2036,6 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
6315     delete model;
6316     return RET_ERROR;
6317   }
6318-  model->Free();
6319   set_model(model);
6320   return RET_OK;
6321 }
6322diff --git a/mindspore/lite/src/litert/lite_session.h b/mindspore/lite/src/litert/lite_session.h
6323index f8f8fe08..64a5f6d3 100644
6324--- a/mindspore/lite/src/litert/lite_session.h
6325+++ b/mindspore/lite/src/litert/lite_session.h
6326@@ -178,6 +178,7 @@ class MS_API LiteSession {
6327   int CreateNPUDelegate();
6328   int CreateNNAPIDelegate();
6329   int CreateCoreMLDelegate();
6330+  int CreateNNRTDelegate();
6331   int DelegateInit();
6332   int InitGPURuntime();
6333   int InitSharedThreadPool();
6334diff --git a/mindspore/lite/src/litert/scheduler.cc b/mindspore/lite/src/litert/scheduler.cc
6335index 11382b09..199b4361 100644
6336--- a/mindspore/lite/src/litert/scheduler.cc
6337+++ b/mindspore/lite/src/litert/scheduler.cc
6338@@ -60,6 +60,9 @@
6339 #if defined(PARALLEL_INFERENCE) && defined(ENABLE_MINDRT)
6340 #include "thread/parallel_thread_pool_manager.h"
6341 #endif
6342+#ifdef SUPPORT_NNRT
6343+#include "src/litert/delegate/nnrt/nnrt_delegate.h"
6344+#endif
6345 
6346 using AbstractBaseModel = mindspore::infer::AbstractBaseModel;
6347 
6348@@ -368,6 +371,7 @@ STATUS Scheduler::DelQuantDTypeCastKernel(std::vector<kernel::KernelExec *> *ker
6349 }
6350 
6351 int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
6352+  MS_LOG(DEBUG) << "Start schedule.";
6353   int check_input_ret = CheckInputParam(dst_kernels);
6354   if (check_input_ret != RET_OK) {
6355     MS_LOG(ERROR) << "CheckInputParam failed! ret: " << check_input_ret;
6356@@ -404,11 +408,13 @@ int Scheduler::Schedule(std::vector<kernel::KernelExec *> *dst_kernels) {
6357   }
6358   shape_fusion_pass_->StoreStateAndReset();
6359 
6360+  MS_LOG(DEBUG) << "Start initializing delegate kernels.";
6361   ret = InitDelegateKernels(dst_kernels);
6362   if (ret != RET_OK) {
6363     MS_LOG(ERROR) << "Replace delegate kernels failed.";
6364     return ret;
6365   }
6366+  MS_LOG(DEBUG) << "Finished initializing delegate kernels.";
6367 
6368   ret = CheckCpuValid(dst_kernels);
6369   if (ret != RET_OK) {
6370@@ -500,6 +506,17 @@ int Scheduler::ReplaceDelegateKernels(std::vector<kernel::KernelExec *> *dst_ker
6371     MS_LOG(ERROR) << "New delegate model failed.";
6372     return RET_NULL_PTR;
6373   }
6374+
6375+#ifdef SUPPORT_NNRT
6376+  if (context_->IsDeviceTypeEnabled(DT_NNRT)) {
6377+    auto delegate = static_cast<NNRTDelegate *>(delegate_.get());
6378+    delegate->ShallowCopyLiteGraph(this->src_model_->graph_);
6379+    void *meta_graph = reinterpret_cast<void*>(const_cast<mindspore::schema::MetaGraph *>(
6380+      mindspore::schema::GetMetaGraph(this->src_model_->buf)));
6381+    delegate->SetMetaGraph(meta_graph);
6382+  }
6383+#endif
6384+
6385   auto ret = delegate_->Build(model);
6386   if (ret != mindspore::kSuccess) {
6387     delete model;
6388diff --git a/mindspore/lite/src/litert/tensor_category.cc b/mindspore/lite/src/litert/tensor_category.cc
6389index 70d13865..e57cdb28 100644
6390--- a/mindspore/lite/src/litert/tensor_category.cc
6391+++ b/mindspore/lite/src/litert/tensor_category.cc
6392@@ -30,5 +30,9 @@ Category TensorCategory(const schema::Tensor &tensor) {
6393   auto data_size = tensor.data() == nullptr ? 0 : tensor.data()->size();
6394   return TensorCategory(tensor.nodeType(), shape_num, TypeId(tensor.dataType()), data_size);
6395 }
6396+
6397+bool IsConstTensor(const schema::Tensor &tensor) {
6398+  return TensorCategory(tensor) != Category::VAR;
6399+}
6400 }  // namespace lite
6401 }  // namespace mindspore
6402diff --git a/mindspore/lite/src/litert/tensor_category.h b/mindspore/lite/src/litert/tensor_category.h
6403index 83273032..70e65b31 100644
6404--- a/mindspore/lite/src/litert/tensor_category.h
6405+++ b/mindspore/lite/src/litert/tensor_category.h
6406@@ -35,6 +35,7 @@ enum Category {
6407 
6408 Category TensorCategory(const int node_type, const size_t shape_num, const TypeId data_type, const size_t data_size);
6409 Category TensorCategory(const schema::Tensor &tensor);
6410+bool IsConstTensor(const schema::Tensor &tensor);
6411 }  // namespace lite
6412 }  // namespace mindspore
6413 #endif  // MINDSPORE_LITE_SRC_RUNTIME_TENSOR_CATEGORY_H_
6414diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt
6415index 60e240f0..78dab536 100644
6416--- a/mindspore/lite/test/CMakeLists.txt
6417+++ b/mindspore/lite/test/CMakeLists.txt
6418@@ -28,10 +28,14 @@ file(GLOB_RECURSE TEST_UT_SRC
6419         ${TEST_DIR}/ut/src/runtime/kernel/arm/common/*.cc
6420         ${TEST_DIR}/ut/src/runtime/kernel/arm/fp32/*.cc
6421         ${TEST_DIR}/ut/src/runtime/kernel/arm/string/*.cc
6422-        ${TEST_DIR}/ut/src/api/context_c_test.cc
6423-        ${TEST_DIR}/ut/src/api/model_c_test.cc
6424-        ${TEST_DIR}/ut/src/api/tensor_c_test.cc`
6425+#        ${TEST_DIR}/ut/src/api/context_c_test.cc
6426+#        ${TEST_DIR}/ut/src/api/model_c_test.cc
6427+#        ${TEST_DIR}/ut/src/api/tensor_c_test.cc
6428         )
6429+if(MSLITE_ENABLE_NNRT)
6430+    list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/runtime/nnrt_delegate/nnrt_delegate_tests.cc)
6431+endif()
6432+
6433 if(MSLITE_ENABLE_SERVER_INFERENCE)
6434     list(APPEND TEST_UT_SRC ${TEST_DIR}/ut/src/api/model_parallel_runner_test.cc)
6435 endif()
6436@@ -86,7 +90,7 @@ endif()
6437 
6438 if(MSLITE_ENABLE_INT8)
6439     file(GLOB_RECURSE TEST_INT8_UT_SRC
6440-            ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
6441+#            ${TEST_DIR}/ut/src/runtime/kernel/arm/int8/*.cc
6442             ${TEST_DIR}/ut/nnacl/int8/*.cc
6443             )
6444     list(APPEND TEST_UT_SRC ${TEST_INT8_UT_SRC})
6445@@ -118,6 +122,7 @@ if(MSLITE_ENABLE_CONVERTER)
6446             ${TEST_DIR}/ut/tools/converter/registry/*.cc
6447             ${TEST_DIR}/ut/tools/converter/parser/tflite/*.cc
6448             ${TEST_DIR}/ut/tools/converter/api/*.cc
6449+            ${TEST_DIR}/ut/tools/converter/config_parser/*.cc
6450             ${TEST_DIR}/st/converter_test.cc
6451             ${TEST_DIR}/st/delegate_test.cc
6452             ${TEST_DIR}/st/mindrt_parallel_test.cc
6453@@ -232,7 +237,7 @@ endif()
6454 
6455 if(MSLITE_ENABLE_CONVERTER)
6456     target_link_libraries(lite-test-converter tflite_parser_mid caffe_parser_mid
6457-                                    onnx_parser_mid tf_parser_mid)
6458+                                    onnx_parser_mid tf_parser_mid third_party_parser_mid)
6459 endif()
6460 
6461 if(MSLITE_ENABLE_MODEL_OBF)
6462diff --git a/mindspore/lite/test/runtest.sh b/mindspore/lite/test/runtest.sh
6463index c0d6d843..abdea6f4 100644
6464--- a/mindspore/lite/test/runtest.sh
6465+++ b/mindspore/lite/test/runtest.sh
6466@@ -80,6 +80,7 @@ if [ "$ENABLE_CONVERTER_TEST" = true ]; then
6467   ./lite-test-converter --gtest_filter="PassRegistryTest.TestRegistry"
6468   ./lite-test-converter --gtest_filter="TestConverterAPI.*"
6469   ./lite-test-converter --gtest_filter="SpecifyGraphOutputFormatTest*"
6470+  ./lite-test-converter --gtest_filter="TestThirdPartyParamParser.*"
6471 fi
6472 ./lite-test --gtest_filter="TestRegistry.TestAdd"
6473 ./lite-test --gtest_filter="TestRegistryCustomOp.TestCustomAdd"
6474diff --git a/mindspore/lite/test/ut/test_data/third_party_model.cfg b/mindspore/lite/test/ut/test_data/third_party_model.cfg
6475new file mode 100644
6476index 00000000..b5fcba75
6477--- /dev/null
6478+++ b/mindspore/lite/test/ut/test_data/third_party_model.cfg
6479@@ -0,0 +1,8 @@
6480+[third_party_model]
6481+input_names=demo_in_0;demo_in_1;demo_in_2
6482+input_dtypes=float32;float16;float64
6483+input_shapes=1;2,3;4,5,6
6484+output_names=demo_out_0;demo_out_1;demo_out_2;demo_out_4
6485+output_dtypes=int32;int16;int8;uint8
6486+output_shapes=10;20,30;40;50,60,70
6487+extended_parameters=foo:foo_value;bar:bar_value
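
Note on the config format above: each field carries one entry per tensor, with ';' separating tensors, ',' separating the dimensions of a single shape, and 'key:value' pairs inside extended_parameters. The sketch below illustrates that two-level split; SplitString and ParseShapes are hypothetical helpers for illustration only, not the ThirdPartyParamParser implementation exercised by the tests later in this patch.

// Illustration of the ';' / ',' nesting used by third_party_model.cfg (hypothetical helpers).
#include <cstdint>
#include <string>
#include <vector>

std::vector<std::string> SplitString(const std::string &text, char delim) {
  std::vector<std::string> parts;
  size_t start = 0;
  while (start <= text.size()) {
    size_t pos = text.find(delim, start);
    if (pos == std::string::npos) {
      parts.push_back(text.substr(start));
      break;
    }
    parts.push_back(text.substr(start, pos - start));
    start = pos + 1;
  }
  return parts;
}

// "1;2,3;4,5,6" -> {{1}, {2, 3}, {4, 5, 6}}, matching input_shapes above.
std::vector<std::vector<int64_t>> ParseShapes(const std::string &shapes_str) {
  std::vector<std::vector<int64_t>> shapes;
  for (const auto &shape_str : SplitString(shapes_str, ';')) {
    std::vector<int64_t> shape;
    for (const auto &dim_str : SplitString(shape_str, ',')) {
      shape.push_back(std::stoll(dim_str));
    }
    shapes.push_back(shape);
  }
  return shapes;
}
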
6488diff --git a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6489index 549bdd72..e73afc0e 100644
6490--- a/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6491+++ b/mindspore/lite/test/ut/tools/converter/api/converter_api_test.cc
6492@@ -34,3 +34,13 @@ TEST(TestConverterAPI, ConvertCaffeWithNotExistWeight) {
6493   mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeCaffe, caffe_model, output_model, caffe_weight);
6494   ASSERT_FALSE(converter.Convert().IsOk());
6495 }
6496+
6497+TEST(TestConverterAPI, ConvertThirdParty) {
6498+  std::string third_party_model = "./relu.mindir";
6499+  std::string config_model = "./third_party_model.cfg";
6500+  std::string output_model = "./demo_third_party.ms";
6501+
6502+  mindspore::Converter converter(mindspore::converter::FmkType::kFmkTypeThirdParty, third_party_model, output_model);
6503+  converter.SetConfigFile(config_model);
6504+  ASSERT_TRUE(converter.Convert().IsOk());
6505+}
6506\ No newline at end of file
6507diff --git a/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
6508new file mode 100644
6509index 00000000..c8eb5536
6510--- /dev/null
6511+++ b/mindspore/lite/test/ut/tools/converter/config_parser/third_party_param_parser_test.cc
6512@@ -0,0 +1,176 @@
6513+/**
6514+ * Copyright 2023 Huawei Technologies Co., Ltd
6515+ *
6516+ * Licensed under the Apache License, Version 2.0 (the "License");
6517+ * you may not use this file except in compliance with the License.
6518+ * You may obtain a copy of the License at
6519+ *
6520+ * http://www.apache.org/licenses/LICENSE-2.0
6521+ *
6522+ * Unless required by applicable law or agreed to in writing, software
6523+ * distributed under the License is distributed on an "AS IS" BASIS,
6524+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6525+ * See the License for the specific language governing permissions and
6526+ * limitations under the License.
6527+ */
6528+
6529+#include "gtest/gtest.h"
6530+#include "tools/converter/config_parser/third_party_param_parser.h"
6531+
6532+using mindspore::ThirdPartyModelParam;
6533+using mindspore::TypeId;
6534+using mindspore::lite::RET_OK;
6535+using mindspore::lite::ThirdPartyModelString;
6536+using mindspore::lite::ThirdPartyParamParser;
6537+
6538+const ThirdPartyModelString kDemoSISOParam = {
6539+  // SISO is short for single-input-single-output.
6540+  .input_dtypes = "float32",
6541+  .input_shapes = "1,2,3,4",
6542+  .input_names = "siso_input",
6543+  .output_dtypes = "int32",
6544+  .output_shapes = "2",
6545+  .output_names = "siso_output",
6546+  .extended_parameters = "siso_foo:siso_foo_value;siso_bar:siso_bar_value",
6547+};
6548+
6549+const ThirdPartyModelString kDemoMIMOParam = {
6550+  // MIMO is short for multiple-input-multiple-output.
6551+  .input_dtypes = "float32;int8;float16",
6552+  .input_shapes = "1,2,3,4;5,6;7,8,9",
6553+  .input_names = "mimo_in_0;mimo_in_1;mimo_in_2",
6554+  .output_dtypes = "int32;float32",
6555+  .output_shapes = "2,4;10,20,30",
6556+  .output_names = "mimo_out_0;mimo_out_1",
6557+  .extended_parameters = "mimo_foo:mimo_foo_value;mimo_bar:mimo_bar_value",
6558+};
6559+
6560+TEST(TestThirdPartyParamParser, ParseSISOParam) {
6561+  ThirdPartyModelString param_string = kDemoSISOParam;
6562+  ThirdPartyModelParam result;
6563+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6564+
6565+  ASSERT_EQ(result.input_names, std::vector<std::string>{"siso_input"});
6566+  ASSERT_EQ(result.input_shapes.size(), 1U);
6567+  std::vector<int64_t> expect_in_shape = {1, 2, 3, 4};
6568+  ASSERT_EQ(result.input_shapes[0], expect_in_shape);
6569+  ASSERT_EQ(result.input_dtypes, std::vector<TypeId>{TypeId::kNumberTypeFloat32});
6570+
6571+  ASSERT_EQ(result.output_names, std::vector<std::string>{"siso_output"});
6572+  ASSERT_EQ(result.output_shapes.size(), 1U);
6573+  std::vector<int64_t> expect_out_shape = {2};
6574+  ASSERT_EQ(result.output_shapes[0], expect_out_shape);
6575+  ASSERT_EQ(result.output_dtypes, std::vector<TypeId>{TypeId::kNumberTypeInt32});
6576+
6577+  const auto &ext_param = result.extended_parameters;
6578+  ASSERT_EQ(ext_param.size(), 2U);
6579+  ASSERT_TRUE(ext_param.find("siso_foo") != ext_param.end());
6580+  auto expect_foo_value = ext_param.at("siso_foo");
6581+  ASSERT_EQ(std::string(expect_foo_value.begin(), expect_foo_value.end()), "siso_foo_value");
6582+  ASSERT_TRUE(ext_param.find("siso_bar") != ext_param.end());
6583+  auto expect_bar_value = ext_param.at("siso_bar");
6584+  ASSERT_EQ(std::string(expect_bar_value.begin(), expect_bar_value.end()), "siso_bar_value");
6585+}
6586+
6587+TEST(TestThirdPartyParamParser, ParseValidDtype) {
6588+  ThirdPartyModelString param_string = kDemoSISOParam;
6589+  const std::vector<std::string> kValidDtypeStrings = {
6590+    "float64", "float32", "float16", "int64", "int32", "int16", "int8", "uint8", "bool",
6591+  };
6592+
6593+  const std::vector<TypeId> kExpects = {
6594+    TypeId::kNumberTypeFloat64, TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat16,
6595+    TypeId::kNumberTypeInt64,   TypeId::kNumberTypeInt32,   TypeId::kNumberTypeInt16,
6596+    TypeId::kNumberTypeInt8,    TypeId::kNumberTypeUInt8,   TypeId::kNumberTypeBool};
6597+
6598+  for (size_t i = 0; i < kValidDtypeStrings.size(); i++) {
6599+    param_string.input_dtypes = kValidDtypeStrings[i];
6600+    ThirdPartyModelParam result;
6601+    ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6602+    ASSERT_EQ(result.input_dtypes[0], kExpects[i]);
6603+  }
6604+}
6605+
6606+TEST(TestThirdPartyParamParser, ParseInvalidDtype) {
6607+  ThirdPartyModelParam result;
6608+  ThirdPartyModelString param_string = kDemoSISOParam;
6609+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6610+  param_string.input_dtypes = "bad_dtype";
6611+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6612+}
6613+
6614+TEST(TestThirdPartyParamParser, ParseValidShape) {
6615+  ThirdPartyModelString param_string = kDemoSISOParam;
6616+  param_string.input_shapes = "256,256,1024,96";  // Only fixed shapes are supported.
6617+  ThirdPartyModelParam result;
6618+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6619+  std::vector<int64_t> expect = {256, 256, 1024, 96};
6620+  ASSERT_EQ(result.input_shapes[0], expect);
6621+}
6622+
6623+TEST(TestThirdPartyParamParser, ParseInvalidShape) {
6624+  ThirdPartyModelParam result;
6625+  ThirdPartyModelString param_string = kDemoSISOParam;
6626+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6627+
6628+  param_string.input_shapes = "256,256,1024,-1";
6629+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6630+
6631+  param_string.input_shapes = "256,256,0,96";
6632+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6633+
6634+  param_string.input_shapes = "256,-256,1024,96";
6635+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6636+
6637+  param_string.input_shapes = "256,foo,1024,96";
6638+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6639+}
6640+
6641+TEST(TestThirdPartyParamParser, ParseDefaultName) {
6642+  ThirdPartyModelParam result;
6643+  ThirdPartyModelString param_string = kDemoSISOParam;
6644+  param_string.input_names = "";
6645+  param_string.output_names = "";
6646+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6647+  ASSERT_EQ(result.input_names[0], "in_0");
6648+  ASSERT_EQ(result.output_names[0], "out_0");
6649+}
6650+
6651+TEST(TestThirdPartyParamParser, ParseMIMOParam) {
6652+  ThirdPartyModelString param_string = kDemoMIMOParam;
6653+  ThirdPartyModelParam result;
6654+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6655+
6656+  std::vector<std::string> expect_input_names = {"mimo_in_0", "mimo_in_1", "mimo_in_2"};
6657+  ASSERT_EQ(result.input_names, expect_input_names);
6658+  std::vector<std::vector<int64_t>> expect_input_shapes = {{1, 2, 3, 4}, {5, 6}, {7, 8, 9}};
6659+  ASSERT_EQ(result.input_shapes, expect_input_shapes);
6660+  std::vector<TypeId> expect_input_dtypes = {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeInt8,
6661+                                             TypeId::kNumberTypeFloat16};
6662+  ASSERT_EQ(result.input_dtypes, expect_input_dtypes);
6663+
6664+  std::vector<std::string> expect_output_names = {"mimo_out_0", "mimo_out_1"};
6665+  ASSERT_EQ(result.output_names, expect_output_names);
6666+  std::vector<std::vector<int64_t>> expect_output_shapes = {{2, 4}, {10, 20, 30}};
6667+  ASSERT_EQ(result.output_shapes, expect_output_shapes);
6668+  std::vector<TypeId> expect_output_dtypes = {TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32};
6669+  ASSERT_EQ(result.output_dtypes, expect_output_dtypes);
6670+}
6671+
6672+TEST(TestThirdPartyParamParser, ParseMismatchedShapeAndDtypeSize) {
6673+  ThirdPartyModelString param_string = kDemoMIMOParam;
6674+  ThirdPartyModelParam result;
6675+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6676+
6677+  param_string.input_shapes = "1,2,3,4;5,6";  // shape size is 2 while dtype size is 3.
6678+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6679+}
6680+
6681+TEST(TestThirdPartyParamParser, ParseMismatchedNameAndDtypeSize) {
6682+  ThirdPartyModelString param_string = kDemoMIMOParam;
6683+  ThirdPartyModelParam result;
6684+  ASSERT_EQ(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6685+
6686+  param_string.input_names = "mimo_in_0;mimo_in_1";  // name size is 2 while dtype size is 3.
6687+  ASSERT_NE(ThirdPartyParamParser::Parse(param_string, &result), RET_OK);
6688+}
6689diff --git a/mindspore/lite/tools/benchmark/benchmark_base.cc b/mindspore/lite/tools/benchmark/benchmark_base.cc
6690index 16b1e218..ebaa9212 100644
6691--- a/mindspore/lite/tools/benchmark/benchmark_base.cc
6692+++ b/mindspore/lite/tools/benchmark/benchmark_base.cc
6693@@ -323,7 +323,7 @@ int BenchmarkBase::CheckThreadNumValid() {
6694 
6695 int BenchmarkBase::CheckDeviceTypeValid() {
6696   if (flags_->device_ != "CPU" && flags_->device_ != "GPU" && flags_->device_ != "NPU" &&
6697-      flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P") {
6698+      flags_->device_ != "Ascend310" && flags_->device_ != "Ascend310P" && flags_->device_ != "NNRT") {
6699     MS_LOG(ERROR) << "Device type:" << flags_->device_ << " is not supported.";
6700     std::cerr << "Device type:" << flags_->device_ << " is not supported." << std::endl;
6701     return RET_ERROR;
6702diff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h
6703index acdea21a..f818270c 100644
6704--- a/mindspore/lite/tools/benchmark/benchmark_base.h
6705+++ b/mindspore/lite/tools/benchmark/benchmark_base.h
6706@@ -122,7 +122,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
6707     AddFlag(&BenchmarkFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
6708     AddFlag(&BenchmarkFlags::group_info_file_, "GroupInfoFile", "Communication group info file", "");
6709     AddFlag(&BenchmarkFlags::config_file_, "configFile", "Config file", "");
6710-    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | Auto", "CPU");
6711+    AddFlag(&BenchmarkFlags::device_, "device", "CPU | GPU | NPU | Ascend310 | Ascend310P | NNRT | Auto", "CPU");
6712     AddFlag(&BenchmarkFlags::provider_, "provider", "device provider litert | tensorrt | mindrt", "litert");
6713     AddFlag(&BenchmarkFlags::cpu_bind_mode_, "cpuBindMode", "Input 0 for NO_BIND, 1 for HIGHER_CPU, 2 for MID_CPU.", 1);
6714     // MarkPerformance
6715diff --git a/mindspore/lite/tools/benchmark/benchmark_c_api.cc b/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6716index 252e65c6..cb0c56b0 100644
6717--- a/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6718+++ b/mindspore/lite/tools/benchmark/benchmark_c_api.cc
6719@@ -125,6 +125,10 @@ int BenchmarkCApi::InitContext() {
6720     OH_AI_DeviceInfoSetFrequency(npu_device_info, kFrequencyDefault);
6721     OH_AI_ContextAddDeviceInfo(context_, npu_device_info);
6722   }
6723+  if (flags_->device_ == "NNRT") {
6724+    OH_AI_DeviceInfoHandle nnrt_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
6725+    OH_AI_ContextAddDeviceInfo(context_, nnrt_device_info);
6726+  }
6727   OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
6728   OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
6729   OH_AI_ContextAddDeviceInfo(context_, cpu_device_info);
6730diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6731index bb36c168..c18111b6 100644
6732--- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6733+++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc
6734@@ -521,6 +521,11 @@ int BenchmarkUnifiedApi::InitMSContext(const std::shared_ptr<mindspore::Context>
6735     // InitMSContextForAscend(context, &device_list);
6736   }
6737 
6738+  if (flags_->device_ == "NNRT" || flags_->device_ == "Auto") {
6739+    std::shared_ptr<NNRTDeviceInfo> nnrt_device_info = std::make_shared<NNRTDeviceInfo>();
6740+    device_list.push_back(nnrt_device_info);
6741+  }
6742+
6743   // CPU priority is behind GPU and NPU
6744   std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
6745   device_info->SetEnableFP16(flags_->enable_fp16_);
6746diff --git a/mindspore/lite/tools/benchmark_train/CMakeLists.txt b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6747index 0c558524..1b9fc347 100644
6748--- a/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6749+++ b/mindspore/lite/tools/benchmark_train/CMakeLists.txt
6750@@ -9,6 +9,9 @@ set(COMMON_SRC
6751 set(TEST_SRC
6752     ${CMAKE_CURRENT_SOURCE_DIR}/main.cc
6753     ${CMAKE_CURRENT_SOURCE_DIR}/net_train.cc
6754+    ${CMAKE_CURRENT_SOURCE_DIR}/net_train_base.cc
6755+    ${CMAKE_CURRENT_SOURCE_DIR}/run_net_train.cc
6756+    ${CMAKE_CURRENT_SOURCE_DIR}/net_train_c_api.cc
6757     )
6758 
6759 # add static securec link library
6760diff --git a/mindspore/lite/tools/benchmark_train/main.cc b/mindspore/lite/tools/benchmark_train/main.cc
6761index abf3d9dd..76f85aa7 100644
6762--- a/mindspore/lite/tools/benchmark_train/main.cc
6763+++ b/mindspore/lite/tools/benchmark_train/main.cc
6764@@ -17,7 +17,8 @@
6765 #include <malloc.h>
6766 #include <unistd.h>
6767 #include <fstream>
6768-#include "tools/benchmark_train/net_train.h"
6769+#include <iostream>
6770+#include "tools/benchmark_train/run_net_train.h"
6771 
6772 void PrintMem() {
6773   std::string proc_file = "/proc/" + std::to_string(getpid()) + "/status";
6774diff --git a/mindspore/lite/tools/benchmark_train/net_runner.cc b/mindspore/lite/tools/benchmark_train/net_runner.cc
6775index 9b63d29f..edf3e964 100644
6776--- a/mindspore/lite/tools/benchmark_train/net_runner.cc
6777+++ b/mindspore/lite/tools/benchmark_train/net_runner.cc
6778@@ -15,7 +15,7 @@
6779  */
6780 
6781 #include "tools/benchmark_train/net_runner.h"
6782-#include "tools/benchmark_train/net_train.h"
6783+#include "tools/benchmark_train/net_train_base.h"
6784 #include <getopt.h>
6785 #include <malloc.h>
6786 #include <cmath>
6787@@ -187,7 +187,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
6788     auto output = tensor.Data();
6789     size_t size;
6790     std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
6791-    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(output_file.c_str(), &size));
6792+    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(output_file.c_str(), &size));
6793     if (bin_buf == nullptr) {
6794       MS_LOG(ERROR) << "ReadFile return nullptr";
6795       std::cout << "ReadFile return nullptr" << std::endl;
6796@@ -200,7 +200,7 @@ int NetRunner::CompareOutput(const std::vector<mindspore::MSTensor> &outputs) {
6797                 << ", read size: " << size << std::endl;
6798       return mindspore::kLiteError;
6799     }
6800-    float bias = mindspore::lite::NetTrain::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
6801+    float bias = mindspore::lite::NetTrainBase::CompareData<float>(bin_buf.get(), tensor.ElementNum(),
6802                                                                reinterpret_cast<const float *>(output.get()));
6803     if (bias >= 0) {
6804       total_bias += bias;
6805@@ -332,7 +332,7 @@ int NetRunner::ReadInputFile(std::vector<mindspore::MSTensor> *ms_inputs) {
6806     }
6807     size_t size;
6808     std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
6809-    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrain::ReadFileBuf(file_name.c_str(), &size));
6810+    auto bin_buf = std::unique_ptr<float[]>(mindspore::lite::NetTrainBase::ReadFileBuf(file_name.c_str(), &size));
6811     if (bin_buf == nullptr) {
6812       MS_LOG(ERROR) << "ReadFile return nullptr";
6813       std::cout << "ReadFile return nullptr" << std::endl;
6814@@ -368,4 +368,4 @@ int CallBack(mindspore::lite::NetTrainFlags *flags) {
6815   return nr.Main();
6816 }
6817 
6818-int init = mindspore::lite::NetTrain::SetNr(CallBack);
6819+int init = mindspore::lite::NetTrainBase::SetNr(CallBack);
6820diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc
6821index d1150043..514bba53 100644
6822--- a/mindspore/lite/tools/benchmark_train/net_train.cc
6823+++ b/mindspore/lite/tools/benchmark_train/net_train.cc
6824@@ -31,74 +31,11 @@
6825 
6826 namespace mindspore {
6827 namespace lite {
6828-static const char *DELIM_SLASH = "/";
6829-constexpr const char *DELIM_COLON = ":";
6830-constexpr const char *DELIM_COMMA = ",";
6831-constexpr int RET_TOO_BIG = -9;
6832 constexpr int kField0 = 0;
6833 constexpr int kField1 = 1;
6834 constexpr int kField2 = 2;
6835 constexpr int kField3 = 3;
6836 constexpr int kField4 = 4;
6837-constexpr int kFieldsToPrint = 5;
6838-constexpr int kPrintOffset = 4;
6839-static const int kTHOUSAND = 1000;
6840-constexpr int kDumpInputsAndOutputs = 0;
6841-constexpr int kDumpOutputs = 2;
6842-
6843-const std::unordered_map<int, std::string> kTypeIdMap{
6844-  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
6845-  {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
6846-  {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
6847-  {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
6848-  {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
6849-
6850-const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{
6851-  {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"},     {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"},
6852-  {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"},     {mindspore::CKHW, "CKHW"},   {mindspore::KHWC, "KHWC"},
6853-  {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"},         {mindspore::HW4, "HW4"},     {mindspore::NC, "NC"},
6854-  {mindspore::NC4, "NC4"},   {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}};
6855-
6856-std::function<int(NetTrainFlags *)> NetTrain::nr_cb_ = nullptr;
6857-
6858-int NetTrain::SetNr(std::function<int(NetTrainFlags *)> param) {
6859-  nr_cb_ = param;
6860-  return 0;
6861-}
6862-
6863-float *NetTrain::ReadFileBuf(const std::string file, size_t *size) {
6864-  if (file.empty()) {
6865-    MS_LOG(ERROR) << "file is nullptr";
6866-    return nullptr;
6867-  }
6868-  MS_ASSERT(size != nullptr);
6869-  std::string real_path = RealPath(file.c_str());
6870-  std::ifstream ifs(real_path);
6871-  if (!ifs.good()) {
6872-    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
6873-    return nullptr;
6874-  }
6875-
6876-  if (!ifs.is_open()) {
6877-    MS_LOG(ERROR) << "file: " << real_path << " open failed";
6878-    return nullptr;
6879-  }
6880-
6881-  ifs.seekg(0, std::ios::end);
6882-  *size = ifs.tellg();
6883-  std::unique_ptr<float[]> buf = std::make_unique<float[]>(*size / sizeof(float) + 1);
6884-  if (buf == nullptr) {
6885-    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
6886-    ifs.close();
6887-    return nullptr;
6888-  }
6889-
6890-  ifs.seekg(0, std::ios::beg);
6891-  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
6892-  ifs.close();
6893-
6894-  return buf.release();
6895-}
6896 
6897 int NetTrain::GenerateInputData() {
6898   for (auto tensor : ms_inputs_for_api_) {
6899@@ -120,28 +57,6 @@ int NetTrain::GenerateInputData() {
6900   return RET_OK;
6901 }
6902 
6903-int NetTrain::LoadInput() {
6904-  inputs_buf_.clear();
6905-  inputs_size_.clear();
6906-  batch_num_ = 0;
6907-  if (flags_->in_data_file_.empty()) {
6908-    auto status = GenerateInputData();
6909-    if (status != RET_OK) {
6910-      std::cerr << "Generate input data error " << status << std::endl;
6911-      MS_LOG(ERROR) << "Generate input data error " << status;
6912-      return status;
6913-    }
6914-  } else {
6915-    auto status = ReadInputFile();
6916-    if (status != RET_OK) {
6917-      std::cerr << "Read Input File error, " << status << std::endl;
6918-      MS_LOG(ERROR) << "Read Input File error, " << status;
6919-      return status;
6920-    }
6921-  }
6922-  return RET_OK;
6923-}
6924-
6925 int NetTrain::LoadStepInput(size_t step) {
6926   if (step >= batch_num_) {
6927     auto cur_batch = step + 1;
6928@@ -269,30 +184,6 @@ int NetTrain::CompareOutput() {
6929   }
6930 }
6931 
6932-std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
6933-                                   const std::string &file_type, const size_t &idx) {
6934-  std::string file_name = op_name;
6935-  auto pos = file_name.find_first_of('/');
6936-  while (pos != std::string::npos) {
6937-    file_name.replace(pos, 1, ".");
6938-    pos = file_name.find_first_of('/');
6939-  }
6940-  file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
6941-  for (const auto &dim : tensor->Shape()) {
6942-    file_name += std::to_string(dim) + "_";
6943-  }
6944-  if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
6945-    file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
6946-  }
6947-  auto tensor_format = tensor->format();
6948-  if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
6949-    file_name += "_" + kTensorFormatMap.at(tensor_format) + ".bin";
6950-  }
6951-
6952-  file_name += ".bin";
6953-  return file_name;
6954-}
6955-
6956 int NetTrain::MarkPerformance() {
6957   MS_LOG(INFO) << "Running train loops...";
6958   std::cout << "Running train loops..." << std::endl;
6959@@ -574,26 +465,6 @@ int NetTrain::CreateAndRunNetwork(const std::string &filename, const std::string
6960   return RET_OK;
6961 }
6962 
6963-int NetTrain::RunNetTrain() {
6964-  auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
6965-  bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty();
6966-  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_);
6967-  if (status != RET_OK) {
6968-    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
6969-    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
6970-              << std::endl;
6971-    return status;
6972-  }
6973-
6974-  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
6975-  if (status != RET_OK) {
6976-    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
6977-    std::cout << "Run CheckExecute error: " << status << std::endl;
6978-    return status;
6979-  }
6980-  return RET_OK;
6981-}
6982-
6983 int NetTrain::SaveModels() {
6984   if (!flags_->export_file_.empty()) {
6985     if (flags_->bb_model_file_.empty()) {
6986@@ -635,77 +506,6 @@ int NetTrain::SaveModels() {
6987   return RET_OK;
6988 }
6989 
6990-int NetTrain::CheckExecutionOfSavedModels() {
6991-  int status = RET_OK;
6992-  if (!flags_->export_file_.empty()) {
6993-    status = NetTrain::CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
6994-    if (status != RET_OK) {
6995-      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
6996-      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
6997-      return status;
6998-    }
6999-    if (flags_->bb_model_file_.empty()) {
7000-      status = NetTrain::CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
7001-      if (status != RET_OK) {
7002-        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
7003-        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
7004-        return status;
7005-      }
7006-    }
7007-  }
7008-  if (!flags_->inference_file_.empty()) {
7009-    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
7010-    if (status != RET_OK) {
7011-      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
7012-      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
7013-      return status;
7014-    }
7015-    status = NetTrain::CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
7016-    if (status != RET_OK) {
7017-      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
7018-      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
7019-      return status;
7020-    }
7021-  }
7022-  return status;
7023-}
7024-
7025-void NetTrain::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) {
7026-  if (tensor == nullptr) {
7027-    MS_LOG(ERROR) << "input tensor is nullptr.";
7028-    return;
7029-  }
7030-  int tensor_size = tensor->ElementNum();
7031-  void *data = tensor->MutableData();
7032-  auto *fdata = reinterpret_cast<float *>(tensor->MutableData());
7033-  auto type = tensor->DataType();
7034-  std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum=";
7035-  switch (type) {
7036-    case mindspore::DataType::kNumberTypeFloat32:
7037-      TensorNan(reinterpret_cast<float *>(data), tensor_size);
7038-      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
7039-      std::cout << "tensor name: " << tensor->Name() << std::endl;
7040-      std::cout << "data: ";
7041-      for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) {
7042-        std::cout << static_cast<float>(fdata[i]) << ", ";
7043-      }
7044-      std::cout << std::endl;
7045-      break;
7046-    case mindspore::DataType::kNumberTypeInt32:
7047-      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
7048-      break;
7049-#ifdef ENABLE_FP16
7050-    case mindspore::DataType::kNumberTypeFloat16:
7051-      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
7052-      TensorNan(reinterpret_cast<float16_t *>(data), tensor_size);
7053-      break;
7054-#endif
7055-    default:
7056-      std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
7057-      break;
7058-  }
7059-}
7060-
7061 int NetTrain::InitDumpTensorDataCallbackParameter() {
7062   // before callback
7063   before_call_back_ = [&](const std::vector<mindspore::MSTensor> &before_inputs,
7064@@ -815,178 +615,6 @@ int NetTrain::InitTimeProfilingCallbackParameter() {
7065   return RET_OK;
7066 }
7067 
7068-int NetTrain::InitCallbackParameter() {
7069-  int ret = RET_OK;
7070-  if (flags_->dump_tensor_data_) {
7071-    ret = InitDumpTensorDataCallbackParameter();
7072-  } else if (flags_->time_profiling_) {
7073-    ret = InitTimeProfilingCallbackParameter();
7074-  }
7075-  return ret;
7076-}
7077-
7078-void NetTrainFlags::InitResizeDimsList() {
7079-  std::string content = this->resize_dims_in_;
7080-  std::vector<int> shape;
7081-  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
7082-  for (const auto &shape_str : shape_strs) {
7083-    shape.clear();
7084-    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
7085-    std::cout << "Resize Dims: ";
7086-    for (const auto &dim_str : dim_strs) {
7087-      std::cout << dim_str << " ";
7088-      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
7089-    }
7090-    std::cout << std::endl;
7091-    this->resize_dims_.emplace_back(shape);
7092-  }
7093-}
7094-
7095-int NetTrain::Init() {
7096-  if (this->flags_ == nullptr) {
7097-    return 1;
7098-  }
7099-  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
7100-  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
7101-  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
7102-  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
7103-  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
7104-  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
7105-  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
7106-  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
7107-  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
7108-  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
7109-  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;
7110-
7111-  if (this->flags_->epochs_ < 0) {
7112-    MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0";
7113-    std::cerr << "epochs:" << this->flags_->epochs_ << " must be equal/greater than 0" << std::endl;
7114-    return RET_ERROR;
7115-  }
7116-
7117-  if (this->flags_->num_threads_ < 1) {
7118-    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
7119-    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
7120-    return RET_ERROR;
7121-  }
7122-
7123-  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
7124-
7125-  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
7126-    MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided";
7127-    std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl;
7128-    return RET_ERROR;
7129-  }
7130-
7131-  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
7132-    MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided";
7133-    std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl;
7134-    return RET_ERROR;
7135-  }
7136-
7137-  if (flags_->model_file_.empty()) {
7138-    MS_LOG(ERROR) << "modelPath is required";
7139-    std::cerr << "modelPath is required" << std::endl;
7140-    return 1;
7141-  }
7142-
7143-  // get dump data output path
7144-  auto dump_cfg_path = std::getenv(dump::kConfigPath);
7145-  if (dump_cfg_path != nullptr) {
7146-    flags_->dump_tensor_data_ = true;
7147-    if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
7148-      MS_LOG(ERROR) << "parse dump config file failed.";
7149-      return RET_ERROR;
7150-    }
7151-  } else {
7152-    MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data";
7153-  }
7154-
7155-  auto status = InitCallbackParameter();
7156-  if (status != RET_OK) {
7157-    MS_LOG(ERROR) << "Init callback Parameter failed.";
7158-    std::cerr << "Init callback Parameter failed." << std::endl;
7159-    return RET_ERROR;
7160-  }
7161-
7162-  flags_->InitResizeDimsList();
7163-  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
7164-      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
7165-    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
7166-    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
7167-    return RET_ERROR;
7168-  }
7169-  return RET_OK;
7170-}
7171-
7172-namespace {
7173-constexpr int kNumToPrint = 5;
7174-}
7175-
7176-int NetTrain::InitDumpConfigFromJson(std::string path) {
7177-  auto real_path = RealPath(path.c_str());
7178-  std::ifstream ifs(real_path);
7179-  if (!ifs.good()) {
7180-    MS_LOG(ERROR) << "file: " << real_path << " is not exist";
7181-    return RET_ERROR;
7182-  }
7183-  if (!ifs.is_open()) {
7184-    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7185-    return RET_ERROR;
7186-  }
7187-
7188-  try {
7189-    dump_cfg_json_ = nlohmann::json::parse(ifs);
7190-  } catch (const nlohmann::json::parse_error &error) {
7191-    MS_LOG(ERROR) << "parse json file failed, please check your file.";
7192-    return RET_ERROR;
7193-  }
7194-  if (dump_cfg_json_[dump::kSettings] == nullptr) {
7195-    MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
7196-    return RET_ERROR;
7197-  }
7198-  if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
7199-    MS_LOG(ERROR) << "\"dump_mode\" is required.";
7200-    return RET_ERROR;
7201-  }
7202-  if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
7203-    MS_LOG(ERROR) << "\"path\" is required.";
7204-    return RET_ERROR;
7205-  }
7206-  if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
7207-    dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
7208-  }
7209-  if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
7210-    dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
7211-  }
7212-  if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
7213-      !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
7214-    if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
7215-      MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
7216-      return RET_ERROR;
7217-    }
7218-  }
7219-
7220-  auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
7221-  auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
7222-  if (abs_path.back() == '\\' || abs_path.back() == '/') {
7223-    dump_file_output_dir_ = abs_path + net_name;
7224-  } else {
7225-#ifdef _WIN32
7226-    dump_file_output_dir_ = abs_path + "\\" + net_name;
7227-#else
7228-    dump_file_output_dir_ = abs_path + "/" + net_name;
7229-#endif
7230-  }
7231-
7232-  auto status = CreateOutputDir(&dump_file_output_dir_);
7233-  if (status != RET_OK) {
7234-    MS_LOG(ERROR) << "create data output directory failed.";
7235-    return RET_ERROR;
7236-  }
7237-  return RET_OK;
7238-}
7239-
7240 int NetTrain::PrintResult(const std::vector<std::string> &title,
7241                           const std::map<std::string, std::pair<int, float>> &result) {
7242   std::vector<size_t> columnLenMax(kFieldsToPrint);
7243@@ -1035,7 +663,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7244   }
7245 
7246   printf("-------------------------------------------------------------------------\n");
7247-  for (int i = 0; i < kNumToPrint; i++) {
7248+  for (int i = 0; i < kFieldsToPrint; i++) {
7249     auto printBuf = title[i];
7250     if (printBuf.size() > columnLenMax.at(i)) {
7251       columnLenMax.at(i) = printBuf.size();
7252@@ -1045,7 +673,7 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7253   }
7254   printf("\n");
7255   for (auto &row : rows) {
7256-    for (int j = 0; j < kNumToPrint; j++) {
7257+    for (int j = 0; j < kFieldsToPrint; j++) {
7258       auto printBuf = row[j];
7259       printBuf.resize(columnLenMax.at(j), ' ');
7260       printf("%s\t", printBuf.c_str());
7261@@ -1054,47 +682,5 @@ int NetTrain::PrintResult(const std::vector<std::string> &title,
7262   }
7263   return RET_OK;
7264 }
7265-
7266-int RunNetTrain(int argc, const char **argv) {
7267-  NetTrainFlags flags;
7268-  Option<std::string> err = flags.ParseFlags(argc, argv);
7269-
7270-  if (err.IsSome()) {
7271-    std::cerr << err.Get() << std::endl;
7272-    std::cerr << flags.Usage() << std::endl;
7273-    return RET_ERROR;
7274-  }
7275-
7276-  if (flags.help) {
7277-    std::cerr << flags.Usage() << std::endl;
7278-    return RET_OK;
7279-  }
7280-  if (flags.unified_api_) {
7281-    return NetTrain::RunNr(&flags);
7282-  }
7283-  NetTrain net_trainer(&flags);
7284-  auto status = net_trainer.Init();
7285-  if (status != RET_OK) {
7286-    MS_LOG(ERROR) << "NetTrain init Error : " << status;
7287-    std::cerr << "NetTrain init Error : " << status << std::endl;
7288-    return RET_ERROR;
7289-  }
7290-
7291-  status = net_trainer.RunNetTrain();
7292-  if (status != RET_OK) {
7293-    MS_LOG(ERROR) << "Run NetTrain "
7294-                  << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7295-                  << " Failed : " << status;
7296-    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7297-              << " Failed : " << status << std::endl;
7298-    return RET_ERROR;
7299-  }
7300-
7301-  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7302-               << " Success.";
7303-  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
7304-            << " Success." << std::endl;
7305-  return RET_OK;
7306-}
7307 }  // namespace lite
7308 }  // namespace mindspore
7309diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h
7310index 67e58a04..bdf0ec88 100644
7311--- a/mindspore/lite/tools/benchmark_train/net_train.h
7312+++ b/mindspore/lite/tools/benchmark_train/net_train.h
7313@@ -42,183 +42,22 @@
7314 #include "tools/common/flag_parser.h"
7315 #include "src/common/file_utils.h"
7316 #include "src/common/utils.h"
7317-
7318-#ifdef ENABLE_FP16
7319-static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) {
7320-  volatile float16_t d = var;
7321-  return d != d;
7322-}
7323-#endif
7324+#include "tools/benchmark_train/net_train_base.h"
7325 
7326 namespace mindspore::lite {
7327-enum MS_API DataType { kImage = 0, kBinary = 1 };
7328-
7329-constexpr float relativeTolerance = 1e-5;
7330-constexpr float absoluteTolerance = 1e-8;
7331 extern const std::unordered_map<int, std::string> kTypeIdMap;
7332 extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
7333 
7334-namespace dump {
7335-constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
7336-constexpr auto kSettings = "common_dump_settings";
7337-constexpr auto kMode = "dump_mode";
7338-constexpr auto kPath = "path";
7339-constexpr auto kNetName = "net_name";
7340-constexpr auto kInputOutput = "input_output";
7341-constexpr auto kKernels = "kernels";
7342-}  // namespace dump
7343-
7344-template <typename T>
7345-float TensorSum(const void *data, int size) {
7346-  const T *typed_data = reinterpret_cast<const T *>(data);
7347-  float sum = 0.f;
7348-  for (int i = 0; i < size; i++) {
7349-    sum += static_cast<float>(typed_data[i]);
7350-  }
7351-  return sum;
7352-}
7353-
7354-class MS_API NetTrainFlags : public virtual FlagParser {
7355+class MS_API NetTrain : public NetTrainBase {
7356  public:
7357-  NetTrainFlags() {
7358-    // common
7359-    AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", "");
7360-    AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backboine model for transfer session", "");
7361-    AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
7362-    // MarkPerformance
7363-    AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0);
7364-    AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false);
7365-    AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1);
7366-    AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1);
7367-    // MarkAccuracy
7368-    AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", "");
7369-    AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
7370-    AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
7371-    AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
7372-    AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
7373-    AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", "");
7374-    AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", "");
7375-    AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
7376-    AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
7377-            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
7378-    AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false);
7379-  }
7380-
7381-  ~NetTrainFlags() override = default;
7382-  void InitResizeDimsList();
7383+  explicit NetTrain(NetTrainFlags *flags) : NetTrainBase(flags) {}
7384+  virtual ~NetTrain() {}
7385 
7386- public:
7387-  // common
7388-  std::string model_file_;
7389-  std::string in_data_file_;
7390-  std::string bb_model_file_;
7391-  std::vector<std::string> input_data_list_;
7392-  DataType in_data_type_;
7393-  std::string in_data_type_in_ = "bin";
7394-  int cpu_bind_mode_ = 1;
7395-  bool enable_fp16_ = false;
7396-  bool virtual_batch_ = false;
7397-  // MarkPerformance
7398-  int num_threads_ = 1;
7399-  int warm_up_loop_count_ = 0;
7400-  bool time_profiling_;
7401-  int epochs_ = 1;
7402-  // MarkAccuracy
7403-  std::string data_file_;
7404-  std::string data_type_ = "FLOAT";
7405-  float accuracy_threshold_;
7406-  // Resize
7407-  std::string export_file_ = "";
7408-  std::string resize_dims_in_ = "";
7409-  bool layer_checksum_ = false;
7410-  std::vector<std::vector<int>> resize_dims_;
7411-  std::string loss_name_ = "";
7412-  std::string inference_file_ = "";
7413-  bool unified_api_ = false;
7414-  bool dump_tensor_data_ = false;
7415-};
7416-
7417-class MS_API NetTrain {
7418- public:
7419-  explicit NetTrain(NetTrainFlags *flags) : flags_(flags) {}
7420-  virtual ~NetTrain() = default;
7421-
7422-  int Init();
7423-  int RunNetTrain();
7424-  static float *ReadFileBuf(const std::string file, size_t *size);
7425-  static int SetNr(std::function<int(NetTrainFlags *)> param);
7426-  static int RunNr(NetTrainFlags *flags) {
7427-    if (nr_cb_ != nullptr) {
7428-      return nr_cb_(flags);
7429-    }
7430-    MS_LOG(WARNING) << "unified api was not tested";
7431-    std::cout << "unified api was not tested";
7432-    return RET_OK;
7433-  }
7434-  // tensorData need to be converter first
7435-  template <typename T>
7436-  static float CompareData(const float *refOutput, int size, const T *msTensorData) {
7437-    size_t errorCount = 0;
7438-    float meanError = 0;
7439-    std::cout << "Out tensor size is: " << size << std::endl;
7440-    std::cout << "Data of model output: ";
7441-    for (int j = 0; j < std::min(50, size); j++) {
7442-      std::cout << static_cast<float>(msTensorData[j]) << " ";
7443-    }
7444-    std::cout << std::endl;
7445-    std::cout << "Data of Ref output  : ";
7446-    for (int j = 0; j < std::min(50, size); j++) {
7447-      std::cout << refOutput[j] << " ";
7448-    }
7449-    std::cout << std::endl;
7450-    for (int j = 0; j < size; j++) {
7451-      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
7452-        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
7453-        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
7454-        return RET_ERROR;
7455-      }
7456-
7457-      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
7458-      auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]);
7459-      if (absoluteError > tolerance) {
7460-        if (fabs(refOutput[j]) == 0) {
7461-          if (absoluteError > 1e-5) {
7462-            meanError += absoluteError;
7463-            errorCount++;
7464-          } else {
7465-            continue;
7466-          }
7467-        } else {
7468-          // just assume that atol = rtol
7469-          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
7470-          errorCount++;
7471-        }
7472-      }
7473-    }
7474-    std::cout << std::endl;
7475-    if (meanError > 0.0f) {
7476-      meanError /= errorCount;
7477-    }
7478-
7479-    if (meanError <= 0.0000001) {
7480-      std::cout << "Mean bias of tensor: 0%" << std::endl;
7481-    } else {
7482-      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
7483-    }
7484-    return meanError;
7485-  }
7486-  int InitDumpConfigFromJson(std::string path);
7487-
7488- private:
7489-  // call GenerateInputData or ReadInputFile to init inputTensors
7490-  int LoadInput();
7491-  void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out);
7492+ protected:
7493   // call GenerateRandomData to fill inputTensors
7494-  int GenerateInputData();
7495+  int GenerateInputData() override;
7496 
7497-  int GenerateRandomData(mindspore::MSTensor *tensor);
7498-
7499-  int ReadInputFile();
7500+  int ReadInputFile() override;
7501 
7502   int LoadStepInput(size_t step);
7503 
7504@@ -227,20 +66,19 @@ class MS_API NetTrain {
7505   void InitTrainCfg(const std::shared_ptr<TrainCfg> &train_cfg);
7506 
7507   int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
7508-                          bool check_accuracy = true);
7509+                          bool check_accuracy = true) override;
7510 
7511   int CreateAndRunNetworkForInference(const std::string &filename, const std::shared_ptr<mindspore::Context> &context);
7512 
7513   int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
7514                                   const std::shared_ptr<mindspore::Context> &context,
7515                                   const std::shared_ptr<TrainCfg> &train_cfg, int epochs);
7516-  int InitCallbackParameter();
7517 
7518-  int InitDumpTensorDataCallbackParameter();
7519+  int InitDumpTensorDataCallbackParameter() override;
7520 
7521-  int InitTimeProfilingCallbackParameter();
7522+  int InitTimeProfilingCallbackParameter() override;
7523 
7524-  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result);
7525+  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override;
7526 
7527   template <typename T>
7528   void PrintInputData(mindspore::MSTensor *input) {
7529@@ -256,39 +94,11 @@ class MS_API NetTrain {
7530     std::cout << std::endl;
7531   }
7532 
7533-  template <typename T>
7534-  std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) {
7535-    std::vector<int64_t> dims;
7536-    for (auto shape : srcDims) {
7537-      dims.push_back(static_cast<int64_t>(shape));
7538-    }
7539-    return dims;
7540-  }
7541-  int MarkPerformance();
7542-  int MarkAccuracy(bool enforce_accuracy = true);
7543-  int CompareOutput();
7544-  int SaveModels();
7545-  int CheckExecutionOfSavedModels();
7546-  void TensorNan(const float *data, int size) {
7547-    for (int i = 0; i < size; i++) {
7548-      if (std::isnan(data[i])) {
7549-        std::cout << "nan value of index=" << i << ", " << data[i] << std::endl;
7550-        break;
7551-      }
7552-    }
7553-  }
7554-#ifdef ENABLE_FP16
7555-  void TensorNan(float16_t *data, int size) {
7556-    for (int i = 0; i < size; i++) {
7557-      if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) {
7558-        std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl;
7559-        break;
7560-      }
7561-    }
7562-  }
7563-#endif
7564-  NetTrainFlags *flags_{nullptr};
7565-  static std::function<int(NetTrainFlags *)> nr_cb_;
7566+  int MarkPerformance() override;
7567+  int MarkAccuracy(bool enforce_accuracy = true) override;
7568+  int CompareOutput() override;
7569+  int SaveModels() override;
7570+
7571   // callback parameters
7572   uint64_t op_begin_ = 0;
7573   int op_call_times_total_ = 0;
7574@@ -301,13 +111,6 @@ class MS_API NetTrain {
7575 
7576   mindspore::MSKernelCallBack before_call_back_{nullptr};
7577   mindspore::MSKernelCallBack after_call_back_{nullptr};
7578-  nlohmann::json dump_cfg_json_;
7579-  std::string dump_file_output_dir_;
7580-  std::vector<std::shared_ptr<char>> inputs_buf_;
7581-  std::vector<size_t> inputs_size_;
7582-  size_t batch_num_ = 0;
7583 };
7584-
7585-int MS_API RunNetTrain(int argc, const char **argv);
7586 }  // namespace mindspore::lite
7587 #endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_
7588diff --git a/mindspore/lite/tools/benchmark_train/net_train_base.cc b/mindspore/lite/tools/benchmark_train/net_train_base.cc
7589new file mode 100644
7590index 00000000..8d3c75de
7591--- /dev/null
7592+++ b/mindspore/lite/tools/benchmark_train/net_train_base.cc
7593@@ -0,0 +1,410 @@
7594+/**
7595+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
7596+ *
7597+ * Licensed under the Apache License, Version 2.0 (the "License");
7598+ * you may not use this file except in compliance with the License.
7599+ * You may obtain a copy of the License at
7600+ *
7601+ * http://www.apache.org/licenses/LICENSE-2.0
7602+ *
7603+ * Unless required by applicable law or agreed to in writing, software
7604+ * distributed under the License is distributed on an "AS IS" BASIS,
7605+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7606+ * See the License for the specific language governing permissions and
7607+ * limitations under the License.
7608+ */
7609+
7610+#include "tools/benchmark_train/net_train_base.h"
7611+#define __STDC_FORMAT_MACROS
7612+#undef __STDC_FORMAT_MACROS
7613+#include <algorithm>
7614+#include <cstring>
7615+#ifdef ENABLE_NEON
7616+#include <arm_neon.h>
7617+#endif
7618+#include "src/common/common.h"
7619+#include "include/api/serialization.h"
7620+
7621+namespace mindspore {
7622+namespace lite {
7623+const std::unordered_map<int, std::string> kTypeIdMap{
7624+  {kNumberTypeFloat16, "Float16"}, {kNumberTypeFloat, "Float32"},    {kNumberTypeFloat32, "Float32"},
7625+  {kNumberTypeInt8, "Int8"},       {kNumberTypeInt16, "Int16"},      {kNumberTypeInt, "Int32"},
7626+  {kNumberTypeInt32, "Int32"},     {kNumberTypeUInt8, "UInt8"},      {kNumberTypeUInt16, "UInt16"},
7627+  {kNumberTypeUInt, "UInt32"},     {kNumberTypeUInt32, "UInt32"},    {kObjectTypeString, "String"},
7628+  {kNumberTypeBool, "Bool"},       {kObjectTypeTensorType, "Tensor"}};
7629+
7630+const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap{
7631+  {mindspore::NCHW, "NCHW"}, {mindspore::NHWC, "NHWC"},     {mindspore::NHWC4, "NHWC4"}, {mindspore::HWKC, "HWKC"},
7632+  {mindspore::HWCK, "HWCK"}, {mindspore::KCHW, "KCHW"},     {mindspore::CKHW, "CKHW"},   {mindspore::KHWC, "KHWC"},
7633+  {mindspore::CHWK, "CHWK"}, {mindspore::HW, "HW"},         {mindspore::HW4, "HW4"},     {mindspore::NC, "NC"},
7634+  {mindspore::NC4, "NC4"},   {mindspore::NC4HW4, "NC4HW4"}, {mindspore::NCDHW, "NCDHW"}};
7635+
7636+std::function<int(NetTrainFlags *)> NetTrainBase::nr_cb_ = nullptr;
7637+
7638+int NetTrainBase::SetNr(std::function<int(NetTrainFlags *)> param) {
7639+  nr_cb_ = param;
7640+  return 0;
7641+}
7642+
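+// Reads the whole file into a heap-allocated float buffer and stores the byte size in *size.
+// Ownership passes to the caller (release with delete[]); returns nullptr on failure.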
7643+float *NetTrainBase::ReadFileBuf(const std::string file, size_t *size) {
7644+  if (file.empty()) {
7645+    MS_LOG(ERROR) << "file path is empty";
7646+    return nullptr;
7647+  }
7648+  MS_ASSERT(size != nullptr);
7649+  std::string real_path = RealPath(file.c_str());
7650+  std::ifstream ifs(real_path);
7651+  if (!ifs.good()) {
7652+    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
7653+    return nullptr;
7654+  }
7655+
7656+  if (!ifs.is_open()) {
7657+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7658+    return nullptr;
7659+  }
7660+
7661+  ifs.seekg(0, std::ios::end);
7662+  *size = ifs.tellg();
7663+  std::unique_ptr<float[]> buf(new (std::nothrow) float[*size / sizeof(float) + 1]);
7664+  if (buf == nullptr) {
7665+    MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
7666+    ifs.close();
7667+    return nullptr;
7668+  }
7669+
7670+  ifs.seekg(0, std::ios::beg);
7671+  ifs.read(reinterpret_cast<char *>(buf.get()), *size);
7672+  ifs.close();
7673+
7674+  return buf.release();
7675+}
7676+
7677+int NetTrainBase::GenerateRandomData(mindspore::MSTensor *tensor) {
7678+  auto input_data = tensor->MutableData();
7679+  if (input_data == nullptr) {
7680+    MS_LOG(ERROR) << "MallocData for inTensor failed";
7681+    return RET_ERROR;
7682+  }
7683+  auto tensor_byte_size = tensor->DataSize();
7684+  char *casted_data = static_cast<char *>(input_data);
7685+  for (size_t i = 0; i < tensor_byte_size; i++) {
7686+    casted_data[i] =
7687+      (tensor->DataType() == mindspore::DataType::kNumberTypeFloat32) ? static_cast<char>(i) : static_cast<char>(0);
7688+  }
7689+  return RET_OK;
7690+}
7691+
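+// Fills the model inputs either with generated data (when no inDataFile is given) or from the input file(s).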
7692+int NetTrainBase::LoadInput() {
7693+  inputs_buf_.clear();
7694+  inputs_size_.clear();
7695+  batch_num_ = 0;
7696+  if (flags_->in_data_file_.empty()) {
7697+    auto status = GenerateInputData();
7698+    if (status != RET_OK) {
7699+      std::cerr << "Generate input data error " << status << std::endl;
7700+      MS_LOG(ERROR) << "Generate input data error " << status;
7701+      return status;
7702+    }
7703+  } else {
7704+    auto status = ReadInputFile();
7705+    if (status != RET_OK) {
7706+      std::cerr << "Read Input File error, " << status << std::endl;
7707+      MS_LOG(ERROR) << "Read Input File error, " << status;
7708+      return status;
7709+    }
7710+  }
7711+  return RET_OK;
7712+}
7713+
7714+int NetTrainBase::RunNetTrain() {
7715+  auto file_name = flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1);
7716+  bool is_train = (file_name.find("train") != std::string::npos) || !flags_->bb_model_file_.empty();
7717+  auto status = CreateAndRunNetwork(flags_->model_file_, flags_->bb_model_file_, is_train, flags_->epochs_);
7718+  if (status != RET_OK) {
7719+    MS_LOG(ERROR) << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status;
7720+    std::cout << "CreateAndRunNetwork failed for model " << flags_->model_file_ << ". Status is " << status
7721+              << std::endl;
7722+    return status;
7723+  }
7724+
7725+  status = CheckExecutionOfSavedModels();  // re-initialize sessions according to flags
7726+  if (status != RET_OK) {
7727+    MS_LOG(ERROR) << "Run CheckExecute error: " << status;
7728+    std::cout << "Run CheckExecute error: " << status << std::endl;
7729+    return status;
7730+  }
7731+  return RET_OK;
7732+}
7733+
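+// Re-runs the exported / inference models (and their "_qt" quantized variants) with zero epochs
+// to verify that the saved models still build and execute.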
7734+int NetTrainBase::CheckExecutionOfSavedModels() {
7735+  int status = RET_OK;
7736+  if (!flags_->export_file_.empty()) {
7737+    status = CreateAndRunNetwork(flags_->export_file_, flags_->bb_model_file_, true, 0);
7738+    if (status != RET_OK) {
7739+      MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << " error: " << status;
7740+      std::cout << "Run Exported model " << flags_->export_file_ << " error: " << status << std::endl;
7741+      return status;
7742+    }
7743+    if (flags_->bb_model_file_.empty()) {
7744+      status = CreateAndRunNetwork(flags_->export_file_ + "_qt", "", true, 0, false);
7745+      if (status != RET_OK) {
7746+        MS_LOG(ERROR) << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status;
7747+        std::cout << "Run Exported model " << flags_->export_file_ << "_qt.ms error: " << status << std::endl;
7748+        return status;
7749+      }
7750+    }
7751+  }
7752+  if (!flags_->inference_file_.empty()) {
7753+    status = CreateAndRunNetwork(flags_->inference_file_, "", false, 0);
7754+    if (status != RET_OK) {
7755+      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << ".ms error: " << status;
7756+      std::cout << "Running saved model " << flags_->inference_file_ << ".ms error: " << status << std::endl;
7757+      return status;
7758+    }
7759+    status = CreateAndRunNetwork(flags_->inference_file_ + "_qt", "", false, 0, false);
7760+    if (status != RET_OK) {
7761+      MS_LOG(ERROR) << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status;
7762+      std::cout << "Running saved model " << flags_->inference_file_ << "_qt.ms error: " << status << std::endl;
7763+      return status;
7764+    }
7765+  }
7766+  return status;
7767+}
7768+
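+// Debug helper: prints shape, element sum and (for float tensors) the first few values of the tensor,
+// and reports NaN/Inf entries.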
7769+void NetTrainBase::CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out) {
7770+  if (tensor == nullptr) {
7771+    MS_LOG(ERROR) << "input tensor is nullptr.";
7772+    return;
7773+  }
7774+  int tensor_size = tensor->ElementNum();
7775+  void *data = tensor->MutableData();
7776+  auto *fdata = reinterpret_cast<float *>(tensor->MutableData());
7777+  auto type = tensor->DataType();
7778+  std::cout << node_type << " " << in_out << id << " shape=" << tensor->Shape() << " sum=";
7779+  switch (type) {
7780+    case mindspore::DataType::kNumberTypeFloat32:
7781+      TensorNan(reinterpret_cast<float *>(data), tensor_size);
7782+      std::cout << TensorSum<float>(data, tensor_size) << std::endl;
7783+      std::cout << "tensor name: " << tensor->Name() << std::endl;
7784+      std::cout << "data: ";
7785+      for (int i = 0; i <= kPrintOffset && i < tensor_size; i++) {
7786+        std::cout << static_cast<float>(fdata[i]) << ", ";
7787+      }
7788+      std::cout << std::endl;
7789+      break;
7790+    case mindspore::DataType::kNumberTypeInt32:
7791+      std::cout << TensorSum<int>(data, tensor_size) << std::endl;
7792+      break;
7793+#ifdef ENABLE_FP16
7794+    case mindspore::DataType::kNumberTypeFloat16:
7795+      std::cout << TensorSum<float16_t>(data, tensor_size) << std::endl;
7796+      TensorNan(reinterpret_cast<float16_t *>(data), tensor_size);
7797+      break;
7798+#endif
7799+    default:
7800+      std::cout << "unsupported type:" << static_cast<int>(type) << std::endl;
7801+      break;
7802+  }
7803+}
7804+
7805+std::string NetTrainBase::GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
7806+                                                 const std::string &file_type, const size_t &idx) {
7807+  std::string file_name = op_name;
7808+  auto pos = file_name.find_first_of('/');
7809+  while (pos != std::string::npos) {
7810+    file_name.replace(pos, 1, ".");
7811+    pos = file_name.find_first_of('/');
7812+  }
7813+  file_name += "_" + file_type + "_" + std::to_string(idx) + "_shape_";
7814+  for (const auto &dim : tensor->Shape()) {
7815+    file_name += std::to_string(dim) + "_";
7816+  }
7817+  if (kTypeIdMap.find(static_cast<int>(tensor->DataType())) != kTypeIdMap.end()) {
7818+    file_name += kTypeIdMap.at(static_cast<int>(tensor->DataType()));
7819+  }
7820+  auto tensor_format = tensor->format();
7821+  if (kTensorFormatMap.find(tensor_format) != kTensorFormatMap.end()) {
7822+    file_name += "_" + kTensorFormatMap.at(tensor_format);
7823+  }
7824+
7825+  file_name += ".bin";
7826+  return file_name;
7827+}
7828+
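+// Installs either the dump-tensor-data or the time-profiling callbacks, depending on the flags.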
7829+int NetTrainBase::InitCallbackParameter() {
7830+  int ret = RET_OK;
7831+  if (flags_->dump_tensor_data_) {
7832+    ret = InitDumpTensorDataCallbackParameter();
7833+  } else if (flags_->time_profiling_) {
7834+    ret = InitTimeProfilingCallbackParameter();
7835+  }
7836+  return ret;
7837+}
7838+
7839+void NetTrainFlags::InitResizeDimsList() {
7840+  std::string content = this->resize_dims_in_;
7841+  if (content.empty()) {
7842+    return;
7843+  }
7844+  std::vector<int> shape;
7845+  auto shape_strs = StrSplit(content, std::string(DELIM_COLON));
7846+  for (const auto &shape_str : shape_strs) {
7847+    shape.clear();
7848+    auto dim_strs = StrSplit(shape_str, std::string(DELIM_COMMA));
7849+    std::cout << "Resize Dims: ";
7850+    for (const auto &dim_str : dim_strs) {
7851+      std::cout << dim_str << " ";
7852+      shape.emplace_back(static_cast<int>(std::stoi(dim_str)));
7853+    }
7854+    std::cout << std::endl;
7855+    this->resize_dims_.emplace_back(shape);
7856+  }
7857+}
7858+
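+// Validates the parsed command-line flags, loads the dump configuration from MINDSPORE_DUMP_CONFIG (if set)
+// and installs the benchmark callbacks.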
7859+int NetTrainBase::Init() {
7860+  if (this->flags_ == nullptr) {
7861+    return 1;
7862+  }
7863+  MS_LOG(INFO) << "ModelPath = " << this->flags_->model_file_;
7864+  MS_LOG(INFO) << "InDataPath = " << this->flags_->in_data_file_;
7865+  MS_LOG(INFO) << "InDataType = " << this->flags_->in_data_type_in_;
7866+  MS_LOG(INFO) << "Epochs = " << this->flags_->epochs_;
7867+  MS_LOG(INFO) << "AccuracyThreshold = " << this->flags_->accuracy_threshold_;
7868+  MS_LOG(INFO) << "WarmUpLoopCount = " << this->flags_->warm_up_loop_count_;
7869+  MS_LOG(INFO) << "NumThreads = " << this->flags_->num_threads_;
7870+  MS_LOG(INFO) << "expectedDataFile = " << this->flags_->data_file_;
7871+  MS_LOG(INFO) << "exportDataFile = " << this->flags_->export_file_;
7872+  MS_LOG(INFO) << "enableFp16 = " << this->flags_->enable_fp16_;
7873+  MS_LOG(INFO) << "virtualBatch = " << this->flags_->virtual_batch_;
7874+
7875+  if (this->flags_->epochs_ < 0) {
7876+    MS_LOG(ERROR) << "epochs:" << this->flags_->epochs_ << " must be greater than or equal to 0";
7877+    std::cerr << "epochs:" << this->flags_->epochs_ << " must be greater than or equal to 0" << std::endl;
7878+    return RET_ERROR;
7879+  }
7880+
7881+  if (this->flags_->num_threads_ < 1) {
7882+    MS_LOG(ERROR) << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0";
7883+    std::cerr << "numThreads:" << this->flags_->num_threads_ << " must be greater than 0" << std::endl;
7884+    return RET_ERROR;
7885+  }
7886+
7887+  this->flags_->in_data_type_ = this->flags_->in_data_type_in_ == "img" ? kImage : kBinary;
7888+
7889+  if (flags_->in_data_file_.empty() && !flags_->data_file_.empty()) {
7890+    MS_LOG(ERROR) << "expectedDataFile not supported in case that inDataFile is not provided";
7891+    std::cerr << "expectedDataFile is not supported in case that inDataFile is not provided" << std::endl;
7892+    return RET_ERROR;
7893+  }
7894+
7895+  if (flags_->in_data_file_.empty() && !flags_->export_file_.empty()) {
7896+    MS_LOG(ERROR) << "exportDataFile not supported in case that inDataFile is not provided";
7897+    std::cerr << "exportDataFile is not supported in case that inDataFile is not provided" << std::endl;
7898+    return RET_ERROR;
7899+  }
7900+
7901+  if (flags_->model_file_.empty()) {
7902+    MS_LOG(ERROR) << "modelPath is required";
7903+    std::cerr << "modelPath is required" << std::endl;
7904+    return 1;
7905+  }
7906+
7907+  // get dump data output path
7908+  auto dump_cfg_path = std::getenv(dump::kConfigPath);
7909+  if (dump_cfg_path != nullptr) {
7910+    flags_->dump_tensor_data_ = true;
7911+    if (InitDumpConfigFromJson(dump_cfg_path) != RET_OK) {
7912+      MS_LOG(ERROR) << "parse dump config file failed.";
7913+      return RET_ERROR;
7914+    }
7915+  } else {
7916+    MS_LOG(INFO) << "No MINDSPORE_DUMP_CONFIG in env, don't need to dump data";
7917+  }
7918+
7919+  auto status = InitCallbackParameter();
7920+  if (status != RET_OK) {
7921+    MS_LOG(ERROR) << "Init callback Parameter failed.";
7922+    std::cerr << "Init callback Parameter failed." << std::endl;
7923+    return RET_ERROR;
7924+  }
7925+
7926+  flags_->InitResizeDimsList();
7927+  if (!flags_->resize_dims_.empty() && !flags_->input_data_list_.empty() &&
7928+      flags_->resize_dims_.size() != flags_->input_data_list_.size()) {
7929+    MS_LOG(ERROR) << "Size of input resizeDims should be equal to size of input inDataPath";
7930+    std::cerr << "Size of input resizeDims should be equal to size of input inDataPath" << std::endl;
7931+    return RET_ERROR;
7932+  }
7933+  return RET_OK;
7934+}
7935+
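+// Parses the dump configuration JSON at `path`, validates the mandatory keys and creates the
+// directory used for dumped tensor data.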
7936+int NetTrainBase::InitDumpConfigFromJson(std::string path) {
7937+  auto real_path = RealPath(path.c_str());
7938+  std::ifstream ifs(real_path);
7939+  if (!ifs.good()) {
7940+    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
7941+    return RET_ERROR;
7942+  }
7943+  if (!ifs.is_open()) {
7944+    MS_LOG(ERROR) << "file: " << real_path << " open failed";
7945+    return RET_ERROR;
7946+  }
7947+
7948+  try {
7949+    dump_cfg_json_ = nlohmann::json::parse(ifs);
7950+  } catch (const nlohmann::json::parse_error &error) {
7951+    MS_LOG(ERROR) << "parse json file failed, please check your file.";
7952+    return RET_ERROR;
7953+  }
7954+  if (dump_cfg_json_[dump::kSettings] == nullptr) {
7955+    MS_LOG(ERROR) << "\"common_dump_settings\" is required.";
7956+    return RET_ERROR;
7957+  }
7958+  if (dump_cfg_json_[dump::kSettings][dump::kMode] == nullptr) {
7959+    MS_LOG(ERROR) << "\"dump_mode\" is required.";
7960+    return RET_ERROR;
7961+  }
7962+  if (dump_cfg_json_[dump::kSettings][dump::kPath] == nullptr) {
7963+    MS_LOG(ERROR) << "\"path\" is required.";
7964+    return RET_ERROR;
7965+  }
7966+  if (dump_cfg_json_[dump::kSettings][dump::kNetName] == nullptr) {
7967+    dump_cfg_json_[dump::kSettings][dump::kNetName] = "default";
7968+  }
7969+  if (dump_cfg_json_[dump::kSettings][dump::kInputOutput] == nullptr) {
7970+    dump_cfg_json_[dump::kSettings][dump::kInputOutput] = 0;
7971+  }
7972+  if (dump_cfg_json_[dump::kSettings][dump::kKernels] != nullptr &&
7973+      !dump_cfg_json_[dump::kSettings][dump::kKernels].empty()) {
7974+    if (dump_cfg_json_[dump::kSettings][dump::kMode] == 0) {
7975+      MS_LOG(ERROR) << R"("dump_mode" should be 1 when "kernels" isn't empty.)";
7976+      return RET_ERROR;
7977+    }
7978+  }
7979+
7980+  auto abs_path = dump_cfg_json_[dump::kSettings][dump::kPath].get<std::string>();
7981+  auto net_name = dump_cfg_json_[dump::kSettings][dump::kNetName].get<std::string>();
7982+  if (abs_path.back() == '\\' || abs_path.back() == '/') {
7983+    dump_file_output_dir_ = abs_path + net_name;
7984+  } else {
7985+#ifdef _WIN32
7986+    dump_file_output_dir_ = abs_path + "\\" + net_name;
7987+#else
7988+    dump_file_output_dir_ = abs_path + "/" + net_name;
7989+#endif
7990+  }
7991+
7992+  auto status = CreateOutputDir(&dump_file_output_dir_);
7993+  if (status != RET_OK) {
7994+    MS_LOG(ERROR) << "create data output directory failed.";
7995+    return RET_ERROR;
7996+  }
7997+  return RET_OK;
7998+}
7999+
8000+NetTrainBase::~NetTrainBase() {
8001+}
8002+}  // namespace lite
8003+}  // namespace mindspore
8004diff --git a/mindspore/lite/tools/benchmark_train/net_train_base.h b/mindspore/lite/tools/benchmark_train/net_train_base.h
8005new file mode 100644
8006index 00000000..e3d5f39a
8007--- /dev/null
8008+++ b/mindspore/lite/tools/benchmark_train/net_train_base.h
8009@@ -0,0 +1,288 @@
8010+/**
8011+ * Copyright 2020-2023 Huawei Technologies Co., Ltd
8012+ *
8013+ * Licensed under the Apache License, Version 2.0 (the "License");
8014+ * you may not use this file except in compliance with the License.
8015+ * You may obtain a copy of the License at
8016+ *
8017+ * http://www.apache.org/licenses/LICENSE-2.0
8018+ *
8019+ * Unless required by applicable law or agreed to in writing, software
8020+ * distributed under the License is distributed on an "AS IS" BASIS,
8021+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8022+ * See the License for the specific language governing permissions and
8023+ * limitations under the License.
8024+ */
8025+
8026+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8027+#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8028+
8029+#include <getopt.h>
8030+#include <csignal>
8031+#include <unordered_map>
8032+#include <fstream>
8033+#include <iostream>
8034+#include <map>
8035+#include <cmath>
8036+#include <string>
8037+#include <vector>
8038+#include <memory>
8039+#include <cfloat>
8040+#include <utility>
8041+#include <algorithm>
8042+#include <nlohmann/json.hpp>
8043+#include "include/api/model.h"
8044+#include "include/api/types.h"
8045+#include "include/api/context.h"
8046+#include "include/api/cfg.h"
8047+
8048+#ifdef ENABLE_FP16
8049+#include <arm_neon.h>
8050+#endif
8051+#include "tools/common/flag_parser.h"
8052+#include "src/common/file_utils.h"
8053+#include "src/common/utils.h"
8054+
8055+#ifdef ENABLE_FP16
8056+static __attribute__((always_inline)) inline bool MS_ISNAN_FP16(float16_t var) {
8057+  volatile float16_t d = var;
8058+  return d != d;
8059+}
8060+#endif
8061+
8062+namespace mindspore::lite {
8063+enum MS_API DataType { kImage = 0, kBinary = 1 };
8064+
8065+constexpr float relativeTolerance = 1e-5;
8066+constexpr float absoluteTolerance = 1e-8;
8067+extern const std::unordered_map<int, std::string> kTypeIdMap;
8068+extern const std::unordered_map<mindspore::Format, std::string> kTensorFormatMap;
8069+
8070+constexpr const char *DELIM_SLASH = "/";
8071+constexpr const char *DELIM_COLON = ":";
8072+constexpr const char *DELIM_COMMA = ",";
8073+
8074+constexpr int RET_TOO_BIG = -9;
8075+constexpr int kFieldsToPrint = 5;
8076+constexpr int kPrintOffset = 4;
8077+constexpr int kDumpInputsAndOutputs = 0;
8078+constexpr int kDumpOutputs = 2;
8079+constexpr int kTHOUSAND = 1000;
8080+
8081+namespace dump {
8082+constexpr auto kConfigPath = "MINDSPORE_DUMP_CONFIG";
8083+constexpr auto kSettings = "common_dump_settings";
8084+constexpr auto kMode = "dump_mode";
8085+constexpr auto kPath = "path";
8086+constexpr auto kNetName = "net_name";
8087+constexpr auto kInputOutput = "input_output";
8088+constexpr auto kKernels = "kernels";
8089+}  // namespace dump
8090+
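+// Sums all `size` elements of `data` (interpreted as T) as float; used for the checksum printing.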
8091+template <typename T>
8092+float TensorSum(const void *data, int size) {
8093+  const T *typed_data = reinterpret_cast<const T *>(data);
8094+  float sum = 0.f;
8095+  for (int i = 0; i < size; i++) {
8096+    sum += static_cast<float>(typed_data[i]);
8097+  }
8098+  return sum;
8099+}
8100+
8101+class MS_API NetTrainFlags : public virtual FlagParser {
8102+ public:
8103+  NetTrainFlags() {
8104+    // common
8105+    AddFlag(&NetTrainFlags::model_file_, "modelFile", "Input model file", "");
8106+    AddFlag(&NetTrainFlags::bb_model_file_, "bbModelFile", "Backbone model for transfer session", "");
8107+    AddFlag(&NetTrainFlags::in_data_file_, "inDataFile", "Input data file, if not set, use random input", "");
8108+    // MarkPerformance
8109+    AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0);
8110+    AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false);
8111+    AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1);
8112+    AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1);
8113+    // MarkAccuracy
8114+    AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", "");
8115+    AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", "");
8116+    AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
8117+    AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false);
8118+    AddFlag(&NetTrainFlags::enable_fp16_, "enableFp16", "Enable float16", false);
8119+    AddFlag(&NetTrainFlags::loss_name_, "lossName", "loss layer name", "");
8120+    AddFlag(&NetTrainFlags::inference_file_, "inferenceFile", "MS file to export inference model", "");
8121+    AddFlag(&NetTrainFlags::virtual_batch_, "virtualBatch", "use virtual batch", false);
8122+    AddFlag(&NetTrainFlags::resize_dims_in_, "inputShapes",
8123+            "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
8124+    AddFlag(&NetTrainFlags::unified_api_, "unifiedApi", "do unified api test", false);
8125+  }
8126+
8127+  ~NetTrainFlags() override = default;
8128+  void InitResizeDimsList();
8129+
8130+ public:
8131+  // common
8132+  std::string model_file_;
8133+  std::string in_data_file_;
8134+  std::string bb_model_file_;
8135+  std::vector<std::string> input_data_list_;
8136+  DataType in_data_type_;
8137+  std::string in_data_type_in_ = "bin";
8138+  int cpu_bind_mode_ = 1;
8139+  bool enable_fp16_ = false;
8140+  bool virtual_batch_ = false;
8141+  // MarkPerformance
8142+  int num_threads_ = 1;
8143+  int warm_up_loop_count_ = 0;
8144+  bool time_profiling_;
8145+  int epochs_ = 1;
8146+  // MarkAccuracy
8147+  std::string data_file_;
8148+  std::string data_type_ = "FLOAT";
8149+  float accuracy_threshold_;
8150+  // Resize
8151+  std::string export_file_ = "";
8152+  std::string resize_dims_in_ = "";
8153+  bool layer_checksum_ = false;
8154+  std::vector<std::vector<int>> resize_dims_;
8155+  std::string loss_name_ = "";
8156+  std::string inference_file_ = "";
8157+  bool unified_api_ = false;
8158+  bool dump_tensor_data_ = false;
8159+};
8160+
8161+class MS_API NetTrainBase {
8162+ public:
8163+  explicit NetTrainBase(NetTrainFlags *flags) : flags_(flags) {}
8164+  virtual ~NetTrainBase();
8165+
8166+  int Init();
8167+  int RunNetTrain();
8168+  static float *ReadFileBuf(const std::string file, size_t *size);
8169+  static int SetNr(std::function<int(NetTrainFlags *)> param);
8170+  static int RunNr(NetTrainFlags *flags) {
8171+    if (nr_cb_ != nullptr) {
8172+      return nr_cb_(flags);
8173+    }
8174+    MS_LOG(WARNING) << "unified api was not tested";
8175+    std::cout << "unified api was not tested" << std::endl;
8176+    return RET_OK;
8177+  }
8178+  // tensorData needs to be converted first
8179+  template <typename T>
8180+  static float CompareData(const float *refOutput, int size, const T *msTensorData) {
8181+    size_t errorCount = 0;
8182+    float meanError = 0;
8183+    std::cout << "Out tensor size is: " << size << std::endl;
8184+    std::cout << "Data of model output: ";
8185+    for (int j = 0; j < std::min(50, size); j++) {
8186+      std::cout << static_cast<float>(msTensorData[j]) << " ";
8187+    }
8188+    std::cout << std::endl;
8189+    std::cout << "Data of Ref output  : ";
8190+    for (int j = 0; j < std::min(50, size); j++) {
8191+      std::cout << refOutput[j] << " ";
8192+    }
8193+    std::cout << std::endl;
8194+    for (int j = 0; j < size; j++) {
8195+      if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) {
8196+        std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl;
8197+        MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
8198+        return RET_ERROR;
8199+      }
8200+
8201+      auto tolerance = absoluteTolerance + relativeTolerance * fabs(refOutput[j]);
8202+      auto absoluteError = std::fabs(static_cast<float>(msTensorData[j]) - refOutput[j]);
8203+      if (absoluteError > tolerance) {
8204+        if (fabs(refOutput[j]) == 0) {
8205+          if (absoluteError > 1e-5) {
8206+            meanError += absoluteError;
8207+            errorCount++;
8208+          } else {
8209+            continue;
8210+          }
8211+        } else {
8212+          // just assume that atol = rtol
8213+          meanError += absoluteError / (fabs(refOutput[j]) + FLT_MIN);
8214+          errorCount++;
8215+        }
8216+      }
8217+    }
8218+    std::cout << std::endl;
8219+    if (meanError > 0.0f) {
8220+      meanError /= errorCount;
8221+    }
8222+
8223+    if (meanError <= 0.0000001) {
8224+      std::cout << "Mean bias of tensor: 0%" << std::endl;
8225+    } else {
8226+      std::cout << "Mean bias of tensor: " << meanError * 100 << "%" << std::endl;
8227+    }
8228+    return meanError;
8229+  }
8230+  int InitDumpConfigFromJson(std::string path);
8231+
8232+ protected:
8233+  // call GenerateInputData or ReadInputFile to init inputTensors
8234+  int LoadInput();
8235+  void CheckSum(MSTensor *tensor, const std::string &node_type, int id, const std::string &in_out);
8236+  // call GenerateRandomData to fill inputTensors
8237+  virtual int GenerateInputData() = 0;
8238+
8239+  int GenerateRandomData(mindspore::MSTensor *tensor);
8240+
8241+  std::string GenerateOutputFileName(mindspore::MSTensor *tensor, const std::string &op_name,
8242+                                     const std::string &file_type, const size_t &idx);
8243+  virtual int ReadInputFile() = 0;
8244+
8245+  virtual int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train,
8246+                                  int epochs, bool check_accuracy = true) = 0;
8247+
8248+  int InitCallbackParameter();
8249+
8250+  virtual int InitDumpTensorDataCallbackParameter() = 0;
8251+
8252+  virtual int InitTimeProfilingCallbackParameter() = 0;
8253+
8254+  virtual int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) = 0;
8255+
8256+  template <typename T>
8257+  std::vector<int64_t> ConverterToInt64Vector(const std::vector<T> &srcDims) {
8258+    std::vector<int64_t> dims;
8259+    for (auto shape : srcDims) {
8260+      dims.push_back(static_cast<int64_t>(shape));
8261+    }
8262+    return dims;
8263+  }
8264+  virtual int MarkPerformance() = 0;
8265+  virtual int MarkAccuracy(bool enforce_accuracy = true) = 0;
8266+  virtual int CompareOutput() = 0;
8267+  virtual int SaveModels() = 0;
8268+  int CheckExecutionOfSavedModels();
8269+  void TensorNan(const float *data, int size) {
8270+    for (int i = 0; i < size; i++) {
8271+      if (std::isnan(data[i])) {
8272+        std::cout << "nan value of index=" << i << ", " << data[i] << std::endl;
8273+        break;
8274+      }
8275+    }
8276+  }
8277+#ifdef ENABLE_FP16
8278+  void TensorNan(float16_t *data, int size) {
8279+    for (int i = 0; i < size; i++) {
8280+      if (MS_ISNAN_FP16(data[i]) || std::isinf(data[i])) {
8281+        std::cout << "nan or inf value of index=" << i << ", " << data[i] << std::endl;
8282+        break;
8283+      }
8284+    }
8285+  }
8286+#endif
8287+  NetTrainFlags *flags_{nullptr};
8288+  static std::function<int(NetTrainFlags *)> nr_cb_;
8289+
8290+  nlohmann::json dump_cfg_json_;
8291+  std::string dump_file_output_dir_;
8292+  std::vector<std::shared_ptr<char>> inputs_buf_;
8293+  std::vector<size_t> inputs_size_;
8294+  size_t batch_num_ = 0;
8295+};
8296+}  // namespace mindspore::lite
8297+#endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_BASE_H_
8298diff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.cc b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc
8299new file mode 100644
8300index 00000000..4dcf3af6
8301--- /dev/null
8302+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.cc
8303@@ -0,0 +1,659 @@
8304+/**
8305+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
8306+ *
8307+ * Licensed under the Apache License, Version 2.0 (the "License");
8308+ * you may not use this file except in compliance with the License.
8309+ * You may obtain a copy of the License at
8310+ *
8311+ * http://www.apache.org/licenses/LICENSE-2.0
8312+ *
8313+ * Unless required by applicable law or agreed to in writing, software
8314+ * distributed under the License is distributed on an "AS IS" BASIS,
8315+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8316+ * See the License for the specific language governing permissions and
8317+ * limitations under the License.
8318+ */
8319+
8320+#include "net_train_c_api.h"
8321+#include "securec/include/securec.h"
8322+
8323+namespace mindspore {
8324+namespace lite {
8325+uint64_t g_op_begin_ = 0;
8326+int g_op_call_times_total_ = 0;
8327+float g_op_cost_total_ = 0.0f;
8328+
8329+int NetTrainCApi::GenerateInputData() {
8330+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8331+    OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i];
8332+    auto data_type = OH_AI_TensorGetDataType(tensor);
8333+    if (data_type == OH_AI_DATATYPE_OBJECTTYPE_STRING) {
8334+      MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING";
8335+      return RET_ERROR;
8336+    } else {
8337+      (void)GenerateRandomData(static_cast<mindspore::MSTensor *>(tensor));
8338+    }
8339+  }
8340+  return RET_OK;
8341+}
8342+
8343+int NetTrainCApi::SaveModels() {
8344+  if (!flags_->export_file_.empty()) {
8345+    if (flags_->bb_model_file_.empty()) {
8346+      auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_ + "_qt").c_str(),
8347+                                      OH_AI_WEIGHT_QUANT, false, nullptr, 0);
8348+      if (status != OH_AI_STATUS_SUCCESS) {
8349+        MS_LOG(ERROR) << "Export quantized model error " << flags_->export_file_ + "_qt";
8350+        std::cout << "Export quantized model error " << flags_->export_file_ + "_qt" << std::endl;
8351+        return RET_ERROR;
8352+      }
8353+    }
8354+    auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->export_file_).c_str(), OH_AI_NO_QUANT, false,
8355+                                    nullptr, 0);
8356+
8357+    if (status != OH_AI_STATUS_SUCCESS) {
8358+      MS_LOG(ERROR) << "Export non quantized model error " << flags_->export_file_;
8359+      std::cout << "Export non quantized model error " << flags_->export_file_ << std::endl;
8360+      return RET_ERROR;
8361+    }
8362+  }
8363+  if (!flags_->inference_file_.empty()) {
8364+    auto status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_ + "_qt").c_str(),
8365+                                    OH_AI_WEIGHT_QUANT, true, nullptr, 0);
8366+    if (status != OH_AI_STATUS_SUCCESS) {
8367+      MS_LOG(ERROR) << "Export quantized inference model error " << flags_->inference_file_ + "_qt";
8368+      std::cout << "Export quantized inference model error " << flags_->inference_file_ + "_qt" << std::endl;
8369+      return RET_ERROR;
8370+    }
8371+
8372+    auto tick = GetTimeUs();
8373+    status = OH_AI_ExportModel(ms_model_, OH_AI_MODELTYPE_MINDIR, (flags_->inference_file_).c_str(), OH_AI_NO_QUANT,
8374+                               true, nullptr, 0);
8375+    if (status != OH_AI_STATUS_SUCCESS) {
8376+      MS_LOG(ERROR) << "Export non quantized inference model error " << flags_->inference_file_;
8377+      std::cout << "Export non quantized inference model error " << flags_->inference_file_ << std::endl;
8378+      return RET_ERROR;
8379+    }
8380+    std::cout << "ExportInference() execution time is " << GetTimeUs() - tick << "us\n";
8381+  }
8382+  return RET_OK;
8383+}
8384+
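+// Copies the `step`-th batch slice from the cached input buffers into the model input tensors.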
8385+int NetTrainCApi::LoadStepInput(size_t step) {
8386+  if (step >= batch_num_) {
8387+    auto cur_batch = step + 1;
8388+    MS_LOG(ERROR) << "Max input Batch is: " << batch_num_ << " but got batch: " << cur_batch;
8389+    return RET_ERROR;
8390+  }
8391+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8392+    OH_AI_TensorHandle cur_tensor = ms_inputs_for_api_.handle_list[i];
8393+    MS_ASSERT(cur_tensor != nullptr);
8394+    auto tensor_data_size = OH_AI_TensorGetDataSize(cur_tensor);
8395+    auto input_data = OH_AI_TensorGetMutableData(cur_tensor);
8396+    MS_ASSERT(input_data != nullptr);
8397+    memcpy_s(input_data, tensor_data_size, inputs_buf_[i].get() + step * tensor_data_size, tensor_data_size);
8398+  }
8399+  return RET_OK;
8400+}
8401+
8402+int NetTrainCApi::ReadInputFile() {
8403+  if (this->flags_->in_data_type_ == lite::kImage) {
8404+    MS_LOG(ERROR) << "Unsupported image input";
8405+    return RET_ERROR;
8406+  } else {
8407+    for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8408+      OH_AI_TensorHandle tensor = ms_inputs_for_api_.handle_list[i];
8409+      MS_ASSERT(tensor != nullptr);
8410+      size_t size;
8411+      std::string file_name = flags_->in_data_file_ + std::to_string(i + 1) + ".bin";
8412+      auto bin_buf = lite::ReadFile(file_name.c_str(), &size);
8413+      if (bin_buf == nullptr) {
8414+        MS_LOG(ERROR) << "ReadFile failed";
8415+        return RET_ERROR;
8416+      }
8417+      auto tensor_data_size = OH_AI_TensorGetDataSize(tensor);
8418+      MS_ASSERT(tensor_data_size != 0);
8419+      if (size == 0 || size % tensor_data_size != 0 || (batch_num_ != 0 && size / tensor_data_size != batch_num_)) {
8420+        std::cerr << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size
8421+                  << ", file_name: " << file_name.c_str() << std::endl;
8422+        MS_LOG(ERROR) << "Input binary file size error, required: N * " << tensor_data_size << ", in fact: " << size
8423+                      << ", file_name: " << file_name.c_str();
8424+        delete[] bin_buf;
8425+        return RET_ERROR;
8426+      }
8427+      inputs_buf_.emplace_back(bin_buf);
8428+      inputs_size_.emplace_back(size);
8429+      batch_num_ = size / tensor_data_size;
8430+    }
8431+  }
8432+  return RET_OK;
8433+}
8434+
8435+int NetTrainCApi::InitDumpTensorDataCallbackParameter() {
8436+  MS_LOG(ERROR) << "Unsupported feature.";
8437+  return RET_ERROR;
8438+}
8439+
8440+int NetTrainCApi::InitTimeProfilingCallbackParameter() {
8441+  before_call_back_ = TimeProfilingBeforeCallback;
8442+  after_call_back_ = TimeProfilingAfterCallback;
8443+  return RET_OK;
8444+}
8445+
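+// Creates the C API context with the configured thread number and affinity mode, and adds a CPU
+// device info (optionally with float16 enabled).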
8446+int NetTrainCApi::InitMSContext() {
8447+  context_ = OH_AI_ContextCreate();
8448+  if (context_ == nullptr) {
8449+    MS_LOG(INFO) << "OH_AI_ContextCreate failed";
8450+    return RET_ERROR;
8451+  }
8452+  OH_AI_ContextSetThreadNum(context_, flags_->num_threads_);
8453+  OH_AI_ContextSetThreadAffinityMode(context_, flags_->cpu_bind_mode_);
8454+
8455+  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
8456+  OH_AI_DeviceInfoSetEnableFP16(cpu_device_info, flags_->enable_fp16_);
8457+  OH_AI_ContextAddDeviceInfo(context_, cpu_device_info);
8458+  return RET_OK;
8459+}
8460+
8461+char **NetTrainCApi::TransStrVectorToCharArrays(const std::vector<std::string> &s) {
8462+  char **char_arr = static_cast<char **>(malloc(s.size() * sizeof(char *)));
8463+  for (size_t i = 0; i < s.size(); i++) {
8464+    char_arr[i] = static_cast<char *>(malloc((s[i].size() + 1)));
8465+    strcpy(char_arr[i], s[i].c_str());
8466+  }
8467+  return char_arr;
8468+}
8469+
8470+std::vector<std::string> NetTrainCApi::TransCharArraysToStrVector(char **c, const size_t &num) {
8471+  std::vector<std::string> str;
8472+  for (size_t i = 0; i < num; i++) {
8473+    str.push_back(std::string(c[i]));
8474+  }
8475+  return str;
8476+}
8477+
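+// Splits the comma-separated lossName flag and registers each loss node name in the C API train config.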
8478+void NetTrainCApi::InitTrainCfg() {
8479+  if (flags_->loss_name_.empty()) {
8480+    return;
8481+  }
8482+
8483+  std::string delimiter = ",";
8484+  size_t pos = 0;
8485+  std::string token;
8486+  train_cfg_ = OH_AI_TrainCfgCreate();
8487+  size_t num = 0;
8488+  std::vector<std::string> train_cfg_loss_name;
8489+  OH_AI_TrainCfgSetLossName(train_cfg_, nullptr, train_cfg_loss_name.size());
8490+  while ((pos = flags_->loss_name_.find(delimiter)) != std::string::npos) {
8491+    token = flags_->loss_name_.substr(0, pos);
8492+    flags_->loss_name_.erase(0, pos + delimiter.length());  // change to delim without deletion
8493+    char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num);
8494+    train_cfg_loss_name = TransCharArraysToStrVector(name, num);
8495+    train_cfg_loss_name.push_back(token);
8496+    char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name);
8497+    OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size());
8498+    for (size_t i = 0; i < train_cfg_loss_name.size(); i++) {
8499+      free(loss_name[i]);
8500+    }
8501+    free(loss_name);
8502+    for (size_t i = 0; i < num; i++) {
8503+      free(name[i]);
8504+    }
8505+    free(name);
8506+  }
8507+  if (!(flags_->loss_name_.empty())) {
8508+    char **name = OH_AI_TrainCfgGetLossName(train_cfg_, &num);
8509+    train_cfg_loss_name = TransCharArraysToStrVector(name, num);
8510+    train_cfg_loss_name.push_back(flags_->loss_name_);
8511+    char **loss_name = TransStrVectorToCharArrays(train_cfg_loss_name);
8512+    OH_AI_TrainCfgSetLossName(train_cfg_, const_cast<const char **>(loss_name), train_cfg_loss_name.size());
8513+    for (size_t i = 0; i < train_cfg_loss_name.size(); i++) {
8514+      free(loss_name[i]);
8515+    }
8516+    free(loss_name);
8517+    for (size_t i = 0; i < num; i++) {
8518+      free(name[i]);
8519+    }
8520+    free(name);
8521+  }
8522+}
8523+
8524+int NetTrainCApi::CreateAndRunNetworkForInference(const std::string &filename,
8525+                                              const OH_AI_ContextHandle &context) {
8526+  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
8527+  std::string filenamems = filename;
8528+  if (filenamems.substr(filenamems.find_last_of('.') + 1) != "ms") {
8529+    filenamems = filenamems + ".ms";
8530+  }
8531+  MS_LOG(INFO) << "start reading model file " << filenamems.c_str();
8532+  std::cout << "start reading model file " << filenamems.c_str() << std::endl;
8533+  auto status = OH_AI_ModelBuildFromFile(ms_model_, filenamems.c_str(),
8534+                                         static_cast<OH_AI_ModelType>(mindspore::kMindIR), context);
8535+  if (status != OH_AI_STATUS_SUCCESS) {
8536+    MS_LOG(ERROR) << "ms model build failed. " << model_name;
8537+    return RET_ERROR;
8538+  }
8539+  return RET_OK;
8540+}
8541+
8542+int NetTrainCApi::CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
8543+                                          const OH_AI_ContextHandle &context,
8544+                                          const OH_AI_TrainCfgHandle &train_cfg, int epochs) {
8545+  std::string model_name = filename.substr(filename.find_last_of(DELIM_SLASH) + 1);
8546+  OH_AI_Status status;
8547+  if (!bb_filename.empty()) {
8548+    MS_LOG(ERROR) << "build transfer learning not supported. " << model_name;
8549+    return RET_ERROR;
8550+  } else {
8551+    MS_LOG(INFO) << "Build mindspore model from model file " << filename.c_str();
8552+    std::cout << "Build mindspore model from model file " << filename.c_str() << std::endl;
8553+    status = OH_AI_TrainModelBuildFromFile(ms_model_, filename.c_str(), OH_AI_MODELTYPE_MINDIR, context, train_cfg);
8554+    if (status != OH_AI_STATUS_SUCCESS) {
8555+      MS_LOG(ERROR) << "build train model failed. " << model_name;
8556+      return RET_ERROR;
8557+    }
8558+  }
8559+  if (epochs > 0) {
8560+    if (flags_->virtual_batch_) {
8561+      OH_AI_ModelSetupVirtualBatch(ms_model_, epochs, -1.0f, -1.0f);
8562+    }
8563+    status = OH_AI_ModelSetTrainMode(ms_model_, true);
8564+    if (status != OH_AI_STATUS_SUCCESS) {
8565+      MS_LOG(ERROR) << "set train mode failed. ";
8566+      return RET_ERROR;
8567+    }
8568+  }
8569+  return RET_OK;
8570+}
8571+
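+// Compares every model output tensor against the reference data in "<expectedDataFile><i>.bin" and
+// returns RET_TOO_BIG when the mean bias exceeds the accuracy threshold.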
8572+int NetTrainCApi::CompareOutput() {
8573+  std::cout << "================ Comparing Forward Output data ================" << std::endl;
8574+  float total_bias = 0;
8575+  int total_size = 0;
8576+  bool has_error = false;
8577+  auto output_tensors_handle = OH_AI_ModelGetOutputs(ms_model_);
8578+
8579+  std::vector<mindspore::MSTensor> output_tensors;
8580+  for (size_t i = 0; i < output_tensors_handle.handle_num; i++) {
8581+    output_tensors.push_back(*static_cast<mindspore::MSTensor *>(output_tensors_handle.handle_list[i]));
8582+  }
8583+  if (output_tensors.empty()) {
8584+    MS_LOG(ERROR) << "Cannot find output tensors, get model output failed";
8585+    return RET_ERROR;
8586+  }
8587+  std::map<std::string, MSTensor> ordered_outputs;
8588+  for (const auto &output_tensor : output_tensors) {
8589+    ordered_outputs.insert({output_tensor.Name(), output_tensor});
8590+  }
8591+  int i = 1;
8592+  mindspore::MSTensor tensor;
8593+  for (auto &ordered_output : ordered_outputs) {
8594+    tensor = ordered_output.second;
8595+    std::cout << "output is tensor " << ordered_output.first << "\n";
8596+    auto outputs = tensor.MutableData();
8597+    size_t size;
8598+    std::string output_file = flags_->data_file_ + std::to_string(i) + ".bin";
8599+    auto bin_buf = std::unique_ptr<float[]>(ReadFileBuf(output_file.c_str(), &size));
8600+    if (bin_buf == nullptr) {
8601+      MS_LOG(ERROR) << "ReadFile return nullptr";
8602+      std::cout << "ReadFile return nullptr" << std::endl;
8603+      return RET_ERROR;
8604+    }
8605+    if (size != tensor.DataSize()) {
8606+      MS_LOG(ERROR) << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
8607+                    << ", read size: " << size;
8608+      std::cout << "Output buffer and output file differ by size. Tensor size: " << tensor.DataSize()
8609+                << ", read size: " << size << std::endl;
8610+      return RET_ERROR;
8611+    }
8612+    float bias = CompareData<float>(bin_buf.get(), tensor.ElementNum(), reinterpret_cast<float *>(outputs));
8613+    if (bias >= 0) {
8614+      total_bias += bias;
8615+      total_size++;
8616+    } else {
8617+      has_error = true;
8618+      break;
8619+    }
8620+    i++;
8621+  }
8622+
8623+  if (!has_error) {
8624+    float mean_bias;
8625+    if (total_size != 0) {
8626+      mean_bias = total_bias / total_size * 100;
8627+    } else {
8628+      mean_bias = 0;
8629+    }
8630+
8631+    std::cout << "Mean bias of all nodes/tensors: " << mean_bias << "%"
8632+              << " threshold is:" << this->flags_->accuracy_threshold_ << std::endl;
8633+    std::cout << "=======================================================" << std::endl << std::endl;
8634+
8635+    if (mean_bias > this->flags_->accuracy_threshold_) {
8636+      MS_LOG(INFO) << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%";
8637+      std::cout << "Mean bias of all nodes/tensors is too big: " << mean_bias << "%" << std::endl;
8638+      return RET_TOO_BIG;
8639+    } else {
8640+      return RET_OK;
8641+    }
8642+  } else {
8643+    MS_LOG(ERROR) << "Error in CompareData";
8644+    std::cerr << "Error in CompareData" << std::endl;
8645+    std::cout << "=======================================================" << std::endl << std::endl;
8646+    return RET_ERROR;
8647+  }
8648+}
8649+
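+// Runs the configured number of training epochs, measuring min/max/average epoch time and, when
+// time profiling is enabled, per-op statistics.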
8650+int NetTrainCApi::MarkPerformance() {
8651+  MS_LOG(INFO) << "Running train loops...";
8652+  std::cout << "Running train loops..." << std::endl;
8653+  uint64_t time_min = 0xFFFFFFFFFFFFFFFF;
8654+  uint64_t time_max = 0;
8655+  uint64_t time_avg = 0;
8656+  std::vector<MSTensor> outputs;
8657+
8658+  for (int i = 0; i < flags_->epochs_; i++) {
8659+    auto start = GetTimeUs();
8660+    for (size_t step = 0; step < batch_num_; step++) {
8661+      MS_LOG(INFO) << "Run for epoch:" << i << " step:" << step;
8662+      auto ret = LoadStepInput(step);
8663+      if (ret != RET_OK) {
8664+        return ret;
8665+      }
8666+      auto status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_);
8667+      if (status != OH_AI_STATUS_SUCCESS) {
8668+        MS_LOG(ERROR) << "Inference error " << status;
8669+        std::cerr << "Inference error " << status;
8670+        return RET_ERROR;
8671+      }
8672+    }
8673+
8674+    auto end = GetTimeUs();
8675+    auto time = end - start;
8676+    time_min = std::min(time_min, time);
8677+    time_max = std::max(time_max, time);
8678+    time_avg += time;
8679+  }
8680+
8681+  if (flags_->time_profiling_) {
8682+    const std::vector<std::string> per_op_name = {"opName", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
8683+    const std::vector<std::string> per_op_type = {"opType", "avg(ms)", "percent", "calledTimes", "opTotalTime"};
8684+    PrintResult(per_op_name, g_c_op_times_by_name_);
8685+    PrintResult(per_op_type, g_c_op_times_by_type_);
8686+  }
8687+
8688+  if (flags_->epochs_ > 0) {
8689+    time_avg /= static_cast<size_t>(flags_->epochs_);
8690+    MS_LOG(INFO) << "Model = " << flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str()
8691+                 << ", NumThreads = " << flags_->num_threads_ << ", MinRunTime = " << time_min / 1000.0f
8692+                 << ", MaxRuntime = " << time_max / 1000.0f << ", AvgRunTime = " << time_avg / 1000.0f;
8693+    printf("Model = %s, NumThreads = %d, MinRunTime = %f ms, MaxRuntime = %f ms, AvgRunTime = %f ms\n",
8694+           flags_->model_file_.substr(flags_->model_file_.find_last_of(DELIM_SLASH) + 1).c_str(), flags_->num_threads_,
8695+           time_min / 1000.0f, time_max / 1000.0f, time_avg / 1000.0f);
8696+  }
8697+  return RET_OK;
8698+}
8699+
8700+int NetTrainCApi::MarkAccuracy(bool enforce_accuracy) {
8701+  MS_LOG(INFO) << "MarkAccuracy";
8702+  auto load_ret = LoadStepInput(0);
8703+  if (load_ret != RET_OK) {
8704+    return load_ret;
8705+  }
8706+  auto status = PrintInputData();
8707+  if (status != RET_OK) {
8708+    MS_LOG(ERROR) << "PrintInputData failed, ret: " << status;
8709+    return status;
8710+  }
8711+  status = OH_AI_RunStep(ms_model_, before_call_back_, after_call_back_);
8712+  if (status != OH_AI_STATUS_SUCCESS) {
8713+    MS_LOG(ERROR) << "Inference error " << status;
8714+    std::cerr << "Inference error " << status << std::endl;
8715+    return RET_ERROR;
8716+  }
8717+
8718+  auto ret = CompareOutput();
8719+  if (ret == RET_TOO_BIG && !enforce_accuracy) {
8720+    MS_LOG(INFO) << "Accuracy Error is big but not enforced";
8721+    std::cout << "Accuracy Error is big but not enforced" << std::endl;
8722+    return RET_OK;
8723+  }
8724+
8725+  if (ret != RET_OK) {
8726+    MS_LOG(ERROR) << "Compare output error " << ret;
8727+    std::cerr << "Compare output error " << ret << std::endl;
8728+    return ret;
8729+  }
8730+  return RET_OK;
8731+}
8732+
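+// End-to-end driver: builds the model (train or inference), resizes inputs if requested, loads the
+// input data, then runs MarkPerformance and/or MarkAccuracy according to the flags.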
8733+int NetTrainCApi::CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train,
8734+                                  int epochs, bool check_accuracy) {
8735+  auto start_prepare_time = GetTimeUs();
8736+
8737+  int ret = InitMSContext();
8738+  if (ret != RET_OK) {
8739+    MS_LOG(ERROR) << "InitContext failed, ret: " << ret;
8740+    return ret;
8741+  }
8742+
8743+  InitTrainCfg();
8744+  ms_model_ = OH_AI_ModelCreate();
8745+
8746+  if (is_train) {
8747+    ret = CreateAndRunNetworkForTrain(filename, bb_filename, context_, train_cfg_, epochs);
8748+    if (ret != RET_OK) {
8749+      MS_LOG(ERROR) << "CreateAndRunNetworkForTrain failed.";
8750+      return RET_ERROR;
8751+    }
8752+  } else {
8753+    ret = CreateAndRunNetworkForInference(filename, context_);
8754+    if (ret != RET_OK) {
8755+      MS_LOG(ERROR) << "CreateAndRunNetworkForInference failed.";
8756+      return RET_ERROR;
8757+    }
8758+  }
8759+
8760+  ms_inputs_for_api_ = OH_AI_ModelGetInputs(ms_model_);
8761+  if (ms_inputs_for_api_.handle_list == nullptr) {
8762+    MS_LOG(ERROR) << "OH_AI_ModelGetInputs failed.";
8763+    return RET_ERROR;
8764+  }
8765+
8766+  if (!flags_->resize_dims_.empty()) {
8767+    std::vector<OH_AI_ShapeInfo> shape_infos;
8768+    std::transform(flags_->resize_dims_.begin(), flags_->resize_dims_.end(), std::back_inserter(shape_infos),
8769+                   [&](auto &shapes) {
8770+                     OH_AI_ShapeInfo shape_info;
8771+                     shape_info.shape_num = shapes.size();
8772+                     for (size_t i = 0; i < shape_info.shape_num; i++) {
8773+                       shape_info.shape[i] = shapes[i];
8774+                     }
8775+                     return shape_info;
8776+                   });
8777+    auto status = OH_AI_ModelResize(ms_model_, ms_inputs_for_api_, shape_infos.data(), shape_infos.size());
8778+    if (status != OH_AI_STATUS_SUCCESS) {
8779+      MS_LOG(ERROR) << "Input tensor resize failed.";
8780+      std::cout << "Input tensor resize failed.";
8781+      return RET_ERROR;
8782+    }
8783+  }
8784+
8785+  auto end_prepare_time = GetTimeUs();
8786+  MS_LOG(INFO) << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms";
8787+  std::cout << "PrepareTime = " << ((end_prepare_time - start_prepare_time) / kTHOUSAND) << " ms" << std::endl;
8788+  // Load input
8789+  MS_LOG(INFO) << "Load input data";
8790+  auto status = LoadInput();
8791+  if (status != RET_OK) {
8792+    MS_LOG(ERROR) << "Load input data error";
8793+    std::cout << "Load input data error" << std::endl;
8794+    return status;
8795+  }
8796+
8797+  if ((epochs > 0) && is_train) {
8798+    status = MarkPerformance();
8799+    if (status != RET_OK) {
8800+      MS_LOG(ERROR) << "Run MarkPerformance error: " << status;
8801+      std::cout << "Run MarkPerformance error: " << status << std::endl;
8802+      return status;
8803+    }
8804+    SaveModels();  // save file if flags are on
8805+  }
8806+  if (!flags_->data_file_.empty()) {
8807+    auto res = OH_AI_ModelSetTrainMode(ms_model_, false);
8808+    if (res != OH_AI_STATUS_SUCCESS) {
8809+      MS_LOG(ERROR) << "set eval mode failed. ";
8810+      return RET_ERROR;
8811+    }
8812+
8813+    status = MarkAccuracy(check_accuracy);
8814+    if (status != RET_OK) {
8815+      MS_LOG(ERROR) << "Run MarkAccuracy error: " << status;
8816+      std::cout << "Run MarkAccuracy error: " << status << std::endl;
8817+      return status;
8818+    }
8819+  }
8820+  return RET_OK;
8821+}
8822+
8823+int NetTrainCApi::PrintInputData() {
8824+  constexpr int64_t kPrintDataNum = 20;
8825+  for (size_t i = 0; i < ms_inputs_for_api_.handle_num; i++) {
8826+    auto input = ms_inputs_for_api_.handle_list[i];
8827+    std::cout << "InData" << i << ": ";
8828+    auto data_type = static_cast<TypeId>(OH_AI_TensorGetDataType(input));
8829+    if (data_type == TypeId::kObjectTypeString) {
8830+      MS_LOG(ERROR) << "Unsupported OH_AI_DATATYPE_OBJECTTYPE_STRING.";
8831+      return RET_ERROR;
8832+    }
8833+    auto tensor_data = OH_AI_TensorGetData(input);
8834+    size_t print_num = std::min(OH_AI_TensorGetElementNum(input), kPrintDataNum);
8835+    for (size_t j = 0; j < print_num; j++) {
8836+      if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat) {
8837+        std::cout << static_cast<const float *>(tensor_data)[j] << " ";
8838+      } else if (data_type == TypeId::kNumberTypeInt8) {
8839+        std::cout << static_cast<const int8_t *>(tensor_data)[j] << " ";
8840+      } else if (data_type == TypeId::kNumberTypeUInt8) {
8841+        std::cout << static_cast<const uint8_t *>(tensor_data)[j] << " ";
8842+      } else if (data_type == TypeId::kNumberTypeInt32) {
8843+        std::cout << static_cast<const int32_t *>(tensor_data)[j] << " ";
8844+      } else if (data_type == TypeId::kNumberTypeInt64) {
8845+        std::cout << static_cast<const int64_t *>(tensor_data)[j] << " ";
8846+      } else if (data_type == TypeId::kNumberTypeBool) {
8847+        std::cout << static_cast<const bool *>(tensor_data)[j] << " ";
8848+      } else {
8849+        MS_LOG(ERROR) << "Datatype: " << data_type << " is not supported.";
8850+        return RET_ERROR;
8851+      }
8852+    }
8853+    std::cout << std::endl;
8854+  }
8855+  return RET_OK;
8856+}
8857+
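+// Prints the per-op profiling statistics (call count and accumulated time) as an aligned table.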
8858+int NetTrainCApi::PrintResult(const std::vector<std::string> &title,
8859+                          const std::map<std::string, std::pair<int, float>> &result) {
8860+  std::vector<size_t> columnLenMax(kFieldsToPrint);
8861+  std::vector<std::vector<std::string>> rows;
8862+
8863+  for (auto &iter : result) {
8864+    std::string stringBuf[kFieldsToPrint];
8865+    std::vector<std::string> columns;
8866+    size_t len = 0;
8867+    int index = 0;
8868+    len = iter.first.size();
8869+    if (len > columnLenMax.at(index)) {
8870+      columnLenMax.at(index) = len + kPrintOffset;
8871+    }
8872+    columns.push_back(iter.first);
8873+
8874+    index++;
8875+    if (title[0] == "opName") {
8876+      stringBuf[index] = std::to_string(iter.second.second / flags_->epochs_);
8877+    } else {
8878+      stringBuf[index] = std::to_string(iter.second.second / iter.second.first);
8879+    }
8880+    len = stringBuf[index].length();
8881+    if (len > columnLenMax.at(index)) {
8882+      columnLenMax.at(index) = len + kPrintOffset;
8883+    }
8884+    columns.emplace_back(stringBuf[index]);
8885+
8886+    index++;
8887+    stringBuf[index] = std::to_string(iter.second.second / g_op_cost_total_);
8888+    len = stringBuf[index].length();
8889+    if (len > columnLenMax.at(index)) {
8890+      columnLenMax.at(index) = len + kPrintOffset;
8891+    }
8892+    columns.emplace_back(stringBuf[index]);
8893+
8894+    index++;
8895+    stringBuf[index] = std::to_string(iter.second.first);
8896+    len = stringBuf[index].length();
8897+    if (len > columnLenMax.at(index)) {
8898+      columnLenMax.at(index) = len + kPrintOffset;
8899+    }
8900+    columns.emplace_back(stringBuf[index]);
8901+
8902+    index++;
8903+    stringBuf[index] = std::to_string(iter.second.second);
8904+    len = stringBuf[index].length();
8905+    if (len > columnLenMax.at(index)) {
8906+      columnLenMax.at(index) = len + kPrintOffset;
8907+    }
8908+    columns.emplace_back(stringBuf[index]);
8909+
8910+    rows.push_back(columns);
8911+  }
8912+
8913+  printf("-------------------------------------------------------------------------\n");
8914+  for (int i = 0; i < kFieldsToPrint; i++) {
8915+    auto printBuf = title[i];
8916+    if (printBuf.size() > columnLenMax.at(i)) {
8917+      columnLenMax.at(i) = printBuf.size();
8918+    }
8919+    printBuf.resize(columnLenMax.at(i), ' ');
8920+    printf("%s\t", printBuf.c_str());
8921+  }
8922+  printf("\n");
8923+  for (auto &row : rows) {
8924+    for (int j = 0; j < kFieldsToPrint; j++) {
8925+      auto printBuf = row[j];
8926+      printBuf.resize(columnLenMax.at(j), ' ');
8927+      printf("%s\t", printBuf.c_str());
8928+    }
8929+    printf("\n");
8930+  }
8931+  return RET_OK;
8932+}
8933+
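+// The two callbacks below accumulate per-operator timing for the C-API benchmark:
+// g_c_op_times_by_type_ / g_c_op_times_by_name_ map an operator type/name to a pair of
+// (call count, accumulated cost in ms), which PrintResult later renders as a table.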
8934+bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
8935+                                 const OH_AI_CallBackParam kernel_Info) {
8936+  if (g_c_op_times_by_type_.find(kernel_Info.node_type) == g_c_op_times_by_type_.end()) {
8937+    g_c_op_times_by_type_.insert(std::make_pair(kernel_Info.node_type, std::make_pair(0, 0.0f)));
8938+  }
8939+  if (g_c_op_times_by_name_.find(kernel_Info.node_name) == g_c_op_times_by_name_.end()) {
8940+    g_c_op_times_by_name_.insert(std::make_pair(kernel_Info.node_name, std::make_pair(0, 0.0f)));
8941+  }
8942+
8943+  g_op_call_times_total_++;
8944+  g_op_begin_ = mindspore::lite::GetTimeUs();
8945+  return true;
8946+}
8947+
8948+bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
8949+                                const OH_AI_CallBackParam kernel_Info) {
8950+  uint64_t opEnd = mindspore::lite::GetTimeUs();
8951+  float cost = static_cast<float>(opEnd - g_op_begin_) / 1000.0f;
8952+  g_op_cost_total_ += cost;
8953+  g_c_op_times_by_type_[kernel_Info.node_type].first++;
8954+  g_c_op_times_by_type_[kernel_Info.node_type].second += cost;
8955+  g_c_op_times_by_name_[kernel_Info.node_name].first++;
8956+  g_c_op_times_by_name_[kernel_Info.node_name].second += cost;
8957+  return true;
8958+}
8959+}  // namespace lite
8960+}  // namespace mindspore
8961+
8962+
8963diff --git a/mindspore/lite/tools/benchmark_train/net_train_c_api.h b/mindspore/lite/tools/benchmark_train/net_train_c_api.h
8964new file mode 100644
8965index 00000000..bb84d3c1
8966--- /dev/null
8967+++ b/mindspore/lite/tools/benchmark_train/net_train_c_api.h
8968@@ -0,0 +1,121 @@
8969+/**
8970+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
8971+ *
8972+ * Licensed under the Apache License, Version 2.0 (the "License");
8973+ * you may not use this file except in compliance with the License.
8974+ * You may obtain a copy of the License at
8975+ *
8976+ * http://www.apache.org/licenses/LICENSE-2.0
8977+ *
8978+ * Unless required by applicable law or agreed to in writing, software
8979+ * distributed under the License is distributed on an "AS IS" BASIS,
8980+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8981+ * See the License for the specific language governing permissions and
8982+ * limitations under the License.
8983+ */
8984+
8985+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
8986+#define MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
8987+
8988+#include <getopt.h>
8989+#include <csignal>
8990+#include <unordered_map>
8991+#include <fstream>
8992+#include <iostream>
8993+#include <map>
8994+#include <cmath>
8995+#include <string>
8996+#include <vector>
8997+#include <memory>
8998+#include <cfloat>
8999+#include <utility>
9000+#include <algorithm>
9001+#include <nlohmann/json.hpp>
9002+#include "include/api/model.h"
9003+#include "include/api/types.h"
9004+#include "include/api/context.h"
9005+#include "include/api/cfg.h"
9006+
9007+#include "include/c_api/model_c.h"
9008+#include "include/c_api/context_c.h"
9009+
9010+#ifdef ENABLE_FP16
9011+#include <arm_neon.h>
9012+#endif
9013+#include "tools/common/flag_parser.h"
9014+#include "src/common/file_utils.h"
9015+#include "src/common/utils.h"
9016+#include "tools/benchmark_train/net_train_base.h"
9017+
9018+namespace mindspore::lite {
9019+  namespace {
9020+    std::map<std::string, std::pair<int, float>> g_c_op_times_by_type_;
9021+    std::map<std::string, std::pair<int, float>> g_c_op_times_by_name_;
9022+  }
9023+#ifdef __cplusplus
9024+  extern "C" {
9025+#endif
9026+  bool TimeProfilingBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
9027+                                   const OH_AI_CallBackParam kernel_Info);
9028+  bool TimeProfilingAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
9029+                                  const OH_AI_CallBackParam kernel_Info);
9030+#ifdef __cplusplus
9031+  }
9032+#endif
9033+
+class MS_API NetTrainCApi : public NetTrainBase {
+ public:
+  explicit NetTrainCApi(NetTrainFlags *flags) : NetTrainBase(flags) {}
+  ~NetTrainCApi() override = default;
9038+
9039+ protected:
9040+  // call GenerateRandomData to fill inputTensors
9041+  int GenerateInputData() override;
9042+
9043+  int ReadInputFile() override;
9044+
9045+  int LoadStepInput(size_t step);
9046+
9047+  int InitMSContext();
9048+
9049+  void InitTrainCfg();
9050+
9051+  char **TransStrVectorToCharArrays(const std::vector<std::string> &s);
9052+
9053+  std::vector<std::string> TransCharArraysToStrVector(char **c, const size_t &num);
9054+
9055+  int CreateAndRunNetwork(const std::string &filename, const std::string &bb_filename, bool is_train, int epochs,
9056+    bool check_accuracy = true) override;
9057+
9058+  int CreateAndRunNetworkForInference(const std::string &filename, const OH_AI_ContextHandle &context);
9059+
9060+  int CreateAndRunNetworkForTrain(const std::string &filename, const std::string &bb_filename,
9061+    const OH_AI_ContextHandle &context,
9062+    const OH_AI_TrainCfgHandle &train_cfg, int epochs);
9063+
9064+  int InitDumpTensorDataCallbackParameter() override;
9065+
9066+  int InitTimeProfilingCallbackParameter() override;
9067+
9068+  int PrintResult(const std::vector<std::string> &title, const std::map<std::string, std::pair<int, float>> &result) override;
9069+
9070+  int PrintInputData();
9071+
9072+  int MarkPerformance() override;
9073+
9074+  int MarkAccuracy(bool enforce_accuracy = true) override;
9075+
9076+  int CompareOutput() override;
9077+
9078+  int SaveModels() override;
9079+
9080+  OH_AI_ModelHandle ms_model_;
9081+  OH_AI_TensorHandleArray ms_inputs_for_api_;
9082+  OH_AI_ContextHandle context_ = nullptr;
9083+  OH_AI_TrainCfgHandle train_cfg_ = nullptr;
9084+  OH_AI_KernelCallBack before_call_back_{nullptr};
9085+  OH_AI_KernelCallBack after_call_back_{nullptr};
9086+};
9087+}  // namespace mindspore::lite
9088+
+#endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_NET_TRAIN_C_API_H
9090diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.cc b/mindspore/lite/tools/benchmark_train/run_net_train.cc
9091new file mode 100644
9092index 00000000..37a7e602
9093--- /dev/null
9094+++ b/mindspore/lite/tools/benchmark_train/run_net_train.cc
9095@@ -0,0 +1,86 @@
9096+/**
9097+ * Copyright 2020 Huawei Technologies Co., Ltd
9098+ *
9099+ * Licensed under the Apache License, Version 2.0 (the "License");
9100+ * you may not use this file except in compliance with the License.
9101+ * You may obtain a copy of the License at
9102+ *
9103+ * http://www.apache.org/licenses/LICENSE-2.0
9104+ *
9105+ * Unless required by applicable law or agreed to in writing, software
9106+ * distributed under the License is distributed on an "AS IS" BASIS,
9107+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9108+ * See the License for the specific language governing permissions and
9109+ * limitations under the License.
9110+ */
9111+
9112+#include "tools/benchmark_train/run_net_train.h"
9113+#include "tools/benchmark_train/net_train.h"
9114+#include "tools/benchmark_train/net_train_c_api.h"
9115+
9116+namespace mindspore {
9117+namespace lite {
9118+int RunNetTrain(int argc, const char **argv) {
9119+  NetTrainFlags flags;
9120+  Option<std::string> err = flags.ParseFlags(argc, argv);
9121+
9122+  if (err.IsSome()) {
9123+    std::cerr << err.Get() << std::endl;
9124+    std::cerr << flags.Usage() << std::endl;
9125+    return RET_ERROR;
9126+  }
9127+
9128+  if (flags.help) {
9129+    std::cerr << flags.Usage() << std::endl;
9130+    return RET_OK;
9131+  }
9132+  if (flags.unified_api_) {
9133+    return NetTrain::RunNr(&flags);
9134+  }
9135+
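+  // The benchmark implementation is selected through the MSLITE_API_TYPE environment
+  // variable, e.g. (illustrative): `export MSLITE_API_TYPE=C` to exercise the C API;
+  // leaving it unset or setting it to "NEW" uses the default C++ API implementation.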
9136+  auto api_type = std::getenv("MSLITE_API_TYPE");
9137+  if (api_type != nullptr) {
9138+    MS_LOG(INFO) << "MSLITE_API_TYPE = " << api_type;
9139+    std::cout << "MSLITE_API_TYPE = " << api_type << std::endl;
9140+  }
9141+
9142+  NetTrainBase *net_trainer = nullptr;
9143+  if (api_type == nullptr || std::string(api_type) == "NEW") {
9144+    net_trainer = new (std::nothrow) NetTrain(&flags);
9145+  } else if (std::string(api_type) == "C") {
9146+    net_trainer = new (std::nothrow) NetTrainCApi(&flags);
9147+  } else {
+    MS_LOG(ERROR) << "Invalid MSLITE_API_TYPE, supported values: NEW|C (default: NEW)";
9149+    return RET_ERROR;
9150+  }
9151+
9152+  if (net_trainer == nullptr) {
9153+    MS_LOG(ERROR) << "new net_trainer failed.";
9154+    return RET_ERROR;
9155+  }
9156+  auto status = net_trainer->Init();
9157+  if (status != RET_OK) {
9158+    MS_LOG(ERROR) << "NetTrain init Error : " << status;
9159+    std::cerr << "NetTrain init Error : " << status << std::endl;
9160+    return RET_ERROR;
9161+  }
9162+
9163+  status = net_trainer->RunNetTrain();
9164+  if (status != RET_OK) {
9165+    MS_LOG(ERROR) << "Run NetTrain "
9166+                  << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9167+                  << " Failed : " << status;
9168+    std::cerr << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9169+              << " Failed : " << status << std::endl;
9170+    return RET_ERROR;
9171+  }
9172+
9173+  MS_LOG(INFO) << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9174+               << " Success.";
9175+  std::cout << "Run NetTrain " << flags.model_file_.substr(flags.model_file_.find_last_of("/") + 1).c_str()
9176+            << " Success." << std::endl;
9177+  delete net_trainer;
9178+  return RET_OK;
9179+}
9180+}  // namespace lite
9181+}  // namespace mindspore
9182\ No newline at end of file
9183diff --git a/mindspore/lite/tools/benchmark_train/run_net_train.h b/mindspore/lite/tools/benchmark_train/run_net_train.h
9184new file mode 100644
9185index 00000000..9ca2d73c
9186--- /dev/null
9187+++ b/mindspore/lite/tools/benchmark_train/run_net_train.h
9188@@ -0,0 +1,22 @@
9189+/**
9190+ * Copyright 2023-2023 Huawei Technologies Co., Ltd
9191+ *
9192+ * Licensed under the Apache License, Version 2.0 (the "License");
9193+ * you may not use this file except in compliance with the License.
9194+ * You may obtain a copy of the License at
9195+ *
9196+ * http://www.apache.org/licenses/LICENSE-2.0
9197+ *
9198+ * Unless required by applicable law or agreed to in writing, software
9199+ * distributed under the License is distributed on an "AS IS" BASIS,
9200+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9201+ * See the License for the specific language governing permissions and
9202+ * limitations under the License.
9203+ */
9204+
9205+#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9206+#define MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9207+namespace mindspore::lite {
9208+int RunNetTrain(int argc, const char **argv);
9209+}  // namespace mindspore::lite
9210+#endif  // MINDSPORE_LITE_TOOLS_BENCHMARK_RUN_NET_TRAIN_H
9211diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt
9212index 1e09d2ed..f854620f 100644
9213--- a/mindspore/lite/tools/converter/CMakeLists.txt
9214+++ b/mindspore/lite/tools/converter/CMakeLists.txt
9215@@ -7,6 +7,8 @@ endif()
9216 
9217 set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
9218 
9219+include_directories(${CMAKE_SOURCE_DIR}/mindspore/lite/)
9220+
9221 if(ENABLE_GPU)
9222     add_compile_definitions(ENABLE_GPU)
9223 endif()
9224@@ -70,6 +72,7 @@ add_subdirectory(parser/caffe)
9225 add_subdirectory(parser/tflite)
9226 add_subdirectory(parser/onnx)
9227 add_subdirectory(parser/tf)
9228+add_subdirectory(parser/third_party)
9229 if(ENABLE_CONVERT_PYTORCH_MODEL)
9230     add_subdirectory(parser/pytorch)
9231 endif()
9232@@ -363,6 +366,7 @@ target_link_libraries(mindspore_converter
9233         tf_parser_mid
9234         caffe_parser_mid
9235         onnx_parser_mid
9236+        third_party_parser_mid
9237         lite_exporter_mid
9238         graph_pass_mid
9239         fusion_mid
9240diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9241index fecc56d9..2e7ca749 100644
9242--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9243+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
9244@@ -34,6 +34,7 @@ constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
9245 constexpr auto kDataPreprocessParam = "data_preprocess_param";
9246 constexpr auto kRegistry = "registry";
9247 constexpr auto kMicroParam = "micro_param";
9248+constexpr auto kThirdPartyModelParam = "third_party_model";
9249 constexpr auto kCpuOptionParam = "cpu_option_cfg_param";
9250 constexpr auto kCustomOppPath = "custom_opp_path";
9251 constexpr auto kTransformQuantParam = "transform_quant_param";
9252@@ -330,6 +331,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
9253     MS_LOG(ERROR) << "ParseMicroParamString failed.";
9254     return ret;
9255   }
9256+  ret = ParseThirdPartyParamString(*maps);
9257+  (void)maps->erase(kThirdPartyModelParam);
9258+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "ParseThirdPartyParamString failed.";
9260+    return ret;
9261+  }
9262   ret = ParseWeightQuantString(*maps);
9263   (void)maps->erase(kWeightQuantParam);
9264   if (ret != RET_OK) {
9265@@ -594,5 +601,25 @@ int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::ma
9266   }
9267   return RET_OK;
9268 }
9269+
9270+int ConfigFileParser::ParseThirdPartyParamString(
9271+  const std::map<std::string, std::map<std::string, std::string>> &sections) {
9272+  if (sections.find(kThirdPartyModelParam) == sections.end()) {
9273+    return RET_OK;
9274+  }
9275+  const auto &input_args = sections.at(kThirdPartyModelParam);
9276+  const std::map<std::string, std::string &> kValidArgs = {
9277+    {"input_shapes", third_party_model_string_.input_shapes},
9278+    {"input_dtypes", third_party_model_string_.input_dtypes},
9279+    {"input_names", third_party_model_string_.input_names},
9280+    {"input_formats", third_party_model_string_.input_formats},
9281+    {"output_shapes", third_party_model_string_.output_shapes},
9282+    {"output_dtypes", third_party_model_string_.output_dtypes},
9283+    {"output_names", third_party_model_string_.output_names},
9284+    {"output_formats", third_party_model_string_.output_formats},
9285+    {"extended_parameters", third_party_model_string_.extended_parameters},
9286+  };
9287+  return SetMapData(input_args, kValidArgs, kThirdPartyModelParam);
9288+}
9289 }  // namespace lite
9290 }  // namespace mindspore
9291diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9292index 31269816..6997bac8 100644
9293--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9294+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h
9295@@ -110,6 +110,18 @@ struct MicroParamString {
9296   std::string changeable_weights_name;
9297 };
9298 
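+// ThirdPartyModelString holds the raw strings read from the [third_party_model] section of the
+// converter config file. An illustrative section (field values are examples only):
+//   [third_party_model]
+//   input_names=in_0;in_1
+//   input_dtypes=float32;int32
+//   input_shapes=1,256,256,3;1,96
+//   output_names=out_0
+//   output_dtypes=float32
+//   output_shapes=1,1000
+//   extended_parameters=key_1:value_1;key_2:value_2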
9299+struct ThirdPartyModelString {
9300+  std::string input_dtypes;
9301+  std::string input_shapes;
9302+  std::string input_names;  // optional, default: ""
9303+  std::string input_formats;  // optional, default: NHWC
9304+  std::string output_dtypes;
9305+  std::string output_shapes;
9306+  std::string output_names;  // optional, default: ""
9307+  std::string output_formats;  // optional, default: NHWC
+  std::string extended_parameters;  // optional, format: key_1:value_1;key_2:value_2
9309+};
9310+
9311 struct CpuOptionCfgString {
9312   std::string architecture;
9313   std::string instruction;
9314@@ -144,6 +156,7 @@ class ConfigFileParser {
9315   RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; }
9316   AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
9317   MicroParamString GetMicroParamString() { return this->micro_param_string_; }
9318+  lite::ThirdPartyModelString GetThirdPartyModelString() const { return this->third_party_model_string_; }
9319   CpuOptionCfgString GetCpuOptionCfgString() { return this->cpu_option_cfg_string_; }
9320   TransformQuantString GetTransformQuantString() const { return this->transform_quant_string_; }
9321   AscendQuantString GetAscendQuantString() const { return this->ascend_quant_string_; }
9322@@ -161,6 +174,7 @@ class ConfigFileParser {
9323   int SetMapData(const std::map<std::string, std::string> &input_map,
9324                  const std::map<std::string, std::string &> &parse_map, const std::string &section);
9325   int ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9326+  int ParseThirdPartyParamString(const std::map<std::string, std::map<std::string, std::string>> &sections);
9327   int ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9328   int ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9329   int ParseAscendQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
9330@@ -176,6 +190,7 @@ class ConfigFileParser {
9331   RegistryInfoString registry_info_string_;
9332   AclOptionCfgString acl_option_cfg_string_;
9333   MicroParamString micro_param_string_;
9334+  lite::ThirdPartyModelString third_party_model_string_;
9335   CpuOptionCfgString cpu_option_cfg_string_;
9336   TransformQuantString transform_quant_string_;
9337   AscendQuantString ascend_quant_string_;
9338diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
9339new file mode 100644
9340index 00000000..aee6a29c
9341--- /dev/null
9342+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.cc
9343@@ -0,0 +1,299 @@
9344+/**
9345+ * Copyright 2023 Huawei Technologies Co., Ltd
9346+ *
9347+ * Licensed under the Apache License, Version 2.0 (the "License");
9348+ * you may not use this file except in compliance with the License.
9349+ * You may obtain a copy of the License at
9350+ *
9351+ * http://www.apache.org/licenses/LICENSE-2.0
9352+ *
9353+ * Unless required by applicable law or agreed to in writing, software
9354+ * distributed under the License is distributed on an "AS IS" BASIS,
9355+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9356+ * See the License for the specific language governing permissions and
9357+ * limitations under the License.
9358+ */
9359+
9360+#include "tools/converter/config_parser/third_party_param_parser.h"
9361+#include <vector>
9362+#include <string>
9363+#include <map>
9364+#include "include/errorcode.h"
9365+#include "src/common/log_adapter.h"
9366+#include "nnacl/op_base.h"
9367+#include "tools/common/string_util.h"
9368+
9369+namespace mindspore {
9370+namespace lite {
9371+namespace {
9372+const std::map<std::string, TypeId> kDataTypeMap = {
9373+  {"float64", TypeId::kNumberTypeFloat64}, {"float32", TypeId::kNumberTypeFloat32},
9374+  {"float16", TypeId::kNumberTypeFloat16}, {"int64", TypeId::kNumberTypeInt64},
9375+  {"int32", TypeId::kNumberTypeInt32},     {"int16", TypeId::kNumberTypeInt16},
9376+  {"int8", TypeId::kNumberTypeInt8},       {"uint8", TypeId::kNumberTypeUInt8},
9377+  {"bool", TypeId::kNumberTypeBool},
9378+};
9379+
9380+TypeId ConvertDataType(const std::string &type) {
9381+  auto iter = kDataTypeMap.find(type);
9382+  if (iter == kDataTypeMap.end()) {
9383+    return TypeId::kTypeUnknown;
9384+  }
9385+  return iter->second;
9386+}
9387+}  // namespace
9388+
9389+/**
9390+ * Parse shapes like "1,256,256,3;3,96;96,96", and return like [[1,256,256,3], [3,96], [96,96]].
9391+ */
9392+int ThirdPartyParamParser::DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes) {
9393+  MS_CHECK_TRUE_RET(dst_shapes != nullptr, RET_ERROR);
9394+  dst_shapes->clear();
9395+
9396+  auto tmp_shapes = SplitStringToVector(src, ";");
9397+  for (auto tmp_shape : tmp_shapes) {
9398+    auto tmp = SplitStringToVector(tmp_shape, ",");
9399+    std::vector<int64_t> shape = {};
9400+    for (auto t : tmp) {
9401+      int value = 0;
9402+      if (!ConvertIntNum(t, &value)) {
9403+        MS_LOG(ERROR) << "Found error when convert shape string to integer";
9404+        return RET_ERROR;
9405+      }
9406+      if (value <= 0) {  // Valid shape value should be greater than 0.
9407+        MS_LOG(ERROR) << "Only support fixed shapes in third party param";
9408+        return RET_ERROR;
9409+      }
9410+      shape.push_back(value);
9411+    }
9412+    dst_shapes->push_back(shape);
9413+  }
9414+  return RET_OK;
9415+}
9416+
9417+/**
9418+ * Parse extended parameter like "key_1:value_1;key_2:value_2" and get {{"key_1", "value_1"}, {"key_2", "value_2"}}.
9419+ */
9420+int ThirdPartyParamParser::DoParseExtendedParameters(const std::string &src,
9421+                                                     std::map<std::string, std::vector<uint8_t>> *dst_ext_param) {
9422+  MS_CHECK_TRUE_RET(dst_ext_param != nullptr, RET_ERROR);
9423+  constexpr size_t kKeyIndex = 0U;
9424+  constexpr size_t kValueIndex = 1U;
9425+  constexpr size_t kKeyValueSize = 2U;
9426+
+  if (src == "") {  // Just return if 'extended_parameters' is not configured.
9428+    return RET_OK;
9429+  }
9430+
9431+  auto tmp_list = SplitStringToVector(src, ";");
9432+  std::map<std::string, std::vector<uint8_t>> tmp_map = {};
9433+  for (auto tmp : tmp_list) {
9434+    auto key_and_value = SplitStringToVector(tmp, ":");
9435+    if (key_and_value.size() != kKeyValueSize) {
9436+      MS_LOG(ERROR) << "Parse extended parameters failed, should keep key:value format";
9437+      return RET_ERROR;
9438+    }
9439+    auto key = key_and_value[kKeyIndex];
9440+    auto value = key_and_value[kValueIndex];
9441+    if (tmp_map.find(key) != tmp_map.end()) {
9442+      MS_LOG(ERROR) << "Parse extended parameters failed, key should not be duplicated";
9443+      return RET_ERROR;
9444+    }
9445+    tmp_map.emplace(key, std::vector<uint8_t>(value.begin(), value.end()));
9446+  }
9447+
9448+  *dst_ext_param = tmp_map;
9449+  return RET_OK;
9450+}
9451+
9452+/**
9453+ * Parse dtypes like "float32;float32;int32" and return [kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeInt32]
9454+ */
9455+int ThirdPartyParamParser::DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes) {
9456+  MS_CHECK_TRUE_RET(dst_dtypes != nullptr, RET_ERROR);
9457+  dst_dtypes->clear();
9458+  auto tmp_dtypes = SplitStringToVector(src, ";");
9459+  for (auto tmp_dtype : tmp_dtypes) {
9460+    TypeId type = ConvertDataType(tmp_dtype);
9461+    if (type == kTypeUnknown) {
9462+      MS_LOG(ERROR) << "Parse dtypes in third party model config failed";
9463+      return RET_ERROR;
9464+    }
9465+    dst_dtypes->push_back(type);
9466+  }
9467+  return RET_OK;
9468+}
9469+
9470+/**
9471+ * Parse names like "foo;bar;boo" and get ["foo", "bar", "boo"]
9472+ * If input names are not provided in config, use the default prefix to generate like: "in_0;in_1;..;in_n"
9473+ */
9474+int ThirdPartyParamParser::DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
9475+                                        std::vector<std::string> *dst_names) {
9476+  MS_CHECK_TRUE_RET(dst_names != nullptr, RET_ERROR);
9477+  std::string tmp_names = src;
9478+  if (tmp_names.empty()) {
9479+    std::string tmp = "";
9480+    for (size_t i = 0; i < num; i++) {
9481+      tmp += default_prefix + "_" + std::to_string(i);
9482+      if (i + 1 < num) {
9483+        tmp += ";";
9484+      }
9485+    }
9486+    tmp_names = tmp;
9487+  }
9488+
9489+  *dst_names = SplitStringToVector(tmp_names, ";");
9490+  if (dst_names->size() != num) {
9491+    MS_LOG(ERROR) << "Name number " << dst_names->size() << " and input number: " << num << " are not equal";
9492+    return RET_ERROR;
9493+  }
9494+  return RET_OK;
9495+}
9496+
9497+/**
9498+ * Parse formats like "NCHW;NHWC" and get [NCHW, NHWC]
9499+ */
9500+namespace {
9501+  int StringToFormat(const std::string &format_string, schema::Format *format) {
9502+    static const std::unordered_map<std::string, schema::Format> kFormatTable = {
9503+      {"NCHW", schema::Format::Format_NCHW},
9504+      {"NHWC", schema::Format::Format_NHWC},
9505+      {"NHWC4", schema::Format::Format_NHWC4},
9506+      {"HWKC", schema::Format::Format_HWKC},
9507+      {"HWCK", schema::Format::Format_HWCK},
9508+      {"KCHW", schema::Format::Format_KCHW},
9509+      {"CKHW", schema::Format::Format_CKHW},
9510+      {"KHWC", schema::Format::Format_KHWC},
9511+      {"CHWK", schema::Format::Format_CHWK},
9512+      {"HW", schema::Format::Format_HW},
9513+      {"HW4", schema::Format::Format_HW4},
9514+      {"NC", schema::Format::Format_NC},
9515+      {"NC4", schema::Format::Format_NC4},
9516+      {"NC4HW4", schema::Format::Format_NC4HW4},
9517+      {"NUM_OF_FORMAT", schema::Format::Format_NUM_OF_FORMAT},
9518+      {"NCDHW", schema::Format::Format_NCDHW},
9519+      {"NWC", schema::Format::Format_NWC},
9520+      {"NCW", schema::Format::Format_NCW},
9521+    };
9522+
9523+    if (format == nullptr) {
9524+      return RET_NULL_PTR;
9525+    }
9526+
9527+    auto iter = kFormatTable.find(format_string);
9528+    if (iter == kFormatTable.end()) {
9529+      return RET_PARAM_INVALID;
9530+    }
9531+
9532+    *format = iter->second;
9533+    return RET_OK;
9534+  }
+}  // namespace
9536+
9537+int ThirdPartyParamParser::DoParseFormats(const std::string &src, size_t num,
9538+                                          std::vector<schema::Format> *result_formats) {
9539+  MS_CHECK_TRUE_RET(result_formats != nullptr, RET_ERROR);
9540+  std::string tmp_names = src;
9541+  if (tmp_names.empty()) {
9542+    std::vector<schema::Format> default_formats(num, schema::Format::Format_NHWC);
9543+    *result_formats = default_formats;
9544+    return RET_OK;
9545+  }
9546+
9547+  auto format_strings = SplitStringToVector(tmp_names, ";");
9548+  if (format_strings.size() != num) {
+    MS_LOG(ERROR) << "Number of formats: " << format_strings.size() << " and number of tensors: " << num
+                  << " are not equal";
9550+    return RET_ERROR;
9551+  }
9552+
9553+  std::vector<schema::Format> result(num);
9554+  for (size_t i = 0; i < num; i++) {
9555+    if (StringToFormat(format_strings[i], &result[i]) != RET_OK) {
9556+      MS_LOG(ERROR) << "Tensor format:" << format_strings[i] << " is invalid";
9557+      return RET_PARAM_INVALID;
9558+    }
9559+  }
9560+  *result_formats = result;
9561+  return RET_OK;
9562+}
9563+
9564+int ThirdPartyParamParser::Parse(const ThirdPartyModelString &param_string, ThirdPartyModelParam *param) {
9565+  MS_CHECK_TRUE_RET(param != nullptr, RET_ERROR);
9566+
9567+  auto ret = DoParseShape(param_string.input_shapes, &(param->input_shapes));
9568+  if (ret != RET_OK) {
9569+    MS_LOG(ERROR) << "Parse input shapes of third party param failed";
9570+    return RET_ERROR;
9571+  }
9572+
9573+  ret = DoParseDtypes(param_string.input_dtypes, &(param->input_dtypes));
9574+  if (ret != RET_OK) {
9575+    MS_LOG(ERROR) << "Parse input dtypes of third party param failed";
9576+    return RET_ERROR;
9577+  }
9578+
9579+  auto input_shape_num = param->input_shapes.size();
9580+  auto input_dtype_num = param->input_dtypes.size();
9581+  if (input_shape_num != input_dtype_num) {
9582+    MS_LOG(ERROR) << "Input shape number: " << input_shape_num << " and dtype number: " << input_dtype_num
9583+                  << " are not equal";
9584+    return RET_ERROR;
9585+  }
9586+
9587+  ret = DoParseFormats(param_string.input_formats, input_shape_num, &(param->input_formats));
9588+  if (ret != RET_OK) {
9589+    MS_LOG(ERROR) << "Parse input formats of third party param failed";
9590+    return RET_ERROR;
9591+  }
9592+
9593+  const std::string kInputNamePrefix = "in";
9594+  ret = DoParseNames(param_string.input_names, input_shape_num, kInputNamePrefix, &(param->input_names));
9595+  if (ret != RET_OK) {
9596+    MS_LOG(ERROR) << "Parse input names of third party param failed";
9597+    return RET_ERROR;
9598+  }
9599+
9600+  ret = DoParseShape(param_string.output_shapes, &(param->output_shapes));
9601+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Parse output shapes of third party param failed";
9603+    return RET_ERROR;
9604+  }
9605+
9606+  ret = DoParseDtypes(param_string.output_dtypes, &(param->output_dtypes));
9607+  if (ret != RET_OK) {
9608+    MS_LOG(ERROR) << "Parse output dtypes of third party param failed";
9609+    return RET_ERROR;
9610+  }
9611+
9612+  auto output_shape_num = param->output_shapes.size();
9613+  auto output_dtype_num = param->output_dtypes.size();
9614+  if (output_shape_num != output_dtype_num) {
9615+    MS_LOG(ERROR) << "Output shape number: " << output_shape_num << " and dtype number: " << output_dtype_num
9616+                  << " are not equal";
9617+    return RET_ERROR;
9618+  }
9619+
9620+  ret = DoParseFormats(param_string.output_formats, output_shape_num, &(param->output_formats));
9621+  if (ret != RET_OK) {
9622+    MS_LOG(ERROR) << "Parse output formats of third party param failed";
9623+    return RET_ERROR;
9624+  }
9625+
9626+  const std::string kOutputNamePrefix = "out";
9627+  ret = DoParseNames(param_string.output_names, output_shape_num, kOutputNamePrefix, &(param->output_names));
9628+  if (ret != RET_OK) {
9629+    MS_LOG(ERROR) << "Parse output names of third party param failed";
9630+    return RET_ERROR;
9631+  }
9632+
9633+  ret = DoParseExtendedParameters(param_string.extended_parameters, &(param->extended_parameters));
9634+  if (ret != RET_OK) {
9635+    MS_LOG(ERROR) << "Parse extended parameter of third party param failed";
9636+    return RET_ERROR;
9637+  }
9638+
9639+  return RET_OK;
9640+}
9641+}  // namespace lite
9642+}  // namespace mindspore
9643diff --git a/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
9644new file mode 100644
9645index 00000000..5cf6e8fb
9646--- /dev/null
9647+++ b/mindspore/lite/tools/converter/config_parser/third_party_param_parser.h
9648@@ -0,0 +1,44 @@
9649+/**
9650+ * Copyright 2023 Huawei Technologies Co., Ltd
9651+ *
9652+ * Licensed under the Apache License, Version 2.0 (the "License");
9653+ * you may not use this file except in compliance with the License.
9654+ * You may obtain a copy of the License at
9655+ *
9656+ * http://www.apache.org/licenses/LICENSE-2.0
9657+ *
9658+ * Unless required by applicable law or agreed to in writing, software
9659+ * distributed under the License is distributed on an "AS IS" BASIS,
9660+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9661+ * See the License for the specific language governing permissions and
9662+ * limitations under the License.
9663+ */
9664+
9665+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9666+#define MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9667+#include <string>
9668+#include <vector>
9669+#include <map>
9670+#include "include/errorcode.h"
9671+#include "tools/converter/cxx_api/converter_para.h"
9672+#include "tools/converter/config_parser/config_file_parser.h"
9673+
9674+namespace mindspore {
9675+namespace lite {
9676+class ThirdPartyParamParser {
9677+ public:
9678+  static int Parse(const lite::ThirdPartyModelString &param_string, ThirdPartyModelParam *param);
9679+
9680+ private:
9681+  static int DoParseShape(const std::string &src, std::vector<std::vector<int64_t>> *dst_shapes);
9682+  static int DoParseExtendedParameters(const std::string &src,
9683+                                       std::map<std::string, std::vector<uint8_t>> *dst_ext_param);
9684+  static int DoParseDtypes(const std::string &src, std::vector<TypeId> *dst_dtypes);
9685+  static int DoParseNames(const std::string &src, size_t num, const std::string &default_prefix,
9686+                          std::vector<std::string> *dst_names);
9687+  static int DoParseFormats(const std::string &src, size_t num, std::vector<schema::Format> *result_formats);
9688+};
9689+}  // namespace lite
9690+}  // namespace mindspore
9691+
9692+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_CONFIG_PARSER_THIRD_PARTY_PARAM_PARSER_H_
9693diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
9694index df3176c2..a61bd51c 100644
9695--- a/mindspore/lite/tools/converter/converter.cc
9696+++ b/mindspore/lite/tools/converter/converter.cc
9697@@ -49,6 +49,7 @@
9698 #include "tools/converter/config_parser/preprocess_parser.h"
9699 #include "tools/converter/config_parser/quant_param_parser.h"
9700 #include "tools/converter/config_parser/graph_kernel_param_parser.h"
9701+#include "tools/converter/config_parser/third_party_param_parser.h"
9702 #include "tools/converter/converter_funcgraph.h"
9703 #include "tools/converter/converter_metagraph.h"
9704 #include "tools/common/string_util.h"
9705@@ -472,6 +473,12 @@ int ConverterImpl::ParseParam(lite::ConfigFileParser *config_parser, const std::
9706       MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
9707       return ret;
9708     }
9709+    ret = lite::ThirdPartyParamParser::Parse(config_parser->GetThirdPartyModelString(),
9710+                                             &param->thirdPartyModelParam);
9711+    if (ret != RET_OK) {
9712+      MS_LOG(ERROR) << "Parse third party param failed.";
9713+      return ret;
9714+    }
9715     ret = InitExtendedIntegrationInfo(param, *config_parser);
9716     if (ret != RET_OK) {
9717       MS_LOG(ERROR) << "Parse extended integration info failed.";
9718@@ -699,19 +706,20 @@ std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const s
9719 
9720 int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
-  if (param != nullptr) {
-    std::set valid_values = {FmkType::kFmkTypeTf,    FmkType::kFmkTypeCaffe,  FmkType::kFmkTypeOnnx,
-                             FmkType::kFmkTypeMs,    FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
-                             FmkType::kFmkTypeMsLite};
-    if (std::find(valid_values.begin(), valid_values.end(), param->fmk_type) == valid_values.end()) {
-      MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
-                       "kFmkTypeTf|kFmkTypeCaffe|kFmkTypeOnnx|kFmkTypeMs|kFmkTypeTflite|kFmkTypeMsLite"
-                    << ", but got " << param->fmk_type;
-      return RET_INPUT_PARAM_INVALID;
-    }
-    if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
-      MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag";
-      return RET_INPUT_PARAM_INVALID;
-    }
+  if (param == nullptr) {
+    return RET_OK;
+  }
9737+  std::set kValidFmkTypes = {FmkType::kFmkTypeTf,    FmkType::kFmkTypeCaffe,  FmkType::kFmkTypeOnnx,
9738+                           FmkType::kFmkTypeMs,    FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
9739+                           FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty};
9740+  if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
9741+    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
9742+                     "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY"
9743+                  << ", but got " << param->fmk_type;
9744+    return RET_INPUT_PARAM_INVALID;
9745+  }
9746+  if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
9747+    MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag";
9748+    return RET_INPUT_PARAM_INVALID;
9749   }
9750   return RET_OK;
9751 }
9752diff --git a/mindspore/lite/tools/converter/converter_funcgraph.cc b/mindspore/lite/tools/converter/converter_funcgraph.cc
9753index f03f995c..61d5c463 100644
9754--- a/mindspore/lite/tools/converter/converter_funcgraph.cc
9755+++ b/mindspore/lite/tools/converter/converter_funcgraph.cc
9756@@ -90,6 +90,7 @@ FuncGraphPtr ConverterFuncGraph::Load3rdModelToFuncgraph(const std::shared_ptr<C
9757   converter_parameters.save_type = param->save_type;
9758   converter_parameters.model_file = param->model_file;
9759   converter_parameters.weight_file = param->weight_file;
9760+  converter_parameters.attrs.emplace("config_file", param->config_file);
9761   func_graph_base = model_parser->Parse(converter_parameters);
9762   if (func_graph_base == nullptr) {
9763     delete model_parser;
9764@@ -447,11 +448,13 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> &param,
9765     return status;
9766   }
9767 
9768-  AnfTransform funcgraph_transform;
9769-  status = funcgraph_transform.Transform(func_graph, param);
9770-  if (status != RET_OK) {
9771-    MS_LOG(ERROR) << "Transform anf graph failed.";
9772-    return status;
9773+  if (param->fmk_type != converter::FmkType::kFmkTypeThirdParty) {
9774+    AnfTransform funcgraph_transform;
9775+    status = funcgraph_transform.Transform(func_graph, param);
9776+    if (status != RET_OK) {
9777+      MS_LOG(ERROR) << "Transform anf graph failed.";
9778+      return status;
9779+    }
9780   }
9781 
9782   status = UnifyFuncGraphOutputFormat(param, func_graph);
9783diff --git a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9784index 4883c48d..024e209f 100644
9785--- a/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9786+++ b/mindspore/lite/tools/converter/converter_lite/converter_flags.cc
9787@@ -138,11 +138,11 @@ int Flags::InitFmk() {
9788   // value check not here, it is in converter c++ API's CheckValueParam method.
9789   std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
9790     {"CAFFE", kFmkTypeCaffe}, {"MINDIR", kFmkTypeMs},       {"TFLITE", kFmkTypeTflite}, {"ONNX", kFmkTypeOnnx},
9791-    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}};
9792+    {"TF", kFmkTypeTf},       {"PYTORCH", kFmkTypePytorch}, {"MSLITE", kFmkTypeMsLite}, {"THIRDPARTY", kFmkTypeThirdParty}};
9793   if (StrToEnumFmkTypeMap.find(this->fmkIn) != StrToEnumFmkTypeMap.end()) {
9794     this->fmk = StrToEnumFmkTypeMap.at(this->fmkIn);
9795   } else {
9796-    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX" << std::endl;
9797+    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|THIRDPARTY" << std::endl;
9798     return RET_INPUT_PARAM_INVALID;
9799   }
9800 
9801diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h
9802index a4f72a69..33210fd0 100644
9803--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h
9804+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h
9805@@ -21,6 +21,7 @@
9806 #include <vector>
9807 #include <set>
9808 #include "include/converter.h"
9809+#include "mindapi/base/type_id.h"
9810 #include "tools/converter/quantizer/quant_params.h"
9811 #include "tools/converter/preprocess/preprocess_param.h"
9812 #include "tools/converter/adapter/acl/common/acl_types.h"
9813@@ -35,6 +36,18 @@ struct ParallelSplitConfig {
9814   std::vector<std::string> parallel_devices_;
9815 };
9816 
9817+struct ThirdPartyModelParam {
9818+  std::vector<TypeId> input_dtypes;
9819+  std::vector<std::vector<int64_t>> input_shapes;
9820+  std::vector<std::string> input_names;
9821+  std::vector<schema::Format> input_formats;
9822+  std::vector<TypeId> output_dtypes;
9823+  std::vector<std::vector<int64_t>> output_shapes;
9824+  std::vector<std::string> output_names;
9825+  std::vector<schema::Format> output_formats;
9826+  std::map<std::string, std::vector<uint8_t>> extended_parameters;
9827+};
9828+
9829 struct CpuOptionCfg {
9830   std::string architecture;
9831   std::string instruction;
9832@@ -97,6 +110,7 @@ struct ConverterPara {
9833   lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
9834   lite::micro::MicroParam microParam;
9835   ParallelSplitConfig parallel_split_config;
9836+  ThirdPartyModelParam thirdPartyModelParam;
9837   AscendGeOptionCfg ascendGeOptionCfg;
9838   std::string device;
9839   std::string provider;
9840diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
9841index 90b744e5..bf1a82ae 100644
9842--- a/mindspore/lite/tools/converter/graphdef_transform.cc
9843+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
9844@@ -76,11 +76,55 @@ int QuantTransform(const std::shared_ptr<ConverterPara> &param, schema::MetaGrap
9845   }
9846   return RET_OK;
9847 }
9848+
+int FillGraphOutputShape(MetaGraphT *meta_graph, const std::vector<std::vector<int64_t>> &output_shapes) {
+  const auto &out_indices = meta_graph->outputIndex;
+  if (out_indices.size() != output_shapes.size()) {
+    MS_LOG(ERROR) << "Graph output number " << out_indices.size() << " and configured output shape number "
+                  << output_shapes.size() << " are not equal";
+    return RET_ERROR;
+  }
+  for (size_t i = 0; i < out_indices.size(); i++) {
+    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
+    out_tensor->dims = {};
+    for (size_t k = 0; k < output_shapes[i].size(); k++) {
+      out_tensor->dims.push_back(static_cast<int32_t>(output_shapes[i][k]));
+    }
+  }
+  return RET_OK;
+}
9860+
9861+void FillGraphInputAndOutputFormats(MetaGraphT *meta_graph, const ConverterPara &para) {
9862+  const auto &in_indices = meta_graph->inputIndex;
9863+  for (size_t i = 0; i < in_indices.size(); i++) {
9864+    auto &in_tensor = meta_graph->allTensors[in_indices[i]];
9865+    in_tensor->format = para.thirdPartyModelParam.input_formats[i];
9866+    MS_LOG_DEBUG << "input " << i << " format: " << EnumNameFormat(in_tensor->format);
9867+  }
9868+
9869+  const auto &out_indices = meta_graph->outputIndex;
9870+  for (size_t i = 0; i < out_indices.size(); i++) {
9871+    auto &out_tensor = meta_graph->allTensors[out_indices[i]];
9872+    out_tensor->format = para.thirdPartyModelParam.output_formats[i];
9873+    MS_LOG_DEBUG << "output " << i << " format: " << EnumNameFormat(out_tensor->format);
9874+  }
9875+}
9876 }  // namespace
9877 
9878 int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> &param) {
9879   MS_ASSERT(param != nullptr);
9880   STATUS status;
9881+
9882+  if (param->fmk_type == converter::kFmkTypeThirdParty) {
9883+
9884+    // Legacy optimizer infer shape, but op Custom which wraps third party model has no infer-shape function.
9885+    // So we don't perform legacy optimization for kFmkTypeThirdParty case.
9886+    auto ret = FillGraphOutputShape(graph_defT_, param->thirdPartyModelParam.output_shapes);
9887+    if (ret != RET_OK) {
9888+      MS_LOG(ERROR) << "Fill output shape of third party model failed, ret:" << ret;
9889+      return ret;
9890+    }
9891+
9892+    // Tensor of FuncGraph has no attribute of format, so set format in MetaGraph.
9893+    FillGraphInputAndOutputFormats(graph_defT_, *param);
9894+    return RET_OK;
9895+  }
9896+
9897   {
9898     auto old_nodes = GetGraphNodes(*graph_defT_);
9899     Optimizer unused_op_remove_optimizer;
9900diff --git a/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
9901new file mode 100644
9902index 00000000..b55e0194
9903--- /dev/null
9904+++ b/mindspore/lite/tools/converter/parser/third_party/CMakeLists.txt
9905@@ -0,0 +1,4 @@
9906+add_library(third_party_parser_mid OBJECT third_party_model_parser.cc)
9907+add_dependencies(third_party_parser_mid proto_mid)
9908+add_dependencies(third_party_parser_mid fbs_src)
9909+add_dependencies(third_party_parser_mid fbs_inner_src)
9910\ No newline at end of file
9911diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
9912new file mode 100644
9913index 00000000..652db4af
9914--- /dev/null
9915+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.cc
9916@@ -0,0 +1,277 @@
9917+/**
9918+ * Copyright 2023 Huawei Technologies Co., Ltd
9919+ *
9920+ * Licensed under the Apache License, Version 2.0 (the "License");
9921+ * you may not use this file except in compliance with the License.
9922+ * You may obtain a copy of the License at
9923+ *
9924+ * http://www.apache.org/licenses/LICENSE-2.0
9925+ *
9926+ * Unless required by applicable law or agreed to in writing, software
9927+ * distributed under the License is distributed on an "AS IS" BASIS,
9928+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9929+ * See the License for the specific language governing permissions and
9930+ * limitations under the License.
9931+ */
9932+#include "tools/converter/parser/third_party/third_party_model_parser.h"
9933+#include <string>
9934+#include <vector>
9935+#include <memory>
9936+#include "ir/value.h"
9937+#include "mindapi/base/type_id.h"
9938+#include "src/common/log_util.h"
9939+#include "src/common/file_utils.h"
9940+#include "nnacl/op_base.h"
9941+#include "ops/primitive_c.h"
9942+#include "ops/custom.h"
9943+#include "ops/tuple_get_item.h"
9944+#include "ops/make_tuple.h"
9945+#include "ops/return.h"
9946+#include "tools/converter/config_parser/config_file_parser.h"
9947+#include "include/registry/model_parser_registry.h"
9948+#include "tools/common/graph_util.h"
9949+#include "tools/common/tensor_util.h"
9950+#include "tools/converter/converter_context.h"
9951+#include "tools/converter/parser/lite_model_parser_creator.h"
9952+
9953+using mindspore::converter::kFmkTypeThirdParty;
9954+
9955+namespace mindspore {
9956+namespace lite {
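+// The third-party parser does not interpret the model buffer itself. It builds a FuncGraph with one
+// parameter per declared input, one extra const parameter holding the raw model bytes, and a single
+// Custom node (type "ThirdPartyModel") whose outputs are exposed through TupleGetItem nodes.
+// An illustrative converter invocation (flag values are examples only):
+//   ./converter_lite --fmk=THIRDPARTY --modelFile=model.bin --outputFile=model --configFile=3rd.cfg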
9957+api::FuncGraphPtr ThirdPartyModelParser::Parse(const converter::ConverterParameters &flag) {
9958+  model_file_ = flag.model_file;
9959+  auto &attrs = flag.attrs;
9960+  auto iter = attrs.find("config_file");
9961+  if (iter == attrs.end()) {
9962+    return nullptr;
9963+  }
9964+  auto config_file = iter->second;
9965+
9966+  auto ret = InitConfig(config_file);
9967+  if (ret != RET_OK) {
9968+    MS_LOG(ERROR) << "Init config for third party model parsing failed";
9969+    return nullptr;
9970+  }
9971+
9972+  return CreateFuncGraph();
9973+}
9974+
9975+STATUS ThirdPartyModelParser::InitConfig(const std::string &config_file) {
9976+  lite::ConfigFileParser config_parser;
9977+  if (config_file.empty()) {
9978+    MS_LOG(ERROR) << "Missing config file in converting third party model";
9979+    return RET_ERROR;
9980+  }
9981+  auto ret = config_parser.ParseConfigFile(config_file);
9982+  if (ret != RET_OK) {
9983+    MS_LOG(ERROR) << "Get third party model section from config file failed";
9984+    return RET_ERROR;
9985+  }
9986+
9987+  ret = ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), &param_);
9988+  if (ret != RET_OK) {
9989+    MS_LOG(ERROR) << "Parse third party model param failed.";
9990+    return ret;
9991+  }
9992+  return RET_OK;
9993+}
9994+
9995+api::FuncGraphPtr ThirdPartyModelParser::CreateFuncGraph() {
9996+  auto func_graph = std::make_shared<FuncGraph>();
9997+  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
9998+  auto type_value = MakeValue(static_cast<int>(converter::kFmkTypeThirdParty));
9999+  MS_CHECK_TRUE_RET(type_value != nullptr, nullptr);
10000+  func_graph->set_attr("fmk", type_value);
10001+  auto attr_value = MakeValue("third_party");
10002+  MS_CHECK_TRUE_RET(attr_value != nullptr, nullptr);
10003+  func_graph->set_attr("graph_name", attr_value);
10004+
10005+  std::vector<AnfNodePtr> input_nodes = {};
10006+  auto ret = BuildGraphInputs(func_graph, &input_nodes);
10007+  if (ret != RET_OK) {
10008+    MS_LOG(ERROR) << "Create func graph input nodes failed";
10009+    return nullptr;
10010+  }
10011+
10012+  CNodePtr custom_node = nullptr;
10013+  ret = BuildCustomOp(func_graph, input_nodes, &custom_node);
10014+  if (ret != RET_OK) {
10015+    MS_LOG(ERROR) << "Create func graph custom op node failed";
10016+    return nullptr;
10017+  }
10018+
10019+  ret = BuildGraphOutputs(func_graph, custom_node);
10020+  if (ret != RET_OK) {
10021+    MS_LOG(ERROR) << "Create func graph output nodes failed";
10022+    return nullptr;
10023+  }
10024+
10025+  static auto manager = Manage(func_graph);
10026+  func_graph->set_manager(manager);
10027+
10028+  auto result_graph = api::MakeShared<api::FuncGraph>(func_graph);
10029+  return result_graph;
10030+}
10031+
10032+STATUS ThirdPartyModelParser::BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs) {
+  MS_ASSERT(func_graph != nullptr && op_inputs != nullptr);
10034+  auto &dtypes = param_.input_dtypes;
10035+  auto &shapes = param_.input_shapes;
10036+  auto &names = param_.input_names;
10037+
10038+  auto input_size = dtypes.size();
10039+
10040+  // Create parameter nodes for graph inputs
10041+  for (size_t i = 0; i < input_size; i++) {
10042+    auto parameter = func_graph->add_parameter();
10043+    MSLITE_CHECK_PTR(parameter);
10044+    auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]);
10045+    if (abstract_tensor == nullptr) {
10046+      MS_LOG(ERROR) << "Create tensor abstract failed";
10047+      return RET_ERROR;
10048+    }
10049+    parameter->set_abstract(abstract_tensor);
10050+    parameter->set_name(names[i]);
10051+    op_inputs->push_back(parameter);
10052+  }
10053+
+  // Create a parameter node for the const tensor that wraps the third-party model buffer.
+  size_t model_size = 0U;
+  auto model_data = ReadFile(model_file_.c_str(), &model_size);
+  if (model_data == nullptr || model_size == 0U) {
+    MS_LOG(ERROR) << "Read third party model file failed: " << model_file_;
+    return RET_ERROR;
+  }
+  std::vector<int64_t> model_shape = {static_cast<int64_t>(model_size)};
+  auto tensor_info = CreateTensorInfo(nullptr, 0, model_shape, kNumberTypeUInt8);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "init tensor info failed";
+    delete[] model_data;
+    return RET_NULL_PTR;
+  }
+  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
+  if (memcpy_s(tensor_data, tensor_info->Size(), model_data, model_size) != EOK) {
+    MS_LOG(ERROR) << "memcpy failed.";
+    delete[] model_data;
+    return RET_ERROR;
+  }
+  delete[] model_data;
10071+  auto parameter = func_graph->add_parameter();
10072+  MSLITE_CHECK_PTR(parameter);
10073+  auto status = InitParameterFromTensorInfo(parameter, tensor_info);
10074+  if (status != RET_OK) {
10075+    MS_LOG(ERROR) << "init parameter from tensor info failed.";
10076+    return RET_ERROR;
10077+  }
10078+  parameter->set_name("ThirdPartyModel");
10079+  op_inputs->push_back(parameter);
10080+  return RET_OK;
10081+}
10082+
10083+STATUS ThirdPartyModelParser::BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs,
10084+                                            CNodePtr *operator_node) {
+  MS_ASSERT(func_graph != nullptr && operator_node != nullptr);
10086+  NotSupportOp::GetInstance()->set_fmk_type("THIRDPARTY");
10087+  STATUS status = RET_OK;
10088+
10089+  // create primitive and build CNode of CUSTOM operator
10090+  ops::PrimitiveCPtr primitive_c;
10091+  auto prim = std::make_unique<ops::Custom>();
10092+  MS_CHECK_TRUE_RET(prim != nullptr, RET_ERROR);
10093+  prim->set_type("ThirdPartyModel");
10094+
10095+  const auto &attr = param_.extended_parameters;
10096+  prim->set_attr(attr);
10097+  primitive_c = prim->GetPrim();
10098+  if (primitive_c == nullptr) {
10099+    MS_LOG(ERROR) << "failed to create primitive: custom";
10100+    return RET_ERROR;
10101+  }
10102+
10103+  auto operator_cnode = func_graph->NewCNode(primitive_c, op_inputs);
10104+  MSLITE_CHECK_PTR(operator_cnode);
10105+  operator_cnode->set_fullname_with_scope("Custom");
10106+  *operator_node = operator_cnode;
10107+  return status;
10108+}
10109+
10110+STATUS ThirdPartyModelParser::BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node) {
+  MS_ASSERT(func_graph != nullptr && operator_node != nullptr);
10112+
10113+  auto dtypes = param_.output_dtypes;
10114+  auto shapes = param_.output_shapes;
10115+  auto names = param_.output_names;
10116+
10117+  auto output_size = dtypes.size();
10118+  std::vector<AnfNodePtr> output_nodes = {};
10119+
10120+  // Use TupleGetItem to wrap op outputs.
10121+  AbstractBasePtrList abstract_list;
10122+  for (size_t i = 0; i < output_size; i++) {
10123+    auto abstract_tensor = CreateTensorAbstract(shapes[i], dtypes[i]);
10124+    if (abstract_tensor == nullptr) {
10125+      MS_LOG(ERROR) << "Create tensor abstract failed";
10126+      return RET_ERROR;
10127+    }
10128+    abstract_list.emplace_back(abstract_tensor);
10129+    auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
10130+    if (tuple_get_item_prim_ptr == nullptr) {
10131+      MS_LOG(ERROR) << "new TupleGetItem failed";
10132+      return RET_NULL_PTR;
10133+    }
10134+    auto tuple_get_item_prim_c = tuple_get_item_prim_ptr->GetPrim();
10135+    MSLITE_CHECK_PTR(tuple_get_item_prim_c);
10136+    auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_c);
10137+    MSLITE_CHECK_PTR(tuple_get_item_prim);
10138+    auto get_item_value = NewValueNode(MakeValue<int>(i));
10139+    MSLITE_CHECK_PTR(get_item_value);
10140+    std::vector<AnfNodePtr> inputs = {tuple_get_item_prim, operator_node, get_item_value};
10141+    CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
10142+    MSLITE_CHECK_PTR(get_item_cnode);
10143+    std::string output_item_name = operator_node->fullname_with_scope() + "_getitem_" + std::to_string(i);
10144+    auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
10145+    if (get_item_abstract == nullptr) {
+      MS_LOG(ERROR) << "Create tensor abstract failed";
10147+      return RET_ERROR;
10148+    }
10149+    get_item_cnode->set_fullname_with_scope(output_item_name);
10150+    get_item_cnode->set_abstract(get_item_abstract);
10151+    output_nodes.push_back(get_item_cnode);
10152+  }
10153+  auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
10154+  MSLITE_CHECK_PTR(abstract_tuple);
10155+  operator_node->set_abstract(abstract_tuple);
10156+
10157+  // Use MakeTuple node to wrap all outputs as single input of Return node.
10158+  auto make_tuple_prim_ptr = std::make_shared<ops::MakeTuple>();
10159+  if (make_tuple_prim_ptr == nullptr) {
10160+    MS_LOG(ERROR) << "new MakeTuple failed";
10161+    return RET_NULL_PTR;
10162+  }
10163+  auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim();
10164+  MSLITE_CHECK_PTR(make_tuple_prim_c);
10165+  auto make_tuple_prim = NewValueNode(make_tuple_prim_c);
10166+  MSLITE_CHECK_PTR(make_tuple_prim);
10167+  std::vector<AnfNodePtr> make_tuple_inputs = output_nodes;
10168+  make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim);
10169+  auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
10170+  MSLITE_CHECK_PTR(make_tuple_cnode);
10171+  make_tuple_cnode->set_fullname_with_scope("return_tuple");
10172+
10173+  auto return_prim_ptr = std::make_shared<ops::Return>();
10174+  if (return_prim_ptr == nullptr) {
10175+    MS_LOG(ERROR) << "new Return failed";
10176+    return RET_NULL_PTR;
10177+  }
10178+  auto return_prim_c = return_prim_ptr->GetPrim();
10179+  MSLITE_CHECK_PTR(return_prim_c);
10180+  std::vector<AnfNodePtr> op_inputs{make_tuple_cnode};
10181+  auto cnode = func_graph->NewCNode(return_prim_c, op_inputs);
10182+  MSLITE_CHECK_PTR(cnode);
10183+  cnode->set_fullname_with_scope("Return");
10184+  func_graph->set_return(cnode);
10185+
10186+  // Save original output tensor names.
10187+  ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(names);
10188+  return RET_OK;
10189+}
10190+
10191+REG_MODEL_PARSER(kFmkTypeThirdParty, LiteModelParserCreator<ThirdPartyModelParser>)
10192+}  // namespace lite
10193+}  // namespace mindspore
10194diff --git a/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
10195new file mode 100644
10196index 00000000..c4b197b8
10197--- /dev/null
10198+++ b/mindspore/lite/tools/converter/parser/third_party/third_party_model_parser.h
10199@@ -0,0 +1,50 @@
10200+/**
10201+ * Copyright 2023 Huawei Technologies Co., Ltd
10202+ *
10203+ * Licensed under the Apache License, Version 2.0 (the "License");
10204+ * you may not use this file except in compliance with the License.
10205+ * You may obtain a copy of the License at
10206+ *
10207+ * http://www.apache.org/licenses/LICENSE-2.0
10208+ *
10209+ * Unless required by applicable law or agreed to in writing, software
10210+ * distributed under the License is distributed on an "AS IS" BASIS,
10211+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10212+ * See the License for the specific language governing permissions and
10213+ * limitations under the License.
10214+ */
10215+
10216+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10217+#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10218+
10219+#include <string>
10220+#include <vector>
10221+#include "schema/inner/model_generated.h"
10222+#include "base/base.h"
10223+#include "ir/anf.h"
10224+#include "ir/func_graph.h"
10225+#include "include/errorcode.h"
10226+#include "include/registry/model_parser.h"
10227+#include "tools/converter/config_parser/third_party_param_parser.h"
10228+
10229+namespace mindspore {
10230+namespace lite {
10231+class ThirdPartyModelParser : public converter::ModelParser {
10232+ public:
10233+  api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
10234+
10235+ private:
10236+  STATUS InitConfig(const std::string &config_file);
10237+  api::FuncGraphPtr CreateFuncGraph();
10238+  STATUS BuildGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *op_inputs);
10239+  STATUS BuildCustomOp(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &op_inputs,
10240+                       CNodePtr *operator_node);
10241+  STATUS BuildGraphOutputs(const FuncGraphPtr &func_graph, const CNodePtr &operator_node);
10242+
10243+  std::string model_file_ = "";
10244+  ThirdPartyModelParam param_;
10245+};
10246+}  // namespace lite
10247+}  // namespace mindspore
10248+
10249+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_THIRDPARTY_THIRDPARTY_MODEL_PARSER_H_
10250diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10251index 832fb92d..6bc2d4d3 100644
10252--- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10253+++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc
10254@@ -26,7 +26,7 @@ std::map<FmkType, ModelParserCreator> model_parser_room;
10255 }  // namespace
10256 
10257 ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
10258-  if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
10259+  if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) {
10260     MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
10261     return;
10262   }
10263@@ -38,7 +38,7 @@ ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator
10264 }
10265 
10266 converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
10267-  if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypePytorch) {
10268+  if (fmk < converter::kFmkTypeTf || fmk >= converter::kFmkTypeEnd) {
10269     MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
10270     return nullptr;
10271   }
10272-- 
102732.17.1
10274
10275