1// Copyright (c) 2015-2016 The Khronos Group Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/val/validate.h"
16
17#include <functional>
18#include <iterator>
19#include <memory>
20#include <string>
21#include <vector>
22
23#include "source/binary.h"
24#include "source/diagnostic.h"
25#include "source/enum_string_mapping.h"
26#include "source/extensions.h"
27#include "source/opcode.h"
28#include "source/spirv_constant.h"
29#include "source/spirv_endian.h"
30#include "source/spirv_target_env.h"
31#include "source/val/construct.h"
32#include "source/val/instruction.h"
33#include "source/val/validation_state.h"
34#include "spirv-tools/libspirv.h"
35
namespace {
// Maximum number of warnings requested from every ValidationState_t that this
// file constructs.
// TODO(issue 1950): The validator only returns a single message anyway, so no
// point in generating more than 1 warning.
constexpr uint32_t kDefaultMaxNumOfWarnings = 1;
}  // namespace
41
42namespace spvtools {
43namespace val {
44namespace {
45
46// Parses OpExtension instruction and registers extension.
47void RegisterExtension(ValidationState_t& _,
48                       const spv_parsed_instruction_t* inst) {
49  const std::string extension_str = spvtools::GetExtensionString(inst);
50  Extension extension;
51  if (!GetExtensionFromString(extension_str.c_str(), &extension)) {
52    // The error will be logged in the ProcessInstruction pass.
53    return;
54  }
55
56  _.RegisterExtension(extension);
57}
58
59// Parses the beginning of the module searching for OpExtension instructions.
60// Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION
61// once an instruction which is not spv::Op::OpCapability and
62// spv::Op::OpExtension is encountered. According to the SPIR-V spec extensions
63// are declared after capabilities and before everything else.
64spv_result_t ProcessExtensions(void* user_data,
65                               const spv_parsed_instruction_t* inst) {
66  const spv::Op opcode = static_cast<spv::Op>(inst->opcode);
67  if (opcode == spv::Op::OpCapability) return SPV_SUCCESS;
68
69  if (opcode == spv::Op::OpExtension) {
70    ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
71    RegisterExtension(_, inst);
72    return SPV_SUCCESS;
73  }
74
75  // OpExtension block is finished, requesting termination.
76  return SPV_REQUESTED_TERMINATION;
77}
78
79spv_result_t ProcessInstruction(void* user_data,
80                                const spv_parsed_instruction_t* inst) {
81  ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
82
83  auto* instruction = _.AddOrderedInstruction(inst);
84  _.RegisterDebugInstruction(instruction);
85
86  return SPV_SUCCESS;
87}
88
89spv_result_t ValidateForwardDecls(ValidationState_t& _) {
90  if (_.unresolved_forward_id_count() == 0) return SPV_SUCCESS;
91
92  std::stringstream ss;
93  std::vector<uint32_t> ids = _.UnresolvedForwardIds();
94
95  std::transform(
96      std::begin(ids), std::end(ids),
97      std::ostream_iterator<std::string>(ss, " "),
98      bind(&ValidationState_t::getIdName, std::ref(_), std::placeholders::_1));
99
100  auto id_str = ss.str();
101  return _.diag(SPV_ERROR_INVALID_ID, nullptr)
102         << "The following forward referenced IDs have not been defined:\n"
103         << id_str.substr(0, id_str.size() - 1);
104}
105
106// Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the
107// SPIRV spec:
108// * There is at least one OpEntryPoint instruction, unless the Linkage
109//   capability is being used.
110// * No function can be targeted by both an OpEntryPoint instruction and an
111//   OpFunctionCall instruction.
112//
113// Additionally enforces that entry points for Vulkan should not have recursion.
114spv_result_t ValidateEntryPoints(ValidationState_t& _) {
115  _.ComputeFunctionToEntryPointMapping();
116  _.ComputeRecursiveEntryPoints();
117
118  if (_.entry_points().empty() && !_.HasCapability(spv::Capability::Linkage)) {
119    return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
120           << "No OpEntryPoint instruction was found. This is only allowed if "
121              "the Linkage capability is being used.";
122  }
123
124  for (const auto& entry_point : _.entry_points()) {
125    if (_.IsFunctionCallTarget(entry_point)) {
126      return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
127             << "A function (" << entry_point
128             << ") may not be targeted by both an OpEntryPoint instruction and "
129                "an OpFunctionCall instruction.";
130    }
131
132    // For Vulkan, the static function-call graph for an entry point
133    // must not contain cycles.
134    if (spvIsVulkanEnv(_.context()->target_env)) {
135      if (_.recursive_entry_points().find(entry_point) !=
136          _.recursive_entry_points().end()) {
137        return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
138               << _.VkErrorID(4634)
139               << "Entry points may not have a call graph with cycles.";
140      }
141    }
142  }
143
144  return SPV_SUCCESS;
145}
146
147spv_result_t ValidateBinaryUsingContextAndValidationState(
148    const spv_context_t& context, const uint32_t* words, const size_t num_words,
149    spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
150  auto binary = std::unique_ptr<spv_const_binary_t>(
151      new spv_const_binary_t{words, num_words});
152
153  spv_endianness_t endian;
154  spv_position_t position = {};
155  if (spvBinaryEndianness(binary.get(), &endian)) {
156    return DiagnosticStream(position, context.consumer, "",
157                            SPV_ERROR_INVALID_BINARY)
158           << "Invalid SPIR-V magic number.";
159  }
160
161  spv_header_t header;
162  if (spvBinaryHeaderGet(binary.get(), endian, &header)) {
163    return DiagnosticStream(position, context.consumer, "",
164                            SPV_ERROR_INVALID_BINARY)
165           << "Invalid SPIR-V header.";
166  }
167
168  if (header.version > spvVersionForTargetEnv(context.target_env)) {
169    return DiagnosticStream(position, context.consumer, "",
170                            SPV_ERROR_WRONG_VERSION)
171           << "Invalid SPIR-V binary version "
172           << SPV_SPIRV_VERSION_MAJOR_PART(header.version) << "."
173           << SPV_SPIRV_VERSION_MINOR_PART(header.version)
174           << " for target environment "
175           << spvTargetEnvDescription(context.target_env) << ".";
176  }
177
178  if (header.bound > vstate->options()->universal_limits_.max_id_bound) {
179    return DiagnosticStream(position, context.consumer, "",
180                            SPV_ERROR_INVALID_BINARY)
181           << "Invalid SPIR-V.  The id bound is larger than the max id bound "
182           << vstate->options()->universal_limits_.max_id_bound << ".";
183  }
184
185  // Look for OpExtension instructions and register extensions.
186  // This parse should not produce any error messages. Hijack the context and
187  // replace the message consumer so that we do not pollute any state in input
188  // consumer.
189  spv_context_t hijacked_context = context;
190  hijacked_context.consumer = [](spv_message_level_t, const char*,
191                                 const spv_position_t&, const char*) {};
192  spvBinaryParse(&hijacked_context, vstate, words, num_words,
193                 /* parsed_header = */ nullptr, ProcessExtensions,
194                 /* diagnostic = */ nullptr);
195
196  // Parse the module and perform inline validation checks. These checks do
197  // not require the knowledge of the whole module.
198  if (auto error = spvBinaryParse(&context, vstate, words, num_words,
199                                  /*parsed_header =*/nullptr,
200                                  ProcessInstruction, pDiagnostic)) {
201    return error;
202  }
203
204  bool has_mask_task_nv = false;
205  bool has_mask_task_ext = false;
206  std::vector<Instruction*> visited_entry_points;
207  for (auto& instruction : vstate->ordered_instructions()) {
208    {
209      // In order to do this work outside of Process Instruction we need to be
210      // able to, briefly, de-const the instruction.
211      Instruction* inst = const_cast<Instruction*>(&instruction);
212
213      if (inst->opcode() == spv::Op::OpEntryPoint) {
214        const auto entry_point = inst->GetOperandAs<uint32_t>(1);
215        const auto execution_model = inst->GetOperandAs<spv::ExecutionModel>(0);
216        const std::string desc_name = inst->GetOperandAs<std::string>(2);
217
218        ValidationState_t::EntryPointDescription desc;
219        desc.name = desc_name;
220
221        std::vector<uint32_t> interfaces;
222        for (size_t j = 3; j < inst->operands().size(); ++j)
223          desc.interfaces.push_back(inst->word(inst->operand(j).offset));
224
225        vstate->RegisterEntryPoint(entry_point, execution_model,
226                                   std::move(desc));
227
228        if (visited_entry_points.size() > 0) {
229          for (const Instruction* check_inst : visited_entry_points) {
230            const auto check_execution_model =
231                check_inst->GetOperandAs<spv::ExecutionModel>(0);
232            const std::string check_name =
233                check_inst->GetOperandAs<std::string>(2);
234
235            if (desc_name == check_name &&
236                execution_model == check_execution_model) {
237              return vstate->diag(SPV_ERROR_INVALID_DATA, inst)
238                     << "2 Entry points cannot share the same name and "
239                        "ExecutionMode.";
240            }
241          }
242        }
243        visited_entry_points.push_back(inst);
244
245        has_mask_task_nv |= (execution_model == spv::ExecutionModel::TaskNV ||
246                             execution_model == spv::ExecutionModel::MeshNV);
247        has_mask_task_ext |= (execution_model == spv::ExecutionModel::TaskEXT ||
248                              execution_model == spv::ExecutionModel::MeshEXT);
249      }
250      if (inst->opcode() == spv::Op::OpFunctionCall) {
251        if (!vstate->in_function_body()) {
252          return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction)
253                 << "A FunctionCall must happen within a function body.";
254        }
255
256        const auto called_id = inst->GetOperandAs<uint32_t>(2);
257        vstate->AddFunctionCallTarget(called_id);
258      }
259
260      if (vstate->in_function_body()) {
261        inst->set_function(&(vstate->current_function()));
262        inst->set_block(vstate->current_function().current_block());
263
264        if (vstate->in_block() && spvOpcodeIsBlockTerminator(inst->opcode())) {
265          vstate->current_function().current_block()->set_terminator(inst);
266        }
267      }
268
269      if (auto error = IdPass(*vstate, inst)) return error;
270    }
271
272    if (auto error = CapabilityPass(*vstate, &instruction)) return error;
273    if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error;
274    if (auto error = CfgPass(*vstate, &instruction)) return error;
275    if (auto error = InstructionPass(*vstate, &instruction)) return error;
276
277    // Now that all of the checks are done, update the state.
278    {
279      Instruction* inst = const_cast<Instruction*>(&instruction);
280      vstate->RegisterInstruction(inst);
281      if (inst->opcode() == spv::Op::OpTypeForwardPointer) {
282        vstate->RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0));
283      }
284    }
285  }
286
287  if (!vstate->has_memory_model_specified())
288    return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
289           << "Missing required OpMemoryModel instruction.";
290
291  if (vstate->in_function_body())
292    return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
293           << "Missing OpFunctionEnd at end of module.";
294
295  if (vstate->HasCapability(spv::Capability::BindlessTextureNV) &&
296      !vstate->has_samplerimage_variable_address_mode_specified())
297    return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
298           << "Missing required OpSamplerImageAddressingModeNV instruction.";
299
300  if (has_mask_task_ext && has_mask_task_nv)
301    return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
302           << vstate->VkErrorID(7102)
303           << "Module can't mix MeshEXT/TaskEXT with MeshNV/TaskNV Execution "
304              "Model.";
305
306  // Catch undefined forward references before performing further checks.
307  if (auto error = ValidateForwardDecls(*vstate)) return error;
308
309  // Calculate reachability after all the blocks are parsed, but early that it
310  // can be relied on in subsequent pases.
311  ReachabilityPass(*vstate);
312
313  // ID usage needs be handled in its own iteration of the instructions,
314  // between the two others. It depends on the first loop to have been
315  // finished, so that all instructions have been registered. And the following
316  // loop depends on all of the usage data being populated. Thus it cannot live
317  // in either of those iterations.
318  // It should also live after the forward declaration check, since it will
319  // have problems with missing forward declarations, but give less useful error
320  // messages.
321  for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) {
322    auto& instruction = vstate->ordered_instructions()[i];
323    if (auto error = UpdateIdUse(*vstate, &instruction)) return error;
324  }
325
326  // Validate individual opcodes.
327  for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) {
328    auto& instruction = vstate->ordered_instructions()[i];
329
330    // Keep these passes in the order they appear in the SPIR-V specification
331    // sections to maintain test consistency.
332    if (auto error = MiscPass(*vstate, &instruction)) return error;
333    if (auto error = DebugPass(*vstate, &instruction)) return error;
334    if (auto error = AnnotationPass(*vstate, &instruction)) return error;
335    if (auto error = ExtensionPass(*vstate, &instruction)) return error;
336    if (auto error = ModeSettingPass(*vstate, &instruction)) return error;
337    if (auto error = TypePass(*vstate, &instruction)) return error;
338    if (auto error = ConstantPass(*vstate, &instruction)) return error;
339    if (auto error = MemoryPass(*vstate, &instruction)) return error;
340    if (auto error = FunctionPass(*vstate, &instruction)) return error;
341    if (auto error = ImagePass(*vstate, &instruction)) return error;
342    if (auto error = ConversionPass(*vstate, &instruction)) return error;
343    if (auto error = CompositesPass(*vstate, &instruction)) return error;
344    if (auto error = ArithmeticsPass(*vstate, &instruction)) return error;
345    if (auto error = BitwisePass(*vstate, &instruction)) return error;
346    if (auto error = LogicalsPass(*vstate, &instruction)) return error;
347    if (auto error = ControlFlowPass(*vstate, &instruction)) return error;
348    if (auto error = DerivativesPass(*vstate, &instruction)) return error;
349    if (auto error = AtomicsPass(*vstate, &instruction)) return error;
350    if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
351    if (auto error = BarriersPass(*vstate, &instruction)) return error;
352    // Group
353    // Device-Side Enqueue
354    // Pipe
355    if (auto error = NonUniformPass(*vstate, &instruction)) return error;
356
357    if (auto error = LiteralsPass(*vstate, &instruction)) return error;
358    if (auto error = RayQueryPass(*vstate, &instruction)) return error;
359    if (auto error = RayTracingPass(*vstate, &instruction)) return error;
360    if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
361    if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
362  }
363
364  // Validate the preconditions involving adjacent instructions. e.g.
365  // spv::Op::OpPhi must only be preceded by spv::Op::OpLabel, spv::Op::OpPhi,
366  // or spv::Op::OpLine.
367  if (auto error = ValidateAdjacency(*vstate)) return error;
368
369  if (auto error = ValidateEntryPoints(*vstate)) return error;
370  // CFG checks are performed after the binary has been parsed
371  // and the CFGPass has collected information about the control flow
372  if (auto error = PerformCfgChecks(*vstate)) return error;
373  if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
374  if (auto error = ValidateDecorations(*vstate)) return error;
375  if (auto error = ValidateInterfaces(*vstate)) return error;
376  // TODO(dsinclair): Restructure ValidateBuiltins so we can move into the
377  // for() above as it loops over all ordered_instructions internally.
378  if (auto error = ValidateBuiltIns(*vstate)) return error;
379  // These checks must be performed after individual opcode checks because
380  // those checks register the limitation checked here.
381  for (const auto& inst : vstate->ordered_instructions()) {
382    if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error;
383    if (auto error = ValidateSmallTypeUses(*vstate, &inst)) return error;
384    if (auto error = ValidateQCOMImageProcessingTextureUsages(*vstate, &inst))
385      return error;
386  }
387
388  return SPV_SUCCESS;
389}
390
391}  // namespace
392
393spv_result_t ValidateBinaryAndKeepValidationState(
394    const spv_const_context context, spv_const_validator_options options,
395    const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
396    std::unique_ptr<ValidationState_t>* vstate) {
397  spv_context_t hijack_context = *context;
398  if (pDiagnostic) {
399    *pDiagnostic = nullptr;
400    UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
401  }
402
403  vstate->reset(new ValidationState_t(&hijack_context, options, words,
404                                      num_words, kDefaultMaxNumOfWarnings));
405
406  return ValidateBinaryUsingContextAndValidationState(
407      hijack_context, words, num_words, pDiagnostic, vstate->get());
408}
409
410}  // namespace val
411}  // namespace spvtools
412
413spv_result_t spvValidate(const spv_const_context context,
414                         const spv_const_binary binary,
415                         spv_diagnostic* pDiagnostic) {
416  return spvValidateBinary(context, binary->code, binary->wordCount,
417                           pDiagnostic);
418}
419
420spv_result_t spvValidateBinary(const spv_const_context context,
421                               const uint32_t* words, const size_t num_words,
422                               spv_diagnostic* pDiagnostic) {
423  spv_context_t hijack_context = *context;
424  if (pDiagnostic) {
425    *pDiagnostic = nullptr;
426    spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
427  }
428
429  // This interface is used for default command line options.
430  spv_validator_options default_options = spvValidatorOptionsCreate();
431
432  // Create the ValidationState using the context and default options.
433  spvtools::val::ValidationState_t vstate(&hijack_context, default_options,
434                                          words, num_words,
435                                          kDefaultMaxNumOfWarnings);
436
437  spv_result_t result =
438      spvtools::val::ValidateBinaryUsingContextAndValidationState(
439          hijack_context, words, num_words, pDiagnostic, &vstate);
440
441  spvValidatorOptionsDestroy(default_options);
442  return result;
443}
444
445spv_result_t spvValidateWithOptions(const spv_const_context context,
446                                    spv_const_validator_options options,
447                                    const spv_const_binary binary,
448                                    spv_diagnostic* pDiagnostic) {
449  spv_context_t hijack_context = *context;
450  if (pDiagnostic) {
451    *pDiagnostic = nullptr;
452    spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
453  }
454
455  // Create the ValidationState using the context.
456  spvtools::val::ValidationState_t vstate(&hijack_context, options,
457                                          binary->code, binary->wordCount,
458                                          kDefaultMaxNumOfWarnings);
459
460  return spvtools::val::ValidateBinaryUsingContextAndValidationState(
461      hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate);
462}
463