1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/struct_cfg_analysis.h"
16
17 #include "source/opt/ir_context.h"
18
19 namespace spvtools {
20 namespace opt {
namespace {
// In-operand positions inside a block's merge instruction (OpSelectionMerge or
// OpLoopMerge). The merge block id comes first; for OpLoopMerge the continue
// target id follows it.
constexpr uint32_t kMergeNodeIndex{0};
constexpr uint32_t kContinueNodeIndex{1};
}  // namespace
25
StructuredCFGAnalysis(IRContext* ctx)26 StructuredCFGAnalysis::StructuredCFGAnalysis(IRContext* ctx) : context_(ctx) {
27 // If this is not a shader, there are no merge instructions, and not
28 // structured CFG to analyze.
29 if (!context_->get_feature_mgr()->HasCapability(spv::Capability::Shader)) {
30 return;
31 }
32
33 for (auto& func : *context_->module()) {
34 AddBlocksInFunction(&func);
35 }
36 }
37
// Populates |bb_to_construct_| for every block that appears in the structured
// order computed from the entry block of |func|, and marks every merge block
// in |merge_blocks_|.
//
// Blocks are visited in structured order while a stack of active constructs is
// maintained:
// - A construct is pushed when its header (a block with a merge instruction)
//   is visited.
// - It is popped when its merge block is reached.
// Each block is recorded with the ConstructInfo of the construct on top of the
// stack at the time it is visited.
void StructuredCFGAnalysis::AddBlocksInFunction(Function* func) {
  // Nothing to analyze for a function without blocks (presumably a
  // declaration-only function).
  if (func->begin() == func->end()) return;

  std::list<BasicBlock*> order;
  context_->cfg()->ComputeStructuredOrder(func, &*func->begin(), &order);

  // One entry of the construct stack:
  //   cinfo         - what to record for blocks inside this construct.
  //   merge_node    - the block id that ends this construct.
  //   continue_node - the continue target of the innermost enclosing loop
  //                   (inherited unchanged by non-loop constructs).
  struct TraversalInfo {
    ConstructInfo cinfo;
    uint32_t merge_node;
    uint32_t continue_node;
  };

  // Set up a stack to keep track of currently active constructs.
  // The bottom entry means "not inside any construct". Its ids are all 0, the
  // "none" sentinel used throughout this file, so no block ever matches its
  // merge_node and it is never popped.
  std::vector<TraversalInfo> state;
  state.emplace_back();
  state[0].cinfo.containing_construct = 0;
  state[0].cinfo.containing_loop = 0;
  state[0].cinfo.containing_switch = 0;
  state[0].cinfo.in_continue = false;
  state[0].merge_node = 0;
  state[0].continue_node = 0;

  for (BasicBlock* block : order) {
    // The CFG's pseudo entry/exit blocks are not real blocks of |func|.
    if (context_->cfg()->IsPseudoEntryBlock(block) ||
        context_->cfg()->IsPseudoExitBlock(block)) {
      continue;
    }

    // Reaching a construct's merge block ends that construct. The merge block
    // is popped *before* being recorded, so it is attributed to the enclosing
    // construct. Only one level is popped.
    // NOTE(review): this relies on each block being the merge block of at most
    // one header, which is presumably guaranteed for valid SPIR-V — confirm.
    if (block->id() == state.back().merge_node) {
      state.pop_back();
    }

    // This works because the structured order is designed to keep the blocks in
    // the continue construct between the continue header and the merge node.
    if (block->id() == state.back().continue_node) {
      state.back().cinfo.in_continue = true;
    }

    // Record the innermost constructs containing this block. A header is
    // recorded here with the info of the construct that *encloses* it, before
    // its own construct is pushed below.
    bb_to_construct_.emplace(std::make_pair(block->id(), state.back().cinfo));

    // A block with a merge instruction is a construct header: push a new
    // construct that lasts until its merge block is reached.
    if (Instruction* merge_inst = block->GetMergeInst()) {
      TraversalInfo new_state;
      new_state.merge_node =
          merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
      new_state.cinfo.containing_construct = block->id();

      if (merge_inst->opcode() == spv::Op::OpLoopMerge) {
        // Loop construct:
        // - It becomes the innermost loop.
        // - containing_switch is reset to 0.
        // - Its own continue target is tracked.
        new_state.cinfo.containing_loop = block->id();
        new_state.cinfo.containing_switch = 0;
        new_state.continue_node =
            merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
        if (block->id() == new_state.continue_node) {
          // The header is its own continue target. The loop's blocks, and the
          // header's own entry recorded above, are all in the continue
          // construct.
          new_state.cinfo.in_continue = true;
          bb_to_construct_[block->id()].in_continue = true;
        } else {
          new_state.cinfo.in_continue = false;
        }
      } else {
        // Selection construct: it inherits the loop and continue context of
        // the enclosing construct.
        new_state.cinfo.containing_loop = state.back().cinfo.containing_loop;
        new_state.cinfo.in_continue = state.back().cinfo.in_continue;
        new_state.continue_node = state.back().continue_node;

        // The header is a switch construct when the instruction right after
        // its merge instruction is OpSwitch. Otherwise the innermost switch is
        // inherited.
        if (merge_inst->NextNode()->opcode() == spv::Op::OpSwitch) {
          new_state.cinfo.containing_switch = block->id();
        } else {
          new_state.cinfo.containing_switch =
              state.back().cinfo.containing_switch;
        }
      }

      state.emplace_back(new_state);
      merge_blocks_.Set(new_state.merge_node);
    }
  }
}
113
ContainingConstruct(Instruction* inst)114 uint32_t StructuredCFGAnalysis::ContainingConstruct(Instruction* inst) {
115 uint32_t bb = context_->get_instr_block(inst)->id();
116 return ContainingConstruct(bb);
117 }
118
MergeBlock(uint32_t bb_id)119 uint32_t StructuredCFGAnalysis::MergeBlock(uint32_t bb_id) {
120 uint32_t header_id = ContainingConstruct(bb_id);
121 if (header_id == 0) {
122 return 0;
123 }
124
125 BasicBlock* header = context_->cfg()->block(header_id);
126 Instruction* merge_inst = header->GetMergeInst();
127 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
128 }
129
NestingDepth(uint32_t bb_id)130 uint32_t StructuredCFGAnalysis::NestingDepth(uint32_t bb_id) {
131 uint32_t result = 0;
132
133 // Find the merge block of the current merge construct as long as the block is
134 // inside a merge construct, exiting one for each iteration.
135 for (uint32_t merge_block_id = MergeBlock(bb_id); merge_block_id != 0;
136 merge_block_id = MergeBlock(merge_block_id)) {
137 result++;
138 }
139
140 return result;
141 }
142
LoopMergeBlock(uint32_t bb_id)143 uint32_t StructuredCFGAnalysis::LoopMergeBlock(uint32_t bb_id) {
144 uint32_t header_id = ContainingLoop(bb_id);
145 if (header_id == 0) {
146 return 0;
147 }
148
149 BasicBlock* header = context_->cfg()->block(header_id);
150 Instruction* merge_inst = header->GetMergeInst();
151 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
152 }
153
LoopContinueBlock(uint32_t bb_id)154 uint32_t StructuredCFGAnalysis::LoopContinueBlock(uint32_t bb_id) {
155 uint32_t header_id = ContainingLoop(bb_id);
156 if (header_id == 0) {
157 return 0;
158 }
159
160 BasicBlock* header = context_->cfg()->block(header_id);
161 Instruction* merge_inst = header->GetMergeInst();
162 return merge_inst->GetSingleWordInOperand(kContinueNodeIndex);
163 }
164
LoopNestingDepth(uint32_t bb_id)165 uint32_t StructuredCFGAnalysis::LoopNestingDepth(uint32_t bb_id) {
166 uint32_t result = 0;
167
168 // Find the merge block of the current loop as long as the block is inside a
169 // loop, exiting a loop for each iteration.
170 for (uint32_t merge_block_id = LoopMergeBlock(bb_id); merge_block_id != 0;
171 merge_block_id = LoopMergeBlock(merge_block_id)) {
172 result++;
173 }
174
175 return result;
176 }
177
SwitchMergeBlock(uint32_t bb_id)178 uint32_t StructuredCFGAnalysis::SwitchMergeBlock(uint32_t bb_id) {
179 uint32_t header_id = ContainingSwitch(bb_id);
180 if (header_id == 0) {
181 return 0;
182 }
183
184 BasicBlock* header = context_->cfg()->block(header_id);
185 Instruction* merge_inst = header->GetMergeInst();
186 return merge_inst->GetSingleWordInOperand(kMergeNodeIndex);
187 }
188
IsContinueBlock(uint32_t bb_id)189 bool StructuredCFGAnalysis::IsContinueBlock(uint32_t bb_id) {
190 assert(bb_id != 0);
191 return LoopContinueBlock(bb_id) == bb_id;
192 }
193
IsInContainingLoopsContinueConstruct( uint32_t bb_id)194 bool StructuredCFGAnalysis::IsInContainingLoopsContinueConstruct(
195 uint32_t bb_id) {
196 auto it = bb_to_construct_.find(bb_id);
197 if (it == bb_to_construct_.end()) {
198 return false;
199 }
200 return it->second.in_continue;
201 }
202
IsInContinueConstruct(uint32_t bb_id)203 bool StructuredCFGAnalysis::IsInContinueConstruct(uint32_t bb_id) {
204 while (bb_id != 0) {
205 if (IsInContainingLoopsContinueConstruct(bb_id)) {
206 return true;
207 }
208 bb_id = ContainingLoop(bb_id);
209 }
210 return false;
211 }
212
IsMergeBlock(uint32_t bb_id)213 bool StructuredCFGAnalysis::IsMergeBlock(uint32_t bb_id) {
214 return merge_blocks_.Get(bb_id);
215 }
216
217 std::unordered_set<uint32_t>
FindFuncsCalledFromContinue()218 StructuredCFGAnalysis::FindFuncsCalledFromContinue() {
219 std::unordered_set<uint32_t> called_from_continue;
220 std::queue<uint32_t> funcs_to_process;
221
222 // First collect the functions that are called directly from a continue
223 // construct.
224 for (Function& func : *context_->module()) {
225 for (auto& bb : func) {
226 if (IsInContainingLoopsContinueConstruct(bb.id())) {
227 for (const Instruction& inst : bb) {
228 if (inst.opcode() == spv::Op::OpFunctionCall) {
229 funcs_to_process.push(inst.GetSingleWordInOperand(0));
230 }
231 }
232 }
233 }
234 }
235
236 // Now collect all of the functions that are indirectly called as well.
237 while (!funcs_to_process.empty()) {
238 uint32_t func_id = funcs_to_process.front();
239 funcs_to_process.pop();
240 Function* func = context_->GetFunction(func_id);
241 if (called_from_continue.insert(func_id).second) {
242 context_->AddCalls(func, &funcs_to_process);
243 }
244 }
245 return called_from_continue;
246 }
247
248 } // namespace opt
249 } // namespace spvtools
250