1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <string>
16#include <unordered_set>
17#include <vector>
18
19#include "gmock/gmock.h"
20#include "test/opt/pass_fixture.h"
21#include "test/opt/pass_utils.h"
22
23namespace spvtools {
24namespace opt {
25namespace {
26
27using ::testing::HasSubstr;
28using ::testing::MatchesRegex;
29using StrengthReductionBasicTest = PassTest<::testing::Test>;
30
31// Test to make sure we replace 5*8.
32TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
33  const std::vector<const char*> text = {
34      // clang-format off
35               "OpCapability Shader",
36          "%1 = OpExtInstImport \"GLSL.std.450\"",
37               "OpMemoryModel Logical GLSL450",
38               "OpEntryPoint Vertex %main \"main\"",
39               "OpName %main \"main\"",
40       "%void = OpTypeVoid",
41          "%4 = OpTypeFunction %void",
42       "%uint = OpTypeInt 32 0",
43     "%uint_5 = OpConstant %uint 5",
44     "%uint_8 = OpConstant %uint 8",
45       "%main = OpFunction %void None %4",
46          "%8 = OpLabel",
47          "%9 = OpIMul %uint %uint_5 %uint_8",
48               "OpReturn",
49               "OpFunctionEnd"
50      // clang-format on
51  };
52
53  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
54      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
55
56  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
57  const std::string& output = std::get<0>(result);
58  EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
59  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
60}
61
// TODO(dneto): Add Effcee as required dependency, and make this unconditional.
// Test to make sure we replace 16*5.
// This test also demonstrates the use of Effcee matching.
65TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
66  const std::string text = R"(
67               OpCapability Shader
68          %1 = OpExtInstImport "GLSL.std.450"
69               OpMemoryModel Logical GLSL450
70               OpEntryPoint Vertex %main "main"
71               OpName %main "main"
72       %void = OpTypeVoid
73          %4 = OpTypeFunction %void
74; We know disassembly will produce %uint here, but
75;  CHECK: %uint = OpTypeInt 32 0
76;  CHECK-DAG: [[five:%[a-zA-Z_\d]+]] = OpConstant %uint 5
77
78; We have RE2 regular expressions, so \w matches [_a-zA-Z0-9].
79; This shows the preferred pattern for matching SPIR-V identifiers.
80; (We could have cheated in this case since we know the disassembler will
81; generate the 'nice' name of "%uint_4".
82;  CHECK-DAG: [[four:%\w+]] = OpConstant %uint 4
83       %uint = OpTypeInt 32 0
84     %uint_5 = OpConstant %uint 5
85    %uint_16 = OpConstant %uint 16
86       %main = OpFunction %void None %4
87; CHECK: OpLabel
88          %8 = OpLabel
89; CHECK-NEXT: OpShiftLeftLogical %uint [[five]] [[four]]
90; The multiplication disappears.
91; CHECK-NOT: OpIMul
92          %9 = OpIMul %uint %uint_16 %uint_5
93               OpReturn
94; CHECK: OpFunctionEnd
95               OpFunctionEnd)";
96
97  SinglePassRunAndMatch<StrengthReductionPass>(text, false);
98}
99
// Test to make sure we replace the product of 32 and 4.
101TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
102  // In this case, we have two powers of 2.  Need to make sure we replace only
103  // one of them for the bit shift.
104  // clang-format off
105  const std::string text = R"(
106          OpCapability Shader
107     %1 = OpExtInstImport "GLSL.std.450"
108          OpMemoryModel Logical GLSL450
109          OpEntryPoint Vertex %main "main"
110          OpName %main "main"
111  %void = OpTypeVoid
112     %4 = OpTypeFunction %void
113   %int = OpTypeInt 32 1
114%int_32 = OpConstant %int 32
115 %int_4 = OpConstant %int 4
116  %main = OpFunction %void None %4
117     %8 = OpLabel
118     %9 = OpIMul %int %int_32 %int_4
119          OpReturn
120          OpFunctionEnd
121)";
122  // clang-format on
123  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
124      text, /* skip_nop = */ true, /* do_validation = */ false);
125
126  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
127  const std::string& output = std::get<0>(result);
128  EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
129  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
130}
131
132// Test to make sure we don't replace 0*5.
133TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
134  const std::vector<const char*> text = {
135      // clang-format off
136               "OpCapability Shader",
137          "%1 = OpExtInstImport \"GLSL.std.450\"",
138               "OpMemoryModel Logical GLSL450",
139               "OpEntryPoint Vertex %main \"main\"",
140               "OpName %main \"main\"",
141       "%void = OpTypeVoid",
142          "%4 = OpTypeFunction %void",
143        "%int = OpTypeInt 32 1",
144      "%int_0 = OpConstant %int 0",
145      "%int_5 = OpConstant %int 5",
146       "%main = OpFunction %void None %4",
147          "%8 = OpLabel",
148          "%9 = OpIMul %int %int_0 %int_5",
149               "OpReturn",
150               "OpFunctionEnd"
151      // clang-format on
152  };
153
154  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
155      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
156
157  EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
158}
159
// Test to make sure we do not replace the product of 5 and 7 (neither is a
// power of 2).
161TEST_F(StrengthReductionBasicTest, BasicNoChange) {
162  const std::vector<const char*> text = {
163      // clang-format off
164             "OpCapability Shader",
165        "%1 = OpExtInstImport \"GLSL.std.450\"",
166             "OpMemoryModel Logical GLSL450",
167             "OpEntryPoint Vertex %2 \"main\"",
168             "OpName %2 \"main\"",
169        "%3 = OpTypeVoid",
170        "%4 = OpTypeFunction %3",
171        "%5 = OpTypeInt 32 1",
172        "%6 = OpTypeInt 32 0",
173        "%7 = OpConstant %5 5",
174        "%8 = OpConstant %5 7",
175        "%2 = OpFunction %3 None %4",
176        "%9 = OpLabel",
177        "%10 = OpIMul %5 %7 %8",
178             "OpReturn",
179             "OpFunctionEnd",
180      // clang-format on
181  };
182
183  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
184      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
185
186  EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
187}
188
189// Test to make sure constants and types are reused and not duplicated.
190TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
191  const std::vector<const char*> text = {
192      // clang-format off
193               "OpCapability Shader",
194          "%1 = OpExtInstImport \"GLSL.std.450\"",
195               "OpMemoryModel Logical GLSL450",
196               "OpEntryPoint Vertex %main \"main\"",
197               "OpName %main \"main\"",
198       "%void = OpTypeVoid",
199          "%4 = OpTypeFunction %void",
200       "%uint = OpTypeInt 32 0",
201     "%uint_8 = OpConstant %uint 8",
202     "%uint_3 = OpConstant %uint 3",
203       "%main = OpFunction %void None %4",
204          "%8 = OpLabel",
205          "%9 = OpIMul %uint %uint_8 %uint_3",
206               "OpReturn",
207               "OpFunctionEnd",
208      // clang-format on
209  };
210
211  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
212      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
213
214  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
215  const std::string& output = std::get<0>(result);
216  EXPECT_THAT(output,
217              Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*")));
218  EXPECT_THAT(output, Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*")));
219}
220
// Test to make sure a shift-amount constant needed by two replacements is
// generated only once.
222TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
223  const std::vector<const char*> text = {
224      // clang-format off
225               "OpCapability Shader",
226          "%1 = OpExtInstImport \"GLSL.std.450\"",
227               "OpMemoryModel Logical GLSL450",
228               "OpEntryPoint Vertex %main \"main\"",
229               "OpName %main \"main\"",
230       "%void = OpTypeVoid",
231          "%4 = OpTypeFunction %void",
232       "%uint = OpTypeInt 32 0",
233     "%uint_5 = OpConstant %uint 5",
234     "%uint_9 = OpConstant %uint 9",
235   "%uint_128 = OpConstant %uint 128",
236       "%main = OpFunction %void None %4",
237          "%8 = OpLabel",
238          "%9 = OpIMul %uint %uint_5 %uint_128",
239         "%10 = OpIMul %uint %uint_9 %uint_128",
240               "OpReturn",
241               "OpFunctionEnd"
242      // clang-format on
243  };
244
245  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
246      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
247
248  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
249  const std::string& output = std::get<0>(result);
250  EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
251  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
252  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));
253}
254
// Test to make sure we generate the instructions in the correct position and
// that the uses get replaced as well. Here we check that the use in the
// OpReturnValue is replaced; we also check that we can replace two OpIMuls
// when one feeds the other.
259TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
260  // This is just the preamble to set up the test.
261  const std::vector<const char*> common_text = {
262      // clang-format off
263               "OpCapability Shader",
264          "%1 = OpExtInstImport \"GLSL.std.450\"",
265               "OpMemoryModel Logical GLSL450",
266               "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
267               "OpExecutionMode %main OriginUpperLeft",
268               "OpName %main \"main\"",
269               "OpName %foo_i1_ \"foo(i1;\"",
270               "OpName %n \"n\"",
271               "OpName %gl_FragColor \"gl_FragColor\"",
272               "OpName %param \"param\"",
273               "OpDecorate %gl_FragColor Location 0",
274       "%void = OpTypeVoid",
275          "%3 = OpTypeFunction %void",
276        "%int = OpTypeInt 32 1",
277"%_ptr_Function_int = OpTypePointer Function %int",
278          "%8 = OpTypeFunction %int %_ptr_Function_int",
279    "%int_256 = OpConstant %int 256",
280      "%int_2 = OpConstant %int 2",
281      "%float = OpTypeFloat 32",
282    "%v4float = OpTypeVector %float 4",
283"%_ptr_Output_v4float = OpTypePointer Output %v4float",
284"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
285    "%float_1 = OpConstant %float 1",
286     "%int_10 = OpConstant %int 10",
287  "%float_0_375 = OpConstant %float 0.375",
288  "%float_0_75 = OpConstant %float 0.75",
289       "%uint = OpTypeInt 32 0",
290     "%uint_8 = OpConstant %uint 8",
291     "%uint_1 = OpConstant %uint 1",
292       "%main = OpFunction %void None %3",
293          "%5 = OpLabel",
294      "%param = OpVariable %_ptr_Function_int Function",
295               "OpStore %param %int_10",
296         "%26 = OpFunctionCall %int %foo_i1_ %param",
297         "%27 = OpConvertSToF %float %26",
298         "%28 = OpFDiv %float %float_1 %27",
299         "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
300               "OpStore %gl_FragColor %31",
301               "OpReturn",
302               "OpFunctionEnd"
303      // clang-format on
304  };
305
306  // This is the real test.  The two OpIMul should be replaced.  The expected
307  // output is in |foo_after|.
308  const std::vector<const char*> foo_before = {
309      // clang-format off
310    "%foo_i1_ = OpFunction %int None %8",
311          "%n = OpFunctionParameter %_ptr_Function_int",
312         "%11 = OpLabel",
313         "%12 = OpLoad %int %n",
314         "%14 = OpIMul %int %12 %int_256",
315         "%16 = OpIMul %int %14 %int_2",
316               "OpReturnValue %16",
317               "OpFunctionEnd",
318
319      // clang-format on
320  };
321
322  const std::vector<const char*> foo_after = {
323      // clang-format off
324    "%foo_i1_ = OpFunction %int None %8",
325          "%n = OpFunctionParameter %_ptr_Function_int",
326         "%11 = OpLabel",
327         "%12 = OpLoad %int %n",
328         "%33 = OpShiftLeftLogical %int %12 %uint_8",
329         "%34 = OpShiftLeftLogical %int %33 %uint_1",
330               "OpReturnValue %34",
331               "OpFunctionEnd",
332      // clang-format on
333  };
334
335  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
336  SinglePassRunAndCheck<StrengthReductionPass>(
337      JoinAllInsts(Concat(common_text, foo_before)),
338      JoinAllInsts(Concat(common_text, foo_after)),
339      /* skip_nop = */ true, /* do_validate = */ true);
340}
341
// Test that, when the result of an OpIMul instruction has more than one use and
// the instruction is replaced, all of the uses of that result are replaced with
// the new result.
345TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
346  // This is just the preamble to set up the test.
347  const std::vector<const char*> common_text = {
348      // clang-format off
349               "OpCapability Shader",
350          "%1 = OpExtInstImport \"GLSL.std.450\"",
351               "OpMemoryModel Logical GLSL450",
352               "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
353               "OpExecutionMode %main OriginUpperLeft",
354               "OpName %main \"main\"",
355               "OpName %foo_i1_ \"foo(i1;\"",
356               "OpName %n \"n\"",
357               "OpName %gl_FragColor \"gl_FragColor\"",
358               "OpName %param \"param\"",
359               "OpDecorate %gl_FragColor Location 0",
360       "%void = OpTypeVoid",
361          "%3 = OpTypeFunction %void",
362        "%int = OpTypeInt 32 1",
363"%_ptr_Function_int = OpTypePointer Function %int",
364          "%8 = OpTypeFunction %int %_ptr_Function_int",
365    "%int_256 = OpConstant %int 256",
366      "%int_2 = OpConstant %int 2",
367      "%float = OpTypeFloat 32",
368    "%v4float = OpTypeVector %float 4",
369"%_ptr_Output_v4float = OpTypePointer Output %v4float",
370"%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
371    "%float_1 = OpConstant %float 1",
372     "%int_10 = OpConstant %int 10",
373  "%float_0_375 = OpConstant %float 0.375",
374  "%float_0_75 = OpConstant %float 0.75",
375       "%uint = OpTypeInt 32 0",
376     "%uint_8 = OpConstant %uint 8",
377     "%uint_1 = OpConstant %uint 1",
378       "%main = OpFunction %void None %3",
379          "%5 = OpLabel",
380      "%param = OpVariable %_ptr_Function_int Function",
381               "OpStore %param %int_10",
382         "%26 = OpFunctionCall %int %foo_i1_ %param",
383         "%27 = OpConvertSToF %float %26",
384         "%28 = OpFDiv %float %float_1 %27",
385         "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
386               "OpStore %gl_FragColor %31",
387               "OpReturn",
388               "OpFunctionEnd"
389      // clang-format on
390  };
391
392  // This is the real test.  The two OpIMul instructions should be replaced.  In
393  // particular, we want to be sure that both uses of %16 are changed to use the
394  // new result.
395  const std::vector<const char*> foo_before = {
396      // clang-format off
397    "%foo_i1_ = OpFunction %int None %8",
398          "%n = OpFunctionParameter %_ptr_Function_int",
399         "%11 = OpLabel",
400         "%12 = OpLoad %int %n",
401         "%14 = OpIMul %int %12 %int_256",
402         "%16 = OpIMul %int %14 %int_2",
403         "%17 = OpIAdd %int %14 %16",
404               "OpReturnValue %17",
405               "OpFunctionEnd",
406
407      // clang-format on
408  };
409
410  const std::vector<const char*> foo_after = {
411      // clang-format off
412    "%foo_i1_ = OpFunction %int None %8",
413          "%n = OpFunctionParameter %_ptr_Function_int",
414         "%11 = OpLabel",
415         "%12 = OpLoad %int %n",
416         "%34 = OpShiftLeftLogical %int %12 %uint_8",
417         "%35 = OpShiftLeftLogical %int %34 %uint_1",
418         "%17 = OpIAdd %int %34 %35",
419               "OpReturnValue %17",
420               "OpFunctionEnd",
421      // clang-format on
422  };
423
424  SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
425  SinglePassRunAndCheck<StrengthReductionPass>(
426      JoinAllInsts(Concat(common_text, foo_before)),
427      JoinAllInsts(Concat(common_text, foo_after)),
428      /* skip_nop = */ true, /* do_validate = */ true);
429}
430
431}  // namespace
432}  // namespace opt
433}  // namespace spvtools
434