1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 #include "mir_type.h"
28 
29 namespace maple {
30 
31 namespace {
32 constexpr uint32 kByteSizeOfBit64 = 8;                            // byte number for 64 bit
33 constexpr uint32 kBitSizePerByte = 8;
34 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
35 
36 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
37 
operator *(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)38 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
39 {
40     if (!v1 && !v2) {
41         return std::nullopt;
42     }
43 
44     // Perform all calculations in terms of the maximum available signed type.
45     // The value will be truncated for an appropriate type when constant is created in PairToExpr function
46     return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
47 }
48 
49 // Perform all calculations in terms of the maximum available signed type.
50 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)51 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
52 {
53     if (!v1 && !v2) {
54         return std::nullopt;
55     }
56 
57     if (v1 && v2) {
58         return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
59     }
60 
61     if (v1) {
62         return v1->TruncOrExtend(PTY_i64);
63     }
64 
65     // !v1 && v2
66     return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
67 }
68 
operator +(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)69 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
70 {
71     return AddSub(v1, v2, true);
72 }
73 
operator -(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)74 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
75 {
76     return AddSub(v1, v2, false);
77 }
78 
79 }  // anonymous namespace
80 
// This phase optimizes code by simplifying constant expressions:
// each constant expression is evaluated at compile time and replaced
// by its computed value, saving work at run time.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
90 
91 // true if the constant's bits are made of only one group of contiguous 1's
92 // starting at bit 0
ContiguousBitsOf1(uint64 x)93 static bool ContiguousBitsOf1(uint64 x)
94 {
95     if (x == 0) {
96         return false;
97     }
98     return (~x & (x + 1)) == (x + 1);
99 }
100 
IsPowerOf2(uint64 num)101 inline bool IsPowerOf2(uint64 num)
102 {
103     if (num == 0) {
104         return false;
105     }
106     return (~(num - 1) & num) == num;
107 }
108 
NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, BaseNode *rhs) const109 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
110                                         BaseNode *rhs) const
111 {
112     CHECK_NULL_FATAL(old);
113     BinaryNode *result = nullptr;
114     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
115         result = old;
116     } else {
117         result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
118     }
119     return result;
120 }
121 
NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const122 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
123 {
124     CHECK_NULL_FATAL(old);
125     UnaryNode *result = nullptr;
126     if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
127         result = old;
128     } else {
129         result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
130     }
131     return result;
132 }
133 
// Rebuilds one MIR expression from a folded (expression, constant) pair: pair.first is the
// non-constant part and pair.second the optional constant added to it. The constant is
// emitted as an integer constant of type resultType.
// Returns pair.first unchanged when there is no constant, the constant is zero, or
// resultType is wider than 8 bytes.
BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
{
    CHECK_NULL_FATAL(pair.first);
    BaseNode *result = pair.first;
    // Nothing to add, or a type this folding does not handle: keep the expression as is.
    if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
        return result;
    }
    if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
        // -a, 5 -> 5 - a
        // A negated operand plus a non-negative constant becomes a subtraction, dropping the neg node.
        ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
            static_cast<uint64>(pair.second->GetExtValue()), resultType);
        BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
        result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
    } else {
        // Emit "a + c" when c is positive at resultType's bit width, or when c cannot be negated
        // without overflow (it is the minimum value of resultType after truncation, or INT64_MIN).
        // Otherwise emit "a - (-c)" so that the emitted constant is positive.
        if ((!pair.second->GetSignBit() &&
            pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
            pair.second->TruncOrExtend(resultType).IsMinValue() ||
            pair.second->GetSXTValue() == INT64_MIN) {
            // +-a, 5 -> a + 5
            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
                static_cast<uint64>(pair.second->GetExtValue()), resultType);
            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
        } else {
            // +-a, -5 -> a - 5
            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
                static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
        }
    }
    return result;
}
165 
FoldBase(BaseNode *node) const166 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
167 {
168     return std::make_pair(node, std::nullopt);
169 }
170 
DispatchFold(BaseNode *node)171 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
172 {
173     CHECK_NULL_FATAL(node);
174     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
175         return {node, std::nullopt};
176     }
177     switch (node->GetOpCode()) {
178         case OP_abs:
179         case OP_bnot:
180         case OP_lnot:
181         case OP_neg:
182         case OP_sqrt:
183             return FoldUnary(static_cast<UnaryNode*>(node));
184         case OP_ceil:
185         case OP_floor:
186         case OP_trunc:
187         case OP_cvt:
188             return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
189         case OP_sext:
190         case OP_zext:
191         case OP_extractbits:
192             return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
193         case OP_iread:
194             return FoldIread(static_cast<IreadNode*>(node));
195         case OP_add:
196         case OP_ashr:
197         case OP_band:
198         case OP_bior:
199         case OP_bxor:
200         case OP_div:
201         case OP_lshr:
202         case OP_max:
203         case OP_min:
204         case OP_mul:
205         case OP_rem:
206         case OP_shl:
207         case OP_sub:
208             return FoldBinary(static_cast<BinaryNode*>(node));
209         case OP_eq:
210         case OP_ne:
211         case OP_ge:
212         case OP_gt:
213         case OP_le:
214         case OP_lt:
215         case OP_cmp:
216             return FoldCompare(static_cast<CompareNode*>(node));
217         case OP_retype:
218             return FoldRetype(static_cast<RetypeNode*>(node));
219         default:
220             return FoldBase(static_cast<BaseNode*>(node));
221     }
222 }
223 
Negate(BaseNode *node) const224 BaseNode *ConstantFold::Negate(BaseNode *node) const
225 {
226     CHECK_NULL_FATAL(node);
227     return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
228 }
229 
Negate(UnaryNode *node) const230 BaseNode *ConstantFold::Negate(UnaryNode *node) const
231 {
232     CHECK_NULL_FATAL(node);
233     BaseNode *result = nullptr;
234     if (node->GetOpCode() == OP_neg) {
235         result = static_cast<BaseNode*>(node->Opnd(0));
236     } else {
237         BaseNode *n = static_cast<BaseNode*>(node);
238         result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
239     }
240     return result;
241 }
242 
Negate(const ConstvalNode *node) const243 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
244 {
245     CHECK_NULL_FATAL(node);
246     ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
247     CHECK_NULL_FATAL(copy);
248     copy->GetConstVal()->Neg();
249     return copy;
250 }
251 
NegateTree(BaseNode *node) const252 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
253 {
254     CHECK_NULL_FATAL(node);
255     if (node->IsUnaryNode()) {
256         return Negate(static_cast<UnaryNode*>(node));
257     } else if (node->GetOpCode() == OP_constval) {
258         return Negate(static_cast<ConstvalNode*>(node));
259     } else {
260         return Negate(static_cast<BaseNode*>(node));
261     }
262 }
263 
FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRIntConst &intConst0, const MIRIntConst &intConst1) const264 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
265                                                           const MIRIntConst &intConst0,
266                                                           const MIRIntConst &intConst1) const
267 {
268     uint64 result = 0;
269 
270     bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
271     bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
272     bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
273 
274     switch (opcode) {
275         case OP_eq: {
276             result = equal;
277             break;
278         }
279         case OP_ge: {
280             result = (greater || equal) ? 1 : 0;
281             break;
282         }
283         case OP_gt: {
284             result = greater;
285             break;
286         }
287         case OP_le: {
288             result = (less || equal) ? 1 : 0;
289             break;
290         }
291         case OP_lt: {
292             result = less;
293             break;
294         }
295         case OP_ne: {
296             result = !equal;
297             break;
298         }
299         case OP_cmp: {
300             if (greater) {
301                 result = kGreater;
302             } else if (equal) {
303                 result = kEqual;
304             } else if (less) {
305                 result = static_cast<uint64>(kLess);
306             }
307             break;
308         }
309         default:
310             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
311             break;
312     }
313     // determine the type
314     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
315     // form the constant
316     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
317     return constValue;
318 }
319 
FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const320 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
321                                                    const ConstvalNode &const0, const ConstvalNode &const1) const
322 {
323     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
324     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
325     CHECK_NULL_FATAL(intConst0);
326     CHECK_NULL_FATAL(intConst1);
327     MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
328     // form the ConstvalNode
329     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
330     resultConst->SetPrimType(resultType);
331     resultConst->SetConstVal(constValue);
332     return resultConst;
333 }
334 
FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0, const MIRIntConst &intConst1)335 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
336                                                    const MIRIntConst &intConst1)
337 {
338     IntVal intVal0 = intConst0.GetValue();
339     IntVal intVal1 = intConst1.GetValue();
340     IntVal result(static_cast<uint64>(0), resultType);
341 
342     switch (opcode) {
343         case OP_add: {
344             result = intVal0.Add(intVal1, resultType);
345             break;
346         }
347         case OP_sub: {
348             result = intVal0.Sub(intVal1, resultType);
349             break;
350         }
351         case OP_mul: {
352             result = intVal0.Mul(intVal1, resultType);
353             break;
354         }
355         case OP_div: {
356             result = intVal0.Div(intVal1, resultType);
357             break;
358         }
359         case OP_rem: {
360             result = intVal0.Rem(intVal1, resultType);
361             break;
362         }
363         case OP_ashr: {
364             result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
365             break;
366         }
367         case OP_lshr: {
368             result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
369             break;
370         }
371         case OP_shl: {
372             result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
373             break;
374         }
375         case OP_max: {
376             result = Max(intVal0, intVal1, resultType);
377             break;
378         }
379         case OP_min: {
380             result = Min(intVal0, intVal1, resultType);
381             break;
382         }
383         case OP_band: {
384             result = intVal0.And(intVal1, resultType);
385             break;
386         }
387         case OP_bior: {
388             result = intVal0.Or(intVal1, resultType);
389             break;
390         }
391         case OP_bxor: {
392             result = intVal0.Xor(intVal1, resultType);
393             break;
394         }
395         default:
396             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
397             break;
398     }
399     // determine the type
400     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
401     // form the constant
402     MIRIntConst *constValue =
403         GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
404     return constValue;
405 }
406 
FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const407 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
408                                                const ConstvalNode &const1) const
409 {
410     const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
411     const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
412     CHECK_NULL_FATAL(intConst0);
413     CHECK_NULL_FATAL(intConst1);
414     MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
415     // form the ConstvalNode
416     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
417     resultConst->SetPrimType(resultType);
418     resultConst->SetConstVal(constValue);
419     return resultConst;
420 }
421 
FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const422 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
423                                               const ConstvalNode &const1) const
424 {
425     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
426     const MIRDoubleConst *doubleConst0 = nullptr;
427     const MIRDoubleConst *doubleConst1 = nullptr;
428     const MIRFloatConst *floatConst0 = nullptr;
429     const MIRFloatConst *floatConst1 = nullptr;
430     bool useDouble = (const0.GetPrimType() == PTY_f64);
431     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
432     resultConst->SetPrimType(resultType);
433     if (useDouble) {
434         doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
435         doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
436         CHECK_NULL_FATAL(doubleConst0);
437         CHECK_NULL_FATAL(doubleConst1);
438     } else {
439         floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
440         floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
441         CHECK_NULL_FATAL(floatConst0);
442         CHECK_NULL_FATAL(floatConst1);
443     }
444     float constValueFloat = 0.0;
445     double constValueDouble = 0.0;
446     switch (opcode) {
447         case OP_add: {
448             if (useDouble) {
449                 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
450             } else {
451                 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
452             }
453             break;
454         }
455         case OP_sub: {
456             if (useDouble) {
457                 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
458             } else {
459                 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
460             }
461             break;
462         }
463         case OP_mul: {
464             if (useDouble) {
465                 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
466             } else {
467                 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
468             }
469             break;
470         }
471         case OP_div: {
472             // for floats div by 0 is well defined
473             if (useDouble) {
474                 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
475             } else {
476                 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
477             }
478             break;
479         }
480         case OP_max: {
481             if (useDouble) {
482                 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
483                                                                                         : doubleConst1->GetValue();
484             } else {
485                 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
486                                                                                     : floatConst1->GetValue();
487             }
488             break;
489         }
490         case OP_min: {
491             if (useDouble) {
492                 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
493                                                                                         : doubleConst1->GetValue();
494             } else {
495                 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
496                                                                                     : floatConst1->GetValue();
497             }
498             break;
499         }
500         case OP_rem:
501         case OP_ashr:
502         case OP_lshr:
503         case OP_shl:
504         case OP_band:
505         case OP_bior:
506         case OP_bxor: {
507             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
508             break;
509         }
510         default:
511             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
512             break;
513     }
514     if (resultType == PTY_f64) {
515         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
516     } else {
517         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
518     }
519     return resultConst;
520 }
521 
ConstValueEqual(int64 leftValue, int64 rightValue) const522 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
523 {
524     return (leftValue == rightValue);
525 }
526 
ConstValueEqual(float leftValue, float rightValue) const527 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
528 {
529     auto result = fabs(leftValue - rightValue);
530     return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
531 }
532 
ConstValueEqual(double leftValue, double rightValue) const533 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
534 {
535     auto result = fabs(leftValue - rightValue);
536     return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
537 }
538 
539 template<typename T>
FullyEqual(T leftValue, T rightValue) const540 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
541 {
542     if (std::isinf(leftValue) && std::isinf(rightValue)) {
543         // (inf == inf), add the judgement here in case of the subtraction between float type inf
544         return true;
545     } else {
546         return ConstValueEqual(leftValue, rightValue);
547     }
548 }
549 
550 template<typename T>
ComparisonResult(Opcode op, T *leftConst, T *rightConst) const551 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
552 {
553     DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
554     typename T::value_type leftValue = leftConst->GetValue();
555     DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
556     typename T::value_type rightValue = rightConst->GetValue();
557     int64 result = 0;
558     switch (op) {
559         case OP_eq: {
560             result = FullyEqual(leftValue, rightValue);
561             break;
562         }
563         case OP_ge: {
564             result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
565             break;
566         }
567         case OP_gt: {
568             result = (leftValue > rightValue);
569             break;
570         }
571         case OP_le: {
572             result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
573             break;
574         }
575         case OP_lt: {
576             result = (leftValue < rightValue);
577             break;
578         }
579         case OP_ne: {
580             result = !FullyEqual(leftValue, rightValue);
581             break;
582         }
583         [[clang::fallthrough]];
584         case OP_cmp: {
585             if (leftValue > rightValue) {
586                 result = kGreater;
587             } else if (FullyEqual(leftValue, rightValue)) {
588                 result = kEqual;
589             } else {
590                 result = kLess;
591             }
592             break;
593         }
594         default:
595             DEBUG_ASSERT(false, "Unknown opcode for Comparison");
596             break;
597     }
598     return result;
599 }
600 
FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRConst &leftConst, const MIRConst &rightConst) const601 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
602                                                          const MIRConst &leftConst, const MIRConst &rightConst) const
603 {
604     int64 result = 0;
605     bool useDouble = (opndType == PTY_f64);
606     if (useDouble) {
607         result =
608             ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
609     } else {
610         result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
611     }
612     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
613     MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
614     return resultConst;
615 }
616 
FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const617 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
618                                                   const ConstvalNode &const0, const ConstvalNode &const1) const
619 {
620     DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
621     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
622     resultConst->SetPrimType(resultType);
623     resultConst->SetConstVal(
624         FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
625     return resultConst;
626 }
627 
FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRConst &const0, const MIRConst &const1) const628 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
629                                                     const MIRConst &const0, const MIRConst &const1) const
630 {
631     MIRConst *returnValue = nullptr;
632     if (IsPrimitiveInteger(opndType)) {
633         const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
634         const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
635         ASSERT_NOT_NULL(intConst0);
636         ASSERT_NOT_NULL(intConst1);
637         returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
638     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
639         returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
640     } else {
641         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
642     }
643     return returnValue;
644 }
645 
FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const646 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
647                                                 const ConstvalNode &const0, const ConstvalNode &const1) const
648 {
649     ConstvalNode *returnValue = nullptr;
650     if (IsPrimitiveInteger(opndType)) {
651         returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
652     } else if (opndType == PTY_f32 || opndType == PTY_f64) {
653         returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
654     } else {
655         DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
656     }
657     return returnValue;
658 }
659 
FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType, BaseNode &l, BaseNode &r) const660 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
661                                                       BaseNode &l, BaseNode &r) const
662 {
663     CompareNode *result = nullptr;
664     Opcode op = opcode;
665     switch (opcode) {
666         case OP_gt: {
667             op = OP_lt;
668             break;
669         }
670         case OP_lt: {
671             op = OP_gt;
672             break;
673         }
674         case OP_ge: {
675             op = OP_le;
676             break;
677         }
678         case OP_le: {
679             op = OP_ge;
680             break;
681         }
682         case OP_eq: {
683             break;
684         }
685         case OP_ne: {
686             break;
687         }
688         default:
689             DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
690             break;
691     }
692 
693     result =
694         mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
695     return result;
696 }
697 
FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const698 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
699                                             const ConstvalNode &const1) const
700 {
701     ConstvalNode *returnValue = nullptr;
702     if (IsPrimitiveInteger(resultType)) {
703         returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
704     } else if (resultType == PTY_f32 || resultType == PTY_f64) {
705         returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
706     } else {
707         DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
708     }
709     return returnValue;
710 }
711 
FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)712 MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
713 {
714     CHECK_NULL_FATAL(constNode);
715     IntVal result = constNode->GetValue().TruncOrExtend(resultType);
716     switch (opcode) {
717         case OP_abs: {
718             if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
719                 result = -result;
720             }
721             break;
722         }
723         case OP_bnot: {
724             result = ~result;
725             break;
726         }
727         case OP_lnot: {
728             uint64 resultInt = result == 0 ? 1 : 0;
729             result = {resultInt, resultType};
730             break;
731         }
732         case OP_neg: {
733             result = -result;
734             break;
735         }
736         case OP_sext:         // handled in FoldExtractbits
737         case OP_zext:         // handled in FoldExtractbits
738         case OP_extractbits:  // handled in FoldExtractbits
739         case OP_sqrt: {
740             DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
741             break;
742         }
743         default:
744             DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
745             break;
746     }
747     // determine the type
748     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
749     // form the constant
750     MIRIntConst *constValue =
751         GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
752     return constValue;
753 }
754 
755 template <typename T>
FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const756 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
757 {
758     CHECK_NULL_FATAL(constNode);
759     double constValue = 0;
760     T *fpCst = static_cast<T*>(constNode->GetConstVal());
761     switch (opcode) {
762         case OP_neg: {
763             constValue = typename T::value_type(-fpCst->GetValue());
764             break;
765         }
766         case OP_abs: {
767             constValue = typename T::value_type(fabs(fpCst->GetValue()));
768             break;
769         }
770         case OP_sqrt: {
771             constValue = typename T::value_type(sqrt(fpCst->GetValue()));
772             break;
773         }
774         case OP_bnot:
775         case OP_lnot:
776         case OP_sext:
777         case OP_zext:
778         case OP_extractbits: {
779             DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
780             break;
781         }
782         default:
783             DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
784             break;
785     }
786     auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
787     resultConst->SetPrimType(resultType);
788     if (resultType == PTY_f32) {
789         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
790     } else if (resultType == PTY_f64) {
791         resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
792     } else {
793         CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
794     }
795     return resultConst;
796 }
797 
FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const798 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
799 {
800     ConstvalNode *returnValue = nullptr;
801     if (IsPrimitiveInteger(resultType)) {
802         const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
803         auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
804         returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
805         returnValue->SetPrimType(resultType);
806         returnValue->SetConstVal(constValue);
807     } else if (resultType == PTY_f32) {
808         returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
809     } else if (resultType == PTY_f64) {
810         returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
811     } else {
812         DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
813     }
814     return returnValue;
815 }
816 
FoldRetype(RetypeNode *node)817 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
818 {
819     CHECK_NULL_FATAL(node);
820     BaseNode *result = node;
821     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
822     if (node->Opnd(0) != p.first) {
823         RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
824         CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
825         newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
826         result = newRetNode;
827     }
828     return std::make_pair(result, std::nullopt);
829 }
830 
FoldUnary(UnaryNode *node)831 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
832 {
833     CHECK_NULL_FATAL(node);
834     BaseNode *result = nullptr;
835     std::optional<IntVal> sum = std::nullopt;
836     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
837     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
838     if (cst != nullptr) {
839         result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
840     } else {
841         bool isInt = IsPrimitiveInteger(node->GetPrimType());
842         // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
843         // primType will be set to opnd type. There will be problems in some cases. For example:
844         // before cf:
845         //   neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
846         // after cf:
847         //   neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))  # wrong!
848         // As a workaround, we exclude u1 opnd type
849         if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
850             result = NegateTree(p.first);
851             if (result->GetOpCode() == OP_neg) {
852                 PrimType origPtyp = node->GetPrimType();
853                 PrimType newPtyp = result->GetPrimType();
854                 if (newPtyp == origPtyp) {
855                 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
856                     // NegateTree returned an UnaryNode quivalent to `n`, so keep the
857                     // original UnaryNode to preserve identity
858                     result = node;
859                 }
860                 } else {
861                     if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
862                         // do not fold explicit cvt
863                         result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
864                             PairToExpr(node->Opnd(0)->GetPrimType(), p));
865                         return std::make_pair(result, std::nullopt);
866                     } else {
867                         result->SetPrimType(origPtyp);
868                     }
869                 }
870             }
871             if (p.second) {
872                 sum = -(*p.second);
873             }
874         } else {
875             result =
876                 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
877         }
878     }
879     return std::make_pair(result, sum);
880 }
881 
FloatToIntOverflow(float fval, PrimType totype)882 static bool FloatToIntOverflow(float fval, PrimType totype)
883 {
884     static const float safeFloatMaxToInt32 = 2147483520.0f;  // 2^31 - 128
885     static const float safeFloatMinToInt32 = -2147483520.0f;
886     static const float safeFloatMaxToInt64 = 9223372036854775680.0f;  // 2^63 - 128
887     static const float safeFloatMinToInt64 = -9223372036854775680.0f;
888     if (!std::isfinite(fval)) {
889         return true;
890     }
891     if (totype == PTY_i64 || totype == PTY_u64) {
892         if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
893             return true;
894         }
895     } else {
896         if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
897             return true;
898         }
899     }
900     return false;
901 }
902 
DoubleToIntOverflow(double dval, PrimType totype)903 static bool DoubleToIntOverflow(double dval, PrimType totype)
904 {
905     static const double safeDoubleMaxToInt32 = 2147482624.0;  // 2^31 - 1024
906     static const double safeDoubleMinToInt32 = -2147482624.0;
907     static const double safeDoubleMaxToInt64 = 9223372036854774784.0;  // 2^63 - 1024
908     static const double safeDoubleMinToInt64 = -9223372036854774784.0;
909     if (!std::isfinite(dval)) {
910         return true;
911     }
912     if (totype == PTY_i64 || totype == PTY_u64) {
913         if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
914             return true;
915         }
916     } else {
917         if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
918             return true;
919         }
920     }
921     return false;
922 }
923 
// Folds `ceil` applied to the f32/f64 constant `cst` (of type `fromType`) into a ConstvalNode of `toType`.
// Floating-point targets receive the rounded value. Integer targets yield nullptr (no folding) when the rounded
// value may overflow. Otherwise negative values are converted through int64, because converting a negative
// floating value directly to uint64 is undefined behavior.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // the overflow check keeps negative values >= INT64_MIN, so the int64 conversion is well-defined
            uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                             : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // the overflow check keeps negative values within int64, so the int64 conversion is well-defined
            uint64 intValue = doubleValue < 0 ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                              : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    }
    return resultConst;
}
956 
957 template <class T>
CalIntValueFromFloatValue(T value, const MIRType &resultType) const958 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
959 {
960     DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
961     size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
962     bool isSigned = IsSignedInteger(resultType.GetPrimType());
963     int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
964     uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
965     int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
966     if (isSigned && (value > max)) {
967         return static_cast<T>(max);
968     } else if (!isSigned && (value > umax)) {
969         return static_cast<T>(umax);
970     } else if (value < min) {
971         return static_cast<T>(min);
972     }
973     return value;
974 }
975 
// Converts the f32/f64 constant `cst` (of type `fromType`) to `toType`, applying floor() first when `isFloor`
// is set. With isFloor cleared, integer targets get the plain truncating conversion.
// Floating-point targets receive the (possibly floored) value. Integer targets yield nullptr when the value may
// overflow. Otherwise the value is saturated to toType's range and converted to an integer constant; negative
// values go through int64, since converting them straight to uint64 is undefined behavior.
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        // negative values are >= INT64_MIN after the overflow check, so the int64 conversion is well-defined
        uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                         : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(constValue.GetValue());
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1010 
// Folds `floor` applied to the constant `cst` (of type `fromType`) into a ConstvalNode of `toType`.
// Returns nullptr (no folding) when FoldFloorMIRConst cannot produce a constant, e.g. when the floored value may
// overflow an integer `toType`, instead of returning a ConstvalNode whose constant value is null.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1018 
// Folds `round` of a floating constant, and integer-to-floating conversions, of the constant `cst`.
// - f32/f64 -> integer: rounds half away from zero; returns nullptr when the result may overflow `toType`.
// - integer -> f32/f64: folds only when the conversion is exact, i.e. converting the result back reproduces the
//   original integer. A result of 2^63 (signed source) or 2^64 (unsigned source) cannot be converted back
//   without undefined behavior, so such values are not folded.
// Any other combination returns nullptr.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    // 2^63 and 2^64 are exactly representable in float and double; an integer converted to floating point can
    // round up to them, and both lie outside the range of the back-conversion to int64 / uint64.
    constexpr float kFloatTwoPow63 = 9223372036854775808.0f;
    constexpr float kFloatTwoPow64 = 18446744073709551616.0f;
    constexpr double kDoubleTwoPow63 = 9223372036854775808.0;
    constexpr double kDoubleTwoPow64 = 18446744073709551616.0;
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = round(constValue.GetValue());
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        // negative values are >= INT64_MIN after the overflow check; they must not be converted to uint64 directly
        uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                         : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType);
    } else if (fromType == PTY_f64) {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = round(constValue.GetValue());
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kFloatTwoPow63 && static_cast<int64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kFloatTwoPow64 && static_cast<uint64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        }
    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kDoubleTwoPow63 && static_cast<int64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kDoubleTwoPow64 && static_cast<uint64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        }
    }
    return nullptr;
}
1070 
// Folds `round` (or an integer-to-floating conversion) of the constant `cst` into a ConstvalNode of `toType`.
// Returns nullptr (no folding) when FoldRoundMIRConst cannot produce a constant, e.g. on possible overflow or an
// inexact conversion, instead of returning a ConstvalNode whose constant value is null.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *folded = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (folded == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(folded);
    return resultConst;
}
1078 
// Folds `trunc` (round toward zero) applied to the f32/f64 constant `cst` into a ConstvalNode of `toType`.
// Floating-point targets receive the truncated value. Integer targets yield nullptr (no folding) when the value
// may overflow. Otherwise negative values are converted through int64, because converting a negative floating
// value directly to uint64 is undefined behavior.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            // the overflow check keeps negative values >= INT64_MIN, so the int64 conversion is well-defined
            uint64 intValue = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                             : static_cast<uint64>(floatValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            // the overflow check keeps negative values within int64, so the int64 conversion is well-defined
            uint64 intValue = doubleValue < 0 ? static_cast<uint64>(static_cast<int64>(doubleValue))
                                              : static_cast<uint64>(doubleValue);
            resultConst->SetConstVal(GlobalTables::GetIntConstTable().GetOrCreateIntConst(intValue, resultType));
        }
    }
    return resultConst;
}
1111 
FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const1112 MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1113 {
1114     if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
1115         MIRConst *toConst = nullptr;
1116         uint32 fromSize = GetPrimTypeBitSize(fromType);
1117         uint32 toSize = GetPrimTypeBitSize(toType);
1118         // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
1119         if (fromType == PTY_u1) {
1120             fromSize = 1;
1121         }
1122         if (toType == PTY_u1) {
1123             toSize = 1;
1124         }
1125         if (toSize > fromSize) {
1126             Opcode op = OP_zext;
1127             if (IsSignedInteger(fromType)) {
1128                 op = OP_sext;
1129             }
1130             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1131             ASSERT_NOT_NULL(constVal);
1132             toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
1133                 constVal->GetValue().TruncOrExtend(fromType));
1134         } else {
1135             const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1136             ASSERT_NOT_NULL(constVal);
1137             MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
1138             toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1139                 static_cast<uint64>(constVal->GetExtValue()), type);
1140         }
1141         return toConst;
1142     }
1143     if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
1144         MIRConst *toConst = nullptr;
1145         if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
1146             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
1147             const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
1148             ASSERT_NOT_NULL(fromValue);
1149             float floatValue = static_cast<float>(fromValue->GetValue());
1150             MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1151             toConst = toValue;
1152         } else {
1153             DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
1154             const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
1155             ASSERT_NOT_NULL(fromValue);
1156             double doubleValue = static_cast<double>(fromValue->GetValue());
1157             MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1158             toConst = toValue;
1159         }
1160         return toConst;
1161     }
1162     if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
1163         return FoldFloorMIRConst(cst, fromType, toType, false);
1164     }
1165     if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
1166         return FoldRoundMIRConst(cst, fromType, toType);
1167     }
1168     CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
1169     return nullptr;
1170 }
1171 
// Folds a `cvt` of the constant `cst` and wraps the result in a new ConstvalNode. The node's primitive type is
// taken from the produced constant. Returns nullptr when FoldTypeCvtMIRConst could not fold the conversion.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    ConstvalNode *node = nullptr;
    if (MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType); converted != nullptr) {
        node = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
        node->SetPrimType(converted->GetType().GetPrimType());
        node->SetConstVal(converted);
    }
    return node;
}
1183 
1184 // return a primType with bit size >= bitSize (and the nearest one),
1185 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)1186 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1187 {
1188     bool isSigned = IsSignedInteger(ptyp);
1189     bool isFloat = IsPrimitiveFloat(ptyp);
1190     if (bitSize == 1) { // 1 bit
1191         return PTY_u1;
1192     }
1193     if (bitSize <= 8) { // 8 bit
1194         return isSigned ? PTY_i8 : PTY_u8;
1195     }
1196     if (bitSize <= 16) { // 16 bit
1197         return isSigned ? PTY_i16 : PTY_u16;
1198     }
1199     if (bitSize <= 32) { // 32 bit
1200         return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1201     }
1202     if (bitSize <= 64) { // 64 bit
1203         return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1204     }
1205     return ptyp;
1206 }
1207 
GetIntPrimTypeMax(PrimType ptyp)1208 size_t GetIntPrimTypeMax(PrimType ptyp)
1209 {
1210     switch (ptyp) {
1211         case PTY_u1:
1212             return 1;
1213         case PTY_u8:
1214             return UINT8_MAX;
1215         case PTY_i8:
1216             return INT8_MAX;
1217         case PTY_u16:
1218             return UINT16_MAX;
1219         case PTY_i16:
1220             return INT16_MAX;
1221         case PTY_u32:
1222             return UINT32_MAX;
1223         case PTY_i32:
1224             return INT32_MAX;
1225         case PTY_u64:
1226             return UINT64_MAX;
1227         case PTY_i64:
1228             return INT64_MAX;
1229         default:
1230             CHECK_FATAL(false, "NYI");
1231     }
1232 }
1233 
GetIntPrimTypeMin(PrimType ptyp)1234 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1235 {
1236     if (IsUnsignedInteger(ptyp)) {
1237         return 0;
1238     }
1239     switch (ptyp) {
1240         case PTY_i8:
1241             return INT8_MIN;
1242         case PTY_i16:
1243             return INT16_MIN;
1244         case PTY_i32:
1245             return INT32_MIN;
1246         case PTY_i64:
1247             return INT64_MIN;
1248         default:
1249             CHECK_FATAL(false, "NYI");
1250     }
1251 }
1252 
IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)1253 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1254 {
1255     if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1256         return false;
1257     }
1258     if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1259         return false;
1260     }
1261     return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1262         (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1263         (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1264 }
1265 
FoldTypeCvt(TypeCvtNode *node)1266 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1267 {
1268     CHECK_NULL_FATAL(node);
1269     BaseNode *result = nullptr;
1270     if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1271         return {node, std::nullopt};
1272     }
1273     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1274     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1275     PrimType destPtyp = node->GetPrimType();
1276     PrimType fromPtyp = node->FromType();
1277     if (cst != nullptr) {
1278         switch (node->GetOpCode()) {
1279             case OP_ceil: {
1280                 result = FoldCeil(*cst, fromPtyp, destPtyp);
1281                 break;
1282             }
1283             case OP_cvt: {
1284                 result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1285                 break;
1286             }
1287             case OP_floor: {
1288                 result = FoldFloor(*cst, fromPtyp, destPtyp);
1289                 break;
1290             }
1291             case OP_trunc: {
1292                 result = FoldTrunc(*cst, fromPtyp, destPtyp);
1293                 break;
1294             }
1295             default:
1296                 DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1297                 break;
1298         }
1299     } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1300         // the cvt is redundant
1301         return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1302     }
1303     if (result == nullptr) {
1304         BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1305         if (e != node->Opnd(0)) {
1306             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1307                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1308         } else {
1309             result = node;
1310         }
1311     }
1312     return std::make_pair(result, std::nullopt);
1313 }
1314 
FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const1315 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1316 {
1317     uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1318     MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1319     MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1320     return constValue;
1321 }
1322 
FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size, const ConstvalNode &cst) const1323 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1324                                            const ConstvalNode &cst) const
1325 {
1326     ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1327     const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1328     ASSERT_NOT_NULL(intCst);
1329     IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1330     MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1331     resultConst->SetPrimType(toConst->GetType().GetPrimType());
1332     resultConst->SetConstVal(toConst);
1333     return resultConst;
1334 }
1335 
1336 // check if truncation is redundant due to dread or iread having same effect
ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)1337 static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1338 {
1339     if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1340         return false;  // this is trying to be conservative
1341     }
1342     BaseNode *opnd = x.Opnd(0);
1343     MIRType *mirType = nullptr;
1344     if (opnd->GetOpCode() == OP_dread) {
1345         DreadNode *dread = static_cast<DreadNode*>(opnd);
1346         MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1347         ASSERT_NOT_NULL(sym);
1348         mirType = sym->GetType();
1349     } else if (opnd->GetOpCode() == OP_iread) {
1350         IreadNode *iread = static_cast<IreadNode*>(opnd);
1351         MIRPtrType *ptrType =
1352             dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1353         if (ptrType == nullptr) {
1354             return false;
1355         }
1356         mirType = ptrType->GetPointedType();
1357     } else if (opnd->GetOpCode() == OP_extractbits &&
1358                 x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1359         return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1360             IsUnsignedInteger(opnd->GetPrimType()));
1361     } else {
1362         return false;
1363     }
1364     return IsPrimitiveInteger(mirType->GetPrimType()) &&
1365             ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1366             (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1367             mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1368             mirType->GetPrimType() == x.GetPrimType();
1369 }
1370 
1371 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode *node)1372 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1373 {
1374     CHECK_NULL_FATAL(node);
1375     BaseNode *result = nullptr;
1376     uint8 offset = node->GetBitsOffset();
1377     uint8 size = node->GetBitsSize();
1378     Opcode opcode = node->GetOpCode();
1379     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1380     ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1381     if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1382         result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1383         return std::make_pair(result, std::nullopt);
1384     }
1385     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1386     if (e != node->Opnd(0)) {
1387         result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1388                                                                        size, e);
1389     } else {
1390         result = node;
1391     }
1392     // check for consecutive and redundant extraction of same bits
1393     BaseNode *opnd = result->Opnd(0);
1394     DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1395     Opcode opndOp = opnd->GetOpCode();
1396     if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1397         uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1398         uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1399         if (offset == opndOffset && size == opndSize) {
1400             result->SetOpnd(opnd->Opnd(0), 0);  // delete the redundant extraction
1401         }
1402     }
1403     if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1404         if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1405             return std::make_pair(result->Opnd(0), std::nullopt);
1406         }
1407     }
1408     return std::make_pair(result, std::nullopt);
1409 }
1410 
FoldIread(IreadNode *node)1411 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1412 {
1413     CHECK_NULL_FATAL(node);
1414     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1415     BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1416     node->SetOpnd(e, 0);
1417     BaseNode *result = node;
1418     if (e->GetOpCode() != OP_addrof) {
1419         return std::make_pair(result, std::nullopt);
1420     }
1421 
1422     AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1423     MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1424     DEBUG_ASSERT(msy != nullptr, "nullptr check");
1425     TyIdx typeId = msy->GetTyIdx();
1426     CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1427     MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1428     MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1429     // If the high level type of iaddrof/iread doesn't match
1430     // the type of addrof's rhs, this optimization cannot be done.
1431     if (ptrType->GetPointedType() != msyType) {
1432         return std::make_pair(result, std::nullopt);
1433     }
1434 
1435     Opcode op = node->GetOpCode();
1436     if (op == OP_iread) {
1437         result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1438                                                                   node->GetFieldID() + addrofNode->GetFieldID());
1439     }
1440     return std::make_pair(result, std::nullopt);
1441 }
1442 
IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)1443 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1444 {
1445     switch (op) {
1446         case OP_add: {
1447             int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1448             if (IsUnsignedInteger(primType)) {
1449                 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1450             }
1451             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1452             return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1453                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1454                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1455                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1456         }
1457         case OP_sub: {
1458             if (IsUnsignedInteger(primType)) {
1459                 return cstA < cstB;
1460             }
1461             int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1462             auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1463             return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1464                     static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1465                    (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1466                     static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1467         }
1468         default: {
1469             return false;
1470         }
1471     }
1472 }
1473 
FoldBinary(BinaryNode *node)1474 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1475 {
1476     CHECK_NULL_FATAL(node);
1477     BaseNode *result = nullptr;
1478     std::optional<IntVal> sum = std::nullopt;
1479     Opcode op = node->GetOpCode();
1480     PrimType primType = node->GetPrimType();
1481     PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1482     PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1483     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1484     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1485     BaseNode *l = lp.first;
1486     BaseNode *r = rp.first;
1487     ASSERT_NOT_NULL(r);
1488     ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1489     ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1490     bool isInt = IsPrimitiveInteger(primType);
1491 
1492     if (lConst != nullptr && rConst != nullptr) {
1493         MIRConst *lConstVal = lConst->GetConstVal();
1494         MIRConst *rConstVal = rConst->GetConstVal();
1495         ASSERT_NOT_NULL(lConstVal);
1496         ASSERT_NOT_NULL(rConstVal);
1497         // Don't fold div by 0, for floats div by 0 is well defined.
1498         if ((op == OP_div || op == OP_rem) && isInt &&
1499             !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1500             result = NewBinaryNode(node, op, primType, lConst, rConst);
1501         } else {
1502             // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1503             // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1504             // logic since the alternative is to return pair(result = nullptr, sum = 6).
1505             // Doing so would introduce many nullptr checks in the code. See previous
1506             // commits that implemented that logic for a comparison.
1507             result = FoldConstBinary(op, primType, *lConst, *rConst);
1508         }
1509     } else if (lConst != nullptr && isInt) {
1510         MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1511         ASSERT_NOT_NULL(mcst);
1512         PrimType cstTyp = mcst->GetType().GetPrimType();
1513         IntVal cst = mcst->GetValue();
1514         if (op == OP_add) {
1515             if (IsSignedInteger(cstTyp) && rp.second &&
1516                 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1517                 // do not introduce signed integer overflow
1518                 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1519             } else {
1520                 sum = cst + rp.second;
1521                 result = r;
1522             }
1523         } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1524             // We exclude u1 type for fixing the following wrong example:
1525             // before cf:
1526             //   sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1527             // after cf:
1528             //   add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1529             sum = cst - rp.second;
1530             if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1531                 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1532             }
1533             result = NegateTree(r);
1534         } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1535                     op == OP_band) &&
1536                     cst == 0) {
1537             // 0 * X -> 0
1538             // 0 / X -> 0
1539             // 0 % X -> 0
1540             // 0 >> X -> 0
1541             // 0 << X -> 0
1542             // 0 & X -> 0
1543             // 0 && X -> 0
1544             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1545         } else if (op == OP_mul && cst == 1) {
1546             // 1 * X --> X
1547             sum = rp.second;
1548             result = r;
1549         } else if (op == OP_bior && cst == -1) {
1550             // (-1) | X -> -1
1551             result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1552         } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1553             // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1554             sum = cst * rp.second;
1555             if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1556                 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1557             }
1558             result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1559         } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1560             // 0 | X -> X
1561             // 0 ^ X -> X
1562             sum = rp.second;
1563             result = r;
1564         } else {
1565             result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1566         }
1567         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1568             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1569         }
1570     } else if (rConst != nullptr && isInt) {
1571         MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1572         ASSERT_NOT_NULL(mcst);
1573         PrimType cstTyp = mcst->GetType().GetPrimType();
1574         IntVal cst = mcst->GetValue();
1575         if (op == OP_add) {
1576             if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1577                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1578             } else {
1579                 result = l;
1580                 sum = lp.second + cst;
1581             }
1582         } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1583             result = l;
1584             sum = lp.second - cst;
1585         } else if ((op == OP_mul || op == OP_band) && cst == 0) {
1586             // X * 0 -> 0
1587             // X & 0 -> 0
1588             // X && 0 -> 0
1589             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1590         } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1591             // case [X * 1 -> X]
1592             // case [X / 1 = X]
1593             sum = lp.second;
1594             result = l;
1595         } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1596                 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1597             // temporary fix for constfold of mul/div in DejaGnu
1598             // Later we need a more formal interface for pattern match
1599             // X * Y / Y -> X
1600             BaseNode *x = l->Opnd(0);
1601             BaseNode *y = l->Opnd(1);
1602             ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1603             ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1604             bool foldMulDiv = false;
1605             if (yConst != nullptr && xConst == nullptr &&
1606                 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1607                 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1608                 ASSERT_NOT_NULL(yCst);
1609                 IntVal mulCst = yCst->GetValue();
1610                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1611                     mulCst.GetExtValue() == cst.GetExtValue()) {
1612                     foldMulDiv = true;
1613                     result = x;
1614                 }
1615             } else if (xConst != nullptr && yConst == nullptr &&
1616                         IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1617                 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1618                 ASSERT_NOT_NULL(xCst);
1619                 IntVal mulCst = xCst->GetValue();
1620                 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1621                     mulCst.GetExtValue() == cst.GetExtValue()) {
1622                     foldMulDiv = true;
1623                     result = y;
1624                 }
1625             }
1626             if (!foldMulDiv) {
1627                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1628             }
1629         } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1630             // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1631             sum = lp.second * cst;
1632             if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1633                 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1634             }
1635             if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1636                 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1637                 result = lp.first->Opnd(0);
1638             } else {
1639                 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1640             }
1641         } else if (op == OP_band && cst == -1) {
1642             // X & (-1) -> X
1643             sum = lp.second;
1644             result = l;
1645         } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1646                    (!lp.second.has_value() || lp.second == 0)) {
1647             bool fold2extractbits = false;
1648             if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1649                 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1650                 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1651                     ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1652                     int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1653                     uint64 ucst = cst.GetZXTValue();
1654                     uint32 bsize = 0;
1655                     do {
1656                         bsize++;
1657                         ucst >>= 1;
1658                     } while (ucst != 0);
1659                     if (shrAmt + static_cast<int64>(bsize) <=
1660                         static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1661                         static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1662                         fold2extractbits = true;
1663                         // change to use extractbits
1664                         result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1665                             GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1666                         sum = std::nullopt;
1667                     }
1668                 }
1669             }
1670             if (!fold2extractbits) {
1671                 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1672                 sum = std::nullopt;
1673             }
1674         } else if (op == OP_bior && cst == -1) {
1675             // X | (-1) -> -1
1676             result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1677         } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1678             // X >> 0 -> X
1679             // X << 0 -> X
1680             // X | 0 -> X
1681             // X ^ 0 -> X
1682             sum = lp.second;
1683             result = l;
1684         } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1685             // bxor i32 (
1686             //   cvt i32 u1 (regread u1 %13),
1687             //  constValue i32 1),
1688             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1689             if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1690                 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1691                 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1692                     BaseNode *base = cvtNode->Opnd(0);
1693                     BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1694                     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1695                     BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1696                     result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1697                 }
1698             }
1699         } else if (op == OP_rem && cst == 1) {
1700             // X % 1 -> 0
1701             result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1702         } else {
1703             result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1704         }
1705         if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1706             result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1707         }
1708     } else if (isInt && (op == OP_add || op == OP_sub)) {
1709         if (op == OP_add) {
1710             result = NewBinaryNode(node, op, primType, l, r);
1711             sum = lp.second + rp.second;
1712         } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1713             // if fold is (x - (y - z))    ->     (x - neg(z)) - y
1714             // (x - neg(z)) Could cross the int limit
1715             // return node
1716             result = node;
1717         } else {
1718             result = NewBinaryNode(node, op, primType, l, r);
1719             sum = lp.second - rp.second;
1720         }
1721     } else {
1722         result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1723     }
1724     return std::make_pair(result, sum);
1725 }
1726 
SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const1727 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
1728 {
1729     if (isRConstval) {
1730         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
1731         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1732             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
1733             return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1734                 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1735         }
1736     } else {
1737         ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
1738         if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1739             const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
1740             if (isGtOrLt) {
1741                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1742                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
1743             } else {
1744                 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1745                     node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1746             }
1747         }
1748     }
1749     return &node;
1750 }
1751 
SimplifyDoubleCompare(CompareNode &compareNode) const1752 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1753 {
1754     // See arm manual B.cond(P2993) and FCMP(P1091)
1755     CompareNode *node = &compareNode;
1756     BaseNode *result = node;
1757     BaseNode *l = node->Opnd(0);
1758     BaseNode *r = node->Opnd(1);
1759     if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1760         if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1761             (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1762             result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
1763         } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1764             // ne (u1 x, constValue 0)  <==> x
1765             ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1766             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1767                 BaseNode *opnd = l;
1768                 do {
1769                     if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1770                         result = opnd;
1771                         break;
1772                     } else if (opnd->GetOpCode() == OP_cvt) {
1773                         TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
1774                         opnd = cvtNode->Opnd(0);
1775                     } else {
1776                         opnd = nullptr;
1777                     }
1778                 } while (opnd != nullptr);
1779             }
1780         } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1781             ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1782             if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1783                 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1784                 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
1785                 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1786                     resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
1787             }
1788         }
1789     } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1790         if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1791             (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1792             result = SimplifyDoubleConstvalCompare(*node,
1793                 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
1794         }
1795     }
1796     return result;
1797 }
1798 
FoldCompare(CompareNode *node)1799 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
1800 {
1801     CHECK_NULL_FATAL(node);
1802     BaseNode *result = nullptr;
1803     std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1804     std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1805     ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
1806     ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
1807     Opcode opcode = node->GetOpCode();
1808     if (lConst != nullptr && rConst != nullptr) {
1809         result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
1810     } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
1811                lConst->GetConstVal()->GetKind() == kConstInt) {
1812         BaseNode *l = lp.first;
1813         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1814         result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
1815     } else {
1816         BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
1817         BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1818         if (l != node->Opnd(0) || r != node->Opnd(1)) {
1819             result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1820                 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
1821         } else {
1822             result = node;
1823         }
1824         auto *compareNode = static_cast<CompareNode*>(result);
1825         CHECK_NULL_FATAL(compareNode);
1826         result = SimplifyDoubleCompare(*compareNode);
1827     }
1828     return std::make_pair(result, std::nullopt);
1829 }
1830 
Fold(BaseNode *node)1831 BaseNode *ConstantFold::Fold(BaseNode *node)
1832 {
1833     if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
1834         return nullptr;
1835     }
1836     std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
1837     BaseNode *result = PairToExpr(node->GetPrimType(), p);
1838     if (result == node) {
1839         result = nullptr;
1840     }
1841     return result;
1842 }
1843 
1844 }  // namespace maple
1845