1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <memory>
16#include <vector>
17
18#include "gmock/gmock.h"
19#include "source/opt/loop_descriptor.h"
20#include "source/opt/loop_fusion.h"
21#include "test/opt/pass_fixture.h"
22
23namespace spvtools {
24namespace opt {
25namespace {
26
27using FusionCompatibilityTest = PassTest<::testing::Test>;
28
29/*
30Generated from the following GLSL + --eliminate-local-multi-store
31
32#version 440 core
33void main() {
34  int i = 0; // Can't fuse, i=0 in first & i=10 in second
35  for (; i < 10; i++) {}
36  for (; i < 10; i++) {}
37}
38*/
39TEST_F(FusionCompatibilityTest, SameInductionVariableDifferentBounds) {
40  const std::string text = R"(
41               OpCapability Shader
42          %1 = OpExtInstImport "GLSL.std.450"
43               OpMemoryModel Logical GLSL450
44               OpEntryPoint Fragment %4 "main"
45               OpExecutionMode %4 OriginUpperLeft
46               OpSource GLSL 440
47               OpName %4 "main"
48               OpName %8 "i"
49          %2 = OpTypeVoid
50          %3 = OpTypeFunction %2
51          %6 = OpTypeInt 32 1
52          %7 = OpTypePointer Function %6
53          %9 = OpConstant %6 0
54         %16 = OpConstant %6 10
55         %17 = OpTypeBool
56         %20 = OpConstant %6 1
57          %4 = OpFunction %2 None %3
58          %5 = OpLabel
59          %8 = OpVariable %7 Function
60               OpStore %8 %9
61               OpBranch %10
62         %10 = OpLabel
63         %31 = OpPhi %6 %9 %5 %21 %13
64               OpLoopMerge %12 %13 None
65               OpBranch %14
66         %14 = OpLabel
67         %18 = OpSLessThan %17 %31 %16
68               OpBranchConditional %18 %11 %12
69         %11 = OpLabel
70               OpBranch %13
71         %13 = OpLabel
72         %21 = OpIAdd %6 %31 %20
73               OpStore %8 %21
74               OpBranch %10
75         %12 = OpLabel
76               OpBranch %22
77         %22 = OpLabel
78         %32 = OpPhi %6 %31 %12 %30 %25
79               OpLoopMerge %24 %25 None
80               OpBranch %26
81         %26 = OpLabel
82         %28 = OpSLessThan %17 %32 %16
83               OpBranchConditional %28 %23 %24
84         %23 = OpLabel
85               OpBranch %25
86         %25 = OpLabel
87         %30 = OpIAdd %6 %32 %20
88               OpStore %8 %30
89               OpBranch %22
90         %24 = OpLabel
91               OpReturn
92               OpFunctionEnd
93  )";
94
95  std::unique_ptr<IRContext> context =
96      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
97                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
98  Module* module = context->module();
99  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
100                             << text << std::endl;
101  Function& f = *module->begin();
102  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
103  EXPECT_EQ(ld.NumLoops(), 2u);
104
105  auto loops = ld.GetLoopsInBinaryLayoutOrder();
106
107  LoopFusion fusion(context.get(), loops[0], loops[1]);
108  EXPECT_FALSE(fusion.AreCompatible());
109}
110
111/*
112Generated from the following GLSL + --eliminate-local-multi-store
113
114// 1
115#version 440 core
116void main() {
117  for (int i = 0; i < 10; i++) {}
118  for (int i = 0; i < 10; i++) {}
119}
120*/
121TEST_F(FusionCompatibilityTest, Compatible) {
122  const std::string text = R"(
123               OpCapability Shader
124          %1 = OpExtInstImport "GLSL.std.450"
125               OpMemoryModel Logical GLSL450
126               OpEntryPoint Fragment %4 "main"
127               OpExecutionMode %4 OriginUpperLeft
128               OpSource GLSL 440
129               OpName %4 "main"
130               OpName %8 "i"
131               OpName %22 "i"
132          %2 = OpTypeVoid
133          %3 = OpTypeFunction %2
134          %6 = OpTypeInt 32 1
135          %7 = OpTypePointer Function %6
136          %9 = OpConstant %6 0
137         %16 = OpConstant %6 10
138         %17 = OpTypeBool
139         %20 = OpConstant %6 1
140          %4 = OpFunction %2 None %3
141          %5 = OpLabel
142          %8 = OpVariable %7 Function
143         %22 = OpVariable %7 Function
144               OpStore %8 %9
145               OpBranch %10
146         %10 = OpLabel
147         %32 = OpPhi %6 %9 %5 %21 %13
148               OpLoopMerge %12 %13 None
149               OpBranch %14
150         %14 = OpLabel
151         %18 = OpSLessThan %17 %32 %16
152               OpBranchConditional %18 %11 %12
153         %11 = OpLabel
154               OpBranch %13
155         %13 = OpLabel
156         %21 = OpIAdd %6 %32 %20
157               OpStore %8 %21
158               OpBranch %10
159         %12 = OpLabel
160               OpStore %22 %9
161               OpBranch %23
162         %23 = OpLabel
163         %33 = OpPhi %6 %9 %12 %31 %26
164               OpLoopMerge %25 %26 None
165               OpBranch %27
166         %27 = OpLabel
167         %29 = OpSLessThan %17 %33 %16
168               OpBranchConditional %29 %24 %25
169         %24 = OpLabel
170               OpBranch %26
171         %26 = OpLabel
172         %31 = OpIAdd %6 %33 %20
173               OpStore %22 %31
174               OpBranch %23
175         %25 = OpLabel
176               OpReturn
177               OpFunctionEnd
178  )";
179
180  std::unique_ptr<IRContext> context =
181      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
182                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
183  Module* module = context->module();
184  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
185                             << text << std::endl;
186  Function& f = *module->begin();
187  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
188  EXPECT_EQ(ld.NumLoops(), 2u);
189
190  auto loops = ld.GetLoopsInBinaryLayoutOrder();
191
192  LoopFusion fusion(context.get(), loops[0], loops[1]);
193  EXPECT_TRUE(fusion.AreCompatible());
194}
195
196/*
197Generated from the following GLSL + --eliminate-local-multi-store
198
199// 2
200#version 440 core
201void main() {
202  for (int i = 0; i < 10; i++) {}
203  for (int j = 0; j < 10; j++) {}
204}
205
206*/
207TEST_F(FusionCompatibilityTest, DifferentName) {
208  const std::string text = R"(
209               OpCapability Shader
210          %1 = OpExtInstImport "GLSL.std.450"
211               OpMemoryModel Logical GLSL450
212               OpEntryPoint Fragment %4 "main"
213               OpExecutionMode %4 OriginUpperLeft
214               OpSource GLSL 440
215               OpName %4 "main"
216               OpName %8 "i"
217               OpName %22 "j"
218          %2 = OpTypeVoid
219          %3 = OpTypeFunction %2
220          %6 = OpTypeInt 32 1
221          %7 = OpTypePointer Function %6
222          %9 = OpConstant %6 0
223         %16 = OpConstant %6 10
224         %17 = OpTypeBool
225         %20 = OpConstant %6 1
226          %4 = OpFunction %2 None %3
227          %5 = OpLabel
228          %8 = OpVariable %7 Function
229         %22 = OpVariable %7 Function
230               OpStore %8 %9
231               OpBranch %10
232         %10 = OpLabel
233         %32 = OpPhi %6 %9 %5 %21 %13
234               OpLoopMerge %12 %13 None
235               OpBranch %14
236         %14 = OpLabel
237         %18 = OpSLessThan %17 %32 %16
238               OpBranchConditional %18 %11 %12
239         %11 = OpLabel
240               OpBranch %13
241         %13 = OpLabel
242         %21 = OpIAdd %6 %32 %20
243               OpStore %8 %21
244               OpBranch %10
245         %12 = OpLabel
246               OpStore %22 %9
247               OpBranch %23
248         %23 = OpLabel
249         %33 = OpPhi %6 %9 %12 %31 %26
250               OpLoopMerge %25 %26 None
251               OpBranch %27
252         %27 = OpLabel
253         %29 = OpSLessThan %17 %33 %16
254               OpBranchConditional %29 %24 %25
255         %24 = OpLabel
256               OpBranch %26
257         %26 = OpLabel
258         %31 = OpIAdd %6 %33 %20
259               OpStore %22 %31
260               OpBranch %23
261         %25 = OpLabel
262               OpReturn
263               OpFunctionEnd
264  )";
265
266  std::unique_ptr<IRContext> context =
267      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
268                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
269  Module* module = context->module();
270  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
271                             << text << std::endl;
272  Function& f = *module->begin();
273  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
274  EXPECT_EQ(ld.NumLoops(), 2u);
275
276  auto loops = ld.GetLoopsInBinaryLayoutOrder();
277
278  LoopFusion fusion(context.get(), loops[0], loops[1]);
279  EXPECT_TRUE(fusion.AreCompatible());
280}
281
282/*
283Generated from the following GLSL + --eliminate-local-multi-store
284
285#version 440 core
286void main() {
287  // Can't fuse, different step
288  for (int i = 0; i < 10; i++) {}
289  for (int j = 0; j < 10; j=j+2) {}
290}
291
292*/
293TEST_F(FusionCompatibilityTest, SameBoundsDifferentStep) {
294  const std::string text = R"(
295               OpCapability Shader
296          %1 = OpExtInstImport "GLSL.std.450"
297               OpMemoryModel Logical GLSL450
298               OpEntryPoint Fragment %4 "main"
299               OpExecutionMode %4 OriginUpperLeft
300               OpSource GLSL 440
301               OpName %4 "main"
302               OpName %8 "i"
303               OpName %22 "j"
304          %2 = OpTypeVoid
305          %3 = OpTypeFunction %2
306          %6 = OpTypeInt 32 1
307          %7 = OpTypePointer Function %6
308          %9 = OpConstant %6 0
309         %16 = OpConstant %6 10
310         %17 = OpTypeBool
311         %20 = OpConstant %6 1
312         %31 = OpConstant %6 2
313          %4 = OpFunction %2 None %3
314          %5 = OpLabel
315          %8 = OpVariable %7 Function
316         %22 = OpVariable %7 Function
317               OpStore %8 %9
318               OpBranch %10
319         %10 = OpLabel
320         %33 = OpPhi %6 %9 %5 %21 %13
321               OpLoopMerge %12 %13 None
322               OpBranch %14
323         %14 = OpLabel
324         %18 = OpSLessThan %17 %33 %16
325               OpBranchConditional %18 %11 %12
326         %11 = OpLabel
327               OpBranch %13
328         %13 = OpLabel
329         %21 = OpIAdd %6 %33 %20
330               OpStore %8 %21
331               OpBranch %10
332         %12 = OpLabel
333               OpStore %22 %9
334               OpBranch %23
335         %23 = OpLabel
336         %34 = OpPhi %6 %9 %12 %32 %26
337               OpLoopMerge %25 %26 None
338               OpBranch %27
339         %27 = OpLabel
340         %29 = OpSLessThan %17 %34 %16
341               OpBranchConditional %29 %24 %25
342         %24 = OpLabel
343               OpBranch %26
344         %26 = OpLabel
345         %32 = OpIAdd %6 %34 %31
346               OpStore %22 %32
347               OpBranch %23
348         %25 = OpLabel
349               OpReturn
350               OpFunctionEnd
351  )";
352
353  std::unique_ptr<IRContext> context =
354      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
355                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
356  Module* module = context->module();
357  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
358                             << text << std::endl;
359  Function& f = *module->begin();
360  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
361  EXPECT_EQ(ld.NumLoops(), 2u);
362
363  auto loops = ld.GetLoopsInBinaryLayoutOrder();
364
365  LoopFusion fusion(context.get(), loops[0], loops[1]);
366  EXPECT_FALSE(fusion.AreCompatible());
367}
368
369/*
370Generated from the following GLSL + --eliminate-local-multi-store
371
372// 4
373#version 440 core
374void main() {
375  // Can't fuse, different upper bound
376  for (int i = 0; i < 10; i++) {}
377  for (int j = 0; j < 20; j++) {}
378}
379
380*/
381TEST_F(FusionCompatibilityTest, DifferentUpperBound) {
382  const std::string text = R"(
383               OpCapability Shader
384          %1 = OpExtInstImport "GLSL.std.450"
385               OpMemoryModel Logical GLSL450
386               OpEntryPoint Fragment %4 "main"
387               OpExecutionMode %4 OriginUpperLeft
388               OpSource GLSL 440
389               OpName %4 "main"
390               OpName %8 "i"
391               OpName %22 "j"
392          %2 = OpTypeVoid
393          %3 = OpTypeFunction %2
394          %6 = OpTypeInt 32 1
395          %7 = OpTypePointer Function %6
396          %9 = OpConstant %6 0
397         %16 = OpConstant %6 10
398         %17 = OpTypeBool
399         %20 = OpConstant %6 1
400         %29 = OpConstant %6 20
401          %4 = OpFunction %2 None %3
402          %5 = OpLabel
403          %8 = OpVariable %7 Function
404         %22 = OpVariable %7 Function
405               OpStore %8 %9
406               OpBranch %10
407         %10 = OpLabel
408         %33 = OpPhi %6 %9 %5 %21 %13
409               OpLoopMerge %12 %13 None
410               OpBranch %14
411         %14 = OpLabel
412         %18 = OpSLessThan %17 %33 %16
413               OpBranchConditional %18 %11 %12
414         %11 = OpLabel
415               OpBranch %13
416         %13 = OpLabel
417         %21 = OpIAdd %6 %33 %20
418               OpStore %8 %21
419               OpBranch %10
420         %12 = OpLabel
421               OpStore %22 %9
422               OpBranch %23
423         %23 = OpLabel
424         %34 = OpPhi %6 %9 %12 %32 %26
425               OpLoopMerge %25 %26 None
426               OpBranch %27
427         %27 = OpLabel
428         %30 = OpSLessThan %17 %34 %29
429               OpBranchConditional %30 %24 %25
430         %24 = OpLabel
431               OpBranch %26
432         %26 = OpLabel
433         %32 = OpIAdd %6 %34 %20
434               OpStore %22 %32
435               OpBranch %23
436         %25 = OpLabel
437               OpReturn
438               OpFunctionEnd
439  )";
440
441  std::unique_ptr<IRContext> context =
442      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
443                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
444  Module* module = context->module();
445  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
446                             << text << std::endl;
447  Function& f = *module->begin();
448  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
449  EXPECT_EQ(ld.NumLoops(), 2u);
450
451  auto loops = ld.GetLoopsInBinaryLayoutOrder();
452
453  LoopFusion fusion(context.get(), loops[0], loops[1]);
454  EXPECT_FALSE(fusion.AreCompatible());
455}
456
457/*
458Generated from the following GLSL + --eliminate-local-multi-store
459
460// 5
461#version 440 core
462void main() {
463  // Can't fuse, different lower bound
464  for (int i = 5; i < 10; i++) {}
465  for (int j = 0; j < 10; j++) {}
466}
467
468*/
469TEST_F(FusionCompatibilityTest, DifferentLowerBound) {
470  const std::string text = R"(
471                OpCapability Shader
472          %1 = OpExtInstImport "GLSL.std.450"
473               OpMemoryModel Logical GLSL450
474               OpEntryPoint Fragment %4 "main"
475               OpExecutionMode %4 OriginUpperLeft
476               OpSource GLSL 440
477               OpName %4 "main"
478               OpName %8 "i"
479               OpName %22 "j"
480          %2 = OpTypeVoid
481          %3 = OpTypeFunction %2
482          %6 = OpTypeInt 32 1
483          %7 = OpTypePointer Function %6
484          %9 = OpConstant %6 5
485         %16 = OpConstant %6 10
486         %17 = OpTypeBool
487         %20 = OpConstant %6 1
488         %23 = OpConstant %6 0
489          %4 = OpFunction %2 None %3
490          %5 = OpLabel
491          %8 = OpVariable %7 Function
492         %22 = OpVariable %7 Function
493               OpStore %8 %9
494               OpBranch %10
495         %10 = OpLabel
496         %33 = OpPhi %6 %9 %5 %21 %13
497               OpLoopMerge %12 %13 None
498               OpBranch %14
499         %14 = OpLabel
500         %18 = OpSLessThan %17 %33 %16
501               OpBranchConditional %18 %11 %12
502         %11 = OpLabel
503               OpBranch %13
504         %13 = OpLabel
505         %21 = OpIAdd %6 %33 %20
506               OpStore %8 %21
507               OpBranch %10
508         %12 = OpLabel
509               OpStore %22 %23
510               OpBranch %24
511         %24 = OpLabel
512         %34 = OpPhi %6 %23 %12 %32 %27
513               OpLoopMerge %26 %27 None
514               OpBranch %28
515         %28 = OpLabel
516         %30 = OpSLessThan %17 %34 %16
517               OpBranchConditional %30 %25 %26
518         %25 = OpLabel
519               OpBranch %27
520         %27 = OpLabel
521         %32 = OpIAdd %6 %34 %20
522               OpStore %22 %32
523               OpBranch %24
524         %26 = OpLabel
525               OpReturn
526               OpFunctionEnd
527  )";
528
529  std::unique_ptr<IRContext> context =
530      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
531                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
532  Module* module = context->module();
533  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
534                             << text << std::endl;
535  Function& f = *module->begin();
536  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
537  EXPECT_EQ(ld.NumLoops(), 2u);
538
539  auto loops = ld.GetLoopsInBinaryLayoutOrder();
540
541  LoopFusion fusion(context.get(), loops[0], loops[1]);
542  EXPECT_FALSE(fusion.AreCompatible());
543}
544
545/*
546Generated from the following GLSL + --eliminate-local-multi-store
547
548// 6
549#version 440 core
550void main() {
551  // Can't fuse, break in first loop
552  for (int i = 0; i < 10; i++) {
553    if (i == 5) {
554      break;
555    }
556  }
557  for (int j = 0; j < 10; j++) {}
558}
559
560*/
561TEST_F(FusionCompatibilityTest, Break) {
562  const std::string text = R"(
563                OpCapability Shader
564          %1 = OpExtInstImport "GLSL.std.450"
565               OpMemoryModel Logical GLSL450
566               OpEntryPoint Fragment %4 "main"
567               OpExecutionMode %4 OriginUpperLeft
568               OpSource GLSL 440
569               OpName %4 "main"
570               OpName %8 "i"
571               OpName %28 "j"
572          %2 = OpTypeVoid
573          %3 = OpTypeFunction %2
574          %6 = OpTypeInt 32 1
575          %7 = OpTypePointer Function %6
576          %9 = OpConstant %6 0
577         %16 = OpConstant %6 10
578         %17 = OpTypeBool
579         %20 = OpConstant %6 5
580         %26 = OpConstant %6 1
581          %4 = OpFunction %2 None %3
582          %5 = OpLabel
583          %8 = OpVariable %7 Function
584         %28 = OpVariable %7 Function
585               OpStore %8 %9
586               OpBranch %10
587         %10 = OpLabel
588         %38 = OpPhi %6 %9 %5 %27 %13
589               OpLoopMerge %12 %13 None
590               OpBranch %14
591         %14 = OpLabel
592         %18 = OpSLessThan %17 %38 %16
593               OpBranchConditional %18 %11 %12
594         %11 = OpLabel
595         %21 = OpIEqual %17 %38 %20
596               OpSelectionMerge %23 None
597               OpBranchConditional %21 %22 %23
598         %22 = OpLabel
599               OpBranch %12
600         %23 = OpLabel
601               OpBranch %13
602         %13 = OpLabel
603         %27 = OpIAdd %6 %38 %26
604               OpStore %8 %27
605               OpBranch %10
606         %12 = OpLabel
607               OpStore %28 %9
608               OpBranch %29
609         %29 = OpLabel
610         %39 = OpPhi %6 %9 %12 %37 %32
611               OpLoopMerge %31 %32 None
612               OpBranch %33
613         %33 = OpLabel
614         %35 = OpSLessThan %17 %39 %16
615               OpBranchConditional %35 %30 %31
616         %30 = OpLabel
617               OpBranch %32
618         %32 = OpLabel
619         %37 = OpIAdd %6 %39 %26
620               OpStore %28 %37
621               OpBranch %29
622         %31 = OpLabel
623               OpReturn
624               OpFunctionEnd
625  )";
626
627  std::unique_ptr<IRContext> context =
628      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
629                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
630  Module* module = context->module();
631  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
632                             << text << std::endl;
633  Function& f = *module->begin();
634  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
635  EXPECT_EQ(ld.NumLoops(), 2u);
636
637  auto loops = ld.GetLoopsInBinaryLayoutOrder();
638
639  LoopFusion fusion(context.get(), loops[0], loops[1]);
640  EXPECT_FALSE(fusion.AreCompatible());
641}
642
643/*
644Generated from the following GLSL + --eliminate-local-multi-store
645
646#version 440 core
647layout(location = 0) in vec4 c;
648void main() {
649  int N = int(c.x);
650  for (int i = 0; i < N; i++) {}
651  for (int j = 0; j < N; j++) {}
652}
653
654*/
655TEST_F(FusionCompatibilityTest, UnknownButSameUpperBound) {
656  const std::string text = R"(
657               OpCapability Shader
658          %1 = OpExtInstImport "GLSL.std.450"
659               OpMemoryModel Logical GLSL450
660               OpEntryPoint Fragment %4 "main" %12
661               OpExecutionMode %4 OriginUpperLeft
662               OpSource GLSL 440
663               OpName %4 "main"
664               OpName %8 "N"
665               OpName %12 "c"
666               OpName %19 "i"
667               OpName %33 "j"
668               OpDecorate %12 Location 0
669          %2 = OpTypeVoid
670          %3 = OpTypeFunction %2
671          %6 = OpTypeInt 32 1
672          %7 = OpTypePointer Function %6
673          %9 = OpTypeFloat 32
674         %10 = OpTypeVector %9 4
675         %11 = OpTypePointer Input %10
676         %12 = OpVariable %11 Input
677         %13 = OpTypeInt 32 0
678         %14 = OpConstant %13 0
679         %15 = OpTypePointer Input %9
680         %20 = OpConstant %6 0
681         %28 = OpTypeBool
682         %31 = OpConstant %6 1
683          %4 = OpFunction %2 None %3
684          %5 = OpLabel
685          %8 = OpVariable %7 Function
686         %19 = OpVariable %7 Function
687         %33 = OpVariable %7 Function
688         %16 = OpAccessChain %15 %12 %14
689         %17 = OpLoad %9 %16
690         %18 = OpConvertFToS %6 %17
691               OpStore %8 %18
692               OpStore %19 %20
693               OpBranch %21
694         %21 = OpLabel
695         %44 = OpPhi %6 %20 %5 %32 %24
696               OpLoopMerge %23 %24 None
697               OpBranch %25
698         %25 = OpLabel
699         %29 = OpSLessThan %28 %44 %18
700               OpBranchConditional %29 %22 %23
701         %22 = OpLabel
702               OpBranch %24
703         %24 = OpLabel
704         %32 = OpIAdd %6 %44 %31
705               OpStore %19 %32
706               OpBranch %21
707         %23 = OpLabel
708               OpStore %33 %20
709               OpBranch %34
710         %34 = OpLabel
711         %46 = OpPhi %6 %20 %23 %43 %37
712               OpLoopMerge %36 %37 None
713               OpBranch %38
714         %38 = OpLabel
715         %41 = OpSLessThan %28 %46 %18
716               OpBranchConditional %41 %35 %36
717         %35 = OpLabel
718               OpBranch %37
719         %37 = OpLabel
720         %43 = OpIAdd %6 %46 %31
721               OpStore %33 %43
722               OpBranch %34
723         %36 = OpLabel
724               OpReturn
725               OpFunctionEnd
726  )";
727
728  std::unique_ptr<IRContext> context =
729      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
730                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
731  Module* module = context->module();
732  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
733                             << text << std::endl;
734  Function& f = *module->begin();
735  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
736  EXPECT_EQ(ld.NumLoops(), 2u);
737
738  auto loops = ld.GetLoopsInBinaryLayoutOrder();
739
740  LoopFusion fusion(context.get(), loops[0], loops[1]);
741  EXPECT_TRUE(fusion.AreCompatible());
742}
743
744/*
745Generated from the following GLSL + --eliminate-local-multi-store
746
747#version 440 core
748layout(location = 0) in vec4 c;
749void main() {
750  int N = int(c.x);
751  for (int i = 0; N > j; i++) {}
752  for (int j = 0; N > j; j++) {}
753}
754*/
755TEST_F(FusionCompatibilityTest, UnknownButSameUpperBoundReverseCondition) {
756  const std::string text = R"(
757               OpCapability Shader
758          %1 = OpExtInstImport "GLSL.std.450"
759               OpMemoryModel Logical GLSL450
760               OpEntryPoint Fragment %4 "main" %12
761               OpExecutionMode %4 OriginUpperLeft
762               OpSource GLSL 440
763               OpName %4 "main"
764               OpName %8 "N"
765               OpName %12 "c"
766               OpName %19 "i"
767               OpName %33 "j"
768               OpDecorate %12 Location 0
769          %2 = OpTypeVoid
770          %3 = OpTypeFunction %2
771          %6 = OpTypeInt 32 1
772          %7 = OpTypePointer Function %6
773          %9 = OpTypeFloat 32
774         %10 = OpTypeVector %9 4
775         %11 = OpTypePointer Input %10
776         %12 = OpVariable %11 Input
777         %13 = OpTypeInt 32 0
778         %14 = OpConstant %13 0
779         %15 = OpTypePointer Input %9
780         %20 = OpConstant %6 0
781         %28 = OpTypeBool
782         %31 = OpConstant %6 1
783          %4 = OpFunction %2 None %3
784          %5 = OpLabel
785          %8 = OpVariable %7 Function
786         %19 = OpVariable %7 Function
787         %33 = OpVariable %7 Function
788         %16 = OpAccessChain %15 %12 %14
789         %17 = OpLoad %9 %16
790         %18 = OpConvertFToS %6 %17
791               OpStore %8 %18
792               OpStore %19 %20
793               OpBranch %21
794         %21 = OpLabel
795         %45 = OpPhi %6 %20 %5 %32 %24
796               OpLoopMerge %23 %24 None
797               OpBranch %25
798         %25 = OpLabel
799         %29 = OpSGreaterThan %28 %18 %45
800               OpBranchConditional %29 %22 %23
801         %22 = OpLabel
802               OpBranch %24
803         %24 = OpLabel
804         %32 = OpIAdd %6 %45 %31
805               OpStore %19 %32
806               OpBranch %21
807         %23 = OpLabel
808               OpStore %33 %20
809               OpBranch %34
810         %34 = OpLabel
811         %47 = OpPhi %6 %20 %23 %43 %37
812               OpLoopMerge %36 %37 None
813               OpBranch %38
814         %38 = OpLabel
815         %41 = OpSGreaterThan %28 %18 %47
816               OpBranchConditional %41 %35 %36
817         %35 = OpLabel
818               OpBranch %37
819         %37 = OpLabel
820         %43 = OpIAdd %6 %47 %31
821               OpStore %33 %43
822               OpBranch %34
823         %36 = OpLabel
824               OpReturn
825               OpFunctionEnd
826  )";
827
828  std::unique_ptr<IRContext> context =
829      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
830                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
831  Module* module = context->module();
832  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
833                             << text << std::endl;
834  Function& f = *module->begin();
835  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
836  EXPECT_EQ(ld.NumLoops(), 2u);
837
838  auto loops = ld.GetLoopsInBinaryLayoutOrder();
839
840  LoopFusion fusion(context.get(), loops[0], loops[1]);
841  EXPECT_TRUE(fusion.AreCompatible());
842}
843
844/*
845Generated from the following GLSL + --eliminate-local-multi-store
846
847#version 440 core
848layout(location = 0) in vec4 c;
849void main() {
850  // Can't fuse different bound
851  int N = int(c.x);
852  for (int i = 0; i < N; i++) {}
853  for (int j = 0; j < N+1; j++) {}
854}
855
856*/
857TEST_F(FusionCompatibilityTest, UnknownUpperBoundAddition) {
858  const std::string text = R"(
859               OpCapability Shader
860          %1 = OpExtInstImport "GLSL.std.450"
861               OpMemoryModel Logical GLSL450
862               OpEntryPoint Fragment %4 "main" %12
863               OpExecutionMode %4 OriginUpperLeft
864               OpSource GLSL 440
865               OpName %4 "main"
866               OpName %8 "N"
867               OpName %12 "c"
868               OpName %19 "i"
869               OpName %33 "j"
870               OpDecorate %12 Location 0
871          %2 = OpTypeVoid
872          %3 = OpTypeFunction %2
873          %6 = OpTypeInt 32 1
874          %7 = OpTypePointer Function %6
875          %9 = OpTypeFloat 32
876         %10 = OpTypeVector %9 4
877         %11 = OpTypePointer Input %10
878         %12 = OpVariable %11 Input
879         %13 = OpTypeInt 32 0
880         %14 = OpConstant %13 0
881         %15 = OpTypePointer Input %9
882         %20 = OpConstant %6 0
883         %28 = OpTypeBool
884         %31 = OpConstant %6 1
885          %4 = OpFunction %2 None %3
886          %5 = OpLabel
887          %8 = OpVariable %7 Function
888         %19 = OpVariable %7 Function
889         %33 = OpVariable %7 Function
890         %16 = OpAccessChain %15 %12 %14
891         %17 = OpLoad %9 %16
892         %18 = OpConvertFToS %6 %17
893               OpStore %8 %18
894               OpStore %19 %20
895               OpBranch %21
896         %21 = OpLabel
897         %45 = OpPhi %6 %20 %5 %32 %24
898               OpLoopMerge %23 %24 None
899               OpBranch %25
900         %25 = OpLabel
901         %29 = OpSLessThan %28 %45 %18
902               OpBranchConditional %29 %22 %23
903         %22 = OpLabel
904               OpBranch %24
905         %24 = OpLabel
906         %32 = OpIAdd %6 %45 %31
907               OpStore %19 %32
908               OpBranch %21
909         %23 = OpLabel
910               OpStore %33 %20
911               OpBranch %34
912         %34 = OpLabel
913         %47 = OpPhi %6 %20 %23 %44 %37
914               OpLoopMerge %36 %37 None
915               OpBranch %38
916         %38 = OpLabel
917         %41 = OpIAdd %6 %18 %31
918         %42 = OpSLessThan %28 %47 %41
919               OpBranchConditional %42 %35 %36
920         %35 = OpLabel
921               OpBranch %37
922         %37 = OpLabel
923         %44 = OpIAdd %6 %47 %31
924               OpStore %33 %44
925               OpBranch %34
926         %36 = OpLabel
927               OpReturn
928               OpFunctionEnd
929  )";
930
931  std::unique_ptr<IRContext> context =
932      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
933                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
934  Module* module = context->module();
935  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
936                             << text << std::endl;
937  Function& f = *module->begin();
938  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
939  EXPECT_EQ(ld.NumLoops(), 2u);
940
941  auto loops = ld.GetLoopsInBinaryLayoutOrder();
942
943  LoopFusion fusion(context.get(), loops[0], loops[1]);
944  EXPECT_FALSE(fusion.AreCompatible());
945}
946
947/*
948Generated from the following GLSL + --eliminate-local-multi-store
949
950// 10
951#version 440 core
952void main() {
953  for (int i = 0; i < 10; i++) {}
954  for (int j = 0; j < 10; j++) {}
955  for (int k = 0; k < 10; k++) {}
956}
957
958*/
959TEST_F(FusionCompatibilityTest, SeveralAdjacentLoops) {
960  const std::string text = R"(
961               OpCapability Shader
962          %1 = OpExtInstImport "GLSL.std.450"
963               OpMemoryModel Logical GLSL450
964               OpEntryPoint Fragment %4 "main"
965               OpExecutionMode %4 OriginUpperLeft
966               OpSource GLSL 440
967               OpName %4 "main"
968               OpName %8 "i"
969               OpName %22 "j"
970               OpName %32 "k"
971          %2 = OpTypeVoid
972          %3 = OpTypeFunction %2
973          %6 = OpTypeInt 32 1
974          %7 = OpTypePointer Function %6
975          %9 = OpConstant %6 0
976         %16 = OpConstant %6 10
977         %17 = OpTypeBool
978         %20 = OpConstant %6 1
979          %4 = OpFunction %2 None %3
980          %5 = OpLabel
981          %8 = OpVariable %7 Function
982         %22 = OpVariable %7 Function
983         %32 = OpVariable %7 Function
984               OpStore %8 %9
985               OpBranch %10
986         %10 = OpLabel
987         %42 = OpPhi %6 %9 %5 %21 %13
988               OpLoopMerge %12 %13 None
989               OpBranch %14
990         %14 = OpLabel
991         %18 = OpSLessThan %17 %42 %16
992               OpBranchConditional %18 %11 %12
993         %11 = OpLabel
994               OpBranch %13
995         %13 = OpLabel
996         %21 = OpIAdd %6 %42 %20
997               OpStore %8 %21
998               OpBranch %10
999         %12 = OpLabel
1000               OpStore %22 %9
1001               OpBranch %23
1002         %23 = OpLabel
1003         %43 = OpPhi %6 %9 %12 %31 %26
1004               OpLoopMerge %25 %26 None
1005               OpBranch %27
1006         %27 = OpLabel
1007         %29 = OpSLessThan %17 %43 %16
1008               OpBranchConditional %29 %24 %25
1009         %24 = OpLabel
1010               OpBranch %26
1011         %26 = OpLabel
1012         %31 = OpIAdd %6 %43 %20
1013               OpStore %22 %31
1014               OpBranch %23
1015         %25 = OpLabel
1016               OpStore %32 %9
1017               OpBranch %33
1018         %33 = OpLabel
1019         %44 = OpPhi %6 %9 %25 %41 %36
1020               OpLoopMerge %35 %36 None
1021               OpBranch %37
1022         %37 = OpLabel
1023         %39 = OpSLessThan %17 %44 %16
1024               OpBranchConditional %39 %34 %35
1025         %34 = OpLabel
1026               OpBranch %36
1027         %36 = OpLabel
1028         %41 = OpIAdd %6 %44 %20
1029               OpStore %32 %41
1030               OpBranch %33
1031         %35 = OpLabel
1032               OpReturn
1033               OpFunctionEnd
1034  )";
1035
1036  std::unique_ptr<IRContext> context =
1037      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1038                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1039  Module* module = context->module();
1040  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1041                             << text << std::endl;
1042  Function& f = *module->begin();
1043  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1044  EXPECT_EQ(ld.NumLoops(), 3u);
1045
1046  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1047
1048  auto loop_0 = loops[0];
1049  auto loop_1 = loops[1];
1050  auto loop_2 = loops[2];
1051
1052  EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_0).AreCompatible());
1053  EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible());
1054  EXPECT_FALSE(LoopFusion(context.get(), loop_1, loop_0).AreCompatible());
1055  EXPECT_TRUE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible());
1056  EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible());
1057}
1058
1059/*
1060Generated from the following GLSL + --eliminate-local-multi-store
1061
1062#version 440 core
1063void main() {
1064  // Can't fuse, not adjacent
1065  int x = 0;
1066  for (int i = 0; i < 10; i++) {
1067    if (i > 10) {
1068      x++;
1069    }
1070  }
1071  x++;
1072  for (int j = 0; j < 10; j++) {}
1073  for (int k = 0; k < 10; k++) {}
1074}
1075
1076*/
1077TEST_F(FusionCompatibilityTest, NonAdjacentLoops) {
1078  const std::string text = R"(
1079               OpCapability Shader
1080          %1 = OpExtInstImport "GLSL.std.450"
1081               OpMemoryModel Logical GLSL450
1082               OpEntryPoint Fragment %4 "main"
1083               OpExecutionMode %4 OriginUpperLeft
1084               OpSource GLSL 440
1085               OpName %4 "main"
1086               OpName %8 "x"
1087               OpName %10 "i"
1088               OpName %31 "j"
1089               OpName %41 "k"
1090          %2 = OpTypeVoid
1091          %3 = OpTypeFunction %2
1092          %6 = OpTypeInt 32 1
1093          %7 = OpTypePointer Function %6
1094          %9 = OpConstant %6 0
1095         %17 = OpConstant %6 10
1096         %18 = OpTypeBool
1097         %25 = OpConstant %6 1
1098          %4 = OpFunction %2 None %3
1099          %5 = OpLabel
1100          %8 = OpVariable %7 Function
1101         %10 = OpVariable %7 Function
1102         %31 = OpVariable %7 Function
1103         %41 = OpVariable %7 Function
1104               OpStore %8 %9
1105               OpStore %10 %9
1106               OpBranch %11
1107         %11 = OpLabel
1108         %52 = OpPhi %6 %9 %5 %56 %14
1109         %51 = OpPhi %6 %9 %5 %28 %14
1110               OpLoopMerge %13 %14 None
1111               OpBranch %15
1112         %15 = OpLabel
1113         %19 = OpSLessThan %18 %51 %17
1114               OpBranchConditional %19 %12 %13
1115         %12 = OpLabel
1116         %21 = OpSGreaterThan %18 %52 %17
1117               OpSelectionMerge %23 None
1118               OpBranchConditional %21 %22 %23
1119         %22 = OpLabel
1120         %26 = OpIAdd %6 %52 %25
1121               OpStore %8 %26
1122               OpBranch %23
1123         %23 = OpLabel
1124         %56 = OpPhi %6 %52 %12 %26 %22
1125               OpBranch %14
1126         %14 = OpLabel
1127         %28 = OpIAdd %6 %51 %25
1128               OpStore %10 %28
1129               OpBranch %11
1130         %13 = OpLabel
1131         %30 = OpIAdd %6 %52 %25
1132               OpStore %8 %30
1133               OpStore %31 %9
1134               OpBranch %32
1135         %32 = OpLabel
1136         %53 = OpPhi %6 %9 %13 %40 %35
1137               OpLoopMerge %34 %35 None
1138               OpBranch %36
1139         %36 = OpLabel
1140         %38 = OpSLessThan %18 %53 %17
1141               OpBranchConditional %38 %33 %34
1142         %33 = OpLabel
1143               OpBranch %35
1144         %35 = OpLabel
1145         %40 = OpIAdd %6 %53 %25
1146               OpStore %31 %40
1147               OpBranch %32
1148         %34 = OpLabel
1149               OpStore %41 %9
1150               OpBranch %42
1151         %42 = OpLabel
1152         %54 = OpPhi %6 %9 %34 %50 %45
1153               OpLoopMerge %44 %45 None
1154               OpBranch %46
1155         %46 = OpLabel
1156         %48 = OpSLessThan %18 %54 %17
1157               OpBranchConditional %48 %43 %44
1158         %43 = OpLabel
1159               OpBranch %45
1160         %45 = OpLabel
1161         %50 = OpIAdd %6 %54 %25
1162               OpStore %41 %50
1163               OpBranch %42
1164         %44 = OpLabel
1165               OpReturn
1166               OpFunctionEnd
1167  )";
1168
1169  std::unique_ptr<IRContext> context =
1170      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1171                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1172  Module* module = context->module();
1173  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1174                             << text << std::endl;
1175  Function& f = *module->begin();
1176  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1177  EXPECT_EQ(ld.NumLoops(), 3u);
1178
1179  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1180
1181  auto loop_0 = loops[0];
1182  auto loop_1 = loops[1];
1183  auto loop_2 = loops[2];
1184
1185  EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_0).AreCompatible());
1186  EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_2).AreCompatible());
1187  EXPECT_FALSE(LoopFusion(context.get(), loop_0, loop_1).AreCompatible());
1188  EXPECT_TRUE(LoopFusion(context.get(), loop_1, loop_2).AreCompatible());
1189}
1190
1191/*
1192Generated from the following GLSL + --eliminate-local-multi-store
1193
1194// 12
1195#version 440 core
1196void main() {
1197  int j = 0;
1198  int i = 0;
1199  for (; i < 10; i++) {}
1200  for (; j < 10; j++) {}
1201}
1202
1203*/
1204TEST_F(FusionCompatibilityTest, CompatibleInitDeclaredBeforeLoops) {
1205  const std::string text = R"(
1206               OpCapability Shader
1207          %1 = OpExtInstImport "GLSL.std.450"
1208               OpMemoryModel Logical GLSL450
1209               OpEntryPoint Fragment %4 "main"
1210               OpExecutionMode %4 OriginUpperLeft
1211               OpSource GLSL 440
1212               OpName %4 "main"
1213               OpName %8 "j"
1214               OpName %10 "i"
1215          %2 = OpTypeVoid
1216          %3 = OpTypeFunction %2
1217          %6 = OpTypeInt 32 1
1218          %7 = OpTypePointer Function %6
1219          %9 = OpConstant %6 0
1220         %17 = OpConstant %6 10
1221         %18 = OpTypeBool
1222         %21 = OpConstant %6 1
1223          %4 = OpFunction %2 None %3
1224          %5 = OpLabel
1225          %8 = OpVariable %7 Function
1226         %10 = OpVariable %7 Function
1227               OpStore %8 %9
1228               OpStore %10 %9
1229               OpBranch %11
1230         %11 = OpLabel
1231         %32 = OpPhi %6 %9 %5 %22 %14
1232               OpLoopMerge %13 %14 None
1233               OpBranch %15
1234         %15 = OpLabel
1235         %19 = OpSLessThan %18 %32 %17
1236               OpBranchConditional %19 %12 %13
1237         %12 = OpLabel
1238               OpBranch %14
1239         %14 = OpLabel
1240         %22 = OpIAdd %6 %32 %21
1241               OpStore %10 %22
1242               OpBranch %11
1243         %13 = OpLabel
1244               OpBranch %23
1245         %23 = OpLabel
1246         %33 = OpPhi %6 %9 %13 %31 %26
1247               OpLoopMerge %25 %26 None
1248               OpBranch %27
1249         %27 = OpLabel
1250         %29 = OpSLessThan %18 %33 %17
1251               OpBranchConditional %29 %24 %25
1252         %24 = OpLabel
1253               OpBranch %26
1254         %26 = OpLabel
1255         %31 = OpIAdd %6 %33 %21
1256               OpStore %8 %31
1257               OpBranch %23
1258         %25 = OpLabel
1259               OpReturn
1260               OpFunctionEnd
1261  )";
1262
1263  std::unique_ptr<IRContext> context =
1264      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1265                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1266  Module* module = context->module();
1267  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1268                             << text << std::endl;
1269  Function& f = *module->begin();
1270  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1271  EXPECT_EQ(ld.NumLoops(), 2u);
1272
1273  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1274
1275  EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1276}
1277
1278/*
1279Generated from the following GLSL + --eliminate-local-multi-store
1280
1281// 13 regenerate!
1282#version 440 core
1283void main() {
1284  int[10] a;
1285  int[10] b;
1286  // Can't fuse, several induction variables
1287  for (int j = 0; j < 10; j++) {
1288    b[i] = a[i];
1289  }
1290  for (int i = 0, j = 0; i < 10; i++, j = j+2) {
1291  }
1292}
1293
1294*/
1295TEST_F(FusionCompatibilityTest, SeveralInductionVariables) {
1296  const std::string text = R"(
1297               OpCapability Shader
1298          %1 = OpExtInstImport "GLSL.std.450"
1299               OpMemoryModel Logical GLSL450
1300               OpEntryPoint Fragment %4 "main"
1301               OpExecutionMode %4 OriginUpperLeft
1302               OpSource GLSL 440
1303               OpName %4 "main"
1304               OpName %8 "j"
1305               OpName %23 "b"
1306               OpName %25 "a"
1307               OpName %33 "i"
1308               OpName %34 "j"
1309          %2 = OpTypeVoid
1310          %3 = OpTypeFunction %2
1311          %6 = OpTypeInt 32 1
1312          %7 = OpTypePointer Function %6
1313          %9 = OpConstant %6 0
1314         %16 = OpConstant %6 10
1315         %17 = OpTypeBool
1316         %19 = OpTypeInt 32 0
1317         %20 = OpConstant %19 10
1318         %21 = OpTypeArray %6 %20
1319         %22 = OpTypePointer Function %21
1320         %31 = OpConstant %6 1
1321         %48 = OpConstant %6 2
1322          %4 = OpFunction %2 None %3
1323          %5 = OpLabel
1324          %8 = OpVariable %7 Function
1325         %23 = OpVariable %22 Function
1326         %25 = OpVariable %22 Function
1327         %33 = OpVariable %7 Function
1328         %34 = OpVariable %7 Function
1329               OpStore %8 %9
1330               OpBranch %10
1331         %10 = OpLabel
1332         %50 = OpPhi %6 %9 %5 %32 %13
1333               OpLoopMerge %12 %13 None
1334               OpBranch %14
1335         %14 = OpLabel
1336         %18 = OpSLessThan %17 %50 %16
1337               OpBranchConditional %18 %11 %12
1338         %11 = OpLabel
1339         %27 = OpAccessChain %7 %25 %50
1340         %28 = OpLoad %6 %27
1341         %29 = OpAccessChain %7 %23 %50
1342               OpStore %29 %28
1343               OpBranch %13
1344         %13 = OpLabel
1345         %32 = OpIAdd %6 %50 %31
1346               OpStore %8 %32
1347               OpBranch %10
1348         %12 = OpLabel
1349               OpStore %33 %9
1350               OpStore %34 %9
1351               OpBranch %35
1352         %35 = OpLabel
1353         %52 = OpPhi %6 %9 %12 %49 %38
1354         %51 = OpPhi %6 %9 %12 %46 %38
1355               OpLoopMerge %37 %38 None
1356               OpBranch %39
1357         %39 = OpLabel
1358         %41 = OpSLessThan %17 %51 %16
1359               OpBranchConditional %41 %36 %37
1360         %36 = OpLabel
1361         %44 = OpAccessChain %7 %25 %52
1362               OpStore %44 %51
1363               OpBranch %38
1364         %38 = OpLabel
1365         %46 = OpIAdd %6 %51 %31
1366               OpStore %33 %46
1367         %49 = OpIAdd %6 %52 %48
1368               OpStore %34 %49
1369               OpBranch %35
1370         %37 = OpLabel
1371               OpReturn
1372               OpFunctionEnd
1373  )";
1374
1375  std::unique_ptr<IRContext> context =
1376      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1377                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1378  Module* module = context->module();
1379  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1380                             << text << std::endl;
1381  Function& f = *module->begin();
1382  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1383  EXPECT_EQ(ld.NumLoops(), 2u);
1384
1385  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1386
1387  EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1388}
1389
1390/*
1391Generated from the following GLSL + --eliminate-local-multi-store
1392
1393// 14
1394#version 440 core
1395void main() {
1396  // Fine
1397  for (int i = 0; i < 10; i = i + 2) {}
1398  for (int j = 0; j < 10; j = j + 2) {}
1399}
1400
1401*/
1402TEST_F(FusionCompatibilityTest, CompatibleNonIncrementStep) {
1403  const std::string text = R"(
1404               OpCapability Shader
1405          %1 = OpExtInstImport "GLSL.std.450"
1406               OpMemoryModel Logical GLSL450
1407               OpEntryPoint Fragment %4 "main"
1408               OpExecutionMode %4 OriginUpperLeft
1409               OpSource GLSL 440
1410               OpName %4 "main"
1411               OpName %8 "j"
1412               OpName %10 "i"
1413               OpName %11 "i"
1414               OpName %24 "j"
1415          %2 = OpTypeVoid
1416          %3 = OpTypeFunction %2
1417          %6 = OpTypeInt 32 1
1418          %7 = OpTypePointer Function %6
1419          %9 = OpConstant %6 0
1420         %18 = OpConstant %6 10
1421         %19 = OpTypeBool
1422         %22 = OpConstant %6 2
1423          %4 = OpFunction %2 None %3
1424          %5 = OpLabel
1425          %8 = OpVariable %7 Function
1426         %10 = OpVariable %7 Function
1427         %11 = OpVariable %7 Function
1428         %24 = OpVariable %7 Function
1429               OpStore %8 %9
1430               OpStore %10 %9
1431               OpStore %11 %9
1432               OpBranch %12
1433         %12 = OpLabel
1434         %34 = OpPhi %6 %9 %5 %23 %15
1435               OpLoopMerge %14 %15 None
1436               OpBranch %16
1437         %16 = OpLabel
1438         %20 = OpSLessThan %19 %34 %18
1439               OpBranchConditional %20 %13 %14
1440         %13 = OpLabel
1441               OpBranch %15
1442         %15 = OpLabel
1443         %23 = OpIAdd %6 %34 %22
1444               OpStore %11 %23
1445               OpBranch %12
1446         %14 = OpLabel
1447               OpStore %24 %9
1448               OpBranch %25
1449         %25 = OpLabel
1450         %35 = OpPhi %6 %9 %14 %33 %28
1451               OpLoopMerge %27 %28 None
1452               OpBranch %29
1453         %29 = OpLabel
1454         %31 = OpSLessThan %19 %35 %18
1455               OpBranchConditional %31 %26 %27
1456         %26 = OpLabel
1457               OpBranch %28
1458         %28 = OpLabel
1459         %33 = OpIAdd %6 %35 %22
1460               OpStore %24 %33
1461               OpBranch %25
1462         %27 = OpLabel
1463               OpReturn
1464               OpFunctionEnd
1465  )";
1466
1467  std::unique_ptr<IRContext> context =
1468      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1469                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1470  Module* module = context->module();
1471  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1472                             << text << std::endl;
1473  Function& f = *module->begin();
1474  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1475  EXPECT_EQ(ld.NumLoops(), 2u);
1476
1477  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1478
1479  EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1480}
1481
1482/*
1483Generated from the following GLSL + --eliminate-local-multi-store
1484
1485// 15
1486#version 440 core
1487
1488int j = 0;
1489
1490void main() {
1491  // Not compatible, unknown init for second.
1492  for (int i = 0; i < 10; i = i + 2) {}
1493  for (; j < 10; j = j + 2) {}
1494}
1495
1496*/
1497TEST_F(FusionCompatibilityTest, UnknonInitForSecondLoop) {
1498  const std::string text = R"(
1499               OpCapability Shader
1500          %1 = OpExtInstImport "GLSL.std.450"
1501               OpMemoryModel Logical GLSL450
1502               OpEntryPoint Fragment %4 "main"
1503               OpExecutionMode %4 OriginUpperLeft
1504               OpSource GLSL 440
1505               OpName %4 "main"
1506               OpName %8 "j"
1507               OpName %11 "i"
1508          %2 = OpTypeVoid
1509          %3 = OpTypeFunction %2
1510          %6 = OpTypeInt 32 1
1511          %7 = OpTypePointer Private %6
1512          %8 = OpVariable %7 Private
1513          %9 = OpConstant %6 0
1514         %10 = OpTypePointer Function %6
1515         %18 = OpConstant %6 10
1516         %19 = OpTypeBool
1517         %22 = OpConstant %6 2
1518          %4 = OpFunction %2 None %3
1519          %5 = OpLabel
1520         %11 = OpVariable %10 Function
1521               OpStore %8 %9
1522               OpStore %11 %9
1523               OpBranch %12
1524         %12 = OpLabel
1525         %33 = OpPhi %6 %9 %5 %23 %15
1526               OpLoopMerge %14 %15 None
1527               OpBranch %16
1528         %16 = OpLabel
1529         %20 = OpSLessThan %19 %33 %18
1530               OpBranchConditional %20 %13 %14
1531         %13 = OpLabel
1532               OpBranch %15
1533         %15 = OpLabel
1534         %23 = OpIAdd %6 %33 %22
1535               OpStore %11 %23
1536               OpBranch %12
1537         %14 = OpLabel
1538               OpBranch %24
1539         %24 = OpLabel
1540               OpLoopMerge %26 %27 None
1541               OpBranch %28
1542         %28 = OpLabel
1543         %29 = OpLoad %6 %8
1544         %30 = OpSLessThan %19 %29 %18
1545               OpBranchConditional %30 %25 %26
1546         %25 = OpLabel
1547               OpBranch %27
1548         %27 = OpLabel
1549         %31 = OpLoad %6 %8
1550         %32 = OpIAdd %6 %31 %22
1551               OpStore %8 %32
1552               OpBranch %24
1553         %26 = OpLabel
1554               OpReturn
1555               OpFunctionEnd
1556  )";
1557
1558  std::unique_ptr<IRContext> context =
1559      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1560                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1561  Module* module = context->module();
1562  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1563                             << text << std::endl;
1564  Function& f = *module->begin();
1565  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1566  EXPECT_EQ(ld.NumLoops(), 2u);
1567
1568  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1569
1570  EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1571}
1572
1573/*
1574Generated from the following GLSL + --eliminate-local-multi-store
1575
1576// 16
1577#version 440 core
1578void main() {
1579  // Not compatible, continue in loop 0
1580  for (int i = 0; i < 10; ++i) {
1581    if (i % 2 == 1) {
1582      continue;
1583    }
1584  }
1585  for (int j = 0; j < 10; ++j) {}
1586}
1587
1588*/
1589TEST_F(FusionCompatibilityTest, Continue) {
1590  const std::string text = R"(
1591               OpCapability Shader
1592          %1 = OpExtInstImport "GLSL.std.450"
1593               OpMemoryModel Logical GLSL450
1594               OpEntryPoint Fragment %4 "main"
1595               OpExecutionMode %4 OriginUpperLeft
1596               OpSource GLSL 440
1597               OpName %4 "main"
1598               OpName %8 "i"
1599               OpName %29 "j"
1600          %2 = OpTypeVoid
1601          %3 = OpTypeFunction %2
1602          %6 = OpTypeInt 32 1
1603          %7 = OpTypePointer Function %6
1604          %9 = OpConstant %6 0
1605         %16 = OpConstant %6 10
1606         %17 = OpTypeBool
1607         %20 = OpConstant %6 2
1608         %22 = OpConstant %6 1
1609          %4 = OpFunction %2 None %3
1610          %5 = OpLabel
1611          %8 = OpVariable %7 Function
1612         %29 = OpVariable %7 Function
1613               OpStore %8 %9
1614               OpBranch %10
1615         %10 = OpLabel
1616         %39 = OpPhi %6 %9 %5 %28 %13
1617               OpLoopMerge %12 %13 None
1618               OpBranch %14
1619         %14 = OpLabel
1620         %18 = OpSLessThan %17 %39 %16
1621               OpBranchConditional %18 %11 %12
1622         %11 = OpLabel
1623         %21 = OpSMod %6 %39 %20
1624         %23 = OpIEqual %17 %21 %22
1625               OpSelectionMerge %25 None
1626               OpBranchConditional %23 %24 %25
1627         %24 = OpLabel
1628               OpBranch %13
1629         %25 = OpLabel
1630               OpBranch %13
1631         %13 = OpLabel
1632         %28 = OpIAdd %6 %39 %22
1633               OpStore %8 %28
1634               OpBranch %10
1635         %12 = OpLabel
1636               OpStore %29 %9
1637               OpBranch %30
1638         %30 = OpLabel
1639         %40 = OpPhi %6 %9 %12 %38 %33
1640               OpLoopMerge %32 %33 None
1641               OpBranch %34
1642         %34 = OpLabel
1643         %36 = OpSLessThan %17 %40 %16
1644               OpBranchConditional %36 %31 %32
1645         %31 = OpLabel
1646               OpBranch %33
1647         %33 = OpLabel
1648         %38 = OpIAdd %6 %40 %22
1649               OpStore %29 %38
1650               OpBranch %30
1651         %32 = OpLabel
1652               OpReturn
1653               OpFunctionEnd
1654  )";
1655
1656  std::unique_ptr<IRContext> context =
1657      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1658                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1659  Module* module = context->module();
1660  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1661                             << text << std::endl;
1662  Function& f = *module->begin();
1663  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1664  EXPECT_EQ(ld.NumLoops(), 2u);
1665
1666  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1667
1668  EXPECT_FALSE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1669}
1670
1671/*
1672Generated from the following GLSL + --eliminate-local-multi-store
1673
1674#version 440 core
1675void main() {
1676  int[10] a;
1677  // Compatible
1678  for (int i = 0; i < 10; ++i) {
1679    if (i % 2 == 1) {
1680    } else {
1681      a[i] = i;
1682    }
1683  }
1684  for (int j = 0; j < 10; ++j) {}
1685}
1686
1687*/
1688TEST_F(FusionCompatibilityTest, IfElseInLoop) {
1689  const std::string text = R"(
1690               OpCapability Shader
1691          %1 = OpExtInstImport "GLSL.std.450"
1692               OpMemoryModel Logical GLSL450
1693               OpEntryPoint Fragment %4 "main"
1694               OpExecutionMode %4 OriginUpperLeft
1695               OpSource GLSL 440
1696               OpName %4 "main"
1697               OpName %8 "i"
1698               OpName %31 "a"
1699               OpName %37 "j"
1700          %2 = OpTypeVoid
1701          %3 = OpTypeFunction %2
1702          %6 = OpTypeInt 32 1
1703          %7 = OpTypePointer Function %6
1704          %9 = OpConstant %6 0
1705         %16 = OpConstant %6 10
1706         %17 = OpTypeBool
1707         %20 = OpConstant %6 2
1708         %22 = OpConstant %6 1
1709         %27 = OpTypeInt 32 0
1710         %28 = OpConstant %27 10
1711         %29 = OpTypeArray %6 %28
1712         %30 = OpTypePointer Function %29
1713          %4 = OpFunction %2 None %3
1714          %5 = OpLabel
1715          %8 = OpVariable %7 Function
1716         %31 = OpVariable %30 Function
1717         %37 = OpVariable %7 Function
1718               OpStore %8 %9
1719               OpBranch %10
1720         %10 = OpLabel
1721         %47 = OpPhi %6 %9 %5 %36 %13
1722               OpLoopMerge %12 %13 None
1723               OpBranch %14
1724         %14 = OpLabel
1725         %18 = OpSLessThan %17 %47 %16
1726               OpBranchConditional %18 %11 %12
1727         %11 = OpLabel
1728         %21 = OpSMod %6 %47 %20
1729         %23 = OpIEqual %17 %21 %22
1730               OpSelectionMerge %25 None
1731               OpBranchConditional %23 %24 %26
1732         %24 = OpLabel
1733               OpBranch %25
1734         %26 = OpLabel
1735         %34 = OpAccessChain %7 %31 %47
1736               OpStore %34 %47
1737               OpBranch %25
1738         %25 = OpLabel
1739               OpBranch %13
1740         %13 = OpLabel
1741         %36 = OpIAdd %6 %47 %22
1742               OpStore %8 %36
1743               OpBranch %10
1744         %12 = OpLabel
1745               OpStore %37 %9
1746               OpBranch %38
1747         %38 = OpLabel
1748         %48 = OpPhi %6 %9 %12 %46 %41
1749               OpLoopMerge %40 %41 None
1750               OpBranch %42
1751         %42 = OpLabel
1752         %44 = OpSLessThan %17 %48 %16
1753               OpBranchConditional %44 %39 %40
1754         %39 = OpLabel
1755               OpBranch %41
1756         %41 = OpLabel
1757         %46 = OpIAdd %6 %48 %22
1758               OpStore %37 %46
1759               OpBranch %38
1760         %40 = OpLabel
1761               OpReturn
1762               OpFunctionEnd
1763  )";
1764
1765  std::unique_ptr<IRContext> context =
1766      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
1767                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1768  Module* module = context->module();
1769  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
1770                             << text << std::endl;
1771  Function& f = *module->begin();
1772  LoopDescriptor& ld = *context->GetLoopDescriptor(&f);
1773  EXPECT_EQ(ld.NumLoops(), 2u);
1774
1775  auto loops = ld.GetLoopsInBinaryLayoutOrder();
1776
1777  EXPECT_TRUE(LoopFusion(context.get(), loops[0], loops[1]).AreCompatible());
1778}
1779
1780}  // namespace
1781}  // namespace opt
1782}  // namespace spvtools
1783