1be168c0dSopenharmony_ci/* 2be168c0dSopenharmony_ci * Copyright (c) 2023 Huawei Device Co., Ltd. 3be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License"); 4be168c0dSopenharmony_ci * you may not use this file except in compliance with the License. 5be168c0dSopenharmony_ci * You may obtain a copy of the License at 6be168c0dSopenharmony_ci * 7be168c0dSopenharmony_ci * http://www.apache.org/licenses/LICENSE-2.0 8be168c0dSopenharmony_ci * 9be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software 10be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS, 11be168c0dSopenharmony_ci * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12be168c0dSopenharmony_ci * See the License for the specific language governing permissions and 13be168c0dSopenharmony_ci * limitations under the License. 14be168c0dSopenharmony_ci */ 15be168c0dSopenharmony_ci 16be168c0dSopenharmony_ci#include "model_utils.h" 17be168c0dSopenharmony_ci#include <securec.h> 18be168c0dSopenharmony_ci#include "gtest/gtest.h" 19be168c0dSopenharmony_ci#include "include/c_api/context_c.h" 20be168c0dSopenharmony_ci#include "include/c_api/model_c.h" 21be168c0dSopenharmony_ci#include "include/c_api/types_c.h" 22be168c0dSopenharmony_ci#include "include/c_api/status_c.h" 23be168c0dSopenharmony_ci#include "include/c_api/data_type_c.h" 24be168c0dSopenharmony_ci#include "include/c_api/tensor_c.h" 25be168c0dSopenharmony_ci#include "include/c_api/format_c.h" 26be168c0dSopenharmony_ci#include "common.h" 27be168c0dSopenharmony_ci 28be168c0dSopenharmony_cistd::string g_testResourcesDir = "/data/test/resource/"; 29be168c0dSopenharmony_ci 30be168c0dSopenharmony_ci// function before callback 31be168c0dSopenharmony_cibool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 32be168c0dSopenharmony_ci const OH_AI_CallBackParam kernelInfo) { 33be168c0dSopenharmony_ci std::cout << "Before forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl; 34be168c0dSopenharmony_ci return true; 35be168c0dSopenharmony_ci} 36be168c0dSopenharmony_ci 37be168c0dSopenharmony_ci// function after callback 38be168c0dSopenharmony_cibool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 39be168c0dSopenharmony_ci const OH_AI_CallBackParam kernelInfo) { 40be168c0dSopenharmony_ci std::cout << "After forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl; 41be168c0dSopenharmony_ci return true; 42be168c0dSopenharmony_ci} 43be168c0dSopenharmony_ci 44be168c0dSopenharmony_ci// add cpu device info 45be168c0dSopenharmony_civoid AddContextDeviceCPU(OH_AI_ContextHandle context) { 46be168c0dSopenharmony_ci OH_AI_DeviceInfoHandle cpuDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 47be168c0dSopenharmony_ci ASSERT_NE(cpuDeviceInfo, nullptr); 48be168c0dSopenharmony_ci OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(cpuDeviceInfo); 49be168c0dSopenharmony_ci printf("==========deviceType:%d\n", deviceType); 50be168c0dSopenharmony_ci ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_CPU); 51be168c0dSopenharmony_ci OH_AI_ContextAddDeviceInfo(context, cpuDeviceInfo); 52be168c0dSopenharmony_ci} 53be168c0dSopenharmony_ci 54be168c0dSopenharmony_cibool IsNPU() { 55be168c0dSopenharmony_ci size_t num = 0; 56be168c0dSopenharmony_ci auto desc = OH_AI_GetAllNNRTDeviceDescs(&num); 57be168c0dSopenharmony_ci if (desc == nullptr) { 58be168c0dSopenharmony_ci return false; 59be168c0dSopenharmony_ci } 60be168c0dSopenharmony_ci auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc); 61be168c0dSopenharmony_ci const std::string npuNamePrefix = "NPU_"; 62be168c0dSopenharmony_ci if (strncmp(npuNamePrefix.c_str(), name, npuNamePrefix.size()) != 0) { 63be168c0dSopenharmony_ci return false; 64be168c0dSopenharmony_ci } 65be168c0dSopenharmony_ci return true; 66be168c0dSopenharmony_ci} 67be168c0dSopenharmony_ci 68be168c0dSopenharmony_ci// add nnrt device info 69be168c0dSopenharmony_civoid AddContextDeviceNNRT(OH_AI_ContextHandle context) { 70be168c0dSopenharmony_ci size_t num = 0; 71be168c0dSopenharmony_ci auto desc = OH_AI_GetAllNNRTDeviceDescs(&num); 72be168c0dSopenharmony_ci if (desc == nullptr) { 73be168c0dSopenharmony_ci return; 74be168c0dSopenharmony_ci } 75be168c0dSopenharmony_ci 76be168c0dSopenharmony_ci std::cout << "found " << num << " nnrt devices" << std::endl; 77be168c0dSopenharmony_ci auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc); 78be168c0dSopenharmony_ci auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc); 79be168c0dSopenharmony_ci auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc); 80be168c0dSopenharmony_ci std::cout << "NNRT device: id = " << id << ", name: " << name << ", type:" << type << std::endl; 81be168c0dSopenharmony_ci 82be168c0dSopenharmony_ci OH_AI_DeviceInfoHandle nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 83be168c0dSopenharmony_ci ASSERT_NE(nnrtDeviceInfo, nullptr); 84be168c0dSopenharmony_ci OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id); 85be168c0dSopenharmony_ci OH_AI_DestroyAllNNRTDeviceDescs(&desc); 86be168c0dSopenharmony_ci 87be168c0dSopenharmony_ci OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo); 88be168c0dSopenharmony_ci printf("==========deviceType:%d\n", deviceType); 89be168c0dSopenharmony_ci ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT); 90be168c0dSopenharmony_ci 91be168c0dSopenharmony_ci OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM); 92be168c0dSopenharmony_ci ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM); 93be168c0dSopenharmony_ci OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM); 94be168c0dSopenharmony_ci ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM); 95be168c0dSopenharmony_ci 96be168c0dSopenharmony_ci OH_AI_ContextAddDeviceInfo(context, nnrtDeviceInfo); 97be168c0dSopenharmony_ci} 98be168c0dSopenharmony_ci 99be168c0dSopenharmony_ci// fill data to inputs tensor 100be168c0dSopenharmony_civoid FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose) { 101be168c0dSopenharmony_ci for (size_t i = 0; i < inputs.handle_num; ++i) { 102be168c0dSopenharmony_ci printf("==========ReadFile==========\n"); 103be168c0dSopenharmony_ci size_t size1; 104be168c0dSopenharmony_ci size_t *ptrSize1 = &size1; 105be168c0dSopenharmony_ci std::string inputDataPath = g_testResourcesDir + modelName + "_" + std::to_string(i) + ".input"; 106be168c0dSopenharmony_ci const char *imagePath = inputDataPath.c_str(); 107be168c0dSopenharmony_ci char *imageBuf = ReadFile(imagePath, ptrSize1); 108be168c0dSopenharmony_ci ASSERT_NE(imageBuf, nullptr); 109be168c0dSopenharmony_ci OH_AI_TensorHandle tensor = inputs.handle_list[i]; 110be168c0dSopenharmony_ci int64_t elementNum = OH_AI_TensorGetElementNum(tensor); 111be168c0dSopenharmony_ci printf("Tensor name: %s. \n", OH_AI_TensorGetName(tensor)); 112be168c0dSopenharmony_ci float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(inputs.handle_list[i])); 113be168c0dSopenharmony_ci ASSERT_NE(inputData, nullptr); 114be168c0dSopenharmony_ci if (isTranspose) { 115be168c0dSopenharmony_ci printf("==========Transpose==========\n"); 116be168c0dSopenharmony_ci size_t shapeNum; 117be168c0dSopenharmony_ci const int64_t *shape = OH_AI_TensorGetShape(tensor, &shapeNum); 118be168c0dSopenharmony_ci auto imageBufNhwc = new char[size1]; 119be168c0dSopenharmony_ci PackNCHWToNHWCFp32(imageBuf, imageBufNhwc, shape[0], shape[1] * shape[2], shape[3]); 120be168c0dSopenharmony_ci errno_t ret = memcpy_s(inputData, size1, imageBufNhwc, size1); 121be168c0dSopenharmony_ci if (ret != EOK) { 122be168c0dSopenharmony_ci printf("memcpy_s failed, ret: %d\n", ret); 123be168c0dSopenharmony_ci } 124be168c0dSopenharmony_ci delete[] imageBufNhwc; 125be168c0dSopenharmony_ci } else { 126be168c0dSopenharmony_ci errno_t ret = memcpy_s(inputData, size1, imageBuf, size1); 127be168c0dSopenharmony_ci if (ret != EOK) { 128be168c0dSopenharmony_ci printf("memcpy_s failed, ret: %d\n", ret); 129be168c0dSopenharmony_ci } 130be168c0dSopenharmony_ci } 131be168c0dSopenharmony_ci printf("input data after filling is: "); 132be168c0dSopenharmony_ci for (int j = 0; j < elementNum && j <= 20; ++j) { 133be168c0dSopenharmony_ci printf("%f ", inputData[j]); 134be168c0dSopenharmony_ci } 135be168c0dSopenharmony_ci printf("\n"); 136be168c0dSopenharmony_ci delete[] imageBuf; 137be168c0dSopenharmony_ci } 138be168c0dSopenharmony_ci} 139be168c0dSopenharmony_ci 140be168c0dSopenharmony_ci// compare result after predict 141be168c0dSopenharmony_civoid CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol, float rtol) { 142be168c0dSopenharmony_ci printf("==========GetOutput==========\n"); 143be168c0dSopenharmony_ci for (size_t i = 0; i < outputs.handle_num; ++i) { 144be168c0dSopenharmony_ci OH_AI_TensorHandle tensor = outputs.handle_list[i]; 145be168c0dSopenharmony_ci int64_t elementNum = OH_AI_TensorGetElementNum(tensor); 146be168c0dSopenharmony_ci printf("Tensor name: %s .\n", OH_AI_TensorGetName(tensor)); 147be168c0dSopenharmony_ci float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor)); 148be168c0dSopenharmony_ci printf("output data is:"); 149be168c0dSopenharmony_ci for (int j = 0; j < elementNum && j <= 20; ++j) { 150be168c0dSopenharmony_ci printf("%f ", outputData[j]); 151be168c0dSopenharmony_ci } 152be168c0dSopenharmony_ci printf("\n"); 153be168c0dSopenharmony_ci printf("==========compFp32WithTData==========\n"); 154be168c0dSopenharmony_ci std::string outputFile = g_testResourcesDir + modelName + std::to_string(i) + ".output"; 155be168c0dSopenharmony_ci bool result = compFp32WithTData(outputData, outputFile, atol, rtol, false); 156be168c0dSopenharmony_ci EXPECT_EQ(result, true); 157be168c0dSopenharmony_ci } 158be168c0dSopenharmony_ci} 159be168c0dSopenharmony_ci 160be168c0dSopenharmony_ci// model build and predict 161be168c0dSopenharmony_civoid ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 162be168c0dSopenharmony_ci OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) { 163be168c0dSopenharmony_ci std::string modelPath = g_testResourcesDir + modelName + ".ms"; 164be168c0dSopenharmony_ci const char *graphPath = modelPath.c_str(); 165be168c0dSopenharmony_ci OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 166be168c0dSopenharmony_ci if (buildByGraph) { 167be168c0dSopenharmony_ci printf("==========Build model by graphBuf==========\n"); 168be168c0dSopenharmony_ci size_t size; 169be168c0dSopenharmony_ci size_t *ptrSize = &size; 170be168c0dSopenharmony_ci char *graphBuf = ReadFile(graphPath, ptrSize); 171be168c0dSopenharmony_ci ASSERT_NE(graphBuf, nullptr); 172be168c0dSopenharmony_ci ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context); 173be168c0dSopenharmony_ci delete[] graphBuf; 174be168c0dSopenharmony_ci } else { 175be168c0dSopenharmony_ci printf("==========Build model==========\n"); 176be168c0dSopenharmony_ci ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context); 177be168c0dSopenharmony_ci } 178be168c0dSopenharmony_ci printf("==========build model return code:%d\n", ret); 179be168c0dSopenharmony_ci ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 180be168c0dSopenharmony_ci printf("==========GetInputs==========\n"); 181be168c0dSopenharmony_ci OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model); 182be168c0dSopenharmony_ci ASSERT_NE(inputs.handle_list, nullptr); 183be168c0dSopenharmony_ci if (shapeInfos.shape_num != 0) { 184be168c0dSopenharmony_ci printf("==========Resizes==========\n"); 185be168c0dSopenharmony_ci OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num); 186be168c0dSopenharmony_ci printf("==========Resizes return code:%d\n", resize_ret); 187be168c0dSopenharmony_ci ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS); 188be168c0dSopenharmony_ci } 189be168c0dSopenharmony_ci 190be168c0dSopenharmony_ci FillInputsData(inputs, modelName, isTranspose); 191be168c0dSopenharmony_ci OH_AI_TensorHandleArray outputs; 192be168c0dSopenharmony_ci OH_AI_Status predictRet = OH_AI_STATUS_SUCCESS; 193be168c0dSopenharmony_ci if (isCallback) { 194be168c0dSopenharmony_ci printf("==========Model Predict Callback==========\n"); 195be168c0dSopenharmony_ci OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback; 196be168c0dSopenharmony_ci OH_AI_KernelCallBack afterCallBack = PrintAfterCallback; 197be168c0dSopenharmony_ci predictRet = OH_AI_ModelPredict(model, inputs, &outputs, beforeCallBack, afterCallBack); 198be168c0dSopenharmony_ci } else { 199be168c0dSopenharmony_ci printf("==========Model Predict==========\n"); 200be168c0dSopenharmony_ci predictRet = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr); 201be168c0dSopenharmony_ci } 202be168c0dSopenharmony_ci printf("==========Model Predict End==========\n"); 203be168c0dSopenharmony_ci ASSERT_EQ(predictRet, OH_AI_STATUS_SUCCESS); 204be168c0dSopenharmony_ci printf("=========CompareResult===========\n"); 205be168c0dSopenharmony_ci CompareResult(outputs, modelName); 206be168c0dSopenharmony_ci printf("=========OH_AI_ModelDestroy===========\n"); 207be168c0dSopenharmony_ci OH_AI_ModelDestroy(&model); 208be168c0dSopenharmony_ci printf("=========OH_AI_ModelDestroy End===========\n"); 209be168c0dSopenharmony_ci} 210be168c0dSopenharmony_ci 211be168c0dSopenharmony_ci// add invalid model_type 212be168c0dSopenharmony_civoid ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 213be168c0dSopenharmony_ci bool buildByGraph) { 214be168c0dSopenharmony_ci std::string modelPath = g_testResourcesDir + modelName + ".ms"; 215be168c0dSopenharmony_ci const char *graphPath = modelPath.c_str(); 216be168c0dSopenharmony_ci OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 217be168c0dSopenharmony_ci if (buildByGraph) { 218be168c0dSopenharmony_ci printf("==========Build model by graphBuf==========\n"); 219be168c0dSopenharmony_ci size_t size; 220be168c0dSopenharmony_ci size_t *ptrSize = &size; 221be168c0dSopenharmony_ci char *graphBuf = ReadFile(graphPath, ptrSize); 222be168c0dSopenharmony_ci ASSERT_NE(graphBuf, nullptr); 223be168c0dSopenharmony_ci ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_INVALID, context); 224be168c0dSopenharmony_ci if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) { 225be168c0dSopenharmony_ci printf("OH_AI_ModelBuild failed due to model_type is OH_AI_MODELTYPE_INVALID.\n"); 226be168c0dSopenharmony_ci } 227be168c0dSopenharmony_ci delete[] graphBuf; 228be168c0dSopenharmony_ci } else { 229be168c0dSopenharmony_ci printf("==========Build model==========\n"); 230be168c0dSopenharmony_ci ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_INVALID, context); 231be168c0dSopenharmony_ci if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) { 232be168c0dSopenharmony_ci printf("OH_AI_ModelBuildFromFile failed due to model_type is OH_AI_MODELTYPE_INVALID.\n"); 233be168c0dSopenharmony_ci } 234be168c0dSopenharmony_ci } 235be168c0dSopenharmony_ci printf("==========build model return code:%d\n", ret); 236be168c0dSopenharmony_ci} 237be168c0dSopenharmony_ci 238be168c0dSopenharmony_ci// model train build and predict 239be168c0dSopenharmony_civoid ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 240be168c0dSopenharmony_ci OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) { 241be168c0dSopenharmony_ci std::string modelPath = g_testResourcesDir + modelName + ".ms"; 242be168c0dSopenharmony_ci const char *graphPath = modelPath.c_str(); 243be168c0dSopenharmony_ci OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate(); 244be168c0dSopenharmony_ci OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 245be168c0dSopenharmony_ci if (buildByGraph) { 246be168c0dSopenharmony_ci printf("==========Build model by graphBuf==========\n"); 247be168c0dSopenharmony_ci size_t size; 248be168c0dSopenharmony_ci size_t *ptrSize = &size; 249be168c0dSopenharmony_ci char *graphBuf = ReadFile(graphPath, ptrSize); 250be168c0dSopenharmony_ci ASSERT_NE(graphBuf, nullptr); 251be168c0dSopenharmony_ci ret = OH_AI_TrainModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context, trainCfg); 252be168c0dSopenharmony_ci delete[] graphBuf; 253be168c0dSopenharmony_ci } else { 254be168c0dSopenharmony_ci printf("==========Build model==========\n"); 255be168c0dSopenharmony_ci ret = OH_AI_TrainModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context, trainCfg); 256be168c0dSopenharmony_ci } 257be168c0dSopenharmony_ci printf("==========build model return code:%d\n", ret); 258be168c0dSopenharmony_ci ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 259be168c0dSopenharmony_ci printf("==========GetInputs==========\n"); 260be168c0dSopenharmony_ci OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model); 261be168c0dSopenharmony_ci ASSERT_NE(inputs.handle_list, nullptr); 262be168c0dSopenharmony_ci if (shapeInfos.shape_num != 0) { 263be168c0dSopenharmony_ci printf("==========Resizes==========\n"); 264be168c0dSopenharmony_ci OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num); 265be168c0dSopenharmony_ci printf("==========Resizes return code:%d\n", resize_ret); 266be168c0dSopenharmony_ci ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS); 267be168c0dSopenharmony_ci } 268be168c0dSopenharmony_ci FillInputsData(inputs, modelName, isTranspose); 269be168c0dSopenharmony_ci ret = OH_AI_ModelSetTrainMode(model, true); 270be168c0dSopenharmony_ci ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 271be168c0dSopenharmony_ci if (isCallback) { 272be168c0dSopenharmony_ci printf("==========Model RunStep Callback==========\n"); 273be168c0dSopenharmony_ci OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback; 274be168c0dSopenharmony_ci OH_AI_KernelCallBack afterCallBack = PrintAfterCallback; 275be168c0dSopenharmony_ci ret = OH_AI_RunStep(model, beforeCallBack, afterCallBack); 276be168c0dSopenharmony_ci } else { 277be168c0dSopenharmony_ci printf("==========Model RunStep==========\n"); 278be168c0dSopenharmony_ci ret = OH_AI_RunStep(model, nullptr, nullptr); 279be168c0dSopenharmony_ci } 280be168c0dSopenharmony_ci printf("==========Model RunStep End==========\n"); 281be168c0dSopenharmony_ci ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 282be168c0dSopenharmony_ci} 283