1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef MPL2MPL_INCLUDE_SCC_H
17#define MPL2MPL_INCLUDE_SCC_H
18#include "base_graph_node.h"
19namespace maple {
20class BaseGraphNode;
21
// Shift width used by SCCNode::GetUniqueID(): the unique id is the SCC id shifted left by this many bits.
constexpr uint32 kShiftSccUniqueIDNum = 16;
23
24// Note that T is the type of the graph node.
25template <typename T>
26class SCCNode {
27public:
28    SCCNode(uint32 index, MapleAllocator &alloc)
29        : id(index), nodes(alloc.Adapter()), inScc(alloc.Adapter()), outScc(alloc.Adapter())
30    {
31    }
32
33    ~SCCNode() = default;
34
35    void AddNode(T *node)
36    {
37        nodes.push_back(node);
38    }
39
40    void Dump() const
41    {
42        LogInfo::MapleLogger() << "SCC " << id << " contains " << nodes.size() << " node(s)\n";
43        for (auto const kIt : nodes) {
44            T *node = kIt;
45            LogInfo::MapleLogger() << "  " << node->GetIdentity() << "\n";
46        }
47    }
48
49    void DumpCycle() const
50    {
51        T *currNode = nodes[0];
52        std::vector<T *> searched;
53        searched.push_back(currNode);
54        std::vector<T *> invalidNodes;
55        std::vector<BaseGraphNode *> outNodes;
56        while (true) {
57            bool findNewOut = false;
58            currNode->GetOutNodes(outNodes);
59            for (auto outIt : outNodes) {
60                auto outNode = static_cast<T *>(outIt);
61                if (outNode->GetSCCNode() == this) {
62                    size_t j = 0;
63                    for (; j < invalidNodes.size(); ++j) {
64                        if (invalidNodes[j] == outNode) {
65                            break;
66                        }
67                    }
68                    // Find a invalid node
69                    if (j < invalidNodes.size()) {
70                        continue;
71                    }
72                    for (j = 0; j < searched.size(); ++j) {
73                        if (searched[j] == outNode) {
74                            break;
75                        }
76                    }
77                    if (j == searched.size()) {
78                        currNode = outNode;
79                        searched.push_back(currNode);
80                        findNewOut = true;
81                        break;
82                    }
83                }
84            }
85            outNodes.clear();
86            if (searched.size() == nodes.size()) {
87                break;
88            }
89            if (!findNewOut) {
90                invalidNodes.push_back(searched[searched.size() - 1]);
91                searched.pop_back();
92                currNode = searched[searched.size() - 1];
93            }
94        }
95        for (auto it = searched.begin(); it != searched.end(); ++it) {
96            LogInfo::MapleLogger() << (*it)->GetIdentity() << '\n';
97        }
98    }
99
100    void Verify() const
101    {
102        CHECK_FATAL(!nodes.empty(), "the size of nodes less than zero");
103        for (T *const &node : nodes) {
104            if (node->GetSCCNode() != this) {
105                CHECK_FATAL(false, "must equal this");
106            }
107        }
108    }
109
110    void Setup()
111    {
112        std::vector<BaseGraphNode *> outNodes;
113        std::vector<BaseGraphNode *> inNodes;
114        for (T *const &node : nodes) {
115            node->GetOutNodes(outNodes);
116            for (auto outIt : outNodes) {
117                auto outNode = static_cast<T *>(outIt);
118                if (outNode == nullptr) {
119                    continue;
120                }
121                auto outNodeScc = outNode->GetSCCNode();
122                if (outNodeScc == this) {
123                    continue;
124                }
125                outScc.insert(outNodeScc);
126                outNodeScc->inScc.insert(this);
127            }
128            outNodes.clear();
129        }
130    }
131
132    const MapleVector<T *> &GetNodes() const
133    {
134        return nodes;
135    }
136
137    MapleVector<T *> &GetNodes()
138    {
139        return nodes;
140    }
141
142    const MapleSet<SCCNode<T> *, Comparator<SCCNode<T>>> &GetOutScc() const
143    {
144        return outScc;
145    }
146
147    const MapleSet<SCCNode<T> *, Comparator<SCCNode<T>>> &GetInScc() const
148    {
149        return inScc;
150    }
151
152    void RemoveInScc(SCCNode<T> *const sccNode)
153    {
154        inScc.erase(sccNode);
155    }
156
157    bool HasRecursion() const
158    {
159        if (nodes.empty()) {
160            return false;
161        }
162        if (nodes.size() > 1) {
163            return true;
164        }
165        T *node = nodes[0];
166        std::vector<BaseGraphNode *> outNodes;
167        node->GetOutNodes(outNodes);
168        for (auto outIt : outNodes) {
169            auto outNode = static_cast<T *>(outIt);
170            if (outNode == nullptr) {
171                continue;
172            }
173            if (node == outNode) {
174                return true;
175            }
176        }
177        return false;
178    }
179
180    bool HasSelfRecursion() const
181    {
182        if (nodes.size() != 1) {
183            return false;
184        }
185        T *node = nodes[0];
186        std::vector<BaseGraphNode *> outNodes;
187        node->GetOutNodes(outNodes);
188        for (auto outIt : outNodes) {
189            auto outNode = static_cast<T *>(outIt);
190            if (outNode == nullptr) {
191                continue;
192            }
193            if (node == outNode) {
194                return true;
195            }
196        }
197        return false;
198    }
199
200    bool HasInScc() const
201    {
202        return (!inScc.empty());
203    }
204
205    uint32 GetID() const
206    {
207        return id;
208    }
209
210    uint32 GetUniqueID() const
211    {
212        return GetID() << maple::kShiftSccUniqueIDNum;
213    }
214
215private:
216    uint32 id;
217    MapleVector<T *> nodes;
218    MapleSet<SCCNode<T> *, Comparator<SCCNode<T>>> inScc;
219    MapleSet<SCCNode<T> *, Comparator<SCCNode<T>>> outScc;
220};
221
222template <typename T>
223void BuildSCCDFS(T &rootNode, uint32 &visitIndex, MapleVector<SCCNode<T> *> &sccNodes, std::vector<T *> &nodes,
224                 std::vector<uint32> &visitedOrder, std::vector<uint32> &lowestOrder, std::vector<bool> &inStack,
225                 std::vector<uint32> &visitStack, uint32 &numOfSccs, MapleAllocator &cgAlloc)
226{
227    uint32 id = rootNode.GetID();
228    nodes.at(id) = &rootNode;
229    visitedOrder.at(id) = visitIndex;
230    lowestOrder.at(id) = visitIndex;
231    ++visitIndex;
232    inStack.at(id) = true;
233
234    std::vector<BaseGraphNode *> outNodes;
235    rootNode.GetOutNodes(outNodes);
236    for (auto outIt : outNodes) {
237        auto outNode = static_cast<T *>(outIt);
238        if (outNode == nullptr) {
239            continue;
240        }
241        uint32 outNodeId = outNode->GetID();
242        if (visitedOrder.at(outNodeId) == 0) {
243            // callee has not been processed yet
244            BuildSCCDFS(*outNode, visitIndex, sccNodes, nodes, visitedOrder, lowestOrder, inStack, visitStack,
245                        numOfSccs, cgAlloc);
246            if (lowestOrder.at(outNodeId) < lowestOrder.at(id)) {
247                lowestOrder.at(id) = lowestOrder.at(outNodeId);
248            }
249        } else if (inStack.at(outNodeId) && (visitedOrder.at(outNodeId) < lowestOrder.at(id))) {
250            // back edge
251            lowestOrder.at(id) = visitedOrder.at(outNodeId);
252        }
253    }
254
255    if (visitedOrder.at(id) == lowestOrder.at(id)) {
256        SCCNode<T> *sccNode = cgAlloc.GetMemPool()->New<SCCNode<T>>(numOfSccs++, cgAlloc);
257        inStack.at(id) = false;
258        T *rootNode = nodes.at(id);
259        rootNode->SetSCCNode(sccNode);
260        sccNode->AddNode(rootNode);
261        while (!visitStack.empty()) {
262            auto stackTopId = visitStack.back();
263            if (visitedOrder.at(stackTopId) < visitedOrder.at(id)) {
264                break;
265            }
266            visitStack.pop_back();
267            inStack.at(stackTopId) = false;
268            T *topNode = nodes.at(stackTopId);
269            topNode->SetSCCNode(sccNode);
270            sccNode->AddNode(topNode);
271        }
272        sccNodes.push_back(sccNode);
273    } else {
274        visitStack.push_back(id);
275    }
276}
277
278template <typename T>
279void VerifySCC(std::vector<T *> nodes)
280{
281    for (auto node : nodes) {
282        if (node->GetSCCNode() == nullptr) {
283            CHECK_FATAL(false, "nullptr check in VerifySCC()");
284        }
285    }
286}
287
288template <typename T>
289uint32 BuildSCC(MapleAllocator &cgAlloc, uint32 numOfNodes, std::vector<T *> &allNodes, bool debugScc,
290                MapleVector<SCCNode<T> *> &topologicalVec, bool clearOld = false)
291{
292    // This is the mapping between cg_id to node.
293    std::vector<T *> id2NodeMap(numOfNodes, nullptr);
294    std::vector<uint32> visitedOrder(numOfNodes, 0);
295    std::vector<uint32> lowestOrder(numOfNodes, 0);
296    std::vector<bool> inStack(numOfNodes, false);
297    std::vector<uint32> visitStack;
298    uint32 visitIndex = 1;
299    uint32 numOfSccs = 0;
300    if (clearOld) {
301        // clear old scc before computing
302        for (auto node : allNodes) {
303            node->SetSCCNode(nullptr);
304        }
305    }
306    // However, not all SCC can be reached from roots.
307    // E.g. foo()->foo(), foo is not considered as a root.
308    for (auto node : allNodes) {
309        if (node->GetSCCNode() == nullptr) {
310            BuildSCCDFS(*node, visitIndex, topologicalVec, id2NodeMap, visitedOrder, lowestOrder, inStack, visitStack,
311                        numOfSccs, cgAlloc);
312        }
313    }
314    for (auto scc : topologicalVec) {
315        scc->Verify();
316        scc->Setup();  // fix caller and callee info.
317        if (debugScc && scc->HasRecursion()) {
318            scc->Dump();
319        }
320    }
321    std::reverse(topologicalVec.begin(), topologicalVec.end());
322    return numOfSccs;
323}
324}  // namespace maple
325#endif  // MPL2MPL_INCLUDE_SCC_H