1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 #include "mir_type.h"
28 
29 #include "gtest/gtest.h"
30 
31 using namespace maple;
32 using namespace std;
33 namespace {
/**
 * Verifies that ConstantFold::FoldIntConstBinaryMIRConst folds `2 <op> 1`
 * (both operands are i64 constants) to the expected value for every listed
 * integer binary opcode.
 */
TEST(FoldIntConstBinaryMIRConst_FUNC, t01)
{
    // One entry per opcode: keeping the opcode and its expected result together
    // makes it impossible for two parallel lists to drift out of sync.
    struct FoldCase {
        Opcode op;
        uint64 expected;
    };
    const std::vector<FoldCase> cases = {
        {OP_add, 3}, {OP_sub, 1}, {OP_mul, 2}, {OP_div, 2}, {OP_rem, 0},
        {OP_ashr, 1}, {OP_lshr, 1}, {OP_shl, 4}, {OP_max, 2}, {OP_min, 1},
        {OP_band, 0}, {OP_bior, 3}, {OP_bxor, 3}};
    const PrimType resultType = PTY_i64;

    auto &int64Type = *GlobalTables::GetTypeTable().GetInt64();
    MIRIntConst *lhs = GlobalTables::GetIntConstTable().GetOrCreateIntConst(2, int64Type);
    MIRIntConst *rhs = GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, int64Type);
    ASSERT_NE(lhs, nullptr);
    ASSERT_NE(rhs, nullptr);

    MIRModule mirmodule("mirmodule");
    ConstantFold cf_obj(mirmodule, false);
    for (const FoldCase &foldCase : cases)
    {
        MIRConst *ans = cf_obj.FoldIntConstBinaryMIRConst(foldCase.op, resultType, *lhs, *rhs);
        // Fail the test cleanly instead of crashing if no constant was produced.
        ASSERT_NE(ans, nullptr) << "no folded constant for opcode " << static_cast<int>(foldCase.op);
        // NOTE(review): assumes the folded constant is a MIRIntConst, as the original cast did.
        const auto *intAns = static_cast<MIRIntConst *>(ans);
        // All expected values are non-negative, so compare in the unsigned domain
        // to avoid a signed/unsigned comparison inside EXPECT_EQ.
        EXPECT_EQ(static_cast<uint64>(intAns->GetExtValue()), foldCase.expected)
            << "wrong folded value for opcode " << static_cast<int>(foldCase.op);
    }
}
54 
/**
 * Verifies that ConstantFold::Fold turns a floating-point binary expression
 * whose operands are both constants (1.0 and 2.0) into a constant node, for
 * every listed opcode and for both f32 and f64 result types.
 */
TEST(FoldFPConstBinary_FUNC, t02)
{
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    // NOTE: after creating a MIRModule object, a current function (with a memory
    // pool attached) must be set on it before folding.
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    const std::vector<PrimType> primTypes = {PTY_f32, PTY_f64};
    const std::vector<Opcode> opcodes = {OP_add, OP_sub, OP_mul, OP_div, OP_max, OP_min};

    // Both operands are float constants; the same pair is reused for every case.
    ConstvalNode lhsNode;
    ConstvalNode rhsNode;
    lhsNode.SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0));
    rhsNode.SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0));

    ConstantFold cf_obj(mirmodule, false);
    for (Opcode op : opcodes)
    {
        for (PrimType primType : primTypes)
        {
            BinaryNode root(op, primType, &lhsNode, &rhsNode);
            BaseNode *ans = cf_obj.Fold(&root);
            // Fold() may hand back a null pointer (presumably when nothing was
            // folded - confirm against ConstantFold::Fold); report that as a
            // test failure rather than dereferencing it.
            ASSERT_NE(ans, nullptr) << "Fold returned null for opcode " << static_cast<int>(op)
                                    << ", result type " << static_cast<int>(primType);
            EXPECT_TRUE(ans->IsConstval());
        }
    }
}
86 
/**
 * Verifies that ConstantFold::Fold reduces an integer comparison between the
 * i64 constants 0 and 1 to a constant node for every comparison opcode.
 */
