/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_phi_builder.h"
#include "util/u_math.h"

static bool
move_system_values_to_top(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         /* These intrinsics not only can't be re-materialized but aren't
          * preserved when moving to the continuation shader.  We have to move
          * them to the top to ensure they get spilled as needed.
          */
         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_shader_record_ptr:
         case nir_intrinsic_load_btd_local_arg_addr_intel:
            nir_instr_remove(instr);
            nir_instr_insert(nir_before_cf_list(&impl->body), instr);
            progress = true;
            break;

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

static bool
instr_is_shader_call(nir_instr *instr)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_trace_ray ||
          intrin->intrinsic == nir_intrinsic_report_ray_intersection ||
          intrin->intrinsic == nir_intrinsic_execute_callable;
}

/* Previously named bitset, it had to be renamed as FreeBSD defines a struct
 * named bitset in sys/_bitset.h required by pthread_np.h which is included
 * from src/util/u_thread.h that is indirectly included by this file.
 */
struct brw_bitset {
   BITSET_WORD *set;
   unsigned size;
};

static struct brw_bitset
bitset_create(void *mem_ctx, unsigned size)
{
   return (struct brw_bitset) {
      .set = rzalloc_array(mem_ctx, BITSET_WORD, BITSET_WORDS(size)),
      .size = size,
   };
}

static bool
src_is_in_bitset(nir_src *src, void *_set)
{
   struct brw_bitset *set = _set;
   assert(src->is_ssa);

   /* Any SSA values which were added after we generated liveness information
    * are things generated by this pass and, while most of it is arithmetic
    * which we could re-materialize, we don't need to because it's only used
    * for a single load/store and so shouldn't cross any shader calls.
    */
   if (src->ssa->index >= set->size)
      return false;

   return BITSET_TEST(set->set, src->ssa->index);
}

static void
add_ssa_def_to_bitset(nir_ssa_def *def, struct brw_bitset *set)
{
   if (def->index >= set->size)
      return;

   BITSET_SET(set->set, def->index);
}

static bool
can_remat_instr(nir_instr *instr, struct brw_bitset *remat)
{
   /* Set of all values which are trivially re-materializable and we shouldn't
    * ever spill them.  This includes:
    *
    *   - Undef values
    *   - Constants
    *   - Uniforms (UBO or push constant)
    *   - ALU combinations of any of the above
    *   - Derefs which are either complete or casts of any of the above
    *
    * Because this pass rewrites things in-order and phis are always turned
    * into register writes, we can use "is it SSA?" to answer the question
    * "can my source be re-materialized?".
    */
   switch (instr->type) {
   case nir_instr_type_alu:
      if (!nir_instr_as_alu(instr)->dest.dest.is_ssa)
         return false;

      return nir_foreach_src(instr, src_is_in_bitset, remat);

   case nir_instr_type_deref:
      return nir_foreach_src(instr, src_is_in_bitset, remat);

   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_ubo:
      case nir_intrinsic_vulkan_resource_index:
      case nir_intrinsic_vulkan_resource_reindex:
      case nir_intrinsic_load_vulkan_descriptor:
      case nir_intrinsic_load_push_constant:
         /* These intrinsics don't need to be spilled as long as they don't
          * depend on any spilled values.
          */
         return nir_foreach_src(instr, src_is_in_bitset, remat);

      case nir_intrinsic_load_scratch_base_ptr:
      case nir_intrinsic_load_ray_launch_id:
      case nir_intrinsic_load_topology_id_intel:
      case nir_intrinsic_load_btd_global_arg_addr_intel:
      case nir_intrinsic_load_btd_resume_sbt_addr_intel:
      case nir_intrinsic_load_ray_base_mem_addr_intel:
      case nir_intrinsic_load_ray_hw_stack_size_intel:
      case nir_intrinsic_load_ray_sw_stack_size_intel:
      case nir_intrinsic_load_ray_num_dss_rt_stacks_intel:
      case nir_intrinsic_load_ray_hit_sbt_addr_intel:
      case nir_intrinsic_load_ray_hit_sbt_stride_intel:
      case nir_intrinsic_load_ray_miss_sbt_addr_intel:
      case nir_intrinsic_load_ray_miss_sbt_stride_intel:
      case nir_intrinsic_load_callable_sbt_addr_intel:
      case nir_intrinsic_load_callable_sbt_stride_intel:
      case nir_intrinsic_load_reloc_const_intel:
      case nir_intrinsic_load_ray_query_global_intel:
         /* Notably missing from the above list is btd_local_arg_addr_intel.
          * This is because the resume shader will have a different local
          * argument pointer because it has a different BSR.  Any access of
          * the original shader's local arguments needs to be preserved so
          * that pointer has to be saved on the stack.
          *
          * TODO: There may be some system values we want to avoid
          *       re-materializing as well but we have to be very careful
          *       to ensure that it's a system value which cannot change
          *       across a shader call.
          */
         return true;

      default:
         return false;
      }
   }

   case nir_instr_type_ssa_undef:
   case nir_instr_type_load_const:
      return true;

   default:
      return false;
   }
}

static bool
can_remat_ssa_def(nir_ssa_def *def, struct brw_bitset *remat)
{
   return can_remat_instr(def->parent_instr, remat);
}

static nir_ssa_def *
remat_ssa_def(nir_builder *b, nir_ssa_def *def)
{
   nir_instr *clone = nir_instr_clone(b->shader, def->parent_instr);
   nir_builder_instr_insert(b, clone);
   return nir_instr_ssa_def(clone);
}

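/* Maps SSA def indices to phi builder values.  Defs whose index is outside
 * the array, or whose entry is NULL, are left untouched by the rewrite.
 */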
struct pbv_array {
   struct nir_phi_builder_value **arr;
   unsigned len;
};

static struct nir_phi_builder_value *
get_phi_builder_value_for_def(nir_ssa_def *def,
                              struct pbv_array *pbv_arr)
{
   if (def->index >= pbv_arr->len)
      return NULL;

   return pbv_arr->arr[def->index];
}

static nir_ssa_def *
get_phi_builder_def_for_src(nir_src *src, struct pbv_array *pbv_arr,
                            nir_block *block)
{
   assert(src->is_ssa);

   struct nir_phi_builder_value *pbv =
      get_phi_builder_value_for_def(src->ssa, pbv_arr);
   if (pbv == NULL)
      return NULL;

   return nir_phi_builder_value_get_block_def(pbv, block);
}

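/* Rewrite a single source to the phi builder's definition for the block in
 * which the source is used.  For phi sources, the relevant block is the
 * predecessor block the value flows in from, not the block containing the
 * phi itself.
 */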
static bool
rewrite_instr_src_from_phi_builder(nir_src *src, void *_pbv_arr)
{
   nir_block *block;
   if (src->parent_instr->type == nir_instr_type_phi) {
      nir_phi_src *phi_src = exec_node_data(nir_phi_src, src, src);
      block = phi_src->pred;
   } else {
      block = src->parent_instr->block;
   }

   nir_ssa_def *new_def = get_phi_builder_def_for_src(src, _pbv_arr, block);
   if (new_def != NULL)
      nir_instr_rewrite_src(src->parent_instr, src, nir_src_for_ssa(new_def));
   return true;
}

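/* Spill an SSA def to the stack (scratch) right before the call using the
 * "before" builder and reload it right after the call using the "after"
 * builder.  Returns the reloaded (filled) def.
 */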
static nir_ssa_def *
spill_fill(nir_builder *before, nir_builder *after, nir_ssa_def *def, unsigned offset,
           nir_address_format address_format, unsigned stack_alignment)
{
   const unsigned comp_size = def->bit_size / 8;

   switch (address_format) {
   case nir_address_format_32bit_offset:
      nir_store_scratch(before, def, nir_imm_int(before, offset),
                        .align_mul = MIN2(comp_size, stack_alignment),
                        .write_mask = BITFIELD_MASK(def->num_components));
      def = nir_load_scratch(after, def->num_components, def->bit_size,
                             nir_imm_int(after, offset),
                             .align_mul = MIN2(comp_size, stack_alignment));
      break;
   case nir_address_format_64bit_global: {
      nir_ssa_def *addr = nir_iadd_imm(before, nir_load_scratch_base_ptr(before, 1, 64, 1), offset);
      nir_store_global(before, addr, MIN2(comp_size, stack_alignment), def, ~0);
      addr = nir_iadd_imm(after, nir_load_scratch_base_ptr(after, 1, 64, 1), offset);
      def = nir_load_global(after, addr, MIN2(comp_size, stack_alignment),
                            def->num_components, def->bit_size);
      break;
   }
   default:
      unreachable("Unimplemented address format");
   }
   return def;
}

static void
spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls,
                                      nir_address_format address_format,
                                      unsigned stack_alignment)
{
   /* TODO: If an SSA def is filled more than once, we probably want to just
    *       spill it at the LCM of the fill sites so we avoid unnecessary
    *       extra spills
    *
    * TODO: If an SSA def is defined outside a loop but live through some call
    *       inside the loop, we probably want to spill outside the loop.  We
    *       may also want to fill outside the loop if it's not used in the
    *       loop.
    *
    * TODO: Right now, we only re-materialize things if their immediate
    *       sources are things which we filled.  We probably want to expand
    *       that to re-materialize things whose sources are things we can
    *       re-materialize from things we filled.  We may want some DAG depth
    *       heuristic on this.
    */

   /* This happens per-shader rather than per-impl because we mess with
    * nir_shader::scratch_size.
    */
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_metadata_require(impl, nir_metadata_live_ssa_defs |
                              nir_metadata_dominance |
                              nir_metadata_block_index);

   void *mem_ctx = ralloc_context(shader);

   const unsigned num_ssa_defs = impl->ssa_alloc;
   const unsigned live_words = BITSET_WORDS(num_ssa_defs);
   struct brw_bitset trivial_remat = bitset_create(mem_ctx, num_ssa_defs);

   /* Array of all live SSA defs which are spill candidates */
   nir_ssa_def **spill_defs =
      rzalloc_array(mem_ctx, nir_ssa_def *, num_ssa_defs);

   /* For each spill candidate, an array of every time it's defined by a fill,
    * indexed by call instruction index.
    */
   nir_ssa_def ***fill_defs =
      rzalloc_array(mem_ctx, nir_ssa_def **, num_ssa_defs);

   /* For each call instruction, the liveness set at the call */
   const BITSET_WORD **call_live =
      rzalloc_array(mem_ctx, const BITSET_WORD *, num_calls);

   /* For each call instruction, the block index of the block it lives in */
   uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls);

   /* Walk the call instructions and fetch the liveness set and block index
    * for each one.  We need to do this before we start modifying the shader
    * so that liveness doesn't complain that it's been invalidated.  Don't
    * worry, we'll be very careful with our live sets. :-)
    */
   unsigned call_idx = 0;
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (!instr_is_shader_call(instr))
            continue;

         call_block_indices[call_idx] = block->index;

         /* The objective here is to preserve values around shader call
          * instructions.  Therefore, we use the live set after the
          * instruction as the set of things we want to preserve.  Because
          * none of our shader call intrinsics return anything, we don't have
          * to worry about spilling over a return value.
          *
          * TODO: This isn't quite true for report_intersection.
          */
         call_live[call_idx] =
            nir_get_live_ssa_defs(nir_after_instr(instr), mem_ctx);

         call_idx++;
      }
   }

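   /* Now walk the shader a second time, this time modifying it: around each
    * shader call we spill or re-materialize everything in that call's live
    * set, then lower the call itself to rt_trace_ray/rt_execute_callable
    * followed by an rt_resume placeholder annotated with the call index and
    * stack size.
    */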
   nir_builder before, after;
   nir_builder_init(&before, impl);
   nir_builder_init(&after, impl);

   call_idx = 0;
   unsigned max_scratch_size = shader->scratch_size;
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_ssa_def *def = nir_instr_ssa_def(instr);
         if (def != NULL) {
            if (can_remat_ssa_def(def, &trivial_remat)) {
               add_ssa_def_to_bitset(def, &trivial_remat);
            } else {
               spill_defs[def->index] = def;
            }
         }

         if (!instr_is_shader_call(instr))
            continue;

         const BITSET_WORD *live = call_live[call_idx];

         /* Make a copy of trivial_remat that we'll update as we crawl through
          * the live SSA defs and unspill them.
          */
         struct brw_bitset remat = bitset_create(mem_ctx, num_ssa_defs);
         memcpy(remat.set, trivial_remat.set, live_words * sizeof(BITSET_WORD));

         /* Because the two builders are always separated by the call
          * instruction, it won't break anything to have two of them.
          */
         before.cursor = nir_before_instr(instr);
         after.cursor = nir_after_instr(instr);

         unsigned offset = shader->scratch_size;
         for (unsigned w = 0; w < live_words; w++) {
            BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w];
            while (spill_mask) {
               int i = u_bit_scan(&spill_mask);
               assert(i >= 0);
               unsigned index = w * BITSET_WORDBITS + i;
               assert(index < num_ssa_defs);

               nir_ssa_def *def = spill_defs[index];
               if (can_remat_ssa_def(def, &remat)) {
                  /* If this SSA def is re-materializable or based on other
                   * things we've already spilled, re-materialize it rather
                   * than spilling and filling.  Anything which is trivially
                   * re-materializable won't even get here because we take
                   * those into account in spill_mask above.
                   */
                  def = remat_ssa_def(&after, def);
               } else {
                  bool is_bool = def->bit_size == 1;
                  if (is_bool)
                     def = nir_b2b32(&before, def);

                  const unsigned comp_size = def->bit_size / 8;
                  offset = ALIGN(offset, comp_size);

                  def = spill_fill(&before, &after, def, offset,
                                   address_format, stack_alignment);

                  if (is_bool)
                     def = nir_b2b1(&after, def);

                  offset += def->num_components * comp_size;
               }

               /* Mark this SSA def as available in the remat set so that, if
                * some other SSA def we need is computed based on it, we can
                * just re-compute instead of fetching from memory.
                */
               BITSET_SET(remat.set, index);

               /* For now, we just make a note of this new SSA def.  We'll
                * fix things up with the phi builder as a second pass.
                */
               if (fill_defs[index] == NULL) {
                  fill_defs[index] =
                     rzalloc_array(mem_ctx, nir_ssa_def *, num_calls);
               }
               fill_defs[index][call_idx] = def;
            }
         }

         nir_builder *b = &before;

         offset = ALIGN(offset, stack_alignment);
         max_scratch_size = MAX2(max_scratch_size, offset);

         /* First thing on the called shader's stack is the resume address
          * followed by a pointer to the payload.
          */
         nir_intrinsic_instr *call = nir_instr_as_intrinsic(instr);

         /* Lower to generic intrinsics with information about the stack & resume shader. */
         switch (call->intrinsic) {
         case nir_intrinsic_trace_ray: {
            nir_rt_trace_ray(b, call->src[0].ssa, call->src[1].ssa,
                             call->src[2].ssa, call->src[3].ssa,
                             call->src[4].ssa, call->src[5].ssa,
                             call->src[6].ssa, call->src[7].ssa,
                             call->src[8].ssa, call->src[9].ssa,
                             call->src[10].ssa,
                             .call_idx = call_idx, .stack_size = offset);
            break;
         }

         case nir_intrinsic_report_ray_intersection:
            unreachable("Any-hit shaders must be inlined");

         case nir_intrinsic_execute_callable: {
            nir_rt_execute_callable(b, call->src[0].ssa, call->src[1].ssa,
                                    .call_idx = call_idx, .stack_size = offset);
            break;
         }

         default:
            unreachable("Invalid shader call instruction");
         }

         nir_rt_resume(b, .call_idx = call_idx, .stack_size = offset);

         nir_instr_remove(&call->instr);

         call_idx++;
      }
   }
   assert(call_idx == num_calls);
   shader->scratch_size = max_scratch_size;

   struct nir_phi_builder *pb = nir_phi_builder_create(impl);
   struct pbv_array pbv_arr = {
      .arr = rzalloc_array(mem_ctx, struct nir_phi_builder_value *,
                           num_ssa_defs),
      .len = num_ssa_defs,
   };

   const unsigned block_words = BITSET_WORDS(impl->num_blocks);
   BITSET_WORD *def_blocks = ralloc_array(mem_ctx, BITSET_WORD, block_words);

   /* Go through and set up phi builder values for each spillable value which
    * we ever needed to spill at any point.
    */
   for (unsigned index = 0; index < num_ssa_defs; index++) {
      if (fill_defs[index] == NULL)
         continue;

      nir_ssa_def *def = spill_defs[index];

      memset(def_blocks, 0, block_words * sizeof(BITSET_WORD));
      BITSET_SET(def_blocks, def->parent_instr->block->index);
      for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
         if (fill_defs[index][call_idx] != NULL)
            BITSET_SET(def_blocks, call_block_indices[call_idx]);
      }

      pbv_arr.arr[index] = nir_phi_builder_add_value(pb, def->num_components,
                                                     def->bit_size, def_blocks);
   }

   /* Walk the shader one more time and rewrite SSA defs as needed using the
    * phi builder.
    */
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         nir_ssa_def *def = nir_instr_ssa_def(instr);
         if (def != NULL) {
            struct nir_phi_builder_value *pbv =
               get_phi_builder_value_for_def(def, &pbv_arr);
            if (pbv != NULL)
               nir_phi_builder_value_set_block_def(pbv, block, def);
         }

         if (instr->type == nir_instr_type_phi)
            continue;

         nir_foreach_src(instr, rewrite_instr_src_from_phi_builder, &pbv_arr);

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
         if (resume->intrinsic != nir_intrinsic_rt_resume)
            continue;

         call_idx = nir_intrinsic_call_idx(resume);

         /* Technically, this is the wrong place to add the fill defs to the
          * phi builder values because we haven't seen any of the load_scratch
          * instructions for this call yet.  However, we know based on how we
          * emitted them that no value ever gets used until after the load
          * instruction has been emitted so this should be safe.  If we ever
          * fail validation due to this, it likely means a bug in our spilling
          * code and not the phi re-construction code here.
          */
         for (unsigned index = 0; index < num_ssa_defs; index++) {
            if (fill_defs[index] && fill_defs[index][call_idx]) {
               nir_phi_builder_value_set_block_def(pbv_arr.arr[index], block,
                                                   fill_defs[index][call_idx]);
            }
         }
      }

      nir_if *following_if = nir_block_get_following_if(block);
      if (following_if) {
         nir_ssa_def *new_def =
            get_phi_builder_def_for_src(&following_if->condition,
                                        &pbv_arr, block);
         if (new_def != NULL)
            nir_if_rewrite_condition(following_if, nir_src_for_ssa(new_def));
      }

      /* Handle phi sources that source from this block.  We have to do this
       * as a separate pass because the phi builder assumes that uses and
       * defs are processed in an order that respects dominance.  When we have
       * loops, a phi source may be a back-edge so we have to handle it as if
       * it were one of the last instructions in the predecessor block.
       */
      nir_foreach_phi_src_leaving_block(block,
                                        rewrite_instr_src_from_phi_builder,
                                        &pbv_arr);
   }

   nir_phi_builder_finish(pb);

   ralloc_free(mem_ctx);

   nir_metadata_preserve(impl, nir_metadata_block_index |
                               nir_metadata_dominance);
}

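/* Find the rt_resume intrinsic that was emitted for the given call index. */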
static nir_instr *
find_resume_instr(nir_function_impl *impl, unsigned call_idx)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
         if (resume->intrinsic != nir_intrinsic_rt_resume)
            continue;

         if (nir_intrinsic_call_idx(resume) == call_idx)
            return &resume->instr;
      }
   }
   unreachable("Couldn't find resume instruction");
}

/* Walk the CF tree and duplicate the contents of every loop, one half runs on
 * resume and the other half is for any post-resume loop iterations.  We are
 * careful in our duplication to ensure that resume_instr is in the resume
 * half of the loop though a copy of resume_instr will remain in the other
 * half as well in case the same shader call happens twice.
 */
static bool
duplicate_loop_bodies(nir_function_impl *impl, nir_instr *resume_instr)
{
   nir_register *resume_reg = NULL;
   for (nir_cf_node *node = resume_instr->block->cf_node.parent;
        node->type != nir_cf_node_function; node = node->parent) {
      if (node->type != nir_cf_node_loop)
         continue;

      nir_loop *loop = nir_cf_node_as_loop(node);

      if (resume_reg == NULL) {
         /* We only create resume_reg if we encounter a loop.  This way we can
          * avoid re-validating the shader and calling ssa_to_regs in the case
          * where it's just if-ladders.
          */
         resume_reg = nir_local_reg_create(impl);
         resume_reg->num_components = 1;
         resume_reg->bit_size = 1;

         nir_builder b;
         nir_builder_init(&b, impl);

         /* Initialize resume to true */
         b.cursor = nir_before_cf_list(&impl->body);
         nir_store_reg(&b, resume_reg, nir_imm_true(&b), 1);

         /* Set resume to false right after the resume instruction */
         b.cursor = nir_after_instr(resume_instr);
         nir_store_reg(&b, resume_reg, nir_imm_false(&b), 1);
      }

      /* Before we go any further, make sure that everything which exits the
       * loop or continues around to the top of the loop does so through
       * registers.  We're about to duplicate the loop body and we'll have
       * serious trouble if we don't do this.
       */
      nir_convert_loop_to_lcssa(loop);
      nir_lower_phis_to_regs_block(nir_loop_first_block(loop));
      nir_lower_phis_to_regs_block(
         nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node)));

      nir_cf_list cf_list;
      nir_cf_list_extract(&cf_list, &loop->body);

      nir_if *_if = nir_if_create(impl->function->shader);
      _if->condition = nir_src_for_reg(resume_reg);
      nir_cf_node_insert(nir_after_cf_list(&loop->body), &_if->cf_node);

      nir_cf_list clone;
      nir_cf_list_clone(&clone, &cf_list, &loop->cf_node, NULL);

      /* Insert the clone in the else and the original in the then so that
       * the resume_instr remains valid even after the duplication.
       */
      nir_cf_reinsert(&cf_list, nir_before_cf_list(&_if->then_list));
      nir_cf_reinsert(&clone, nir_before_cf_list(&_if->else_list));
   }

   if (resume_reg != NULL)
      nir_metadata_preserve(impl, nir_metadata_none);

   return resume_reg != NULL;
}

static bool
cf_node_contains_block(nir_cf_node *node, nir_block *block)
{
   for (nir_cf_node *n = &block->cf_node; n != NULL; n = n->parent) {
      if (n == node)
         return true;
   }

   return false;
}

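/* Rewrite every phi at the top of the given block so that its uses see the
 * value the phi would have taken when entered from the given predecessor.
 * flatten_resume_if_ladder() uses this when it flattens an if: once we know
 * which side of the if the resume path takes, the phis at the merge block
 * are replaced by the values coming from that side.
 */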
static void
rewrite_phis_to_pred(nir_block *block, nir_block *pred)
{
   nir_foreach_instr(instr, block) {
      if (instr->type != nir_instr_type_phi)
         break;

      nir_phi_instr *phi = nir_instr_as_phi(instr);

      ASSERTED bool found = false;
      nir_foreach_phi_src(phi_src, phi) {
         if (phi_src->pred == pred) {
            found = true;
            assert(phi_src->src.is_ssa);
            nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src.ssa);
            break;
         }
      }
      assert(found);
   }
}

static bool
cursor_is_after_jump(nir_cursor cursor)
{
   switch (cursor.option) {
   case nir_cursor_before_instr:
   case nir_cursor_before_block:
      return false;
   case nir_cursor_after_instr:
      return cursor.instr->type == nir_instr_type_jump;
   case nir_cursor_after_block:
      return nir_block_ends_in_jump(cursor.block);
   }
   unreachable("Invalid cursor option");
}

/** Flattens if ladders leading up to a resume
 *
 * Given a resume_instr, this function flattens any if ladders leading to the
 * resume instruction and deletes any code that cannot be encountered on a
 * direct path to the resume instruction.  This way we get, for the most part,
 * straight-line control-flow up to the resume instruction.
 *
 * While we do this flattening, we also move any code which is in the remat
 * set up to the top of the function or to the top of the resume portion of
 * the current loop.  We don't worry about control-flow as we do this because
 * phis will never be in the remat set (see can_remat_instr) and so nothing
 * control-dependent will ever need to be re-materialized.  It is possible
 * that this algorithm will preserve too many instructions by moving them to
 * the top but we leave that for DCE to clean up.  Any code not in the remat
 * set is deleted because it's either unused in the continuation or else
 * unspilled from a previous continuation and the unspill code is after the
 * resume instruction.
 *
 * If, for instance, we have something like this:
 *
 *    // block 0
 *    if (cond1) {
 *       // block 1
 *    } else {
 *       // block 2
 *       if (cond2) {
 *          // block 3
 *          resume;
 *          if (cond3) {
 *             // block 4
 *          }
 *       } else {
 *          // block 5
 *       }
 *    }
 *
 * then we know, because we know the resume instruction had to be encountered,
 * that cond1 = false and cond2 = true and we lower as follows:
 *
 *    // block 0
 *    // block 2
 *    // block 3
 *    resume;
 *    if (cond3) {
 *       // block 4
 *    }
 *
 * As you can see, the code in blocks 1 and 5 was removed because there is no
 * path from the start of the shader to the resume instruction which executes
 * blocks 1 or 5.  Any remat code from blocks 0, 2, and 3 is preserved and
 * moved to the top.  If the resume instruction is inside a loop then we know
 * a priori that it is of the form
 *
 *    loop {
 *       if (resume) {
 *          // Contents containing resume_instr
 *       } else {
 *          // Second copy of contents
 *       }
 *    }
 *
 * In this case, we only descend into the first half of the loop.  The second
 * half is left alone as that portion is only ever executed after the resume
 * instruction.
 */
static bool
flatten_resume_if_ladder(nir_builder *b,
                         nir_cf_node *parent_node,
                         struct exec_list *child_list,
                         bool child_list_contains_cursor,
                         nir_instr *resume_instr,
                         struct brw_bitset *remat)
{
   nir_cf_list cf_list;

   /* If our child list contains the cursor instruction then we start out
    * before the cursor instruction.  We need to know this so that we can skip
    * moving instructions which are already before the cursor.
    */
   bool before_cursor = child_list_contains_cursor;

   nir_cf_node *resume_node = NULL;
   foreach_list_typed_safe(nir_cf_node, child, node, child_list) {
      switch (child->type) {
      case nir_cf_node_block: {
         nir_block *block = nir_cf_node_as_block(child);
         if (b->cursor.option == nir_cursor_before_block &&
             b->cursor.block == block) {
            assert(before_cursor);
            before_cursor = false;
         }
         nir_foreach_instr_safe(instr, block) {
            if ((b->cursor.option == nir_cursor_before_instr ||
                 b->cursor.option == nir_cursor_after_instr) &&
                b->cursor.instr == instr) {
               assert(nir_cf_node_is_first(&block->cf_node));
               assert(before_cursor);
               before_cursor = false;
               continue;
            }

            if (instr == resume_instr)
               goto found_resume;

            if (!before_cursor && can_remat_instr(instr, remat)) {
               nir_instr_remove(instr);
               nir_instr_insert(b->cursor, instr);
               b->cursor = nir_after_instr(instr);

               nir_ssa_def *def = nir_instr_ssa_def(instr);
               BITSET_SET(remat->set, def->index);
            }
         }
         if (b->cursor.option == nir_cursor_after_block &&
             b->cursor.block == block) {
            assert(before_cursor);
            before_cursor = false;
         }
         break;
      }

      case nir_cf_node_if: {
         nir_if *_if = nir_cf_node_as_if(child);

         /* Because of the dummy blocks inserted in the first if block of the
          * loops, it's possible we find an empty if block that contains our
          * cursor. At this point, the block should still be empty and we can
          * just skip it and consider we're after the cursor.
          */
         if (cf_node_contains_block(&_if->cf_node,
                                    nir_cursor_current_block(b->cursor))) {
            /* Some sanity checks to verify this is actually a dummy block */
            assert(nir_src_as_bool(_if->condition) == true);
            assert(nir_cf_list_is_empty_block(&_if->then_list));
            assert(nir_cf_list_is_empty_block(&_if->else_list));
            before_cursor = false;
            break;
         }
         assert(!before_cursor);

         if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->then_list,
                                      false, resume_instr, remat)) {
            resume_node = child;
            rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
                                 nir_if_last_then_block(_if));
            goto found_resume;
         }

         if (flatten_resume_if_ladder(b, &_if->cf_node, &_if->else_list,
                                      false, resume_instr, remat)) {
            resume_node = child;
            rewrite_phis_to_pred(nir_cf_node_as_block(nir_cf_node_next(child)),
                                 nir_if_last_else_block(_if));
            goto found_resume;
         }
         break;
      }

      case nir_cf_node_loop: {
         assert(!before_cursor);
         nir_loop *loop = nir_cf_node_as_loop(child);

         if (cf_node_contains_block(&loop->cf_node, resume_instr->block)) {
            /* Thanks to our loop body duplication pass, every level of loop
             * containing the resume instruction contains exactly three nodes:
             * two blocks and an if.  We don't want to lower away this if
             * because it's the resume selection if.  The resume half is
             * always the then_list so that's what we want to flatten.
             */
            nir_block *header = nir_loop_first_block(loop);
            nir_if *_if = nir_cf_node_as_if(nir_cf_node_next(&header->cf_node));

            nir_builder bl;
            nir_builder_init(&bl, b->impl);
            bl.cursor = nir_before_cf_list(&_if->then_list);
            /* We want to place anything re-materialized from inside the loop
             * at the top of the resume half of the loop.
             *
             * Because we're inside a loop, we might run into break/continue
             * instructions. We can't place those within a block of
             * instructions; they need to be at the end of a block. So we
             * build our own dummy block to place them.
             */
            nir_push_if(&bl, nir_imm_true(&bl));
            {
               ASSERTED bool found =
                  flatten_resume_if_ladder(&bl, &_if->cf_node, &_if->then_list,
                                           true, resume_instr, remat);
               assert(found);
            }
            nir_pop_if(&bl, NULL);

            resume_node = child;
            goto found_resume;
         } else {
            ASSERTED bool found =
               flatten_resume_if_ladder(b, &loop->cf_node, &loop->body,
                                        false, resume_instr, remat);
            assert(!found);
         }
         break;
      }

      case nir_cf_node_function:
         unreachable("Unsupported CF node type");
      }
   }
   assert(!before_cursor);

   /* If we got here, we didn't find the resume node or instruction. */
   return false;

found_resume:
   /* If we got here then we found either the resume node or the resume
    * instruction in this CF list.
    */
   if (resume_node) {
      /* If the resume instruction is buried inside one of our child CF
       * nodes, resume_node now points to that child.
       */
      if (resume_node->type == nir_cf_node_if) {
         /* Thanks to the recursive call, all of the interesting contents of
          * resume_node have been copied before the cursor.  We just need to
          * copy the stuff after resume_node.
          */
         nir_cf_extract(&cf_list, nir_after_cf_node(resume_node),
                                  nir_after_cf_list(child_list));
      } else {
         /* The loop contains its own cursor and still has useful stuff in it.
          * We want to move everything after and including the loop to before
          * the cursor.
          */
         assert(resume_node->type == nir_cf_node_loop);
         nir_cf_extract(&cf_list, nir_before_cf_node(resume_node),
                                  nir_after_cf_list(child_list));
      }
   } else {
      /* If we found the resume instruction in one of our blocks, grab
       * everything after it in the entire list (not just the one block), and
       * place it before the cursor instr.
       */
      nir_cf_extract(&cf_list, nir_after_instr(resume_instr),
                               nir_after_cf_list(child_list));
   }

   if (cursor_is_after_jump(b->cursor)) {
      /* If the resume instruction is in a loop, it's possible cf_list ends
       * in a break or continue instruction, in which case we don't want to
       * insert anything.  It's also possible we have an early return if
       * someone hasn't lowered those yet.  In either case, nothing after that
       * point executes in this context so we can delete it.
       */
      nir_cf_delete(&cf_list);
   } else {
      b->cursor = nir_cf_reinsert(&cf_list, b->cursor);
   }

   if (!resume_node) {
      /* We want the resume to be the first "interesting" instruction */
      nir_instr_remove(resume_instr);
      nir_instr_insert(nir_before_cf_list(&b->impl->body), resume_instr);
   }

   /* We've copied everything interesting out of this CF list to before the
    * cursor.  Delete everything else.
    */
   if (child_list_contains_cursor) {
      /* If the cursor is in child_list, then we're in either a loop or function
       * that contains the cursor. Cursors are always placed in a wrapper if
       * (true) to deal with break/continue and early returns. We've already
       * moved everything interesting inside the wrapper if and we want to
       * remove whatever is left after it.
       */
      nir_block *cursor_block = nir_cursor_current_block(b->cursor);
      nir_if *wrapper_if = nir_cf_node_as_if(cursor_block->cf_node.parent);
      assert(wrapper_if->cf_node.parent == parent_node);
      /* The wrapper if blocks are either put into the body of the main
       * function, or within the resume if block of the loops.
       */
      assert(parent_node->type == nir_cf_node_function ||
             (parent_node->type == nir_cf_node_if &&
              parent_node->parent->type == nir_cf_node_loop));
      nir_cf_extract(&cf_list, nir_after_cf_node(&wrapper_if->cf_node),
                     nir_after_cf_list(child_list));
   } else {
      nir_cf_list_extract(&cf_list, child_list);
   }
   nir_cf_delete(&cf_list);

   return true;
}

static nir_instr *
lower_resume(nir_shader *shader, int call_idx)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_instr *resume_instr = find_resume_instr(impl, call_idx);

   if (duplicate_loop_bodies(impl, resume_instr)) {
      nir_validate_shader(shader, "after duplicate_loop_bodies in "
                                  "brw_nir_lower_shader_calls");
      /* If we duplicated the bodies of any loops, run regs_to_ssa to get rid
       * of all those pesky registers we just added.
       */
      NIR_PASS_V(shader, nir_lower_regs_to_ssa);
   }

   /* Re-index nir_ssa_def::index.  We don't care about actual liveness in
    * this pass but, so we can use the same helpers as the spilling pass, we
    * need to make sure that live_index is something sane.  It's used
    * constantly for determining if an SSA value has been added since the
    * start of the pass.
    */
   nir_index_ssa_defs(impl);

   void *mem_ctx = ralloc_context(shader);

   /* Used to track which things may have been assumed to be re-materialized
    * by the spilling pass and which we shouldn't delete.
    */
   struct brw_bitset remat = bitset_create(mem_ctx, impl->ssa_alloc);

   /* Wrap everything in an if (true) and use it to hold the cursor as we
    * extract and re-insert stuff into the CFG.
    */
   nir_builder b;
   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   nir_push_if(&b, nir_imm_true(&b));
   {
      ASSERTED bool found =
         flatten_resume_if_ladder(&b, &impl->cf_node, &impl->body,
                                  true, resume_instr, &remat);
      assert(found);
   }
   nir_pop_if(&b, NULL);

   ralloc_free(mem_ctx);

   nir_validate_shader(shader, "after flatten_resume_if_ladder in "
                               "brw_nir_lower_shader_calls");

   nir_metadata_preserve(impl, nir_metadata_none);

   return resume_instr;
}

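/* Replace every rt_resume intrinsic other than "keep" with a halt and delete
 * the rest of its block.  In a given shader, execution never continues past a
 * lowered shader call; the corresponding resume shader picks up instead.
 */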
static void
replace_resume_with_halt(nir_shader *shader, nir_instr *keep)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr == keep)
            continue;

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *resume = nir_instr_as_intrinsic(instr);
         if (resume->intrinsic != nir_intrinsic_rt_resume)
            continue;

         /* If this is some other resume, then we've kicked off a ray or
          * bindless thread and we don't want to go any further in this
          * shader.  Insert a halt so that NIR will delete any instructions
          * dominated by this call instruction including the scratch_load
          * instructions we inserted.
          */
         nir_cf_list cf_list;
         nir_cf_extract(&cf_list, nir_after_instr(&resume->instr),
                                  nir_after_block(block));
         nir_cf_delete(&cf_list);
         b.cursor = nir_instr_remove(&resume->instr);
         nir_jump(&b, nir_jump_halt);
         break;
      }
   }
}

/** Lower shader call instructions to split shaders.
 *
 * Shader calls can be split into an initial shader and a series of "resume"
 * shaders.  When the shader is first invoked, it is the initial shader which
 * is executed.  At any point in the initial shader or any one of the resume
 * shaders, a shader call operation may be performed.  The possible shader call
 * operations are:
 *
 *  - trace_ray
 *  - report_ray_intersection
 *  - execute_callable
 *
 * When a shader call operation is performed, we push all live values to the
 * stack, call rt_trace_ray/rt_execute_callable and then kill the shader. Once
 * the operation we invoked is complete, a callee shader will return execution
 * to the respective resume shader. The resume shader pops the contents off
 * the stack and picks up where the calling shader left off.
 *
 * Stack management is assumed to be done after this pass. Call
 * instructions and their resumes get annotated with stack information that
 * should be enough for the backend to implement proper stack management.
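 *
 * A rough driver-side usage sketch (the variable names and the 64-byte stack
 * alignment below are illustrative only, not taken from any particular
 * driver):
 *
 *    nir_shader **resume_shaders = NULL;
 *    uint32_t num_resume_shaders = 0;
 *    if (nir_lower_shader_calls(shader, nir_address_format_64bit_global,
 *                               64, &resume_shaders, &num_resume_shaders,
 *                               mem_ctx)) {
 *       // Compile "shader" and each resume_shaders[i] as separate
 *       // continuation shaders.
 *    }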
 */
bool
nir_lower_shader_calls(nir_shader *shader,
                       nir_address_format address_format,
                       unsigned stack_alignment,
                       nir_shader ***resume_shaders_out,
                       uint32_t *num_resume_shaders_out,
                       void *mem_ctx)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_builder b;
   nir_builder_init(&b, impl);

   int num_calls = 0;
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr_is_shader_call(instr))
            num_calls++;
      }
   }

   if (num_calls == 0) {
      nir_shader_preserve_all_metadata(shader);
      *num_resume_shaders_out = 0;
      return false;
   }

   /* Some intrinsics not only can't be re-materialized but aren't preserved
    * when moving to the continuation shader.  We have to move them to the top
    * to ensure they get spilled as needed.
    */
   {
      bool progress = false;
      NIR_PASS(progress, shader, move_system_values_to_top);
      if (progress)
         NIR_PASS(progress, shader, nir_opt_cse);
   }

   NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
              num_calls, address_format, stack_alignment);

   nir_opt_remove_phis(shader);

   /* Make N copies of our shader */
   nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
   for (unsigned i = 0; i < num_calls; i++) {
      resume_shaders[i] = nir_shader_clone(mem_ctx, shader);

      /* Give them a recognizable name */
      resume_shaders[i]->info.name =
         ralloc_asprintf(mem_ctx, "%s%sresume_%u",
                         shader->info.name ? shader->info.name : "",
                         shader->info.name ? "-" : "",
                         i);
   }

   replace_resume_with_halt(shader, NULL);
   for (unsigned i = 0; i < num_calls; i++) {
      nir_instr *resume_instr = lower_resume(resume_shaders[i], i);
      replace_resume_with_halt(resume_shaders[i], resume_instr);
      nir_opt_remove_phis(resume_shaders[i]);
      /* Remove the dummy blocks added by flatten_resume_if_ladder() */
      nir_opt_if(resume_shaders[i], nir_opt_if_optimize_phi_true_false);
   }

   *resume_shaders_out = resume_shaders;
   *num_resume_shaders_out = num_calls;

   return true;
}