1// Copyright (c) 2018 Google LLC. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include <algorithm> 16 17#include "source/enum_string_mapping.h" 18#include "source/opcode.h" 19#include "source/val/instruction.h" 20#include "source/val/validate.h" 21#include "source/val/validation_state.h" 22 23namespace spvtools { 24namespace val { 25namespace { 26 27// Returns true if |a| and |b| are instructions defining pointers that point to 28// types logically match and the decorations that apply to |b| are a subset 29// of the decorations that apply to |a|. 30bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b, 31 ValidationState_t& _) { 32 if (a->opcode() != spv::Op::OpTypePointer || 33 b->opcode() != spv::Op::OpTypePointer) { 34 return false; 35 } 36 37 const auto& dec_a = _.id_decorations(a->id()); 38 const auto& dec_b = _.id_decorations(b->id()); 39 for (const auto& dec : dec_b) { 40 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) { 41 return false; 42 } 43 } 44 45 uint32_t a_type = a->GetOperandAs<uint32_t>(2); 46 uint32_t b_type = b->GetOperandAs<uint32_t>(2); 47 48 if (a_type == b_type) { 49 return true; 50 } 51 52 Instruction* a_type_inst = _.FindDef(a_type); 53 Instruction* b_type_inst = _.FindDef(b_type); 54 55 return _.LogicallyMatch(a_type_inst, b_type_inst, true); 56} 57 58spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { 59 const auto function_type_id = inst->GetOperandAs<uint32_t>(3); 60 const auto function_type = _.FindDef(function_type_id); 61 if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) { 62 return _.diag(SPV_ERROR_INVALID_ID, inst) 63 << "OpFunction Function Type <id> " << _.getIdName(function_type_id) 64 << " is not a function type."; 65 } 66 67 const auto return_id = function_type->GetOperandAs<uint32_t>(1); 68 if (return_id != inst->type_id()) { 69 return _.diag(SPV_ERROR_INVALID_ID, inst) 70 << "OpFunction Result Type <id> " << _.getIdName(inst->type_id()) 71 << " does not match the Function Type's return type <id> " 72 << _.getIdName(return_id) << "."; 73 } 74 75 const std::vector<spv::Op> acceptable = { 76 spv::Op::OpGroupDecorate, 77 spv::Op::OpDecorate, 78 spv::Op::OpEnqueueKernel, 79 spv::Op::OpEntryPoint, 80 spv::Op::OpExecutionMode, 81 spv::Op::OpExecutionModeId, 82 spv::Op::OpFunctionCall, 83 spv::Op::OpGetKernelNDrangeSubGroupCount, 84 spv::Op::OpGetKernelNDrangeMaxSubGroupSize, 85 spv::Op::OpGetKernelWorkGroupSize, 86 spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple, 87 spv::Op::OpGetKernelLocalSizeForSubgroupCount, 88 spv::Op::OpGetKernelMaxNumSubgroups, 89 spv::Op::OpName}; 90 for (auto& pair : inst->uses()) { 91 const auto* use = pair.first; 92 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == 93 acceptable.end() && 94 !use->IsNonSemantic() && !use->IsDebugInfo()) { 95 return _.diag(SPV_ERROR_INVALID_ID, use) 96 << "Invalid use of function result id " << _.getIdName(inst->id()) 97 << "."; 98 } 99 } 100 101 return SPV_SUCCESS; 102} 103 104spv_result_t ValidateFunctionParameter(ValidationState_t& _, 105 const Instruction* inst) { 106 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. 107 size_t param_index = 0; 108 size_t inst_num = inst->LineNum() - 1; 109 if (inst_num == 0) { 110 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 111 << "Function parameter cannot be the first instruction."; 112 } 113 114 auto func_inst = &_.ordered_instructions()[inst_num]; 115 while (--inst_num) { 116 func_inst = &_.ordered_instructions()[inst_num]; 117 if (func_inst->opcode() == spv::Op::OpFunction) { 118 break; 119 } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) { 120 ++param_index; 121 } 122 } 123 124 if (func_inst->opcode() != spv::Op::OpFunction) { 125 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 126 << "Function parameter must be preceded by a function."; 127 } 128 129 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); 130 const auto function_type = _.FindDef(function_type_id); 131 if (!function_type) { 132 return _.diag(SPV_ERROR_INVALID_ID, func_inst) 133 << "Missing function type definition."; 134 } 135 if (param_index >= function_type->words().size() - 3) { 136 return _.diag(SPV_ERROR_INVALID_ID, inst) 137 << "Too many OpFunctionParameters for " << func_inst->id() 138 << ": expected " << function_type->words().size() - 3 139 << " based on the function's type"; 140 } 141 142 const auto param_type = 143 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); 144 if (!param_type || inst->type_id() != param_type->id()) { 145 return _.diag(SPV_ERROR_INVALID_ID, inst) 146 << "OpFunctionParameter Result Type <id> " 147 << _.getIdName(inst->type_id()) 148 << " does not match the OpTypeFunction parameter " 149 "type of the same index."; 150 } 151 152 // Validate that PhysicalStorageBuffer have one of Restrict, Aliased, 153 // RestrictPointer, or AliasedPointer. 154 auto param_nonarray_type_id = param_type->id(); 155 while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) { 156 param_nonarray_type_id = 157 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u); 158 } 159 if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer) { 160 auto param_nonarray_type = _.FindDef(param_nonarray_type_id); 161 if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) == 162 spv::StorageClass::PhysicalStorageBuffer) { 163 // check for Aliased or Restrict 164 const auto& decorations = _.id_decorations(inst->id()); 165 166 bool foundAliased = std::any_of( 167 decorations.begin(), decorations.end(), [](const Decoration& d) { 168 return spv::Decoration::Aliased == d.dec_type(); 169 }); 170 171 bool foundRestrict = std::any_of( 172 decorations.begin(), decorations.end(), [](const Decoration& d) { 173 return spv::Decoration::Restrict == d.dec_type(); 174 }); 175 176 if (!foundAliased && !foundRestrict) { 177 return _.diag(SPV_ERROR_INVALID_ID, inst) 178 << "OpFunctionParameter " << inst->id() 179 << ": expected Aliased or Restrict for PhysicalStorageBuffer " 180 "pointer."; 181 } 182 if (foundAliased && foundRestrict) { 183 return _.diag(SPV_ERROR_INVALID_ID, inst) 184 << "OpFunctionParameter " << inst->id() 185 << ": can't specify both Aliased and Restrict for " 186 "PhysicalStorageBuffer pointer."; 187 } 188 } else { 189 const auto pointee_type_id = 190 param_nonarray_type->GetOperandAs<uint32_t>(2); 191 const auto pointee_type = _.FindDef(pointee_type_id); 192 if (spv::Op::OpTypePointer == pointee_type->opcode() && 193 pointee_type->GetOperandAs<spv::StorageClass>(1u) == 194 spv::StorageClass::PhysicalStorageBuffer) { 195 // check for AliasedPointer/RestrictPointer 196 const auto& decorations = _.id_decorations(inst->id()); 197 198 bool foundAliased = std::any_of( 199 decorations.begin(), decorations.end(), [](const Decoration& d) { 200 return spv::Decoration::AliasedPointer == d.dec_type(); 201 }); 202 203 bool foundRestrict = std::any_of( 204 decorations.begin(), decorations.end(), [](const Decoration& d) { 205 return spv::Decoration::RestrictPointer == d.dec_type(); 206 }); 207 208 if (!foundAliased && !foundRestrict) { 209 return _.diag(SPV_ERROR_INVALID_ID, inst) 210 << "OpFunctionParameter " << inst->id() 211 << ": expected AliasedPointer or RestrictPointer for " 212 "PhysicalStorageBuffer pointer."; 213 } 214 if (foundAliased && foundRestrict) { 215 return _.diag(SPV_ERROR_INVALID_ID, inst) 216 << "OpFunctionParameter " << inst->id() 217 << ": can't specify both AliasedPointer and " 218 "RestrictPointer for PhysicalStorageBuffer pointer."; 219 } 220 } 221 } 222 } 223 224 return SPV_SUCCESS; 225} 226 227spv_result_t ValidateFunctionCall(ValidationState_t& _, 228 const Instruction* inst) { 229 const auto function_id = inst->GetOperandAs<uint32_t>(2); 230 const auto function = _.FindDef(function_id); 231 if (!function || spv::Op::OpFunction != function->opcode()) { 232 return _.diag(SPV_ERROR_INVALID_ID, inst) 233 << "OpFunctionCall Function <id> " << _.getIdName(function_id) 234 << " is not a function."; 235 } 236 237 auto return_type = _.FindDef(function->type_id()); 238 if (!return_type || return_type->id() != inst->type_id()) { 239 return _.diag(SPV_ERROR_INVALID_ID, inst) 240 << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id()) 241 << "s type does not match Function <id> " 242 << _.getIdName(return_type->id()) << "s return type."; 243 } 244 245 const auto function_type_id = function->GetOperandAs<uint32_t>(3); 246 const auto function_type = _.FindDef(function_type_id); 247 if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) { 248 return _.diag(SPV_ERROR_INVALID_ID, inst) 249 << "Missing function type definition."; 250 } 251 252 const auto function_call_arg_count = inst->words().size() - 4; 253 const auto function_param_count = function_type->words().size() - 3; 254 if (function_param_count != function_call_arg_count) { 255 return _.diag(SPV_ERROR_INVALID_ID, inst) 256 << "OpFunctionCall Function <id>'s parameter count does not match " 257 "the argument count."; 258 } 259 260 for (size_t argument_index = 3, param_index = 2; 261 argument_index < inst->operands().size(); 262 argument_index++, param_index++) { 263 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); 264 const auto argument = _.FindDef(argument_id); 265 if (!argument) { 266 return _.diag(SPV_ERROR_INVALID_ID, inst) 267 << "Missing argument " << argument_index - 3 << " definition."; 268 } 269 270 const auto argument_type = _.FindDef(argument->type_id()); 271 if (!argument_type) { 272 return _.diag(SPV_ERROR_INVALID_ID, inst) 273 << "Missing argument " << argument_index - 3 274 << " type definition."; 275 } 276 277 const auto parameter_type_id = 278 function_type->GetOperandAs<uint32_t>(param_index); 279 const auto parameter_type = _.FindDef(parameter_type_id); 280 if (!parameter_type || argument_type->id() != parameter_type->id()) { 281 if (!_.options()->before_hlsl_legalization || 282 !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) { 283 return _.diag(SPV_ERROR_INVALID_ID, inst) 284 << "OpFunctionCall Argument <id> " << _.getIdName(argument_id) 285 << "s type does not match Function <id> " 286 << _.getIdName(parameter_type_id) << "s parameter type."; 287 } 288 } 289 290 if (_.addressing_model() == spv::AddressingModel::Logical) { 291 if (parameter_type->opcode() == spv::Op::OpTypePointer && 292 !_.options()->relax_logical_pointer) { 293 spv::StorageClass sc = 294 parameter_type->GetOperandAs<spv::StorageClass>(1u); 295 // Validate which storage classes can be pointer operands. 296 switch (sc) { 297 case spv::StorageClass::UniformConstant: 298 case spv::StorageClass::Function: 299 case spv::StorageClass::Private: 300 case spv::StorageClass::Workgroup: 301 case spv::StorageClass::AtomicCounter: 302 // These are always allowed. 303 break; 304 case spv::StorageClass::StorageBuffer: 305 if (!_.features().variable_pointers) { 306 return _.diag(SPV_ERROR_INVALID_ID, inst) 307 << "StorageBuffer pointer operand " 308 << _.getIdName(argument_id) 309 << " requires a variable pointers capability"; 310 } 311 break; 312 default: 313 return _.diag(SPV_ERROR_INVALID_ID, inst) 314 << "Invalid storage class for pointer operand " 315 << _.getIdName(argument_id); 316 } 317 318 // Validate memory object declaration requirements. 319 if (argument->opcode() != spv::Op::OpVariable && 320 argument->opcode() != spv::Op::OpFunctionParameter) { 321 const bool ssbo_vptr = _.features().variable_pointers && 322 sc == spv::StorageClass::StorageBuffer; 323 const bool wg_vptr = 324 _.HasCapability(spv::Capability::VariablePointers) && 325 sc == spv::StorageClass::Workgroup; 326 const bool uc_ptr = sc == spv::StorageClass::UniformConstant; 327 if (!ssbo_vptr && !wg_vptr && !uc_ptr) { 328 return _.diag(SPV_ERROR_INVALID_ID, inst) 329 << "Pointer operand " << _.getIdName(argument_id) 330 << " must be a memory object declaration"; 331 } 332 } 333 } 334 } 335 } 336 return SPV_SUCCESS; 337} 338 339} // namespace 340 341spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { 342 switch (inst->opcode()) { 343 case spv::Op::OpFunction: 344 if (auto error = ValidateFunction(_, inst)) return error; 345 break; 346 case spv::Op::OpFunctionParameter: 347 if (auto error = ValidateFunctionParameter(_, inst)) return error; 348 break; 349 case spv::Op::OpFunctionCall: 350 if (auto error = ValidateFunctionCall(_, inst)) return error; 351 break; 352 default: 353 break; 354 } 355 356 return SPV_SUCCESS; 357} 358 359} // namespace val 360} // namespace spvtools 361