TEST(FoldCompare_FUNC, t03)
{
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj); // mir_func can be set up arbitrarily
    mir_func.SetMemPool(&memPool);                // SetMemPool must be called
    mirmodule.SetCurFunction(&mir_func);

    const PrimType resultType = PTY_i64;
    // An i64 operand type is what routes the fold to the integer comparison
    // path (FoldIntConstComparisonMIRConst, per the original author's note).
    const PrimType opndType = PTY_i64;
    const std::vector<Opcode> opcodes = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};

    auto &int64Type = *GlobalTables::GetTypeTable().GetInt64();
    MIRIntConst *zero = GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, int64Type);
    MIRIntConst *one = GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, int64Type);

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode lhsNode;
    ConstvalNode rhsNode;

    for (Opcode op : opcodes)
    {
        // Re-seat the operands on every iteration so each fold starts from the same inputs.
        lhsNode.SetConstVal(zero);
        rhsNode.SetConstVal(one);
        CompareNode root(op, resultType, opndType, &lhsNode, &rhsNode);
        BaseNode *ans = cf_obj.Fold(&root);
        // Fold() may hand back a null pointer (presumably when nothing was
        // folded - confirm against ConstantFold::Fold); report that as a
        // test failure rather than dereferencing it.
        ASSERT_NE(ans, nullptr) << "Fold returned null for opcode " << static_cast<int>(op);
        EXPECT_TRUE(ans->IsConstval());
    }
}
117 
/**
 * Verifies that ConstantFold::Fold reduces a floating-point comparison between
 * the constants 1.0 and 2.0 to a constant node, for every comparison opcode
 * and for both f32 and f64 operand types.
 */
TEST(FoldCompare_FUNC, t04)
{
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    const PrimType resultType = PTY_i64;
    const std::vector<Opcode> opcodes = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};

    // NOTE: the constants handed to SetConstVal() must match the operand type
    // (float constants for PTY_f32, double constants for PTY_f64), so each
    // operand type is stored together with its own pair of constants.
    struct OperandSet {
        PrimType opndType;
        MIRConst *lhs;
        MIRConst *rhs;
    };
    const std::vector<OperandSet> operandSets = {
        {PTY_f32,
         GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0),
         GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0)},
        {PTY_f64,
         GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0),
         GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(2.0)}};

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode lhsNode;
    ConstvalNode rhsNode;

    for (Opcode op : opcodes)
    {
        for (const OperandSet &operands : operandSets)
        {
            lhsNode.SetConstVal(operands.lhs);
            rhsNode.SetConstVal(operands.rhs);
            CompareNode root(op, resultType, operands.opndType, &lhsNode, &rhsNode);
            BaseNode *ans = cf_obj.Fold(&root);
            // Fold() may hand back a null pointer (presumably when nothing was
            // folded - confirm against ConstantFold::Fold); report that as a
            // test failure rather than dereferencing it.
            ASSERT_NE(ans, nullptr) << "Fold returned null for opcode " << static_cast<int>(op)
                                    << ", operand type " << static_cast<int>(operands.opndType);
            EXPECT_TRUE(ans->IsConstval());
        }
    }
}
154 
/**
 * Verifies that ConstantFold::Fold reduces a unary expression over a constant
 * operand to a constant node. Integer operands (i64) are exercised with
 * abs/bnot/lnot/neg; floating-point operands (f32, f64) with abs/neg/sqrt.
 */
