1// Copyright (c) 2023 Google LLC. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "extract_source.h" 16 17#include <cassert> 18#include <string> 19#include <unordered_map> 20#include <vector> 21 22#include "source/opt/log.h" 23#include "spirv-tools/libspirv.hpp" 24#include "spirv/unified1/spirv.hpp" 25#include "tools/util/cli_consumer.h" 26 27namespace { 28 29constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; 30 31// Extract a string literal from a given range. 32// Copies all the characters from `begin` to the first '\0' it encounters, while 33// removing escape patterns. 34// Not finding a '\0' before reaching `end` fails the extraction. 35// 36// Returns `true` if the extraction succeeded. 37// `output` value is undefined if false is returned. 38spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin, 39 const char* end, std::string* output) { 40 size_t sourceLength = std::distance(begin, end); 41 std::string escapedString; 42 escapedString.resize(sourceLength); 43 44 size_t writeIndex = 0; 45 size_t readIndex = 0; 46 for (; readIndex < sourceLength; writeIndex++, readIndex++) { 47 const char read = begin[readIndex]; 48 if (read == '\0') { 49 escapedString.resize(writeIndex); 50 output->append(escapedString); 51 return SPV_SUCCESS; 52 } 53 54 if (read == '\\') { 55 ++readIndex; 56 } 57 escapedString[writeIndex] = begin[readIndex]; 58 } 59 60 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 61 "Missing NULL terminator for literal string."); 62 return SPV_ERROR_INVALID_BINARY; 63} 64 65spv_result_t extractOpString(const spv_position_t& loc, 66 const spv_parsed_instruction_t& instruction, 67 std::string* output) { 68 assert(output != nullptr); 69 assert(instruction.opcode == spv::Op::OpString); 70 if (instruction.num_operands != 2) { 71 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 72 "Missing operands for OpString."); 73 return SPV_ERROR_INVALID_BINARY; 74 } 75 76 const auto& operand = instruction.operands[1]; 77 const char* stringBegin = 78 reinterpret_cast<const char*>(instruction.words + operand.offset); 79 const char* stringEnd = reinterpret_cast<const char*>( 80 instruction.words + operand.offset + operand.num_words); 81 return ExtractStringLiteral(loc, stringBegin, stringEnd, output); 82} 83 84spv_result_t extractOpSourceContinued( 85 const spv_position_t& loc, const spv_parsed_instruction_t& instruction, 86 std::string* output) { 87 assert(output != nullptr); 88 assert(instruction.opcode == spv::Op::OpSourceContinued); 89 if (instruction.num_operands != 1) { 90 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 91 "Missing operands for OpSourceContinued."); 92 return SPV_ERROR_INVALID_BINARY; 93 } 94 95 const auto& operand = instruction.operands[0]; 96 const char* stringBegin = 97 reinterpret_cast<const char*>(instruction.words + operand.offset); 98 const char* stringEnd = reinterpret_cast<const char*>( 99 instruction.words + operand.offset + operand.num_words); 100 return ExtractStringLiteral(loc, stringBegin, stringEnd, output); 101} 102 103spv_result_t extractOpSource(const spv_position_t& loc, 104 const spv_parsed_instruction_t& instruction, 105 spv::Id* filename, std::string* code) { 106 assert(filename != nullptr && code != nullptr); 107 assert(instruction.opcode == spv::Op::OpSource); 108 // OpCode [ Source Language | Version | File (optional) | Source (optional) ] 109 if (instruction.num_words < 3) { 110 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 111 "Missing operands for OpSource."); 112 return SPV_ERROR_INVALID_BINARY; 113 } 114 115 *filename = 0; 116 *code = ""; 117 if (instruction.num_words < 4) { 118 return SPV_SUCCESS; 119 } 120 *filename = instruction.words[3]; 121 122 if (instruction.num_words < 5) { 123 return SPV_SUCCESS; 124 } 125 126 const char* stringBegin = 127 reinterpret_cast<const char*>(instruction.words + 4); 128 const char* stringEnd = 129 reinterpret_cast<const char*>(instruction.words + instruction.num_words); 130 return ExtractStringLiteral(loc, stringBegin, stringEnd, code); 131} 132 133} // namespace 134 135bool ExtractSourceFromModule( 136 const std::vector<uint32_t>& binary, 137 std::unordered_map<std::string, std::string>* output) { 138 auto context = spvtools::SpirvTools(kDefaultEnvironment); 139 context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); 140 141 // There is nothing valuable in the header. 142 spvtools::HeaderParser headerParser = [](const spv_endianness_t, 143 const spv_parsed_header_t&) { 144 return SPV_SUCCESS; 145 }; 146 147 std::unordered_map<uint32_t, std::string> stringMap; 148 std::vector<std::pair<spv::Id, std::string>> sources; 149 spv::Op lastOpcode = spv::Op::OpMax; 150 size_t instructionIndex = 0; 151 152 spvtools::InstructionParser instructionParser = 153 [&stringMap, &sources, &lastOpcode, 154 &instructionIndex](const spv_parsed_instruction_t& instruction) { 155 const spv_position_t loc = {0, 0, instructionIndex + 1}; 156 spv_result_t result = SPV_SUCCESS; 157 158 if (instruction.opcode == spv::Op::OpString) { 159 std::string content; 160 result = extractOpString(loc, instruction, &content); 161 if (result == SPV_SUCCESS) { 162 stringMap.emplace(instruction.result_id, std::move(content)); 163 } 164 } else if (instruction.opcode == spv::Op::OpSource) { 165 spv::Id filenameId; 166 std::string code; 167 result = extractOpSource(loc, instruction, &filenameId, &code); 168 if (result == SPV_SUCCESS) { 169 sources.emplace_back(std::make_pair(filenameId, std::move(code))); 170 } 171 } else if (instruction.opcode == spv::Op::OpSourceContinued) { 172 if (lastOpcode != spv::Op::OpSource) { 173 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 174 "OpSourceContinued MUST follow an OpSource."); 175 return SPV_ERROR_INVALID_BINARY; 176 } 177 178 assert(sources.size() > 0); 179 result = extractOpSourceContinued(loc, instruction, 180 &sources.back().second); 181 } 182 183 ++instructionIndex; 184 lastOpcode = static_cast<spv::Op>(instruction.opcode); 185 return result; 186 }; 187 188 if (!context.Parse(binary, headerParser, instructionParser)) { 189 return false; 190 } 191 192 std::string defaultName = "unnamed-"; 193 size_t unnamedCount = 0; 194 for (auto & [ id, code ] : sources) { 195 std::string filename; 196 const auto it = stringMap.find(id); 197 if (it == stringMap.cend() || it->second.empty()) { 198 filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl"; 199 ++unnamedCount; 200 } else { 201 filename = it->second; 202 } 203 204 if (output->count(filename) != 0) { 205 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {}, 206 "Source file name conflict."); 207 return false; 208 } 209 output->insert({filename, code}); 210 } 211 212 return true; 213} 214