/*
 * Copyright (c) 2023 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "constantfold.h"

#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mpl_logging.h"
#include "mir_function.h"
#include "mir_builder.h"
#include "global_tables.h"
#include "me_option.h"
#include "maple_phase_manager.h"
#include "mir_type.h"
#include "gtest/gtest.h"

using namespace maple;
using namespace std;

namespace {
// Unit tests for maple::ConstantFold. Each test builds constant operand nodes, asks the folder to fold an
// expression over them, and checks that a constant comes back. Every fold result is null-checked with
// ASSERT_NE before it is dereferenced, so a case that does not fold fails the test instead of crashing the
// whole test binary.

// Folds 2 <op> 1 (both i64) for each binary opcode and checks the exact integer result.
TEST(FoldIntConstBinaryMIRConst_FUNC, t01) {
  MIRIntConst *mc_int_ptr1 =
      GlobalTables::GetIntConstTable().GetOrCreateIntConst(2, *GlobalTables::GetTypeTable().GetInt64());
  MIRIntConst *mc_int_ptr2 =
      GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64());
  // output_ls[i] is the expected value of (2 input_op_ls[i] 1).
  const std::vector<Opcode> input_op_ls = {OP_add, OP_sub, OP_mul, OP_div, OP_rem, OP_ashr, OP_lshr,
                                           OP_shl, OP_max, OP_min, OP_band, OP_bior, OP_bxor};
  const std::vector<std::int64_t> output_ls = {3, 1, 2, 2, 0, 1, 1, 4, 2, 1, 0, 3, 3};
  const PrimType resultType = PTY_i64;
  MIRModule mirmodule("mirmodule");
  ConstantFold cf_obj(mirmodule, false);
  ASSERT_EQ(input_op_ls.size(), output_ls.size());
  for (std::size_t i = 0; i < input_op_ls.size(); ++i) {
    MIRConst *ans = cf_obj.FoldIntConstBinaryMIRConst(input_op_ls[i], resultType, *mc_int_ptr1, *mc_int_ptr2);
    ASSERT_NE(ans, nullptr) << "opcode index " << i;
    EXPECT_EQ(static_cast<MIRIntConst *>(ans)->GetExtValue(), output_ls[i]) << "opcode index " << i;
  }
}

// Folds floating-point binary expressions (1.0 <op> 2.0) for each opcode/result-type combination.
TEST(FoldFPConstBinary_FUNC, t02) {
  MIRModule mirmodule("mirmodule"), mirmodule2("mirmodule2");
  // NOTE: after creating a MIRModule object, a current MIRFunction must be set on it before folding.
  StIdx sdidx_obj;
  MemPoolCtrler memPoolCtrler;
  MemPool memPool(memPoolCtrler, "poolName");
  MIRFunction mir_func(&mirmodule2, sdidx_obj);
  mir_func.SetMemPool(&memPool);
  mirmodule.SetCurFunction(&mir_func);
  ConstvalNode constvalnode0, constvalnode1;
  const std::vector<PrimType> primtype_ls = {PTY_f32, PTY_f64};
  const std::vector<Opcode> input_op_ls = {OP_add, OP_sub, OP_mul, OP_div, OP_max, OP_min};
  MIRFloatConst *mirConst_float_ptr1 = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0);
  MIRFloatConst *mirConst_float_ptr2 = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0);
  // NOTE(review): both operands hold MIRFloatConst values even when the result type is PTY_f64;
  // confirm the f64 case exercises the intended folding path.
  constvalnode0.SetConstVal(mirConst_float_ptr1);
  constvalnode1.SetConstVal(mirConst_float_ptr2);
  ConstantFold cf_obj(mirmodule, false);
  for (Opcode opcode : input_op_ls) {
    for (PrimType primtype : primtype_ls) {
      BinaryNode root(opcode, primtype, &constvalnode0, &constvalnode1);
      BaseNode *ans = cf_obj.Fold(&root);
      ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode) << ", type " << static_cast<int>(primtype);
      EXPECT_TRUE(ans->IsConstval());
    }
  }
}

// Folds integer comparisons (0 <cmp> 1, i64 operands) for each comparison opcode.
TEST(FoldCompare_FUNC, t03) {
  MIRModule mirmodule("mirmodule"), mirmodule2("mirmodule2");
  StIdx sdidx_obj;
  MemPoolCtrler memPoolCtrler;
  MemPool memPool(memPoolCtrler, "poolName");
  MIRFunction mir_func(&mirmodule2, sdidx_obj);  // the function's contents are arbitrary; it only has to exist
  mir_func.SetMemPool(&memPool);                 // SetMemPool must be called before folding
  mirmodule.SetCurFunction(&mir_func);
  const PrimType primtyp = PTY_i64;
  const PrimType opndType = PTY_i64;  // PTY_i64 operands -> ... -> FoldIntConstComparisonMIRConst
  const std::vector<Opcode> opcode_ls = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};
  MIRIntConst *mc_int_ptr0 =
      GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *GlobalTables::GetTypeTable().GetInt64());
  MIRIntConst *mc_int_ptr1 =
      GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64());
  ConstantFold cf_obj(mirmodule, false);
  ConstvalNode constvalnode0, constvalnode1;
  for (Opcode opcode : opcode_ls) {
    constvalnode0.SetConstVal(mc_int_ptr0);
    constvalnode1.SetConstVal(mc_int_ptr1);
    CompareNode root(opcode, primtyp, opndType, &constvalnode0, &constvalnode1);
    BaseNode *ans = cf_obj.Fold(&root);
    ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode);
    EXPECT_TRUE(ans->IsConstval());
  }
}

// Folds floating-point comparisons (1.0 <cmp> 2.0) for f32 and f64 operands.
TEST(FoldCompare_FUNC, t04) {
  MIRModule mirmodule("mirmodule"), mirmodule2("mirmodule2");
  StIdx sdidx_obj;
  MemPoolCtrler memPoolCtrler;
  MemPool memPool(memPoolCtrler, "poolName");
  MIRFunction mir_func(&mirmodule2, sdidx_obj);
  mir_func.SetMemPool(&memPool);
  mirmodule.SetCurFunction(&mir_func);
  const PrimType primtype = PTY_i64;
  // Operand type plus its two operands. The operand constants must match the operand type
  // (float constants for PTY_f32, double constants for PTY_f64); PTY_f32/PTY_f64 -> ... -> ComparisonResult.
  struct FpOperands {
    PrimType opndType;
    MIRConst *lhs;
    MIRConst *rhs;
  };
  const std::vector<Opcode> opcode_ls = {OP_eq, OP_ge, OP_gt, OP_le, OP_lt, OP_ne, OP_cmp};
  const std::vector<FpOperands> operand_ls = {
      {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0),
       GlobalTables::GetFpConstTable().GetOrCreateFloatConst(2.0)},
      {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0),
       GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(2.0)}};
  ConstantFold cf_obj(mirmodule, false);
  ConstvalNode constvalnode0, constvalnode1;
  for (Opcode opcode : opcode_ls) {
    for (const FpOperands &operands : operand_ls) {
      constvalnode0.SetConstVal(operands.lhs);
      constvalnode1.SetConstVal(operands.rhs);
      CompareNode root(opcode, primtype, operands.opndType, &constvalnode0, &constvalnode1);
      BaseNode *ans = cf_obj.Fold(&root);
      ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode) << ", operand type "
                              << static_cast<int>(operands.opndType);
      EXPECT_TRUE(ans->IsConstval());
    }
  }
}

// Folds unary expressions: integer opcodes on an i64 constant, FP opcodes on f32 and f64 constants.
TEST(FoldUnary_FUNC, t05) {
  MIRModule mirmodule("mirmodule"), mirmodule2("mirmodule2");
  StIdx sdidx_obj;
  MemPoolCtrler memPoolCtrler;
  MemPool memPool(memPoolCtrler, "poolName");
  MIRFunction mir_func(&mirmodule2, sdidx_obj);
  mir_func.SetMemPool(&memPool);
  mirmodule.SetCurFunction(&mir_func);
  // Result type, operand constant, and the opcodes applied to it.
  struct UnaryCase {
    PrimType primtype;
    MIRConst *operand;
    std::vector<Opcode> opcodes;
  };
  const std::vector<Opcode> int_opcodes = {OP_abs, OP_bnot, OP_lnot, OP_neg};
  const std::vector<Opcode> fp_opcodes = {OP_abs, OP_neg, OP_sqrt};
  const std::vector<UnaryCase> case_ls = {
      {PTY_i64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(0, *GlobalTables::GetTypeTable().GetInt64()),
       int_opcodes},
      {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.0), fp_opcodes},
      {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.0), fp_opcodes}};
  ConstantFold cf_obj(mirmodule, false);
  ConstvalNode constvalnode;
  for (const UnaryCase &unaryCase : case_ls) {
    for (Opcode opcode : unaryCase.opcodes) {
      constvalnode.SetConstVal(unaryCase.operand);
      UnaryNode root(opcode, unaryCase.primtype, &constvalnode);
      BaseNode *ans = cf_obj.Fold(&root);
      ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode) << ", type "
                              << static_cast<int>(unaryCase.primtype);
      EXPECT_TRUE(ans->IsConstval());
    }
  }
}

// {from, to} type pairs that drive each conversion opcode into its folding branch (e.g. the ceil branch for
// OP_ceil). The tables are const and file-local; t06 reads them read-only.
using PrimTypePair = std::pair<PrimType, PrimType>;

static const std::vector<PrimTypePair> primtype_pairs_ceil = {
    {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
static const std::vector<PrimTypePair> primtype_pairs_floor = {
    {PTY_f32, PTY_i64}, {PTY_f64, PTY_i64}, {PTY_i64, PTY_f32},
    {PTY_u64, PTY_f32}, {PTY_i64, PTY_f64}, {PTY_u64, PTY_f64}};
static const std::vector<PrimTypePair> primtype_pairs_cvt = {{PTY_u1, PTY_u1}, {PTY_f32, PTY_u8}};
static const std::vector<PrimTypePair> primtype_pairs_trunc = {
    {PTY_f32, PTY_f32}, {PTY_f32, PTY_u64}, {PTY_f64, PTY_f64}, {PTY_f64, PTY_u64}};
static const std::unordered_map<Opcode, std::vector<PrimTypePair>> opcode_to_primtype_pairs = {
    {OP_ceil, primtype_pairs_ceil},
    {OP_floor, primtype_pairs_floor},
    {OP_cvt, primtype_pairs_cvt},
    {OP_trunc, primtype_pairs_trunc}};

// Folds type-conversion nodes for every {from, to} pair registered for each conversion opcode.
TEST(FoldTypeCvt_FUNC, t06) {
  MIRModule mirmodule("mirmodule"), mirmodule2("mirmodule2");
  StIdx sdidx_obj;
  MemPoolCtrler memPoolCtrler;
  MemPool memPool(memPoolCtrler, "poolName");
  MIRFunction mir_func(&mirmodule2, sdidx_obj);
  mir_func.SetMemPool(&memPool);
  mirmodule.SetCurFunction(&mir_func);
  const std::vector<Opcode> opcode_ls = {OP_ceil, OP_floor, OP_trunc, OP_cvt};
  // One operand constant per source ("from") type used in the tables above.
  const std::unordered_map<PrimType, MIRConst *> const_ptr_map = {
      {PTY_f32, GlobalTables::GetFpConstTable().GetOrCreateFloatConst(1.5)},
      {PTY_u64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt64())},
      {PTY_f64, GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(1.5)},
      {PTY_i64, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetInt64())},
      {PTY_u1, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt1())},
      {PTY_u8, GlobalTables::GetIntConstTable().GetOrCreateIntConst(1, *GlobalTables::GetTypeTable().GetUInt8())}};
  ConstantFold cf_obj(mirmodule, false);
  ConstvalNode constvalnode;
  for (Opcode opcode : opcode_ls) {
    // .at() rather than operator[]: a missing opcode or source type must fail loudly instead of silently
    // inserting an empty table / a null constant.
    for (const PrimTypePair &cvtPair : opcode_to_primtype_pairs.at(opcode)) {
      const PrimType from = cvtPair.first;
      const PrimType to = cvtPair.second;
      constvalnode.SetConstVal(const_ptr_map.at(from));
      TypeCvtNode type_cvt_node(opcode, to, from, &constvalnode);
      BaseNode *ans = cf_obj.Fold(&type_cvt_node);
      ASSERT_NE(ans, nullptr) << "opcode " << static_cast<int>(opcode) << ", from " << static_cast<int>(from)
                              << " to " << static_cast<int>(to);
      EXPECT_TRUE(ans->IsConstval());
    }
  }
}
}  // namespace