1 /*
2  * Copyright © 2021 Google
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "radv_acceleration_structure.h"
25 #include "radv_debug.h"
26 #include "radv_meta.h"
27 #include "radv_private.h"
28 #include "radv_rt_common.h"
29 #include "radv_shader.h"
30 
31 #include "nir/nir.h"
32 #include "nir/nir_builder.h"
33 #include "nir/nir_builtin_builder.h"
34 
35 static VkRayTracingPipelineCreateInfoKHR
radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)36 radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
37 {
38    VkRayTracingPipelineCreateInfoKHR local_create_info = *pCreateInfo;
39    uint32_t total_stages = pCreateInfo->stageCount;
40    uint32_t total_groups = pCreateInfo->groupCount;
41 
42    if (pCreateInfo->pLibraryInfo) {
43       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
44          RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
45          struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline);
46 
47          total_stages += library_pipeline->stage_count;
48          total_groups += library_pipeline->group_count;
49       }
50    }
51    VkPipelineShaderStageCreateInfo *stages = NULL;
52    VkRayTracingShaderGroupCreateInfoKHR *groups = NULL;
53    local_create_info.stageCount = total_stages;
54    local_create_info.groupCount = total_groups;
55    local_create_info.pStages = stages =
56       malloc(sizeof(VkPipelineShaderStageCreateInfo) * total_stages);
57    local_create_info.pGroups = groups =
58       malloc(sizeof(VkRayTracingShaderGroupCreateInfoKHR) * total_groups);
59    if (!local_create_info.pStages || !local_create_info.pGroups)
60       return local_create_info;
61 
62    total_stages = pCreateInfo->stageCount;
63    total_groups = pCreateInfo->groupCount;
64    for (unsigned j = 0; j < pCreateInfo->stageCount; ++j)
65       stages[j] = pCreateInfo->pStages[j];
66    for (unsigned j = 0; j < pCreateInfo->groupCount; ++j)
67       groups[j] = pCreateInfo->pGroups[j];
68 
69    if (pCreateInfo->pLibraryInfo) {
70       for (unsigned i = 0; i < pCreateInfo->pLibraryInfo->libraryCount; ++i) {
71          RADV_FROM_HANDLE(radv_pipeline, pipeline, pCreateInfo->pLibraryInfo->pLibraries[i]);
72          struct radv_library_pipeline *library_pipeline = radv_pipeline_to_library(pipeline);
73 
74          for (unsigned j = 0; j < library_pipeline->stage_count; ++j)
75             stages[total_stages + j] = library_pipeline->stages[j];
76          for (unsigned j = 0; j < library_pipeline->group_count; ++j) {
77             VkRayTracingShaderGroupCreateInfoKHR *dst = &groups[total_groups + j];
78             *dst = library_pipeline->groups[j];
79             if (dst->generalShader != VK_SHADER_UNUSED_KHR)
80                dst->generalShader += total_stages;
81             if (dst->closestHitShader != VK_SHADER_UNUSED_KHR)
82                dst->closestHitShader += total_stages;
83             if (dst->anyHitShader != VK_SHADER_UNUSED_KHR)
84                dst->anyHitShader += total_stages;
85             if (dst->intersectionShader != VK_SHADER_UNUSED_KHR)
86                dst->intersectionShader += total_stages;
87          }
88          total_stages += library_pipeline->stage_count;
89          total_groups += library_pipeline->group_count;
90       }
91    }
92    return local_create_info;
93 }
94 
95 static VkResult
radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)96 radv_rt_pipeline_library_create(VkDevice _device, VkPipelineCache _cache,
97                                 const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
98                                 const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
99 {
100    RADV_FROM_HANDLE(radv_device, device, _device);
101    struct radv_library_pipeline *pipeline;
102 
103    pipeline = vk_zalloc2(&device->vk.alloc, pAllocator, sizeof(*pipeline), 8,
104                          VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
105    if (pipeline == NULL)
106       return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
107 
108    radv_pipeline_init(device, &pipeline->base, RADV_PIPELINE_LIBRARY);
109 
110    VkRayTracingPipelineCreateInfoKHR local_create_info =
111       radv_create_merged_rt_create_info(pCreateInfo);
112    if (!local_create_info.pStages || !local_create_info.pGroups)
113       goto fail;
114 
115    if (local_create_info.stageCount) {
116       pipeline->stage_count = local_create_info.stageCount;
117 
118       size_t size = sizeof(VkPipelineShaderStageCreateInfo) * local_create_info.stageCount;
119       pipeline->stages = malloc(size);
120       if (!pipeline->stages)
121          goto fail;
122 
123       memcpy(pipeline->stages, local_create_info.pStages, size);
124 
125       pipeline->hashes = malloc(sizeof(*pipeline->hashes) * local_create_info.stageCount);
126       if (!pipeline->hashes)
127          goto fail;
128 
129       pipeline->identifiers = malloc(sizeof(*pipeline->identifiers) * local_create_info.stageCount);
130       if (!pipeline->identifiers)
131          goto fail;
132 
133       for (uint32_t i = 0; i < local_create_info.stageCount; i++) {
134          RADV_FROM_HANDLE(vk_shader_module, module, pipeline->stages[i].module);
135 
136          const VkPipelineShaderStageModuleIdentifierCreateInfoEXT *iinfo =
137             vk_find_struct_const(local_create_info.pStages[i].pNext,
138                                  PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT);
139 
140          if (module) {
141             struct vk_shader_module *new_module = vk_shader_module_clone(NULL, module);
142             pipeline->stages[i].module = vk_shader_module_to_handle(new_module);
143             pipeline->stages[i].pNext = NULL;
144          } else {
145             assert(iinfo);
146             pipeline->identifiers[i].identifierSize =
147                MIN2(iinfo->identifierSize, sizeof(pipeline->hashes[i].sha1));
148             memcpy(pipeline->hashes[i].sha1, iinfo->pIdentifier,
149                    pipeline->identifiers[i].identifierSize);
150             pipeline->stages[i].module = VK_NULL_HANDLE;
151             pipeline->stages[i].pNext = &pipeline->identifiers[i];
152             pipeline->identifiers[i].sType =
153                VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_MODULE_IDENTIFIER_CREATE_INFO_EXT;
154             pipeline->identifiers[i].pNext = NULL;
155             pipeline->identifiers[i].pIdentifier = pipeline->hashes[i].sha1;
156          }
157       }
158    }
159 
160    if (local_create_info.groupCount) {
161       size_t size = sizeof(VkRayTracingShaderGroupCreateInfoKHR) * local_create_info.groupCount;
162       pipeline->group_count = local_create_info.groupCount;
163       pipeline->groups = malloc(size);
164       if (!pipeline->groups)
165          goto fail;
166       memcpy(pipeline->groups, local_create_info.pGroups, size);
167    }
168 
169    *pPipeline = radv_pipeline_to_handle(&pipeline->base);
170 
171    free((void *)local_create_info.pGroups);
172    free((void *)local_create_info.pStages);
173    return VK_SUCCESS;
174 fail:
175    free(pipeline->groups);
176    free(pipeline->stages);
177    free(pipeline->hashes);
178    free(pipeline->identifiers);
179    free((void *)local_create_info.pGroups);
180    free((void *)local_create_info.pStages);
181    return VK_ERROR_OUT_OF_HOST_MEMORY;
182 }
183 
184 /*
185  * Global variables for an RT pipeline
186  */
187 struct rt_variables {
188    const VkRayTracingPipelineCreateInfoKHR *create_info;
189 
190    /* idx of the next shader to run in the next iteration of the main loop.
191     * During traversal, idx is used to store the SBT index and will contain
192     * the correct resume index upon returning.
193     */
194    nir_variable *idx;
195 
196    /* scratch offset of the argument area relative to stack_ptr */
197    nir_variable *arg;
198 
199    nir_variable *stack_ptr;
200 
201    /* global address of the SBT entry used for the shader */
202    nir_variable *shader_record_ptr;
203 
204    /* trace_ray arguments */
205    nir_variable *accel_struct;
206    nir_variable *flags;
207    nir_variable *cull_mask;
208    nir_variable *sbt_offset;
209    nir_variable *sbt_stride;
210    nir_variable *miss_index;
211    nir_variable *origin;
212    nir_variable *tmin;
213    nir_variable *direction;
214    nir_variable *tmax;
215 
216    /* from the BTAS instance currently being visited */
217    nir_variable *custom_instance_and_mask;
218 
219    /* Properties of the primitive currently being visited. */
220    nir_variable *primitive_id;
221    nir_variable *geometry_id_and_flags;
222    nir_variable *instance_id;
223    nir_variable *instance_addr;
224    nir_variable *hit_kind;
225    nir_variable *opaque;
226 
227    /* Safeguard to ensure we don't end up in an infinite loop of non-existing case. Should not be
228     * needed but is extra anti-hang safety during bring-up. */
229    nir_variable *main_loop_case_visited;
230 
231    /* Output variables for intersection & anyhit shaders. */
232    nir_variable *ahit_accept;
233    nir_variable *ahit_terminate;
234 
235    /* Array of stack size struct for recording the max stack size for each group. */
236    struct radv_pipeline_shader_stack_size *stack_sizes;
237    unsigned stage_idx;
238 };
239 
240 static void
reserve_stack_size(struct rt_variables *vars, uint32_t size)241 reserve_stack_size(struct rt_variables *vars, uint32_t size)
242 {
243    for (uint32_t group_idx = 0; group_idx < vars->create_info->groupCount; group_idx++) {
244       const VkRayTracingShaderGroupCreateInfoKHR *group = vars->create_info->pGroups + group_idx;
245 
246       if (vars->stage_idx == group->generalShader || vars->stage_idx == group->closestHitShader)
247          vars->stack_sizes[group_idx].recursive_size =
248             MAX2(vars->stack_sizes[group_idx].recursive_size, size);
249 
250       if (vars->stage_idx == group->anyHitShader || vars->stage_idx == group->intersectionShader)
251          vars->stack_sizes[group_idx].non_recursive_size =
252             MAX2(vars->stack_sizes[group_idx].non_recursive_size, size);
253    }
254 }
255 
256 static struct rt_variables
create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info, struct radv_pipeline_shader_stack_size *stack_sizes)257 create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *create_info,
258                     struct radv_pipeline_shader_stack_size *stack_sizes)
259 {
260    struct rt_variables vars = {
261       .create_info = create_info,
262    };
263    vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
264    vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
265    vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
266    vars.shader_record_ptr =
267       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr");
268 
269    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
270    vars.accel_struct =
271       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct");
272    vars.flags = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "ray_flags");
273    vars.cull_mask = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "cull_mask");
274    vars.sbt_offset =
275       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_offset");
276    vars.sbt_stride =
277       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "sbt_stride");
278    vars.miss_index =
279       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "miss_index");
280    vars.origin = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_origin");
281    vars.tmin = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmin");
282    vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
283    vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
284 
285    vars.custom_instance_and_mask = nir_variable_create(
286       shader, nir_var_shader_temp, glsl_uint_type(), "custom_instance_and_mask");
287    vars.primitive_id =
288       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
289    vars.geometry_id_and_flags =
290       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
291    vars.instance_id =
292       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "instance_id");
293    vars.instance_addr =
294       nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
295    vars.hit_kind = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "hit_kind");
296    vars.opaque = nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "opaque");
297 
298    vars.main_loop_case_visited =
299       nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "main_loop_case_visited");
300    vars.ahit_accept =
301       nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_accept");
302    vars.ahit_terminate =
303       nir_variable_create(shader, nir_var_shader_temp, glsl_bool_type(), "ahit_terminate");
304 
305    vars.stack_sizes = stack_sizes;
306    return vars;
307 }
308 
309 /*
310  * Remap all the variables between the two rt_variables struct for inlining.
311  */
312 static void
map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const struct rt_variables *dst)313 map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
314                  const struct rt_variables *dst)
315 {
316    src->create_info = dst->create_info;
317 
318    _mesa_hash_table_insert(var_remap, src->idx, dst->idx);
319    _mesa_hash_table_insert(var_remap, src->arg, dst->arg);
320    _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
321    _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
322 
323    _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct);
324    _mesa_hash_table_insert(var_remap, src->flags, dst->flags);
325    _mesa_hash_table_insert(var_remap, src->cull_mask, dst->cull_mask);
326    _mesa_hash_table_insert(var_remap, src->sbt_offset, dst->sbt_offset);
327    _mesa_hash_table_insert(var_remap, src->sbt_stride, dst->sbt_stride);
328    _mesa_hash_table_insert(var_remap, src->miss_index, dst->miss_index);
329    _mesa_hash_table_insert(var_remap, src->origin, dst->origin);
330    _mesa_hash_table_insert(var_remap, src->tmin, dst->tmin);
331    _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
332    _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
333 
334    _mesa_hash_table_insert(var_remap, src->custom_instance_and_mask, dst->custom_instance_and_mask);
335    _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
336    _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
337    _mesa_hash_table_insert(var_remap, src->instance_id, dst->instance_id);
338    _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
339    _mesa_hash_table_insert(var_remap, src->hit_kind, dst->hit_kind);
340    _mesa_hash_table_insert(var_remap, src->opaque, dst->opaque);
341    _mesa_hash_table_insert(var_remap, src->ahit_accept, dst->ahit_accept);
342    _mesa_hash_table_insert(var_remap, src->ahit_terminate, dst->ahit_terminate);
343 
344    src->stack_sizes = dst->stack_sizes;
345    src->stage_idx = dst->stage_idx;
346 }
347 
348 /*
349  * Create a copy of the global rt variables where the primitive/instance related variables are
350  * independent.This is needed as we need to keep the old values of the global variables around
351  * in case e.g. an anyhit shader reject the collision. So there are inner variables that get copied
352  * to the outer variables once we commit to a better hit.
353  */
354 static struct rt_variables
create_inner_vars(nir_builder *b, const struct rt_variables *vars)355 create_inner_vars(nir_builder *b, const struct rt_variables *vars)
356 {
357    struct rt_variables inner_vars = *vars;
358    inner_vars.idx =
359       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
360    inner_vars.shader_record_ptr = nir_variable_create(
361       b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
362    inner_vars.primitive_id =
363       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
364    inner_vars.geometry_id_and_flags = nir_variable_create(
365       b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_geometry_id_and_flags");
366    inner_vars.tmax =
367       nir_variable_create(b->shader, nir_var_shader_temp, glsl_float_type(), "inner_tmax");
368    inner_vars.instance_id =
369       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_instance_id");
370    inner_vars.instance_addr = nir_variable_create(b->shader, nir_var_shader_temp,
371                                                   glsl_uint64_t_type(), "inner_instance_addr");
372    inner_vars.hit_kind =
373       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_kind");
374    inner_vars.custom_instance_and_mask = nir_variable_create(
375       b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_custom_instance_and_mask");
376 
377    return inner_vars;
378 }
379 
380 /* The hit attributes are stored on the stack. This is the offset compared to the current stack
381  * pointer of where the hit attrib is stored. */
382 const uint32_t RADV_HIT_ATTRIB_OFFSET = -(16 + RADV_MAX_HIT_ATTRIB_SIZE);
383 
384 static void
insert_rt_return(nir_builder *b, const struct rt_variables *vars)385 insert_rt_return(nir_builder *b, const struct rt_variables *vars)
386 {
387    nir_store_var(b, vars->stack_ptr, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), -16), 1);
388    nir_store_var(b, vars->idx,
389                  nir_load_scratch(b, 1, 32, nir_load_var(b, vars->stack_ptr), .align_mul = 16), 1);
390 }
391 
392 enum sbt_type {
393    SBT_RAYGEN = offsetof(VkTraceRaysIndirectCommand2KHR, raygenShaderRecordAddress),
394    SBT_MISS = offsetof(VkTraceRaysIndirectCommand2KHR, missShaderBindingTableAddress),
395    SBT_HIT = offsetof(VkTraceRaysIndirectCommand2KHR, hitShaderBindingTableAddress),
396    SBT_CALLABLE = offsetof(VkTraceRaysIndirectCommand2KHR, callableShaderBindingTableAddress),
397 };
398 
399 static nir_ssa_def *
get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)400 get_sbt_ptr(nir_builder *b, nir_ssa_def *idx, enum sbt_type binding)
401 {
402    nir_ssa_def *desc_base_addr = nir_load_sbt_base_amd(b);
403 
404    nir_ssa_def *desc =
405       nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, nir_imm_int(b, binding)));
406 
407    nir_ssa_def *stride_offset = nir_imm_int(b, binding + (binding == SBT_RAYGEN ? 8 : 16));
408    nir_ssa_def *stride =
409       nir_pack_64_2x32(b, nir_build_load_smem_amd(b, 2, desc_base_addr, stride_offset));
410 
411    return nir_iadd(b, desc, nir_imul(b, nir_u2u64(b, idx), stride));
412 }
413 
414 static void
load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx, enum sbt_type binding, unsigned offset)415 load_sbt_entry(nir_builder *b, const struct rt_variables *vars, nir_ssa_def *idx,
416                enum sbt_type binding, unsigned offset)
417 {
418    nir_ssa_def *addr = get_sbt_ptr(b, idx, binding);
419 
420    nir_ssa_def *load_addr = nir_iadd_imm(b, addr, offset);
421    nir_ssa_def *v_idx = nir_build_load_global(b, 1, 32, load_addr);
422 
423    nir_store_var(b, vars->idx, v_idx, 1);
424 
425    nir_ssa_def *record_addr = nir_iadd_imm(b, addr, RADV_RT_HANDLE_SIZE);
426    nir_store_var(b, vars->shader_record_ptr, record_addr, 1);
427 }
428 
429 /* This lowers all the RT instructions that we do not want to pass on to the combined shader and
430  * that we can implement using the variables from the shader we are going to inline into. */
431 static void
lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)432 lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned call_idx_base)
433 {
434    nir_builder b_shader;
435    nir_builder_init(&b_shader, nir_shader_get_entrypoint(shader));
436 
437    nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
438       nir_foreach_instr_safe (instr, block) {
439          switch (instr->type) {
440          case nir_instr_type_intrinsic: {
441             b_shader.cursor = nir_before_instr(instr);
442             nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
443             nir_ssa_def *ret = NULL;
444 
445             switch (intr->intrinsic) {
446             case nir_intrinsic_rt_execute_callable: {
447                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
448                uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
449 
450                nir_store_var(
451                   &b_shader, vars->stack_ptr,
452                   nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
453                nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
454                                  nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
455 
456                nir_store_var(&b_shader, vars->stack_ptr,
457                              nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
458                              1);
459                load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_CALLABLE, 0);
460 
461                nir_store_var(&b_shader, vars->arg,
462                              nir_iadd_imm(&b_shader, intr->src[1].ssa, -size - 16), 1);
463 
464                reserve_stack_size(vars, size + 16);
465                break;
466             }
467             case nir_intrinsic_rt_trace_ray: {
468                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
469                uint32_t ret_idx = call_idx_base + nir_intrinsic_call_idx(intr) + 1;
470 
471                nir_store_var(
472                   &b_shader, vars->stack_ptr,
473                   nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), size), 1);
474                nir_store_scratch(&b_shader, nir_imm_int(&b_shader, ret_idx),
475                                  nir_load_var(&b_shader, vars->stack_ptr), .align_mul = 16);
476 
477                nir_store_var(&b_shader, vars->stack_ptr,
478                              nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), 16),
479                              1);
480 
481                nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 1), 1);
482                nir_store_var(&b_shader, vars->arg,
483                              nir_iadd_imm(&b_shader, intr->src[10].ssa, -size - 16), 1);
484 
485                reserve_stack_size(vars, size + 16);
486 
487                /* Per the SPIR-V extension spec we have to ignore some bits for some arguments. */
488                nir_store_var(&b_shader, vars->accel_struct, intr->src[0].ssa, 0x1);
489                nir_store_var(&b_shader, vars->flags, intr->src[1].ssa, 0x1);
490                nir_store_var(&b_shader, vars->cull_mask,
491                              nir_iand_imm(&b_shader, intr->src[2].ssa, 0xff), 0x1);
492                nir_store_var(&b_shader, vars->sbt_offset,
493                              nir_iand_imm(&b_shader, intr->src[3].ssa, 0xf), 0x1);
494                nir_store_var(&b_shader, vars->sbt_stride,
495                              nir_iand_imm(&b_shader, intr->src[4].ssa, 0xf), 0x1);
496                nir_store_var(&b_shader, vars->miss_index,
497                              nir_iand_imm(&b_shader, intr->src[5].ssa, 0xffff), 0x1);
498                nir_store_var(&b_shader, vars->origin, intr->src[6].ssa, 0x7);
499                nir_store_var(&b_shader, vars->tmin, intr->src[7].ssa, 0x1);
500                nir_store_var(&b_shader, vars->direction, intr->src[8].ssa, 0x7);
501                nir_store_var(&b_shader, vars->tmax, intr->src[9].ssa, 0x1);
502                break;
503             }
504             case nir_intrinsic_rt_resume: {
505                uint32_t size = align(nir_intrinsic_stack_size(intr), 16) + RADV_MAX_HIT_ATTRIB_SIZE;
506 
507                nir_store_var(
508                   &b_shader, vars->stack_ptr,
509                   nir_iadd_imm(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), -size), 1);
510                break;
511             }
512             case nir_intrinsic_rt_return_amd: {
513                if (shader->info.stage == MESA_SHADER_RAYGEN) {
514                   nir_store_var(&b_shader, vars->idx, nir_imm_int(&b_shader, 0), 1);
515                   break;
516                }
517                insert_rt_return(&b_shader, vars);
518                break;
519             }
520             case nir_intrinsic_load_scratch: {
521                nir_instr_rewrite_src_ssa(
522                   instr, &intr->src[0],
523                   nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[0].ssa));
524                continue;
525             }
526             case nir_intrinsic_store_scratch: {
527                nir_instr_rewrite_src_ssa(
528                   instr, &intr->src[1],
529                   nir_iadd(&b_shader, nir_load_var(&b_shader, vars->stack_ptr), intr->src[1].ssa));
530                continue;
531             }
532             case nir_intrinsic_load_rt_arg_scratch_offset_amd: {
533                ret = nir_load_var(&b_shader, vars->arg);
534                break;
535             }
536             case nir_intrinsic_load_shader_record_ptr: {
537                ret = nir_load_var(&b_shader, vars->shader_record_ptr);
538                break;
539             }
540             case nir_intrinsic_load_ray_launch_id: {
541                ret = nir_load_global_invocation_id(&b_shader, 32);
542                break;
543             }
544             case nir_intrinsic_load_ray_launch_size: {
545                nir_ssa_def *launch_size_addr =
546                   nir_load_ray_launch_size_addr_amd(&b_shader);
547 
548                nir_ssa_def * xy = nir_build_load_smem_amd(
549                   &b_shader, 2, launch_size_addr, nir_imm_int(&b_shader, 0));
550                nir_ssa_def * z = nir_build_load_smem_amd(
551                   &b_shader, 1, launch_size_addr, nir_imm_int(&b_shader, 8));
552 
553                nir_ssa_def *xyz[3] = {
554                   nir_channel(&b_shader, xy, 0),
555                   nir_channel(&b_shader, xy, 1),
556                   z,
557                };
558                ret = nir_vec(&b_shader, xyz, 3);
559                break;
560             }
561             case nir_intrinsic_load_ray_t_min: {
562                ret = nir_load_var(&b_shader, vars->tmin);
563                break;
564             }
565             case nir_intrinsic_load_ray_t_max: {
566                ret = nir_load_var(&b_shader, vars->tmax);
567                break;
568             }
569             case nir_intrinsic_load_ray_world_origin: {
570                ret = nir_load_var(&b_shader, vars->origin);
571                break;
572             }
573             case nir_intrinsic_load_ray_world_direction: {
574                ret = nir_load_var(&b_shader, vars->direction);
575                break;
576             }
577             case nir_intrinsic_load_ray_instance_custom_index: {
578                ret = nir_load_var(&b_shader, vars->custom_instance_and_mask);
579                ret = nir_iand_imm(&b_shader, ret, 0xFFFFFF);
580                break;
581             }
582             case nir_intrinsic_load_primitive_id: {
583                ret = nir_load_var(&b_shader, vars->primitive_id);
584                break;
585             }
586             case nir_intrinsic_load_ray_geometry_index: {
587                ret = nir_load_var(&b_shader, vars->geometry_id_and_flags);
588                ret = nir_iand_imm(&b_shader, ret, 0xFFFFFFF);
589                break;
590             }
591             case nir_intrinsic_load_instance_id: {
592                ret = nir_load_var(&b_shader, vars->instance_id);
593                break;
594             }
595             case nir_intrinsic_load_ray_flags: {
596                ret = nir_load_var(&b_shader, vars->flags);
597                break;
598             }
599             case nir_intrinsic_load_ray_hit_kind: {
600                ret = nir_load_var(&b_shader, vars->hit_kind);
601                break;
602             }
603             case nir_intrinsic_load_ray_world_to_object: {
604                unsigned c = nir_intrinsic_column(intr);
605                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
606                nir_ssa_def *wto_matrix[3];
607                nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
608 
609                nir_ssa_def *vals[3];
610                for (unsigned i = 0; i < 3; ++i)
611                   vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
612 
613                ret = nir_vec(&b_shader, vals, 3);
614                if (c == 3)
615                   ret = nir_fneg(&b_shader,
616                                  nir_build_vec3_mat_mult(&b_shader, ret, wto_matrix, false));
617                break;
618             }
619             case nir_intrinsic_load_ray_object_to_world: {
620                unsigned c = nir_intrinsic_column(intr);
621                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
622                if (c == 3) {
623                   nir_ssa_def *wto_matrix[3];
624                   nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
625 
626                   nir_ssa_def *vals[3];
627                   for (unsigned i = 0; i < 3; ++i)
628                      vals[i] = nir_channel(&b_shader, wto_matrix[i], c);
629 
630                   ret = nir_vec(&b_shader, vals, 3);
631                } else {
632                   ret = nir_build_load_global(
633                      &b_shader, 3, 32, nir_iadd_imm(&b_shader, instance_node_addr, 92 + c * 12));
634                }
635                break;
636             }
637             case nir_intrinsic_load_ray_object_origin: {
638                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
639                nir_ssa_def *wto_matrix[] = {
640                   nir_build_load_global(&b_shader, 4, 32,
641                                         nir_iadd_imm(&b_shader, instance_node_addr, 16),
642                                         .align_mul = 64, .align_offset = 16),
643                   nir_build_load_global(&b_shader, 4, 32,
644                                         nir_iadd_imm(&b_shader, instance_node_addr, 32),
645                                         .align_mul = 64, .align_offset = 32),
646                   nir_build_load_global(&b_shader, 4, 32,
647                                         nir_iadd_imm(&b_shader, instance_node_addr, 48),
648                                         .align_mul = 64, .align_offset = 48)};
649                ret = nir_build_vec3_mat_mult_pre(
650                   &b_shader, nir_load_var(&b_shader, vars->origin), wto_matrix);
651                break;
652             }
653             case nir_intrinsic_load_ray_object_direction: {
654                nir_ssa_def *instance_node_addr = nir_load_var(&b_shader, vars->instance_addr);
655                nir_ssa_def *wto_matrix[3];
656                nir_build_wto_matrix_load(&b_shader, instance_node_addr, wto_matrix);
657                ret = nir_build_vec3_mat_mult(
658                   &b_shader, nir_load_var(&b_shader, vars->direction), wto_matrix, false);
659                break;
660             }
661             case nir_intrinsic_load_intersection_opaque_amd: {
662                ret = nir_load_var(&b_shader, vars->opaque);
663                break;
664             }
665             case nir_intrinsic_load_cull_mask: {
666                ret = nir_load_var(&b_shader, vars->cull_mask);
667                break;
668             }
669             case nir_intrinsic_ignore_ray_intersection: {
670                nir_store_var(&b_shader, vars->ahit_accept, nir_imm_false(&b_shader), 0x1);
671 
672                /* The if is a workaround to avoid having to fix up control flow manually */
673                nir_push_if(&b_shader, nir_imm_true(&b_shader));
674                nir_jump(&b_shader, nir_jump_return);
675                nir_pop_if(&b_shader, NULL);
676                break;
677             }
678             case nir_intrinsic_terminate_ray: {
679                nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
680                nir_store_var(&b_shader, vars->ahit_terminate, nir_imm_true(&b_shader), 0x1);
681 
682                /* The if is a workaround to avoid having to fix up control flow manually */
683                nir_push_if(&b_shader, nir_imm_true(&b_shader));
684                nir_jump(&b_shader, nir_jump_return);
685                nir_pop_if(&b_shader, NULL);
686                break;
687             }
688             case nir_intrinsic_report_ray_intersection: {
689                nir_push_if(
690                   &b_shader,
691                   nir_iand(
692                      &b_shader,
693                      nir_fge(&b_shader, nir_load_var(&b_shader, vars->tmax), intr->src[0].ssa),
694                      nir_fge(&b_shader, intr->src[0].ssa, nir_load_var(&b_shader, vars->tmin))));
695                {
696                   nir_store_var(&b_shader, vars->ahit_accept, nir_imm_true(&b_shader), 0x1);
697                   nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 1);
698                   nir_store_var(&b_shader, vars->hit_kind, intr->src[1].ssa, 1);
699                }
700                nir_pop_if(&b_shader, NULL);
701                break;
702             }
703             default:
704                continue;
705             }
706 
707             if (ret)
708                nir_ssa_def_rewrite_uses(&intr->dest.ssa, ret);
709             nir_instr_remove(instr);
710             break;
711          }
712          case nir_instr_type_jump: {
713             nir_jump_instr *jump = nir_instr_as_jump(instr);
714             if (jump->type == nir_jump_halt) {
715                b_shader.cursor = nir_instr_remove(instr);
716                nir_jump(&b_shader, nir_jump_return);
717             }
718             break;
719          }
720          default:
721             break;
722          }
723       }
724    }
725 
726    nir_metadata_preserve(nir_shader_get_entrypoint(shader), nir_metadata_none);
727 }
728 
729 static void
insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx)730 insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx,
731                uint32_t call_idx_base, uint32_t call_idx)
732 {
733    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
734 
735    nir_opt_dead_cf(shader);
736 
737    struct rt_variables src_vars = create_rt_variables(shader, vars->create_info, vars->stack_sizes);
738    map_rt_variables(var_remap, &src_vars, vars);
739 
740    NIR_PASS_V(shader, lower_rt_instructions, &src_vars, call_idx_base);
741 
742    NIR_PASS(_, shader, nir_opt_remove_phis);
743    NIR_PASS(_, shader, nir_lower_returns);
744    NIR_PASS(_, shader, nir_opt_dce);
745 
746    reserve_stack_size(vars, shader->scratch_size);
747 
748    nir_push_if(b, nir_ieq_imm(b, idx, call_idx));
749    nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
750    nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
751    nir_pop_if(b, NULL);
752 
753    /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
754    ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
755 
756    ralloc_free(var_remap);
757 }
758 
759 static bool
lower_rt_derefs(nir_shader *shader)760 lower_rt_derefs(nir_shader *shader)
761 {
762    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
763 
764    bool progress = false;
765 
766    nir_builder b;
767    nir_builder_init(&b, impl);
768 
769    b.cursor = nir_before_cf_list(&impl->body);
770    nir_ssa_def *arg_offset = nir_load_rt_arg_scratch_offset_amd(&b);
771 
772    nir_foreach_block (block, impl) {
773       nir_foreach_instr_safe (instr, block) {
774          if (instr->type != nir_instr_type_deref)
775             continue;
776 
777          nir_deref_instr *deref = nir_instr_as_deref(instr);
778          b.cursor = nir_before_instr(&deref->instr);
779 
780          nir_deref_instr *replacement = NULL;
781          if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
782             deref->modes = nir_var_function_temp;
783             progress = true;
784 
785             if (deref->deref_type == nir_deref_type_var)
786                replacement =
787                   nir_build_deref_cast(&b, arg_offset, nir_var_function_temp, deref->var->type, 0);
788          } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
789             deref->modes = nir_var_function_temp;
790             progress = true;
791 
792             if (deref->deref_type == nir_deref_type_var)
793                replacement = nir_build_deref_cast(&b, nir_imm_int(&b, RADV_HIT_ATTRIB_OFFSET),
794                                                   nir_var_function_temp, deref->type, 0);
795          }
796 
797          if (replacement != NULL) {
798             nir_ssa_def_rewrite_uses(&deref->dest.ssa, &replacement->dest.ssa);
799             nir_instr_remove(&deref->instr);
800          }
801       }
802    }
803 
804    if (progress)
805       nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
806    else
807       nir_metadata_preserve(impl, nir_metadata_all);
808 
809    return progress;
810 }
811 
812 static nir_shader *
parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)813 parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreateInfo *sinfo)
814 {
815    struct radv_pipeline_key key;
816    memset(&key, 0, sizeof(key));
817 
818    struct radv_pipeline_stage rt_stage;
819 
820    radv_pipeline_stage_init(sinfo, &rt_stage, vk_to_mesa_shader_stage(sinfo->stage));
821 
822    nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, &key);
823 
824    if (shader->info.stage == MESA_SHADER_RAYGEN || shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
825        shader->info.stage == MESA_SHADER_CALLABLE || shader->info.stage == MESA_SHADER_MISS) {
826       nir_block *last_block = nir_impl_last_block(nir_shader_get_entrypoint(shader));
827       nir_builder b_inner;
828       nir_builder_init(&b_inner, nir_shader_get_entrypoint(shader));
829       b_inner.cursor = nir_after_block(last_block);
830       nir_rt_return_amd(&b_inner);
831    }
832 
833    NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
834             nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
835             glsl_get_natural_size_align_bytes);
836 
837    NIR_PASS(_, shader, lower_rt_derefs);
838 
839    NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp,
840             nir_address_format_32bit_offset);
841 
842    return shader;
843 }
844 
845 static nir_function_impl *
lower_any_hit_for_intersection(nir_shader *any_hit)846 lower_any_hit_for_intersection(nir_shader *any_hit)
847 {
848    nir_function_impl *impl = nir_shader_get_entrypoint(any_hit);
849 
850    /* Any-hit shaders need three parameters */
851    assert(impl->function->num_params == 0);
852    nir_parameter params[] = {
853       {
854          /* A pointer to a boolean value for whether or not the hit was
855           * accepted.
856           */
857          .num_components = 1,
858          .bit_size = 32,
859       },
860       {
861          /* The hit T value */
862          .num_components = 1,
863          .bit_size = 32,
864       },
865       {
866          /* The hit kind */
867          .num_components = 1,
868          .bit_size = 32,
869       },
870    };
871    impl->function->num_params = ARRAY_SIZE(params);
872    impl->function->params = ralloc_array(any_hit, nir_parameter, ARRAY_SIZE(params));
873    memcpy(impl->function->params, params, sizeof(params));
874 
875    nir_builder build;
876    nir_builder_init(&build, impl);
877    nir_builder *b = &build;
878 
879    b->cursor = nir_before_cf_list(&impl->body);
880 
881    nir_ssa_def *commit_ptr = nir_load_param(b, 0);
882    nir_ssa_def *hit_t = nir_load_param(b, 1);
883    nir_ssa_def *hit_kind = nir_load_param(b, 2);
884 
885    nir_deref_instr *commit =
886       nir_build_deref_cast(b, commit_ptr, nir_var_function_temp, glsl_bool_type(), 0);
887 
888    nir_foreach_block_safe (block, impl) {
889       nir_foreach_instr_safe (instr, block) {
890          switch (instr->type) {
891          case nir_instr_type_intrinsic: {
892             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
893             switch (intrin->intrinsic) {
894             case nir_intrinsic_ignore_ray_intersection:
895                b->cursor = nir_instr_remove(&intrin->instr);
896                /* We put the newly emitted code inside a dummy if because it's
897                 * going to contain a jump instruction and we don't want to
898                 * deal with that mess here.  It'll get dealt with by our
899                 * control-flow optimization passes.
900                 */
901                nir_store_deref(b, commit, nir_imm_false(b), 0x1);
902                nir_push_if(b, nir_imm_true(b));
903                nir_jump(b, nir_jump_return);
904                nir_pop_if(b, NULL);
905                break;
906 
907             case nir_intrinsic_terminate_ray:
908                /* The "normal" handling of terminateRay works fine in
909                 * intersection shaders.
910                 */
911                break;
912 
913             case nir_intrinsic_load_ray_t_max:
914                nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_t);
915                nir_instr_remove(&intrin->instr);
916                break;
917 
918             case nir_intrinsic_load_ray_hit_kind:
919                nir_ssa_def_rewrite_uses(&intrin->dest.ssa, hit_kind);
920                nir_instr_remove(&intrin->instr);
921                break;
922 
923             default:
924                break;
925             }
926             break;
927          }
928          case nir_instr_type_jump: {
929             nir_jump_instr *jump = nir_instr_as_jump(instr);
930             if (jump->type == nir_jump_halt) {
931                b->cursor = nir_instr_remove(instr);
932                nir_jump(b, nir_jump_return);
933             }
934             break;
935          }
936 
937          default:
938             break;
939          }
940       }
941    }
942 
943    nir_validate_shader(any_hit, "after initial any-hit lowering");
944 
945    nir_lower_returns_impl(impl);
946 
947    nir_validate_shader(any_hit, "after lowering returns");
948 
949    return impl;
950 }
951 
952 /* Inline the any_hit shader into the intersection shader so we don't have
953  * to implement yet another shader call interface here. Neither do any recursion.
954  */
955 static void
nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)956 nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
957 {
958    void *dead_ctx = ralloc_context(intersection);
959 
960    nir_function_impl *any_hit_impl = NULL;
961    struct hash_table *any_hit_var_remap = NULL;
962    if (any_hit) {
963       any_hit = nir_shader_clone(dead_ctx, any_hit);
964       NIR_PASS(_, any_hit, nir_opt_dce);
965       any_hit_impl = lower_any_hit_for_intersection(any_hit);
966       any_hit_var_remap = _mesa_pointer_hash_table_create(dead_ctx);
967    }
968 
969    nir_function_impl *impl = nir_shader_get_entrypoint(intersection);
970 
971    nir_builder build;
972    nir_builder_init(&build, impl);
973    nir_builder *b = &build;
974 
975    b->cursor = nir_before_cf_list(&impl->body);
976 
977    nir_variable *commit = nir_local_variable_create(impl, glsl_bool_type(), "ray_commit");
978    nir_store_var(b, commit, nir_imm_false(b), 0x1);
979 
980    nir_foreach_block_safe (block, impl) {
981       nir_foreach_instr_safe (instr, block) {
982          if (instr->type != nir_instr_type_intrinsic)
983             continue;
984 
985          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
986          if (intrin->intrinsic != nir_intrinsic_report_ray_intersection)
987             continue;
988 
989          b->cursor = nir_instr_remove(&intrin->instr);
990          nir_ssa_def *hit_t = nir_ssa_for_src(b, intrin->src[0], 1);
991          nir_ssa_def *hit_kind = nir_ssa_for_src(b, intrin->src[1], 1);
992          nir_ssa_def *min_t = nir_load_ray_t_min(b);
993          nir_ssa_def *max_t = nir_load_ray_t_max(b);
994 
995          /* bool commit_tmp = false; */
996          nir_variable *commit_tmp = nir_local_variable_create(impl, glsl_bool_type(), "commit_tmp");
997          nir_store_var(b, commit_tmp, nir_imm_false(b), 0x1);
998 
999          nir_push_if(b, nir_iand(b, nir_fge(b, hit_t, min_t), nir_fge(b, max_t, hit_t)));
1000          {
1001             /* Any-hit defaults to commit */
1002             nir_store_var(b, commit_tmp, nir_imm_true(b), 0x1);
1003 
1004             if (any_hit_impl != NULL) {
1005                nir_push_if(b, nir_inot(b, nir_load_intersection_opaque_amd(b)));
1006                {
1007                   nir_ssa_def *params[] = {
1008                      &nir_build_deref_var(b, commit_tmp)->dest.ssa,
1009                      hit_t,
1010                      hit_kind,
1011                   };
1012                   nir_inline_function_impl(b, any_hit_impl, params, any_hit_var_remap);
1013                }
1014                nir_pop_if(b, NULL);
1015             }
1016 
1017             nir_push_if(b, nir_load_var(b, commit_tmp));
1018             {
1019                nir_report_ray_intersection(b, 1, hit_t, hit_kind);
1020             }
1021             nir_pop_if(b, NULL);
1022          }
1023          nir_pop_if(b, NULL);
1024 
1025          nir_ssa_def *accepted = nir_load_var(b, commit_tmp);
1026          nir_ssa_def_rewrite_uses(&intrin->dest.ssa, accepted);
1027       }
1028    }
1029 
1030    /* We did some inlining; have to re-index SSA defs */
1031    nir_index_ssa_defs(impl);
1032 
1033    /* Eliminate the casts introduced for the commit return of the any-hit shader. */
1034    NIR_PASS(_, intersection, nir_opt_deref);
1035 
1036    ralloc_free(dead_ctx);
1037 }
1038 
1039 /* Variables only used internally to ray traversal. This is data that describes
1040  * the current state of the traversal vs. what we'd give to a shader.  e.g. what
1041  * is the instance we're currently visiting vs. what is the instance of the
1042  * closest hit. */
1043 struct rt_traversal_vars {
1044    nir_variable *origin;
1045    nir_variable *dir;
1046    nir_variable *inv_dir;
1047    nir_variable *sbt_offset_and_flags;
1048    nir_variable *instance_id;
1049    nir_variable *custom_instance_and_mask;
1050    nir_variable *instance_addr;
1051    nir_variable *hit;
1052    nir_variable *bvh_base;
1053    nir_variable *stack;
1054    nir_variable *top_stack;
1055 };
1056 
1057 static struct rt_traversal_vars
init_traversal_vars(nir_builder *b)1058 init_traversal_vars(nir_builder *b)
1059 {
1060    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1061    struct rt_traversal_vars ret;
1062 
1063    ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
1064    ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
1065    ret.inv_dir =
1066       nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
1067    ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1068                                                   "traversal_sbt_offset_and_flags");
1069    ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1070                                          "traversal_instance_id");
1071    ret.custom_instance_and_mask = nir_variable_create(
1072       b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
1073    ret.instance_addr =
1074       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1075    ret.hit = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(), "traversal_hit");
1076    ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
1077                                       "traversal_bvh_base");
1078    ret.stack =
1079       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
1080    ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
1081                                        "traversal_top_stack_ptr");
1082    return ret;
1083 }
1084 
1085 static void
visit_any_hit_shaders(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, struct rt_variables *vars)1086 visit_any_hit_shaders(struct radv_device *device,
1087                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1088                       struct rt_variables *vars)
1089 {
1090    nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
1091 
1092    nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
1093    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
1094       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
1095       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
1096 
1097       switch (group_info->type) {
1098       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
1099          shader_id = group_info->anyHitShader;
1100          break;
1101       default:
1102          break;
1103       }
1104       if (shader_id == VK_SHADER_UNUSED_KHR)
1105          continue;
1106 
1107       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
1108       nir_shader *nir_stage = parse_rt_stage(device, stage);
1109 
1110       vars->stage_idx = shader_id;
1111       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
1112    }
1113    nir_pop_if(b, NULL);
1114 }
1115 
1116 static void
insert_traversal_triangle_case(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b, nir_ssa_def *result, const struct rt_variables *vars, const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)1117 insert_traversal_triangle_case(struct radv_device *device,
1118                                const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1119                                nir_ssa_def *result, const struct rt_variables *vars,
1120                                const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
1121 {
1122    nir_ssa_def *dist = nir_channel(b, result, 0);
1123    nir_ssa_def *div = nir_channel(b, result, 1);
1124    dist = nir_fdiv(b, dist, div);
1125    nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
1126    nir_ssa_def *switch_ccw =
1127       nir_test_mask(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1128                     VK_GEOMETRY_INSTANCE_TRIANGLE_FLIP_FACING_BIT_KHR << 24);
1129    frontface = nir_ixor(b, frontface, switch_ccw);
1130 
1131    nir_ssa_def *not_cull =
1132       nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipTrianglesKHRMask));
1133    nir_ssa_def *not_facing_cull = nir_ieq_imm(
1134       b,
1135       nir_iand(b, nir_load_var(b, vars->flags),
1136                nir_bcsel(b, frontface, nir_imm_int(b, SpvRayFlagsCullFrontFacingTrianglesKHRMask),
1137                          nir_imm_int(b, SpvRayFlagsCullBackFacingTrianglesKHRMask))),
1138       0);
1139 
1140    not_cull = nir_iand(
1141       b, not_cull,
1142       nir_ior(b, not_facing_cull,
1143               nir_test_mask(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1144                             VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR << 24)));
1145 
1146    nir_push_if(b, nir_iand(b,
1147                            nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
1148                                     nir_flt(b, nir_load_var(b, vars->tmin), dist)),
1149                            not_cull));
1150    {
1151 
1152       nir_ssa_def *triangle_info =
1153          nir_build_load_global(b, 2, 32,
1154                                nir_iadd_imm(b, build_node_to_addr(device, b, bvh_node),
1155                                             offsetof(struct radv_bvh_triangle_node, triangle_id)));
1156       nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
1157       nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
1158       nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff);
1159       nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1160                                              nir_load_var(b, vars->flags), geometry_id_and_flags);
1161 
1162       not_cull =
1163          nir_ieq_imm(b,
1164                      nir_iand(b, nir_load_var(b, vars->flags),
1165                               nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask),
1166                                         nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))),
1167                      0);
1168       nir_push_if(b, not_cull);
1169       {
1170          nir_ssa_def *sbt_idx = nir_iadd(
1171             b,
1172             nir_iadd(b, nir_load_var(b, vars->sbt_offset),
1173                      nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)),
1174             nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
1175          nir_ssa_def *divs[2] = {div, div};
1176          nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
1177          nir_ssa_def *hit_kind =
1178             nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
1179 
1180          nir_store_scratch(
1181             b, ij, nir_iadd_imm(b, nir_load_var(b, vars->stack_ptr), RADV_HIT_ATTRIB_OFFSET),
1182             .align_mul = 16);
1183 
1184          nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
1185          nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1);
1186 
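         /* For non-opaque triangles, run the matching any-hit shader with its own set of
          * inner variables; if it rejects the hit, continue traversal. */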
1187          nir_push_if(b, nir_inot(b, is_opaque));
1188          {
1189             struct rt_variables inner_vars = create_inner_vars(b, vars);
1190 
1191             nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
1192             nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
1193             nir_store_var(b, inner_vars.tmax, dist, 0x1);
1194             nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1195             nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
1196                           0x1);
1197             nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
1198             nir_store_var(b, inner_vars.custom_instance_and_mask,
1199                           nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1200 
1201             load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
1202 
1203             visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);
1204 
1205             nir_push_if(b, nir_inot(b, nir_load_var(b, vars->ahit_accept)));
1206             {
1207                nir_jump(b, nir_jump_continue);
1208             }
1209             nir_pop_if(b, NULL);
1210          }
1211          nir_pop_if(b, NULL);
1212 
1213          nir_store_var(b, vars->primitive_id, primitive_id, 1);
1214          nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
1215          nir_store_var(b, vars->tmax, dist, 0x1);
1216          nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1217          nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1218          nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
1219          nir_store_var(b, vars->custom_instance_and_mask,
1220                        nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1221 
1222          nir_store_var(b, vars->idx, sbt_idx, 1);
1223          nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1);
1224 
1225          nir_ssa_def *terminate_on_first_hit =
1226             nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
1227          nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate);
1228          nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1229          {
1230             nir_jump(b, nir_jump_break);
1231          }
1232          nir_pop_if(b, NULL);
1233       }
1234       nir_pop_if(b, NULL);
1235    }
1236    nir_pop_if(b, NULL);
1237 }
1238 
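/* Handle an AABB (procedural) leaf node during traversal: look up the hit group, run its
 * intersection shader (with the any-hit shader inlined into it), or fall back to a plain
 * ray/box test when no intersection shader is bound, and commit the hit if it was accepted.
 */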
1239 static void
1240 insert_traversal_aabb_case(struct radv_device *device,
1241                            const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
1242                            const struct rt_variables *vars,
1243                            const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
1244 {
1245    nir_ssa_def *node_addr = build_node_to_addr(device, b, bvh_node);
1246    nir_ssa_def *triangle_info = nir_build_load_global(b, 2, 32, nir_iadd_imm(b, node_addr, 24));
1247    nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
1248    nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
1249    nir_ssa_def *geometry_id = nir_iand_imm(b, geometry_id_and_flags, 0xfffffff);
1250    nir_ssa_def *is_opaque = hit_is_opaque(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
1251                                           nir_load_var(b, vars->flags), geometry_id_and_flags);
1252 
1253    nir_ssa_def *not_skip_aabb =
1254       nir_inot(b, nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsSkipAABBsKHRMask));
1255    nir_ssa_def *not_cull = nir_iand(
1256       b, not_skip_aabb,
1257       nir_ieq_imm(b,
1258                   nir_iand(b, nir_load_var(b, vars->flags),
1259                            nir_bcsel(b, is_opaque, nir_imm_int(b, SpvRayFlagsCullOpaqueKHRMask),
1260                                      nir_imm_int(b, SpvRayFlagsCullNoOpaqueKHRMask))),
1261                   0));
1262    nir_push_if(b, not_cull);
1263    {
1264       nir_ssa_def *sbt_idx = nir_iadd(
1265          b,
1266          nir_iadd(b, nir_load_var(b, vars->sbt_offset),
1267                   nir_iand_imm(b, nir_load_var(b, trav_vars->sbt_offset_and_flags), 0xffffff)),
1268          nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
1269 
1270       struct rt_variables inner_vars = create_inner_vars(b, vars);
1271 
1272       /* For AABBs the intersection shader writes the hit kind, and only does so if the
1273        * candidate is the next closest hit. */
1274       inner_vars.hit_kind = vars->hit_kind;
1275 
1276       nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
1277       nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
1278       nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
1279       nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1280       nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1281       nir_store_var(b, inner_vars.custom_instance_and_mask,
1282                     nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1283       nir_store_var(b, inner_vars.opaque, is_opaque, 1);
1284 
1285       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
1286 
1287       nir_store_var(b, vars->ahit_accept, nir_imm_false(b), 0x1);
1288       nir_store_var(b, vars->ahit_terminate, nir_imm_false(b), 0x1);
1289 
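      /* A non-zero shader index means an intersection shader exists for this SBT entry:
       * inline each procedural hit group's intersection shader (with its any-hit shader
       * lowered into it) and select the right one by index. */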
1290       nir_push_if(b, nir_ine_imm(b, nir_load_var(b, inner_vars.idx), 0));
1291       for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
1292          const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
1293          uint32_t shader_id = VK_SHADER_UNUSED_KHR;
1294          uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
1295 
1296          switch (group_info->type) {
1297          case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
1298             shader_id = group_info->intersectionShader;
1299             any_hit_shader_id = group_info->anyHitShader;
1300             break;
1301          default:
1302             break;
1303          }
1304          if (shader_id == VK_SHADER_UNUSED_KHR)
1305             continue;
1306 
1307          const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
1308          nir_shader *nir_stage = parse_rt_stage(device, stage);
1309 
1310          nir_shader *any_hit_stage = NULL;
1311          if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
1312             stage = &pCreateInfo->pStages[any_hit_shader_id];
1313             any_hit_stage = parse_rt_stage(device, stage);
1314 
1315             nir_lower_intersection_shader(nir_stage, any_hit_stage);
1316             ralloc_free(any_hit_stage);
1317          }
1318 
1319          inner_vars.stage_idx = shader_id;
1320          insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
1321       }
1322       nir_push_else(b, NULL);
1323       {
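         /* No intersection shader is bound, so fall back to a slab test against the node's
          * AABB and treat the clamped entry distance as the hit distance. */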
1324          nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
1325          nir_ssa_def *vec3_inf =
1326             nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);
1327 
1328          nir_ssa_def *bvh_lo = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 0));
1329          nir_ssa_def *bvh_hi = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 12));
1330 
1331          bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
1332          bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
1333          nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
1334                                        nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
1335          nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
1336                                         nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
1337          /* If the ray runs parallel to one of the box's slabs, the range should be [0, inf), not [0, 0]. */
1338          t2_vec =
1339             nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);
1340 
1341          nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
1342          t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));
1343 
1344          nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
1345          t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));
1346 
1347          nir_push_if(b, nir_iand(b, nir_fge(b, nir_load_var(b, vars->tmax), t_min),
1348                                  nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
1349          {
1350             nir_store_var(b, vars->ahit_accept, nir_imm_true(b), 0x1);
1351             nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
1352          }
1353          nir_pop_if(b, NULL);
1354       }
1355       nir_pop_if(b, NULL);
1356 
1357       nir_push_if(b, nir_load_var(b, vars->ahit_accept));
1358       {
1359          nir_store_var(b, vars->primitive_id, primitive_id, 1);
1360          nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
1361          nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
1362          nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
1363          nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
1364          nir_store_var(b, vars->custom_instance_and_mask,
1365                        nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
1366 
1367          nir_store_var(b, vars->idx, sbt_idx, 1);
1368          nir_store_var(b, trav_vars->hit, nir_imm_true(b), 1);
1369 
1370          nir_ssa_def *terminate_on_first_hit =
1371             nir_test_mask(b, nir_load_var(b, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
1372          nir_ssa_def *ray_terminated = nir_load_var(b, vars->ahit_terminate);
1373          nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
1374          {
1375             nir_jump(b, nir_jump_break);
1376          }
1377          nir_pop_if(b, NULL);
1378       }
1379       nir_pop_if(b, NULL);
1380    }
1381    nir_pop_if(b, NULL);
1382 }
1383 
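/* Build a standalone compute shader that traverses the acceleration structure for one ray:
 * it walks the two-level BVH with a shared-memory stack, tests triangle and AABB leaves,
 * and finally loads the closest-hit or miss SBT entry for the follow-up shader.
 */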
1384 static nir_shader *
1385 build_traversal_shader(struct radv_device *device,
1386                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1387                        const struct rt_variables *dst_vars,
1388                        struct hash_table *var_remap)
1389 {
1390    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
1391    b.shader->info.internal = false;
1392    b.shader->info.workgroup_size[0] = 8;
1393    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
1394    struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, dst_vars->stack_sizes);
1395    map_rt_variables(var_remap, &vars, dst_vars);
1396 
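   /* The traversal stack lives in shared memory and is interleaved per lane: lane i uses
    * entries i, i + wave_size, i + 2 * wave_size, ... */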
1397    unsigned lanes = device->physical_device->rt_wave_size;
1398    unsigned elements = lanes * MAX_STACK_ENTRY_COUNT;
1399    nir_variable *stack_var = nir_variable_create(b.shader, nir_var_mem_shared,
1400                                                  glsl_array_type(glsl_uint_type(), elements, 0),
1401                                                  "trav_stack");
1402    nir_deref_instr *stack_deref = nir_build_deref_var(&b, stack_var);
1403    nir_deref_instr *stack;
1404    nir_ssa_def *stack_idx_stride = nir_imm_int(&b, lanes);
1405    nir_ssa_def *stack_idx_base = nir_load_local_invocation_index(&b);
1406 
1407    nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
1408 
1409    struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
1410 
1411    nir_store_var(&b, trav_vars.hit, nir_imm_false(&b), 1);
1412 
1413    nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
1414    {
1415       nir_store_var(&b, trav_vars.bvh_base, build_addr_to_node(&b, accel_struct), 1);
1416 
1417       nir_ssa_def *bvh_root = nir_build_load_global(
1418          &b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);
1419 
1420       nir_ssa_def *desc = create_bvh_descriptor(&b);
1421       nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);
1422 
1423       nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
1424       nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
1425       nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
1426       nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
1427       nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
1428 
1429       nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_idx_base, stack_idx_stride), 1);
1430       stack = nir_build_deref_array(&b, stack_deref, stack_idx_base);
1431       nir_store_deref(&b, stack, bvh_root, 0x1);
1432 
1433       nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
1434 
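      /* Main traversal loop: pop one node per iteration until this lane's stack is empty. */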
1435       nir_push_loop(&b);
1436 
1437       nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_idx_base));
1438       nir_jump(&b, nir_jump_break);
1439       nir_pop_if(&b, NULL);
1440 
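      /* If the stack has shrunk back to the level recorded when we entered the BLAS, we are
       * done with that instance: return to the TLAS base and the world-space ray. */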
1441       nir_push_if(
1442          &b, nir_uge(&b, nir_load_var(&b, trav_vars.top_stack), nir_load_var(&b, trav_vars.stack)));
1443       nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
1444       nir_store_var(&b, trav_vars.bvh_base,
1445                     build_addr_to_node(&b, nir_load_var(&b, vars.accel_struct)), 1);
1446       nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
1447       nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
1448       nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
1449       nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
1450 
1451       nir_pop_if(&b, NULL);
1452 
1453       nir_store_var(&b, trav_vars.stack,
1454                     nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
1455 
1456       stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
1457       nir_ssa_def *bvh_node = nir_load_deref(&b, stack);
1458       nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7);
1459 
1460       bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64));
1461       nir_ssa_def *intrinsic_result = NULL;
1462       if (!radv_emulate_rt(device->physical_device)) {
1463          intrinsic_result = nir_bvh64_intersect_ray_amd(
1464             &b, 32, desc, nir_unpack_64_2x32(&b, bvh_node), nir_load_var(&b, vars.tmax),
1465             nir_load_var(&b, trav_vars.origin), nir_load_var(&b, trav_vars.dir),
1466             nir_load_var(&b, trav_vars.inv_dir));
1467       }
1468 
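      /* The node type is stored in the low 3 bits of the node id: values with bit 2 clear are
       * triangle nodes, 4/5 are internal box nodes, 6 is an instance and 7 is an AABB leaf. */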
1469       nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 4), 0));
1470       {
1471          nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 2), 0));
1472          {
1473             /* custom */
1474             nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 1), 0));
1475             if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)) {
1476                insert_traversal_aabb_case(device, pCreateInfo, &b, &vars, &trav_vars, bvh_node);
1477             }
1478             nir_push_else(&b, NULL);
1479             {
1480                /* instance */
1481                nir_ssa_def *instance_node_addr = build_node_to_addr(device, &b, bvh_node);
1482                nir_ssa_def *instance_data =
1483                   nir_build_load_global(&b, 4, 32, instance_node_addr, .align_mul = 64);
1484                nir_ssa_def *wto_matrix[] = {
1485                   nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 16),
1486                                         .align_mul = 64, .align_offset = 16),
1487                   nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 32),
1488                                         .align_mul = 64, .align_offset = 32),
1489                   nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 48),
1490                                         .align_mul = 64, .align_offset = 48)};
1491                nir_ssa_def *instance_id =
1492                   nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, instance_node_addr, 88));
1493                nir_ssa_def *instance_and_mask = nir_channel(&b, instance_data, 2);
1494                nir_ssa_def *instance_mask = nir_ushr_imm(&b, instance_and_mask, 24);
1495 
1496                nir_push_if(
1497                   &b,
1498                   nir_ieq_imm(&b, nir_iand(&b, instance_mask, nir_load_var(&b, vars.cull_mask)), 0));
1499                nir_jump(&b, nir_jump_continue);
1500                nir_pop_if(&b, NULL);
1501 
1502                nir_store_var(&b, trav_vars.top_stack, nir_load_var(&b, trav_vars.stack), 1);
1503                nir_store_var(&b, trav_vars.bvh_base,
1504                              build_addr_to_node(
1505                                 &b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
1506                              1);
1507                stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
1508                nir_store_deref(&b, stack, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63), 0x1);
1509 
1510                nir_store_var(&b, trav_vars.stack,
1511                              nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
1512 
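               /* Transform the ray into the instance's object space for traversal of its BLAS. */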
1513                nir_store_var(
1514                   &b, trav_vars.origin,
1515                   nir_build_vec3_mat_mult_pre(&b, nir_load_var(&b, vars.origin), wto_matrix), 7);
1516                nir_store_var(
1517                   &b, trav_vars.dir,
1518                   nir_build_vec3_mat_mult(&b, nir_load_var(&b, vars.direction), wto_matrix, false),
1519                   7);
1520                nir_store_var(&b, trav_vars.inv_dir,
1521                              nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
1522                nir_store_var(&b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
1523                nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_channel(&b, instance_data, 3),
1524                              1);
1525                nir_store_var(&b, trav_vars.instance_id, instance_id, 1);
1526                nir_store_var(&b, trav_vars.instance_addr, instance_node_addr, 1);
1527             }
1528             nir_pop_if(&b, NULL);
1529          }
1530          nir_push_else(&b, NULL);
1531          {
1532             /* box */
1533             nir_ssa_def *result = intrinsic_result;
1534             if (!result) {
1535                /* If we didn't run the intrinsic because the hardware doesn't support it,
1536                 * emulate the ray/box intersection here. */
1537                result = intersect_ray_amd_software_box(device,
1538                   &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
1539                   nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
1540             }
1541 
1542             for (unsigned i = 4; i-- > 0; ) {
1543                nir_ssa_def *new_node = nir_channel(&b, result, i);
1544                nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
1545                {
1546                   stack = nir_build_deref_array(&b, stack_deref, nir_load_var(&b, trav_vars.stack));
1547                   nir_store_deref(&b, stack, new_node, 0x1);
1548                   nir_store_var(
1549                      &b, trav_vars.stack,
1550                      nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_idx_stride), 1);
1551                }
1552                nir_pop_if(&b, NULL);
1553             }
1554          }
1555          nir_pop_if(&b, NULL);
1556       }
1557       nir_push_else(&b, NULL);
1558       if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)) {
1559          nir_ssa_def *result = intrinsic_result;
1560          if (!result) {
1561             /* If we didn't run the intrinsic because the hardware doesn't support it,
1562              * emulate the ray/triangle intersection here. */
1563             result = intersect_ray_amd_software_tri(device,
1564                &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
1565                nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
1566          }
1567          insert_traversal_triangle_case(device, pCreateInfo, &b, result, &vars, &trav_vars, bvh_node);
1568       }
1569       nir_pop_if(&b, NULL);
1570 
1571       nir_pop_loop(&b, NULL);
1572    }
1573    nir_pop_if(&b, NULL);
1574 
1575    /* Initialize follow-up shader. */
1576    nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
1577    {
1578       /* vars.idx contains the SBT index at this point. */
1579       load_sbt_entry(&b, &vars, nir_load_var(&b, vars.idx), SBT_HIT, 0);
1580 
1581       nir_ssa_def *should_return = nir_ior(&b,
1582                                            nir_test_mask(&b, nir_load_var(&b, vars.flags),
1583                                                          SpvRayFlagsSkipClosestHitShaderKHRMask),
1584                                            nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
1585 
1586       /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
1587        * need to return immediately to the calling shader. */
1588       nir_push_if(&b, should_return);
1589       {
1590          insert_rt_return(&b, &vars);
1591       }
1592       nir_pop_if(&b, NULL);
1593    }
1594    nir_push_else(&b, NULL);
1595    {
1596       /* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
1597        * for miss shaders if none of the rays miss. */
1598       load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
1599    }
1600    nir_pop_if(&b, NULL);
1601 
1602    return b.shader;
1603 }
1604 
1605 
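/* Inline the traversal shader into the caller as the vars->idx == 1 case of the main dispatch loop. */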
1606 static void
1607 insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1608                  nir_builder *b, const struct rt_variables *vars)
1609 {
1610    struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
1611    nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);
1612 
1613    /* For now, just inline the traversal shader */
1614    nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
1615    nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
1616    nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
1617    nir_pop_if(b, NULL);
1618 
1619    /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
1620    ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
1621 
1622    ralloc_free(var_remap);
1623 }
1624 
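/* Compute a worst-case scratch stack size for the pipeline: the raygen stack, plus one
 * recursion level that may also hold the non-recursive (any-hit/intersection) stack, plus
 * the remaining recursion levels of closest-hit/miss, plus two levels of callable shaders.
 */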
1625 static unsigned
1626 compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1627                       const struct radv_pipeline_shader_stack_size *stack_sizes)
1628 {
1629    unsigned raygen_size = 0;
1630    unsigned callable_size = 0;
1631    unsigned chit_size = 0;
1632    unsigned miss_size = 0;
1633    unsigned non_recursive_size = 0;
1634 
1635    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
1636       non_recursive_size = MAX2(stack_sizes[i].non_recursive_size, non_recursive_size);
1637 
1638       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
1639       uint32_t shader_id = VK_SHADER_UNUSED_KHR;
1640       unsigned size = stack_sizes[i].recursive_size;
1641 
1642       switch (group_info->type) {
1643       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
1644          shader_id = group_info->generalShader;
1645          break;
1646       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
1647       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
1648          shader_id = group_info->closestHitShader;
1649          break;
1650       default:
1651          break;
1652       }
1653       if (shader_id == VK_SHADER_UNUSED_KHR)
1654          continue;
1655 
1656       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
1657       switch (stage->stage) {
1658       case VK_SHADER_STAGE_RAYGEN_BIT_KHR:
1659          raygen_size = MAX2(raygen_size, size);
1660          break;
1661       case VK_SHADER_STAGE_MISS_BIT_KHR:
1662          miss_size = MAX2(miss_size, size);
1663          break;
1664       case VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR:
1665          chit_size = MAX2(chit_size, size);
1666          break;
1667       case VK_SHADER_STAGE_CALLABLE_BIT_KHR:
1668          callable_size = MAX2(callable_size, size);
1669          break;
1670       default:
1671          unreachable("Invalid stage type in RT shader");
1672       }
1673    }
1674    return raygen_size +
1675           MIN2(pCreateInfo->maxPipelineRayRecursionDepth, 1) *
1676              MAX2(MAX2(chit_size, miss_size), non_recursive_size) +
1677           MAX2(0, (int)(pCreateInfo->maxPipelineRayRecursionDepth) - 1) *
1678              MAX2(chit_size, miss_size) +
1679           2 * callable_size;
1680 }
1681 
1682 bool
1683 radv_rt_pipeline_has_dynamic_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo)
1684 {
1685    if (!pCreateInfo->pDynamicState)
1686       return false;
1687 
1688    for (unsigned i = 0; i < pCreateInfo->pDynamicState->dynamicStateCount; ++i) {
1689       if (pCreateInfo->pDynamicState->pDynamicStates[i] ==
1690           VK_DYNAMIC_STATE_RAY_TRACING_PIPELINE_STACK_SIZE_KHR)
1691          return true;
1692    }
1693 
1694    return false;
1695 }
1696 
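/* Ray tracing system values produced by rt_trace_ray; their loads have to be hoisted so they
 * are read before any nested trace call can overwrite them.
 */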
1697 static bool
1698 should_move_rt_instruction(nir_intrinsic_op intrinsic)
1699 {
1700    switch (intrinsic) {
1701    case nir_intrinsic_load_rt_arg_scratch_offset_amd:
1702    case nir_intrinsic_load_ray_flags:
1703    case nir_intrinsic_load_ray_object_origin:
1704    case nir_intrinsic_load_ray_world_origin:
1705    case nir_intrinsic_load_ray_t_min:
1706    case nir_intrinsic_load_ray_object_direction:
1707    case nir_intrinsic_load_ray_world_direction:
1708    case nir_intrinsic_load_ray_t_max:
1709       return true;
1710    default:
1711       return false;
1712    }
1713 }
1714 
1715 static void
1716 move_rt_instructions(nir_shader *shader)
1717 {
1718    nir_cursor target = nir_before_cf_list(&nir_shader_get_entrypoint(shader)->body);
1719 
1720    nir_foreach_block (block, nir_shader_get_entrypoint(shader)) {
1721       nir_foreach_instr_safe (instr, block) {
1722          if (instr->type != nir_instr_type_intrinsic)
1723             continue;
1724 
1725          nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
1726 
1727          if (!should_move_rt_instruction(intrinsic->intrinsic))
1728             continue;
1729 
1730          nir_instr_move(target, instr);
1731       }
1732    }
1733 
1734    nir_metadata_preserve(nir_shader_get_entrypoint(shader),
1735                          nir_metadata_all & (~nir_metadata_instr_index));
1736 }
1737 
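/* Build the monolithic "rt_combined" compute shader: a dispatch loop that keys on vars.idx to
 * run the raygen, closest-hit, miss and callable stages (plus the resume shaders produced by
 * nir_lower_shader_calls), with the traversal shader inlined as its own case.
 */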
1738 static nir_shader *
1739 create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1740                  struct radv_pipeline_shader_stack_size *stack_sizes)
1741 {
1742    struct radv_pipeline_key key;
1743    memset(&key, 0, sizeof(key));
1744 
1745    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined");
1746    b.shader->info.internal = false;
1747    b.shader->info.workgroup_size[0] = 8;
1748    b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
1749 
1750    struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
1751    load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, 0);
1752    nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
1753 
1754    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
1755 
1756    nir_loop *loop = nir_push_loop(&b);
1757 
1758    nir_push_if(&b, nir_ior(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0),
1759                            nir_inot(&b, nir_load_var(&b, vars.main_loop_case_visited))));
1760    nir_jump(&b, nir_jump_break);
1761    nir_pop_if(&b, NULL);
1762 
1763    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);
1764 
1765    insert_traversal(device, pCreateInfo, &b, &vars);
1766 
1767    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
1768 
1769    /* We do a trick with the indexing of the resume shaders so that the first
1770     * shader of stage x always gets id x and the resume shader ids then come after
1771     * stageCount. This makes the shader group handles independent of compilation. */
1772    unsigned call_idx_base = pCreateInfo->stageCount + 1;
1773    for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
1774       const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
1775       gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
1776       if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
1777           type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
1778          continue;
1779 
1780       nir_shader *nir_stage = parse_rt_stage(device, stage);
1781 
1782       /* Move the ray tracing system values set by rt_trace_ray to the top, to
1783        * prevent them from being overwritten by other rt_trace_ray calls.
1784        */
1785       NIR_PASS_V(nir_stage, move_rt_instructions);
1786 
1787       uint32_t num_resume_shaders = 0;
1788       nir_shader **resume_shaders = NULL;
1789       nir_lower_shader_calls(nir_stage, nir_address_format_32bit_offset, 16, &resume_shaders,
1790                              &num_resume_shaders, nir_stage);
1791 
1792       vars.stage_idx = i;
1793       insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
1794       for (unsigned j = 0; j < num_resume_shaders; ++j) {
1795          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, call_idx_base + 1 + j);
1796       }
1797       call_idx_base += num_resume_shaders;
1798    }
1799 
1800    nir_pop_loop(&b, loop);
1801 
1802    if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo)) {
1803       /* Put something so scratch gets enabled in the shader. */
1804       b.shader->scratch_size = 16;
1805    } else
1806       b.shader->scratch_size = compute_rt_stack_size(pCreateInfo, stack_sizes);
1807 
1808    /* Deal with all the inline functions. */
1809    nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
1810    nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
1811 
1812    return b.shader;
1813 }
1814 
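/* Create a ray tracing pipeline by lowering it to a compute pipeline: try the pipeline cache
 * first, and only generate the combined NIR shader if compilation is actually required.
 */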
1815 static VkResult
1816 radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
1817                         const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
1818                         const VkAllocationCallbacks *pAllocator, VkPipeline *pPipeline)
1819 {
1820    RADV_FROM_HANDLE(radv_device, device, _device);
1821    VkResult result;
1822    struct radv_pipeline *pipeline = NULL;
1823    struct radv_compute_pipeline *compute_pipeline = NULL;
1824    struct radv_pipeline_shader_stack_size *stack_sizes = NULL;
1825    uint8_t hash[20];
1826    nir_shader *shader = NULL;
1827    bool keep_statistic_info =
1828       (pCreateInfo->flags & VK_PIPELINE_CREATE_CAPTURE_STATISTICS_BIT_KHR) ||
1829       (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADER_STATS) || device->keep_shader_info;
1830 
1831    if (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR)
1832       return radv_rt_pipeline_library_create(_device, _cache, pCreateInfo, pAllocator, pPipeline);
1833 
1834    VkRayTracingPipelineCreateInfoKHR local_create_info =
1835       radv_create_merged_rt_create_info(pCreateInfo);
1836    if (!local_create_info.pStages || !local_create_info.pGroups) {
1837       result = VK_ERROR_OUT_OF_HOST_MEMORY;
1838       goto fail;
1839    }
1840 
1841    radv_hash_rt_shaders(hash, &local_create_info, radv_get_hash_flags(device, keep_statistic_info));
1842    struct vk_shader_module module = {.base.type = VK_OBJECT_TYPE_SHADER_MODULE};
1843 
1844    VkPipelineShaderStageRequiredSubgroupSizeCreateInfo subgroup_size = {
1845       .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO,
1846       .pNext = NULL,
1847       .requiredSubgroupSize = device->physical_device->rt_wave_size,
1848    };
1849 
1850    VkComputePipelineCreateInfo compute_info = {
1851       .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1852       .pNext = NULL,
1853       .flags = pCreateInfo->flags | VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT,
1854       .stage =
1855          {
1856             .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1857             .pNext = &subgroup_size,
1858             .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1859             .module = vk_shader_module_to_handle(&module),
1860             .pName = "main",
1861          },
1862       .layout = pCreateInfo->layout,
1863    };
1864 
1865    /* First check if we can get things from the cache before we take the expensive step of
1866     * generating the NIR. */
1867    result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
1868                                          stack_sizes, local_create_info.groupCount, pPipeline);
1869 
1870    if (result == VK_PIPELINE_COMPILE_REQUIRED) {
1871       if (pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
1872          goto fail;
1873 
1874       stack_sizes = calloc(sizeof(*stack_sizes), local_create_info.groupCount);
1875       if (!stack_sizes) {
1876          result = VK_ERROR_OUT_OF_HOST_MEMORY;
1877          goto fail;
1878       }
1879 
1880       shader = create_rt_shader(device, &local_create_info, stack_sizes);
1881       module.nir = shader;
1882       compute_info.flags = pCreateInfo->flags;
1883       result = radv_compute_pipeline_create(_device, _cache, &compute_info, pAllocator, hash,
1884                                             stack_sizes, local_create_info.groupCount, pPipeline);
1885       stack_sizes = NULL;
1886 
1887       if (result != VK_SUCCESS)
1888          goto shader_fail;
1889    }
1890    pipeline = radv_pipeline_from_handle(*pPipeline);
1891    compute_pipeline = radv_pipeline_to_compute(pipeline);
1892 
1893    compute_pipeline->rt_group_handles =
1894       calloc(sizeof(*compute_pipeline->rt_group_handles), local_create_info.groupCount);
1895    if (!compute_pipeline->rt_group_handles) {
1896       result = VK_ERROR_OUT_OF_HOST_MEMORY;
1897       goto shader_fail;
1898    }
1899 
1900    compute_pipeline->dynamic_stack_size = radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo);
1901 
1902    /* For General and ClosestHit shaders, we can use the shader ID directly as handle.
1903     * As (potentially different) AnyHit shaders are inlined, for Intersection shaders
1904     * we use the Group ID.
1905     */
1906    for (unsigned i = 0; i < local_create_info.groupCount; ++i) {
1907       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &local_create_info.pGroups[i];
1908       switch (group_info->type) {
1909       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
1910          if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
1911             compute_pipeline->rt_group_handles[i].handles[0] = group_info->generalShader + 2;
1912          break;
1913       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
1914          if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR)
1915             compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
1916          FALLTHROUGH;
1917       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
1918          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
1919             compute_pipeline->rt_group_handles[i].handles[0] = group_info->closestHitShader + 2;
1920          if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
1921             compute_pipeline->rt_group_handles[i].handles[1] = i + 2;
1922          break;
1923       case VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR:
1924          unreachable("VK_RAY_TRACING_SHADER_GROUP_TYPE_MAX_ENUM_KHR");
1925       }
1926    }
1927 
1928 shader_fail:
1929    if (result != VK_SUCCESS && pipeline)
1930       radv_pipeline_destroy(device, pipeline, pAllocator);
1931    ralloc_free(shader);
1932 fail:
1933    free((void *)local_create_info.pGroups);
1934    free((void *)local_create_info.pStages);
1935    free(stack_sizes);
1936    return result;
1937 }
1938 
1939 VKAPI_ATTR VkResult VKAPI_CALL
1940 radv_CreateRayTracingPipelinesKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
1941                                   VkPipelineCache pipelineCache, uint32_t count,
1942                                   const VkRayTracingPipelineCreateInfoKHR *pCreateInfos,
1943                                   const VkAllocationCallbacks *pAllocator, VkPipeline *pPipelines)
1944 {
1945    VkResult result = VK_SUCCESS;
1946 
1947    unsigned i = 0;
1948    for (; i < count; i++) {
1949       VkResult r;
1950       r = radv_rt_pipeline_create(_device, pipelineCache, &pCreateInfos[i], pAllocator,
1951                                   &pPipelines[i]);
1952       if (r != VK_SUCCESS) {
1953          result = r;
1954          pPipelines[i] = VK_NULL_HANDLE;
1955 
1956          if (pCreateInfos[i].flags & VK_PIPELINE_CREATE_EARLY_RETURN_ON_FAILURE_BIT)
1957             break;
1958       }
1959    }
1960 
1961    for (; i < count; ++i)
1962       pPipelines[i] = VK_NULL_HANDLE;
1963 
1964    return result;
1965 }
1966 
1967 VKAPI_ATTR VkResult VKAPI_CALL
1968 radv_GetRayTracingShaderGroupHandlesKHR(VkDevice device, VkPipeline _pipeline, uint32_t firstGroup,
1969                                         uint32_t groupCount, size_t dataSize, void *pData)
1970 {
1971    RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
1972    struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
1973    char *data = pData;
1974 
1975    STATIC_ASSERT(sizeof(*compute_pipeline->rt_group_handles) <= RADV_RT_HANDLE_SIZE);
1976 
1977    memset(data, 0, groupCount * RADV_RT_HANDLE_SIZE);
1978 
1979    for (uint32_t i = 0; i < groupCount; ++i) {
1980       memcpy(data + i * RADV_RT_HANDLE_SIZE, &compute_pipeline->rt_group_handles[firstGroup + i],
1981              sizeof(*compute_pipeline->rt_group_handles));
1982    }
1983 
1984    return VK_SUCCESS;
1985 }
1986 
1987 VKAPI_ATTR VkDeviceSize VKAPI_CALL
1988 radv_GetRayTracingShaderGroupStackSizeKHR(VkDevice device, VkPipeline _pipeline, uint32_t group,
1989                                           VkShaderGroupShaderKHR groupShader)
1990 {
1991    RADV_FROM_HANDLE(radv_pipeline, pipeline, _pipeline);
1992    struct radv_compute_pipeline *compute_pipeline = radv_pipeline_to_compute(pipeline);
1993    const struct radv_pipeline_shader_stack_size *stack_size =
1994       &compute_pipeline->rt_stack_sizes[group];
1995 
1996    if (groupShader == VK_SHADER_GROUP_SHADER_ANY_HIT_KHR ||
1997        groupShader == VK_SHADER_GROUP_SHADER_INTERSECTION_KHR)
1998       return stack_size->non_recursive_size;
1999    else
2000       return stack_size->recursive_size;
2001 }
2002 
2003 VKAPI_ATTR VkResult VKAPI_CALL
2004 radv_GetRayTracingCaptureReplayShaderGroupHandlesKHR(VkDevice _device, VkPipeline pipeline,
2005                                                      uint32_t firstGroup, uint32_t groupCount,
2006                                                      size_t dataSize, void *pData)
2007 {
2008    RADV_FROM_HANDLE(radv_device, device, _device);
2009    unreachable("Unimplemented");
2010    return vk_error(device, VK_ERROR_FEATURE_NOT_PRESENT);
2011 }
2012