1/* 2 * Copyright (c) 2024 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16#include "ecmascript/compiler/induction_variable_analysis.h" 17 18namespace panda::ecmascript::kungfu { 19 20bool InductionVariableAnalysis::IsIntConstant(GateRef gate) const 21{ 22 if (acc_.GetOpCode(gate) != OpCode::CONSTANT) { 23 return false; 24 } 25 JSTaggedValue value(acc_.GetConstantValue(gate)); 26 return value.IsInt(); 27} 28 29bool InductionVariableAnalysis::IsInductionVariable(GateRef gate) const 30{ 31 if (acc_.GetOpCode(gate) != OpCode::VALUE_SELECTOR) { 32 return false; 33 } 34 size_t numValueIn = acc_.GetNumValueIn(gate); 35 GateRef startGate = acc_.GetValueIn(gate, 0); 36 GateRef valueGate = acc_.GetValueIn(gate, 1); 37 if (!IsIntConstant(startGate)) { 38 return false; 39 } 40 if (acc_.GetOpCode(valueGate) != OpCode::TYPED_BINARY_OP) { 41 return false; 42 } 43 TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate); 44 if (binOp != TypedBinOp::TYPED_ADD && binOp != TypedBinOp::TYPED_SUB) { 45 return false; 46 } 47 TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate)); 48 const ParamType paramType = accessor.GetParamType(); 49 if (!paramType.IsIntType()) { 50 return false; 51 } 52 53 for (size_t i = 2; i < numValueIn; i++) { // 2: skip startGate and valueGate 54 if (acc_.GetValueIn(gate, i) != valueGate) { 55 return false; 56 } 57 } 58 59 // check if value satisfies a = a + x 60 if (acc_.GetValueIn(valueGate, 0) != gate && acc_.GetValueIn(valueGate, 1) != gate) { 61 return false; 62 } 63 GateRef stride = acc_.GetValueIn(valueGate, 1); 64 if (acc_.GetValueIn(valueGate, 0) != gate) { 65 stride = acc_.GetValueIn(valueGate, 0); 66 } 67 if (!IsIntConstant(stride)) { 68 return false; 69 } 70 return true; 71} 72 73std::pair<int32_t, int32_t> InductionVariableAnalysis::GetStartAndStride(GateRef gate) const 74{ 75 ASSERT(acc_.GetOpCode(gate) == OpCode::VALUE_SELECTOR); 76 GateRef startGate = acc_.GetValueIn(gate, 0); 77 ASSERT(IsIntConstant(startGate)); 78 auto start = GetIntFromTaggedConstant(startGate); 79 80 GateRef valueGate = acc_.GetValueIn(gate, 1); 81 ASSERT(acc_.GetOpCode(valueGate) == OpCode::TYPED_BINARY_OP); 82 [[maybe_unused]]TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate); 83 ASSERT(binOp == TypedBinOp::TYPED_ADD || binOp == TypedBinOp::TYPED_SUB); 84 TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate)); 85 [[maybe_unused]]const ParamType paramType = accessor.GetParamType(); 86 ASSERT(paramType.IsIntType()); 87 88 GateRef strideGate = acc_.GetValueIn(valueGate, 1); 89 if (acc_.GetValueIn(valueGate, 0) != gate) { 90 strideGate = acc_.GetValueIn(valueGate, 0); 91 } 92 ASSERT(IsIntConstant(strideGate)); 93 auto stride = GetIntFromTaggedConstant(strideGate); 94 95 // a - xb < c -> a + (-x)b < c 96 if (acc_.GetTypedBinaryOp(valueGate) == TypedBinOp::TYPED_SUB) { 97 stride = -stride; 98 } 99 100 return std::make_pair(start, stride); 101} 102 103int32_t InductionVariableAnalysis::GetIntFromTaggedConstant(GateRef gate) const 104{ 105 ASSERT(acc_.GetOpCode(gate) == OpCode::CONSTANT); 106 JSTaggedValue value(acc_.GetConstantValue(gate)); 107 return value.GetInt(); 108} 109 110bool InductionVariableAnalysis::IsLessOrGreaterCmp(GateRef gate) const 111{ 112 return acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATEREQ || 113 acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATER || 114 acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESSEQ || 115 acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESS; 116} 117 118bool InductionVariableAnalysis::TryGetLoopTimes(const GraphLinearizer::LoopInfo& loop, int32_t& loopTimes) const 119{ 120 if (loop.loopExits->size() > 1) { 121 return false; 122 } 123 ASSERT(loop.loopExits->size() == 1); 124 GateRef loopExit = loop.loopExits->at(0)->GetState(); 125 ASSERT(acc_.GetOpCode(loopExit) == OpCode::IF_TRUE || acc_.GetOpCode(loopExit) == OpCode::IF_FALSE); 126 GateRef conditionJump = acc_.GetState(loopExit); 127 GateRef cmp = acc_.GetValueIn(conditionJump); 128 if (acc_.GetOpCode(cmp) != OpCode::TYPED_BINARY_OP || !IsLessOrGreaterCmp(cmp)) { 129 return false; 130 } 131 GateRef limitGate = acc_.GetValueIn(cmp, 1); 132 if (!IsIntConstant(limitGate)) { 133 return false; 134 } 135 int32_t limit = GetIntFromTaggedConstant(limitGate); 136 137 GateRef selector = acc_.GetValueIn(cmp, 0); 138 if (!IsInductionVariable(selector)) { 139 return false; 140 } 141 142 auto [start, stride] = GetStartAndStride(selector); 143 144 bool cmpFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^ 145 (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^ 146 (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ || 147 acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESS); 148 bool equalFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^ 149 (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^ 150 (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_GREATEREQ || 151 acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ); 152 153 // a + xb >= c -> c - xb <= a 154 if (!cmpFlag) { 155 std::swap(start, limit); 156 stride = -stride; 157 } 158 // a + xb < c -> a + xb <= c - 1 159 if (!equalFlag) { 160 limit--; 161 } 162 if (start > limit) { 163 loopTimes = 0; 164 return true; 165 } 166 loopTimes = (limit - start) / stride + 1; 167 if (IsLogEnabled() && IsTraced()) { 168 LOG_COMPILER(INFO) << "loopTimes: "<< loopTimes << " start: " << start 169 << " stride: " << stride << " limit: " << limit; 170 } 171 return true; 172} 173 174void InductionVariableAnalysis::CollectInductionSelector() 175{ 176 for (const auto &loop : graphLinearizer_.loops_) { 177 int32_t loopTimes = 0; 178 179 if (TryGetLoopTimes(loop, loopTimes)) { 180 ReplaceInductionVariable(loop, loopTimes); 181 } 182 } 183} 184 185void InductionVariableAnalysis::ReplaceInductionVariable(const GraphLinearizer::LoopInfo& loop, 186 const int32_t loopTimes) 187{ 188 GateRef loopBegin = loop.loopHead->GetState(); 189 auto uses = acc_.Uses(loopBegin); 190 for (auto it = uses.begin(); it != uses.end(); it++) { 191 if (acc_.GetOpCode(*it) == OpCode::VALUE_SELECTOR) { 192 ASSERT(acc_.GetState(*it) == loopBegin); 193 if (!IsInductionVariable(*it)) { 194 continue; 195 } 196 auto [start, stride] = GetStartAndStride(*it); 197 int64_t result = start + static_cast<int64_t>(stride) * loopTimes; 198 if (result > static_cast<int64_t>(INT_MAX) || result < static_cast<int64_t>(INT_MIN)) { 199 return; 200 } 201 if (IsLogEnabled() && IsTraced()) { 202 LOG_COMPILER(INFO) << "result = " << start << " + " << stride << " * " 203 << loopTimes << " = " << result; 204 } 205 TryReplaceOutOfLoopUses(*it, loop, static_cast<int32_t>(result)); 206 } 207 } 208} 209 210void InductionVariableAnalysis::TryReplaceOutOfLoopUses(GateRef gate, 211 const GraphLinearizer::LoopInfo& loop, 212 const int32_t result) 213{ 214 ASSERT(IsInductionVariable(gate)); 215 auto uses = acc_.Uses(gate); 216 for (auto it = uses.begin(); it != uses.end();) { 217 auto region = graphLinearizer_.GateToRegion(*it); 218 if (!loop.loopBodys->TestBit(region->GetId()) && loop.loopHead != region) { 219 GateRef constantValue = builder_.Int32(result); 220 it = acc_.ReplaceIn(it, constantValue); 221 } else { 222 it++; 223 } 224 } 225} 226 227void InductionVariableAnalysis::Run() 228{ 229 graphLinearizer_.SetScheduleJSOpcode(); 230 graphLinearizer_.LinearizeGraph(); 231 CollectInductionSelector(); 232 if (IsLogEnabled()) { 233 LOG_COMPILER(INFO) << ""; 234 LOG_COMPILER(INFO) << "\033[34m" 235 << "====================" 236 << " After Induction Variable Analysis " 237 << "[" << GetMethodName() << "]" 238 << "====================" 239 << "\033[0m"; 240 circuit_->PrintAllGatesWithBytecode(); 241 LOG_COMPILER(INFO) << "\033[34m" << "========================= End ==========================" << "\033[0m"; 242 } 243} 244 245}