1 /*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "constantfold.h"
17 #include <cmath>
18 #include <cfloat>
19 #include <climits>
20 #include <type_traits>
21 #include "mpl_logging.h"
22 #include "mir_function.h"
23 #include "mir_builder.h"
24 #include "global_tables.h"
25 #include "me_option.h"
26 #include "maple_phase_manager.h"
27 #include "mir_type.h"
28
29 namespace maple {
30
31 namespace {
32 constexpr uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit
33 constexpr uint32 kBitSizePerByte = 8;
34 constexpr maple::int32 kMaxOffset = INT_MAX - 8;
35
36 enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 };
37
operator *(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)38 std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
39 {
40 if (!v1 && !v2) {
41 return std::nullopt;
42 }
43
44 // Perform all calculations in terms of the maximum available signed type.
45 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
46 return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64);
47 }
48
49 // Perform all calculations in terms of the maximum available signed type.
50 // The value will be truncated for an appropriate type when constant is created in PairToExpr function
AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)51 std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd)
52 {
53 if (!v1 && !v2) {
54 return std::nullopt;
55 }
56
57 if (v1 && v2) {
58 return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64);
59 }
60
61 if (v1) {
62 return v1->TruncOrExtend(PTY_i64);
63 }
64
65 // !v1 && v2
66 return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64));
67 }
68
operator +(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)69 std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
70 {
71 return AddSub(v1, v2, true);
72 }
73
operator -(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)74 std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2)
75 {
76 return AddSub(v1, v2, false);
77 }
78
79 } // anonymous namespace
80
// This phase optimizes code by simplifying constant expressions:
// each constant expression is evaluated at compile time and
// replaced by the computed value, so the work is not repeated
// at runtime.
//
// The main procedure is as follows:
// A. Analyze the expression type
// B. Analyze the operator type
// C. Replace the expression with the result of the operation
90
91 // true if the constant's bits are made of only one group of contiguous 1's
92 // starting at bit 0
ContiguousBitsOf1(uint64 x)93 static bool ContiguousBitsOf1(uint64 x)
94 {
95 if (x == 0) {
96 return false;
97 }
98 return (~x & (x + 1)) == (x + 1);
99 }
100
IsPowerOf2(uint64 num)101 inline bool IsPowerOf2(uint64 num)
102 {
103 if (num == 0) {
104 return false;
105 }
106 return (~(num - 1) & num) == num;
107 }
108
NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, BaseNode *rhs) const109 BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs,
110 BaseNode *rhs) const
111 {
112 CHECK_NULL_FATAL(old);
113 BinaryNode *result = nullptr;
114 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) {
115 result = old;
116 } else {
117 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs);
118 }
119 return result;
120 }
121
NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const122 UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const
123 {
124 CHECK_NULL_FATAL(old);
125 UnaryNode *result = nullptr;
126 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) {
127 result = old;
128 } else {
129 result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr);
130 }
131 return result;
132 }
133
PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const134 BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const
135 {
136 CHECK_NULL_FATAL(pair.first);
137 BaseNode *result = pair.first;
138 if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) {
139 return result;
140 }
141 if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) {
142 // -a, 5 -> 5 - a
143 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
144 static_cast<uint64>(pair.second->GetExtValue()), resultType);
145 BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0);
146 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r);
147 } else {
148 if ((!pair.second->GetSignBit() &&
149 pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) ||
150 pair.second->TruncOrExtend(resultType).IsMinValue() ||
151 pair.second->GetSXTValue() == INT64_MIN) {
152 // +-a, 5 -> a + 5
153 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
154 static_cast<uint64>(pair.second->GetExtValue()), resultType);
155 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val);
156 } else {
157 // +-a, -5 -> a + -5
158 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(
159 static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType);
160 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val);
161 }
162 }
163 return result;
164 }
165
FoldBase(BaseNode *node) const166 std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const
167 {
168 return std::make_pair(node, std::nullopt);
169 }
170
DispatchFold(BaseNode *node)171 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node)
172 {
173 CHECK_NULL_FATAL(node);
174 if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
175 return {node, std::nullopt};
176 }
177 switch (node->GetOpCode()) {
178 case OP_abs:
179 case OP_bnot:
180 case OP_lnot:
181 case OP_neg:
182 case OP_sqrt:
183 return FoldUnary(static_cast<UnaryNode*>(node));
184 case OP_ceil:
185 case OP_floor:
186 case OP_trunc:
187 case OP_cvt:
188 return FoldTypeCvt(static_cast<TypeCvtNode*>(node));
189 case OP_sext:
190 case OP_zext:
191 case OP_extractbits:
192 return FoldExtractbits(static_cast<ExtractbitsNode*>(node));
193 case OP_iread:
194 return FoldIread(static_cast<IreadNode*>(node));
195 case OP_add:
196 case OP_ashr:
197 case OP_band:
198 case OP_bior:
199 case OP_bxor:
200 case OP_div:
201 case OP_lshr:
202 case OP_max:
203 case OP_min:
204 case OP_mul:
205 case OP_rem:
206 case OP_shl:
207 case OP_sub:
208 return FoldBinary(static_cast<BinaryNode*>(node));
209 case OP_eq:
210 case OP_ne:
211 case OP_ge:
212 case OP_gt:
213 case OP_le:
214 case OP_lt:
215 case OP_cmp:
216 return FoldCompare(static_cast<CompareNode*>(node));
217 case OP_retype:
218 return FoldRetype(static_cast<RetypeNode*>(node));
219 default:
220 return FoldBase(static_cast<BaseNode*>(node));
221 }
222 }
223
Negate(BaseNode *node) const224 BaseNode *ConstantFold::Negate(BaseNode *node) const
225 {
226 CHECK_NULL_FATAL(node);
227 return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node);
228 }
229
Negate(UnaryNode *node) const230 BaseNode *ConstantFold::Negate(UnaryNode *node) const
231 {
232 CHECK_NULL_FATAL(node);
233 BaseNode *result = nullptr;
234 if (node->GetOpCode() == OP_neg) {
235 result = static_cast<BaseNode*>(node->Opnd(0));
236 } else {
237 BaseNode *n = static_cast<BaseNode*>(node);
238 result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n);
239 }
240 return result;
241 }
242
Negate(const ConstvalNode *node) const243 BaseNode *ConstantFold::Negate(const ConstvalNode *node) const
244 {
245 CHECK_NULL_FATAL(node);
246 ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
247 CHECK_NULL_FATAL(copy);
248 copy->GetConstVal()->Neg();
249 return copy;
250 }
251
NegateTree(BaseNode *node) const252 BaseNode *ConstantFold::NegateTree(BaseNode *node) const
253 {
254 CHECK_NULL_FATAL(node);
255 if (node->IsUnaryNode()) {
256 return Negate(static_cast<UnaryNode*>(node));
257 } else if (node->GetOpCode() == OP_constval) {
258 return Negate(static_cast<ConstvalNode*>(node));
259 } else {
260 return Negate(static_cast<BaseNode*>(node));
261 }
262 }
263
FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRIntConst &intConst0, const MIRIntConst &intConst1) const264 MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
265 const MIRIntConst &intConst0,
266 const MIRIntConst &intConst1) const
267 {
268 uint64 result = 0;
269
270 bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType);
271 bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType);
272 bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType);
273
274 switch (opcode) {
275 case OP_eq: {
276 result = equal;
277 break;
278 }
279 case OP_ge: {
280 result = (greater || equal) ? 1 : 0;
281 break;
282 }
283 case OP_gt: {
284 result = greater;
285 break;
286 }
287 case OP_le: {
288 result = (less || equal) ? 1 : 0;
289 break;
290 }
291 case OP_lt: {
292 result = less;
293 break;
294 }
295 case OP_ne: {
296 result = !equal;
297 break;
298 }
299 case OP_cmp: {
300 if (greater) {
301 result = kGreater;
302 } else if (equal) {
303 result = kEqual;
304 } else if (less) {
305 result = static_cast<uint64>(kLess);
306 }
307 break;
308 }
309 default:
310 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison");
311 break;
312 }
313 // determine the type
314 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
315 // form the constant
316 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
317 return constValue;
318 }
319
FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const320 ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
321 const ConstvalNode &const0, const ConstvalNode &const1) const
322 {
323 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
324 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
325 CHECK_NULL_FATAL(intConst0);
326 CHECK_NULL_FATAL(intConst1);
327 MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
328 // form the ConstvalNode
329 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
330 resultConst->SetPrimType(resultType);
331 resultConst->SetConstVal(constValue);
332 return resultConst;
333 }
334
FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0, const MIRIntConst &intConst1)335 MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0,
336 const MIRIntConst &intConst1)
337 {
338 IntVal intVal0 = intConst0.GetValue();
339 IntVal intVal1 = intConst1.GetValue();
340 IntVal result(static_cast<uint64>(0), resultType);
341
342 switch (opcode) {
343 case OP_add: {
344 result = intVal0.Add(intVal1, resultType);
345 break;
346 }
347 case OP_sub: {
348 result = intVal0.Sub(intVal1, resultType);
349 break;
350 }
351 case OP_mul: {
352 result = intVal0.Mul(intVal1, resultType);
353 break;
354 }
355 case OP_div: {
356 result = intVal0.Div(intVal1, resultType);
357 break;
358 }
359 case OP_rem: {
360 result = intVal0.Rem(intVal1, resultType);
361 break;
362 }
363 case OP_ashr: {
364 result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
365 break;
366 }
367 case OP_lshr: {
368 result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
369 break;
370 }
371 case OP_shl: {
372 result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType);
373 break;
374 }
375 case OP_max: {
376 result = Max(intVal0, intVal1, resultType);
377 break;
378 }
379 case OP_min: {
380 result = Min(intVal0, intVal1, resultType);
381 break;
382 }
383 case OP_band: {
384 result = intVal0.And(intVal1, resultType);
385 break;
386 }
387 case OP_bior: {
388 result = intVal0.Or(intVal1, resultType);
389 break;
390 }
391 case OP_bxor: {
392 result = intVal0.Xor(intVal1, resultType);
393 break;
394 }
395 default:
396 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary");
397 break;
398 }
399 // determine the type
400 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
401 // form the constant
402 MIRIntConst *constValue =
403 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
404 return constValue;
405 }
406
FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const407 ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
408 const ConstvalNode &const1) const
409 {
410 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal());
411 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal());
412 CHECK_NULL_FATAL(intConst0);
413 CHECK_NULL_FATAL(intConst1);
414 MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1);
415 // form the ConstvalNode
416 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
417 resultConst->SetPrimType(resultType);
418 resultConst->SetConstVal(constValue);
419 return resultConst;
420 }
421
FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const422 ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
423 const ConstvalNode &const1) const
424 {
425 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
426 const MIRDoubleConst *doubleConst0 = nullptr;
427 const MIRDoubleConst *doubleConst1 = nullptr;
428 const MIRFloatConst *floatConst0 = nullptr;
429 const MIRFloatConst *floatConst1 = nullptr;
430 bool useDouble = (const0.GetPrimType() == PTY_f64);
431 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
432 resultConst->SetPrimType(resultType);
433 if (useDouble) {
434 doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal());
435 doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal());
436 CHECK_NULL_FATAL(doubleConst0);
437 CHECK_NULL_FATAL(doubleConst1);
438 } else {
439 floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal());
440 floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal());
441 CHECK_NULL_FATAL(floatConst0);
442 CHECK_NULL_FATAL(floatConst1);
443 }
444 float constValueFloat = 0.0;
445 double constValueDouble = 0.0;
446 switch (opcode) {
447 case OP_add: {
448 if (useDouble) {
449 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue();
450 } else {
451 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue();
452 }
453 break;
454 }
455 case OP_sub: {
456 if (useDouble) {
457 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue();
458 } else {
459 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue();
460 }
461 break;
462 }
463 case OP_mul: {
464 if (useDouble) {
465 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue();
466 } else {
467 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue();
468 }
469 break;
470 }
471 case OP_div: {
472 // for floats div by 0 is well defined
473 if (useDouble) {
474 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue();
475 } else {
476 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue();
477 }
478 break;
479 }
480 case OP_max: {
481 if (useDouble) {
482 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue()
483 : doubleConst1->GetValue();
484 } else {
485 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue()
486 : floatConst1->GetValue();
487 }
488 break;
489 }
490 case OP_min: {
491 if (useDouble) {
492 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue()
493 : doubleConst1->GetValue();
494 } else {
495 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? floatConst0->GetValue()
496 : floatConst1->GetValue();
497 }
498 break;
499 }
500 case OP_rem:
501 case OP_ashr:
502 case OP_lshr:
503 case OP_shl:
504 case OP_band:
505 case OP_bior:
506 case OP_bxor: {
507 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary");
508 break;
509 }
510 default:
511 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary");
512 break;
513 }
514 if (resultType == PTY_f64) {
515 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble));
516 } else {
517 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat));
518 }
519 return resultConst;
520 }
521
ConstValueEqual(int64 leftValue, int64 rightValue) const522 bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const
523 {
524 return (leftValue == rightValue);
525 }
526
ConstValueEqual(float leftValue, float rightValue) const527 bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const
528 {
529 auto result = fabs(leftValue - rightValue);
530 return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN;
531 }
532
ConstValueEqual(double leftValue, double rightValue) const533 bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const
534 {
535 auto result = fabs(leftValue - rightValue);
536 return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? result < DBL_MIN : result <= DBL_MIN;
537 }
538
539 template<typename T>
FullyEqual(T leftValue, T rightValue) const540 bool ConstantFold::FullyEqual(T leftValue, T rightValue) const
541 {
542 if (std::isinf(leftValue) && std::isinf(rightValue)) {
543 // (inf == inf), add the judgement here in case of the subtraction between float type inf
544 return true;
545 } else {
546 return ConstValueEqual(leftValue, rightValue);
547 }
548 }
549
550 template<typename T>
ComparisonResult(Opcode op, T *leftConst, T *rightConst) const551 int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const
552 {
553 DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr");
554 typename T::value_type leftValue = leftConst->GetValue();
555 DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr");
556 typename T::value_type rightValue = rightConst->GetValue();
557 int64 result = 0;
558 switch (op) {
559 case OP_eq: {
560 result = FullyEqual(leftValue, rightValue);
561 break;
562 }
563 case OP_ge: {
564 result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue);
565 break;
566 }
567 case OP_gt: {
568 result = (leftValue > rightValue);
569 break;
570 }
571 case OP_le: {
572 result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue);
573 break;
574 }
575 case OP_lt: {
576 result = (leftValue < rightValue);
577 break;
578 }
579 case OP_ne: {
580 result = !FullyEqual(leftValue, rightValue);
581 break;
582 }
583 [[clang::fallthrough]];
584 case OP_cmp: {
585 if (leftValue > rightValue) {
586 result = kGreater;
587 } else if (FullyEqual(leftValue, rightValue)) {
588 result = kEqual;
589 } else {
590 result = kLess;
591 }
592 break;
593 }
594 default:
595 DEBUG_ASSERT(false, "Unknown opcode for Comparison");
596 break;
597 }
598 return result;
599 }
600
FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRConst &leftConst, const MIRConst &rightConst) const601 MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
602 const MIRConst &leftConst, const MIRConst &rightConst) const
603 {
604 int64 result = 0;
605 bool useDouble = (opndType == PTY_f64);
606 if (useDouble) {
607 result =
608 ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst));
609 } else {
610 result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst));
611 }
612 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
613 MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type);
614 return resultConst;
615 }
616
FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const617 ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
618 const ConstvalNode &const0, const ConstvalNode &const1) const
619 {
620 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match");
621 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
622 resultConst->SetPrimType(resultType);
623 resultConst->SetConstVal(
624 FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal()));
625 return resultConst;
626 }
627
FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, const MIRConst &const0, const MIRConst &const1) const628 MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType,
629 const MIRConst &const0, const MIRConst &const1) const
630 {
631 MIRConst *returnValue = nullptr;
632 if (IsPrimitiveInteger(opndType)) {
633 const auto *intConst0 = safe_cast<MIRIntConst>(&const0);
634 const auto *intConst1 = safe_cast<MIRIntConst>(&const1);
635 ASSERT_NOT_NULL(intConst0);
636 ASSERT_NOT_NULL(intConst1);
637 returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1);
638 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
639 returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1);
640 } else {
641 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst");
642 }
643 return returnValue;
644 }
645
FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, const ConstvalNode &const0, const ConstvalNode &const1) const646 ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType,
647 const ConstvalNode &const0, const ConstvalNode &const1) const
648 {
649 ConstvalNode *returnValue = nullptr;
650 if (IsPrimitiveInteger(opndType)) {
651 returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1);
652 } else if (opndType == PTY_f32 || opndType == PTY_f64) {
653 returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1);
654 } else {
655 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison");
656 }
657 return returnValue;
658 }
659
FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType, BaseNode &l, BaseNode &r) const660 CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType,
661 BaseNode &l, BaseNode &r) const
662 {
663 CompareNode *result = nullptr;
664 Opcode op = opcode;
665 switch (opcode) {
666 case OP_gt: {
667 op = OP_lt;
668 break;
669 }
670 case OP_lt: {
671 op = OP_gt;
672 break;
673 }
674 case OP_ge: {
675 op = OP_le;
676 break;
677 }
678 case OP_le: {
679 op = OP_ge;
680 break;
681 }
682 case OP_eq: {
683 break;
684 }
685 case OP_ne: {
686 break;
687 }
688 default:
689 DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse");
690 break;
691 }
692
693 result =
694 mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l);
695 return result;
696 }
697
FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, const ConstvalNode &const1) const698 ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0,
699 const ConstvalNode &const1) const
700 {
701 ConstvalNode *returnValue = nullptr;
702 if (IsPrimitiveInteger(resultType)) {
703 returnValue = FoldIntConstBinary(opcode, resultType, const0, const1);
704 } else if (resultType == PTY_f32 || resultType == PTY_f64) {
705 returnValue = FoldFPConstBinary(opcode, resultType, const0, const1);
706 } else {
707 DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary");
708 }
709 return returnValue;
710 }
711
FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)712 MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode)
713 {
714 CHECK_NULL_FATAL(constNode);
715 IntVal result = constNode->GetValue().TruncOrExtend(resultType);
716 switch (opcode) {
717 case OP_abs: {
718 if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) {
719 result = -result;
720 }
721 break;
722 }
723 case OP_bnot: {
724 result = ~result;
725 break;
726 }
727 case OP_lnot: {
728 uint64 resultInt = result == 0 ? 1 : 0;
729 result = {resultInt, resultType};
730 break;
731 }
732 case OP_neg: {
733 result = -result;
734 break;
735 }
736 case OP_sext: // handled in FoldExtractbits
737 case OP_zext: // handled in FoldExtractbits
738 case OP_extractbits: // handled in FoldExtractbits
739 case OP_sqrt: {
740 DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst");
741 break;
742 }
743 default:
744 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst");
745 break;
746 }
747 // determine the type
748 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
749 // form the constant
750 MIRIntConst *constValue =
751 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type);
752 return constValue;
753 }
754
755 template <typename T>
FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const756 ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const
757 {
758 CHECK_NULL_FATAL(constNode);
759 double constValue = 0;
760 T *fpCst = static_cast<T*>(constNode->GetConstVal());
761 switch (opcode) {
762 case OP_neg: {
763 constValue = typename T::value_type(-fpCst->GetValue());
764 break;
765 }
766 case OP_abs: {
767 constValue = typename T::value_type(fabs(fpCst->GetValue()));
768 break;
769 }
770 case OP_sqrt: {
771 constValue = typename T::value_type(sqrt(fpCst->GetValue()));
772 break;
773 }
774 case OP_bnot:
775 case OP_lnot:
776 case OP_sext:
777 case OP_zext:
778 case OP_extractbits: {
779 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary");
780 break;
781 }
782 default:
783 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary");
784 break;
785 }
786 auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
787 resultConst->SetPrimType(resultType);
788 if (resultType == PTY_f32) {
789 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue)));
790 } else if (resultType == PTY_f64) {
791 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue));
792 } else {
793 CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64");
794 }
795 return resultConst;
796 }
797
FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const798 ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const
799 {
800 ConstvalNode *returnValue = nullptr;
801 if (IsPrimitiveInteger(resultType)) {
802 const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal());
803 auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst);
804 returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
805 returnValue->SetPrimType(resultType);
806 returnValue->SetConstVal(constValue);
807 } else if (resultType == PTY_f32) {
808 returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode);
809 } else if (resultType == PTY_f64) {
810 returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode);
811 } else {
812 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary");
813 }
814 return returnValue;
815 }
816
FoldRetype(RetypeNode *node)817 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node)
818 {
819 CHECK_NULL_FATAL(node);
820 BaseNode *result = node;
821 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
822 if (node->Opnd(0) != p.first) {
823 RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator());
824 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype");
825 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0);
826 result = newRetNode;
827 }
828 return std::make_pair(result, std::nullopt);
829 }
830
// Folds a unary expression whose opcode is taken from `node`.
// Returns the folded tree together with an optional pending integer addend ("sum"); callers in
// this file materialize such pairs with PairToExpr.
//  * Constant operand: folded directly by FoldConstUnary, no pending addend.
//  * Integer `neg` whose folded operand is not u1: the negation is pushed into the operand tree
//    with NegateTree, and the operand's pending addend (if any) is negated as well.
//  * Anything else: the node is rebuilt around the folded operand.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node)
{
    CHECK_NULL_FATAL(node);
    BaseNode *result = nullptr;
    std::optional<IntVal> sum = std::nullopt;
    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
    if (cst != nullptr) {
        result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst);
    } else {
        bool isInt = IsPrimitiveInteger(node->GetPrimType());
        // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's
        // primType will be set to opnd type. There will be problems in some cases. For example:
        // before cf:
        // neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f))
        // after cf:
        // neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) # wrong!
        // As a workaround, we exclude u1 opnd type
        if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) {
            result = NegateTree(p.first);
            // NegateTree may still hand back a neg node; reconcile its primType with the original.
            if (result->GetOpCode() == OP_neg) {
                PrimType origPtyp = node->GetPrimType();
                PrimType newPtyp = result->GetPrimType();
                if (newPtyp == origPtyp) {
                    if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) {
                        // NegateTree returned an UnaryNode equivalent to `node`, so keep the
                        // original UnaryNode to preserve identity
                        result = node;
                    }
                } else {
                    if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) {
                        // do not fold explicit cvt
                        // Widths differ: keep the original neg around the folded operand (pending
                        // addend included via PairToExpr) and report no pending addend.
                        result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(),
                                              PairToExpr(node->Opnd(0)->GetPrimType(), p));
                        return std::make_pair(result, std::nullopt);
                    } else {
                        // Same width, different type: simply retype the neg to the original type.
                        result->SetPrimType(origPtyp);
                    }
                }
            }
            // -(X + k) == (-X) + (-k): `result` is -X, so carry -k upward as the pending addend.
            if (p.second) {
                sum = -(*p.second);
            }
        } else {
            result =
                NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), p));
        }
    }
    return std::make_pair(result, sum);
}
881
FloatToIntOverflow(float fval, PrimType totype)882 static bool FloatToIntOverflow(float fval, PrimType totype)
883 {
884 static const float safeFloatMaxToInt32 = 2147483520.0f; // 2^31 - 128
885 static const float safeFloatMinToInt32 = -2147483520.0f;
886 static const float safeFloatMaxToInt64 = 9223372036854775680.0f; // 2^63 - 128
887 static const float safeFloatMinToInt64 = -9223372036854775680.0f;
888 if (!std::isfinite(fval)) {
889 return true;
890 }
891 if (totype == PTY_i64 || totype == PTY_u64) {
892 if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) {
893 return true;
894 }
895 } else {
896 if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) {
897 return true;
898 }
899 }
900 return false;
901 }
902
DoubleToIntOverflow(double dval, PrimType totype)903 static bool DoubleToIntOverflow(double dval, PrimType totype)
904 {
905 static const double safeDoubleMaxToInt32 = 2147482624.0; // 2^31 - 1024
906 static const double safeDoubleMinToInt32 = -2147482624.0;
907 static const double safeDoubleMaxToInt64 = 9223372036854774784.0; // 2^63 - 1024
908 static const double safeDoubleMinToInt64 = -9223372036854774784.0;
909 if (!std::isfinite(dval)) {
910 return true;
911 }
912 if (totype == PTY_i64 || totype == PTY_u64) {
913 if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) {
914 return true;
915 }
916 } else {
917 if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) {
918 return true;
919 }
920 }
921 return false;
922 }
923
// Folds `ceil <toType> <fromType> (constval)`: rounds the f32/f64 constant toward +infinity.
// The result is either a floating constant (float target) or an integer constant.
// Returns nullptr when the rounded value is non-finite or outside the integer range accepted by
// FloatToIntOverflow/DoubleToIntOverflow; the caller then keeps the original node.
ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Float -> integer bit pattern without undefined behavior. Converting a negative value
    // directly to uint64 is UB (and host dependent in practice), so negatives go through int64.
    // Non-negative values go through uint64. The overflow checks above each use keep both
    // conversions in range.
    auto toIntBits = [](auto value) -> uint64 {
        return value < 0 ? static_cast<uint64>(static_cast<int64>(value)) : static_cast<uint64>(value);
    };
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        float floatValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(floatValue), resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        ASSERT_NOT_NULL(constValue);
        double doubleValue = ceil(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(doubleValue), resultType));
        }
    }
    return resultConst;
}
956
957 template <class T>
CalIntValueFromFloatValue(T value, const MIRType &resultType) const958 T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const
959 {
960 DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type");
961 size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte;
962 bool isSigned = IsSignedInteger(resultType.GetPrimType());
963 int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue();
964 uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum;
965 int64 min = isSigned ? (IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0;
966 if (isSigned && (value > max)) {
967 return static_cast<T>(max);
968 } else if (!isSigned && (value > umax)) {
969 return static_cast<T>(umax);
970 } else if (value < min) {
971 return static_cast<T>(min);
972 }
973 return value;
974 }
975
// Converts the f32/f64 constant `cst` to `toType`, optionally applying floor first (`isFloor`).
// Float targets get the (floored) value unchanged. Integer targets get the value clamped to
// the target range by CalIntValueFromFloatValue, with the fractional part truncated toward zero
// by the integer conversion. Returns nullptr when the value overflows (see
// FloatToIntOverflow/DoubleToIntOverflow).
MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const
{
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = constValue.GetValue();
        if (isFloor) {
            floatValue = floor(floatValue);
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
        }
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        floatValue = CalIntValueFromFloatValue(floatValue, resultType);
        // Signed targets may still hold a negative value here. Converting a negative float
        // directly to uint64 is undefined behavior, so negatives go through int64 (as the f64
        // path below does). Non-negative values go through uint64, which also covers 2^63.
        uint64 intBits = floatValue < 0 ? static_cast<uint64>(static_cast<int64>(floatValue))
                                        : static_cast<uint64>(floatValue);
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(intBits, resultType);
    } else {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = constValue.GetValue();
        if (isFloor) {
            doubleValue = floor(doubleValue);
        }
        if (IsPrimitiveFloat(toType)) {
            return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
        }
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        doubleValue = CalIntValueFromFloatValue(doubleValue, resultType);
        // gcc/clang have bugs convert double to unsigned long, must convert to signed long first;
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType);
    }
}
1010
// Folds `floor <toType> <fromType> (constval)`.
// Returns nullptr when FoldFloorMIRConst cannot fold the value (e.g. it overflows the integer
// target). This matches FoldCeil/FoldTrunc, so the caller keeps the original node instead of
// receiving a ConstvalNode with no constant attached.
ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *floorConst = FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType);
    if (floorConst == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(floorConst);
    return resultConst;
}
1018
// Folds `round`-style conversions of a constant.
//  * f32/f64 -> integer: rounds half away from zero. Returns nullptr when the result overflows.
//  * integer -> f32/f64: folds only when the integer converts to the float type exactly
//    (round-trips back to the same integer). Returns nullptr otherwise.
// Any other combination returns nullptr.
MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    // Exact powers of two bounding the integer ranges. An integer near the top of its range can
    // round up to one of these when converted to float/double (e.g. float(INT64_MAX) == 2^63).
    // Converting that back to int64/uint64 would be undefined behavior, so such values are
    // rejected before the round-trip cast.
    constexpr float kTwoPow63F = 9223372036854775808.0f;   // 2^63
    constexpr float kTwoPow64F = 18446744073709551616.0f;  // 2^64
    constexpr double kTwoPow63D = 9223372036854775808.0;   // 2^63
    constexpr double kTwoPow64D = 18446744073709551616.0;  // 2^64
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const auto &constValue = static_cast<const MIRFloatConst&>(cst);
        float floatValue = round(constValue.GetValue());
        if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType);
    } else if (fromType == PTY_f64) {
        const auto &constValue = static_cast<const MIRDoubleConst&>(cst);
        double doubleValue = round(constValue.GetValue());
        if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        }
        return GlobalTables::GetIntConstTable().GetOrCreateIntConst(
            static_cast<uint64>(static_cast<int64>(doubleValue)), resultType);
    } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            float floatValue = round(static_cast<float>(fromValue));
            // float(int64) is always >= -2^63, so only the upper bound needs checking.
            if (floatValue < kTwoPow63F && static_cast<int64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            float floatValue = round(static_cast<float>(fromValue));
            if (floatValue < kTwoPow64F && static_cast<uint64>(floatValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            }
        }
    } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) {
        const auto &constValue = static_cast<const MIRIntConst&>(cst);
        if (IsSignedInteger(fromType)) {
            int64 fromValue = constValue.GetExtValue();
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kTwoPow63D && static_cast<int64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        } else {
            uint64 fromValue = static_cast<uint64>(constValue.GetExtValue());
            double doubleValue = round(static_cast<double>(fromValue));
            if (doubleValue < kTwoPow64D && static_cast<uint64>(doubleValue) == fromValue) {
                return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            }
        }
    }
    return nullptr;
}
1070
// Folds `round <toType> <fromType> (constval)`.
// Returns nullptr when FoldRoundMIRConst declines to fold (overflow, inexact int -> float
// conversion, or an unsupported type pair). This matches FoldCeil/FoldTrunc instead of
// returning a ConstvalNode with no constant attached.
ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *roundConst = FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType);
    if (roundConst == nullptr) {
        return nullptr;
    }
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    resultConst->SetConstVal(roundConst);
    return resultConst;
}
1078
// Folds `trunc <toType> <fromType> (constval)`: rounds the f32/f64 constant toward zero.
// The result is either a floating constant (float target) or an integer constant.
// Returns nullptr when the truncated value is non-finite or outside the integer range accepted
// by FloatToIntOverflow/DoubleToIntOverflow.
ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    // Float -> integer bit pattern without undefined behavior. Converting a negative value
    // directly to uint64 is UB (e.g. trunc(-3.7) to i32), so negatives go through int64.
    // Non-negative values go through uint64. The overflow checks keep both conversions in range.
    auto toIntBits = [](auto value) -> uint64 {
        return value < 0 ? static_cast<uint64>(static_cast<int64>(value)) : static_cast<uint64>(value);
    };
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    resultConst->SetPrimType(toType);
    MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType);
    if (fromType == PTY_f32) {
        const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        float floatValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue));
        } else if (FloatToIntOverflow(floatValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(floatValue), resultType));
        }
    } else {
        const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal());
        CHECK_NULL_FATAL(constValue);
        double doubleValue = trunc(constValue->GetValue());
        if (IsPrimitiveFloat(toType)) {
            resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue));
        } else if (DoubleToIntOverflow(doubleValue, toType)) {
            return nullptr;
        } else {
            resultConst->SetConstVal(
                GlobalTables::GetIntConstTable().GetOrCreateIntConst(toIntBits(doubleValue), resultType));
        }
    }
    return resultConst;
}
1111
// Converts the constant `cst` of type `fromType` into a constant of type `toType` (the constant
// part of a `cvt`). Returns nullptr when the value cannot be folded.
//  * integer -> integer: widening extends the low fromSize bits with sext/zext according to the
//    source signedness. Narrowing or same-width conversions re-create the constant from its
//    extended value in the target type (presumably truncated by GetOrCreateIntConst — confirm).
//  * float -> float: f64 -> f32 narrowing via static_cast, otherwise treated as f32 -> f64.
//    NOTE(review): a same-width conversion (e.g. f64 -> f64) also reaches the widening branch,
//    where safe_cast<MIRFloatConst> would yield nullptr — confirm such cvts never get here.
//  * float -> integer: delegated to FoldFloorMIRConst without flooring (the value is clamped to
//    the target range and converted). It returns nullptr on overflow.
//  * integer -> float: delegated to FoldRoundMIRConst, which only folds values that convert
//    exactly.
MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const
{
    if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) {
        MIRConst *toConst = nullptr;
        uint32 fromSize = GetPrimTypeBitSize(fromType);
        uint32 toSize = GetPrimTypeBitSize(toType);
        // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here.
        if (fromType == PTY_u1) {
            fromSize = 1;
        }
        if (toType == PTY_u1) {
            toSize = 1;
        }
        if (toSize > fromSize) {
            // Widening: extend the low fromSize bits according to the source signedness.
            Opcode op = OP_zext;
            if (IsSignedInteger(fromType)) {
                op = OP_sext;
            }
            const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
            ASSERT_NOT_NULL(constVal);
            toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize),
                                             constVal->GetValue().TruncOrExtend(fromType));
        } else {
            // Narrowing or same width: rebuild the constant in the target type.
            const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst);
            ASSERT_NOT_NULL(constVal);
            MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType);
            toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(
                static_cast<uint64>(constVal->GetExtValue()), type);
        }
        return toConst;
    }
    if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) {
        MIRConst *toConst = nullptr;
        if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) {
            // f64 -> f32 narrowing.
            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 and F64"); // just support 32 or 64
            const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst);
            ASSERT_NOT_NULL(fromValue);
            float floatValue = static_cast<float>(fromValue->GetValue());
            MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue);
            toConst = toValue;
        } else {
            // Assumed f32 -> f64 widening (exact).
            DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64
            const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst);
            ASSERT_NOT_NULL(fromValue);
            double doubleValue = static_cast<double>(fromValue->GetValue());
            MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue);
            toConst = toValue;
        }
        return toConst;
    }
    if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) {
        // isFloor == false: plain conversion, no floor applied first.
        return FoldFloorMIRConst(cst, fromType, toType, false);
    }
    if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) {
        return FoldRoundMIRConst(cst, fromType, toType);
    }
    CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt");
    return nullptr;
}
1171
// Folds `cvt <toType> <fromType> (constval)` into a new ConstvalNode.
// The node's primType is taken from the folded constant's own type.
// Returns nullptr when FoldTypeCvtMIRConst cannot fold the constant.
ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const
{
    MIRConst *converted = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType);
    if (converted != nullptr) {
        // Only allocate the node once the constant is known to be foldable.
        ConstvalNode *folded = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
        folded->SetPrimType(converted->GetType().GetPrimType());
        folded->SetConstVal(converted);
        return folded;
    }
    return nullptr;
}
1183
1184 // return a primType with bit size >= bitSize (and the nearest one),
1185 // and its signed/float type is the same as ptyp
GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)1186 PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp)
1187 {
1188 bool isSigned = IsSignedInteger(ptyp);
1189 bool isFloat = IsPrimitiveFloat(ptyp);
1190 if (bitSize == 1) { // 1 bit
1191 return PTY_u1;
1192 }
1193 if (bitSize <= 8) { // 8 bit
1194 return isSigned ? PTY_i8 : PTY_u8;
1195 }
1196 if (bitSize <= 16) { // 16 bit
1197 return isSigned ? PTY_i16 : PTY_u16;
1198 }
1199 if (bitSize <= 32) { // 32 bit
1200 return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
1201 }
1202 if (bitSize <= 64) { // 64 bit
1203 return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
1204 }
1205 return ptyp;
1206 }
1207
GetIntPrimTypeMax(PrimType ptyp)1208 size_t GetIntPrimTypeMax(PrimType ptyp)
1209 {
1210 switch (ptyp) {
1211 case PTY_u1:
1212 return 1;
1213 case PTY_u8:
1214 return UINT8_MAX;
1215 case PTY_i8:
1216 return INT8_MAX;
1217 case PTY_u16:
1218 return UINT16_MAX;
1219 case PTY_i16:
1220 return INT16_MAX;
1221 case PTY_u32:
1222 return UINT32_MAX;
1223 case PTY_i32:
1224 return INT32_MAX;
1225 case PTY_u64:
1226 return UINT64_MAX;
1227 case PTY_i64:
1228 return INT64_MAX;
1229 default:
1230 CHECK_FATAL(false, "NYI");
1231 }
1232 }
1233
GetIntPrimTypeMin(PrimType ptyp)1234 ssize_t GetIntPrimTypeMin(PrimType ptyp)
1235 {
1236 if (IsUnsignedInteger(ptyp)) {
1237 return 0;
1238 }
1239 switch (ptyp) {
1240 case PTY_i8:
1241 return INT8_MIN;
1242 case PTY_i16:
1243 return INT16_MIN;
1244 case PTY_i32:
1245 return INT32_MIN;
1246 case PTY_i64:
1247 return INT64_MIN;
1248 default:
1249 CHECK_FATAL(false, "NYI");
1250 }
1251 }
1252
IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)1253 static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
1254 {
1255 if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
1256 return false;
1257 }
1258 if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
1259 return false;
1260 }
1261 return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
1262 (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
1263 (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
1264 }
1265
// Folds a type-conversion node (ceil/cvt/floor/trunc).
//  * Result types wider than 8 bytes are returned untouched.
//  * Constant operand: dispatched to FoldCeil/FoldTypeCvt/FoldFloor/FoldTrunc by opcode. If that
//    yields nullptr, the node is rebuilt (or kept) below as if the operand were not constant.
//  * Redundant cvt (IsCvtEliminatable): the folded operand is returned directly, and its pending
//    addend (if any) is re-expressed in the node's primType.
//  * Otherwise a new TypeCvtNode is built only when the folded operand differs from the original.
// Apart from the redundant-cvt case, no pending addend is returned.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
{
    CHECK_NULL_FATAL(node);
    BaseNode *result = nullptr;
    if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
        return {node, std::nullopt};
    }
    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
    PrimType destPtyp = node->GetPrimType();
    PrimType fromPtyp = node->FromType();
    if (cst != nullptr) {
        switch (node->GetOpCode()) {
            case OP_ceil: {
                result = FoldCeil(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_cvt: {
                result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_floor: {
                result = FoldFloor(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_trunc: {
                result = FoldTrunc(*cst, fromPtyp, destPtyp);
                break;
            }
            default:
                DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
                break;
        }
    } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
        // the cvt is redundant
        return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
    }
    // Not folded to a constant: keep the conversion, around the folded operand if it changed.
    if (result == nullptr) {
        BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
        if (e != node->Opnd(0)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
                Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
        } else {
            result = node;
        }
    }
    return std::make_pair(result, std::nullopt);
}
1314
FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const1315 MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
1316 {
1317 uint64 result = opcode == OP_sext ? static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
1318 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
1319 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
1320 return constValue;
1321 }
1322
FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size, const ConstvalNode &cst) const1323 ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
1324 const ConstvalNode &cst) const
1325 {
1326 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
1327 const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
1328 ASSERT_NOT_NULL(intCst);
1329 IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
1330 MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
1331 resultConst->SetPrimType(toConst->GetType().GetPrimType());
1332 resultConst->SetConstVal(toConst);
1333 return resultConst;
1334 }
1335
// check if truncation is redundant due to dread or iread having same effect
// Returns true when the extension node `x` (zext/sext/extractbits) can be dropped because its
// operand already produces exactly the extracted bits:
//  * dread / iread operand: the loaded type must be an integer whose bit size equals x's bit
//    size and whose primType equals x's. The loaded type is the symbol's type for dread, or the
//    pointed-to type of the iread's pointer type (false if it is not a pointer type). In
//    addition, x's extension kind (zext/sext) must match the signedness of the operand's primType.
//  * extractbits operand narrower than x: redundant only for a zext with the same, unsigned
//    primType as the operand.
// Nodes with an 8-byte result type are never considered redundant (conservative).
// NOTE(review): x's bit offset is not checked here; the caller only calls this for offset 0.
static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
{
    if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
        return false; // this is trying to be conservative
    }
    BaseNode *opnd = x.Opnd(0);
    MIRType *mirType = nullptr;
    if (opnd->GetOpCode() == OP_dread) {
        DreadNode *dread = static_cast<DreadNode*>(opnd);
        MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
        ASSERT_NOT_NULL(sym);
        mirType = sym->GetType();
    } else if (opnd->GetOpCode() == OP_iread) {
        IreadNode *iread = static_cast<IreadNode*>(opnd);
        MIRPtrType *ptrType =
            dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
        if (ptrType == nullptr) {
            return false;
        }
        mirType = ptrType->GetPointedType();
    } else if (opnd->GetOpCode() == OP_extractbits &&
               x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
        // Zero-extending an already narrower unsigned extraction adds nothing.
        return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
                IsUnsignedInteger(opnd->GetPrimType()));
    } else {
        return false;
    }
    return IsPrimitiveInteger(mirType->GetPrimType()) &&
           ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
            (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
           mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
           mirType->GetPrimType() == x.GetPrimType();
}
1370
1371 // sext and zext also handled automatically
FoldExtractbits(ExtractbitsNode *node)1372 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
1373 {
1374 CHECK_NULL_FATAL(node);
1375 BaseNode *result = nullptr;
1376 uint8 offset = node->GetBitsOffset();
1377 uint8 size = node->GetBitsSize();
1378 Opcode opcode = node->GetOpCode();
1379 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1380 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
1381 if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
1382 result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
1383 return std::make_pair(result, std::nullopt);
1384 }
1385 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1386 if (e != node->Opnd(0)) {
1387 result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
1388 size, e);
1389 } else {
1390 result = node;
1391 }
1392 // check for consecutive and redundant extraction of same bits
1393 BaseNode *opnd = result->Opnd(0);
1394 DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
1395 Opcode opndOp = opnd->GetOpCode();
1396 if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
1397 uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
1398 uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
1399 if (offset == opndOffset && size == opndSize) {
1400 result->SetOpnd(opnd->Opnd(0), 0); // delete the redundant extraction
1401 }
1402 }
1403 if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
1404 if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
1405 return std::make_pair(result->Opnd(0), std::nullopt);
1406 }
1407 }
1408 return std::make_pair(result, std::nullopt);
1409 }
1410
FoldIread(IreadNode *node)1411 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
1412 {
1413 CHECK_NULL_FATAL(node);
1414 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
1415 BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
1416 node->SetOpnd(e, 0);
1417 BaseNode *result = node;
1418 if (e->GetOpCode() != OP_addrof) {
1419 return std::make_pair(result, std::nullopt);
1420 }
1421
1422 AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
1423 MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
1424 DEBUG_ASSERT(msy != nullptr, "nullptr check");
1425 TyIdx typeId = msy->GetTyIdx();
1426 CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
1427 MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
1428 MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
1429 // If the high level type of iaddrof/iread doesn't match
1430 // the type of addrof's rhs, this optimization cannot be done.
1431 if (ptrType->GetPointedType() != msyType) {
1432 return std::make_pair(result, std::nullopt);
1433 }
1434
1435 Opcode op = node->GetOpCode();
1436 if (op == OP_iread) {
1437 result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(),
1438 node->GetFieldID() + addrofNode->GetFieldID());
1439 }
1440 return std::make_pair(result, std::nullopt);
1441 }
1442
IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)1443 bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB)
1444 {
1445 switch (op) {
1446 case OP_add: {
1447 int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB));
1448 if (IsUnsignedInteger(primType)) {
1449 return static_cast<uint64>(res) < static_cast<uint64>(cstA);
1450 }
1451 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1452 return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1453 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) &&
1454 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1455 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag);
1456 }
1457 case OP_sub: {
1458 if (IsUnsignedInteger(primType)) {
1459 return cstA < cstB;
1460 }
1461 int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB));
1462 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1;
1463 return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag !=
1464 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) &&
1465 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag !=
1466 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag);
1467 }
1468 default: {
1469 return false;
1470 }
1471 }
1472 }
1473
FoldBinary(BinaryNode *node)1474 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node)
1475 {
1476 CHECK_NULL_FATAL(node);
1477 BaseNode *result = nullptr;
1478 std::optional<IntVal> sum = std::nullopt;
1479 Opcode op = node->GetOpCode();
1480 PrimType primType = node->GetPrimType();
1481 PrimType lPrimTypes = node->Opnd(0)->GetPrimType();
1482 PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
1483 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1484 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1485 BaseNode *l = lp.first;
1486 BaseNode *r = rp.first;
1487 ASSERT_NOT_NULL(r);
1488 ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
1489 ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
1490 bool isInt = IsPrimitiveInteger(primType);
1491
1492 if (lConst != nullptr && rConst != nullptr) {
1493 MIRConst *lConstVal = lConst->GetConstVal();
1494 MIRConst *rConstVal = rConst->GetConstVal();
1495 ASSERT_NOT_NULL(lConstVal);
1496 ASSERT_NOT_NULL(rConstVal);
1497 // Don't fold div by 0, for floats div by 0 is well defined.
1498 if ((op == OP_div || op == OP_rem) && isInt &&
1499 !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
1500 result = NewBinaryNode(node, op, primType, lConst, rConst);
1501 } else {
1502 // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
1503 // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
1504 // logic since the alternative is to return pair(result = nullptr, sum = 6).
1505 // Doing so would introduce many nullptr checks in the code. See previous
1506 // commits that implemented that logic for a comparison.
1507 result = FoldConstBinary(op, primType, *lConst, *rConst);
1508 }
1509 } else if (lConst != nullptr && isInt) {
1510 MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
1511 ASSERT_NOT_NULL(mcst);
1512 PrimType cstTyp = mcst->GetType().GetPrimType();
1513 IntVal cst = mcst->GetValue();
1514 if (op == OP_add) {
1515 if (IsSignedInteger(cstTyp) && rp.second &&
1516 IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
1517 // do not introduce signed integer overflow
1518 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1519 } else {
1520 sum = cst + rp.second;
1521 result = r;
1522 }
1523 } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
1524 // We exclude u1 type for fixing the following wrong example:
1525 // before cf:
1526 // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
1527 // after cf:
1528 // add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
1529 sum = cst - rp.second;
1530 if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
1531 r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
1532 }
1533 result = NegateTree(r);
1534 } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
1535 op == OP_band) &&
1536 cst == 0) {
1537 // 0 * X -> 0
1538 // 0 / X -> 0
1539 // 0 % X -> 0
1540 // 0 >> X -> 0
1541 // 0 << X -> 0
1542 // 0 & X -> 0
1543 // 0 && X -> 0
1544 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1545 } else if (op == OP_mul && cst == 1) {
1546 // 1 * X --> X
1547 sum = rp.second;
1548 result = r;
1549 } else if (op == OP_bior && cst == -1) {
1550 // (-1) | X -> -1
1551 result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
1552 } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
1553 // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
1554 sum = cst * rp.second;
1555 if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
1556 rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
1557 }
1558 result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
1559 } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
1560 // 0 | X -> X
1561 // 0 ^ X -> X
1562 sum = rp.second;
1563 result = r;
1564 } else {
1565 result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
1566 }
1567 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1568 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1569 }
1570 } else if (rConst != nullptr && isInt) {
1571 MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
1572 ASSERT_NOT_NULL(mcst);
1573 PrimType cstTyp = mcst->GetType().GetPrimType();
1574 IntVal cst = mcst->GetValue();
1575 if (op == OP_add) {
1576 if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
1577 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1578 } else {
1579 result = l;
1580 sum = lp.second + cst;
1581 }
1582 } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
1583 result = l;
1584 sum = lp.second - cst;
1585 } else if ((op == OP_mul || op == OP_band) && cst == 0) {
1586 // X * 0 -> 0
1587 // X & 0 -> 0
1588 // X && 0 -> 0
1589 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1590 } else if ((op == OP_mul || op == OP_div) && cst == 1) {
1591 // case [X * 1 -> X]
1592 // case [X / 1 = X]
1593 sum = lp.second;
1594 result = l;
1595 } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
1596 IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
1597 // temporary fix for constfold of mul/div in DejaGnu
1598 // Later we need a more formal interface for pattern match
1599 // X * Y / Y -> X
1600 BaseNode *x = l->Opnd(0);
1601 BaseNode *y = l->Opnd(1);
1602 ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
1603 ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
1604 bool foldMulDiv = false;
1605 if (yConst != nullptr && xConst == nullptr &&
1606 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1607 MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
1608 ASSERT_NOT_NULL(yCst);
1609 IntVal mulCst = yCst->GetValue();
1610 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1611 mulCst.GetExtValue() == cst.GetExtValue()) {
1612 foldMulDiv = true;
1613 result = x;
1614 }
1615 } else if (xConst != nullptr && yConst == nullptr &&
1616 IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
1617 MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
1618 ASSERT_NOT_NULL(xCst);
1619 IntVal mulCst = xCst->GetValue();
1620 if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
1621 mulCst.GetExtValue() == cst.GetExtValue()) {
1622 foldMulDiv = true;
1623 result = y;
1624 }
1625 }
1626 if (!foldMulDiv) {
1627 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1628 }
1629 } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
1630 // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
1631 sum = lp.second * cst;
1632 if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
1633 lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
1634 }
1635 if (lp.first->GetOpCode() == OP_neg && cst == -1) {
1636 // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
1637 result = lp.first->Opnd(0);
1638 } else {
1639 result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
1640 }
1641 } else if (op == OP_band && cst == -1) {
1642 // X & (-1) -> X
1643 sum = lp.second;
1644 result = l;
1645 } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
1646 (!lp.second.has_value() || lp.second == 0)) {
1647 bool fold2extractbits = false;
1648 if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
1649 BinaryNode *shrNode = static_cast<BinaryNode *>(l);
1650 if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
1651 ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
1652 int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
1653 uint64 ucst = cst.GetZXTValue();
1654 uint32 bsize = 0;
1655 do {
1656 bsize++;
1657 ucst >>= 1;
1658 } while (ucst != 0);
1659 if (shrAmt + static_cast<int64>(bsize) <=
1660 static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
1661 static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
1662 fold2extractbits = true;
1663 // change to use extractbits
1664 result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
1665 GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
1666 sum = std::nullopt;
1667 }
1668 }
1669 }
1670 if (!fold2extractbits) {
1671 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1672 sum = std::nullopt;
1673 }
1674 } else if (op == OP_bior && cst == -1) {
1675 // X | (-1) -> -1
1676 result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
1677 } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
1678 // X >> 0 -> X
1679 // X << 0 -> X
1680 // X | 0 -> X
1681 // X ^ 0 -> X
1682 sum = lp.second;
1683 result = l;
1684 } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
1685 // bxor i32 (
1686 // cvt i32 u1 (regread u1 %13),
1687 // constValue i32 1),
1688 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1689 if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
1690 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
1691 if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
1692 BaseNode *base = cvtNode->Opnd(0);
1693 BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
1694 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
1695 BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
1696 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
1697 }
1698 }
1699 } else if (op == OP_rem && cst == 1) {
1700 // X % 1 -> 0
1701 result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
1702 } else {
1703 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
1704 }
1705 if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
1706 result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
1707 }
1708 } else if (isInt && (op == OP_add || op == OP_sub)) {
1709 if (op == OP_add) {
1710 result = NewBinaryNode(node, op, primType, l, r);
1711 sum = lp.second + rp.second;
1712 } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
1713 // if fold is (x - (y - z)) -> (x - neg(z)) - y
1714 // (x - neg(z)) Could cross the int limit
1715 // return node
1716 result = node;
1717 } else {
1718 result = NewBinaryNode(node, op, primType, l, r);
1719 sum = lp.second - rp.second;
1720 }
1721 } else {
1722 result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
1723 }
1724 return std::make_pair(result, sum);
1725 }
1726
SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const1727 BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
1728 {
1729 if (isRConstval) {
1730 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
1731 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1732 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
1733 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1734 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1735 }
1736 } else {
1737 ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
1738 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1739 const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
1740 if (isGtOrLt) {
1741 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1742 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
1743 } else {
1744 return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
1745 node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
1746 }
1747 }
1748 }
1749 return &node;
1750 }
1751
SimplifyDoubleCompare(CompareNode &compareNode) const1752 BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
1753 {
1754 // See arm manual B.cond(P2993) and FCMP(P1091)
1755 CompareNode *node = &compareNode;
1756 BaseNode *result = node;
1757 BaseNode *l = node->Opnd(0);
1758 BaseNode *r = node->Opnd(1);
1759 if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
1760 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1761 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1762 result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
1763 } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
1764 // ne (u1 x, constValue 0) <==> x
1765 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1766 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
1767 BaseNode *opnd = l;
1768 do {
1769 if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1770 result = opnd;
1771 break;
1772 } else if (opnd->GetOpCode() == OP_cvt) {
1773 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd);
1774 opnd = cvtNode->Opnd(0);
1775 } else {
1776 opnd = nullptr;
1777 }
1778 } while (opnd != nullptr);
1779 }
1780 } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) {
1781 ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
1782 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() &&
1783 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
1784 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne;
1785 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1786 resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1));
1787 }
1788 }
1789 } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) {
1790 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
1791 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
1792 result = SimplifyDoubleConstvalCompare(*node,
1793 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true);
1794 }
1795 }
1796 return result;
1797 }
1798
FoldCompare(CompareNode *node)1799 std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node)
1800 {
1801 CHECK_NULL_FATAL(node);
1802 BaseNode *result = nullptr;
1803 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
1804 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
1805 ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first);
1806 ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first);
1807 Opcode opcode = node->GetOpCode();
1808 if (lConst != nullptr && rConst != nullptr) {
1809 result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst);
1810 } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp &&
1811 lConst->GetConstVal()->GetKind() == kConstInt) {
1812 BaseNode *l = lp.first;
1813 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1814 result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r);
1815 } else {
1816 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp);
1817 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp);
1818 if (l != node->Opnd(0) || r != node->Opnd(1)) {
1819 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>(
1820 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r);
1821 } else {
1822 result = node;
1823 }
1824 auto *compareNode = static_cast<CompareNode*>(result);
1825 CHECK_NULL_FATAL(compareNode);
1826 result = SimplifyDoubleCompare(*compareNode);
1827 }
1828 return std::make_pair(result, std::nullopt);
1829 }
1830
Fold(BaseNode *node)1831 BaseNode *ConstantFold::Fold(BaseNode *node)
1832 {
1833 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) {
1834 return nullptr;
1835 }
1836 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node);
1837 BaseNode *result = PairToExpr(node->GetPrimType(), p);
1838 if (result == node) {
1839 result = nullptr;
1840 }
1841 return result;
1842 }
1843
1844 } // namespace maple
1845