1// Copyright (c) 2016 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include <string> 16#include <utility> 17#include <vector> 18 19#include "gmock/gmock.h" 20#include "gtest/gtest.h" 21#include "spirv-tools/optimizer.hpp" 22#include "spirv/unified1/spirv.hpp11" 23 24namespace spvtools { 25namespace { 26 27using ::testing::ContainerEq; 28using ::testing::HasSubstr; 29 30// Return a string that contains the minimum instructions needed to form 31// a valid module. Other instructions can be appended to this string. 32std::string Header() { 33 return R"(OpCapability Shader 34OpCapability Linkage 35OpMemoryModel Logical GLSL450 36)"; 37} 38 39// When we assemble with a target environment of SPIR-V 1.1, we expect 40// the following in the module header version word. 41const uint32_t kExpectedSpvVersion = 0x10100; 42 43TEST(CppInterface, SuccessfulRoundTrip) { 44 const std::string input_text = "%2 = OpSizeOf %1 %3\n"; 45 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 46 47 std::vector<uint32_t> binary; 48 EXPECT_TRUE(t.Assemble(input_text, &binary)); 49 EXPECT_TRUE(binary.size() > 5u); 50 EXPECT_EQ(spv::MagicNumber, binary[0]); 51 EXPECT_EQ(kExpectedSpvVersion, binary[1]); 52 53 // This cannot pass validation since %1 is not defined. 54 t.SetMessageConsumer([](spv_message_level_t level, const char* source, 55 const spv_position_t& position, const char* message) { 56 EXPECT_EQ(SPV_MSG_ERROR, level); 57 EXPECT_STREQ("input", source); 58 EXPECT_EQ(0u, position.line); 59 EXPECT_EQ(0u, position.column); 60 EXPECT_EQ(1u, position.index); 61 EXPECT_STREQ("ID '1[%1]' has not been defined\n %2 = OpSizeOf %1 %3\n", 62 message); 63 }); 64 EXPECT_FALSE(t.Validate(binary)); 65 66 std::string output_text; 67 EXPECT_TRUE(t.Disassemble(binary, &output_text)); 68 EXPECT_EQ(input_text, output_text); 69} 70 71TEST(CppInterface, AssembleEmptyModule) { 72 std::vector<uint32_t> binary(10, 42); 73 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 74 EXPECT_TRUE(t.Assemble("", &binary)); 75 // We only have the header. 76 EXPECT_EQ(5u, binary.size()); 77 EXPECT_EQ(spv::MagicNumber, binary[0]); 78 EXPECT_EQ(kExpectedSpvVersion, binary[1]); 79} 80 81TEST(CppInterface, AssembleOverloads) { 82 const std::string input_text = "%2 = OpSizeOf %1 %3\n"; 83 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 84 { 85 std::vector<uint32_t> binary; 86 EXPECT_TRUE(t.Assemble(input_text, &binary)); 87 EXPECT_TRUE(binary.size() > 5u); 88 EXPECT_EQ(spv::MagicNumber, binary[0]); 89 EXPECT_EQ(kExpectedSpvVersion, binary[1]); 90 } 91 { 92 std::vector<uint32_t> binary; 93 EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size(), &binary)); 94 EXPECT_TRUE(binary.size() > 5u); 95 EXPECT_EQ(spv::MagicNumber, binary[0]); 96 EXPECT_EQ(kExpectedSpvVersion, binary[1]); 97 } 98 { // Ignore the last newline. 99 std::vector<uint32_t> binary; 100 EXPECT_TRUE(t.Assemble(input_text.data(), input_text.size() - 1, &binary)); 101 EXPECT_TRUE(binary.size() > 5u); 102 EXPECT_EQ(spv::MagicNumber, binary[0]); 103 EXPECT_EQ(kExpectedSpvVersion, binary[1]); 104 } 105} 106 107TEST(CppInterface, DisassembleEmptyModule) { 108 std::string text(10, 'x'); 109 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 110 int invocation_count = 0; 111 t.SetMessageConsumer( 112 [&invocation_count](spv_message_level_t level, const char* source, 113 const spv_position_t& position, const char* message) { 114 ++invocation_count; 115 EXPECT_EQ(SPV_MSG_ERROR, level); 116 EXPECT_STREQ("input", source); 117 EXPECT_EQ(0u, position.line); 118 EXPECT_EQ(0u, position.column); 119 EXPECT_EQ(0u, position.index); 120 EXPECT_STREQ("Missing module.", message); 121 }); 122 EXPECT_FALSE(t.Disassemble({}, &text)); 123 EXPECT_EQ("xxxxxxxxxx", text); // The original string is unmodified. 124 EXPECT_EQ(1, invocation_count); 125} 126 127TEST(CppInterface, DisassembleOverloads) { 128 const std::string input_text = "%2 = OpSizeOf %1 %3\n"; 129 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 130 131 std::vector<uint32_t> binary; 132 EXPECT_TRUE(t.Assemble(input_text, &binary)); 133 134 { 135 std::string output_text; 136 EXPECT_TRUE(t.Disassemble(binary, &output_text)); 137 EXPECT_EQ(input_text, output_text); 138 } 139 { 140 std::string output_text; 141 EXPECT_TRUE(t.Disassemble(binary.data(), binary.size(), &output_text)); 142 EXPECT_EQ(input_text, output_text); 143 } 144} 145 146TEST(CppInterface, SuccessfulValidation) { 147 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 148 int invocation_count = 0; 149 t.SetMessageConsumer([&invocation_count](spv_message_level_t, const char*, 150 const spv_position_t&, const char*) { 151 ++invocation_count; 152 }); 153 154 std::vector<uint32_t> binary; 155 EXPECT_TRUE(t.Assemble(Header(), &binary)); 156 EXPECT_TRUE(t.Validate(binary)); 157 EXPECT_EQ(0, invocation_count); 158} 159 160TEST(CppInterface, ValidateOverloads) { 161 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 162 std::vector<uint32_t> binary; 163 EXPECT_TRUE(t.Assemble(Header(), &binary)); 164 165 { EXPECT_TRUE(t.Validate(binary)); } 166 { EXPECT_TRUE(t.Validate(binary.data(), binary.size())); } 167} 168 169TEST(CppInterface, ValidateEmptyModule) { 170 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 171 int invocation_count = 0; 172 t.SetMessageConsumer( 173 [&invocation_count](spv_message_level_t level, const char* source, 174 const spv_position_t& position, const char* message) { 175 ++invocation_count; 176 EXPECT_EQ(SPV_MSG_ERROR, level); 177 EXPECT_STREQ("input", source); 178 EXPECT_EQ(0u, position.line); 179 EXPECT_EQ(0u, position.column); 180 EXPECT_EQ(0u, position.index); 181 EXPECT_STREQ("Invalid SPIR-V magic number.", message); 182 }); 183 EXPECT_FALSE(t.Validate({})); 184 EXPECT_EQ(1, invocation_count); 185} 186 187// Returns the assembly for a SPIR-V module with a struct declaration 188// with the given number of members. 189std::string MakeModuleHavingStruct(int num_members) { 190 std::stringstream os; 191 os << Header(); 192 os << R"(%1 = OpTypeInt 32 0 193 %2 = OpTypeStruct)"; 194 for (int i = 0; i < num_members; i++) os << " %1"; 195 return os.str(); 196} 197 198TEST(CppInterface, ValidateWithOptionsPass) { 199 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 200 std::vector<uint32_t> binary; 201 EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); 202 const ValidatorOptions opts; 203 204 EXPECT_TRUE(t.Validate(binary.data(), binary.size(), opts)); 205} 206 207TEST(CppInterface, ValidateWithOptionsFail) { 208 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 209 std::vector<uint32_t> binary; 210 EXPECT_TRUE(t.Assemble(MakeModuleHavingStruct(10), &binary)); 211 ValidatorOptions opts; 212 opts.SetUniversalLimit(spv_validator_limit_max_struct_members, 9); 213 std::stringstream os; 214 t.SetMessageConsumer([&os](spv_message_level_t, const char*, 215 const spv_position_t&, 216 const char* message) { os << message; }); 217 218 EXPECT_FALSE(t.Validate(binary.data(), binary.size(), opts)); 219 EXPECT_THAT( 220 os.str(), 221 HasSubstr( 222 "Number of OpTypeStruct members (10) has exceeded the limit (9)")); 223} 224 225// Checks that after running the given optimizer |opt| on the given |original| 226// source code, we can get the given |optimized| source code. 227void CheckOptimization(const std::string& original, 228 const std::string& optimized, const Optimizer& opt) { 229 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 230 std::vector<uint32_t> original_binary; 231 ASSERT_TRUE(t.Assemble(original, &original_binary)); 232 233 std::vector<uint32_t> optimized_binary; 234 EXPECT_TRUE(opt.Run(original_binary.data(), original_binary.size(), 235 &optimized_binary)); 236 237 std::string optimized_text; 238 EXPECT_TRUE(t.Disassemble(optimized_binary, &optimized_text)); 239 EXPECT_EQ(optimized, optimized_text); 240} 241 242TEST(CppInterface, OptimizeEmptyModule) { 243 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 244 std::vector<uint32_t> binary; 245 EXPECT_TRUE(t.Assemble("", &binary)); 246 247 Optimizer o(SPV_ENV_UNIVERSAL_1_1); 248 o.RegisterPass(CreateStripDebugInfoPass()); 249 250 // Fails to validate. 251 EXPECT_FALSE(o.Run(binary.data(), binary.size(), &binary)); 252} 253 254TEST(CppInterface, OptimizeModifiedModule) { 255 Optimizer o(SPV_ENV_UNIVERSAL_1_1); 256 o.RegisterPass(CreateStripDebugInfoPass()); 257 CheckOptimization(Header() + "OpSource GLSL 450", Header(), o); 258} 259 260TEST(CppInterface, OptimizeMulitplePasses) { 261 std::string original_text = Header() + 262 "OpSource GLSL 450 " 263 "OpDecorate %true SpecId 1 " 264 "%bool = OpTypeBool " 265 "%true = OpSpecConstantTrue %bool"; 266 267 Optimizer o(SPV_ENV_UNIVERSAL_1_1); 268 o.RegisterPass(CreateStripDebugInfoPass()) 269 .RegisterPass(CreateFreezeSpecConstantValuePass()); 270 271 std::string expected_text = Header() + 272 "%bool = OpTypeBool\n" 273 "%true = OpConstantTrue %bool\n"; 274 275 CheckOptimization(original_text, expected_text, o); 276} 277 278TEST(CppInterface, OptimizeDoNothingWithPassToken) { 279 CreateFreezeSpecConstantValuePass(); 280 auto token = CreateUnifyConstantPass(); 281} 282 283TEST(CppInterface, OptimizeReassignPassToken) { 284 auto token = CreateNullPass(); 285 token = CreateStripDebugInfoPass(); 286 287 CheckOptimization( 288 Header() + "OpSource GLSL 450", Header(), 289 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token))); 290} 291 292TEST(CppInterface, OptimizeMoveConstructPassToken) { 293 auto token1 = CreateStripDebugInfoPass(); 294 Optimizer::PassToken token2(std::move(token1)); 295 296 CheckOptimization( 297 Header() + "OpSource GLSL 450", Header(), 298 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); 299} 300 301TEST(CppInterface, OptimizeMoveAssignPassToken) { 302 auto token1 = CreateStripDebugInfoPass(); 303 auto token2 = CreateNullPass(); 304 token2 = std::move(token1); 305 306 CheckOptimization( 307 Header() + "OpSource GLSL 450", Header(), 308 Optimizer(SPV_ENV_UNIVERSAL_1_1).RegisterPass(std::move(token2))); 309} 310 311TEST(CppInterface, OptimizeSameAddressForOriginalOptimizedBinary) { 312 SpirvTools t(SPV_ENV_UNIVERSAL_1_1); 313 std::vector<uint32_t> binary; 314 ASSERT_TRUE(t.Assemble(Header() + "OpSource GLSL 450", &binary)); 315 316 EXPECT_TRUE(Optimizer(SPV_ENV_UNIVERSAL_1_1) 317 .RegisterPass(CreateStripDebugInfoPass()) 318 .Run(binary.data(), binary.size(), &binary)); 319 320 std::string optimized_text; 321 EXPECT_TRUE(t.Disassemble(binary, &optimized_text)); 322 EXPECT_EQ(Header(), optimized_text); 323} 324 325TEST(SpirvHeadersCpp, BitwiseOrMemoryAccessMask) { 326 EXPECT_EQ(spv::MemoryAccessMask(6), spv::MemoryAccessMask::Aligned | 327 spv::MemoryAccessMask::Nontemporal); 328} 329 330TEST(SpirvHeadersCpp, BitwiseAndMemoryAccessMask) { 331 EXPECT_EQ(spv::MemoryAccessMask::Aligned, 332 spv::MemoryAccessMask::Aligned & spv::MemoryAccessMask(6)); 333 EXPECT_EQ(spv::MemoryAccessMask::Nontemporal, 334 spv::MemoryAccessMask::Nontemporal & spv::MemoryAccessMask(6)); 335 EXPECT_EQ(spv::MemoryAccessMask(0), spv::MemoryAccessMask::Nontemporal & 336 spv::MemoryAccessMask::Aligned); 337} 338 339TEST(SpirvHeadersCpp, BitwiseXorMemoryAccessMask) { 340 EXPECT_EQ(spv::MemoryAccessMask::Nontemporal, 341 spv::MemoryAccessMask::Aligned ^ spv::MemoryAccessMask(6)); 342 EXPECT_EQ(spv::MemoryAccessMask::Aligned, 343 spv::MemoryAccessMask::Nontemporal ^ spv::MemoryAccessMask(6)); 344 EXPECT_EQ(spv::MemoryAccessMask(6), spv::MemoryAccessMask::Nontemporal ^ 345 spv::MemoryAccessMask::Aligned); 346 EXPECT_EQ(spv::MemoryAccessMask(0), spv::MemoryAccessMask::Nontemporal ^ 347 spv::MemoryAccessMask::Nontemporal); 348} 349 350TEST(SpirvHeadersCpp, BitwiseNegateMemoryAccessMask) { 351 EXPECT_EQ(spv::MemoryAccessMask(~(uint32_t(4))), 352 ~spv::MemoryAccessMask::Nontemporal); 353} 354 355// TODO(antiagainst): tests for SetMessageConsumer(). 356 357} // namespace 358} // namespace spvtools 359