1be168c0dSopenharmony_ci/* 2be168c0dSopenharmony_ci * Copyright (c) 2023 Huawei Device Co., Ltd. 3be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 4be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 5be168c0dSopenharmony_ci * You may obtain a copy of the License at 6be168c0dSopenharmony_ci * 7be168c0dSopenharmony_ci * http://www.apache.org/licenses/LICENSE-2.0 8be168c0dSopenharmony_ci * 9be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software 10be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS, 11be168c0dSopenharmony_ci * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12be168c0dSopenharmony_ci * See the License for the specific language governing permissions and 13be168c0dSopenharmony_ci * limitations under the License. 14be168c0dSopenharmony_ci */ 15be168c0dSopenharmony_ci#ifndef OHOS_MINDSPORE_TEST_MODEL_UTILS_H 16be168c0dSopenharmony_ci#define OHOS_MINDSPORE_TEST_MODEL_UTILS_H 17be168c0dSopenharmony_ci 18be168c0dSopenharmony_ci#include "gtest/gtest.h" 19be168c0dSopenharmony_ci#include "include/c_api/context_c.h" 20be168c0dSopenharmony_ci#include "include/c_api/model_c.h" 21be168c0dSopenharmony_ci#include "include/c_api/types_c.h" 22be168c0dSopenharmony_ci#include "include/c_api/status_c.h" 23be168c0dSopenharmony_ci#include "include/c_api/data_type_c.h" 24be168c0dSopenharmony_ci#include "include/c_api/tensor_c.h" 25be168c0dSopenharmony_ci#include "include/c_api/format_c.h" 26be168c0dSopenharmony_ci 27be168c0dSopenharmony_ci// function before callback 28be168c0dSopenharmony_cibool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 29be168c0dSopenharmony_ci const OH_AI_CallBackParam kernelInfo); 30be168c0dSopenharmony_ci// function after callback 31be168c0dSopenharmony_cibool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 32be168c0dSopenharmony_ci const OH_AI_CallBackParam kernelInfo); 33be168c0dSopenharmony_ci// add cpu device info 34be168c0dSopenharmony_civoid AddContextDeviceCPU(OH_AI_ContextHandle context); 35be168c0dSopenharmony_cibool IsNPU(); 36be168c0dSopenharmony_ci// add nnrt device info 37be168c0dSopenharmony_civoid AddContextDeviceNNRT(OH_AI_ContextHandle context); 38be168c0dSopenharmony_ci// fill data to inputs tensor 39be168c0dSopenharmony_civoid FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose); 40be168c0dSopenharmony_ci// compare result after predict 41be168c0dSopenharmony_civoid CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol = 0.01, float rtol = 0.01); 42be168c0dSopenharmony_ci// model build and predict 43be168c0dSopenharmony_civoid ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 44be168c0dSopenharmony_ci OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback); 45be168c0dSopenharmony_civoid ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 46be168c0dSopenharmony_ci bool buildByGraph); 47be168c0dSopenharmony_civoid ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 48be168c0dSopenharmony_ci OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback); 49be168c0dSopenharmony_ci 50be168c0dSopenharmony_ci#endif //OHOS_MINDSPORE_TEST_MODEL_UTILS_H 51