1// Copyright (c) 2023 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "extract_source.h"
16
#include <algorithm>
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
21
22#include "source/opt/log.h"
23#include "spirv-tools/libspirv.hpp"
24#include "spirv/unified1/spirv.hpp"
25#include "tools/util/cli_consumer.h"
26
27namespace {
28
29constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
30
31// Extract a string literal from a given range.
32// Copies all the characters from `begin` to the first '\0' it encounters, while
33// removing escape patterns.
34// Not finding a '\0' before reaching `end` fails the extraction.
35//
36// Returns `true` if the extraction succeeded.
37// `output` value is undefined if false is returned.
38spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
39                                  const char* end, std::string* output) {
40  size_t sourceLength = std::distance(begin, end);
41  std::string escapedString;
42  escapedString.resize(sourceLength);
43
44  size_t writeIndex = 0;
45  size_t readIndex = 0;
46  for (; readIndex < sourceLength; writeIndex++, readIndex++) {
47    const char read = begin[readIndex];
48    if (read == '\0') {
49      escapedString.resize(writeIndex);
50      output->append(escapedString);
51      return SPV_SUCCESS;
52    }
53
54    if (read == '\\') {
55      ++readIndex;
56    }
57    escapedString[writeIndex] = begin[readIndex];
58  }
59
60  spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
61                  "Missing NULL terminator for literal string.");
62  return SPV_ERROR_INVALID_BINARY;
63}
64
65spv_result_t extractOpString(const spv_position_t& loc,
66                             const spv_parsed_instruction_t& instruction,
67                             std::string* output) {
68  assert(output != nullptr);
69  assert(instruction.opcode == spv::Op::OpString);
70  if (instruction.num_operands != 2) {
71    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
72                    "Missing operands for OpString.");
73    return SPV_ERROR_INVALID_BINARY;
74  }
75
76  const auto& operand = instruction.operands[1];
77  const char* stringBegin =
78      reinterpret_cast<const char*>(instruction.words + operand.offset);
79  const char* stringEnd = reinterpret_cast<const char*>(
80      instruction.words + operand.offset + operand.num_words);
81  return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
82}
83
84spv_result_t extractOpSourceContinued(
85    const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
86    std::string* output) {
87  assert(output != nullptr);
88  assert(instruction.opcode == spv::Op::OpSourceContinued);
89  if (instruction.num_operands != 1) {
90    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
91                    "Missing operands for OpSourceContinued.");
92    return SPV_ERROR_INVALID_BINARY;
93  }
94
95  const auto& operand = instruction.operands[0];
96  const char* stringBegin =
97      reinterpret_cast<const char*>(instruction.words + operand.offset);
98  const char* stringEnd = reinterpret_cast<const char*>(
99      instruction.words + operand.offset + operand.num_words);
100  return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
101}
102
103spv_result_t extractOpSource(const spv_position_t& loc,
104                             const spv_parsed_instruction_t& instruction,
105                             spv::Id* filename, std::string* code) {
106  assert(filename != nullptr && code != nullptr);
107  assert(instruction.opcode == spv::Op::OpSource);
108  // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
109  if (instruction.num_words < 3) {
110    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
111                    "Missing operands for OpSource.");
112    return SPV_ERROR_INVALID_BINARY;
113  }
114
115  *filename = 0;
116  *code = "";
117  if (instruction.num_words < 4) {
118    return SPV_SUCCESS;
119  }
120  *filename = instruction.words[3];
121
122  if (instruction.num_words < 5) {
123    return SPV_SUCCESS;
124  }
125
126  const char* stringBegin =
127      reinterpret_cast<const char*>(instruction.words + 4);
128  const char* stringEnd =
129      reinterpret_cast<const char*>(instruction.words + instruction.num_words);
130  return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
131}
132
133}  // namespace
134
135bool ExtractSourceFromModule(
136    const std::vector<uint32_t>& binary,
137    std::unordered_map<std::string, std::string>* output) {
138  auto context = spvtools::SpirvTools(kDefaultEnvironment);
139  context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
140
141  // There is nothing valuable in the header.
142  spvtools::HeaderParser headerParser = [](const spv_endianness_t,
143                                           const spv_parsed_header_t&) {
144    return SPV_SUCCESS;
145  };
146
147  std::unordered_map<uint32_t, std::string> stringMap;
148  std::vector<std::pair<spv::Id, std::string>> sources;
149  spv::Op lastOpcode = spv::Op::OpMax;
150  size_t instructionIndex = 0;
151
152  spvtools::InstructionParser instructionParser =
153      [&stringMap, &sources, &lastOpcode,
154       &instructionIndex](const spv_parsed_instruction_t& instruction) {
155        const spv_position_t loc = {0, 0, instructionIndex + 1};
156        spv_result_t result = SPV_SUCCESS;
157
158        if (instruction.opcode == spv::Op::OpString) {
159          std::string content;
160          result = extractOpString(loc, instruction, &content);
161          if (result == SPV_SUCCESS) {
162            stringMap.emplace(instruction.result_id, std::move(content));
163          }
164        } else if (instruction.opcode == spv::Op::OpSource) {
165          spv::Id filenameId;
166          std::string code;
167          result = extractOpSource(loc, instruction, &filenameId, &code);
168          if (result == SPV_SUCCESS) {
169            sources.emplace_back(std::make_pair(filenameId, std::move(code)));
170          }
171        } else if (instruction.opcode == spv::Op::OpSourceContinued) {
172          if (lastOpcode != spv::Op::OpSource) {
173            spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
174                            "OpSourceContinued MUST follow an OpSource.");
175            return SPV_ERROR_INVALID_BINARY;
176          }
177
178          assert(sources.size() > 0);
179          result = extractOpSourceContinued(loc, instruction,
180                                            &sources.back().second);
181        }
182
183        ++instructionIndex;
184        lastOpcode = static_cast<spv::Op>(instruction.opcode);
185        return result;
186      };
187
188  if (!context.Parse(binary, headerParser, instructionParser)) {
189    return false;
190  }
191
192  std::string defaultName = "unnamed-";
193  size_t unnamedCount = 0;
194  for (auto & [ id, code ] : sources) {
195    std::string filename;
196    const auto it = stringMap.find(id);
197    if (it == stringMap.cend() || it->second.empty()) {
198      filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
199      ++unnamedCount;
200    } else {
201      filename = it->second;
202    }
203
204    if (output->count(filename) != 0) {
205      spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
206                      "Source file name conflict.");
207      return false;
208    }
209    output->insert({filename, code});
210  }
211
212  return true;
213}
214