1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
#include "constantfold.h"
#include <cmath>
#include <cfloat>
#include <climits>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "mpl_logging.h"
#include "mir_function.h"
#include "mir_builder.h"
#include "global_tables.h"
#include "me_option.h"
#include "maple_phase_manager.h"
#include "mir_type.h"

#include "gtest/gtest.h"
30
31 using namespace maple;
32 using namespace std;
33 namespace {
TEST(FoldIntConstBinaryMIRConst_FUNC, t01)
{
    // Folds each integer binary opcode over the i64 constants 2 and 1 via
    // ConstantFold::FoldIntConstBinaryMIRConst and checks the exact folded value.
    MIRIntConst *lhs =
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(2, *GlobalTables::GetTypeTable().GetInt64());
    MIRIntConst *rhs =
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64());

    // Each opcode is paired with its expected result so the two can never drift out of sync.
    // `expected` is signed to match the return type of MIRIntConst::GetExtValue().
    struct BinaryCase {
        Opcode opcode;
        int64 expected;  // value of (2 <opcode> 1)
    };
    const std::vector<BinaryCase> cases = {
        {OP_add, 3}, {OP_sub, 1}, {OP_mul, 2}, {OP_div, 2}, {OP_rem, 0},
        {OP_ashr, 1}, {OP_lshr, 1}, {OP_shl, 4}, {OP_max, 2}, {OP_min, 1},
        {OP_band, 0}, {OP_bior, 3}, {OP_bxor, 3}};
    const PrimType resultType = PTY_i64;

    MIRModule mirmodule("mirmodule");
    ConstantFold cf_obj(mirmodule, false);
    for (const BinaryCase &c : cases) {
        MIRConst *ans = cf_obj.FoldIntConstBinaryMIRConst(c.opcode, resultType, *lhs, *rhs);
        // A failed fold must fail this test, not crash the whole test binary.
        ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(c.opcode);
        EXPECT_EQ(static_cast<MIRIntConst *>(ans)->GetExtValue(), c.expected)
            << "opcode " << static_cast<int>(c.opcode);
    }
}
54
TEST(FoldFPConstBinary_FUNC, t02)
{
    // Folds each FP binary opcode over the constants 1.0 and 2.0 for f32 and f64 result types
    // and checks that ConstantFold::Fold produces a constant node.
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    // NOTE: after creating a MIRModule object, a current function (with a memory pool installed)
    // must be set on it before ConstantFold::Fold is used.
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    ConstvalNode constvalnode0;
    ConstvalNode constvalnode1;
    const std::vector<PrimType> primtype_ls = {PTY_f32, PTY_f64};
    const std::vector<Opcode> input_op_ls = {OP_add, OP_sub, OP_mul, OP_div, OP_max, OP_min};
    // NOTE(review): the same MIRFloatConst operands are also used for the PTY_f64 result type;
    // unlike FoldCompare_FUNC.t04, no double constants are created here — confirm this is intended.
    MIRFloatConst *lhsConst = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0);
    MIRFloatConst *rhsConst = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0);

    constvalnode0.SetConstVal(lhsConst);
    constvalnode1.SetConstVal(rhsConst);

    ConstantFold cf_obj(mirmodule, false);
    for (const Opcode opcode : input_op_ls) {
        for (const PrimType primType : primtype_ls) {
            BinaryNode root(opcode, primType, &constvalnode0, &constvalnode1);
            BaseNode *ans = cf_obj.Fold(&root);
            // Guard against a null result so a failed fold is reported instead of crashing.
            ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode)
                                    << ", primType " << static_cast<int>(primType);
            EXPECT_TRUE(ans->IsConstval()) << "opcode " << static_cast<int>(opcode)
                                           << ", primType " << static_cast<int>(primType);
        }
    }
}
86
TEST(FoldCompare_FUNC, t03)
{
    // Folds each comparison opcode over the i64 constants 0 and 1 and checks that
    // ConstantFold::Fold produces a constant node.
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);  // mir_func can be designed arbitrarily
    mir_func.SetMemPool(&memPool);                 // SetMemPool must be called before folding
    mirmodule.SetCurFunction(&mir_func);

    const PrimType primtyp = PTY_i64;
    // PTY_i64 operands route folding through the integer comparison path
    // (-> ... -> FoldIntConstComparisonMIRConst).
    const PrimType opndType = PTY_i64;
    const std::vector<Opcode> opcode_ls = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};

    MIRIntConst *zeroConst =
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *GlobalTables::GetTypeTable().GetInt64());
    MIRIntConst *oneConst =
        GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64());

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode0;
    ConstvalNode constvalnode1;

    for (const Opcode opcode : opcode_ls) {
        // Operands are (re)installed on every iteration so each case starts from the same inputs.
        constvalnode0.SetConstVal(zeroConst);
        constvalnode1.SetConstVal(oneConst);
        CompareNode root(opcode, primtyp, opndType, &constvalnode0, &constvalnode1);
        BaseNode *ans = cf_obj.Fold(&root);
        // Guard against a null result so a failed fold is reported instead of crashing.
        ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode);
        EXPECT_TRUE(ans->IsConstval()) << "opcode " << static_cast<int>(opcode);
    }
}
117
TEST(FoldCompare_FUNC, t04)
{
    // Folds each comparison opcode over FP constants (1.0 vs 2.0) for f32 and f64 operand types
    // and checks that ConstantFold::Fold produces a constant node.
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    // PTY_f32 / PTY_f64 operand types route folding through the FP comparison path
    // (-> ... -> ComparisonResult).
    const PrimType primtype = PTY_i64;
    const std::vector<Opcode> opcode_ls = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};

    // NOTE: the operand constants must match opndType (MIRFloatConst for PTY_f32,
    // MIRDoubleConst for PTY_f64), so each operand type carries its own pair of constants.
    struct FpCompareCase {
        PrimType opndType;
        MIRConst *lhs;
        MIRConst *rhs;
    };
    const std::vector<FpCompareCase> cases = {
        {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0),
         GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0)},
        {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0),
         GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(2.0)}};

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode0;
    ConstvalNode constvalnode1;

    for (const Opcode opcode : opcode_ls) {
        for (const FpCompareCase &c : cases) {
            constvalnode0.SetConstVal(c.lhs);
            constvalnode1.SetConstVal(c.rhs);
            CompareNode root(opcode, primtype, c.opndType, &constvalnode0, &constvalnode1);
            BaseNode *ans = cf_obj.Fold(&root);
            // Guard against a null result so a failed fold is reported instead of crashing.
            ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode)
                                    << ", opndType " << static_cast<int>(c.opndType);
            EXPECT_TRUE(ans->IsConstval()) << "opcode " << static_cast<int>(opcode)
                                           << ", opndType " << static_cast<int>(c.opndType);
        }
    }
}
154
TEST(FoldUnary_FUNC, t05)
{
    // Folds unary opcodes over an i64, an f32 and an f64 constant and checks that
    // ConstantFold::Fold produces a constant node for every combination.
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    // Each case names its result type, its operand constant and the opcodes applied to it:
    // the integer operand gets the integer unary opcodes, the FP operands get the FP ones.
    const std::vector<Opcode> intOpcodes = {OP_abs, OP_bnot, OP_lnot, OP_neg};
    const std::vector<Opcode> fpOpcodes = {OP_abs, OP_neg, OP_sqrt};
    struct UnaryCase {
        PrimType primType;
        MIRConst *operand;
        std::vector<Opcode> opcodes;
    };
    const std::vector<UnaryCase> cases = {
        {PTY_i64,
         GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *GlobalTables::GetTypeTable().GetInt64()),
         intOpcodes},
        {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0), fpOpcodes},
        {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0), fpOpcodes}};

    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode;

    for (const UnaryCase &c : cases) {
        for (const Opcode opcode : c.opcodes) {
            constvalnode.SetConstVal(c.operand);
            UnaryNode root(opcode, c.primType, &constvalnode);
            BaseNode *ans = cf_obj.Fold(&root);
            // Guard against a null result so a failed fold is reported instead of crashing.
            ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode)
                                    << ", primType " << static_cast<int>(c.primType);
            EXPECT_TRUE(ans->IsConstval()) << "opcode " << static_cast<int>(opcode)
                                           << ", primType " << static_cast<int>(c.primType);
        }
    }
}
200 // The from and to types required to enter the Ceil branch {from,to}.Pay attention to modifying with 'static'
201 static std::vector<std::vector<PrimType>> primtype_pairs_ceil = {
202 {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
203 static std::vector<std::vector<PrimType>> primtype_pairs_floor = {
204 {PTY_f32, PTY_i64}, {PTY_f64, PTY_i64}, {PTY_i64, PTY_f32},
205 {PTY_u64, PTY_f32}, {PTY_i64, PTY_f64}, {PTY_u64, PTY_f64}};
206 static std::vector<std::vector<PrimType>> primtype_pairs_cvt = {
207 {PTY_u1, PTY_u1}, {PTY_f32, PTY_u8}};
208 static std::vector<std::vector<PrimType>> primtype_pairs_trunc = {
209 {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
210 static std::unordered_map<Opcode, std::vector<std::vector<PrimType>>> opcode_to_primtype_pairs = {
211 {OP_ceil, primtype_pairs_ceil}, {OP_floor, primtype_pairs_floor},
212 {OP_cvt, primtype_pairs_cvt}, {OP_trunc, primtype_pairs_trunc}};
213
TEST(FoldTypeCvt_FUNC, t06)
{
    // For each conversion opcode, builds a TypeCvtNode for every {from, to} pair listed in
    // opcode_to_primtype_pairs and checks that ConstantFold::Fold produces a constant node.
    MIRModule mirmodule("mirmodule");
    MIRModule mirmodule2("mirmodule2");
    StIdx sdidx_obj;
    MemPoolCtrler memPoolCtrler;
    MemPool memPool(memPoolCtrler, "poolName");
    MIRFunction mir_func(&mirmodule2, sdidx_obj);
    mir_func.SetMemPool(&memPool);
    mirmodule.SetCurFunction(&mir_func);

    const std::vector<Opcode> opcode_ls = {OP_ceil, OP_floor, OP_trunc, OP_cvt};
    // One operand constant per source ("from") type used by the conversion tables.
    const std::unordered_map<PrimType, MIRConst *> const_ptr_map = {
        {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.5)},
        {PTY_u64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1,
            *GlobalTables::GetTypeTable().GetUInt64())},
        {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.5)},
        {PTY_i64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1,
            *GlobalTables::GetTypeTable().GetInt64())},
        {PTY_u1, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1,
            *GlobalTables::GetTypeTable().GetUInt1())},
        {PTY_u8, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1,
            *GlobalTables::GetTypeTable().GetUInt8())}};
    ConstantFold cf_obj(mirmodule, false);
    ConstvalNode constvalnode;

    for (const Opcode opcode : opcode_ls) {
        // Use find() rather than operator[] so a missing table entry fails loudly instead of
        // silently inserting an empty list (or, below, a null operand constant).
        const auto pairsIt = opcode_to_primtype_pairs.find(opcode);
        ASSERT_TRUE(pairsIt != opcode_to_primtype_pairs.end())
            << "no type pairs for opcode " << static_cast<int>(opcode);
        for (const std::vector<PrimType> &fromTo : pairsIt->second) {
            ASSERT_EQ(fromTo.size(), 2u) << "malformed {from, to} pair for opcode " << static_cast<int>(opcode);
            const PrimType from = fromTo[0];
            const PrimType to = fromTo[1];
            const auto constIt = const_ptr_map.find(from);
            ASSERT_TRUE(constIt != const_ptr_map.end())
                << "no operand constant for from-type " << static_cast<int>(from);
            constvalnode.SetConstVal(constIt->second);
            TypeCvtNode type_cvt_node(opcode, to, from, &constvalnode);
            BaseNode *ans = cf_obj.Fold(&type_cvt_node);
            // Guard against a null result so a failed fold is reported instead of crashing.
            ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode) << ", from "
                                    << static_cast<int>(from) << ", to " << static_cast<int>(to);
            EXPECT_TRUE(ans->IsConstval()) << "opcode " << static_cast<int>(opcode) << ", from "
                                           << static_cast<int>(from) << ", to " << static_cast<int>(to);
        }
    }
}
251 } // namespace
252