# Using the MindSpore Lite Engine for On-Device Training (C/C++)

## When to Use

MindSpore Lite is an AI engine that implements AI model inference for different hardware devices. It has been used in a wide range of fields, such as image classification, target recognition, facial recognition, and character recognition. In addition, MindSpore Lite supports deployment of model training on devices, making it possible to adapt to user behavior in actual service scenarios.

This topic describes the general development process for using MindSpore Lite for model training on devices.


## Available APIs
The following table lists some APIs for using MindSpore Lite for model training.

| API       | Description       |
| ------------------ | ----------------- |
|`OH_AI_ContextHandle OH_AI_ContextCreate()`|Creates a context object.|
|`OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type)`|Creates a runtime device information object.|
|`void OH_AI_ContextDestroy(OH_AI_ContextHandle *context)`|Destroys a context object.|
|`void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info)`|Adds a runtime device information object.|
|`OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate()`|Creates the pointer to a training configuration object.|
|`void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg)`|Destroys the pointer to a training configuration object.|
|`OH_AI_ModelHandle OH_AI_ModelCreate()`|Creates a model object.|
|`OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path, OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context, const OH_AI_TrainCfgHandle train_cfg)`|Loads and builds a MindSpore training model from a model file.|
|`OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before, const OH_AI_KernelCallBack after)`|Runs a single-step training model.|
|`OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train)`|Sets the training mode.|
|`OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file, OH_AI_QuantizationType quantization_type, bool export_inference_only, char **output_tensor_name, size_t num)`|Exports a trained MS model.|
|`void OH_AI_ModelDestroy(OH_AI_ModelHandle *model)`|Destroys a model object.|


## How to Develop
The following figure shows the development process for MindSpore Lite model training.

**Figure 1** Development process for MindSpore Lite model training
![how-to-use-train](figures/train_sequence_unify_api.png)

Before starting development, include the required header files and write a function that generates random input data. The sample code is as follows:

```c
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "mindspore/model.h"

// Fill every input tensor with random values.
int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
  for (size_t i = 0; i < inputs.handle_num; ++i) {
    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
    if (input_data == NULL) {
      printf("OH_AI_TensorGetMutableData failed.\n");
      return OH_AI_STATUS_LITE_ERROR;
    }
    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
    const int divisor = 10;
    for (int64_t j = 0; j < num; j++) {
      input_data[j] = (float)(rand() % divisor) / divisor;  // Values from 0 to 0.9
    }
  }
  return OH_AI_STATUS_SUCCESS;
}
```
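
Because `rand()` produces the same sequence on every run unless the generator is seeded, it can help to seed it explicitly when reproducible inputs are wanted. The following standalone sketch applies the same fill loop to a plain `float` buffer. `FillRandom` and its `seed` parameter are illustrative names assumed for this example and are not part of the MindSpore Lite API.

```c
#include <stddef.h>
#include <stdlib.h>

// Hypothetical helper: fills buf with n values taken from {0.0, 0.1, ..., 0.9},
// the same range that GenerateInputDataWithRandom produces.
// Returns 0 on success and -1 if buf is NULL.
int FillRandom(float *buf, size_t n, unsigned int seed) {
  if (buf == NULL) {
    return -1;
  }
  const int divisor = 10;
  srand(seed);  // A fixed seed makes the generated inputs reproducible.
  for (size_t i = 0; i < n; ++i) {
    buf[i] = (float)(rand() % divisor) / divisor;
  }
  return 0;
}
```

Calling `FillRandom` twice with the same seed yields identical buffers, which makes it easier to compare two training runs.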

The development process consists of the following main steps:

1. Prepare the required model.

    The prepared model is in `.ms` format. This topic uses [lenet_train.ms](https://gitee.com/openharmony-sig/compatibility/blob/master/test_suite/resource/master/standard%20system/acts/resource/ai/mindspore/lenet_train/lenet_train.ms) as an example. To use a custom model, perform the following steps:

    - Use Python to create a network model based on the MindSpore architecture and export the model as a `.mindir` file. For details, see [Quick Start](https://www.mindspore.cn/tutorials/en/r2.1/beginner/quick_start.html).
    - Convert the `.mindir` model file into an `.ms` file. For details about the conversion procedure, see [Converting MindSpore Lite Models](https://www.mindspore.cn/lite/docs/en/r2.1/use/converter_train.html). The `.ms` file can be imported to the device to implement training based on the MindSpore device framework.

2. Create a context and set parameters such as the device type and training configuration.

    ```c
    // Create and init context, add CPU device info
    OH_AI_ContextHandle context = OH_AI_ContextCreate();
    if (context == NULL) {
        printf("OH_AI_ContextCreate failed.\n");
        return OH_AI_STATUS_LITE_ERROR;
    }

    OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
    if (cpu_device_info == NULL) {
        printf("OH_AI_DeviceInfoCreate failed.\n");
        OH_AI_ContextDestroy(&context);
        return OH_AI_STATUS_LITE_ERROR;
    }
    OH_AI_ContextAddDeviceInfo(context, cpu_device_info);

    // Create trainCfg
    OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
    if (trainCfg == NULL) {
        printf("OH_AI_TrainCfgCreate failed.\n");
        OH_AI_ContextDestroy(&context);
        return OH_AI_STATUS_LITE_ERROR;
    }
    ```

3. Create, load, and build the model.

    Call **OH_AI_TrainModelBuildFromFile** to load and build the model.

    ```c
    // Create model
    OH_AI_ModelHandle model = OH_AI_ModelCreate();
    if (model == NULL) {
        printf("OH_AI_ModelCreate failed.\n");
        OH_AI_TrainCfgDestroy(&trainCfg);
        OH_AI_ContextDestroy(&context);
        return OH_AI_STATUS_LITE_ERROR;
    }

    // Build model
    int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }
    ```

4. Input data.

    Before executing model training, you need to fill the input tensors with data. In this example, random data is used.

    ```c
    // Get Inputs
    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
    if (inputs.handle_list == NULL) {
        // ret still holds the success status of the previous call, so return an explicit error.
        printf("OH_AI_ModelGetInputs failed.\n");
        OH_AI_ModelDestroy(&model);
        return OH_AI_STATUS_LITE_ERROR;
    }

    // Generate random data as input data.
    ret = GenerateInputDataWithRandom(inputs);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }
    ```

5. Execute model training.

    Use **OH_AI_ModelSetTrainMode** to set the training mode and use **OH_AI_RunStep** to run model training.

    ```c
    // Set Train Mode
    ret = OH_AI_ModelSetTrainMode(model, true);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }

    // Model Train Step
    ret = OH_AI_RunStep(model, NULL, NULL);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("OH_AI_RunStep failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }
    printf("Train Step Success.\n");
    ```

6. Export the trained model.

    Use **OH_AI_ExportModel** to export the trained model.

    ```c
    // Export Train Model
    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }
    printf("Export Train Model Success.\n");

    // Export Inference Model
    ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
    if (ret != OH_AI_STATUS_SUCCESS) {
        printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
        OH_AI_ModelDestroy(&model);
        return ret;
    }
    printf("Export Inference Model Success.\n");
    ```

7. Destroy the model.

    If the MindSpore Lite framework is no longer needed, destroy the created model.

    ```c
    // Delete model.
    OH_AI_ModelDestroy(&model);
    ```
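
The snippets in steps 2 to 7 repeat the release calls on every error path. One common C idiom for keeping such paths consistent is a single exit label that releases whatever has been created so far. The sketch below shows the idiom with stand-in resources only: `Resource`, `ResourceCreate`, `ResourceDestroy`, and `RunPipeline` are hypothetical names and are not MindSpore Lite APIs. Whether a handle must still be released after it has been passed to another call depends on the ownership rules of the real API, which this sketch does not model.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdlib.h>

// Stand-in for an opaque handle such as a context or a model (hypothetical).
typedef struct Resource { int id; } Resource;

static int g_live_resources = 0;  // Number of resources not yet destroyed.

static Resource *ResourceCreate(int id) {
  Resource *r = (Resource *)malloc(sizeof(Resource));
  if (r != NULL) {
    r->id = id;
    ++g_live_resources;
  }
  return r;
}

// Takes the address of the handle, as the Destroy functions in this topic do.
// This sketch also resets the handle to NULL so that a second call is harmless.
static void ResourceDestroy(Resource **r) {
  if (r != NULL && *r != NULL) {
    free(*r);
    *r = NULL;
    --g_live_resources;
  }
}

// Creates two resources, runs a step that may fail, and releases everything
// through one exit path. Returns 0 on success and -1 on failure.
int RunPipeline(bool fail_step) {
  int ret = -1;
  Resource *context = NULL;
  Resource *model = NULL;

  context = ResourceCreate(1);
  if (context == NULL) goto cleanup;
  model = ResourceCreate(2);
  if (model == NULL) goto cleanup;
  if (fail_step) goto cleanup;  // Simulated failure of a build or training step.
  ret = 0;

cleanup:
  // Safe for handles that were never created because they are still NULL.
  ResourceDestroy(&model);
  ResourceDestroy(&context);
  return ret;
}
```

With this layout, adding a new step means adding one `goto cleanup` check rather than copying the full list of release calls.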


## Verification

1. Write **CMakeLists.txt**.

    ```cmake
    cmake_minimum_required(VERSION 3.14)
    project(TrainDemo)

    add_executable(train_demo main.c)

    target_link_libraries(
            train_demo
            mindspore_lite_ndk
    )
    ```

   - To use ohos-sdk for cross compilation, you need to set the native toolchain path for the CMake tool as follows: `-DCMAKE_TOOLCHAIN_FILE="/xxx/native/build/cmake/ohos.toolchain.cmake"`.

   - Start cross compilation. When running the compilation command, set **OHOS_NDK** to the native toolchain path.
      ```shell
      mkdir -p build

      cd ./build || exit
      # Set OHOS_NDK to the native toolchain path.
      OHOS_NDK=""
      cmake -G "Unix Makefiles" \
            -S ../ \
            -DCMAKE_TOOLCHAIN_FILE="$OHOS_NDK/build/cmake/ohos.toolchain.cmake" \
            -DOHOS_ARCH=arm64-v8a \
            -DCMAKE_BUILD_TYPE=Release

      make
      ```

2. Run the compiled executable program.

    - Use hdc to connect to the device and push **train_demo** and **lenet_train.ms** to the same directory on the device.
    - Use hdc shell to access the device, go to the directory where **train_demo** is located, and run the following command:

    ```shell
    ./train_demo ./lenet_train.ms export_train_model export_infer_model
    ```

    The operation is successful if the output is similar to the following:

    ```shell
    Train Step Success.
    Export Train Model Success.
    Export Inference Model Success.
    Tensor name: Default/network-WithLossCell/_backbone-LeNet5/fc3-Dense/BiasAdd-op121, tensor size is 80, elements num: 20.
    output data is:
    0.000265 0.000231 0.000254 0.000269 0.000238 0.000228
    ```

    In the directory where **train_demo** is located, you can view the exported model files **export_train_model.ms** and **export_infer_model.ms**.
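
As the exported file names show, the `.ms` suffix is added to the path passed on the command line. When later code needs the full file name, for example to load the exported inference model, the name can be built in a separate buffer instead of modifying the original string. A minimal sketch, where `BuildMsPath` is a hypothetical helper assumed for this example and not part of MindSpore Lite:

```c
#include <stddef.h>
#include <stdio.h>

// Hypothetical helper: writes "<base>.ms" into out.
// Returns 0 on success, or -1 if an argument is invalid or out is too small.
int BuildMsPath(const char *base, char *out, size_t out_size) {
  if (base == NULL || out == NULL || out_size == 0) {
    return -1;
  }
  int written = snprintf(out, out_size, "%s.ms", base);
  if (written < 0 || (size_t)written >= out_size) {
    return -1;  // Encoding error or truncated result.
  }
  return 0;
}
```

Checking the return value of `snprintf` against the buffer size detects a truncated path before it is handed to a file-loading call.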


## Sample

```c
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "mindspore/model.h"

// Fill every input tensor with random values.
int GenerateInputDataWithRandom(OH_AI_TensorHandleArray inputs) {
  for (size_t i = 0; i < inputs.handle_num; ++i) {
    float *input_data = (float *)OH_AI_TensorGetMutableData(inputs.handle_list[i]);
    if (input_data == NULL) {
      printf("OH_AI_TensorGetMutableData failed.\n");
      return OH_AI_STATUS_LITE_ERROR;
    }
    int64_t num = OH_AI_TensorGetElementNum(inputs.handle_list[i]);
    const int divisor = 10;
    for (int64_t j = 0; j < num; j++) {
      input_data[j] = (float)(rand() % divisor) / divisor;  // Values from 0 to 0.9
    }
  }
  return OH_AI_STATUS_SUCCESS;
}

// Load the exported inference model and run one prediction on random input.
int ModelPredict(const char *model_file) {
  // Create and init context, add CPU device info
  OH_AI_ContextHandle context = OH_AI_ContextCreate();
  if (context == NULL) {
    printf("OH_AI_ContextCreate failed.\n");
    return OH_AI_STATUS_LITE_ERROR;
  }

  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
  if (cpu_device_info == NULL) {
    printf("OH_AI_DeviceInfoCreate failed.\n");
    OH_AI_ContextDestroy(&context);
    return OH_AI_STATUS_LITE_ERROR;
  }
  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);

  // Create model
  OH_AI_ModelHandle model = OH_AI_ModelCreate();
  if (model == NULL) {
    printf("OH_AI_ModelCreate failed.\n");
    OH_AI_ContextDestroy(&context);
    return OH_AI_STATUS_LITE_ERROR;
  }

  // Build model
  int ret = OH_AI_ModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_ModelBuildFromFile failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Get Inputs
  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
  if (inputs.handle_list == NULL) {
    // ret still holds the success status of the previous call, so return an explicit error.
    printf("OH_AI_ModelGetInputs failed.\n");
    OH_AI_ModelDestroy(&model);
    return OH_AI_STATUS_LITE_ERROR;
  }

  // Generate random data as input data.
  ret = GenerateInputDataWithRandom(inputs);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Model Predict
  OH_AI_TensorHandleArray outputs;
  ret = OH_AI_ModelPredict(model, inputs, &outputs, NULL, NULL);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_ModelPredict failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Print Output Tensor Data.
  for (size_t i = 0; i < outputs.handle_num; ++i) {
    OH_AI_TensorHandle tensor = outputs.handle_list[i];
    int64_t element_num = OH_AI_TensorGetElementNum(tensor);
    printf("Tensor name: %s, tensor size is %zu, elements num: %lld.\n", OH_AI_TensorGetName(tensor),
           OH_AI_TensorGetDataSize(tensor), (long long)element_num);
    const float *data = (const float *)OH_AI_TensorGetData(tensor);
    printf("output data is:\n");
    const int max_print_num = 50;
    for (int64_t j = 0; j < element_num && j < max_print_num; ++j) {
      printf("%f ", data[j]);
    }
    printf("\n");
  }

  OH_AI_ModelDestroy(&model);
  return OH_AI_STATUS_SUCCESS;
}

int TrainDemo(int argc, const char **argv) {
  if (argc < 4) {
    printf("Model file must be provided.\n");
    printf("Export Train Model path must be provided.\n");
    printf("Export Inference Model path must be provided.\n");
    return OH_AI_STATUS_LITE_ERROR;
  }
  const char *model_file = argv[1];
  const char *export_train_model = argv[2];
  const char *export_infer_model = argv[3];

  // Create and init context, add CPU device info
  OH_AI_ContextHandle context = OH_AI_ContextCreate();
  if (context == NULL) {
    printf("OH_AI_ContextCreate failed.\n");
    return OH_AI_STATUS_LITE_ERROR;
  }

  OH_AI_DeviceInfoHandle cpu_device_info = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
  if (cpu_device_info == NULL) {
    printf("OH_AI_DeviceInfoCreate failed.\n");
    OH_AI_ContextDestroy(&context);
    return OH_AI_STATUS_LITE_ERROR;
  }
  OH_AI_ContextAddDeviceInfo(context, cpu_device_info);

  // Create trainCfg
  OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
  if (trainCfg == NULL) {
    printf("OH_AI_TrainCfgCreate failed.\n");
    OH_AI_ContextDestroy(&context);
    return OH_AI_STATUS_LITE_ERROR;
  }

  // Create model
  OH_AI_ModelHandle model = OH_AI_ModelCreate();
  if (model == NULL) {
    printf("OH_AI_ModelCreate failed.\n");
    OH_AI_TrainCfgDestroy(&trainCfg);
    OH_AI_ContextDestroy(&context);
    return OH_AI_STATUS_LITE_ERROR;
  }

  // Build model
  int ret = OH_AI_TrainModelBuildFromFile(model, model_file, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_TrainModelBuildFromFile failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Get Inputs
  OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
  if (inputs.handle_list == NULL) {
    // ret still holds the success status of the previous call, so return an explicit error.
    printf("OH_AI_ModelGetInputs failed.\n");
    OH_AI_ModelDestroy(&model);
    return OH_AI_STATUS_LITE_ERROR;
  }

  // Generate random data as input data.
  ret = GenerateInputDataWithRandom(inputs);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("GenerateInputDataWithRandom failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Set Train Mode
  ret = OH_AI_ModelSetTrainMode(model, true);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_ModelSetTrainMode failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }

  // Model Train Step
  ret = OH_AI_RunStep(model, NULL, NULL);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_RunStep failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }
  printf("Train Step Success.\n");

  // Export Train Model
  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_train_model, OH_AI_NO_QUANT, false, NULL, 0);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_ExportModel train failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }
  printf("Export Train Model Success.\n");

  // Export Inference Model
  ret = OH_AI_ExportModel(model, OH_AI_MODELTYPE_MINDIR, export_infer_model, OH_AI_NO_QUANT, true, NULL, 0);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("OH_AI_ExportModel inference failed, ret: %d.\n", ret);
    OH_AI_ModelDestroy(&model);
    return ret;
  }
  printf("Export Inference Model Success.\n");

  // Delete model.
  OH_AI_ModelDestroy(&model);

  // Use the exported model to predict. The exported file carries the ".ms" suffix.
  // Build the file name in a local buffer: argv strings must not be extended in place.
  char infer_model_path[256];
  int len = snprintf(infer_model_path, sizeof(infer_model_path), "%s.ms", export_infer_model);
  if (len < 0 || (size_t)len >= sizeof(infer_model_path)) {
    printf("Export Inference Model path is too long.\n");
    return OH_AI_STATUS_LITE_ERROR;
  }
  ret = ModelPredict(infer_model_path);
  if (ret != OH_AI_STATUS_SUCCESS) {
    printf("Exported Model to predict failed, ret: %d.\n", ret);
    return ret;
  }
  return OH_AI_STATUS_SUCCESS;
}

int main(int argc, const char **argv) { return TrainDemo(argc, argv); }

```