/*
 * Copyright © 2015 Connor Abbott
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 */

/**
 * nir_opt_vectorize() aims to vectorize ALU instructions: identical ALU
 * instructions that only differ in which components of the same sources
 * they read are combined into a single, wider instruction.
 *
 * The default vectorization width is 4.
 * If desired, a callback function that returns the maximum vectorization
 * width per instruction can be provided.
 *
 * The maximum vectorization width must be a power of 2.
 */
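
/*
 * Illustrative sketch (a hypothetical example, not part of this pass): a
 * backend that only wants vec2 for 16-bit ALU operations could pass a
 * callback of the nir_vectorize_cb shape assumed here (instruction plus
 * user data, returning the maximum width), e.g.
 *
 *    static uint8_t
 *    vec_width_cb(const nir_instr *instr, const void *data)
 *    {
 *       if (instr->type != nir_instr_type_alu)
 *          return 1;
 *       const nir_alu_instr *alu = nir_instr_as_alu(instr);
 *       return alu->dest.dest.ssa.bit_size == 16 ? 2 : 4;
 *    }
 *
 *    nir_opt_vectorize(shader, vec_width_cb, NULL);
 */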

#include "nir.h"
#include "nir_vla.h"
#include "nir_builder.h"
#include "util/u_dynarray.h"

#define HASH(hash, data) XXH32(&data, sizeof(data), hash)

static uint32_t
hash_src(uint32_t hash, const nir_src *src)
{
   assert(src->is_ssa);
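   /* Hash every constant source identically (as NULL) so that two
    * instructions whose corresponding sources are different constants can
    * still land in the same bucket; instr_try_combine() then folds the
    * constants into one vector immediate.
    */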
   void *hash_data = nir_src_is_const(*src) ? NULL : src->ssa;

   return HASH(hash, hash_data);
}

static uint32_t
hash_alu_src(uint32_t hash, const nir_alu_src *src,
             uint32_t num_components, uint32_t max_vec)
{
   assert(!src->abs && !src->negate);

   /* Hash which max_vec-sized block of the source the swizzle reads from,
    * i.e. whether it accesses components beyond the maximum vectorization
    * width.  For example, with a maximum width of 2 for 16-bit values,
    * accesses to .x/.y hash differently from accesses to .z/.w.
    */
   uint32_t swizzle = (src->swizzle[0] & ~(max_vec - 1));
   hash = HASH(hash, swizzle);

   return hash_src(hash, &src->src);
}

static uint32_t
hash_instr(const void *data)
{
   const nir_instr *instr = (nir_instr *) data;
   assert(instr->type == nir_instr_type_alu);
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   uint32_t hash = HASH(0, alu->op);
   hash = HASH(hash, alu->dest.dest.ssa.bit_size);

   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      hash = hash_alu_src(hash, &alu->src[i],
                          alu->dest.dest.ssa.num_components,
                          instr->pass_flags);

   return hash;
}

static bool
srcs_equal(const nir_src *src1, const nir_src *src2)
{
   assert(src1->is_ssa);
   assert(src2->is_ssa);

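   /* Mirror hash_src(): any two constant sources compare equal here, which
    * lets two instructions with different constants be paired up; the
    * constants are merged into a single immediate by instr_try_combine().
    */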
   return src1->ssa == src2->ssa ||
          (nir_src_is_const(*src1) && nir_src_is_const(*src2));
}

static bool
alu_srcs_equal(const nir_alu_src *src1, const nir_alu_src *src2,
               uint32_t max_vec)
{
   assert(!src1->abs);
   assert(!src1->negate);
   assert(!src2->abs);
   assert(!src2->negate);

   uint32_t mask = ~(max_vec - 1);
   if ((src1->swizzle[0] & mask) != (src2->swizzle[0] & mask))
      return false;

   return srcs_equal(&src1->src, &src2->src);
}

static bool
instrs_equal(const void *data1, const void *data2)
{
   const nir_instr *instr1 = (nir_instr *) data1;
   const nir_instr *instr2 = (nir_instr *) data2;
   assert(instr1->type == nir_instr_type_alu);
   assert(instr2->type == nir_instr_type_alu);

   nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
   nir_alu_instr *alu2 = nir_instr_as_alu(instr2);

   if (alu1->op != alu2->op)
      return false;

   if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size)
      return false;

   for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
      if (!alu_srcs_equal(&alu1->src[i], &alu2->src[i], instr1->pass_flags))
         return false;
   }

   return true;
}

static bool
instr_can_rewrite(nir_instr *instr)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);

      /* Don't try to vectorize movs.  Either they'll be handled by copy
       * propagation, or they're actually necessary and trying to vectorize
       * them would just fight with copy propagation.
       */
      if (alu->op == nir_op_mov)
         return false;

      /* no need to hash instructions which are already vectorized */
      if (alu->dest.dest.ssa.num_components >= instr->pass_flags)
         return false;

      if (nir_op_infos[alu->op].output_size != 0)
         return false;

      for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
         if (nir_op_infos[alu->op].input_sizes[i] != 0)
            return false;

         /* Don't hash instructions whose source swizzles already cross the
          * maximum vectorization width; those are better scalarized first. */
         uint32_t mask = ~(instr->pass_flags - 1);
         for (unsigned j = 1; j < alu->dest.dest.ssa.num_components; j++) {
            if ((alu->src[i].swizzle[0] & mask) != (alu->src[i].swizzle[j] & mask))
               return false;
         }
      }

      return true;
   }

   /* TODO support phi nodes */
   default:
      break;
   }

   return false;
}

/*
 * Tries to combine two instructions whose sources are different components
 * of the same instruction into one vectorized instruction.  Note that
 * instr1 must dominate instr2.
 */
static nir_instr *
instr_try_combine(struct set *instr_set, nir_instr *instr1, nir_instr *instr2)
{
   assert(instr1->type == nir_instr_type_alu);
   assert(instr2->type == nir_instr_type_alu);
   nir_alu_instr *alu1 = nir_instr_as_alu(instr1);
   nir_alu_instr *alu2 = nir_instr_as_alu(instr2);

   assert(alu1->dest.dest.ssa.bit_size == alu2->dest.dest.ssa.bit_size);
   unsigned alu1_components = alu1->dest.dest.ssa.num_components;
   unsigned alu2_components = alu2->dest.dest.ssa.num_components;
   unsigned total_components = alu1_components + alu2_components;

   assert(instr1->pass_flags == instr2->pass_flags);
   if (total_components > instr1->pass_flags)
      return NULL;

   nir_builder b;
   nir_builder_init(&b, nir_cf_node_get_function(&instr1->block->cf_node));
   b.cursor = nir_after_instr(instr1);

   nir_alu_instr *new_alu = nir_alu_instr_create(b.shader, alu1->op);
   nir_ssa_dest_init(&new_alu->instr, &new_alu->dest.dest,
                     total_components, alu1->dest.dest.ssa.bit_size, NULL);
   new_alu->dest.write_mask = (1 << total_components) - 1;
   new_alu->instr.pass_flags = alu1->instr.pass_flags;

   /* If either channel is exact, we have to preserve it even if it's
    * not optimal for other channels.
    */
   new_alu->exact = alu1->exact || alu2->exact;

   /* If all channels don't wrap, we can say that the whole vector doesn't
    * wrap.
    */
   new_alu->no_signed_wrap = alu1->no_signed_wrap && alu2->no_signed_wrap;
   new_alu->no_unsigned_wrap = alu1->no_unsigned_wrap && alu2->no_unsigned_wrap;

   for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) {
      /* Handle the constant-merging case: the only permitted source
       * mismatch is two (possibly different) constants, which are combined
       * into a single vector immediate.
       */
      if (alu1->src[i].src.ssa != alu2->src[i].src.ssa) {
         nir_const_value *c1 = nir_src_as_const_value(alu1->src[i].src);
         nir_const_value *c2 = nir_src_as_const_value(alu2->src[i].src);
         assert(c1 && c2);
         nir_const_value value[NIR_MAX_VEC_COMPONENTS];
         unsigned bit_size = alu1->src[i].src.ssa->bit_size;

         for (unsigned j = 0; j < total_components; j++) {
            value[j].u64 = j < alu1_components ?
                              c1[alu1->src[i].swizzle[j]].u64 :
                              c2[alu2->src[i].swizzle[j - alu1_components]].u64;
         }
         nir_ssa_def *def = nir_build_imm(&b, total_components, bit_size, value);

         new_alu->src[i].src = nir_src_for_ssa(def);
         for (unsigned j = 0; j < total_components; j++)
            new_alu->src[i].swizzle[j] = j;
         continue;
      }

      new_alu->src[i].src = alu1->src[i].src;

      for (unsigned j = 0; j < alu1_components; j++)
         new_alu->src[i].swizzle[j] = alu1->src[i].swizzle[j];

      for (unsigned j = 0; j < alu2_components; j++) {
         new_alu->src[i].swizzle[j + alu1_components] =
            alu2->src[i].swizzle[j];
      }
   }

   nir_builder_instr_insert(&b, &new_alu->instr);

   /* update all ALU uses */
   nir_foreach_use_safe(src, &alu1->dest.dest.ssa) {
      nir_instr *user_instr = src->parent_instr;
      if (user_instr->type == nir_instr_type_alu) {
         /* Check if user is found in the hashset */
         struct set_entry *entry = _mesa_set_search(instr_set, user_instr);

         /* For ALU instructions, rewrite the source directly to avoid a
          * round-trip through copy propagation.
          */
         nir_instr_rewrite_src(user_instr, src,
                               nir_src_for_ssa(&new_alu->dest.dest.ssa));

         /* Rehash user if it was found in the hashset */
         if (entry && entry->key == user_instr) {
            _mesa_set_remove(instr_set, entry);
            _mesa_set_add(instr_set, user_instr);
         }
      }
   }

   nir_foreach_use_safe(src, &alu2->dest.dest.ssa) {
      if (src->parent_instr->type == nir_instr_type_alu) {
         /* For ALU instructions, rewrite the source directly to avoid a
          * round-trip through copy propagation.
          */
         nir_instr_rewrite_src(src->parent_instr, src,
                               nir_src_for_ssa(&new_alu->dest.dest.ssa));

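         /* alu2's result now lives in the upper components of new_alu, so
          * shift this use's swizzle up by alu1's component count.
          */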
         nir_alu_src *alu_src = container_of(src, nir_alu_src, src);
         nir_alu_instr *use = nir_instr_as_alu(src->parent_instr);
         unsigned components = nir_ssa_alu_instr_src_components(use, alu_src - use->src);
         for (unsigned i = 0; i < components; i++)
            alu_src->swizzle[i] += alu1_components;
      }
   }

   /* update all other uses if there are any */
   unsigned swiz[NIR_MAX_VEC_COMPONENTS];

   if (!nir_ssa_def_is_unused(&alu1->dest.dest.ssa)) {
      for (unsigned i = 0; i < alu1_components; i++)
         swiz[i] = i;
      nir_ssa_def *new_alu1 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
                                          alu1_components);
      nir_ssa_def_rewrite_uses(&alu1->dest.dest.ssa, new_alu1);
   }

   if (!nir_ssa_def_is_unused(&alu2->dest.dest.ssa)) {
      for (unsigned i = 0; i < alu2_components; i++)
         swiz[i] = i + alu1_components;
      nir_ssa_def *new_alu2 = nir_swizzle(&b, &new_alu->dest.dest.ssa, swiz,
                                          alu2_components);
      nir_ssa_def_rewrite_uses(&alu2->dest.dest.ssa, new_alu2);
   }

   nir_instr_remove(instr1);
   nir_instr_remove(instr2);

   return &new_alu->instr;
}

static struct set *
vec_instr_set_create(void)
{
   return _mesa_set_create(NULL, hash_instr, instrs_equal);
}

static void
vec_instr_set_destroy(struct set *instr_set)
{
   _mesa_set_destroy(instr_set, NULL);
}

static bool
vec_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
                             nir_vectorize_cb filter, void *data)
{
   /* Record the maximum vectorization width in the instruction's
    * pass_flags; it is used when hashing and comparing swizzles. */
   instr->pass_flags = filter ? filter(instr, data) : 4;
   assert(util_is_power_of_two_or_zero(instr->pass_flags));

   if (!instr_can_rewrite(instr))
      return false;

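   /* Look for an equivalent instruction that was recorded earlier (and
    * therefore dominates this one).  If one exists, try to fuse the pair;
    * if fusing fails or there is no match, record this instruction as a
    * candidate for later combines.
    */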
   struct set_entry *entry = _mesa_set_search(instr_set, instr);
   if (entry) {
      nir_instr *old_instr = (nir_instr *) entry->key;
      _mesa_set_remove(instr_set, entry);
      nir_instr *new_instr = instr_try_combine(instr_set, old_instr, instr);
      if (new_instr) {
         if (instr_can_rewrite(new_instr))
            _mesa_set_add(instr_set, new_instr);
         return true;
      }
   }

   _mesa_set_add(instr_set, instr);
   return false;
}

static bool
vectorize_block(nir_block *block, struct set *instr_set,
                nir_vectorize_cb filter, void *data)
{
   bool progress = false;

   nir_foreach_instr_safe(instr, block) {
      if (vec_instr_set_add_or_rewrite(instr_set, instr, filter, data))
         progress = true;
   }

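   /* Recurse into the dominance-tree children so their instructions can be
    * combined with dominating instructions recorded above.
    */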
   for (unsigned i = 0; i < block->num_dom_children; i++) {
      nir_block *child = block->dom_children[i];
      progress |= vectorize_block(child, instr_set, filter, data);
   }

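   /* On the way back up, drop this block's instructions from the set so
    * they cannot be paired with instructions they do not dominate.
    */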
   nir_foreach_instr_reverse(instr, block) {
      if (instr_can_rewrite(instr))
         _mesa_set_remove_key(instr_set, instr);
   }

   return progress;
}

static bool
nir_opt_vectorize_impl(nir_function_impl *impl,
                       nir_vectorize_cb filter, void *data)
{
   struct set *instr_set = vec_instr_set_create();

   nir_metadata_require(impl, nir_metadata_dominance);

   bool progress = vectorize_block(nir_start_block(impl), instr_set,
                                   filter, data);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   vec_instr_set_destroy(instr_set);
   return progress;
}

bool
nir_opt_vectorize(nir_shader *shader, nir_vectorize_cb filter,
                  void *data)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= nir_opt_vectorize_impl(function->impl, filter, data);
   }

   return progress;
}