1// Copyright (c) 2021 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/control_dependence.h"
16
#include <algorithm>
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <vector>
19
20#include "gmock/gmock-matchers.h"
21#include "gtest/gtest.h"
22#include "source/opt/build_module.h"
23#include "source/opt/cfg.h"
24#include "test/opt/function_utils.h"
25
26namespace spvtools {
27namespace opt {
28
29namespace {
30void GatherEdges(const ControlDependenceAnalysis& cdg,
31                 std::vector<ControlDependence>& ret) {
32  cdg.ForEachBlockLabel([&](uint32_t label) {
33    ret.reserve(ret.size() + cdg.GetDependenceTargets(label).size());
34    ret.insert(ret.end(), cdg.GetDependenceTargets(label).begin(),
35               cdg.GetDependenceTargets(label).end());
36  });
37  std::sort(ret.begin(), ret.end());
38  // Verify that reverse graph is the same.
39  std::vector<ControlDependence> reverse_edges;
40  reverse_edges.reserve(ret.size());
41  cdg.ForEachBlockLabel([&](uint32_t label) {
42    reverse_edges.insert(reverse_edges.end(),
43                         cdg.GetDependenceSources(label).begin(),
44                         cdg.GetDependenceSources(label).end());
45  });
46  std::sort(reverse_edges.begin(), reverse_edges.end());
47  ASSERT_THAT(reverse_edges, testing::ElementsAreArray(ret));
48}
49
// Test-suite name alias. The tests below are declared with TEST (not
// TEST_F), so this alias is not used as their fixture; it only gives the
// suite name a matching type alias of ::testing::Test.
using ControlDependenceTest = ::testing::Test;
51
52TEST(ControlDependenceTest, DependenceSimpleCFG) {
53  const std::string text = R"(
54               OpCapability Addresses
55               OpCapability Kernel
56               OpMemoryModel Physical64 OpenCL
57               OpEntryPoint Kernel %1 "main"
58          %2 = OpTypeVoid
59          %3 = OpTypeFunction %2
60          %4 = OpTypeBool
61          %5 = OpTypeInt 32 0
62          %6 = OpConstant %5 0
63          %7 = OpConstantFalse %4
64          %8 = OpConstantTrue %4
65          %9 = OpConstant %5 1
66          %1 = OpFunction %2 None %3
67         %10 = OpLabel
68               OpBranch %11
69         %11 = OpLabel
70               OpSwitch %6 %12 1 %13
71         %12 = OpLabel
72               OpBranch %14
73         %13 = OpLabel
74               OpBranch %14
75         %14 = OpLabel
76               OpBranchConditional %8 %15 %16
77         %15 = OpLabel
78               OpBranch %19
79         %16 = OpLabel
80               OpBranchConditional %8 %17 %18
81         %17 = OpLabel
82               OpBranch %18
83         %18 = OpLabel
84               OpBranch %19
85         %19 = OpLabel
86               OpReturn
87               OpFunctionEnd
88)";
89
90  // CFG: (all edges pointing downward)
91  //   %10
92  //    |
93  //   %11
94  //  /   \ (R: %6 == 1, L: default)
95  // %12 %13
96  //  \   /
97  //   %14
98  // T/   \F
99  // %15  %16
100  //  | T/ |F
101  //  | %17|
102  //  |  \ |
103  //  |   %18
104  //  |  /
105  // %19
106
107  std::unique_ptr<IRContext> context =
108      BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
109                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
110  Module* module = context->module();
111  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
112                             << text << std::endl;
113  const Function* fn = spvtest::GetFunction(module, 1);
114  const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10);
115  EXPECT_EQ(entry, fn->entry().get())
116      << "The entry node is not the expected one";
117
118  {
119    PostDominatorAnalysis pdom;
120    const CFG& cfg = *context->cfg();
121    pdom.InitializeTree(cfg, fn);
122    ControlDependenceAnalysis cdg;
123    cdg.ComputeControlDependenceGraph(cfg, pdom);
124
125    // Test HasBlock.
126    for (uint32_t id = 10; id <= 19; id++) {
127      EXPECT_TRUE(cdg.HasBlock(id));
128    }
129    EXPECT_TRUE(cdg.HasBlock(ControlDependenceAnalysis::kPseudoEntryBlock));
130    // Check blocks before/after valid range.
131    EXPECT_FALSE(cdg.HasBlock(5));
132    EXPECT_FALSE(cdg.HasBlock(25));
133    EXPECT_FALSE(cdg.HasBlock(UINT32_MAX));
134
135    // Test ForEachBlockLabel.
136    std::set<uint32_t> block_labels;
137    cdg.ForEachBlockLabel([&block_labels](uint32_t id) {
138      bool inserted = block_labels.insert(id).second;
139      EXPECT_TRUE(inserted);  // Should have no duplicates.
140    });
141    EXPECT_THAT(block_labels, testing::ElementsAre(0, 10, 11, 12, 13, 14, 15,
142                                                   16, 17, 18, 19));
143
144    {
145      // Test WhileEachBlockLabel.
146      uint32_t iters = 0;
147      EXPECT_TRUE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
148        ++iters;
149        return true;
150      }));
151      EXPECT_EQ((uint32_t)block_labels.size(), iters);
152      iters = 0;
153      EXPECT_FALSE(cdg.WhileEachBlockLabel([&iters](uint32_t) {
154        ++iters;
155        return false;
156      }));
157      EXPECT_EQ(1, iters);
158    }
159
160    // Test IsDependent.
161    EXPECT_TRUE(cdg.IsDependent(12, 11));
162    EXPECT_TRUE(cdg.IsDependent(13, 11));
163    EXPECT_TRUE(cdg.IsDependent(15, 14));
164    EXPECT_TRUE(cdg.IsDependent(16, 14));
165    EXPECT_TRUE(cdg.IsDependent(18, 14));
166    EXPECT_TRUE(cdg.IsDependent(17, 16));
167    EXPECT_TRUE(cdg.IsDependent(10, 0));
168    EXPECT_TRUE(cdg.IsDependent(11, 0));
169    EXPECT_TRUE(cdg.IsDependent(14, 0));
170    EXPECT_TRUE(cdg.IsDependent(19, 0));
171    EXPECT_FALSE(cdg.IsDependent(14, 11));
172    EXPECT_FALSE(cdg.IsDependent(17, 14));
173    EXPECT_FALSE(cdg.IsDependent(19, 14));
174    EXPECT_FALSE(cdg.IsDependent(12, 0));
175
176    // Test GetDependenceSources/Targets.
177    std::vector<ControlDependence> edges;
178    GatherEdges(cdg, edges);
179    EXPECT_THAT(edges,
180                testing::ElementsAre(
181                    ControlDependence(0, 10), ControlDependence(0, 11, 10),
182                    ControlDependence(0, 14, 10), ControlDependence(0, 19, 10),
183                    ControlDependence(11, 12), ControlDependence(11, 13),
184                    ControlDependence(14, 15), ControlDependence(14, 16),
185                    ControlDependence(14, 18, 16), ControlDependence(16, 17)));
186
187    const uint32_t expected_condition_ids[] = {
188        0, 0, 0, 0, 6, 6, 8, 8, 8, 8,
189    };
190
191    for (uint32_t i = 0; i < edges.size(); i++) {
192      EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
193    }
194  }
195}
196
197TEST(ControlDependenceTest, DependencePaperCFG) {
198  const std::string text = R"(
199               OpCapability Addresses
200               OpCapability Kernel
201               OpMemoryModel Physical64 OpenCL
202               OpEntryPoint Kernel %101 "main"
203        %102 = OpTypeVoid
204        %103 = OpTypeFunction %102
205        %104 = OpTypeBool
206        %108 = OpConstantTrue %104
207        %101 = OpFunction %102 None %103
208          %1 = OpLabel
209               OpBranch %2
210          %2 = OpLabel
211               OpBranchConditional %108 %3 %7
212          %3 = OpLabel
213               OpBranchConditional %108 %4 %5
214          %4 = OpLabel
215               OpBranch %6
216          %5 = OpLabel
217               OpBranch %6
218          %6 = OpLabel
219               OpBranch %8
220          %7 = OpLabel
221               OpBranch %8
222          %8 = OpLabel
223               OpBranch %9
224          %9 = OpLabel
225               OpBranchConditional %108 %10 %11
226         %10 = OpLabel
227               OpBranch %11
228         %11 = OpLabel
229               OpBranchConditional %108 %12 %9
230         %12 = OpLabel
231               OpBranchConditional %108 %13 %2
232         %13 = OpLabel
233               OpReturn
234               OpFunctionEnd
235)";
236
237  // CFG: (edges pointing downward if no arrow)
238  //         %1
239  //         |
240  //         %2 <----+
241  //       T/  \F    |
242  //      %3    \    |
243  //    T/  \F   \   |
244  //    %4  %5    %7 |
245  //     \  /    /   |
246  //      %6    /    |
247  //        \  /     |
248  //         %8      |
249  //         |       |
250  //         %9 <-+  |
251  //       T/  |  |  |
252  //       %10 |  |  |
253  //        \  |  |  |
254  //         %11-F+  |
255  //         T|      |
256  //         %12-F---+
257  //         T|
258  //         %13
259
260  std::unique_ptr<IRContext> context =
261      BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text,
262                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
263  Module* module = context->module();
264  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
265                             << text << std::endl;
266  const Function* fn = spvtest::GetFunction(module, 101);
267  const BasicBlock* entry = spvtest::GetBasicBlock(fn, 1);
268  EXPECT_EQ(entry, fn->entry().get())
269      << "The entry node is not the expected one";
270
271  {
272    PostDominatorAnalysis pdom;
273    const CFG& cfg = *context->cfg();
274    pdom.InitializeTree(cfg, fn);
275    ControlDependenceAnalysis cdg;
276    cdg.ComputeControlDependenceGraph(cfg, pdom);
277
278    std::vector<ControlDependence> edges;
279    GatherEdges(cdg, edges);
280    EXPECT_THAT(
281        edges, testing::ElementsAre(
282                   ControlDependence(0, 1), ControlDependence(0, 2, 1),
283                   ControlDependence(0, 8, 1), ControlDependence(0, 9, 1),
284                   ControlDependence(0, 11, 1), ControlDependence(0, 12, 1),
285                   ControlDependence(0, 13, 1), ControlDependence(2, 3),
286                   ControlDependence(2, 6, 3), ControlDependence(2, 7),
287                   ControlDependence(3, 4), ControlDependence(3, 5),
288                   ControlDependence(9, 10), ControlDependence(11, 9),
289                   ControlDependence(11, 11, 9), ControlDependence(12, 2),
290                   ControlDependence(12, 8, 2), ControlDependence(12, 9, 2),
291                   ControlDependence(12, 11, 2), ControlDependence(12, 12, 2)));
292
293    const uint32_t expected_condition_ids[] = {
294        0,   0,   0,   0,   0,   0,   0,   108, 108, 108,
295        108, 108, 108, 108, 108, 108, 108, 108, 108, 108,
296    };
297
298    for (uint32_t i = 0; i < edges.size(); i++) {
299      EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg));
300    }
301  }
302}
303
304}  // namespace
305}  // namespace opt
306}  // namespace spvtools
307