1// Copyright (c) 2018 Google LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <algorithm>
16
17#include "source/enum_string_mapping.h"
18#include "source/opcode.h"
19#include "source/val/instruction.h"
20#include "source/val/validate.h"
21#include "source/val/validation_state.h"
22
23namespace spvtools {
24namespace val {
25namespace {
26
27// Returns true if |a| and |b| are instructions defining pointers that point to
28// types logically match and the decorations that apply to |b| are a subset
29// of the decorations that apply to |a|.
30bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31                              ValidationState_t& _) {
32  if (a->opcode() != spv::Op::OpTypePointer ||
33      b->opcode() != spv::Op::OpTypePointer) {
34    return false;
35  }
36
37  const auto& dec_a = _.id_decorations(a->id());
38  const auto& dec_b = _.id_decorations(b->id());
39  for (const auto& dec : dec_b) {
40    if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41      return false;
42    }
43  }
44
45  uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46  uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47
48  if (a_type == b_type) {
49    return true;
50  }
51
52  Instruction* a_type_inst = _.FindDef(a_type);
53  Instruction* b_type_inst = _.FindDef(b_type);
54
55  return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56}
57
58spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59  const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60  const auto function_type = _.FindDef(function_type_id);
61  if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62    return _.diag(SPV_ERROR_INVALID_ID, inst)
63           << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64           << " is not a function type.";
65  }
66
67  const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68  if (return_id != inst->type_id()) {
69    return _.diag(SPV_ERROR_INVALID_ID, inst)
70           << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71           << " does not match the Function Type's return type <id> "
72           << _.getIdName(return_id) << ".";
73  }
74
75  const std::vector<spv::Op> acceptable = {
76      spv::Op::OpGroupDecorate,
77      spv::Op::OpDecorate,
78      spv::Op::OpEnqueueKernel,
79      spv::Op::OpEntryPoint,
80      spv::Op::OpExecutionMode,
81      spv::Op::OpExecutionModeId,
82      spv::Op::OpFunctionCall,
83      spv::Op::OpGetKernelNDrangeSubGroupCount,
84      spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85      spv::Op::OpGetKernelWorkGroupSize,
86      spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87      spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88      spv::Op::OpGetKernelMaxNumSubgroups,
89      spv::Op::OpName};
90  for (auto& pair : inst->uses()) {
91    const auto* use = pair.first;
92    if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
93            acceptable.end() &&
94        !use->IsNonSemantic() && !use->IsDebugInfo()) {
95      return _.diag(SPV_ERROR_INVALID_ID, use)
96             << "Invalid use of function result id " << _.getIdName(inst->id())
97             << ".";
98    }
99  }
100
101  return SPV_SUCCESS;
102}
103
104spv_result_t ValidateFunctionParameter(ValidationState_t& _,
105                                       const Instruction* inst) {
106  // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
107  size_t param_index = 0;
108  size_t inst_num = inst->LineNum() - 1;
109  if (inst_num == 0) {
110    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
111           << "Function parameter cannot be the first instruction.";
112  }
113
114  auto func_inst = &_.ordered_instructions()[inst_num];
115  while (--inst_num) {
116    func_inst = &_.ordered_instructions()[inst_num];
117    if (func_inst->opcode() == spv::Op::OpFunction) {
118      break;
119    } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
120      ++param_index;
121    }
122  }
123
124  if (func_inst->opcode() != spv::Op::OpFunction) {
125    return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
126           << "Function parameter must be preceded by a function.";
127  }
128
129  const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
130  const auto function_type = _.FindDef(function_type_id);
131  if (!function_type) {
132    return _.diag(SPV_ERROR_INVALID_ID, func_inst)
133           << "Missing function type definition.";
134  }
135  if (param_index >= function_type->words().size() - 3) {
136    return _.diag(SPV_ERROR_INVALID_ID, inst)
137           << "Too many OpFunctionParameters for " << func_inst->id()
138           << ": expected " << function_type->words().size() - 3
139           << " based on the function's type";
140  }
141
142  const auto param_type =
143      _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
144  if (!param_type || inst->type_id() != param_type->id()) {
145    return _.diag(SPV_ERROR_INVALID_ID, inst)
146           << "OpFunctionParameter Result Type <id> "
147           << _.getIdName(inst->type_id())
148           << " does not match the OpTypeFunction parameter "
149              "type of the same index.";
150  }
151
152  // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
153  // RestrictPointer, or AliasedPointer.
154  auto param_nonarray_type_id = param_type->id();
155  while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
156    param_nonarray_type_id =
157        _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
158  }
159  if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) {
160    auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
161    if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
162        spv::StorageClass::PhysicalStorageBuffer) {
163      // check for Aliased or Restrict
164      const auto& decorations = _.id_decorations(inst->id());
165
166      bool foundAliased = std::any_of(
167          decorations.begin(), decorations.end(), [](const Decoration& d) {
168            return spv::Decoration::Aliased == d.dec_type();
169          });
170
171      bool foundRestrict = std::any_of(
172          decorations.begin(), decorations.end(), [](const Decoration& d) {
173            return spv::Decoration::Restrict == d.dec_type();
174          });
175
176      if (!foundAliased && !foundRestrict) {
177        return _.diag(SPV_ERROR_INVALID_ID, inst)
178               << "OpFunctionParameter " << inst->id()
179               << ": expected Aliased or Restrict for PhysicalStorageBuffer "
180                  "pointer.";
181      }
182      if (foundAliased && foundRestrict) {
183        return _.diag(SPV_ERROR_INVALID_ID, inst)
184               << "OpFunctionParameter " << inst->id()
185               << ": can't specify both Aliased and Restrict for "
186                  "PhysicalStorageBuffer pointer.";
187      }
188    } else {
189      const auto pointee_type_id =
190          param_nonarray_type->GetOperandAs<uint32_t>(2);
191      const auto pointee_type = _.FindDef(pointee_type_id);
192      if (spv::Op::OpTypePointer == pointee_type->opcode() &&
193          pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
194              spv::StorageClass::PhysicalStorageBuffer) {
195        // check for AliasedPointer/RestrictPointer
196        const auto& decorations = _.id_decorations(inst->id());
197
198        bool foundAliased = std::any_of(
199            decorations.begin(), decorations.end(), [](const Decoration& d) {
200              return spv::Decoration::AliasedPointer == d.dec_type();
201            });
202
203        bool foundRestrict = std::any_of(
204            decorations.begin(), decorations.end(), [](const Decoration& d) {
205              return spv::Decoration::RestrictPointer == d.dec_type();
206            });
207
208        if (!foundAliased && !foundRestrict) {
209          return _.diag(SPV_ERROR_INVALID_ID, inst)
210                 << "OpFunctionParameter " << inst->id()
211                 << ": expected AliasedPointer or RestrictPointer for "
212                    "PhysicalStorageBuffer pointer.";
213        }
214        if (foundAliased && foundRestrict) {
215          return _.diag(SPV_ERROR_INVALID_ID, inst)
216                 << "OpFunctionParameter " << inst->id()
217                 << ": can't specify both AliasedPointer and "
218                    "RestrictPointer for PhysicalStorageBuffer pointer.";
219        }
220      }
221    }
222  }
223
224  return SPV_SUCCESS;
225}
226
227spv_result_t ValidateFunctionCall(ValidationState_t& _,
228                                  const Instruction* inst) {
229  const auto function_id = inst->GetOperandAs<uint32_t>(2);
230  const auto function = _.FindDef(function_id);
231  if (!function || spv::Op::OpFunction != function->opcode()) {
232    return _.diag(SPV_ERROR_INVALID_ID, inst)
233           << "OpFunctionCall Function <id> " << _.getIdName(function_id)
234           << " is not a function.";
235  }
236
237  auto return_type = _.FindDef(function->type_id());
238  if (!return_type || return_type->id() != inst->type_id()) {
239    return _.diag(SPV_ERROR_INVALID_ID, inst)
240           << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
241           << "s type does not match Function <id> "
242           << _.getIdName(return_type->id()) << "s return type.";
243  }
244
245  const auto function_type_id = function->GetOperandAs<uint32_t>(3);
246  const auto function_type = _.FindDef(function_type_id);
247  if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
248    return _.diag(SPV_ERROR_INVALID_ID, inst)
249           << "Missing function type definition.";
250  }
251
252  const auto function_call_arg_count = inst->words().size() - 4;
253  const auto function_param_count = function_type->words().size() - 3;
254  if (function_param_count != function_call_arg_count) {
255    return _.diag(SPV_ERROR_INVALID_ID, inst)
256           << "OpFunctionCall Function <id>'s parameter count does not match "
257              "the argument count.";
258  }
259
260  for (size_t argument_index = 3, param_index = 2;
261       argument_index < inst->operands().size();
262       argument_index++, param_index++) {
263    const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
264    const auto argument = _.FindDef(argument_id);
265    if (!argument) {
266      return _.diag(SPV_ERROR_INVALID_ID, inst)
267             << "Missing argument " << argument_index - 3 << " definition.";
268    }
269
270    const auto argument_type = _.FindDef(argument->type_id());
271    if (!argument_type) {
272      return _.diag(SPV_ERROR_INVALID_ID, inst)
273             << "Missing argument " << argument_index - 3
274             << " type definition.";
275    }
276
277    const auto parameter_type_id =
278        function_type->GetOperandAs<uint32_t>(param_index);
279    const auto parameter_type = _.FindDef(parameter_type_id);
280    if (!parameter_type || argument_type->id() != parameter_type->id()) {
281      if (!_.options()->before_hlsl_legalization ||
282          !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
283        return _.diag(SPV_ERROR_INVALID_ID, inst)
284               << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
285               << "s type does not match Function <id> "
286               << _.getIdName(parameter_type_id) << "s parameter type.";
287      }
288    }
289
290    if (_.addressing_model() == spv::AddressingModel::Logical) {
291      if (parameter_type->opcode() == spv::Op::OpTypePointer &&
292          !_.options()->relax_logical_pointer) {
293        spv::StorageClass sc =
294            parameter_type->GetOperandAs<spv::StorageClass>(1u);
295        // Validate which storage classes can be pointer operands.
296        switch (sc) {
297          case spv::StorageClass::UniformConstant:
298          case spv::StorageClass::Function:
299          case spv::StorageClass::Private:
300          case spv::StorageClass::Workgroup:
301          case spv::StorageClass::AtomicCounter:
302            // These are always allowed.
303            break;
304          case spv::StorageClass::StorageBuffer:
305            if (!_.features().variable_pointers) {
306              return _.diag(SPV_ERROR_INVALID_ID, inst)
307                     << "StorageBuffer pointer operand "
308                     << _.getIdName(argument_id)
309                     << " requires a variable pointers capability";
310            }
311            break;
312          default:
313            return _.diag(SPV_ERROR_INVALID_ID, inst)
314                   << "Invalid storage class for pointer operand "
315                   << _.getIdName(argument_id);
316        }
317
318        // Validate memory object declaration requirements.
319        if (argument->opcode() != spv::Op::OpVariable &&
320            argument->opcode() != spv::Op::OpFunctionParameter) {
321          const bool ssbo_vptr = _.features().variable_pointers &&
322                                 sc == spv::StorageClass::StorageBuffer;
323          const bool wg_vptr =
324              _.HasCapability(spv::Capability::VariablePointers) &&
325              sc == spv::StorageClass::Workgroup;
326          const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
327          if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
328            return _.diag(SPV_ERROR_INVALID_ID, inst)
329                   << "Pointer operand " << _.getIdName(argument_id)
330                   << " must be a memory object declaration";
331          }
332        }
333      }
334    }
335  }
336  return SPV_SUCCESS;
337}
338
339}  // namespace
340
341spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
342  switch (inst->opcode()) {
343    case spv::Op::OpFunction:
344      if (auto error = ValidateFunction(_, inst)) return error;
345      break;
346    case spv::Op::OpFunctionParameter:
347      if (auto error = ValidateFunctionParameter(_, inst)) return error;
348      break;
349    case spv::Op::OpFunctionCall:
350      if (auto error = ValidateFunctionCall(_, inst)) return error;
351      break;
352    default:
353      break;
354  }
355
356  return SPV_SUCCESS;
357}
358
359}  // namespace val
360}  // namespace spvtools
361