1// Copyright (c) 2015-2016 The Khronos Group Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Assembler tests for instructions in the "Control Flow" section of the
16// SPIR-V spec.
17
18#include <sstream>
19#include <string>
20#include <tuple>
21#include <vector>
22
23#include "gmock/gmock.h"
24#include "test/test_fixture.h"
25#include "test/unit_spirv.h"
26
27namespace spvtools {
28namespace {
29
30using spvtest::Concatenate;
31using spvtest::EnumCase;
32using spvtest::MakeInstruction;
33using spvtest::TextToBinaryTest;
34using ::testing::Combine;
35using ::testing::Eq;
36using ::testing::TestWithParam;
37using ::testing::Values;
38using ::testing::ValuesIn;
39
40// Test OpSelectionMerge
41
42using OpSelectionMergeTest = spvtest::TextToBinaryTestBase<
43    TestWithParam<EnumCase<spv::SelectionControlMask>>>;
44
45TEST_P(OpSelectionMergeTest, AnySingleSelectionControlMask) {
46  const std::string input = "OpSelectionMerge %1 " + GetParam().name();
47  EXPECT_THAT(CompiledInstructions(input),
48              Eq(MakeInstruction(spv::Op::OpSelectionMerge,
49                                 {1, uint32_t(GetParam().value())})));
50}
51
52// clang-format off
53#define CASE(VALUE,NAME) { spv::SelectionControlMask::VALUE, NAME}
54INSTANTIATE_TEST_SUITE_P(TextToBinarySelectionMerge, OpSelectionMergeTest,
55                        ValuesIn(std::vector<EnumCase<spv::SelectionControlMask>>{
56                            CASE(MaskNone, "None"),
57                            CASE(Flatten, "Flatten"),
58                            CASE(DontFlatten, "DontFlatten"),
59                        }));
60#undef CASE
61// clang-format on
62
63TEST_F(OpSelectionMergeTest, CombinedSelectionControlMask) {
64  const std::string input = "OpSelectionMerge %1 Flatten|DontFlatten";
65  const uint32_t expected_mask =
66      uint32_t(spv::SelectionControlMask::Flatten |
67               spv::SelectionControlMask::DontFlatten);
68  EXPECT_THAT(
69      CompiledInstructions(input),
70      Eq(MakeInstruction(spv::Op::OpSelectionMerge, {1, expected_mask})));
71}
72
73TEST_F(OpSelectionMergeTest, WrongSelectionControl) {
74  // Case sensitive: "flatten" != "Flatten" and thus wrong.
75  EXPECT_THAT(CompileFailure("OpSelectionMerge %1 flatten|DontFlatten"),
76              Eq("Invalid selection control operand 'flatten|DontFlatten'."));
77}
78
79// Test OpLoopMerge
80
81using OpLoopMergeTest = spvtest::TextToBinaryTestBase<
82    TestWithParam<std::tuple<spv_target_env, EnumCase<int>>>>;
83
84TEST_P(OpLoopMergeTest, AnySingleLoopControlMask) {
85  const auto ctrl = std::get<1>(GetParam());
86  std::ostringstream input;
87  input << "OpLoopMerge %merge %continue " << ctrl.name();
88  for (auto num : ctrl.operands()) input << " " << num;
89  EXPECT_THAT(CompiledInstructions(input.str(), std::get<0>(GetParam())),
90              Eq(MakeInstruction(spv::Op::OpLoopMerge, {1, 2, ctrl.value()},
91                                 ctrl.operands())));
92}
93
94#define CASE(VALUE, NAME) \
95  { int32_t(spv::LoopControlMask::VALUE), NAME }
96#define CASE1(VALUE, NAME, PARM)                         \
97  {                                                      \
98    int32_t(spv::LoopControlMask::VALUE), NAME, { PARM } \
99  }
100INSTANTIATE_TEST_SUITE_P(
101    TextToBinaryLoopMerge, OpLoopMergeTest,
102    Combine(Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_1),
103            ValuesIn(std::vector<EnumCase<int>>{
104                // clang-format off
105                CASE(MaskNone, "None"),
106                CASE(Unroll, "Unroll"),
107                CASE(DontUnroll, "DontUnroll"),
108                // clang-format on
109            })));
110
111INSTANTIATE_TEST_SUITE_P(
112    TextToBinaryLoopMergeV11, OpLoopMergeTest,
113    Combine(Values(SPV_ENV_UNIVERSAL_1_1),
114            ValuesIn(std::vector<EnumCase<int>>{
115                // clang-format off
116                CASE(DependencyInfinite, "DependencyInfinite"),
117                CASE1(DependencyLength, "DependencyLength", 234),
118                {int32_t(spv::LoopControlMask::Unroll|spv::LoopControlMask::DependencyLength),
119                      "DependencyLength|Unroll", {33}},
120                // clang-format on
121            })));
122#undef CASE
123#undef CASE1
124
125TEST_F(OpLoopMergeTest, CombinedLoopControlMask) {
126  const std::string input = "OpLoopMerge %merge %continue Unroll|DontUnroll";
127  const uint32_t expected_mask =
128      uint32_t(spv::LoopControlMask::Unroll | spv::LoopControlMask::DontUnroll);
129  EXPECT_THAT(CompiledInstructions(input),
130              Eq(MakeInstruction(spv::Op::OpLoopMerge, {1, 2, expected_mask})));
131}
132
133TEST_F(OpLoopMergeTest, WrongLoopControl) {
134  EXPECT_THAT(CompileFailure("OpLoopMerge %m %c none"),
135              Eq("Invalid loop control operand 'none'."));
136}
137
138// Test OpSwitch
139
140TEST_F(TextToBinaryTest, SwitchGoodZeroTargets) {
141  EXPECT_THAT(CompiledInstructions("OpSwitch %selector %default"),
142              Eq(MakeInstruction(spv::Op::OpSwitch, {1, 2})));
143}
144
145TEST_F(TextToBinaryTest, SwitchGoodOneTarget) {
146  EXPECT_THAT(
147      CompiledInstructions("%1 = OpTypeInt 32 0\n"
148                           "%2 = OpConstant %1 52\n"
149                           "OpSwitch %2 %default 12 %target0"),
150      Eq(Concatenate({MakeInstruction(spv::Op::OpTypeInt, {1, 32, 0}),
151                      MakeInstruction(spv::Op::OpConstant, {1, 2, 52}),
152                      MakeInstruction(spv::Op::OpSwitch, {2, 3, 12, 4})})));
153}
154
155TEST_F(TextToBinaryTest, SwitchGoodTwoTargets) {
156  EXPECT_THAT(
157      CompiledInstructions("%1 = OpTypeInt 32 0\n"
158                           "%2 = OpConstant %1 52\n"
159                           "OpSwitch %2 %default 12 %target0 42 %target1"),
160      Eq(Concatenate({
161          MakeInstruction(spv::Op::OpTypeInt, {1, 32, 0}),
162          MakeInstruction(spv::Op::OpConstant, {1, 2, 52}),
163          MakeInstruction(spv::Op::OpSwitch, {2, 3, 12, 4, 42, 5}),
164      })));
165}
166
167TEST_F(TextToBinaryTest, SwitchBadMissingSelector) {
168  EXPECT_THAT(CompileFailure("OpSwitch"),
169              Eq("Expected operand for OpSwitch instruction, but found the end "
170                 "of the stream."));
171}
172
173TEST_F(TextToBinaryTest, SwitchBadInvalidSelector) {
174  EXPECT_THAT(CompileFailure("OpSwitch 12"),
175              Eq("Expected id to start with %."));
176}
177
178TEST_F(TextToBinaryTest, SwitchBadMissingDefault) {
179  EXPECT_THAT(CompileFailure("OpSwitch %selector"),
180              Eq("Expected operand for OpSwitch instruction, but found the end "
181                 "of the stream."));
182}
183
184TEST_F(TextToBinaryTest, SwitchBadInvalidDefault) {
185  EXPECT_THAT(CompileFailure("OpSwitch %selector 12"),
186              Eq("Expected id to start with %."));
187}
188
// Checks that an id ("%abc") in the position of a case literal is not
// accepted as part of the OpSwitch.
TEST_F(TextToBinaryTest, SwitchBadInvalidLiteral) {
  // The assembler recognizes "OpSwitch %selector %default" as a complete
  // instruction.  Then it tries to parse "%abc" as the start of a new
  // instruction, but can't since it hits the end of stream.
  // Because "%abc" begins with '%', it is read as a result id, so the
  // assembler next expects '=' and reports the end of stream instead.
  const auto input = R"(%i32 = OpTypeInt 32 0
                        %selector = OpConstant %i32 42
                        OpSwitch %selector %default %abc)";
  EXPECT_THAT(CompileFailure(input), Eq("Expected '=', found end of stream."));
}
198
199TEST_F(TextToBinaryTest, SwitchBadMissingTarget) {
200  EXPECT_THAT(CompileFailure("%1 = OpTypeInt 32 0\n"
201                             "%2 = OpConstant %1 52\n"
202                             "OpSwitch %2 %default 12"),
203              Eq("Expected operand for OpSwitch instruction, but found the end "
204                 "of the stream."));
205}
206
// Parameters for one OpSwitch assembly test. Each case also exercises the
// encoding of an OpConstant integer literal, so the set covers both
// single-word and multi-word integer literals.
struct SwitchTestCase {
  // Operands of OpTypeInt as text: width then signedness, e.g. "32 0".
  std::string constant_type_args;
  // Literal text of the OpConstant that defines the selector.
  std::string constant_value_arg;
  // Literal text of the single non-default OpSwitch case.
  std::string case_value_arg;
  // Expected binary for the OpTypeInt, OpConstant and OpSwitch instructions.
  std::vector<uint32_t> expected_instructions;
};
217
218using OpSwitchValidTest =
219    spvtest::TextToBinaryTestBase<TestWithParam<SwitchTestCase>>;
220
221// Tests the encoding of OpConstant literal values, and also
222// the literal integer cases in an OpSwitch.  This can
223// test both single and multi-word integer literal encodings.
224TEST_P(OpSwitchValidTest, ValidTypes) {
225  const std::string input = "%1 = OpTypeInt " + GetParam().constant_type_args +
226                            "\n"
227                            "%2 = OpConstant %1 " +
228                            GetParam().constant_value_arg +
229                            "\n"
230                            "OpSwitch %2 %default " +
231                            GetParam().case_value_arg + " %4\n";
232  std::vector<uint32_t> instructions;
233  EXPECT_THAT(CompiledInstructions(input),
234              Eq(GetParam().expected_instructions));
235}
236
237// Constructs a SwitchTestCase from the given integer_width, signedness,
238// constant value string, and expected encoded constant.
239SwitchTestCase MakeSwitchTestCase(uint32_t integer_width,
240                                  uint32_t integer_signedness,
241                                  std::string constant_str,
242                                  std::vector<uint32_t> encoded_constant,
243                                  std::string case_value_str,
244                                  std::vector<uint32_t> encoded_case_value) {
245  std::stringstream ss;
246  ss << integer_width << " " << integer_signedness;
247  return SwitchTestCase{
248      ss.str(),
249      constant_str,
250      case_value_str,
251      {Concatenate(
252          {MakeInstruction(spv::Op::OpTypeInt,
253                           {1, integer_width, integer_signedness}),
254           MakeInstruction(spv::Op::OpConstant,
255                           Concatenate({{1, 2}, encoded_constant})),
256           MakeInstruction(spv::Op::OpSwitch,
257                           Concatenate({{2, 3}, encoded_case_value, {4}}))})}};
258}
259
// Integer types of at most 32 bits: both the OpConstant value and the OpSwitch
// case literal are encoded in a single word.
INSTANTIATE_TEST_SUITE_P(
    TextToBinaryOpSwitchValid1Word, OpSwitchValidTest,
    ValuesIn(std::vector<SwitchTestCase>({
        MakeSwitchTestCase(32, 0, "42", {42}, "100", {100}),
        MakeSwitchTestCase(32, 1, "-1", {0xffffffff}, "100", {100}),
        // SPIR-V 1.0 Rev 1 clarified that for an integer narrower than 32-bits,
        // its bits will appear in the lower order bits of the 32-bit word, and
        // a signed integer is sign-extended.
        MakeSwitchTestCase(7, 0, "127", {127}, "100", {100}),
        MakeSwitchTestCase(14, 0, "99", {99}, "100", {100}),
        MakeSwitchTestCase(16, 0, "65535", {65535}, "100", {100}),
        MakeSwitchTestCase(16, 1, "101", {101}, "100", {100}),
        // Demonstrate sign extension
        MakeSwitchTestCase(16, 1, "-2", {0xfffffffe}, "100", {100}),
        // Hex cases
        // For the signed 16-bit type, 0x8000 and 0x8100 have the sign bit set
        // and are sign-extended; for the unsigned type they are not.
        MakeSwitchTestCase(16, 1, "0x7ffe", {0x7ffe}, "0x1234", {0x1234}),
        MakeSwitchTestCase(16, 1, "0x8000", {0xffff8000}, "0x8100",
                           {0xffff8100}),
        MakeSwitchTestCase(16, 0, "0x8000", {0x00008000}, "0x8100", {0x8100}),
    })));
