1// Copyright (c) 2017 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "source/opt/propagator.h" 16 17#include <map> 18#include <memory> 19#include <vector> 20 21#include "gmock/gmock.h" 22#include "gtest/gtest.h" 23#include "source/opt/build_module.h" 24#include "source/opt/cfg.h" 25#include "source/opt/ir_context.h" 26 27namespace spvtools { 28namespace opt { 29namespace { 30 31using ::testing::UnorderedElementsAre; 32 33class PropagatorTest : public testing::Test { 34 protected: 35 virtual void TearDown() { 36 ctx_.reset(nullptr); 37 values_.clear(); 38 values_vec_.clear(); 39 } 40 41 void Assemble(const std::string& input) { 42 ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input); 43 ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n" 44 << input << "\n"; 45 } 46 47 bool Propagate(const SSAPropagator::VisitFunction& visit_fn) { 48 SSAPropagator propagator(ctx_.get(), visit_fn); 49 bool retval = false; 50 for (auto& fn : *ctx_->module()) { 51 retval |= propagator.Run(&fn); 52 } 53 return retval; 54 } 55 56 const std::vector<uint32_t>& GetValues() { 57 values_vec_.clear(); 58 for (const auto& it : values_) { 59 values_vec_.push_back(it.second); 60 } 61 return values_vec_; 62 } 63 64 std::unique_ptr<IRContext> ctx_; 65 std::map<uint32_t, uint32_t> values_; 66 std::vector<uint32_t> values_vec_; 67}; 68 69TEST_F(PropagatorTest, LocalPropagate) { 70 const std::string spv_asm = R"( 71 OpCapability Shader 72 %1 = OpExtInstImport "GLSL.std.450" 73 OpMemoryModel Logical GLSL450 74 OpEntryPoint Fragment %main "main" %outparm 75 OpExecutionMode %main OriginUpperLeft 76 OpSource GLSL 450 77 OpName %main "main" 78 OpName %x "x" 79 OpName %y "y" 80 OpName %z "z" 81 OpName %outparm "outparm" 82 OpDecorate %outparm Location 0 83 %void = OpTypeVoid 84 %3 = OpTypeFunction %void 85 %int = OpTypeInt 32 1 86%_ptr_Function_int = OpTypePointer Function %int 87 %int_4 = OpConstant %int 4 88 %int_3 = OpConstant %int 3 89 %int_1 = OpConstant %int 1 90%_ptr_Output_int = OpTypePointer Output %int 91 %outparm = OpVariable %_ptr_Output_int Output 92 %main = OpFunction %void None %3 93 %5 = OpLabel 94 %x = OpVariable %_ptr_Function_int Function 95 %y = OpVariable %_ptr_Function_int Function 96 %z = OpVariable %_ptr_Function_int Function 97 OpStore %x %int_4 98 OpStore %y %int_3 99 OpStore %z %int_1 100 %20 = OpLoad %int %z 101 OpStore %outparm %20 102 OpReturn 103 OpFunctionEnd 104 )"; 105 Assemble(spv_asm); 106 107 const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) { 108 *dest_bb = nullptr; 109 if (instr->opcode() == spv::Op::OpStore) { 110 uint32_t lhs_id = instr->GetSingleWordOperand(0); 111 uint32_t rhs_id = instr->GetSingleWordOperand(1); 112 Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); 113 if (rhs_def->opcode() == spv::Op::OpConstant) { 114 uint32_t val = rhs_def->GetSingleWordOperand(2); 115 values_[lhs_id] = val; 116 return SSAPropagator::kInteresting; 117 } 118 } 119 return SSAPropagator::kVarying; 120 }; 121 122 EXPECT_TRUE(Propagate(visit_fn)); 123 EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1)); 124} 125 126TEST_F(PropagatorTest, PropagateThroughPhis) { 127 const std::string spv_asm = R"( 128 OpCapability Shader 129 %1 = OpExtInstImport "GLSL.std.450" 130 OpMemoryModel Logical GLSL450 131 OpEntryPoint Fragment %main "main" %x %outparm 132 OpExecutionMode %main OriginUpperLeft 133 OpSource GLSL 450 134 OpName %main "main" 135 OpName %x "x" 136 OpName %outparm "outparm" 137 OpDecorate %x Flat 138 OpDecorate %x Location 0 139 OpDecorate %outparm Location 0 140 %void = OpTypeVoid 141 %3 = OpTypeFunction %void 142 %int = OpTypeInt 32 1 143 %bool = OpTypeBool 144%_ptr_Function_int = OpTypePointer Function %int 145 %int_4 = OpConstant %int 4 146 %int_3 = OpConstant %int 3 147 %int_1 = OpConstant %int 1 148%_ptr_Input_int = OpTypePointer Input %int 149 %x = OpVariable %_ptr_Input_int Input 150%_ptr_Output_int = OpTypePointer Output %int 151 %outparm = OpVariable %_ptr_Output_int Output 152 %main = OpFunction %void None %3 153 %4 = OpLabel 154 %5 = OpLoad %int %x 155 %6 = OpSGreaterThan %bool %5 %int_3 156 OpSelectionMerge %25 None 157 OpBranchConditional %6 %22 %23 158 %22 = OpLabel 159 %7 = OpLoad %int %int_4 160 OpBranch %25 161 %23 = OpLabel 162 %8 = OpLoad %int %int_4 163 OpBranch %25 164 %25 = OpLabel 165 %35 = OpPhi %int %7 %22 %8 %23 166 OpStore %outparm %35 167 OpReturn 168 OpFunctionEnd 169 )"; 170 171 Assemble(spv_asm); 172 173 Instruction* phi_instr = nullptr; 174 const auto visit_fn = [this, &phi_instr](Instruction* instr, 175 BasicBlock** dest_bb) { 176 *dest_bb = nullptr; 177 if (instr->opcode() == spv::Op::OpLoad) { 178 uint32_t rhs_id = instr->GetSingleWordOperand(2); 179 Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id); 180 if (rhs_def->opcode() == spv::Op::OpConstant) { 181 uint32_t val = rhs_def->GetSingleWordOperand(2); 182 values_[instr->result_id()] = val; 183 return SSAPropagator::kInteresting; 184 } 185 } else if (instr->opcode() == spv::Op::OpPhi) { 186 phi_instr = instr; 187 SSAPropagator::PropStatus retval; 188 for (uint32_t i = 2; i < instr->NumOperands(); i += 2) { 189 uint32_t phi_arg_id = instr->GetSingleWordOperand(i); 190 auto it = values_.find(phi_arg_id); 191 if (it != values_.end()) { 192 EXPECT_EQ(it->second, 4u); 193 retval = SSAPropagator::kInteresting; 194 values_[instr->result_id()] = it->second; 195 } else { 196 retval = SSAPropagator::kNotInteresting; 197 break; 198 } 199 } 200 return retval; 201 } 202 203 return SSAPropagator::kVarying; 204 }; 205 206 EXPECT_TRUE(Propagate(visit_fn)); 207 208 // The propagator should've concluded that the Phi instruction has a constant 209 // value of 4. 210 EXPECT_NE(phi_instr, nullptr); 211 EXPECT_EQ(values_[phi_instr->result_id()], 4u); 212 213 EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u)); 214} 215 216} // namespace 217} // namespace opt 218} // namespace spvtools 219