1/*
2 * Copyright © 2015 Thomas Helland
3 * Copyright © 2019 Valve Corporation
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 */
24
25/*
26 * This pass converts the ssa-graph into "Loop Closed SSA form". This is
27 * done by placing phi nodes at the exits of the loop for all values
28 * that are used outside the loop. The result is it transforms:
29 *
30 * loop {                    ->      loop {
31 *    ssa2 = ....            ->          ssa2 = ...
32 *    if (cond)              ->          if (cond)
33 *       break;              ->             break;
34 *    ssa3 = ssa2 * ssa4     ->          ssa3 = ssa2 * ssa4
35 * }                         ->       }
36 * ssa6 = ssa2 + 4           ->       ssa5 = phi(ssa2)
37 *                                    ssa6 = ssa5 + 4
38 */
39
40#include "nir.h"
41
/* Bookkeeping shared by the LCSSA conversion helpers in this file. */
typedef struct {
   /* The nir_shader we are transforming */
   nir_shader *shader;

   /* The loop we store information for */
   nir_loop *loop;
   /* The block that directly follows `loop` in its CF list; LCSSA phis are
    * inserted at the start of this block.
    */
   nir_block *block_after_loop;
   /* Predecessors of block_after_loop as returned by
    * nir_block_get_predecessors_sorted(); these are used as the predecessor
    * blocks of every LCSSA phi source.
    */
   nir_block **exit_blocks;

   /* Whether to skip loop invariant variables */
   bool skip_invariants;
   /* When skip_invariants is set, 1-bit defs are only skipped if this is
    * set as well.
    */
   bool skip_bool_invariants;

   /* Set to true as soon as an LCSSA phi has been inserted. */
   bool progress;
} lcssa_state;
57
58static bool
59is_if_use_inside_loop(nir_src *use, nir_loop *loop)
60{
61   nir_block *block_before_loop =
62      nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
63   nir_block *block_after_loop =
64      nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
65
66   nir_block *prev_block =
67      nir_cf_node_as_block(nir_cf_node_prev(&use->parent_if->cf_node));
68   if (prev_block->index <= block_before_loop->index ||
69       prev_block->index >= block_after_loop->index) {
70      return false;
71   }
72
73   return true;
74}
75
76static bool
77is_use_inside_loop(nir_src *use, nir_loop *loop)
78{
79   nir_block *block_before_loop =
80      nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
81   nir_block *block_after_loop =
82      nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
83
84   if (use->parent_instr->block->index <= block_before_loop->index ||
85       use->parent_instr->block->index >= block_after_loop->index) {
86      return false;
87   }
88
89   return true;
90}
91
92static bool
93is_defined_before_loop(nir_ssa_def *def, nir_loop *loop)
94{
95   nir_instr *instr = def->parent_instr;
96   nir_block *block_before_loop =
97      nir_cf_node_as_block(nir_cf_node_prev(&loop->cf_node));
98
99   return instr->block->index <= block_before_loop->index;
100}
101
/* Loop-invariance classification of an instruction. The helpers in this file
 * store these values in nir_instr::pass_flags.
 */
typedef enum instr_invariance {
   /* Not classified yet. */
   undefined = 0,
   /* Classified as loop-invariant. */
   invariant,
   /* Classified as not loop-invariant. */
   not_invariant,
} instr_invariance;
107
/* Forward declaration: def_is_invariant() calls instr_is_invariant(), which
 * in turn reaches def_is_invariant() again through its sources.
 */