TEST(FoldUnary_FUNC, t05)
{
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    const std::vector<Opcode> intOpcodes = {OP_abs, OP_bnot, OP_lnot, OP_neg};
    const std::vector<Opcode> fpOpcodes = {OP_abs, OP_neg, OP_sqrt};

    // Each case bundles the node type, a constant of that type, and the
    // opcodes applicable to it, instead of relying on matching indices across
    // several parallel containers.
    struct UnaryCase {
        PrimType primType;
        MIRConst *operand;
        std::vector<Opcode> opcodes;
    };
    const std::vector<UnaryCase> cases = {
        {PTY_i64,
         GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *GlobalTables::GetTypeTable().GetInt64()),
         intOpcodes},
        {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0), fpOpcodes},
        {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0), fpOpcodes}};

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode;

    for (const UnaryCase &unaryCase : cases)
    {
        for (Opcode op : unaryCase.opcodes)
        {
            constvalnode.SetConstVal(unaryCase.operand);
            UnaryNode root(op, unaryCase.primType, &constvalnode);
            BaseNode *ans = cf_obj.Fold(&root);
            // Fold() may hand back a null pointer (presumably when nothing was
            // folded - confirm against ConstantFold::Fold); report that as a
            // test failure rather than dereferencing it.
            ASSERT_NE(ans, nullptr) << "Fold returned null for opcode " << static_cast<int>(op)
                                    << ", type " << static_cast<int>(unaryCase.primType);
            EXPECT_TRUE(ans->IsConstval());
        }
    }
}
200 // The from and to types required to enter the Ceil branch {from,to}.Pay attention to modifying with 'static'
201 static std::vector<std::vector<PrimType>> primtype_pairs_ceil = {
202     {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
203 static std::vector<std::vector<PrimType>> primtype_pairs_floor = {
204     {PTY_f32, PTY_i64}, {PTY_f64, PTY_i64}, {PTY_i64, PTY_f32},
205     {PTY_u64, PTY_f32}, {PTY_i64, PTY_f64}, {PTY_u64, PTY_f64}};
206 static std::vector<std::vector<PrimType>> primtype_pairs_cvt = {
207     {PTY_u1, PTY_u1}, {PTY_f32, PTY_u8}};
208 static std::vector<std::vector<PrimType>> primtype_pairs_trunc = {
209     {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
210 static std::unordered_map<Opcode, std::vector<std::vector<PrimType>>> opcode_to_primtype_pairs = {
211     {OP_ceil, primtype_pairs_ceil}, {OP_floor, primtype_pairs_floor},
212     {OP_cvt, primtype_pairs_cvt}, {OP_trunc, primtype_pairs_trunc}};
213 
/**
 * Verifies that ConstantFold::Fold reduces a type-conversion node
 * (ceil/floor/trunc/cvt) over a constant operand to a constant node, for every
 * {from, to} type pair registered for that opcode in opcode_to_primtype_pairs.
 */
TEST(FoldTypeCvt_FUNC, t06)
{
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    const std::vector<Opcode> opcodes = {OP_ceil, OP_floor, OP_trunc, OP_cvt};
    // One source constant per `from` type used by the pair tables.
    const std::unordered_map<PrimType, MIRConst *> constByType = {
        {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.5)},
        {PTY_u64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt64())},
        {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.5)},
        {PTY_i64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64())},
        {PTY_u1, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt1())},
        {PTY_u8, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt8())}};

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode;
    for (Opcode opcode : opcodes)
    {
        // Use find() rather than operator[]: a missing opcode must fail the test,
        // not silently insert an empty pair list and skip the checks.
        const auto pairsIt = opcode_to_primtype_pairs.find(opcode);
        ASSERT_TRUE(pairsIt != opcode_to_primtype_pairs.end())
            << "no type pairs registered for opcode " << static_cast<int>(opcode);

        for (const std::vector<PrimType> &fromTo : pairsIt->second)
        {
            ASSERT_EQ(fromTo.size(), 2U) << "each entry must be a {from, to} pair";
            const PrimType from = fromTo[0];
            const PrimType to = fromTo[1];

            // Likewise, a missing source constant must be reported instead of
            // feeding a default-inserted null pointer into the node.
            const auto constIt = constByType.find(from);
            ASSERT_TRUE(constIt != constByType.end())
                << "no source constant for type " << static_cast<int>(from);

            constvalnode.SetConstVal(constIt->second);
            TypeCvtNode type_cvt_node(opcode, to, from, &constvalnode);
            BaseNode *ans = cf_obj.Fold(&type_cvt_node);
            // Fold() may hand back a null pointer (presumably when nothing was
            // folded - confirm against ConstantFold::Fold); report that as a
            // test failure rather than dereferencing it.
            ASSERT_NE(ans, nullptr) << "Fold returned null for opcode " << static_cast<int>(opcode)
                                    << ", from " << static_cast<int>(from) << ", to " << static_cast<int>(to);
            EXPECT_TRUE(ans->IsConstval());
        }
    }
}
251 }  // namespace
252