/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"
#include "nir_vla.h"

#include "util/set.h"
#include "util/u_math.h"

static struct set *
get_complex_used_vars(nir_shader *shader, void *mem_ctx)
{
   struct set *complex_vars = _mesa_pointer_set_create(mem_ctx);

   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      nir_foreach_block(block, function->impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            /* We only need to consider var derefs because
             * nir_deref_instr_has_complex_use is recursive.
             */
            if (deref->deref_type == nir_deref_type_var &&
                nir_deref_instr_has_complex_use(deref, 0))
               _mesa_set_add(complex_vars, deref->var);
         }
      }
   }

   return complex_vars;
}

struct split_var_state {
   void *mem_ctx;

   nir_shader *shader;
   nir_function_impl *impl;

   nir_variable *base_var;
};

struct field {
   struct field *parent;

   const struct glsl_type *type;

   unsigned num_fields;
   struct field *fields;

   nir_variable *var;
};
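
/* Illustrative sketch (hypothetical variable, not part of the pass): for a
 * variable of type
 *
 *    struct { struct { float a; } inner; vec4 b; } s;
 *
 * the field tree built below has a root node (the type of s, no var), one
 * child for "inner" (which in turn has one leaf child for "a") and one leaf
 * child for "b".  Only leaf fields get a replacement nir_variable.
 */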

static int
num_array_levels_in_array_of_vector_type(const struct glsl_type *type)
{
   int num_levels = 0;
   while (true) {
      if (glsl_type_is_array_or_matrix(type)) {
         num_levels++;
         type = glsl_get_array_element(type);
      } else if (glsl_type_is_vector_or_scalar(type)) {
         return num_levels;
      } else {
         /* Not an array of vectors */
         return -1;
      }
   }
}
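
/* Illustrative examples for num_array_levels_in_array_of_vector_type
 * (matrices count as one array level of their column vectors):
 *
 *    vec4        -> 0
 *    float[8]    -> 1
 *    mat4[2]     -> 2   (array of matrices of vec4 columns)
 *    struct S[4] -> -1  (not an array of vectors)
 */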

static void
init_field_for_type(struct field *field, struct field *parent,
                    const struct glsl_type *type,
                    const char *name,
                    struct split_var_state *state)
{
   *field = (struct field) {
      .parent = parent,
      .type = type,
   };

   const struct glsl_type *struct_type = glsl_without_array(type);
   if (glsl_type_is_struct_or_ifc(struct_type)) {
      field->num_fields = glsl_get_length(struct_type);
      field->fields = ralloc_array(state->mem_ctx, struct field,
                                   field->num_fields);
      for (unsigned i = 0; i < field->num_fields; i++) {
         char *field_name = NULL;
         if (name) {
            field_name = ralloc_asprintf(state->mem_ctx, "%s_%s", name,
                                         glsl_get_struct_elem_name(struct_type, i));
         } else {
            field_name = ralloc_asprintf(state->mem_ctx, "{unnamed %s}_%s",
                                         glsl_get_type_name(struct_type),
                                         glsl_get_struct_elem_name(struct_type, i));
         }
         init_field_for_type(&field->fields[i], field,
                             glsl_get_struct_field(struct_type, i),
                             field_name, state);
      }
   } else {
      const struct glsl_type *var_type = type;
      for (struct field *f = field->parent; f; f = f->parent)
         var_type = glsl_type_wrap_in_arrays(var_type, f->type);

      nir_variable_mode mode = state->base_var->data.mode;
      if (mode == nir_var_function_temp) {
         field->var = nir_local_variable_create(state->impl, var_type, name);
      } else {
         field->var = nir_variable_create(state->shader, mode, var_type, name);
      }
      field->var->data.ray_query = state->base_var->data.ray_query;
   }
}

static bool
split_var_list_structs(nir_shader *shader,
                       nir_function_impl *impl,
                       struct exec_list *vars,
                       nir_variable_mode mode,
                       struct hash_table *var_field_map,
                       struct set **complex_vars,
                       void *mem_ctx)
{
   struct split_var_state state = {
      .mem_ctx = mem_ctx,
      .shader = shader,
      .impl = impl,
   };

   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   /* To avoid list confusion (we'll be adding things as we split variables),
    * pull all of the variables we plan to split off of the list.
    */
   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced by a deref with any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      exec_node_remove(&var->node);
      exec_list_push_tail(&split_vars, &var->node);
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      state.base_var = var;

      struct field *root_field = ralloc(mem_ctx, struct field);
      init_field_for_type(root_field, NULL, var->type, var->name, &state);
      _mesa_hash_table_insert(var_field_map, var, root_field);
   }

   return !exec_list_is_empty(&split_vars);
}

static void
split_struct_derefs_impl(nir_function_impl *impl,
                         struct hash_table *var_field_map,
                         nir_variable_mode modes,
                         void *mem_ctx)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (!nir_deref_mode_may_be(deref, modes))
            continue;

         /* Clean up any dead derefs we find lying around.  They may refer to
          * variables we're planning to split.
          */
         if (nir_deref_instr_remove_if_unused(deref))
            continue;

         if (!glsl_type_is_vector_or_scalar(deref->type))
            continue;

         nir_variable *base_var = nir_deref_instr_get_variable(deref);
         /* If we can't chase back to the variable, then we're a complex use.
          * This should have been detected by get_complex_used_vars() and the
          * variable should not have been split.  However, we have no way of
          * knowing that here, so we just have to trust it.
          */
         if (base_var == NULL)
            continue;

         struct hash_entry *entry =
            _mesa_hash_table_search(var_field_map, base_var);
         if (!entry)
            continue;

         struct field *root_field = entry->data;

         nir_deref_path path;
         nir_deref_path_init(&path, deref, mem_ctx);

         struct field *tail_field = root_field;
         for (unsigned i = 0; path.path[i]; i++) {
            if (path.path[i]->deref_type != nir_deref_type_struct)
               continue;

            assert(i > 0);
            assert(glsl_type_is_struct_or_ifc(path.path[i - 1]->type));
            assert(path.path[i - 1]->type ==
                   glsl_without_array(tail_field->type));

            tail_field = &tail_field->fields[path.path[i]->strct.index];
         }
         nir_variable *split_var = tail_field->var;

         nir_deref_instr *new_deref = NULL;
         for (unsigned i = 0; path.path[i]; i++) {
            nir_deref_instr *p = path.path[i];
            b.cursor = nir_after_instr(&p->instr);

            switch (p->deref_type) {
            case nir_deref_type_var:
               assert(new_deref == NULL);
               new_deref = nir_build_deref_var(&b, split_var);
               break;

            case nir_deref_type_array:
            case nir_deref_type_array_wildcard:
               new_deref = nir_build_deref_follower(&b, new_deref, p);
               break;

            case nir_deref_type_struct:
               /* Nothing to do; we're splitting structs */
               break;

            default:
               unreachable("Invalid deref type in path");
            }
         }

         assert(new_deref->type == deref->type);
         nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                  &new_deref->dest.ssa);
         nir_deref_instr_remove_if_unused(deref);
      }
   }
}

/** A pass for splitting structs into multiple variables
 *
 * This pass splits arrays of structs into multiple variables, one for each
 * (possibly nested) structure member.  After this pass completes, no
 * variables of the given mode will contain a struct type.
 */
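/* Illustrative sketch (hypothetical shader code, not part of the pass):
 *
 *    struct S { vec4 color; float w; };
 *    S s[2];                 // before
 *
 * becomes two variables after splitting:
 *
 *    vec4 s_color[2];        // after
 *    float s_w[2];
 *
 * with every deref of s[i].color / s[i].w rewritten to the new variables.
 */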
bool
nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
{
   void *mem_ctx = ralloc_context(NULL);
   struct hash_table *var_field_map =
      _mesa_pointer_hash_table_create(mem_ctx);
   struct set *complex_vars = NULL;

   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);

   bool has_global_splits = false;
   if (modes & nir_var_shader_temp) {
      has_global_splits = split_var_list_structs(shader, NULL,
                                                 &shader->variables,
                                                 nir_var_shader_temp,
                                                 var_field_map,
                                                 &complex_vars,
                                                 mem_ctx);
   }

   bool progress = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool has_local_splits = false;
      if (modes & nir_var_function_temp) {
         has_local_splits = split_var_list_structs(shader, function->impl,
                                                   &function->impl->locals,
                                                   nir_var_function_temp,
                                                   var_field_map,
                                                   &complex_vars,
                                                   mem_ctx);
      }

      if (has_global_splits || has_local_splits) {
         split_struct_derefs_impl(function->impl, var_field_map,
                                  modes, mem_ctx);

         nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                               nir_metadata_dominance);
         progress = true;
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}

struct array_level_info {
   unsigned array_len;
   bool split;
};

struct array_split {
   /* Only set if this is the tail end of the splitting */
   nir_variable *var;

   unsigned num_splits;
   struct array_split *splits;
};

struct array_var_info {
   nir_variable *base_var;

   const struct glsl_type *split_var_type;

   bool split_var;
   struct array_split root_split;

   unsigned num_levels;
   struct array_level_info levels[0];
};

static bool
init_var_list_array_infos(nir_shader *shader,
                          struct exec_list *vars,
                          nir_variable_mode mode,
                          struct hash_table *var_info_map,
                          struct set **complex_vars,
                          void *mem_ctx)
{
   bool has_array = false;

   nir_foreach_variable_in_list(var, vars) {
      if (var->data.mode != mode)
         continue;

      int num_levels = num_array_levels_in_array_of_vector_type(var->type);
      if (num_levels <= 0)
         continue;

      if (*complex_vars == NULL)
         *complex_vars = get_complex_used_vars(shader, mem_ctx);

      /* We can't split a variable that's referenced by a deref with any
       * sort of complex usage.
       */
      if (_mesa_set_search(*complex_vars, var))
         continue;

      struct array_var_info *info =
         rzalloc_size(mem_ctx, sizeof(*info) +
                               num_levels * sizeof(info->levels[0]));

      info->base_var = var;
      info->num_levels = num_levels;

      const struct glsl_type *type = var->type;
      for (int i = 0; i < num_levels; i++) {
         info->levels[i].array_len = glsl_get_length(type);
         type = glsl_get_array_element(type);

         /* All levels start out initially as split */
         info->levels[i].split = true;
      }

      _mesa_hash_table_insert(var_info_map, var, info);
      has_array = true;
   }

   return has_array;
}

static struct array_var_info *
get_array_var_info(nir_variable *var,
                   struct hash_table *var_info_map)
{
   struct hash_entry *entry =
      _mesa_hash_table_search(var_info_map, var);
   return entry ? entry->data : NULL;
}

static struct array_var_info *
get_array_deref_info(nir_deref_instr *deref,
                     struct hash_table *var_info_map,
                     nir_variable_mode modes)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return NULL;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return NULL;

   return get_array_var_info(var, var_info_map);
}

static void
mark_array_deref_used(nir_deref_instr *deref,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   struct array_var_info *info =
      get_array_deref_info(deref, var_info_map, modes);
   if (!info)
      return;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   /* Walk the path and look for indirects.  If we have an array deref with
    * an indirect, mark the given level as not being split.
    */
   for (unsigned i = 0; i < info->num_levels; i++) {
      nir_deref_instr *p = path.path[i + 1];
      if (p->deref_type == nir_deref_type_array &&
          !nir_src_is_const(p->arr.index))
         info->levels[i].split = false;
   }
}

static void
mark_array_usage_impl(nir_function_impl *impl,
                      struct hash_table *var_info_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_copy_deref:
            mark_array_deref_used(nir_src_as_deref(intrin->src[1]),
                                  var_info_map, modes, mem_ctx);
            FALLTHROUGH;

         case nir_intrinsic_load_deref:
         case nir_intrinsic_store_deref:
            mark_array_deref_used(nir_src_as_deref(intrin->src[0]),
                                  var_info_map, modes, mem_ctx);
            break;

         default:
            break;
         }
      }
   }
}

static void
create_split_array_vars(struct array_var_info *var_info,
                        unsigned level,
                        struct array_split *split,
                        const char *name,
                        nir_shader *shader,
                        nir_function_impl *impl,
                        void *mem_ctx)
{
   while (level < var_info->num_levels && !var_info->levels[level].split) {
      name = ralloc_asprintf(mem_ctx, "%s[*]", name);
      level++;
   }

   if (level == var_info->num_levels) {
      /* We add parens to the variable name so it looks like "(foo[2][*])" so
       * that further derefs will look like "(foo[2][*])[ssa_6]"
       */
      name = ralloc_asprintf(mem_ctx, "(%s)", name);

      nir_variable_mode mode = var_info->base_var->data.mode;
      if (mode == nir_var_function_temp) {
         split->var = nir_local_variable_create(impl,
                                                var_info->split_var_type, name);
      } else {
         split->var = nir_variable_create(shader, mode,
                                          var_info->split_var_type, name);
      }
      split->var->data.ray_query = var_info->base_var->data.ray_query;
   } else {
      assert(var_info->levels[level].split);
      split->num_splits = var_info->levels[level].array_len;
      split->splits = rzalloc_array(mem_ctx, struct array_split,
                                    split->num_splits);
      for (unsigned i = 0; i < split->num_splits; i++) {
         create_split_array_vars(var_info, level + 1, &split->splits[i],
                                 ralloc_asprintf(mem_ctx, "%s[%d]", name, i),
                                 shader, impl, mem_ctx);
      }
   }
}

static bool
split_var_list_arrays(nir_shader *shader,
                      nir_function_impl *impl,
                      struct exec_list *vars,
                      nir_variable_mode mode,
                      struct hash_table *var_info_map,
                      void *mem_ctx)
{
   struct exec_list split_vars;
   exec_list_make_empty(&split_vars);

   nir_foreach_variable_in_list_safe(var, vars) {
      if (var->data.mode != mode)
         continue;

      struct array_var_info *info = get_array_var_info(var, var_info_map);
      if (!info)
         continue;

      bool has_split = false;
      const struct glsl_type *split_type =
         glsl_without_array_or_matrix(var->type);
      for (int i = info->num_levels - 1; i >= 0; i--) {
         if (info->levels[i].split) {
            has_split = true;
            continue;
         }

         /* If the original type was a matrix type, we'd like to keep that so
          * we don't convert matrices into arrays.
          */
         if (i == info->num_levels - 1 &&
             glsl_type_is_matrix(glsl_without_array(var->type))) {
            split_type = glsl_matrix_type(glsl_get_base_type(split_type),
                                          glsl_get_components(split_type),
                                          info->levels[i].array_len);
         } else {
            split_type = glsl_array_type(split_type, info->levels[i].array_len, 0);
         }
      }

      if (has_split) {
         info->split_var_type = split_type;
         /* To avoid list confusion (we'll be adding things as we split
          * variables), pull all of the variables we plan to split off of the
          * main variable list.
          */
         exec_node_remove(&var->node);
         exec_list_push_tail(&split_vars, &var->node);
      } else {
         assert(split_type == glsl_get_bare_type(var->type));
         /* If we're not modifying this variable, delete the info so we skip
          * it faster in later passes.
          */
         _mesa_hash_table_remove_key(var_info_map, var);
      }
   }

   nir_foreach_variable_in_list(var, &split_vars) {
      struct array_var_info *info = get_array_var_info(var, var_info_map);
      create_split_array_vars(info, 0, &info->root_split, var->name,
                              shader, impl, mem_ctx);
   }

   return !exec_list_is_empty(&split_vars);
}

static bool
deref_has_split_wildcard(nir_deref_path *path,
                         struct array_var_info *info)
{
   if (info == NULL)
      return false;

   assert(path->path[0]->var == info->base_var);
   for (unsigned i = 0; i < info->num_levels; i++) {
      if (path->path[i + 1]->deref_type == nir_deref_type_array_wildcard &&
          info->levels[i].split)
         return true;
   }

   return false;
}

static bool
array_path_is_out_of_bounds(nir_deref_path *path,
                            struct array_var_info *info)
{
   if (info == NULL)
      return false;

   assert(path->path[0]->var == info->base_var);
   for (unsigned i = 0; i < info->num_levels; i++) {
      nir_deref_instr *p = path->path[i + 1];
      if (p->deref_type == nir_deref_type_array_wildcard)
         continue;

      if (nir_src_is_const(p->arr.index) &&
          nir_src_as_uint(p->arr.index) >= info->levels[i].array_len)
         return true;
   }

   return false;
}

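/* Recursively emits copies for a copy_deref whose source or destination
 * contains wildcards over split array levels.  Both deref paths are walked
 * in lockstep; at each wildcard level that is split on either side, the copy
 * is unrolled into one element-wise copy per array element, while wildcard
 * levels that aren't split on either side are passed through as wildcards.
 */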
static void
emit_split_copies(nir_builder *b,
                  struct array_var_info *dst_info, nir_deref_path *dst_path,
                  unsigned dst_level, nir_deref_instr *dst,
                  struct array_var_info *src_info, nir_deref_path *src_path,
                  unsigned src_level, nir_deref_instr *src)
{
   nir_deref_instr *dst_p, *src_p;

   while ((dst_p = dst_path->path[dst_level + 1])) {
      if (dst_p->deref_type == nir_deref_type_array_wildcard)
         break;

      dst = nir_build_deref_follower(b, dst, dst_p);
      dst_level++;
   }

   while ((src_p = src_path->path[src_level + 1])) {
      if (src_p->deref_type == nir_deref_type_array_wildcard)
         break;

      src = nir_build_deref_follower(b, src, src_p);
      src_level++;
   }

   if (src_p == NULL || dst_p == NULL) {
      assert(src_p == NULL && dst_p == NULL);
      nir_copy_deref(b, dst, src);
   } else {
      assert(dst_p->deref_type == nir_deref_type_array_wildcard &&
             src_p->deref_type == nir_deref_type_array_wildcard);

      if ((dst_info && dst_info->levels[dst_level].split) ||
          (src_info && src_info->levels[src_level].split)) {
         /* This level is split on at least one of the source and the
          * destination (i.e. it has no indirects there), so we lower the
          * copy to per-element copies.
          */
         assert(glsl_get_length(dst_path->path[dst_level]->type) ==
                glsl_get_length(src_path->path[src_level]->type));
         unsigned len = glsl_get_length(dst_path->path[dst_level]->type);
         for (unsigned i = 0; i < len; i++) {
            emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                              nir_build_deref_array_imm(b, dst, i),
                              src_info, src_path, src_level + 1,
                              nir_build_deref_array_imm(b, src, i));
         }
      } else {
         /* Neither side is being split so we just keep going */
         emit_split_copies(b, dst_info, dst_path, dst_level + 1,
                           nir_build_deref_array_wildcard(b, dst),
                           src_info, src_path, src_level + 1,
                           nir_build_deref_array_wildcard(b, src));
      }
   }
}

static void
split_array_copies_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *copy = nir_instr_as_intrinsic(instr);
         if (copy->intrinsic != nir_intrinsic_copy_deref)
            continue;

         nir_deref_instr *dst_deref = nir_src_as_deref(copy->src[0]);
         nir_deref_instr *src_deref = nir_src_as_deref(copy->src[1]);

         struct array_var_info *dst_info =
            get_array_deref_info(dst_deref, var_info_map, modes);
         struct array_var_info *src_info =
            get_array_deref_info(src_deref, var_info_map, modes);

         if (!src_info && !dst_info)
            continue;

         nir_deref_path dst_path, src_path;
         nir_deref_path_init(&dst_path, dst_deref, mem_ctx);
         nir_deref_path_init(&src_path, src_deref, mem_ctx);

         if (!deref_has_split_wildcard(&dst_path, dst_info) &&
             !deref_has_split_wildcard(&src_path, src_info))
            continue;

         b.cursor = nir_instr_remove(&copy->instr);

         emit_split_copies(&b, dst_info, &dst_path, 0, dst_path.path[0],
                               src_info, &src_path, 0, src_path.path[0]);
      }
   }
}

static void
split_array_access_impl(nir_function_impl *impl,
                        struct hash_table *var_info_map,
                        nir_variable_mode modes,
                        void *mem_ctx)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_deref) {
            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we're planning to split.
             */
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (nir_deref_mode_may_be(deref, modes))
               nir_deref_instr_remove_if_unused(deref);
            continue;
         }

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_deref &&
             intrin->intrinsic != nir_intrinsic_store_deref &&
             intrin->intrinsic != nir_intrinsic_copy_deref)
            continue;

         const unsigned num_derefs =
            intrin->intrinsic == nir_intrinsic_copy_deref ? 2 : 1;

         for (unsigned d = 0; d < num_derefs; d++) {
            nir_deref_instr *deref = nir_src_as_deref(intrin->src[d]);

            struct array_var_info *info =
               get_array_deref_info(deref, var_info_map, modes);
            if (!info)
               continue;

            nir_deref_path path;
            nir_deref_path_init(&path, deref, mem_ctx);

            b.cursor = nir_before_instr(&intrin->instr);

            if (array_path_is_out_of_bounds(&path, info)) {
               /* If one of the derefs is out-of-bounds, we just delete the
                * instruction.  If a destination is out of bounds, then it may
                * have been in-bounds prior to shrinking so we don't want to
                * accidentally stomp something.  However, we've already proven
                * that it will never be read so it's safe to delete.  If a
                * source is out of bounds then it is loading random garbage.
                * For loads, we replace their uses with an undef instruction
                * and for copies we just delete the copy since it was writing
                * undefined garbage anyway and we may as well leave the random
                * garbage in the destination alone.
                */
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_ssa_def *u =
                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
                                       intrin->dest.ssa.bit_size);
                  nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                           u);
               }
               nir_instr_remove(&intrin->instr);
               for (unsigned i = 0; i < num_derefs; i++)
                  nir_deref_instr_remove_if_unused(nir_src_as_deref(intrin->src[i]));
               break;
            }

            struct array_split *split = &info->root_split;
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (info->levels[i].split) {
                  nir_deref_instr *p = path.path[i + 1];
                  unsigned index = nir_src_as_uint(p->arr.index);
                  assert(index < info->levels[i].array_len);
                  split = &split->splits[index];
               }
            }
            assert(!split->splits && split->var);

            nir_deref_instr *new_deref = nir_build_deref_var(&b, split->var);
            for (unsigned i = 0; i < info->num_levels; i++) {
               if (!info->levels[i].split) {
                  new_deref = nir_build_deref_follower(&b, new_deref,
                                                       path.path[i + 1]);
               }
            }
            assert(new_deref->type == deref->type);

            /* Rewrite the deref source to point to the split one */
            nir_instr_rewrite_src(&intrin->instr, &intrin->src[d],
                                  nir_src_for_ssa(&new_deref->dest.ssa));
            nir_deref_instr_remove_if_unused(deref);
         }
      }
   }
}

/** A pass for splitting arrays of vectors into multiple variables
 *
 * This pass looks at arrays (possibly multiple levels) of vectors (not
 * structures or other types) and tries to split them into piles of variables,
 * one for each array element.  The heuristic used is simple: If a given array
 * level is never used with an indirect, that array level will get split.
 *
 * This pass probably could handle structures easily enough but making a pass
 * that could see through an array of structures of arrays would be difficult
 * so it's best to just run nir_split_struct_vars first.
 */
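/* Illustrative sketch (hypothetical shader code, not part of the pass):
 *
 *    vec4 a[3];   // only ever indexed with constants
 *    vec4 b[8];   // indexed somewhere as b[i] (an indirect)
 *
 * a is split into a[0], a[1], and a[2] as three separate vec4 variables,
 * while b is left alone because the indirect access pins its array level.
 */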
bool
nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
{
   void *mem_ctx = ralloc_context(NULL);
   struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
   struct set *complex_vars = NULL;

   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);

   bool has_global_array = false;
   if (modes & nir_var_shader_temp) {
      has_global_array = init_var_list_array_infos(shader,
                                                   &shader->variables,
                                                   nir_var_shader_temp,
                                                   var_info_map,
                                                   &complex_vars,
                                                   mem_ctx);
   }

   bool has_any_array = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool has_local_array = false;
      if (modes & nir_var_function_temp) {
         has_local_array = init_var_list_array_infos(shader,
                                                     &function->impl->locals,
                                                     nir_var_function_temp,
                                                     var_info_map,
                                                     &complex_vars,
                                                     mem_ctx);
      }

      if (has_global_array || has_local_array) {
         has_any_array = true;
         mark_array_usage_impl(function->impl, var_info_map, modes, mem_ctx);
      }
   }

   /* If we failed to find any arrays of vectors, bail early. */
   if (!has_any_array) {
      ralloc_free(mem_ctx);
      nir_shader_preserve_all_metadata(shader);
      return false;
   }

   bool has_global_splits = false;
   if (modes & nir_var_shader_temp) {
      has_global_splits = split_var_list_arrays(shader, NULL,
                                                &shader->variables,
                                                nir_var_shader_temp,
                                                var_info_map, mem_ctx);
   }

   bool progress = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool has_local_splits = false;
      if (modes & nir_var_function_temp) {
         has_local_splits = split_var_list_arrays(shader, function->impl,
                                                  &function->impl->locals,
                                                  nir_var_function_temp,
                                                  var_info_map, mem_ctx);
      }

      if (has_global_splits || has_local_splits) {
         split_array_copies_impl(function->impl, var_info_map, modes, mem_ctx);
         split_array_access_impl(function->impl, var_info_map, modes, mem_ctx);

         nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                               nir_metadata_dominance);
         progress = true;
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}

struct array_level_usage {
   unsigned array_len;

   /* The value UINT_MAX will be used to indicate an indirect */
   unsigned max_read;
   unsigned max_written;

   /* True if there is a copy that isn't to/from a shrinkable array */
   bool has_external_copy;
   struct set *levels_copied;
};

struct vec_var_usage {
   /* Convenience set of all components this variable has */
   nir_component_mask_t all_comps;

   nir_component_mask_t comps_read;
   nir_component_mask_t comps_written;

   nir_component_mask_t comps_kept;

   /* True if there is a copy that isn't to/from a shrinkable vector */
   bool has_external_copy;
   bool has_complex_use;
   struct set *vars_copied;

   unsigned num_levels;
   struct array_level_usage levels[0];
};

static struct vec_var_usage *
get_vec_var_usage(nir_variable *var,
                  struct hash_table *var_usage_map,
                  bool add_usage_entry, void *mem_ctx)
{
   struct hash_entry *entry = _mesa_hash_table_search(var_usage_map, var);
   if (entry)
      return entry->data;

   if (!add_usage_entry)
      return NULL;

   /* Check to make sure that we are working with an array of vectors.  We
    * don't bother to shrink single vectors because we figure that we can
    * clean it up better with SSA than by inserting piles of vecN instructions
    * to compact results.
    */
   int num_levels = num_array_levels_in_array_of_vector_type(var->type);
   if (num_levels < 1)
      return NULL; /* Not an array of vectors */

   struct vec_var_usage *usage =
      rzalloc_size(mem_ctx, sizeof(*usage) +
                            num_levels * sizeof(usage->levels[0]));

   usage->num_levels = num_levels;
   const struct glsl_type *type = var->type;
   for (unsigned i = 0; i < num_levels; i++) {
      usage->levels[i].array_len = glsl_get_length(type);
      type = glsl_get_array_element(type);
   }
   assert(glsl_type_is_vector_or_scalar(type));

   usage->all_comps = (1 << glsl_get_components(type)) - 1;

   _mesa_hash_table_insert(var_usage_map, var, usage);

   return usage;
}

static struct vec_var_usage *
get_vec_deref_usage(nir_deref_instr *deref,
                    struct hash_table *var_usage_map,
                    nir_variable_mode modes,
                    bool add_usage_entry, void *mem_ctx)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return NULL;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return NULL;

   return get_vec_var_usage(var, var_usage_map, add_usage_entry, mem_ctx);
}

static void
mark_deref_if_complex(nir_deref_instr *deref,
                      struct hash_table *var_usage_map,
                      nir_variable_mode modes,
                      void *mem_ctx)
{
   /* Only bother with var derefs because nir_deref_instr_has_complex_use is
    * recursive.
    */
   if (deref->deref_type != nir_deref_type_var)
      return;

   if (!(deref->var->data.mode & modes))
      return;

   if (!nir_deref_instr_has_complex_use(deref, 0))
      return;

   struct vec_var_usage *usage =
      get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
   if (!usage)
      return;

   usage->has_complex_use = true;
}

static void
mark_deref_used(nir_deref_instr *deref,
                nir_component_mask_t comps_read,
                nir_component_mask_t comps_written,
                nir_deref_instr *copy_deref,
                struct hash_table *var_usage_map,
                nir_variable_mode modes,
                void *mem_ctx)
{
   if (!nir_deref_mode_may_be(deref, modes))
      return;

   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (var == NULL)
      return;

   struct vec_var_usage *usage =
      get_vec_var_usage(var, var_usage_map, true, mem_ctx);
   if (!usage)
      return;

   usage->comps_read |= comps_read & usage->all_comps;
   usage->comps_written |= comps_written & usage->all_comps;

   struct vec_var_usage *copy_usage = NULL;
   if (copy_deref) {
      copy_usage = get_vec_deref_usage(copy_deref, var_usage_map, modes,
                                       true, mem_ctx);
      if (copy_usage) {
         if (usage->vars_copied == NULL) {
            usage->vars_copied = _mesa_pointer_set_create(mem_ctx);
         }
         _mesa_set_add(usage->vars_copied, copy_usage);
      } else {
         usage->has_external_copy = true;
      }
   }

   nir_deref_path path;
   nir_deref_path_init(&path, deref, mem_ctx);

   nir_deref_path copy_path;
   if (copy_usage)
      nir_deref_path_init(&copy_path, copy_deref, mem_ctx);

   unsigned copy_i = 0;
   for (unsigned i = 0; i < usage->num_levels; i++) {
      struct array_level_usage *level = &usage->levels[i];
      nir_deref_instr *deref = path.path[i + 1];
      assert(deref->deref_type == nir_deref_type_array ||
             deref->deref_type == nir_deref_type_array_wildcard);

      unsigned max_used;
      if (deref->deref_type == nir_deref_type_array) {
         max_used = nir_src_is_const(deref->arr.index) ?
                    nir_src_as_uint(deref->arr.index) : UINT_MAX;
      } else {
         /* For wildcards, we read or wrote the whole thing. */
         assert(deref->deref_type == nir_deref_type_array_wildcard);
         max_used = level->array_len - 1;

         if (copy_usage) {
            /* Match each wildcard level with the level on copy_usage */
            for (; copy_path.path[copy_i + 1]; copy_i++) {
               if (copy_path.path[copy_i + 1]->deref_type ==
                   nir_deref_type_array_wildcard)
                  break;
            }
            struct array_level_usage *copy_level =
               &copy_usage->levels[copy_i++];

            if (level->levels_copied == NULL) {
               level->levels_copied = _mesa_pointer_set_create(mem_ctx);
            }
            _mesa_set_add(level->levels_copied, copy_level);
         } else {
            /* We have a wildcard and it comes from a variable we aren't
             * tracking; flag it and we'll know to not shorten this array.
             */
            level->has_external_copy = true;
         }
      }

      if (comps_written)
         level->max_written = MAX2(level->max_written, max_used);
      if (comps_read)
         level->max_read = MAX2(level->max_read, max_used);
   }
}

static bool
src_is_load_deref(nir_src src, nir_src deref_src)
{
   nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
   if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
      return false;

   assert(load->src[0].is_ssa);

   return load->src[0].ssa == deref_src.ssa;
}

/* Returns all non-self-referential components of a store instruction.  A
 * component is self-referential if it comes from the same component of a load
 * instruction on the same deref.  If the only data in a particular component
 * of a variable came directly from that component then it's undefined.  The
 * only way to get defined data into a component of a variable is for it to
 * get written there by something outside or from a different component.
 *
 * This is a fairly common pattern in shaders that come from either GLSL IR or
 * GLSLang because both glsl_to_nir and GLSLang implement write-masking with
 * load-vec-store.
 */
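/* Illustrative sketch of the load-vec-store pattern (NIR-like pseudocode
 * with hypothetical SSA names): a masked write "v.x = f" often arrives as
 *
 *    %old = load_deref &v
 *    %new = vec4 %f, %old.y, %old.z, %old.w
 *    store_deref &v, %new
 *
 * Components .yzw are self-referential (loaded from the same component of
 * the same deref), so only .x counts as actually written here.
 */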
static nir_component_mask_t
get_non_self_referential_store_comps(nir_intrinsic_instr *store)
{
   nir_component_mask_t comps = nir_intrinsic_write_mask(store);

   assert(store->src[1].is_ssa);
   nir_instr *src_instr = store->src[1].ssa->parent_instr;
   if (src_instr->type != nir_instr_type_alu)
      return comps;

   nir_alu_instr *src_alu = nir_instr_as_alu(src_instr);

   if (src_alu->op == nir_op_mov) {
      /* If it's just a swizzle of a load from the same deref, discount any
       * channels that don't move in the swizzle.
       */
      if (src_is_load_deref(src_alu->src[0].src, store->src[0])) {
         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
            if (src_alu->src[0].swizzle[i] == i)
               comps &= ~(1u << i);
         }
      }
   } else if (nir_op_is_vec(src_alu->op)) {
      /* If it's a vec, discount any channels that are just loads from the
       * same deref put in the same spot.
       */
      for (unsigned i = 0; i < nir_op_infos[src_alu->op].num_inputs; i++) {
         if (src_is_load_deref(src_alu->src[i].src, store->src[0]) &&
             src_alu->src[i].swizzle[0] == i)
            comps &= ~(1u << i);
      }
   }

   return comps;
}

static void
find_used_components_impl(nir_function_impl *impl,
                          struct hash_table *var_usage_map,
                          nir_variable_mode modes,
                          void *mem_ctx)
{
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type == nir_instr_type_deref) {
            mark_deref_if_complex(nir_instr_as_deref(instr),
                                  var_usage_map, modes, mem_ctx);
         }

         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_deref:
            mark_deref_used(nir_src_as_deref(intrin->src[0]),
                            nir_ssa_def_components_read(&intrin->dest.ssa), 0,
                            NULL, var_usage_map, modes, mem_ctx);
            break;

         case nir_intrinsic_store_deref:
            mark_deref_used(nir_src_as_deref(intrin->src[0]),
                            0, get_non_self_referential_store_comps(intrin),
                            NULL, var_usage_map, modes, mem_ctx);
            break;

         case nir_intrinsic_copy_deref: {
            /* Just mark everything used for copies. */
            nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
            nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
            mark_deref_used(dst, 0, ~0, src, var_usage_map, modes, mem_ctx);
            mark_deref_used(src, ~0, 0, dst, var_usage_map, modes, mem_ctx);
            break;
         }

         default:
            break;
         }
      }
   }
}

static bool
shrink_vec_var_list(struct exec_list *vars,
                    nir_variable_mode mode,
                    struct hash_table *var_usage_map)
{
   /* Initialize the components kept field of each variable.  This is the
    * AND of the components written and components read.  If a component is
    * written but never read, it's dead.  If it is read but never written,
    * then all values read are undefined garbage and we may as well not read
    * them.
    *
    * The same logic applies to the array length.  We make the array length
    * the minimum required length between read and write and plan to discard
    * any OOB access.  The one exception here is indirect writes because we
    * don't know where they will land and we can't shrink an array with
    * indirect writes because previously in-bounds writes may become
    * out-of-bounds and have undefined behavior.
    *
    * Also, if we have a copy to/from something we can't shrink, we need to
    * leave the components and array_len of any wildcards alone.
    */
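   /* For example (a worked instance of the rule above): if comps_read is xy
    * and comps_written is xyz, then comps_kept is xy; the z component is
    * written but never read, so it is dead and can be dropped.
    */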
1274   nir_foreach_variable_in_list(var, vars) {
1275      if (var->data.mode != mode)
1276         continue;
1277
1278      struct vec_var_usage *usage =
1279         get_vec_var_usage(var, var_usage_map, false, NULL);
1280      if (!usage)
1281         continue;
1282
1283      assert(usage->comps_kept == 0);
1284      if (usage->has_external_copy || usage->has_complex_use)
1285         usage->comps_kept = usage->all_comps;
1286      else
1287         usage->comps_kept = usage->comps_read & usage->comps_written;
1288
1289      for (unsigned i = 0; i < usage->num_levels; i++) {
1290         struct array_level_usage *level = &usage->levels[i];
1291         assert(level->array_len > 0);
1292
1293         if (level->max_written == UINT_MAX || level->has_external_copy ||
1294             usage->has_complex_use)
1295            continue; /* Can't shrink */
1296
1297         unsigned max_used = MIN2(level->max_read, level->max_written);
1298         level->array_len = MIN2(max_used, level->array_len - 1) + 1;
1299      }
1300   }
1301
1302   /* In order for variable copies to work, we have to have the same data type
1303    * on the source and the destination.  In order to satisfy this, we run a
1304    * little fixed-point algorithm to transitively ensure that we get enough
1305    * components and array elements for this to hold for all copies.
1306    */
1307   bool fp_progress;
1308   do {
1309      fp_progress = false;
1310      nir_foreach_variable_in_list(var, vars) {
1311         if (var->data.mode != mode)
1312            continue;
1313
1314         struct vec_var_usage *var_usage =
1315            get_vec_var_usage(var, var_usage_map, false, NULL);
1316         if (!var_usage || !var_usage->vars_copied)
1317            continue;
1318
1319         set_foreach(var_usage->vars_copied, copy_entry) {
1320            struct vec_var_usage *copy_usage = (void *)copy_entry->key;
1321            if (copy_usage->comps_kept != var_usage->comps_kept) {
1322               nir_component_mask_t comps_kept =
1323                  (var_usage->comps_kept | copy_usage->comps_kept);
1324               var_usage->comps_kept = comps_kept;
1325               copy_usage->comps_kept = comps_kept;
1326               fp_progress = true;
1327            }
1328         }
1329
1330         for (unsigned i = 0; i < var_usage->num_levels; i++) {
1331            struct array_level_usage *var_level = &var_usage->levels[i];
1332            if (!var_level->levels_copied)
1333               continue;
1334
1335            set_foreach(var_level->levels_copied, copy_entry) {
1336               struct array_level_usage *copy_level = (void *)copy_entry->key;
1337               if (var_level->array_len != copy_level->array_len) {
1338                  unsigned array_len =
1339                     MAX2(var_level->array_len, copy_level->array_len);
1340                  var_level->array_len = array_len;
1341                  copy_level->array_len = array_len;
1342                  fp_progress = true;
1343               }
1344            }
1345         }
1346      }
1347   } while (fp_progress);
1348
1349   bool vars_shrunk = false;
1350   nir_foreach_variable_in_list_safe(var, vars) {
1351      if (var->data.mode != mode)
1352         continue;
1353
1354      struct vec_var_usage *usage =
1355         get_vec_var_usage(var, var_usage_map, false, NULL);
1356      if (!usage)
1357         continue;
1358
1359      bool shrunk = false;
1360      const struct glsl_type *vec_type = var->type;
1361      for (unsigned i = 0; i < usage->num_levels; i++) {
1362         /* If we've reduced the array to zero elements at some level, just
1363          * set comps_kept to 0 and delete the variable.
1364          */
1365         if (usage->levels[i].array_len == 0) {
1366            usage->comps_kept = 0;
1367            break;
1368         }
1369
1370         assert(usage->levels[i].array_len <= glsl_get_length(vec_type));
1371         if (usage->levels[i].array_len < glsl_get_length(vec_type))
1372            shrunk = true;
1373         vec_type = glsl_get_array_element(vec_type);
1374      }
1375      assert(glsl_type_is_vector_or_scalar(vec_type));
1376
1377      assert(usage->comps_kept == (usage->comps_kept & usage->all_comps));
1378      if (usage->comps_kept != usage->all_comps)
1379         shrunk = true;
1380
1381      if (usage->comps_kept == 0) {
1382         /* This variable is dead, remove it */
1383         vars_shrunk = true;
1384         exec_node_remove(&var->node);
1385         continue;
1386      }
1387
1388      if (!shrunk) {
1389         /* This variable doesn't need to be shrunk.  Remove it from the
1390          * hash table so later steps will ignore it.
1391          */
1392         _mesa_hash_table_remove_key(var_usage_map, var);
1393         continue;
1394      }
1395
1396      /* Build the new var type */
1397      unsigned new_num_comps = util_bitcount(usage->comps_kept);
1398      const struct glsl_type *new_type =
1399         glsl_vector_type(glsl_get_base_type(vec_type), new_num_comps);
1400      for (int i = usage->num_levels - 1; i >= 0; i--) {
1401         assert(usage->levels[i].array_len > 0);
1402         /* If the original type was a matrix type, we'd like to keep that so
1403          * we don't convert matrices into arrays.
1404          */
1405         if (i == usage->num_levels - 1 &&
1406             glsl_type_is_matrix(glsl_without_array(var->type)) &&
1407             new_num_comps > 1 && usage->levels[i].array_len > 1) {
1408            new_type = glsl_matrix_type(glsl_get_base_type(new_type),
1409                                        new_num_comps,
1410                                        usage->levels[i].array_len);
1411         } else {
1412            new_type = glsl_array_type(new_type, usage->levels[i].array_len, 0);
1413         }
1414      }
1415      var->type = new_type;
1416
1417      vars_shrunk = true;
1418   }
1419
1420   return vars_shrunk;
1421}
1422
1423static bool
1424vec_deref_is_oob(nir_deref_instr *deref,
1425                 struct vec_var_usage *usage)
1426{
1427   nir_deref_path path;
1428   nir_deref_path_init(&path, deref, NULL);
1429
1430   bool oob = false;
1431   for (unsigned i = 0; i < usage->num_levels; i++) {
1432      nir_deref_instr *p = path.path[i + 1];
1433      if (p->deref_type == nir_deref_type_array_wildcard)
1434         continue;
1435
1436      if (nir_src_is_const(p->arr.index) &&
1437          nir_src_as_uint(p->arr.index) >= usage->levels[i].array_len) {
1438         oob = true;
1439         break;
1440      }
1441   }
1442
1443   nir_deref_path_finish(&path);
1444
1445   return oob;
1446}
1447
1448static bool
1449vec_deref_is_dead_or_oob(nir_deref_instr *deref,
1450                         struct hash_table *var_usage_map,
1451                         nir_variable_mode modes)
1452{
1453   struct vec_var_usage *usage =
1454      get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
1455   if (!usage)
1456      return false;
1457
1458   return usage->comps_kept == 0 || vec_deref_is_oob(deref, usage);
1459}
1460
static void
shrink_vec_var_access_impl(nir_function_impl *impl,
                           struct hash_table *var_usage_map,
                           nir_variable_mode modes)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         switch (instr->type) {
         case nir_instr_type_deref: {
            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (!nir_deref_mode_may_be(deref, modes))
               break;

            /* Clean up any dead derefs we find lying around.  They may refer
             * to variables we've deleted.
             */
            if (nir_deref_instr_remove_if_unused(deref))
               break;

            /* Update the type in the deref to keep the types consistent as
             * you walk down the chain.  We don't need to check if this is one
             * of the derefs we're shrinking because this is a no-op if it
             * isn't.  The worst that could happen is that we accidentally fix
             * an invalid deref.
             */
            if (deref->deref_type == nir_deref_type_var) {
               deref->type = deref->var->type;
            } else if (deref->deref_type == nir_deref_type_array ||
                       deref->deref_type == nir_deref_type_array_wildcard) {
               nir_deref_instr *parent = nir_deref_instr_parent(deref);
               assert(glsl_type_is_array(parent->type) ||
                      glsl_type_is_matrix(parent->type));
               deref->type = glsl_get_array_element(parent->type);
            }
            break;
         }

         case nir_instr_type_intrinsic: {
            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

            /* If we have a copy whose source or destination has been deleted
             * because we determined the variable was dead, then we just
             * delete the copy instruction.  If the source variable was dead
             * then it was writing undefined garbage anyway and if it's the
             * destination variable that's dead then the write isn't needed.
             */
            if (intrin->intrinsic == nir_intrinsic_copy_deref) {
               nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
               nir_deref_instr *src = nir_src_as_deref(intrin->src[1]);
               if (vec_deref_is_dead_or_oob(dst, var_usage_map, modes) ||
                   vec_deref_is_dead_or_oob(src, var_usage_map, modes)) {
                  nir_instr_remove(&intrin->instr);
                  nir_deref_instr_remove_if_unused(dst);
                  nir_deref_instr_remove_if_unused(src);
               }
               continue;
            }

            if (intrin->intrinsic != nir_intrinsic_load_deref &&
                intrin->intrinsic != nir_intrinsic_store_deref)
               continue;

            nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
            if (!nir_deref_mode_may_be(deref, modes))
               continue;

            struct vec_var_usage *usage =
               get_vec_deref_usage(deref, var_usage_map, modes, false, NULL);
            if (!usage)
               continue;

            if (usage->comps_kept == 0 || vec_deref_is_oob(deref, usage)) {
               if (intrin->intrinsic == nir_intrinsic_load_deref) {
                  nir_ssa_def *u =
                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
                                   intrin->dest.ssa.bit_size);
                  nir_ssa_def_rewrite_uses(&intrin->dest.ssa, u);
               }
               nir_instr_remove(&intrin->instr);
               nir_deref_instr_remove_if_unused(deref);
               continue;
            }

            /* If we're not dropping any components, there's no need to
             * compact vectors.
             */
            if (usage->comps_kept == usage->all_comps)
               continue;

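            /* From here on, the access touches a vector that lost some
             * components, so we have to repack around the kept ones.  Loads
             * are narrowed and then re-expanded with undefs in the dropped
             * channels so existing users keep their swizzles; stores get
             * their source compacted and their write mask remapped.  For
             * illustration: with comps_kept = 0b0101, a vec4 load becomes a
             * 2-component load whose result is rebuilt as
             * vec4(x, undef, z, undef) before it reaches the old users.
             */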
            if (intrin->intrinsic == nir_intrinsic_load_deref) {
               b.cursor = nir_after_instr(&intrin->instr);

               nir_ssa_def *undef =
                  nir_ssa_undef(&b, 1, intrin->dest.ssa.bit_size);
               nir_ssa_def *vec_srcs[NIR_MAX_VEC_COMPONENTS];
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i))
                     vec_srcs[i] = nir_channel(&b, &intrin->dest.ssa, c++);
                  else
                     vec_srcs[i] = undef;
               }
               nir_ssa_def *vec = nir_vec(&b, vec_srcs, intrin->num_components);

               nir_ssa_def_rewrite_uses_after(&intrin->dest.ssa, vec,
                                              vec->parent_instr);

               /* The SSA def is now only used by the swizzle.  It's safe to
                * shrink the number of components.
                */
               assert(list_length(&intrin->dest.ssa.uses) == c);
               intrin->num_components = c;
               intrin->dest.ssa.num_components = c;
            } else {
               nir_component_mask_t write_mask =
                  nir_intrinsic_write_mask(intrin);

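               /* Build the compaction swizzle and remap the write mask.
                * For illustration: with comps_kept = 0b0101 and
                * write_mask = 0b1100, the swizzle comes out as {0, 2},
                * c ends up 2, and new_write_mask = 0b10 (the old
                * component 2, now living in component 1, is the only one
                * actually written).
                */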
               unsigned swizzle[NIR_MAX_VEC_COMPONENTS];
               nir_component_mask_t new_write_mask = 0;
               unsigned c = 0;
               for (unsigned i = 0; i < intrin->num_components; i++) {
                  if (usage->comps_kept & (1u << i)) {
                     swizzle[c] = i;
                     if (write_mask & (1u << i))
                        new_write_mask |= 1u << c;
                     c++;
                  }
               }

               b.cursor = nir_before_instr(&intrin->instr);

               nir_ssa_def *swizzled =
                  nir_swizzle(&b, intrin->src[1].ssa, swizzle, c);

               /* Rewrite to use the compacted source */
               nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
                                     nir_src_for_ssa(swizzled));
               nir_intrinsic_set_write_mask(intrin, new_write_mask);
               intrin->num_components = c;
            }
            break;
         }

         default:
            break;
         }
      }
   }
}

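/* Returns true if at least one variable with one of the given modes is
 * visible to this impl: shader-level variables for the non-local modes and
 * the impl's own locals for nir_var_function_temp.
 */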
static bool
function_impl_has_vars_with_modes(nir_function_impl *impl,
                                  nir_variable_mode modes)
{
   nir_shader *shader = impl->function->shader;

   if (modes & ~nir_var_function_temp) {
      nir_foreach_variable_with_modes(var, shader,
                                      modes & ~nir_var_function_temp)
         return true;
   }

   if ((modes & nir_var_function_temp) && !exec_list_is_empty(&impl->locals))
      return true;

   return false;
}

/** Attempt to shrink arrays of vectors
 *
 * This pass looks at variables which contain a vector or an array (possibly
 * multiple dimensions) of vectors and attempts to lower to a smaller vector
 * or array.  If the pass can prove that a component of a vector (or array of
 * vectors) is never really used, then that component will be removed.
 * Similarly, the pass attempts to shorten arrays based on what elements it
 * can prove are never read or never contain valid data.
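 *
 * For example (illustrative only), if only the .xz components of a
 * "vec4 arr[8]" are ever read and only arr[0] and arr[1] ever hold valid
 * data, the variable may be shrunk to a "vec2 arr[2]" with every access
 * rewritten to match.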
 */
bool
nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
{
   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);

   void *mem_ctx = ralloc_context(NULL);

   struct hash_table *var_usage_map =
      _mesa_pointer_hash_table_create(mem_ctx);

   bool has_vars_to_shrink = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      /* Don't even bother crawling the IR if we don't have any variables.
       * Given that this pass deletes any unused variables, it's likely that
       * we will be in this scenario eventually.
       */
      if (function_impl_has_vars_with_modes(function->impl, modes)) {
         has_vars_to_shrink = true;
         find_used_components_impl(function->impl, var_usage_map,
                                   modes, mem_ctx);
      }
   }
   if (!has_vars_to_shrink) {
      ralloc_free(mem_ctx);
      nir_shader_preserve_all_metadata(shader);
      return false;
   }

   bool globals_shrunk = false;
   if (modes & nir_var_shader_temp) {
      globals_shrunk = shrink_vec_var_list(&shader->variables,
                                           nir_var_shader_temp,
                                           var_usage_map);
   }

   bool progress = false;
   nir_foreach_function(function, shader) {
      if (!function->impl)
         continue;

      bool locals_shrunk = false;
      if (modes & nir_var_function_temp) {
         locals_shrunk = shrink_vec_var_list(&function->impl->locals,
                                             nir_var_function_temp,
                                             var_usage_map);
      }

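      /* If a shader_temp variable was shrunk, every impl that touches it
       * needs its accesses rewritten, not just the one where we found the
       * uses, so global progress forces a walk of this impl too.
       */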
      if (globals_shrunk || locals_shrunk) {
         shrink_vec_var_access_impl(function->impl, var_usage_map, modes);

         nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                               nir_metadata_dominance);
         progress = true;
      } else {
         nir_metadata_preserve(function->impl, nir_metadata_all);
      }
   }

   ralloc_free(mem_ctx);

   return progress;
}