1// Copyright (c) 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/reduce/reducer.h"
16
17#include <unordered_map>
18
19#include "source/opt/build_module.h"
20#include "source/reduce/operand_to_const_reduction_opportunity_finder.h"
21#include "source/reduce/remove_unused_instruction_reduction_opportunity_finder.h"
22#include "test/reduce/reduce_test_util.h"
23
24namespace spvtools {
25namespace reduce {
26namespace {
27
28const spv_target_env kEnv = SPV_ENV_UNIVERSAL_1_3;
29const MessageConsumer kMessageConsumer = NopDiagnostic;
30
// A stateful stand-in for an interestingness test.  The first
// |always_interesting_after| + 1 calls to IsInteresting() alternate between
// "interesting" and "not interesting", starting with "interesting"; every
// later call answers "interesting".  Flip-flopping for a while makes reduction
// passes interleave in non-trivial ways, while the eventual constant answer
// keeps the final reduced binary predictable enough for tests to check.
class PingPongInteresting {
 public:
  explicit PingPongInteresting(uint32_t always_interesting_after)
      : always_interesting_after_(always_interesting_after) {}

  // Returns this call's verdict and advances the internal call counter.
  bool IsInteresting() {
    const uint32_t call_index = count_++;
    if (call_index > always_interesting_after_) {
      // Past the limit: always accept.
      return true;
    }
    const bool answer = is_interesting_;
    is_interesting_ = !is_interesting_;
    return answer;
  }

 private:
  // Verdict to give on the next call while still alternating.
  bool is_interesting_ = true;
  // Calls with index greater than this are always interesting.
  const uint32_t always_interesting_after_;
  // Number of IsInteresting() calls made so far.
  uint32_t count_ = 0;
};
61
TEST(ReducerTest, ExprToConstantAndRemoveUnreferenced) {
  // Check that the OperandToConst and RemoveUnusedInstruction reduction passes
  // work together: once some ID uses have been replaced by constants, the
  // instructions defining those IDs are no longer used and can be removed.
  // (The test name refers to these passes as ExprToConstant and
  // RemoveUnreferenced.)
  std::string original = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %60
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %16 "buf2"
               OpMemberName %16 0 "i"
               OpName %18 ""
               OpName %25 "buf1"
               OpMemberName %25 0 "f"
               OpName %27 ""
               OpName %60 "_GLF_color"
               OpMemberDecorate %16 0 Offset 0
               OpDecorate %16 Block
               OpDecorate %18 DescriptorSet 0
               OpDecorate %18 Binding 2
               OpMemberDecorate %25 0 Offset 0
               OpDecorate %25 Block
               OpDecorate %27 DescriptorSet 0
               OpDecorate %27 Binding 1
               OpDecorate %60 Location 0
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %9 = OpConstant %6 0
         %16 = OpTypeStruct %6
         %17 = OpTypePointer Uniform %16
         %18 = OpVariable %17 Uniform
         %19 = OpTypePointer Uniform %6
         %22 = OpTypeBool
        %100 = OpConstantTrue %22
         %24 = OpTypeFloat 32
         %25 = OpTypeStruct %24
         %26 = OpTypePointer Uniform %25
         %27 = OpVariable %26 Uniform
         %28 = OpTypePointer Uniform %24
         %31 = OpConstant %24 2
         %56 = OpConstant %6 1
         %58 = OpTypeVector %24 4
         %59 = OpTypePointer Output %58
         %60 = OpVariable %59 Output
         %72 = OpUndef %24
         %74 = OpUndef %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpBranch %10
         %10 = OpLabel
         %73 = OpPhi %6 %74 %5 %77 %34
         %71 = OpPhi %24 %72 %5 %76 %34
         %70 = OpPhi %6 %9 %5 %57 %34
         %20 = OpAccessChain %19 %18 %9
         %21 = OpLoad %6 %20
         %23 = OpSLessThan %22 %70 %21
               OpLoopMerge %12 %34 None
               OpBranchConditional %23 %11 %12
         %11 = OpLabel
         %29 = OpAccessChain %28 %27 %9
         %30 = OpLoad %24 %29
         %32 = OpFOrdGreaterThan %22 %30 %31
               OpSelectionMerge %90 None
               OpBranchConditional %32 %33 %46
         %33 = OpLabel
         %40 = OpFAdd %24 %71 %30
         %45 = OpISub %6 %73 %21
               OpBranch %90
         %46 = OpLabel
         %50 = OpFMul %24 %71 %30
         %54 = OpSDiv %6 %73 %21
               OpBranch %90
         %90 = OpLabel
         %77 = OpPhi %6 %45 %33 %54 %46
         %76 = OpPhi %24 %40 %33 %50 %46
               OpBranch %34
         %34 = OpLabel
         %57 = OpIAdd %6 %70 %56
               OpBranch %10
         %12 = OpLabel
         %61 = OpAccessChain %28 %27 %9
         %62 = OpLoad %24 %61
         %66 = OpConvertSToF %24 %21
         %68 = OpConvertSToF %24 %73
         %69 = OpCompositeConstruct %58 %62 %71 %66 %68
               OpStore %60 %69
               OpReturn
               OpFunctionEnd
  )";

  // Golden result: every computation in |original| has been replaced by
  // constants and then removed, leaving only the control-flow skeleton plus
  // the types, constants and undefs it still refers to.
  std::string expected = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %9 = OpConstant %6 0
         %22 = OpTypeBool
        %100 = OpConstantTrue %22
         %24 = OpTypeFloat 32
         %31 = OpConstant %24 2
         %56 = OpConstant %6 1
         %72 = OpUndef %24
         %74 = OpUndef %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpBranch %10
         %10 = OpLabel
               OpLoopMerge %12 %34 None
               OpBranchConditional %100 %11 %12
         %11 = OpLabel
               OpSelectionMerge %90 None
               OpBranchConditional %100 %33 %46
         %33 = OpLabel
               OpBranch %90
         %46 = OpLabel
               OpBranch %90
         %90 = OpLabel
               OpBranch %34
         %34 = OpLabel
               OpBranch %10
         %12 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  Reducer reducer(kEnv);
  // The first 11 interestingness queries alternate between accept and reject;
  // all later queries accept, so the final result is deterministic.
  PingPongInteresting ping_pong_interesting(10);
  reducer.SetMessageConsumer(kMessageConsumer);
  reducer.SetInterestingnessFunction(
      [&ping_pong_interesting](const std::vector<uint32_t>&, uint32_t) -> bool {
        return ping_pong_interesting.IsInteresting();
      });
  // NOTE(review): the |expected| output above presumably depends on the order
  // in which these passes are added.  The `false` argument presumably asks the
  // finder to keep constants and undefs (|expected| still contains them);
  // confirm against the finder's header.
  reducer.AddReductionPass(
      MakeUnique<RemoveUnusedInstructionReductionOpportunityFinder>(false));
  reducer.AddReductionPass(
      MakeUnique<OperandToConstReductionOpportunityFinder>());

  std::vector<uint32_t> binary_in;
  SpirvTools t(kEnv);

  ASSERT_TRUE(t.Assemble(original, &binary_in, kReduceAssembleOption));
  std::vector<uint32_t> binary_out;
  spvtools::ReducerOptions reducer_options;
  reducer_options.set_step_limit(500);
  reducer_options.set_fail_on_validation_error(true);
  spvtools::ValidatorOptions validator_options;

  Reducer::ReductionResultStatus status = reducer.Run(
      std::move(binary_in), &binary_out, reducer_options, validator_options);

  ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);

  CheckEqual(kEnv, expected, binary_out);
}
223
224bool InterestingWhileOpcodeExists(const std::vector<uint32_t>& binary,
225                                  spv::Op opcode, uint32_t count, bool dump) {
226  if (dump) {
227    std::stringstream ss;
228    ss << "temp_" << count << ".spv";
229    DumpShader(binary, ss.str().c_str());
230  }
231
232  std::unique_ptr<opt::IRContext> context =
233      BuildModule(kEnv, kMessageConsumer, binary.data(), binary.size());
234  assert(context);
235  bool interesting = false;
236  for (auto& function : *context->module()) {
237    context->cfg()->ForEachBlockInPostOrder(
238        &*function.begin(),
239        [opcode, &interesting](opt::BasicBlock* block) -> void {
240          for (auto& inst : *block) {
241            if (inst.opcode() == spv::Op(opcode)) {
242              interesting = true;
243              break;
244            }
245          }
246        });
247    if (interesting) {
248      break;
249    }
250  }
251  return interesting;
252}
253
254bool InterestingWhileIMulReachable(const std::vector<uint32_t>& binary,
255                                   uint32_t count) {
256  return InterestingWhileOpcodeExists(binary, spv::Op::OpIMul, count, false);
257}
258
259bool InterestingWhileSDivReachable(const std::vector<uint32_t>& binary,
260                                   uint32_t count) {
261  return InterestingWhileOpcodeExists(binary, spv::Op::OpSDiv, count, false);
262}
263
// The shader below was derived from the following GLSL and then optimized.
// In the SPIR-V, foo() does not appear as a separate function: its SDiv (%77)
// sits in main's block %38, the continue target of the loop headed by %31.
// The IMul (%52) sits in block %45, the false branch of the selection in %37.
// The ShaderReduceWhile*Reachable tests reduce this shader while one of those
// two instructions is still present.
//
// #version 310 es
// precision highp float;
// layout(location = 0) out vec4 _GLF_color;
// int foo() {
//    int x = 1;
//    int y;
//    x = y / x;   // SDiv
//    return x;
// }
// void main() {
//    int c;
//    while (bool(c)) {
//        do {
//            if (bool(c)) {
//                if (bool(c)) {
//                    ++c;
//                } else {
//                    _GLF_color.x = float(c*c);  // IMul
//                }
//                return;
//            }
//        } while(bool(foo()));
//        return;
//    }
// }
const std::string kShaderWithLoopsDivAndMul = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %49
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %49 "_GLF_color"
               OpDecorate %49 Location 0
               OpDecorate %52 RelaxedPrecision
               OpDecorate %77 RelaxedPrecision
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
         %12 = OpConstant %6 1
         %27 = OpTypeBool
         %28 = OpTypeInt 32 0
         %29 = OpConstant %28 0
         %46 = OpTypeFloat 32
         %47 = OpTypeVector %46 4
         %48 = OpTypePointer Output %47
         %49 = OpVariable %48 Output
         %54 = OpTypePointer Output %46
         %64 = OpConstantFalse %27
         %67 = OpConstantTrue %27
         %81 = OpUndef %6
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpBranch %61
         %61 = OpLabel
               OpLoopMerge %60 %63 None
               OpBranch %20
         %20 = OpLabel
         %30 = OpINotEqual %27 %81 %29
               OpLoopMerge %22 %23 None
               OpBranchConditional %30 %21 %22
         %21 = OpLabel
               OpBranch %31
         %31 = OpLabel
               OpLoopMerge %33 %38 None
               OpBranch %32
         %32 = OpLabel
               OpBranchConditional %30 %37 %38
         %37 = OpLabel
               OpSelectionMerge %42 None
               OpBranchConditional %30 %41 %45
         %41 = OpLabel
               OpBranch %42
         %45 = OpLabel
         %52 = OpIMul %6 %81 %81
         %53 = OpConvertSToF %46 %52
         %55 = OpAccessChain %54 %49 %29
               OpStore %55 %53
               OpBranch %42
         %42 = OpLabel
               OpBranch %33
         %38 = OpLabel
         %77 = OpSDiv %6 %81 %12
         %58 = OpINotEqual %27 %77 %29
               OpBranchConditional %58 %31 %33
         %33 = OpLabel
         %86 = OpPhi %27 %67 %42 %64 %38
               OpSelectionMerge %68 None
               OpBranchConditional %86 %22 %68
         %68 = OpLabel
               OpBranch %22
         %23 = OpLabel
               OpBranch %20
         %22 = OpLabel
         %90 = OpPhi %27 %64 %20 %86 %33 %67 %68
               OpSelectionMerge %70 None
               OpBranchConditional %90 %60 %70
         %70 = OpLabel
               OpBranch %60
         %63 = OpLabel
               OpBranch %61
         %60 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
371
// The shader below comes from the following GLSL.  In the SPIR-V, %4 is
// main, %10 is baz and %13 is bar: main calls %13 and then %10, and %13
// calls %10.  Having several functions lets SingleFunctionReduction check that
// reduction can be restricted to one target function.
//
// #version 320 es
//
// int baz(int x) {
//   int y = x + 1;
//   y = y + 2;
//   if (y > 0) {
//     return x;
//   }
//   return x + 1;
// }
//
// int bar(int a) {
//   if (a == 3) {
//     return baz(2*a);
//   }
//   a = a + 1;
//   for (int i = 0; i < 10; i++) {
//     a += baz(a);
//   }
//   return a;
// }
//
// void main() {
//   int x;
//   x = 3;
//   x += 1;
//   x += bar(x);
//   x += baz(x);
// }
const std::string kShaderWithMultipleFunctions = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 320
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %8 = OpTypeFunction %6 %7
         %17 = OpConstant %6 1
         %20 = OpConstant %6 2
         %23 = OpConstant %6 0
         %24 = OpTypeBool
         %35 = OpConstant %6 3
         %53 = OpConstant %6 10
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %65 = OpVariable %7 Function
         %68 = OpVariable %7 Function
         %73 = OpVariable %7 Function
               OpStore %65 %35
         %66 = OpLoad %6 %65
         %67 = OpIAdd %6 %66 %17
               OpStore %65 %67
         %69 = OpLoad %6 %65
               OpStore %68 %69
         %70 = OpFunctionCall %6 %13 %68
         %71 = OpLoad %6 %65
         %72 = OpIAdd %6 %71 %70
               OpStore %65 %72
         %74 = OpLoad %6 %65
               OpStore %73 %74
         %75 = OpFunctionCall %6 %10 %73
         %76 = OpLoad %6 %65
         %77 = OpIAdd %6 %76 %75
               OpStore %65 %77
               OpReturn
               OpFunctionEnd
         %10 = OpFunction %6 None %8
          %9 = OpFunctionParameter %7
         %11 = OpLabel
         %15 = OpVariable %7 Function
         %16 = OpLoad %6 %9
         %18 = OpIAdd %6 %16 %17
               OpStore %15 %18
         %19 = OpLoad %6 %15
         %21 = OpIAdd %6 %19 %20
               OpStore %15 %21
         %22 = OpLoad %6 %15
         %25 = OpSGreaterThan %24 %22 %23
               OpSelectionMerge %27 None
               OpBranchConditional %25 %26 %27
         %26 = OpLabel
         %28 = OpLoad %6 %9
               OpReturnValue %28
         %27 = OpLabel
         %30 = OpLoad %6 %9
         %31 = OpIAdd %6 %30 %17
               OpReturnValue %31
               OpFunctionEnd
         %13 = OpFunction %6 None %8
         %12 = OpFunctionParameter %7
         %14 = OpLabel
         %41 = OpVariable %7 Function
         %46 = OpVariable %7 Function
         %55 = OpVariable %7 Function
         %34 = OpLoad %6 %12
         %36 = OpIEqual %24 %34 %35
               OpSelectionMerge %38 None
               OpBranchConditional %36 %37 %38
         %37 = OpLabel
         %39 = OpLoad %6 %12
         %40 = OpIMul %6 %20 %39
               OpStore %41 %40
         %42 = OpFunctionCall %6 %10 %41
               OpReturnValue %42
         %38 = OpLabel
         %44 = OpLoad %6 %12
         %45 = OpIAdd %6 %44 %17
               OpStore %12 %45
               OpStore %46 %23
               OpBranch %47
         %47 = OpLabel
               OpLoopMerge %49 %50 None
               OpBranch %51
         %51 = OpLabel
         %52 = OpLoad %6 %46
         %54 = OpSLessThan %24 %52 %53
               OpBranchConditional %54 %48 %49
         %48 = OpLabel
         %56 = OpLoad %6 %12
               OpStore %55 %56
         %57 = OpFunctionCall %6 %10 %55
         %58 = OpLoad %6 %12
         %59 = OpIAdd %6 %58 %57
               OpStore %12 %59
               OpBranch %50
         %50 = OpLabel
         %60 = OpLoad %6 %46
         %61 = OpIAdd %6 %60 %17
               OpStore %46 %61
               OpBranch %47
         %49 = OpLabel
         %62 = OpLoad %6 %12
               OpReturnValue %62
               OpFunctionEnd
  )";
