1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/propagator.h"
16
17#include <map>
18#include <memory>
19#include <vector>
20
21#include "gmock/gmock.h"
22#include "gtest/gtest.h"
23#include "source/opt/build_module.h"
24#include "source/opt/cfg.h"
25#include "source/opt/ir_context.h"
26
27namespace spvtools {
28namespace opt {
29namespace {
30
31using ::testing::UnorderedElementsAre;
32
33class PropagatorTest : public testing::Test {
34 protected:
35  virtual void TearDown() {
36    ctx_.reset(nullptr);
37    values_.clear();
38    values_vec_.clear();
39  }
40
41  void Assemble(const std::string& input) {
42    ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
43    ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
44                             << input << "\n";
45  }
46
47  bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
48    SSAPropagator propagator(ctx_.get(), visit_fn);
49    bool retval = false;
50    for (auto& fn : *ctx_->module()) {
51      retval |= propagator.Run(&fn);
52    }
53    return retval;
54  }
55
56  const std::vector<uint32_t>& GetValues() {
57    values_vec_.clear();
58    for (const auto& it : values_) {
59      values_vec_.push_back(it.second);
60    }
61    return values_vec_;
62  }
63
64  std::unique_ptr<IRContext> ctx_;
65  std::map<uint32_t, uint32_t> values_;
66  std::vector<uint32_t> values_vec_;
67};
68
69TEST_F(PropagatorTest, LocalPropagate) {
70  const std::string spv_asm = R"(
71               OpCapability Shader
72          %1 = OpExtInstImport "GLSL.std.450"
73               OpMemoryModel Logical GLSL450
74               OpEntryPoint Fragment %main "main" %outparm
75               OpExecutionMode %main OriginUpperLeft
76               OpSource GLSL 450
77               OpName %main "main"
78               OpName %x "x"
79               OpName %y "y"
80               OpName %z "z"
81               OpName %outparm "outparm"
82               OpDecorate %outparm Location 0
83       %void = OpTypeVoid
84          %3 = OpTypeFunction %void
85        %int = OpTypeInt 32 1
86%_ptr_Function_int = OpTypePointer Function %int
87      %int_4 = OpConstant %int 4
88      %int_3 = OpConstant %int 3
89      %int_1 = OpConstant %int 1
90%_ptr_Output_int = OpTypePointer Output %int
91    %outparm = OpVariable %_ptr_Output_int Output
92       %main = OpFunction %void None %3
93          %5 = OpLabel
94          %x = OpVariable %_ptr_Function_int Function
95          %y = OpVariable %_ptr_Function_int Function
96          %z = OpVariable %_ptr_Function_int Function
97               OpStore %x %int_4
98               OpStore %y %int_3
99               OpStore %z %int_1
100         %20 = OpLoad %int %z
101               OpStore %outparm %20
102               OpReturn
103               OpFunctionEnd
104               )";
105  Assemble(spv_asm);
106
107  const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
108    *dest_bb = nullptr;
109    if (instr->opcode() == spv::Op::OpStore) {
110      uint32_t lhs_id = instr->GetSingleWordOperand(0);
111      uint32_t rhs_id = instr->GetSingleWordOperand(1);
112      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
113      if (rhs_def->opcode() == spv::Op::OpConstant) {
114        uint32_t val = rhs_def->GetSingleWordOperand(2);
115        values_[lhs_id] = val;
116        return SSAPropagator::kInteresting;
117      }
118    }
119    return SSAPropagator::kVarying;
120  };
121
122  EXPECT_TRUE(Propagate(visit_fn));
123  EXPECT_THAT(GetValues(), UnorderedElementsAre(4, 3, 1));
124}
125
126TEST_F(PropagatorTest, PropagateThroughPhis) {
127  const std::string spv_asm = R"(
128               OpCapability Shader
129          %1 = OpExtInstImport "GLSL.std.450"
130               OpMemoryModel Logical GLSL450
131               OpEntryPoint Fragment %main "main" %x %outparm
132               OpExecutionMode %main OriginUpperLeft
133               OpSource GLSL 450
134               OpName %main "main"
135               OpName %x "x"
136               OpName %outparm "outparm"
137               OpDecorate %x Flat
138               OpDecorate %x Location 0
139               OpDecorate %outparm Location 0
140       %void = OpTypeVoid
141          %3 = OpTypeFunction %void
142        %int = OpTypeInt 32 1
143       %bool = OpTypeBool
144%_ptr_Function_int = OpTypePointer Function %int
145      %int_4 = OpConstant %int 4
146      %int_3 = OpConstant %int 3
147      %int_1 = OpConstant %int 1
148%_ptr_Input_int = OpTypePointer Input %int
149          %x = OpVariable %_ptr_Input_int Input
150%_ptr_Output_int = OpTypePointer Output %int
151    %outparm = OpVariable %_ptr_Output_int Output
152       %main = OpFunction %void None %3
153          %4 = OpLabel
154          %5 = OpLoad %int %x
155          %6 = OpSGreaterThan %bool %5 %int_3
156               OpSelectionMerge %25 None
157               OpBranchConditional %6 %22 %23
158         %22 = OpLabel
159          %7 = OpLoad %int %int_4
160               OpBranch %25
161         %23 = OpLabel
162          %8 = OpLoad %int %int_4
163               OpBranch %25
164         %25 = OpLabel
165         %35 = OpPhi %int %7 %22 %8 %23
166               OpStore %outparm %35
167               OpReturn
168               OpFunctionEnd
169               )";
170
171  Assemble(spv_asm);
172
173  Instruction* phi_instr = nullptr;
174  const auto visit_fn = [this, &phi_instr](Instruction* instr,
175                                           BasicBlock** dest_bb) {
176    *dest_bb = nullptr;
177    if (instr->opcode() == spv::Op::OpLoad) {
178      uint32_t rhs_id = instr->GetSingleWordOperand(2);
179      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
180      if (rhs_def->opcode() == spv::Op::OpConstant) {
181        uint32_t val = rhs_def->GetSingleWordOperand(2);
182        values_[instr->result_id()] = val;
183        return SSAPropagator::kInteresting;
184      }
185    } else if (instr->opcode() == spv::Op::OpPhi) {
186      phi_instr = instr;
187      SSAPropagator::PropStatus retval;
188      for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
189        uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
190        auto it = values_.find(phi_arg_id);
191        if (it != values_.end()) {
192          EXPECT_EQ(it->second, 4u);
193          retval = SSAPropagator::kInteresting;
194          values_[instr->result_id()] = it->second;
195        } else {
196          retval = SSAPropagator::kNotInteresting;
197          break;
198        }
199      }
200      return retval;
201    }
202
203    return SSAPropagator::kVarying;
204  };
205
206  EXPECT_TRUE(Propagate(visit_fn));
207
208  // The propagator should've concluded that the Phi instruction has a constant
209  // value of 4.
210  EXPECT_NE(phi_instr, nullptr);
211  EXPECT_EQ(values_[phi_instr->result_id()], 4u);
212
213  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
214}
215
216}  // namespace
217}  // namespace opt
218}  // namespace spvtools
219