1be168c0dSopenharmony_ci/*
2be168c0dSopenharmony_ci * Copyright (c) 2023 Huawei Device Co., Ltd.
3be168c0dSopenharmony_ci * Licensed under the Apache License, Version 2.0 (the "License");
4be168c0dSopenharmony_ci * you may not use this file except in compliance with the License.
5be168c0dSopenharmony_ci * You may obtain a copy of the License at
6be168c0dSopenharmony_ci *
7be168c0dSopenharmony_ci *     http://www.apache.org/licenses/LICENSE-2.0
8be168c0dSopenharmony_ci *
9be168c0dSopenharmony_ci * Unless required by applicable law or agreed to in writing, software
10be168c0dSopenharmony_ci * distributed under the License is distributed on an "AS IS" BASIS,
11be168c0dSopenharmony_ci * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12be168c0dSopenharmony_ci * See the License for the specific language governing permissions and
13be168c0dSopenharmony_ci * limitations under the License.
14be168c0dSopenharmony_ci */
15be168c0dSopenharmony_ci
#include <cstdio>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "../utils/model_utils.h"
#include "../utils/common.h"
20be168c0dSopenharmony_ci
21be168c0dSopenharmony_ciclass MSLiteNnrtTest: public testing::Test {
22be168c0dSopenharmony_ci  protected:
23be168c0dSopenharmony_ci    static void SetUpTestCase(void) {}
24be168c0dSopenharmony_ci    static void TearDownTestCase(void) {}
25be168c0dSopenharmony_ci    virtual void SetUp() {}
26be168c0dSopenharmony_ci    virtual void TearDown() {}
27be168c0dSopenharmony_ci};
28be168c0dSopenharmony_ci
29be168c0dSopenharmony_ci/*
30be168c0dSopenharmony_ci * @tc.name: Nnrt_Test
31be168c0dSopenharmony_ci * @tc.desc: Verify the NNRT delegate.
32be168c0dSopenharmony_ci * @tc.type: FUNC
33be168c0dSopenharmony_ci */
34be168c0dSopenharmony_ciHWTEST(MSLiteNnrtTest, Nnrt_ContextTest, testing::ext::TestSize.Level0) {
35be168c0dSopenharmony_ci    std::cout << "==========Get All Nnrt Device Descs==========" << std::endl;
36be168c0dSopenharmony_ci    size_t num = 0;
37be168c0dSopenharmony_ci    auto descs = OH_AI_GetAllNNRTDeviceDescs(&num);
38be168c0dSopenharmony_ci    if (descs == nullptr) {
39be168c0dSopenharmony_ci        std::cout << "descs is nullptr , num: " << num << std::endl;
40be168c0dSopenharmony_ci        ASSERT_EQ(num, 0);
41be168c0dSopenharmony_ci        return;
42be168c0dSopenharmony_ci    }
43be168c0dSopenharmony_ci
44be168c0dSopenharmony_ci    std::cout << "found " << num << " nnrt devices" << std::endl;
45be168c0dSopenharmony_ci    for (size_t i = 0; i < num; i++) {
46be168c0dSopenharmony_ci        auto desc = OH_AI_GetElementOfNNRTDeviceDescs(descs, i);
47be168c0dSopenharmony_ci        ASSERT_NE(desc, nullptr);
48be168c0dSopenharmony_ci        auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
49be168c0dSopenharmony_ci        auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
50be168c0dSopenharmony_ci        auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
51be168c0dSopenharmony_ci        std::cout << "NNRT device: id = " << id << ", name: " << name << ", type:" << type << std::endl;
52be168c0dSopenharmony_ci    }
53be168c0dSopenharmony_ci
54be168c0dSopenharmony_ci    OH_AI_DestroyAllNNRTDeviceDescs(&descs);
55be168c0dSopenharmony_ci    ASSERT_EQ(descs, nullptr);
56be168c0dSopenharmony_ci}
57be168c0dSopenharmony_ci
58be168c0dSopenharmony_ci/*
59be168c0dSopenharmony_ci * @tc.name: Nnrt_CreateNnrtDevice
60be168c0dSopenharmony_ci * @tc.desc: Verify the NNRT device create function.
61be168c0dSopenharmony_ci * @tc.type: FUNC
62be168c0dSopenharmony_ci */
63be168c0dSopenharmony_ciHWTEST(MSLiteNnrtTest, Nnrt_CreateNnrtDevice, testing::ext::TestSize.Level0) {
64be168c0dSopenharmony_ci    std::cout << "==========Get All Nnrt Device Descs==========" << std::endl;
65be168c0dSopenharmony_ci    size_t num = 0;
66be168c0dSopenharmony_ci    auto desc = OH_AI_GetAllNNRTDeviceDescs(&num);
67be168c0dSopenharmony_ci    if (desc == nullptr) {
68be168c0dSopenharmony_ci        std::cout << "descs is nullptr , num: " << num << std::endl;
69be168c0dSopenharmony_ci        ASSERT_EQ(num, 0);
70be168c0dSopenharmony_ci        return;
71be168c0dSopenharmony_ci    }
72be168c0dSopenharmony_ci
73be168c0dSopenharmony_ci    std::cout << "found " << num << " nnrt devices" << std::endl;
74be168c0dSopenharmony_ci    auto id = OH_AI_GetDeviceIdFromNNRTDeviceDesc(desc);
75be168c0dSopenharmony_ci    auto name = OH_AI_GetNameFromNNRTDeviceDesc(desc);
76be168c0dSopenharmony_ci    auto type = OH_AI_GetTypeFromNNRTDeviceDesc(desc);
77be168c0dSopenharmony_ci    std::cout << "NNRT device: id = " << id << ", name = " << name << ", type = " << type << std::endl;
78be168c0dSopenharmony_ci
79be168c0dSopenharmony_ci    // create by name
80be168c0dSopenharmony_ci    auto nnrtDeviceInfo = OH_AI_CreateNNRTDeviceInfoByName(name);
81be168c0dSopenharmony_ci    ASSERT_NE(nnrtDeviceInfo, nullptr);
82be168c0dSopenharmony_ci
83be168c0dSopenharmony_ci    OH_AI_DeviceType deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
84be168c0dSopenharmony_ci    printf("==========deviceType:%d\n", deviceType);
85be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetDeviceId(nnrtDeviceInfo), id);
86be168c0dSopenharmony_ci    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);
87be168c0dSopenharmony_ci    OH_AI_DeviceInfoDestroy(&nnrtDeviceInfo);
88be168c0dSopenharmony_ci    ASSERT_EQ(nnrtDeviceInfo, nullptr);
89be168c0dSopenharmony_ci
90be168c0dSopenharmony_ci    // create by type
91be168c0dSopenharmony_ci    nnrtDeviceInfo = OH_AI_CreateNNRTDeviceInfoByType(type);
92be168c0dSopenharmony_ci    ASSERT_NE(nnrtDeviceInfo, nullptr);
93be168c0dSopenharmony_ci
94be168c0dSopenharmony_ci    deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
95be168c0dSopenharmony_ci    printf("==========deviceType:%d\n", deviceType);
96be168c0dSopenharmony_ci    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);
97be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetDeviceId(nnrtDeviceInfo), id);
98be168c0dSopenharmony_ci    OH_AI_DeviceInfoDestroy(&nnrtDeviceInfo);
99be168c0dSopenharmony_ci    ASSERT_EQ(nnrtDeviceInfo, nullptr);
100be168c0dSopenharmony_ci
101be168c0dSopenharmony_ci    // create by id
102be168c0dSopenharmony_ci    nnrtDeviceInfo = OH_AI_DeviceInfoCreate(OH_AI_DEVICETYPE_NNRT);
103be168c0dSopenharmony_ci    ASSERT_NE(nnrtDeviceInfo, nullptr);
104be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetDeviceId(nnrtDeviceInfo, id);
105be168c0dSopenharmony_ci
106be168c0dSopenharmony_ci    deviceType = OH_AI_DeviceInfoGetDeviceType(nnrtDeviceInfo);
107be168c0dSopenharmony_ci    printf("==========deviceType:%d\n", deviceType);
108be168c0dSopenharmony_ci    ASSERT_EQ(deviceType, OH_AI_DEVICETYPE_NNRT);
109be168c0dSopenharmony_ci
110be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetPerformanceMode(nnrtDeviceInfo, OH_AI_PERFORMANCE_MEDIUM);
111be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetPerformanceMode(nnrtDeviceInfo), OH_AI_PERFORMANCE_MEDIUM);
112be168c0dSopenharmony_ci    OH_AI_DeviceInfoSetPriority(nnrtDeviceInfo, OH_AI_PRIORITY_MEDIUM);
113be168c0dSopenharmony_ci    ASSERT_EQ(OH_AI_DeviceInfoGetPriority(nnrtDeviceInfo), OH_AI_PRIORITY_MEDIUM);
114be168c0dSopenharmony_ci    std::string cachePath = "/data/local/tmp/";
115be168c0dSopenharmony_ci    std::string cacheVersion = "1";
116be168c0dSopenharmony_ci    OH_AI_DeviceInfoAddExtension(nnrtDeviceInfo, "CachePath", cachePath.c_str(), cachePath.size());
117be168c0dSopenharmony_ci    OH_AI_DeviceInfoAddExtension(nnrtDeviceInfo, "CacheVersion", cacheVersion.c_str(), cacheVersion.size());
118be168c0dSopenharmony_ci    OH_AI_DeviceInfoDestroy(&nnrtDeviceInfo);
119be168c0dSopenharmony_ci    ASSERT_EQ(nnrtDeviceInfo, nullptr);
120be168c0dSopenharmony_ci
121be168c0dSopenharmony_ci    OH_AI_DestroyAllNNRTDeviceDescs(&desc);
122be168c0dSopenharmony_ci}
123be168c0dSopenharmony_ci
124be168c0dSopenharmony_ci/*
125be168c0dSopenharmony_ci * @tc.name: Nnrt_NpuPredict
126be168c0dSopenharmony_ci * @tc.desc: Verify the NNRT predict.
127be168c0dSopenharmony_ci * @tc.type: FUNC
128be168c0dSopenharmony_ci */
129be168c0dSopenharmony_ciHWTEST(MSLiteNnrtTest, Nnrt_NpuPredict, testing::ext::TestSize.Level0) {
130be168c0dSopenharmony_ci    if (!IsNPU()) {
131be168c0dSopenharmony_ci        printf("NNRt is not NPU, skip this test");
132be168c0dSopenharmony_ci        return;
133be168c0dSopenharmony_ci    }
134be168c0dSopenharmony_ci
135be168c0dSopenharmony_ci    printf("==========Init Context==========\n");
136be168c0dSopenharmony_ci    OH_AI_ContextHandle context = OH_AI_ContextCreate();
137be168c0dSopenharmony_ci    ASSERT_NE(context, nullptr);
138be168c0dSopenharmony_ci    AddContextDeviceNNRT(context);
139be168c0dSopenharmony_ci    printf("==========Create model==========\n");
140be168c0dSopenharmony_ci    OH_AI_ModelHandle model = OH_AI_ModelCreate();
141be168c0dSopenharmony_ci    ASSERT_NE(model, nullptr);
142be168c0dSopenharmony_ci    printf("==========Build model==========\n");
143be168c0dSopenharmony_ci    OH_AI_Status ret = OH_AI_ModelBuildFromFile(model, "/data/test/resource/tinynet.om.ms",
144be168c0dSopenharmony_ci                                                OH_AI_MODELTYPE_MINDIR, context);
145be168c0dSopenharmony_ci    printf("==========build model return code:%d\n", ret);
146be168c0dSopenharmony_ci    if (ret != OH_AI_STATUS_SUCCESS) {
147be168c0dSopenharmony_ci        printf("==========build model failed, ret: %d\n", ret);
148be168c0dSopenharmony_ci        OH_AI_ModelDestroy(&model);
149be168c0dSopenharmony_ci        return;
150be168c0dSopenharmony_ci    }
151be168c0dSopenharmony_ci
152be168c0dSopenharmony_ci    printf("==========GetInputs==========\n");
153be168c0dSopenharmony_ci    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
154be168c0dSopenharmony_ci    ASSERT_NE(inputs.handle_list, nullptr);
155be168c0dSopenharmony_ci    for (size_t i = 0; i < inputs.handle_num; ++i) {
156be168c0dSopenharmony_ci        OH_AI_TensorHandle tensor = inputs.handle_list[i];
157be168c0dSopenharmony_ci        float *inputData = reinterpret_cast<float *>(OH_AI_TensorGetMutableData(tensor));
158be168c0dSopenharmony_ci        size_t elementNum = OH_AI_TensorGetElementNum(tensor);
159be168c0dSopenharmony_ci        std::random_device rd;
160be168c0dSopenharmony_ci        std::mt19937 gen(rd());
161be168c0dSopenharmony_ci        std::uniform_real_distribution<float> dis(0.0f,1.0f);
162be168c0dSopenharmony_ci        for (size_t z = 0; z < elementNum; z++) {
163be168c0dSopenharmony_ci            inputData[z] = dis(gen);
164be168c0dSopenharmony_ci        }
165be168c0dSopenharmony_ci    }
166be168c0dSopenharmony_ci    printf("==========Model Predict==========\n");
167be168c0dSopenharmony_ci    OH_AI_TensorHandleArray outputs;
168be168c0dSopenharmony_ci    ret = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
169be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
170be168c0dSopenharmony_ci    OH_AI_ModelDestroy(&model);
171be168c0dSopenharmony_ci}
172be168c0dSopenharmony_ci
173be168c0dSopenharmony_ci/*
174be168c0dSopenharmony_ci * @tc.name: Nnrt_NpuCpuPredict
175be168c0dSopenharmony_ci * @tc.desc: Verify the NNRT npu/cpu predict.
176be168c0dSopenharmony_ci * @tc.type: FUNC
177be168c0dSopenharmony_ci */
178be168c0dSopenharmony_ciHWTEST(MSLiteNnrtTest, Nnrt_NpuCpuPredict, testing::ext::TestSize.Level0) {
179be168c0dSopenharmony_ci    printf("==========Init Context==========\n");
180be168c0dSopenharmony_ci    OH_AI_ContextHandle context = OH_AI_ContextCreate();
181be168c0dSopenharmony_ci    ASSERT_NE(context, nullptr);
182be168c0dSopenharmony_ci    AddContextDeviceNNRT(context);
183be168c0dSopenharmony_ci    AddContextDeviceCPU(context);
184be168c0dSopenharmony_ci    printf("==========Create model==========\n");
185be168c0dSopenharmony_ci    OH_AI_ModelHandle model = OH_AI_ModelCreate();
186be168c0dSopenharmony_ci    ASSERT_NE(model, nullptr);
187be168c0dSopenharmony_ci    printf("==========Build model==========\n");
188be168c0dSopenharmony_ci    OH_AI_Status ret = OH_AI_ModelBuildFromFile(model, "/data/test/resource/ml_face_isface.ms",
189be168c0dSopenharmony_ci                                                OH_AI_MODELTYPE_MINDIR, context);
190be168c0dSopenharmony_ci    printf("==========build model return code:%d\n", ret);
191be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
192be168c0dSopenharmony_ci
193be168c0dSopenharmony_ci    printf("==========GetInputs==========\n");
194be168c0dSopenharmony_ci    OH_AI_TensorHandleArray inputs = OH_AI_ModelGetInputs(model);
195be168c0dSopenharmony_ci    ASSERT_NE(inputs.handle_list, nullptr);
196be168c0dSopenharmony_ci    FillInputsData(inputs, "ml_face_isface", true);
197be168c0dSopenharmony_ci    printf("==========Model Predict==========\n");
198be168c0dSopenharmony_ci    OH_AI_TensorHandleArray outputs;
199be168c0dSopenharmony_ci    ret = OH_AI_ModelPredict(model, inputs, &outputs, nullptr, nullptr);
200be168c0dSopenharmony_ci    ASSERT_EQ(ret, OH_AI_STATUS_SUCCESS);
201be168c0dSopenharmony_ci    CompareResult(outputs, "ml_face_isface");
202be168c0dSopenharmony_ci    OH_AI_ModelDestroy(&model);
203be168c0dSopenharmony_ci}
204