1// Copyright (c) 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <list>
#include <memory>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "source/opt/ir_context.h"
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"
22
23namespace spvtools {
24namespace opt {
25namespace {
26
27using ::testing::ContainerEq;
28
29using CFGTest = PassTest<::testing::Test>;
30
31TEST_F(CFGTest, ForEachBlockInPostOrderIf) {
32  const std::string test = R"(
33OpCapability Shader
34%1 = OpExtInstImport "GLSL.std.450"
35OpMemoryModel Logical GLSL450
36OpEntryPoint Vertex %main "main"
37OpName %main "main"
38%bool = OpTypeBool
39%true = OpConstantTrue %bool
40%void = OpTypeVoid
41%4 = OpTypeFunction %void
42%uint = OpTypeInt 32 0
43%5 = OpConstant %uint 5
44%main = OpFunction %void None %4
45%8 = OpLabel
46OpSelectionMerge %10 None
47OpBranchConditional %true %9 %10
48%9 = OpLabel
49OpBranch %10
50%10 = OpLabel
51OpReturn
52OpFunctionEnd
53)";
54
55  std::unique_ptr<IRContext> context =
56      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
57                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
58  ASSERT_NE(nullptr, context);
59
60  CFG* cfg = context->cfg();
61  Module* module = context->module();
62  Function* function = &*module->begin();
63  std::vector<uint32_t> order;
64  cfg->ForEachBlockInPostOrder(&*function->begin(), [&order](BasicBlock* bb) {
65    order.push_back(bb->id());
66  });
67
68  std::vector<uint32_t> expected_result = {10, 9, 8};
69  EXPECT_THAT(order, ContainerEq(expected_result));
70}
71
72TEST_F(CFGTest, ForEachBlockInPostOrderLoop) {
73  const std::string test = R"(
74OpCapability Shader
75%1 = OpExtInstImport "GLSL.std.450"
76OpMemoryModel Logical GLSL450
77OpEntryPoint Vertex %main "main"
78OpName %main "main"
79%bool = OpTypeBool
80%true = OpConstantTrue %bool
81%void = OpTypeVoid
82%4 = OpTypeFunction %void
83%uint = OpTypeInt 32 0
84%5 = OpConstant %uint 5
85%main = OpFunction %void None %4
86%8 = OpLabel
87OpBranch %9
88%9 = OpLabel
89OpLoopMerge %11 %10 None
90OpBranchConditional %true %11 %10
91%10 = OpLabel
92OpBranch %9
93%11 = OpLabel
94OpReturn
95OpFunctionEnd
96)";
97
98  std::unique_ptr<IRContext> context =
99      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
100                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
101  ASSERT_NE(nullptr, context);
102
103  CFG* cfg = context->cfg();
104  Module* module = context->module();
105  Function* function = &*module->begin();
106  std::vector<uint32_t> order;
107  cfg->ForEachBlockInPostOrder(&*function->begin(), [&order](BasicBlock* bb) {
108    order.push_back(bb->id());
109  });
110
111  std::vector<uint32_t> expected_result1 = {10, 11, 9, 8};
112  std::vector<uint32_t> expected_result2 = {11, 10, 9, 8};
113  EXPECT_THAT(order, AnyOf(ContainerEq(expected_result1),
114                           ContainerEq(expected_result2)));
115}
116
117TEST_F(CFGTest, ForEachBlockInReversePostOrderIf) {
118  const std::string test = R"(
119OpCapability Shader
120%1 = OpExtInstImport "GLSL.std.450"
121OpMemoryModel Logical GLSL450
122OpEntryPoint Vertex %main "main"
123OpName %main "main"
124%bool = OpTypeBool
125%true = OpConstantTrue %bool
126%void = OpTypeVoid
127%4 = OpTypeFunction %void
128%uint = OpTypeInt 32 0
129%5 = OpConstant %uint 5
130%main = OpFunction %void None %4
131%8 = OpLabel
132OpSelectionMerge %10 None
133OpBranchConditional %true %9 %10
134%9 = OpLabel
135OpBranch %10
136%10 = OpLabel
137OpReturn
138OpFunctionEnd
139)";
140
141  std::unique_ptr<IRContext> context =
142      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
143                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
144  ASSERT_NE(nullptr, context);
145
146  CFG* cfg = context->cfg();
147  Module* module = context->module();
148  Function* function = &*module->begin();
149  std::vector<uint32_t> order;
150  cfg->ForEachBlockInReversePostOrder(
151      &*function->begin(),
152      [&order](BasicBlock* bb) { order.push_back(bb->id()); });
153
154  std::vector<uint32_t> expected_result = {8, 9, 10};
155  EXPECT_THAT(order, ContainerEq(expected_result));
156}
157
158TEST_F(CFGTest, ForEachBlockInReversePostOrderLoop) {
159  const std::string test = R"(
160OpCapability Shader
161%1 = OpExtInstImport "GLSL.std.450"
162OpMemoryModel Logical GLSL450
163OpEntryPoint Vertex %main "main"
164OpName %main "main"
165%bool = OpTypeBool
166%true = OpConstantTrue %bool
167%void = OpTypeVoid
168%4 = OpTypeFunction %void
169%uint = OpTypeInt 32 0
170%5 = OpConstant %uint 5
171%main = OpFunction %void None %4
172%8 = OpLabel
173OpBranch %9
174%9 = OpLabel
175OpLoopMerge %11 %10 None
176OpBranchConditional %true %11 %10
177%10 = OpLabel
178OpBranch %9
179%11 = OpLabel
180OpReturn
181OpFunctionEnd
182)";
183
184  std::unique_ptr<IRContext> context =
185      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
186                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
187  ASSERT_NE(nullptr, context);
188
189  CFG* cfg = context->cfg();
190  Module* module = context->module();
191  Function* function = &*module->begin();
192  std::vector<uint32_t> order;
193  cfg->ForEachBlockInReversePostOrder(
194      &*function->begin(),
195      [&order](BasicBlock* bb) { order.push_back(bb->id()); });
196
197  std::vector<uint32_t> expected_result1 = {8, 9, 10, 11};
198  std::vector<uint32_t> expected_result2 = {8, 9, 11, 10};
199  EXPECT_THAT(order, AnyOf(ContainerEq(expected_result1),
200                           ContainerEq(expected_result2)));
201}
202
203TEST_F(CFGTest, SplitLoopHeaderForSingleBlockLoop) {
204  const std::string test = R"(
205               OpCapability Shader
206          %1 = OpExtInstImport "GLSL.std.450"
207               OpMemoryModel Logical GLSL450
208               OpEntryPoint Fragment %2 "main"
209               OpExecutionMode %2 OriginUpperLeft
210       %void = OpTypeVoid
211       %uint = OpTypeInt 32 0
212     %uint_0 = OpConstant %uint 0
213          %6 = OpTypeFunction %void
214          %2 = OpFunction %void None %6
215          %7 = OpLabel
216               OpBranch %8
217          %8 = OpLabel
218          %9 = OpPhi %uint %uint_0 %7 %9 %8
219               OpLoopMerge %10 %8 None
220               OpBranch %8
221         %10 = OpLabel
222               OpUnreachable
223               OpFunctionEnd
224)";
225
226  const std::string expected_result = R"(OpCapability Shader
227%1 = OpExtInstImport "GLSL.std.450"
228OpMemoryModel Logical GLSL450
229OpEntryPoint Fragment %2 "main"
230OpExecutionMode %2 OriginUpperLeft
231%void = OpTypeVoid
232%uint = OpTypeInt 32 0
233%uint_0 = OpConstant %uint 0
234%6 = OpTypeFunction %void
235%2 = OpFunction %void None %6
236%7 = OpLabel
237OpBranch %8
238%8 = OpLabel
239OpBranch %11
240%11 = OpLabel
241%9 = OpPhi %uint %9 %11 %uint_0 %8
242OpLoopMerge %10 %11 None
243OpBranch %11
244%10 = OpLabel
245OpUnreachable
246OpFunctionEnd
247)";
248
249  std::unique_ptr<IRContext> context =
250      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
251                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
252  ASSERT_NE(nullptr, context);
253
254  BasicBlock* loop_header = context->get_instr_block(8);
255  ASSERT_TRUE(loop_header->GetLoopMergeInst() != nullptr);
256
257  CFG* cfg = context->cfg();
258  cfg->SplitLoopHeader(loop_header);
259
260  std::vector<uint32_t> binary;
261  bool skip_nop = false;
262  context->module()->ToBinary(&binary, skip_nop);
263
264  std::string optimized_asm;
265  SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
266  EXPECT_TRUE(tools.Disassemble(binary, &optimized_asm,
267                                SpirvTools::kDefaultDisassembleOption))
268      << "Disassembling failed for shader\n"
269      << std::endl;
270
271  EXPECT_EQ(optimized_asm, expected_result);
272}
273
274TEST_F(CFGTest, ComputeStructedOrderForLoop) {
275  const std::string test = R"(
276OpCapability Shader
277%1 = OpExtInstImport "GLSL.std.450"
278OpMemoryModel Logical GLSL450
279OpEntryPoint Vertex %main "main"
280OpName %main "main"
281%bool = OpTypeBool
282%true = OpConstantTrue %bool
283%void = OpTypeVoid
284%4 = OpTypeFunction %void
285%uint = OpTypeInt 32 0
286%5 = OpConstant %uint 5
287%main = OpFunction %void None %4
288%8 = OpLabel
289OpBranch %9
290%9 = OpLabel
291OpLoopMerge %11 %10 None
292OpBranchConditional %true %11 %10
293%10 = OpLabel
294OpBranch %9
295%11 = OpLabel
296OpBranch %12
297%12 = OpLabel
298OpReturn
299OpFunctionEnd
300)";
301
302  std::unique_ptr<IRContext> context =
303      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, test,
304                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
305  ASSERT_NE(nullptr, context);
306
307  CFG* cfg = context->cfg();
308  Module* module = context->module();
309  Function* function = &*module->begin();
310  std::list<BasicBlock*> order;
311  cfg->ComputeStructuredOrder(function, context->get_instr_block(9),
312                              context->get_instr_block(11), &order);
313
314  // Order should contain the loop header, the continue target, and the merge
315  // node.
316  std::list<BasicBlock*> expected_result = {context->get_instr_block(9),
317                                            context->get_instr_block(10),
318                                            context->get_instr_block(11)};
319  EXPECT_THAT(order, ContainerEq(expected_result));
320}
321
322}  // namespace
323}  // namespace opt
324}  // namespace spvtools
325