1/**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17/**
18 * @addtogroup MindSpore
19 * @{
20 *
21 * @brief 提供MindSpore Lite的模型推理相关接口。
22 *
23 * @Syscap SystemCapability.Ai.MindSpore
24 * @since 9
25 */
26
27/**
28 * @file model.h
29 *
30 * @brief 提供了模型相关接口,可以用于模型创建、模型推理等。
31 *
32 * @library libmindspore_lite_ndk.so
33 * @since 9
34 */
35#ifndef MINDSPORE_INCLUDE_C_API_MODEL_C_H
36#define MINDSPORE_INCLUDE_C_API_MODEL_C_H
37
38#include "mindspore/tensor.h"
39#include "mindspore/context.h"
40#include "mindspore/status.h"
41
42#ifdef __cplusplus
43extern "C" {
44#endif
45
/** Opaque handle to a model object; obtained from OH_AI_ModelCreate and released with OH_AI_ModelDestroy. */
typedef void *OH_AI_ModelHandle;

/** Opaque handle to a training configuration; obtained from OH_AI_TrainCfgCreate and released with OH_AI_TrainCfgDestroy. */
typedef void *OH_AI_TrainCfgHandle;
49
/**
 * @brief An array of tensor handles, e.g. all inputs or all outputs of a model.
 */
typedef struct OH_AI_TensorHandleArray {
  size_t handle_num;                // Number of elements in handle_list.
  OH_AI_TensorHandle *handle_list;  // Pointer to the tensor handles.
} OH_AI_TensorHandleArray;
54
/** Maximum number of dimensions a tensor shape may carry in OH_AI_ShapeInfo. */
#define OH_AI_MAX_SHAPE_NUM 32
/**
 * @brief Describes a tensor shape; only the first shape_num entries of shape are valid.
 */
typedef struct OH_AI_ShapeInfo {
  size_t shape_num;                    // Number of valid dimensions in shape.
  int64_t shape[OH_AI_MAX_SHAPE_NUM];  // Dimension sizes, in order.
} OH_AI_ShapeInfo;
60
/**
 * @brief Identifies the kernel (graph node) for which an OH_AI_KernelCallBack is invoked.
 */
typedef struct OH_AI_CallBackParam {
  char *node_name;  // Name of the node being executed.
  char *node_type;  // Type of the node being executed.
} OH_AI_CallBackParam;
65
/**
 * @brief Per-kernel callback invoked around kernel execution (see OH_AI_ModelPredict and OH_AI_RunStep).
 *
 * Receives the kernel's input and output tensors and identifying information.
 * NOTE(review): the bool return value presumably tells the runtime whether to continue
 * execution — confirm against the runtime implementation.
 */
typedef bool (*OH_AI_KernelCallBack)(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
                                     const OH_AI_CallBackParam kernel_Info);
68
69/**
70 * @brief Create a model object.
71 * @return Model object handle.
72 * @since 9
73 */
74OH_AI_API OH_AI_ModelHandle OH_AI_ModelCreate();
75
/**
 * @brief Destroy the model object and release its resources.
 * @param model Address of the model object handle; set to NULL-safe invalid state after destruction
 *        (NOTE(review): presumably reset by the implementation — confirm).
 * @since 9
 */
OH_AI_API void OH_AI_ModelDestroy(OH_AI_ModelHandle *model);
82
/**
 * @brief Build the model from a model file buffer so that it can run on a device.
 * @param model Model object handle.
 * @param model_data Buffer holding the contents of a model file.
 * @param data_size Size of the model file buffer, in bytes.
 * @param model_type Type of the model file.
 * @param model_context Context used to store options during execution.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                        OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context);
95
/**
 * @brief Load and build the model from a model file path so that it can run on a device.
 * @param model Model object handle.
 * @param model_path Path of the model file.
 * @param model_type Type of the model file.
 * @param model_context Context used to store options during execution.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
                                                OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context);
107
/**
 * @brief Resize the shapes of the model inputs.
 * @param model Model object handle.
 * @param inputs Array that includes all input tensor handles.
 * @param shape_infos New shapes of the inputs; must be consistent with inputs, element for element.
 * @param shape_info_num Number of entries in shape_infos.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelResize(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                         OH_AI_ShapeInfo *shape_infos, size_t shape_info_num);
119
/**
 * @brief Run model inference.
 * @param model Model object handle.
 * @param inputs Array that includes all input tensor handles.
 * @param outputs Receives the array that includes all output tensor handles.
 * @param before Callback invoked before each kernel runs; may be NULL
 *        (NOTE(review): NULL-tolerance presumed — confirm with implementation).
 * @param after Callback invoked after each kernel runs; may be NULL (same note as above).
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 9
 */
OH_AI_API OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray inputs,
                                          OH_AI_TensorHandleArray *outputs, const OH_AI_KernelCallBack before,
                                          const OH_AI_KernelCallBack after);
