1/* 2 * Copyright (c) 2023 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include "constantfold.h" 17#include <cmath> 18#include <cfloat> 19#include <climits> 20#include <type_traits> 21#include "mpl_logging.h" 22#include "mir_function.h" 23#include "mir_builder.h" 24#include "global_tables.h" 25#include "me_option.h" 26#include "maple_phase_manager.h" 27#include "mir_type.h" 28 29namespace maple { 30 31namespace { 32constexpr uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit 33constexpr uint32 kBitSizePerByte = 8; 34constexpr maple::int32 kMaxOffset = INT_MAX - 8; 35 36enum CompareRes : int64 { kLess = -1, kEqual = 0, kGreater = 1 }; 37 38std::optional<IntVal> operator*(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2) 39{ 40 if (!v1 && !v2) { 41 return std::nullopt; 42 } 43 44 // Perform all calculations in terms of the maximum available signed type. 45 // The value will be truncated for an appropriate type when constant is created in PairToExpr function 46 return v1 && v2 ? v1->Mul(*v2, PTY_i64) : IntVal(static_cast<uint64>(0), PTY_i64); 47} 48 49// Perform all calculations in terms of the maximum available signed type. 
50// The value will be truncated for an appropriate type when constant is created in PairToExpr function 51std::optional<IntVal> AddSub(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2, bool isAdd) 52{ 53 if (!v1 && !v2) { 54 return std::nullopt; 55 } 56 57 if (v1 && v2) { 58 return isAdd ? v1->Add(*v2, PTY_i64) : v1->Sub(*v2, PTY_i64); 59 } 60 61 if (v1) { 62 return v1->TruncOrExtend(PTY_i64); 63 } 64 65 // !v1 && v2 66 return isAdd ? v2->TruncOrExtend(PTY_i64) : -(v2->TruncOrExtend(PTY_i64)); 67} 68 69std::optional<IntVal> operator+(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2) 70{ 71 return AddSub(v1, v2, true); 72} 73 74std::optional<IntVal> operator-(const std::optional<IntVal> &v1, const std::optional<IntVal> &v2) 75{ 76 return AddSub(v1, v2, false); 77} 78 79} // anonymous namespace 80 81// This phase is designed to achieve compiler optimization by 82// simplifying constant expressions. The constant expression 83// is evaluated and replaced by the value calculated on compile 84// time to save time on runtime. 85// 86// The main procedure shows as following: 87// A. Analyze expression type 88// B. Analysis operator type 89// C. 
Replace the expression with the result of the operation 90 91// true if the constant's bits are made of only one group of contiguous 1's 92// starting at bit 0 93static bool ContiguousBitsOf1(uint64 x) 94{ 95 if (x == 0) { 96 return false; 97 } 98 return (~x & (x + 1)) == (x + 1); 99} 100 101inline bool IsPowerOf2(uint64 num) 102{ 103 if (num == 0) { 104 return false; 105 } 106 return (~(num - 1) & num) == num; 107} 108 109BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, 110 BaseNode *rhs) const 111{ 112 CHECK_NULL_FATAL(old); 113 BinaryNode *result = nullptr; 114 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == lhs && old->Opnd(1) == rhs) { 115 result = old; 116 } else { 117 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(op, primType, lhs, rhs); 118 } 119 return result; 120} 121 122UnaryNode *ConstantFold::NewUnaryNode(UnaryNode *old, Opcode op, PrimType primType, BaseNode *expr) const 123{ 124 CHECK_NULL_FATAL(old); 125 UnaryNode *result = nullptr; 126 if (old->GetOpCode() == op && old->GetPrimType() == primType && old->Opnd(0) == expr) { 127 result = old; 128 } else { 129 result = mirModule->CurFuncCodeMemPool()->New<UnaryNode>(op, primType, expr); 130 } 131 return result; 132} 133 134BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair<BaseNode*, std::optional<IntVal>> &pair) const 135{ 136 CHECK_NULL_FATAL(pair.first); 137 BaseNode *result = pair.first; 138 if (!pair.second || *pair.second == 0 || GetPrimTypeSize(resultType) > k8ByteSize) { 139 return result; 140 } 141 if (pair.first->GetOpCode() == OP_neg && !pair.second->GetSignBit()) { 142 // -a, 5 -> 5 - a 143 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst( 144 static_cast<uint64>(pair.second->GetExtValue()), resultType); 145 BaseNode *r = static_cast<UnaryNode*>(pair.first)->Opnd(0); 146 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, val, r); 147 
} else { 148 if ((!pair.second->GetSignBit() && 149 pair.second->GetSXTValue(static_cast<uint8>(GetPrimTypeBitSize(resultType))) > 0) || 150 pair.second->TruncOrExtend(resultType).IsMinValue() || 151 pair.second->GetSXTValue() == INT64_MIN) { 152 // +-a, 5 -> a + 5 153 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst( 154 static_cast<uint64>(pair.second->GetExtValue()), resultType); 155 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_add, resultType, pair.first, val); 156 } else { 157 // +-a, -5 -> a + -5 158 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst( 159 static_cast<uint64>((-pair.second.value()).GetExtValue()), resultType); 160 result = mirModule->CurFuncCodeMemPool()->New<BinaryNode>(OP_sub, resultType, pair.first, val); 161 } 162 } 163 return result; 164} 165 166std::pair<BaseNode *, std::optional<IntVal>> ConstantFold::FoldBase(BaseNode *node) const 167{ 168 return std::make_pair(node, std::nullopt); 169} 170 171std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::DispatchFold(BaseNode *node) 172{ 173 CHECK_NULL_FATAL(node); 174 if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) { 175 return {node, std::nullopt}; 176 } 177 switch (node->GetOpCode()) { 178 case OP_abs: 179 case OP_bnot: 180 case OP_lnot: 181 case OP_neg: 182 case OP_sqrt: 183 return FoldUnary(static_cast<UnaryNode*>(node)); 184 case OP_ceil: 185 case OP_floor: 186 case OP_trunc: 187 case OP_cvt: 188 return FoldTypeCvt(static_cast<TypeCvtNode*>(node)); 189 case OP_sext: 190 case OP_zext: 191 case OP_extractbits: 192 return FoldExtractbits(static_cast<ExtractbitsNode*>(node)); 193 case OP_iread: 194 return FoldIread(static_cast<IreadNode*>(node)); 195 case OP_add: 196 case OP_ashr: 197 case OP_band: 198 case OP_bior: 199 case OP_bxor: 200 case OP_div: 201 case OP_lshr: 202 case OP_max: 203 case OP_min: 204 case OP_mul: 205 case OP_rem: 206 case OP_shl: 207 case OP_sub: 208 return FoldBinary(static_cast<BinaryNode*>(node)); 209 case OP_eq: 
210 case OP_ne: 211 case OP_ge: 212 case OP_gt: 213 case OP_le: 214 case OP_lt: 215 case OP_cmp: 216 return FoldCompare(static_cast<CompareNode*>(node)); 217 case OP_retype: 218 return FoldRetype(static_cast<RetypeNode*>(node)); 219 default: 220 return FoldBase(static_cast<BaseNode*>(node)); 221 } 222} 223 224BaseNode *ConstantFold::Negate(BaseNode *node) const 225{ 226 CHECK_NULL_FATAL(node); 227 return mirModule->CurFuncCodeMemPool()->New<UnaryNode>(OP_neg, PrimType(node->GetPrimType()), node); 228} 229 230BaseNode *ConstantFold::Negate(UnaryNode *node) const 231{ 232 CHECK_NULL_FATAL(node); 233 BaseNode *result = nullptr; 234 if (node->GetOpCode() == OP_neg) { 235 result = static_cast<BaseNode*>(node->Opnd(0)); 236 } else { 237 BaseNode *n = static_cast<BaseNode*>(node); 238 result = NewUnaryNode(node, OP_neg, node->GetPrimType(), n); 239 } 240 return result; 241} 242 243BaseNode *ConstantFold::Negate(const ConstvalNode *node) const 244{ 245 CHECK_NULL_FATAL(node); 246 ConstvalNode *copy = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator()); 247 CHECK_NULL_FATAL(copy); 248 copy->GetConstVal()->Neg(); 249 return copy; 250} 251 252BaseNode *ConstantFold::NegateTree(BaseNode *node) const 253{ 254 CHECK_NULL_FATAL(node); 255 if (node->IsUnaryNode()) { 256 return Negate(static_cast<UnaryNode*>(node)); 257 } else if (node->GetOpCode() == OP_constval) { 258 return Negate(static_cast<ConstvalNode*>(node)); 259 } else { 260 return Negate(static_cast<BaseNode*>(node)); 261 } 262} 263 264MIRIntConst *ConstantFold::FoldIntConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, 265 const MIRIntConst &intConst0, 266 const MIRIntConst &intConst1) const 267{ 268 uint64 result = 0; 269 270 bool greater = intConst0.GetValue().Greater(intConst1.GetValue(), opndType); 271 bool equal = intConst0.GetValue().Equal(intConst1.GetValue(), opndType); 272 bool less = intConst0.GetValue().Less(intConst1.GetValue(), opndType); 273 274 switch (opcode) { 275 case 
OP_eq: { 276 result = equal; 277 break; 278 } 279 case OP_ge: { 280 result = (greater || equal) ? 1 : 0; 281 break; 282 } 283 case OP_gt: { 284 result = greater; 285 break; 286 } 287 case OP_le: { 288 result = (less || equal) ? 1 : 0; 289 break; 290 } 291 case OP_lt: { 292 result = less; 293 break; 294 } 295 case OP_ne: { 296 result = !equal; 297 break; 298 } 299 case OP_cmp: { 300 if (greater) { 301 result = kGreater; 302 } else if (equal) { 303 result = kEqual; 304 } else if (less) { 305 result = static_cast<uint64>(kLess); 306 } 307 break; 308 } 309 default: 310 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstComparison"); 311 break; 312 } 313 // determine the type 314 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType); 315 // form the constant 316 MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type); 317 return constValue; 318} 319 320ConstvalNode *ConstantFold::FoldIntConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, 321 const ConstvalNode &const0, const ConstvalNode &const1) const 322{ 323 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal()); 324 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal()); 325 CHECK_NULL_FATAL(intConst0); 326 CHECK_NULL_FATAL(intConst1); 327 MIRIntConst *constValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1); 328 // form the ConstvalNode 329 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 330 resultConst->SetPrimType(resultType); 331 resultConst->SetConstVal(constValue); 332 return resultConst; 333} 334 335MIRConst *ConstantFold::FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst &intConst0, 336 const MIRIntConst &intConst1) 337{ 338 IntVal intVal0 = intConst0.GetValue(); 339 IntVal intVal1 = intConst1.GetValue(); 340 IntVal result(static_cast<uint64>(0), resultType); 341 342 switch (opcode) { 343 
case OP_add: { 344 result = intVal0.Add(intVal1, resultType); 345 break; 346 } 347 case OP_sub: { 348 result = intVal0.Sub(intVal1, resultType); 349 break; 350 } 351 case OP_mul: { 352 result = intVal0.Mul(intVal1, resultType); 353 break; 354 } 355 case OP_div: { 356 result = intVal0.Div(intVal1, resultType); 357 break; 358 } 359 case OP_rem: { 360 result = intVal0.Rem(intVal1, resultType); 361 break; 362 } 363 case OP_ashr: { 364 result = intVal0.AShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType); 365 break; 366 } 367 case OP_lshr: { 368 result = intVal0.LShr(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType); 369 break; 370 } 371 case OP_shl: { 372 result = intVal0.Shl(intVal1.GetZXTValue() % GetAlignedPrimTypeBitSize(resultType), resultType); 373 break; 374 } 375 case OP_max: { 376 result = Max(intVal0, intVal1, resultType); 377 break; 378 } 379 case OP_min: { 380 result = Min(intVal0, intVal1, resultType); 381 break; 382 } 383 case OP_band: { 384 result = intVal0.And(intVal1, resultType); 385 break; 386 } 387 case OP_bior: { 388 result = intVal0.Or(intVal1, resultType); 389 break; 390 } 391 case OP_bxor: { 392 result = intVal0.Xor(intVal1, resultType); 393 break; 394 } 395 default: 396 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstBinary"); 397 break; 398 } 399 // determine the type 400 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType); 401 // form the constant 402 MIRIntConst *constValue = 403 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type); 404 return constValue; 405} 406 407ConstvalNode *ConstantFold::FoldIntConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, 408 const ConstvalNode &const1) const 409{ 410 const MIRIntConst *intConst0 = safe_cast<MIRIntConst>(const0.GetConstVal()); 411 const MIRIntConst *intConst1 = safe_cast<MIRIntConst>(const1.GetConstVal()); 412 CHECK_NULL_FATAL(intConst0); 413 
CHECK_NULL_FATAL(intConst1); 414 MIRConst *constValue = FoldIntConstBinaryMIRConst(opcode, resultType, *intConst0, *intConst1); 415 // form the ConstvalNode 416 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 417 resultConst->SetPrimType(resultType); 418 resultConst->SetConstVal(constValue); 419 return resultConst; 420} 421 422ConstvalNode *ConstantFold::FoldFPConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, 423 const ConstvalNode &const1) const 424{ 425 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match"); 426 const MIRDoubleConst *doubleConst0 = nullptr; 427 const MIRDoubleConst *doubleConst1 = nullptr; 428 const MIRFloatConst *floatConst0 = nullptr; 429 const MIRFloatConst *floatConst1 = nullptr; 430 bool useDouble = (const0.GetPrimType() == PTY_f64); 431 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 432 resultConst->SetPrimType(resultType); 433 if (useDouble) { 434 doubleConst0 = safe_cast<MIRDoubleConst>(const0.GetConstVal()); 435 doubleConst1 = safe_cast<MIRDoubleConst>(const1.GetConstVal()); 436 CHECK_NULL_FATAL(doubleConst0); 437 CHECK_NULL_FATAL(doubleConst1); 438 } else { 439 floatConst0 = safe_cast<MIRFloatConst>(const0.GetConstVal()); 440 floatConst1 = safe_cast<MIRFloatConst>(const1.GetConstVal()); 441 CHECK_NULL_FATAL(floatConst0); 442 CHECK_NULL_FATAL(floatConst1); 443 } 444 float constValueFloat = 0.0; 445 double constValueDouble = 0.0; 446 switch (opcode) { 447 case OP_add: { 448 if (useDouble) { 449 constValueDouble = doubleConst0->GetValue() + doubleConst1->GetValue(); 450 } else { 451 constValueFloat = floatConst0->GetValue() + floatConst1->GetValue(); 452 } 453 break; 454 } 455 case OP_sub: { 456 if (useDouble) { 457 constValueDouble = doubleConst0->GetValue() - doubleConst1->GetValue(); 458 } else { 459 constValueFloat = floatConst0->GetValue() - floatConst1->GetValue(); 460 } 461 break; 462 } 463 case 
OP_mul: { 464 if (useDouble) { 465 constValueDouble = doubleConst0->GetValue() * doubleConst1->GetValue(); 466 } else { 467 constValueFloat = floatConst0->GetValue() * floatConst1->GetValue(); 468 } 469 break; 470 } 471 case OP_div: { 472 // for floats div by 0 is well defined 473 if (useDouble) { 474 constValueDouble = doubleConst0->GetValue() / doubleConst1->GetValue(); 475 } else { 476 constValueFloat = floatConst0->GetValue() / floatConst1->GetValue(); 477 } 478 break; 479 } 480 case OP_max: { 481 if (useDouble) { 482 constValueDouble = (doubleConst0->GetValue() >= doubleConst1->GetValue()) ? doubleConst0->GetValue() 483 : doubleConst1->GetValue(); 484 } else { 485 constValueFloat = (floatConst0->GetValue() >= floatConst1->GetValue()) ? floatConst0->GetValue() 486 : floatConst1->GetValue(); 487 } 488 break; 489 } 490 case OP_min: { 491 if (useDouble) { 492 constValueDouble = (doubleConst0->GetValue() <= doubleConst1->GetValue()) ? doubleConst0->GetValue() 493 : doubleConst1->GetValue(); 494 } else { 495 constValueFloat = (floatConst0->GetValue() <= floatConst1->GetValue()) ? 
floatConst0->GetValue() 496 : floatConst1->GetValue(); 497 } 498 break; 499 } 500 case OP_rem: 501 case OP_ashr: 502 case OP_lshr: 503 case OP_shl: 504 case OP_band: 505 case OP_bior: 506 case OP_bxor: { 507 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstBinary"); 508 break; 509 } 510 default: 511 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstBinary"); 512 break; 513 } 514 if (resultType == PTY_f64) { 515 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValueDouble)); 516 } else { 517 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(constValueFloat)); 518 } 519 return resultConst; 520} 521 522bool ConstantFold::ConstValueEqual(int64 leftValue, int64 rightValue) const 523{ 524 return (leftValue == rightValue); 525} 526 527bool ConstantFold::ConstValueEqual(float leftValue, float rightValue) const 528{ 529 auto result = fabs(leftValue - rightValue); 530 return leftValue <= FLT_MIN && rightValue <= FLT_MIN ? result < FLT_MIN : result <= FLT_MIN; 531} 532 533bool ConstantFold::ConstValueEqual(double leftValue, double rightValue) const 534{ 535 auto result = fabs(leftValue - rightValue); 536 return leftValue <= DBL_MIN && rightValue <= DBL_MIN ? 
result < DBL_MIN : result <= DBL_MIN; 537} 538 539template<typename T> 540bool ConstantFold::FullyEqual(T leftValue, T rightValue) const 541{ 542 if (std::isinf(leftValue) && std::isinf(rightValue)) { 543 // (inf == inf), add the judgement here in case of the subtraction between float type inf 544 return true; 545 } else { 546 return ConstValueEqual(leftValue, rightValue); 547 } 548} 549 550template<typename T> 551int64 ConstantFold::ComparisonResult(Opcode op, T *leftConst, T *rightConst) const 552{ 553 DEBUG_ASSERT(leftConst != nullptr, "leftConst should not be nullptr"); 554 typename T::value_type leftValue = leftConst->GetValue(); 555 DEBUG_ASSERT(rightConst != nullptr, "rightConst should not be nullptr"); 556 typename T::value_type rightValue = rightConst->GetValue(); 557 int64 result = 0; 558 switch (op) { 559 case OP_eq: { 560 result = FullyEqual(leftValue, rightValue); 561 break; 562 } 563 case OP_ge: { 564 result = (leftValue > rightValue) || FullyEqual(leftValue, rightValue); 565 break; 566 } 567 case OP_gt: { 568 result = (leftValue > rightValue); 569 break; 570 } 571 case OP_le: { 572 result = (leftValue < rightValue) || FullyEqual(leftValue, rightValue); 573 break; 574 } 575 case OP_lt: { 576 result = (leftValue < rightValue); 577 break; 578 } 579 case OP_ne: { 580 result = !FullyEqual(leftValue, rightValue); 581 break; 582 } 583 [[clang::fallthrough]]; 584 case OP_cmp: { 585 if (leftValue > rightValue) { 586 result = kGreater; 587 } else if (FullyEqual(leftValue, rightValue)) { 588 result = kEqual; 589 } else { 590 result = kLess; 591 } 592 break; 593 } 594 default: 595 DEBUG_ASSERT(false, "Unknown opcode for Comparison"); 596 break; 597 } 598 return result; 599} 600 601MIRIntConst *ConstantFold::FoldFPConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, 602 const MIRConst &leftConst, const MIRConst &rightConst) const 603{ 604 int64 result = 0; 605 bool useDouble = (opndType == PTY_f64); 606 if (useDouble) { 607 result = 608 
ComparisonResult(opcode, safe_cast<MIRDoubleConst>(&leftConst), safe_cast<MIRDoubleConst>(&rightConst)); 609 } else { 610 result = ComparisonResult(opcode, safe_cast<MIRFloatConst>(&leftConst), safe_cast<MIRFloatConst>(&rightConst)); 611 } 612 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType); 613 MIRIntConst *resultConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result), type); 614 return resultConst; 615} 616 617ConstvalNode *ConstantFold::FoldFPConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, 618 const ConstvalNode &const0, const ConstvalNode &const1) const 619{ 620 DEBUG_ASSERT(const0.GetPrimType() == const1.GetPrimType(), "The types of the operands must match"); 621 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 622 resultConst->SetPrimType(resultType); 623 resultConst->SetConstVal( 624 FoldFPConstComparisonMIRConst(opcode, resultType, opndType, *const0.GetConstVal(), *const1.GetConstVal())); 625 return resultConst; 626} 627 628MIRConst *ConstantFold::FoldConstComparisonMIRConst(Opcode opcode, PrimType resultType, PrimType opndType, 629 const MIRConst &const0, const MIRConst &const1) const 630{ 631 MIRConst *returnValue = nullptr; 632 if (IsPrimitiveInteger(opndType)) { 633 const auto *intConst0 = safe_cast<MIRIntConst>(&const0); 634 const auto *intConst1 = safe_cast<MIRIntConst>(&const1); 635 ASSERT_NOT_NULL(intConst0); 636 ASSERT_NOT_NULL(intConst1); 637 returnValue = FoldIntConstComparisonMIRConst(opcode, resultType, opndType, *intConst0, *intConst1); 638 } else if (opndType == PTY_f32 || opndType == PTY_f64) { 639 returnValue = FoldFPConstComparisonMIRConst(opcode, resultType, opndType, const0, const1); 640 } else { 641 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparisonMIRConst"); 642 } 643 return returnValue; 644} 645 646ConstvalNode *ConstantFold::FoldConstComparison(Opcode opcode, PrimType resultType, PrimType opndType, 647 const 
ConstvalNode &const0, const ConstvalNode &const1) const 648{ 649 ConstvalNode *returnValue = nullptr; 650 if (IsPrimitiveInteger(opndType)) { 651 returnValue = FoldIntConstComparison(opcode, resultType, opndType, const0, const1); 652 } else if (opndType == PTY_f32 || opndType == PTY_f64) { 653 returnValue = FoldFPConstComparison(opcode, resultType, opndType, const0, const1); 654 } else { 655 DEBUG_ASSERT(false, "Unhandled case for FoldConstComparison"); 656 } 657 return returnValue; 658} 659 660CompareNode *ConstantFold::FoldConstComparisonReverse(Opcode opcode, PrimType resultType, PrimType opndType, 661 BaseNode &l, BaseNode &r) const 662{ 663 CompareNode *result = nullptr; 664 Opcode op = opcode; 665 switch (opcode) { 666 case OP_gt: { 667 op = OP_lt; 668 break; 669 } 670 case OP_lt: { 671 op = OP_gt; 672 break; 673 } 674 case OP_ge: { 675 op = OP_le; 676 break; 677 } 678 case OP_le: { 679 op = OP_ge; 680 break; 681 } 682 case OP_eq: { 683 break; 684 } 685 case OP_ne: { 686 break; 687 } 688 default: 689 DEBUG_ASSERT(false, "Unknown opcode for FoldConstComparisonReverse"); 690 break; 691 } 692 693 result = 694 mirModule->CurFuncCodeMemPool()->New<CompareNode>(Opcode(op), PrimType(resultType), PrimType(opndType), &r, &l); 695 return result; 696} 697 698ConstvalNode *ConstantFold::FoldConstBinary(Opcode opcode, PrimType resultType, const ConstvalNode &const0, 699 const ConstvalNode &const1) const 700{ 701 ConstvalNode *returnValue = nullptr; 702 if (IsPrimitiveInteger(resultType)) { 703 returnValue = FoldIntConstBinary(opcode, resultType, const0, const1); 704 } else if (resultType == PTY_f32 || resultType == PTY_f64) { 705 returnValue = FoldFPConstBinary(opcode, resultType, const0, const1); 706 } else { 707 DEBUG_ASSERT(false, "Unhandled case for FoldConstBinary"); 708 } 709 return returnValue; 710} 711 712MIRIntConst *ConstantFold::FoldIntConstUnaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *constNode) 713{ 714 CHECK_NULL_FATAL(constNode); 715 
IntVal result = constNode->GetValue().TruncOrExtend(resultType); 716 switch (opcode) { 717 case OP_abs: { 718 if (IsSignedInteger(constNode->GetType().GetPrimType()) && result.GetSignBit()) { 719 result = -result; 720 } 721 break; 722 } 723 case OP_bnot: { 724 result = ~result; 725 break; 726 } 727 case OP_lnot: { 728 uint64 resultInt = result == 0 ? 1 : 0; 729 result = {resultInt, resultType}; 730 break; 731 } 732 case OP_neg: { 733 result = -result; 734 break; 735 } 736 case OP_sext: // handled in FoldExtractbits 737 case OP_zext: // handled in FoldExtractbits 738 case OP_extractbits: // handled in FoldExtractbits 739 case OP_sqrt: { 740 DEBUG_ASSERT(false, "Unexpected opcode in FoldIntConstUnaryMIRConst"); 741 break; 742 } 743 default: 744 DEBUG_ASSERT(false, "Unknown opcode for FoldIntConstUnaryMIRConst"); 745 break; 746 } 747 // determine the type 748 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType); 749 // form the constant 750 MIRIntConst *constValue = 751 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(result.GetExtValue()), type); 752 return constValue; 753} 754 755template <typename T> 756ConstvalNode *ConstantFold::FoldFPConstUnary(Opcode opcode, PrimType resultType, ConstvalNode *constNode) const 757{ 758 CHECK_NULL_FATAL(constNode); 759 double constValue = 0; 760 T *fpCst = static_cast<T*>(constNode->GetConstVal()); 761 switch (opcode) { 762 case OP_neg: { 763 constValue = typename T::value_type(-fpCst->GetValue()); 764 break; 765 } 766 case OP_abs: { 767 constValue = typename T::value_type(fabs(fpCst->GetValue())); 768 break; 769 } 770 case OP_sqrt: { 771 constValue = typename T::value_type(sqrt(fpCst->GetValue())); 772 break; 773 } 774 case OP_bnot: 775 case OP_lnot: 776 case OP_sext: 777 case OP_zext: 778 case OP_extractbits: { 779 DEBUG_ASSERT(false, "Unexpected opcode in FoldFPConstUnary"); 780 break; 781 } 782 default: 783 DEBUG_ASSERT(false, "Unknown opcode for FoldFPConstUnary"); 784 break; 785 } 
786 auto *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 787 resultConst->SetPrimType(resultType); 788 if (resultType == PTY_f32) { 789 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(static_cast<float>(constValue))); 790 } else if (resultType == PTY_f64) { 791 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(constValue)); 792 } else { 793 CHECK_FATAL(false, "PrimType for MIRFloatConst / MIRDoubleConst should be PTY_f32 / PTY_f64"); 794 } 795 return resultConst; 796} 797 798ConstvalNode *ConstantFold::FoldConstUnary(Opcode opcode, PrimType resultType, ConstvalNode &constNode) const 799{ 800 ConstvalNode *returnValue = nullptr; 801 if (IsPrimitiveInteger(resultType)) { 802 const MIRIntConst *cst = safe_cast<MIRIntConst>(constNode.GetConstVal()); 803 auto constValue = FoldIntConstUnaryMIRConst(opcode, resultType, cst); 804 returnValue = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 805 returnValue->SetPrimType(resultType); 806 returnValue->SetConstVal(constValue); 807 } else if (resultType == PTY_f32) { 808 returnValue = FoldFPConstUnary<MIRFloatConst>(opcode, resultType, &constNode); 809 } else if (resultType == PTY_f64) { 810 returnValue = FoldFPConstUnary<MIRDoubleConst>(opcode, resultType, &constNode); 811 } else { 812 DEBUG_ASSERT(false, "Unhandled case for FoldConstUnary"); 813 } 814 return returnValue; 815} 816 817std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldRetype(RetypeNode *node) 818{ 819 CHECK_NULL_FATAL(node); 820 BaseNode *result = node; 821 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0)); 822 if (node->Opnd(0) != p.first) { 823 RetypeNode *newRetNode = node->CloneTree(mirModule->GetCurFuncCodeMPAllocator()); 824 CHECK_FATAL(newRetNode != nullptr, "newRetNode is null in ConstantFold::FoldRetype"); 825 newRetNode->SetOpnd(PairToExpr(node->Opnd(0)->GetPrimType(), p), 0); 826 result = newRetNode; 827 } 828 return 
std::make_pair(result, std::nullopt); 829} 830 831std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldUnary(UnaryNode *node) 832{ 833 CHECK_NULL_FATAL(node); 834 BaseNode *result = nullptr; 835 std::optional<IntVal> sum = std::nullopt; 836 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0)); 837 ConstvalNode *cst = safe_cast<ConstvalNode>(p.first); 838 if (cst != nullptr) { 839 result = FoldConstUnary(node->GetOpCode(), node->GetPrimType(), *cst); 840 } else { 841 bool isInt = IsPrimitiveInteger(node->GetPrimType()); 842 // The neg node will be recreated regardless of whether the folding is successful or not. And the neg node's 843 // primType will be set to opnd type. There will be problems in some cases. For example: 844 // before cf: 845 // neg i32 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) 846 // after cf: 847 // neg u1 (eq u1 f32 (dread f32 %f_4_2, constval f32 0f)) # wrong! 848 // As a workaround, we exclude u1 opnd type 849 if (isInt && node->GetOpCode() == OP_neg && p.first->GetPrimType() != PTY_u1) { 850 result = NegateTree(p.first); 851 if (result->GetOpCode() == OP_neg) { 852 PrimType origPtyp = node->GetPrimType(); 853 PrimType newPtyp = result->GetPrimType(); 854 if (newPtyp == origPtyp) { 855 if (static_cast<UnaryNode*>(result)->Opnd(0) == node->Opnd(0)) { 856 // NegateTree returned an UnaryNode quivalent to `n`, so keep the 857 // original UnaryNode to preserve identity 858 result = node; 859 } 860 } else { 861 if (GetPrimTypeSize(newPtyp) != GetPrimTypeSize(origPtyp)) { 862 // do not fold explicit cvt 863 result = NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), 864 PairToExpr(node->Opnd(0)->GetPrimType(), p)); 865 return std::make_pair(result, std::nullopt); 866 } else { 867 result->SetPrimType(origPtyp); 868 } 869 } 870 } 871 if (p.second) { 872 sum = -(*p.second); 873 } 874 } else { 875 result = 876 NewUnaryNode(node, node->GetOpCode(), node->GetPrimType(), PairToExpr(node->Opnd(0)->GetPrimType(), 
p)); 877 } 878 } 879 return std::make_pair(result, sum); 880} 881 882static bool FloatToIntOverflow(float fval, PrimType totype) 883{ 884 static const float safeFloatMaxToInt32 = 2147483520.0f; // 2^31 - 128 885 static const float safeFloatMinToInt32 = -2147483520.0f; 886 static const float safeFloatMaxToInt64 = 9223372036854775680.0f; // 2^63 - 128 887 static const float safeFloatMinToInt64 = -9223372036854775680.0f; 888 if (!std::isfinite(fval)) { 889 return true; 890 } 891 if (totype == PTY_i64 || totype == PTY_u64) { 892 if (fval < safeFloatMinToInt64 || fval > safeFloatMaxToInt64) { 893 return true; 894 } 895 } else { 896 if (fval < safeFloatMinToInt32 || fval > safeFloatMaxToInt32) { 897 return true; 898 } 899 } 900 return false; 901} 902 903static bool DoubleToIntOverflow(double dval, PrimType totype) 904{ 905 static const double safeDoubleMaxToInt32 = 2147482624.0; // 2^31 - 1024 906 static const double safeDoubleMinToInt32 = -2147482624.0; 907 static const double safeDoubleMaxToInt64 = 9223372036854774784.0; // 2^63 - 1024 908 static const double safeDoubleMinToInt64 = -9223372036854774784.0; 909 if (!std::isfinite(dval)) { 910 return true; 911 } 912 if (totype == PTY_i64 || totype == PTY_u64) { 913 if (dval < safeDoubleMinToInt64 || dval > safeDoubleMaxToInt64) { 914 return true; 915 } 916 } else { 917 if (dval < safeDoubleMinToInt32 || dval > safeDoubleMaxToInt32) { 918 return true; 919 } 920 } 921 return false; 922} 923 924ConstvalNode *ConstantFold::FoldCeil(const ConstvalNode &cst, PrimType fromType, PrimType toType) const 925{ 926 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 927 resultConst->SetPrimType(toType); 928 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType); 929 if (fromType == PTY_f32) { 930 const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal()); 931 ASSERT_NOT_NULL(constValue); 932 float floatValue = ceil(constValue->GetValue()); 933 if 
(IsPrimitiveFloat(toType)) { 934 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue)); 935 } else if (FloatToIntOverflow(floatValue, toType)) { 936 return nullptr; 937 } else { 938 resultConst->SetConstVal( 939 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType)); 940 } 941 } else { 942 const MIRDoubleConst *constValue = safe_cast<MIRDoubleConst>(cst.GetConstVal()); 943 ASSERT_NOT_NULL(constValue); 944 double doubleValue = ceil(constValue->GetValue()); 945 if (IsPrimitiveFloat(toType)) { 946 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue)); 947 } else if (DoubleToIntOverflow(doubleValue, toType)) { 948 return nullptr; 949 } else { 950 resultConst->SetConstVal( 951 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType)); 952 } 953 } 954 return resultConst; 955} 956 957template <class T> 958T ConstantFold::CalIntValueFromFloatValue(T value, const MIRType &resultType) const 959{ 960 DEBUG_ASSERT(kByteSizeOfBit64 >= resultType.GetSize(), "unsupported type"); 961 size_t shiftNum = (kByteSizeOfBit64 - resultType.GetSize()) * kBitSizePerByte; 962 bool isSigned = IsSignedInteger(resultType.GetPrimType()); 963 int64 max = (IntVal(std::numeric_limits<int64>::max(), PTY_i64) >> shiftNum).GetExtValue(); 964 uint64 umax = std::numeric_limits<uint64>::max() >> shiftNum; 965 int64 min = isSigned ? 
(IntVal(std::numeric_limits<int64>::min(), PTY_i64) >> shiftNum).GetExtValue() : 0; 966 if (isSigned && (value > max)) { 967 return static_cast<T>(max); 968 } else if (!isSigned && (value > umax)) { 969 return static_cast<T>(umax); 970 } else if (value < min) { 971 return static_cast<T>(min); 972 } 973 return value; 974} 975 976MIRConst *ConstantFold::FoldFloorMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType, bool isFloor) const 977{ 978 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType); 979 if (fromType == PTY_f32) { 980 const auto &constValue = static_cast<const MIRFloatConst&>(cst); 981 float floatValue = constValue.GetValue(); 982 if (isFloor) { 983 floatValue = floor(constValue.GetValue()); 984 } 985 if (IsPrimitiveFloat(toType)) { 986 return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue); 987 } 988 if (FloatToIntOverflow(floatValue, toType)) { 989 return nullptr; 990 } 991 floatValue = CalIntValueFromFloatValue(floatValue, resultType); 992 return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType); 993 } else { 994 const auto &constValue = static_cast<const MIRDoubleConst&>(cst); 995 double doubleValue = constValue.GetValue(); 996 if (isFloor) { 997 doubleValue = floor(constValue.GetValue()); 998 } 999 if (IsPrimitiveFloat(toType)) { 1000 return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue); 1001 } 1002 if (DoubleToIntOverflow(doubleValue, toType)) { 1003 return nullptr; 1004 } 1005 doubleValue = CalIntValueFromFloatValue(doubleValue, resultType); 1006 // gcc/clang have bugs convert double to unsigned long, must convert to signed long first; 1007 return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(doubleValue), resultType); 1008 } 1009} 1010 1011ConstvalNode *ConstantFold::FoldFloor(const ConstvalNode &cst, PrimType fromType, PrimType toType) const 1012{ 1013 ConstvalNode *resultConst = 
mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 1014 resultConst->SetPrimType(toType); 1015 resultConst->SetConstVal(FoldFloorMIRConst(*cst.GetConstVal(), fromType, toType)); 1016 return resultConst; 1017} 1018 1019MIRConst *ConstantFold::FoldRoundMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const 1020{ 1021 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType); 1022 if (fromType == PTY_f32) { 1023 const auto &constValue = static_cast<const MIRFloatConst&>(cst); 1024 float floatValue = round(constValue.GetValue()); 1025 if (FloatToIntOverflow(floatValue, toType)) { 1026 return nullptr; 1027 } 1028 return GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<int64>(floatValue), resultType); 1029 } else if (fromType == PTY_f64) { 1030 const auto &constValue = static_cast<const MIRDoubleConst&>(cst); 1031 double doubleValue = round(constValue.GetValue()); 1032 if (DoubleToIntOverflow(doubleValue, toType)) { 1033 return nullptr; 1034 } 1035 return GlobalTables::GetIntConstTable().GetOrCreateIntConst( 1036 static_cast<uint64>(static_cast<int64>(doubleValue)), resultType); 1037 } else if (toType == PTY_f32 && IsPrimitiveInteger(fromType)) { 1038 const auto &constValue = static_cast<const MIRIntConst&>(cst); 1039 if (IsSignedInteger(fromType)) { 1040 int64 fromValue = constValue.GetExtValue(); 1041 float floatValue = round(static_cast<float>(fromValue)); 1042 if (static_cast<int64>(floatValue) == fromValue) { 1043 return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue); 1044 } 1045 } else { 1046 uint64 fromValue = static_cast<uint64>(constValue.GetExtValue()); 1047 float floatValue = round(static_cast<float>(fromValue)); 1048 if (static_cast<uint64>(floatValue) == fromValue) { 1049 return GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue); 1050 } 1051 } 1052 } else if (toType == PTY_f64 && IsPrimitiveInteger(fromType)) { 1053 const auto &constValue = static_cast<const 
MIRIntConst&>(cst); 1054 if (IsSignedInteger(fromType)) { 1055 int64 fromValue = constValue.GetExtValue(); 1056 double doubleValue = round(static_cast<double>(fromValue)); 1057 if (static_cast<int64>(doubleValue) == fromValue) { 1058 return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue); 1059 } 1060 } else { 1061 uint64 fromValue = static_cast<uint64>(constValue.GetExtValue()); 1062 double doubleValue = round(static_cast<double>(fromValue)); 1063 if (static_cast<uint64>(doubleValue) == fromValue) { 1064 return GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue); 1065 } 1066 } 1067 } 1068 return nullptr; 1069} 1070 1071ConstvalNode *ConstantFold::FoldRound(const ConstvalNode &cst, PrimType fromType, PrimType toType) const 1072{ 1073 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 1074 resultConst->SetPrimType(toType); 1075 resultConst->SetConstVal(FoldRoundMIRConst(*cst.GetConstVal(), fromType, toType)); 1076 return resultConst; 1077} 1078 1079ConstvalNode *ConstantFold::FoldTrunc(const ConstvalNode &cst, PrimType fromType, PrimType toType) const 1080{ 1081 ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 1082 resultConst->SetPrimType(toType); 1083 MIRType &resultType = *GlobalTables::GetTypeTable().GetPrimType(toType); 1084 if (fromType == PTY_f32) { 1085 const MIRFloatConst *constValue = safe_cast<MIRFloatConst>(cst.GetConstVal()); 1086 CHECK_NULL_FATAL(constValue); 1087 float floatValue = trunc(constValue->GetValue()); 1088 if (IsPrimitiveFloat(toType)) { 1089 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue)); 1090 } else if (FloatToIntOverflow(floatValue, toType)) { 1091 return nullptr; 1092 } else { 1093 resultConst->SetConstVal( 1094 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(floatValue), resultType)); 1095 } 1096 } else { 1097 const MIRDoubleConst *constValue = 
safe_cast<MIRDoubleConst>(cst.GetConstVal()); 1098 CHECK_NULL_FATAL(constValue); 1099 double doubleValue = trunc(constValue->GetValue()); 1100 if (IsPrimitiveFloat(toType)) { 1101 resultConst->SetConstVal(GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue)); 1102 } else if (DoubleToIntOverflow(doubleValue, toType)) { 1103 return nullptr; 1104 } else { 1105 resultConst->SetConstVal( 1106 GlobalTables::GetIntConstTable().GetOrCreateIntConst(static_cast<uint64>(doubleValue), resultType)); 1107 } 1108 } 1109 return resultConst; 1110} 1111 1112MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromType, PrimType toType) const 1113{ 1114 if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) { 1115 MIRConst *toConst = nullptr; 1116 uint32 fromSize = GetPrimTypeBitSize(fromType); 1117 uint32 toSize = GetPrimTypeBitSize(toType); 1118 // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here. 1119 if (fromType == PTY_u1) { 1120 fromSize = 1; 1121 } 1122 if (toType == PTY_u1) { 1123 toSize = 1; 1124 } 1125 if (toSize > fromSize) { 1126 Opcode op = OP_zext; 1127 if (IsSignedInteger(fromType)) { 1128 op = OP_sext; 1129 } 1130 const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst); 1131 ASSERT_NOT_NULL(constVal); 1132 toConst = FoldSignExtendMIRConst(op, toType, static_cast<uint8>(fromSize), 1133 constVal->GetValue().TruncOrExtend(fromType)); 1134 } else { 1135 const MIRIntConst *constVal = safe_cast<MIRIntConst>(cst); 1136 ASSERT_NOT_NULL(constVal); 1137 MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType); 1138 toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst( 1139 static_cast<uint64>(constVal->GetExtValue()), type); 1140 } 1141 return toConst; 1142 } 1143 if (IsPrimitiveFloat(fromType) && IsPrimitiveFloat(toType)) { 1144 MIRConst *toConst = nullptr; 1145 if (GetPrimTypeBitSize(toType) < GetPrimTypeBitSize(fromType)) { 1146 DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 32, "We suppot F32 
and F64"); // just support 32 or 64 1147 const MIRDoubleConst *fromValue = safe_cast<MIRDoubleConst>(cst); 1148 ASSERT_NOT_NULL(fromValue); 1149 float floatValue = static_cast<float>(fromValue->GetValue()); 1150 MIRFloatConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateFloatConst(floatValue); 1151 toConst = toValue; 1152 } else { 1153 DEBUG_ASSERT(GetPrimTypeBitSize(toType) == 64, "We suppot F32 and F64"); // just support 32 or 64 1154 const MIRFloatConst *fromValue = safe_cast<MIRFloatConst>(cst); 1155 ASSERT_NOT_NULL(fromValue); 1156 double doubleValue = static_cast<double>(fromValue->GetValue()); 1157 MIRDoubleConst *toValue = GlobalTables::GetFpConstTable().GetOrCreateDoubleConst(doubleValue); 1158 toConst = toValue; 1159 } 1160 return toConst; 1161 } 1162 if (IsPrimitiveFloat(fromType) && IsPrimitiveInteger(toType)) { 1163 return FoldFloorMIRConst(cst, fromType, toType, false); 1164 } 1165 if (IsPrimitiveInteger(fromType) && IsPrimitiveFloat(toType)) { 1166 return FoldRoundMIRConst(cst, fromType, toType); 1167 } 1168 CHECK_FATAL(false, "Unexpected case in ConstFoldTypeCvt"); 1169 return nullptr; 1170} 1171 1172ConstvalNode *ConstantFold::FoldTypeCvt(const ConstvalNode &cst, PrimType fromType, PrimType toType) const 1173{ 1174 MIRConst *toConstValue = FoldTypeCvtMIRConst(*cst.GetConstVal(), fromType, toType); 1175 if (toConstValue == nullptr) { 1176 return nullptr; 1177 } 1178 ConstvalNode *toConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>(); 1179 toConst->SetPrimType(toConstValue->GetType().GetPrimType()); 1180 toConst->SetConstVal(toConstValue); 1181 return toConst; 1182} 1183 1184// return a primType with bit size >= bitSize (and the nearest one), 1185// and its signed/float type is the same as ptyp 1186PrimType GetNearestSizePtyp(uint8 bitSize, PrimType ptyp) 1187{ 1188 bool isSigned = IsSignedInteger(ptyp); 1189 bool isFloat = IsPrimitiveFloat(ptyp); 1190 if (bitSize == 1) { // 1 bit 1191 return PTY_u1; 1192 } 1193 if (bitSize <= 8) { 
// 8 bit
        return isSigned ? PTY_i8 : PTY_u8;
    }
    if (bitSize <= 16) { // 16 bit
        return isSigned ? PTY_i16 : PTY_u16;
    }
    if (bitSize <= 32) { // 32 bit
        return isFloat ? PTY_f32 : (isSigned ? PTY_i32 : PTY_u32);
    }
    if (bitSize <= 64) { // 64 bit
        return isFloat ? PTY_f64 : (isSigned ? PTY_i64 : PTY_u64);
    }
    return ptyp;
}

// Returns the maximum value representable by the integer type `ptyp` (1 for u1).
// Any other type hits CHECK_FATAL(false, "NYI").
size_t GetIntPrimTypeMax(PrimType ptyp)
{
    switch (ptyp) {
        case PTY_u1:
            return 1;
        case PTY_u8:
            return UINT8_MAX;
        case PTY_i8:
            return INT8_MAX;
        case PTY_u16:
            return UINT16_MAX;
        case PTY_i16:
            return INT16_MAX;
        case PTY_u32:
            return UINT32_MAX;
        case PTY_i32:
            return INT32_MAX;
        case PTY_u64:
            return UINT64_MAX;
        case PTY_i64:
            return INT64_MAX;
        default:
            CHECK_FATAL(false, "NYI");
    }
}

// Returns the minimum value representable by the integer type `ptyp`: 0 for every unsigned type,
// INTn_MIN for the signed ones. Any other type hits CHECK_FATAL(false, "NYI").
ssize_t GetIntPrimTypeMin(PrimType ptyp)
{
    if (IsUnsignedInteger(ptyp)) {
        return 0;
    }
    switch (ptyp) {
        case PTY_i8:
            return INT8_MIN;
        case PTY_i16:
            return INT16_MIN;
        case PTY_i32:
            return INT32_MIN;
        case PTY_i64:
            return INT64_MIN;
        default:
            CHECK_FATAL(false, "NYI");
    }
}

// Returns true when a cvt from `fromPtyp` to `destPtyp` can be dropped: `op` must be OP_cvt, its
// operand (opcode `opndOp`) must not be a zext/sext, both types must have the same size, and both
// must fall into the same class (possible 64-bit address, possible 32-bit address, or pure scalar).
static bool IsCvtEliminatable(PrimType fromPtyp, PrimType destPtyp, Opcode op, Opcode opndOp)
{
    if (op != OP_cvt || (opndOp == OP_zext || opndOp == OP_sext)) {
        return false;
    }
    if (GetPrimTypeSize(fromPtyp) != GetPrimTypeSize(destPtyp)) {
        return false;
    }
    return (IsPossible64BitAddress(fromPtyp) && IsPossible64BitAddress(destPtyp)) ||
           (IsPossible32BitAddress(fromPtyp) && IsPossible32BitAddress(destPtyp)) ||
           (IsPrimitivePureScalar(fromPtyp) && IsPrimitivePureScalar(destPtyp));
}

// Folds a type-conversion node (ceil / cvt / floor / trunc).
// - Conversions to types larger than 8 bytes are returned untouched.
// - If the operand folds to a constant, the matching Fold* helper computes the result; a nullptr
//   from the helper means "not folded".
// - A redundant cvt (see IsCvtEliminatable) is dropped, and the operand's pending constant addend
//   is re-typed to the destination type.
// - Otherwise the node is rebuilt around the folded operand, or returned as-is if nothing changed.
// The returned pending addend is std::nullopt except in the redundant-cvt case.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldTypeCvt(TypeCvtNode *node)
{
    CHECK_NULL_FATAL(node);
    BaseNode *result = nullptr;
    if (GetPrimTypeSize(node->GetPrimType()) > k8ByteSize) {
        return {node, std::nullopt};
    }
    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
    PrimType destPtyp = node->GetPrimType();
    PrimType fromPtyp = node->FromType();
    if (cst != nullptr) {
        switch (node->GetOpCode()) {
            case OP_ceil: {
                result = FoldCeil(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_cvt: {
                result = FoldTypeCvt(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_floor: {
                result = FoldFloor(*cst, fromPtyp, destPtyp);
                break;
            }
            case OP_trunc: {
                result = FoldTrunc(*cst, fromPtyp, destPtyp);
                break;
            }
            default:
                // NOTE(review): OP_round is not dispatched here (FoldRound exists) — confirm intended.
                DEBUG_ASSERT(false, "Unexpected opcode in TypeCvtNodeConstFold");
                break;
        }
    } else if (IsCvtEliminatable(fromPtyp, destPtyp, node->GetOpCode(), p.first->GetOpCode())) {
        // the cvt is redundant
        return std::make_pair(p.first, p.second ? IntVal(*p.second, node->GetPrimType()) : p.second);
    }
    if (result == nullptr) {
        // Not folded to a constant: rebuild the conversion only if the operand changed.
        BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
        if (e != node->Opnd(0)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(
                Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->FromType()), e);
        } else {
            result = node;
        }
    }
    return std::make_pair(result, std::nullopt);
}

// Builds an integer constant of type `resultType` from the low `size` bits of `val`:
// sign-extended for OP_sext, zero-extended otherwise.
MIRConst *ConstantFold::FoldSignExtendMIRConst(Opcode opcode, PrimType resultType, uint8 size, const IntVal &val) const
{
    uint64 result = opcode == OP_sext ?
static_cast<uint64>(val.GetSXTValue(size)) : val.GetZXTValue(size);
    MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(resultType);
    MIRIntConst *constValue = GlobalTables::GetIntConstTable().GetOrCreateIntConst(result, type);
    return constValue;
}

// Folds sext/zext (selected by `opcode`) of `size` bits applied to the constant `cst`.
// The constant is first truncated/extended to `size` bits (signed for OP_sext); the new node
// takes its prim type from the resulting constant.
ConstvalNode *ConstantFold::FoldSignExtend(Opcode opcode, PrimType resultType, uint8 size,
                                           const ConstvalNode &cst) const
{
    ConstvalNode *resultConst = mirModule->CurFuncCodeMemPool()->New<ConstvalNode>();
    const auto *intCst = safe_cast<MIRIntConst>(cst.GetConstVal());
    ASSERT_NOT_NULL(intCst);
    IntVal val = intCst->GetValue().TruncOrExtend(size, opcode == OP_sext);
    MIRConst *toConst = FoldSignExtendMIRConst(opcode, resultType, size, val);
    resultConst->SetPrimType(toConst->GetType().GetPrimType());
    resultConst->SetConstVal(toConst);
    return resultConst;
}

// check if truncation is redundant due to dread or iread having same effect
// Returns true when the extension `x` adds nothing because its operand already yields exactly
// those bits:
//  - a dread/iread of an integer object whose size in bits equals x's bit size and whose prim type
//    equals x's, with zext over an unsigned operand type or sext over a signed one; or
//  - a zext of a narrower extractbits whose prim type matches x's and is unsigned.
// 8-byte results are never considered redundant.
static bool ExtractbitsRedundant(const ExtractbitsNode &x, MIRFunction &f)
{
    if (GetPrimTypeSize(x.GetPrimType()) == k8ByteSize) {
        return false; // this is trying to be conservative
    }
    BaseNode *opnd = x.Opnd(0);
    MIRType *mirType = nullptr;
    if (opnd->GetOpCode() == OP_dread) {
        DreadNode *dread = static_cast<DreadNode*>(opnd);
        MIRSymbol *sym = f.GetLocalOrGlobalSymbol(dread->GetStIdx());
        ASSERT_NOT_NULL(sym);
        mirType = sym->GetType();
    } else if (opnd->GetOpCode() == OP_iread) {
        IreadNode *iread = static_cast<IreadNode*>(opnd);
        MIRPtrType *ptrType =
            dynamic_cast<MIRPtrType*>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(iread->GetTyIdx()));
        if (ptrType == nullptr) {
            return false;
        }
        mirType = ptrType->GetPointedType();
    } else if (opnd->GetOpCode() == OP_extractbits &&
               x.GetBitsSize() > static_cast<ExtractbitsNode*>(opnd)->GetBitsSize()) {
        // Zero-extending an already narrower unsigned extraction of the same type is a no-op.
        return (x.GetOpCode() == OP_zext && x.GetPrimType() == opnd->GetPrimType() &&
                IsUnsignedInteger(opnd->GetPrimType()));
    } else {
        return false;
    }
    return IsPrimitiveInteger(mirType->GetPrimType()) &&
           ((x.GetOpCode() == OP_zext && IsUnsignedInteger(opnd->GetPrimType())) ||
            (x.GetOpCode() == OP_sext && IsSignedInteger(opnd->GetPrimType()))) &&
           mirType->GetSize() * kBitSizePerByte == x.GetBitsSize() &&
           mirType->GetPrimType() == x.GetPrimType();
}

// sext and zext also handled automatically
// Folds extractbits/sext/zext. A sext/zext of a constant folds to a constant. Otherwise the node is
// rebuilt around the folded operand, an inner extraction of exactly the same bits is removed, and
// an offset-0 extension of a power-of-two width >= 8 bits is dropped when ExtractbitsRedundant
// says the operand already has that effect.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldExtractbits(ExtractbitsNode *node)
{
    CHECK_NULL_FATAL(node);
    BaseNode *result = nullptr;
    uint8 offset = node->GetBitsOffset();
    uint8 size = node->GetBitsSize();
    Opcode opcode = node->GetOpCode();
    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
    ConstvalNode *cst = safe_cast<ConstvalNode>(p.first);
    if (cst != nullptr && (opcode == OP_sext || opcode == OP_zext)) {
        result = FoldSignExtend(opcode, node->GetPrimType(), size, *cst);
        return std::make_pair(result, std::nullopt);
    }
    BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
    if (e != node->Opnd(0)) {
        result = mirModule->CurFuncCodeMemPool()->New<ExtractbitsNode>(opcode, PrimType(node->GetPrimType()), offset,
                                                                       size, e);
    } else {
        result = node;
    }
    // check for consecutive and redundant extraction of same bits
    BaseNode *opnd = result->Opnd(0);
    DEBUG_ASSERT(opnd != nullptr, "opnd shoule not be null");
    Opcode opndOp = opnd->GetOpCode();
    if (opndOp == OP_extractbits || opndOp == OP_sext || opndOp == OP_zext) {
        uint8 opndOffset = static_cast<ExtractbitsNode*>(opnd)->GetBitsOffset();
        uint8 opndSize = static_cast<ExtractbitsNode*>(opnd)->GetBitsSize();
        if (offset == opndOffset && size == opndSize) {
            result->SetOpnd(opnd->Opnd(0), 0); // delete the redundant extraction
        }
    }
    // `size` is in bits here, so this admits widths of 8, 16, 32 and 64.
    if (offset == 0 && size >= k8ByteSize && IsPowerOf2(size)) {
        if (ExtractbitsRedundant(*static_cast<ExtractbitsNode*>(result), *mirModule->CurFunction())) {
            return std::make_pair(result->Opnd(0), std::nullopt);
        }
    }
    return std::make_pair(result, std::nullopt);
}

// Folds the address operand of `node` in place. An iread whose address is an `addrof` of a symbol
// whose type equals the pointed-to type is turned into a dread of that symbol with the two field
// IDs added together.
std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldIread(IreadNode *node)
{
    CHECK_NULL_FATAL(node);
    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node->Opnd(0));
    BaseNode *e = PairToExpr(node->Opnd(0)->GetPrimType(), p);
    node->SetOpnd(e, 0);
    BaseNode *result = node;
    if (e->GetOpCode() != OP_addrof) {
        return std::make_pair(result, std::nullopt);
    }

    AddrofNode *addrofNode = static_cast<AddrofNode*>(e);
    MIRSymbol *msy = mirModule->CurFunction()->GetLocalOrGlobalSymbol(addrofNode->GetStIdx());
    DEBUG_ASSERT(msy != nullptr, "nullptr check");
    TyIdx typeId = msy->GetTyIdx();
    CHECK_FATAL(!GlobalTables::GetTypeTable().GetTypeTable().empty(), "container check");
    MIRType *msyType = GlobalTables::GetTypeTable().GetTypeTable()[typeId];
    MIRPtrType *ptrType = static_cast<MIRPtrType *>(GlobalTables::GetTypeTable().GetTypeFromTyIdx(node->GetTyIdx()));
    // If the high level type of iaddrof/iread doesn't match
    // the type of addrof's rhs, this optimization cannot be done.
1431 if (ptrType->GetPointedType() != msyType) { 1432 return std::make_pair(result, std::nullopt); 1433 } 1434 1435 Opcode op = node->GetOpCode(); 1436 if (op == OP_iread) { 1437 result = mirModule->CurFuncCodeMemPool()->New<AddrofNode>(OP_dread, node->GetPrimType(), addrofNode->GetStIdx(), 1438 node->GetFieldID() + addrofNode->GetFieldID()); 1439 } 1440 return std::make_pair(result, std::nullopt); 1441} 1442 1443bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB) 1444{ 1445 switch (op) { 1446 case OP_add: { 1447 int64 res = static_cast<int64>(static_cast<uint64>(cstA) + static_cast<uint64>(cstB)); 1448 if (IsUnsignedInteger(primType)) { 1449 return static_cast<uint64>(res) < static_cast<uint64>(cstA); 1450 } 1451 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1; 1452 return (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag != 1453 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag) && 1454 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag != 1455 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag); 1456 } 1457 case OP_sub: { 1458 if (IsUnsignedInteger(primType)) { 1459 return cstA < cstB; 1460 } 1461 int64 res = static_cast<int64>(static_cast<uint64>(cstA) - static_cast<uint64>(cstB)); 1462 auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1; 1463 return (static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag != 1464 static_cast<uint64>(cstB) >> rightShiftNumToGetSignFlag) && 1465 (static_cast<uint64>(res) >> rightShiftNumToGetSignFlag != 1466 static_cast<uint64>(cstA) >> rightShiftNumToGetSignFlag); 1467 } 1468 default: { 1469 return false; 1470 } 1471 } 1472} 1473 1474std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldBinary(BinaryNode *node) 1475{ 1476 CHECK_NULL_FATAL(node); 1477 BaseNode *result = nullptr; 1478 std::optional<IntVal> sum = std::nullopt; 1479 Opcode op = node->GetOpCode(); 1480 PrimType primType = node->GetPrimType(); 1481 PrimType 
lPrimTypes = node->Opnd(0)->GetPrimType();
    PrimType rPrimTypes = node->Opnd(1)->GetPrimType();
    // Fold both operands first; each pair is (folded expression, pending constant addend).
    std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0));
    std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1));
    BaseNode *l = lp.first;
    BaseNode *r = rp.first;
    ASSERT_NOT_NULL(r);
    ConstvalNode *lConst = safe_cast<ConstvalNode>(l);
    ConstvalNode *rConst = safe_cast<ConstvalNode>(r);
    bool isInt = IsPrimitiveInteger(primType);

    if (lConst != nullptr && rConst != nullptr) {
        // Both operands are constants.
        MIRConst *lConstVal = lConst->GetConstVal();
        MIRConst *rConstVal = rConst->GetConstVal();
        ASSERT_NOT_NULL(lConstVal);
        ASSERT_NOT_NULL(rConstVal);
        // Don't fold div by 0, for floats div by 0 is well defined.
        if ((op == OP_div || op == OP_rem) && isInt &&
            !IsDivSafe(static_cast<MIRIntConst &>(*lConstVal), static_cast<MIRIntConst &>(*rConstVal), primType)) {
            result = NewBinaryNode(node, op, primType, lConst, rConst);
        } else {
            // 4 + 2 -> return a pair(result = ConstValNode(6), sum = 0)
            // Create a new ConstvalNode for 6 but keep the sum = 0. This simplify the
            // logic since the alternative is to return pair(result = nullptr, sum = 6).
            // Doing so would introduce many nullptr checks in the code. See previous
            // commits that implemented that logic for a comparison.
            result = FoldConstBinary(op, primType, *lConst, *rConst);
        }
    } else if (lConst != nullptr && isInt) {
        // Only the left operand is an integer constant.
        MIRIntConst *mcst = safe_cast<MIRIntConst>(lConst->GetConstVal());
        ASSERT_NOT_NULL(mcst);
        PrimType cstTyp = mcst->GetType().GetPrimType();
        IntVal cst = mcst->GetValue();
        if (op == OP_add) {
            if (IsSignedInteger(cstTyp) && rp.second &&
                IntegerOpIsOverflow(OP_add, cstTyp, cst.GetExtValue(), rp.second->GetExtValue())) {
                // do not introduce signed integer overflow
                result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
            } else {
                // cst + (X + k) -> pair [X, cst + k]
                sum = cst + rp.second;
                result = r;
            }
        } else if (op == OP_sub && r->GetPrimType() != PTY_u1) {
            // We exclude u1 type for fixing the following wrong example:
            // before cf:
            // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16)))
            // after cf:
            // add i32 (cvt i32 u1 (neg u1 (eq u1 i32 (dread i32 %i, constval i32 16))), constval i32 17))
            // cst - (X + k) -> pair [-X, cst - k]
            sum = cst - rp.second;
            if (GetPrimTypeSize(r->GetPrimType()) < GetPrimTypeSize(primType)) {
                r = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, r->GetPrimType(), r);
            }
            result = NegateTree(r);
        } else if ((op == OP_mul || op == OP_div || op == OP_rem || op == OP_ashr || op == OP_lshr || op == OP_shl ||
                    op == OP_band) &&
                   cst == 0) {
            // 0 * X -> 0
            // 0 / X -> 0
            // 0 % X -> 0
            // 0 >> X -> 0
            // 0 << X -> 0
            // 0 & X -> 0
            // 0 && X -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else if (op == OP_mul && cst == 1) {
            // 1 * X --> X
            sum = rp.second;
            result = r;
        } else if (op == OP_bior && cst == -1) {
            // (-1) | X -> -1
            result = mirModule->GetMIRBuilder()->CreateIntConst(static_cast<uint64>(-1), cstTyp);
        } else if (op == OP_mul && rp.second.has_value() && *rp.second != 0) {
            // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)]
            sum = cst * rp.second;
            if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) {
                // NOTE(review): the cvt's source type is hard-coded to PTY_i32 rather than
                // rp.first->GetPrimType() — confirm this is intended.
                rp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, rp.first);
            }
            result = NewBinaryNode(node, OP_mul, primType, lConst, rp.first);
        } else if ((op == OP_bior || op == OP_bxor) && cst == 0) {
            // 0 | X -> X
            // 0 ^ X -> X
            sum = rp.second;
            result = r;
        } else {
            result = NewBinaryNode(node, op, primType, l, PairToExpr(rPrimTypes, rp));
        }
        // Keep the node's declared result type if a simplification produced a different one.
        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
        }
    } else if (rConst != nullptr && isInt) {
        // Only the right operand is an integer constant.
        MIRIntConst *mcst = safe_cast<MIRIntConst>(rConst->GetConstVal());
        ASSERT_NOT_NULL(mcst);
        PrimType cstTyp = mcst->GetType().GetPrimType();
        IntVal cst = mcst->GetValue();
        if (op == OP_add) {
            if (lp.second && IntegerOpIsOverflow(op, cstTyp, lp.second->GetExtValue(), cst.GetExtValue())) {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
            } else {
                // (X + k) + cst -> pair [X, k + cst]
                result = l;
                sum = lp.second + cst;
            }
        } else if (op == OP_sub && (!cst.IsSigned() || !cst.IsMinValue())) {
            // (X + k) - cst -> pair [X, k - cst]; a signed minimum cst is excluded.
            result = l;
            sum = lp.second - cst;
        } else if ((op == OP_mul || op == OP_band) && cst == 0) {
            // X * 0 -> 0
            // X & 0 -> 0
            // X && 0 -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else if ((op == OP_mul || op == OP_div) && cst == 1) {
            // case [X * 1 -> X]
            // case [X / 1 = X]
            sum = lp.second;
            result = l;
        } else if (op == OP_div && !lp.second.has_value() && l->GetOpCode() == OP_mul &&
                   IsSignedInteger(primType) && IsSignedInteger(lPrimTypes) && IsSignedInteger(rPrimTypes)) {
            // temporary fix for constfold of mul/div in DejaGnu
            // Later we need a more formal interface for pattern match
            // X * Y / Y -> X
            BaseNode *x = l->Opnd(0);
            BaseNode *y = l->Opnd(1);
            ConstvalNode *xConst = safe_cast<ConstvalNode>(x);
            ConstvalNode *yConst = safe_cast<ConstvalNode>(y);
            bool foldMulDiv = false;
            if (yConst != nullptr && xConst == nullptr &&
                IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
                MIRIntConst *yCst = safe_cast<MIRIntConst>(yConst->GetConstVal());
                ASSERT_NOT_NULL(yCst);
                IntVal mulCst = yCst->GetValue();
                // Fold only when the multiplier equals the divisor in width, signedness and value.
                if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
                    mulCst.GetExtValue() == cst.GetExtValue()) {
                    foldMulDiv = true;
                    result = x;
                }
            } else if (xConst != nullptr && yConst == nullptr &&
                       IsSignedInteger(x->GetPrimType()) && IsSignedInteger(y->GetPrimType())) {
                MIRIntConst *xCst = safe_cast<MIRIntConst>(xConst->GetConstVal());
                ASSERT_NOT_NULL(xCst);
                IntVal mulCst = xCst->GetValue();
                if (mulCst.GetBitWidth() == cst.GetBitWidth() && mulCst.IsSigned() == cst.IsSigned() &&
                    mulCst.GetExtValue() == cst.GetExtValue()) {
                    foldMulDiv = true;
                    result = y;
                }
            }
            if (!foldMulDiv) {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
            }
        } else if (op == OP_mul && lp.second.has_value() && *lp.second != 0 && lp.second->GetSXTValue() > -kMaxOffset) {
            // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)]
            sum = lp.second * cst;
            if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) {
                // NOTE(review): source type hard-coded to PTY_i32, as in the mirrored lConst case.
                lp.first = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_i32, lp.first);
            }
            if (lp.first->GetOpCode() == OP_neg && cst == -1) {
                // special case: ((-X) + konst) * (-1) -> the pair [(X), -konst]
                result = lp.first->Opnd(0);
            } else {
                result = NewBinaryNode(node, OP_mul, primType, lp.first, rConst);
            }
        } else if (op == OP_band && cst == -1) {
            // X & (-1) -> X
            sum = lp.second;
            result = l;
        } else if (op == OP_band && ContiguousBitsOf1(cst.GetZXTValue()) &&
                   (!lp.second.has_value() || lp.second == 0)) {
            // (X >> shrAmt) & mask, where mask is a run of 1 bits, may become an extractbits.
            bool fold2extractbits = false;
            if (l->GetOpCode() == OP_ashr || l->GetOpCode() == OP_lshr) {
                BinaryNode *shrNode = static_cast<BinaryNode *>(l);
                if (shrNode->Opnd(1)->GetOpCode() == OP_constval) {
                    ConstvalNode *shrOpnd = static_cast<ConstvalNode *>(shrNode->Opnd(1));
                    int64 shrAmt = static_cast<MIRIntConst*>(shrOpnd->GetConstVal())->GetExtValue();
                    uint64 ucst = cst.GetZXTValue();
                    // bsize = number of bits up to and including the mask's highest set bit.
                    uint32 bsize = 0;
                    do {
                        bsize++;
                        ucst >>= 1;
                    } while (ucst != 0);
                    if (shrAmt + static_cast<int64>(bsize) <=
                        static_cast<int64>(GetPrimTypeSize(primType) * kBitSizePerByte) &&
                        static_cast<uint64>(shrAmt) < GetPrimTypeSize(primType) * kBitSizePerByte) {
                        fold2extractbits = true;
                        // change to use extractbits
                        result = mirModule->GetMIRBuilder()->CreateExprExtractbits(OP_extractbits,
                            GetUnsignedPrimType(primType), static_cast<uint32>(shrAmt), bsize, shrNode->Opnd(0));
                        sum = std::nullopt;
                    }
                }
            }
            if (!fold2extractbits) {
                result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
                sum = std::nullopt;
            }
        } else if (op == OP_bior && cst == -1) {
            // X | (-1) -> -1
            result = mirModule->GetMIRBuilder()->CreateIntConst(-1ULL, cstTyp);
        } else if ((op == OP_ashr || op == OP_lshr || op == OP_shl || op == OP_bior || op == OP_bxor) && cst == 0) {
            // X >> 0 -> X
            // X << 0 -> X
            // X | 0 -> X
            // X ^ 0 -> X
            sum = lp.second;
            result = l;
        } else if (op == OP_bxor && cst == 1 && primType != PTY_u1) {
            // bxor i32 (
            //     cvt i32 u1 (regread u1 %13),
            //     constValue i32 1),
            // is rewritten as cvt i32 u1 (bxor u1 (regread u1 %13, constval u1 1)).
            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
            if (l->GetOpCode() == OP_cvt && (!lp.second || lp.second == 0)) {
                TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(l);
                if (cvtNode->Opnd(0)->GetPrimType() == PTY_u1) {
                    BaseNode *base = cvtNode->Opnd(0);
                    BaseNode *constValue = mirModule->GetMIRBuilder()->CreateIntConst(1, base->GetPrimType());
                    std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(base);
                    BinaryNode *temp = NewBinaryNode(node, op, PTY_u1, PairToExpr(base->GetPrimType(), p), constValue);
                    result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, PTY_u1, temp);
                }
            }
        } else if (op == OP_rem && cst == 1) {
            // X % 1 -> 0
            result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp);
        } else {
            result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r);
        }
        // Keep the node's declared result type if a simplification produced a different one.
        if (!IsNoCvtNeeded(result->GetPrimType(), primType)) {
            result = mirModule->CurFuncCodeMemPool()->New<TypeCvtNode>(OP_cvt, primType, result->GetPrimType(), result);
        }
    } else if (isInt && (op == OP_add || op == OP_sub)) {
        // Neither operand is constant: combine the pending addends of both sides.
        if (op == OP_add) {
            result = NewBinaryNode(node, op, primType, l, r);
            sum = lp.second + rp.second;
        } else if (r != nullptr && node->Opnd(1)->GetOpCode() == OP_sub && r->GetOpCode() == OP_neg) {
            // if fold is (x - (y - z)) -> (x - neg(z)) - y
            // (x - neg(z)) Could cross the int limit
            // return node
            result = node;
        } else {
            result = NewBinaryNode(node, op, primType, l, r);
            sum = lp.second - rp.second;
        }
    } else {
        result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp));
    }
    return std::make_pair(result, sum);
}

// `node` has one cmp operand and one constval operand; `isRConstval` tells whether the constant is
// the right operand. If that constant is integer zero, returns a new comparison with node's opcode
// applied directly to the cmp's operands. The operands are swapped when the constant is on the left
// and `isGtOrLt` is set. Otherwise `node` is returned unchanged.
BaseNode *ConstantFold::SimplifyDoubleConstvalCompare(CompareNode &node, bool isRConstval, bool isGtOrLt) const
{
    if (isRConstval) {
        ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(1));
        if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
            const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(0));
            return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
                node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
        }
    } else {
        ConstvalNode *constNode = static_cast<ConstvalNode*>(node.Opnd(0));
        if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
            const CompareNode *compNode = static_cast<CompareNode*>(node.Opnd(1));
            if (isGtOrLt) {
                return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
                    node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(1), compNode->Opnd(0));
            } else {
                return mirModule->CurFuncCodeMemPool()->New<CompareNode>(node.GetOpCode(),
                    node.GetPrimType(), compNode->GetOpndType(), compNode->Opnd(0), compNode->Opnd(1));
            }
        }
    }
    return &node;
}

BaseNode *ConstantFold::SimplifyDoubleCompare(CompareNode &compareNode) const
{
    // See arm manual B.cond(P2993) and FCMP(P1091)
    CompareNode *node = &compareNode;
    BaseNode *result = node;
    BaseNode *l = node->Opnd(0);
    BaseNode *r = node->Opnd(1);
    if (node->GetOpCode() == OP_ne || node->GetOpCode() == OP_eq) {
        if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) ||
            (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) {
            result = SimplifyDoubleConstvalCompare(*node, (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval));
        } else if (node->GetOpCode() == OP_ne && r->GetOpCode() == OP_constval) {
            // ne (u1 x, constValue 0) <==> x
            ConstvalNode *constNode = static_cast<ConstvalNode*>(r);
            if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero()) {
                BaseNode *opnd = l;
                do {
                    if (opnd->GetPrimType() == PTY_u1 || (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) {
                        result
= opnd; 1771 break; 1772 } else if (opnd->GetOpCode() == OP_cvt) { 1773 TypeCvtNode *cvtNode = static_cast<TypeCvtNode*>(opnd); 1774 opnd = cvtNode->Opnd(0); 1775 } else { 1776 opnd = nullptr; 1777 } 1778 } while (opnd != nullptr); 1779 } 1780 } else if (node->GetOpCode() == OP_eq && r->GetOpCode() == OP_constval) { 1781 ConstvalNode *constNode = static_cast<ConstvalNode*>(r); 1782 if (constNode->GetConstVal()->GetKind() == kConstInt && constNode->GetConstVal()->IsZero() && 1783 (l->GetOpCode() == OP_ne || l->GetOpCode() == OP_eq)) { 1784 auto resOp = l->GetOpCode() == OP_ne ? OP_eq : OP_ne; 1785 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>( 1786 resOp, l->GetPrimType(), static_cast<CompareNode*>(l)->GetOpndType(), l->Opnd(0), l->Opnd(1)); 1787 } 1788 } 1789 } else if (node->GetOpCode() == OP_gt || node->GetOpCode() == OP_lt) { 1790 if ((l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval) || 1791 (r->GetOpCode() == OP_cmp && l->GetOpCode() == OP_constval)) { 1792 result = SimplifyDoubleConstvalCompare(*node, 1793 (l->GetOpCode() == OP_cmp && r->GetOpCode() == OP_constval), true); 1794 } 1795 } 1796 return result; 1797} 1798 1799std::pair<BaseNode*, std::optional<IntVal>> ConstantFold::FoldCompare(CompareNode *node) 1800{ 1801 CHECK_NULL_FATAL(node); 1802 BaseNode *result = nullptr; 1803 std::pair<BaseNode*, std::optional<IntVal>> lp = DispatchFold(node->Opnd(0)); 1804 std::pair<BaseNode*, std::optional<IntVal>> rp = DispatchFold(node->Opnd(1)); 1805 ConstvalNode *lConst = safe_cast<ConstvalNode>(lp.first); 1806 ConstvalNode *rConst = safe_cast<ConstvalNode>(rp.first); 1807 Opcode opcode = node->GetOpCode(); 1808 if (lConst != nullptr && rConst != nullptr) { 1809 result = FoldConstComparison(node->GetOpCode(), node->GetPrimType(), node->GetOpndType(), *lConst, *rConst); 1810 } else if (lConst != nullptr && rConst == nullptr && opcode != OP_cmp && 1811 lConst->GetConstVal()->GetKind() == kConstInt) { 1812 BaseNode *l = lp.first; 1813 BaseNode *r 
= PairToExpr(node->Opnd(1)->GetPrimType(), rp); 1814 result = FoldConstComparisonReverse(opcode, node->GetPrimType(), node->GetOpndType(), *l, *r); 1815 } else { 1816 BaseNode *l = PairToExpr(node->Opnd(0)->GetPrimType(), lp); 1817 BaseNode *r = PairToExpr(node->Opnd(1)->GetPrimType(), rp); 1818 if (l != node->Opnd(0) || r != node->Opnd(1)) { 1819 result = mirModule->CurFuncCodeMemPool()->New<CompareNode>( 1820 Opcode(node->GetOpCode()), PrimType(node->GetPrimType()), PrimType(node->GetOpndType()), l, r); 1821 } else { 1822 result = node; 1823 } 1824 auto *compareNode = static_cast<CompareNode*>(result); 1825 CHECK_NULL_FATAL(compareNode); 1826 result = SimplifyDoubleCompare(*compareNode); 1827 } 1828 return std::make_pair(result, std::nullopt); 1829} 1830 1831BaseNode *ConstantFold::Fold(BaseNode *node) 1832{ 1833 if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) { 1834 return nullptr; 1835 } 1836 std::pair<BaseNode*, std::optional<IntVal>> p = DispatchFold(node); 1837 BaseNode *result = PairToExpr(node->GetPrimType(), p); 1838 if (result == node) { 1839 result = nullptr; 1840 } 1841 return result; 1842} 1843 1844} // namespace maple 1845