/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "model_utils.h"
#include <securec.h>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <string>
#include "gtest/gtest.h"
#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"
#include "include/c_api/types_c.h"
#include "include/c_api/status_c.h"
#include "include/c_api/data_type_c.h"
#include "include/c_api/tensor_c.h"
#include "include/c_api/format_c.h"
#include "common.h"

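// The helpers in this file expect the following files under the resource directory,
// for a model named <model> (see the paths built in ModelPredict/ModelTrain,
// FillInputsData and CompareResult):
//   <model>.ms          - MindSpore Lite model file
//   <model>_<i>.input   - raw float data for input tensor i
//   <model><i>.output   - expected float data for output tensor i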
std::string g_testResourcesDir = "/data/test/resource/";

// callback invoked before each node is executed
bool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
                         const OH_AI_CallBackParam kernelInfo) {
    std::cout << "Before forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
    return true;
}

// callback invoked after each node is executed
bool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
                        const OH_AI_CallBackParam kernelInfo) {
    std::cout << "After forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
    return true;
}

// add cpu device info
void AddContextDeviceCPU(OH_AI_ContextHandle context) {
    OH_AI_DeviceInfoHandle cpuDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
    ASSERT_NE(cpuDeviceInfo, nullptr);
    OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(cpuDeviceInfo);
    printf("==========deviceType:%d\n", deviceType);
    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_CPU);
    OH_AI_ContextAddDeviceInfo(context, cpuDeviceInfo);
}

// check whether the first NNRT device is an NPU (its name starts with "NPU_")
bool IsNPU() {
    size_t num = 0;
    auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
    if (desc == nullptr) {
        return false;
    }
    auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
    const std::string npuNamePrefix = "NPU_";
    bool isNpu = (strncmp(npuNamePrefix.c_str(), name, npuNamePrefix.size()) == 0);
    // release the descriptors before returning to avoid leaking them
    OH_AI_DestroyAllNNRTDeviceDescs(&desc);
    return isNpu;
}

// add nnrt device info
void AddContextDeviceNNRT(OH_AI_ContextHandle context) {
    size_t num = 0;
    auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
    if (desc == nullptr) {
        return;
    }

    std::cout << "found " << num << " nnrt devices" << std::endl;
    // only the first reported device is used
    auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
    auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
    auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
    std::cout << "NNRT device: id = " << id << ", name: " << name << ", type: " << type << std::endl;

    OH_AI_DeviceInfoHandle nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
    ASSERT_NE(nnrtDeviceInfo, nullptr);
    OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id);
    OH_AI_DestroyAllNNRTDeviceDescs(&desc);

    OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
    printf("==========deviceType:%d\n", deviceType);
    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);

    OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM);
    ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM);
    OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM);
    ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM);

    OH_AI_ContextAddDeviceInfo(context, nnrtDeviceInfo);
}

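// Typical device selection in a test body (a sketch, not taken from a specific test
// case): prefer NNRT when an NPU backend is reported, otherwise fall back to CPU.
//
//   OH_AI_ContextHandle context = OH_AI_ContextCreate();
//   if (IsNPU()) {
//       AddContextDeviceNNRT(context);
//   } else {
//       AddContextDeviceCPU(context);
//   }
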
// fill data into the input tensors from the <model>_<i>.input files
void FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose) {
    for (size_t i = 0; i < inputs.handle_num; ++i) {
        printf("==========ReadFile==========\n");
        size_t size1 = 0;
        std::string inputDataPath = g_testResourcesDir + modelName + "_" + std::to_string(i) + ".input";
        const char *imagePath = inputDataPath.c_str();
        char *imageBuf = ReadFile(imagePath, &size1);
        ASSERT_NE(imageBuf, nullptr);
        OH_AI_TensorHandle tensor = inputs.handle_list[i];
        int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
        printf("Tensor name: %s.\n", OH_AI_TensorGetName(tensor));
        float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
        ASSERT_NE(inputData, nullptr);
        size_t dataSize = OH_AI_TensorGetDataSize(tensor);
        if (isTranspose) {
            printf("==========Transpose==========\n");
            // the .input file stores NCHW data while the tensor shape is NHWC ([N, H, W, C]),
            // so repack it: plane = H * W, channel = C
            size_t shapeNum;
            const int64_t *shape = OH_AI_TensorGetShape(tensor, &shapeNum);
            auto imageBufNhwc = new char[size1];
            PackNCHWToNHWCFp32(imageBuf, imageBufNhwc, shape[0], shape[1] * shape[2], shape[3]);
            errno_t ret = memcpy_s(inputData, dataSize, imageBufNhwc, size1);
            if (ret != EOK) {
                printf("memcpy_s failed, ret: %d\n", ret);
            }
            delete[] imageBufNhwc;
        } else {
            errno_t ret = memcpy_s(inputData, dataSize, imageBuf, size1);
            if (ret != EOK) {
                printf("memcpy_s failed, ret: %d\n", ret);
            }
        }
        printf("input data after filling is: ");
        for (int64_t j = 0; j < elementNum && j <= 20; ++j) {
            printf("%f ", inputData[j]);
        }
        printf("\n");
        delete[] imageBuf;
    }
}

// compare the outputs with the expected <model><i>.output data after predict
// (atol/rtol defaults are assumed to be provided by the declaration in model_utils.h)
void CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol, float rtol) {
    printf("==========GetOutput==========\n");
    for (size_t i = 0; i < outputs.handle_num; ++i) {
        OH_AI_TensorHandle tensor = outputs.handle_list[i];
        int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
        printf("Tensor name: %s.\n", OH_AI_TensorGetName(tensor));
        float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
        ASSERT_NE(outputData, nullptr);
        printf("output data is: ");
        for (int64_t j = 0; j < elementNum && j <= 20; ++j) {
            printf("%f ", outputData[j]);
        }
        printf("\n");
        printf("==========compFp32WithTData==========\n");
        std::string outputFile = g_testResourcesDir + modelName + std::to_string(i) + ".output";
        bool result = compFp32WithTData(outputData, outputFile, atol, rtol, false);
        EXPECT_EQ(result, true);
    }
}

// build the model, run prediction, and compare the outputs with the expected data
void ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
                  OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
    std::string modelPath = g_testResourcesDir + modelName + ".ms";
    const char *graphPath = modelPath.c_str();
    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
    if (buildByGraph) {
        printf("==========Build model by graphBuf==========\n");
        size_t size = 0;
        char *graphBuf = ReadFile(graphPath, &size);
        ASSERT_NE(graphBuf, nullptr);
        ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context);
        delete[] graphBuf;
    } else {
        printf("==========Build model==========\n");
        ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context);
    }
    printf("==========build model return code:%d\n", ret);
    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
    printf("==========GetInputs==========\n");
    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
    ASSERT_NE(inputs.handle_list, nullptr);
    if (shapeInfos.shape_num != 0) {
        printf("==========Resizes==========\n");
        // a single OH_AI_ShapeInfo is passed with count inputs.handle_num,
        // which assumes a single-input model
        OH_AI_Status resizeRet = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
        printf("==========Resizes return code:%d\n", resizeRet);
        ASSERT_EQ(resizeRet, OH_AI_STATUS_SUCCESS);
    }

    FillInputsData(inputs, modelName, isTranspose);
    OH_AI_TensorHandleArray outputs;
    OH_AI_Status predictRet = OH_AI_STATUS_SUCCESS;
    if (isCallback) {
        printf("==========Model Predict Callback==========\n");
        OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
        OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
        predictRet = OH_AI_ModelPredict(model, inputs, &outputs, beforeCallBack, afterCallBack);
    } else {
        printf("==========Model Predict==========\n");
        predictRet = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
    }
    printf("==========Model Predict End==========\n");
    ASSERT_EQ(predictRet, OH_AI_STATUS_SUCCESS);
    printf("=========CompareResult===========\n");
    CompareResult(outputs, modelName);
    printf("=========OH_AI_ModelDestroy===========\n");
    OH_AI_ModelDestroy(&model);
    printf("=========OH_AI_ModelDestroy End===========\n");
}

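// Example of driving ModelPredict from a gtest case (a sketch; "xxx" stands for a
// model file xxx.ms under g_testResourcesDir and is not a real resource name):
//
//   OH_AI_ContextHandle context = OH_AI_ContextCreate();
//   ASSERT_NE(context, nullptr);
//   AddContextDeviceCPU(context);
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   ASSERT_NE(model, nullptr);
//   // build from file, no resize, transpose NCHW inputs to NHWC, no callbacks;
//   // ModelPredict destroys the model itself via OH_AI_ModelDestroy
//   ModelPredict(model, context, "xxx", {0, {0}}, false, true, false);
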
// negative case: build with an invalid model_type; the build is expected to return
// OH_AI_STATUS_LITE_PARAM_INVALID
void ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
                             bool buildByGraph) {
    std::string modelPath = g_testResourcesDir + modelName + ".ms";
    const char *graphPath = modelPath.c_str();
    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
    if (buildByGraph) {
        printf("==========Build model by graphBuf==========\n");
        size_t size = 0;
        char *graphBuf = ReadFile(graphPath, &size);
        ASSERT_NE(graphBuf, nullptr);
        ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_INVALID, context);
        if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) {
            printf("OH_AI_ModelBuild with OH_AI_MODELTYPE_INVALID did not return OH_AI_STATUS_LITE_PARAM_INVALID.\n");
        }
        delete[] graphBuf;
    } else {
        printf("==========Build model==========\n");
        ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_INVALID, context);
        if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) {
            printf("OH_AI_ModelBuildFromFile with OH_AI_MODELTYPE_INVALID did not return OH_AI_STATUS_LITE_PARAM_INVALID.\n");
        }
    }
    printf("==========build model return code:%d\n", ret);
}

// build the model in train mode and run one training step
void ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
                OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
    std::string modelPath = g_testResourcesDir + modelName + ".ms";
    const char *graphPath = modelPath.c_str();
    OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
    if (buildByGraph) {
        printf("==========Build model by graphBuf==========\n");
        size_t size = 0;
        char *graphBuf = ReadFile(graphPath, &size);
        ASSERT_NE(graphBuf, nullptr);
        ret = OH_AI_TrainModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
        delete[] graphBuf;
    } else {
        printf("==========Build model==========\n");
        ret = OH_AI_TrainModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
    }
    printf("==========build model return code:%d\n", ret);
    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
    printf("==========GetInputs==========\n");
    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
    ASSERT_NE(inputs.handle_list, nullptr);
    if (shapeInfos.shape_num != 0) {
        printf("==========Resizes==========\n");
        OH_AI_Status resizeRet = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
        printf("==========Resizes return code:%d\n", resizeRet);
        ASSERT_EQ(resizeRet, OH_AI_STATUS_SUCCESS);
    }
    FillInputsData(inputs, modelName, isTranspose);
    ret = OH_AI_ModelSetTrainMode(model, true);
    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
    if (isCallback) {
        printf("==========Model RunStep Callback==========\n");
        OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
        OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
        ret = OH_AI_RunStep(model, beforeCallBack, afterCallBack);
    } else {
        printf("==========Model RunStep==========\n");
        ret = OH_AI_RunStep(model, nullptr, nullptr);
    }
    printf("==========Model RunStep End==========\n");
    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
}

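// Example of driving ModelTrain from a gtest case (a sketch; "xxx" stands for a
// training-capable model file xxx.ms under g_testResourcesDir). Unlike ModelPredict,
// ModelTrain does not destroy the model, so the caller releases it:
//
//   OH_AI_ContextHandle context = OH_AI_ContextCreate();
//   ASSERT_NE(context, nullptr);
//   AddContextDeviceCPU(context);
//   OH_AI_ModelHandle model = OH_AI_ModelCreate();
//   ASSERT_NE(model, nullptr);
//   ModelTrain(model, context, "xxx", {0, {0}}, false, false, false);
//   OH_AI_ModelDestroy(&model);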