1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25
TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest)26 TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
27 std::string shader = R"(
28 OpCapability Shader
29 %1 = OpExtInstImport "GLSL.std.450"
30 OpMemoryModel Logical GLSL450
31 OpEntryPoint Fragment %4 "main"
32 OpExecutionMode %4 OriginUpperLeft
33 OpSource ESSL 310
34 OpName %4 "main"
35 %2 = OpTypeVoid
36 %3 = OpTypeFunction %2
37 %6 = OpTypeInt 32 1
38 %7 = OpTypeInt 32 0
39 %8 = OpConstant %7 2
40 %9 = OpTypeArray %6 %8
41 %10 = OpTypePointer Function %9
42 %12 = OpConstant %6 1
43 %13 = OpConstant %6 2
44 %14 = OpConstantComposite %9 %12 %13
45 %15 = OpTypePointer Function %6
46 %17 = OpConstant %6 0
47 %29 = OpTypeFloat 32
48 %30 = OpTypeArray %29 %8
49 %31 = OpTypePointer Function %30
50 %33 = OpConstant %29 1
51 %34 = OpConstant %29 2
52 %35 = OpConstantComposite %30 %33 %34
53 %36 = OpTypePointer Function %29
54 %49 = OpTypeVector %29 3
55 %50 = OpTypeArray %49 %8
56 %51 = OpTypePointer Function %50
57 %53 = OpConstant %29 3
58 %54 = OpConstantComposite %49 %33 %34 %53
59 %55 = OpConstant %29 4
60 %56 = OpConstant %29 5
61 %57 = OpConstant %29 6
62 %58 = OpConstantComposite %49 %55 %56 %57
63 %59 = OpConstantComposite %50 %54 %58
64 %61 = OpTypePointer Function %49
65 %4 = OpFunction %2 None %3
66 %5 = OpLabel
67 %11 = OpVariable %10 Function
68 %16 = OpVariable %15 Function
69 %23 = OpVariable %15 Function
70 %32 = OpVariable %31 Function
71 %37 = OpVariable %36 Function
72 %43 = OpVariable %36 Function
73 %52 = OpVariable %51 Function
74 %60 = OpVariable %36 Function
75 OpStore %11 %14
76 %18 = OpAccessChain %15 %11 %17
77 %19 = OpLoad %6 %18
78 %20 = OpAccessChain %15 %11 %12
79 %21 = OpLoad %6 %20
80 %22 = OpIAdd %6 %19 %21
81 OpStore %16 %22
82 %24 = OpAccessChain %15 %11 %17
83 %25 = OpLoad %6 %24
84 %26 = OpAccessChain %15 %11 %12
85 %27 = OpLoad %6 %26
86 %28 = OpIMul %6 %25 %27
87 OpStore %23 %28
88 OpStore %32 %35
89 %38 = OpAccessChain %36 %32 %17
90 %39 = OpLoad %29 %38
91 %40 = OpAccessChain %36 %32 %12
92 %41 = OpLoad %29 %40
93 %42 = OpFAdd %29 %39 %41
94 OpStore %37 %42
95 %44 = OpAccessChain %36 %32 %17
96 %45 = OpLoad %29 %44
97 %46 = OpAccessChain %36 %32 %12
98 %47 = OpLoad %29 %46
99 %48 = OpFMul %29 %45 %47
100 OpStore %43 %48
101 OpStore %52 %59
102 %62 = OpAccessChain %61 %52 %17
103 %63 = OpLoad %49 %62
104 %64 = OpAccessChain %61 %52 %12
105 %65 = OpLoad %49 %64
106 %66 = OpDot %29 %63 %65
107 OpStore %60 %66
108 OpReturn
109 OpFunctionEnd
110 )";
111
112 const auto env = SPV_ENV_UNIVERSAL_1_5;
113 const auto consumer = nullptr;
114 const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
115 spvtools::ValidatorOptions validator_options;
116 ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
117 kConsoleMessageConsumer));
118 TransformationContext transformation_context(
119 MakeUnique<FactManager>(context.get()), validator_options);
120 // Tests existing commutative instructions
121 auto instructionDescriptor =
122 MakeInstructionDescriptor(22, spv::Op::OpIAdd, 0);
123 auto transformation =
124 TransformationSwapCommutableOperands(instructionDescriptor);
125 ASSERT_TRUE(
126 transformation.IsApplicable(context.get(), transformation_context));
127
128 instructionDescriptor = MakeInstructionDescriptor(28, spv::Op::OpIMul, 0);
129 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
130 ASSERT_TRUE(
131 transformation.IsApplicable(context.get(), transformation_context));
132
133 instructionDescriptor = MakeInstructionDescriptor(42, spv::Op::OpFAdd, 0);
134 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
135 ASSERT_TRUE(
136 transformation.IsApplicable(context.get(), transformation_context));
137
138 instructionDescriptor = MakeInstructionDescriptor(48, spv::Op::OpFMul, 0);
139 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
140 ASSERT_TRUE(
141 transformation.IsApplicable(context.get(), transformation_context));
142
143 instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpDot, 0);
144 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
145 ASSERT_TRUE(
146 transformation.IsApplicable(context.get(), transformation_context));
147
148 // Tests existing non-commutative instructions
149 instructionDescriptor =
150 MakeInstructionDescriptor(1, spv::Op::OpExtInstImport, 0);
151 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
152 ASSERT_FALSE(
153 transformation.IsApplicable(context.get(), transformation_context));
154
155 instructionDescriptor = MakeInstructionDescriptor(5, spv::Op::OpLabel, 0);
156 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
157 ASSERT_FALSE(
158 transformation.IsApplicable(context.get(), transformation_context));
159
160 instructionDescriptor = MakeInstructionDescriptor(8, spv::Op::OpConstant, 0);
161 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
162 ASSERT_FALSE(
163 transformation.IsApplicable(context.get(), transformation_context));
164
165 instructionDescriptor = MakeInstructionDescriptor(11, spv::Op::OpVariable, 0);
166 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
167 ASSERT_FALSE(
168 transformation.IsApplicable(context.get(), transformation_context));
169
170 instructionDescriptor =
171 MakeInstructionDescriptor(14, spv::Op::OpConstantComposite, 0);
172 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
173 ASSERT_FALSE(
174 transformation.IsApplicable(context.get(), transformation_context));
175
176 // Tests the base instruction id not existing
177 instructionDescriptor =
178 MakeInstructionDescriptor(67, spv::Op::OpIAddCarry, 0);
179 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
180 ASSERT_FALSE(
181 transformation.IsApplicable(context.get(), transformation_context));
182
183 instructionDescriptor = MakeInstructionDescriptor(68, spv::Op::OpIEqual, 0);
184 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
185 ASSERT_FALSE(
186 transformation.IsApplicable(context.get(), transformation_context));
187
188 instructionDescriptor =
189 MakeInstructionDescriptor(69, spv::Op::OpINotEqual, 0);
190 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
191 ASSERT_FALSE(
192 transformation.IsApplicable(context.get(), transformation_context));
193
194 instructionDescriptor =
195 MakeInstructionDescriptor(70, spv::Op::OpFOrdEqual, 0);
196 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
197 ASSERT_FALSE(
198 transformation.IsApplicable(context.get(), transformation_context));
199
200 instructionDescriptor = MakeInstructionDescriptor(71, spv::Op::OpPtrEqual, 0);
201 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
202 ASSERT_FALSE(
203 transformation.IsApplicable(context.get(), transformation_context));
204
205 // Tests there being no instruction with the desired opcode after the base
206 // instruction id
207 instructionDescriptor = MakeInstructionDescriptor(24, spv::Op::OpIAdd, 0);
208 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
209 ASSERT_FALSE(
210 transformation.IsApplicable(context.get(), transformation_context));
211
212 instructionDescriptor = MakeInstructionDescriptor(38, spv::Op::OpIMul, 0);
213 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
214 ASSERT_FALSE(
215 transformation.IsApplicable(context.get(), transformation_context));
216
217 instructionDescriptor = MakeInstructionDescriptor(45, spv::Op::OpFAdd, 0);
218 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
219 ASSERT_FALSE(
220 transformation.IsApplicable(context.get(), transformation_context));
221
222 instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpFMul, 0);
223 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
224 ASSERT_FALSE(
225 transformation.IsApplicable(context.get(), transformation_context));
226
227 // Tests there being an instruction with the desired opcode after the base
228 // instruction id, but the skip count associated with the instruction
229 // descriptor being so high.
230 instructionDescriptor = MakeInstructionDescriptor(11, spv::Op::OpIAdd, 100);
231 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
232 ASSERT_FALSE(
233 transformation.IsApplicable(context.get(), transformation_context));
234
235 instructionDescriptor = MakeInstructionDescriptor(16, spv::Op::OpIMul, 100);
236 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
237 ASSERT_FALSE(
238 transformation.IsApplicable(context.get(), transformation_context));
239
240 instructionDescriptor = MakeInstructionDescriptor(23, spv::Op::OpFAdd, 100);
241 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
242 ASSERT_FALSE(
243 transformation.IsApplicable(context.get(), transformation_context));
244
245 instructionDescriptor = MakeInstructionDescriptor(32, spv::Op::OpFMul, 100);
246 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
247 ASSERT_FALSE(
248 transformation.IsApplicable(context.get(), transformation_context));
249
250 instructionDescriptor = MakeInstructionDescriptor(37, spv::Op::OpDot, 100);
251 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
252 ASSERT_FALSE(
253 transformation.IsApplicable(context.get(), transformation_context));
254 }
255
TEST(TransformationSwapCommutableOperandsTest, ApplyTest)256 TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
257 std::string shader = R"(
258 OpCapability Shader
259 %1 = OpExtInstImport "GLSL.std.450"
260 OpMemoryModel Logical GLSL450
261 OpEntryPoint Fragment %4 "main"
262 OpExecutionMode %4 OriginUpperLeft
263 OpSource ESSL 310
264 OpName %4 "main"
265 %2 = OpTypeVoid
266 %3 = OpTypeFunction %2
267 %6 = OpTypeInt 32 1
268 %7 = OpTypeInt 32 0
269 %8 = OpConstant %7 2
270 %9 = OpTypeArray %6 %8
271 %10 = OpTypePointer Function %9
272 %12 = OpConstant %6 1
273 %13 = OpConstant %6 2
274 %14 = OpConstantComposite %9 %12 %13
275 %15 = OpTypePointer Function %6
276 %17 = OpConstant %6 0
277 %29 = OpTypeFloat 32
278 %30 = OpTypeArray %29 %8
279 %31 = OpTypePointer Function %30
280 %33 = OpConstant %29 1
281 %34 = OpConstant %29 2
282 %35 = OpConstantComposite %30 %33 %34
283 %36 = OpTypePointer Function %29
284 %49 = OpTypeVector %29 3
285 %50 = OpTypeArray %49 %8
286 %51 = OpTypePointer Function %50
287 %53 = OpConstant %29 3
288 %54 = OpConstantComposite %49 %33 %34 %53
289 %55 = OpConstant %29 4
290 %56 = OpConstant %29 5
291 %57 = OpConstant %29 6
292 %58 = OpConstantComposite %49 %55 %56 %57
293 %59 = OpConstantComposite %50 %54 %58
294 %61 = OpTypePointer Function %49
295 %4 = OpFunction %2 None %3
296 %5 = OpLabel
297 %11 = OpVariable %10 Function
298 %16 = OpVariable %15 Function
299 %23 = OpVariable %15 Function
300 %32 = OpVariable %31 Function
301 %37 = OpVariable %36 Function
302 %43 = OpVariable %36 Function
303 %52 = OpVariable %51 Function
304 %60 = OpVariable %36 Function
305 OpStore %11 %14
306 %18 = OpAccessChain %15 %11 %17
307 %19 = OpLoad %6 %18
308 %20 = OpAccessChain %15 %11 %12
309 %21 = OpLoad %6 %20
310 %22 = OpIAdd %6 %19 %21
311 OpStore %16 %22
312 %24 = OpAccessChain %15 %11 %17
313 %25 = OpLoad %6 %24
314 %26 = OpAccessChain %15 %11 %12
315 %27 = OpLoad %6 %26
316 %28 = OpIMul %6 %25 %27
317 OpStore %23 %28
318 OpStore %32 %35
319 %38 = OpAccessChain %36 %32 %17
320 %39 = OpLoad %29 %38
321 %40 = OpAccessChain %36 %32 %12
322 %41 = OpLoad %29 %40
323 %42 = OpFAdd %29 %39 %41
324 OpStore %37 %42
325 %44 = OpAccessChain %36 %32 %17
326 %45 = OpLoad %29 %44
327 %46 = OpAccessChain %36 %32 %12
328 %47 = OpLoad %29 %46
329 %48 = OpFMul %29 %45 %47
330 OpStore %43 %48
331 OpStore %52 %59
332 %62 = OpAccessChain %61 %52 %17
333 %63 = OpLoad %49 %62
334 %64 = OpAccessChain %61 %52 %12
335 %65 = OpLoad %49 %64
336 %66 = OpDot %29 %63 %65
337 OpStore %60 %66
338 OpReturn
339 OpFunctionEnd
340 )";
341
342 const auto env = SPV_ENV_UNIVERSAL_1_5;
343 const auto consumer = nullptr;
344 const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
345 spvtools::ValidatorOptions validator_options;
346 ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
347 kConsoleMessageConsumer));
348 TransformationContext transformation_context(
349 MakeUnique<FactManager>(context.get()), validator_options);
350 auto instructionDescriptor =
351 MakeInstructionDescriptor(22, spv::Op::OpIAdd, 0);
352 auto transformation =
353 TransformationSwapCommutableOperands(instructionDescriptor);
354 ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
355
356 instructionDescriptor = MakeInstructionDescriptor(28, spv::Op::OpIMul, 0);
357 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
358 ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
359
360 instructionDescriptor = MakeInstructionDescriptor(42, spv::Op::OpFAdd, 0);
361 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
362 ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
363
364 instructionDescriptor = MakeInstructionDescriptor(48, spv::Op::OpFMul, 0);
365 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
366 ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
367
368 instructionDescriptor = MakeInstructionDescriptor(66, spv::Op::OpDot, 0);
369 transformation = TransformationSwapCommutableOperands(instructionDescriptor);
370 ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
371
372 std::string variantShader = R"(
373 OpCapability Shader
374 %1 = OpExtInstImport "GLSL.std.450"
375 OpMemoryModel Logical GLSL450
376 OpEntryPoint Fragment %4 "main"
377 OpExecutionMode %4 OriginUpperLeft
378 OpSource ESSL 310
379 OpName %4 "main"
380 %2 = OpTypeVoid
381 %3 = OpTypeFunction %2
382 %6 = OpTypeInt 32 1
383 %7 = OpTypeInt 32 0
384 %8 = OpConstant %7 2
385 %9 = OpTypeArray %6 %8
386 %10 = OpTypePointer Function %9
387 %12 = OpConstant %6 1
388 %13 = OpConstant %6 2
389 %14 = OpConstantComposite %9 %12 %13
390 %15 = OpTypePointer Function %6
391 %17 = OpConstant %6 0
392 %29 = OpTypeFloat 32
393 %30 = OpTypeArray %29 %8
394 %31 = OpTypePointer Function %30
395 %33 = OpConstant %29 1
396 %34 = OpConstant %29 2
397 %35 = OpConstantComposite %30 %33 %34
398 %36 = OpTypePointer Function %29
399 %49 = OpTypeVector %29 3
400 %50 = OpTypeArray %49 %8
401 %51 = OpTypePointer Function %50
402 %53 = OpConstant %29 3
403 %54 = OpConstantComposite %49 %33 %34 %53
404 %55 = OpConstant %29 4
405 %56 = OpConstant %29 5
406 %57 = OpConstant %29 6
407 %58 = OpConstantComposite %49 %55 %56 %57
408 %59 = OpConstantComposite %50 %54 %58
409 %61 = OpTypePointer Function %49
410 %4 = OpFunction %2 None %3
411 %5 = OpLabel
412 %11 = OpVariable %10 Function
413 %16 = OpVariable %15 Function
414 %23 = OpVariable %15 Function
415 %32 = OpVariable %31 Function
416 %37 = OpVariable %36 Function
417 %43 = OpVariable %36 Function
418 %52 = OpVariable %51 Function
419 %60 = OpVariable %36 Function
420 OpStore %11 %14
421 %18 = OpAccessChain %15 %11 %17
422 %19 = OpLoad %6 %18
423 %20 = OpAccessChain %15 %11 %12
424 %21 = OpLoad %6 %20
425 %22 = OpIAdd %6 %21 %19
426 OpStore %16 %22
427 %24 = OpAccessChain %15 %11 %17
428 %25 = OpLoad %6 %24
429 %26 = OpAccessChain %15 %11 %12
430 %27 = OpLoad %6 %26
431 %28 = OpIMul %6 %27 %25
432 OpStore %23 %28
433 OpStore %32 %35
434 %38 = OpAccessChain %36 %32 %17
435 %39 = OpLoad %29 %38
436 %40 = OpAccessChain %36 %32 %12
437 %41 = OpLoad %29 %40
438 %42 = OpFAdd %29 %41 %39
439 OpStore %37 %42
440 %44 = OpAccessChain %36 %32 %17
441 %45 = OpLoad %29 %44
442 %46 = OpAccessChain %36 %32 %12
443 %47 = OpLoad %29 %46
444 %48 = OpFMul %29 %47 %45
445 OpStore %43 %48
446 OpStore %52 %59
447 %62 = OpAccessChain %61 %52 %17
448 %63 = OpLoad %49 %62
449 %64 = OpAccessChain %61 %52 %12
450 %65 = OpLoad %49 %64
451 %66 = OpDot %29 %65 %63
452 OpStore %60 %66
453 OpReturn
454 OpFunctionEnd
455 )";
456
457 ASSERT_TRUE(IsEqual(env, variantShader, context.get()));
458 }
459
460 } // namespace
461 } // namespace fuzz
462 } // namespace spvtools
463