133
/**
 * @brief Obtain all input tensor handles of the model.
 * @param model Model object handle.
 * @return Array that includes all input tensor handles.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetInputs(const OH_AI_ModelHandle model);
141
/**
 * @brief Obtain all output tensor handles of the model.
 * @param model Model object handle.
 * @return Array that includes all output tensor handles.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetOutputs(const OH_AI_ModelHandle model);
149
/**
 * @brief Obtain an input tensor handle of the model by name.
 * @param model Model object handle.
 * @param tensor_name Name of the tensor to look up.
 * @return The input tensor handle with the given name; NULL if the name is not found.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetInputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);
158
/**
 * @brief Obtain an output tensor handle of the model by name.
 * @param model Model object handle.
 * @param tensor_name Name of the tensor to look up.
 * @return The output tensor handle with the given name; NULL if the name is not found.
 * @since 9
 */
OH_AI_API OH_AI_TensorHandle OH_AI_ModelGetOutputByTensorName(const OH_AI_ModelHandle model, const char *tensor_name);
167
168/**
169 * @brief Create a TrainCfg object. Only valid for Lite Train.
170 * @return TrainCfg object handle.
171 * @since 11
172 */
173OH_AI_API OH_AI_TrainCfgHandle OH_AI_TrainCfgCreate();
174
/**
 * @brief Destroy the TrainCfg object. Only valid for Lite Train.
 * @param train_cfg Address of the TrainCfg object handle.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgDestroy(OH_AI_TrainCfgHandle *train_cfg);
181
/**
 * @brief Obtain the name fragments that identify loss kernels. Only valid for Lite Train.
 * @param train_cfg TrainCfg object handle.
 * @param num Receives the number of entries in the returned array.
 * @return Array of loss-name strings.
 *         NOTE(review): ownership of the returned array is not visible here — confirm
 *         whether the caller must free it.
 * @since 11
 */
OH_AI_API char **OH_AI_TrainCfgGetLossName(OH_AI_TrainCfgHandle train_cfg, size_t *num);
190
/**
 * @brief Set the name fragments that identify loss kernels. Only valid for Lite Train.
 * @param train_cfg TrainCfg object handle.
 * @param loss_name Array of name fragments that identify loss kernels.
 * @param num Number of entries in loss_name.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgSetLossName(OH_AI_TrainCfgHandle train_cfg, const char **loss_name, size_t num);
199
/**
 * @brief Obtain the optimization level of the TrainCfg. Only valid for Lite Train.
 * @param train_cfg TrainCfg object handle.
 * @return The current OH_AI_OptimizationLevel.
 * @since 11
 */
OH_AI_API OH_AI_OptimizationLevel OH_AI_TrainCfgGetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg);
207
/**
 * @brief Set the optimization level of the TrainCfg. Only valid for Lite Train.
 * @param train_cfg TrainCfg object handle.
 * @param level Optimization level to apply.
 * @since 11
 */
OH_AI_API void OH_AI_TrainCfgSetOptimizationLevel(OH_AI_TrainCfgHandle train_cfg, OH_AI_OptimizationLevel level);
215
/**
 * @brief Build a train model from a model buffer so that it can run on a device. Only valid for Lite Train.
 * @param model Model object handle.
 * @param model_data Buffer holding the contents of a model file.
 * @param data_size Size of the model file buffer, in bytes.
 * @param model_type Type of the model file.
 * @param model_context Context used to store options during execution.
 * @param train_cfg Configuration used for training.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_TrainModelBuild(OH_AI_ModelHandle model, const void *model_data, size_t data_size,
                                             OH_AI_ModelType model_type, const OH_AI_ContextHandle model_context,
                                             const OH_AI_TrainCfgHandle train_cfg);
230
/**
 * @brief Build a train model from a model file path so that it can run on a device. Only valid for Lite Train.
 * @param model Model object handle.
 * @param model_path Path of the model file.
 * @param model_type Type of the model file.
 * @param model_context Context used to store options during execution.
 * @param train_cfg Configuration used for training.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_TrainModelBuildFromFile(OH_AI_ModelHandle model, const char *model_path,
                                                     OH_AI_ModelType model_type,
                                                     const OH_AI_ContextHandle model_context,
                                                     const OH_AI_TrainCfgHandle train_cfg);
245
/**
 * @brief Run one training step of the model. Only valid for Lite Train.
 * @param model Model object handle.
 * @param before Callback invoked before each kernel runs; may be NULL
 *        (NOTE(review): NULL-tolerance presumed — confirm with implementation).
 * @param after Callback invoked after each kernel runs; may be NULL (same note as above).
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_RunStep(OH_AI_ModelHandle model, const OH_AI_KernelCallBack before,
                                     const OH_AI_KernelCallBack after);
256
/**
 * @brief Set the learning rate for training. Only valid for Lite Train.
 * @param model Model object handle.
 * @param learning_rate The learning rate to set.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetLearningRate(OH_AI_ModelHandle model, float learning_rate);
264
/**
 * @brief Obtain the learning rate of the optimizer. Only valid for Lite Train.
 * @param model Model object handle.
 * @return Learning rate; 0.0 if no optimizer was found.
 * @since 11
 */
