1// Copyright (c) 2021 Google LLC. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "source/opt/control_dependence.h" 16 17#include <algorithm> 18#include <vector> 19 20#include "gmock/gmock-matchers.h" 21#include "gtest/gtest.h" 22#include "source/opt/build_module.h" 23#include "source/opt/cfg.h" 24#include "test/opt/function_utils.h" 25 26namespace spvtools { 27namespace opt { 28 29namespace { 30void GatherEdges(const ControlDependenceAnalysis& cdg, 31 std::vector<ControlDependence>& ret) { 32 cdg.ForEachBlockLabel([&](uint32_t label) { 33 ret.reserve(ret.size() + cdg.GetDependenceTargets(label).size()); 34 ret.insert(ret.end(), cdg.GetDependenceTargets(label).begin(), 35 cdg.GetDependenceTargets(label).end()); 36 }); 37 std::sort(ret.begin(), ret.end()); 38 // Verify that reverse graph is the same. 39 std::vector<ControlDependence> reverse_edges; 40 reverse_edges.reserve(ret.size()); 41 cdg.ForEachBlockLabel([&](uint32_t label) { 42 reverse_edges.insert(reverse_edges.end(), 43 cdg.GetDependenceSources(label).begin(), 44 cdg.GetDependenceSources(label).end()); 45 }); 46 std::sort(reverse_edges.begin(), reverse_edges.end()); 47 ASSERT_THAT(reverse_edges, testing::ElementsAreArray(ret)); 48} 49 50using ControlDependenceTest = ::testing::Test; 51 52TEST(ControlDependenceTest, DependenceSimpleCFG) { 53 const std::string text = R"( 54 OpCapability Addresses 55 OpCapability Kernel 56 OpMemoryModel Physical64 OpenCL 57 OpEntryPoint Kernel %1 "main" 58 %2 = OpTypeVoid 59 %3 = OpTypeFunction %2 60 %4 = OpTypeBool 61 %5 = OpTypeInt 32 0 62 %6 = OpConstant %5 0 63 %7 = OpConstantFalse %4 64 %8 = OpConstantTrue %4 65 %9 = OpConstant %5 1 66 %1 = OpFunction %2 None %3 67 %10 = OpLabel 68 OpBranch %11 69 %11 = OpLabel 70 OpSwitch %6 %12 1 %13 71 %12 = OpLabel 72 OpBranch %14 73 %13 = OpLabel 74 OpBranch %14 75 %14 = OpLabel 76 OpBranchConditional %8 %15 %16 77 %15 = OpLabel 78 OpBranch %19 79 %16 = OpLabel 80 OpBranchConditional %8 %17 %18 81 %17 = OpLabel 82 OpBranch %18 83 %18 = OpLabel 84 OpBranch %19 85 %19 = OpLabel 86 OpReturn 87 OpFunctionEnd 88)"; 89 90 // CFG: (all edges pointing downward) 91 // %10 92 // | 93 // %11 94 // / \ (R: %6 == 1, L: default) 95 // %12 %13 96 // \ / 97 // %14 98 // T/ \F 99 // %15 %16 100 // | T/ |F 101 // | %17| 102 // | \ | 103 // | %18 104 // | / 105 // %19 106 107 std::unique_ptr<IRContext> context = 108 BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, 109 SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); 110 Module* module = context->module(); 111 EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" 112 << text << std::endl; 113 const Function* fn = spvtest::GetFunction(module, 1); 114 const BasicBlock* entry = spvtest::GetBasicBlock(fn, 10); 115 EXPECT_EQ(entry, fn->entry().get()) 116 << "The entry node is not the expected one"; 117 118 { 119 PostDominatorAnalysis pdom; 120 const CFG& cfg = *context->cfg(); 121 pdom.InitializeTree(cfg, fn); 122 ControlDependenceAnalysis cdg; 123 cdg.ComputeControlDependenceGraph(cfg, pdom); 124 125 // Test HasBlock. 126 for (uint32_t id = 10; id <= 19; id++) { 127 EXPECT_TRUE(cdg.HasBlock(id)); 128 } 129 EXPECT_TRUE(cdg.HasBlock(ControlDependenceAnalysis::kPseudoEntryBlock)); 130 // Check blocks before/after valid range. 131 EXPECT_FALSE(cdg.HasBlock(5)); 132 EXPECT_FALSE(cdg.HasBlock(25)); 133 EXPECT_FALSE(cdg.HasBlock(UINT32_MAX)); 134 135 // Test ForEachBlockLabel. 136 std::set<uint32_t> block_labels; 137 cdg.ForEachBlockLabel([&block_labels](uint32_t id) { 138 bool inserted = block_labels.insert(id).second; 139 EXPECT_TRUE(inserted); // Should have no duplicates. 140 }); 141 EXPECT_THAT(block_labels, testing::ElementsAre(0, 10, 11, 12, 13, 14, 15, 142 16, 17, 18, 19)); 143 144 { 145 // Test WhileEachBlockLabel. 146 uint32_t iters = 0; 147 EXPECT_TRUE(cdg.WhileEachBlockLabel([&iters](uint32_t) { 148 ++iters; 149 return true; 150 })); 151 EXPECT_EQ((uint32_t)block_labels.size(), iters); 152 iters = 0; 153 EXPECT_FALSE(cdg.WhileEachBlockLabel([&iters](uint32_t) { 154 ++iters; 155 return false; 156 })); 157 EXPECT_EQ(1, iters); 158 } 159 160 // Test IsDependent. 161 EXPECT_TRUE(cdg.IsDependent(12, 11)); 162 EXPECT_TRUE(cdg.IsDependent(13, 11)); 163 EXPECT_TRUE(cdg.IsDependent(15, 14)); 164 EXPECT_TRUE(cdg.IsDependent(16, 14)); 165 EXPECT_TRUE(cdg.IsDependent(18, 14)); 166 EXPECT_TRUE(cdg.IsDependent(17, 16)); 167 EXPECT_TRUE(cdg.IsDependent(10, 0)); 168 EXPECT_TRUE(cdg.IsDependent(11, 0)); 169 EXPECT_TRUE(cdg.IsDependent(14, 0)); 170 EXPECT_TRUE(cdg.IsDependent(19, 0)); 171 EXPECT_FALSE(cdg.IsDependent(14, 11)); 172 EXPECT_FALSE(cdg.IsDependent(17, 14)); 173 EXPECT_FALSE(cdg.IsDependent(19, 14)); 174 EXPECT_FALSE(cdg.IsDependent(12, 0)); 175 176 // Test GetDependenceSources/Targets. 177 std::vector<ControlDependence> edges; 178 GatherEdges(cdg, edges); 179 EXPECT_THAT(edges, 180 testing::ElementsAre( 181 ControlDependence(0, 10), ControlDependence(0, 11, 10), 182 ControlDependence(0, 14, 10), ControlDependence(0, 19, 10), 183 ControlDependence(11, 12), ControlDependence(11, 13), 184 ControlDependence(14, 15), ControlDependence(14, 16), 185 ControlDependence(14, 18, 16), ControlDependence(16, 17))); 186 187 const uint32_t expected_condition_ids[] = { 188 0, 0, 0, 0, 6, 6, 8, 8, 8, 8, 189 }; 190 191 for (uint32_t i = 0; i < edges.size(); i++) { 192 EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg)); 193 } 194 } 195} 196 197TEST(ControlDependenceTest, DependencePaperCFG) { 198 const std::string text = R"( 199 OpCapability Addresses 200 OpCapability Kernel 201 OpMemoryModel Physical64 OpenCL 202 OpEntryPoint Kernel %101 "main" 203 %102 = OpTypeVoid 204 %103 = OpTypeFunction %102 205 %104 = OpTypeBool 206 %108 = OpConstantTrue %104 207 %101 = OpFunction %102 None %103 208 %1 = OpLabel 209 OpBranch %2 210 %2 = OpLabel 211 OpBranchConditional %108 %3 %7 212 %3 = OpLabel 213 OpBranchConditional %108 %4 %5 214 %4 = OpLabel 215 OpBranch %6 216 %5 = OpLabel 217 OpBranch %6 218 %6 = OpLabel 219 OpBranch %8 220 %7 = OpLabel 221 OpBranch %8 222 %8 = OpLabel 223 OpBranch %9 224 %9 = OpLabel 225 OpBranchConditional %108 %10 %11 226 %10 = OpLabel 227 OpBranch %11 228 %11 = OpLabel 229 OpBranchConditional %108 %12 %9 230 %12 = OpLabel 231 OpBranchConditional %108 %13 %2 232 %13 = OpLabel 233 OpReturn 234 OpFunctionEnd 235)"; 236 237 // CFG: (edges pointing downward if no arrow) 238 // %1 239 // | 240 // %2 <----+ 241 // T/ \F | 242 // %3 \ | 243 // T/ \F \ | 244 // %4 %5 %7 | 245 // \ / / | 246 // %6 / | 247 // \ / | 248 // %8 | 249 // | | 250 // %9 <-+ | 251 // T/ | | | 252 // %10 | | | 253 // \ | | | 254 // %11-F+ | 255 // T| | 256 // %12-F---+ 257 // T| 258 // %13 259 260 std::unique_ptr<IRContext> context = 261 BuildModule(SPV_ENV_UNIVERSAL_1_0, nullptr, text, 262 SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); 263 Module* module = context->module(); 264 EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n" 265 << text << std::endl; 266 const Function* fn = spvtest::GetFunction(module, 101); 267 const BasicBlock* entry = spvtest::GetBasicBlock(fn, 1); 268 EXPECT_EQ(entry, fn->entry().get()) 269 << "The entry node is not the expected one"; 270 271 { 272 PostDominatorAnalysis pdom; 273 const CFG& cfg = *context->cfg(); 274 pdom.InitializeTree(cfg, fn); 275 ControlDependenceAnalysis cdg; 276 cdg.ComputeControlDependenceGraph(cfg, pdom); 277 278 std::vector<ControlDependence> edges; 279 GatherEdges(cdg, edges); 280 EXPECT_THAT( 281 edges, testing::ElementsAre( 282 ControlDependence(0, 1), ControlDependence(0, 2, 1), 283 ControlDependence(0, 8, 1), ControlDependence(0, 9, 1), 284 ControlDependence(0, 11, 1), ControlDependence(0, 12, 1), 285 ControlDependence(0, 13, 1), ControlDependence(2, 3), 286 ControlDependence(2, 6, 3), ControlDependence(2, 7), 287 ControlDependence(3, 4), ControlDependence(3, 5), 288 ControlDependence(9, 10), ControlDependence(11, 9), 289 ControlDependence(11, 11, 9), ControlDependence(12, 2), 290 ControlDependence(12, 8, 2), ControlDependence(12, 9, 2), 291 ControlDependence(12, 11, 2), ControlDependence(12, 12, 2))); 292 293 const uint32_t expected_condition_ids[] = { 294 0, 0, 0, 0, 0, 0, 0, 108, 108, 108, 295 108, 108, 108, 108, 108, 108, 108, 108, 108, 108, 296 }; 297 298 for (uint32_t i = 0; i < edges.size(); i++) { 299 EXPECT_EQ(expected_condition_ids[i], edges[i].GetConditionID(cfg)); 300 } 301 } 302} 303 304} // namespace 305} // namespace opt 306} // namespace spvtools 307