1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef MPL2MPL_INCLUDE_CONSTANTFOLD_H
17 #define MPL2MPL_INCLUDE_CONSTANTFOLD_H
18 #include "mir_nodes.h"
19 #include "phase_impl.h"
20 
21 #include <optional>
22 
23 namespace maple {
24 class ConstantFold : public FuncOptimizeImpl {
25 public:
ConstantFold(MIRModule &mod, bool trace)26     ConstantFold(MIRModule &mod, bool trace) : FuncOptimizeImpl(mod, trace), mirModule(&mod) {}
27 
ConstantFold(MIRModule &mod)28     explicit ConstantFold(MIRModule &mod) : FuncOptimizeImpl(mod, false), mirModule(&mod) {}
29 
30     // Fold an expression.
31     // It returns a new expression if there was something to fold, or
32     // nullptr otherwise.
33     BaseNode *Fold(BaseNode *node);
34 
35     FuncOptimizeImpl *Clone() override
36     {
37         return new ConstantFold(*this);
38     }
39 
40     ~ConstantFold() override
41     {
42         mirModule = nullptr;
43     }
44 
45     template <class T>
46     T CalIntValueFromFloatValue(T value, const MIRType &resultType) const;
47     MIRConst *FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor = true) const;
48     MIRConst *FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const;
49     MIRConst *FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const;
50     MIRConst *FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const;
51     static MIRConst *FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
52                                                 const MIRIntConst &intConst1);
53     MIRConst *FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
54                                           const MIRConst &const0, const MIRConst &const1) const;
55     static bool IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB);
56     static MIRIntConst *FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode);
57 private:
58     std::pair<BaseNode*, std::optional<IntVal>> FoldBase(BaseNode *node) const;
59     std::pair<BaseNode*, std::optional<IntVal>> FoldBinary(BinaryNode *node);
60     std::pair<BaseNode*, std::optional<IntVal>> FoldCompare(CompareNode *node);
61     std::pair<BaseNode*, std::optional<IntVal>> FoldExtractbits(ExtractbitsNode *node);
62     ConstvalNode *FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size, const ConstvalNode &cst) const;
63     std::pair<BaseNode*, std::optional<IntVal>> FoldIread(IreadNode *node);
64     std::pair<BaseNode*, std::optional<IntVal>> FoldRetype(RetypeNode *node);
65     std::pair<BaseNode*, std::optional<IntVal>> FoldUnary(UnaryNode *node);
66     std::pair<BaseNode*, std::optional<IntVal>> FoldTypeCvt(TypeCvtNode *node);
67     ConstvalNode *FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
68     ConstvalNode *FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
69     ConstvalNode *FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
70     ConstvalNode *FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
71     ConstvalNode *FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const;
72     ConstvalNode *FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0,
73                                       const ConstvalNode &const1) const;
74     ConstvalNode *FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
75                                   const ConstvalNode &const1) const;
76     ConstvalNode *FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
77                                          const ConstvalNode &const0, const ConstvalNode &const1) const;
78     MIRIntConst *FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
79                                                 const MIRIntConst &intConst0, const MIRIntConst &intConst1) const;
80     ConstvalNode *FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
81                                      const ConstvalNode &const1) const;
82     ConstvalNode *FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
83                                         const ConstvalNode &const0, const ConstvalNode &const1) const;
84     bool ConstValueEqual(int64 leftValue, int64 rightValue) const;
85     bool ConstValueEqual(float leftValue, float rightValue) const;
86     bool ConstValueEqual(double leftValue, double rightValue) const;
87     template<typename T>
88     bool FullyEqual(T leftValue, T rightValue) const;
89     template<typename T>
90     int64 ComparisonResult(Opcode op, T *leftConst, T *rightConst) const;
91     MIRIntConst *FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
92                                                const MIRConst &leftConst, const MIRConst &rightConst) const;
93     ConstvalNode *FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
94                                     const ConstvalNode &const1) const;
95     ConstvalNode *FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const;
96     template <typename T>
97     ConstvalNode *FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const;
98     BaseNode *NegateTree(BaseNode *node) const;
99     BaseNode *Negate(BaseNode *node) const;
100     BaseNode *Negate(UnaryNode *node) const;
101     BaseNode *Negate(const ConstvalNode *node) const;
102     BinaryNode *NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, BaseNode *rhs) const;
103     UnaryNode *NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const;
104     std::pair<BaseNode*, std::optional<IntVal>> DispatchFold(BaseNode *node);
105     BaseNode *PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const;
106     BaseNode *SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt = false) const;
107     BaseNode *SimplifyDoubleCompare(CompareNode &compareNode) const;
108     CompareNode *FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType, BaseNode &l,
109                                             BaseNode &r) const;
110     MIRModule *mirModule;
111 };
112 
113 }  // namespace maple
114 #endif  // MPL2MPL_INCLUDE_CONSTANTFOLD_H
115