280
// Integer types wider than 32 bits: each literal takes two words, and the
// word holding the low-order 32 bits comes first. Signed values narrower
// than 64 bits are sign-extended into the high word.
INSTANTIATE_TEST_SUITE_P(
    TextToBinaryOpSwitchValid2Words, OpSwitchValidTest,
    ValuesIn(std::vector<SwitchTestCase>({
        MakeSwitchTestCase(33, 0, "101", {101, 0}, "500", {500, 0}),
        MakeSwitchTestCase(48, 1, "-1", {0xffffffff, 0xffffffff}, "900",
                           {900, 0}),
        MakeSwitchTestCase(64, 1, "-2", {0xfffffffe, 0xffffffff}, "-5",
                           {0xfffffffb, uint32_t(-1)}),
        // Hex cases
        MakeSwitchTestCase(48, 1, "0x7fffffffffff", {0xffffffff, 0x00007fff},
                           "100", {100, 0}),
        MakeSwitchTestCase(48, 1, "0x800000000000", {0x00000000, 0xffff8000},
                           "0x800000000000", {0x00000000, 0xffff8000}),
        MakeSwitchTestCase(48, 0, "0x800000000000", {0x00000000, 0x00008000},
                           "0x800000000000", {0x00000000, 0x00008000}),
        MakeSwitchTestCase(63, 0, "0x500000000", {0, 5}, "12", {12, 0}),
        MakeSwitchTestCase(64, 0, "0x600000000", {0, 6}, "12", {12, 0}),
        MakeSwitchTestCase(64, 1, "0x700000123", {0x123, 7}, "12", {12, 0}),
    })));
