1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <memory>
16#include <string>
17#include <unordered_set>
18#include <vector>
19
20#include "gmock/gmock.h"
21#include "source/opt/iterator.h"
22#include "source/opt/loop_descriptor.h"
23#include "source/opt/pass.h"
24#include "source/opt/tree_iterator.h"
25#include "test/opt/assembly_builder.h"
26#include "test/opt/function_utils.h"
27#include "test/opt/pass_fixture.h"
28#include "test/opt/pass_utils.h"
29
30namespace spvtools {
31namespace opt {
32namespace {
33
34using ::testing::UnorderedElementsAre;
35
36bool Validate(const std::vector<uint32_t>& bin) {
37  spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
38  spv_context spvContext = spvContextCreate(target_env);
39  spv_diagnostic diagnostic = nullptr;
40  spv_const_binary_t binary = {bin.data(), bin.size()};
41  spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
42  if (error != 0) spvDiagnosticPrint(diagnostic);
43  spvDiagnosticDestroy(diagnostic);
44  spvContextDestroy(spvContext);
45  return error == 0;
46}
47
48using PassClassTest = PassTest<::testing::Test>;
49
50/*
51Generated from the following GLSL
52#version 330 core
53layout(location = 0) out vec4 c;
54void main() {
55  int i = 0;
56  for (; i < 10; ++i) {
57    int j = 0;
58    int k = 0;
59    for (; j < 11; ++j) {}
60    for (; k < 12; ++k) {}
61  }
62}
63*/
64TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
65  const std::string text = R"(
66               OpCapability Shader
67          %1 = OpExtInstImport "GLSL.std.450"
68               OpMemoryModel Logical GLSL450
69               OpEntryPoint Fragment %2 "main" %3
70               OpExecutionMode %2 OriginUpperLeft
71               OpSource GLSL 330
72               OpName %2 "main"
73               OpName %4 "i"
74               OpName %5 "j"
75               OpName %6 "k"
76               OpName %3 "c"
77               OpDecorate %3 Location 0
78          %7 = OpTypeVoid
79          %8 = OpTypeFunction %7
80          %9 = OpTypeInt 32 1
81         %10 = OpTypePointer Function %9
82         %11 = OpConstant %9 0
83         %12 = OpConstant %9 10
84         %13 = OpTypeBool
85         %14 = OpConstant %9 11
86         %15 = OpConstant %9 1
87         %16 = OpConstant %9 12
88         %17 = OpTypeFloat 32
89         %18 = OpTypeVector %17 4
90         %19 = OpTypePointer Output %18
91          %3 = OpVariable %19 Output
92          %2 = OpFunction %7 None %8
93         %20 = OpLabel
94          %4 = OpVariable %10 Function
95          %5 = OpVariable %10 Function
96          %6 = OpVariable %10 Function
97               OpStore %4 %11
98               OpBranch %21
99         %21 = OpLabel
100               OpLoopMerge %22 %23 None
101               OpBranch %24
102         %24 = OpLabel
103         %25 = OpLoad %9 %4
104         %26 = OpSLessThan %13 %25 %12
105               OpBranchConditional %26 %27 %22
106         %27 = OpLabel
107               OpStore %5 %11
108               OpStore %6 %11
109               OpBranch %28
110         %28 = OpLabel
111               OpLoopMerge %29 %30 None
112               OpBranch %31
113         %31 = OpLabel
114         %32 = OpLoad %9 %5
115         %33 = OpSLessThan %13 %32 %14
116               OpBranchConditional %33 %34 %29
117         %34 = OpLabel
118               OpBranch %30
119         %30 = OpLabel
120         %35 = OpLoad %9 %5
121         %36 = OpIAdd %9 %35 %15
122               OpStore %5 %36
123               OpBranch %28
124         %29 = OpLabel
125               OpBranch %37
126         %37 = OpLabel
127               OpLoopMerge %38 %39 None
128               OpBranch %40
129         %40 = OpLabel
130         %41 = OpLoad %9 %6
131         %42 = OpSLessThan %13 %41 %16
132               OpBranchConditional %42 %43 %38
133         %43 = OpLabel
134               OpBranch %39
135         %39 = OpLabel
136         %44 = OpLoad %9 %6
137         %45 = OpIAdd %9 %44 %15
138               OpStore %6 %45
139               OpBranch %37
140         %38 = OpLabel
141               OpBranch %23
142         %23 = OpLabel
143         %46 = OpLoad %9 %4
144         %47 = OpIAdd %9 %46 %15
145               OpStore %4 %47
146               OpBranch %21
147         %22 = OpLabel
148               OpReturn
149               OpFunctionEnd
150  )";
151  // clang-format on
152  std::unique_ptr<IRContext> context =
153      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
154                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
155  Module* module = context->module();
156  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
157                             << text << std::endl;
158  const Function* f = spvtest::GetFunction(module, 2);
159  LoopDescriptor& ld = *context->GetLoopDescriptor(f);
160
161  EXPECT_EQ(ld.NumLoops(), 3u);
162
163  // Invalid basic block id.
164  EXPECT_EQ(ld[0u], nullptr);
165  // Not a loop header.
166  EXPECT_EQ(ld[20], nullptr);
167
168  Loop& parent_loop = *ld[21];
169  EXPECT_TRUE(parent_loop.HasNestedLoops());
170  EXPECT_FALSE(parent_loop.IsNested());
171  EXPECT_EQ(parent_loop.GetDepth(), 1u);
172  EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
173  EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
174  EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
175  EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));
176
177  Loop& child_loop_1 = *ld[28];
178  EXPECT_FALSE(child_loop_1.HasNestedLoops());
179  EXPECT_TRUE(child_loop_1.IsNested());
180  EXPECT_EQ(child_loop_1.GetDepth(), 2u);
181  EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
182  EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
183  EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
184  EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));
185
186  Loop& child_loop_2 = *ld[37];
187  EXPECT_FALSE(child_loop_2.HasNestedLoops());
188  EXPECT_TRUE(child_loop_2.IsNested());
189  EXPECT_EQ(child_loop_2.GetDepth(), 2u);
190  EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
191  EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
192  EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
193  EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
194}
195
196static void CheckLoopBlocks(Loop* loop,
197                            std::unordered_set<uint32_t>* expected_ids) {
198  SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
199  for (uint32_t bb_id : loop->GetBlocks()) {
200    EXPECT_EQ(expected_ids->count(bb_id), 1u);
201    expected_ids->erase(bb_id);
202  }
203  EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
204  EXPECT_EQ(expected_ids->size(), 0u);
205}
206
207/*
208Generated from the following GLSL
209#version 330 core
210layout(location = 0) out vec4 c;
211void main() {
212  int i = 0;
213  for (; i < 10; ++i) {
214    for (int j = 0; j < 11; ++j) {
215      if (j < 5) {
216        for (int k = 0; k < 12; ++k) {}
217      }
218      else {}
219      for (int k = 0; k < 12; ++k) {}
220    }
221  }
222}*/
223TEST_F(PassClassTest, TripleNestedLoop) {
224  const std::string text = R"(
225               OpCapability Shader
226          %1 = OpExtInstImport "GLSL.std.450"
227               OpMemoryModel Logical GLSL450
228               OpEntryPoint Fragment %2 "main" %3
229               OpExecutionMode %2 OriginUpperLeft
230               OpSource GLSL 330
231               OpName %2 "main"
232               OpName %4 "i"
233               OpName %5 "j"
234               OpName %6 "k"
235               OpName %7 "k"
236               OpName %3 "c"
237               OpDecorate %3 Location 0
238          %8 = OpTypeVoid
239          %9 = OpTypeFunction %8
240         %10 = OpTypeInt 32 1
241         %11 = OpTypePointer Function %10
242         %12 = OpConstant %10 0
243         %13 = OpConstant %10 10
244         %14 = OpTypeBool
245         %15 = OpConstant %10 11
246         %16 = OpConstant %10 5
247         %17 = OpConstant %10 12
248         %18 = OpConstant %10 1
249         %19 = OpTypeFloat 32
250         %20 = OpTypeVector %19 4
251         %21 = OpTypePointer Output %20
252          %3 = OpVariable %21 Output
253          %2 = OpFunction %8 None %9
254         %22 = OpLabel
255          %4 = OpVariable %11 Function
256          %5 = OpVariable %11 Function
257          %6 = OpVariable %11 Function
258          %7 = OpVariable %11 Function
259               OpStore %4 %12
260               OpBranch %23
261         %23 = OpLabel
262               OpLoopMerge %24 %25 None
263               OpBranch %26
264         %26 = OpLabel
265         %27 = OpLoad %10 %4
266         %28 = OpSLessThan %14 %27 %13
267               OpBranchConditional %28 %29 %24
268         %29 = OpLabel
269               OpStore %5 %12
270               OpBranch %30
271         %30 = OpLabel
272               OpLoopMerge %31 %32 None
273               OpBranch %33
274         %33 = OpLabel
275         %34 = OpLoad %10 %5
276         %35 = OpSLessThan %14 %34 %15
277               OpBranchConditional %35 %36 %31
278         %36 = OpLabel
279         %37 = OpLoad %10 %5
280         %38 = OpSLessThan %14 %37 %16
281               OpSelectionMerge %39 None
282               OpBranchConditional %38 %40 %39
283         %40 = OpLabel
284               OpStore %6 %12
285               OpBranch %41
286         %41 = OpLabel
287               OpLoopMerge %42 %43 None
288               OpBranch %44
289         %44 = OpLabel
290         %45 = OpLoad %10 %6
291         %46 = OpSLessThan %14 %45 %17
292               OpBranchConditional %46 %47 %42
293         %47 = OpLabel
294               OpBranch %43
295         %43 = OpLabel
296         %48 = OpLoad %10 %6
297         %49 = OpIAdd %10 %48 %18
298               OpStore %6 %49
299               OpBranch %41
300         %42 = OpLabel
301               OpBranch %39
302         %39 = OpLabel
303               OpStore %7 %12
304               OpBranch %50
305         %50 = OpLabel
306               OpLoopMerge %51 %52 None
307               OpBranch %53
308         %53 = OpLabel
309         %54 = OpLoad %10 %7
310         %55 = OpSLessThan %14 %54 %17
311               OpBranchConditional %55 %56 %51
312         %56 = OpLabel
313               OpBranch %52
314         %52 = OpLabel
315         %57 = OpLoad %10 %7
316         %58 = OpIAdd %10 %57 %18
317               OpStore %7 %58
318               OpBranch %50
319         %51 = OpLabel
320               OpBranch %32
321         %32 = OpLabel
322         %59 = OpLoad %10 %5
323         %60 = OpIAdd %10 %59 %18
324               OpStore %5 %60
325               OpBranch %30
326         %31 = OpLabel
327               OpBranch %25
328         %25 = OpLabel
329         %61 = OpLoad %10 %4
330         %62 = OpIAdd %10 %61 %18
331               OpStore %4 %62
332               OpBranch %23
333         %24 = OpLabel
334               OpReturn
335               OpFunctionEnd
336  )";
337  // clang-format on
338  std::unique_ptr<IRContext> context =
339      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
340                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
341  Module* module = context->module();
342  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
343                             << text << std::endl;
344  const Function* f = spvtest::GetFunction(module, 2);
345  LoopDescriptor& ld = *context->GetLoopDescriptor(f);
346
347  EXPECT_EQ(ld.NumLoops(), 4u);
348
349  // Invalid basic block id.
350  EXPECT_EQ(ld[0u], nullptr);
351  // Not in a loop.
352  EXPECT_EQ(ld[22], nullptr);
353
354  // Check that we can map basic block to the correct loop.
355  // The following block ids do not belong to a loop.
356  for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr);
357
358  {
359    std::unordered_set<uint32_t> basic_block_in_loop = {
360        {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
361         42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
362    Loop* loop = ld[23];
363    CheckLoopBlocks(loop, &basic_block_in_loop);
364
365    EXPECT_TRUE(loop->HasNestedLoops());
366    EXPECT_FALSE(loop->IsNested());
367    EXPECT_EQ(loop->GetDepth(), 1u);
368    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
369    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
370    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
371    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
372    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
373    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
374    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
375  }
376
377  {
378    std::unordered_set<uint32_t> basic_block_in_loop = {
379        {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
380    Loop* loop = ld[30];
381    CheckLoopBlocks(loop, &basic_block_in_loop);
382
383    EXPECT_TRUE(loop->HasNestedLoops());
384    EXPECT_TRUE(loop->IsNested());
385    EXPECT_EQ(loop->GetDepth(), 2u);
386    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
387    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
388    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
389    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
390    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
391    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
392    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
393  }
394
395  {
396    std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
397    Loop* loop = ld[41];
398    CheckLoopBlocks(loop, &basic_block_in_loop);
399
400    EXPECT_FALSE(loop->HasNestedLoops());
401    EXPECT_TRUE(loop->IsNested());
402    EXPECT_EQ(loop->GetDepth(), 3u);
403    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
404    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
405    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
406    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
407    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
408    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
409    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
410  }
411
412  {
413    std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
414    Loop* loop = ld[50];
415    CheckLoopBlocks(loop, &basic_block_in_loop);
416
417    EXPECT_FALSE(loop->HasNestedLoops());
418    EXPECT_TRUE(loop->IsNested());
419    EXPECT_EQ(loop->GetDepth(), 3u);
420    EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
421    EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
422    EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
423    EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
424    EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
425    EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
426    EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
427  }
428
429  // Make sure LoopDescriptor gives us the inner most loop when we query for
430  // loops.
431  for (const BasicBlock& bb : *f) {
432    if (Loop* loop = ld[&bb]) {
433      for (Loop& sub_loop :
434           make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
435        EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
436      }
437    }
438  }
439}
440
441/*
442Generated from the following GLSL
443#version 330 core
444layout(location = 0) out vec4 c;
445void main() {
446  for (int i = 0; i < 10; ++i) {
447    for (int j = 0; j < 11; ++j) {
448      for (int k = 0; k < 11; ++k) {}
449    }
450    for (int k = 0; k < 12; ++k) {}
451  }
452}
453*/
454TEST_F(PassClassTest, LoopParentTest) {
455  const std::string text = R"(
456               OpCapability Shader
457          %1 = OpExtInstImport "GLSL.std.450"
458               OpMemoryModel Logical GLSL450
459               OpEntryPoint Fragment %2 "main" %3
460               OpExecutionMode %2 OriginUpperLeft
461               OpSource GLSL 330
462               OpName %2 "main"
463               OpName %4 "i"
464               OpName %5 "j"
465               OpName %6 "k"
466               OpName %7 "k"
467               OpName %3 "c"
468               OpDecorate %3 Location 0
469          %8 = OpTypeVoid
470          %9 = OpTypeFunction %8
471         %10 = OpTypeInt 32 1
472         %11 = OpTypePointer Function %10
473         %12 = OpConstant %10 0
474         %13 = OpConstant %10 10
475         %14 = OpTypeBool
476         %15 = OpConstant %10 11
477         %16 = OpConstant %10 1
478         %17 = OpConstant %10 12
479         %18 = OpTypeFloat 32
480         %19 = OpTypeVector %18 4
481         %20 = OpTypePointer Output %19
482          %3 = OpVariable %20 Output
483          %2 = OpFunction %8 None %9
484         %21 = OpLabel
485          %4 = OpVariable %11 Function
486          %5 = OpVariable %11 Function
487          %6 = OpVariable %11 Function
488          %7 = OpVariable %11 Function
489               OpStore %4 %12
490               OpBranch %22
491         %22 = OpLabel
492               OpLoopMerge %23 %24 None
493               OpBranch %25
494         %25 = OpLabel
495         %26 = OpLoad %10 %4
496         %27 = OpSLessThan %14 %26 %13
497               OpBranchConditional %27 %28 %23
498         %28 = OpLabel
499               OpStore %5 %12
500               OpBranch %29
501         %29 = OpLabel
502               OpLoopMerge %30 %31 None
503               OpBranch %32
504         %32 = OpLabel
505         %33 = OpLoad %10 %5
506         %34 = OpSLessThan %14 %33 %15
507               OpBranchConditional %34 %35 %30
508         %35 = OpLabel
509               OpStore %6 %12
510               OpBranch %36
511         %36 = OpLabel
512               OpLoopMerge %37 %38 None
513               OpBranch %39
514         %39 = OpLabel
515         %40 = OpLoad %10 %6
516         %41 = OpSLessThan %14 %40 %15
517               OpBranchConditional %41 %42 %37
518         %42 = OpLabel
519               OpBranch %38
520         %38 = OpLabel
521         %43 = OpLoad %10 %6
522         %44 = OpIAdd %10 %43 %16
523               OpStore %6 %44
524               OpBranch %36
525         %37 = OpLabel
526               OpBranch %31
527         %31 = OpLabel
528         %45 = OpLoad %10 %5
529         %46 = OpIAdd %10 %45 %16
530               OpStore %5 %46
531               OpBranch %29
532         %30 = OpLabel
533               OpStore %7 %12
534               OpBranch %47
535         %47 = OpLabel
536               OpLoopMerge %48 %49 None
537               OpBranch %50
538         %50 = OpLabel
539         %51 = OpLoad %10 %7
540         %52 = OpSLessThan %14 %51 %17
541               OpBranchConditional %52 %53 %48
542         %53 = OpLabel
543               OpBranch %49
544         %49 = OpLabel
545         %54 = OpLoad %10 %7
546         %55 = OpIAdd %10 %54 %16
547               OpStore %7 %55
548               OpBranch %47
549         %48 = OpLabel
550               OpBranch %24
551         %24 = OpLabel
552         %56 = OpLoad %10 %4
553         %57 = OpIAdd %10 %56 %16
554               OpStore %4 %57
555               OpBranch %22
556         %23 = OpLabel
557               OpReturn
558               OpFunctionEnd
559  )";
560  // clang-format on
561  std::unique_ptr<IRContext> context =
562      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
563                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
564  Module* module = context->module();
565  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
566                             << text << std::endl;
567  const Function* f = spvtest::GetFunction(module, 2);
568  LoopDescriptor& ld = *context->GetLoopDescriptor(f);
569
570  EXPECT_EQ(ld.NumLoops(), 4u);
571
572  {
573    Loop& loop = *ld[22];
574    EXPECT_TRUE(loop.HasNestedLoops());
575    EXPECT_FALSE(loop.IsNested());
576    EXPECT_EQ(loop.GetDepth(), 1u);
577    EXPECT_EQ(loop.GetParent(), nullptr);
578  }
579
580  {
581    Loop& loop = *ld[29];
582    EXPECT_TRUE(loop.HasNestedLoops());
583    EXPECT_TRUE(loop.IsNested());
584    EXPECT_EQ(loop.GetDepth(), 2u);
585    EXPECT_EQ(loop.GetParent(), ld[22]);
586  }
587
588  {
589    Loop& loop = *ld[36];
590    EXPECT_FALSE(loop.HasNestedLoops());
591    EXPECT_TRUE(loop.IsNested());
592    EXPECT_EQ(loop.GetDepth(), 3u);
593    EXPECT_EQ(loop.GetParent(), ld[29]);
594  }
595
596  {
597    Loop& loop = *ld[47];
598    EXPECT_FALSE(loop.HasNestedLoops());
599    EXPECT_TRUE(loop.IsNested());
600    EXPECT_EQ(loop.GetDepth(), 2u);
601    EXPECT_EQ(loop.GetParent(), ld[22]);
602  }
603}
604
605/*
606Generated from the following GLSL + --eliminate-local-multi-store
The preheaders of loops %33 and %41 were removed as well.
608
609#version 330 core
610void main() {
611  int a = 0;
612  for (int i = 0; i < 10; ++i) {
613    if (i == 0) {
614      a = 1;
615    } else {
616      a = 2;
617    }
618    for (int j = 0; j < 11; ++j) {
619      a++;
620    }
621  }
622  for (int k = 0; k < 12; ++k) {}
623}
624*/
625TEST_F(PassClassTest, CreatePreheaderTest) {
626  const std::string text = R"(
627               OpCapability Shader
628          %1 = OpExtInstImport "GLSL.std.450"
629               OpMemoryModel Logical GLSL450
630               OpEntryPoint Fragment %2 "main"
631               OpExecutionMode %2 OriginUpperLeft
632               OpSource GLSL 330
633               OpName %2 "main"
634          %3 = OpTypeVoid
635          %4 = OpTypeFunction %3
636          %5 = OpTypeInt 32 1
637          %6 = OpTypePointer Function %5
638          %7 = OpConstant %5 0
639          %8 = OpConstant %5 10
640          %9 = OpTypeBool
641         %10 = OpConstant %5 1
642         %11 = OpConstant %5 2
643         %12 = OpConstant %5 11
644         %13 = OpConstant %5 12
645         %14 = OpUndef %5
646          %2 = OpFunction %3 None %4
647         %15 = OpLabel
648               OpBranch %16
649         %16 = OpLabel
650         %17 = OpPhi %5 %7 %15 %18 %19
651         %20 = OpPhi %5 %7 %15 %21 %19
652         %22 = OpPhi %5 %14 %15 %23 %19
653               OpLoopMerge %41 %19 None
654               OpBranch %25
655         %25 = OpLabel
656         %26 = OpSLessThan %9 %20 %8
657               OpBranchConditional %26 %27 %41
658         %27 = OpLabel
659         %28 = OpIEqual %9 %20 %7
660               OpSelectionMerge %33 None
661               OpBranchConditional %28 %30 %31
662         %30 = OpLabel
663               OpBranch %33
664         %31 = OpLabel
665               OpBranch %33
666         %33 = OpLabel
667         %18 = OpPhi %5 %10 %30 %11 %31 %34 %35
668         %23 = OpPhi %5 %7 %30 %7 %31 %36 %35
669               OpLoopMerge %37 %35 None
670               OpBranch %38
671         %38 = OpLabel
672         %39 = OpSLessThan %9 %23 %12
673               OpBranchConditional %39 %40 %37
674         %40 = OpLabel
675         %34 = OpIAdd %5 %18 %10
676               OpBranch %35
677         %35 = OpLabel
678         %36 = OpIAdd %5 %23 %10
679               OpBranch %33
680         %37 = OpLabel
681               OpBranch %19
682         %19 = OpLabel
683         %21 = OpIAdd %5 %20 %10
684               OpBranch %16
685         %41 = OpLabel
686         %42 = OpPhi %5 %7 %25 %43 %44
687               OpLoopMerge %45 %44 None
688               OpBranch %46
689         %46 = OpLabel
690         %47 = OpSLessThan %9 %42 %13
691               OpBranchConditional %47 %48 %45
692         %48 = OpLabel
693               OpBranch %44
694         %44 = OpLabel
695         %43 = OpIAdd %5 %42 %10
696               OpBranch %41
697         %45 = OpLabel
698               OpReturn
699               OpFunctionEnd
700  )";
701  // clang-format on
702  std::unique_ptr<IRContext> context =
703      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
704                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
705  Module* module = context->module();
706  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
707                             << text << std::endl;
708  const Function* f = spvtest::GetFunction(module, 2);
709  LoopDescriptor& ld = *context->GetLoopDescriptor(f);
710  // No invalidation of the cfg should occur during this test.
711  CFG* cfg = context->cfg();
712
713  EXPECT_EQ(ld.NumLoops(), 3u);
714
715  {
716    Loop& loop = *ld[16];
717    EXPECT_TRUE(loop.HasNestedLoops());
718    EXPECT_FALSE(loop.IsNested());
719    EXPECT_EQ(loop.GetDepth(), 1u);
720    EXPECT_EQ(loop.GetParent(), nullptr);
721  }
722
723  {
724    Loop& loop = *ld[33];
725    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
726    EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
727    // Make sure the loop descriptor was properly updated.
728    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
729    {
730      const std::vector<uint32_t>& preds =
731          cfg->preds(loop.GetPreHeaderBlock()->id());
732      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
733      EXPECT_EQ(pred_set.size(), 2u);
734      EXPECT_TRUE(pred_set.count(30));
735      EXPECT_TRUE(pred_set.count(31));
736      // Check the phi instructions.
737      loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
738        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
739          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
740        }
741      });
742    }
743    {
744      const std::vector<uint32_t>& preds =
745          cfg->preds(loop.GetHeaderBlock()->id());
746      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
747      EXPECT_EQ(pred_set.size(), 2u);
748      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
749      EXPECT_TRUE(pred_set.count(35));
750      // Check the phi instructions.
751      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
752        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
753          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
754        }
755      });
756    }
757  }
758
759  {
760    Loop& loop = *ld[41];
761    EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
762    EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
763    EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
764    EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u);
765    EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u);
766    // Check the phi instructions.
767    loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) {
768      EXPECT_EQ(phi->NumInOperands(), 2u);
769      EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
770    });
771    {
772      const std::vector<uint32_t>& preds =
773          cfg->preds(loop.GetHeaderBlock()->id());
774      std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
775      EXPECT_EQ(pred_set.size(), 2u);
776      EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
777      EXPECT_TRUE(pred_set.count(44));
778      // Check the phi instructions.
779      loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
780        for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
781          EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
782        }
783      });
784    }
785  }
786
787  // Make sure pre-header insertion leaves the module valid.
788  std::vector<uint32_t> bin;
789  context->module()->ToBinary(&bin, true);
790  EXPECT_TRUE(Validate(bin));
791}
792
793}  // namespace
794}  // namespace opt
795}  // namespace spvtools
796