1be168c0dSopenharmony_ci/*
2be168c0dSopenharmony_ci * Copyright (c) 2023 Huawei Device Co., Ltd.
3be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License");
4be168c0dSopenharmony_ci * you may not use this file except in compliance with the License.
5be168c0dSopenharmony_ci * You may obtain a copy of the License at
6be168c0dSopenharmony_ci *
7be168c0dSopenharmony_ci *     http://www.apache.org/licenses/LICENSE-2.0
8be168c0dSopenharmony_ci *
9be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software
10be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS,
11be168c0dSopenharmony_ci * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12be168c0dSopenharmony_ci * See the License for the specific language governing permissions and
13be168c0dSopenharmony_ci * limitations under the License.
14be168c0dSopenharmony_ci */
15be168c0dSopenharmony_ci#ifndef OHOS_MINDSPORE_TEST_MODEL_UTILS_H
16be168c0dSopenharmony_ci#define OHOS_MINDSPORE_TEST_MODEL_UTILS_H
17be168c0dSopenharmony_ci
#include <string>

#include "gtest/gtest.h"
#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"
#include "include/c_api/types_c.h"
#include "include/c_api/status_c.h"
#include "include/c_api/data_type_c.h"
#include "include/c_api/tensor_c.h"
#include "include/c_api/format_c.h"
26be168c0dSopenharmony_ci
27be168c0dSopenharmony_ci// function before callback
28be168c0dSopenharmony_cibool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
29be168c0dSopenharmony_ci                         const OH_AI_CallBackParam kernelInfo);
30be168c0dSopenharmony_ci// function after callback
31be168c0dSopenharmony_cibool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
32be168c0dSopenharmony_ci                        const OH_AI_CallBackParam kernelInfo);
33be168c0dSopenharmony_ci// add cpu device info
34be168c0dSopenharmony_civoid AddContextDeviceCPU(OH_AI_ContextHandle context);
35be168c0dSopenharmony_cibool IsNPU();
36be168c0dSopenharmony_ci// add nnrt device info
37be168c0dSopenharmony_civoid AddContextDeviceNNRT(OH_AI_ContextHandle context);
38be168c0dSopenharmony_ci// fill data to inputs tensor
39be168c0dSopenharmony_civoid FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose);
40be168c0dSopenharmony_ci// compare result after predict
41be168c0dSopenharmony_civoid CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol = 0.01, float rtol = 0.01);
42be168c0dSopenharmony_ci// model build and predict
43be168c0dSopenharmony_civoid ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
44be168c0dSopenharmony_ci                  OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback);
45be168c0dSopenharmony_civoid ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
46be168c0dSopenharmony_ci                             bool buildByGraph);
47be168c0dSopenharmony_civoid ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
48be168c0dSopenharmony_ci                OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback);
49be168c0dSopenharmony_ci
50be168c0dSopenharmony_ci#endif //OHOS_MINDSPORE_TEST_MODEL_UTILS_H
51