301
302using ControlFlowRoundTripTest = RoundTripTest;
303
304TEST_P(ControlFlowRoundTripTest, DisassemblyEqualsAssemblyInput) {
305  const std::string assembly = GetParam();
306  EXPECT_THAT(EncodeAndDecodeSuccessfully(assembly), Eq(assembly)) << assembly;
307}
308
// Round-trip OpSwitch over unsigned selectors of several widths. The case
// literals must be printed back exactly as written, including values that
// need more than 32 bits.
INSTANTIATE_TEST_SUITE_P(
    OpSwitchRoundTripUnsignedIntegers, ControlFlowRoundTripTest,
    ValuesIn(std::vector<std::string>({
        // Unsigned 16-bit.
        "%1 = OpTypeInt 16 0\n%2 = OpConstant %1 65535\nOpSwitch %2 %3\n",
        // Unsigned 32-bit, three non-default cases.
        "%1 = OpTypeInt 32 0\n%2 = OpConstant %1 123456\n"
        "OpSwitch %2 %3 100 %4 102 %5 1000000 %6\n",
        // Unsigned 48-bit, three non-default cases.
        "%1 = OpTypeInt 48 0\n%2 = OpConstant %1 5000000000\n"
        "OpSwitch %2 %3 100 %4 102 %5 6000000000 %6\n",
        // Unsigned 64-bit, three non-default cases.
        "%1 = OpTypeInt 64 0\n%2 = OpConstant %1 9223372036854775807\n"
        "OpSwitch %2 %3 100 %4 102 %5 9000000000000000000 %6\n",
    })));
