// Copyright (c) 2023 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14fd4e5da5Sopenharmony_ci
15fd4e5da5Sopenharmony_ci#include "extract_source.h"
16fd4e5da5Sopenharmony_ci
17fd4e5da5Sopenharmony_ci#include <cassert>
18fd4e5da5Sopenharmony_ci#include <string>
19fd4e5da5Sopenharmony_ci#include <unordered_map>
20fd4e5da5Sopenharmony_ci#include <vector>
21fd4e5da5Sopenharmony_ci
22fd4e5da5Sopenharmony_ci#include "source/opt/log.h"
23fd4e5da5Sopenharmony_ci#include "spirv-tools/libspirv.hpp"
24fd4e5da5Sopenharmony_ci#include "spirv/unified1/spirv.hpp"
25fd4e5da5Sopenharmony_ci#include "tools/util/cli_consumer.h"
26fd4e5da5Sopenharmony_ci
27fd4e5da5Sopenharmony_cinamespace {
28fd4e5da5Sopenharmony_ci
29fd4e5da5Sopenharmony_ciconstexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
30fd4e5da5Sopenharmony_ci
31fd4e5da5Sopenharmony_ci// Extract a string literal from a given range.
32fd4e5da5Sopenharmony_ci// Copies all the characters from `begin` to the first '\0' it encounters, while
33fd4e5da5Sopenharmony_ci// removing escape patterns.
34fd4e5da5Sopenharmony_ci// Not finding a '\0' before reaching `end` fails the extraction.
35fd4e5da5Sopenharmony_ci//
36fd4e5da5Sopenharmony_ci// Returns `true` if the extraction succeeded.
37fd4e5da5Sopenharmony_ci// `output` value is undefined if false is returned.
38fd4e5da5Sopenharmony_cispv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
39fd4e5da5Sopenharmony_ci                                  const char* end, std::string* output) {
40fd4e5da5Sopenharmony_ci  size_t sourceLength = std::distance(begin, end);
41fd4e5da5Sopenharmony_ci  std::string escapedString;
42fd4e5da5Sopenharmony_ci  escapedString.resize(sourceLength);
43fd4e5da5Sopenharmony_ci
44fd4e5da5Sopenharmony_ci  size_t writeIndex = 0;
45fd4e5da5Sopenharmony_ci  size_t readIndex = 0;
46fd4e5da5Sopenharmony_ci  for (; readIndex < sourceLength; writeIndex++, readIndex++) {
47fd4e5da5Sopenharmony_ci    const char read = begin[readIndex];
48fd4e5da5Sopenharmony_ci    if (read == '\0') {
49fd4e5da5Sopenharmony_ci      escapedString.resize(writeIndex);
50fd4e5da5Sopenharmony_ci      output->append(escapedString);
51fd4e5da5Sopenharmony_ci      return SPV_SUCCESS;
52fd4e5da5Sopenharmony_ci    }
53fd4e5da5Sopenharmony_ci
54fd4e5da5Sopenharmony_ci    if (read == '\\') {
55fd4e5da5Sopenharmony_ci      ++readIndex;
56fd4e5da5Sopenharmony_ci    }
57fd4e5da5Sopenharmony_ci    escapedString[writeIndex] = begin[readIndex];
58fd4e5da5Sopenharmony_ci  }
59fd4e5da5Sopenharmony_ci
60fd4e5da5Sopenharmony_ci  spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
61fd4e5da5Sopenharmony_ci                  "Missing NULL terminator for literal string.");
62fd4e5da5Sopenharmony_ci  return SPV_ERROR_INVALID_BINARY;
63fd4e5da5Sopenharmony_ci}
64fd4e5da5Sopenharmony_ci
65fd4e5da5Sopenharmony_cispv_result_t extractOpString(const spv_position_t& loc,
66fd4e5da5Sopenharmony_ci                             const spv_parsed_instruction_t& instruction,
67fd4e5da5Sopenharmony_ci                             std::string* output) {
68fd4e5da5Sopenharmony_ci  assert(output != nullptr);
69fd4e5da5Sopenharmony_ci  assert(instruction.opcode == spv::Op::OpString);
70fd4e5da5Sopenharmony_ci  if (instruction.num_operands != 2) {
71fd4e5da5Sopenharmony_ci    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
72fd4e5da5Sopenharmony_ci                    "Missing operands for OpString.");
73fd4e5da5Sopenharmony_ci    return SPV_ERROR_INVALID_BINARY;
74fd4e5da5Sopenharmony_ci  }
75fd4e5da5Sopenharmony_ci
76fd4e5da5Sopenharmony_ci  const auto& operand = instruction.operands[1];
77fd4e5da5Sopenharmony_ci  const char* stringBegin =
78fd4e5da5Sopenharmony_ci      reinterpret_cast<const char*>(instruction.words + operand.offset);
79fd4e5da5Sopenharmony_ci  const char* stringEnd = reinterpret_cast<const char*>(
80fd4e5da5Sopenharmony_ci      instruction.words + operand.offset + operand.num_words);
81fd4e5da5Sopenharmony_ci  return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
82fd4e5da5Sopenharmony_ci}
83fd4e5da5Sopenharmony_ci
84fd4e5da5Sopenharmony_cispv_result_t extractOpSourceContinued(
85fd4e5da5Sopenharmony_ci    const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
86fd4e5da5Sopenharmony_ci    std::string* output) {
87fd4e5da5Sopenharmony_ci  assert(output != nullptr);
88fd4e5da5Sopenharmony_ci  assert(instruction.opcode == spv::Op::OpSourceContinued);
89fd4e5da5Sopenharmony_ci  if (instruction.num_operands != 1) {
90fd4e5da5Sopenharmony_ci    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
91fd4e5da5Sopenharmony_ci                    "Missing operands for OpSourceContinued.");
92fd4e5da5Sopenharmony_ci    return SPV_ERROR_INVALID_BINARY;
93fd4e5da5Sopenharmony_ci  }
94fd4e5da5Sopenharmony_ci
95fd4e5da5Sopenharmony_ci  const auto& operand = instruction.operands[0];
96fd4e5da5Sopenharmony_ci  const char* stringBegin =
97fd4e5da5Sopenharmony_ci      reinterpret_cast<const char*>(instruction.words + operand.offset);
98fd4e5da5Sopenharmony_ci  const char* stringEnd = reinterpret_cast<const char*>(
99fd4e5da5Sopenharmony_ci      instruction.words + operand.offset + operand.num_words);
100fd4e5da5Sopenharmony_ci  return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
101fd4e5da5Sopenharmony_ci}
102fd4e5da5Sopenharmony_ci
103fd4e5da5Sopenharmony_cispv_result_t extractOpSource(const spv_position_t& loc,
104fd4e5da5Sopenharmony_ci                             const spv_parsed_instruction_t& instruction,
105fd4e5da5Sopenharmony_ci                             spv::Id* filename, std::string* code) {
106fd4e5da5Sopenharmony_ci  assert(filename != nullptr && code != nullptr);
107fd4e5da5Sopenharmony_ci  assert(instruction.opcode == spv::Op::OpSource);
108fd4e5da5Sopenharmony_ci  // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
109fd4e5da5Sopenharmony_ci  if (instruction.num_words < 3) {
110fd4e5da5Sopenharmony_ci    spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
111fd4e5da5Sopenharmony_ci                    "Missing operands for OpSource.");
112fd4e5da5Sopenharmony_ci    return SPV_ERROR_INVALID_BINARY;
113fd4e5da5Sopenharmony_ci  }
114fd4e5da5Sopenharmony_ci
115fd4e5da5Sopenharmony_ci  *filename = 0;
116fd4e5da5Sopenharmony_ci  *code = "";
117fd4e5da5Sopenharmony_ci  if (instruction.num_words < 4) {
118fd4e5da5Sopenharmony_ci    return SPV_SUCCESS;
119fd4e5da5Sopenharmony_ci  }
120fd4e5da5Sopenharmony_ci  *filename = instruction.words[3];
121fd4e5da5Sopenharmony_ci
122fd4e5da5Sopenharmony_ci  if (instruction.num_words < 5) {
123fd4e5da5Sopenharmony_ci    return SPV_SUCCESS;
124fd4e5da5Sopenharmony_ci  }
125fd4e5da5Sopenharmony_ci
126fd4e5da5Sopenharmony_ci  const char* stringBegin =
127fd4e5da5Sopenharmony_ci      reinterpret_cast<const char*>(instruction.words + 4);
128fd4e5da5Sopenharmony_ci  const char* stringEnd =
129fd4e5da5Sopenharmony_ci      reinterpret_cast<const char*>(instruction.words + instruction.num_words);
130fd4e5da5Sopenharmony_ci  return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
131fd4e5da5Sopenharmony_ci}
132fd4e5da5Sopenharmony_ci
133fd4e5da5Sopenharmony_ci}  // namespace
134fd4e5da5Sopenharmony_ci
135fd4e5da5Sopenharmony_cibool ExtractSourceFromModule(
136fd4e5da5Sopenharmony_ci    const std::vector<uint32_t>& binary,
137fd4e5da5Sopenharmony_ci    std::unordered_map<std::string, std::string>* output) {
138fd4e5da5Sopenharmony_ci  auto context = spvtools::SpirvTools(kDefaultEnvironment);
139fd4e5da5Sopenharmony_ci  context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
140fd4e5da5Sopenharmony_ci
141fd4e5da5Sopenharmony_ci  // There is nothing valuable in the header.
142fd4e5da5Sopenharmony_ci  spvtools::HeaderParser headerParser = [](const spv_endianness_t,
143fd4e5da5Sopenharmony_ci                                           const spv_parsed_header_t&) {
144fd4e5da5Sopenharmony_ci    return SPV_SUCCESS;
145fd4e5da5Sopenharmony_ci  };
146fd4e5da5Sopenharmony_ci
147fd4e5da5Sopenharmony_ci  std::unordered_map<uint32_t, std::string> stringMap;
148fd4e5da5Sopenharmony_ci  std::vector<std::pair<spv::Id, std::string>> sources;
149fd4e5da5Sopenharmony_ci  spv::Op lastOpcode = spv::Op::OpMax;
150fd4e5da5Sopenharmony_ci  size_t instructionIndex = 0;
151fd4e5da5Sopenharmony_ci
152fd4e5da5Sopenharmony_ci  spvtools::InstructionParser instructionParser =
153fd4e5da5Sopenharmony_ci      [&stringMap, &sources, &lastOpcode,
154fd4e5da5Sopenharmony_ci       &instructionIndex](const spv_parsed_instruction_t& instruction) {
155fd4e5da5Sopenharmony_ci        const spv_position_t loc = {0, 0, instructionIndex + 1};
156fd4e5da5Sopenharmony_ci        spv_result_t result = SPV_SUCCESS;
157fd4e5da5Sopenharmony_ci
158fd4e5da5Sopenharmony_ci        if (instruction.opcode == spv::Op::OpString) {
159fd4e5da5Sopenharmony_ci          std::string content;
160fd4e5da5Sopenharmony_ci          result = extractOpString(loc, instruction, &content);
161fd4e5da5Sopenharmony_ci          if (result == SPV_SUCCESS) {
162fd4e5da5Sopenharmony_ci            stringMap.emplace(instruction.result_id, std::move(content));
163fd4e5da5Sopenharmony_ci          }
164fd4e5da5Sopenharmony_ci        } else if (instruction.opcode == spv::Op::OpSource) {
165fd4e5da5Sopenharmony_ci          spv::Id filenameId;
166fd4e5da5Sopenharmony_ci          std::string code;
167fd4e5da5Sopenharmony_ci          result = extractOpSource(loc, instruction, &filenameId, &code);
168fd4e5da5Sopenharmony_ci          if (result == SPV_SUCCESS) {
169fd4e5da5Sopenharmony_ci            sources.emplace_back(std::make_pair(filenameId, std::move(code)));
170fd4e5da5Sopenharmony_ci          }
171fd4e5da5Sopenharmony_ci        } else if (instruction.opcode == spv::Op::OpSourceContinued) {
172fd4e5da5Sopenharmony_ci          if (lastOpcode != spv::Op::OpSource) {
173fd4e5da5Sopenharmony_ci            spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
174fd4e5da5Sopenharmony_ci                            "OpSourceContinued MUST follow an OpSource.");
175fd4e5da5Sopenharmony_ci            return SPV_ERROR_INVALID_BINARY;
176fd4e5da5Sopenharmony_ci          }
177fd4e5da5Sopenharmony_ci
178fd4e5da5Sopenharmony_ci          assert(sources.size() > 0);
179fd4e5da5Sopenharmony_ci          result = extractOpSourceContinued(loc, instruction,
180fd4e5da5Sopenharmony_ci                                            &sources.back().second);
181fd4e5da5Sopenharmony_ci        }
182fd4e5da5Sopenharmony_ci
183fd4e5da5Sopenharmony_ci        ++instructionIndex;
184fd4e5da5Sopenharmony_ci        lastOpcode = static_cast<spv::Op>(instruction.opcode);
185fd4e5da5Sopenharmony_ci        return result;
186fd4e5da5Sopenharmony_ci      };
187fd4e5da5Sopenharmony_ci
188fd4e5da5Sopenharmony_ci  if (!context.Parse(binary, headerParser, instructionParser)) {
189fd4e5da5Sopenharmony_ci    return false;
190fd4e5da5Sopenharmony_ci  }
191fd4e5da5Sopenharmony_ci
192fd4e5da5Sopenharmony_ci  std::string defaultName = "unnamed-";
193fd4e5da5Sopenharmony_ci  size_t unnamedCount = 0;
194fd4e5da5Sopenharmony_ci  for (auto & [ id, code ] : sources) {
195fd4e5da5Sopenharmony_ci    std::string filename;
196fd4e5da5Sopenharmony_ci    const auto it = stringMap.find(id);
197fd4e5da5Sopenharmony_ci    if (it == stringMap.cend() || it->second.empty()) {
198fd4e5da5Sopenharmony_ci      filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
199fd4e5da5Sopenharmony_ci      ++unnamedCount;
200fd4e5da5Sopenharmony_ci    } else {
201fd4e5da5Sopenharmony_ci      filename = it->second;
202fd4e5da5Sopenharmony_ci    }
203fd4e5da5Sopenharmony_ci
204fd4e5da5Sopenharmony_ci    if (output->count(filename) != 0) {
205fd4e5da5Sopenharmony_ci      spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
206fd4e5da5Sopenharmony_ci                      "Source file name conflict.");
207fd4e5da5Sopenharmony_ci      return false;
208fd4e5da5Sopenharmony_ci    }
209fd4e5da5Sopenharmony_ci    output->insert({filename, code});
210fd4e5da5Sopenharmony_ci  }
211fd4e5da5Sopenharmony_ci
212fd4e5da5Sopenharmony_ci  return true;
213fd4e5da5Sopenharmony_ci}
214