/*
 * Copyright © 2021 Google
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "radv_acceleration_structure.h"
#include "radv_debug.h"
#include "radv_meta.h"
#include "radv_private.h"
#include "radv_rt_common.h"
#include "radv_shader.h"

#include "nir/nir.h"
#include "nir/nir_builder.h"
#include "nir/nir_builtin_builder.h"

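/* Merge the stages and groups of any pipeline libraries referenced through pLibraryInfo into a
 * single create info. Shader indices in library groups are rebased past the stages already
 * merged; the returned pStages/pGroups arrays are malloc'ed and owned by the caller.
 */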
static VkRayTracingPipelineCreateInfoKHR
radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
{
   VkRayTracingPipelineCreateInfoKHR local_create_info = *pCreateInfo;
   uint32_t total_stages = pCreateInfo->stageCount;
   uint32_t total_groups = pCreateInfo->groupCount;

   if (pCreateInfo->pLibraryInfo) {
      for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
         RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
         struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline);

         total_stages += library_pipeline->stage_count;
         total_groups += library_pipeline->group_count;
      }
   }
   VkPipelineShaderStageCreateInfo *stages = NULL;
   VkRayTracingShaderGroupCreateInfoKHR *groups = NULL;
   local_create_info.stageCount = total_stages;
   local_create_info.groupCount = total_groups;
   local_create_info.pStages = stages =
      malloc(sizeof(VkPipelineShaderStageCreateInfo) * total_stages);
   local_create_info.pGroups = groups =
      malloc(sizeof(VkRayTracingShaderGroupCreateInfoKHR) * total_groups);
   if (!local_create_info.pStages || !local_create_info.pGroups)
      return local_create_info;

   total_stages = pCreateInfo->stageCount;
   total_groups = pCreateInfo->groupCount;
   for (unsigned j = 0; j < pCreateInfo->stageCount; ++j)
      stages[j] = pCreateInfo->pStages[j];
   for (unsigned j = 0; j < pCreateInfo->groupCount; ++j)
      groups[j] = pCreateInfo->pGroups[j];

   if (pCreateInfo->pLibraryInfo) {
      for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
         RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
         struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline);

         for (unsigned j = 0; j < library_pipeline->stage_count; ++j)
            stages[total_stages + j] = library_pipeline->stages[j];
         for (unsigned j = 0; j < library_pipeline->group_count; ++j) {
            VkRayTracingShaderGroupCreateInfoKHR *dst = &groups[total_groups + j];
            *dst = library_pipeline->groups[j];
            if (dst->generalShader != VK_SHADER_UNUSED_KHR)
               dst->generalShader += total_stages;
            if (dst->closestHitShader != VK_SHADER_UNUSED_KHR)
               dst->closestHitShader += total_stages;
            if (dst->anyHitShader != VK_SHADER_UNUSED_KHR)
               dst->anyHitShader += total_stages;
            if (dst->intersectionShader != VK_SHADER_UNUSED_KHR)
               dst->intersectionShader += total_stages;
         }
         total_stages += library_pipeline->stage_count;
         total_groups += library_pipeline->group_count;
      }
   }
   return local_create_info;
}

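/* Create a ray tracing pipeline library. Nothing is compiled here: we only keep copies of the
 * merged stages (or their module identifiers when no module is provided) and groups, so they can
 * be merged again when a full pipeline is created from this library.
 */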
static VkResult
radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
                                const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                                const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   struct radv_library_pipeline *pipeline;

   pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8,
                         VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
   if (pipeline == NULL)
      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);

   radv_pipeline_init(device, &pipeline->base, RADV_PIPELINE_LIBRARY);

   VkRayTracingPipelineCreateInfoKHR local_create_info =
      radv_create_merged_rt_create_info(pCreateInfo);
   if (!local_create_info.pStages || !local_create_info.pGroups)
      goto fail;

   if (local_create_info.stageCount) {
      pipeline->stage_count = local_create_info.stageCount;

      size_t size = sizeof(VkPipelineShaderStageCreateInfo) * local_create_info.stageCount;
      pipeline->stages = malloc(size);
      if (!pipeline->stages)
         goto fail;

      memcpy(pipeline->stages, local_create_info.pStages, size);

      pipeline->hashes = malloc(sizeof(*pipeline->hashes) * local_create_info.stageCount);
      if (!pipeline->hashes)
         goto fail;

      pipeline->identifiers = malloc(sizeof(*pipeline->identifiers) * local_create_info.stageCount);
      if (!pipeline->identifiers)
         goto fail;

      for (uint32_t i = 0; i < local_create_info.stageCount; i++) {
         RADV_FROM_HANDLE(vk_shader_module, module, pipeline->stages[i].module);

         const VkPipelineShaderStageModuleIdentifierCreateInfoEXT *iinfo =
            vk_find_struct_const(local_create_info.pStages[i].pNext,
                                 PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT);

         if (module) {
            struct vk_shader_module *new_module = vk_shader_module_clone(NULL, module);
            pipeline->stages[i].module = vk_shader_module_to_handle(new_module);
            pipeline->stages[i].pNext = NULL;
         } else {
            assert(iinfo);
            pipeline->identifiers[i].identifierSize =
               MIN2(iinfo->identifierSize, sizeof(pipeline->hashes[i].sha1));
            memcpy(pipeline->hashes[i].sha1, iinfo->pIdentifier,
                   pipeline->identifiers[i].identifierSize);
            pipeline->stages[i].module = VK_NULL_HANDLE;
            pipeline->stages[i].pNext = &pipeline->identifiers[i];
            pipeline->identifiers[i].sType =
               VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT;
            pipeline->identifiers[i].pNext = NULL;
            pipeline->identifiers[i].pIdentifier = pipeline->hashes[i].sha1;
         }
      }
   }

   if (local_create_info.groupCount) {
      size_t size = sizeof(VkRayTracingShaderGroupCreateInfoKHR) * local_create_info.groupCount;
      pipeline->group_count = local_create_info.groupCount;
      pipeline->groups = malloc(size);
      if (!pipeline->groups)
         goto fail;
      memcpy(pipeline->groups, local_create_info.pGroups, size);
   }

   *pPipeline = radv_pipeline_to_handle(&pipeline->base);

   free((void *)local_create_info.pGroups);
   free((void *)local_create_info.pStages);
   return VK_SUCCESS;
fail:
   free(pipeline->groups);
   free(pipeline->stages);
   free(pipeline->hashes);
   free(pipeline->identifiers);
   free((void *)local_create_info.pGroups);
   free((void *)local_create_info.pStages);
   return VK_ERROR_OUT_OF_HOST_MEMORY;
}

/*
 * Global variables for an RT pipeline
 */
struct rt_variables {
   const VkRayTracingPipelineCreateInfoKHR *create_info;

   /* idx of the next shader to run in the next iteration of the main loop.
    * During traversal, idx is used to store the SBT index and will contain
    * the correct resume index upon returning.
    */
   nir_variable *idx;

   /* scratch offset of the argument area relative to stack_ptr */
   nir_variable *arg;

   nir_variable *stack_ptr;

   /* global address of the SBT entry used for the shader */
   nir_variable *shader_record_ptr;

   /* trace_ray arguments */
   nir_variable *accel_struct;
   nir_variable *flags;
   nir_variable *cull_mask;
   nir_variable *sbt_offset;
   nir_variable *sbt_stride;
   nir_variable *miss_index;
   nir_variable *origin;
   nir_variable *tmin;
   nir_variable *direction;
   nir_variable *tmax;

   /* from the BLAS instance currently being visited */
   nir_variable *custom_instance_and_mask;

   /* Properties of the primitive currently being visited. */
   nir_variable *primitive_id;
   nir_variable *geometry_id_and_flags;
   nir_variable *instance_id;
   nir_variable *instance_addr;
   nir_variable *hit_kind;
   nir_variable *opaque;

   /* Safeguard to ensure we don't end up in an infinite loop when a non-existing case is
    * selected. Should not be needed, but is extra anti-hang safety during bring-up. */
   nir_variable *main_loop_case_visited;

   /* Output variables for intersection & anyhit shaders. */
   nir_variable *ahit_accept;
   nir_variable *ahit_terminate;

   /* Array of stack size structs recording the max stack size for each group. */
   struct radv_pipeline_shader_stack_size *stack_sizes;
   unsigned stage_idx;
};

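/* Record that the shader currently being processed (vars->stage_idx) needs at least "size" bytes
 * of stack in every group that references it: general and closest-hit shaders count towards the
 * recursive stack size, any-hit and intersection shaders towards the non-recursive one.
 */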
static void
reserve_stack_size(struct rt_variables *vars, uint32_t size)
{
   for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
      const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;

      if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader)
         vars->stack_sizes[group_idx].recursive_size =
            MAX2(vars->stack_sizes[group_idx].recursive_size, size);

      if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader)
         vars->stack_sizes[group_idx].non_recursive_size =
            MAX2(vars->stack_sizes[group_idx].non_recursive_size, size);
   }
}

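/* Create the global RT variables described above as shader temporaries of the given shader. */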
static struct rt_variables
create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
                    struct radv_pipeline_shader_stack_size *stack_sizes)
{
   struct rt_variables vars = {
      .create_info = create_info,
   };
   vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
   vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
   vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
   vars.shader_record_ptr =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");

   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   vars.accel_struct =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
   vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags");
   vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
   vars.sbt_offset =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
   vars.sbt_stride =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
   vars.miss_index =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
   vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
   vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
   vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
   vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");

   vars.custom_instance_and_mask = nir_variable_create(
      shader, nir_var_shader_temp, glsl_uint_type(), "custom_instance_and_mask");
   vars.primitive_id =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
   vars.geometry_id_and_flags =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
   vars.instance_id =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "instance_id");
   vars.instance_addr =
      nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
   vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");

   vars.main_loop_case_visited =
      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "main_loop_case_visited");
   vars.ahit_accept =
      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
   vars.ahit_terminate =
      nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");

   vars.stack_sizes = stack_sizes;
   return vars;
}

/*
 * Remap all the variables between the two rt_variables structs for inlining.
 */
static void
map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
                 const struct rt_variables *dst)
{
   src->create_info = dst->create_info;

   _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
   _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
   _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
   _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);

   _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
   _mesa_hash_table_insert(var_remap, src->flags, dst->flags);
   _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
   _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
   _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
   _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
   _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
   _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
   _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
   _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);

   _mesa_hash_table_insert(var_remap, src->custom_instance_and_mask, dst->custom_instance_and_mask);
   _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
   _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
   _mesa_hash_table_insert(var_remap, src->instance_id, dst->instance_id);
   _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
   _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
   _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
   _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
   _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);

   src->stack_sizes = dst->stack_sizes;
   src->stage_idx = dst->stage_idx;
}

/*
 * Create a copy of the global rt variables where the primitive/instance related variables are
 * independent. This is needed as we need to keep the old values of the global variables around
 * in case e.g. an anyhit shader rejects the collision. So there are inner variables that get
 * copied to the outer variables once we commit to a better hit.
 */
static struct rt_variables
create_inner_vars(nir_builder *b, const struct rt_variables *vars)
{
   struct rt_variables inner_vars = *vars;
   inner_vars.idx =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
   inner_vars.shader_record_ptr = nir_variable_create(
      b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
   inner_vars.primitive_id =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
   inner_vars.geometry_id_and_flags = nir_variable_create(
      b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
   inner_vars.tmax =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
   inner_vars.instance_id =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_instance_id");
   inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp,
                                                  glsl_uint64_t_type(), "inner_instance_addr");
   inner_vars.hit_kind =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
   inner_vars.custom_instance_and_mask = nir_variable_create(
      b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_custom_instance_and_mask");

   return inner_vars;
}

/* The hit attributes are stored on the stack. This is the offset, relative to the current stack
 * pointer, at which the hit attributes are stored. */
const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);

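/* Pop the caller's resume index off the stack and make it the next shader to run. */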
static void
insert_rt_return(nir_builder *b, const struct rt_variables *vars)
{
   nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
   nir_store_var(b, vars->idx,
                 nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
}

enum sbt_type {
   SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
   SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
   SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
   SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
};

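/* Compute the GPU address of SBT record "idx" in the given table. The table base address and
 * stride are read from the VkTraceRaysIndirectCommand2KHR-layout data located at the SBT
 * descriptor base address.
 */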
static nir_ssa_def *
get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
{
   nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b);

   nir_ssa_def *desc =
      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));

   nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
   nir_ssa_def *stride =
      nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset));

   return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
}

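/* Load the shader index stored at "offset" inside SBT record "idx" into vars->idx and point
 * vars->shader_record_ptr at the application data that follows the shader handle.
 */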
static void
load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
               enum sbt_type binding, unsigned offset)
{
   nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);

   nir_ssa_def *load_addr = nir_iadd_imm(b, addr, offset);
   nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr);

   nir_store_var(b, vars->idx, v_idx, 1);

   nir_ssa_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE);
   nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
}

/* This lowers all the RT instructions that we do not want to pass on to the combined shader and
 * that we can implement using the variables from the shader we are going to inline into. */
static void
lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
{
   nir_builder b_shader;
   nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));

   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
      nir_foreach_instr_safe (instr, block) {
         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            b_shader.cursor = nir_before_instr(instr);
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
            nir_ssa_def *ret = NULL;

            switch (intr->intrinsic) {
            case nir_intrinsic_rt_execute_callable: {
               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;

               nir_store_var(
                  &b_shader, vars->stack_ptr,
                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);

               nir_store_var(&b_shader, vars->stack_ptr,
                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
                             1);
               load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);

               nir_store_var(&b_shader, vars->arg,
                             nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);

               reserve_stack_size(vars, size + 16);
               break;
            }
            case nir_intrinsic_rt_trace_ray: {
               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
               uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;

               nir_store_var(
                  &b_shader, vars->stack_ptr,
                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
               nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
                                 nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);

               nir_store_var(&b_shader, vars->stack_ptr,
                             nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
                             1);

               nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
               nir_store_var(&b_shader, vars->arg,
                             nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);

               reserve_stack_size(vars, size + 16);

               /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
               nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
               nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
               nir_store_var(&b_shader, vars->cull_mask,
                             nir_iand_imm(&b_shader, intr->src[2].ssa, 0xff), 0x1);
               nir_store_var(&b_shader, vars->sbt_offset,
                             nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1);
               nir_store_var(&b_shader, vars->sbt_stride,
                             nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1);
               nir_store_var(&b_shader, vars->miss_index,
                             nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1);
               nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
               nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
               nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
               nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
               break;
            }
            case nir_intrinsic_rt_resume: {
               uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;

               nir_store_var(
                  &b_shader, vars->stack_ptr,
                  nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1);
               break;
            }
            case nir_intrinsic_rt_return_amd: {
               if (shader->info.stage == MESA_SHADER_RAYGEN) {
                  nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
                  break;
               }
               insert_rt_return(&b_shader, vars);
               break;
            }
            case nir_intrinsic_load_scratch: {
               nir_instr_rewrite_src_ssa(
                  instr, &intr->src[0],
                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
               continue;
            }
            case nir_intrinsic_store_scratch: {
               nir_instr_rewrite_src_ssa(
                  instr, &intr->src[1],
                  nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
               continue;
            }
            case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
               ret = nir_load_var(&b_shader, vars->arg);
               break;
            }
            case nir_intrinsic_load_shader_record_ptr: {
               ret = nir_load_var(&b_shader, vars->shader_record_ptr);
               break;
            }
            case nir_intrinsic_load_ray_launch_id: {
               ret = nir_load_global_invocation_id(&b_shader, 32);
               break;
            }
            case nir_intrinsic_load_ray_launch_size: {
               nir_ssa_def *launch_size_addr =
                  nir_load_ray_launch_size_addr_amd(&b_shader);

               nir_ssa_def *xy = nir_build_load_smem_amd(
                  &b_shader, 2, launch_size_addr, nir_imm_int(&b_shader, 0));
               nir_ssa_def *z = nir_build_load_smem_amd(
                  &b_shader, 1, launch_size_addr, nir_imm_int(&b_shader, 8));

               nir_ssa_def *xyz[3] = {
                  nir_channel(&b_shader, xy, 0),
                  nir_channel(&b_shader, xy, 1),
                  z,
               };
               ret = nir_vec(&b_shader, xyz, 3);
               break;
            }
            case nir_intrinsic_load_ray_t_min: {
               ret = nir_load_var(&b_shader, vars->tmin);
               break;
            }
            case nir_intrinsic_load_ray_t_max: {
               ret = nir_load_var(&b_shader, vars->tmax);
               break;
            }
            case nir_intrinsic_load_ray_world_origin: {
               ret = nir_load_var(&b_shader, vars->origin);
               break;
            }
            case nir_intrinsic_load_ray_world_direction: {
               ret = nir_load_var(&b_shader, vars->direction);
               break;
            }
            case nir_intrinsic_load_ray_instance_custom_index: {
               ret = nir_load_var(&b_shader, vars->custom_instance_and_mask);
               ret = nir_iand_imm(&b_shader, ret, 0xFFFFFF);
               break;
            }
            case nir_intrinsic_load_primitive_id: {
               ret = nir_load_var(&b_shader, vars->primitive_id);
               break;
            }
            case nir_intrinsic_load_ray_geometry_index: {
               ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
               ret = nir_iand_imm(&b_shader, ret, 0xFFFFFFF);
               break;
            }
            case nir_intrinsic_load_instance_id: {
               ret = nir_load_var(&b_shader, vars->instance_id);
               break;
            }
            case nir_intrinsic_load_ray_flags: {
               ret = nir_load_var(&b_shader, vars->flags);
               break;
            }
            case nir_intrinsic_load_ray_hit_kind: {
               ret = nir_load_var(&b_shader, vars->hit_kind);
               break;
            }
            case nir_intrinsic_load_ray_world_to_object: {
               unsigned c = nir_intrinsic_column(intr);
               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
               nir_ssa_def *wto_matrix[3];
               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);

               nir_ssa_def *vals[3];
               for (unsigned i = 0; i < 3; ++i)
                  vals[i] = nir_channel(&b_shader, wto_matrix[i], c);

               ret = nir_vec(&b_shader, vals, 3);
               if (c == 3)
                  ret = nir_fneg(&b_shader,
                                 nir_build_vec3_mat_mult(&b_shader, ret, wto_matrix, false));
               break;
            }
            case nir_intrinsic_load_ray_object_to_world: {
               unsigned c = nir_intrinsic_column(intr);
               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
               if (c == 3) {
                  nir_ssa_def *wto_matrix[3];
                  nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);

                  nir_ssa_def *vals[3];
                  for (unsigned i = 0; i < 3; ++i)
                     vals[i] = nir_channel(&b_shader, wto_matrix[i], c);

                  ret = nir_vec(&b_shader, vals, 3);
               } else {
                  ret = nir_build_load_global(
                     &b_shader, 3, 32, nir_iadd_imm(&b_shader, instance_node_addr, 92 + c * 12));
               }
               break;
            }
            case nir_intrinsic_load_ray_object_origin: {
               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
               nir_ssa_def *wto_matrix[] = {
                  nir_build_load_global(&b_shader, 4, 32,
                                        nir_iadd_imm(&b_shader, instance_node_addr, 16),
                                        .align_mul = 64, .align_offset = 16),
                  nir_build_load_global(&b_shader, 4, 32,
                                        nir_iadd_imm(&b_shader, instance_node_addr, 32),
                                        .align_mul = 64, .align_offset = 32),
                  nir_build_load_global(&b_shader, 4, 32,
                                        nir_iadd_imm(&b_shader, instance_node_addr, 48),
                                        .align_mul = 64, .align_offset = 48)};
               ret = nir_build_vec3_mat_mult_pre(
                  &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix);
               break;
            }
            case nir_intrinsic_load_ray_object_direction: {
               nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
               nir_ssa_def *wto_matrix[3];
               nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
               ret = nir_build_vec3_mat_mult(
                  &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
               break;
            }
            case nir_intrinsic_load_intersection_opaque_amd: {
               ret = nir_load_var(&b_shader, vars->opaque);
               break;
            }
            case nir_intrinsic_load_cull_mask: {
               ret = nir_load_var(&b_shader, vars->cull_mask);
               break;
            }
            case nir_intrinsic_ignore_ray_intersection: {
               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1);

               /* The if is a workaround to avoid having to fix up control flow manually */
               nir_push_if(&b_shader, nir_imm_true(&b_shader));
               nir_jump(&b_shader, nir_jump_return);
               nir_pop_if(&b_shader, NULL);
               break;
            }
            case nir_intrinsic_terminate_ray: {
               nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
               nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1);

               /* The if is a workaround to avoid having to fix up control flow manually */
               nir_push_if(&b_shader, nir_imm_true(&b_shader));
               nir_jump(&b_shader, nir_jump_return);
               nir_pop_if(&b_shader, NULL);
               break;
            }
            case nir_intrinsic_report_ray_intersection: {
               nir_push_if(
                  &b_shader,
                  nir_iand(
                     &b_shader,
                     nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa),
                     nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
               {
                  nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
                  nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
                  nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
               }
               nir_pop_if(&b_shader, NULL);
               break;
            }
            default:
               continue;
            }

            if (ret)
               nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
            nir_instr_remove(instr);
            break;
         }
         case nir_instr_type_jump: {
            nir_jump_instr *jump = nir_instr_as_jump(instr);
            if (jump->type == nir_jump_halt) {
               b_shader.cursor = nir_instr_remove(instr);
               nir_jump(&b_shader, nir_jump_return);
            }
            break;
         }
         default:
            break;
         }
      }
   }

   nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
}

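/* Inline "shader" into the combined shader as one case of the main loop: its RT instructions are
 * lowered against the caller's variables and the inlined body is guarded by "idx == call_idx".
 */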
static void
insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
               uint32_t call_idx_base, uint32_t call_idx)
{
   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);

   nir_opt_dead_cf(shader);

   struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
   map_rt_variables(var_remap, &src_vars, vars);

   NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);

   NIR_PASS(_, shader, nir_opt_remove_phis);
   NIR_PASS(_, shader, nir_lower_returns);
   NIR_PASS(_, shader, nir_opt_dce);

   reserve_stack_size(vars, shader->scratch_size);

   nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
   nir_pop_if(b, NULL);

   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));

   ralloc_free(var_remap);
}

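/* Rewrite shader-call-data and hit-attribute derefs into function-temp derefs rooted at the
 * argument scratch offset and at the fixed hit-attribute stack offset, respectively.
 */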
static bool
lower_rt_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   nir_builder b;
   nir_builder_init(&b, impl);

   b.cursor = nir_before_cf_list(&impl->body);
   nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);

   nir_foreach_block (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         b.cursor = nir_before_instr(&deref->instr);

         nir_deref_instr *replacement = NULL;
         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
            deref->modes = nir_var_function_temp;
            progress = true;

            if (deref->deref_type == nir_deref_type_var)
               replacement =
                  nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
            deref->modes = nir_var_function_temp;
            progress = true;

            if (deref->deref_type == nir_deref_type_var)
               replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
                                                  nir_var_function_temp, deref->type, 0);
         }

         if (replacement != NULL) {
            nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
            nir_instr_remove(&deref->instr);
         }
      }
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
   else
      nir_metadata_preserve(impl, nir_metadata_all);

   return progress;
}

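/* Translate a ray tracing stage to NIR and lower it far enough that it can be inlined into the
 * combined shader: raygen/closest-hit/miss/callable stages get an explicit rt_return appended,
 * and shader call data and hit attributes are lowered to explicit scratch offsets.
 */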
static nir_shader *
parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)
{
   struct radv_pipeline_key key;
   memset(&key, 0, sizeof(key));

   struct radv_pipeline_stage rt_stage;

   radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage));

   nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key);

   if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
       shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
      nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader));
      nir_builder b_inner;
      nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader));
      b_inner.cursor = nir_after_block(last_block);
      nir_rt_return_amd(&b_inner);
   }

   NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
            nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
            glsl_get_natural_size_align_bytes);

   NIR_PASS(_, shader, lower_rt_derefs);

   NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
            nir_address_format_32bit_offset);

   return shader;
}

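/* Turn the any-hit entrypoint into a function taking a commit pointer, the hit T and the hit kind
 * as parameters, so it can be inlined at every reportIntersection() site in the intersection
 * shader.
 */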
static nir_function_impl *
lower_any_hit_for_intersection(nir_shader *any_hit)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);

   /* Any-hit shaders need three parameters */
   assert(impl->function->num_params == 0);
   nir_parameter params[] = {
      {
         /* A pointer to a boolean value for whether or not the hit was
          * accepted.
          */
         .num_components = 1,
         .bit_size = 32,
      },
      {
         /* The hit T value */
         .num_components = 1,
         .bit_size = 32,
      },
      {
         /* The hit kind */
         .num_components = 1,
         .bit_size = 32,
      },
   };
   impl->function->num_params = ARRAY_SIZE(params);
   impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
   memcpy(impl->function->params, params, sizeof(params));

   nir_builder build;
   nir_builder_init(&build, impl);
   nir_builder *b = &build;

   b->cursor = nir_before_cf_list(&impl->body);

   nir_ssa_def *commit_ptr = nir_load_param(b, 0);
   nir_ssa_def *hit_t = nir_load_param(b, 1);
   nir_ssa_def *hit_kind = nir_load_param(b, 2);

   nir_deref_instr *commit =
      nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         switch (instr->type) {
         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_ignore_ray_intersection:
               b->cursor = nir_instr_remove(&intrin->instr);
               /* We put the newly emitted code inside a dummy if because it's
                * going to contain a jump instruction and we don't want to
                * deal with that mess here.  It'll get dealt with by our
                * control-flow optimization passes.
                */
               nir_store_deref(b, commit, nir_imm_false(b), 0x1);
               nir_push_if(b, nir_imm_true(b));
               nir_jump(b, nir_jump_return);
               nir_pop_if(b, NULL);
               break;

            case nir_intrinsic_terminate_ray:
               /* The "normal" handling of terminateRay works fine in
                * intersection shaders.
                */
               break;

            case nir_intrinsic_load_ray_t_max:
               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
               nir_instr_remove(&intrin->instr);
               break;

            case nir_intrinsic_load_ray_hit_kind:
               nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
               nir_instr_remove(&intrin->instr);
               break;

            default:
               break;
            }
            break;
         }
         case nir_instr_type_jump: {
            nir_jump_instr *jump = nir_instr_as_jump(instr);
            if (jump->type == nir_jump_halt) {
               b->cursor = nir_instr_remove(instr);
               nir_jump(b, nir_jump_return);
            }
            break;
         }

         default:
            break;
         }
      }
   }

   nir_validate_shader(any_hit, "after initial any-hit lowering");

   nir_lower_returns_impl(impl);

   nir_validate_shader(any_hit, "after lowering returns");

   return impl;
}

/* Inline the any_hit shader into the intersection shader so we don't have to implement yet
 * another shader call interface here, nor deal with any recursion.
 */
static void
nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
{
   void *dead_ctx = ralloc_context(intersection);

   nir_function_impl *any_hit_impl = NULL;
   struct hash_table *any_hit_var_remap = NULL;
   if (any_hit) {
      any_hit = nir_shader_clone(dead_ctx, any_hit);
      NIR_PASS(_, any_hit, nir_opt_dce);
      any_hit_impl = lower_any_hit_for_intersection(any_hit);
      any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
   }

   nir_function_impl *impl = nir_shader_get_entrypoint(intersection);

   nir_builder build;
   nir_builder_init(&build, impl);
   nir_builder *b = &build;

   b->cursor = nir_before_cf_list(&impl->body);

   nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
   nir_store_var(b, commit, nir_imm_false(b), 0x1);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
            continue;

         b->cursor = nir_instr_remove(&intrin->instr);
         nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
         nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
         nir_ssa_def *min_t = nir_load_ray_t_min(b);
         nir_ssa_def *max_t = nir_load_ray_t_max(b);

         /* bool commit_tmp = false; */
         nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
         nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
         {
            /* Any-hit defaults to commit */
            nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);

            if (any_hit_impl != NULL) {
               nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
               {
                  nir_ssa_def *params[] = {
                     &nir_build_deref_var(b, commit_tmp)->dest.ssa,
                     hit_t,
                     hit_kind,
                  };
                  nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
               }
               nir_pop_if(b, NULL);
            }

            nir_push_if(b, nir_load_var(b, commit_tmp));
            {
               nir_report_ray_intersection(b, 1, hit_t, hit_kind);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);

         nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
         nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
      }
   }

   /* We did some inlining; have to re-index SSA defs */
   nir_index_ssa_defs(impl);

   /* Eliminate the casts introduced for the commit return of the any-hit shader. */
   NIR_PASS(_, intersection, nir_opt_deref);

   ralloc_free(dead_ctx);
}

/* Variables only used internally to ray traversal. This is data that describes the current state
 * of the traversal vs. what we'd give to a shader, e.g. the instance we're currently visiting vs.
 * the instance of the closest hit. */
struct rt_traversal_vars {
   nir_variable *origin;
   nir_variable *dir;
   nir_variable *inv_dir;
   nir_variable *sbt_offset_and_flags;
   nir_variable *instance_id;
   nir_variable *custom_instance_and_mask;
   nir_variable *instance_addr;
   nir_variable *hit;
   nir_variable *bvh_base;
   nir_variable *stack;
   nir_variable *top_stack;
};

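/* Create the traversal-only variables listed above as shader temporaries. */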
static struct rt_traversal_vars
init_traversal_vars(nir_builder *b)
{
   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
   struct rt_traversal_vars ret;

   ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
   ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
   ret.inv_dir =
      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
   ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                                  "traversal_sbt_offset_and_flags");
   ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                         "traversal_instance_id");
   ret.custom_instance_and_mask = nir_variable_create(
      b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
   ret.instance_addr =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
   ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
   ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
                                      "traversal_bvh_base");
   ret.stack =
      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
   ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
                                       "traversal_top_stack_ptr");
   return ret;
}

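/* Inline every any-hit shader referenced by a triangle hit group, each guarded by a comparison
 * against the SBT index currently stored in vars->idx (the whole block is skipped when the index
 * is 0).
 */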
static void
visit_any_hit_shaders(struct radv_device *device,
                      const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
                      struct rt_variables *vars)
{
   nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);

   nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
      uint32_t shader_id = VK_SHADER_UNUSED_KHR;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         shader_id = group_info->anyHitShader;
         break;
      default:
         break;
      }
      if (shader_id == VK_SHADER_UNUSED_KHR)
         continue;

      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
      nir_shader *nir_stage = parse_rt_stage(device, stage);

      vars->stage_idx = shader_id;
      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
   }
   nir_pop_if(b, NULL);
}

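/* Handle an intersected triangle node during traversal: apply the triangle and facing culling
 * implied by the ray flags and instance flags, run the any-hit shader for non-opaque geometry,
 * and commit the hit (breaking out of traversal if requested) when it is accepted and closer
 * than the current tmax.
 */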
static void
insert_traversal_triangle_case(struct radv_device *device,
                               const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
                               nir_ssa_def *result, const struct rt_variables *vars,
                               const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
{
   nir_ssa_def *dist = nir_channel(b, result, 0);
   nir_ssa_def *div = nir_channel(b, result, 1);
   dist = nir_fdiv(b, dist, div);
   nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
   nir_ssa_def *switch_ccw =
      nir_test_mask(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                    VK_GEOMETRY_INSTANCE_TRIANGLE_FLIP_FACING_BIT_KHR << 24);
   frontface = nir_ixor(b, frontface, switch_ccw);

   nir_ssa_def *not_cull =
      nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipTrianglesKHRMask));
   nir_ssa_def *not_facing_cull = nir_ieq_imm(
      b,
      nir_iand(b, nir_load_var(b, vars->flags),
               nir_bcsel(b, frontface, nir_imm_int(b, SpvRayFlagsCullFrontFacingTrianglesKHRMask),
                         nir_imm_int(b, SpvRayFlagsCullBackFacingTrianglesKHRMask))),
      0);

   not_cull = nir_iand(
      b, not_cull,
      nir_ior(b, not_facing_cull,
              nir_test_mask(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                            VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR << 24)));

   nir_push_if(b, nir_iand(b,
                           nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
                                    nir_flt(b, nir_load_var(b, vars->tmin), dist)),
                           not_cull));
   {

      nir_ssa_def *triangle_info =
         nir_build_load_global(b, 2, 32,
                               nir_iadd_imm(b, build_node_to_addr(device, b, bvh_node),
                                            offsetof(struct radv_bvh_triangle_node, triangle_id)));
      nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
      nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
      nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff);
      nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                                             nir_load_var(b, vars->flags), geometry_id_and_flags);

      not_cull =
         nir_ieq_imm(b,
                     nir_iand(b, nir_load_var(b, vars->flags),
                              nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask),
                                        nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))),
                     0);
      nir_push_if(b, not_cull);
      {
         nir_ssa_def *sbt_idx = nir_iadd(
            b,
            nir_iadd(b, nir_load_var(b, vars->sbt_offset),
                     nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)),
            nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
         nir_ssa_def *divs[2] = {div, div};
         nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
         nir_ssa_def *hit_kind =
            nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));

         nir_store_scratch(
            b, ij, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET),
            .align_mul = 16);

         nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
         nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1);

         nir_push_if(b, nir_inot(b, is_opaque));
         {
            struct rt_variables inner_vars = create_inner_vars(b, vars);

            nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
            nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
            nir_store_var(b, inner_vars.tmax, dist, 0x1);
            nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
            nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
                          0x1);
            nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
            nir_store_var(b, inner_vars.custom_instance_and_mask,
                          nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

            load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);

            visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);

            nir_push_if(b, nir_inot(b, nir_load_var(b, vars->ahit_accept)));
            {
               nir_jump(b, nir_jump_continue);
            }
            nir_pop_if(b, NULL);
         }
         nir_pop_if(b, NULL);

         nir_store_var(b, vars->primitive_id, primitive_id, 1);
         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
         nir_store_var(b, vars->tmax, dist, 0x1);
         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
         nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
         nir_store_var(b, vars->custom_instance_and_mask,
                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

         nir_store_var(b, vars->idx, sbt_idx, 1);
         nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1);

         nir_ssa_def *terminate_on_first_hit =
            nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
         nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate);
         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
         {
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

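/* Handle an intersected AABB (procedural) node: unless culled by the skip-AABBs or opacity ray
 * flags, set up inner variables and invoke the intersection shader (with its any-hit shader
 * inlined) from the matching procedural hit group.
 */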
1239static void
1240insert_traversal_aabb_case(struct radv_device *device,
1241                           const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1242                           const struct rt_variables *vars,
1243                           const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
1244{
1245   nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
   nir_ssa_def *triangle_info = nir_build_load_global(b, 2, 32, nir_iadd_imm(b, node_addr, 24));
   nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
   nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
   nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff);
   nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
                                          nir_load_var(b, vars->flags), geometry_id_and_flags);

   nir_ssa_def *not_skip_aabb =
      nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipAABBsKHRMask));
   nir_ssa_def *not_cull = nir_iand(
      b, not_skip_aabb,
      nir_ieq_imm(b,
                  nir_iand(b, nir_load_var(b, vars->flags),
                           nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask),
                                     nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))),
                  0));
   nir_push_if(b, not_cull);
   {
      nir_ssa_def *sbt_idx = nir_iadd(
         b,
         nir_iadd(b, nir_load_var(b, vars->sbt_offset),
                  nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)),
         nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));

      struct rt_variables inner_vars = create_inner_vars(b, vars);

      /* For AABBs the intersection shader writes the hit kind, and only does
       * so if the hit is the next closest hit candidate. */
      inner_vars.hit_kind = vars->hit_kind;

      nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
      nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
      nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
      nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
      nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
      nir_store_var(b, inner_vars.custom_instance_and_mask,
                    nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
      nir_store_var(b, inner_vars.opaque, is_opaque, 1);

      load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);

      nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
      nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1);

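      /* If the SBT entry references an intersection shader (idx != 0), run the
       * matching intersection shader case; otherwise fall back to a plain
       * ray/AABB slab test against the node bounds below. */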
      nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
      for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
         const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
         uint32_t shader_id = VK_SHADER_UNUSED_KHR;
         uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;

         switch (group_info->type) {
         case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
            shader_id = group_info->intersectionShader;
            any_hit_shader_id = group_info->anyHitShader;
            break;
         default:
            break;
         }
         if (shader_id == VK_SHADER_UNUSED_KHR)
            continue;

         const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
         nir_shader *nir_stage = parse_rt_stage(device, stage);

         nir_shader *any_hit_stage = NULL;
         if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
            stage = &pCreateInfo->pStages[any_hit_shader_id];
            any_hit_stage = parse_rt_stage(device, stage);

            nir_lower_intersection_shader(nir_stage, any_hit_stage);
            ralloc_free(any_hit_stage);
         }

         inner_vars.stage_idx = shader_id;
         insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
      }
      nir_push_else(b, NULL);
      {
         nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
         nir_ssa_def *vec3_inf =
            nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);

         nir_ssa_def *bvh_lo = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 0));
         nir_ssa_def *bvh_hi = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 12));

         bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
         bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
         nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
                                       nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
         nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
                                        nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
         /* If we run parallel to one of the slabs, the range should be [0, inf), not [0, 0]. */
         t2_vec =
            nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);

         nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
         t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));

         nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
         t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));

         nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), t_min),
                                 nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
         {
            nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
            nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);

      nir_push_if(b, nir_load_var(b, vars->ahit_accept));
      {
         nir_store_var(b, vars->primitive_id, primitive_id, 1);
         nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
         nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
         nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
         nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
         nir_store_var(b, vars->custom_instance_and_mask,
                       nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);

         nir_store_var(b, vars->idx, sbt_idx, 1);
         nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1);

         nir_ssa_def *terminate_on_first_hit =
            nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
         nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate);
         nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
         {
            nir_jump(b, nir_jump_break);
         }
         nir_pop_if(b, NULL);
      }
      nir_pop_if(b, NULL);
   }
   nir_pop_if(b, NULL);
}

static nir_shader *
build_traversal_shader(struct radv_device *device,
                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                       const struct rt_variables *dst_vars,
                       struct hash_table *var_remap)
{
   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
   b.shader->info.internal = false;
   b.shader->info.workgroup_size[0] = 8;
   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, dst_vars->stack_sizes);
   map_rt_variables(var_remap, &vars, dst_vars);

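   /* Per-lane traversal stack in LDS: MAX_STACK_ENTRY_COUNT entries per
    * invocation, interleaved so that entry i of lane l lives at index
    * i * wave_size + l. */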
   unsigned lanes = device->physical_device->rt_wave_size;
   unsigned elements = lanes * MAX_STACK_ENTRY_COUNT;
   nir_variable *stack_var = nir_variable_create(b.shader, nir_var_mem_shared,
                                                 glsl_array_type(glsl_uint_type(), elements, 0),
                                                 "trav_stack");
   nir_deref_instr *stack_deref = nir_build_deref_var(&b, stack_var);
   nir_deref_instr *stack;
   nir_ssa_def *stack_idx_stride = nir_imm_int(&b, lanes);
   nir_ssa_def *stack_idx_base = nir_load_local_invocation_index(&b);

   nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);

   struct rt_traversal_vars trav_vars = init_traversal_vars(&b);

   nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);

   nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
   {
      nir_store_var(&b, trav_vars.bvh_base, build_addr_to_node(&b, accel_struct), 1);

      nir_ssa_def *bvh_root = nir_build_load_global(
         &b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);

      nir_ssa_def *desc = create_bvh_descriptor(&b);
      nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);

      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
      nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);

      nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_idx_base, stack_idx_stride), 1);
      stack = nir_build_deref_array(&b, stack_deref, stack_idx_base);
      nir_store_deref(&b, stack, bvh_root, 0x1);

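      /* top_stack records the stack level at which the current bottom-level BVH
       * was entered; once the stack pops back to that level, the top-level BVH
       * base and the world-space ray are restored. */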
      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);

      nir_push_loop(&b);

      nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_idx_base));
      nir_jump(&b, nir_jump_break);
      nir_pop_if(&b, NULL);

      nir_push_if(
         &b, nir_uge(&b, nir_load_var(&b, trav_vars.top_stack), nir_load_var(&b, trav_vars.stack)));
      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
      nir_store_var(&b, trav_vars.bvh_base,
                    build_addr_to_node(&b, nir_load_var(&b, vars.accel_struct)), 1);
      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);

      nir_pop_if(&b, NULL);

      nir_store_var(&b, trav_vars.stack,
                    nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);

      stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
      nir_ssa_def *bvh_node = nir_load_deref(&b, stack);
      nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7);

      bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64));
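      /* Use the hardware BVH intersection instruction when available; the
       * software intersection helpers below cover the emulated-RT path. */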
      nir_ssa_def *intrinsic_result = NULL;
      if (!radv_emulate_rt(device->physical_device)) {
         intrinsic_result = nir_bvh64_intersect_ray_amd(
            &b, 32, desc, nir_unpack_64_2x32(&b, bvh_node), nir_load_var(&b, vars.tmax),
            nir_load_var(&b, trav_vars.origin), nir_load_var(&b, trav_vars.dir),
            nir_load_var(&b, trav_vars.inv_dir));
      }

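      /* Dispatch on the node type encoded in the low 3 bits of the node pointer:
       * bit 2 distinguishes non-triangle from triangle nodes, bit 1 then selects
       * AABB/instance ("custom") over box nodes, and bit 0 picks AABB vs. instance. */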
      nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 4), 0));
      {
         nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 2), 0));
         {
            /* custom */
            nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 1), 0));
            if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)) {
               insert_traversal_aabb_case(device, pCreateInfo, &b, &vars, &trav_vars, bvh_node);
            }
            nir_push_else(&b, NULL);
            {
               /* instance */
               nir_ssa_def *instance_node_addr = build_node_to_addr(device, &b, bvh_node);
               nir_ssa_def *instance_data =
                  nir_build_load_global(&b, 4, 32, instance_node_addr, .align_mul = 64);
               nir_ssa_def *wto_matrix[] = {
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 16),
                                        .align_mul = 64, .align_offset = 16),
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 32),
                                        .align_mul = 64, .align_offset = 32),
                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 48),
                                        .align_mul = 64, .align_offset = 48)};
               nir_ssa_def *instance_id =
                  nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, instance_node_addr, 88));
               nir_ssa_def *instance_and_mask = nir_channel(&b, instance_data, 2);
               nir_ssa_def *instance_mask = nir_ushr_imm(&b, instance_and_mask, 24);

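               /* Skip the instance if its mask does not overlap the ray's cull mask. */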
               nir_push_if(
                  &b,
                  nir_ieq_imm(&b, nir_iand(&b, instance_mask, nir_load_var(&b, vars.cull_mask)), 0));
               nir_jump(&b, nir_jump_continue);
               nir_pop_if(&b, NULL);

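               /* Descend into the instance's bottom-level BVH: remember the
                * current stack level in top_stack, switch the BVH base, push the
                * BLAS root node, and transform the ray into object space. */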
               nir_store_var(&b, trav_vars.top_stack, nir_load_var(&b, trav_vars.stack), 1);
               nir_store_var(&b, trav_vars.bvh_base,
                             build_addr_to_node(
                                &b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
                             1);
               stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
               nir_store_deref(&b, stack, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63), 0x1);

               nir_store_var(&b, trav_vars.stack,
                             nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);

               nir_store_var(
                  &b, trav_vars.origin,
                  nir_build_vec3_mat_mult_pre(&b, nir_load_var(&b, vars.origin), wto_matrix), 7);
               nir_store_var(
                  &b, trav_vars.dir,
                  nir_build_vec3_mat_mult(&b, nir_load_var(&b, vars.direction), wto_matrix, false),
                  7);
               nir_store_var(&b, trav_vars.inv_dir,
                             nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
               nir_store_var(&b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
               nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_channel(&b, instance_data, 3),
                             1);
               nir_store_var(&b, trav_vars.instance_id, instance_id, 1);
               nir_store_var(&b, trav_vars.instance_addr, instance_node_addr, 1);
            }
            nir_pop_if(&b, NULL);
         }
         nir_push_else(&b, NULL);
         {
            /* box */
            nir_ssa_def *result = intrinsic_result;
            if (!result) {
               /* If we didn't run the intrinsic because the hardware didn't
                * support it, emulate the ray/box intersection here. */
               result = intersect_ray_amd_software_box(device,
                  &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
                  nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
            }

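            /* Push the up-to-four child nodes returned by the box test onto the
             * stack in reverse order; 0xffffffff marks a missing child. */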
            for (unsigned i = 4; i-- > 0; ) {
               nir_ssa_def *new_node = nir_channel(&b, result, i);
               nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
               {
                  stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
                  nir_store_deref(&b, stack, new_node, 0x1);
                  nir_store_var(
                     &b, trav_vars.stack,
                     nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
               }
               nir_pop_if(&b, NULL);
            }
         }
         nir_pop_if(&b, NULL);
      }
      nir_push_else(&b, NULL);
      if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)) {
         nir_ssa_def *result = intrinsic_result;
         if (!result) {
            /* If we didn't run the intrinsic because the hardware didn't
             * support it, emulate the ray/triangle intersection here. */
            result = intersect_ray_amd_software_tri(device,
               &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
               nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
         }
         insert_traversal_triangle_case(device, pCreateInfo, &b, result, &vars, &trav_vars, bvh_node);
      }
      nir_pop_if(&b, NULL);

      nir_pop_loop(&b, NULL);
   }
   nir_pop_if(&b, NULL);

   /* Initialize follow-up shader. */
   nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
   {
      /* vars.idx contains the SBT index at this point. */
      load_sbt_entry(&b, &vars, nir_load_var(&b, vars.idx), SBT_HIT, 0);

      nir_ssa_def *should_return = nir_ior(&b,
                                           nir_test_mask(&b, nir_load_var(&b, vars.flags),
                                                         SpvRayFlagsSkipClosestHitShaderKHRMask),
                                           nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));

      /* should_return is set if we had a hit but won't be calling the closest-hit
       * shader, and hence need to return immediately to the calling shader. */
      nir_push_if(&b, should_return);
      {
         insert_rt_return(&b, &vars);
      }
      nir_pop_if(&b, NULL);
   }
   nir_push_else(&b, NULL);
   {
      /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
       * for miss shaders if none of the rays miss. */
      load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
   }
   nir_pop_if(&b, NULL);

   return b.shader;
}

static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 nir_builder *b, const struct rt_variables *vars)
{
   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
   nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);

   /* For now, just inline the traversal shader */
   nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
   nir_pop_if(b, NULL);

   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));

   ralloc_free(var_remap);
}

static unsigned
compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                      const struct radv_pipeline_shader_stack_size *stack_sizes)
{
   unsigned raygen_size = 0;
   unsigned callable_size = 0;
   unsigned chit_size = 0;
   unsigned miss_size = 0;
   unsigned non_recursive_size = 0;

   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
      non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);

      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
      uint32_t shader_id = VK_SHADER_UNUSED_KHR;
      unsigned size = stack_sizes[i].recursive_size;

      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         shader_id = group_info->generalShader;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         shader_id = group_info->closestHitShader;
         break;
      default:
         break;
      }
      if (shader_id == VK_SHADER_UNUSED_KHR)
         continue;

      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
      switch (stage->stage) {
      case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
         raygen_size = MAX2(raygen_size, size);
         break;
      case VK_SHADER_STAGE_MISS_BIT_KHR:
         miss_size = MAX2(miss_size, size);
         break;
      case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
         chit_size = MAX2(chit_size, size);
         break;
      case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
         callable_size = MAX2(callable_size, size);
         break;
      default:
         unreachable("Invalid stage type in RT shader");
      }
   }
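   /* The first recursion level can use the raygen stack plus the largest of the
    * closest-hit/miss/non-recursive sizes; every additional recursion level only
    * adds the largest closest-hit/miss size. Two levels are reserved for
    * callable shaders. */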
   return raygen_size +
          MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
             MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
          MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
             MAX2(chit_size, miss_size) +
          2 * callable_size;
}

bool
radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
{
   if (!pCreateInfo->pDynamicState)
      return false;

   for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) {
      if (pCreateInfo->pDynamicState->pDynamicStates[i] ==
          VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR)
         return true;
   }

   return false;
}

static bool
should_move_rt_instruction(nir_intrinsic_op intrinsic)
{
   switch (intrinsic) {
   case nir_intrinsic_load_rt_arg_scratch_offset_amd:
   case nir_intrinsic_load_ray_flags:
   case nir_intrinsic_load_ray_object_origin:
   case nir_intrinsic_load_ray_world_origin:
   case nir_intrinsic_load_ray_t_min:
   case nir_intrinsic_load_ray_object_direction:
   case nir_intrinsic_load_ray_world_direction:
   case nir_intrinsic_load_ray_t_max:
      return true;
   default:
      return false;
   }
}

static void
move_rt_instructions(nir_shader *shader)
{
   nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);

   nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

         if (!should_move_rt_instruction(intrinsic->intrinsic))
            continue;

         nir_instr_move(target, instr);
      }
   }

   nir_metadata_preserve(nir_shader_get_entrypoint(shader),
                         nir_metadata_all & (~nir_metadata_instr_index));
}

static nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                 struct radv_pipeline_shader_stack_size *stack_sizes)
{
   struct radv_pipeline_key key;
   memset(&key, 0, sizeof(key));

   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
   b.shader->info.internal = false;
   b.shader->info.workgroup_size[0] = 8;
   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;

   struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
   load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
   nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);

   nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);

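   /* Main dispatch loop: every iteration runs the shader selected by vars.idx
    * (0 means we are done, 1 selects the inlined traversal shader, higher values
    * select stage/resume shader cases). main_loop_case_visited guards against
    * looping forever on an unknown idx. */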
   nir_loop *loop = nir_push_loop(&b);

   nir_push_if(&b, nir_ior(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0),
                           nir_inot(&b, nir_load_var(&b, vars.main_loop_case_visited))));
   nir_jump(&b, nir_jump_break);
   nir_pop_if(&b, NULL);

   nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);

   insert_traversal(device, pCreateInfo, &b, &vars);

   nir_ssa_def *idx = nir_load_var(&b, vars.idx);

   /* We do a trick with the indexing of the resume shaders so that the first
    * shader of stage x always gets id x and the resume shader ids then come after
    * stageCount. This makes the shader group handles independent of compilation. */
   unsigned call_idx_base = pCreateInfo->stageCount + 1;
   for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
      gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
      if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
          type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
         continue;

      nir_shader *nir_stage = parse_rt_stage(device, stage);

      /* Move the ray tracing system values that are set by rt_trace_ray to the top
       * of the shader, to prevent them from being overwritten by subsequent
       * rt_trace_ray calls. */
      NIR_PASS_V(nir_stage, move_rt_instructions);

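      /* nir_lower_shader_calls splits the stage at each shader call (e.g. traceRay,
       * executeCallable); every split point yields a resume shader that continues
       * execution once the call returns. */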
      uint32_t num_resume_shaders = 0;
      nir_shader **resume_shaders = NULL;
      nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
                             &num_resume_shaders, nir_stage);

      vars.stage_idx = i;
      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
      for (unsigned j = 0; j < num_resume_shaders; ++j) {
         insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
      }
      call_idx_base += num_resume_shaders;
   }

   nir_pop_loop(&b, loop);

   if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
      /* Set a non-zero size so scratch gets enabled in the shader. */
      b.shader->scratch_size = 16;
   } else
      b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes);

   /* Deal with all the inline functions. */
   nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
   nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);

   return b.shader;
}

static VkResult
radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   VkResult result;
   struct radv_pipeline *pipeline = NULL;
   struct radv_compute_pipeline *compute_pipeline = NULL;
   struct radv_pipeline_shader_stack_size *stack_sizes = NULL;
   uint8_t hash[20];
   nir_shader *shader = NULL;
   bool keep_statistic_info =
      (pCreateInfo->flags & VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR) ||
      (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADER_STATS) || device->keep_shader_info;

   if (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR)
      return radv_rt_pipeline_library_create(_device, _cache, pCreateInfo, pAllocator, pPipeline);

   VkRayTracingPipelineCreateInfoKHR local_create_info =
      radv_create_merged_rt_create_info(pCreateInfo);
   if (!local_create_info.pStages || !local_create_info.pGroups) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto fail;
   }

   radv_hash_rt_shaders(hash, &local_create_info, radv_get_hash_flags(device, keep_statistic_info));
   struct vk_shader_module module = {.base.type = VK_OBJECT_TYPE_SHADER_MODULE};

   VkPipelineShaderStageRequiredSubgroupSizeCreateInfo subgroup_size = {
      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
      .pNext = NULL,
      .requiredSubgroupSize = device->physical_device->rt_wave_size,
   };

   VkComputePipelineCreateInfo compute_info = {
      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
      .pNext = NULL,
      .flags = pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT,
      .stage =
         {
            .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
            .pNext = &subgroup_size,
            .stage = VK_SHADER_STAGE_COMPUTE_BIT,
            .module = vk_shader_module_to_handle(&module),
            .pName = "main",
         },
      .layout = pCreateInfo->layout,
   };

   /* First check if we can get things from the cache before we take the expensive step of
    * generating the nir. */
   result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                         stack_sizes, local_create_info.groupCount, pPipeline);

   if (result == VK_PIPELINE_COMPILE_REQUIRED) {
      if (pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
         goto fail;

      stack_sizes = calloc(sizeof(*stack_sizes), local_create_info.groupCount);
      if (!stack_sizes) {
         result = VK_ERROR_OUT_OF_HOST_MEMORY;
         goto fail;
      }

      shader = create_rt_shader(device, &local_create_info, stack_sizes);
      module.nir = shader;
      compute_info.flags = pCreateInfo->flags;
      result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
                                            stack_sizes, local_create_info.groupCount, pPipeline);
      stack_sizes = NULL;

      if (result != VK_SUCCESS)
         goto shader_fail;
   }
   pipeline = radv_pipeline_from_handle(*pPipeline);
   compute_pipeline = radv_pipeline_to_compute(pipeline);

   compute_pipeline->rt_group_handles =
      calloc(sizeof(*compute_pipeline->rt_group_handles), local_create_info.groupCount);
   if (!compute_pipeline->rt_group_handles) {
      result = VK_ERROR_OUT_OF_HOST_MEMORY;
      goto shader_fail;
   }

   compute_pipeline->dynamic_stack_size = radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo);

   /* For General and ClosestHit shaders, we can use the shader ID directly as the handle.
    * Intersection shaders use the group ID instead, since (potentially different) AnyHit
    * shaders are inlined into them.
    */
   for (unsigned i = 0; i < local_create_info.groupCount; ++i) {
      const VkRayTracingShaderGroupCreateInfoKHR *group_info = &local_create_info.pGroups[i];
      switch (group_info->type) {
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
         if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[0] = group_info->generalShader + 2;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
         FALLTHROUGH;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
         if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[0] = group_info->closestHitShader + 2;
         if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
            compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
         break;
      case VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR:
         unreachable("VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR");
      }
   }

shader_fail:
   if (result != VK_SUCCESS && pipeline)
      radv_pipeline_destroy(device, pipeline, pAllocator);
   ralloc_free(shader);
fail:
   free((void *)local_create_info.pGroups);
   free((void *)local_create_info.pStages);
   free(stack_sizes);
   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
                                  VkPipelineCache pipelineCache, uint32_t count,
                                  const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
                                  const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)
{
   VkResult result = VK_SUCCESS;

   unsigned i = 0;
   for (; i < count; i++) {
      VkResult r;
      r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator,
                                  &pPipelines[i]);
      if (r != VK_SUCCESS) {
         result = r;
         pPipelines[i] = VK_NULL_HANDLE;

         if (pCreateInfos[i].flags & VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT)
            break;
      }
   }

   for (; i < count; ++i)
      pPipelines[i] = VK_NULL_HANDLE;

   return result;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup,
                                        uint32_t groupCount, size_t dataSize, void *pData)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
   char *data = pData;

   STATIC_ASSERT(sizeof(*compute_pipeline->rt_group_handles) <= RADV_RT_HANDLE_SIZE);

   memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);

   for (uint32_t i = 0; i < groupCount; ++i) {
      memcpy(data + i * RADV_RT_HANDLE_SIZE, &compute_pipeline->rt_group_handles[firstGroup + i],
             sizeof(*compute_pipeline->rt_group_handles));
   }

   return VK_SUCCESS;
}

VKAPI_ATTR VkDeviceSize VKAPI_CALL
radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group,
                                          VkShaderGroupShaderKHR groupShader)
{
   RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
   struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
   const struct radv_pipeline_shader_stack_size *stack_size =
      &compute_pipeline->rt_stack_sizes[group];

   if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR ||
       groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR)
      return stack_size->non_recursive_size;
   else
      return stack_size->recursive_size;
}

VKAPI_ATTR VkResult VKAPI_CALL
radv_GetRayTracingCaptureReplayShaderGroupHandlesKHR(VkDevice _device, VkPipeline pipeline,
                                                     uint32_t firstGroup, uint32_t groupCount,
                                                     size_t dataSize, void *pData)
{
   RADV_FROM_HANDLE(radv_device, device, _device);
   unreachable("Unimplemented");
   return vk_error(device, VK_ERROR_FEATURE_NOT_PRESENT);
}