1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "ecmascript/compiler/induction_variable_analysis.h"
17
18namespace panda::ecmascript::kungfu {
19
20bool InductionVariableAnalysis::IsIntConstant(GateRef gate) const
21{
22    if (acc_.GetOpCode(gate) != OpCode::CONSTANT) {
23        return false;
24    }
25    JSTaggedValue value(acc_.GetConstantValue(gate));
26    return value.IsInt();
27}
28
29bool InductionVariableAnalysis::IsInductionVariable(GateRef gate) const
30{
31    if (acc_.GetOpCode(gate) != OpCode::VALUE_SELECTOR) {
32        return false;
33    }
34    size_t numValueIn = acc_.GetNumValueIn(gate);
35    GateRef startGate = acc_.GetValueIn(gate, 0);
36    GateRef valueGate = acc_.GetValueIn(gate, 1);
37    if (!IsIntConstant(startGate)) {
38        return false;
39    }
40    if (acc_.GetOpCode(valueGate) != OpCode::TYPED_BINARY_OP) {
41        return false;
42    }
43    TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate);
44    if (binOp != TypedBinOp::TYPED_ADD && binOp != TypedBinOp::TYPED_SUB) {
45        return false;
46    }
47    TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate));
48    const ParamType paramType = accessor.GetParamType();
49    if (!paramType.IsIntType()) {
50        return false;
51    }
52
53    for (size_t i = 2; i < numValueIn; i++) { // 2: skip startGate and valueGate
54        if (acc_.GetValueIn(gate, i) != valueGate) {
55            return false;
56        }
57    }
58
59    // check if value satisfies a = a + x
60    if (acc_.GetValueIn(valueGate, 0) != gate && acc_.GetValueIn(valueGate, 1) != gate) {
61        return false;
62    }
63    GateRef stride = acc_.GetValueIn(valueGate, 1);
64    if (acc_.GetValueIn(valueGate, 0) != gate) {
65        stride = acc_.GetValueIn(valueGate, 0);
66    }
67    if (!IsIntConstant(stride)) {
68        return false;
69    }
70    return true;
71}
72
73std::pair<int32_t, int32_t> InductionVariableAnalysis::GetStartAndStride(GateRef gate) const
74{
75    ASSERT(acc_.GetOpCode(gate) == OpCode::VALUE_SELECTOR);
76    GateRef startGate = acc_.GetValueIn(gate, 0);
77    ASSERT(IsIntConstant(startGate));
78    auto start = GetIntFromTaggedConstant(startGate);
79
80    GateRef valueGate = acc_.GetValueIn(gate, 1);
81    ASSERT(acc_.GetOpCode(valueGate) == OpCode::TYPED_BINARY_OP);
82    [[maybe_unused]]TypedBinOp binOp = acc_.GetTypedBinaryOp(valueGate);
83    ASSERT(binOp == TypedBinOp::TYPED_ADD || binOp == TypedBinOp::TYPED_SUB);
84    TypedBinaryAccessor accessor(acc_.TryGetValue(valueGate));
85    [[maybe_unused]]const ParamType paramType = accessor.GetParamType();
86    ASSERT(paramType.IsIntType());
87
88    GateRef strideGate = acc_.GetValueIn(valueGate, 1);
89    if (acc_.GetValueIn(valueGate, 0) != gate) {
90        strideGate = acc_.GetValueIn(valueGate, 0);
91    }
92    ASSERT(IsIntConstant(strideGate));
93    auto stride = GetIntFromTaggedConstant(strideGate);
94
95    // a - xb < c -> a + (-x)b < c
96    if (acc_.GetTypedBinaryOp(valueGate) == TypedBinOp::TYPED_SUB) {
97        stride = -stride;
98    }
99
100    return std::make_pair(start, stride);
101}
102
103int32_t InductionVariableAnalysis::GetIntFromTaggedConstant(GateRef gate) const
104{
105    ASSERT(acc_.GetOpCode(gate) == OpCode::CONSTANT);
106    JSTaggedValue value(acc_.GetConstantValue(gate));
107    return value.GetInt();
108}
109
110bool InductionVariableAnalysis::IsLessOrGreaterCmp(GateRef gate) const
111{
112    return acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATEREQ ||
113        acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_GREATER ||
114        acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESSEQ ||
115        acc_.GetTypedBinaryOp(gate) == TypedBinOp::TYPED_LESS;
116}
117
118bool InductionVariableAnalysis::TryGetLoopTimes(const GraphLinearizer::LoopInfo& loop, int32_t& loopTimes) const
119{
120    if (loop.loopExits->size() > 1) {
121        return false;
122    }
123    ASSERT(loop.loopExits->size() == 1);
124    GateRef loopExit = loop.loopExits->at(0)->GetState();
125    ASSERT(acc_.GetOpCode(loopExit) == OpCode::IF_TRUE || acc_.GetOpCode(loopExit) == OpCode::IF_FALSE);
126    GateRef conditionJump = acc_.GetState(loopExit);
127    GateRef cmp = acc_.GetValueIn(conditionJump);
128    if (acc_.GetOpCode(cmp) != OpCode::TYPED_BINARY_OP || !IsLessOrGreaterCmp(cmp)) {
129        return false;
130    }
131    GateRef limitGate = acc_.GetValueIn(cmp, 1);
132    if (!IsIntConstant(limitGate)) {
133        return false;
134    }
135    int32_t limit = GetIntFromTaggedConstant(limitGate);
136
137    GateRef selector = acc_.GetValueIn(cmp, 0);
138    if (!IsInductionVariable(selector)) {
139        return false;
140    }
141
142    auto [start, stride] = GetStartAndStride(selector);
143
144    bool cmpFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^
145        (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^
146        (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ ||
147        acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESS);
148    bool equalFlag = (acc_.GetOpCode(loopExit) == OpCode::IF_TRUE) ^
149        (acc_.GetTypedJumpAccessor(conditionJump).GetTypedJumpOp() == TypedJumpOp::TYPED_JEQZ) ^
150        (acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_GREATEREQ ||
151        acc_.GetTypedBinaryOp(cmp) == TypedBinOp::TYPED_LESSEQ);
152
153    // a + xb >= c -> c - xb <= a
154    if (!cmpFlag) {
155        std::swap(start, limit);
156        stride = -stride;
157    }
158    // a + xb < c -> a + xb <= c - 1
159    if (!equalFlag) {
160        limit--;
161    }
162    if (start > limit) {
163        loopTimes = 0;
164        return true;
165    }
166    loopTimes = (limit - start) / stride + 1;
167    if (IsLogEnabled() && IsTraced()) {
168        LOG_COMPILER(INFO) << "loopTimes: "<< loopTimes << " start: " << start
169                           << " stride: " << stride << " limit: " << limit;
170    }
171    return true;
172}
173
174void InductionVariableAnalysis::CollectInductionSelector()
175{
176    for (const auto &loop : graphLinearizer_.loops_) {
177        int32_t loopTimes = 0;
178
179        if (TryGetLoopTimes(loop, loopTimes)) {
180            ReplaceInductionVariable(loop, loopTimes);
181        }
182    }
183}
184
185void InductionVariableAnalysis::ReplaceInductionVariable(const GraphLinearizer::LoopInfo& loop,
186                                                         const int32_t loopTimes)
187{
188    GateRef loopBegin = loop.loopHead->GetState();
189    auto uses = acc_.Uses(loopBegin);
190    for (auto it = uses.begin(); it != uses.end(); it++) {
191        if (acc_.GetOpCode(*it) == OpCode::VALUE_SELECTOR) {
192            ASSERT(acc_.GetState(*it) == loopBegin);
193            if (!IsInductionVariable(*it)) {
194                continue;
195            }
196            auto [start, stride] = GetStartAndStride(*it);
197            int64_t result = start + static_cast<int64_t>(stride) * loopTimes;
198            if (result > static_cast<int64_t>(INT_MAX) || result < static_cast<int64_t>(INT_MIN)) {
199                return;
200            }
201            if (IsLogEnabled() && IsTraced()) {
202                LOG_COMPILER(INFO) << "result = " << start << " + " << stride << " * "
203                                    << loopTimes << " = " << result;
204            }
205            TryReplaceOutOfLoopUses(*it, loop, static_cast<int32_t>(result));
206        }
207    }
208}
209
210void InductionVariableAnalysis::TryReplaceOutOfLoopUses(GateRef gate,
211                                                        const GraphLinearizer::LoopInfo& loop,
212                                                        const int32_t result)
213{
214    ASSERT(IsInductionVariable(gate));
215    auto uses = acc_.Uses(gate);
216    for (auto it = uses.begin(); it != uses.end();) {
217        auto region = graphLinearizer_.GateToRegion(*it);
218        if (!loop.loopBodys->TestBit(region->GetId()) && loop.loopHead != region) {
219            GateRef constantValue = builder_.Int32(result);
220            it = acc_.ReplaceIn(it, constantValue);
221        } else {
222            it++;
223        }
224    }
225}
226
227void InductionVariableAnalysis::Run()
228{
229    graphLinearizer_.SetScheduleJSOpcode();
230    graphLinearizer_.LinearizeGraph();
231    CollectInductionSelector();
232    if (IsLogEnabled()) {
233        LOG_COMPILER(INFO) << "";
234        LOG_COMPILER(INFO) << "\033[34m"
235                           << "===================="
236                           << " After Induction Variable Analysis "
237                           << "[" << GetMethodName() << "]"
238                           << "===================="
239                           << "\033[0m";
240        circuit_->PrintAllGatesWithBytecode();
241        LOG_COMPILER(INFO) << "\033[34m" << "========================= End ==========================" << "\033[0m";
242    }
243}
244
245}