324
// Round-trip OpSwitch over signed selectors of several widths. Negative
// constants and case literals must be printed back with their sign,
// including the most negative 16-bit and 64-bit values.
INSTANTIATE_TEST_SUITE_P(
    OpSwitchRoundTripSignedIntegers, ControlFlowRoundTripTest,
    ValuesIn(std::vector<std::string>{
        // Signed 16-bit, with two non-default cases
        "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 32767\n"
        "OpSwitch %2 %3 99 %4 -102 %5\n",
        "%1 = OpTypeInt 16 1\n%2 = OpConstant %1 -32768\n"
        "OpSwitch %2 %3 99 %4 -102 %5\n",
        // Signed 32-bit, two non-default cases.
        "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 -123456\n"
        "OpSwitch %2 %3 100 %4 -123456 %5\n",
        "%1 = OpTypeInt 32 1\n%2 = OpConstant %1 123456\n"
        "OpSwitch %2 %3 100 %4 123456 %5\n",
        // Signed 48-bit, three non-default cases.
        "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 5000000000\n"
        "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n",
        "%1 = OpTypeInt 48 1\n%2 = OpConstant %1 -5000000000\n"
        "OpSwitch %2 %3 100 %4 -7000000000 %5 6000000000 %6\n",
        // Signed 64-bit, three non-default cases.
        "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 9223372036854775807\n"
        "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n",
        "%1 = OpTypeInt 64 1\n%2 = OpConstant %1 -9223372036854775808\n"
        "OpSwitch %2 %3 100 %4 7000000000 %5 -1000000000000000000 %6\n",
    }));
