1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef MPL2MPL_INCLUDE_CONSTANTFOLD_H
17#define MPL2MPL_INCLUDE_CONSTANTFOLD_H
18#include "mir_nodes.h"
19#include "phase_impl.h"
20
21#include <optional>
22
23namespace maple {
24class ConstantFold : public FuncOptimizeImpl {
25public:
26    ConstantFold(MIRModule &mod, bool trace) : FuncOptimizeImpl(mod, trace), mirModule(&mod) {}
27
28    explicit ConstantFold(MIRModule &mod) : FuncOptimizeImpl(mod, false), mirModule(&mod) {}
29
30    // Fold an expression.
31    // It returns a new expression if there was something to fold, or
32    // nullptr otherwise.
33    BaseNode *Fold(BaseNode *node);
34
35    FuncOptimizeImpl *Clone() override
36    {
37        return new ConstantFold(*this);
38    }
39
40    ~ConstantFold() override
41    {
42        mirModule = nullptr;
43    }
44
45    template <class T>
46    T CalIntValueFromFloatValue(T value, const MIRType &resultType) const;
47    MIRConst *FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor = true) const;
48    MIRConst *FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const;
49    MIRConst *FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const;
50    MIRConst *FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const;
51    static MIRConst *FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
52                                                const MIRIntConst &intConst1);
53    MIRConst *FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
54                                          const MIRConst &const0, const MIRConst &const1) const;
55    static bool IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB);
56    static MIRIntConst *FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode);
57private:
58    std::pair<BaseNode*, std::optional<IntVal>> FoldBase(BaseNode *node) const;
59    std::pair<BaseNode*, std::optional<IntVal>> FoldBinary(BinaryNode *node);
60    std::pair<BaseNode*, std::optional<IntVal>> FoldCompare(CompareNode *node);
61    std::pair<BaseNode*, std::optional<IntVal>> FoldExtractbits(ExtractbitsNode *node);
62    ConstvalNode *FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size, const ConstvalNode &cst) const;
63    std::pair<BaseNode*, std::optional<IntVal>> FoldIread(IreadNode *node);
64    std::pair<BaseNode*, std::optional<IntVal>> FoldRetype(RetypeNode *node);
65    std::pair<BaseNode*, std::optional<IntVal>> FoldUnary(UnaryNode *node);
66    std::pair<BaseNode*, std::optional<IntVal>> FoldTypeCvt(TypeCvtNode *node);
67    ConstvalNode *FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
68    ConstvalNode *FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
69    ConstvalNode *FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
70    ConstvalNode *FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
71    ConstvalNode *FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
72    ConstvalNode *FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0,
73                                      const ConstvalNode &const1) const;
74    ConstvalNode *FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
75                                  const ConstvalNode &const1) const;
76    ConstvalNode *FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
77                                         const ConstvalNode &const0, const ConstvalNode &const1) const;
78    MIRIntConst *FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
79                                                const MIRIntConst &intConst0, const MIRIntConst &intConst1) const;
80    ConstvalNode *FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
81                                     const ConstvalNode &const1) const;
82    ConstvalNode *FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
83                                        const ConstvalNode &const0, const ConstvalNode &const1) const;
84    bool ConstValueEqual(int64 leftValue, int64 rightValue) const;
85    bool ConstValueEqual(float leftValue, float rightValue) const;
86    bool ConstValueEqual(double leftValue, double rightValue) const;
87    template<typename T>
88    bool FullyEqual(T leftValue, T rightValue) const;
89    template<typename T>
90    int64 ComparisonResult(Opcode op, T *leftConst, T *rightConst) const;
91    MIRIntConst *FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
92                                               const MIRConst &leftConst, const MIRConst &rightConst) const;
93    ConstvalNode *FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
94                                    const ConstvalNode &const1) const;
95    ConstvalNode *FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const;
96    template <typename T>
97    ConstvalNode *FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const;
98    BaseNode *NegateTree(BaseNode *node) const;
99    BaseNode *Negate(BaseNode *node) const;
100    BaseNode *Negate(UnaryNode *node) const;
101    BaseNode *Negate(const ConstvalNode *node) const;
102    BinaryNode *NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, BaseNode *rhs) const;
103    UnaryNode *NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const;
104    std::pair<BaseNode*, std::optional<IntVal>> DispatchFold(BaseNode *node);
105    BaseNode *PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const;
106    BaseNode *SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt = false) const;
107    BaseNode *SimplifyDoubleCompare(CompareNode &compareNode) const;
108    CompareNode *FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType, BaseNode &l,
109                                            BaseNode &r) const;
110    MIRModule *mirModule;
111};
112
113}  // namespace maple
114#endif  // MPL2MPL_INCLUDE_CONSTANTFOLD_H
115