static instr_invariance
instr_is_invariant(nir_instr *instr, nir_loop *loop);
110
111static bool
112def_is_invariant(nir_ssa_def *def, nir_loop *loop)
113{
114   if (is_defined_before_loop(def, loop))
115      return invariant;
116
117   if (def->parent_instr->pass_flags == undefined)
118      def->parent_instr->pass_flags = instr_is_invariant(def->parent_instr, loop);
119
120   return def->parent_instr->pass_flags == invariant;
121}
122
123static bool
124src_is_invariant(nir_src *src, void *state)
125{
126   assert(src->is_ssa);
127   return def_is_invariant(src->ssa, (nir_loop *)state);
128}
129
130static instr_invariance
131phi_is_invariant(nir_phi_instr *instr, nir_loop *loop)
132{
133   /* Base case: it's a phi at the loop header
134    * Loop-header phis are updated in each loop iteration with
135    * the loop-carried value, and thus control-flow dependent
136    * on the loop itself.
137    */
138   if (instr->instr.block == nir_loop_first_block(loop))
139      return not_invariant;
140
141   nir_foreach_phi_src(src, instr) {
142      if (!src_is_invariant(&src->src, loop))
143         return not_invariant;
144   }
145
146   /* All loop header- and LCSSA-phis should be handled by this point. */
147   nir_cf_node *prev = nir_cf_node_prev(&instr->instr.block->cf_node);
148   assert(prev && prev->type == nir_cf_node_if);
149
150   /* Invariance of phis after if-nodes also depends on the invariance
151    * of the branch condition.
152    */
153   nir_if *if_node = nir_cf_node_as_if(prev);
154   if (!def_is_invariant(if_node->condition.ssa, loop))
155      return not_invariant;
156
157   return invariant;
158}
159
160
161/* An instruction is said to be loop-invariant if it
162 * - has no sideeffects and
163 * - solely depends on variables defined outside of the loop or
164 *   by other invariant instructions
165 */
166static instr_invariance
167instr_is_invariant(nir_instr *instr, nir_loop *loop)
168{
169   assert(instr->pass_flags == undefined);
170
171   switch (instr->type) {
172   case nir_instr_type_load_const:
173   case nir_instr_type_ssa_undef:
174      return invariant;
175   case nir_instr_type_call:
176      return not_invariant;
177   case nir_instr_type_phi:
178      return phi_is_invariant(nir_instr_as_phi(instr), loop);
179   case nir_instr_type_intrinsic: {
180      nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
181      if (!(nir_intrinsic_infos[intrinsic->intrinsic].flags & NIR_INTRINSIC_CAN_REORDER))
182         return not_invariant;
183   }
184   FALLTHROUGH;
185   default:
186      return nir_foreach_src(instr, src_is_invariant, loop) ? invariant : not_invariant;
187   }
188
189   return invariant;
190}
191
/* Per-def callback (passed to nir_foreach_ssa_def() by the callers in this
 * file): if `def` has at least one use outside state->loop, insert an LCSSA
 * phi for it at the start of state->block_after_loop and redirect the outside
 * uses to that phi.
 *
 * void_state is the lcssa_state describing the loop being converted.
 * Always returns true. Sets state->progress when a phi was inserted.
 */
