/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_nir_rt.h"
#include "brw_nir_rt_builder.h"

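/* Resizes the SSA destination of a deref to the given number of components
 * and bit size.  If the bit size of an array deref changes, its index is
 * rewritten to the new bit size as well.  Returns true if anything changed.
 */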
static bool
resize_deref(nir_builder *b, nir_deref_instr *deref,
             unsigned num_components, unsigned bit_size)
{
   assert(deref->dest.is_ssa);
   if (deref->dest.ssa.num_components == num_components &&
       deref->dest.ssa.bit_size == bit_size)
      return false;

   /* NIR requires that array indices match the deref bit size */
   if (deref->dest.ssa.bit_size != bit_size &&
       (deref->deref_type == nir_deref_type_array ||
        deref->deref_type == nir_deref_type_ptr_as_array)) {
      b->cursor = nir_before_instr(&deref->instr);
      assert(deref->arr.index.is_ssa);
      nir_ssa_def *idx;
      if (nir_src_is_const(deref->arr.index)) {
         idx = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size);
      } else {
         idx = nir_i2i(b, deref->arr.index.ssa, bit_size);
      }
      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
                            nir_src_for_ssa(idx));
   }

   deref->dest.ssa.num_components = num_components;
   deref->dest.ssa.bit_size = bit_size;

   return true;
}

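/* Rewrites shader_call_data and ray_hit_attrib variable derefs into
 * function_temp derefs of casted pointers: the call data address is loaded
 * from the shader's scratch space and the hit attribute address points
 * either at the MemHit barycentrics or at the procedural hit attribute
 * storage.  Any remaining function_temp derefs are resized to scalar 64-bit
 * so they're ready for nir_lower_explicit_io.
 */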
static bool
lower_rt_io_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   unsigned num_shader_call_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
      num_shader_call_vars++;

   unsigned num_ray_hit_attrib_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_ray_hit_attrib)
      num_ray_hit_attrib_vars++;

   /* At most one payload is allowed because it's an input.  Technically, this
    * is also true for hit attribute variables.  However, after we inline an
    * any-hit shader into an intersection shader, we can end up with multiple
    * hit attribute variables.  They'll end up mapping to a cast from the same
    * base pointer so this is fine.
    */
   assert(num_shader_call_vars <= 1);

   nir_builder b;
   nir_builder_init(&b, impl);

   b.cursor = nir_before_cf_list(&impl->body);
   nir_ssa_def *call_data_addr = NULL;
   if (num_shader_call_vars > 0) {
      assert(shader->scratch_size >= BRW_BTD_STACK_CALLEE_DATA_SIZE);
      call_data_addr =
         brw_nir_rt_load_scratch(&b, BRW_BTD_STACK_CALL_DATA_PTR_OFFSET, 8,
                                 1, 64);
      progress = true;
   }

   gl_shader_stage stage = shader->info.stage;
   nir_ssa_def *hit_attrib_addr = NULL;
   if (num_ray_hit_attrib_vars > 0) {
      assert(stage == MESA_SHADER_ANY_HIT ||
             stage == MESA_SHADER_CLOSEST_HIT ||
             stage == MESA_SHADER_INTERSECTION);
      nir_ssa_def *hit_addr =
         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
      /* The vec2 barycentrics are in the 2nd and 3rd dwords of MemHit */
      nir_ssa_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
                                      brw_nir_rt_hit_attrib_data_addr(&b),
                                      bary_addr);
      progress = true;
   }

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, call_data_addr,
                                       nir_var_function_temp,
                                       deref->var->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, hit_attrib_addr,
                                       nir_var_function_temp,
                                       deref->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         }

         /* We're going to lower all function_temp memory to scratch using
          * 64-bit addresses.  We need to resize all our derefs first or else
          * nir_lower_explicit_io will have a fit.
          */
         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
             resize_deref(&b, deref, 1, 64))
            progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

/** Lowers ray-tracing shader I/O and scratch access
 *
 * SPV_KHR_ray_tracing adds three new types of I/O, each of which needs its
 * own bit of special care:
 *
 *  - Shader payload data:  This is represented by the IncomingCallableData
 *    and IncomingRayPayload storage classes which are both represented by
 *    nir_var_shader_call_data in NIR.  There is at most one of these per
 *    shader and they contain payload data passed down the stack from the
 *    parent shader when it calls executeCallable() or traceRay().  In our
 *    implementation, the actual storage lives in the calling shader's
 *    scratch space and we're passed a pointer to it.
 *
 *  - Hit attribute data:  This is represented by the HitAttribute storage
 *    class in SPIR-V and nir_var_ray_hit_attrib in NIR.  For triangle
 *    geometry, it's supposed to contain two floats which are the barycentric
 *    coordinates.  For AABB/procedural geometry, it contains the hit data
 *    written out by the intersection shader.  In our implementation, it's a
 *    64-bit pointer which points either to the u/v area of the relevant
 *    MemHit data structure or the space right after the HW ray stack entry.
 *
 *  - Shader record buffer data:  This allows read-only access to the data
 *    stored in the SBT right after the bindless shader handles.  It's
 *    effectively a UBO with a magic address.  Coming out of spirv_to_nir,
 *    we get a nir_intrinsic_load_shader_record_ptr which is cast to a
 *    nir_var_mem_global deref and all access happens through that.  The
 *    shader_record_ptr system value is handled in brw_nir_lower_rt_intrinsics
 *    and we assume nir_lower_explicit_io is called elsewhere thanks to
 *    VK_KHR_buffer_device_address so there's really nothing to do here.
 *
 * We also handle lowering any remaining function_temp variables to scratch at
 * this point.  This gets rid of any remaining arrays and also takes care of
 * the sending side of ray payloads where we pass pointers to a function_temp
 * variable down the call stack.
 */
static void
lower_rt_io_and_scratch(nir_shader *nir)
{
   /* First, we need to ensure all the I/O variables have explicit types.
    * Because these are shader-internal and don't come in from outside, they
    * don't have an explicit memory layout and we have to assign them one.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_function_temp |
              nir_var_shader_call_data |
              nir_var_ray_hit_attrib,
              glsl_get_natural_size_align_bytes);

   /* Now patch any derefs to I/O vars */
   NIR_PASS_V(nir, lower_rt_io_derefs);

   /* Finally, lower any remaining function_temp, mem_constant, or
    * ray_hit_attrib access to 64-bit global memory access.
    */
   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_function_temp |
              nir_var_mem_constant |
              nir_var_ray_hit_attrib,
              nir_address_format_64bit_global);
}

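/* Terminates the ray from an any-hit shader: if the ray flags say to skip
 * the closest-hit shader we simply return up the stack, otherwise we commit
 * the hit and spawn the closest-hit shader.
 */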
static void
build_terminate_ray(nir_builder *b)
{
   nir_ssa_def *skip_closest_hit = nir_test_mask(b, nir_load_ray_flags(b),
      BRW_RT_RAY_FLAG_SKIP_CLOSEST_HIT_SHADER);
   nir_push_if(b, skip_closest_hit);
   {
      /* The shader that calls traceRay() is unable to access any ray hit
       * information except for that which is explicitly written into the ray
       * payload by shaders invoked during the trace.  If there's no closest-
       * hit shader, then accepting the hit has no observable effect; it's
       * just extra memory traffic for no reason.
       */
      brw_nir_btd_return(b);
      nir_jump(b, nir_jump_halt);
   }
   nir_push_else(b, NULL);
   {
      /* The closest hit shader is in the same shader group as the any-hit
       * shader that we're currently in.  We can get the address for its SBT
       * handle by looking at the shader record pointer and subtracting the
       * size of a SBT handle.  The BINDLESS_SHADER_RECORD for a closest hit
       * shader is the first one in the SBT handle.
       */
      nir_ssa_def *closest_hit =
         nir_iadd_imm(b, nir_load_shader_record_ptr(b),
                        -BRW_RT_SBT_HANDLE_SIZE);

      brw_nir_rt_commit_hit(b);
      brw_nir_btd_spawn(b, closest_hit);
      nir_jump(b, nir_jump_halt);
   }
   nir_pop_if(b, NULL);
}

/** Lowers away ray walk intrinsics
 *
 * This lowers terminate_ray, ignore_ray_intersection, and the NIR-specific
 * accept_ray_intersection intrinsics to the appropriate Intel-specific
 * intrinsics.
 */
static bool
lower_ray_walk_intrinsics(nir_shader *shader,
                          const struct intel_device_info *devinfo)
{
   assert(shader->info.stage == MESA_SHADER_ANY_HIT ||
          shader->info.stage == MESA_SHADER_INTERSECTION);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_builder b;
   nir_builder_init(&b, impl);

   bool progress = false;
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

         switch (intrin->intrinsic) {
         case nir_intrinsic_ignore_ray_intersection: {
            b.cursor = nir_instr_remove(&intrin->instr);

            /* We put the newly emitted code inside a dummy if because it's
             * going to contain a jump instruction and we don't want to deal
             * with that mess here.  It'll get dealt with by our control-flow
             * optimization passes.
             */
            nir_push_if(&b, nir_imm_true(&b));
            nir_trace_ray_intel(&b,
                                nir_load_btd_global_arg_addr_intel(&b),
                                nir_imm_int(&b, BRW_RT_BVH_LEVEL_OBJECT),
                                nir_imm_int(&b, GEN_RT_TRACE_RAY_CONTINUE),
                                .synchronous = false);
            nir_jump(&b, nir_jump_halt);
            nir_pop_if(&b, NULL);
            progress = true;
            break;
         }

         case nir_intrinsic_accept_ray_intersection: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_ssa_def *terminate = nir_test_mask(&b, nir_load_ray_flags(&b),
               BRW_RT_RAY_FLAG_TERMINATE_ON_FIRST_HIT);
            nir_push_if(&b, terminate);
            {
               build_terminate_ray(&b);
            }
            nir_push_else(&b, NULL);
            {
               nir_trace_ray_intel(&b,
                                   nir_load_btd_global_arg_addr_intel(&b),
                                   nir_imm_int(&b, BRW_RT_BVH_LEVEL_OBJECT),
                                   nir_imm_int(&b, GEN_RT_TRACE_RAY_COMMIT),
                                   .synchronous = false);
               nir_jump(&b, nir_jump_halt);
            }
            nir_pop_if(&b, NULL);
            progress = true;
            break;
         }

         case nir_intrinsic_terminate_ray: {
            b.cursor = nir_instr_remove(&intrin->instr);
            build_terminate_ray(&b);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_none);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

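/* Per-stage entry points.  Each one lowers shader returns and then applies
 * the common I/O and scratch lowering; the any-hit and intersection stages
 * additionally lower the ray walk intrinsics.
 */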
void
brw_nir_lower_raygen(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_RAYGEN);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

void
brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
{
   assert(nir->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(nir);
}

void
brw_nir_lower_closest_hit(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

void
brw_nir_lower_miss(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_MISS);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

void
brw_nir_lower_callable(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CALLABLE);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

void
brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                            const nir_shader *any_hit,
                                            const struct intel_device_info *devinfo)
{
   assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
   assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(intersection, brw_nir_lower_shader_returns);
   NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
              any_hit, devinfo);
   NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(intersection);
}

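/* Loads a value from the trampoline's push constants at the given byte
 * offset.
 */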
static nir_ssa_def *
build_load_uniform(nir_builder *b, unsigned offset,
                   unsigned num_components, unsigned bit_size)
{
   return nir_load_uniform(b, num_components, bit_size, nir_imm_int(b, 0),
                           .base = offset,
                           .range = num_components * bit_size / 8);
}

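/* Convenience wrapper which loads the field `name` of
 * struct brw_rt_raygen_trampoline_params from the push constants.
 */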
#define load_trampoline_param(b, name, num_components, bit_size) \
   build_load_uniform((b), offsetof(struct brw_rt_raygen_trampoline_params, name), \
                      (num_components), (bit_size))

nir_shader *
brw_nir_create_raygen_trampoline(const struct brw_compiler *compiler,
                                 void *mem_ctx)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->nir_options[MESA_SHADER_COMPUTE];

   STATIC_ASSERT(sizeof(struct brw_rt_raygen_trampoline_params) == 32);

   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE,
                                                  nir_options,
                                                  "RT Ray-Gen Trampoline");
   ralloc_steal(mem_ctx, b.shader);

   b.shader->info.workgroup_size_variable = true;

   /* The RT global data and raygen BINDLESS_SHADER_RECORD addresses are
    * passed in as push constants in the first register.  We deal with the
    * raygen BSR address here; the global data we'll deal with later.
    */
   b.shader->num_uniforms = 32;
   nir_ssa_def *raygen_bsr_addr =
      load_trampoline_param(&b, raygen_bsr_addr, 1, 64);
   nir_ssa_def *local_shift =
      nir_u2u32(&b, load_trampoline_param(&b, local_group_size_log2, 3, 8));

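   /* Each invocation handles one launch index.  The local (within-workgroup)
    * part of the index is unpacked from the flat SIMD channel using the
    * per-dimension log2 workgroup sizes, then combined with the workgroup ID
    * to form the full 3D launch ID.
    */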
   nir_ssa_def *global_id = nir_load_workgroup_id(&b, 32);
   nir_ssa_def *simd_channel = nir_load_subgroup_invocation(&b);
   nir_ssa_def *local_x =
      nir_ubfe(&b, simd_channel, nir_imm_int(&b, 0),
                  nir_channel(&b, local_shift, 0));
   nir_ssa_def *local_y =
      nir_ubfe(&b, simd_channel, nir_channel(&b, local_shift, 0),
                  nir_channel(&b, local_shift, 1));
   nir_ssa_def *local_z =
      nir_ubfe(&b, simd_channel,
                  nir_iadd(&b, nir_channel(&b, local_shift, 0),
                              nir_channel(&b, local_shift, 1)),
                  nir_channel(&b, local_shift, 2));
   nir_ssa_def *launch_id =
      nir_iadd(&b, nir_ishl(&b, global_id, local_shift),
                  nir_vec3(&b, local_x, local_y, local_z));

   nir_ssa_def *launch_size = nir_load_ray_launch_size(&b);
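   /* The dispatch is rounded up to whole workgroups, so some invocations may
    * fall outside the requested launch size.  Only in-bounds invocations
    * initialize the SW hotzone and spawn the raygen shader.
    */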
   nir_push_if(&b, nir_ball(&b, nir_ult(&b, launch_id, launch_size)));
   {
      nir_store_global(&b, brw_nir_rt_sw_hotzone_addr(&b, devinfo), 16,
                       nir_vec4(&b, nir_imm_int(&b, 0), /* Stack ptr */
                                    nir_channel(&b, launch_id, 0),
                                    nir_channel(&b, launch_id, 1),
                                    nir_channel(&b, launch_id, 2)),
                       0xf /* write mask */);

      brw_nir_btd_spawn(&b, raygen_bsr_addr);
   }
   nir_push_else(&b, NULL);
   {
      /* Even though these invocations aren't being used for anything, the
       * hardware allocated stack IDs for them.  They need to retire them.
       */
      brw_nir_btd_retire(&b);
   }
   nir_pop_if(&b, NULL);

   nir_shader *nir = b.shader;
   nir->info.name = ralloc_strdup(nir, "RT: TraceRay trampoline");
   nir_validate_shader(nir, "in brw_nir_create_raygen_trampoline");
   brw_preprocess_nir(compiler, nir, NULL);

   NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);

   nir_builder_init(&b, nir_shader_get_entrypoint(b.shader));
   /* brw_nir_lower_rt_intrinsics will leave us with a btd_global_arg_addr
    * intrinsic which doesn't exist in compute shaders.  We also created one
    * above when we generated the BTD spawn intrinsic.  Now we go through and
    * replace them with a uniform load.
    */
   nir_foreach_block(block, b.impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_btd_global_arg_addr_intel)
            continue;

         b.cursor = nir_before_instr(&intrin->instr);
         nir_ssa_def *global_arg_addr =
            load_trampoline_param(&b, rt_disp_globals_addr, 1, 64);
         assert(intrin->dest.is_ssa);
         nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                  global_arg_addr);
         nir_instr_remove(instr);
      }
   }

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);

   brw_nir_optimize(nir, compiler, true, false);

   return nir;
}
