1/* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include "model_utils.h" 17#include <securec.h> 18#include "gtest/gtest.h" 19#include "include/c_api/context_c.h" 20#include "include/c_api/model_c.h" 21#include "include/c_api/types_c.h" 22#include "include/c_api/status_c.h" 23#include "include/c_api/data_type_c.h" 24#include "include/c_api/tensor_c.h" 25#include "include/c_api/format_c.h" 26#include "common.h" 27 28std::string g_testResourcesDir = "/data/test/resource/"; 29 30// function before callback 31bool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 32 const OH_AI_CallBackParam kernelInfo) { 33 std::cout << "Before forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl; 34 return true; 35} 36 37// function after callback 38bool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs, 39 const OH_AI_CallBackParam kernelInfo) { 40 std::cout << "After forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl; 41 return true; 42} 43 44// add cpu device info 45void AddContextDeviceCPU(OH_AI_ContextHandle context) { 46 OH_AI_DeviceInfoHandle cpuDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU); 47 ASSERT_NE(cpuDeviceInfo, nullptr); 48 OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(cpuDeviceInfo); 49 printf("==========deviceType:%d\n", deviceType); 50 ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_CPU); 51 OH_AI_ContextAddDeviceInfo(context, cpuDeviceInfo); 52} 53 54bool IsNPU() { 55 size_t num = 0; 56 auto desc = OH_AI_GetAllNNRTDeviceDescs(&num); 57 if (desc == nullptr) { 58 return false; 59 } 60 auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc); 61 const std::string npuNamePrefix = "NPU_"; 62 if (strncmp(npuNamePrefix.c_str(), name, npuNamePrefix.size()) != 0) { 63 return false; 64 } 65 return true; 66} 67 68// add nnrt device info 69void AddContextDeviceNNRT(OH_AI_ContextHandle context) { 70 size_t num = 0; 71 auto desc = OH_AI_GetAllNNRTDeviceDescs(&num); 72 if (desc == nullptr) { 73 return; 74 } 75 76 std::cout << "found " << num << " nnrt devices" << std::endl; 77 auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc); 78 auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc); 79 auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc); 80 std::cout << "NNRT device: id = " << id << ", name: " << name << ", type:" << type << std::endl; 81 82 OH_AI_DeviceInfoHandle nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT); 83 ASSERT_NE(nnrtDeviceInfo, nullptr); 84 OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id); 85 OH_AI_DestroyAllNNRTDeviceDescs(&desc); 86 87 OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo); 88 printf("==========deviceType:%d\n", deviceType); 89 ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT); 90 91 OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM); 92 ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM); 93 OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM); 94 ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM); 95 96 OH_AI_ContextAddDeviceInfo(context, nnrtDeviceInfo); 97} 98 99// fill data to inputs tensor 100void FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose) { 101 for (size_t i = 0; i < inputs.handle_num; ++i) { 102 printf("==========ReadFile==========\n"); 103 size_t size1; 104 size_t *ptrSize1 = &size1; 105 std::string inputDataPath = g_testResourcesDir + modelName + "_" + std::to_string(i) + ".input"; 106 const char *imagePath = inputDataPath.c_str(); 107 char *imageBuf = ReadFile(imagePath, ptrSize1); 108 ASSERT_NE(imageBuf, nullptr); 109 OH_AI_TensorHandle tensor = inputs.handle_list[i]; 110 int64_t elementNum = OH_AI_TensorGetElementNum(tensor); 111 printf("Tensor name: %s. \n", OH_AI_TensorGetName(tensor)); 112 float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(inputs.handle_list[i])); 113 ASSERT_NE(inputData, nullptr); 114 if (isTranspose) { 115 printf("==========Transpose==========\n"); 116 size_t shapeNum; 117 const int64_t *shape = OH_AI_TensorGetShape(tensor, &shapeNum); 118 auto imageBufNhwc = new char[size1]; 119 PackNCHWToNHWCFp32(imageBuf, imageBufNhwc, shape[0], shape[1] * shape[2], shape[3]); 120 errno_t ret = memcpy_s(inputData, size1, imageBufNhwc, size1); 121 if (ret != EOK) { 122 printf("memcpy_s failed, ret: %d\n", ret); 123 } 124 delete[] imageBufNhwc; 125 } else { 126 errno_t ret = memcpy_s(inputData, size1, imageBuf, size1); 127 if (ret != EOK) { 128 printf("memcpy_s failed, ret: %d\n", ret); 129 } 130 } 131 printf("input data after filling is: "); 132 for (int j = 0; j < elementNum && j <= 20; ++j) { 133 printf("%f ", inputData[j]); 134 } 135 printf("\n"); 136 delete[] imageBuf; 137 } 138} 139 140// compare result after predict 141void CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol, float rtol) { 142 printf("==========GetOutput==========\n"); 143 for (size_t i = 0; i < outputs.handle_num; ++i) { 144 OH_AI_TensorHandle tensor = outputs.handle_list[i]; 145 int64_t elementNum = OH_AI_TensorGetElementNum(tensor); 146 printf("Tensor name: %s .\n", OH_AI_TensorGetName(tensor)); 147 float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor)); 148 printf("output data is:"); 149 for (int j = 0; j < elementNum && j <= 20; ++j) { 150 printf("%f ", outputData[j]); 151 } 152 printf("\n"); 153 printf("==========compFp32WithTData==========\n"); 154 std::string outputFile = g_testResourcesDir + modelName + std::to_string(i) + ".output"; 155 bool result = compFp32WithTData(outputData, outputFile, atol, rtol, false); 156 EXPECT_EQ(result, true); 157 } 158} 159 160// model build and predict 161void ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 162 OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) { 163 std::string modelPath = g_testResourcesDir + modelName + ".ms"; 164 const char *graphPath = modelPath.c_str(); 165 OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 166 if (buildByGraph) { 167 printf("==========Build model by graphBuf==========\n"); 168 size_t size; 169 size_t *ptrSize = &size; 170 char *graphBuf = ReadFile(graphPath, ptrSize); 171 ASSERT_NE(graphBuf, nullptr); 172 ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context); 173 delete[] graphBuf; 174 } else { 175 printf("==========Build model==========\n"); 176 ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context); 177 } 178 printf("==========build model return code:%d\n", ret); 179 ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 180 printf("==========GetInputs==========\n"); 181 OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model); 182 ASSERT_NE(inputs.handle_list, nullptr); 183 if (shapeInfos.shape_num != 0) { 184 printf("==========Resizes==========\n"); 185 OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num); 186 printf("==========Resizes return code:%d\n", resize_ret); 187 ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS); 188 } 189 190 FillInputsData(inputs, modelName, isTranspose); 191 OH_AI_TensorHandleArray outputs; 192 OH_AI_Status predictRet = OH_AI_STATUS_SUCCESS; 193 if (isCallback) { 194 printf("==========Model Predict Callback==========\n"); 195 OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback; 196 OH_AI_KernelCallBack afterCallBack = PrintAfterCallback; 197 predictRet = OH_AI_ModelPredict(model, inputs, &outputs, beforeCallBack, afterCallBack); 198 } else { 199 printf("==========Model Predict==========\n"); 200 predictRet = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr); 201 } 202 printf("==========Model Predict End==========\n"); 203 ASSERT_EQ(predictRet, OH_AI_STATUS_SUCCESS); 204 printf("=========CompareResult===========\n"); 205 CompareResult(outputs, modelName); 206 printf("=========OH_AI_ModelDestroy===========\n"); 207 OH_AI_ModelDestroy(&model); 208 printf("=========OH_AI_ModelDestroy End===========\n"); 209} 210 211// add invalid model_type 212void ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 213 bool buildByGraph) { 214 std::string modelPath = g_testResourcesDir + modelName + ".ms"; 215 const char *graphPath = modelPath.c_str(); 216 OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 217 if (buildByGraph) { 218 printf("==========Build model by graphBuf==========\n"); 219 size_t size; 220 size_t *ptrSize = &size; 221 char *graphBuf = ReadFile(graphPath, ptrSize); 222 ASSERT_NE(graphBuf, nullptr); 223 ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_INVALID, context); 224 if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) { 225 printf("OH_AI_ModelBuild failed due to model_type is OH_AI_MODELTYPE_INVALID.\n"); 226 } 227 delete[] graphBuf; 228 } else { 229 printf("==========Build model==========\n"); 230 ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_INVALID, context); 231 if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) { 232 printf("OH_AI_ModelBuildFromFile failed due to model_type is OH_AI_MODELTYPE_INVALID.\n"); 233 } 234 } 235 printf("==========build model return code:%d\n", ret); 236} 237 238// model train build and predict 239void ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName, 240 OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) { 241 std::string modelPath = g_testResourcesDir + modelName + ".ms"; 242 const char *graphPath = modelPath.c_str(); 243 OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate(); 244 OH_AI_Status ret = OH_AI_STATUS_SUCCESS; 245 if (buildByGraph) { 246 printf("==========Build model by graphBuf==========\n"); 247 size_t size; 248 size_t *ptrSize = &size; 249 char *graphBuf = ReadFile(graphPath, ptrSize); 250 ASSERT_NE(graphBuf, nullptr); 251 ret = OH_AI_TrainModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context, trainCfg); 252 delete[] graphBuf; 253 } else { 254 printf("==========Build model==========\n"); 255 ret = OH_AI_TrainModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context, trainCfg); 256 } 257 printf("==========build model return code:%d\n", ret); 258 ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 259 printf("==========GetInputs==========\n"); 260 OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model); 261 ASSERT_NE(inputs.handle_list, nullptr); 262 if (shapeInfos.shape_num != 0) { 263 printf("==========Resizes==========\n"); 264 OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num); 265 printf("==========Resizes return code:%d\n", resize_ret); 266 ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS); 267 } 268 FillInputsData(inputs, modelName, isTranspose); 269 ret = OH_AI_ModelSetTrainMode(model, true); 270 ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 271 if (isCallback) { 272 printf("==========Model RunStep Callback==========\n"); 273 OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback; 274 OH_AI_KernelCallBack afterCallBack = PrintAfterCallback; 275 ret = OH_AI_RunStep(model, beforeCallBack, afterCallBack); 276 } else { 277 printf("==========Model RunStep==========\n"); 278 ret = OH_AI_RunStep(model, nullptr, nullptr); 279 } 280 printf("==========Model RunStep End==========\n"); 281 ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS); 282} 283