static bool
convert_loop_exit_for_ssa(nir_ssa_def *def, void *void_state)
{
   lcssa_state *state = void_state;
   bool all_uses_inside_loop = true;

   /* Don't create LCSSA-Phis for loop-invariant variables.
    * 1-bit defs are only skipped when skip_bool_invariants is set too.
    */
   if (state->skip_invariants &&
       (def->bit_size != 1 || state->skip_bool_invariants)) {
      /* The instruction must already have been classified at this point. */
      assert(def->parent_instr->pass_flags != undefined);
      if (def->parent_instr->pass_flags == invariant)
         return true;
   }

   /* Look for instruction uses outside the loop. Phis located in the block
    * after the loop are not counted as outside uses.
    */
   nir_foreach_use(use, def) {
      if (use->parent_instr->type == nir_instr_type_phi &&
          use->parent_instr->block == state->block_after_loop) {
         continue;
      }

      if (!is_use_inside_loop(use, state->loop)) {
         all_uses_inside_loop = false;
      }
   }

   /* Same check for uses as an if-condition. */
   nir_foreach_if_use(use, def) {
      if (!is_if_use_inside_loop(use, state->loop)) {
         all_uses_inside_loop = false;
      }
   }

   /* There were no uses outside the loop, so no phi is needed */
   if (all_uses_inside_loop)
      return true;

   /* Initialize a phi-instruction */
   nir_phi_instr *phi = nir_phi_instr_create(state->shader);
   nir_ssa_dest_init(&phi->instr, &phi->dest,
                     def->num_components, def->bit_size, "LCSSA-phi");

   /* Create a phi node with as many sources pointing to the same ssa_def as
    * the block has predecessors. exit_blocks was built from the predecessors
    * of block_after_loop, so it is indexed with the same count.
    */
   uint32_t num_exits = state->block_after_loop->predecessors->entries;
   for (uint32_t i = 0; i < num_exits; i++) {
      nir_phi_instr_add_src(phi, state->exit_blocks[i], nir_src_for_ssa(def));
   }

   nir_instr_insert_before_block(state->block_after_loop, &phi->instr);
   /* `dest` is the value outside uses are redirected to: the phi itself, or
    * the cast created below for derefs.
    */
   nir_ssa_def *dest = &phi->dest.ssa;

   /* deref instructions need a cast after the phi */
   if (def->parent_instr->type == nir_instr_type_deref) {
      nir_deref_instr *cast =
         nir_deref_instr_create(state->shader, nir_deref_type_cast);

      /* Carry modes, type and array stride of the original deref over to
       * the cast, whose parent is the phi.
       */
      nir_deref_instr *instr = nir_instr_as_deref(def->parent_instr);
      cast->modes = instr->modes;
      cast->type = instr->type;
      cast->parent = nir_src_for_ssa(&phi->dest.ssa);
      cast->cast.ptr_stride = nir_deref_instr_array_stride(instr);

      nir_ssa_dest_init(&cast->instr, &cast->dest,
                        phi->dest.ssa.num_components,
                        phi->dest.ssa.bit_size, NULL);
      nir_instr_insert(nir_after_phis(state->block_after_loop), &cast->instr);
      dest = &cast->dest.ssa;
   }

   /* Run through all uses and rewrite those outside the loop to point to
    * the phi instead of pointing to the ssa-def. Phis in the block after the
    * loop (including the one just created) keep their source.
    */
   nir_foreach_use_safe(use, def) {
      if (use->parent_instr->type == nir_instr_type_phi &&
          state->block_after_loop == use->parent_instr->block) {
         continue;
      }

      if (!is_use_inside_loop(use, state->loop)) {
         nir_instr_rewrite_src(use->parent_instr, use, nir_src_for_ssa(dest));
      }
   }

   /* Likewise for if-conditions outside the loop. */
   nir_foreach_if_use_safe(use, def) {
      if (!is_if_use_inside_loop(use, state->loop)) {
         nir_if_rewrite_condition(use->parent_if, nir_src_for_ssa(dest));
      }
   }

   state->progress = true;
   return true;
}
284
285static void
286setup_loop_state(lcssa_state *state, nir_loop *loop)
287{
288   state->loop = loop;
289   state->block_after_loop =
290      nir_cf_node_as_block(nir_cf_node_next(&loop->cf_node));
291
292   ralloc_free(state->exit_blocks);
293   state->exit_blocks = nir_block_get_predecessors_sorted(state->block_after_loop, state);
294}
295
296static void
297convert_to_lcssa(nir_cf_node *cf_node, lcssa_state *state)
298{
299   switch (cf_node->type) {
300   case nir_cf_node_block:
301      return;
302   case nir_cf_node_if: {
303      nir_if *if_stmt = nir_cf_node_as_if(cf_node);
304      foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->then_list)
305         convert_to_lcssa(nested_node, state);
306      foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->else_list)
307         convert_to_lcssa(nested_node, state);
308      return;
309   }
310   case nir_cf_node_loop: {
311      if (state->skip_invariants) {
312         nir_foreach_block_in_cf_node(block, cf_node) {
313            nir_foreach_instr(instr, block)
314               instr->pass_flags = undefined;
315         }
316      }
317
318      /* first, convert inner loops */
319      nir_loop *loop = nir_cf_node_as_loop(cf_node);
320      foreach_list_typed(nir_cf_node, nested_node, node, &loop->body)
321         convert_to_lcssa(nested_node, state);
322
323      setup_loop_state(state, loop);
324
325      /* mark loop-invariant instructions */
326      if (state->skip_invariants) {
327         /* Without a loop all instructions are invariant.
328          * For outer loops, multiple breaks can still create phis.
329          * The variance then depends on all (nested) break conditions.
330          * We don't consider this, but assume all not_invariant.
331          */
332         if (nir_loop_first_block(loop)->predecessors->entries == 1)
333            goto end;
334
335         nir_foreach_block_in_cf_node(block, cf_node) {
336            nir_foreach_instr(instr, block) {
337               if (instr->pass_flags == undefined)
338                  instr->pass_flags = instr_is_invariant(instr, nir_cf_node_as_loop(cf_node));
339            }
340         }
341      }
342
343      nir_foreach_block_in_cf_node(block, cf_node) {
344         nir_foreach_instr(instr, block) {
345            nir_foreach_ssa_def(instr, convert_loop_exit_for_ssa, state);
346
347            /* for outer loops, invariant instructions can be variant */
348            if (state->skip_invariants && instr->pass_flags == invariant)
349               instr->pass_flags = undefined;
350         }
351      }
352
353end:
354      /* For outer loops, the LCSSA-phi should be considered not invariant */
355      if (state->skip_invariants) {
356         nir_foreach_instr(instr, state->block_after_loop) {
357            if (instr->type == nir_instr_type_phi)
358               instr->pass_flags = not_invariant;
359            else
360               break;
361         }
362      }
363      return;
364   }
365   default:
366      unreachable("unknown cf node type");
367   }
368}
369
370void
371nir_convert_loop_to_lcssa(nir_loop *loop)
372{
373   nir_function_impl *impl = nir_cf_node_get_function(&loop->cf_node);
374
375   nir_metadata_require(impl, nir_metadata_block_index);
376
377   lcssa_state *state = rzalloc(NULL, lcssa_state);
378   setup_loop_state(state, loop);
379   state->shader = impl->function->shader;
380   state->skip_invariants = false;
381   state->skip_bool_invariants = false;
382
383   nir_foreach_block_in_cf_node (block, &loop->cf_node) {
384      nir_foreach_instr(instr, block)
385         nir_foreach_ssa_def(instr, convert_loop_exit_for_ssa, state);
386   }
387
388   ralloc_free(state);
389}
390
391bool
392nir_convert_to_lcssa(nir_shader *shader, bool skip_invariants, bool skip_bool_invariants)
393{
394   bool progress = false;
395   lcssa_state *state = rzalloc(NULL, lcssa_state);
396   state->shader = shader;
397   state->skip_invariants = skip_invariants;
398   state->skip_bool_invariants = skip_bool_invariants;
399
400   nir_foreach_function(function, shader) {
401      if (function->impl == NULL)
402         continue;
403
404      state->progress = false;
405      nir_metadata_require(function->impl, nir_metadata_block_index);
406
407      foreach_list_typed(nir_cf_node, node, node, &function->impl->body)
408         convert_to_lcssa(node, state);
409
410      if (state->progress) {
411         progress = true;
412         nir_metadata_preserve(function->impl, nir_metadata_block_index |
413                                               nir_metadata_dominance);
414      } else {
415         nir_metadata_preserve(function->impl, nir_metadata_all);
416      }
417   }
418
419   ralloc_free(state);
420   return progress;
421}
422
423