1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "constantfold.h"
17#include <cmath>
18#include <cfloat>
19#include <climits>
20#include <type_traits>
21#include "mpl_logging.h"
22#include "mir_function.h"
23#include "mir_builder.h"
24#include "global_tables.h"
25#include "me_option.h"
26#include "maple_phase_manager.h"
27#include "mir_type.h"
28
29namespace maple {
30
31namespace {
32constexpr uint32 kByteSizeOfBit64 = 8;                            // byte number for 64 bit
33constexpr uint32 kBitSizePerByte = 8;
34constexpr maple::int32 kMaxOffset = INT_MAX - 8;
35
36enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
37
38std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
39{
40    if (!v1 && !v2) {
41        return std::nullopt;
42    }
43
44    // Perform all calculations in terms of the maximum available signed type.
45    // The value will be truncated for an appropriate type when constant is created in PairToExpr function
46    return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
47}
48
49// Perform all calculations in terms of the maximum available signed type.
50// The value will be truncated for an appropriate type when constant is created in PairToExpr function
51std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
52{
53    if (!v1 && !v2) {
54        return std::nullopt;
55    }
56
57    if (v1 && v2) {
58        return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
59    }
60
61    if (v1) {
62        return v1->TruncOrExtend(PTY_i64);
63    }
64
65    // !v1 && v2
66    return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
67}
68
69std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
70{
71    return AddSub(v1, v2, true);
72}
73
74std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
75{
76    return AddSub(v1, v2, false);
77}
78
79}  // anonymous namespace
80
81// This phase is designed to achieve compiler optimization by
82// simplifying constant expressions. The constant expression
83// is evaluated and replaced by the value calculated on compile
84// time to save time on runtime.
85//
86// The main procedure shows as following:
87// A. Analyze expression type
// B. Analyze operator type
89// C. Replace the expression with the result of the operation
90
// Returns true when the set bits of `x` form exactly one contiguous run of 1s that
// starts at bit 0 (0b1, 0b11, 0b111, ... up to all 64 bits set). Zero yields false.
static bool ContiguousBitsOf1(uint64 x)
{
    // For such a value x + 1 is either the single bit just above the run or, when every
    // bit is set, wraps to 0; in both cases it shares no bit with x.
    const uint64 next = x + 1;
    return x != 0 && (x & next) == 0;
}
100
// Returns true when exactly one bit of `num` is set, i.e. num is a power of two.
// Zero is not a power of two.
inline bool IsPowerOf2(uint64 num)
{
    // Clearing the lowest set bit leaves nothing behind only when a single bit was set.
    return num != 0 && (num & (num - 1)) == 0;
}
108
109BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
110                                        BaseNode *rhs) const
111{
112    CHECK_NULL_FATAL(old);
113    BinaryNode *result = nullptr;
114    if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
115        result = old;
116    } else {
117        result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
118    }
119    return result;
120}
121
122UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
123{
124    CHECK_NULL_FATAL(old);
125    UnaryNode *result = nullptr;
126    if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
127        result = old;
128    } else {
129        result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
130    }
131    return result;
132}
133
134BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
135{
136    CHECK_NULL_FATAL(pair.first);
137    BaseNode *result = pair.first;
138    if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
139        return result;
140    }
141    if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
142        // -a, 5 -> 5 - a
143        ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
144            static_cast<uint64>(pair.second->GetExtValue()), resultType);
145        BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
146        result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
147    } else {
148        if ((!pair.second->GetSignBit() &&
149            pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
150            pair.second->TruncOrExtend(resultType).IsMinValue() ||
151            pair.second->GetSXTValue() == INT64_MIN) {
152            // +-a, 5 -> a + 5
153            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
154                static_cast<uint64>(pair.second->GetExtValue()), resultType);
155            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
156        } else {
157            // +-a, -5 -> a + -5
158            ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
159                static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
160            result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
161        }
162    }
163    return result;
164}
165
166std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
167{
168    return std::make_pair(node, std::nullopt);
169}
170
171std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
172{
173    CHECK_NULL_FATAL(node);
174    if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
175        return {node, std::nullopt};
176    }
177    switch (node->GetOpCode()) {
178        case OP_abs:
179        case OP_bnot:
180        case OP_lnot:
181        case OP_neg:
182        case OP_sqrt:
183            return FoldUnary(static_cast<UnaryNode*>(node));
184        case OP_ceil:
185        case OP_floor:
186        case OP_trunc:
187        case OP_cvt:
188            return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
189        case OP_sext:
190        case OP_zext:
191        case OP_extractbits:
192            return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
193        case OP_iread:
194            return FoldIread(static_cast<IreadNode*>(node));
195        case OP_add:
196        case OP_ashr:
197        case OP_band:
198        case OP_bior:
199        case OP_bxor:
200        case OP_div:
201        case OP_lshr:
202        case OP_max:
203        case OP_min:
204        case OP_mul:
205        case OP_rem:
206        case OP_shl:
207        case OP_sub:
208            return FoldBinary(static_cast<BinaryNode*>(node));
209        case OP_eq:
210        case OP_ne:
211        case OP_ge:
212        case OP_gt:
213        case OP_le:
214        case OP_lt:
215        case OP_cmp:
216            return FoldCompare(static_cast<CompareNode*>(node));
217        case OP_retype:
218            return FoldRetype(static_cast<RetypeNode*>(node));
219        default:
220            return FoldBase(static_cast<BaseNode*>(node));
221    }
222}
223
224BaseNode *ConstantFold::Negate(BaseNode *node) const
225{
226    CHECK_NULL_FATAL(node);
227    return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
228}
229
230BaseNode *ConstantFold::Negate(UnaryNode *node) const
231{
232    CHECK_NULL_FATAL(node);
233    BaseNode *result = nullptr;
234    if (node->GetOpCode() == OP_neg) {
235        result = static_cast<BaseNode*>(node->Opnd(0));
236    } else {
237        BaseNode *n = static_cast<BaseNode*>(node);
238        result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
239    }
240    return result;
241}
242
243BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
244{
245    CHECK_NULL_FATAL(node);
246    ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
247    CHECK_NULL_FATAL(copy);
248    copy->GetConstVal()->Neg();
249    return copy;
250}
251
252BaseNode *ConstantFold::NegateTree(BaseNode *node) const
253{
254    CHECK_NULL_FATAL(node);
255    if (node->IsUnaryNode()) {
256        return Negate(static_cast<UnaryNode*>(node));
257    } else if (node->GetOpCode() == OP_constval) {
258        return Negate(static_cast<ConstvalNode*>(node));
259    } else {
260        return Negate(static_cast<BaseNode*>(node));
261    }
262}
263
264MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
265                                                          const MIRIntConst &intConst0,
266                                                          const MIRIntConst &intConst1) const
267{
268    uint64 result = 0;
269
270    bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
271    bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
272    bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
273
274    switch (opcode) {
275        case OP_eq: {
276            result = equal;
277            break;
278        }
279        case OP_ge: {
280            result = (greater || equal) ? 1 : 0;
281            break;
282        }
283        case OP_gt: {
284            result = greater;
285            break;
286        }
287        case OP_le: {
288            result = (less || equal) ? 1 : 0;
289            break;
290        }
291        case OP_lt: {
292            result = less;
293            break;
294        }
295        case OP_ne: {
296            result = !equal;
297            break;
298        }
299        case OP_cmp: {
300            if (greater) {
301                result = kGreater;
302            } else if (equal) {
303                result = kEqual;
304            } else if (less) {
305                result = static_cast<uint64>(kLess);
306            }
307            break;
308        }
309        default:
310            DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
311            break;
312    }
313    // determine the type
314    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
315    // form the constant
316    MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
317    return constValue;
318}
319
320ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
321                                                   const ConstvalNode &const0, const ConstvalNode &const1) const
322{
323    const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
324    const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
325    CHECK_NULL_FATAL(intConst0);
326    CHECK_NULL_FATAL(intConst1);
327    MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
328    // form the ConstvalNode
329    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
330    resultConst->SetPrimType(resultType);
331    resultConst->SetConstVal(constValue);
332    return resultConst;
333}
334
335MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
336                                                   const MIRIntConst &intConst1)
337{
338    IntVal intVal0 = intConst0.GetValue();
339    IntVal intVal1 = intConst1.GetValue();
340    IntVal result(static_cast<uint64>(0), resultType);
341
342    switch (opcode) {
343        case OP_add: {
344            result = intVal0.Add(intVal1, resultType);
345            break;
346        }
347        case OP_sub: {
348            result = intVal0.Sub(intVal1, resultType);
349            break;
350        }
351        case OP_mul: {
352            result = intVal0.Mul(intVal1, resultType);
353            break;
354        }
355        case OP_div: {
356            result = intVal0.Div(intVal1, resultType);
357            break;
358        }
359        case OP_rem: {
360            result = intVal0.Rem(intVal1, resultType);
361            break;
362        }
363        case OP_ashr: {
364            result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
365            break;
366        }
367        case OP_lshr: {
368            result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
369            break;
370        }
371        case OP_shl: {
372            result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
373            break;
374        }
375        case OP_max: {
376            result = Max(intVal0, intVal1, resultType);
377            break;
378        }
379        case OP_min: {
380            result = Min(intVal0, intVal1, resultType);
381            break;
382        }
383        case OP_band: {
384            result = intVal0.And(intVal1, resultType);
385            break;
386        }
387        case OP_bior: {
388            result = intVal0.Or(intVal1, resultType);
389            break;
390        }
391        case OP_bxor: {
392            result = intVal0.Xor(intVal1, resultType);
393            break;
394        }
395        default:
396            DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
397            break;
398    }
399    // determine the type
400    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
401    // form the constant
402    MIRIntConst *constValue =
403        GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
404    return constValue;
405}
406
407ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
408                                               const ConstvalNode &const1) const
409{
410    const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
411    const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
412    CHECK_NULL_FATAL(intConst0);
413    CHECK_NULL_FATAL(intConst1);
414    MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
415    // form the ConstvalNode
416    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
417    resultConst->SetPrimType(resultType);
418    resultConst->SetConstVal(constValue);
419    return resultConst;
420}
421
422ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
423                                              const ConstvalNode &const1) const
424{
425    DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
426    const MIRDoubleConst *doubleConst0 = nullptr;
427    const MIRDoubleConst *doubleConst1 = nullptr;
428    const MIRFloatConst *floatConst0 = nullptr;
429    const MIRFloatConst *floatConst1 = nullptr;
430    bool useDouble = (const0.GetPrimType() == PTY_f64);
431    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
432    resultConst->SetPrimType(resultType);
433    if (useDouble) {
434        doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
435        doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
436        CHECK_NULL_FATAL(doubleConst0);
437        CHECK_NULL_FATAL(doubleConst1);
438    } else {
439        floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
440        floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
441        CHECK_NULL_FATAL(floatConst0);
442        CHECK_NULL_FATAL(floatConst1);
443    }
444    float constValueFloat = 0.0;
445    double constValueDouble = 0.0;
446    switch (opcode) {
447        case OP_add: {
448            if (useDouble) {
449                constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
450            } else {
451                constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
452            }
453            break;
454        }
455        case OP_sub: {
456            if (useDouble) {
457                constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
458            } else {
459                constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
460            }
461            break;
462        }
463        case OP_mul: {
464            if (useDouble) {
465                constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
466            } else {
467                constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
468            }
469            break;
470        }
471        case OP_div: {
472            // for floats div by 0 is well defined
473            if (useDouble) {
474                constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
475            } else {
476                constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
477            }
478            break;
479        }
480        case OP_max: {
481            if (useDouble) {
482                constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
483                                                                                        : doubleConst1->GetValue();
484            } else {
485                constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
486                                                                                    : floatConst1->GetValue();
487            }
488            break;
489        }
490        case OP_min: {
491            if (useDouble) {
492                constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
493                                                                                        : doubleConst1->GetValue();
494            } else {
495                constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
496                                                                                    : floatConst1->GetValue();
497            }
498            break;
499        }
500        case OP_rem:
501        case OP_ashr:
502        case OP_lshr:
503        case OP_shl:
504        case OP_band:
505        case OP_bior:
506        case OP_bxor: {
507            DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
508            break;
509        }
510        default:
511            DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
512            break;
513    }
514    if (resultType == PTY_f64) {
515        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
516    } else {
517        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
518    }
519    return resultConst;
520}
521
522bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
523{
524    return (leftValue == rightValue);
525}
526
527bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
528{
529    auto result = fabs(leftValue - rightValue);
530    return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
531}
532
533bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
534{
535    auto result = fabs(leftValue - rightValue);
536    return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
537}
538
539template<typename T>
540bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
541{
542    if (std::isinf(leftValue) && std::isinf(rightValue)) {
543        // (inf == inf), add the judgement here in case of the subtraction between float type inf
544        return true;
545    } else {
546        return ConstValueEqual(leftValue, rightValue);
547    }
548}
549
550template<typename T>
551int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
552{
553    DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
554    typename T::value_type leftValue = leftConst->GetValue();
555    DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
556    typename T::value_type rightValue = rightConst->GetValue();
557    int64 result = 0;
558    switch (op) {
559        case OP_eq: {
560            result = FullyEqual(leftValue, rightValue);
561            break;
562        }
563        case OP_ge: {
564            result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
565            break;
566        }
567        case OP_gt: {
568            result = (leftValue > rightValue);
569            break;
570        }
571        case OP_le: {
572            result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
573            break;
574        }
575        case OP_lt: {
576            result = (leftValue < rightValue);
577            break;
578        }
579        case OP_ne: {
580            result = !FullyEqual(leftValue, rightValue);
581            break;
582        }
583        [[clang::fallthrough]];
584        case OP_cmp: {
585            if (leftValue > rightValue) {
586                result = kGreater;
587            } else if (FullyEqual(leftValue, rightValue)) {
588                result = kEqual;
589            } else {
590                result = kLess;
591            }
592            break;
593        }
594        default:
595            DEBUG_ASSERT(false, "Unknown opcode for Comparison");
596            break;
597    }
598    return result;
599}
600
601MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
602                                                         const MIRConst &leftConst, const MIRConst &rightConst) const
603{
604    int64 result = 0;
605    bool useDouble = (opndType == PTY_f64);
606    if (useDouble) {
607        result =
608            ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
609    } else {
610        result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
611    }
612    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
613    MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
614    return resultConst;
615}
616
617ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
618                                                  const ConstvalNode &const0, const ConstvalNode &const1) const
619{
620    DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
621    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
622    resultConst->SetPrimType(resultType);
623    resultConst->SetConstVal(
624        FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
625    return resultConst;
626}
627
628MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
629                                                    const MIRConst &const0, const MIRConst &const1) const
630{
631    MIRConst *returnValue = nullptr;
632    if (IsPrimitiveInteger(opndType)) {
633        const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
634        const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
635        ASSERT_NOT_NULL(intConst0);
636        ASSERT_NOT_NULL(intConst1);
637        returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
638    } else if (opndType == PTY_f32 || opndType == PTY_f64) {
639        returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
640    } else {
641        DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
642    }
643    return returnValue;
644}
645
646ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
647                                                const ConstvalNode &const0, const ConstvalNode &const1) const
648{
649    ConstvalNode *returnValue = nullptr;
650    if (IsPrimitiveInteger(opndType)) {
651        returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
652    } else if (opndType == PTY_f32 || opndType == PTY_f64) {
653        returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
654    } else {
655        DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
656    }
657    return returnValue;
658}
659
660CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
661                                                      BaseNode &l, BaseNode &r) const
662{
663    CompareNode *result = nullptr;
664    Opcode op = opcode;
665    switch (opcode) {
666        case OP_gt: {
667            op = OP_lt;
668            break;
669        }
670        case OP_lt: {
671            op = OP_gt;
672            break;
673        }
674        case OP_ge: {
675            op = OP_le;
676            break;
677        }
678        case OP_le: {
679            op = OP_ge;
680            break;
681        }
682        case OP_eq: {
683            break;
684        }
685        case OP_ne: {
686            break;
687        }
688        default:
689            DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
690            break;
691    }
692
693    result =
694        mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
695    return result;
696}
697
698ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
699                                            const ConstvalNode &const1) const
700{
701    ConstvalNode *returnValue = nullptr;
702    if (IsPrimitiveInteger(resultType)) {
703        returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
704    } else if (resultType == PTY_f32 || resultType == PTY_f64) {
705        returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
706    } else {
707        DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
708    }
709    return returnValue;
710}
711
712MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
713{
714    CHECK_NULL_FATAL(constNode);
715    IntVal result = constNode->GetValue().TruncOrExtend(resultType);
716    switch (opcode) {
717        case OP_abs: {
718            if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
719                result = -result;
720            }
721            break;
722        }
723        case OP_bnot: {
724            result = ~result;
725            break;
726        }
727        case OP_lnot: {
728            uint64 resultInt = result == 0 ? 1 : 0;
729            result = {resultInt, resultType};
730            break;
731        }
732        case OP_neg: {
733            result = -result;
734            break;
735        }
736        case OP_sext:         // handled in FoldExtractbits
737        case OP_zext:         // handled in FoldExtractbits
738        case OP_extractbits:  // handled in FoldExtractbits
739        case OP_sqrt: {
740            DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
741            break;
742        }
743        default:
744            DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
745            break;
746    }
747    // determine the type
748    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
749    // form the constant
750    MIRIntConst *constValue =
751        GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
752    return constValue;
753}
754
755template <typename T>
756ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
757{
758    CHECK_NULL_FATAL(constNode);
759    double constValue = 0;
760    T *fpCst = static_cast<T*>(constNode->GetConstVal());
761    switch (opcode) {
762        case OP_neg: {
763            constValue = typename T::value_type(-fpCst->GetValue());
764            break;
765        }
766        case OP_abs: {
767            constValue = typename T::value_type(fabs(fpCst->GetValue()));
768            break;
769        }
770        case OP_sqrt: {
771            constValue = typename T::value_type(sqrt(fpCst->GetValue()));
772            break;
773        }
774        case OP_bnot:
775        case OP_lnot:
776        case OP_sext:
777        case OP_zext:
778        case OP_extractbits: {
779            DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
780            break;
781        }
782        default:
783            DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
784            break;
785    }
786    auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
787    resultConst->SetPrimType(resultType);
788    if (resultType == PTY_f32) {
789        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
790    } else if (resultType == PTY_f64) {
791        resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
792    } else {
793        CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
794    }
795    return resultConst;
796}
797
798ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
799{
800    ConstvalNode *returnValue = nullptr;
801    if (IsPrimitiveInteger(resultType)) {
802        const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
803        auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
804        returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
805        returnValue->SetPrimType(resultType);
806        returnValue->SetConstVal(constValue);
807    } else if (resultType == PTY_f32) {
808        returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
809    } else if (resultType == PTY_f64) {
810        returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
811    } else {
812        DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
813    }
814    return returnValue;
815}
816
817std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
818{
819    CHECK_NULL_FATAL(node);
820    BaseNode *result = node;
821    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
822    if (node->Opnd(0) != p.first) {
823        RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
824        CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
825        newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
826        result = newRetNode;
827    }
828    return std::make_pair(result, std::nullopt);
829}
830
831std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
832{
833    CHECK_NULL_FATAL(node);
834    BaseNode *result = nullptr;
835    std::optional<IntVal> sum = std::nullopt;
836    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
837    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
838    if (cst != nullptr) {
839        result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
840    } else {
841        bool isInt = IsPrimitiveInteger(node->GetPrimType());
842        // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
843        // primType will be set to opnd type. There will be problems in some cases. For example:
844        // before cf:
845        //   neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
846        // after cf:
847        //   neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))  # wrong!
848        // As a workaround, we exclude u1 opnd type
849        if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
850            result = NegateTree(p.first);
851            if (result->GetOpCode() == OP_neg) {
852                PrimType origPtyp = node->GetPrimType();
853                PrimType newPtyp = result->GetPrimType();
854                if (newPtyp == origPtyp) {
855                if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
856                    // NegateTree returned an UnaryNode quivalent to `n`, so keep the
857                    // original UnaryNode to preserve identity
858                    result = node;
859                }
860                } else {
861                    if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
862                        // do not fold explicit cvt
863                        result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
864                            PairToExpr(node->Opnd(0)->GetPrimType(), p));
865                        return std::make_pair(result, std::nullopt);
866                    } else {
867                        result->SetPrimType(origPtyp);
868                    }
869                }
870            }
871            if (p.second) {
872                sum = -(*p.second);
873            }
874        } else {
875            result =
876                NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
877        }
878    }
879    return std::make_pair(result, sum);
880}
881
882static bool FloatToIntOverflow(float fval, PrimType totype)
883{
884    static const float safeFloatMaxToInt32 = 2147483520.0f;  // 2^31 - 128
885    static const float safeFloatMinToInt32 = -2147483520.0f;
886    static const float safeFloatMaxToInt64 = 9223372036854775680.0f;  // 2^63 - 128
887    static const float safeFloatMinToInt64 = -9223372036854775680.0f;
888    if (!std::isfinite(fval)) {
889        return true;
890    }
891    if (totype == PTY_i64 || totype == PTY_u64) {
892        if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
893            return true;
894        }
895    } else {
896        if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
897            return true;
898        }
899    }
900    return false;
901}
902
903static bool DoubleToIntOverflow(double dval, PrimType totype)
904{
905    static const double safeDoubleMaxToInt32 = 2147482624.0;  // 2^31 - 1024
906    static const double safeDoubleMinToInt32 = -2147482624.0;
907    static const double safeDoubleMaxToInt64 = 9223372036854774784.0;  // 2^63 - 1024
908    static const double safeDoubleMinToInt64 = -9223372036854774784.0;
909    if (!std::isfinite(dval)) {
910        return true;
911    }
912    if (totype == PTY_i64 || totype == PTY_u64) {
913        if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
914            return true;
915        }
916    } else {
917        if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
918            return true;
919        }
920    }
921    return false;
922}
923
924ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
925{
926    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
927    resultConst->SetPrimType(toType);
928    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
929    if (fromType == PTY_f32) {
930        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
931        ASSERT_NOT_NULL(constValue);
932        float floatValue = ceil(constValue->GetValue());
933        if (IsPrimitiveFloat(toType)) {
934            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
935        } else if (FloatToIntOverflow(floatValue, toType)) {
936            return nullptr;
937        } else {
938            resultConst->SetConstVal(
939                GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType));
940        }
941    } else {
942        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
943        ASSERT_NOT_NULL(constValue);
944        double doubleValue = ceil(constValue->GetValue());
945        if (IsPrimitiveFloat(toType)) {
946            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
947        } else if (DoubleToIntOverflow(doubleValue, toType)) {
948            return nullptr;
949        } else {
950            resultConst->SetConstVal(
951                GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType));
952        }
953    }
954    return resultConst;
955}
956
957template <class T>
958T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
959{
960    DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
961    size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
962    bool isSigned = IsSignedInteger(resultType.GetPrimType());
963    int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
964    uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
965    int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
966    if (isSigned && (value > max)) {
967        return static_cast<T>(max);
968    } else if (!isSigned && (value > umax)) {
969        return static_cast<T>(umax);
970    } else if (value < min) {
971        return static_cast<T>(min);
972    }
973    return value;
974}
975
976MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
977{
978    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
979    if (fromType == PTY_f32) {
980        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
981        float floatValue = constValue.GetValue();
982        if (isFloor) {
983            floatValue = floor(constValue.GetValue());
984        }
985        if (IsPrimitiveFloat(toType)) {
986            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
987        }
988        if (FloatToIntOverflow(floatValue, toType)) {
989            return nullptr;
990        }
991        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
992        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType);
993    } else {
994        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
995        double doubleValue = constValue.GetValue();
996        if (isFloor) {
997            doubleValue = floor(constValue.GetValue());
998        }
999        if (IsPrimitiveFloat(toType)) {
1000            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1001        }
1002        if (DoubleToIntOverflow(doubleValue, toType)) {
1003            return nullptr;
1004        }
1005        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
1006        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
1007        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
1008    }
1009}
1010
1011ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
1012{
1013    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1014    resultConst->SetPrimType(toType);
1015    resultConst->SetConstVal(FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType));
1016    return resultConst;
1017}
1018
1019MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1020{
1021    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
1022    if (fromType == PTY_f32) {
1023        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
1024        float floatValue = round(constValue.GetValue());
1025        if (FloatToIntOverflow(floatValue, toType)) {
1026            return nullptr;
1027        }
1028        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
1029    } else if (fromType == PTY_f64) {
1030        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
1031        double doubleValue = round(constValue.GetValue());
1032        if (DoubleToIntOverflow(doubleValue, toType)) {
1033            return nullptr;
1034        }
1035        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1036            static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
1037    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
1038        const auto &constValue = static_cast<const MIRIntConst&>(cst);
1039        if (IsSignedInteger(fromType)) {
1040            int64 fromValue = constValue.GetExtValue();
1041            float floatValue = round(static_cast<float>(fromValue));
1042            if (static_cast<int64>(floatValue) == fromValue) {
1043                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1044            }
1045        } else {
1046            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
1047            float floatValue = round(static_cast<float>(fromValue));
1048            if (static_cast<uint64>(floatValue) == fromValue) {
1049                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1050            }
1051        }
1052    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
1053        const auto &constValue = static_cast<const MIRIntConst&>(cst);
1054        if (IsSignedInteger(fromType)) {
1055            int64 fromValue = constValue.GetExtValue();
1056            double doubleValue = round(static_cast<double>(fromValue));
1057            if (static_cast<int64>(doubleValue) == fromValue) {
1058                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1059            }
1060        } else {
1061            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
1062            double doubleValue = round(static_cast<double>(fromValue));
1063            if (static_cast<uint64>(doubleValue) == fromValue) {
1064                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1065            }
1066        }
1067    }
1068    return nullptr;
1069}
1070
1071ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
1072{
1073    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1074    resultConst->SetPrimType(toType);
1075    resultConst->SetConstVal(FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType));
1076    return resultConst;
1077}
1078
1079ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
1080{
1081    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1082    resultConst->SetPrimType(toType);
1083    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
1084    if (fromType == PTY_f32) {
1085        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
1086        CHECK_NULL_FATAL(constValue);
1087        float floatValue = trunc(constValue->GetValue());
1088        if (IsPrimitiveFloat(toType)) {
1089            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
1090        } else if (FloatToIntOverflow(floatValue, toType)) {
1091            return nullptr;
1092        } else {
1093            resultConst->SetConstVal(
1094                GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType));
1095        }
1096    } else {
1097        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
1098        CHECK_NULL_FATAL(constValue);
1099        double doubleValue = trunc(constValue->GetValue());
1100        if (IsPrimitiveFloat(toType)) {
1101            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
1102        } else if (DoubleToIntOverflow(doubleValue, toType)) {
1103            return nullptr;
1104        } else {
1105            resultConst->SetConstVal(
1106                GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType));
1107        }
1108    }
1109    return resultConst;
1110}
1111
1112MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
1113{
1114    if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
1115        MIRConst *toConst = nullptr;
1116        uint32 fromSize = GetPrimTypeBitSize(fromType);
1117        uint32 toSize = GetPrimTypeBitSize(toType);
1118        // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
1119        if (fromType == PTY_u1) {
1120            fromSize = 1;
1121        }
1122        if (toType == PTY_u1) {
1123            toSize = 1;
1124        }
1125        if (toSize > fromSize) {
1126            Opcode op = OP_zext;
1127            if (IsSignedInteger(fromType)) {
1128                op = OP_sext;
1129            }
1130            const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1131            ASSERT_NOT_NULL(constVal);
1132            toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
1133                constVal->GetValue().TruncOrExtend(fromType));
1134        } else {
1135            const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
1136            ASSERT_NOT_NULL(constVal);
1137            MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
1138            toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
1139                static_cast<uint64>(constVal->GetExtValue()), type);
1140        }
1141        return toConst;
1142    }
1143    if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
1144        MIRConst *toConst = nullptr;
1145        if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
1146            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
1147            const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
1148            ASSERT_NOT_NULL(fromValue);
1149            float floatValue = static_cast<float>(fromValue->GetValue());
1150            MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
1151            toConst = toValue;
1152        } else {
1153            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
1154            const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
1155            ASSERT_NOT_NULL(fromValue);
1156            double doubleValue = static_cast<double>(fromValue->GetValue());
1157            MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
1158            toConst = toValue;
1159        }
1160        return toConst;
1161    }
1162    if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
1163        return FoldFloorMIRConst(cst, fromType, toType, false);
1164    }
1165    if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
1166        return FoldRoundMIRConst(cst, fromType, toType);
1167    }
1168    CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
1169    return nullptr;
1170}
1171
1172ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
1173{
1174    MIRConst *toConstValue = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
1175    if (toConstValue == nullptr) {
1176        return nullptr;
1177    }
1178    ConstvalNode *toConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1179    toConst->SetPrimType(toConstValue->GetType().GetPrimType());
1180    toConst->SetConstVal(toConstValue);
1181    return toConst;
1182}
1183
1184// return a primType with bit size >= bitSize (and the nearest one),
1185// and its signed/float type is the same as ptyp
1186PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1187{
1188    bool isSigned = IsSignedInteger(ptyp);
1189    bool isFloat = IsPrimitiveFloat(ptyp);
1190    if (bitSize == 1) { // 1 bit
1191        return PTY_u1;
1192    }
1193    if (bitSize <= 8) { // 8 bit
1194        return isSigned ? PTY_i8 : PTY_u8;
1195    }
1196    if (bitSize <= 16) { // 16 bit
1197        return isSigned ? PTY_i16 : PTY_u16;
1198    }
1199    if (bitSize <= 32) { // 32 bit
1200        return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1201    }
1202    if (bitSize <= 64) { // 64 bit
1203        return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1204    }
1205    return ptyp;
1206}
1207
1208size_t GetIntPrimTypeMax(PrimType ptyp)
1209{
1210    switch (ptyp) {
1211        case PTY_u1:
1212            return 1;
1213        case PTY_u8:
1214            return UINT8_MAX;
1215        case PTY_i8:
1216            return INT8_MAX;
1217        case PTY_u16:
1218            return UINT16_MAX;
1219        case PTY_i16:
1220            return INT16_MAX;
1221        case PTY_u32:
1222            return UINT32_MAX;
1223        case PTY_i32:
1224            return INT32_MAX;
1225        case PTY_u64:
1226            return UINT64_MAX;
1227        case PTY_i64:
1228            return INT64_MAX;
1229        default:
1230            CHECK_FATAL(false, "NYI");
1231    }
1232}
1233
1234ssize_t GetIntPrimTypeMin(PrimType ptyp)
1235{
1236    if (IsUnsignedInteger(ptyp)) {
1237        return 0;
1238    }
1239    switch (ptyp) {
1240        case PTY_i8:
1241            return INT8_MIN;
1242        case PTY_i16:
1243            return INT16_MIN;
1244        case PTY_i32:
1245            return INT32_MIN;
1246        case PTY_i64:
1247            return INT64_MIN;
1248        default:
1249            CHECK_FATAL(false, "NYI");
1250    }
1251}
1252
1253static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1254{
1255    if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1256        return false;
1257    }
1258    if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1259        return false;
1260    }
1261    return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1262        (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1263        (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1264}
1265
1266std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
1267{
1268    CHECK_NULL_FATAL(node);
1269    BaseNode *result = nullptr;
1270    if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
1271        return {node, std::nullopt};
1272    }
1273    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1274    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1275    PrimType destPtyp = node->GetPrimType();
1276    PrimType fromPtyp = node->FromType();
1277    if (cst != nullptr) {
1278        switch (node->GetOpCode()) {
1279            case OP_ceil: {
1280                result = FoldCeil(*cst, fromPtyp, destPtyp);
1281                break;
1282            }
1283            case OP_cvt: {
1284                result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
1285                break;
1286            }
1287            case OP_floor: {
1288                result = FoldFloor(*cst, fromPtyp, destPtyp);
1289                break;
1290            }
1291            case OP_trunc: {
1292                result = FoldTrunc(*cst, fromPtyp, destPtyp);
1293                break;
1294            }
1295            default:
1296                DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
1297                break;
1298        }
1299    } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
1300        // the cvt is redundant
1301        return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
1302    }
1303    if (result == nullptr) {
1304        BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1305        if (e != node->Opnd(0)) {
1306            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
1307                Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
1308        } else {
1309            result = node;
1310        }
1311    }
1312    return std::make_pair(result, std::nullopt);
1313}
1314
1315MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1316{
1317    uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1318    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1319    MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1320    return constValue;
1321}
1322
1323ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1324                                           const ConstvalNode &cst) const
1325{
1326    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1327    const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1328    ASSERT_NOT_NULL(intCst);
1329    IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1330    MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1331    resultConst->SetPrimType(toConst->GetType().GetPrimType());
1332    resultConst->SetConstVal(toConst);
1333    return resultConst;
1334}
1335
1336// check if truncation is redundant due to dread or iread having same effect
1337static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
1338{
1339    if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
1340        return false;  // this is trying to be conservative
1341    }
1342    BaseNode *opnd = x.Opnd(0);
1343    MIRType *mirType = nullptr;
1344    if (opnd->GetOpCode() == OP_dread) {
1345        DreadNode *dread = static_cast<DreadNode*>(opnd);
1346        MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
1347        ASSERT_NOT_NULL(sym);
1348        mirType = sym->GetType();
1349    } else if (opnd->GetOpCode() == OP_iread) {
1350        IreadNode *iread = static_cast<IreadNode*>(opnd);
1351        MIRPtrType *ptrType =
1352            dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
1353        if (ptrType == nullptr) {
1354            return false;
1355        }
1356        mirType = ptrType->GetPointedType();
1357    } else if (opnd->GetOpCode() == OP_extractbits &&
1358                x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
1359        return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
1360            IsUnsignedInteger(opnd->GetPrimType()));
1361    } else {
1362        return false;
1363    }
1364    return IsPrimitiveInteger(mirType->GetPrimType()) &&
1365            ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
1366            (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
1367            mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
1368            mirType->GetPrimType() == x.GetPrimType();
1369}
1370
1371// sext and zext also handled automatically
1372std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1373{
1374    CHECK_NULL_FATAL(node);
1375    BaseNode *result = nullptr;
1376    uint8 offset = node->GetBitsOffset();
1377    uint8 size = node->GetBitsSize();
1378    Opcode opcode = node->GetOpCode();
1379    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1380    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1381    if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1382        result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1383        return std::make_pair(result, std::nullopt);
1384    }
1385    BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1386    if (e != node->Opnd(0)) {
1387        result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1388                                                                       size, e);
1389    } else {
1390        result = node;
1391    }
1392    // check for consecutive and redundant extraction of same bits
1393    BaseNode *opnd = result->Opnd(0);
1394    DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1395    Opcode opndOp = opnd->GetOpCode();
1396    if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1397        uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1398        uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1399        if (offset == opndOffset && size == opndSize) {
1400            result->SetOpnd(opnd->Opnd(0), 0);  // delete the redundant extraction
1401        }
1402    }
1403    if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1404        if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1405            return std::make_pair(result->Opnd(0), std::nullopt);
1406        }
1407    }
1408    return std::make_pair(result, std::nullopt);
1409}
1410
1411std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1412{
1413    CHECK_NULL_FATAL(node);
1414    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1415    BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1416    node->SetOpnd(e, 0);
1417    BaseNode *result = node;
1418    if (e->GetOpCode() != OP_addrof) {
1419        return std::make_pair(result, std::nullopt);
1420    }
1421
1422    AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1423    MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1424    DEBUG_ASSERT(msy != nullptr, "nullptr check");
1425    TyIdx typeId = msy->GetTyIdx();
1426    CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1427    MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1428    MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1429    // If the high level type of iaddrof/iread doesn't match
1430    // the type of addrof's rhs, this optimization cannot be done.
1431    if (ptrType->GetPointedType() != msyType) {
1432        return std::make_pair(result, std::nullopt);
1433    }
1434
1435    Opcode op = node->GetOpCode();
1436    if (op == OP_iread) {
1437        result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1438                                                                  node->GetFieldID() + addrofNode->GetFieldID());
1439    }
1440    return std::make_pair(result, std::nullopt);
1441}
1442
1443bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1444{
1445    switch (op) {
1446        case OP_add: {
1447            int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1448            if (IsUnsignedInteger(primType)) {
1449                return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1450            }
1451            auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1452            return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1453                    static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1454                   (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1455                    static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1456        }
1457        case OP_sub: {
1458            if (IsUnsignedInteger(primType)) {
1459                return cstA < cstB;
1460            }
1461            int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1462            auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1463            return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1464                    static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1465                   (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1466                    static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1467        }
1468        default: {
1469            return false;
1470        }
1471    }
1472}
1473
1474std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1475{
1476    CHECK_NULL_FATAL(node);
1477    BaseNode *result = nullptr;
1478    std::optional<IntVal> sum = std::nullopt;
1479    Opcode op = node->GetOpCode();
1480    PrimType primType = node->GetPrimType();
1481    PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1482    PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1483    std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1484    std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1485    BaseNode *l = lp.first;
1486    BaseNode *r = rp.first;
1487    ASSERT_NOT_NULL(r);
1488    ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1489    ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1490    bool isInt = IsPrimitiveInteger(primType);
1491
1492    if (lConst != nullptr && rConst != nullptr) {
1493        MIRConst *lConstVal = lConst->GetConstVal();
1494        MIRConst *rConstVal = rConst->GetConstVal();
1495        ASSERT_NOT_NULL(lConstVal);
1496        ASSERT_NOT_NULL(rConstVal);
1497        // Don't fold div by 0, for floats div by 0 is well defined.
1498        if ((op == OP_div || op == OP_rem) && isInt &&
1499            !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1500            result = NewBinaryNode(node, op, primType, lConst, rConst);
1501        } else {
1502            // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1503            // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1504            // logic since the alternative is to return pair(result = nullptr, sum = 6).
1505            // Doing so would introduce many nullptr checks in the code. See previous
1506            // commits that implemented that logic for a comparison.
1507            result = FoldConstBinary(op, primType, *lConst, *rConst);
1508        }
1509    } else if (lConst != nullptr && isInt) {
1510        MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1511        ASSERT_NOT_NULL(mcst);
1512        PrimType cstTyp = mcst->GetType().GetPrimType();
1513        IntVal cst = mcst->GetValue();
1514        if (op == OP_add) {
1515            if (IsSignedInteger(cstTyp) && rp.second &&
1516                IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1517                // do not introduce signed integer overflow
1518                result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1519            } else {
1520                sum = cst + rp.second;
1521                result = r;
1522            }
1523        } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1524            // We exclude u1 type for fixing the following wrong example:
1525            // before cf:
1526            //   sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1527            // after cf:
1528            //   add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1529            sum = cst - rp.second;
1530            if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1531                r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1532            }
1533            result = NegateTree(r);
1534        } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1535                    op == OP_band) &&
1536                    cst == 0) {
1537            // 0 * X -> 0
1538            // 0 / X -> 0
1539            // 0 % X -> 0
1540            // 0 >> X -> 0
1541            // 0 << X -> 0
1542            // 0 & X -> 0
1543            // 0 && X -> 0
1544            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1545        } else if (op == OP_mul && cst == 1) {
1546            // 1 * X --> X
1547            sum = rp.second;
1548            result = r;
1549        } else if (op == OP_bior && cst == -1) {
1550            // (-1) | X -> -1
1551            result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1552        } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1553            // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1554            sum = cst * rp.second;
1555            if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1556                rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1557            }
1558            result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1559        } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1560            // 0 | X -> X
1561            // 0 ^ X -> X
1562            sum = rp.second;
1563            result = r;
1564        } else {
1565            result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1566        }
1567        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1568            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1569        }
1570    } else if (rConst != nullptr && isInt) {
1571        MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1572        ASSERT_NOT_NULL(mcst);
1573        PrimType cstTyp = mcst->GetType().GetPrimType();
1574        IntVal cst = mcst->GetValue();
1575        if (op == OP_add) {
1576            if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1577                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1578            } else {
1579                result = l;
1580                sum = lp.second + cst;
1581            }
1582        } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1583            result = l;
1584            sum = lp.second - cst;
1585        } else if ((op == OP_mul || op == OP_band) && cst == 0) {
1586            // X * 0 -> 0
1587            // X & 0 -> 0
1588            // X && 0 -> 0
1589            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1590        } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1591            // case [X * 1 -> X]
1592            // case [X / 1 = X]
1593            sum = lp.second;
1594            result = l;
1595        } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1596                IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1597            // temporary fix for constfold of mul/div in DejaGnu
1598            // Later we need a more formal interface for pattern match
1599            // X * Y / Y -> X
1600            BaseNode *x = l->Opnd(0);
1601            BaseNode *y = l->Opnd(1);
1602            ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1603            ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1604            bool foldMulDiv = false;
1605            if (yConst != nullptr && xConst == nullptr &&
1606                IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1607                MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1608                ASSERT_NOT_NULL(yCst);
1609                IntVal mulCst = yCst->GetValue();
1610                if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1611                    mulCst.GetExtValue() == cst.GetExtValue()) {
1612                    foldMulDiv = true;
1613                    result = x;
1614                }
1615            } else if (xConst != nullptr && yConst == nullptr &&
1616                        IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1617                MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1618                ASSERT_NOT_NULL(xCst);
1619                IntVal mulCst = xCst->GetValue();
1620                if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1621                    mulCst.GetExtValue() == cst.GetExtValue()) {
1622                    foldMulDiv = true;
1623                    result = y;
1624                }
1625            }
1626            if (!foldMulDiv) {
1627                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1628            }
1629        } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1630            // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1631            sum = lp.second * cst;
1632            if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1633                lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1634            }
1635            if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1636                // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1637                result = lp.first->Opnd(0);
1638            } else {
1639                result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1640            }
1641        } else if (op == OP_band && cst == -1) {
1642            // X & (-1) -> X
1643            sum = lp.second;
1644            result = l;
1645        } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1646                   (!lp.second.has_value() || lp.second == 0)) {
1647            bool fold2extractbits = false;
1648            if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1649                BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1650                if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1651                    ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1652                    int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1653                    uint64 ucst = cst.GetZXTValue();
1654                    uint32 bsize = 0;
1655                    do {
1656                        bsize++;
1657                        ucst >>= 1;
1658                    } while (ucst != 0);
1659                    if (shrAmt + static_cast<int64>(bsize) <=
1660                        static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1661                        static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1662                        fold2extractbits = true;
1663                        // change to use extractbits
1664                        result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1665                            GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1666                        sum = std::nullopt;
1667                    }
1668                }
1669            }
1670            if (!fold2extractbits) {
1671                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1672                sum = std::nullopt;
1673            }
1674        } else if (op == OP_bior && cst == -1) {
1675            // X | (-1) -> -1
1676            result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1677        } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1678            // X >> 0 -> X
1679            // X << 0 -> X
1680            // X | 0 -> X
1681            // X ^ 0 -> X
1682            sum = lp.second;
1683            result = l;
1684        } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1685            // bxor i32 (
1686            //   cvt i32 u1 (regread u1 %13),
1687            //  constValue i32 1),
1688            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1689            if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1690                TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1691                if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1692                    BaseNode *base = cvtNode->Opnd(0);
1693                    BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1694                    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1695                    BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1696                    result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1697                }
1698            }
1699        } else if (op == OP_rem && cst == 1) {
1700            // X % 1 -> 0
1701            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1702        } else {
1703            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1704        }
1705        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1706            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1707        }
1708    } else if (isInt && (op == OP_add || op == OP_sub)) {
1709        if (op == OP_add) {
1710            result = NewBinaryNode(node, op, primType, l, r);
1711            sum = lp.second + rp.second;
1712        } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1713            // if fold is (x - (y - z))    ->     (x - neg(z)) - y
1714            // (x - neg(z)) Could cross the int limit
1715            // return node
1716            result = node;
1717        } else {
1718            result = NewBinaryNode(node, op, primType, l, r);
1719            sum = lp.second - rp.second;
1720        }
1721    } else {
1722        result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1723    }
1724    return std::make_pair(result, sum);
1725}
1726
1727BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
1728{
1729    if (isRConstval) {
1730        ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
1731        if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1732            const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
1733            return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1734                node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1735        }
1736    } else {
1737        ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
1738        if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1739            const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
1740            if (isGtOrLt) {
1741                return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1742                    node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
1743            } else {
1744                return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1745                    node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1746            }
1747        }
1748    }
1749    return &node;
1750}
1751
1752BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1753{
1754    // See arm manual B.cond(P2993) and FCMP(P1091)
1755    CompareNode *node = &compareNode;
1756    BaseNode *result = node;
1757    BaseNode *l = node->Opnd(0);
1758    BaseNode *r = node->Opnd(1);
1759    if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1760        if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1761            (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1762            result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
1763        } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1764            // ne (u1 x, constValue 0)  <==> x
1765            ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1766            if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1767                BaseNode *opnd = l;
1768                do {
1769                    if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1770                        result = opnd;
1771                        break;
1772                    } else if (opnd->GetOpCode() == OP_cvt) {
1773                        TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
1774                        opnd = cvtNode->Opnd(0);
1775                    } else {
1776                        opnd = nullptr;
1777                    }
1778                } while (opnd != nullptr);
1779            }
1780        } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1781            ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1782            if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1783                (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1784                auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
1785                result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1786                    resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
1787            }
1788        }
1789    } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1790        if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1791            (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1792            result = SimplifyDoubleConstvalCompare(*node,
1793                (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
1794        }
1795    }
1796    return result;
1797}
1798
1799std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
1800{
1801    CHECK_NULL_FATAL(node);
1802    BaseNode *result = nullptr;
1803    std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1804    std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1805    ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
1806    ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
1807    Opcode opcode = node->GetOpCode();
1808    if (lConst != nullptr && rConst != nullptr) {
1809        result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
1810    } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
1811               lConst->GetConstVal()->GetKind() == kConstInt) {
1812        BaseNode *l = lp.first;
1813        BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1814        result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
1815    } else {
1816        BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
1817        BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1818        if (l != node->Opnd(0) || r != node->Opnd(1)) {
1819            result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1820                Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
1821        } else {
1822            result = node;
1823        }
1824        auto *compareNode = static_cast<CompareNode*>(result);
1825        CHECK_NULL_FATAL(compareNode);
1826        result = SimplifyDoubleCompare(*compareNode);
1827    }
1828    return std::make_pair(result, std::nullopt);
1829}
1830
1831BaseNode *ConstantFold::Fold(BaseNode *node)
1832{
1833    if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
1834        return nullptr;
1835    }
1836    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
1837    BaseNode *result = PairToExpr(node->GetPrimType(), p);
1838    if (result == node) {
1839        result = nullptr;
1840    }
1841    return result;
1842}
1843
1844}  // namespace maple
1845