1fd4e5da5Sopenharmony_ci// Copyright (c) 2023 Google LLC. 2fd4e5da5Sopenharmony_ci// 3fd4e5da5Sopenharmony_ci// Licensed under the Apache License, Version 2.0 (the "License"); 4fd4e5da5Sopenharmony_ci// you may not use this file except in compliance with the License. 5fd4e5da5Sopenharmony_ci// You may obtain a copy of the License at 6fd4e5da5Sopenharmony_ci// 7fd4e5da5Sopenharmony_ci// http://www.apache.org/licenses/LICENSE-2.0 8fd4e5da5Sopenharmony_ci// 9fd4e5da5Sopenharmony_ci// Unless required by applicable law or agreed to in writing, software 10fd4e5da5Sopenharmony_ci// distributed under the License is distributed on an "AS IS" BASIS, 11fd4e5da5Sopenharmony_ci// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12fd4e5da5Sopenharmony_ci// See the License for the specific language governing permissions and 13fd4e5da5Sopenharmony_ci// limitations under the License. 14fd4e5da5Sopenharmony_ci 15fd4e5da5Sopenharmony_ci#include "extract_source.h" 16fd4e5da5Sopenharmony_ci 17fd4e5da5Sopenharmony_ci#include <cassert> 18fd4e5da5Sopenharmony_ci#include <string> 19fd4e5da5Sopenharmony_ci#include <unordered_map> 20fd4e5da5Sopenharmony_ci#include <vector> 21fd4e5da5Sopenharmony_ci 22fd4e5da5Sopenharmony_ci#include "source/opt/log.h" 23fd4e5da5Sopenharmony_ci#include "spirv-tools/libspirv.hpp" 24fd4e5da5Sopenharmony_ci#include "spirv/unified1/spirv.hpp" 25fd4e5da5Sopenharmony_ci#include "tools/util/cli_consumer.h" 26fd4e5da5Sopenharmony_ci 27fd4e5da5Sopenharmony_cinamespace { 28fd4e5da5Sopenharmony_ci 29fd4e5da5Sopenharmony_ciconstexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6; 30fd4e5da5Sopenharmony_ci 31fd4e5da5Sopenharmony_ci// Extract a string literal from a given range. 32fd4e5da5Sopenharmony_ci// Copies all the characters from `begin` to the first '\0' it encounters, while 33fd4e5da5Sopenharmony_ci// removing escape patterns. 34fd4e5da5Sopenharmony_ci// Not finding a '\0' before reaching `end` fails the extraction. 35fd4e5da5Sopenharmony_ci// 36fd4e5da5Sopenharmony_ci// Returns `true` if the extraction succeeded. 37fd4e5da5Sopenharmony_ci// `output` value is undefined if false is returned. 38fd4e5da5Sopenharmony_cispv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin, 39fd4e5da5Sopenharmony_ci const char* end, std::string* output) { 40fd4e5da5Sopenharmony_ci size_t sourceLength = std::distance(begin, end); 41fd4e5da5Sopenharmony_ci std::string escapedString; 42fd4e5da5Sopenharmony_ci escapedString.resize(sourceLength); 43fd4e5da5Sopenharmony_ci 44fd4e5da5Sopenharmony_ci size_t writeIndex = 0; 45fd4e5da5Sopenharmony_ci size_t readIndex = 0; 46fd4e5da5Sopenharmony_ci for (; readIndex < sourceLength; writeIndex++, readIndex++) { 47fd4e5da5Sopenharmony_ci const char read = begin[readIndex]; 48fd4e5da5Sopenharmony_ci if (read == '\0') { 49fd4e5da5Sopenharmony_ci escapedString.resize(writeIndex); 50fd4e5da5Sopenharmony_ci output->append(escapedString); 51fd4e5da5Sopenharmony_ci return SPV_SUCCESS; 52fd4e5da5Sopenharmony_ci } 53fd4e5da5Sopenharmony_ci 54fd4e5da5Sopenharmony_ci if (read == '\\') { 55fd4e5da5Sopenharmony_ci ++readIndex; 56fd4e5da5Sopenharmony_ci } 57fd4e5da5Sopenharmony_ci escapedString[writeIndex] = begin[readIndex]; 58fd4e5da5Sopenharmony_ci } 59fd4e5da5Sopenharmony_ci 60fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 61fd4e5da5Sopenharmony_ci "Missing NULL terminator for literal string."); 62fd4e5da5Sopenharmony_ci return SPV_ERROR_INVALID_BINARY; 63fd4e5da5Sopenharmony_ci} 64fd4e5da5Sopenharmony_ci 65fd4e5da5Sopenharmony_cispv_result_t extractOpString(const spv_position_t& loc, 66fd4e5da5Sopenharmony_ci const spv_parsed_instruction_t& instruction, 67fd4e5da5Sopenharmony_ci std::string* output) { 68fd4e5da5Sopenharmony_ci assert(output != nullptr); 69fd4e5da5Sopenharmony_ci assert(instruction.opcode == spv::Op::OpString); 70fd4e5da5Sopenharmony_ci if (instruction.num_operands != 2) { 71fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 72fd4e5da5Sopenharmony_ci "Missing operands for OpString."); 73fd4e5da5Sopenharmony_ci return SPV_ERROR_INVALID_BINARY; 74fd4e5da5Sopenharmony_ci } 75fd4e5da5Sopenharmony_ci 76fd4e5da5Sopenharmony_ci const auto& operand = instruction.operands[1]; 77fd4e5da5Sopenharmony_ci const char* stringBegin = 78fd4e5da5Sopenharmony_ci reinterpret_cast<const char*>(instruction.words + operand.offset); 79fd4e5da5Sopenharmony_ci const char* stringEnd = reinterpret_cast<const char*>( 80fd4e5da5Sopenharmony_ci instruction.words + operand.offset + operand.num_words); 81fd4e5da5Sopenharmony_ci return ExtractStringLiteral(loc, stringBegin, stringEnd, output); 82fd4e5da5Sopenharmony_ci} 83fd4e5da5Sopenharmony_ci 84fd4e5da5Sopenharmony_cispv_result_t extractOpSourceContinued( 85fd4e5da5Sopenharmony_ci const spv_position_t& loc, const spv_parsed_instruction_t& instruction, 86fd4e5da5Sopenharmony_ci std::string* output) { 87fd4e5da5Sopenharmony_ci assert(output != nullptr); 88fd4e5da5Sopenharmony_ci assert(instruction.opcode == spv::Op::OpSourceContinued); 89fd4e5da5Sopenharmony_ci if (instruction.num_operands != 1) { 90fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 91fd4e5da5Sopenharmony_ci "Missing operands for OpSourceContinued."); 92fd4e5da5Sopenharmony_ci return SPV_ERROR_INVALID_BINARY; 93fd4e5da5Sopenharmony_ci } 94fd4e5da5Sopenharmony_ci 95fd4e5da5Sopenharmony_ci const auto& operand = instruction.operands[0]; 96fd4e5da5Sopenharmony_ci const char* stringBegin = 97fd4e5da5Sopenharmony_ci reinterpret_cast<const char*>(instruction.words + operand.offset); 98fd4e5da5Sopenharmony_ci const char* stringEnd = reinterpret_cast<const char*>( 99fd4e5da5Sopenharmony_ci instruction.words + operand.offset + operand.num_words); 100fd4e5da5Sopenharmony_ci return ExtractStringLiteral(loc, stringBegin, stringEnd, output); 101fd4e5da5Sopenharmony_ci} 102fd4e5da5Sopenharmony_ci 103fd4e5da5Sopenharmony_cispv_result_t extractOpSource(const spv_position_t& loc, 104fd4e5da5Sopenharmony_ci const spv_parsed_instruction_t& instruction, 105fd4e5da5Sopenharmony_ci spv::Id* filename, std::string* code) { 106fd4e5da5Sopenharmony_ci assert(filename != nullptr && code != nullptr); 107fd4e5da5Sopenharmony_ci assert(instruction.opcode == spv::Op::OpSource); 108fd4e5da5Sopenharmony_ci // OpCode [ Source Language | Version | File (optional) | Source (optional) ] 109fd4e5da5Sopenharmony_ci if (instruction.num_words < 3) { 110fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 111fd4e5da5Sopenharmony_ci "Missing operands for OpSource."); 112fd4e5da5Sopenharmony_ci return SPV_ERROR_INVALID_BINARY; 113fd4e5da5Sopenharmony_ci } 114fd4e5da5Sopenharmony_ci 115fd4e5da5Sopenharmony_ci *filename = 0; 116fd4e5da5Sopenharmony_ci *code = ""; 117fd4e5da5Sopenharmony_ci if (instruction.num_words < 4) { 118fd4e5da5Sopenharmony_ci return SPV_SUCCESS; 119fd4e5da5Sopenharmony_ci } 120fd4e5da5Sopenharmony_ci *filename = instruction.words[3]; 121fd4e5da5Sopenharmony_ci 122fd4e5da5Sopenharmony_ci if (instruction.num_words < 5) { 123fd4e5da5Sopenharmony_ci return SPV_SUCCESS; 124fd4e5da5Sopenharmony_ci } 125fd4e5da5Sopenharmony_ci 126fd4e5da5Sopenharmony_ci const char* stringBegin = 127fd4e5da5Sopenharmony_ci reinterpret_cast<const char*>(instruction.words + 4); 128fd4e5da5Sopenharmony_ci const char* stringEnd = 129fd4e5da5Sopenharmony_ci reinterpret_cast<const char*>(instruction.words + instruction.num_words); 130fd4e5da5Sopenharmony_ci return ExtractStringLiteral(loc, stringBegin, stringEnd, code); 131fd4e5da5Sopenharmony_ci} 132fd4e5da5Sopenharmony_ci 133fd4e5da5Sopenharmony_ci} // namespace 134fd4e5da5Sopenharmony_ci 135fd4e5da5Sopenharmony_cibool ExtractSourceFromModule( 136fd4e5da5Sopenharmony_ci const std::vector<uint32_t>& binary, 137fd4e5da5Sopenharmony_ci std::unordered_map<std::string, std::string>* output) { 138fd4e5da5Sopenharmony_ci auto context = spvtools::SpirvTools(kDefaultEnvironment); 139fd4e5da5Sopenharmony_ci context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer); 140fd4e5da5Sopenharmony_ci 141fd4e5da5Sopenharmony_ci // There is nothing valuable in the header. 142fd4e5da5Sopenharmony_ci spvtools::HeaderParser headerParser = [](const spv_endianness_t, 143fd4e5da5Sopenharmony_ci const spv_parsed_header_t&) { 144fd4e5da5Sopenharmony_ci return SPV_SUCCESS; 145fd4e5da5Sopenharmony_ci }; 146fd4e5da5Sopenharmony_ci 147fd4e5da5Sopenharmony_ci std::unordered_map<uint32_t, std::string> stringMap; 148fd4e5da5Sopenharmony_ci std::vector<std::pair<spv::Id, std::string>> sources; 149fd4e5da5Sopenharmony_ci spv::Op lastOpcode = spv::Op::OpMax; 150fd4e5da5Sopenharmony_ci size_t instructionIndex = 0; 151fd4e5da5Sopenharmony_ci 152fd4e5da5Sopenharmony_ci spvtools::InstructionParser instructionParser = 153fd4e5da5Sopenharmony_ci [&stringMap, &sources, &lastOpcode, 154fd4e5da5Sopenharmony_ci &instructionIndex](const spv_parsed_instruction_t& instruction) { 155fd4e5da5Sopenharmony_ci const spv_position_t loc = {0, 0, instructionIndex + 1}; 156fd4e5da5Sopenharmony_ci spv_result_t result = SPV_SUCCESS; 157fd4e5da5Sopenharmony_ci 158fd4e5da5Sopenharmony_ci if (instruction.opcode == spv::Op::OpString) { 159fd4e5da5Sopenharmony_ci std::string content; 160fd4e5da5Sopenharmony_ci result = extractOpString(loc, instruction, &content); 161fd4e5da5Sopenharmony_ci if (result == SPV_SUCCESS) { 162fd4e5da5Sopenharmony_ci stringMap.emplace(instruction.result_id, std::move(content)); 163fd4e5da5Sopenharmony_ci } 164fd4e5da5Sopenharmony_ci } else if (instruction.opcode == spv::Op::OpSource) { 165fd4e5da5Sopenharmony_ci spv::Id filenameId; 166fd4e5da5Sopenharmony_ci std::string code; 167fd4e5da5Sopenharmony_ci result = extractOpSource(loc, instruction, &filenameId, &code); 168fd4e5da5Sopenharmony_ci if (result == SPV_SUCCESS) { 169fd4e5da5Sopenharmony_ci sources.emplace_back(std::make_pair(filenameId, std::move(code))); 170fd4e5da5Sopenharmony_ci } 171fd4e5da5Sopenharmony_ci } else if (instruction.opcode == spv::Op::OpSourceContinued) { 172fd4e5da5Sopenharmony_ci if (lastOpcode != spv::Op::OpSource) { 173fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc, 174fd4e5da5Sopenharmony_ci "OpSourceContinued MUST follow an OpSource."); 175fd4e5da5Sopenharmony_ci return SPV_ERROR_INVALID_BINARY; 176fd4e5da5Sopenharmony_ci } 177fd4e5da5Sopenharmony_ci 178fd4e5da5Sopenharmony_ci assert(sources.size() > 0); 179fd4e5da5Sopenharmony_ci result = extractOpSourceContinued(loc, instruction, 180fd4e5da5Sopenharmony_ci &sources.back().second); 181fd4e5da5Sopenharmony_ci } 182fd4e5da5Sopenharmony_ci 183fd4e5da5Sopenharmony_ci ++instructionIndex; 184fd4e5da5Sopenharmony_ci lastOpcode = static_cast<spv::Op>(instruction.opcode); 185fd4e5da5Sopenharmony_ci return result; 186fd4e5da5Sopenharmony_ci }; 187fd4e5da5Sopenharmony_ci 188fd4e5da5Sopenharmony_ci if (!context.Parse(binary, headerParser, instructionParser)) { 189fd4e5da5Sopenharmony_ci return false; 190fd4e5da5Sopenharmony_ci } 191fd4e5da5Sopenharmony_ci 192fd4e5da5Sopenharmony_ci std::string defaultName = "unnamed-"; 193fd4e5da5Sopenharmony_ci size_t unnamedCount = 0; 194fd4e5da5Sopenharmony_ci for (auto & [ id, code ] : sources) { 195fd4e5da5Sopenharmony_ci std::string filename; 196fd4e5da5Sopenharmony_ci const auto it = stringMap.find(id); 197fd4e5da5Sopenharmony_ci if (it == stringMap.cend() || it->second.empty()) { 198fd4e5da5Sopenharmony_ci filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl"; 199fd4e5da5Sopenharmony_ci ++unnamedCount; 200fd4e5da5Sopenharmony_ci } else { 201fd4e5da5Sopenharmony_ci filename = it->second; 202fd4e5da5Sopenharmony_ci } 203fd4e5da5Sopenharmony_ci 204fd4e5da5Sopenharmony_ci if (output->count(filename) != 0) { 205fd4e5da5Sopenharmony_ci spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {}, 206fd4e5da5Sopenharmony_ci "Source file name conflict."); 207fd4e5da5Sopenharmony_ci return false; 208fd4e5da5Sopenharmony_ci } 209fd4e5da5Sopenharmony_ci output->insert({filename, code}); 210fd4e5da5Sopenharmony_ci } 211fd4e5da5Sopenharmony_ci 212fd4e5da5Sopenharmony_ci return true; 213fd4e5da5Sopenharmony_ci} 214