1// Copyright (c) 2016 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "spirv-tools/optimizer.hpp"
#include "spirv/unified1/spirv.hpp11"
23
24namespace spvtools {
25namespace {
26
27using ::testing::ContainerEq;
28using ::testing::HasSubstr;
29
// Returns the assembly text for the smallest module these tests validate
// successfully: the Shader and Linkage capabilities followed by a
// Logical/GLSL450 memory model, one instruction per newline-terminated line.
// Callers append further instructions to the returned string.
std::string Header() {
  // Adjacent literals are joined at compile time into the same three lines.
  return "OpCapability Shader\n"
         "OpCapability Linkage\n"
         "OpMemoryModel Logical GLSL450\n";
}
38
// Module header version word expected after assembling for a SPIR-V 1.1
// target environment. The word is laid out high byte to low byte as
// 0 | major | minor | 0, so version 1.1 is 0x00010100.
constexpr uint32_t kExpectedSpvVersion = 0x00010100u;
42
43TEST(CppInterface, SuccessfulRoundTrip) {
44  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
45  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
46
47  std::vector<uint32_t> binary;
48  EXPECT_TRUE(t.Assemble(input_text, &binary));
49  EXPECT_TRUE(binary.size() > 5u);
50  EXPECT_EQ(spv::MagicNumber, binary[0]);
51  EXPECT_EQ(kExpectedSpvVersion, binary[1]);
52
53  // This cannot pass validation since %1 is not defined.
54  t.SetMessageConsumer([](spv_message_level_t level, const char* source,
55                          const spv_position_t& position, const char* message) {
56    EXPECT_EQ(SPV_MSG_ERROR, level);
57    EXPECT_STREQ("input", source);
58    EXPECT_EQ(0u, position.line);
59    EXPECT_EQ(0u, position.column);
60    EXPECT_EQ(1u, position.index);
61    EXPECT_STREQ("ID '1[%1]' has not been defined\n  %2 = OpSizeOf %1 %3\n",
62                 message);
63  });
64  EXPECT_FALSE(t.Validate(binary));
65
66  std::string output_text;
67  EXPECT_TRUE(t.Disassemble(binary, &output_text));
68  EXPECT_EQ(input_text, output_text);
69}
70
71TEST(CppInterface, AssembleEmptyModule) {
72  std::vector<uint32_t> binary(10, 42);
73  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
74  EXPECT_TRUE(t.Assemble("", &binary));
75  // We only have the header.
76  EXPECT_EQ(5u, binary.size());
77  EXPECT_EQ(spv::MagicNumber, binary[0]);
78  EXPECT_EQ(kExpectedSpvVersion, binary[1]);
79}
80
81TEST(CppInterface, AssembleOverloads) {
82  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
83  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
84  {
85    std::vector<uint32_t> binary;
86    EXPECT_TRUE(t.Assemble(input_text, &binary));
87    EXPECT_TRUE(binary.size() > 5u);
88    EXPECT_EQ(spv::MagicNumber, binary[0]);
89    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
90  }
91  {
92    std::vector<uint32_t> binary;
93    EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary));
94    EXPECT_TRUE(binary.size() > 5u);
95    EXPECT_EQ(spv::MagicNumber, binary[0]);
96    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
97  }
98  {  // Ignore the last newline.
99    std::vector<uint32_t> binary;
100    EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary));
101    EXPECT_TRUE(binary.size() > 5u);
102    EXPECT_EQ(spv::MagicNumber, binary[0]);
103    EXPECT_EQ(kExpectedSpvVersion, binary[1]);
104  }
105}
106
107TEST(CppInterface, DisassembleEmptyModule) {
108  std::string text(10, 'x');
109  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
110  int invocation_count = 0;
111  t.SetMessageConsumer(
112      [&invocation_count](spv_message_level_t level, const char* source,
113                          const spv_position_t& position, const char* message) {
114        ++invocation_count;
115        EXPECT_EQ(SPV_MSG_ERROR, level);
116        EXPECT_STREQ("input", source);
117        EXPECT_EQ(0u, position.line);
118        EXPECT_EQ(0u, position.column);
119        EXPECT_EQ(0u, position.index);
120        EXPECT_STREQ("Missing module.", message);
121      });
122  EXPECT_FALSE(t.Disassemble({}, &text));
123  EXPECT_EQ("xxxxxxxxxx", text);  // The original string is unmodified.
124  EXPECT_EQ(1, invocation_count);
125}
126
127TEST(CppInterface, DisassembleOverloads) {
128  const std::string input_text = "%2 = OpSizeOf %1 %3\n";
129  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
130
131  std::vector<uint32_t> binary;
132  EXPECT_TRUE(t.Assemble(input_text, &binary));
133
134  {
135    std::string output_text;
136    EXPECT_TRUE(t.Disassemble(binary, &output_text));
137    EXPECT_EQ(input_text, output_text);
138  }
139  {
140    std::string output_text;
141    EXPECT_TRUE(t.Disassemble(binary.data(), binary.size(), &output_text));
142    EXPECT_EQ(input_text, output_text);
143  }
144}
145
146TEST(CppInterface, SuccessfulValidation) {
147  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
148  int invocation_count = 0;
149  t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*,
150                                           const spv_position_t&, const char*) {
151    ++invocation_count;
152  });
153
154  std::vector<uint32_t> binary;
155  EXPECT_TRUE(t.Assemble(Header(), &binary));
156  EXPECT_TRUE(t.Validate(binary));
157  EXPECT_EQ(0, invocation_count);
158}
159
160TEST(CppInterface, ValidateOverloads) {
161  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
162  std::vector<uint32_t> binary;
163  EXPECT_TRUE(t.Assemble(Header(), &binary));
164
165  { EXPECT_TRUE(t.Validate(binary)); }
166  { EXPECT_TRUE(t.Validate(binary.data(), binary.size())); }
167}
168
169TEST(CppInterface, ValidateEmptyModule) {
170  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
171  int invocation_count = 0;
172  t.SetMessageConsumer(
173      [&invocation_count](spv_message_level_t level, const char* source,
174                          const spv_position_t& position, const char* message) {
175        ++invocation_count;
176        EXPECT_EQ(SPV_MSG_ERROR, level);
177        EXPECT_STREQ("input", source);
178        EXPECT_EQ(0u, position.line);
179        EXPECT_EQ(0u, position.column);
180        EXPECT_EQ(0u, position.index);
181        EXPECT_STREQ("Invalid SPIR-V magic number.", message);
182      });
183  EXPECT_FALSE(t.Validate({}));
184  EXPECT_EQ(1, invocation_count);
185}
186
187// Returns the assembly for a SPIR-V module with a struct declaration
188// with the given number of members.
189std::string MakeModuleHavingStruct(int num_members) {
190  std::stringstream os;
191  os << Header();
192  os << R"(%1 = OpTypeInt 32 0
193           %2 = OpTypeStruct)";
194  for (int i = 0; i < num_members; i++) os << " %1";
195  return os.str();
196}
197
198TEST(CppInterface, ValidateWithOptionsPass) {
199  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
200  std::vector<uint32_t> binary;
201  EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
202  const ValidatorOptions opts;
203
204  EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts));
205}
206
207TEST(CppInterface, ValidateWithOptionsFail) {
208  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
209  std::vector<uint32_t> binary;
210  EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary));
211  ValidatorOptions opts;
212  opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9);
213  std::stringstream os;
214  t.SetMessageConsumer([&os](spv_message_level_t, const char*,
215                             const spv_position_t&,
216                             const char* message) { os << message; });
217
218  EXPECT_FALSE(t.Validate(binary.data(), binary.size(), opts));
219  EXPECT_THAT(
220      os.str(),
221      HasSubstr(
222          "Number of OpTypeStruct members (10) has exceeded the limit (9)"));
223}
224
225// Checks that after running the given optimizer |opt| on the given |original|
226// source code, we can get the given |optimized| source code.
227void CheckOptimization(const std::string& original,
228                       const std::string& optimized, const Optimizer& opt) {
229  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
230  std::vector<uint32_t> original_binary;
231  ASSERT_TRUE(t.Assemble(original, &original_binary));
232
233  std::vector<uint32_t> optimized_binary;
234  EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(),
235                      &optimized_binary));
236
237  std::string optimized_text;
238  EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text));
239  EXPECT_EQ(optimized, optimized_text);
240}
241
242TEST(CppInterface, OptimizeEmptyModule) {
243  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
244  std::vector<uint32_t> binary;
245  EXPECT_TRUE(t.Assemble("", &binary));
246
247  Optimizer o(SPV_ENV_UNIVERSAL_1_1);
248  o.RegisterPass(CreateStripDebugInfoPass());
249
250  // Fails to validate.
251  EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary));
252}
253
254TEST(CppInterface, OptimizeModifiedModule) {
255  Optimizer o(SPV_ENV_UNIVERSAL_1_1);
256  o.RegisterPass(CreateStripDebugInfoPass());
257  CheckOptimization(Header() + "OpSource GLSL 450", Header(), o);
258}
259
260TEST(CppInterface, OptimizeMulitplePasses) {
261  std::string original_text = Header() +
262                              "OpSource GLSL 450 "
263                              "OpDecorate %true SpecId 1 "
264                              "%bool = OpTypeBool "
265                              "%true = OpSpecConstantTrue %bool";
266
267  Optimizer o(SPV_ENV_UNIVERSAL_1_1);
268  o.RegisterPass(CreateStripDebugInfoPass())
269      .RegisterPass(CreateFreezeSpecConstantValuePass());
270
271  std::string expected_text = Header() +
272                              "%bool = OpTypeBool\n"
273                              "%true = OpConstantTrue %bool\n";
274
275  CheckOptimization(original_text, expected_text, o);
276}
277
278TEST(CppInterface, OptimizeDoNothingWithPassToken) {
279  CreateFreezeSpecConstantValuePass();
280  auto token = CreateUnifyConstantPass();
281}
282
283TEST(CppInterface, OptimizeReassignPassToken) {
284  auto token = CreateNullPass();
285  token = CreateStripDebugInfoPass();
286
287  CheckOptimization(
288      Header() + "OpSource GLSL 450", Header(),
289      Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token)));
290}
291
292TEST(CppInterface, OptimizeMoveConstructPassToken) {
293  auto token1 = CreateStripDebugInfoPass();
294  Optimizer::PassToken token2(std::move(token1));
295
296  CheckOptimization(
297      Header() + "OpSource GLSL 450", Header(),
298      Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
299}
300
301TEST(CppInterface, OptimizeMoveAssignPassToken) {
302  auto token1 = CreateStripDebugInfoPass();
303  auto token2 = CreateNullPass();
304  token2 = std::move(token1);
305
306  CheckOptimization(
307      Header() + "OpSource GLSL 450", Header(),
308      Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2)));
309}
310
311TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) {
312  SpirvTools t(SPV_ENV_UNIVERSAL_1_1);
313  std::vector<uint32_t> binary;
314  ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary));
315
316  EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1)
317                  .RegisterPass(CreateStripDebugInfoPass())
318                  .Run(binary.data(), binary.size(), &binary));
319
320  std::string optimized_text;
321  EXPECT_TRUE(t.Disassemble(binary, &optimized_text));
322  EXPECT_EQ(Header(), optimized_text);
323}
324
325TEST(SpirvHeadersCpp, BitwiseOrMemoryAccessMask) {
326  EXPECT_EQ(spv::MemoryAccessMask(6), spv::MemoryAccessMask::Aligned |
327                                          spv::MemoryAccessMask::Nontemporal);
328}
329
330TEST(SpirvHeadersCpp, BitwiseAndMemoryAccessMask) {
331  EXPECT_EQ(spv::MemoryAccessMask::Aligned,
332            spv::MemoryAccessMask::Aligned & spv::MemoryAccessMask(6));
333  EXPECT_EQ(spv::MemoryAccessMask::Nontemporal,
334            spv::MemoryAccessMask::Nontemporal & spv::MemoryAccessMask(6));
335  EXPECT_EQ(spv::MemoryAccessMask(0), spv::MemoryAccessMask::Nontemporal &
336                                          spv::MemoryAccessMask::Aligned);
337}
338
339TEST(SpirvHeadersCpp, BitwiseXorMemoryAccessMask) {
340  EXPECT_EQ(spv::MemoryAccessMask::Nontemporal,
341            spv::MemoryAccessMask::Aligned ^ spv::MemoryAccessMask(6));
342  EXPECT_EQ(spv::MemoryAccessMask::Aligned,
343            spv::MemoryAccessMask::Nontemporal ^ spv::MemoryAccessMask(6));
344  EXPECT_EQ(spv::MemoryAccessMask(6), spv::MemoryAccessMask::Nontemporal ^
345                                          spv::MemoryAccessMask::Aligned);
346  EXPECT_EQ(spv::MemoryAccessMask(0), spv::MemoryAccessMask::Nontemporal ^
347                                          spv::MemoryAccessMask::Nontemporal);
348}
349
350TEST(SpirvHeadersCpp, BitwiseNegateMemoryAccessMask) {
351  EXPECT_EQ(spv::MemoryAccessMask(~(uint32_t(4))),
352            ~spv::MemoryAccessMask::Nontemporal);
353}
354
355// TODO(antiagainst): tests for SetMessageConsumer().
356
357}  // namespace
358}  // namespace spvtools
359