1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16 
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25 
TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest)26 TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
27   std::string shader = R"(
28                OpCapability Shader
29           %1 = OpExtInstImport "GLSL.std.450"
30                OpMemoryModel Logical GLSL450
31                OpEntryPoint Fragment %4 "main"
32                OpExecutionMode %4 OriginUpperLeft
33                OpSource ESSL 310
34                OpName %4 "main"
35           %2 = OpTypeVoid
36           %3 = OpTypeFunction %2
37           %6 = OpTypeInt 32 1
38           %7 = OpTypeInt 32 0
39           %8 = OpConstant %7 2
40           %9 = OpTypeArray %6 %8
41          %10 = OpTypePointer Function %9
42          %12 = OpConstant %6 1
43          %13 = OpConstant %6 2
44          %14 = OpConstantComposite %9 %12 %13
45          %15 = OpTypePointer Function %6
46          %17 = OpConstant %6 0
47          %29 = OpTypeFloat 32
48          %30 = OpTypeArray %29 %8
49          %31 = OpTypePointer Function %30
50          %33 = OpConstant %29 1
51          %34 = OpConstant %29 2
52          %35 = OpConstantComposite %30 %33 %34
53          %36 = OpTypePointer Function %29
54          %49 = OpTypeVector %29 3
55          %50 = OpTypeArray %49 %8
56          %51 = OpTypePointer Function %50
57          %53 = OpConstant %29 3
58          %54 = OpConstantComposite %49 %33 %34 %53
59          %55 = OpConstant %29 4
60          %56 = OpConstant %29 5
61          %57 = OpConstant %29 6
62          %58 = OpConstantComposite %49 %55 %56 %57
63          %59 = OpConstantComposite %50 %54 %58
64          %61 = OpTypePointer Function %49
65           %4 = OpFunction %2 None %3
66           %5 = OpLabel
67          %11 = OpVariable %10 Function
68          %16 = OpVariable %15 Function
69          %23 = OpVariable %15 Function
70          %32 = OpVariable %31 Function
71          %37 = OpVariable %36 Function
72          %43 = OpVariable %36 Function
73          %52 = OpVariable %51 Function
74          %60 = OpVariable %36 Function
75                OpStore %11 %14
76          %18 = OpAccessChain %15 %11 %17
77          %19 = OpLoad %6 %18
78          %20 = OpAccessChain %15 %11 %12
79          %21 = OpLoad %6 %20
80          %22 = OpIAdd %6 %19 %21
81                OpStore %16 %22
82          %24 = OpAccessChain %15 %11 %17
83          %25 = OpLoad %6 %24
84          %26 = OpAccessChain %15 %11 %12
85          %27 = OpLoad %6 %26
86          %28 = OpIMul %6 %25 %27
87                OpStore %23 %28
88                OpStore %32 %35
89          %38 = OpAccessChain %36 %32 %17
90          %39 = OpLoad %29 %38
91          %40 = OpAccessChain %36 %32 %12
92          %41 = OpLoad %29 %40
93          %42 = OpFAdd %29 %39 %41
94                OpStore %37 %42
95          %44 = OpAccessChain %36 %32 %17
96          %45 = OpLoad %29 %44
97          %46 = OpAccessChain %36 %32 %12
98          %47 = OpLoad %29 %46
99          %48 = OpFMul %29 %45 %47
100                OpStore %43 %48
101                OpStore %52 %59
102          %62 = OpAccessChain %61 %52 %17
103          %63 = OpLoad %49 %62
104          %64 = OpAccessChain %61 %52 %12
105          %65 = OpLoad %49 %64
106          %66 = OpDot %29 %63 %65
107                OpStore %60 %66
108                OpReturn
109                OpFunctionEnd
110   )";
111 
112   const auto env = SPV_ENV_UNIVERSAL_1_5;
113   const auto consumer = nullptr;
114   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
115   spvtools::ValidatorOptions validator_options;
116   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
117                                                kConsoleMessageConsumer));
118   TransformationContext transformation_context(
119       MakeUnique<FactManager>(context.get()), validator_options);
120   // Tests existing commutative instructions
121   auto instructionDescriptor =
122       MakeInstructionDescriptor(22, spv::Op::OpIAdd, 0);
123   auto transformation =
124       TransformationSwapCommutableOperands(instructionDescriptor);
125   ASSERT_TRUE(
126       transformation.IsApplicable(context.get(), transformation_context));
127 
128   instructionDescriptor = MakeInstructionDescriptor(28, spv::Op::OpIMul, 0);
129   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
130   ASSERT_TRUE(
131       transformation.IsApplicable(context.get(), transformation_context));
132 
133   instructionDescriptor = MakeInstructionDescriptor(42, spv::Op::OpFAdd, 0);
134   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
135   ASSERT_TRUE(
136       transformation.IsApplicable(context.get(), transformation_context));
137 
138   instructionDescriptor = MakeInstructionDescriptor(48, spv::Op::OpFMul, 0);
139   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
140   ASSERT_TRUE(
141       transformation.IsApplicable(context.get(), transformation_context));
142 
143   instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpDot, 0);
144   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
145   ASSERT_TRUE(
146       transformation.IsApplicable(context.get(), transformation_context));
147 
148   // Tests existing non-commutative instructions
149   instructionDescriptor =
150       MakeInstructionDescriptor(1, spv::Op::OpExtInstImport, 0);
151   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
152   ASSERT_FALSE(
153       transformation.IsApplicable(context.get(), transformation_context));
154 
155   instructionDescriptor = MakeInstructionDescriptor(5, spv::Op::OpLabel, 0);
156   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
157   ASSERT_FALSE(
158       transformation.IsApplicable(context.get(), transformation_context));
159 
160   instructionDescriptor = MakeInstructionDescriptor(8, spv::Op::OpConstant, 0);
161   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
162   ASSERT_FALSE(
163       transformation.IsApplicable(context.get(), transformation_context));
164 
165   instructionDescriptor = MakeInstructionDescriptor(11, spv::Op::OpVariable, 0);
166   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
167   ASSERT_FALSE(
168       transformation.IsApplicable(context.get(), transformation_context));
169 
170   instructionDescriptor =
171       MakeInstructionDescriptor(14, spv::Op::OpConstantComposite, 0);
172   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
173   ASSERT_FALSE(
174       transformation.IsApplicable(context.get(), transformation_context));
175 
176   // Tests the base instruction id not existing
177   instructionDescriptor =
178       MakeInstructionDescriptor(67, spv::Op::OpIAddCarry, 0);
179   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
180   ASSERT_FALSE(
181       transformation.IsApplicable(context.get(), transformation_context));
182 
183   instructionDescriptor = MakeInstructionDescriptor(68, spv::Op::OpIEqual, 0);
184   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
185   ASSERT_FALSE(
186       transformation.IsApplicable(context.get(), transformation_context));
187 
188   instructionDescriptor =
189       MakeInstructionDescriptor(69, spv::Op::OpINotEqual, 0);
190   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
191   ASSERT_FALSE(
192       transformation.IsApplicable(context.get(), transformation_context));
193 
194   instructionDescriptor =
195       MakeInstructionDescriptor(70, spv::Op::OpFOrdEqual, 0);
196   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
197   ASSERT_FALSE(
198       transformation.IsApplicable(context.get(), transformation_context));
199 
200   instructionDescriptor = MakeInstructionDescriptor(71, spv::Op::OpPtrEqual, 0);
201   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
202   ASSERT_FALSE(
203       transformation.IsApplicable(context.get(), transformation_context));
204 
205   // Tests there being no instruction with the desired opcode after the base
206   // instruction id
207   instructionDescriptor = MakeInstructionDescriptor(24, spv::Op::OpIAdd, 0);
208   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
209   ASSERT_FALSE(
210       transformation.IsApplicable(context.get(), transformation_context));
211 
212   instructionDescriptor = MakeInstructionDescriptor(38, spv::Op::OpIMul, 0);
213   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
214   ASSERT_FALSE(
215       transformation.IsApplicable(context.get(), transformation_context));
216 
217   instructionDescriptor = MakeInstructionDescriptor(45, spv::Op::OpFAdd, 0);
218   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
219   ASSERT_FALSE(
220       transformation.IsApplicable(context.get(), transformation_context));
221 
222   instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpFMul, 0);
223   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
224   ASSERT_FALSE(
225       transformation.IsApplicable(context.get(), transformation_context));
226 
227   // Tests there being an instruction with the desired opcode after the base
228   // instruction id, but the skip count associated with the instruction
229   // descriptor being so high.
230   instructionDescriptor = MakeInstructionDescriptor(11, spv::Op::OpIAdd, 100);
231   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
232   ASSERT_FALSE(
233       transformation.IsApplicable(context.get(), transformation_context));
234 
235   instructionDescriptor = MakeInstructionDescriptor(16, spv::Op::OpIMul, 100);
236   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
237   ASSERT_FALSE(
238       transformation.IsApplicable(context.get(), transformation_context));
239 
240   instructionDescriptor = MakeInstructionDescriptor(23, spv::Op::OpFAdd, 100);
241   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
242   ASSERT_FALSE(
243       transformation.IsApplicable(context.get(), transformation_context));
244 
245   instructionDescriptor = MakeInstructionDescriptor(32, spv::Op::OpFMul, 100);
246   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
247   ASSERT_FALSE(
248       transformation.IsApplicable(context.get(), transformation_context));
249 
250   instructionDescriptor = MakeInstructionDescriptor(37, spv::Op::OpDot, 100);
251   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
252   ASSERT_FALSE(
253       transformation.IsApplicable(context.get(), transformation_context));
254 }
255 
TEST(TransformationSwapCommutableOperandsTest, ApplyTest)256 TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
257   std::string shader = R"(
258                OpCapability Shader
259           %1 = OpExtInstImport "GLSL.std.450"
260                OpMemoryModel Logical GLSL450
261                OpEntryPoint Fragment %4 "main"
262                OpExecutionMode %4 OriginUpperLeft
263                OpSource ESSL 310
264                OpName %4 "main"
265           %2 = OpTypeVoid
266           %3 = OpTypeFunction %2
267           %6 = OpTypeInt 32 1
268           %7 = OpTypeInt 32 0
269           %8 = OpConstant %7 2
270           %9 = OpTypeArray %6 %8
271          %10 = OpTypePointer Function %9
272          %12 = OpConstant %6 1
273          %13 = OpConstant %6 2
274          %14 = OpConstantComposite %9 %12 %13
275          %15 = OpTypePointer Function %6
276          %17 = OpConstant %6 0
277          %29 = OpTypeFloat 32
278          %30 = OpTypeArray %29 %8
279          %31 = OpTypePointer Function %30
280          %33 = OpConstant %29 1
281          %34 = OpConstant %29 2
282          %35 = OpConstantComposite %30 %33 %34
283          %36 = OpTypePointer Function %29
284          %49 = OpTypeVector %29 3
285          %50 = OpTypeArray %49 %8
286          %51 = OpTypePointer Function %50
287          %53 = OpConstant %29 3
288          %54 = OpConstantComposite %49 %33 %34 %53
289          %55 = OpConstant %29 4
290          %56 = OpConstant %29 5
291          %57 = OpConstant %29 6
292          %58 = OpConstantComposite %49 %55 %56 %57
293          %59 = OpConstantComposite %50 %54 %58
294          %61 = OpTypePointer Function %49
295           %4 = OpFunction %2 None %3
296           %5 = OpLabel
297          %11 = OpVariable %10 Function
298          %16 = OpVariable %15 Function
299          %23 = OpVariable %15 Function
300          %32 = OpVariable %31 Function
301          %37 = OpVariable %36 Function
302          %43 = OpVariable %36 Function
303          %52 = OpVariable %51 Function
304          %60 = OpVariable %36 Function
305                OpStore %11 %14
306          %18 = OpAccessChain %15 %11 %17
307          %19 = OpLoad %6 %18
308          %20 = OpAccessChain %15 %11 %12
309          %21 = OpLoad %6 %20
310          %22 = OpIAdd %6 %19 %21
311                OpStore %16 %22
312          %24 = OpAccessChain %15 %11 %17
313          %25 = OpLoad %6 %24
314          %26 = OpAccessChain %15 %11 %12
315          %27 = OpLoad %6 %26
316          %28 = OpIMul %6 %25 %27
317                OpStore %23 %28
318                OpStore %32 %35
319          %38 = OpAccessChain %36 %32 %17
320          %39 = OpLoad %29 %38
321          %40 = OpAccessChain %36 %32 %12
322          %41 = OpLoad %29 %40
323          %42 = OpFAdd %29 %39 %41
324                OpStore %37 %42
325          %44 = OpAccessChain %36 %32 %17
326          %45 = OpLoad %29 %44
327          %46 = OpAccessChain %36 %32 %12
328          %47 = OpLoad %29 %46
329          %48 = OpFMul %29 %45 %47
330                OpStore %43 %48
331                OpStore %52 %59
332          %62 = OpAccessChain %61 %52 %17
333          %63 = OpLoad %49 %62
334          %64 = OpAccessChain %61 %52 %12
335          %65 = OpLoad %49 %64
336          %66 = OpDot %29 %63 %65
337                OpStore %60 %66
338                OpReturn
339                OpFunctionEnd
340   )";
341 
342   const auto env = SPV_ENV_UNIVERSAL_1_5;
343   const auto consumer = nullptr;
344   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
345   spvtools::ValidatorOptions validator_options;
346   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
347                                                kConsoleMessageConsumer));
348   TransformationContext transformation_context(
349       MakeUnique<FactManager>(context.get()), validator_options);
350   auto instructionDescriptor =
351       MakeInstructionDescriptor(22, spv::Op::OpIAdd, 0);
352   auto transformation =
353       TransformationSwapCommutableOperands(instructionDescriptor);
354   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
355 
356   instructionDescriptor = MakeInstructionDescriptor(28, spv::Op::OpIMul, 0);
357   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
358   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
359 
360   instructionDescriptor = MakeInstructionDescriptor(42, spv::Op::OpFAdd, 0);
361   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
362   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
363 
364   instructionDescriptor = MakeInstructionDescriptor(48, spv::Op::OpFMul, 0);
365   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
366   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
367 
368   instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpDot, 0);
369   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
370   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
371 
372   std::string variantShader = R"(
373                OpCapability Shader
374           %1 = OpExtInstImport "GLSL.std.450"
375                OpMemoryModel Logical GLSL450
376                OpEntryPoint Fragment %4 "main"
377                OpExecutionMode %4 OriginUpperLeft
378                OpSource ESSL 310
379                OpName %4 "main"
380           %2 = OpTypeVoid
381           %3 = OpTypeFunction %2
382           %6 = OpTypeInt 32 1
383           %7 = OpTypeInt 32 0
384           %8 = OpConstant %7 2
385           %9 = OpTypeArray %6 %8
386          %10 = OpTypePointer Function %9
387          %12 = OpConstant %6 1
388          %13 = OpConstant %6 2
389          %14 = OpConstantComposite %9 %12 %13
390          %15 = OpTypePointer Function %6
391          %17 = OpConstant %6 0
392          %29 = OpTypeFloat 32
393          %30 = OpTypeArray %29 %8
394          %31 = OpTypePointer Function %30
395          %33 = OpConstant %29 1
396          %34 = OpConstant %29 2
397          %35 = OpConstantComposite %30 %33 %34
398          %36 = OpTypePointer Function %29
399          %49 = OpTypeVector %29 3
400          %50 = OpTypeArray %49 %8
401          %51 = OpTypePointer Function %50
402          %53 = OpConstant %29 3
403          %54 = OpConstantComposite %49 %33 %34 %53
404          %55 = OpConstant %29 4
405          %56 = OpConstant %29 5
406          %57 = OpConstant %29 6
407          %58 = OpConstantComposite %49 %55 %56 %57
408          %59 = OpConstantComposite %50 %54 %58
409          %61 = OpTypePointer Function %49
410           %4 = OpFunction %2 None %3
411           %5 = OpLabel
412          %11 = OpVariable %10 Function
413          %16 = OpVariable %15 Function
414          %23 = OpVariable %15 Function
415          %32 = OpVariable %31 Function
416          %37 = OpVariable %36 Function
417          %43 = OpVariable %36 Function
418          %52 = OpVariable %51 Function
419          %60 = OpVariable %36 Function
420                OpStore %11 %14
421          %18 = OpAccessChain %15 %11 %17
422          %19 = OpLoad %6 %18
423          %20 = OpAccessChain %15 %11 %12
424          %21 = OpLoad %6 %20
425          %22 = OpIAdd %6 %21 %19
426                OpStore %16 %22
427          %24 = OpAccessChain %15 %11 %17
428          %25 = OpLoad %6 %24
429          %26 = OpAccessChain %15 %11 %12
430          %27 = OpLoad %6 %26
431          %28 = OpIMul %6 %27 %25
432                OpStore %23 %28
433                OpStore %32 %35
434          %38 = OpAccessChain %36 %32 %17
435          %39 = OpLoad %29 %38
436          %40 = OpAccessChain %36 %32 %12
437          %41 = OpLoad %29 %40
438          %42 = OpFAdd %29 %41 %39
439                OpStore %37 %42
440          %44 = OpAccessChain %36 %32 %17
441          %45 = OpLoad %29 %44
442          %46 = OpAccessChain %36 %32 %12
443          %47 = OpLoad %29 %46
444          %48 = OpFMul %29 %47 %45
445                OpStore %43 %48
446                OpStore %52 %59
447          %62 = OpAccessChain %61 %52 %17
448          %63 = OpLoad %49 %62
449          %64 = OpAccessChain %61 %52 %12
450          %65 = OpLoad %49 %64
451          %66 = OpDot %29 %65 %63
452                OpStore %60 %66
453                OpReturn
454                OpFunctionEnd
455   )";
456 
457   ASSERT_TRUE(IsEqual(env, variantShader, context.get()));
458 }
459 
460 }  // namespace
461 }  // namespace fuzz
462 }  // namespace spvtools
463