/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "dxil_nir.h"

#include "nir_builder.h"
#include "nir_deref.h"
#include "nir_to_dxil.h"
#include "util/u_math.h"
#include "vulkan/vulkan_core.h"

static void
cl_type_size_align(const struct glsl_type *type, unsigned *size,
                   unsigned *align)
{
   *size = glsl_get_cl_size(type);
   *align = glsl_get_cl_alignment(type);
}

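/* Split the 32-bit components of vec32 into num_dst_comps components of
 * dst_bit_size each: 64-bit values are packed from two 32-bit slots,
 * 8/16-bit values are unpacked from a single 32-bit slot.
 */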
static void
extract_comps_from_vec32(nir_builder *b, nir_ssa_def *vec32,
                         unsigned dst_bit_size,
                         nir_ssa_def **dst_comps,
                         unsigned num_dst_comps)
{
   unsigned step = DIV_ROUND_UP(dst_bit_size, 32);
   unsigned comps_per32b = 32 / dst_bit_size;
   nir_ssa_def *tmp;

   for (unsigned i = 0; i < vec32->num_components; i += step) {
      switch (dst_bit_size) {
      case 64:
         tmp = nir_pack_64_2x32_split(b, nir_channel(b, vec32, i),
                                         nir_channel(b, vec32, i + 1));
         dst_comps[i / 2] = tmp;
         break;
      case 32:
         dst_comps[i] = nir_channel(b, vec32, i);
         break;
      case 16:
      case 8: {
         unsigned dst_offs = i * comps_per32b;

         tmp = nir_unpack_bits(b, nir_channel(b, vec32, i), dst_bit_size);
         for (unsigned j = 0; j < comps_per32b && dst_offs + j < num_dst_comps; j++)
            dst_comps[dst_offs + j] = nir_channel(b, tmp, j);
         }

         break;
      }
   }
}

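/* Opposite of extract_comps_from_vec32(): pack num_src_comps components of
 * src_bit_size each into a vector of 32-bit values.
 */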
static nir_ssa_def *
load_comps_to_vec32(nir_builder *b, unsigned src_bit_size,
                    nir_ssa_def **src_comps, unsigned num_src_comps)
{
   unsigned num_vec32comps = DIV_ROUND_UP(num_src_comps * src_bit_size, 32);
   unsigned step = DIV_ROUND_UP(src_bit_size, 32);
   unsigned comps_per32b = 32 / src_bit_size;
   nir_ssa_def *vec32comps[4];

   for (unsigned i = 0; i < num_vec32comps; i += step) {
      switch (src_bit_size) {
      case 64:
         vec32comps[i] = nir_unpack_64_2x32_split_x(b, src_comps[i / 2]);
         vec32comps[i + 1] = nir_unpack_64_2x32_split_y(b, src_comps[i / 2]);
         break;
      case 32:
         vec32comps[i] = src_comps[i];
         break;
      case 16:
      case 8: {
         unsigned src_offs = i * comps_per32b;

         vec32comps[i] = nir_u2u32(b, src_comps[src_offs]);
         for (unsigned j = 1; j < comps_per32b && src_offs + j < num_src_comps; j++) {
            nir_ssa_def *tmp = nir_ishl(b, nir_u2u32(b, src_comps[src_offs + j]),
                                           nir_imm_int(b, j * src_bit_size));
            vec32comps[i] = nir_ior(b, vec32comps[i], tmp);
         }
         break;
      }
      }
   }

   return nir_vec(b, vec32comps, num_vec32comps);
}

static nir_ssa_def *
build_load_ptr_dxil(nir_builder *b, nir_deref_instr *deref, nir_ssa_def *idx)
{
   return nir_load_ptr_dxil(b, 1, 32, &deref->dest.ssa, idx);
}

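/* Lower a load_deref on a shader_temp variable into 32-bit load_ptr_dxil
 * loads at the deref's byte offset, then repack the result to the original
 * bit size.
 */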
static bool
lower_load_deref(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_is(deref, nir_var_shader_temp))
      return false;
   nir_ssa_def *ptr = nir_u2u32(b, nir_build_deref_offset(b, deref, cl_type_size_align));
   nir_ssa_def *offset = nir_iand(b, ptr, nir_inot(b, nir_imm_int(b, 3)));

   assert(intr->dest.is_ssa);
   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned load_size = MAX2(32, bit_size);
   unsigned num_bits = num_components * bit_size;
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned comp_idx = 0;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);
   nir_ssa_def *base_idx = nir_ishr(b, offset, nir_imm_int(b, 2 /* log2(32 / 8) */));

   /* Split loads into 32-bit chunks */
   for (unsigned i = 0; i < num_bits; i += load_size) {
      unsigned subload_num_bits = MIN2(num_bits - i, load_size);
      nir_ssa_def *idx = nir_iadd(b, base_idx, nir_imm_int(b, i / 32));
      nir_ssa_def *vec32 = build_load_ptr_dxil(b, path.path[0], idx);

      if (load_size == 64) {
         idx = nir_iadd(b, idx, nir_imm_int(b, 1));
         vec32 = nir_vec2(b, vec32,
                             build_load_ptr_dxil(b, path.path[0], idx));
      }

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, ptr, nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   nir_deref_path_finish(&path);
   assert(comp_idx == num_components);
   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

static nir_ssa_def *
ubo_load_select_32b_comps(nir_builder *b, nir_ssa_def *vec32,
                          nir_ssa_def *offset, unsigned num_bytes)
{
   assert(num_bytes == 16 || num_bytes == 12 || num_bytes == 8 ||
          num_bytes == 4 || num_bytes == 3 || num_bytes == 2 ||
          num_bytes == 1);
   assert(vec32->num_components == 4);

   /* 16 and 12 byte types are always aligned on 16 bytes. */
   if (num_bytes > 8)
      return vec32;

   nir_ssa_def *comps[4];
   nir_ssa_def *cond;

   for (unsigned i = 0; i < 4; i++)
      comps[i] = nir_channel(b, vec32, i);

   /* If we have 8 bytes or less to load, select which half of the vec4
    * should be used.
    */
   cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x8)),
                                 nir_imm_int(b, 0));

   comps[0] = nir_bcsel(b, cond, comps[2], comps[0]);
   comps[1] = nir_bcsel(b, cond, comps[3], comps[1]);

   /* Thanks to the CL alignment constraints, if we want 8 bytes we're done. */
   if (num_bytes == 8)
      return nir_vec(b, comps, 2);

   /* If 4 bytes or less are needed, select which of the 32-bit components
    * should be used and return it. The sub-32-bit split is handled in
    * extract_comps_from_vec32().
    */
   cond = nir_ine(b, nir_iand(b, offset, nir_imm_int(b, 0x4)),
                                 nir_imm_int(b, 0));
   return nir_bcsel(b, cond, comps[1], comps[0]);
}

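/* Emit a UBO load of num_components x bit_size at a byte offset by issuing
 * 16-byte load_ubo_dxil (cBufferLoadLegacy) loads and extracting the
 * requested components from them.
 */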
nir_ssa_def *
build_load_ubo_dxil(nir_builder *b, nir_ssa_def *buffer,
                    nir_ssa_def *offset, unsigned num_components,
                    unsigned bit_size)
{
   nir_ssa_def *idx = nir_ushr(b, offset, nir_imm_int(b, 4));
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned num_bits = num_components * bit_size;
   unsigned comp_idx = 0;

   /* We need to split loads in 16byte chunks because that's the
    * granularity of cBufferLoadLegacy().
    */
   for (unsigned i = 0; i < num_bits; i += (16 * 8)) {
      /* For each 16byte chunk (or smaller) we generate a 32bit ubo vec
       * load.
       */
      unsigned subload_num_bits = MIN2(num_bits - i, 16 * 8);
      nir_ssa_def *vec32 =
         nir_load_ubo_dxil(b, 4, 32, buffer, nir_iadd(b, idx, nir_imm_int(b, i / (16 * 8))));

      /* First re-arrange the vec32 to account for intra 16-byte offset. */
      vec32 = ubo_load_select_32b_comps(b, vec32, offset, subload_num_bits / 8);

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, offset,
                                                      nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   assert(comp_idx == num_components);
   return nir_vec(b, comps, num_components);
}

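/* Lower load_ssbo to 32-bit loads of at most 4 components, then repack the
 * result to the original component bit size.
 */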
static bool
lower_load_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *buffer = intr->src[0].ssa;
   nir_ssa_def *offset = nir_iand(b, intr->src[1].ssa, nir_imm_int(b, ~3));
   enum gl_access_qualifier access = nir_intrinsic_access(intr);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned num_bits = num_components * bit_size;

   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   unsigned comp_idx = 0;

   /* We need to split loads in 16byte chunks because that's the optimal
    * granularity of bufferLoad(). Minimum alignment is 4byte, which saves
    * us from extra complexity to extract >= 32 bit components.
    */
   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
      /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
       * load.
       */
      unsigned subload_num_bits = MIN2(num_bits - i, 4 * 32);

      /* The number of components to load depends on the number of bytes. */
      nir_ssa_def *vec32 =
         nir_load_ssbo(b, DIV_ROUND_UP(subload_num_bits, 32), 32,
                       buffer, nir_iadd(b, offset, nir_imm_int(b, i / 8)),
                       .align_mul = 4,
                       .align_offset = 0,
                       .access = access);

      /* If we have 2 bytes or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (subload_num_bits <= 16) {
         nir_ssa_def *shift = nir_imul(b, nir_iand(b, intr->src[1].ssa, nir_imm_int(b, 3)),
                                          nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      extract_comps_from_vec32(b, vec32, bit_size, &comps[comp_idx],
                               subload_num_bits / bit_size);
      comp_idx += subload_num_bits / bit_size;
   }

   assert(comp_idx == num_components);
   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

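/* Lower store_ssbo to 32-bit stores, using the masked DXIL store intrinsic
 * for sub-32-bit writes so that neighbouring bytes are preserved.
 */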
static bool
lower_store_ssbo(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);
   assert(intr->src[2].is_ssa);

   nir_ssa_def *val = intr->src[0].ssa;
   nir_ssa_def *buffer = intr->src[1].ssa;
   nir_ssa_def *offset = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, ~3));

   unsigned bit_size = val->bit_size;
   unsigned num_components = val->num_components;
   unsigned num_bits = num_components * bit_size;

   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { 0 };
   unsigned comp_idx = 0;

   unsigned write_mask = nir_intrinsic_write_mask(intr);
   for (unsigned i = 0; i < num_components; i++)
      if (write_mask & (1 << i))
         comps[i] = nir_channel(b, val, i);

   /* We split stores in 16byte chunks because that's the optimal granularity
    * of bufferStore(). Minimum alignment is 4byte, which saves us from
    * extra complexity to store >= 32 bit components.
    */
   unsigned bit_offset = 0;
   while (true) {
      /* Skip over holes in the write mask */
      while (comp_idx < num_components && comps[comp_idx] == NULL) {
         comp_idx++;
         bit_offset += bit_size;
      }
      if (comp_idx >= num_components)
         break;

      /* For each 16byte chunk (or smaller) we generate a 32bit ssbo vec
       * store. If a component is skipped by the write mask, do a smaller
       * sub-store.
       */
      unsigned num_src_comps_stored = 0, substore_num_bits = 0;
      while (num_src_comps_stored + comp_idx < num_components &&
             substore_num_bits + bit_offset < num_bits &&
             substore_num_bits < 4 * 32 &&
             comps[comp_idx + num_src_comps_stored]) {
         ++num_src_comps_stored;
         substore_num_bits += bit_size;
      }
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, bit_offset / 8));
      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
                                               num_src_comps_stored);
      nir_intrinsic_instr *store;

      if (substore_num_bits < 32) {
         nir_ssa_def *mask = nir_imm_int(b, (1 << substore_num_bits) - 1);

         /* If we have 16 bits or less to store we need to place them
          * correctly in the u32 component. Anything greater than 16 bits
          * (including uchar3) is naturally aligned on 32bits.
          */
         if (substore_num_bits <= 16) {
            nir_ssa_def *pos = nir_iand(b, intr->src[2].ssa, nir_imm_int(b, 3));
            nir_ssa_def *shift = nir_imul_imm(b, pos, 8);

            vec32 = nir_ishl(b, vec32, shift);
            mask = nir_ishl(b, mask, shift);
         }

         store = nir_intrinsic_instr_create(b->shader,
                                            nir_intrinsic_store_ssbo_masked_dxil);
         store->src[0] = nir_src_for_ssa(vec32);
         store->src[1] = nir_src_for_ssa(nir_inot(b, mask));
         store->src[2] = nir_src_for_ssa(buffer);
         store->src[3] = nir_src_for_ssa(local_offset);
      } else {
         store = nir_intrinsic_instr_create(b->shader,
                                            nir_intrinsic_store_ssbo);
         store->src[0] = nir_src_for_ssa(vec32);
         store->src[1] = nir_src_for_ssa(buffer);
         store->src[2] = nir_src_for_ssa(local_offset);

         nir_intrinsic_set_align(store, 4, 0);
      }

      /* The number of components to store depends on the number of bits. */
      store->num_components = DIV_ROUND_UP(substore_num_bits, 32);
      nir_builder_instr_insert(b, &store->instr);
      comp_idx += num_src_comps_stored;
      bit_offset += substore_num_bits;

      if (nir_intrinsic_has_write_mask(store))
         nir_intrinsic_set_write_mask(store, (1 << store->num_components) - 1);
   }

   nir_instr_remove(&intr->instr);
   return true;
}

static void
lower_load_vec32(nir_builder *b, nir_ssa_def *index, unsigned num_comps, nir_ssa_def **comps, nir_intrinsic_op op)
{
   for (unsigned i = 0; i < num_comps; i++) {
      nir_intrinsic_instr *load =
         nir_intrinsic_instr_create(b->shader, op);

      load->num_components = 1;
      load->src[0] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
      nir_ssa_dest_init(&load->instr, &load->dest, 1, 32, NULL);
      nir_builder_instr_insert(b, &load->instr);
      comps[i] = &load->dest.ssa;
   }
}

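/* Lower shared/scratch loads to 32-bit element loads from the i32-array
 * backing store, then repack the result to the original bit size.
 */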
static bool
lower_32b_offset_load(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   unsigned bit_size = nir_dest_bit_size(intr->dest);
   unsigned num_components = nir_dest_num_components(intr->dest);
   unsigned num_bits = num_components * bit_size;

   b->cursor = nir_before_instr(&intr->instr);
   nir_intrinsic_op op = intr->intrinsic;

   assert(intr->src[0].is_ssa);
   nir_ssa_def *offset = intr->src[0].ssa;
   if (op == nir_intrinsic_load_shared) {
      offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
      op = nir_intrinsic_load_shared_dxil;
   } else {
      offset = nir_u2u32(b, offset);
      op = nir_intrinsic_load_scratch_dxil;
   }
   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];
   nir_ssa_def *comps_32bit[NIR_MAX_VEC_COMPONENTS * 2];

   /* We need to split loads in 32-bit accesses because the buffer
    * is an i32 array and DXIL does not support type casts.
    */
   unsigned num_32bit_comps = DIV_ROUND_UP(num_bits, 32);
   lower_load_vec32(b, index, num_32bit_comps, comps_32bit, op);
   unsigned num_comps_per_pass = MIN2(num_32bit_comps, 4);

   for (unsigned i = 0; i < num_32bit_comps; i += num_comps_per_pass) {
      unsigned num_vec32_comps = MIN2(num_32bit_comps - i, 4);
      unsigned num_dest_comps = num_vec32_comps * 32 / bit_size;
      nir_ssa_def *vec32 = nir_vec(b, &comps_32bit[i], num_vec32_comps);

      /* If we have 16 bits or less to load we need to adjust the u32 value so
       * we can always extract the LSB.
       */
      if (num_bits <= 16) {
         nir_ssa_def *shift =
            nir_imul(b, nir_iand(b, offset, nir_imm_int(b, 3)),
                        nir_imm_int(b, 8));
         vec32 = nir_ushr(b, vec32, shift);
      }

      /* And now comes the pack/unpack step to match the original type. */
      unsigned dest_index = i * 32 / bit_size;
      extract_comps_from_vec32(b, vec32, bit_size, &comps[dest_index], num_dest_comps);
   }

   nir_ssa_def *result = nir_vec(b, comps, num_components);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);

   return true;
}

static void
lower_store_vec32(nir_builder *b, nir_ssa_def *index, nir_ssa_def *vec32, nir_intrinsic_op op)
{
   for (unsigned i = 0; i < vec32->num_components; i++) {
      nir_intrinsic_instr *store =
         nir_intrinsic_instr_create(b->shader, op);

      store->src[0] = nir_src_for_ssa(nir_channel(b, vec32, i));
      store->src[1] = nir_src_for_ssa(nir_iadd(b, index, nir_imm_int(b, i)));
      store->num_components = 1;
      nir_builder_instr_insert(b, &store->instr);
   }
}

static void
lower_masked_store_vec32(nir_builder *b, nir_ssa_def *offset, nir_ssa_def *index,
                         nir_ssa_def *vec32, unsigned num_bits, nir_intrinsic_op op)
{
   nir_ssa_def *mask = nir_imm_int(b, (1 << num_bits) - 1);

   /* If we have 16 bits or less to store we need to place them correctly in
    * the u32 component. Anything greater than 16 bits (including uchar3) is
    * naturally aligned on 32bits.
    */
   if (num_bits <= 16) {
      nir_ssa_def *shift =
         nir_imul_imm(b, nir_iand(b, offset, nir_imm_int(b, 3)), 8);

      vec32 = nir_ishl(b, vec32, shift);
      mask = nir_ishl(b, mask, shift);
   }

   if (op == nir_intrinsic_store_shared_dxil) {
      /* Use the dedicated masked intrinsic */
      nir_store_shared_masked_dxil(b, vec32, nir_inot(b, mask), index);
   } else {
      /* For scratch, since we don't need atomics, just generate the read-modify-write in NIR */
      nir_ssa_def *load = nir_load_scratch_dxil(b, 1, 32, index);

      nir_ssa_def *new_val = nir_ior(b, vec32,
                                     nir_iand(b,
                                              nir_inot(b, mask),
                                              load));

      lower_store_vec32(b, index, new_val, op);
   }
}

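/* Lower shared/scratch stores to 32-bit element stores, falling back to a
 * masked store (or a read-modify-write for scratch) for sub-32-bit writes.
 */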
static bool
lower_32b_offset_store(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->src[0].is_ssa);
   unsigned num_components = nir_src_num_components(intr->src[0]);
   unsigned bit_size = nir_src_bit_size(intr->src[0]);
   unsigned num_bits = num_components * bit_size;

   b->cursor = nir_before_instr(&intr->instr);
   nir_intrinsic_op op = intr->intrinsic;

   nir_ssa_def *offset = intr->src[1].ssa;
   if (op == nir_intrinsic_store_shared) {
      offset = nir_iadd(b, offset, nir_imm_int(b, nir_intrinsic_base(intr)));
      op = nir_intrinsic_store_shared_dxil;
   } else {
      offset = nir_u2u32(b, offset);
      op = nir_intrinsic_store_scratch_dxil;
   }
   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS];

   unsigned comp_idx = 0;
   for (unsigned i = 0; i < num_components; i++)
      comps[i] = nir_channel(b, intr->src[0].ssa, i);

   for (unsigned i = 0; i < num_bits; i += 4 * 32) {
      /* For each 16byte chunk (or smaller) we generate 32bit scalar stores.
       */
      unsigned substore_num_bits = MIN2(num_bits - i, 4 * 32);
      nir_ssa_def *local_offset = nir_iadd(b, offset, nir_imm_int(b, i / 8));
      nir_ssa_def *vec32 = load_comps_to_vec32(b, bit_size, &comps[comp_idx],
                                               substore_num_bits / bit_size);
      nir_ssa_def *index = nir_ushr(b, local_offset, nir_imm_int(b, 2));

      /* For anything less than 32bits we need to use the masked version of the
       * intrinsic to preserve data living in the same 32bit slot.
       */
      if (num_bits < 32) {
         lower_masked_store_vec32(b, local_offset, index, vec32, num_bits, op);
      } else {
         lower_store_vec32(b, index, vec32, op);
      }

      comp_idx += substore_num_bits / bit_size;
   }

   nir_instr_remove(&intr->instr);

   return true;
}

static void
ubo_to_temp_patch_deref_mode(nir_deref_instr *deref)
{
   deref->modes = nir_var_shader_temp;
   nir_foreach_use(use_src, &deref->dest.ssa) {
      if (use_src->parent_instr->type != nir_instr_type_deref)
         continue;

      nir_deref_instr *parent = nir_instr_as_deref(use_src->parent_instr);
      ubo_to_temp_patch_deref_mode(parent);
   }
}

static void
ubo_to_temp_update_entry(nir_deref_instr *deref, struct hash_entry *he)
{
   assert(nir_deref_mode_is(deref, nir_var_mem_constant));
   assert(deref->dest.is_ssa);
   assert(he->data);

   nir_foreach_use(use_src, &deref->dest.ssa) {
      if (use_src->parent_instr->type == nir_instr_type_deref) {
         ubo_to_temp_update_entry(nir_instr_as_deref(use_src->parent_instr), he);
      } else if (use_src->parent_instr->type == nir_instr_type_intrinsic) {
         nir_intrinsic_instr *intr = nir_instr_as_intrinsic(use_src->parent_instr);
         if (intr->intrinsic != nir_intrinsic_load_deref)
            he->data = NULL;
      } else {
         he->data = NULL;
      }

      if (!he->data)
         break;
   }
}

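/* Turn constant-memory (UBO) variables that are only ever read through
 * load_deref into shader_temp variables, so the DXIL backend can treat them
 * as global constants.
 */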
bool
dxil_nir_lower_ubo_to_temp(nir_shader *nir)
{
   struct hash_table *ubo_to_temp = _mesa_pointer_hash_table_create(NULL);
   bool progress = false;

   /* First pass: collect all UBO accesses that could be turned into
    * shader temp accesses.
    */
   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (!nir_deref_mode_is(deref, nir_var_mem_constant) ||
                deref->deref_type != nir_deref_type_var)
               continue;

            struct hash_entry *he =
               _mesa_hash_table_search(ubo_to_temp, deref->var);

            if (!he)
               he = _mesa_hash_table_insert(ubo_to_temp, deref->var, deref->var);

            if (!he->data)
               continue;

            ubo_to_temp_update_entry(deref, he);
         }
      }
   }

   hash_table_foreach(ubo_to_temp, he) {
      nir_variable *var = he->data;

      if (!var)
         continue;

      /* Change the variable mode. */
      var->data.mode = nir_var_shader_temp;

      /* Make sure the variable has a name.
       * DXIL variables must have names.
       */
      if (!var->name)
         var->name = ralloc_asprintf(nir, "global_%d", exec_list_length(&nir->variables));

      progress = true;
   }
   _mesa_hash_table_destroy(ubo_to_temp, NULL);

   /* Second pass: patch all derefs that were accessing the converted UBO
    * variables.
    */
   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);
            if (nir_deref_mode_is(deref, nir_var_mem_constant) &&
                deref->deref_type == nir_deref_type_var &&
                deref->var->data.mode == nir_var_shader_temp)
               ubo_to_temp_patch_deref_mode(deref);
         }
      }
   }

   return progress;
}

static bool
lower_load_ubo(nir_builder *b, nir_intrinsic_instr *intr)
{
   assert(intr->dest.is_ssa);
   assert(intr->src[0].is_ssa);
   assert(intr->src[1].is_ssa);

   b->cursor = nir_before_instr(&intr->instr);

   nir_ssa_def *result =
      build_load_ubo_dxil(b, intr->src[0].ssa, intr->src[1].ssa,
                             nir_dest_num_components(intr->dest),
                             nir_dest_bit_size(intr->dest));

   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_loads_stores_to_dxil(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {
            case nir_intrinsic_load_deref:
               progress |= lower_load_deref(&b, intr);
               break;
            case nir_intrinsic_load_shared:
            case nir_intrinsic_load_scratch:
               progress |= lower_32b_offset_load(&b, intr);
               break;
            case nir_intrinsic_load_ssbo:
               progress |= lower_load_ssbo(&b, intr);
               break;
            case nir_intrinsic_load_ubo:
               progress |= lower_load_ubo(&b, intr);
               break;
            case nir_intrinsic_store_shared:
            case nir_intrinsic_store_scratch:
               progress |= lower_32b_offset_store(&b, intr);
               break;
            case nir_intrinsic_store_ssbo:
               progress |= lower_store_ssbo(&b, intr);
               break;
            default:
               break;
            }
         }
      }
   }

   return progress;
}

static bool
lower_shared_atomic(nir_builder *b, nir_intrinsic_instr *intr,
                    nir_intrinsic_op dxil_op)
{
   b->cursor = nir_before_instr(&intr->instr);

   assert(intr->src[0].is_ssa);
   nir_ssa_def *offset =
      nir_iadd(b, intr->src[0].ssa, nir_imm_int(b, nir_intrinsic_base(intr)));
   nir_ssa_def *index = nir_ushr(b, offset, nir_imm_int(b, 2));

   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b->shader, dxil_op);
   atomic->src[0] = nir_src_for_ssa(index);
   assert(intr->src[1].is_ssa);
   atomic->src[1] = nir_src_for_ssa(intr->src[1].ssa);
   if (dxil_op == nir_intrinsic_shared_atomic_comp_swap_dxil) {
      assert(intr->src[2].is_ssa);
      atomic->src[2] = nir_src_for_ssa(intr->src[2].ssa);
   }
   atomic->num_components = 0;
   nir_ssa_dest_init(&atomic->instr, &atomic->dest, 1, 32, NULL);

   nir_builder_instr_insert(b, &atomic->instr);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, &atomic->dest.ssa);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_atomics_to_dxil(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;
            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            switch (intr->intrinsic) {

#define ATOMIC(op)                                                            \
  case nir_intrinsic_shared_atomic_##op:                                      \
     progress |= lower_shared_atomic(&b, intr,                                \
                                     nir_intrinsic_shared_atomic_##op##_dxil); \
     break

            ATOMIC(add);
            ATOMIC(imin);
            ATOMIC(umin);
            ATOMIC(imax);
            ATOMIC(umax);
            ATOMIC(and);
            ATOMIC(or);
            ATOMIC(xor);
            ATOMIC(exchange);
            ATOMIC(comp_swap);

#undef ATOMIC
            default:
               break;
            }
         }
      }
   }

   return progress;
}

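/* Replace deref_var on SSBO variables with a deref_cast of an immediate
 * pointer value that encodes the UAV binding in the upper 32 bits.
 */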
static bool
lower_deref_ssbo(nir_builder *b, nir_deref_instr *deref)
{
   assert(nir_deref_mode_is(deref, nir_var_mem_ssbo));
   assert(deref->deref_type == nir_deref_type_var ||
          deref->deref_type == nir_deref_type_cast);
   nir_variable *var = deref->var;

   b->cursor = nir_before_instr(&deref->instr);

   if (deref->deref_type == nir_deref_type_var) {
      /* We turn all deref_var into deref_cast and build a pointer value based on
       * the var binding which encodes the UAV id.
       */
      nir_ssa_def *ptr = nir_imm_int64(b, (uint64_t)var->data.binding << 32);
      nir_deref_instr *deref_cast =
         nir_build_deref_cast(b, ptr, nir_var_mem_ssbo, deref->type,
                              glsl_get_explicit_stride(var->type));
      nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                               &deref_cast->dest.ssa);
      nir_instr_remove(&deref->instr);

      deref = deref_cast;
      return true;
   }
   return false;
}

bool
dxil_nir_lower_deref_ssbo(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_deref)
               continue;

            nir_deref_instr *deref = nir_instr_as_deref(instr);

            if (!nir_deref_mode_is(deref, nir_var_mem_ssbo) ||
                (deref->deref_type != nir_deref_type_var &&
                 deref->deref_type != nir_deref_type_cast))
               continue;

            progress |= lower_deref_ssbo(&b, deref);
         }
      }
   }

   return progress;
}

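/* Replace deref sources of ALU instructions rooted in a deref_cast with the
 * equivalent raw pointer value (cast base plus the deref's byte offset).
 */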
static bool
lower_alu_deref_srcs(nir_builder *b, nir_alu_instr *alu)
{
   const nir_op_info *info = &nir_op_infos[alu->op];
   bool progress = false;

   b->cursor = nir_before_instr(&alu->instr);

   for (unsigned i = 0; i < info->num_inputs; i++) {
      nir_deref_instr *deref = nir_src_as_deref(alu->src[i].src);

      if (!deref)
         continue;

      nir_deref_path path;
      nir_deref_path_init(&path, deref, NULL);
      nir_deref_instr *root_deref = path.path[0];
      nir_deref_path_finish(&path);

      if (root_deref->deref_type != nir_deref_type_cast)
         continue;

      nir_ssa_def *ptr =
         nir_iadd(b, root_deref->parent.ssa,
                     nir_build_deref_offset(b, deref, cl_type_size_align));
      nir_instr_rewrite_src(&alu->instr, &alu->src[i].src, nir_src_for_ssa(ptr));
      progress = true;
   }

   return progress;
}

bool
dxil_nir_opt_alu_deref_srcs(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_alu)
               continue;

            nir_alu_instr *alu = nir_instr_as_alu(instr);
            progress |= lower_alu_deref_srcs(&b, alu);
         }
      }
   }

   return progress;
}

static nir_ssa_def *
memcpy_load_deref_elem(nir_builder *b, nir_deref_instr *parent,
                       nir_ssa_def *index)
{
   nir_deref_instr *deref;

   index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
   assert(parent->deref_type == nir_deref_type_cast);
   deref = nir_build_deref_ptr_as_array(b, parent, index);

   return nir_load_deref(b, deref);
}

static void
memcpy_store_deref_elem(nir_builder *b, nir_deref_instr *parent,
                        nir_ssa_def *index, nir_ssa_def *value)
{
   nir_deref_instr *deref;

   index = nir_i2i(b, index, nir_dest_bit_size(parent->dest));
   assert(parent->deref_type == nir_deref_type_cast);
   deref = nir_build_deref_ptr_as_array(b, parent, index);
   nir_store_deref(b, deref, value, 1);
}

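/* Lower memcpy_deref to a byte-per-iteration copy loop using u8 casts of the
 * source and destination derefs.
 */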
static bool
lower_memcpy_deref(nir_builder *b, nir_intrinsic_instr *intr)
{
   nir_deref_instr *dst_deref = nir_src_as_deref(intr->src[0]);
   nir_deref_instr *src_deref = nir_src_as_deref(intr->src[1]);
   assert(intr->src[2].is_ssa);
   nir_ssa_def *num_bytes = intr->src[2].ssa;

   assert(dst_deref && src_deref);

   b->cursor = nir_after_instr(&intr->instr);

   dst_deref = nir_build_deref_cast(b, &dst_deref->dest.ssa, dst_deref->modes,
                                       glsl_uint8_t_type(), 1);
   src_deref = nir_build_deref_cast(b, &src_deref->dest.ssa, src_deref->modes,
                                       glsl_uint8_t_type(), 1);

   /*
    * We want to avoid 64b instructions, so let's assume we'll always be
    * passed a value that fits in a 32b type and truncate the 64b value.
    */
   num_bytes = nir_u2u32(b, num_bytes);

   nir_variable *loop_index_var =
     nir_local_variable_create(b->impl, glsl_uint_type(), "loop_index");
   nir_deref_instr *loop_index_deref = nir_build_deref_var(b, loop_index_var);
   nir_store_deref(b, loop_index_deref, nir_imm_int(b, 0), 1);

   nir_loop *loop = nir_push_loop(b);
   nir_ssa_def *loop_index = nir_load_deref(b, loop_index_deref);
   nir_ssa_def *cmp = nir_ige(b, loop_index, num_bytes);
   nir_if *loop_check = nir_push_if(b, cmp);
   nir_jump(b, nir_jump_break);
   nir_pop_if(b, loop_check);
   nir_ssa_def *val = memcpy_load_deref_elem(b, src_deref, loop_index);
   memcpy_store_deref_elem(b, dst_deref, loop_index, val);
   nir_store_deref(b, loop_index_deref, nir_iadd_imm(b, loop_index, 1), 1);
   nir_pop_loop(b, loop);
   nir_instr_remove(&intr->instr);
   return true;
}

bool
dxil_nir_lower_memcpy_deref(nir_shader *nir)
{
   bool progress = false;

   foreach_list_typed(nir_function, func, node, &nir->functions) {
      if (!func->is_entrypoint)
         continue;
      assert(func->impl);

      nir_builder b;
      nir_builder_init(&b, func->impl);

      nir_foreach_block(block, func->impl) {
         nir_foreach_instr_safe(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

            if (intr->intrinsic == nir_intrinsic_memcpy_deref)
               progress |= lower_memcpy_deref(&b, intr);
         }
      }
   }

   return progress;
}

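/* Replace a phi of old_bit_size with one operating at new_bit_size by
 * inserting up-casts on every source and a down-cast after the phi.
 */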
static void
cast_phi(nir_builder *b, nir_phi_instr *phi, unsigned new_bit_size)
{
   nir_phi_instr *lowered = nir_phi_instr_create(b->shader);
   int num_components = 0;
   int old_bit_size = phi->dest.ssa.bit_size;

   nir_op upcast_op = nir_type_conversion_op(nir_type_uint | old_bit_size,
                                             nir_type_uint | new_bit_size,
                                             nir_rounding_mode_undef);
   nir_op downcast_op = nir_type_conversion_op(nir_type_uint | new_bit_size,
                                               nir_type_uint | old_bit_size,
                                               nir_rounding_mode_undef);

   nir_foreach_phi_src(src, phi) {
      assert(num_components == 0 || num_components == src->src.ssa->num_components);
      num_components = src->src.ssa->num_components;

      b->cursor = nir_after_instr_and_phis(src->src.ssa->parent_instr);

      nir_ssa_def *cast = nir_build_alu(b, upcast_op, src->src.ssa, NULL, NULL, NULL);
      nir_phi_instr_add_src(lowered, src->pred, nir_src_for_ssa(cast));
   }

   nir_ssa_dest_init(&lowered->instr, &lowered->dest,
                     num_components, new_bit_size, NULL);

   b->cursor = nir_before_instr(&phi->instr);
   nir_builder_instr_insert(b, &lowered->instr);

   b->cursor = nir_after_phis(nir_cursor_current_block(b->cursor));
   nir_ssa_def *result = nir_build_alu(b, downcast_op, &lowered->dest.ssa, NULL, NULL, NULL);

   nir_ssa_def_rewrite_uses(&phi->dest.ssa, result);
   nir_instr_remove(&phi->instr);
}

static bool
upcast_phi_impl(nir_function_impl *impl, unsigned min_bit_size)
{
   nir_builder b;
   nir_builder_init(&b, impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_phi)
            continue;

         nir_phi_instr *phi = nir_instr_as_phi(instr);
         assert(phi->dest.is_ssa);

         if (phi->dest.ssa.bit_size == 1 ||
             phi->dest.ssa.bit_size >= min_bit_size)
            continue;

         cast_phi(&b, phi, min_bit_size);
         progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

bool
dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= upcast_phi_impl(function->impl, min_bit_size);
   }

   return progress;
}

struct dxil_nir_split_clip_cull_distance_params {
   nir_variable *new_var;
   nir_shader *shader;
};

/* In GLSL and SPIR-V, clip and cull distance are arrays of floats (with a limit of 8).
 * In DXIL, clip and cull distances are up to 2 float4s combined.
 * Coming from GLSL, we can request this 2 float4 format, but coming from SPIR-V,
 * we can't, and have to accept a "compact" array of scalar floats.
 *
 * To help emitting a valid input signature for this case, split the variables so that they
 * match what we need to put in the signature (e.g. { float clip[4]; float clip1; float cull[3]; })
 */
static bool
dxil_nir_split_clip_cull_distance_instr(nir_builder *b,
                                        nir_instr *instr,
                                        void *cb_data)
{
   struct dxil_nir_split_clip_cull_distance_params *params = cb_data;
   nir_variable *new_var = params->new_var;

   if (instr->type != nir_instr_type_deref)
      return false;

   nir_deref_instr *deref = nir_instr_as_deref(instr);
   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (!var ||
       var->data.location < VARYING_SLOT_CLIP_DIST0 ||
       var->data.location > VARYING_SLOT_CULL_DIST1 ||
       !var->data.compact)
      return false;

   /* The location should only be inside clip distance, because clip
    * and cull should've been merged by nir_lower_clip_cull_distance_arrays()
    */
   assert(var->data.location == VARYING_SLOT_CLIP_DIST0 ||
          var->data.location == VARYING_SLOT_CLIP_DIST1);

   /* The deref chain to the clip/cull variables should be simple, just the
    * var and an array with a constant index, otherwise more lowering/optimization
    * might be needed before this pass, e.g. copy prop, lower_io_to_temporaries,
    * split_var_copies, and/or lower_var_copies. In the case of arrayed I/O like
    * inputs to the tessellation or geometry stages, there might be a second level
    * of array index.
    */
   assert(deref->deref_type == nir_deref_type_var ||
          deref->deref_type == nir_deref_type_array);

   b->cursor = nir_before_instr(instr);
   unsigned arrayed_io_length = 0;
   const struct glsl_type *old_type = var->type;
   if (nir_is_arrayed_io(var, b->shader->info.stage)) {
      arrayed_io_length = glsl_array_size(old_type);
      old_type = glsl_get_array_element(old_type);
   }
   if (!new_var) {
      /* Update lengths for new and old vars */
      int old_length = glsl_array_size(old_type);
      int new_length = (old_length + var->data.location_frac) - 4;
      old_length -= new_length;

      /* The existing variable fits in the float4 */
      if (new_length <= 0)
         return false;

      new_var = nir_variable_clone(var, params->shader);
      nir_shader_add_variable(params->shader, new_var);
      assert(glsl_get_base_type(glsl_get_array_element(old_type)) == GLSL_TYPE_FLOAT);
      var->type = glsl_array_type(glsl_float_type(), old_length, 0);
      new_var->type = glsl_array_type(glsl_float_type(), new_length, 0);
      if (arrayed_io_length) {
         var->type = glsl_array_type(var->type, arrayed_io_length, 0);
         new_var->type = glsl_array_type(new_var->type, arrayed_io_length, 0);
      }
      new_var->data.location++;
      new_var->data.location_frac = 0;
      params->new_var = new_var;
   }

   /* Update the type for derefs of the old var */
   if (deref->deref_type == nir_deref_type_var) {
      deref->type = var->type;
      return false;
   }

   if (glsl_type_is_array(deref->type)) {
      assert(arrayed_io_length > 0);
      deref->type = glsl_get_array_element(var->type);
      return false;
   }

   assert(glsl_get_base_type(deref->type) == GLSL_TYPE_FLOAT);

   nir_const_value *index = nir_src_as_const_value(deref->arr.index);
   assert(index);

   /* Treat this array as a vector starting at the component index in location_frac,
    * so if location_frac is 1 and index is 0, then it's accessing the 'y' component
    * of the vector. If index + location_frac is >= 4, there's no component there,
    * so we need to add a new variable and adjust the index.
    */
   unsigned total_index = index->u32 + var->data.location_frac;
   if (total_index < 4)
      return false;

   nir_deref_instr *new_var_deref = nir_build_deref_var(b, new_var);
   nir_deref_instr *new_intermediate_deref = new_var_deref;
   if (arrayed_io_length) {
      nir_deref_instr *parent = nir_src_as_deref(deref->parent);
      assert(parent->deref_type == nir_deref_type_array);
      new_intermediate_deref = nir_build_deref_array(b, new_intermediate_deref, parent->arr.index.ssa);
   }
   nir_deref_instr *new_array_deref = nir_build_deref_array(b, new_intermediate_deref, nir_imm_int(b, total_index % 4));
   nir_ssa_def_rewrite_uses(&deref->dest.ssa, &new_array_deref->dest.ssa);
   return true;
}

bool
dxil_nir_split_clip_cull_distance(nir_shader *shader)
{
   struct dxil_nir_split_clip_cull_distance_params params = {
      .new_var = NULL,
      .shader = shader,
   };
   nir_shader_instructions_pass(shader,
                                dxil_nir_split_clip_cull_distance_instr,
                                nir_metadata_block_index |
                                nir_metadata_dominance |
                                nir_metadata_loop_analysis,
                                &params);
   return params.new_var != NULL;
}

static bool
dxil_nir_lower_double_math_instr(nir_builder *b,
                                 nir_instr *instr,
                                 UNUSED void *cb_data)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   /* TODO: See if we can apply this explicitly to packs/unpacks that are then
    * used as a double. As-is, if we had an app explicitly do a 64bit integer op,
    * then try to bitcast to double (not expressible in HLSL, but it is in other
    * source languages), this would unpack the integer and repack as a double, when
    * we probably want to just send the bitcast through to the backend.
    */

   b->cursor = nir_before_instr(&alu->instr);

   bool progress = false;
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) {
      if (nir_alu_type_get_base_type(nir_op_infos[alu->op].input_types[i]) == nir_type_float &&
          alu->src[i].src.ssa->bit_size == 64) {
         unsigned num_components = nir_op_infos[alu->op].input_sizes[i];
         if (!num_components)
            num_components = alu->dest.dest.ssa.num_components;
         nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
         for (unsigned c = 0; c < num_components; ++c) {
            nir_ssa_def *packed_double = nir_channel(b, alu->src[i].src.ssa, alu->src[i].swizzle[c]);
            nir_ssa_def *unpacked_double = nir_unpack_64_2x32(b, packed_double);
            components[c] = nir_pack_double_2x32_dxil(b, unpacked_double);
            alu->src[i].swizzle[c] = c;
         }
         nir_instr_rewrite_src_ssa(instr, &alu->src[i].src, nir_vec(b, components, num_components));
         progress = true;
      }
   }

   if (nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type) == nir_type_float &&
       alu->dest.dest.ssa.bit_size == 64) {
      b->cursor = nir_after_instr(&alu->instr);
      nir_ssa_def *components[NIR_MAX_VEC_COMPONENTS];
      for (unsigned c = 0; c < alu->dest.dest.ssa.num_components; ++c) {
         nir_ssa_def *packed_double = nir_channel(b, &alu->dest.dest.ssa, c);
         nir_ssa_def *unpacked_double = nir_unpack_double_2x32_dxil(b, packed_double);
         components[c] = nir_pack_64_2x32(b, unpacked_double);
      }
      nir_ssa_def *repacked_dvec = nir_vec(b, components, alu->dest.dest.ssa.num_components);
      nir_ssa_def_rewrite_uses_after(&alu->dest.dest.ssa, repacked_dvec, repacked_dvec->parent_instr);
      progress = true;
   }

   return progress;
}

bool
dxil_nir_lower_double_math(nir_shader *shader)
{
   return nir_shader_instructions_pass(shader,
                                       dxil_nir_lower_double_math_instr,
                                       nir_metadata_block_index |
                                       nir_metadata_dominance |
                                       nir_metadata_loop_analysis,
                                       NULL);
}

typedef struct {
   gl_system_value *values;
   uint32_t count;
} zero_system_values_state;

static bool
lower_system_value_to_zero_filter(const nir_instr* instr, const void* cb_state)
{
   if (instr->type != nir_instr_type_intrinsic) {
      return false;
   }

   nir_intrinsic_instr* intrin = nir_instr_as_intrinsic(instr);

   /* All the intrinsics we care about are loads */
   if (!nir_intrinsic_infos[intrin->intrinsic].has_dest)
      return false;

   assert(intrin->dest.is_ssa);

   zero_system_values_state* state = (zero_system_values_state*)cb_state;
   for (uint32_t i = 0; i < state->count; ++i) {
      gl_system_value value = state->values[i];
      nir_intrinsic_op value_op = nir_intrinsic_from_system_value(value);

      if (intrin->intrinsic == value_op) {
         return true;
      } else if (intrin->intrinsic == nir_intrinsic_load_deref) {
         nir_deref_instr* deref = nir_src_as_deref(intrin->src[0]);
         if (!nir_deref_mode_is(deref, nir_var_system_value))
            return false;

         nir_variable* var = deref->var;
         if (var->data.location == value) {
            return true;
         }
      }
   }

   return false;
}

static nir_ssa_def*
lower_system_value_to_zero_instr(nir_builder* b, nir_instr* instr, void* _state)
{
   return nir_imm_int(b, 0);
}

bool
dxil_nir_lower_system_values_to_zero(nir_shader* shader,
                                     gl_system_value* system_values,
                                     uint32_t count)
{
   zero_system_values_state state = { system_values, count };
   return nir_shader_lower_instructions(shader,
      lower_system_value_to_zero_filter,
      lower_system_value_to_zero_instr,
      &state);
}

static void
lower_load_local_group_size(nir_builder *b, nir_intrinsic_instr *intr)
{
   b->cursor = nir_after_instr(&intr->instr);

   nir_const_value v[3] = {
      nir_const_value_for_int(b->shader->info.workgroup_size[0], 32),
      nir_const_value_for_int(b->shader->info.workgroup_size[1], 32),
      nir_const_value_for_int(b->shader->info.workgroup_size[2], 32)
   };
   nir_ssa_def *size = nir_build_imm(b, 3, 32, v);
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, size);
   nir_instr_remove(&intr->instr);
}

static bool
lower_system_values_impl(nir_builder *b, nir_instr *instr, void *_state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   switch (intr->intrinsic) {
   case nir_intrinsic_load_workgroup_size:
      lower_load_local_group_size(b, intr);
      return true;
   default:
      return false;
   }
}

bool
dxil_nir_lower_system_values(nir_shader *shader)
{
   return nir_shader_instructions_pass(shader, lower_system_values_impl,
      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, NULL);
}

static const struct glsl_type *
get_bare_samplers_for_type(const struct glsl_type *type, bool is_shadow)
{
   const struct glsl_type *base_sampler_type =
      is_shadow ?
      glsl_bare_shadow_sampler_type() : glsl_bare_sampler_type();
   return glsl_type_wrap_in_arrays(base_sampler_type, type);
}

static const struct glsl_type *
get_textures_for_sampler_type(const struct glsl_type *type)
{
   return glsl_type_wrap_in_arrays(
      glsl_sampler_type_to_texture(
         glsl_without_array(type)), type);
}

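/* Make sure sampler accesses reference a bare (or bare shadow) sampler
 * variable, cloning the original variable and rewriting the deref chain or
 * binding table entry as needed.
 */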
static bool
redirect_sampler_derefs(struct nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);

   int sampler_idx = nir_tex_instr_src_index(tex, nir_tex_src_sampler_deref);
   if (sampler_idx == -1) {
      /* No sampler deref - does this instruction even need a sampler? If not,
       * sampler_index doesn't necessarily point to a sampler, so early-out.
       */
      if (!nir_tex_instr_need_sampler(tex))
         return false;

      /* No derefs but needs a sampler, must be using indices */
      nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->sampler_index);

      /* Already have a bare sampler here */
      if (bare_sampler)
         return false;

      nir_variable *old_sampler = NULL;
      nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
         if (var->data.binding <= tex->sampler_index &&
             var->data.binding + glsl_type_get_sampler_count(var->type) >
                tex->sampler_index) {

            /* Already have a bare sampler for this binding and it is of the
             * correct type, add it to the table */
            if (glsl_type_is_bare_sampler(glsl_without_array(var->type)) &&
                glsl_sampler_type_is_shadow(glsl_without_array(var->type)) ==
                   tex->is_shadow) {
               _mesa_hash_table_u64_insert(data, tex->sampler_index, var);
               return false;
            }

            old_sampler = var;
         }
      }

      assert(old_sampler);

      /* Clone the original sampler to a bare sampler of the correct type */
      bare_sampler = nir_variable_clone(old_sampler, b->shader);
      nir_shader_add_variable(b->shader, bare_sampler);

      bare_sampler->type =
         get_bare_samplers_for_type(old_sampler->type, tex->is_shadow);
      _mesa_hash_table_u64_insert(data, tex->sampler_index, bare_sampler);
      return true;
   }

   /* Using derefs means we have to rewrite the deref chain in addition to cloning */
   nir_deref_instr *final_deref = nir_src_as_deref(tex->src[sampler_idx].src);
   nir_deref_path path;
   nir_deref_path_init(&path, final_deref, NULL);

   nir_deref_instr *old_tail = path.path[0];
   assert(old_tail->deref_type == nir_deref_type_var);
   nir_variable *old_var = old_tail->var;
   if (glsl_type_is_bare_sampler(glsl_without_array(old_var->type)) &&
       glsl_sampler_type_is_shadow(glsl_without_array(old_var->type)) ==
          tex->is_shadow) {
      nir_deref_path_finish(&path);
      return false;
   }

   uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
                      old_var->data.binding;
   nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
   if (!new_var) {
      new_var = nir_variable_clone(old_var, b->shader);
      nir_shader_add_variable(b->shader, new_var);
      new_var->type =
         get_bare_samplers_for_type(old_var->type, tex->is_shadow);
      _mesa_hash_table_u64_insert(data, var_key, new_var);
   }

   b->cursor = nir_after_instr(&old_tail->instr);
   nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);

   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
   }

   nir_deref_path_finish(&path);
   nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[sampler_idx].src, &new_tail->dest.ssa);
   return true;
}

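/* Texture half of dxil_nir_split_typed_samplers: redirect texture sources that
 * still point at combined-sampler variables to texture-typed clones, again
 * either by ensuring a suitably typed variable exists at the binding (index
 * case) or by rebuilding the deref chain on a clone (deref case).
 */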
static bool
redirect_texture_derefs(struct nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_tex)
      return false;

   nir_tex_instr *tex = nir_instr_as_tex(instr);

   int texture_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_deref);
   if (texture_idx == -1) {
      /* No derefs, must be using indices */
      nir_variable *bare_sampler = _mesa_hash_table_u64_search(data, tex->texture_index);

      /* Already have a texture here */
      if (bare_sampler)
         return false;

      nir_variable *typed_sampler = NULL;
      nir_foreach_variable_with_modes(var, b->shader, nir_var_uniform) {
         if (var->data.binding <= tex->texture_index &&
             var->data.binding + glsl_type_get_texture_count(var->type) > tex->texture_index) {
            /* Already have a texture for this binding, add it to the table */
            _mesa_hash_table_u64_insert(data, tex->texture_index, var);
            return false;
         }

         if (var->data.binding <= tex->texture_index &&
             var->data.binding + glsl_type_get_sampler_count(var->type) > tex->texture_index &&
             !glsl_type_is_bare_sampler(glsl_without_array(var->type))) {
            typed_sampler = var;
         }
      }

      /* Clone the typed sampler to a texture and we're done */
      assert(typed_sampler);
      bare_sampler = nir_variable_clone(typed_sampler, b->shader);
      bare_sampler->type = get_textures_for_sampler_type(typed_sampler->type);
      nir_shader_add_variable(b->shader, bare_sampler);
      _mesa_hash_table_u64_insert(data, tex->texture_index, bare_sampler);
      return true;
   }

   /* Using derefs means we have to rewrite the deref chain in addition to cloning */
   nir_deref_instr *final_deref = nir_src_as_deref(tex->src[texture_idx].src);
   nir_deref_path path;
   nir_deref_path_init(&path, final_deref, NULL);

   nir_deref_instr *old_tail = path.path[0];
   assert(old_tail->deref_type == nir_deref_type_var);
   nir_variable *old_var = old_tail->var;
   if (glsl_type_is_texture(glsl_without_array(old_var->type)) ||
       glsl_type_is_image(glsl_without_array(old_var->type))) {
      nir_deref_path_finish(&path);
      return false;
   }

   uint64_t var_key = ((uint64_t)old_var->data.descriptor_set << 32) |
                      old_var->data.binding;
   nir_variable *new_var = _mesa_hash_table_u64_search(data, var_key);
   if (!new_var) {
      new_var = nir_variable_clone(old_var, b->shader);
      new_var->type = get_textures_for_sampler_type(old_var->type);
      nir_shader_add_variable(b->shader, new_var);
      _mesa_hash_table_u64_insert(data, var_key, new_var);
   }

   b->cursor = nir_after_instr(&old_tail->instr);
   nir_deref_instr *new_tail = nir_build_deref_var(b, new_var);

   for (unsigned i = 1; path.path[i]; ++i) {
      b->cursor = nir_after_instr(&path.path[i]->instr);
      new_tail = nir_build_deref_follower(b, new_tail, path.path[i]);
   }

   nir_deref_path_finish(&path);
   nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[texture_idx].src, &new_tail->dest.ssa);

   return true;
}

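/* GLSL-style combined samplers have to be split for DXIL, where textures
 * (SRVs) and samplers are distinct resource classes.  Conceptually this turns
 * something like
 *
 *    uniform sampler2D tex;
 *
 * into a texture2D plus a bare sampler at the same binding, with all tex
 * instructions redirected to the new variables.  The same u64 hash table is
 * reused (and cleared in between) to de-duplicate clones in both sub-passes.
 */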
bool
dxil_nir_split_typed_samplers(nir_shader *nir)
{
   struct hash_table_u64 *hash_table = _mesa_hash_table_u64_create(NULL);

   bool progress = nir_shader_instructions_pass(nir, redirect_sampler_derefs,
      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);

   _mesa_hash_table_u64_clear(hash_table);

   progress |= nir_shader_instructions_pass(nir, redirect_texture_derefs,
      nir_metadata_block_index | nir_metadata_dominance | nir_metadata_loop_analysis, hash_table);

   _mesa_hash_table_u64_destroy(hash_table);
   return progress;
}


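/* Boolean inputs (front-face and bool-typed shader inputs) are loaded as
 * 32-bit integers and converted back to 1-bit booleans with i2b, since NIR's
 * 1-bit bool representation does not map cleanly onto a DXIL input.
 */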
static bool
lower_bool_input_filter(const nir_instr *instr,
                        UNUSED const void *_options)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   if (intr->intrinsic == nir_intrinsic_load_front_face)
      return true;

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);
      return var->data.mode == nir_var_shader_in &&
             glsl_get_base_type(var->type) == GLSL_TYPE_BOOL;
   }

   return false;
}

static nir_ssa_def *
lower_bool_input_impl(nir_builder *b, nir_instr *instr,
                      UNUSED void *_options)
{
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->intrinsic == nir_intrinsic_load_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
      nir_variable *var = nir_deref_instr_get_variable(deref);

      /* rewrite var->type */
      var->type = glsl_vector_type(GLSL_TYPE_UINT,
                                   glsl_get_vector_elements(var->type));
      deref->type = var->type;
   }

   intr->dest.ssa.bit_size = 32;
   return nir_i2b1(b, &intr->dest.ssa);
}

bool
dxil_nir_lower_bool_input(struct nir_shader *s)
{
   return nir_shader_lower_instructions(s, lower_bool_input_filter,
                                        lower_bool_input_impl, NULL);
}

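/* Replace front-face, instance-ID and vertex-ID system-value loads with plain
 * load_input intrinsics reading the corresponding variable in `sysval_vars`
 * (indexed by gl_system_value), which the caller is expected to have set up.
 */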
static bool
lower_sysval_to_load_input_impl(nir_builder *b, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   gl_system_value sysval = SYSTEM_VALUE_MAX;
   switch (intr->intrinsic) {
   case nir_intrinsic_load_front_face:
      sysval = SYSTEM_VALUE_FRONT_FACE;
      break;
   case nir_intrinsic_load_instance_id:
      sysval = SYSTEM_VALUE_INSTANCE_ID;
      break;
   case nir_intrinsic_load_vertex_id_zero_base:
      sysval = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
      break;
   default:
      return false;
   }

   nir_variable **sysval_vars = (nir_variable **)data;
   nir_variable *var = sysval_vars[sysval];
   assert(var);

   b->cursor = nir_before_instr(instr);
   nir_ssa_def *result = nir_build_load_input(b, intr->dest.ssa.num_components, intr->dest.ssa.bit_size, nir_imm_int(b, 0),
      .base = var->data.driver_location, .dest_type = nir_get_nir_type_for_glsl_type(var->type));
   nir_ssa_def_rewrite_uses(&intr->dest.ssa, result);
   return true;
}

bool
dxil_nir_lower_sysval_to_load_input(nir_shader *s, nir_variable **sysval_vars)
{
   return nir_shader_instructions_pass(s, lower_sysval_to_load_input_impl,
      nir_metadata_block_index | nir_metadata_dominance, sysval_vars);
}

/* Comparison function to sort IO values so that normal varyings come first,
 * then system values, and then system-generated values.
 */
static int
variable_location_cmp(const nir_variable* a, const nir_variable* b)
{
   // Sort by stream, driver_location, location, location_frac, then index
   unsigned a_location = a->data.location;
   if (a_location >= VARYING_SLOT_PATCH0)
      a_location -= VARYING_SLOT_PATCH0;
   unsigned b_location = b->data.location;
   if (b_location >= VARYING_SLOT_PATCH0)
      b_location -= VARYING_SLOT_PATCH0;
   unsigned a_stream = a->data.stream & ~NIR_STREAM_PACKED;
   unsigned b_stream = b->data.stream & ~NIR_STREAM_PACKED;
   return a_stream != b_stream ?
            a_stream - b_stream :
            a->data.driver_location != b->data.driver_location ?
               a->data.driver_location - b->data.driver_location :
               a_location != b_location ?
                  a_location - b_location :
                  a->data.location_frac != b->data.location_frac ?
                     a->data.location_frac - b->data.location_frac :
                     a->data.index - b->data.index;
}

/* Order varyings according to driver location */
uint64_t
dxil_sort_by_driver_location(nir_shader* s, nir_variable_mode modes)
{
   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      result |= 1ull << var->data.location;
   }
   return result;
}

/* Sort PS outputs so that color outputs come first */
void
dxil_sort_ps_outputs(nir_shader* s)
{
   nir_foreach_variable_with_modes_safe(var, s, nir_var_shader_out) {
      /* We use driver_location as a temporary sort key here to avoid
       * introducing a new struct or member variable. The real driver
       * location is assigned below, after sorting. */
      switch (var->data.location) {
      case FRAG_RESULT_DEPTH:
         var->data.driver_location = 1;
         break;
      case FRAG_RESULT_STENCIL:
         var->data.driver_location = 2;
         break;
      case FRAG_RESULT_SAMPLE_MASK:
         var->data.driver_location = 3;
         break;
      default:
         var->data.driver_location = 0;
      }
   }

   nir_sort_variables_with_modes(s, variable_location_cmp,
                                 nir_var_shader_out);

   unsigned driver_loc = 0;
   nir_foreach_variable_with_modes(var, s, nir_var_shader_out) {
      var->data.driver_location = driver_loc++;
   }
}

/* Order inter-stage values so that normal varyings come first, then system
 * values, and then system-generated values.
 */
uint64_t
dxil_reassign_driver_locations(nir_shader* s, nir_variable_mode modes,
   uint64_t other_stage_mask)
{
   nir_foreach_variable_with_modes_safe(var, s, modes) {
      /* We use driver_location as a temporary sort key here to avoid
       * introducing a new struct or member variable. The real driver
       * location is assigned below, after sorting. */
      var->data.driver_location = nir_var_to_dxil_sysvalue_type(var, other_stage_mask);
   }

   nir_sort_variables_with_modes(s, variable_location_cmp, modes);

   uint64_t result = 0;
   unsigned driver_loc = 0, driver_patch_loc = 0;
   nir_foreach_variable_with_modes(var, s, modes) {
      if (var->data.location < 64)
         result |= 1ull << var->data.location;
      /* Overlap patches with non-patch */
      var->data.driver_location = var->data.patch ?
         driver_patch_loc++ : driver_loc++;
   }
   return result;
}

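/* A UBO declared as an array of one element can only legally be accessed at
 * index 0; any other index is out of bounds and therefore undefined, so a
 * dynamic index can simply be hard-coded to 0.  Rewrite such
 * vulkan_resource_index intrinsics so the descriptor access becomes static.
 */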
static bool
lower_ubo_array_one_to_static(struct nir_builder *b, nir_instr *inst,
                              void *cb_data)
{
   if (inst->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(inst);

   if (intrin->intrinsic != nir_intrinsic_load_vulkan_descriptor)
      return false;

   nir_variable *var =
      nir_get_binding_variable(b->shader, nir_chase_binding(intrin->src[0]));

   if (!var)
      return false;

   if (!glsl_type_is_array(var->type) || glsl_array_size(var->type) != 1)
      return false;

   nir_intrinsic_instr *index = nir_src_as_intrinsic(intrin->src[0]);
   /* We currently do not support reindex */
   assert(index && index->intrinsic == nir_intrinsic_vulkan_resource_index);

   if (nir_src_is_const(index->src[0]) && nir_src_as_uint(index->src[0]) == 0)
      return false;

   if (nir_intrinsic_desc_type(index) != VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER)
      return false;

   b->cursor = nir_instr_remove(&index->instr);

   // Indexing out of bounds on an array of UBOs is undefined behavior, so we
   // just hardcode the index to 0.
   uint8_t bit_size = index->dest.ssa.bit_size;
   nir_ssa_def *zero = nir_imm_intN_t(b, 0, bit_size);
   nir_ssa_def *dest =
      nir_vulkan_resource_index(b, index->num_components, bit_size, zero,
                                .desc_set = nir_intrinsic_desc_set(index),
                                .binding = nir_intrinsic_binding(index),
                                .desc_type = nir_intrinsic_desc_type(index));

   nir_ssa_def_rewrite_uses(&index->dest.ssa, dest);

   return true;
}

bool
dxil_nir_lower_ubo_array_one_to_static(nir_shader *s)
{
   bool progress = nir_shader_instructions_pass(
      s, lower_ubo_array_one_to_static, nir_metadata_none, NULL);

   return progress;
}

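/* Filter for dxil_nir_lower_fquantize2f16: match fquantize2f16 ALU ops so they
 * can be expanded below into explicit range, infinity and flush-to-zero
 * handling plus mantissa truncation.
 */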
static bool
is_fquantize2f16(const nir_instr *instr, const void *data)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);
   return alu->op == nir_op_fquantize2f16;
}

static nir_ssa_def *
lower_fquantize2f16(struct nir_builder *b, nir_instr *instr, void *data)
{
   /*
    * SpvOpQuantizeToF16 documentation says:
    *
    * "
    * If Value is an infinity, the result is the same infinity.
    * If Value is a NaN, the result is a NaN, but not necessarily the same NaN.
    * If Value is positive with a magnitude too large to represent as a 16-bit
    * floating-point value, the result is positive infinity. If Value is negative
    * with a magnitude too large to represent as a 16-bit floating-point value,
    * the result is negative infinity. If the magnitude of Value is too small to
    * represent as a normalized 16-bit floating-point value, the result may be
    * either +0 or -0.
    * "
    *
    * which we turn into:
    *
    *   if (val < MIN_FLOAT16)
    *      return -INFINITY;
    *   else if (val > MAX_FLOAT16)
    *      return INFINITY;
    *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) != 0)
    *      return -0.0f;
    *   else if (fabs(val) < SMALLEST_NORMALIZED_FLOAT16 && sign(val) == 0)
    *      return +0.0f;
    *   else
    *      return round(val);
    */
   nir_alu_instr *alu = nir_instr_as_alu(instr);
   nir_ssa_def *src =
      nir_ssa_for_src(b, alu->src[0].src, nir_src_num_components(alu->src[0].src));

   nir_ssa_def *neg_inf_cond =
      nir_flt(b, src, nir_imm_float(b, -65504.0f));
   nir_ssa_def *pos_inf_cond =
      nir_flt(b, nir_imm_float(b, 65504.0f), src);
   nir_ssa_def *zero_cond =
      nir_flt(b, nir_fabs(b, src), nir_imm_float(b, ldexpf(1.0, -14)));
   nir_ssa_def *zero = nir_iand_imm(b, src, 1 << 31);
   nir_ssa_def *round = nir_iand_imm(b, src, ~BITFIELD_MASK(13));

   nir_ssa_def *res =
      nir_bcsel(b, neg_inf_cond, nir_imm_float(b, -INFINITY), round);
   res = nir_bcsel(b, pos_inf_cond, nir_imm_float(b, INFINITY), res);
   res = nir_bcsel(b, zero_cond, zero, res);
   return res;
}

bool
dxil_nir_lower_fquantize2f16(nir_shader *s)
{
   return nir_shader_lower_instructions(s, is_fquantize2f16, lower_fquantize2f16, NULL);
}

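/* Helper for dxil_nir_fix_io_uint_type: once a variable has been retyped from
 * int to uint, every nir_deref_type_var deref of that variable must pick up
 * the new type as well.
 */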
static bool
fix_io_uint_deref_types(struct nir_builder *builder, nir_instr *instr, void *data)
{
   if (instr->type != nir_instr_type_deref)
      return false;

   nir_deref_instr *deref = nir_instr_as_deref(instr);
   nir_variable *var =
      deref->deref_type == nir_deref_type_var ? deref->var : NULL;

   if (var == data) {
      deref->type = var->type;
      return true;
   }

   return false;
}

static bool
fix_io_uint_type(nir_shader *s, nir_variable_mode modes, int slot)
{
   nir_variable *fixed_var = NULL;
   nir_foreach_variable_with_modes(var, s, modes) {
      if (var->data.location == slot) {
         if (var->type == glsl_uint_type())
            return false;

         assert(var->type == glsl_int_type());
         var->type = glsl_uint_type();
         fixed_var = var;
         break;
      }
   }

   assert(fixed_var);

   return nir_shader_instructions_pass(s, fix_io_uint_deref_types,
                                       nir_metadata_all, fixed_var);
}

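/* For the IO slots given in in_mask/out_mask the caller expects a uint-typed
 * variable.  Retype matching int-typed inputs/outputs to uint and patch up
 * their derefs.
 */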
bool
dxil_nir_fix_io_uint_type(nir_shader *s, uint64_t in_mask, uint64_t out_mask)
{
   if (!(s->info.outputs_written & out_mask) &&
       !(s->info.inputs_read & in_mask))
      return false;

   bool progress = false;

   while (in_mask) {
      int slot = u_bit_scan64(&in_mask);
      progress |= (s->info.inputs_read & (1ull << slot)) &&
                  fix_io_uint_type(s, nir_var_shader_in, slot);
   }

   while (out_mask) {
      int slot = u_bit_scan64(&out_mask);
      progress |= (s->info.outputs_written & (1ull << slot)) &&
                  fix_io_uint_type(s, nir_var_shader_out, slot);
   }

   return progress;
}

struct remove_after_discard_state {
   struct nir_block *active_block;
};

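/* First half of dxil_nir_lower_discard_and_terminate: once a discard/terminate
 * is seen, remember its block and drop every instruction that follows it in
 * that block, since lower_kill below replaces the discard with a demote
 * followed by a return.
 */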
static bool
remove_after_discard(struct nir_builder *builder, nir_instr *instr,
                     void *cb_data)
{
   struct remove_after_discard_state *state = cb_data;
   if (instr->block == state->active_block) {
      nir_instr_remove_v(instr);
      return true;
   }

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->intrinsic != nir_intrinsic_discard &&
       intr->intrinsic != nir_intrinsic_terminate &&
       intr->intrinsic != nir_intrinsic_discard_if &&
       intr->intrinsic != nir_intrinsic_terminate_if)
      return false;

   state->active_block = instr->block;

   return false;
}

static bool
lower_kill(struct nir_builder *builder, nir_instr *instr, void *_cb_data)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   if (intr->intrinsic != nir_intrinsic_discard &&
       intr->intrinsic != nir_intrinsic_terminate &&
       intr->intrinsic != nir_intrinsic_discard_if &&
       intr->intrinsic != nir_intrinsic_terminate_if)
      return false;

   builder->cursor = nir_instr_remove(instr);
   if (intr->intrinsic == nir_intrinsic_discard ||
       intr->intrinsic == nir_intrinsic_terminate) {
      nir_demote(builder);
   } else {
      assert(intr->src[0].is_ssa);
      nir_demote_if(builder, intr->src[0].ssa);
   }

   nir_jump(builder, nir_jump_return);

   return true;
}

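/* Lower discard/terminate to demote (or demote_if) followed by a return, after
 * stripping the now-unreachable instructions that trailed the discard.  Only
 * valid once all functions have been inlined into the entrypoint.
 */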
bool
dxil_nir_lower_discard_and_terminate(nir_shader *s)
{
   if (s->info.stage != MESA_SHADER_FRAGMENT)
      return false;

   // This pass only works if all functions have been inlined
   assert(exec_list_length(&s->functions) == 1);
   struct remove_after_discard_state state;
   state.active_block = NULL;
   nir_shader_instructions_pass(s, remove_after_discard, nir_metadata_none,
                                &state);
   return nir_shader_instructions_pass(s, lower_kill, nir_metadata_none,
                                       NULL);
}

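/* Pad partial stores to VARYING_SLOT_POS so all four components of the
 * position output are always written: e.g. a two-component store with write
 * mask 0x3 becomes a vec4 store of (x, y, 0, 0) with write mask 0xf.
 */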
static bool
update_writes(struct nir_builder *b, nir_instr *instr, void *_state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   if (intr->intrinsic != nir_intrinsic_store_output)
      return false;

   nir_io_semantics io = nir_intrinsic_io_semantics(intr);
   if (io.location != VARYING_SLOT_POS)
      return false;

   nir_ssa_def *src = intr->src[0].ssa;
   unsigned write_mask = nir_intrinsic_write_mask(intr);
   if (src->num_components == 4 && write_mask == 0xf)
      return false;

   b->cursor = nir_before_instr(instr);
   unsigned first_comp = nir_intrinsic_component(intr);
   nir_ssa_def *channels[4] = { NULL, NULL, NULL, NULL };
   assert(first_comp + src->num_components <= ARRAY_SIZE(channels));
   for (unsigned i = 0; i < src->num_components; ++i)
      if (write_mask & (1 << i))
         channels[i + first_comp] = nir_channel(b, src, i);
   for (unsigned i = 0; i < 4; ++i)
      if (!channels[i])
         channels[i] = nir_imm_intN_t(b, 0, src->bit_size);

   nir_instr_rewrite_src_ssa(instr, &intr->src[0], nir_vec(b, channels, 4));
   nir_intrinsic_set_component(intr, 0);
   nir_intrinsic_set_write_mask(intr, 0xf);
   return true;
}

bool
dxil_nir_ensure_position_writes(nir_shader *s)
{
   if (s->info.stage != MESA_SHADER_VERTEX &&
       s->info.stage != MESA_SHADER_GEOMETRY &&
       s->info.stage != MESA_SHADER_TESS_EVAL)
      return false;
   if ((s->info.outputs_written & VARYING_BIT_POS) == 0)
      return false;

   return nir_shader_instructions_pass(s, update_writes,
                                       nir_metadata_block_index | nir_metadata_dominance,
                                       NULL);
}
