1/*
2 * Copyright © 2021 Bas Nieuwenhuizen
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23#include "radv_acceleration_structure.h"
24#include "radv_private.h"
25
26#include "util/format/format_utils.h"
27#include "util/half_float.h"
28#include "nir_builder.h"
29#include "radv_cs.h"
30#include "radv_meta.h"
31
32#include "radix_sort/radv_radix_sort.h"
33
34/* Min and max bounds of the bvh used to compute morton codes */
35#define SCRATCH_TOTAL_BOUNDS_SIZE (6 * sizeof(float))
36
37enum accel_struct_build {
38   accel_struct_build_unoptimized,
39   accel_struct_build_lbvh,
40};
41
42static enum accel_struct_build
43get_accel_struct_build(const struct radv_physical_device *pdevice,
44                       VkAccelerationStructureBuildTypeKHR buildType)
45{
46   return buildType == VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR
47             ? accel_struct_build_lbvh
48             : accel_struct_build_unoptimized;
49}
50
51static uint32_t
52get_node_id_stride(enum accel_struct_build build_mode)
53{
54   switch (build_mode) {
55   case accel_struct_build_unoptimized:
56      return 4;
57   case accel_struct_build_lbvh:
58      return 8;
59   default:
60      unreachable("Unhandled accel_struct_build!");
61   }
62}
63
64VKAPI_ATTR void VKAPI_CALL
65radv_GetAccelerationStructureBuildSizesKHR(
66   VkDevice _device, VkAccelerationStructureBuildTypeKHR buildType,
67   const VkAccelerationStructureBuildGeometryInfoKHR *pBuildInfo,
68   const uint32_t *pMaxPrimitiveCounts, VkAccelerationStructureBuildSizesInfoKHR *pSizeInfo)
69{
70   RADV_FROM_HANDLE(radv_device, device, _device);
71
72   uint64_t triangles = 0, boxes = 0, instances = 0;
73
74   STATIC_ASSERT(sizeof(struct radv_bvh_triangle_node) == 64);
75   STATIC_ASSERT(sizeof(struct radv_bvh_aabb_node) == 64);
76   STATIC_ASSERT(sizeof(struct radv_bvh_instance_node) == 128);
77   STATIC_ASSERT(sizeof(struct radv_bvh_box16_node) == 64);
78   STATIC_ASSERT(sizeof(struct radv_bvh_box32_node) == 128);
79
80   for (uint32_t i = 0; i < pBuildInfo->geometryCount; ++i) {
81      const VkAccelerationStructureGeometryKHR *geometry;
82      if (pBuildInfo->pGeometries)
83         geometry = &pBuildInfo->pGeometries[i];
84      else
85         geometry = pBuildInfo->ppGeometries[i];
86
87      switch (geometry->geometryType) {
88      case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
89         triangles += pMaxPrimitiveCounts[i];
90         break;
91      case VK_GEOMETRY_TYPE_AABBS_KHR:
92         boxes += pMaxPrimitiveCounts[i];
93         break;
94      case VK_GEOMETRY_TYPE_INSTANCES_KHR:
95         instances += pMaxPrimitiveCounts[i];
96         break;
97      case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
98         unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
99      }
100   }
101
102   uint64_t children = boxes + instances + triangles;
103   /* Initialize to 1 to have enought space for the root node. */
104   uint64_t internal_nodes = 1;
105   while (children > 1) {
106      children = DIV_ROUND_UP(children, 4);
107      internal_nodes += children;
108   }
109
110   uint64_t size = boxes * 128 + instances * 128 + triangles * 64 + internal_nodes * 128 +
111                   ALIGN(sizeof(struct radv_accel_struct_header), 64);
112
113   pSizeInfo->accelerationStructureSize = size;
114
115   /* 2x the max number of nodes in a BVH layer and order information for sorting when using
116    * LBVH (one uint32_t each, two buffers) plus space to store the bounds.
117    * LBVH is only supported for device builds and hardware that supports global atomics.
118    */
119   enum accel_struct_build build_mode = get_accel_struct_build(device->physical_device, buildType);
120   uint32_t node_id_stride = get_node_id_stride(build_mode);
121
122   uint32_t leaf_count = boxes + instances + triangles;
123   VkDeviceSize scratchSize = 2 * leaf_count * node_id_stride;
124
125   if (build_mode == accel_struct_build_lbvh) {
126      radix_sort_vk_memory_requirements_t requirements;
127      radix_sort_vk_get_memory_requirements(device->meta_state.accel_struct_build.radix_sort,
128                                            leaf_count, &requirements);
129
130      /* Make sure we have the space required by the radix sort. */
131      scratchSize = MAX2(scratchSize, requirements.keyvals_size * 2);
132
133      scratchSize += requirements.internal_size + SCRATCH_TOTAL_BOUNDS_SIZE;
134   }
135
136   scratchSize = MAX2(4096, scratchSize);
137   pSizeInfo->updateScratchSize = scratchSize;
138   pSizeInfo->buildScratchSize = scratchSize;
139}
140
141VKAPI_ATTR VkResult VKAPI_CALL
142radv_CreateAccelerationStructureKHR(VkDevice _device,
143                                    const VkAccelerationStructureCreateInfoKHR *pCreateInfo,
144                                    const VkAllocationCallbacks *pAllocator,
145                                    VkAccelerationStructureKHR *pAccelerationStructure)
146{
147   RADV_FROM_HANDLE(radv_device, device, _device);
148   RADV_FROM_HANDLE(radv_buffer, buffer, pCreateInfo->buffer);
149   struct radv_acceleration_structure *accel;
150
151   accel = vk_alloc2(&device->vk.alloc, pAllocator, sizeof(*accel), 8,
152                     VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
153   if (accel == NULL)
154      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
155
156   vk_object_base_init(&device->vk, &accel->base, VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_KHR);
157
158   accel->mem_offset = buffer->offset + pCreateInfo->offset;
159   accel->size = pCreateInfo->size;
160   accel->bo = buffer->bo;
161
162   *pAccelerationStructure = radv_acceleration_structure_to_handle(accel);
163   return VK_SUCCESS;
164}
165
166VKAPI_ATTR void VKAPI_CALL
167radv_DestroyAccelerationStructureKHR(VkDevice _device,
168                                     VkAccelerationStructureKHR accelerationStructure,
169                                     const VkAllocationCallbacks *pAllocator)
170{
171   RADV_FROM_HANDLE(radv_device, device, _device);
172   RADV_FROM_HANDLE(radv_acceleration_structure, accel, accelerationStructure);
173
174   if (!accel)
175      return;
176
177   vk_object_base_finish(&accel->base);
178   vk_free2(&device->vk.alloc, pAllocator, accel);
179}
180
181VKAPI_ATTR VkDeviceAddress VKAPI_CALL
182radv_GetAccelerationStructureDeviceAddressKHR(
183   VkDevice _device, const VkAccelerationStructureDeviceAddressInfoKHR *pInfo)
184{
185   RADV_FROM_HANDLE(radv_acceleration_structure, accel, pInfo->accelerationStructure);
186   return radv_accel_struct_get_va(accel);
187}
188
189VKAPI_ATTR VkResult VKAPI_CALL
190radv_WriteAccelerationStructuresPropertiesKHR(
191   VkDevice _device, uint32_t accelerationStructureCount,
192   const VkAccelerationStructureKHR *pAccelerationStructures, VkQueryType queryType,
193   size_t dataSize, void *pData, size_t stride)
194{
195   RADV_FROM_HANDLE(radv_device, device, _device);
196   char *data_out = (char *)pData;
197
198   for (uint32_t i = 0; i < accelerationStructureCount; ++i) {
199      RADV_FROM_HANDLE(radv_acceleration_structure, accel, pAccelerationStructures[i]);
200      const char *base_ptr = (const char *)device->ws->buffer_map(accel->bo);
201      if (!base_ptr)
202         return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
203
204      const struct radv_accel_struct_header *header = (const void *)(base_ptr + accel->mem_offset);
205      if (stride * i + sizeof(VkDeviceSize) <= dataSize) {
206         uint64_t value;
207         switch (queryType) {
208         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_COMPACTED_SIZE_KHR:
209            value = header->compacted_size;
210            break;
211         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR:
212            value = header->serialization_size;
213            break;
214         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_BOTTOM_LEVEL_POINTERS_KHR:
215            value = header->instance_count;
216            break;
217         case VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SIZE_KHR:
218            value = header->size;
219            break;
220         default:
221            unreachable("Unhandled acceleration structure query");
222         }
223         *(VkDeviceSize *)(data_out + stride * i) = value;
224      }
225      device->ws->buffer_unmap(accel->bo);
226   }
227   return VK_SUCCESS;
228}
229
230struct radv_bvh_build_ctx {
231   uint32_t *write_scratch;
232   char *base;
233   char *curr_ptr;
234};
235
236static void
237build_triangles(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
238                const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
239{
240   const VkAccelerationStructureGeometryTrianglesDataKHR *tri_data = &geom->geometry.triangles;
241   VkTransformMatrixKHR matrix;
242   const char *index_data = (const char *)tri_data->indexData.hostAddress;
243   const char *v_data_base = (const char *)tri_data->vertexData.hostAddress;
244
245   if (tri_data->indexType == VK_INDEX_TYPE_NONE_KHR)
246      v_data_base += range->primitiveOffset;
247   else
248      index_data += range->primitiveOffset;
249
250   if (tri_data->transformData.hostAddress) {
251      matrix = *(const VkTransformMatrixKHR *)((const char *)tri_data->transformData.hostAddress +
252                                               range->transformOffset);
253   } else {
254      matrix = (VkTransformMatrixKHR){
255         .matrix = {{1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0}}};
256   }
257
258   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
259      struct radv_bvh_triangle_node *node = (void *)ctx->curr_ptr;
260      uint32_t node_offset = ctx->curr_ptr - ctx->base;
261      uint32_t node_id = node_offset >> 3;
262      *ctx->write_scratch++ = node_id;
263
264      for (unsigned v = 0; v < 3; ++v) {
265         uint32_t v_index = range->firstVertex;
266         switch (tri_data->indexType) {
267         case VK_INDEX_TYPE_NONE_KHR:
268            v_index += p * 3 + v;
269            break;
270         case VK_INDEX_TYPE_UINT8_EXT:
271            v_index += *(const uint8_t *)index_data;
272            index_data += 1;
273            break;
274         case VK_INDEX_TYPE_UINT16:
275            v_index += *(const uint16_t *)index_data;
276            index_data += 2;
277            break;
278         case VK_INDEX_TYPE_UINT32:
279            v_index += *(const uint32_t *)index_data;
280            index_data += 4;
281            break;
282         case VK_INDEX_TYPE_MAX_ENUM:
283            unreachable("Unhandled VK_INDEX_TYPE_MAX_ENUM");
284            break;
285         }
286
287         const char *v_data = v_data_base + v_index * tri_data->vertexStride;
288         float coords[4];
289         switch (tri_data->vertexFormat) {
290         case VK_FORMAT_R32G32_SFLOAT:
291            coords[0] = *(const float *)(v_data + 0);
292            coords[1] = *(const float *)(v_data + 4);
293            coords[2] = 0.0f;
294            coords[3] = 1.0f;
295            break;
296         case VK_FORMAT_R32G32B32_SFLOAT:
297            coords[0] = *(const float *)(v_data + 0);
298            coords[1] = *(const float *)(v_data + 4);
299            coords[2] = *(const float *)(v_data + 8);
300            coords[3] = 1.0f;
301            break;
302         case VK_FORMAT_R32G32B32A32_SFLOAT:
303            coords[0] = *(const float *)(v_data + 0);
304            coords[1] = *(const float *)(v_data + 4);
305            coords[2] = *(const float *)(v_data + 8);
306            coords[3] = *(const float *)(v_data + 12);
307            break;
308         case VK_FORMAT_R16G16_SFLOAT:
309            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
310            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
311            coords[2] = 0.0f;
312            coords[3] = 1.0f;
313            break;
314         case VK_FORMAT_R16G16B16_SFLOAT:
315            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
316            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
317            coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
318            coords[3] = 1.0f;
319            break;
320         case VK_FORMAT_R16G16B16A16_SFLOAT:
321            coords[0] = _mesa_half_to_float(*(const uint16_t *)(v_data + 0));
322            coords[1] = _mesa_half_to_float(*(const uint16_t *)(v_data + 2));
323            coords[2] = _mesa_half_to_float(*(const uint16_t *)(v_data + 4));
324            coords[3] = _mesa_half_to_float(*(const uint16_t *)(v_data + 6));
325            break;
326         case VK_FORMAT_R16G16_SNORM:
327            coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
328            coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
329            coords[2] = 0.0f;
330            coords[3] = 1.0f;
331            break;
332         case VK_FORMAT_R16G16_UNORM:
333            coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
334            coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
335            coords[2] = 0.0f;
336            coords[3] = 1.0f;
337            break;
338         case VK_FORMAT_R16G16B16A16_SNORM:
339            coords[0] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 0), 16);
340            coords[1] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 2), 16);
341            coords[2] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 4), 16);
342            coords[3] = _mesa_snorm_to_float(*(const int16_t *)(v_data + 6), 16);
343            break;
344         case VK_FORMAT_R16G16B16A16_UNORM:
345            coords[0] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 0), 16);
346            coords[1] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 2), 16);
347            coords[2] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 4), 16);
348            coords[3] = _mesa_unorm_to_float(*(const uint16_t *)(v_data + 6), 16);
349            break;
350         case VK_FORMAT_R8G8_SNORM:
351            coords[0] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 0), 8);
352            coords[1] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 1), 8);
353            coords[2] = 0.0f;
354            coords[3] = 1.0f;
355            break;
356         case VK_FORMAT_R8G8_UNORM:
357            coords[0] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 0), 8);
358            coords[1] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 1), 8);
359            coords[2] = 0.0f;
360            coords[3] = 1.0f;
361            break;
362         case VK_FORMAT_R8G8B8A8_SNORM:
363            coords[0] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 0), 8);
364            coords[1] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 1), 8);
365            coords[2] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 2), 8);
366            coords[3] = _mesa_snorm_to_float(*(const int8_t *)(v_data + 3), 8);
367            break;
368         case VK_FORMAT_R8G8B8A8_UNORM:
369            coords[0] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 0), 8);
370            coords[1] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 1), 8);
371            coords[2] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 2), 8);
372            coords[3] = _mesa_unorm_to_float(*(const uint8_t *)(v_data + 3), 8);
373            break;
374         case VK_FORMAT_A2B10G10R10_UNORM_PACK32: {
375            uint32_t val = *(const uint32_t *)v_data;
376            coords[0] = _mesa_unorm_to_float((val >> 0) & 0x3FF, 10);
377            coords[1] = _mesa_unorm_to_float((val >> 10) & 0x3FF, 10);
378            coords[2] = _mesa_unorm_to_float((val >> 20) & 0x3FF, 10);
379            coords[3] = _mesa_unorm_to_float((val >> 30) & 0x3, 2);
380         } break;
381         default:
382            unreachable("Unhandled vertex format in BVH build");
383         }
384
385         for (unsigned j = 0; j < 3; ++j) {
386            float r = 0;
387            for (unsigned k = 0; k < 4; ++k)
388               r += matrix.matrix[j][k] * coords[k];
389            node->coords[v][j] = r;
390         }
391
392         node->triangle_id = p;
393         node->geometry_id_and_flags = geometry_id | (geom->flags << 28);
394
395         /* Seems to be needed for IJ, otherwise I = J = ? */
396         node->id = 9;
397      }
398   }
399}
400
401static VkResult
402build_instances(struct radv_device *device, struct radv_bvh_build_ctx *ctx,
403                const VkAccelerationStructureGeometryKHR *geom,
404                const VkAccelerationStructureBuildRangeInfoKHR *range)
405{
406   const VkAccelerationStructureGeometryInstancesDataKHR *inst_data = &geom->geometry.instances;
407
408   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 128) {
409      const char *instance_data =
410         (const char *)inst_data->data.hostAddress + range->primitiveOffset;
411      const VkAccelerationStructureInstanceKHR *instance =
412         inst_data->arrayOfPointers
413            ? (((const VkAccelerationStructureInstanceKHR *const *)instance_data)[p])
414            : &((const VkAccelerationStructureInstanceKHR *)instance_data)[p];
415      if (!instance->accelerationStructureReference) {
416         continue;
417      }
418
419      struct radv_bvh_instance_node *node = (void *)ctx->curr_ptr;
420      uint32_t node_offset = ctx->curr_ptr - ctx->base;
421      uint32_t node_id = (node_offset >> 3) | radv_bvh_node_instance;
422      *ctx->write_scratch++ = node_id;
423
424      float transform[16], inv_transform[16];
425      memcpy(transform, &instance->transform.matrix, sizeof(instance->transform.matrix));
426      transform[12] = transform[13] = transform[14] = 0.0f;
427      transform[15] = 1.0f;
428
429      util_invert_mat4x4(inv_transform, transform);
430      memcpy(node->wto_matrix, inv_transform, sizeof(node->wto_matrix));
431      node->wto_matrix[3] = transform[3];
432      node->wto_matrix[7] = transform[7];
433      node->wto_matrix[11] = transform[11];
434      node->custom_instance_and_mask = instance->instanceCustomIndex | (instance->mask << 24);
435      node->sbt_offset_and_flags =
436         instance->instanceShaderBindingTableRecordOffset | (instance->flags << 24);
437      node->instance_id = p;
438
439      for (unsigned i = 0; i < 3; ++i)
440         for (unsigned j = 0; j < 3; ++j)
441            node->otw_matrix[i * 3 + j] = instance->transform.matrix[j][i];
442
443      RADV_FROM_HANDLE(radv_acceleration_structure, src_accel_struct,
444                       (VkAccelerationStructureKHR)instance->accelerationStructureReference);
445      const void *src_base = device->ws->buffer_map(src_accel_struct->bo);
446      if (!src_base)
447         return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
448
449      src_base = (const char *)src_base + src_accel_struct->mem_offset;
450      const struct radv_accel_struct_header *src_header = src_base;
451      node->base_ptr = radv_accel_struct_get_va(src_accel_struct) | src_header->root_node_offset;
452
453      for (unsigned j = 0; j < 3; ++j) {
454         node->aabb[0][j] = instance->transform.matrix[j][3];
455         node->aabb[1][j] = instance->transform.matrix[j][3];
456         for (unsigned k = 0; k < 3; ++k) {
457            node->aabb[0][j] += MIN2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
458                                     instance->transform.matrix[j][k] * src_header->aabb[1][k]);
459            node->aabb[1][j] += MAX2(instance->transform.matrix[j][k] * src_header->aabb[0][k],
460                                     instance->transform.matrix[j][k] * src_header->aabb[1][k]);
461         }
462      }
463      device->ws->buffer_unmap(src_accel_struct->bo);
464   }
465   return VK_SUCCESS;
466}
467
468static void
469build_aabbs(struct radv_bvh_build_ctx *ctx, const VkAccelerationStructureGeometryKHR *geom,
470            const VkAccelerationStructureBuildRangeInfoKHR *range, unsigned geometry_id)
471{
472   const VkAccelerationStructureGeometryAabbsDataKHR *aabb_data = &geom->geometry.aabbs;
473
474   for (uint32_t p = 0; p < range->primitiveCount; ++p, ctx->curr_ptr += 64) {
475      struct radv_bvh_aabb_node *node = (void *)ctx->curr_ptr;
476      uint32_t node_offset = ctx->curr_ptr - ctx->base;
477      uint32_t node_id = (node_offset >> 3) | radv_bvh_node_aabb;
478      *ctx->write_scratch++ = node_id;
479
480      const VkAabbPositionsKHR *aabb =
481         (const VkAabbPositionsKHR *)((const char *)aabb_data->data.hostAddress +
482                                      range->primitiveOffset + p * aabb_data->stride);
483
484      node->aabb[0][0] = aabb->minX;
485      node->aabb[0][1] = aabb->minY;
486      node->aabb[0][2] = aabb->minZ;
487      node->aabb[1][0] = aabb->maxX;
488      node->aabb[1][1] = aabb->maxY;
489      node->aabb[1][2] = aabb->maxZ;
490      node->primitive_id = p;
491      node->geometry_id_and_flags = geometry_id;
492   }
493}
494
495static uint32_t
496leaf_node_count(const VkAccelerationStructureBuildGeometryInfoKHR *info,
497                const VkAccelerationStructureBuildRangeInfoKHR *ranges)
498{
499   uint32_t count = 0;
500   for (uint32_t i = 0; i < info->geometryCount; ++i) {
501      count += ranges[i].primitiveCount;
502   }
503   return count;
504}
505
506static void
507compute_bounds(const char *base_ptr, uint32_t node_id, float *bounds)
508{
509   for (unsigned i = 0; i < 3; ++i)
510      bounds[i] = INFINITY;
511   for (unsigned i = 0; i < 3; ++i)
512      bounds[3 + i] = -INFINITY;
513
514   switch (node_id & 7) {
515   case radv_bvh_node_triangle: {
516      const struct radv_bvh_triangle_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
517      for (unsigned v = 0; v < 3; ++v) {
518         for (unsigned j = 0; j < 3; ++j) {
519            bounds[j] = MIN2(bounds[j], node->coords[v][j]);
520            bounds[3 + j] = MAX2(bounds[3 + j], node->coords[v][j]);
521         }
522      }
523      break;
524   }
525   case radv_bvh_node_internal: {
526      const struct radv_bvh_box32_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
527      for (unsigned c2 = 0; c2 < 4; ++c2) {
528         if (isnan(node->coords[c2][0][0]))
529            continue;
530         for (unsigned j = 0; j < 3; ++j) {
531            bounds[j] = MIN2(bounds[j], node->coords[c2][0][j]);
532            bounds[3 + j] = MAX2(bounds[3 + j], node->coords[c2][1][j]);
533         }
534      }
535      break;
536   }
537   case radv_bvh_node_instance: {
538      const struct radv_bvh_instance_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
539      for (unsigned j = 0; j < 3; ++j) {
540         bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
541         bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
542      }
543      break;
544   }
545   case radv_bvh_node_aabb: {
546      const struct radv_bvh_aabb_node *node = (const void *)(base_ptr + (node_id / 8 * 64));
547      for (unsigned j = 0; j < 3; ++j) {
548         bounds[j] = MIN2(bounds[j], node->aabb[0][j]);
549         bounds[3 + j] = MAX2(bounds[3 + j], node->aabb[1][j]);
550      }
551      break;
552   }
553   }
554}
555
556struct bvh_opt_entry {
557   uint64_t key;
558   uint32_t node_id;
559};
560
561static int
562bvh_opt_compare(const void *_a, const void *_b)
563{
564   const struct bvh_opt_entry *a = _a;
565   const struct bvh_opt_entry *b = _b;
566
567   if (a->key < b->key)
568      return -1;
569   if (a->key > b->key)
570      return 1;
571   if (a->node_id < b->node_id)
572      return -1;
573   if (a->node_id > b->node_id)
574      return 1;
575   return 0;
576}
577
578static void
579optimize_bvh(const char *base_ptr, uint32_t *node_ids, uint32_t node_count)
580{
581   if (node_count == 0)
582      return;
583
584   float bounds[6];
585   for (unsigned i = 0; i < 3; ++i)
586      bounds[i] = INFINITY;
587   for (unsigned i = 0; i < 3; ++i)
588      bounds[3 + i] = -INFINITY;
589
590   for (uint32_t i = 0; i < node_count; ++i) {
591      float node_bounds[6];
592      compute_bounds(base_ptr, node_ids[i], node_bounds);
593      for (unsigned j = 0; j < 3; ++j)
594         bounds[j] = MIN2(bounds[j], node_bounds[j]);
595      for (unsigned j = 0; j < 3; ++j)
596         bounds[3 + j] = MAX2(bounds[3 + j], node_bounds[3 + j]);
597   }
598
599   struct bvh_opt_entry *entries = calloc(node_count, sizeof(struct bvh_opt_entry));
600   if (!entries)
601      return;
602
603   for (uint32_t i = 0; i < node_count; ++i) {
604      float node_bounds[6];
605      compute_bounds(base_ptr, node_ids[i], node_bounds);
606      float node_coords[3];
607      for (unsigned j = 0; j < 3; ++j)
608         node_coords[j] = (node_bounds[j] + node_bounds[3 + j]) * 0.5;
609      int32_t coords[3];
610      for (unsigned j = 0; j < 3; ++j)
611         coords[j] = MAX2(
612            MIN2((int32_t)((node_coords[j] - bounds[j]) / (bounds[3 + j] - bounds[j]) * (1 << 21)),
613                 (1 << 21) - 1),
614            0);
615      uint64_t key = 0;
616      for (unsigned j = 0; j < 21; ++j)
617         for (unsigned k = 0; k < 3; ++k)
618            key |= (uint64_t)((coords[k] >> j) & 1) << (j * 3 + k);
619      entries[i].key = key;
620      entries[i].node_id = node_ids[i];
621   }
622
623   qsort(entries, node_count, sizeof(entries[0]), bvh_opt_compare);
624   for (unsigned i = 0; i < node_count; ++i)
625      node_ids[i] = entries[i].node_id;
626
627   free(entries);
628}
629
630static void
631fill_accel_struct_header(struct radv_accel_struct_header *header)
632{
633   /* 16 bytes per invocation, 64 invocations per workgroup */
634   header->copy_dispatch_size[0] = DIV_ROUND_UP(header->compacted_size, 16 * 64);
635   header->copy_dispatch_size[1] = 1;
636   header->copy_dispatch_size[2] = 1;
637
638   header->serialization_size =
639      header->compacted_size + align(sizeof(struct radv_accel_struct_serialization_header) +
640                                        sizeof(uint64_t) * header->instance_count,
641                                     128);
642
643   header->size = header->serialization_size -
644                  sizeof(struct radv_accel_struct_serialization_header) -
645                  sizeof(uint64_t) * header->instance_count;
646}
647
648static VkResult
649build_bvh(struct radv_device *device, const VkAccelerationStructureBuildGeometryInfoKHR *info,
650          const VkAccelerationStructureBuildRangeInfoKHR *ranges)
651{
652   RADV_FROM_HANDLE(radv_acceleration_structure, accel, info->dstAccelerationStructure);
653   VkResult result = VK_SUCCESS;
654
655   uint32_t *scratch[2];
656   scratch[0] = info->scratchData.hostAddress;
657   scratch[1] = scratch[0] + leaf_node_count(info, ranges);
658
659   char *base_ptr = (char *)device->ws->buffer_map(accel->bo);
660   if (!base_ptr)
661      return vk_error(device, VK_ERROR_OUT_OF_HOST_MEMORY);
662
663   base_ptr = base_ptr + accel->mem_offset;
664   struct radv_accel_struct_header *header = (void *)base_ptr;
665   void *first_node_ptr = (char *)base_ptr + ALIGN(sizeof(*header), 64);
666
667   struct radv_bvh_build_ctx ctx = {.write_scratch = scratch[0],
668                                    .base = base_ptr,
669                                    .curr_ptr = (char *)first_node_ptr + 128};
670
671   uint64_t instance_offset = (const char *)ctx.curr_ptr - (const char *)base_ptr;
672   uint64_t instance_count = 0;
673
674   /* This initializes the leaf nodes of the BVH all at the same level. */
675   for (int inst = 1; inst >= 0; --inst) {
676      for (uint32_t i = 0; i < info->geometryCount; ++i) {
677         const VkAccelerationStructureGeometryKHR *geom =
678            info->pGeometries ? &info->pGeometries[i] : info->ppGeometries[i];
679
680         if ((inst && geom->geometryType != VK_GEOMETRY_TYPE_INSTANCES_KHR) ||
681             (!inst && geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR))
682            continue;
683
684         switch (geom->geometryType) {
685         case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
686            build_triangles(&ctx, geom, ranges + i, i);
687            break;
688         case VK_GEOMETRY_TYPE_AABBS_KHR:
689            build_aabbs(&ctx, geom, ranges + i, i);
690            break;
691         case VK_GEOMETRY_TYPE_INSTANCES_KHR: {
692            result = build_instances(device, &ctx, geom, ranges + i);
693            if (result != VK_SUCCESS)
694               goto fail;
695
696            instance_count += ranges[i].primitiveCount;
697            break;
698         }
699         case VK_GEOMETRY_TYPE_MAX_ENUM_KHR:
700            unreachable("VK_GEOMETRY_TYPE_MAX_ENUM_KHR unhandled");
701         }
702      }
703   }
704
705   uint32_t node_counts[2] = {ctx.write_scratch - scratch[0], 0};
706   optimize_bvh(base_ptr, scratch[0], node_counts[0]);
707   unsigned d;
708
709   /*
710    * This is the most naive BVH building algorithm I could think of:
711    * just iteratively builds each level from bottom to top with
712    * the children of each node being in-order and tightly packed.
713    *
714    * Is probably terrible for traversal but should be easy to build an
715    * equivalent GPU version.
716    */
717   for (d = 0; node_counts[d & 1] > 1 || d == 0; ++d) {
718      uint32_t child_count = node_counts[d & 1];
719      const uint32_t *children = scratch[d & 1];
720      uint32_t *dst_ids = scratch[(d & 1) ^ 1];
721      unsigned dst_count;
722      unsigned child_idx = 0;
723      for (dst_count = 0; child_idx < MAX2(1, child_count); ++dst_count, child_idx += 4) {
724         unsigned local_child_count = MIN2(4, child_count - child_idx);
725         uint32_t child_ids[4];
726         float bounds[4][6];
727
728         for (unsigned c = 0; c < local_child_count; ++c) {
729            uint32_t id = children[child_idx + c];
730            child_ids[c] = id;
731
732            compute_bounds(base_ptr, id, bounds[c]);
733         }
734
735         struct radv_bvh_box32_node *node;
736
737         /* Put the root node at base_ptr so the id = 0, which allows some
738          * traversal optimizations. */
739         if (child_idx == 0 && local_child_count == child_count) {
740            node = first_node_ptr;
741            header->root_node_offset = ((char *)first_node_ptr - (char *)base_ptr) / 64 * 8 + 5;
742         } else {
743            uint32_t dst_id = (ctx.curr_ptr - base_ptr) / 64;
744            dst_ids[dst_count] = dst_id * 8 + 5;
745
746            node = (void *)ctx.curr_ptr;
747            ctx.curr_ptr += 128;
748         }
749
750         for (unsigned c = 0; c < local_child_count; ++c) {
751            node->children[c] = child_ids[c];
752            for (unsigned i = 0; i < 2; ++i)
753               for (unsigned j = 0; j < 3; ++j)
754                  node->coords[c][i][j] = bounds[c][i * 3 + j];
755         }
756         for (unsigned c = local_child_count; c < 4; ++c) {
757            for (unsigned i = 0; i < 2; ++i)
758               for (unsigned j = 0; j < 3; ++j)
759                  node->coords[c][i][j] = NAN;
760         }
761      }
762
763      node_counts[(d & 1) ^ 1] = dst_count;
764   }
765
766   compute_bounds(base_ptr, header->root_node_offset, &header->aabb[0][0]);
767
768   header->instance_offset = instance_offset;
769   header->instance_count = instance_count;
770   header->compacted_size = (char *)ctx.curr_ptr - base_ptr;
771
772   fill_accel_struct_header(header);
773
774fail:
775   device->ws->buffer_unmap(accel->bo);
776   return result;
777}
778
779VKAPI_ATTR VkResult VKAPI_CALL
780radv_BuildAccelerationStructuresKHR(
781   VkDevice _device, VkDeferredOperationKHR deferredOperation, uint32_t infoCount,
782   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
783   const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
784{
785   RADV_FROM_HANDLE(radv_device, device, _device);
786   VkResult result = VK_SUCCESS;
787
788   for (uint32_t i = 0; i < infoCount; ++i) {
789      result = build_bvh(device, pInfos + i, ppBuildRangeInfos[i]);
790      if (result != VK_SUCCESS)
791         break;
792   }
793   return result;
794}
795
796VKAPI_ATTR VkResult VKAPI_CALL
797radv_CopyAccelerationStructureKHR(VkDevice _device, VkDeferredOperationKHR deferredOperation,
798                                  const VkCopyAccelerationStructureInfoKHR *pInfo)
799{
800   RADV_FROM_HANDLE(radv_device, device, _device);
801   RADV_FROM_HANDLE(radv_acceleration_structure, src_struct, pInfo->src);
802   RADV_FROM_HANDLE(radv_acceleration_structure, dst_struct, pInfo->dst);
803
804   char *src_ptr = (char *)device->ws->buffer_map(src_struct->bo);
805   if (!src_ptr)
806      return VK_ERROR_OUT_OF_HOST_MEMORY;
807
808   char *dst_ptr = (char *)device->ws->buffer_map(dst_struct->bo);
809   if (!dst_ptr) {
810      device->ws->buffer_unmap(src_struct->bo);
811      return VK_ERROR_OUT_OF_HOST_MEMORY;
812   }
813
814   src_ptr += src_struct->mem_offset;
815   dst_ptr += dst_struct->mem_offset;
816
817   const struct radv_accel_struct_header *header = (const void *)src_ptr;
818   memcpy(dst_ptr, src_ptr, header->compacted_size);
819
820   device->ws->buffer_unmap(src_struct->bo);
821   device->ws->buffer_unmap(dst_struct->bo);
822   return VK_SUCCESS;
823}
824
825static nir_builder
826create_accel_build_shader(struct radv_device *device, const char *name)
827{
828   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "%s", name);
829   b.shader->info.workgroup_size[0] = 64;
830
831   assert(b.shader->info.workgroup_size[1] == 1);
832   assert(b.shader->info.workgroup_size[2] == 1);
833   assert(!b.shader->info.workgroup_size_variable);
834
835   return b;
836}
837
838static nir_ssa_def *
839get_indices(nir_builder *b, nir_ssa_def *addr, nir_ssa_def *type, nir_ssa_def *id)
840{
841   const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3);
842   nir_variable *result =
843      nir_variable_create(b->shader, nir_var_shader_temp, uvec3_type, "indices");
844
845   nir_push_if(b, nir_ult(b, type, nir_imm_int(b, 2)));
846   nir_push_if(b, nir_ieq_imm(b, type, VK_INDEX_TYPE_UINT16));
847   {
848      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 6));
849      nir_ssa_def *indices[3];
850      for (unsigned i = 0; i < 3; ++i) {
851         indices[i] = nir_build_load_global(
852            b, 1, 16, nir_iadd(b, addr, nir_u2u64(b, nir_iadd_imm(b, index_id, 2 * i))));
853      }
854      nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
855   }
856   nir_push_else(b, NULL);
857   {
858      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 12));
859      nir_ssa_def *indices =
860         nir_build_load_global(b, 3, 32, nir_iadd(b, addr, nir_u2u64(b, index_id)));
861      nir_store_var(b, result, indices, 7);
862   }
863   nir_pop_if(b, NULL);
864   nir_push_else(b, NULL);
865   {
866      nir_ssa_def *index_id = nir_umul24(b, id, nir_imm_int(b, 3));
867      nir_ssa_def *indices[] = {
868         index_id,
869         nir_iadd_imm(b, index_id, 1),
870         nir_iadd_imm(b, index_id, 2),
871      };
872
873      nir_push_if(b, nir_ieq_imm(b, type, VK_INDEX_TYPE_NONE_KHR));
874      {
875         nir_store_var(b, result, nir_vec(b, indices, 3), 7);
876      }
877      nir_push_else(b, NULL);
878      {
879         for (unsigned i = 0; i < 3; ++i) {
880            indices[i] =
881               nir_build_load_global(b, 1, 8, nir_iadd(b, addr, nir_u2u64(b, indices[i])));
882         }
883         nir_store_var(b, result, nir_u2u32(b, nir_vec(b, indices, 3)), 7);
884      }
885      nir_pop_if(b, NULL);
886   }
887   nir_pop_if(b, NULL);
888   return nir_load_var(b, result);
889}
890
891static void
892get_vertices(nir_builder *b, nir_ssa_def *addresses, nir_ssa_def *format, nir_ssa_def *positions[3])
893{
894   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
895   nir_variable *results[3] = {
896      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex0"),
897      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex1"),
898      nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "vertex2")};
899
900   VkFormat formats[] = {
901      VK_FORMAT_R32G32B32_SFLOAT,
902      VK_FORMAT_R32G32B32A32_SFLOAT,
903      VK_FORMAT_R16G16B16_SFLOAT,
904      VK_FORMAT_R16G16B16A16_SFLOAT,
905      VK_FORMAT_R16G16_SFLOAT,
906      VK_FORMAT_R32G32_SFLOAT,
907      VK_FORMAT_R16G16_SNORM,
908      VK_FORMAT_R16G16_UNORM,
909      VK_FORMAT_R16G16B16A16_SNORM,
910      VK_FORMAT_R16G16B16A16_UNORM,
911      VK_FORMAT_R8G8_SNORM,
912      VK_FORMAT_R8G8_UNORM,
913      VK_FORMAT_R8G8B8A8_SNORM,
914      VK_FORMAT_R8G8B8A8_UNORM,
915      VK_FORMAT_A2B10G10R10_UNORM_PACK32,
916   };
917
918   for (unsigned f = 0; f < ARRAY_SIZE(formats); ++f) {
919      if (f + 1 < ARRAY_SIZE(formats))
920         nir_push_if(b, nir_ieq_imm(b, format, formats[f]));
921
922      for (unsigned i = 0; i < 3; ++i) {
923         switch (formats[f]) {
924         case VK_FORMAT_R32G32B32_SFLOAT:
925         case VK_FORMAT_R32G32B32A32_SFLOAT:
926            nir_store_var(b, results[i],
927                          nir_build_load_global(b, 3, 32, nir_channel(b, addresses, i)), 7);
928            break;
929         case VK_FORMAT_R32G32_SFLOAT:
930         case VK_FORMAT_R16G16_SFLOAT:
931         case VK_FORMAT_R16G16B16_SFLOAT:
932         case VK_FORMAT_R16G16B16A16_SFLOAT:
933         case VK_FORMAT_R16G16_SNORM:
934         case VK_FORMAT_R16G16_UNORM:
935         case VK_FORMAT_R16G16B16A16_SNORM:
936         case VK_FORMAT_R16G16B16A16_UNORM:
937         case VK_FORMAT_R8G8_SNORM:
938         case VK_FORMAT_R8G8_UNORM:
939         case VK_FORMAT_R8G8B8A8_SNORM:
940         case VK_FORMAT_R8G8B8A8_UNORM:
941         case VK_FORMAT_A2B10G10R10_UNORM_PACK32: {
942            unsigned components = MIN2(3, vk_format_get_nr_components(formats[f]));
943            unsigned comp_bits =
944               vk_format_get_blocksizebits(formats[f]) / vk_format_get_nr_components(formats[f]);
945            unsigned comp_bytes = comp_bits / 8;
946            nir_ssa_def *values[3];
947            nir_ssa_def *addr = nir_channel(b, addresses, i);
948
949            if (formats[f] == VK_FORMAT_A2B10G10R10_UNORM_PACK32) {
950               comp_bits = 10;
951               nir_ssa_def *val = nir_build_load_global(b, 1, 32, addr);
952               for (unsigned j = 0; j < 3; ++j)
953                  values[j] = nir_ubfe(b, val, nir_imm_int(b, j * 10), nir_imm_int(b, 10));
954            } else {
955               for (unsigned j = 0; j < components; ++j)
956                  values[j] =
957                     nir_build_load_global(b, 1, comp_bits, nir_iadd_imm(b, addr, j * comp_bytes));
958
959               for (unsigned j = components; j < 3; ++j)
960                  values[j] = nir_imm_intN_t(b, 0, comp_bits);
961            }
962
963            nir_ssa_def *vec;
964            if (util_format_is_snorm(vk_format_to_pipe_format(formats[f]))) {
965               for (unsigned j = 0; j < 3; ++j) {
966                  values[j] = nir_fdiv(b, nir_i2f32(b, values[j]),
967                                       nir_imm_float(b, (1u << (comp_bits - 1)) - 1));
968                  values[j] = nir_fmax(b, values[j], nir_imm_float(b, -1.0));
969               }
970               vec = nir_vec(b, values, 3);
971            } else if (util_format_is_unorm(vk_format_to_pipe_format(formats[f]))) {
972               for (unsigned j = 0; j < 3; ++j) {
973                  values[j] =
974                     nir_fdiv(b, nir_u2f32(b, values[j]), nir_imm_float(b, (1u << comp_bits) - 1));
975                  values[j] = nir_fmin(b, values[j], nir_imm_float(b, 1.0));
976               }
977               vec = nir_vec(b, values, 3);
978            } else if (comp_bits == 16)
979               vec = nir_f2f32(b, nir_vec(b, values, 3));
980            else
981               vec = nir_vec(b, values, 3);
982            nir_store_var(b, results[i], vec, 7);
983            break;
984         }
985         default:
986            unreachable("Unhandled format");
987         }
988      }
989      if (f + 1 < ARRAY_SIZE(formats))
990         nir_push_else(b, NULL);
991   }
992   for (unsigned f = 1; f < ARRAY_SIZE(formats); ++f) {
993      nir_pop_if(b, NULL);
994   }
995
996   for (unsigned i = 0; i < 3; ++i)
997      positions[i] = nir_load_var(b, results[i]);
998}
999
1000struct build_primitive_constants {
1001   uint64_t node_dst_addr;
1002   uint64_t scratch_addr;
1003   uint32_t dst_offset;
1004   uint32_t dst_scratch_offset;
1005   uint32_t geometry_type;
1006   uint32_t geometry_id;
1007
1008   union {
1009      struct {
1010         uint64_t vertex_addr;
1011         uint64_t index_addr;
1012         uint64_t transform_addr;
1013         uint32_t vertex_stride;
1014         uint32_t vertex_format;
1015         uint32_t index_format;
1016      };
1017      struct {
1018         uint64_t instance_data;
1019         uint32_t array_of_pointers;
1020      };
1021      struct {
1022         uint64_t aabb_addr;
1023         uint32_t aabb_stride;
1024      };
1025   };
1026};
1027
1028struct bounds_constants {
1029   uint64_t node_addr;
1030   uint64_t scratch_addr;
1031};
1032
1033struct morton_constants {
1034   uint64_t node_addr;
1035   uint64_t scratch_addr;
1036};
1037
1038struct fill_constants {
1039   uint64_t addr;
1040   uint32_t value;
1041};
1042
1043struct build_internal_constants {
1044   uint64_t node_dst_addr;
1045   uint64_t scratch_addr;
1046   uint32_t dst_offset;
1047   uint32_t dst_scratch_offset;
1048   uint32_t src_scratch_offset;
1049   uint32_t fill_header;
1050};
1051
1052/* This inverts a 3x3 matrix using cofactors, as in e.g.
1053 * https://www.mathsisfun.com/algebra/matrix-inverse-minors-cofactors-adjugate.html */
1054static void
1055nir_invert_3x3(nir_builder *b, nir_ssa_def *in[3][3], nir_ssa_def *out[3][3])
1056{
1057   nir_ssa_def *cofactors[3][3];
1058   for (unsigned i = 0; i < 3; ++i) {
1059      for (unsigned j = 0; j < 3; ++j) {
1060         cofactors[i][j] =
1061            nir_fsub(b, nir_fmul(b, in[(i + 1) % 3][(j + 1) % 3], in[(i + 2) % 3][(j + 2) % 3]),
1062                     nir_fmul(b, in[(i + 1) % 3][(j + 2) % 3], in[(i + 2) % 3][(j + 1) % 3]));
1063      }
1064   }
1065
1066   nir_ssa_def *det = NULL;
1067   for (unsigned i = 0; i < 3; ++i) {
1068      nir_ssa_def *det_part = nir_fmul(b, in[0][i], cofactors[0][i]);
1069      det = det ? nir_fadd(b, det, det_part) : det_part;
1070   }
1071
1072   nir_ssa_def *det_inv = nir_frcp(b, det);
1073   for (unsigned i = 0; i < 3; ++i) {
1074      for (unsigned j = 0; j < 3; ++j) {
1075         out[i][j] = nir_fmul(b, cofactors[j][i], det_inv);
1076      }
1077   }
1078}
1079
1080static nir_ssa_def *
1081id_to_node_id_offset(nir_builder *b, nir_ssa_def *global_id,
1082                     const struct radv_physical_device *pdevice)
1083{
1084   uint32_t stride = get_node_id_stride(
1085      get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR));
1086
1087   return nir_imul_imm(b, global_id, stride);
1088}
1089
1090static nir_ssa_def *
1091id_to_morton_offset(nir_builder *b, nir_ssa_def *global_id,
1092                    const struct radv_physical_device *pdevice)
1093{
1094   enum accel_struct_build build_mode =
1095      get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1096   assert(build_mode == accel_struct_build_lbvh);
1097
1098   uint32_t stride = get_node_id_stride(build_mode);
1099
1100   return nir_iadd_imm(b, nir_imul_imm(b, global_id, stride), sizeof(uint32_t));
1101}
1102
1103static void
1104atomic_fminmax(struct radv_device *dev, nir_builder *b, nir_ssa_def *addr, bool is_max,
1105               nir_ssa_def *val)
1106{
1107   if (radv_has_shader_buffer_float_minmax(dev->physical_device)) {
1108      if (is_max)
1109         nir_global_atomic_fmax(b, 32, addr, val);
1110      else
1111         nir_global_atomic_fmin(b, 32, addr, val);
1112      return;
1113   }
1114
1115   /* Use an integer comparison to work correctly with negative zero. */
1116   val = nir_bcsel(b, nir_ilt(b, val, nir_imm_int(b, 0)),
1117                   nir_isub(b, nir_imm_int(b, -2147483648), val), val);
1118
1119   if (is_max)
1120      nir_global_atomic_imax(b, 32, addr, val);
1121   else
1122      nir_global_atomic_imin(b, 32, addr, val);
1123}
1124
1125static nir_ssa_def *
1126read_fminmax_atomic(struct radv_device *dev, nir_builder *b, unsigned channels, nir_ssa_def *addr)
1127{
1128   nir_ssa_def *val = nir_build_load_global(b, channels, 32, addr,
1129                                            .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1130
1131   if (radv_has_shader_buffer_float_minmax(dev->physical_device))
1132      return val;
1133
1134   return nir_bcsel(b, nir_ilt(b, val, nir_imm_int(b, 0)),
1135                    nir_isub(b, nir_imm_int(b, -2147483648), val), val);
1136}
1137
1138static nir_shader *
1139build_leaf_shader(struct radv_device *dev)
1140{
1141   enum accel_struct_build build_mode =
1142      get_accel_struct_build(dev->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1143
1144   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1145   nir_builder b = create_accel_build_shader(dev, "accel_build_leaf_shader");
1146
1147   nir_ssa_def *pconst0 =
1148      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1149   nir_ssa_def *pconst1 =
1150      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1151   nir_ssa_def *pconst2 =
1152      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 32, .range = 16);
1153   nir_ssa_def *pconst3 =
1154      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 48, .range = 16);
1155   nir_ssa_def *index_format =
1156      nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 64, .range = 4);
1157
1158   nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1159   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1160   nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1161   nir_ssa_def *scratch_offset = nir_channel(&b, pconst1, 1);
1162   nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
1163   nir_ssa_def *geometry_id = nir_channel(&b, pconst1, 3);
1164
1165   nir_ssa_def *global_id =
1166      nir_iadd(&b,
1167               nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1168                            b.shader->info.workgroup_size[0]),
1169               nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1170   nir_ssa_def *scratch_dst_addr =
1171      nir_iadd(&b, scratch_addr,
1172               nir_u2u64(&b, nir_iadd(&b, scratch_offset,
1173                                      id_to_node_id_offset(&b, global_id, dev->physical_device))));
1174   if (build_mode != accel_struct_build_unoptimized)
1175      scratch_dst_addr = nir_iadd_imm(&b, scratch_dst_addr, SCRATCH_TOTAL_BOUNDS_SIZE);
1176
1177   nir_variable *bounds[2] = {
1178      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1179      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1180   };
1181
1182   nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_TRIANGLES_KHR));
1183   { /* Triangles */
1184      nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
1185      nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b1100));
1186      nir_ssa_def *transform_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst3, 3));
1187      nir_ssa_def *vertex_stride = nir_channel(&b, pconst3, 2);
1188      nir_ssa_def *vertex_format = nir_channel(&b, pconst3, 3);
1189      unsigned repl_swizzle[4] = {0, 0, 0, 0};
1190
1191      nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
1192      nir_ssa_def *triangle_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1193
1194      nir_ssa_def *indices = get_indices(&b, index_addr, index_format, global_id);
1195      nir_ssa_def *vertex_addresses = nir_iadd(
1196         &b, nir_u2u64(&b, nir_imul(&b, indices, nir_swizzle(&b, vertex_stride, repl_swizzle, 3))),
1197         nir_swizzle(&b, vertex_addr, repl_swizzle, 3));
1198      nir_ssa_def *positions[3];
1199      get_vertices(&b, vertex_addresses, vertex_format, positions);
1200
1201      nir_ssa_def *node_data[16];
1202      memset(node_data, 0, sizeof(node_data));
1203
1204      nir_variable *transform[] = {
1205         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform0"),
1206         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform1"),
1207         nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "transform2"),
1208      };
1209      nir_store_var(&b, transform[0], nir_imm_vec4(&b, 1.0, 0.0, 0.0, 0.0), 0xf);
1210      nir_store_var(&b, transform[1], nir_imm_vec4(&b, 0.0, 1.0, 0.0, 0.0), 0xf);
1211      nir_store_var(&b, transform[2], nir_imm_vec4(&b, 0.0, 0.0, 1.0, 0.0), 0xf);
1212
1213      nir_push_if(&b, nir_ine_imm(&b, transform_addr, 0));
1214      nir_store_var(&b, transform[0],
1215                    nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 0),
1216                                          .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1217                    0xf);
1218      nir_store_var(&b, transform[1],
1219                    nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 16),
1220                                          .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1221                    0xf);
1222      nir_store_var(&b, transform[2],
1223                    nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, transform_addr, 32),
1224                                          .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER),
1225                    0xf);
1226      nir_pop_if(&b, NULL);
1227
1228      for (unsigned i = 0; i < 3; ++i)
1229         for (unsigned j = 0; j < 3; ++j)
1230            node_data[i * 3 + j] = nir_fdph(&b, positions[i], nir_load_var(&b, transform[j]));
1231
1232      nir_ssa_def *min_bound = NULL;
1233      nir_ssa_def *max_bound = NULL;
1234      for (unsigned i = 0; i < 3; ++i) {
1235         nir_ssa_def *position = nir_vec(&b, node_data + i * 3, 3);
1236         if (min_bound) {
1237            min_bound = nir_fmin(&b, min_bound, position);
1238            max_bound = nir_fmax(&b, max_bound, position);
1239         } else {
1240            min_bound = position;
1241            max_bound = position;
1242         }
1243      }
1244
1245      nir_store_var(&b, bounds[0], min_bound, 7);
1246      nir_store_var(&b, bounds[1], max_bound, 7);
1247
1248      node_data[12] = global_id;
1249      node_data[13] = geometry_id;
1250      node_data[15] = nir_imm_int(&b, 9);
1251      for (unsigned i = 0; i < ARRAY_SIZE(node_data); ++i)
1252         if (!node_data[i])
1253            node_data[i] = nir_imm_int(&b, 0);
1254
1255      for (unsigned i = 0; i < 4; ++i) {
1256         nir_build_store_global(&b, nir_vec(&b, node_data + i * 4, 4),
1257                                nir_iadd_imm(&b, triangle_node_dst_addr, i * 16), .align_mul = 16);
1258      }
1259
1260      nir_ssa_def *node_id =
1261         nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_triangle);
1262      nir_build_store_global(&b, node_id, scratch_dst_addr);
1263   }
1264   nir_push_else(&b, NULL);
1265   nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_AABBS_KHR));
1266   { /* AABBs */
1267      nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
1268      nir_ssa_def *aabb_stride = nir_channel(&b, pconst2, 2);
1269
1270      nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
1271      nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1272
1273      nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_aabb);
1274      nir_build_store_global(&b, node_id, scratch_dst_addr);
1275
1276      aabb_addr = nir_iadd(&b, aabb_addr, nir_u2u64(&b, nir_imul(&b, aabb_stride, global_id)));
1277
1278      nir_ssa_def *min_bound =
1279         nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 0),
1280                               .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1281      nir_ssa_def *max_bound =
1282         nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, aabb_addr, 12),
1283                               .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
1284
1285      nir_store_var(&b, bounds[0], min_bound, 7);
1286      nir_store_var(&b, bounds[1], max_bound, 7);
1287
1288      nir_ssa_def *values[] = {nir_channel(&b, min_bound, 0),
1289                               nir_channel(&b, min_bound, 1),
1290                               nir_channel(&b, min_bound, 2),
1291                               nir_channel(&b, max_bound, 0),
1292                               nir_channel(&b, max_bound, 1),
1293                               nir_channel(&b, max_bound, 2),
1294                               global_id,
1295                               geometry_id};
1296
1297      nir_build_store_global(&b, nir_vec(&b, values + 0, 4),
1298                             nir_iadd_imm(&b, aabb_node_dst_addr, 0), .align_mul = 16);
1299      nir_build_store_global(&b, nir_vec(&b, values + 4, 4),
1300                             nir_iadd_imm(&b, aabb_node_dst_addr, 16), .align_mul = 16);
1301   }
1302   nir_push_else(&b, NULL);
1303   { /* Instances */
1304
1305      nir_variable *instance_addr_var =
1306         nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
1307      nir_push_if(&b, nir_ine_imm(&b, nir_channel(&b, pconst2, 2), 0));
1308      {
1309         nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011)),
1310                                     nir_u2u64(&b, nir_imul_imm(&b, global_id, 8)));
1311         nir_ssa_def *addr =
1312            nir_pack_64_2x32(&b, nir_build_load_global(&b, 2, 32, ptr, .align_mul = 8));
1313         nir_store_var(&b, instance_addr_var, addr, 1);
1314      }
1315      nir_push_else(&b, NULL);
1316      {
1317         nir_ssa_def *addr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011)),
1318                                      nir_u2u64(&b, nir_imul_imm(&b, global_id, 64)));
1319         nir_store_var(&b, instance_addr_var, addr, 1);
1320      }
1321      nir_pop_if(&b, NULL);
1322      nir_ssa_def *instance_addr = nir_load_var(&b, instance_addr_var);
1323
1324      nir_ssa_def *inst_transform[] = {
1325         nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 0)),
1326         nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 16)),
1327         nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 32))};
1328      nir_ssa_def *inst3 = nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 48));
1329
1330      nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 128));
1331      node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
1332
1333      nir_ssa_def *node_id =
1334         nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_instance);
1335      nir_build_store_global(&b, node_id, scratch_dst_addr);
1336
1337      nir_ssa_def *header_addr = nir_pack_64_2x32(&b, nir_channels(&b, inst3, 12));
1338      nir_push_if(&b, nir_ine_imm(&b, header_addr, 0));
1339      nir_ssa_def *header_root_offset =
1340         nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, header_addr, 0));
1341      nir_ssa_def *header_min = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, header_addr, 8));
1342      nir_ssa_def *header_max = nir_build_load_global(&b, 3, 32, nir_iadd_imm(&b, header_addr, 20));
1343
1344      nir_ssa_def *bound_defs[2][3];
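      /* Transform the BLAS root bounds into this BVH's space: start each axis from the
       * translation component and accumulate the per-component min/max of the scaled
       * header extents, which gives a conservative transformed AABB. */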
1345      for (unsigned i = 0; i < 3; ++i) {
1346         bound_defs[0][i] = bound_defs[1][i] = nir_channel(&b, inst_transform[i], 3);
1347
1348         nir_ssa_def *mul_a = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_min);
1349         nir_ssa_def *mul_b = nir_fmul(&b, nir_channels(&b, inst_transform[i], 7), header_max);
1350         nir_ssa_def *mi = nir_fmin(&b, mul_a, mul_b);
1351         nir_ssa_def *ma = nir_fmax(&b, mul_a, mul_b);
1352         for (unsigned j = 0; j < 3; ++j) {
1353            bound_defs[0][i] = nir_fadd(&b, bound_defs[0][i], nir_channel(&b, mi, j));
1354            bound_defs[1][i] = nir_fadd(&b, bound_defs[1][i], nir_channel(&b, ma, j));
1355         }
1356      }
1357
1358      nir_store_var(&b, bounds[0], nir_vec(&b, bound_defs[0], 3), 7);
1359      nir_store_var(&b, bounds[1], nir_vec(&b, bound_defs[1], 3), 7);
1360
1361      /* Store the object-to-world matrix. */
1362      for (unsigned i = 0; i < 3; ++i) {
1363         nir_ssa_def *vals[3];
1364         for (unsigned j = 0; j < 3; ++j)
1365            vals[j] = nir_channel(&b, inst_transform[j], i);
1366
1367         nir_build_store_global(&b, nir_vec(&b, vals, 3),
1368                                nir_iadd_imm(&b, node_dst_addr, 92 + 12 * i));
1369      }
1370
1371      nir_ssa_def *m_in[3][3], *m_out[3][3], *m_vec[3][4];
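      /* Invert the 3x3 part of the instance transform and store it, together with the
       * untouched translation column, as the 3x4 matrix at node offset 16. */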
1372      for (unsigned i = 0; i < 3; ++i)
1373         for (unsigned j = 0; j < 3; ++j)
1374            m_in[i][j] = nir_channel(&b, inst_transform[i], j);
1375      nir_invert_3x3(&b, m_in, m_out);
1376      for (unsigned i = 0; i < 3; ++i) {
1377         for (unsigned j = 0; j < 3; ++j)
1378            m_vec[i][j] = m_out[i][j];
1379         m_vec[i][3] = nir_channel(&b, inst_transform[i], 3);
1380      }
1381
1382      for (unsigned i = 0; i < 3; ++i) {
1383         nir_build_store_global(&b, nir_vec(&b, m_vec[i], 4),
1384                                nir_iadd_imm(&b, node_dst_addr, 16 + 16 * i));
1385      }
1386
1387      nir_ssa_def *out0[4] = {
1388         nir_ior(&b, nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 0), header_root_offset),
1389         nir_channel(&b, nir_unpack_64_2x32(&b, header_addr), 1), nir_channel(&b, inst3, 0),
1390         nir_channel(&b, inst3, 1)};
1391      nir_build_store_global(&b, nir_vec(&b, out0, 4), nir_iadd_imm(&b, node_dst_addr, 0));
1392      nir_build_store_global(&b, global_id, nir_iadd_imm(&b, node_dst_addr, 88));
1393      nir_pop_if(&b, NULL);
1394      nir_build_store_global(&b, nir_load_var(&b, bounds[0]), nir_iadd_imm(&b, node_dst_addr, 64));
1395      nir_build_store_global(&b, nir_load_var(&b, bounds[1]), nir_iadd_imm(&b, node_dst_addr, 76));
1396   }
1397   nir_pop_if(&b, NULL);
1398   nir_pop_if(&b, NULL);
1399
1400   if (build_mode != accel_struct_build_unoptimized) {
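      /* Fold this invocation's leaf bounds into the global bounds kept at the start of
       * the scratch buffer: reduce within the subgroup first, then let one elected lane
       * issue the six atomic min/max updates. */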
1401      nir_ssa_def *min = nir_load_var(&b, bounds[0]);
1402      nir_ssa_def *max = nir_load_var(&b, bounds[1]);
1403
1404      nir_ssa_def *min_reduced = nir_reduce(&b, min, .reduction_op = nir_op_fmin);
1405      nir_ssa_def *max_reduced = nir_reduce(&b, max, .reduction_op = nir_op_fmax);
1406
1407      nir_push_if(&b, nir_elect(&b, 1));
1408
1409      atomic_fminmax(dev, &b, scratch_addr, false, nir_channel(&b, min_reduced, 0));
1410      atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 4), false,
1411                     nir_channel(&b, min_reduced, 1));
1412      atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 8), false,
1413                     nir_channel(&b, min_reduced, 2));
1414
1415      atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 12), true,
1416                     nir_channel(&b, max_reduced, 0));
1417      atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 16), true,
1418                     nir_channel(&b, max_reduced, 1));
1419      atomic_fminmax(dev, &b, nir_iadd_imm(&b, scratch_addr, 20), true,
1420                     nir_channel(&b, max_reduced, 2));
1421   }
1422
1423   return b.shader;
1424}
1425
1426static void
1427determine_bounds(nir_builder *b, nir_ssa_def *node_addr, nir_ssa_def *node_id,
1428                 nir_variable *bounds_vars[2])
1429{
1430   nir_ssa_def *node_type = nir_iand_imm(b, node_id, 7);
1431   node_addr =
1432      nir_iadd(b, node_addr, nir_u2u64(b, nir_ishl_imm(b, nir_iand_imm(b, node_id, ~7u), 3)));
1433
1434   nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_triangle));
1435   {
1436      nir_ssa_def *positions[3];
1437      for (unsigned i = 0; i < 3; ++i)
1438         positions[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, i * 12));
1439      nir_ssa_def *bounds[] = {positions[0], positions[0]};
1440      for (unsigned i = 1; i < 3; ++i) {
1441         bounds[0] = nir_fmin(b, bounds[0], positions[i]);
1442         bounds[1] = nir_fmax(b, bounds[1], positions[i]);
1443      }
1444      nir_store_var(b, bounds_vars[0], bounds[0], 7);
1445      nir_store_var(b, bounds_vars[1], bounds[1], 7);
1446   }
1447   nir_push_else(b, NULL);
1448   nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_internal));
1449   {
1450      nir_ssa_def *input_bounds[4][2];
1451      for (unsigned i = 0; i < 4; ++i)
1452         for (unsigned j = 0; j < 2; ++j)
1453            input_bounds[i][j] =
1454               nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 16 + i * 24 + j * 12));
1455      nir_ssa_def *bounds[] = {input_bounds[0][0], input_bounds[0][1]};
1456      for (unsigned i = 1; i < 4; ++i) {
1457         bounds[0] = nir_fmin(b, bounds[0], input_bounds[i][0]);
1458         bounds[1] = nir_fmax(b, bounds[1], input_bounds[i][1]);
1459      }
1460
1461      nir_store_var(b, bounds_vars[0], bounds[0], 7);
1462      nir_store_var(b, bounds_vars[1], bounds[1], 7);
1463   }
1464   nir_push_else(b, NULL);
1465   nir_push_if(b, nir_ieq_imm(b, node_type, radv_bvh_node_instance));
1466   { /* Instances */
1467      nir_ssa_def *bounds[2];
1468      for (unsigned i = 0; i < 2; ++i)
1469         bounds[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, 64 + i * 12));
1470      nir_store_var(b, bounds_vars[0], bounds[0], 7);
1471      nir_store_var(b, bounds_vars[1], bounds[1], 7);
1472   }
1473   nir_push_else(b, NULL);
1474   { /* AABBs */
1475      nir_ssa_def *bounds[2];
1476      for (unsigned i = 0; i < 2; ++i)
1477         bounds[i] = nir_build_load_global(b, 3, 32, nir_iadd_imm(b, node_addr, i * 12));
1478      nir_store_var(b, bounds_vars[0], bounds[0], 7);
1479      nir_store_var(b, bounds_vars[1], bounds[1], 7);
1480   }
1481   nir_pop_if(b, NULL);
1482   nir_pop_if(b, NULL);
1483   nir_pop_if(b, NULL);
1484}
1485
1486/* https://developer.nvidia.com/blog/thinking-parallel-part-iii-tree-construction-gpu/ */
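/* Spread the low 8 bits of x so that two zero bits separate each original bit;
 * interleaving three components expanded this way yields a 24-bit Morton code. This is
 * the bit-expansion sequence from the article above with its first step dropped, since
 * only 8 bits per axis are used here.
 */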
1487static nir_ssa_def *
1488build_morton_component(nir_builder *b, nir_ssa_def *x)
1489{
1490   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000101u), 0x0F00F00Fu);
1491   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000011u), 0xC30C30C3u);
1492   x = nir_iand_imm(b, nir_imul_imm(b, x, 0x00000005u), 0x49249249u);
1493   return x;
1494}
1495
1496static nir_shader *
1497build_morton_shader(struct radv_device *dev)
1498{
1499   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1500
1501   nir_builder b = create_accel_build_shader(dev, "accel_build_morton_shader");
1502
1503   /*
1504    * push constants:
1505    *   i32 x 2: node address
1506    *   i32 x 2: scratch address
1507    */
1508   nir_ssa_def *pconst0 =
1509      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1510
1511   nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1512   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1513
1514   nir_ssa_def *global_id =
1515      nir_iadd(&b,
1516               nir_imul_imm(&b, nir_channel(&b, nir_load_workgroup_id(&b, 32), 0),
1517                            b.shader->info.workgroup_size[0]),
1518               nir_load_local_invocation_index(&b));
1519
1520   nir_ssa_def *node_id_addr =
1521      nir_iadd(&b, nir_iadd_imm(&b, scratch_addr, SCRATCH_TOTAL_BOUNDS_SIZE),
1522               nir_u2u64(&b, id_to_node_id_offset(&b, global_id, dev->physical_device)));
1523   nir_ssa_def *node_id =
1524      nir_build_load_global(&b, 1, 32, node_id_addr, .align_mul = 4, .align_offset = 0);
1525
1526   nir_variable *node_bounds[2] = {
1527      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1528      nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1529   };
1530
1531   determine_bounds(&b, node_addr, node_id, node_bounds);
1532
1533   nir_ssa_def *node_min = nir_load_var(&b, node_bounds[0]);
1534   nir_ssa_def *node_max = nir_load_var(&b, node_bounds[1]);
1535   nir_ssa_def *node_pos =
1536      nir_fmul(&b, nir_fadd(&b, node_min, node_max), nir_imm_vec3(&b, 0.5, 0.5, 0.5));
1537
1538   nir_ssa_def *bvh_min = read_fminmax_atomic(dev, &b, 3, scratch_addr);
1539   nir_ssa_def *bvh_max = read_fminmax_atomic(dev, &b, 3, nir_iadd_imm(&b, scratch_addr, 12));
1540   nir_ssa_def *bvh_size = nir_fsub(&b, bvh_max, bvh_min);
1541
1542   nir_ssa_def *normalized_node_pos = nir_fdiv(&b, nir_fsub(&b, node_pos, bvh_min), bvh_size);
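   /* Quantize the centroid (normalized to the global bounds) to 8 bits per axis,
    * interleave the axes into a Morton code and shift it into the upper bits of the
    * 32-bit sort key. */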
1543
1544   nir_ssa_def *x_int =
1545      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 0), 255.0));
1546   nir_ssa_def *x_morton = build_morton_component(&b, x_int);
1547
1548   nir_ssa_def *y_int =
1549      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 1), 255.0));
1550   nir_ssa_def *y_morton = build_morton_component(&b, y_int);
1551
1552   nir_ssa_def *z_int =
1553      nir_f2u32(&b, nir_fmul_imm(&b, nir_channel(&b, normalized_node_pos, 2), 255.0));
1554   nir_ssa_def *z_morton = build_morton_component(&b, z_int);
1555
1556   nir_ssa_def *morton_code = nir_iadd(
1557      &b, nir_iadd(&b, nir_ishl_imm(&b, x_morton, 2), nir_ishl_imm(&b, y_morton, 1)), z_morton);
1558   nir_ssa_def *key = nir_ishl_imm(&b, morton_code, 8);
1559
1560   nir_ssa_def *dst_addr =
1561      nir_iadd(&b, nir_iadd_imm(&b, scratch_addr, SCRATCH_TOTAL_BOUNDS_SIZE),
1562               nir_u2u64(&b, id_to_morton_offset(&b, global_id, dev->physical_device)));
1563   nir_build_store_global(&b, key, dst_addr, .align_mul = 4);
1564
1565   return b.shader;
1566}
1567
1568static nir_shader *
1569build_internal_shader(struct radv_device *dev)
1570{
1571   const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
1572   nir_builder b = create_accel_build_shader(dev, "accel_build_internal_shader");
1573
1574   /*
1575    * push constants:
1576    *   i32 x 2: node dst address
1577    *   i32 x 2: scratch address
1578    *   i32: dst offset
1579    *   i32: dst scratch offset
1580    *   i32: src scratch offset
1581    *   i32: src_node_count | (fill_header << 31)
1582    */
1583   nir_ssa_def *pconst0 =
1584      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1585   nir_ssa_def *pconst1 =
1586      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
1587
1588   nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1589   nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1590   nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
1591   nir_ssa_def *dst_scratch_offset = nir_channel(&b, pconst1, 1);
1592   nir_ssa_def *src_scratch_offset = nir_channel(&b, pconst1, 2);
1593   nir_ssa_def *src_node_count = nir_iand_imm(&b, nir_channel(&b, pconst1, 3), 0x7FFFFFFFU);
1594   nir_ssa_def *fill_header =
1595      nir_ine_imm(&b, nir_iand_imm(&b, nir_channel(&b, pconst1, 3), 0x80000000U), 0);
1596
1597   nir_ssa_def *global_id =
1598      nir_iadd(&b,
1599               nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
1600                            b.shader->info.workgroup_size[0]),
1601               nir_channels(&b, nir_load_local_invocation_id(&b), 1));
1602   nir_ssa_def *src_idx = nir_imul_imm(&b, global_id, 4);
1603   nir_ssa_def *src_count = nir_umin(&b, nir_imm_int(&b, 4), nir_isub(&b, src_node_count, src_idx));
1604
1605   nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_ishl_imm(&b, global_id, 7));
1606   nir_ssa_def *node_dst_addr = nir_iadd(&b, node_addr, nir_u2u64(&b, node_offset));
1607
1608   nir_ssa_def *src_base_addr =
1609      nir_iadd(&b, scratch_addr,
1610               nir_u2u64(&b, nir_iadd(&b, src_scratch_offset,
1611                                      id_to_node_id_offset(&b, src_idx, dev->physical_device))));
1612
1613   enum accel_struct_build build_mode =
1614      get_accel_struct_build(dev->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
1615   uint32_t node_id_stride = get_node_id_stride(build_mode);
1616
1617   nir_ssa_def *src_nodes[4];
1618   for (uint32_t i = 0; i < 4; i++) {
1619      src_nodes[i] =
1620         nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, src_base_addr, i * node_id_stride));
1621      nir_build_store_global(&b, src_nodes[i], nir_iadd_imm(&b, node_dst_addr, i * 4));
1622   }
1623
1624   nir_ssa_def *total_bounds[2] = {
1625      nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1626      nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7),
1627   };
1628
1629   for (unsigned i = 0; i < 4; ++i) {
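      /* Compute each child's bounds (left as NaN for unused slots, which should never be
       * hit during traversal), store them in the box32 node at 16 + 24 * i (min) and
       * 28 + 24 * i (max), and fold them into the running total. */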
1630      nir_variable *bounds[2] = {
1631         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "min_bound"),
1632         nir_variable_create(b.shader, nir_var_shader_temp, vec3_type, "max_bound"),
1633      };
1634      nir_store_var(&b, bounds[0], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1635      nir_store_var(&b, bounds[1], nir_channels(&b, nir_imm_vec4(&b, NAN, NAN, NAN, NAN), 7), 7);
1636
1637      nir_push_if(&b, nir_ilt(&b, nir_imm_int(&b, i), src_count));
1638      determine_bounds(&b, node_addr, src_nodes[i], bounds);
1639      nir_pop_if(&b, NULL);
1640      nir_build_store_global(&b, nir_load_var(&b, bounds[0]),
1641                             nir_iadd_imm(&b, node_dst_addr, 16 + 24 * i));
1642      nir_build_store_global(&b, nir_load_var(&b, bounds[1]),
1643                             nir_iadd_imm(&b, node_dst_addr, 28 + 24 * i));
1644      total_bounds[0] = nir_fmin(&b, total_bounds[0], nir_load_var(&b, bounds[0]));
1645      total_bounds[1] = nir_fmax(&b, total_bounds[1], nir_load_var(&b, bounds[1]));
1646   }
1647
1648   nir_ssa_def *node_id =
1649      nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), radv_bvh_node_internal);
1650   nir_ssa_def *dst_scratch_addr =
1651      nir_iadd(&b, scratch_addr,
1652               nir_u2u64(&b, nir_iadd(&b, dst_scratch_offset,
1653                                      id_to_node_id_offset(&b, global_id, dev->physical_device))));
1654   nir_build_store_global(&b, node_id, dst_scratch_addr);
1655
1656   nir_push_if(&b, fill_header);
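   /* On the final pass the single node written above is the root: record its id and the
    * accumulated bounds in the acceleration structure header. */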
1657   nir_build_store_global(&b, node_id, node_addr);
1658   nir_build_store_global(&b, total_bounds[0], nir_iadd_imm(&b, node_addr, 8));
1659   nir_build_store_global(&b, total_bounds[1], nir_iadd_imm(&b, node_addr, 20));
1660   nir_pop_if(&b, NULL);
1661   return b.shader;
1662}
1663
1664enum copy_mode {
1665   COPY_MODE_COPY,
1666   COPY_MODE_SERIALIZE,
1667   COPY_MODE_DESERIALIZE,
1668};
1669
1670struct copy_constants {
1671   uint64_t src_addr;
1672   uint64_t dst_addr;
1673   uint32_t mode;
1674};
1675
1676static nir_shader *
1677build_copy_shader(struct radv_device *dev)
1678{
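   /* A single shader handles plain copies, serialization and deserialization. A
    * grid-stride loop moves 16 bytes per iteration; when it crosses the start of an
    * instance node it additionally translates the BLAS pointer through the instance
    * table that follows the serialization header. */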
1679   nir_builder b = create_accel_build_shader(dev, "accel_copy");
1680
1681   nir_ssa_def *invoc_id = nir_load_local_invocation_id(&b);
1682   nir_ssa_def *wg_id = nir_load_workgroup_id(&b, 32);
1683   nir_ssa_def *block_size =
1684      nir_imm_ivec4(&b, b.shader->info.workgroup_size[0], b.shader->info.workgroup_size[1],
1685                    b.shader->info.workgroup_size[2], 0);
1686
1687   nir_ssa_def *global_id =
1688      nir_channel(&b, nir_iadd(&b, nir_imul(&b, wg_id, block_size), invoc_id), 0);
1689
1690   nir_variable *offset_var =
1691      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "offset");
1692   nir_ssa_def *offset = nir_imul_imm(&b, global_id, 16);
1693   nir_store_var(&b, offset_var, offset, 1);
1694
1695   nir_ssa_def *increment = nir_imul_imm(&b, nir_channel(&b, nir_load_num_workgroups(&b, 32), 0),
1696                                         b.shader->info.workgroup_size[0] * 16);
1697
1698   nir_ssa_def *pconst0 =
1699      nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
1700   nir_ssa_def *pconst1 =
1701      nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 16, .range = 4);
1702   nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
1703   nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
1704   nir_ssa_def *mode = nir_channel(&b, pconst1, 0);
1705
1706   nir_variable *compacted_size_var =
1707      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "compacted_size");
1708   nir_variable *src_offset_var =
1709      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "src_offset");
1710   nir_variable *dst_offset_var =
1711      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "dst_offset");
1712   nir_variable *instance_offset_var =
1713      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_offset");
1714   nir_variable *instance_count_var =
1715      nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint_type(), "instance_count");
1716   nir_variable *value_var =
1717      nir_variable_create(b.shader, nir_var_shader_temp, glsl_vec4_type(), "value");
1718
1719   nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_SERIALIZE));
1720   {
1721      nir_ssa_def *instance_count = nir_build_load_global(
1722         &b, 1, 32,
1723         nir_iadd_imm(&b, src_base_addr,
1724                      offsetof(struct radv_accel_struct_header, instance_count)));
1725      nir_ssa_def *compacted_size = nir_build_load_global(
1726         &b, 1, 64,
1727         nir_iadd_imm(&b, src_base_addr,
1728                      offsetof(struct radv_accel_struct_header, compacted_size)));
1729      nir_ssa_def *serialization_size = nir_build_load_global(
1730         &b, 1, 64,
1731         nir_iadd_imm(&b, src_base_addr,
1732                      offsetof(struct radv_accel_struct_header, serialization_size)));
1733
1734      nir_store_var(&b, compacted_size_var, compacted_size, 1);
1735      nir_store_var(&b, instance_offset_var,
1736                    nir_build_load_global(
1737                       &b, 1, 32,
1738                       nir_iadd_imm(&b, src_base_addr,
1739                                    offsetof(struct radv_accel_struct_header, instance_offset))),
1740                    1);
1741      nir_store_var(&b, instance_count_var, instance_count, 1);
1742
1743      nir_ssa_def *dst_offset = nir_iadd_imm(&b, nir_imul_imm(&b, instance_count, sizeof(uint64_t)),
1744                                             sizeof(struct radv_accel_struct_serialization_header));
1745      nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1746      nir_store_var(&b, dst_offset_var, dst_offset, 1);
1747
1748      nir_push_if(&b, nir_ieq_imm(&b, global_id, 0));
1749      {
1750         nir_build_store_global(&b, serialization_size,
1751                                nir_iadd_imm(&b, dst_base_addr,
1752                                             offsetof(struct radv_accel_struct_serialization_header,
1753                                                      serialization_size)));
1754         nir_build_store_global(
1755            &b, compacted_size,
1756            nir_iadd_imm(&b, dst_base_addr,
1757                         offsetof(struct radv_accel_struct_serialization_header, compacted_size)));
1758         nir_build_store_global(
1759            &b, nir_u2u64(&b, instance_count),
1760            nir_iadd_imm(&b, dst_base_addr,
1761                         offsetof(struct radv_accel_struct_serialization_header, instance_count)));
1762      }
1763      nir_pop_if(&b, NULL);
1764   }
1765   nir_push_else(&b, NULL);
1766   nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_DESERIALIZE));
1767   {
1768      nir_ssa_def *instance_count = nir_build_load_global(
1769         &b, 1, 32,
1770         nir_iadd_imm(&b, src_base_addr,
1771                      offsetof(struct radv_accel_struct_serialization_header, instance_count)));
1772      nir_ssa_def *src_offset = nir_iadd_imm(&b, nir_imul_imm(&b, instance_count, sizeof(uint64_t)),
1773                                             sizeof(struct radv_accel_struct_serialization_header));
1774
1775      nir_ssa_def *header_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1776      nir_store_var(&b, compacted_size_var,
1777                    nir_build_load_global(
1778                       &b, 1, 64,
1779                       nir_iadd_imm(&b, header_addr,
1780                                    offsetof(struct radv_accel_struct_header, compacted_size))),
1781                    1);
1782      nir_store_var(&b, instance_offset_var,
1783                    nir_build_load_global(
1784                       &b, 1, 32,
1785                       nir_iadd_imm(&b, header_addr,
1786                                    offsetof(struct radv_accel_struct_header, instance_offset))),
1787                    1);
1788      nir_store_var(&b, instance_count_var, instance_count, 1);
1789      nir_store_var(&b, src_offset_var, src_offset, 1);
1790      nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1791   }
1792   nir_push_else(&b, NULL); /* COPY_MODE_COPY */
1793   {
1794      nir_store_var(&b, compacted_size_var,
1795                    nir_build_load_global(
1796                       &b, 1, 64,
1797                       nir_iadd_imm(&b, src_base_addr,
1798                                    offsetof(struct radv_accel_struct_header, compacted_size))),
1799                    1);
1800
1801      nir_store_var(&b, src_offset_var, nir_imm_int(&b, 0), 1);
1802      nir_store_var(&b, dst_offset_var, nir_imm_int(&b, 0), 1);
1803      nir_store_var(&b, instance_offset_var, nir_imm_int(&b, 0), 1);
1804      nir_store_var(&b, instance_count_var, nir_imm_int(&b, 0), 1);
1805   }
1806   nir_pop_if(&b, NULL);
1807   nir_pop_if(&b, NULL);
1808
1809   nir_ssa_def *instance_bound =
1810      nir_imul_imm(&b, nir_load_var(&b, instance_count_var), sizeof(struct radv_bvh_instance_node));
1811   nir_ssa_def *compacted_size = nir_build_load_global(
1812      &b, 1, 32,
1813      nir_iadd_imm(&b, src_base_addr, offsetof(struct radv_accel_struct_header, compacted_size)));
1814
1815   nir_push_loop(&b);
1816   {
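      /* Grid-stride copy loop: each iteration moves 16 bytes and advances the offset by
       * the total invocation count times 16 until compacted_size has been covered. */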
1817      offset = nir_load_var(&b, offset_var);
1818      nir_push_if(&b, nir_ilt(&b, offset, compacted_size));
1819      {
1820         nir_ssa_def *src_offset = nir_iadd(&b, offset, nir_load_var(&b, src_offset_var));
1821         nir_ssa_def *dst_offset = nir_iadd(&b, offset, nir_load_var(&b, dst_offset_var));
1822         nir_ssa_def *src_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, src_offset));
1823         nir_ssa_def *dst_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, dst_offset));
1824
1825         nir_ssa_def *value = nir_build_load_global(&b, 4, 32, src_addr, .align_mul = 16);
1826         nir_store_var(&b, value_var, value, 0xf);
1827
1828         nir_ssa_def *instance_offset = nir_isub(&b, offset, nir_load_var(&b, instance_offset_var));
1829         nir_ssa_def *in_instance_bound =
1830            nir_iand(&b, nir_uge(&b, offset, nir_load_var(&b, instance_offset_var)),
1831                     nir_ult(&b, instance_offset, instance_bound));
1832         nir_ssa_def *instance_start = nir_ieq_imm(
1833            &b, nir_iand_imm(&b, instance_offset, sizeof(struct radv_bvh_instance_node) - 1), 0);
1834
1835         nir_push_if(&b, nir_iand(&b, in_instance_bound, instance_start));
1836         {
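            /* This iteration covers the first 16 bytes of an instance node, whose
             * leading 8 bytes are the BLAS pointer. Serialization writes that pointer
             * out to the instance table behind the serialization header; deserialization
             * replaces it with the pointer read back from that table. */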
1837            nir_ssa_def *instance_id = nir_ushr_imm(&b, instance_offset, 7);
1838
1839            nir_push_if(&b, nir_ieq_imm(&b, mode, COPY_MODE_SERIALIZE));
1840            {
1841               nir_ssa_def *instance_addr = nir_imul_imm(&b, instance_id, sizeof(uint64_t));
1842               instance_addr = nir_iadd_imm(&b, instance_addr,
1843                                            sizeof(struct radv_accel_struct_serialization_header));
1844               instance_addr = nir_iadd(&b, dst_base_addr, nir_u2u64(&b, instance_addr));
1845
1846               nir_build_store_global(&b, nir_channels(&b, value, 3), instance_addr,
1847                                      .align_mul = 8);
1848            }
1849            nir_push_else(&b, NULL);
1850            {
1851               nir_ssa_def *instance_addr = nir_imul_imm(&b, instance_id, sizeof(uint64_t));
1852               instance_addr = nir_iadd_imm(&b, instance_addr,
1853                                            sizeof(struct radv_accel_struct_serialization_header));
1854               instance_addr = nir_iadd(&b, src_base_addr, nir_u2u64(&b, instance_addr));
1855
1856               nir_ssa_def *instance_value =
1857                  nir_build_load_global(&b, 2, 32, instance_addr, .align_mul = 8);
1858
1859               nir_ssa_def *values[] = {
1860                  nir_channel(&b, instance_value, 0),
1861                  nir_channel(&b, instance_value, 1),
1862                  nir_channel(&b, value, 2),
1863                  nir_channel(&b, value, 3),
1864               };
1865
1866               nir_store_var(&b, value_var, nir_vec(&b, values, 4), 0xf);
1867            }
1868            nir_pop_if(&b, NULL);
1869         }
1870         nir_pop_if(&b, NULL);
1871
1872         nir_store_var(&b, offset_var, nir_iadd(&b, offset, increment), 1);
1873
1874         nir_build_store_global(&b, nir_load_var(&b, value_var), dst_addr, .align_mul = 16);
1875      }
1876      nir_push_else(&b, NULL);
1877      {
1878         nir_jump(&b, nir_jump_break);
1879      }
1880      nir_pop_if(&b, NULL);
1881   }
1882   nir_pop_loop(&b, NULL);
1883   return b.shader;
1884}
1885
1886void
1887radv_device_finish_accel_struct_build_state(struct radv_device *device)
1888{
1889   struct radv_meta_state *state = &device->meta_state;
1890   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.copy_pipeline,
1891                        &state->alloc);
1892   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.internal_pipeline,
1893                        &state->alloc);
1894   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.leaf_pipeline,
1895                        &state->alloc);
1896   radv_DestroyPipeline(radv_device_to_handle(device), state->accel_struct_build.morton_pipeline,
1897                        &state->alloc);
1898   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1899                              state->accel_struct_build.copy_p_layout, &state->alloc);
1900   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1901                              state->accel_struct_build.internal_p_layout, &state->alloc);
1902   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1903                              state->accel_struct_build.leaf_p_layout, &state->alloc);
1904   radv_DestroyPipelineLayout(radv_device_to_handle(device),
1905                              state->accel_struct_build.morton_p_layout, &state->alloc);
1906
1907   if (state->accel_struct_build.radix_sort)
1908      radix_sort_vk_destroy(state->accel_struct_build.radix_sort, radv_device_to_handle(device),
1909                            &state->alloc);
1910}
1911
1912static VkResult
1913create_build_pipeline(struct radv_device *device, nir_shader *shader, unsigned push_constant_size,
1914                      VkPipeline *pipeline, VkPipelineLayout *layout)
1915{
1916   const VkPipelineLayoutCreateInfo pl_create_info = {
1917      .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
1918      .setLayoutCount = 0,
1919      .pushConstantRangeCount = 1,
1920      .pPushConstantRanges =
1921         &(VkPushConstantRange){VK_SHADER_STAGE_COMPUTE_BIT, 0, push_constant_size},
1922   };
1923
1924   VkResult result = radv_CreatePipelineLayout(radv_device_to_handle(device), &pl_create_info,
1925                                               &device->meta_state.alloc, layout);
1926   if (result != VK_SUCCESS) {
1927      ralloc_free(shader);
1928      return result;
1929   }
1930
1931   VkPipelineShaderStageCreateInfo shader_stage = {
1932      .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
1933      .stage = VK_SHADER_STAGE_COMPUTE_BIT,
1934      .module = vk_shader_module_handle_from_nir(shader),
1935      .pName = "main",
1936      .pSpecializationInfo = NULL,
1937   };
1938
1939   VkComputePipelineCreateInfo pipeline_info = {
1940      .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
1941      .stage = shader_stage,
1942      .flags = 0,
1943      .layout = *layout,
1944   };
1945
1946   result = radv_CreateComputePipelines(radv_device_to_handle(device),
1947                                        radv_pipeline_cache_to_handle(&device->meta_state.cache), 1,
1948                                        &pipeline_info, &device->meta_state.alloc, pipeline);
1949
1950   if (result != VK_SUCCESS) {
1951      ralloc_free(shader);
1952      return result;
1953   }
1954
1955   return VK_SUCCESS;
1956}
1957
1958static void
1959radix_sort_fill_buffer(VkCommandBuffer commandBuffer,
1960                       radix_sort_vk_buffer_info_t const *buffer_info, VkDeviceSize offset,
1961                       VkDeviceSize size, uint32_t data)
1962{
1963   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
1964
1965   assert(size != VK_WHOLE_SIZE);
1966
1967   radv_fill_buffer(cmd_buffer, NULL, NULL, buffer_info->devaddr + buffer_info->offset + offset,
1968                    size, data);
1969}
1970
1971VkResult
1972radv_device_init_accel_struct_build_state(struct radv_device *device)
1973{
1974   VkResult result;
1975   nir_shader *leaf_cs = build_leaf_shader(device);
1976   nir_shader *internal_cs = build_internal_shader(device);
1977   nir_shader *copy_cs = build_copy_shader(device);
1978
1979   result = create_build_pipeline(device, leaf_cs, sizeof(struct build_primitive_constants),
1980                                  &device->meta_state.accel_struct_build.leaf_pipeline,
1981                                  &device->meta_state.accel_struct_build.leaf_p_layout);
1982   if (result != VK_SUCCESS)
1983      return result;
1984
1985   result = create_build_pipeline(device, internal_cs, sizeof(struct build_internal_constants),
1986                                  &device->meta_state.accel_struct_build.internal_pipeline,
1987                                  &device->meta_state.accel_struct_build.internal_p_layout);
1988   if (result != VK_SUCCESS)
1989      return result;
1990
1991   result = create_build_pipeline(device, copy_cs, sizeof(struct copy_constants),
1992                                  &device->meta_state.accel_struct_build.copy_pipeline,
1993                                  &device->meta_state.accel_struct_build.copy_p_layout);
1994
1995   if (result != VK_SUCCESS)
1996      return result;
1997
1998   if (get_accel_struct_build(device->physical_device,
1999                              VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR) ==
2000       accel_struct_build_lbvh) {
2001      nir_shader *morton_cs = build_morton_shader(device);
2002
2003      result = create_build_pipeline(device, morton_cs, sizeof(struct morton_constants),
2004                                     &device->meta_state.accel_struct_build.morton_pipeline,
2005                                     &device->meta_state.accel_struct_build.morton_p_layout);
2006      if (result != VK_SUCCESS)
2007         return result;
2008
2009      device->meta_state.accel_struct_build.radix_sort =
2010         radv_create_radix_sort_u64(radv_device_to_handle(device), &device->meta_state.alloc,
2011                                    radv_pipeline_cache_to_handle(&device->meta_state.cache));
2012
2013      struct radix_sort_vk_sort_devaddr_info *radix_sort_info =
2014         &device->meta_state.accel_struct_build.radix_sort_info;
2015      radix_sort_info->ext = NULL;
2016      radix_sort_info->key_bits = 24;
2017      radix_sort_info->fill_buffer = radix_sort_fill_buffer;
2018   }
2019
2020   return result;
2021}
2022
2023struct bvh_state {
2024   uint32_t node_offset;
2025   uint32_t node_count;
2026   uint32_t scratch_offset;
2027   uint32_t buffer_1_offset;
2028   uint32_t buffer_2_offset;
2029
2030   uint32_t instance_offset;
2031   uint32_t instance_count;
2032};
2033
2034VKAPI_ATTR void VKAPI_CALL
2035radv_CmdBuildAccelerationStructuresKHR(
2036   VkCommandBuffer commandBuffer, uint32_t infoCount,
2037   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
2038   const VkAccelerationStructureBuildRangeInfoKHR *const *ppBuildRangeInfos)
2039{
2040   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2041   struct radv_meta_saved_state saved_state;
2042
2043   enum radv_cmd_flush_bits flush_bits =
2044      RADV_CMD_FLAG_CS_PARTIAL_FLUSH |
2045      radv_src_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
2046                            NULL) |
2047      radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_SHADER_READ_BIT | VK_ACCESS_2_SHADER_WRITE_BIT,
2048                            NULL);
2049
2050   enum accel_struct_build build_mode = get_accel_struct_build(
2051      cmd_buffer->device->physical_device, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR);
2052   uint32_t node_id_stride = get_node_id_stride(build_mode);
2053
2054   radv_meta_save(
2055      &saved_state, cmd_buffer,
2056      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2057   struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
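   /* Build phases: clear the scratch bounds, emit one leaf node per primitive, then (for
    * LBVH) compute Morton codes and radix-sort the leaf ids, collapse nodes four at a
    * time with the internal-node shader until a single root remains, and finally write
    * the header. */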
2058
2059   if (build_mode != accel_struct_build_unoptimized) {
2060      for (uint32_t i = 0; i < infoCount; ++i) {
2061         if (radv_has_shader_buffer_float_minmax(cmd_buffer->device->physical_device)) {
2062            /* Clear the bvh bounds with NaN. */
2063            si_cp_dma_clear_buffer(cmd_buffer, pInfos[i].scratchData.deviceAddress,
2064                                   6 * sizeof(float), 0x7FC00000);
2065         } else {
2066            /* Clear the bvh bounds with int max/min. */
2067            si_cp_dma_clear_buffer(cmd_buffer, pInfos[i].scratchData.deviceAddress,
2068                                   3 * sizeof(float), 0x7fffffff);
2069            si_cp_dma_clear_buffer(cmd_buffer,
2070                                   pInfos[i].scratchData.deviceAddress + 3 * sizeof(float),
2071                                   3 * sizeof(float), 0x80000000);
2072         }
2073      }
2074
2075      cmd_buffer->state.flush_bits |= flush_bits;
2076   }
2077
2078   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2079                        cmd_buffer->device->meta_state.accel_struct_build.leaf_pipeline);
2080
2081   for (uint32_t i = 0; i < infoCount; ++i) {
2082      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2083                       pInfos[i].dstAccelerationStructure);
2084
2085      struct build_primitive_constants prim_consts = {
2086         .node_dst_addr = radv_accel_struct_get_va(accel_struct),
2087         .scratch_addr = pInfos[i].scratchData.deviceAddress,
2088         .dst_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64) + 128,
2089         .dst_scratch_offset = 0,
2090      };
2091      bvh_states[i].node_offset = prim_consts.dst_offset;
2092      bvh_states[i].instance_offset = prim_consts.dst_offset;
2093
2094      for (int inst = 1; inst >= 0; --inst) {
2095         for (unsigned j = 0; j < pInfos[i].geometryCount; ++j) {
2096            const VkAccelerationStructureGeometryKHR *geom =
2097               pInfos[i].pGeometries ? &pInfos[i].pGeometries[j] : pInfos[i].ppGeometries[j];
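            /* Two passes over the geometries: instances first (inst == 1), so that the
             * instance nodes start at instance_offset, then every other geometry type
             * (inst == 0). */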
2098
2099            if ((geom->geometryType == VK_GEOMETRY_TYPE_INSTANCES_KHR) != (inst == 1))
2100               continue;
2101
2102            const VkAccelerationStructureBuildRangeInfoKHR *buildRangeInfo =
2103               &ppBuildRangeInfos[i][j];
2104
2105            prim_consts.geometry_type = geom->geometryType;
2106            prim_consts.geometry_id = j | (geom->flags << 28);
2107            unsigned prim_size;
2108            switch (geom->geometryType) {
2109            case VK_GEOMETRY_TYPE_TRIANGLES_KHR:
2110               prim_consts.vertex_addr =
2111                  geom->geometry.triangles.vertexData.deviceAddress +
2112                  buildRangeInfo->firstVertex * geom->geometry.triangles.vertexStride;
2113               prim_consts.index_addr = geom->geometry.triangles.indexData.deviceAddress;
2114
2115               if (geom->geometry.triangles.indexType == VK_INDEX_TYPE_NONE_KHR)
2116                  prim_consts.vertex_addr += buildRangeInfo->primitiveOffset;
2117               else
2118                  prim_consts.index_addr += buildRangeInfo->primitiveOffset;
2119
2120               prim_consts.transform_addr = geom->geometry.triangles.transformData.deviceAddress;
2121               if (prim_consts.transform_addr)
2122                  prim_consts.transform_addr += buildRangeInfo->transformOffset;
2123
2124               prim_consts.vertex_stride = geom->geometry.triangles.vertexStride;
2125               prim_consts.vertex_format = geom->geometry.triangles.vertexFormat;
2126               prim_consts.index_format = geom->geometry.triangles.indexType;
2127               prim_size = 64;
2128               break;
2129            case VK_GEOMETRY_TYPE_AABBS_KHR:
2130               prim_consts.aabb_addr =
2131                  geom->geometry.aabbs.data.deviceAddress + buildRangeInfo->primitiveOffset;
2132               prim_consts.aabb_stride = geom->geometry.aabbs.stride;
2133               prim_size = 64;
2134               break;
2135            case VK_GEOMETRY_TYPE_INSTANCES_KHR:
2136               prim_consts.instance_data =
2137                  geom->geometry.instances.data.deviceAddress + buildRangeInfo->primitiveOffset;
2138               prim_consts.array_of_pointers = geom->geometry.instances.arrayOfPointers ? 1 : 0;
2139               prim_size = 128;
2140               bvh_states[i].instance_count += buildRangeInfo->primitiveCount;
2141               break;
2142            default:
2143               unreachable("Unknown geometryType");
2144            }
2145
2146            radv_CmdPushConstants(
2147               commandBuffer, cmd_buffer->device->meta_state.accel_struct_build.leaf_p_layout,
2148               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(prim_consts), &prim_consts);
2149            radv_unaligned_dispatch(cmd_buffer, buildRangeInfo->primitiveCount, 1, 1);
2150            prim_consts.dst_offset += prim_size * buildRangeInfo->primitiveCount;
2151            prim_consts.dst_scratch_offset += node_id_stride * buildRangeInfo->primitiveCount;
2152         }
2153      }
2154      bvh_states[i].node_offset = prim_consts.dst_offset;
2155      bvh_states[i].node_count = prim_consts.dst_scratch_offset / node_id_stride;
2156   }
2157
2158   if (build_mode == accel_struct_build_lbvh) {
2159      cmd_buffer->state.flush_bits |= flush_bits;
2160
2161      radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2162                           cmd_buffer->device->meta_state.accel_struct_build.morton_pipeline);
2163
2164      for (uint32_t i = 0; i < infoCount; ++i) {
2165         RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2166                          pInfos[i].dstAccelerationStructure);
2167
2168         const struct morton_constants consts = {
2169            .node_addr = radv_accel_struct_get_va(accel_struct),
2170            .scratch_addr = pInfos[i].scratchData.deviceAddress,
2171         };
2172
2173         radv_CmdPushConstants(commandBuffer,
2174                               cmd_buffer->device->meta_state.accel_struct_build.morton_p_layout,
2175                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2176         radv_unaligned_dispatch(cmd_buffer, bvh_states[i].node_count, 1, 1);
2177      }
2178
2179      cmd_buffer->state.flush_bits |= flush_bits;
2180
2181      for (uint32_t i = 0; i < infoCount; ++i) {
2182         struct radix_sort_vk_memory_requirements requirements;
2183         radix_sort_vk_get_memory_requirements(
2184            cmd_buffer->device->meta_state.accel_struct_build.radix_sort, bvh_states[i].node_count,
2185            &requirements);
2186
2187         struct radix_sort_vk_sort_devaddr_info info =
2188            cmd_buffer->device->meta_state.accel_struct_build.radix_sort_info;
2189         info.count = bvh_states[i].node_count;
2190
2191         VkDeviceAddress base_addr =
2192            pInfos[i].scratchData.deviceAddress + SCRATCH_TOTAL_BOUNDS_SIZE;
2193
2194         info.keyvals_even.buffer = VK_NULL_HANDLE;
2195         info.keyvals_even.offset = 0;
2196         info.keyvals_even.devaddr = base_addr;
2197
2198         info.keyvals_odd = base_addr + requirements.keyvals_size;
2199
2200         info.internal.buffer = VK_NULL_HANDLE;
2201         info.internal.offset = 0;
2202         info.internal.devaddr = base_addr + requirements.keyvals_size * 2;
2203
2204         VkDeviceAddress result_addr;
2205         radix_sort_vk_sort_devaddr(cmd_buffer->device->meta_state.accel_struct_build.radix_sort,
2206                                    &info, radv_device_to_handle(cmd_buffer->device), commandBuffer,
2207                                    &result_addr);
2208
2209         assert(result_addr == info.keyvals_even.devaddr || result_addr == info.keyvals_odd);
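         /* Remember which of the two keyval buffers ended up with the sorted node ids so
          * that the first internal-node pass reads from it. */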
2210
2211         if (result_addr == info.keyvals_even.devaddr) {
2212            bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
2213            bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
2214         } else {
2215            bvh_states[i].buffer_1_offset = SCRATCH_TOTAL_BOUNDS_SIZE + requirements.keyvals_size;
2216            bvh_states[i].buffer_2_offset = SCRATCH_TOTAL_BOUNDS_SIZE;
2217         }
2218         bvh_states[i].scratch_offset = bvh_states[i].buffer_1_offset;
2219      }
2220
2221      cmd_buffer->state.flush_bits |= flush_bits;
2222   } else {
2223      for (uint32_t i = 0; i < infoCount; ++i) {
2224         bvh_states[i].buffer_1_offset = 0;
2225         bvh_states[i].buffer_2_offset = bvh_states[i].node_count * 4;
2226      }
2227   }
2228
2229   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
2230                        cmd_buffer->device->meta_state.accel_struct_build.internal_pipeline);
2231   bool progress = true;
2232   for (unsigned iter = 0; progress; ++iter) {
2233      progress = false;
2234      for (uint32_t i = 0; i < infoCount; ++i) {
2235         RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2236                          pInfos[i].dstAccelerationStructure);
2237
2238         if (iter && bvh_states[i].node_count == 1)
2239            continue;
2240
2241         if (!progress)
2242            cmd_buffer->state.flush_bits |= flush_bits;
2243
2244         progress = true;
2245
2246         uint32_t dst_node_count = MAX2(1, DIV_ROUND_UP(bvh_states[i].node_count, 4));
2247         bool final_iter = dst_node_count == 1;
2248
2249         uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
2250         uint32_t buffer_1_offset = bvh_states[i].buffer_1_offset;
2251         uint32_t buffer_2_offset = bvh_states[i].buffer_2_offset;
2252         uint32_t dst_scratch_offset =
2253            (src_scratch_offset == buffer_1_offset) ? buffer_2_offset : buffer_1_offset;
2254
2255         uint32_t dst_node_offset = bvh_states[i].node_offset;
2256         if (final_iter)
2257            dst_node_offset = ALIGN(sizeof(struct radv_accel_struct_header), 64);
2258
2259         const struct build_internal_constants consts = {
2260            .node_dst_addr = radv_accel_struct_get_va(accel_struct),
2261            .scratch_addr = pInfos[i].scratchData.deviceAddress,
2262            .dst_offset = dst_node_offset,
2263            .dst_scratch_offset = dst_scratch_offset,
2264            .src_scratch_offset = src_scratch_offset,
2265            .fill_header = bvh_states[i].node_count | (final_iter ? 0x80000000U : 0),
2266         };
2267
2268         radv_CmdPushConstants(commandBuffer,
2269                               cmd_buffer->device->meta_state.accel_struct_build.internal_p_layout,
2270                               VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2271         radv_unaligned_dispatch(cmd_buffer, dst_node_count, 1, 1);
2272         if (!final_iter)
2273            bvh_states[i].node_offset += dst_node_count * 128;
2274         bvh_states[i].node_count = dst_node_count;
2275         bvh_states[i].scratch_offset = dst_scratch_offset;
2276      }
2277   }
2278   for (uint32_t i = 0; i < infoCount; ++i) {
2279      RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct,
2280                       pInfos[i].dstAccelerationStructure);
2281      const size_t base = offsetof(struct radv_accel_struct_header, compacted_size);
2282      struct radv_accel_struct_header header;
2283
2284      header.instance_offset = bvh_states[i].instance_offset;
2285      header.instance_count = bvh_states[i].instance_count;
2286      header.compacted_size = bvh_states[i].node_offset;
2287
2288      fill_accel_struct_header(&header);
2289
2290      radv_update_buffer_cp(cmd_buffer,
2291                            radv_buffer_get_va(accel_struct->bo) + accel_struct->mem_offset + base,
2292                            (const char *)&header + base, sizeof(header) - base);
2293   }
2294   free(bvh_states);
2295   radv_meta_restore(&saved_state, cmd_buffer);
2296}
2297
2298VKAPI_ATTR void VKAPI_CALL
2299radv_CmdCopyAccelerationStructureKHR(VkCommandBuffer commandBuffer,
2300                                     const VkCopyAccelerationStructureInfoKHR *pInfo)
2301{
2302   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2303   RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
2304   RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
2305   struct radv_meta_saved_state saved_state;
2306
2307   radv_meta_save(
2308      &saved_state, cmd_buffer,
2309      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2310
2311   uint64_t src_addr = radv_accel_struct_get_va(src);
2312   uint64_t dst_addr = radv_accel_struct_get_va(dst);
2313
2314   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2315                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2316
2317   const struct copy_constants consts = {
2318      .src_addr = src_addr,
2319      .dst_addr = dst_addr,
2320      .mode = COPY_MODE_COPY,
2321   };
2322
2323   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2324                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2325                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2326
2327   cmd_buffer->state.flush_bits |=
2328      radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_INDIRECT_COMMAND_READ_BIT, NULL);
2329
2330   radv_indirect_dispatch(cmd_buffer, src->bo,
2331                          src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
2332   radv_meta_restore(&saved_state, cmd_buffer);
2333}
2334
2335VKAPI_ATTR void VKAPI_CALL
2336radv_GetDeviceAccelerationStructureCompatibilityKHR(
2337   VkDevice _device, const VkAccelerationStructureVersionInfoKHR *pVersionInfo,
2338   VkAccelerationStructureCompatibilityKHR *pCompatibility)
2339{
2340   RADV_FROM_HANDLE(radv_device, device, _device);
2341   uint8_t zero[VK_UUID_SIZE] = {
2342      0,
2343   };
2344   bool compat =
2345      memcmp(pVersionInfo->pVersionData, device->physical_device->driver_uuid, VK_UUID_SIZE) == 0 &&
2346      memcmp(pVersionInfo->pVersionData + VK_UUID_SIZE, zero, VK_UUID_SIZE) == 0;
2347   *pCompatibility = compat ? VK_ACCELERATION_STRUCTURE_COMPATIBILITY_COMPATIBLE_KHR
2348                            : VK_ACCELERATION_STRUCTURE_COMPATIBILITY_INCOMPATIBLE_KHR;
2349}
2350
2351VKAPI_ATTR VkResult VKAPI_CALL
2352radv_CopyMemoryToAccelerationStructureKHR(VkDevice _device,
2353                                          VkDeferredOperationKHR deferredOperation,
2354                                          const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
2355{
2356   RADV_FROM_HANDLE(radv_device, device, _device);
2357   RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->dst);
2358
2359   char *base = device->ws->buffer_map(accel_struct->bo);
2360   if (!base)
2361      return VK_ERROR_OUT_OF_HOST_MEMORY;
2362
2363   base += accel_struct->mem_offset;
2364   const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
2365
2366   const char *src = pInfo->src.hostAddress;
2367   struct radv_accel_struct_serialization_header *src_header = (void *)src;
2368   src += sizeof(*src_header) + sizeof(uint64_t) * src_header->instance_count;
2369
2370   memcpy(base, src, src_header->compacted_size);
2371
2372   for (unsigned i = 0; i < src_header->instance_count; ++i) {
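      /* Splice the BLAS pointers from the serialized instance table back into the
       * instance nodes, preserving the low bits of the copied value (the root node id). */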
2373      uint64_t *p = (uint64_t *)(base + i * 128 + header->instance_offset);
2374      *p = (*p & 63) | src_header->instances[i];
2375   }
2376
2377   device->ws->buffer_unmap(accel_struct->bo);
2378   return VK_SUCCESS;
2379}
2380
2381VKAPI_ATTR VkResult VKAPI_CALL
2382radv_CopyAccelerationStructureToMemoryKHR(VkDevice _device,
2383                                          VkDeferredOperationKHR deferredOperation,
2384                                          const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2385{
2386   RADV_FROM_HANDLE(radv_device, device, _device);
2387   RADV_FROM_HANDLE(radv_acceleration_structure, accel_struct, pInfo->src);
2388
2389   const char *base = device->ws->buffer_map(accel_struct->bo);
2390   if (!base)
2391      return VK_ERROR_OUT_OF_HOST_MEMORY;
2392
2393   base += accel_struct->mem_offset;
2394   const struct radv_accel_struct_header *header = (const struct radv_accel_struct_header *)base;
2395
2396   char *dst = pInfo->dst.hostAddress;
2397   struct radv_accel_struct_serialization_header *dst_header = (void *)dst;
2398   dst += sizeof(*dst_header) + sizeof(uint64_t) * header->instance_count;
2399
2400   memcpy(dst_header->driver_uuid, device->physical_device->driver_uuid, VK_UUID_SIZE);
2401   memset(dst_header->accel_struct_compat, 0, VK_UUID_SIZE);
2402
2403   dst_header->serialization_size = header->serialization_size;
2404   dst_header->compacted_size = header->compacted_size;
2405   dst_header->instance_count = header->instance_count;
2406
2407   memcpy(dst, base, header->compacted_size);
2408
2409   for (unsigned i = 0; i < header->instance_count; ++i) {
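      /* Export each instance's BLAS pointer with the low root-node-id bits masked off;
       * they are restored on deserialization. */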
2410      dst_header->instances[i] =
2411         *(const uint64_t *)(base + i * 128 + header->instance_offset) & ~63ull;
2412   }
2413
2414   device->ws->buffer_unmap(accel_struct->bo);
2415   return VK_SUCCESS;
2416}
2417
2418VKAPI_ATTR void VKAPI_CALL
2419radv_CmdCopyMemoryToAccelerationStructureKHR(
2420   VkCommandBuffer commandBuffer, const VkCopyMemoryToAccelerationStructureInfoKHR *pInfo)
2421{
2422   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2423   RADV_FROM_HANDLE(radv_acceleration_structure, dst, pInfo->dst);
2424   struct radv_meta_saved_state saved_state;
2425
2426   radv_meta_save(
2427      &saved_state, cmd_buffer,
2428      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2429
2430   uint64_t dst_addr = radv_accel_struct_get_va(dst);
2431
2432   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2433                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2434
2435   const struct copy_constants consts = {
2436      .src_addr = pInfo->src.deviceAddress,
2437      .dst_addr = dst_addr,
2438      .mode = COPY_MODE_DESERIALIZE,
2439   };
2440
2441   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2442                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2443                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2444
2445   radv_CmdDispatch(commandBuffer, 512, 1, 1);
2446   radv_meta_restore(&saved_state, cmd_buffer);
2447}
2448
2449VKAPI_ATTR void VKAPI_CALL
2450radv_CmdCopyAccelerationStructureToMemoryKHR(
2451   VkCommandBuffer commandBuffer, const VkCopyAccelerationStructureToMemoryInfoKHR *pInfo)
2452{
2453   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
2454   RADV_FROM_HANDLE(radv_acceleration_structure, src, pInfo->src);
2455   struct radv_meta_saved_state saved_state;
2456
2457   radv_meta_save(
2458      &saved_state, cmd_buffer,
2459      RADV_META_SAVE_COMPUTE_PIPELINE | RADV_META_SAVE_DESCRIPTORS | RADV_META_SAVE_CONSTANTS);
2460
2461   uint64_t src_addr = radv_accel_struct_get_va(src);
2462
2463   radv_CmdBindPipeline(radv_cmd_buffer_to_handle(cmd_buffer), VK_PIPELINE_BIND_POINT_COMPUTE,
2464                        cmd_buffer->device->meta_state.accel_struct_build.copy_pipeline);
2465
2466   const struct copy_constants consts = {
2467      .src_addr = src_addr,
2468      .dst_addr = pInfo->dst.deviceAddress,
2469      .mode = COPY_MODE_SERIALIZE,
2470   };
2471
2472   radv_CmdPushConstants(radv_cmd_buffer_to_handle(cmd_buffer),
2473                         cmd_buffer->device->meta_state.accel_struct_build.copy_p_layout,
2474                         VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
2475
2476   cmd_buffer->state.flush_bits |=
2477      radv_dst_access_flush(cmd_buffer, VK_ACCESS_2_INDIRECT_COMMAND_READ_BIT, NULL);
2478
2479   radv_indirect_dispatch(cmd_buffer, src->bo,
2480                          src_addr + offsetof(struct radv_accel_struct_header, copy_dispatch_size));
2481   radv_meta_restore(&saved_state, cmd_buffer);
2482
2483   /* Set the header of the serialized data. */
2484   uint8_t header_data[2 * VK_UUID_SIZE] = {0};
2485   memcpy(header_data, cmd_buffer->device->physical_device->driver_uuid, VK_UUID_SIZE);
2486
2487   radv_update_buffer_cp(cmd_buffer, pInfo->dst.deviceAddress, header_data, sizeof(header_data));
2488}
2489
2490VKAPI_ATTR void VKAPI_CALL
2491radv_CmdBuildAccelerationStructuresIndirectKHR(
2492   VkCommandBuffer commandBuffer, uint32_t infoCount,
2493   const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
2494   const VkDeviceAddress *pIndirectDeviceAddresses, const uint32_t *pIndirectStrides,
2495   const uint32_t *const *ppMaxPrimitiveCounts)
2496{
2497   unreachable("Unimplemented");
2498}
2499