349
350using OpSwitchInvalidTypeTestCase =
351    spvtest::TextToBinaryTestBase<TestWithParam<std::string>>;
352
353TEST_P(OpSwitchInvalidTypeTestCase, InvalidTypes) {
354  const std::string input =
355      "%1 = " + GetParam() +
356      "\n"
357      "%3 = OpCopyObject %1 %2\n"  // We only care the type of the expression
358      "     OpSwitch %3 %default 32 %c\n";
359  EXPECT_THAT(CompileFailure(input),
360              Eq("The selector operand for OpSwitch must be the result of an "
361                 "instruction that generates an integer scalar"));
362}
363
364// clang-format off
365INSTANTIATE_TEST_SUITE_P(
366    TextToBinaryOpSwitchInvalidTests, OpSwitchInvalidTypeTestCase,
367    ValuesIn(std::vector<std::string>{
368      {"OpTypeVoid",
369       "OpTypeBool",
370       "OpTypeFloat 32",
371       "OpTypeVector %a 32",
372       "OpTypeMatrix %a 32",
373       "OpTypeImage %a 1D 0 0 0 0 Unknown",
374       "OpTypeSampler",
375       "OpTypeSampledImage %a",
376       "OpTypeArray %a %b",
377       "OpTypeRuntimeArray %a",
378       "OpTypeStruct %a",
379       "OpTypeOpaque \"Foo\"",
380       "OpTypePointer UniformConstant %a",
381       "OpTypeFunction %a %b",
382       "OpTypeEvent",
383       "OpTypeDeviceEvent",
384       "OpTypeReserveId",
385       "OpTypeQueue",
386       "OpTypePipe ReadOnly",
387
388       // Skip OpTypeForwardPointer because it doesn't even produce a result
389       // ID.
390
391       // At least one thing that isn't a type at all
392       "OpNot %a %b"
393      },
394    }));
395// clang-format on
396
397using OpKillTest = spvtest::TextToBinaryTest;
398
399INSTANTIATE_TEST_SUITE_P(OpKillTest, ControlFlowRoundTripTest,
400                         Values("OpKill\n"));
401
402TEST_F(OpKillTest, ExtraArgsAssemblyError) {
403  const std::string input = "OpKill 1";
404  EXPECT_THAT(CompileFailure(input),
405              Eq("Expected <opcode> or <result-id> at the beginning of an "
406                 "instruction, found '1'."));
407}
408
409using OpTerminateInvocationTest = spvtest::TextToBinaryTest;
410
411INSTANTIATE_TEST_SUITE_P(OpTerminateInvocationTest, ControlFlowRoundTripTest,
412                         Values("OpTerminateInvocation\n"));
413
414TEST_F(OpTerminateInvocationTest, ExtraArgsAssemblyError) {
415  const std::string input = "OpTerminateInvocation 1";
416  EXPECT_THAT(CompileFailure(input),
417              Eq("Expected <opcode> or <result-id> at the beginning of an "
418                 "instruction, found '1'."));
419}
420
421// TODO(dneto): OpPhi
422// TODO(dneto): OpLoopMerge
423// TODO(dneto): OpLabel
424// TODO(dneto): OpBranch
425// TODO(dneto): OpSwitch
426// TODO(dneto): OpReturn
427// TODO(dneto): OpReturnValue
428// TODO(dneto): OpUnreachable
429// TODO(dneto): OpLifetimeStart
430// TODO(dneto): OpLifetimeStop
431
432}  // namespace
433}  // namespace spvtools
434