/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#ifndef BRW_NIR_RT_BUILDER_H
#define BRW_NIR_RT_BUILDER_H

/* This file provides helpers to access the memory-based data structures that
 * the RT hardware reads/writes, and to compute their locations.
 *
 * See also "Memory Based Data Structures for Ray Tracing" (BSpec 47547) and
 * "Ray Tracing Address Computation for Memory Resident Structures" (BSpec
 * 47550).
 */

#include "brw_rt.h"
#include "nir_builder.h"

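/* In fragment shaders the loads/stores emitted by these helpers are marked
 * with ACCESS_INCLUDE_HELPERS, presumably so that ray queries also behave
 * correctly in helper invocations.  Other stages need no extra access flag.
 */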
#define is_access_for_builder(b) \
   ((b)->shader->info.stage == MESA_SHADER_FRAGMENT ? \
    ACCESS_INCLUDE_HELPERS : 0)

static inline nir_ssa_def *
brw_nir_rt_load(nir_builder *b, nir_ssa_def *addr, unsigned align,
                unsigned components, unsigned bit_size)
{
   return nir_build_load_global(b, components, bit_size, addr,
                                .align_mul = align,
                                .access = is_access_for_builder(b));
}

static inline void
brw_nir_rt_store(nir_builder *b, nir_ssa_def *addr, unsigned align,
                 nir_ssa_def *value, unsigned write_mask)
{
   nir_build_store_global(b, value, addr,
                          .align_mul = align,
                          .write_mask = (write_mask) &
                                        BITFIELD_MASK(value->num_components),
                          .access = is_access_for_builder(b));
}
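
/* Example use of the two helpers above (illustrative sketch, addr being any
 * 1-component 64-bit global address def):
 *
 *    nir_ssa_def *v = brw_nir_rt_load(b, addr, 16, 4, 32);
 *    brw_nir_rt_store(b, addr, 16, v, 0xf);
 */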

static inline nir_ssa_def *
brw_nir_rt_load_const(nir_builder *b, unsigned components,
                      nir_ssa_def *addr, nir_ssa_def *pred)
{
   return nir_build_load_global_const_block_intel(b, components, addr, pred);
}

static inline nir_ssa_def *
brw_load_btd_dss_id(nir_builder *b)
{
   return nir_build_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_DSS);
}

static inline nir_ssa_def *
brw_nir_rt_load_num_simd_lanes_per_dss(nir_builder *b,
                                       const struct intel_device_info *devinfo)
{
   return nir_imm_int(b, devinfo->num_thread_per_eu *
                         devinfo->max_eus_per_subslice *
                         16 /* The RT computation is based off SIMD16 */);
}

static inline nir_ssa_def *
brw_load_eu_thread_simd(nir_builder *b)
{
   return nir_build_load_topology_id_intel(b, .base = BRW_TOPOLOGY_ID_EU_THREAD_SIMD);
}

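/* Computes the asynchronous stack ID of the current invocation:
 *
 *    asyncStackID = DSSID * RTDispatchGlobals.numDSSRTStacks + stackID
 */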
static inline nir_ssa_def *
brw_nir_rt_async_stack_id(nir_builder *b)
{
   return nir_iadd(b, nir_umul_32x16(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                                        brw_load_btd_dss_id(b)),
                      nir_load_btd_stack_id_intel(b));
}

static inline nir_ssa_def *
brw_nir_rt_sync_stack_id(nir_builder *b)
{
   return brw_load_eu_thread_simd(b);
}

/* We have our own load/store scratch helpers because they emit a global
 * memory read or write based on the scratch_base_ptr system value rather
 * than a load/store_scratch intrinsic.
 */
static inline nir_ssa_def *
brw_nir_rt_load_scratch(nir_builder *b, uint32_t offset, unsigned align,
                        unsigned num_components, unsigned bit_size)
{
   nir_ssa_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   return brw_nir_rt_load(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                             num_components, bit_size);
}

static inline void
brw_nir_rt_store_scratch(nir_builder *b, uint32_t offset, unsigned align,
                         nir_ssa_def *value, nir_component_mask_t write_mask)
{
   nir_ssa_def *addr =
      nir_iadd_imm(b, nir_load_scratch_base_ptr(b, 1, 64, 1), offset);
   brw_nir_rt_store(b, addr, MIN2(align, BRW_BTD_STACK_ALIGN),
                    value, write_mask);
}
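
/* Example (illustrative): spill a 64-bit resume address so that
 * brw_nir_btd_return() below can reload it, where addr64 is a 1-component
 * 64-bit def:
 *
 *    brw_nir_rt_store_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
 *                             8, addr64, 0x1);
 */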

static inline void
brw_nir_btd_spawn(nir_builder *b, nir_ssa_def *record_addr)
{
   nir_btd_spawn_intel(b, nir_load_btd_global_arg_addr_intel(b), record_addr);
}

static inline void
brw_nir_btd_retire(nir_builder *b)
{
   nir_btd_retire_intel(b);
}

/** This is a pseudo-op which does a bindless return
 *
 * It loads the return address from the stack and calls btd_spawn to spawn the
 * resume shader.
 */
static inline void
brw_nir_btd_return(struct nir_builder *b)
{
   nir_ssa_def *resume_addr =
      brw_nir_rt_load_scratch(b, BRW_BTD_STACK_RESUME_BSR_ADDR_OFFSET,
                              8 /* align */, 1, 64);
   brw_nir_btd_spawn(b, resume_addr);
}

static inline void
assert_def_size(nir_ssa_def *def, unsigned num_components, unsigned bit_size)
{
   assert(def->num_components == num_components);
   assert(def->bit_size == bit_size);
}

static inline nir_ssa_def *
brw_nir_num_rt_stacks(nir_builder *b,
                      const struct intel_device_info *devinfo)
{
   return nir_imul_imm(b, nir_load_ray_num_dss_rt_stacks_intel(b),
                          intel_device_info_num_dual_subslices(devinfo));
}

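/* The SW hotzones are stored below rtMemBasePtr, indexed by the asynchronous
 * stack ID:
 *
 *    hotzoneAddr = rtMemBasePtr
 *                - (numRTStacks - asyncStackID) * sizeof(HotZone)
 */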
static inline nir_ssa_def *
brw_nir_rt_sw_hotzone_addr(nir_builder *b,
                           const struct intel_device_info *devinfo)
{
   nir_ssa_def *offset32 =
      nir_imul_imm(b, brw_nir_rt_async_stack_id(b),
                      BRW_RT_SIZEOF_HOTZONE);

   offset32 = nir_iadd(b, offset32, nir_ineg(b,
      nir_imul_imm(b, brw_nir_num_rt_stacks(b, devinfo),
                      BRW_RT_SIZEOF_HOTZONE)));

   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_i2i64(b, offset32));
}

static inline nir_ssa_def *
brw_nir_rt_sync_stack_addr(nir_builder *b,
                           nir_ssa_def *base_mem_addr,
                           const struct intel_device_info *devinfo)
{
   /* For Ray queries (Synchronous Ray Tracing), the formula is similar but
    * goes down from rtMemBasePtr:
    *
    *    syncBase  = RTDispatchGlobals.rtMemBasePtr
    *              - (DSSID * NUM_SIMD_LANES_PER_DSS + SyncStackID + 1)
    *              * syncStackSize
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_ssa_def *offset32 =
      nir_imul(b,
               nir_iadd(b,
                        nir_imul(b, brw_load_btd_dss_id(b),
                                    brw_nir_rt_load_num_simd_lanes_per_dss(b, devinfo)),
                        nir_iadd_imm(b, brw_nir_rt_sync_stack_id(b), 1)),
               nir_imm_int(b, BRW_RT_SIZEOF_RAY_QUERY));
   return nir_isub(b, base_mem_addr, nir_u2u64(b, offset32));
}

static inline nir_ssa_def *
brw_nir_rt_stack_addr(nir_builder *b)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    stackBase = RTDispatchGlobals.rtMemBasePtr
    *              + (DSSID * RTDispatchGlobals.numDSSRTStacks + stackID)
    *              * RTDispatchGlobals.stackSizePerRay // 64B aligned
    *
    * We assume that we can calculate a 32-bit offset first and then add it
    * to the 64-bit base address at the end.
    */
   nir_ssa_def *offset32 =
      nir_imul(b, brw_nir_rt_async_stack_id(b),
                  nir_load_ray_hw_stack_size_intel(b));
   return nir_iadd(b, nir_load_ray_base_mem_addr_intel(b),
                      nir_u2u64(b, offset32));
}

static inline nir_ssa_def *
brw_nir_rt_mem_hit_addr_from_addr(nir_builder *b,
                        nir_ssa_def *stack_addr,
                        bool committed)
{
   return nir_iadd_imm(b, stack_addr, committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_ssa_def *
brw_nir_rt_mem_hit_addr(nir_builder *b, bool committed)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          committed ? 0 : BRW_RT_SIZEOF_HIT_INFO);
}

static inline nir_ssa_def *
brw_nir_rt_hit_attrib_data_addr(nir_builder *b)
{
   return nir_iadd_imm(b, brw_nir_rt_stack_addr(b),
                          BRW_RT_OFFSETOF_HIT_ATTRIB_DATA);
}

static inline nir_ssa_def *
brw_nir_rt_mem_ray_addr(nir_builder *b,
                        nir_ssa_def *stack_addr,
                        enum brw_rt_bvh_level bvh_level)
{
   /* From the BSpec "Address Computation for Memory Based Data Structures:
    * Ray and TraversalStack (Async Ray Tracing)":
    *
    *    rayBase = stackBase + sizeof(HitInfo) * 2 // 64B aligned
    *    rayPtr  = rayBase + bvhLevel * sizeof(Ray); // 64B aligned
    *
    * In Vulkan, we always have exactly two levels of BVH: World and Object.
    */
   uint32_t offset = BRW_RT_SIZEOF_HIT_INFO * 2 +
                     bvh_level * BRW_RT_SIZEOF_RAY;
   return nir_iadd_imm(b, stack_addr, offset);
}

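/* The SW stacks are stored after the HW ray stacks:
 *
 *    swStackAddr = rtMemBasePtr
 *                + numRTStacks * RTDispatchGlobals.stackSizePerRay
 *                + asyncStackID * swStackSize
 */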
static inline nir_ssa_def *
brw_nir_rt_sw_stack_addr(nir_builder *b,
                         const struct intel_device_info *devinfo)
{
   nir_ssa_def *addr = nir_load_ray_base_mem_addr_intel(b);

   nir_ssa_def *offset32 = nir_imul(b, brw_nir_num_rt_stacks(b, devinfo),
                                       nir_load_ray_hw_stack_size_intel(b));
   addr = nir_iadd(b, addr, nir_u2u64(b, offset32));

   nir_ssa_def *offset_in_stack =
      nir_imul(b, nir_u2u64(b, brw_nir_rt_async_stack_id(b)),
                  nir_u2u64(b, nir_load_ray_sw_stack_size_intel(b)));

   return nir_iadd(b, addr, offset_in_stack);
}

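/* Extracts 16-bit word 2 (bits 47:32) of a 64-bit value. */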
static inline nir_ssa_def *
nir_unpack_64_4x16_split_z(nir_builder *b, nir_ssa_def *val)
{
   return nir_unpack_32_2x16_split_x(b, nir_unpack_64_2x32_split_y(b, val));
}

struct brw_nir_rt_globals_defs {
   nir_ssa_def *base_mem_addr;
   nir_ssa_def *call_stack_handler_addr;
   nir_ssa_def *hw_stack_size;
   nir_ssa_def *num_dss_rt_stacks;
   nir_ssa_def *hit_sbt_addr;
   nir_ssa_def *hit_sbt_stride;
   nir_ssa_def *miss_sbt_addr;
   nir_ssa_def *miss_sbt_stride;
   nir_ssa_def *sw_stack_size;
   nir_ssa_def *launch_size;
   nir_ssa_def *call_sbt_addr;
   nir_ssa_def *call_sbt_stride;
   nir_ssa_def *resume_sbt_addr;
};

static inline void
brw_nir_rt_load_globals_addr(nir_builder *b,
                             struct brw_nir_rt_globals_defs *defs,
                             nir_ssa_def *addr)
{
   nir_ssa_def *data;
   data = brw_nir_rt_load_const(b, 16, addr, nir_imm_true(b));
   defs->base_mem_addr = nir_pack_64_2x32(b, nir_channels(b, data, 0x3));

   defs->call_stack_handler_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));

   defs->hw_stack_size = nir_channel(b, data, 4);
   defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
   defs->hit_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
                                nir_extract_i16(b, nir_channel(b, data, 9),
                                                   nir_imm_int(b, 0)));
   defs->hit_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
   defs->miss_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
                                nir_extract_i16(b, nir_channel(b, data, 11),
                                                   nir_imm_int(b, 0)));
   defs->miss_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
   defs->sw_stack_size = nir_channel(b, data, 12);
   defs->launch_size = nir_channels(b, data, 0x7u << 13);

   data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64), nir_imm_true(b));
   defs->call_sbt_addr =
      nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
                                nir_extract_i16(b, nir_channel(b, data, 1),
                                                   nir_imm_int(b, 0)));
   defs->call_sbt_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));

   defs->resume_sbt_addr =
      nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_load_globals(nir_builder *b,
                        struct brw_nir_rt_globals_defs *defs)
{
   brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
}
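
/* Example (illustrative): pull the HW stack size out of RTDispatchGlobals:
 *
 *    struct brw_nir_rt_globals_defs globals;
 *    brw_nir_rt_load_globals(b, &globals);
 *    nir_ssa_def *hw_stack_size = globals.hw_stack_size;
 */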

static inline nir_ssa_def *
brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_ssa_def *vec2)
{
   /* Hit record leaf pointers are 42-bit and assumed to be in 64B chunks.
    * This leaves 22 bits at the top for other stuff.
    */
   nir_ssa_def *ptr64 = nir_imul_imm(b, nir_pack_64_2x32(b, vec2), 64);

   /* The top 16 bits (remember, we shifted by 6 already) contain garbage
    * that we need to get rid of.
    */
   nir_ssa_def *ptr_lo = nir_unpack_64_2x32_split_x(b, ptr64);
   nir_ssa_def *ptr_hi = nir_unpack_64_2x32_split_y(b, ptr64);
   ptr_hi = nir_extract_i16(b, ptr_hi, nir_imm_int(b, 0));
   return nir_pack_64_2x32_split(b, ptr_lo, ptr_hi);
}

/**
 * MemHit memory layout (BSpec 47547):
 *
 *      name            bits    description
 *    - t               32      hit distance of current hit (or initial traversal distance)
 *    - u               32      barycentric hit coordinates
 *    - v               32      barycentric hit coordinates
 *    - primIndexDelta  16      prim index delta for compressed meshlets and quads
 *    - valid            1      set if there is a hit
 *    - leafType         3      type of node primLeafPtr is pointing to
 *    - primLeafIndex    4      index of the hit primitive inside the leaf
 *    - bvhLevel         3      the instancing level at which the hit occurred
 *    - frontFace        1      whether we hit the front-facing side of a triangle (also used to pass opaque flag when calling intersection shaders)
 *    - pad0             4      unused bits
 *    - primLeafPtr     42      pointer to BVH leaf node (multiple of 64 bytes)
 *    - hitGroupRecPtr0 22      LSB of hit group record of the hit triangle (multiple of 16 bytes)
 *    - instLeafPtr     42      pointer to BVH instance leaf node (in multiple of 64 bytes)
 *    - hitGroupRecPtr1 22      MSB of hit group record of the hit triangle (multiple of 32 bytes)
 */
struct brw_nir_rt_mem_hit_defs {
   nir_ssa_def *t;
   nir_ssa_def *tri_bary; /**< Only valid for triangle geometry */
   nir_ssa_def *aabb_hit_kind; /**< Only valid for AABB geometry */
   nir_ssa_def *valid;
   nir_ssa_def *leaf_type;
   nir_ssa_def *prim_leaf_index;
   nir_ssa_def *bvh_level;
   nir_ssa_def *front_face;
   nir_ssa_def *done; /**< Only for ray queries */
   nir_ssa_def *prim_leaf_ptr;
   nir_ssa_def *inst_leaf_ptr;
};

static inline void
brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_hit_defs *defs,
                                  nir_ssa_def *stack_addr,
                                  bool committed)
{
   nir_ssa_def *hit_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);

   nir_ssa_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
   defs->t = nir_channel(b, data, 0);
   defs->aabb_hit_kind = nir_channel(b, data, 1);
   defs->tri_bary = nir_channels(b, data, 0x6);
   nir_ssa_def *bitfield = nir_channel(b, data, 3);
   defs->valid = nir_i2b(b, nir_iand_imm(b, bitfield, 1u << 16));
   defs->leaf_type =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 17), nir_imm_int(b, 3));
   defs->prim_leaf_index =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 20), nir_imm_int(b, 4));
   defs->bvh_level =
      nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 24), nir_imm_int(b, 3));
   defs->front_face = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 27));
   defs->done = nir_i2b(b, nir_iand_imm(b, bitfield, 1 << 28));

   data = brw_nir_rt_load(b, nir_iadd_imm(b, hit_addr, 16), 16, 4, 32);
   defs->prim_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 0));
   defs->inst_leaf_ptr =
      brw_nir_rt_unpack_leaf_ptr(b, nir_channels(b, data, 0x3 << 2));
}

static inline void
brw_nir_rt_init_mem_hit_at_addr(nir_builder *b,
                                nir_ssa_def *stack_addr,
                                bool committed,
                                nir_ssa_def *t_max)
{
   nir_ssa_def *mem_hit_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);

   /* Set the t_max value from the ray initialization */
   nir_ssa_def *hit_t_addr = mem_hit_addr;
   brw_nir_rt_store(b, hit_t_addr, 4, t_max, 0x1);

   /* Clear all the flags packed behind primIndexDelta */
   nir_ssa_def *state_addr = nir_iadd_imm(b, mem_hit_addr, 12);
   brw_nir_rt_store(b, state_addr, 4, nir_imm_int(b, 0), 0x1);
}

static inline void
brw_nir_rt_load_mem_hit(nir_builder *b,
                        struct brw_nir_rt_mem_hit_defs *defs,
                        bool committed)
{
   brw_nir_rt_load_mem_hit_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     committed);
}

static inline void
brw_nir_memcpy_global(nir_builder *b,
                      nir_ssa_def *dst_addr, uint32_t dst_align,
                      nir_ssa_def *src_addr, uint32_t src_align,
                      uint32_t size)
{
   /* We're going to copy in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);
   src_align = MIN2(src_align, 16);

   for (unsigned offset = 0; offset < size; offset += 16) {
      nir_ssa_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), src_align,
                         4, 32);
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       data, 0xf /* write_mask */);
   }
}
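
/* Example (illustrative): copy the uncommitted hit into the committed one,
 * 16 bytes at a time:
 *
 *    brw_nir_memcpy_global(b,
 *                          brw_nir_rt_mem_hit_addr(b, true), 16,
 *                          brw_nir_rt_mem_hit_addr(b, false), 16,
 *                          BRW_RT_SIZEOF_HIT_INFO);
 */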

static inline void
brw_nir_memclear_global(nir_builder *b,
                        nir_ssa_def *dst_addr, uint32_t dst_align,
                        uint32_t size)
{
   /* We're going to clear in 16B chunks */
   assert(size % 16 == 0);
   dst_align = MIN2(dst_align, 16);

   nir_ssa_def *zero = nir_imm_ivec4(b, 0, 0, 0, 0);
   for (unsigned offset = 0; offset < size; offset += 16) {
      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), dst_align,
                       zero, 0xf /* write_mask */);
   }
}

static inline nir_ssa_def *
brw_nir_rt_query_done(nir_builder *b, nir_ssa_def *stack_addr)
{
   struct brw_nir_rt_mem_hit_defs hit_in = {};
   brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr,
                                     false /* committed */);

   return hit_in.done;
}

static inline void
brw_nir_rt_set_dword_bit_at(nir_builder *b,
                            nir_ssa_def *addr,
                            uint32_t addr_offset,
                            uint32_t bit)
{
   nir_ssa_def *dword_addr = nir_iadd_imm(b, addr, addr_offset);
   nir_ssa_def *dword = brw_nir_rt_load(b, dword_addr, 4, 1, 32);
   brw_nir_rt_store(b, dword_addr, 4, nir_ior_imm(b, dword, 1u << bit), 0x1);
}

static inline void
brw_nir_rt_query_mark_done(nir_builder *b, nir_ssa_def *stack_addr)
{
   brw_nir_rt_set_dword_bit_at(b,
                               brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                                 false /* committed */),
                               4 * 3 /* dword offset */, 28 /* bit */);
}

/* This helper clears the 3rd dword of the MemHit structure where the valid
 * bit is located.
 */
static inline void
brw_nir_rt_query_mark_init(nir_builder *b, nir_ssa_def *stack_addr)
{
   nir_ssa_def *dword_addr;

   for (uint32_t i = 0; i < 2; i++) {
      dword_addr =
         nir_iadd_imm(b,
                      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr,
                                                        i == 0 /* committed */),
                      4 * 3 /* dword offset */);
      brw_nir_rt_store(b, dword_addr, 4, nir_imm_int(b, 0), 0x1);
   }
}

/* This helper is essentially a memcpy of the uncommitted hit structure into
 * the committed one, additionally setting the valid bit.
 */
static inline void
brw_nir_rt_commit_hit_addr(nir_builder *b, nir_ssa_def *stack_addr)
{
   nir_ssa_def *dst_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_ssa_def *src_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   for (unsigned offset = 0; offset < BRW_RT_SIZEOF_HIT_INFO; offset += 16) {
      nir_ssa_def *data =
         brw_nir_rt_load(b, nir_iadd_imm(b, src_addr, offset), 16, 4, 32);

      if (offset == 0) {
         data = nir_vec4(b,
                         nir_channel(b, data, 0),
                         nir_channel(b, data, 1),
                         nir_channel(b, data, 2),
                         nir_ior_imm(b,
                                     nir_channel(b, data, 3),
                                     0x1 << 16 /* valid */));

         /* Also write the potential hit as we change it. */
         brw_nir_rt_store(b, nir_iadd_imm(b, src_addr, offset), 16,
                          data, 0xf /* write_mask */);
      }

      brw_nir_rt_store(b, nir_iadd_imm(b, dst_addr, offset), 16,
                       data, 0xf /* write_mask */);
   }
}

static inline void
brw_nir_rt_commit_hit(nir_builder *b)
{
   nir_ssa_def *stack_addr = brw_nir_rt_stack_addr(b);
   brw_nir_rt_commit_hit_addr(b, stack_addr);
}

static inline void
brw_nir_rt_generate_hit_addr(nir_builder *b, nir_ssa_def *stack_addr, nir_ssa_def *t_val)
{
   nir_ssa_def *committed_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, true /* committed */);
   nir_ssa_def *potential_addr =
      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, false /* committed */);

   /* Set:
    *
    *   potential.t     = t_val;
    *   potential.valid = true;
    */
   nir_ssa_def *potential_hit_dwords_0_3 =
      brw_nir_rt_load(b, potential_addr, 16, 4, 32);
   potential_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_channel(b, potential_hit_dwords_0_3, 1),
               nir_channel(b, potential_hit_dwords_0_3, 2),
               nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3),
                           (0x1 << 16) /* valid */));
   brw_nir_rt_store(b, potential_addr, 16, potential_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.t               = t_val;
    *   committed.u               = 0.0f;
    *   committed.v               = 0.0f;
    *   committed.valid           = true;
    *   committed.leaf_type       = potential.leaf_type;
    *   committed.bvh_level       = BRW_RT_BVH_LEVEL_OBJECT;
    *   committed.front_face      = false;
    *   committed.prim_leaf_index = 0;
    *   committed.done            = false;
    */
   nir_ssa_def *committed_hit_dwords_0_3 =
      brw_nir_rt_load(b, committed_addr, 16, 4, 32);
   committed_hit_dwords_0_3 =
      nir_vec4(b,
               t_val,
               nir_imm_float(b, 0.0f),
               nir_imm_float(b, 0.0f),
               nir_ior_imm(b,
                           nir_ior_imm(b, nir_channel(b, potential_hit_dwords_0_3, 3), 0x000e0000),
                           (0x1 << 16)                     /* valid */ |
                           (BRW_RT_BVH_LEVEL_OBJECT << 24) /* bvh_level */));
   brw_nir_rt_store(b, committed_addr, 16, committed_hit_dwords_0_3, 0xf /* write_mask */);

   /* Set:
    *
    *   committed.prim_leaf_ptr   = potential.prim_leaf_ptr;
    *   committed.inst_leaf_ptr   = potential.inst_leaf_ptr;
    */
   brw_nir_memcpy_global(b,
                         nir_iadd_imm(b, committed_addr, 16), 16,
                         nir_iadd_imm(b, potential_addr, 16), 16,
                         16);
}

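/* Fields of the HW MemRay structure, in the order they are written by
 * brw_nir_rt_store_mem_ray() below.  The pointers are 48-bit addresses, each
 * packed into a QWord together with a 16-bit field:
 *
 *    - orig, dir                 ray origin and direction
 *    - t_near, t_far             traversal interval
 *    - root_node_ptr             + ray_flags
 *    - hit_group_sr_base_ptr     + hit_group_sr_stride
 *    - miss_sr_ptr               + shader_index_multiplier
 *    - inst_leaf_ptr             + ray_mask
 */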
struct brw_nir_rt_mem_ray_defs {
   nir_ssa_def *orig;
   nir_ssa_def *dir;
   nir_ssa_def *t_near;
   nir_ssa_def *t_far;
   nir_ssa_def *root_node_ptr;
   nir_ssa_def *ray_flags;
   nir_ssa_def *hit_group_sr_base_ptr;
   nir_ssa_def *hit_group_sr_stride;
   nir_ssa_def *miss_sr_ptr;
   nir_ssa_def *shader_index_multiplier;
   nir_ssa_def *inst_leaf_ptr;
   nir_ssa_def *ray_mask;
};

static inline void
brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                                       nir_ssa_def *ray_addr,
                                       const struct brw_nir_rt_mem_ray_defs *defs)
{
   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags)),
      0x3 /* write mask */);

   /* leaf_ptr is optional */
   nir_ssa_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
      nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_store_mem_ray(nir_builder *b,
                         const struct brw_nir_rt_mem_ray_defs *defs,
                         enum brw_rt_bvh_level bvh_level)
{
   nir_ssa_def *ray_addr =
      brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);

   assert_def_size(defs->orig, 3, 32);
   assert_def_size(defs->dir, 3, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 0), 16,
      nir_vec4(b, nir_channel(b, defs->orig, 0),
                  nir_channel(b, defs->orig, 1),
                  nir_channel(b, defs->orig, 2),
                  nir_channel(b, defs->dir, 0)),
      ~0 /* write mask */);

   assert_def_size(defs->t_near, 1, 32);
   assert_def_size(defs->t_far, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 16), 16,
      nir_vec4(b, nir_channel(b, defs->dir, 1),
                  nir_channel(b, defs->dir, 2),
                  defs->t_near,
                  defs->t_far),
      ~0 /* write mask */);

   assert_def_size(defs->root_node_ptr, 1, 64);
   assert_def_size(defs->ray_flags, 1, 16);
   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
   assert_def_size(defs->hit_group_sr_stride, 1, 16);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
                     defs->ray_flags),
                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
                     defs->hit_group_sr_stride)),
      ~0 /* write mask */);

   /* leaf_ptr is optional */
   nir_ssa_def *inst_leaf_ptr;
   if (defs->inst_leaf_ptr) {
      inst_leaf_ptr = defs->inst_leaf_ptr;
   } else {
      inst_leaf_ptr = nir_imm_int64(b, 0);
   }

   assert_def_size(defs->miss_sr_ptr, 1, 64);
   assert_def_size(defs->shader_index_multiplier, 1, 32);
   assert_def_size(inst_leaf_ptr, 1, 64);
   assert_def_size(defs->ray_mask, 1, 32);
   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
                     nir_unpack_32_2x16_split_x(b,
                        nir_ishl(b, defs->shader_index_multiplier,
                                    nir_imm_int(b, 8)))),
                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
                  nir_pack_32_2x16_split(b,
                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
      ~0 /* write mask */);
}

static inline void
brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                                  struct brw_nir_rt_mem_ray_defs *defs,
                                  nir_ssa_def *ray_base_addr,
                                  enum brw_rt_bvh_level bvh_level)
{
   nir_ssa_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
                                                   ray_base_addr,
                                                   bvh_level);

   nir_ssa_def *data[4] = {
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr,  0), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 16), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 32), 16, 4, 32),
      brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 48), 16, 4, 32),
   };

   defs->orig = nir_channels(b, data[0], 0x7);
   defs->dir = nir_vec3(b, nir_channel(b, data[0], 3),
                           nir_channel(b, data[1], 0),
                           nir_channel(b, data[1], 1));
   defs->t_near = nir_channel(b, data[1], 2);
   defs->t_far = nir_channel(b, data[1], 3);
   defs->root_node_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
                                nir_extract_i16(b, nir_channel(b, data[2], 1),
                                                   nir_imm_int(b, 0)));
   defs->ray_flags =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
   defs->hit_group_sr_base_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
                                nir_extract_i16(b, nir_channel(b, data[2], 3),
                                                   nir_imm_int(b, 0)));
   defs->hit_group_sr_stride =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
   defs->miss_sr_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
                                nir_extract_i16(b, nir_channel(b, data[3], 1),
                                                   nir_imm_int(b, 0)));
   defs->shader_index_multiplier =
      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
                  nir_imm_int(b, 8));
   defs->inst_leaf_ptr =
      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
                                nir_extract_i16(b, nir_channel(b, data[3], 3),
                                                   nir_imm_int(b, 0)));
   defs->ray_mask =
      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
}

static inline void
brw_nir_rt_load_mem_ray(nir_builder *b,
                        struct brw_nir_rt_mem_ray_defs *defs,
                        enum brw_rt_bvh_level bvh_level)
{
   brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
                                     bvh_level);
}

struct brw_nir_rt_bvh_instance_leaf_defs {
   nir_ssa_def *shader_index;
   nir_ssa_def *contribution_to_hit_group_index;
   nir_ssa_def *world_to_object[4];
   nir_ssa_def *instance_id;
   nir_ssa_def *instance_index;
   nir_ssa_def *object_to_world[4];
};

static inline void
brw_nir_rt_load_bvh_instance_leaf(nir_builder *b,
                                  struct brw_nir_rt_bvh_instance_leaf_defs *defs,
                                  nir_ssa_def *leaf_addr)
{
   defs->shader_index =
      nir_iand_imm(b, brw_nir_rt_load(b, leaf_addr, 4, 1, 32), (1 << 24) - 1);
   defs->contribution_to_hit_group_index =
      nir_iand_imm(b,
                   brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 4), 4, 1, 32),
                   (1 << 24) - 1);

   defs->world_to_object[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 16), 4, 3, 32);
   defs->world_to_object[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 28), 4, 3, 32);
   defs->world_to_object[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 40), 4, 3, 32);
   /* The last column of each matrix is stored with the other matrix,
    * presumably because that layout is more convenient for the hardware.
    */
   defs->object_to_world[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 52), 4, 3, 32);

   nir_ssa_def *data =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 64), 4, 4, 32);
   defs->instance_id = nir_channel(b, data, 2);
   defs->instance_index = nir_channel(b, data, 3);

   defs->object_to_world[0] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 80), 4, 3, 32);
   defs->object_to_world[1] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 92), 4, 3, 32);
   defs->object_to_world[2] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 104), 4, 3, 32);
   defs->world_to_object[3] =
      brw_nir_rt_load(b, nir_iadd_imm(b, leaf_addr, 116), 4, 3, 32);
}

struct brw_nir_rt_bvh_primitive_leaf_defs {
   nir_ssa_def *shader_index;
   nir_ssa_def *geom_mask;
   nir_ssa_def *geom_index;
   nir_ssa_def *type;
   nir_ssa_def *geom_flags;
};

static inline void
brw_nir_rt_load_bvh_primitive_leaf(nir_builder *b,
                                   struct brw_nir_rt_bvh_primitive_leaf_defs *defs,
                                   nir_ssa_def *leaf_addr)
{
   nir_ssa_def *desc = brw_nir_rt_load(b, leaf_addr, 4, 2, 32);

   defs->shader_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 23), nir_imm_int(b, 0));
   defs->geom_mask =
      nir_ubitfield_extract(b, nir_channel(b, desc, 0),
                            nir_imm_int(b, 31), nir_imm_int(b, 24));

   defs->geom_index =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 28), nir_imm_int(b, 0));
   defs->type =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 29), nir_imm_int(b, 29));
   defs->geom_flags =
      nir_ubitfield_extract(b, nir_channel(b, desc, 1),
                            nir_imm_int(b, 31), nir_imm_int(b, 30));
}

static inline nir_ssa_def *
brw_nir_rt_load_primitive_id_from_hit(nir_builder *b,
                                      nir_ssa_def *is_procedural,
                                      const struct brw_nir_rt_mem_hit_defs *defs)
{
   if (!is_procedural) {
      is_procedural =
         nir_ieq(b, defs->leaf_type,
                    nir_imm_int(b, BRW_RT_BVH_NODE_TYPE_PROCEDURAL));
   }

   /* The primitive IDs are stored in the leaf.  The ID of the hit primitive
    * is at dw[3 + primLeafIndex] for procedural leaves and at dw[2] for quad
    * leaves.
    */
   nir_ssa_def *offset =
      nir_bcsel(b, is_procedural,
                   nir_iadd_imm(b, nir_ishl_imm(b, defs->prim_leaf_index, 2), 12),
                   nir_imm_int(b, 8));
   return nir_load_global(b, nir_iadd(b, defs->prim_leaf_ptr,
                                         nir_u2u64(b, offset)),
                             4, /* align */ 1, 32);
}

static inline nir_ssa_def *
brw_nir_rt_acceleration_structure_to_root_node(nir_builder *b,
                                               nir_ssa_def *as_addr)
{
   /* The HW memory structure in which we specify what acceleration structure
    * to traverse takes the address of the root node in the acceleration
    * structure, not the acceleration structure itself.  To find that, we
    * have to read the root node offset from the acceleration structure,
    * which is stored in the first QWord.
    *
    * But if the acceleration structure pointer is NULL, then we should
    * return NULL as the root node pointer.
    *
    * TODO: we could optimize this by assuming that for a given version of
    * the BVH, we can find the root node at a given offset.
    */
   nir_ssa_def *root_node_ptr, *null_node_ptr;
   nir_push_if(b, nir_ieq(b, as_addr, nir_imm_int64(b, 0)));
   {
      null_node_ptr = nir_imm_int64(b, 0);
   }
   nir_push_else(b, NULL);
   {
      root_node_ptr =
         nir_iadd(b, as_addr, brw_nir_rt_load(b, as_addr, 256, 1, 64));
   }
   nir_pop_if(b, NULL);

   return nir_if_phi(b, null_node_ptr, root_node_ptr);
}

#endif /* BRW_NIR_RT_BUILDER_H */