512
513TEST(ReducerTest, ShaderReduceWhileMulReachable) {
514  Reducer reducer(kEnv);
515
516  reducer.SetInterestingnessFunction(InterestingWhileIMulReachable);
517  reducer.AddDefaultReductionPasses();
518  reducer.SetMessageConsumer(kMessageConsumer);
519
520  std::vector<uint32_t> binary_in;
521  SpirvTools t(kEnv);
522
523  ASSERT_TRUE(
524      t.Assemble(kShaderWithLoopsDivAndMul, &binary_in, kReduceAssembleOption));
525  std::vector<uint32_t> binary_out;
526  spvtools::ReducerOptions reducer_options;
527  reducer_options.set_step_limit(500);
528  reducer_options.set_fail_on_validation_error(true);
529  spvtools::ValidatorOptions validator_options;
530
531  Reducer::ReductionResultStatus status = reducer.Run(
532      std::move(binary_in), &binary_out, reducer_options, validator_options);
533
534  ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
535}
536
537TEST(ReducerTest, ShaderReduceWhileDivReachable) {
538  Reducer reducer(kEnv);
539
540  reducer.SetInterestingnessFunction(InterestingWhileSDivReachable);
541  reducer.AddDefaultReductionPasses();
542  reducer.SetMessageConsumer(kMessageConsumer);
543
544  std::vector<uint32_t> binary_in;
545  SpirvTools t(kEnv);
546
547  ASSERT_TRUE(
548      t.Assemble(kShaderWithLoopsDivAndMul, &binary_in, kReduceAssembleOption));
549  std::vector<uint32_t> binary_out;
550  spvtools::ReducerOptions reducer_options;
551  reducer_options.set_step_limit(500);
552  reducer_options.set_fail_on_validation_error(true);
553  spvtools::ValidatorOptions validator_options;
554
555  Reducer::ReductionResultStatus status = reducer.Run(
556      std::move(binary_in), &binary_out, reducer_options, validator_options);
557
558  ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
559}
560
561// Computes an instruction count for each function in the module represented by
562// |binary|.
563std::unordered_map<uint32_t, uint32_t> GetFunctionInstructionCount(
564    const std::vector<uint32_t>& binary) {
565  std::unique_ptr<opt::IRContext> context =
566      BuildModule(kEnv, kMessageConsumer, binary.data(), binary.size());
567  assert(context != nullptr && "Failed to build module.");
568  std::unordered_map<uint32_t, uint32_t> result;
569  for (auto& function : *context->module()) {
570    uint32_t& count = result[function.result_id()] = 0;
571    function.ForEachInst([&count](opt::Instruction*) { count++; });
572  }
573  return result;
574}
575
576TEST(ReducerTest, SingleFunctionReduction) {
577  Reducer reducer(kEnv);
578
579  PingPongInteresting ping_pong_interesting(4);
580  reducer.SetInterestingnessFunction(
581      [&ping_pong_interesting](const std::vector<uint32_t>&, uint32_t) -> bool {
582        return ping_pong_interesting.IsInteresting();
583      });
584  reducer.AddDefaultReductionPasses();
585  reducer.SetMessageConsumer(kMessageConsumer);
586
587  std::vector<uint32_t> binary_in;
588  SpirvTools t(kEnv);
589
590  ASSERT_TRUE(t.Assemble(kShaderWithMultipleFunctions, &binary_in,
591                         kReduceAssembleOption));
592
593  auto original_instruction_count = GetFunctionInstructionCount(binary_in);
594
595  std::vector<uint32_t> binary_out;
596  spvtools::ReducerOptions reducer_options;
597  reducer_options.set_step_limit(500);
598  reducer_options.set_fail_on_validation_error(true);
599
600  // Instruct the reducer to only target function 13.
601  reducer_options.set_target_function(13);
602
603  spvtools::ValidatorOptions validator_options;
604
605  Reducer::ReductionResultStatus status = reducer.Run(
606      std::move(binary_in), &binary_out, reducer_options, validator_options);
607
608  ASSERT_EQ(status, Reducer::ReductionResultStatus::kComplete);
609
610  auto final_instruction_count = GetFunctionInstructionCount(binary_out);
611
612  // Nothing should have been removed from these functions.
613  ASSERT_EQ(original_instruction_count.at(4), final_instruction_count.at(4));
614  ASSERT_EQ(original_instruction_count.at(10), final_instruction_count.at(10));
615
616  // Function 13 should have been reduced to these five instructions:
617  //   OpFunction
618  //   OpFunctionParameter
619  //   OpLabel
620  //   OpReturnValue
621  //   OpFunctionEnd
622  ASSERT_EQ(5, final_instruction_count.at(13));
623}
624
625}  // namespace
626}  // namespace reduce
627}  // namespace spvtools
628