1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "ecmascript/compiler/induction_variable_analysis.h"
17 
18 namespace panda::ecmascript::kungfu {
19 
IsIntConstant(GateRef gate) const20 bool InductionVariableAnalysis::IsIntConstant(GateRef gate) const
21 {
22     if (acc_.GetOpCode(gate) != OpCode::CONSTANT) {
23         return false;
24     }
25     JSTaggedValue value(acc_.GetConstantValue(gate));
26     return value.IsInt();
27 }
28 
IsInductionVariable(GateRef gate) const29 bool InductionVariableAnalysis::IsInductionVariable(GateRef gate) const
30 {
31     if (acc_.GetOpCode(gate) != OpCode::VALUE_SELECTOR) {
32         return false;
33     }
34     size_t numValueIn = acc_.GetNumValueIn(gate);
35     GateRef startGate = acc_.GetValueIn(gate, 0);
36     GateRef valueGate = acc_.GetValueIn(gate, 1);
37     if (!IsIntConstant(startGate)) {
38         return false;
39     }
40     if (acc_.GetOpCode(valueGate) != OpCode::TYPED_BINARY_OP) {
41         return false;
42     }
43     TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate);
44     if (binOp != TypedBinOp::TYPED_ADD && binOp != TypedBinOp::TYPED_SUB) {
45         return false;
46     }
47     TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate));
48     const ParamType paramType = accessor.GetParamType();
49     if (!paramType.IsIntType()) {
50         return false;
51     }
52 
53     for (size_t i = 2; i < numValueIn; i++) { // 2: skip startGate and valueGate
54         if (acc_.GetValueIn(gate, i) != valueGate) {
55             return false;
56         }
57     }
58 
59     // check if value satisfies a = a + x
60     if (acc_.GetValueIn(valueGate, 0) != gate && acc_.GetValueIn(valueGate, 1) != gate) {
61         return false;
62     }
63     GateRef stride = acc_.GetValueIn(valueGate, 1);
64     if (acc_.GetValueIn(valueGate, 0) != gate) {
65         stride = acc_.GetValueIn(valueGate, 0);
66     }
67     if (!IsIntConstant(stride)) {
68         return false;
69     }
70     return true;
71 }
72 
GetStartAndStride(GateRef gate) const73 std::pair<int32_t, int32_t> InductionVariableAnalysis::GetStartAndStride(GateRef gate) const
74 {
75     ASSERT(acc_.GetOpCode(gate) == OpCode::VALUE_SELECTOR);
76     GateRef startGate = acc_.GetValueIn(gate, 0);
77     ASSERT(IsIntConstant(startGate));
78     auto start = GetIntFromTaggedConstant(startGate);
79 
80     GateRef valueGate = acc_.GetValueIn(gate, 1);
81     ASSERT(acc_.GetOpCode(valueGate) == OpCode::TYPED_BINARY_OP);
82     [[maybe_unused]]TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate);
83     ASSERT(binOp == TypedBinOp::TYPED_ADD || binOp == TypedBinOp::TYPED_SUB);
84     TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate));
85     [[maybe_unused]]const ParamType paramType = accessor.GetParamType();
86     ASSERT(paramType.IsIntType());
87 
88     GateRef strideGate = acc_.GetValueIn(valueGate, 1);
89     if (acc_.GetValueIn(valueGate, 0) != gate) {
90         strideGate = acc_.GetValueIn(valueGate, 0);
91     }
92     ASSERT(IsIntConstant(strideGate));
93     auto stride = GetIntFromTaggedConstant(strideGate);
94 
95     // a - xb < c -> a + (-x)b < c
96     if (acc_.GetTypedBinaryOp(valueGate) == TypedBinOp::TYPED_SUB) {
97         stride = -stride;
98     }
99 
100     return std::make_pair(start, stride);
101 }
102 
GetIntFromTaggedConstant(GateRef gate) const103 int32_t InductionVariableAnalysis::GetIntFromTaggedConstant(GateRef gate) const
104 {
105     ASSERT(acc_.GetOpCode(gate) == OpCode::CONSTANT);
106     JSTaggedValue value(acc_.GetConstantValue(gate));
107     return value.GetInt();
108 }
109 
IsLessOrGreaterCmp(GateRef gate) const110 bool InductionVariableAnalysis::IsLessOrGreaterCmp(GateRef gate) const
111 {
112     return acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATEREQ ||
113         acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATER ||
114         acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESSEQ ||
115         acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESS;
116 }
117 
TryGetLoopTimes(const GraphLinearizer::LoopInfo& loop, int32_t& loopTimes) const118 bool InductionVariableAnalysis::TryGetLoopTimes(const GraphLinearizer::LoopInfo& loop, int32_t& loopTimes) const
119 {
120     if (loop.loopExits->size() > 1) {
121         return false;
122     }
123     ASSERT(loop.loopExits->size() == 1);
124     GateRef loopExit = loop.loopExits->at(0)->GetState();
125     ASSERT(acc_.GetOpCode(loopExit) == OpCode::IF_TRUE || acc_.GetOpCode(loopExit) == OpCode::IF_FALSE);
126     GateRef conditionJump = acc_.GetState(loopExit);
127     GateRef cmp = acc_.GetValueIn(conditionJump);
128     if (acc_.GetOpCode(cmp) != OpCode::TYPED_BINARY_OP || !IsLessOrGreaterCmp(cmp)) {
129         return false;
130     }
131     GateRef limitGate = acc_.GetValueIn(cmp, 1);
132     if (!IsIntConstant(limitGate)) {
133         return false;
134     }
135     int32_t limit = GetIntFromTaggedConstant(limitGate);
136 
137     GateRef selector = acc_.GetValueIn(cmp, 0);
138     if (!IsInductionVariable(selector)) {
139         return false;
140     }
141 
142     auto [start, stride] = GetStartAndStride(selector);
143 
144     bool cmpFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^
145         (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^
146         (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ ||
147         acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESS);
148     bool equalFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^
149         (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^
150         (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_GREATEREQ ||
151         acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ);
152 
153     // a + xb >= c -> c - xb <= a
154     if (!cmpFlag) {
155         std::swap(start, limit);
156         stride = -stride;
157     }
158     // a + xb < c -> a + xb <= c - 1
159     if (!equalFlag) {
160         limit--;
161     }
162     if (start > limit) {
163         loopTimes = 0;
164         return true;
165     }
166     loopTimes = (limit - start) / stride + 1;
167     if (IsLogEnabled() && IsTraced()) {
168         LOG_COMPILER(INFO) << "loopTimes: "<< loopTimes << " start: " << start
169                            << " stride: " << stride << " limit: " << limit;
170     }
171     return true;
172 }
173 
CollectInductionSelector()174 void InductionVariableAnalysis::CollectInductionSelector()
175 {
176     for (const auto &loop : graphLinearizer_.loops_) {
177         int32_t loopTimes = 0;
178 
179         if (TryGetLoopTimes(loop, loopTimes)) {
180             ReplaceInductionVariable(loop, loopTimes);
181         }
182     }
183 }
184 
ReplaceInductionVariable(const GraphLinearizer::LoopInfo& loop, const int32_t loopTimes)185 void InductionVariableAnalysis::ReplaceInductionVariable(const GraphLinearizer::LoopInfo& loop,
186                                                          const int32_t loopTimes)
187 {
188     GateRef loopBegin = loop.loopHead->GetState();
189     auto uses = acc_.Uses(loopBegin);
190     for (auto it = uses.begin(); it != uses.end(); it++) {
191         if (acc_.GetOpCode(*it) == OpCode::VALUE_SELECTOR) {
192             ASSERT(acc_.GetState(*it) == loopBegin);
193             if (!IsInductionVariable(*it)) {
194                 continue;
195             }
196             auto [start, stride] = GetStartAndStride(*it);
197             int64_t result = start + static_cast<int64_t>(stride) * loopTimes;
198             if (result > static_cast<int64_t>(INT_MAX) || result < static_cast<int64_t>(INT_MIN)) {
199                 return;
200             }
201             if (IsLogEnabled() && IsTraced()) {
202                 LOG_COMPILER(INFO) << "result = " << start << " + " << stride << " * "
203                                     << loopTimes << " = " << result;
204             }
205             TryReplaceOutOfLoopUses(*it, loop, static_cast<int32_t>(result));
206         }
207     }
208 }
209 
TryReplaceOutOfLoopUses(GateRef gate, const GraphLinearizer::LoopInfo& loop, const int32_t result)210 void InductionVariableAnalysis::TryReplaceOutOfLoopUses(GateRef gate,
211                                                         const GraphLinearizer::LoopInfo& loop,
212                                                         const int32_t result)
213 {
214     ASSERT(IsInductionVariable(gate));
215     auto uses = acc_.Uses(gate);
216     for (auto it = uses.begin(); it != uses.end();) {
217         auto region = graphLinearizer_.GateToRegion(*it);
218         if (!loop.loopBodys->TestBit(region->GetId()) && loop.loopHead != region) {
219             GateRef constantValue = builder_.Int32(result);
220             it = acc_.ReplaceIn(it, constantValue);
221         } else {
222             it++;
223         }
224     }
225 }
226 
Run()227 void InductionVariableAnalysis::Run()
228 {
229     graphLinearizer_.SetScheduleJSOpcode();
230     graphLinearizer_.LinearizeGraph();
231     CollectInductionSelector();
232     if (IsLogEnabled()) {
233         LOG_COMPILER(INFO) << "";
234         LOG_COMPILER(INFO) << "\033[34m"
235                            << "===================="
236                            << " After Induction Variable Analysis "
237                            << "[" << GetMethodName() << "]"
238                            << "===================="
239                            << "\033[0m";
240         circuit_->PrintAllGatesWithBytecode();
241         LOG_COMPILER(INFO) << "\033[34m" << "========================= End ==========================" << "\033[0m";
242     }
243 }
244 
245 }