1/* 2 * Copyright © 2021 Google 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 */ 23 24#include "radv_acceleration_structure.h" 25#include "radv_debug.h" 26#include "radv_meta.h" 27#include "radv_private.h" 28#include "radv_rt_common.h" 29#include "radv_shader.h" 30 31#include "nir/nir.h" 32#include "nir/nir_builder.h" 33#include "nir/nir_builtin_builder.h" 34 35static VkRayTracingPipelineCreateInfoKHR 36radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo) 37{ 38 VkRayTracingPipelineCreateInfoKHR local_create_info = *pCreateInfo; 39 uint32_t total_stages = pCreateInfo->stageCount; 40 uint32_t total_groups = pCreateInfo->groupCount; 41 42 if (pCreateInfo->pLibraryInfo) { 43 for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) { 44 RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]); 45 struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline); 46 47 total_stages += library_pipeline->stage_count; 48 total_groups += library_pipeline->group_count; 49 } 50 } 51 VkPipelineShaderStageCreateInfo *stages = NULL; 52 VkRayTracingShaderGroupCreateInfoKHR *groups = NULL; 53 local_create_info.stageCount = total_stages; 54 local_create_info.groupCount = total_groups; 55 local_create_info.pStages = stages = 56 malloc(sizeof(VkPipelineShaderStageCreateInfo) * total_stages); 57 local_create_info.pGroups = groups = 58 malloc(sizeof(VkRayTracingShaderGroupCreateInfoKHR) * total_groups); 59 if (!local_create_info.pStages || !local_create_info.pGroups) 60 return local_create_info; 61 62 total_stages = pCreateInfo->stageCount; 63 total_groups = pCreateInfo->groupCount; 64 for (unsigned j = 0; j < pCreateInfo->stageCount; ++j) 65 stages[j] = pCreateInfo->pStages[j]; 66 for (unsigned j = 0; j < pCreateInfo->groupCount; ++j) 67 groups[j] = pCreateInfo->pGroups[j]; 68 69 if (pCreateInfo->pLibraryInfo) { 70 for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) { 71 RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]); 72 struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline); 73 74 for (unsigned j = 0; j < library_pipeline->stage_count; ++j) 75 stages[total_stages + j] = library_pipeline->stages[j]; 76 for (unsigned j = 0; j < library_pipeline->group_count; ++j) { 77 VkRayTracingShaderGroupCreateInfoKHR *dst = &groups[total_groups + j]; 78 *dst = 
library_pipeline->groups[j]; 79 if (dst->generalShader != VK_SHADER_UNUSED_KHR) 80 dst->generalShader += total_stages; 81 if (dst->closestHitShader != VK_SHADER_UNUSED_KHR) 82 dst->closestHitShader += total_stages; 83 if (dst->anyHitShader != VK_SHADER_UNUSED_KHR) 84 dst->anyHitShader += total_stages; 85 if (dst->intersectionShader != VK_SHADER_UNUSED_KHR) 86 dst->intersectionShader += total_stages; 87 } 88 total_stages += library_pipeline->stage_count; 89 total_groups += library_pipeline->group_count; 90 } 91 } 92 return local_create_info; 93} 94 95static VkResult 96radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache, 97 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, 98 const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline) 99{ 100 RADV_FROM_HANDLE(radv_device, device, _device); 101 struct radv_library_pipeline *pipeline; 102 103 pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8, 104 VK_SYSTEM_ALLOCATION_SCOPE_OBJECT); 105 if (pipeline == NULL) 106 return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY); 107 108 radv_pipeline_init(device, &pipeline->base, RADV_PIPELINE_LIBRARY); 109 110 VkRayTracingPipelineCreateInfoKHR local_create_info = 111 radv_create_merged_rt_create_info(pCreateInfo); 112 if (!local_create_info.pStages || !local_create_info.pGroups) 113 goto fail; 114 115 if (local_create_info.stageCount) { 116 pipeline->stage_count = local_create_info.stageCount; 117 118 size_t size = sizeof(VkPipelineShaderStageCreateInfo) * local_create_info.stageCount; 119 pipeline->stages = malloc(size); 120 if (!pipeline->stages) 121 goto fail; 122 123 memcpy(pipeline->stages, local_create_info.pStages, size); 124 125 pipeline->hashes = malloc(sizeof(*pipeline->hashes) * local_create_info.stageCount); 126 if (!pipeline->hashes) 127 goto fail; 128 129 pipeline->identifiers = malloc(sizeof(*pipeline->identifiers) * local_create_info.stageCount); 130 if (!pipeline->identifiers) 131 goto fail; 132 133 for (uint32_t i = 0; i < local_create_info.stageCount; i++) { 134 RADV_FROM_HANDLE(vk_shader_module, module, pipeline->stages[i].module); 135 136 const VkPipelineShaderStageModuleIdentifierCreateInfoEXT *iinfo = 137 vk_find_struct_const(local_create_info.pStages[i].pNext, 138 PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT); 139 140 if (module) { 141 struct vk_shader_module *new_module = vk_shader_module_clone(NULL, module); 142 pipeline->stages[i].module = vk_shader_module_to_handle(new_module); 143 pipeline->stages[i].pNext = NULL; 144 } else { 145 assert(iinfo); 146 pipeline->identifiers[i].identifierSize = 147 MIN2(iinfo->identifierSize, sizeof(pipeline->hashes[i].sha1)); 148 memcpy(pipeline->hashes[i].sha1, iinfo->pIdentifier, 149 pipeline->identifiers[i].identifierSize); 150 pipeline->stages[i].module = VK_NULL_HANDLE; 151 pipeline->stages[i].pNext = &pipeline->identifiers[i]; 152 pipeline->identifiers[i].sType = 153 VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT; 154 pipeline->identifiers[i].pNext = NULL; 155 pipeline->identifiers[i].pIdentifier = pipeline->hashes[i].sha1; 156 } 157 } 158 } 159 160 if (local_create_info.groupCount) { 161 size_t size = sizeof(VkRayTracingShaderGroupCreateInfoKHR) * local_create_info.groupCount; 162 pipeline->group_count = local_create_info.groupCount; 163 pipeline->groups = malloc(size); 164 if (!pipeline->groups) 165 goto fail; 166 memcpy(pipeline->groups, local_create_info.pGroups, size); 167 } 168 169 *pPipeline = radv_pipeline_to_handle(&pipeline->base); 170 171 
free((void *)local_create_info.pGroups); 172 free((void *)local_create_info.pStages); 173 return VK_SUCCESS; 174fail: 175 free(pipeline->groups); 176 free(pipeline->stages); 177 free(pipeline->hashes); 178 free(pipeline->identifiers); 179 free((void *)local_create_info.pGroups); 180 free((void *)local_create_info.pStages); 181 return VK_ERROR_OUT_OF_HOST_MEMORY; 182} 183 184/* 185 * Global variables for an RT pipeline 186 */ 187struct rt_variables { 188 const VkRayTracingPipelineCreateInfoKHR *create_info; 189 190 /* idx of the next shader to run in the next iteration of the main loop. 191 * During traversal, idx is used to store the SBT index and will contain 192 * the correct resume index upon returning. 193 */ 194 nir_variable *idx; 195 196 /* scratch offset of the argument area relative to stack_ptr */ 197 nir_variable *arg; 198 199 nir_variable *stack_ptr; 200 201 /* global address of the SBT entry used for the shader */ 202 nir_variable *shader_record_ptr; 203 204 /* trace_ray arguments */ 205 nir_variable *accel_struct; 206 nir_variable *flags; 207 nir_variable *cull_mask; 208 nir_variable *sbt_offset; 209 nir_variable *sbt_stride; 210 nir_variable *miss_index; 211 nir_variable *origin; 212 nir_variable *tmin; 213 nir_variable *direction; 214 nir_variable *tmax; 215 216 /* from the BTAS instance currently being visited */ 217 nir_variable *custom_instance_and_mask; 218 219 /* Properties of the primitive currently being visited. */ 220 nir_variable *primitive_id; 221 nir_variable *geometry_id_and_flags; 222 nir_variable *instance_id; 223 nir_variable *instance_addr; 224 nir_variable *hit_kind; 225 nir_variable *opaque; 226 227 /* Safeguard to ensure we don't end up in an infinite loop of non-existing case. Should not be 228 * needed but is extra anti-hang safety during bring-up. */ 229 nir_variable *main_loop_case_visited; 230 231 /* Output variables for intersection & anyhit shaders. */ 232 nir_variable *ahit_accept; 233 nir_variable *ahit_terminate; 234 235 /* Array of stack size struct for recording the max stack size for each group. 
*/ 236 struct radv_pipeline_shader_stack_size *stack_sizes; 237 unsigned stage_idx; 238}; 239 240static void 241reserve_stack_size(struct rt_variables *vars, uint32_t size) 242{ 243 for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) { 244 const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx; 245 246 if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader) 247 vars->stack_sizes[group_idx].recursive_size = 248 MAX2(vars->stack_sizes[group_idx].recursive_size, size); 249 250 if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader) 251 vars->stack_sizes[group_idx].non_recursive_size = 252 MAX2(vars->stack_sizes[group_idx].non_recursive_size, size); 253 } 254} 255 256static struct rt_variables 257create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info, 258 struct radv_pipeline_shader_stack_size *stack_sizes) 259{ 260 struct rt_variables vars = { 261 .create_info = create_info, 262 }; 263 vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx"); 264 vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg"); 265 vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr"); 266 vars.shader_record_ptr = 267 nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr"); 268 269 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); 270 vars.accel_struct = 271 nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct"); 272 vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags"); 273 vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask"); 274 vars.sbt_offset = 275 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset"); 276 vars.sbt_stride = 277 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride"); 278 vars.miss_index = 279 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index"); 280 vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin"); 281 vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin"); 282 vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction"); 283 vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax"); 284 285 vars.custom_instance_and_mask = nir_variable_create( 286 shader, nir_var_shader_temp, glsl_uint_type(), "custom_instance_and_mask"); 287 vars.primitive_id = 288 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id"); 289 vars.geometry_id_and_flags = 290 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags"); 291 vars.instance_id = 292 nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "instance_id"); 293 vars.instance_addr = 294 nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr"); 295 vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind"); 296 vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque"); 297 298 vars.main_loop_case_visited = 299 nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "main_loop_case_visited"); 300 vars.ahit_accept = 
301 nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept"); 302 vars.ahit_terminate = 303 nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate"); 304 305 vars.stack_sizes = stack_sizes; 306 return vars; 307} 308 309/* 310 * Remap all the variables between the two rt_variables struct for inlining. 311 */ 312static void 313map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, 314 const struct rt_variables *dst) 315{ 316 src->create_info = dst->create_info; 317 318 _mesa_hash_table_insert(var_remap, src->idx, dst->idx); 319 _mesa_hash_table_insert(var_remap, src->arg, dst->arg); 320 _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr); 321 _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr); 322 323 _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct); 324 _mesa_hash_table_insert(var_remap, src->flags, dst->flags); 325 _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask); 326 _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset); 327 _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride); 328 _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index); 329 _mesa_hash_table_insert(var_remap, src->origin, dst->origin); 330 _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin); 331 _mesa_hash_table_insert(var_remap, src->direction, dst->direction); 332 _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax); 333 334 _mesa_hash_table_insert(var_remap, src->custom_instance_and_mask, dst->custom_instance_and_mask); 335 _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id); 336 _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags); 337 _mesa_hash_table_insert(var_remap, src->instance_id, dst->instance_id); 338 _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr); 339 _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind); 340 _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque); 341 _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept); 342 _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate); 343 344 src->stack_sizes = dst->stack_sizes; 345 src->stage_idx = dst->stage_idx; 346} 347 348/* 349 * Create a copy of the global rt variables where the primitive/instance related variables are 350 * independent.This is needed as we need to keep the old values of the global variables around 351 * in case e.g. an anyhit shader reject the collision. So there are inner variables that get copied 352 * to the outer variables once we commit to a better hit. 
353 */ 354static struct rt_variables 355create_inner_vars(nir_builder *b, const struct rt_variables *vars) 356{ 357 struct rt_variables inner_vars = *vars; 358 inner_vars.idx = 359 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx"); 360 inner_vars.shader_record_ptr = nir_variable_create( 361 b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr"); 362 inner_vars.primitive_id = 363 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id"); 364 inner_vars.geometry_id_and_flags = nir_variable_create( 365 b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags"); 366 inner_vars.tmax = 367 nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax"); 368 inner_vars.instance_id = 369 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_instance_id"); 370 inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp, 371 glsl_uint64_t_type(), "inner_instance_addr"); 372 inner_vars.hit_kind = 373 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind"); 374 inner_vars.custom_instance_and_mask = nir_variable_create( 375 b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_custom_instance_and_mask"); 376 377 return inner_vars; 378} 379 380/* The hit attributes are stored on the stack. This is the offset compared to the current stack 381 * pointer of where the hit attrib is stored. */ 382const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE); 383 384static void 385insert_rt_return(nir_builder *b, const struct rt_variables *vars) 386{ 387 nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1); 388 nir_store_var(b, vars->idx, 389 nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1); 390} 391 392enum sbt_type { 393 SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress), 394 SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress), 395 SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress), 396 SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress), 397}; 398 399static nir_ssa_def * 400get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding) 401{ 402 nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b); 403 404 nir_ssa_def *desc = 405 nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding))); 406 407 nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 
8 : 16)); 408 nir_ssa_def *stride = 409 nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset)); 410 411 return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride)); 412} 413 414static void 415load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx, 416 enum sbt_type binding, unsigned offset) 417{ 418 nir_ssa_def *addr = get_sbt_ptr(b, idx, binding); 419 420 nir_ssa_def *load_addr = nir_iadd_imm(b, addr, offset); 421 nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr); 422 423 nir_store_var(b, vars->idx, v_idx, 1); 424 425 nir_ssa_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE); 426 nir_store_var(b, vars->shader_record_ptr, record_addr, 1); 427} 428 429/* This lowers all the RT instructions that we do not want to pass on to the combined shader and 430 * that we can implement using the variables from the shader we are going to inline into. */ 431static void 432lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base) 433{ 434 nir_builder b_shader; 435 nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader)); 436 437 nir_foreach_block (block, nir_shader_get_entrypoint(shader)) { 438 nir_foreach_instr_safe (instr, block) { 439 switch (instr->type) { 440 case nir_instr_type_intrinsic: { 441 b_shader.cursor = nir_before_instr(instr); 442 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); 443 nir_ssa_def *ret = NULL; 444 445 switch (intr->intrinsic) { 446 case nir_intrinsic_rt_execute_callable: { 447 uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; 448 uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1; 449 450 nir_store_var( 451 &b_shader, vars->stack_ptr, 452 nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1); 453 nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx), 454 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16); 455 456 nir_store_var(&b_shader, vars->stack_ptr, 457 nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16), 458 1); 459 load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0); 460 461 nir_store_var(&b_shader, vars->arg, 462 nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1); 463 464 reserve_stack_size(vars, size + 16); 465 break; 466 } 467 case nir_intrinsic_rt_trace_ray: { 468 uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; 469 uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1; 470 471 nir_store_var( 472 &b_shader, vars->stack_ptr, 473 nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1); 474 nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx), 475 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16); 476 477 nir_store_var(&b_shader, vars->stack_ptr, 478 nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16), 479 1); 480 481 nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1); 482 nir_store_var(&b_shader, vars->arg, 483 nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1); 484 485 reserve_stack_size(vars, size + 16); 486 487 /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. 
*/ 488 nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1); 489 nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1); 490 nir_store_var(&b_shader, vars->cull_mask, 491 nir_iand_imm(&b_shader, intr->src[2].ssa, 0xff), 0x1); 492 nir_store_var(&b_shader, vars->sbt_offset, 493 nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1); 494 nir_store_var(&b_shader, vars->sbt_stride, 495 nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1); 496 nir_store_var(&b_shader, vars->miss_index, 497 nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1); 498 nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7); 499 nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1); 500 nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7); 501 nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1); 502 break; 503 } 504 case nir_intrinsic_rt_resume: { 505 uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE; 506 507 nir_store_var( 508 &b_shader, vars->stack_ptr, 509 nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1); 510 break; 511 } 512 case nir_intrinsic_rt_return_amd: { 513 if (shader->info.stage == MESA_SHADER_RAYGEN) { 514 nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1); 515 break; 516 } 517 insert_rt_return(&b_shader, vars); 518 break; 519 } 520 case nir_intrinsic_load_scratch: { 521 nir_instr_rewrite_src_ssa( 522 instr, &intr->src[0], 523 nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa)); 524 continue; 525 } 526 case nir_intrinsic_store_scratch: { 527 nir_instr_rewrite_src_ssa( 528 instr, &intr->src[1], 529 nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa)); 530 continue; 531 } 532 case nir_intrinsic_load_rt_arg_scratch_offset_amd: { 533 ret = nir_load_var(&b_shader, vars->arg); 534 break; 535 } 536 case nir_intrinsic_load_shader_record_ptr: { 537 ret = nir_load_var(&b_shader, vars->shader_record_ptr); 538 break; 539 } 540 case nir_intrinsic_load_ray_launch_id: { 541 ret = nir_load_global_invocation_id(&b_shader, 32); 542 break; 543 } 544 case nir_intrinsic_load_ray_launch_size: { 545 nir_ssa_def *launch_size_addr = 546 nir_load_ray_launch_size_addr_amd(&b_shader); 547 548 nir_ssa_def * xy = nir_build_load_smem_amd( 549 &b_shader, 2, launch_size_addr, nir_imm_int(&b_shader, 0)); 550 nir_ssa_def * z = nir_build_load_smem_amd( 551 &b_shader, 1, launch_size_addr, nir_imm_int(&b_shader, 8)); 552 553 nir_ssa_def *xyz[3] = { 554 nir_channel(&b_shader, xy, 0), 555 nir_channel(&b_shader, xy, 1), 556 z, 557 }; 558 ret = nir_vec(&b_shader, xyz, 3); 559 break; 560 } 561 case nir_intrinsic_load_ray_t_min: { 562 ret = nir_load_var(&b_shader, vars->tmin); 563 break; 564 } 565 case nir_intrinsic_load_ray_t_max: { 566 ret = nir_load_var(&b_shader, vars->tmax); 567 break; 568 } 569 case nir_intrinsic_load_ray_world_origin: { 570 ret = nir_load_var(&b_shader, vars->origin); 571 break; 572 } 573 case nir_intrinsic_load_ray_world_direction: { 574 ret = nir_load_var(&b_shader, vars->direction); 575 break; 576 } 577 case nir_intrinsic_load_ray_instance_custom_index: { 578 ret = nir_load_var(&b_shader, vars->custom_instance_and_mask); 579 ret = nir_iand_imm(&b_shader, ret, 0xFFFFFF); 580 break; 581 } 582 case nir_intrinsic_load_primitive_id: { 583 ret = nir_load_var(&b_shader, vars->primitive_id); 584 break; 585 } 586 case nir_intrinsic_load_ray_geometry_index: { 587 ret = nir_load_var(&b_shader, vars->geometry_id_and_flags); 588 ret = 
nir_iand_imm(&b_shader, ret, 0xFFFFFFF); 589 break; 590 } 591 case nir_intrinsic_load_instance_id: { 592 ret = nir_load_var(&b_shader, vars->instance_id); 593 break; 594 } 595 case nir_intrinsic_load_ray_flags: { 596 ret = nir_load_var(&b_shader, vars->flags); 597 break; 598 } 599 case nir_intrinsic_load_ray_hit_kind: { 600 ret = nir_load_var(&b_shader, vars->hit_kind); 601 break; 602 } 603 case nir_intrinsic_load_ray_world_to_object: { 604 unsigned c = nir_intrinsic_column(intr); 605 nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); 606 nir_ssa_def *wto_matrix[3]; 607 nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); 608 609 nir_ssa_def *vals[3]; 610 for (unsigned i = 0; i < 3; ++i) 611 vals[i] = nir_channel(&b_shader, wto_matrix[i], c); 612 613 ret = nir_vec(&b_shader, vals, 3); 614 if (c == 3) 615 ret = nir_fneg(&b_shader, 616 nir_build_vec3_mat_mult(&b_shader, ret, wto_matrix, false)); 617 break; 618 } 619 case nir_intrinsic_load_ray_object_to_world: { 620 unsigned c = nir_intrinsic_column(intr); 621 nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); 622 if (c == 3) { 623 nir_ssa_def *wto_matrix[3]; 624 nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); 625 626 nir_ssa_def *vals[3]; 627 for (unsigned i = 0; i < 3; ++i) 628 vals[i] = nir_channel(&b_shader, wto_matrix[i], c); 629 630 ret = nir_vec(&b_shader, vals, 3); 631 } else { 632 ret = nir_build_load_global( 633 &b_shader, 3, 32, nir_iadd_imm(&b_shader, instance_node_addr, 92 + c * 12)); 634 } 635 break; 636 } 637 case nir_intrinsic_load_ray_object_origin: { 638 nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); 639 nir_ssa_def *wto_matrix[] = { 640 nir_build_load_global(&b_shader, 4, 32, 641 nir_iadd_imm(&b_shader, instance_node_addr, 16), 642 .align_mul = 64, .align_offset = 16), 643 nir_build_load_global(&b_shader, 4, 32, 644 nir_iadd_imm(&b_shader, instance_node_addr, 32), 645 .align_mul = 64, .align_offset = 32), 646 nir_build_load_global(&b_shader, 4, 32, 647 nir_iadd_imm(&b_shader, instance_node_addr, 48), 648 .align_mul = 64, .align_offset = 48)}; 649 ret = nir_build_vec3_mat_mult_pre( 650 &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix); 651 break; 652 } 653 case nir_intrinsic_load_ray_object_direction: { 654 nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr); 655 nir_ssa_def *wto_matrix[3]; 656 nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix); 657 ret = nir_build_vec3_mat_mult( 658 &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false); 659 break; 660 } 661 case nir_intrinsic_load_intersection_opaque_amd: { 662 ret = nir_load_var(&b_shader, vars->opaque); 663 break; 664 } 665 case nir_intrinsic_load_cull_mask: { 666 ret = nir_load_var(&b_shader, vars->cull_mask); 667 break; 668 } 669 case nir_intrinsic_ignore_ray_intersection: { 670 nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1); 671 672 /* The if is a workaround to avoid having to fix up control flow manually */ 673 nir_push_if(&b_shader, nir_imm_true(&b_shader)); 674 nir_jump(&b_shader, nir_jump_return); 675 nir_pop_if(&b_shader, NULL); 676 break; 677 } 678 case nir_intrinsic_terminate_ray: { 679 nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1); 680 nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1); 681 682 /* The if is a workaround to avoid having to fix up control flow manually */ 683 
nir_push_if(&b_shader, nir_imm_true(&b_shader)); 684 nir_jump(&b_shader, nir_jump_return); 685 nir_pop_if(&b_shader, NULL); 686 break; 687 } 688 case nir_intrinsic_report_ray_intersection: { 689 nir_push_if( 690 &b_shader, 691 nir_iand( 692 &b_shader, 693 nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa), 694 nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin)))); 695 { 696 nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1); 697 nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1); 698 nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1); 699 } 700 nir_pop_if(&b_shader, NULL); 701 break; 702 } 703 default: 704 continue; 705 } 706 707 if (ret) 708 nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret); 709 nir_instr_remove(instr); 710 break; 711 } 712 case nir_instr_type_jump: { 713 nir_jump_instr *jump = nir_instr_as_jump(instr); 714 if (jump->type == nir_jump_halt) { 715 b_shader.cursor = nir_instr_remove(instr); 716 nir_jump(&b_shader, nir_jump_return); 717 } 718 break; 719 } 720 default: 721 break; 722 } 723 } 724 } 725 726 nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none); 727} 728 729static void 730insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx, 731 uint32_t call_idx_base, uint32_t call_idx) 732{ 733 struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL); 734 735 nir_opt_dead_cf(shader); 736 737 struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes); 738 map_rt_variables(var_remap, &src_vars, vars); 739 740 NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base); 741 742 NIR_PASS(_, shader, nir_opt_remove_phis); 743 NIR_PASS(_, shader, nir_lower_returns); 744 NIR_PASS(_, shader, nir_opt_dce); 745 746 reserve_stack_size(vars, shader->scratch_size); 747 748 nir_push_if(b, nir_ieq_imm(b, idx, call_idx)); 749 nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1); 750 nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap); 751 nir_pop_if(b, NULL); 752 753 /* Adopt the instructions from the source shader, since they are merely moved, not cloned. 
*/ 754 ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader)); 755 756 ralloc_free(var_remap); 757} 758 759static bool 760lower_rt_derefs(nir_shader *shader) 761{ 762 nir_function_impl *impl = nir_shader_get_entrypoint(shader); 763 764 bool progress = false; 765 766 nir_builder b; 767 nir_builder_init(&b, impl); 768 769 b.cursor = nir_before_cf_list(&impl->body); 770 nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b); 771 772 nir_foreach_block (block, impl) { 773 nir_foreach_instr_safe (instr, block) { 774 if (instr->type != nir_instr_type_deref) 775 continue; 776 777 nir_deref_instr *deref = nir_instr_as_deref(instr); 778 b.cursor = nir_before_instr(&deref->instr); 779 780 nir_deref_instr *replacement = NULL; 781 if (nir_deref_mode_is(deref, nir_var_shader_call_data)) { 782 deref->modes = nir_var_function_temp; 783 progress = true; 784 785 if (deref->deref_type == nir_deref_type_var) 786 replacement = 787 nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0); 788 } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) { 789 deref->modes = nir_var_function_temp; 790 progress = true; 791 792 if (deref->deref_type == nir_deref_type_var) 793 replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET), 794 nir_var_function_temp, deref->type, 0); 795 } 796 797 if (replacement != NULL) { 798 nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa); 799 nir_instr_remove(&deref->instr); 800 } 801 } 802 } 803 804 if (progress) 805 nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance); 806 else 807 nir_metadata_preserve(impl, nir_metadata_all); 808 809 return progress; 810} 811 812static nir_shader * 813parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo) 814{ 815 struct radv_pipeline_key key; 816 memset(&key, 0, sizeof(key)); 817 818 struct radv_pipeline_stage rt_stage; 819 820 radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage)); 821 822 nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key); 823 824 if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT || 825 shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) { 826 nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader)); 827 nir_builder b_inner; 828 nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader)); 829 b_inner.cursor = nir_after_block(last_block); 830 nir_rt_return_amd(&b_inner); 831 } 832 833 NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, 834 nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib, 835 glsl_get_natural_size_align_bytes); 836 837 NIR_PASS(_, shader, lower_rt_derefs); 838 839 NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, 840 nir_address_format_32bit_offset); 841 842 return shader; 843} 844 845static nir_function_impl * 846lower_any_hit_for_intersection(nir_shader *any_hit) 847{ 848 nir_function_impl *impl = nir_shader_get_entrypoint(any_hit); 849 850 /* Any-hit shaders need three parameters */ 851 assert(impl->function->num_params == 0); 852 nir_parameter params[] = { 853 { 854 /* A pointer to a boolean value for whether or not the hit was 855 * accepted. 
856 */ 857 .num_components = 1, 858 .bit_size = 32, 859 }, 860 { 861 /* The hit T value */ 862 .num_components = 1, 863 .bit_size = 32, 864 }, 865 { 866 /* The hit kind */ 867 .num_components = 1, 868 .bit_size = 32, 869 }, 870 }; 871 impl->function->num_params = ARRAY_SIZE(params); 872 impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params)); 873 memcpy(impl->function->params, params, sizeof(params)); 874 875 nir_builder build; 876 nir_builder_init(&build, impl); 877 nir_builder *b = &build; 878 879 b->cursor = nir_before_cf_list(&impl->body); 880 881 nir_ssa_def *commit_ptr = nir_load_param(b, 0); 882 nir_ssa_def *hit_t = nir_load_param(b, 1); 883 nir_ssa_def *hit_kind = nir_load_param(b, 2); 884 885 nir_deref_instr *commit = 886 nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0); 887 888 nir_foreach_block_safe (block, impl) { 889 nir_foreach_instr_safe (instr, block) { 890 switch (instr->type) { 891 case nir_instr_type_intrinsic: { 892 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 893 switch (intrin->intrinsic) { 894 case nir_intrinsic_ignore_ray_intersection: 895 b->cursor = nir_instr_remove(&intrin->instr); 896 /* We put the newly emitted code inside a dummy if because it's 897 * going to contain a jump instruction and we don't want to 898 * deal with that mess here. It'll get dealt with by our 899 * control-flow optimization passes. 900 */ 901 nir_store_deref(b, commit, nir_imm_false(b), 0x1); 902 nir_push_if(b, nir_imm_true(b)); 903 nir_jump(b, nir_jump_return); 904 nir_pop_if(b, NULL); 905 break; 906 907 case nir_intrinsic_terminate_ray: 908 /* The "normal" handling of terminateRay works fine in 909 * intersection shaders. 910 */ 911 break; 912 913 case nir_intrinsic_load_ray_t_max: 914 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t); 915 nir_instr_remove(&intrin->instr); 916 break; 917 918 case nir_intrinsic_load_ray_hit_kind: 919 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind); 920 nir_instr_remove(&intrin->instr); 921 break; 922 923 default: 924 break; 925 } 926 break; 927 } 928 case nir_instr_type_jump: { 929 nir_jump_instr *jump = nir_instr_as_jump(instr); 930 if (jump->type == nir_jump_halt) { 931 b->cursor = nir_instr_remove(instr); 932 nir_jump(b, nir_jump_return); 933 } 934 break; 935 } 936 937 default: 938 break; 939 } 940 } 941 } 942 943 nir_validate_shader(any_hit, "after initial any-hit lowering"); 944 945 nir_lower_returns_impl(impl); 946 947 nir_validate_shader(any_hit, "after lowering returns"); 948 949 return impl; 950} 951 952/* Inline the any_hit shader into the intersection shader so we don't have 953 * to implement yet another shader call interface here. Neither do any recursion. 
954 */ 955static void 956nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit) 957{ 958 void *dead_ctx = ralloc_context(intersection); 959 960 nir_function_impl *any_hit_impl = NULL; 961 struct hash_table *any_hit_var_remap = NULL; 962 if (any_hit) { 963 any_hit = nir_shader_clone(dead_ctx, any_hit); 964 NIR_PASS(_, any_hit, nir_opt_dce); 965 any_hit_impl = lower_any_hit_for_intersection(any_hit); 966 any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx); 967 } 968 969 nir_function_impl *impl = nir_shader_get_entrypoint(intersection); 970 971 nir_builder build; 972 nir_builder_init(&build, impl); 973 nir_builder *b = &build; 974 975 b->cursor = nir_before_cf_list(&impl->body); 976 977 nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit"); 978 nir_store_var(b, commit, nir_imm_false(b), 0x1); 979 980 nir_foreach_block_safe (block, impl) { 981 nir_foreach_instr_safe (instr, block) { 982 if (instr->type != nir_instr_type_intrinsic) 983 continue; 984 985 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); 986 if (intrin->intrinsic != nir_intrinsic_report_ray_intersection) 987 continue; 988 989 b->cursor = nir_instr_remove(&intrin->instr); 990 nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1); 991 nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1); 992 nir_ssa_def *min_t = nir_load_ray_t_min(b); 993 nir_ssa_def *max_t = nir_load_ray_t_max(b); 994 995 /* bool commit_tmp = false; */ 996 nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp"); 997 nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1); 998 999 nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t))); 1000 { 1001 /* Any-hit defaults to commit */ 1002 nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1); 1003 1004 if (any_hit_impl != NULL) { 1005 nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b))); 1006 { 1007 nir_ssa_def *params[] = { 1008 &nir_build_deref_var(b, commit_tmp)->dest.ssa, 1009 hit_t, 1010 hit_kind, 1011 }; 1012 nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap); 1013 } 1014 nir_pop_if(b, NULL); 1015 } 1016 1017 nir_push_if(b, nir_load_var(b, commit_tmp)); 1018 { 1019 nir_report_ray_intersection(b, 1, hit_t, hit_kind); 1020 } 1021 nir_pop_if(b, NULL); 1022 } 1023 nir_pop_if(b, NULL); 1024 1025 nir_ssa_def *accepted = nir_load_var(b, commit_tmp); 1026 nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted); 1027 } 1028 } 1029 1030 /* We did some inlining; have to re-index SSA defs */ 1031 nir_index_ssa_defs(impl); 1032 1033 /* Eliminate the casts introduced for the commit return of the any-hit shader. */ 1034 NIR_PASS(_, intersection, nir_opt_deref); 1035 1036 ralloc_free(dead_ctx); 1037} 1038 1039/* Variables only used internally to ray traversal. This is data that describes 1040 * the current state of the traversal vs. what we'd give to a shader. e.g. what 1041 * is the instance we're currently visiting vs. what is the instance of the 1042 * closest hit. 
*/ 1043struct rt_traversal_vars { 1044 nir_variable *origin; 1045 nir_variable *dir; 1046 nir_variable *inv_dir; 1047 nir_variable *sbt_offset_and_flags; 1048 nir_variable *instance_id; 1049 nir_variable *custom_instance_and_mask; 1050 nir_variable *instance_addr; 1051 nir_variable *hit; 1052 nir_variable *bvh_base; 1053 nir_variable *stack; 1054 nir_variable *top_stack; 1055}; 1056 1057static struct rt_traversal_vars 1058init_traversal_vars(nir_builder *b) 1059{ 1060 const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); 1061 struct rt_traversal_vars ret; 1062 1063 ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin"); 1064 ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir"); 1065 ret.inv_dir = 1066 nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir"); 1067 ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), 1068 "traversal_sbt_offset_and_flags"); 1069 ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), 1070 "traversal_instance_id"); 1071 ret.custom_instance_and_mask = nir_variable_create( 1072 b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask"); 1073 ret.instance_addr = 1074 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr"); 1075 ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit"); 1076 ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), 1077 "traversal_bvh_base"); 1078 ret.stack = 1079 nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr"); 1080 ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), 1081 "traversal_top_stack_ptr"); 1082 return ret; 1083} 1084 1085static void 1086visit_any_hit_shaders(struct radv_device *device, 1087 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, 1088 struct rt_variables *vars) 1089{ 1090 nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx); 1091 1092 nir_push_if(b, nir_ine_imm(b, sbt_idx, 0)); 1093 for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { 1094 const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i]; 1095 uint32_t shader_id = VK_SHADER_UNUSED_KHR; 1096 1097 switch (group_info->type) { 1098 case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR: 1099 shader_id = group_info->anyHitShader; 1100 break; 1101 default: 1102 break; 1103 } 1104 if (shader_id == VK_SHADER_UNUSED_KHR) 1105 continue; 1106 1107 const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; 1108 nir_shader *nir_stage = parse_rt_stage(device, stage); 1109 1110 vars->stage_idx = shader_id; 1111 insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2); 1112 } 1113 nir_pop_if(b, NULL); 1114} 1115 1116static void 1117insert_traversal_triangle_case(struct radv_device *device, 1118 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, 1119 nir_ssa_def *result, const struct rt_variables *vars, 1120 const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node) 1121{ 1122 nir_ssa_def *dist = nir_channel(b, result, 0); 1123 nir_ssa_def *div = nir_channel(b, result, 1); 1124 dist = nir_fdiv(b, dist, div); 1125 nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div); 1126 nir_ssa_def *switch_ccw = 1127 nir_test_mask(b, nir_load_var(b, 
trav_vars->sbt_offset_and_flags), 1128 VK_GEOMETRY_INSTANCE_TRIANGLE_FLIP_FACING_BIT_KHR << 24); 1129 frontface = nir_ixor(b, frontface, switch_ccw); 1130 1131 nir_ssa_def *not_cull = 1132 nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipTrianglesKHRMask)); 1133 nir_ssa_def *not_facing_cull = nir_ieq_imm( 1134 b, 1135 nir_iand(b, nir_load_var(b, vars->flags), 1136 nir_bcsel(b, frontface, nir_imm_int(b, SpvRayFlagsCullFrontFacingTrianglesKHRMask), 1137 nir_imm_int(b, SpvRayFlagsCullBackFacingTrianglesKHRMask))), 1138 0); 1139 1140 not_cull = nir_iand( 1141 b, not_cull, 1142 nir_ior(b, not_facing_cull, 1143 nir_test_mask(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 1144 VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR << 24))); 1145 1146 nir_push_if(b, nir_iand(b, 1147 nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)), 1148 nir_flt(b, nir_load_var(b, vars->tmin), dist)), 1149 not_cull)); 1150 { 1151 1152 nir_ssa_def *triangle_info = 1153 nir_build_load_global(b, 2, 32, 1154 nir_iadd_imm(b, build_node_to_addr(device, b, bvh_node), 1155 offsetof(struct radv_bvh_triangle_node, triangle_id))); 1156 nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0); 1157 nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1); 1158 nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff); 1159 nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 1160 nir_load_var(b, vars->flags), geometry_id_and_flags); 1161 1162 not_cull = 1163 nir_ieq_imm(b, 1164 nir_iand(b, nir_load_var(b, vars->flags), 1165 nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask), 1166 nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))), 1167 0); 1168 nir_push_if(b, not_cull); 1169 { 1170 nir_ssa_def *sbt_idx = nir_iadd( 1171 b, 1172 nir_iadd(b, nir_load_var(b, vars->sbt_offset), 1173 nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)), 1174 nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id)); 1175 nir_ssa_def *divs[2] = {div, div}; 1176 nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2)); 1177 nir_ssa_def *hit_kind = 1178 nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF)); 1179 1180 nir_store_scratch( 1181 b, ij, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET), 1182 .align_mul = 16); 1183 1184 nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1); 1185 nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1); 1186 1187 nir_push_if(b, nir_inot(b, is_opaque)); 1188 { 1189 struct rt_variables inner_vars = create_inner_vars(b, vars); 1190 1191 nir_store_var(b, inner_vars.primitive_id, primitive_id, 1); 1192 nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1); 1193 nir_store_var(b, inner_vars.tmax, dist, 0x1); 1194 nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1); 1195 nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 1196 0x1); 1197 nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1); 1198 nir_store_var(b, inner_vars.custom_instance_and_mask, 1199 nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); 1200 1201 load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4); 1202 1203 visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars); 1204 1205 nir_push_if(b, nir_inot(b, nir_load_var(b, vars->ahit_accept))); 1206 { 1207 nir_jump(b, nir_jump_continue); 1208 } 1209 nir_pop_if(b, NULL); 1210 } 
1211 nir_pop_if(b, NULL); 1212 1213 nir_store_var(b, vars->primitive_id, primitive_id, 1); 1214 nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1); 1215 nir_store_var(b, vars->tmax, dist, 0x1); 1216 nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1); 1217 nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1); 1218 nir_store_var(b, vars->hit_kind, hit_kind, 0x1); 1219 nir_store_var(b, vars->custom_instance_and_mask, 1220 nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); 1221 1222 nir_store_var(b, vars->idx, sbt_idx, 1); 1223 nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1); 1224 1225 nir_ssa_def *terminate_on_first_hit = 1226 nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask); 1227 nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate); 1228 nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated)); 1229 { 1230 nir_jump(b, nir_jump_break); 1231 } 1232 nir_pop_if(b, NULL); 1233 } 1234 nir_pop_if(b, NULL); 1235 } 1236 nir_pop_if(b, NULL); 1237} 1238 1239static void 1240insert_traversal_aabb_case(struct radv_device *device, 1241 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, 1242 const struct rt_variables *vars, 1243 const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node) 1244{ 1245 nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node); 1246 nir_ssa_def *triangle_info = nir_build_load_global(b, 2, 32, nir_iadd_imm(b, node_addr, 24)); 1247 nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0); 1248 nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1); 1249 nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff); 1250 nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 1251 nir_load_var(b, vars->flags), geometry_id_and_flags); 1252 1253 nir_ssa_def *not_skip_aabb = 1254 nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipAABBsKHRMask)); 1255 nir_ssa_def *not_cull = nir_iand( 1256 b, not_skip_aabb, 1257 nir_ieq_imm(b, 1258 nir_iand(b, nir_load_var(b, vars->flags), 1259 nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask), 1260 nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))), 1261 0)); 1262 nir_push_if(b, not_cull); 1263 { 1264 nir_ssa_def *sbt_idx = nir_iadd( 1265 b, 1266 nir_iadd(b, nir_load_var(b, vars->sbt_offset), 1267 nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)), 1268 nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id)); 1269 1270 struct rt_variables inner_vars = create_inner_vars(b, vars); 1271 1272 /* For AABBs the intersection shader writes the hit kind, and only does it if it is the 1273 * next closest hit candidate. 
*/ 1274 inner_vars.hit_kind = vars->hit_kind; 1275 1276 nir_store_var(b, inner_vars.primitive_id, primitive_id, 1); 1277 nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1); 1278 nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1); 1279 nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1); 1280 nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1); 1281 nir_store_var(b, inner_vars.custom_instance_and_mask, 1282 nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); 1283 nir_store_var(b, inner_vars.opaque, is_opaque, 1); 1284 1285 load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4); 1286 1287 nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1); 1288 nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1); 1289 1290 nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0)); 1291 for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) { 1292 const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i]; 1293 uint32_t shader_id = VK_SHADER_UNUSED_KHR; 1294 uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR; 1295 1296 switch (group_info->type) { 1297 case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR: 1298 shader_id = group_info->intersectionShader; 1299 any_hit_shader_id = group_info->anyHitShader; 1300 break; 1301 default: 1302 break; 1303 } 1304 if (shader_id == VK_SHADER_UNUSED_KHR) 1305 continue; 1306 1307 const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id]; 1308 nir_shader *nir_stage = parse_rt_stage(device, stage); 1309 1310 nir_shader *any_hit_stage = NULL; 1311 if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) { 1312 stage = &pCreateInfo->pStages[any_hit_shader_id]; 1313 any_hit_stage = parse_rt_stage(device, stage); 1314 1315 nir_lower_intersection_shader(nir_stage, any_hit_stage); 1316 ralloc_free(any_hit_stage); 1317 } 1318 1319 inner_vars.stage_idx = shader_id; 1320 insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2); 1321 } 1322 nir_push_else(b, NULL); 1323 { 1324 nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7); 1325 nir_ssa_def *vec3_inf = 1326 nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7); 1327 1328 nir_ssa_def *bvh_lo = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 0)); 1329 nir_ssa_def *bvh_hi = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 12)); 1330 1331 bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin)); 1332 bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin)); 1333 nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)), 1334 nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir))); 1335 nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)), 1336 nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir))); 1337 /* If we run parallel to one of the edges the range should be [0, inf) not [0,0] */ 1338 t2_vec = 1339 nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec); 1340 1341 nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1)); 1342 t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2)); 1343 1344 nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1)); 1345 t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2)); 1346 1347 nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, 
vars->tmax), t_min), 1348 nir_fge(b, t_max, nir_load_var(b, vars->tmin)))); 1349 { 1350 nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1); 1351 nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1); 1352 } 1353 nir_pop_if(b, NULL); 1354 } 1355 nir_pop_if(b, NULL); 1356 1357 nir_push_if(b, nir_load_var(b, vars->ahit_accept)); 1358 { 1359 nir_store_var(b, vars->primitive_id, primitive_id, 1); 1360 nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1); 1361 nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1); 1362 nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1); 1363 nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1); 1364 nir_store_var(b, vars->custom_instance_and_mask, 1365 nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1); 1366 1367 nir_store_var(b, vars->idx, sbt_idx, 1); 1368 nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1); 1369 1370 nir_ssa_def *terminate_on_first_hit = 1371 nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask); 1372 nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate); 1373 nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated)); 1374 { 1375 nir_jump(b, nir_jump_break); 1376 } 1377 nir_pop_if(b, NULL); 1378 } 1379 nir_pop_if(b, NULL); 1380 } 1381 nir_pop_if(b, NULL); 1382} 1383 1384static nir_shader * 1385build_traversal_shader(struct radv_device *device, 1386 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, 1387 const struct rt_variables *dst_vars, 1388 struct hash_table *var_remap) 1389{ 1390 nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal"); 1391 b.shader->info.internal = false; 1392 b.shader->info.workgroup_size[0] = 8; 1393 b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 
8 : 4; 1394 struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, dst_vars->stack_sizes); 1395 map_rt_variables(var_remap, &vars, dst_vars); 1396 1397 unsigned lanes = device->physical_device->rt_wave_size; 1398 unsigned elements = lanes * MAX_STACK_ENTRY_COUNT; 1399 nir_variable *stack_var = nir_variable_create(b.shader, nir_var_mem_shared, 1400 glsl_array_type(glsl_uint_type(), elements, 0), 1401 "trav_stack"); 1402 nir_deref_instr *stack_deref = nir_build_deref_var(&b, stack_var); 1403 nir_deref_instr *stack; 1404 nir_ssa_def *stack_idx_stride = nir_imm_int(&b, lanes); 1405 nir_ssa_def *stack_idx_base = nir_load_local_invocation_index(&b); 1406 1407 nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct); 1408 1409 struct rt_traversal_vars trav_vars = init_traversal_vars(&b); 1410 1411 nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1); 1412 1413 nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0)); 1414 { 1415 nir_store_var(&b, trav_vars.bvh_base, build_addr_to_node(&b, accel_struct), 1); 1416 1417 nir_ssa_def *bvh_root = nir_build_load_global( 1418 &b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64); 1419 1420 nir_ssa_def *desc = create_bvh_descriptor(&b); 1421 nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7); 1422 1423 nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7); 1424 nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7); 1425 nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7); 1426 nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1); 1427 nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1); 1428 1429 nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_idx_base, stack_idx_stride), 1); 1430 stack = nir_build_deref_array(&b, stack_deref, stack_idx_base); 1431 nir_store_deref(&b, stack, bvh_root, 0x1); 1432 1433 nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1); 1434 1435 nir_push_loop(&b); 1436 1437 nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_idx_base)); 1438 nir_jump(&b, nir_jump_break); 1439 nir_pop_if(&b, NULL); 1440 1441 nir_push_if( 1442 &b, nir_uge(&b, nir_load_var(&b, trav_vars.top_stack), nir_load_var(&b, trav_vars.stack))); 1443 nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1); 1444 nir_store_var(&b, trav_vars.bvh_base, 1445 build_addr_to_node(&b, nir_load_var(&b, vars.accel_struct)), 1); 1446 nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7); 1447 nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7); 1448 nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7); 1449 nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1); 1450 1451 nir_pop_if(&b, NULL); 1452 1453 nir_store_var(&b, trav_vars.stack, 1454 nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1); 1455 1456 stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack)); 1457 nir_ssa_def *bvh_node = nir_load_deref(&b, stack); 1458 nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7); 1459 1460 bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64)); 1461 nir_ssa_def *intrinsic_result = NULL; 1462 if (!radv_emulate_rt(device->physical_device)) { 1463 intrinsic_result = nir_bvh64_intersect_ray_amd( 1464 &b, 32, desc, nir_unpack_64_2x32(&b, bvh_node), nir_load_var(&b, vars.tmax), 1465 nir_load_var(&b, 
      nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 4), 0));
      {
         nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 2), 0));
         {
            /* custom */
            nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 1), 0));
            if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)) {
               insert_traversal_aabb_case(device, pCreateInfo, &b, &vars, &trav_vars, bvh_node);
            }
            nir_push_else(&b, NULL);
            {
               /* instance */
               nir_ssa_def *instance_node_addr = build_node_to_addr(device, &b, bvh_node);
               nir_ssa_def *instance_data =
                  nir_build_load_global(&b, 4, 32, instance_node_addr, .align_mul = 64);
               nir_ssa_def *wto_matrix[] = {
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 16),
                                        .align_mul = 64, .align_offset = 16),
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 32),
                                        .align_mul = 64, .align_offset = 32),
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 48),
                                        .align_mul = 64, .align_offset = 48)};
               nir_ssa_def *instance_id =
                  nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, instance_node_addr, 88));
               nir_ssa_def *instance_and_mask = nir_channel(&b, instance_data, 2);
               nir_ssa_def *instance_mask = nir_ushr_imm(&b, instance_and_mask, 24);

               nir_push_if(
                  &b,
                  nir_ieq_imm(&b, nir_iand(&b, instance_mask, nir_load_var(&b, vars.cull_mask)), 0));
               nir_jump(&b, nir_jump_continue);
               nir_pop_if(&b, NULL);

               nir_store_var(&b, trav_vars.top_stack, nir_load_var(&b, trav_vars.stack), 1);
               nir_store_var(&b, trav_vars.bvh_base,
                             build_addr_to_node(
                                &b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
                             1);
               stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
               nir_store_deref(&b, stack,
                               nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63), 0x1);

               nir_store_var(&b, trav_vars.stack,
                             nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);

               nir_store_var(
                  &b, trav_vars.origin,
                  nir_build_vec3_mat_mult_pre(&b, nir_load_var(&b, vars.origin), wto_matrix), 7);
               nir_store_var(
                  &b, trav_vars.dir,
                  nir_build_vec3_mat_mult(&b, nir_load_var(&b, vars.direction), wto_matrix, false),
                  7);
               nir_store_var(&b, trav_vars.inv_dir,
                             nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
               nir_store_var(&b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
               nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_channel(&b, instance_data, 3),
                             1);
               nir_store_var(&b, trav_vars.instance_id, instance_id, 1);
               nir_store_var(&b, trav_vars.instance_addr, instance_node_addr, 1);
            }
            nir_pop_if(&b, NULL);
         }
         nir_push_else(&b, NULL);
         {
            /* box */
            nir_ssa_def *result = intrinsic_result;
            if (!result) {
               /* If we didn't run the intrinsic because the hardware didn't support it,
                * emulate the ray/box intersection here. */
               result = intersect_ray_amd_software_box(device, &b, bvh_node,
                                                       nir_load_var(&b, vars.tmax),
                                                       nir_load_var(&b, trav_vars.origin),
                                                       nir_load_var(&b, trav_vars.dir),
                                                       nir_load_var(&b, trav_vars.inv_dir));
            }

            for (unsigned i = 4; i-- > 0; ) {
               nir_ssa_def *new_node = nir_channel(&b, result, i);
               nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
               {
                  stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
                  nir_store_deref(&b, stack, new_node, 0x1);
                  nir_store_var(
                     &b, trav_vars.stack,
                     nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
               }
               nir_pop_if(&b, NULL);
            }
         }
         nir_pop_if(&b, NULL);
      }
      nir_push_else(&b, NULL);
      if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)) {
         nir_ssa_def *result = intrinsic_result;
         if (!result) {
            /* If we didn't run the intrinsic because the hardware didn't support it,
             * emulate the ray/triangle intersection here. */
            result = intersect_ray_amd_software_tri(device, &b, bvh_node,
                                                    nir_load_var(&b, vars.tmax),
                                                    nir_load_var(&b, trav_vars.origin),
                                                    nir_load_var(&b, trav_vars.dir),
                                                    nir_load_var(&b, trav_vars.inv_dir));
         }
         insert_traversal_triangle_case(device, pCreateInfo, &b, result, &vars, &trav_vars,
                                        bvh_node);
      }
      nir_pop_if(&b, NULL);

      nir_pop_loop(&b, NULL);
   }
   nir_pop_if(&b, NULL);

   /* Initialize the follow-up shader. */
   nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
   {
      /* vars.idx contains the SBT index at this point. */
      load_sbt_entry(&b, &vars, nir_load_var(&b, vars.idx), SBT_HIT, 0);

      nir_ssa_def *should_return = nir_ior(&b,
         nir_test_mask(&b, nir_load_var(&b, vars.flags),
                       SpvRayFlagsSkipClosestHitShaderKHRMask),
         nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));

      /* should_return is set if we had a hit but we won't be calling the closest-hit shader
       * and hence need to return immediately to the calling shader. */
      nir_push_if(&b, should_return);
      {
         insert_rt_return(&b, &vars);
      }
      nir_pop_if(&b, NULL);
   }
   nir_push_else(&b, NULL);
   {
      /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
       * for miss shaders if none of the rays miss. */
      load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
   }
   nir_pop_if(&b, NULL);

   return b.shader;
}

static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 nir_builder *b, const struct rt_variables *vars)
{
   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
   nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);

   /* For now, just inline the traversal shader. */
   nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
   nir_pop_if(b, NULL);

   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));

   ralloc_free(var_remap);
}
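
/* Note: this computes the conservative default scratch layout used when the application does
 * not make the stack size dynamic: one ray-gen frame, plus the first recursion level (which
 * also has to hold the non-recursive any-hit/intersection scratch), plus the remaining
 * recursion levels of closest-hit/miss frames, plus two callable frames. */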
static unsigned
compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                      const struct radv_pipeline_shader_stack_size *stack_sizes)
{
   unsigned raygen_size = 0;
   unsigned callable_size = 0;
   unsigned chit_size = 0;
   unsigned miss_size = 0;
   unsigned non_recursive_size = 0;

   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
      non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);

      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
      unsigned size = stack_sizes[i].recursive_size;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         shader_id = group_info->generalShader;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         shader_id = group_info->closestHitShader;
         break;
      default:
         break;
      }
      if (shader_id == VK_SHADER_UNUSED_KHR)
         continue;

      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
      switch (stage->stage) {
      case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
         raygen_size = MAX2(raygen_size, size);
         break;
      case VK_SHADER_STAGE_MISS_BIT_KHR:
         miss_size = MAX2(miss_size, size);
         break;
      case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
         chit_size = MAX2(chit_size, size);
         break;
      case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
         callable_size = MAX2(callable_size, size);
         break;
      default:
         unreachable("Invalid stage type in RT shader");
      }
   }
   return raygen_size +
          MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
             MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
          MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
             MAX2(chit_size, miss_size) +
          2 * callable_size;
}

bool
radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
{
   if (!pCreateInfo->pDynamicState)
      return false;

   for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) {
      if (pCreateInfo->pDynamicState->pDynamicStates[i] ==
          VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR)
         return true;
   }

   return false;
}

static bool
should_move_rt_instruction(nir_intrinsic_op intrinsic)
{
   switch (intrinsic) {
   case nir_intrinsic_load_rt_arg_scratch_offset_amd:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      return true;
   default:
      return false;
   }
}

static void
move_rt_instructions(nir_shader *shader)
{
   nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);

   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

         if (!should_move_rt_instruction(intrinsic->intrinsic))
            continue;

         nir_instr_move(target, instr);
      }
   }

   nir_metadata_preserve(nir_shader_get_entrypoint(shader),
                         nir_metadata_all & (~nir_metadata_instr_index));
}
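
/* Note: this builds the single compute shader that implements the whole ray-tracing pipeline:
 * a dispatch loop reads vars.idx (the SBT-derived shader index), and every ray-gen,
 * closest-hit, miss and callable stage, each of their call-resume shaders, and the inlined
 * traversal shader become one case of that loop. */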
static nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 struct radv_pipeline_shader_stack_size *stack_sizes)
{
   struct radv_pipeline_key key;
   memset(&key, 0, sizeof(key));

   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
   b.shader->info.internal = false;
   b.shader->info.workgroup_size[0] = 8;
   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;

   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
   load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
   nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);

   nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);

   nir_loop *loop = nir_push_loop(&b);

   nir_push_if(&b, nir_ior(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0),
                           nir_inot(&b, nir_load_var(&b, vars.main_loop_case_visited))));
   nir_jump(&b, nir_jump_break);
   nir_pop_if(&b, NULL);

   nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);

   insert_traversal(device, pCreateInfo, &b, &vars);

   nir_ssa_def *idx = nir_load_var(&b, vars.idx);

   /* We do a trick with the indexing of the resume shaders so that the first shader of stage x
    * always gets id x and the resume shader ids then come after stageCount. This makes the
    * shader group handles independent of compilation. */
   unsigned call_idx_base = pCreateInfo->stageCount + 1;
   for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
      gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
      if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
          type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
         continue;

      nir_shader *nir_stage = parse_rt_stage(device, stage);

      /* Move the ray-tracing system values that are set by rt_trace_ray to the top, to prevent
       * them from being overwritten by other rt_trace_ray calls. */
      NIR_PASS_V(nir_stage, move_rt_instructions);

      uint32_t num_resume_shaders = 0;
      nir_shader **resume_shaders = NULL;
      nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
                             &num_resume_shaders, nir_stage);

      vars.stage_idx = i;
      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
      for (unsigned j = 0; j < num_resume_shaders; ++j) {
         insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
      }
      call_idx_base += num_resume_shaders;
   }

   nir_pop_loop(&b, loop);

   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
      /* Put something here so that scratch gets enabled in the shader. */
      b.shader->scratch_size = 16;
   } else
      b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes);

   /* Deal with all the inline functions. */
   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);

   return b.shader;
}
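
/* Note: creation entry point for ray-tracing pipelines. Library pipelines are only recorded
 * (see radv_rt_pipeline_library_create); everything else is lowered to a compute pipeline
 * built from the combined shader above. */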
static VkResult
radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   VkResult result;
   struct radv_pipeline *pipeline = NULL;
   struct radv_compute_pipeline *compute_pipeline = NULL;
   struct radv_pipeline_shader_stack_size *stack_sizes = NULL;
   uint8_t hash[20];
   nir_shader *shader = NULL;
   bool keep_statistic_info =
      (pCreateInfo->flags & VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR) ||
      (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADER_STATS) || device->keep_shader_info;

   if (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR)
      return radv_rt_pipeline_library_create(_device, _cache, pCreateInfo, pAllocator, pPipeline);

   VkRayTracingPipelineCreateInfoKHR local_create_info =
      radv_create_merged_rt_create_info(pCreateInfo);
   if (!local_create_info.pStages || !local_create_info.pGroups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   radv_hash_rt_shaders(hash, &local_create_info, radv_get_hash_flags(device, keep_statistic_info));
   struct vk_shader_module module = {.base.type = VK_OBJECT_TYPE_SHADER_MODULE};

   VkPipelineShaderStageRequiredSubgroupSizeCreateInfo subgroup_size = {
      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
      .pNext = NULL,
      .requiredSubgroupSize = device->physical_device->rt_wave_size,
   };

   VkComputePipelineCreateInfo compute_info = {
      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
      .pNext = NULL,
      .flags = pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT,
      .stage =
         {
            .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
            .pNext = &subgroup_size,
            .stage = VK_SHADER_STAGE_COMPUTE_BIT,
            .module = vk_shader_module_to_handle(&module),
            .pName = "main",
         },
      .layout = pCreateInfo->layout,
   };
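
   /* Note: FAIL_ON_PIPELINE_COMPILE_REQUIRED is forced on above so that the first call below
    * acts as a pure cache lookup; if it reports VK_PIPELINE_COMPILE_REQUIRED, we generate the
    * NIR and retry with the application's original flags. */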
   /* First check if we can get things from the cache before we take the expensive step of
    * generating the NIR. */
   result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                         stack_sizes, local_create_info.groupCount, pPipeline);

   if (result == VK_PIPELINE_COMPILE_REQUIRED) {
      if (pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
         goto fail;

      stack_sizes = calloc(sizeof(*stack_sizes), local_create_info.groupCount);
      if (!stack_sizes) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         goto fail;
      }

      shader = create_rt_shader(device, &local_create_info, stack_sizes);
      module.nir = shader;
      compute_info.flags = pCreateInfo->flags;
      result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                            stack_sizes, local_create_info.groupCount, pPipeline);
      stack_sizes = NULL;

      if (result != VK_SUCCESS)
         goto shader_fail;
   }
   pipeline = radv_pipeline_from_handle(*pPipeline);
   compute_pipeline = radv_pipeline_to_compute(pipeline);

   compute_pipeline->rt_group_handles =
      calloc(sizeof(*compute_pipeline->rt_group_handles), local_create_info.groupCount);
   if (!compute_pipeline->rt_group_handles) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto shader_fail;
   }

   compute_pipeline->dynamic_stack_size = radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo);

   /* For General and ClosestHit shaders, the shader ID can be used directly as the handle.
    * AnyHit and Intersection shaders are inlined per group (and the AnyHit shaders may differ
    * between groups), so those entries use the group ID instead. */
   for (unsigned i = 0; i < local_create_info.groupCount; ++i) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &local_create_info.pGroups[i];
      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[0] = group_info->generalShader + 2;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
         FALLTHROUGH;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[0] = group_info->closestHitShader + 2;
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
         break;
      case VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR:
         unreachable("VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR");
      }
   }

shader_fail:
   if (result != VK_SUCCESS && pipeline)
      radv_pipeline_destroy(device, pipeline, pAllocator);
   ralloc_free(shader);
fail:
   free((void *)local_create_info.pGroups);
   free((void *)local_create_info.pStages);
   free(stack_sizes);
   return result;
}
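
/* Note: pipelines are created in order; on failure the failing handle (and, below, every
 * remaining handle) is set to VK_NULL_HANDLE, and creation stops early only when the failing
 * pipeline requested EARLY_RETURN_ON_FAILURE. */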
VKAPI_ATTR VkResult VKAPI_CALL
radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
                                  VkPipelineCache pipelineCache, uint32_t count,
                                  const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
                                  const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   unsigned i = 0;
   for (; i < count; i++) {
      VkResult r;
      r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator,
                                  &pPipelines[i]);
      if (r != VK_SUCCESS) {
         result = r;
         pPipelines[i] = VK_NULL_HANDLE;

         if (pCreateInfos[i].flags & VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT)
            break;
      }
   }

   for (; i < count; ++i)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup,
                                        uint32_t groupCount, size_t dataSize, void *pData)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
   char *data = pData;

   STATIC_ASSERT(sizeof(*compute_pipeline->rt_group_handles) <= RADV_RT_HANDLE_SIZE);

   memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);

   for (uint32_t i = 0; i < groupCount; ++i) {
      memcpy(data + i * RADV_RT_HANDLE_SIZE, &compute_pipeline->rt_group_handles[firstGroup + i],
             sizeof(*compute_pipeline->rt_group_handles));
   }

   return VK_SUCCESS;
}

VKAPI_ATTR VkDeviceSize VKAPI_CALL
radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group,
                                          VkShaderGroupShaderKHR groupShader)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
   const struct radv_pipeline_shader_stack_size *stack_size =
      &compute_pipeline->rt_stack_sizes[group];

   if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR ||
       groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR)
      return stack_size->non_recursive_size;
   else
      return stack_size->recursive_size;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_GetRayTracingCaptureReplayShaderGroupHandlesKHR(VkDevice _device, VkPipeline pipeline,
                                                     uint32_t firstGroup, uint32_t groupCount,
                                                     size_t dataSize, void *pData)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   unreachable("Unimplemented");
   return vk_error(device, VK_ERROR_FEATURE_NOT_PRESENT);
}