OH_AI_API float OH_AI_ModelGetLearningRate(OH_AI_ModelHandle model);
272
/**
 * @brief Obtain all weight tensors of the model. Only valid for Lite Train.
 * @param model Model object handle.
 * @return Array that includes all weight tensors.
 * @since 11
 */
OH_AI_API OH_AI_TensorHandleArray OH_AI_ModelGetWeights(OH_AI_ModelHandle model);
280
/**
 * @brief Update the weight tensors of the model. Only valid for Lite Train.
 * @param model Model object handle.
 * @param new_weights Array of new weight tensors.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelUpdateWeights(OH_AI_ModelHandle model, const OH_AI_TensorHandleArray new_weights);
288
/**
 * @brief Obtain the model running mode.
 * @param model Model object handle.
 * @return true if the model is in Train Mode; false otherwise.
 * @since 11
 */
OH_AI_API bool OH_AI_ModelGetTrainMode(OH_AI_ModelHandle model);
296
/**
 * @brief Set the model running mode. Only valid for Lite Train.
 * @param model Model object handle.
 * @param train true to run the model in Train Mode, false for Eval Mode.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetTrainMode(OH_AI_ModelHandle model, bool train);
305
/**
 * @brief Set up training with virtual batches. Only valid for Lite Train.
 * @param model Model object handle.
 * @param virtual_batch_multiplier Virtual batch multiplier; use any number < 1 to disable.
 * @param lr Learning rate to use for the virtual batch; -1 for internal configuration.
 * @param momentum Batch-norm momentum to use for the virtual batch; -1 for internal configuration.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ModelSetupVirtualBatch(OH_AI_ModelHandle model, int virtual_batch_multiplier, float lr,
                                                    float momentum);
317
/**
 * @brief Export a training model to a file. Only valid for Lite Train.
 * @param model Model object handle.
 * @param model_type Type of the model file.
 * @param model_file Path of the exported model file.
 * @param quantization_type Quantization type to apply on export.
 * @param export_inference_only Whether to export an inference-only model.
 * @param output_tensor_name Names of the output tensors of the exported inference model;
 *        pass empty to export the complete inference model.
 * @param num Number of entries in output_tensor_name.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportModel(OH_AI_ModelHandle model, OH_AI_ModelType model_type, const char *model_file,
                                         OH_AI_QuantizationType quantization_type, bool export_inference_only,
                                         char **output_tensor_name, size_t num);
334
/**
 * @brief Export a training model to a buffer. Only valid for Lite Train.
 * @param model Model object handle.
 * @param model_type Type of the model file.
 * @param model_data Receives the exported model buffer.
 *        NOTE(review): ownership of the returned buffer is not visible here — confirm
 *        whether the caller must free it.
 * @param data_size Receives the exported model buffer size, in bytes.
 * @param quantization_type Quantization type to apply on export.
 * @param export_inference_only Whether to export an inference-only model.
 * @param output_tensor_name Names of the output tensors of the exported inference model;
 *        pass empty to export the complete inference model.
 * @param num Number of entries in output_tensor_name.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportModelBuffer(OH_AI_ModelHandle model, OH_AI_ModelType model_type, char **model_data,
                                                size_t *data_size, OH_AI_QuantizationType quantization_type,
                                                bool export_inference_only, char **output_tensor_name, size_t num);
352
/**
 * @brief Export the model's weights for use with micro only. Only valid for Lite Train.
 * @param model Model object handle.
 * @param model_type Type of the model file.
 * @param weight_file Path of the exported weight file.
 * @param is_inference Whether to export weights from an inference model. Currently, only `true` is supported.
 * @param enable_fp16 Whether float weights are saved in float16 format.
 * @param changeable_weights_name Names of the weight tensors whose shape is changeable.
 * @param num Number of entries in changeable_weights_name.
 * @return OH_AI_Status indicating success or the failure reason.
 * @since 11
 */
OH_AI_API OH_AI_Status OH_AI_ExportWeightsCollaborateWithMicro(OH_AI_ModelHandle model, OH_AI_ModelType model_type,
                                                               const char *weight_file, bool is_inference,
                                                               bool enable_fp16, char **changeable_weights_name,
                                                               size_t num);
369
370#ifdef __cplusplus
371}
372#endif
373#endif  // MINDSPORE_INCLUDE_C_API_MODEL_C_H
374