1be168c0dSopenharmony_ci/*
2be168c0dSopenharmony_ci * Copyright (c) 2023 Huawei Device Co., Ltd.
3be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License");
4be168c0dSopenharmony_ci * you may not use this file except in compliance with the License.
5be168c0dSopenharmony_ci * You may obtain a copy of the License at
6be168c0dSopenharmony_ci *
7be168c0dSopenharmony_ci *     http://www.apache.org/licenses/LICENSE-2.0
8be168c0dSopenharmony_ci *
9be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software
10be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS,
11be168c0dSopenharmony_ci * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12be168c0dSopenharmony_ci * See the License for the specific language governing permissions and
13be168c0dSopenharmony_ci * limitations under the License.
14be168c0dSopenharmony_ci */
15be168c0dSopenharmony_ci
#include "model_utils.h"
#include <memory>
#include <vector>
#include <securec.h>
#include "gtest/gtest.h"
#include "include/c_api/context_c.h"
#include "include/c_api/model_c.h"
#include "include/c_api/types_c.h"
#include "include/c_api/status_c.h"
#include "include/c_api/data_type_c.h"
#include "include/c_api/tensor_c.h"
#include "include/c_api/format_c.h"
#include "common.h"
27be168c0dSopenharmony_ci
// Directory prefix prepended to every model (.ms), input (.input) and reference output (.output) file name.
std::string g_testResourcesDir{"/data/test/resource/"};
29be168c0dSopenharmony_ci
30be168c0dSopenharmony_ci// function before callback
31be168c0dSopenharmony_cibool PrintBeforeCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
32be168c0dSopenharmony_ci                         const OH_AI_CallBackParam kernelInfo) {
33be168c0dSopenharmony_ci    std::cout << "Before forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
34be168c0dSopenharmony_ci    return true;
35be168c0dSopenharmony_ci}
36be168c0dSopenharmony_ci
37be168c0dSopenharmony_ci// function after callback
38be168c0dSopenharmony_cibool PrintAfterCallback(const OH_AI_TensorHandleArray inputs, const OH_AI_TensorHandleArray outputs,
39be168c0dSopenharmony_ci                        const OH_AI_CallBackParam kernelInfo) {
40be168c0dSopenharmony_ci    std::cout << "After forwarding " << kernelInfo.node_name << " " << kernelInfo.node_type << std::endl;
41be168c0dSopenharmony_ci    return true;
42be168c0dSopenharmony_ci}
43be168c0dSopenharmony_ci
44be168c0dSopenharmony_ci// add cpu device info
45be168c0dSopenharmony_civoid AddContextDeviceCPU(OH_AI_ContextHandle context) {
46be168c0dSopenharmony_ci    OH_AI_DeviceInfoHandle cpuDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_CPU);
47be168c0dSopenharmony_ci    ASSERT_NE(cpuDeviceInfo, nullptr);
48be168c0dSopenharmony_ci    OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(cpuDeviceInfo);
49be168c0dSopenharmony_ci    printf("==========deviceType:%d\n", deviceType);
50be168c0dSopenharmony_ci    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_CPU);
51be168c0dSopenharmony_ci    OH_AI_ContextAddDeviceInfo(context, cpuDeviceInfo);
52be168c0dSopenharmony_ci}
53be168c0dSopenharmony_ci
54be168c0dSopenharmony_cibool IsNPU() {
55be168c0dSopenharmony_ci    size_t num = 0;
56be168c0dSopenharmony_ci    auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
57be168c0dSopenharmony_ci    if (desc == nullptr) {
58be168c0dSopenharmony_ci        return false;
59be168c0dSopenharmony_ci    }
60be168c0dSopenharmony_ci    auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
61be168c0dSopenharmony_ci    const std::string npuNamePrefix = "NPU_";
62be168c0dSopenharmony_ci    if (strncmp(npuNamePrefix.c_str(), name, npuNamePrefix.size()) != 0) {
63be168c0dSopenharmony_ci        return false;
64be168c0dSopenharmony_ci    }
65be168c0dSopenharmony_ci    return true;
66be168c0dSopenharmony_ci}
67be168c0dSopenharmony_ci
68be168c0dSopenharmony_ci// add nnrt device info
69be168c0dSopenharmony_civoid AddContextDeviceNNRT(OH_AI_ContextHandle context) {
70be168c0dSopenharmony_ci    size_t num = 0;
71be168c0dSopenharmony_ci    auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
72be168c0dSopenharmony_ci    if (desc == nullptr) {
73be168c0dSopenharmony_ci        return;
74be168c0dSopenharmony_ci    }
75be168c0dSopenharmony_ci
76be168c0dSopenharmony_ci    std::cout << "found " << num << " nnrt devices" << std::endl;
77be168c0dSopenharmony_ci    auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
78be168c0dSopenharmony_ci    auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
79be168c0dSopenharmony_ci    auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
80be168c0dSopenharmony_ci    std::cout << "NNRT device: id = " << id << ", name: " << name << ", type:" << type << std::endl;
81be168c0dSopenharmony_ci
82be168c0dSopenharmony_ci    OH_AI_DeviceInfoHandle nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
83be168c0dSopenharmony_ci    ASSERT_NE(nnrtDeviceInfo, nullptr);
84be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id);
85be168c0dSopenharmony_ci    OH_AI_DestroyAllNNRTDeviceDescs(&desc);
86be168c0dSopenharmony_ci
87be168c0dSopenharmony_ci    OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
88be168c0dSopenharmony_ci    printf("==========deviceType:%d\n", deviceType);
89be168c0dSopenharmony_ci    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);
90be168c0dSopenharmony_ci
91be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM);
92be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM);
93be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM);
94be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM);
95be168c0dSopenharmony_ci
96be168c0dSopenharmony_ci    OH_AI_ContextAddDeviceInfo(context, nnrtDeviceInfo);
97be168c0dSopenharmony_ci}
98be168c0dSopenharmony_ci
99be168c0dSopenharmony_ci// fill data to inputs tensor
100be168c0dSopenharmony_civoid FillInputsData(OH_AI_TensorHandleArray inputs, std::string modelName, bool isTranspose) {
101be168c0dSopenharmony_ci    for (size_t i = 0; i < inputs.handle_num; ++i) {
102be168c0dSopenharmony_ci        printf("==========ReadFile==========\n");
103be168c0dSopenharmony_ci        size_t size1;
104be168c0dSopenharmony_ci        size_t *ptrSize1 = &size1;
105be168c0dSopenharmony_ci        std::string inputDataPath = g_testResourcesDir + modelName + "_" + std::to_string(i) + ".input";
106be168c0dSopenharmony_ci        const char *imagePath = inputDataPath.c_str();
107be168c0dSopenharmony_ci        char *imageBuf = ReadFile(imagePath, ptrSize1);
108be168c0dSopenharmony_ci        ASSERT_NE(imageBuf, nullptr);
109be168c0dSopenharmony_ci        OH_AI_TensorHandle tensor = inputs.handle_list[i];
110be168c0dSopenharmony_ci        int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
111be168c0dSopenharmony_ci        printf("Tensor name: %s. \n", OH_AI_TensorGetName(tensor));
112be168c0dSopenharmony_ci        float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(inputs.handle_list[i]));
113be168c0dSopenharmony_ci        ASSERT_NE(inputData, nullptr);
114be168c0dSopenharmony_ci        if (isTranspose) {
115be168c0dSopenharmony_ci            printf("==========Transpose==========\n");
116be168c0dSopenharmony_ci            size_t shapeNum;
117be168c0dSopenharmony_ci            const int64_t *shape = OH_AI_TensorGetShape(tensor, &shapeNum);
118be168c0dSopenharmony_ci            auto imageBufNhwc = new char[size1];
119be168c0dSopenharmony_ci            PackNCHWToNHWCFp32(imageBuf, imageBufNhwc, shape[0], shape[1] * shape[2], shape[3]);
120be168c0dSopenharmony_ci            errno_t ret = memcpy_s(inputData, size1, imageBufNhwc, size1);
121be168c0dSopenharmony_ci            if (ret != EOK) {
122be168c0dSopenharmony_ci                printf("memcpy_s failed, ret: %d\n", ret);
123be168c0dSopenharmony_ci            }
124be168c0dSopenharmony_ci            delete[] imageBufNhwc;
125be168c0dSopenharmony_ci        } else {
126be168c0dSopenharmony_ci            errno_t ret = memcpy_s(inputData, size1, imageBuf, size1);
127be168c0dSopenharmony_ci            if (ret != EOK) {
128be168c0dSopenharmony_ci                printf("memcpy_s failed, ret: %d\n", ret);
129be168c0dSopenharmony_ci            }
130be168c0dSopenharmony_ci        }
131be168c0dSopenharmony_ci        printf("input data after filling is: ");
132be168c0dSopenharmony_ci        for (int j = 0; j < elementNum && j <= 20; ++j) {
133be168c0dSopenharmony_ci            printf("%f ", inputData[j]);
134be168c0dSopenharmony_ci        }
135be168c0dSopenharmony_ci        printf("\n");
136be168c0dSopenharmony_ci        delete[] imageBuf;
137be168c0dSopenharmony_ci    }
138be168c0dSopenharmony_ci}
139be168c0dSopenharmony_ci
140be168c0dSopenharmony_ci// compare result after predict
141be168c0dSopenharmony_civoid CompareResult(OH_AI_TensorHandleArray outputs, std::string modelName, float atol, float rtol) {
142be168c0dSopenharmony_ci    printf("==========GetOutput==========\n");
143be168c0dSopenharmony_ci    for (size_t i = 0; i < outputs.handle_num; ++i) {
144be168c0dSopenharmony_ci        OH_AI_TensorHandle tensor = outputs.handle_list[i];
145be168c0dSopenharmony_ci        int64_t elementNum = OH_AI_TensorGetElementNum(tensor);
146be168c0dSopenharmony_ci        printf("Tensor name: %s .\n", OH_AI_TensorGetName(tensor));
147be168c0dSopenharmony_ci        float *outputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
148be168c0dSopenharmony_ci        printf("output data is:");
149be168c0dSopenharmony_ci        for (int j = 0; j < elementNum && j <= 20; ++j) {
150be168c0dSopenharmony_ci            printf("%f ", outputData[j]);
151be168c0dSopenharmony_ci        }
152be168c0dSopenharmony_ci        printf("\n");
153be168c0dSopenharmony_ci        printf("==========compFp32WithTData==========\n");
154be168c0dSopenharmony_ci        std::string outputFile = g_testResourcesDir + modelName + std::to_string(i) + ".output";
155be168c0dSopenharmony_ci        bool result = compFp32WithTData(outputData, outputFile, atol, rtol, false);
156be168c0dSopenharmony_ci        EXPECT_EQ(result, true);
157be168c0dSopenharmony_ci    }
158be168c0dSopenharmony_ci}
159be168c0dSopenharmony_ci
160be168c0dSopenharmony_ci// model build and predict
161be168c0dSopenharmony_civoid ModelPredict(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
162be168c0dSopenharmony_ci                  OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
163be168c0dSopenharmony_ci    std::string modelPath = g_testResourcesDir + modelName + ".ms";
164be168c0dSopenharmony_ci    const char *graphPath = modelPath.c_str();
165be168c0dSopenharmony_ci    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
166be168c0dSopenharmony_ci    if (buildByGraph) {
167be168c0dSopenharmony_ci        printf("==========Build model by graphBuf==========\n");
168be168c0dSopenharmony_ci        size_t size;
169be168c0dSopenharmony_ci        size_t *ptrSize = &size;
170be168c0dSopenharmony_ci        char *graphBuf = ReadFile(graphPath, ptrSize);
171be168c0dSopenharmony_ci        ASSERT_NE(graphBuf, nullptr);
172be168c0dSopenharmony_ci        ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context);
173be168c0dSopenharmony_ci        delete[] graphBuf;
174be168c0dSopenharmony_ci    } else {
175be168c0dSopenharmony_ci        printf("==========Build model==========\n");
176be168c0dSopenharmony_ci        ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context);
177be168c0dSopenharmony_ci    }
178be168c0dSopenharmony_ci    printf("==========build model return code:%d\n", ret);
179be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
180be168c0dSopenharmony_ci    printf("==========GetInputs==========\n");
181be168c0dSopenharmony_ci    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
182be168c0dSopenharmony_ci    ASSERT_NE(inputs.handle_list, nullptr);
183be168c0dSopenharmony_ci    if (shapeInfos.shape_num != 0) {
184be168c0dSopenharmony_ci        printf("==========Resizes==========\n");
185be168c0dSopenharmony_ci        OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
186be168c0dSopenharmony_ci        printf("==========Resizes return code:%d\n", resize_ret);
187be168c0dSopenharmony_ci        ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS);
188be168c0dSopenharmony_ci    }
189be168c0dSopenharmony_ci
190be168c0dSopenharmony_ci    FillInputsData(inputs, modelName, isTranspose);
191be168c0dSopenharmony_ci    OH_AI_TensorHandleArray outputs;
192be168c0dSopenharmony_ci    OH_AI_Status predictRet = OH_AI_STATUS_SUCCESS;
193be168c0dSopenharmony_ci    if (isCallback) {
194be168c0dSopenharmony_ci        printf("==========Model Predict Callback==========\n");
195be168c0dSopenharmony_ci        OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
196be168c0dSopenharmony_ci        OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
197be168c0dSopenharmony_ci        predictRet = OH_AI_ModelPredict(model, inputs, &outputs, beforeCallBack, afterCallBack);
198be168c0dSopenharmony_ci    } else {
199be168c0dSopenharmony_ci        printf("==========Model Predict==========\n");
200be168c0dSopenharmony_ci        predictRet = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
201be168c0dSopenharmony_ci    }
202be168c0dSopenharmony_ci    printf("==========Model Predict End==========\n");
203be168c0dSopenharmony_ci    ASSERT_EQ(predictRet, OH_AI_STATUS_SUCCESS);
204be168c0dSopenharmony_ci    printf("=========CompareResult===========\n");
205be168c0dSopenharmony_ci    CompareResult(outputs, modelName);
206be168c0dSopenharmony_ci    printf("=========OH_AI_ModelDestroy===========\n");
207be168c0dSopenharmony_ci    OH_AI_ModelDestroy(&model);
208be168c0dSopenharmony_ci    printf("=========OH_AI_ModelDestroy End===========\n");
209be168c0dSopenharmony_ci}
210be168c0dSopenharmony_ci
211be168c0dSopenharmony_ci// add invalid model_type
212be168c0dSopenharmony_civoid ModelPredict_ModelBuild(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
213be168c0dSopenharmony_ci                             bool buildByGraph) {
214be168c0dSopenharmony_ci    std::string modelPath = g_testResourcesDir + modelName + ".ms";
215be168c0dSopenharmony_ci    const char *graphPath = modelPath.c_str();
216be168c0dSopenharmony_ci    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
217be168c0dSopenharmony_ci    if (buildByGraph) {
218be168c0dSopenharmony_ci        printf("==========Build model by graphBuf==========\n");
219be168c0dSopenharmony_ci        size_t size;
220be168c0dSopenharmony_ci        size_t *ptrSize = &size;
221be168c0dSopenharmony_ci        char *graphBuf = ReadFile(graphPath, ptrSize);
222be168c0dSopenharmony_ci        ASSERT_NE(graphBuf, nullptr);
223be168c0dSopenharmony_ci        ret = OH_AI_ModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_INVALID, context);
224be168c0dSopenharmony_ci        if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) {
225be168c0dSopenharmony_ci            printf("OH_AI_ModelBuild failed due to model_type is OH_AI_MODELTYPE_INVALID.\n");
226be168c0dSopenharmony_ci        }
227be168c0dSopenharmony_ci        delete[] graphBuf;
228be168c0dSopenharmony_ci    } else {
229be168c0dSopenharmony_ci        printf("==========Build model==========\n");
230be168c0dSopenharmony_ci        ret = OH_AI_ModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_INVALID, context);
231be168c0dSopenharmony_ci        if (ret != OH_AI_STATUS_LITE_PARAM_INVALID) {
232be168c0dSopenharmony_ci            printf("OH_AI_ModelBuildFromFile failed due to model_type is OH_AI_MODELTYPE_INVALID.\n");
233be168c0dSopenharmony_ci        }
234be168c0dSopenharmony_ci    }
235be168c0dSopenharmony_ci    printf("==========build model return code:%d\n", ret);
236be168c0dSopenharmony_ci}
237be168c0dSopenharmony_ci
238be168c0dSopenharmony_ci// model train build and predict
239be168c0dSopenharmony_civoid ModelTrain(OH_AI_ModelHandle model, OH_AI_ContextHandle context, std::string modelName,
240be168c0dSopenharmony_ci                OH_AI_ShapeInfo shapeInfos, bool buildByGraph, bool isTranspose, bool isCallback) {
241be168c0dSopenharmony_ci    std::string modelPath = g_testResourcesDir + modelName + ".ms";
242be168c0dSopenharmony_ci    const char *graphPath = modelPath.c_str();
243be168c0dSopenharmony_ci    OH_AI_TrainCfgHandle trainCfg = OH_AI_TrainCfgCreate();
244be168c0dSopenharmony_ci    OH_AI_Status ret = OH_AI_STATUS_SUCCESS;
245be168c0dSopenharmony_ci    if (buildByGraph) {
246be168c0dSopenharmony_ci        printf("==========Build model by graphBuf==========\n");
247be168c0dSopenharmony_ci        size_t size;
248be168c0dSopenharmony_ci        size_t *ptrSize = &size;
249be168c0dSopenharmony_ci        char *graphBuf = ReadFile(graphPath, ptrSize);
250be168c0dSopenharmony_ci        ASSERT_NE(graphBuf, nullptr);
251be168c0dSopenharmony_ci        ret = OH_AI_TrainModelBuild(model, graphBuf, size, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
252be168c0dSopenharmony_ci        delete[] graphBuf;
253be168c0dSopenharmony_ci    } else {
254be168c0dSopenharmony_ci        printf("==========Build model==========\n");
255be168c0dSopenharmony_ci        ret = OH_AI_TrainModelBuildFromFile(model, graphPath, OH_AI_MODELTYPE_MINDIR, context, trainCfg);
256be168c0dSopenharmony_ci    }
257be168c0dSopenharmony_ci    printf("==========build model return code:%d\n", ret);
258be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
259be168c0dSopenharmony_ci    printf("==========GetInputs==========\n");
260be168c0dSopenharmony_ci    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
261be168c0dSopenharmony_ci    ASSERT_NE(inputs.handle_list, nullptr);
262be168c0dSopenharmony_ci    if (shapeInfos.shape_num != 0) {
263be168c0dSopenharmony_ci        printf("==========Resizes==========\n");
264be168c0dSopenharmony_ci        OH_AI_Status resize_ret = OH_AI_ModelResize(model, inputs, &shapeInfos, inputs.handle_num);
265be168c0dSopenharmony_ci        printf("==========Resizes return code:%d\n", resize_ret);
266be168c0dSopenharmony_ci        ASSERT_EQ(resize_ret, OH_AI_STATUS_SUCCESS);
267be168c0dSopenharmony_ci    }
268be168c0dSopenharmony_ci    FillInputsData(inputs, modelName, isTranspose);
269be168c0dSopenharmony_ci    ret = OH_AI_ModelSetTrainMode(model, true);
270be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
271be168c0dSopenharmony_ci    if (isCallback) {
272be168c0dSopenharmony_ci        printf("==========Model RunStep Callback==========\n");
273be168c0dSopenharmony_ci        OH_AI_KernelCallBack beforeCallBack = PrintBeforeCallback;
274be168c0dSopenharmony_ci        OH_AI_KernelCallBack afterCallBack = PrintAfterCallback;
275be168c0dSopenharmony_ci        ret = OH_AI_RunStep(model, beforeCallBack, afterCallBack);
276be168c0dSopenharmony_ci    } else {
277be168c0dSopenharmony_ci        printf("==========Model RunStep==========\n");
278be168c0dSopenharmony_ci        ret = OH_AI_RunStep(model, nullptr, nullptr);
279be168c0dSopenharmony_ci    }
280be168c0dSopenharmony_ci    printf("==========Model RunStep End==========\n");
281be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
282be168c0dSopenharmony_ci}
283