/*
 * Copyright (C) 2020 Google, Inc.
 * Copyright (C) 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

/**
 * Return the intrinsic if it matches the mask in "modes", else return NULL.
 */
static nir_intrinsic_instr *
get_io_intrinsic(nir_instr *instr, nir_variable_mode modes,
                 nir_variable_mode *out_mode)
{
   if (instr->type != nir_instr_type_intrinsic)
      return NULL;

   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);

   switch (intr->intrinsic) {
   case nir_intrinsic_load_input:
   case nir_intrinsic_load_input_vertex:
   case nir_intrinsic_load_interpolated_input:
   case nir_intrinsic_load_per_vertex_input:
      *out_mode = nir_var_shader_in;
      return modes & nir_var_shader_in ? intr : NULL;
   case nir_intrinsic_load_output:
   case nir_intrinsic_load_per_vertex_output:
   case nir_intrinsic_store_output:
   case nir_intrinsic_store_per_vertex_output:
      *out_mode = nir_var_shader_out;
      return modes & nir_var_shader_out ? intr : NULL;
   default:
      return NULL;
   }
}

/**
 * Recompute the IO "base" indices from scratch, using IO locations to assign
 * new bases. This removes holes and fixes bases that became incorrect after
 * IO locations changed. The resulting mapping from locations to bases is
 * monotonically increasing.
 */
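/* For example, if a shader only reads inputs at VARYING_SLOT_VAR1 and
 * VARYING_SLOT_VAR4, their loads are assigned base 0 and base 1,
 * respectively, because BITSET_PREFIX_SUM counts how many used locations
 * precede each one.
 */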
bool
nir_recompute_io_bases(nir_shader *nir, nir_variable_mode modes)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   BITSET_DECLARE(inputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_DECLARE(outputs, NUM_TOTAL_VARYING_SLOTS);
   BITSET_ZERO(inputs);
   BITSET_ZERO(outputs);

   /* Gather the bitmasks of used locations. */
   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
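         /* Mediump varyings are packed two per 32-bit slot, so a range of
          * 16-bit half-slots consumes only half as many bases (rounded up,
          * accounting for a high-half start).
          */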
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

         if (mode == nir_var_shader_in) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(inputs, sem.location + i);
         } else if (!sem.dual_source_blend_index) {
            for (unsigned i = 0; i < num_slots; i++)
               BITSET_SET(outputs, sem.location + i);
         }
      }
   }

   /* Renumber bases. */
   bool changed = false;

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         unsigned num_slots = sem.num_slots;
         if (sem.medium_precision)
            num_slots = (num_slots + sem.high_16bits + 1) / 2;

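         /* The new base is the number of used locations below this one.
          * The dual-source blend output is placed after all other outputs.
          */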
         if (mode == nir_var_shader_in) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(inputs, sem.location));
         } else if (sem.dual_source_blend_index) {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, NUM_TOTAL_VARYING_SLOTS));
         } else {
            nir_intrinsic_set_base(intr,
                                   BITSET_PREFIX_SUM(outputs, sem.location));
         }
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

/**
 * Lower mediump inputs and/or outputs to 16 bits.
 *
 * \param modes            Whether to lower inputs, outputs, or both.
 * \param varying_mask     Selects which varyings to lower; a set bit means
 *    lower that location. (VS inputs, FS outputs, and patch varyings ignore
 *    this mask.)
 * \param use_16bit_slots  Remap lowered slots to VARYING_SLOT_VARn_16BIT.
 */
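/* A hypothetical driver call that lowers all mediump generic varyings written
 * by a vertex shader and packs them into 16-bit slots:
 *
 *    nir_lower_mediump_io(nir, nir_var_shader_out,
 *                         BITFIELD64_RANGE(VARYING_SLOT_VAR0, 32), true);
 */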
bool
nir_lower_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     uint64_t varying_mask, bool use_16bit_slots)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
         nir_ssa_def *(*convert)(nir_builder *, nir_ssa_def *);
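         /* VS inputs and FS outputs are not varyings; everything else
          * handled here is an inter-stage varying.
          */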
         bool is_varying = !(nir->info.stage == MESA_SHADER_VERTEX &&
                             mode == nir_var_shader_in) &&
                           !(nir->info.stage == MESA_SHADER_FRAGMENT &&
                             mode == nir_var_shader_out);

         if (!sem.medium_precision ||
             (is_varying && sem.location <= VARYING_SLOT_VAR31 &&
              !(varying_mask & BITFIELD64_BIT(sem.location))))
            continue; /* can't lower */

         if (nir_intrinsic_has_src_type(intr)) {
            /* Stores. */
            nir_alu_type type = nir_intrinsic_src_type(intr);

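            /* f2fmp/i2imp behave like f2f16/i2i16 but mark the conversion as
             * mediump, which later optimization passes may fold away.
             */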
            switch (type) {
            case nir_type_float32:
               convert = nir_f2fmp;
               break;
            case nir_type_int32:
            case nir_type_uint32:
               convert = nir_i2imp;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit store into a 16-bit store. */
            b.cursor = nir_before_instr(&intr->instr);
            nir_instr_rewrite_src_ssa(&intr->instr, &intr->src[0],
                                      convert(&b, intr->src[0].ssa));
            nir_intrinsic_set_src_type(intr, (type & ~32) | 16);
         } else {
            /* Loads. */
            nir_alu_type type = nir_intrinsic_dest_type(intr);

            switch (type) {
            case nir_type_float32:
               convert = nir_f2f32;
               break;
            case nir_type_int32:
               convert = nir_i2i32;
               break;
            case nir_type_uint32:
               convert = nir_u2u32;
               break;
            default:
               continue; /* already lowered? */
            }

            /* Convert the 32-bit load into a 16-bit load. */
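            /* The load now returns 16 bits, so insert a conversion back to
             * 32 bits and rewrite all other users to consume it; their bit
             * sizes stay unchanged.
             */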
            b.cursor = nir_after_instr(&intr->instr);
            intr->dest.ssa.bit_size = 16;
            nir_intrinsic_set_dest_type(intr, (type & ~32) | 16);
            nir_ssa_def *dst = convert(&b, &intr->dest.ssa);
            nir_ssa_def_rewrite_uses_after(&intr->dest.ssa, dst,
                                           dst->parent_instr);
         }

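         /* Pack two lowered varyings into each *_16BIT slot: even VARn go to
          * the low half and odd VARn to the high half.
          */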
         if (use_16bit_slots && is_varying &&
             sem.location >= VARYING_SLOT_VAR0 &&
             sem.location <= VARYING_SLOT_VAR31) {
            unsigned index = sem.location - VARYING_SLOT_VAR0;

            sem.location = VARYING_SLOT_VAR0_16BIT + index / 2;
            sem.high_16bits = index % 2;
            nir_intrinsic_set_io_semantics(intr, sem);
         }
         changed = true;
      }
   }

   if (changed && use_16bit_slots)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

/**
 * Set the mediump precision bit on shader inputs and outputs whose mode is
 * included in the "modes" mask. Non-generic varyings (which GLES3 doesn't
 * have) are ignored. The "types" mask selects which base types to affect,
 * e.g. (nir_type_float | nir_type_int).
 */
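/* For example, a driver that wants all float FS outputs treated as mediump
 * might call (hypothetically):
 *
 *    nir_force_mediump_io(nir, nir_var_shader_out, nir_type_float);
 */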
bool
nir_force_mediump_io(nir_shader *nir, nir_variable_mode modes,
                     nir_alu_type types)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_alu_type type;
         if (nir_intrinsic_has_src_type(intr))
            type = nir_intrinsic_src_type(intr);
         else
            type = nir_intrinsic_dest_type(intr);
         if (!(type & types))
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (nir->info.stage == MESA_SHADER_FRAGMENT &&
             mode == nir_var_shader_out) {
            /* Only accept color outputs; skip depth, stencil, and
             * sample mask.
             */
            if (sem.location < FRAG_RESULT_DATA0 &&
                sem.location != FRAG_RESULT_COLOR)
               continue;
         } else if (nir->info.stage == MESA_SHADER_VERTEX &&
                    mode == nir_var_shader_in) {
            /* Accept all VS inputs. */
         } else {
            /* Only accept generic varyings. */
            if (sem.location < VARYING_SLOT_VAR0 ||
                sem.location > VARYING_SLOT_VAR31)
               continue;
         }

         sem.medium_precision = 1;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

/**
 * Remap 16-bit varying slots to the original 32-bit varying slots.
 * This only changes IO semantics and bases.
 */
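/* For example, VARYING_SLOT_VAR0_16BIT with high_16bits == 1 maps back to
 * VARYING_SLOT_VAR1, undoing the packing done by nir_lower_mediump_io.
 */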
bool
nir_unpack_16bit_varying_slots(nir_shader *nir, nir_variable_mode modes)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         nir_variable_mode mode;
         nir_intrinsic_instr *intr = get_io_intrinsic(instr, modes, &mode);
         if (!intr)
            continue;

         nir_io_semantics sem = nir_intrinsic_io_semantics(intr);

         if (sem.location < VARYING_SLOT_VAR0_16BIT ||
             sem.location > VARYING_SLOT_VAR15_16BIT)
            continue;

         sem.location = VARYING_SLOT_VAR0 +
                        (sem.location - VARYING_SLOT_VAR0_16BIT) * 2 +
                        sem.high_16bits;
         sem.high_16bits = 0;
         nir_intrinsic_set_io_semantics(intr, sem);
         changed = true;
      }
   }

   if (changed)
      nir_recompute_io_bases(nir, modes);

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

static bool
is_n_to_m_conversion(nir_instr *instr, unsigned n, nir_op m)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);
   return alu->op == m && alu->src[0].src.ssa->bit_size == n;
}

static bool
is_f16_to_f32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_f2f32);
}

static bool
is_f32_to_f16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_f2f16) ||
          is_n_to_m_conversion(instr, 32, nir_op_f2fmp);
}

static bool
is_i16_to_i32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_i2i32);
}

static bool
is_u16_to_u32_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 16, nir_op_u2u32);
}

static bool
is_i32_to_i16_conversion(nir_instr *instr)
{
   return is_n_to_m_conversion(instr, 32, nir_op_i2i16) ||
          is_n_to_m_conversion(instr, 32, nir_op_u2u16) ||
          is_n_to_m_conversion(instr, 32, nir_op_i2imp);
}

/**
 * Fix the types of texture instruction sources according to the given
 * constraints by inserting the appropriate conversion opcodes.
 *
 * For example, if the type of the derivatives must match the type of the
 * texture coordinates and the texture bias must be 32-bit, two constraints
 * describe that.
 */
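/* A hypothetical constraint table expressing exactly that (using designated
 * initializers for the fields referenced below):
 *
 *    static const nir_tex_src_type_constraints constraints = {
 *       [nir_tex_src_bias] = {.legalize_type = true, .bit_size = 32},
 *       [nir_tex_src_ddx]  = {.legalize_type = true,
 *                             .match_src = nir_tex_src_coord},
 *       [nir_tex_src_ddy]  = {.legalize_type = true,
 *                             .match_src = nir_tex_src_coord},
 *    };
 */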
bool
nir_legalize_16bit_sampler_srcs(nir_shader *nir,
                                nir_tex_src_type_constraints constraints)
{
   bool changed = false;
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);
   assert(impl);

   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block_safe (block, impl) {
      nir_foreach_instr_safe (instr, block) {
         if (instr->type != nir_instr_type_tex)
            continue;

         nir_tex_instr *tex = nir_instr_as_tex(instr);
         int8_t map[nir_num_tex_src_types];
         memset(map, -1, sizeof(map));

         /* Create a mapping from src_type to src[i]. */
         for (unsigned i = 0; i < tex->num_srcs; i++)
            map[tex->src[i].src_type] = i;

         /* Legalize src types. */
         for (unsigned i = 0; i < tex->num_srcs; i++) {
            nir_tex_src_type_constraint c = constraints[tex->src[i].src_type];

            if (!c.legalize_type)
               continue;

            /* Determine the required bit size for the src. */
            unsigned bit_size;
            if (c.bit_size) {
               bit_size = c.bit_size;
            } else {
               if (map[c.match_src] == -1)
                  continue; /* e.g. txs */

               bit_size = tex->src[map[c.match_src]].src.ssa->bit_size;
            }

            /* Check if the type is legal. */
            if (bit_size == tex->src[i].src.ssa->bit_size)
               continue;

            /* Fix the bit size. */
            bool is_sint = nir_tex_instr_src_type(tex, i) == nir_type_int;
            bool is_uint = nir_tex_instr_src_type(tex, i) == nir_type_uint;
            nir_ssa_def *(*convert)(nir_builder *, nir_ssa_def *);

            switch (bit_size) {
            case 16:
               convert = is_sint ? nir_i2i16 :
                         is_uint ? nir_u2u16 : nir_f2f16;
               break;
            case 32:
               convert = is_sint ? nir_i2i32 :
                         is_uint ? nir_u2u32 : nir_f2f32;
               break;
            default:
               assert(!"unexpected bit size");
               continue;
            }

            b.cursor = nir_before_instr(&tex->instr);
            nir_ssa_def *conv =
               convert(&b, nir_ssa_for_src(&b, tex->src[i].src,
                                           tex->src[i].src.ssa->num_components));
            nir_instr_rewrite_src_ssa(&tex->instr, &tex->src[i].src, conv);
            changed = true;
         }
      }
   }

   if (changed) {
      nir_metadata_preserve(impl, nir_metadata_dominance |
                                  nir_metadata_block_index);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return changed;
}

static bool
const_is_f16(nir_ssa_scalar scalar)
{
   double value = nir_ssa_scalar_as_float(scalar);
   return value == _mesa_half_to_float(_mesa_float_to_half(value));
}

static bool
const_is_u16(nir_ssa_scalar scalar)
{
   uint64_t value = nir_ssa_scalar_as_uint(scalar);
   return value == (uint16_t) value;
}

static bool
const_is_i16(nir_ssa_scalar scalar)
{
   int64_t value = nir_ssa_scalar_as_int(scalar);
   return value == (int16_t) value;
}

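/* Check whether every component of a 32-bit source is either undef, a
 * constant exactly representable in 16 bits, or the result of a 16->32-bit
 * conversion, i.e. whether the source can be folded to 16 bits losslessly.
 */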
static bool
can_fold_16bit_src(nir_ssa_def *ssa, nir_alu_type src_type, bool sext_matters)
{
   bool fold_f16 = src_type == nir_type_float32;
   bool fold_u16 = src_type == nir_type_uint32 && sext_matters;
   bool fold_i16 = src_type == nir_type_int32 && sext_matters;
   bool fold_i16_u16 = (src_type == nir_type_uint32 || src_type == nir_type_int32) && !sext_matters;

   bool can_fold = fold_f16 || fold_u16 || fold_i16 || fold_i16_u16;
   for (unsigned i = 0; can_fold && i < ssa->num_components; i++) {
      nir_ssa_scalar comp = nir_ssa_scalar_resolved(ssa, i);
      if (comp.def->parent_instr->type == nir_instr_type_ssa_undef)
         continue;
      else if (nir_ssa_scalar_is_const(comp)) {
         if (fold_f16)
            can_fold &= const_is_f16(comp);
         else if (fold_u16)
            can_fold &= const_is_u16(comp);
         else if (fold_i16)
            can_fold &= const_is_i16(comp);
         else if (fold_i16_u16)
            can_fold &= (const_is_u16(comp) || const_is_i16(comp));
      } else {
         if (fold_f16)
            can_fold &= is_f16_to_f32_conversion(comp.def->parent_instr);
         else if (fold_u16)
            can_fold &= is_u16_to_u32_conversion(comp.def->parent_instr);
         else if (fold_i16)
            can_fold &= is_i16_to_i32_conversion(comp.def->parent_instr);
         else if (fold_i16_u16)
            can_fold &= (is_i16_to_i32_conversion(comp.def->parent_instr) ||
                         is_u16_to_u32_conversion(comp.def->parent_instr));
      }
   }

   return can_fold;
}

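/* Replace a foldable 32-bit source with a 16-bit vector built from each
 * component's 16-bit origin: undefs become 16-bit undefs, constants are
 * re-emitted at 16 bits, and up-conversions are bypassed.
 */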
static void
fold_16bit_src(nir_builder *b, nir_instr *instr, nir_src *src, nir_alu_type src_type)
{
   b->cursor = nir_before_instr(instr);

   nir_ssa_scalar new_comps[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < src->ssa->num_components; i++) {
      nir_ssa_scalar comp = nir_ssa_scalar_resolved(src->ssa, i);

      if (comp.def->parent_instr->type == nir_instr_type_ssa_undef)
         new_comps[i] = nir_get_ssa_scalar(nir_ssa_undef(b, 1, 16), 0);
      else if (nir_ssa_scalar_is_const(comp)) {
         nir_ssa_def *constant;
         if (src_type == nir_type_float32)
            constant = nir_imm_float16(b, nir_ssa_scalar_as_float(comp));
         else
            constant = nir_imm_intN_t(b, nir_ssa_scalar_as_uint(comp), 16);
         new_comps[i] = nir_get_ssa_scalar(constant, 0);
      } else {
         /* conversion instruction */
         new_comps[i] = nir_ssa_scalar_chase_alu_src(comp, 0);
      }
   }

   nir_ssa_def *new_vec = nir_vec_scalars(b, new_comps, src->ssa->num_components);

   nir_instr_rewrite_src_ssa(instr, src, new_vec);
}

static bool
fold_16bit_store_data(nir_builder *b, nir_intrinsic_instr *instr)
{
   nir_alu_type src_type = nir_intrinsic_src_type(instr);
   nir_src *data_src = &instr->src[3];

   b->cursor = nir_before_instr(&instr->instr);

   if (!can_fold_16bit_src(data_src->ssa, src_type, true))
      return false;

   fold_16bit_src(b, &instr->instr, data_src, src_type);

   nir_intrinsic_set_src_type(instr, (src_type & ~32) | 16);

   return true;
}

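/* Shrink a 32-bit destination to 16 bits if every use is a compatible
 * 32->16-bit down-conversion: the conversions are turned into movs and the
 * destination's bit size is rewritten in place.
 */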
static bool
fold_16bit_destination(nir_ssa_def *ssa, nir_alu_type dest_type,
                       unsigned exec_mode, nir_rounding_mode rdm)
{
   bool is_f32_to_f16 = dest_type == nir_type_float32;
   bool is_i32_to_i16 = dest_type == nir_type_int32 || dest_type == nir_type_uint32;

   nir_rounding_mode src_rdm =
      nir_get_rounding_mode_from_float_controls(exec_mode, nir_type_float16);
   bool allow_standard = (src_rdm == rdm || src_rdm == nir_rounding_mode_undef);
   bool allow_rtz = rdm == nir_rounding_mode_rtz;
   bool allow_rtne = rdm == nir_rounding_mode_rtne;

   nir_foreach_use(use, ssa) {
      nir_instr *instr = use->parent_instr;
      is_f32_to_f16 &= (allow_standard && is_f32_to_f16_conversion(instr)) ||
                       (allow_rtz && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtz)) ||
                       (allow_rtne && is_n_to_m_conversion(instr, 32, nir_op_f2f16_rtne));
      is_i32_to_i16 &= is_i32_to_i16_conversion(instr);
   }

   if (!is_f32_to_f16 && !is_i32_to_i16)
      return false;

   /* All uses are the same conversions. Replace them with mov. */
   nir_foreach_use(use, ssa) {
      nir_alu_instr *conv = nir_instr_as_alu(use->parent_instr);
      conv->op = nir_op_mov;
   }

   ssa->bit_size = 16;
   return true;
}

static bool
fold_16bit_load_data(nir_builder *b, nir_intrinsic_instr *instr,
                     unsigned exec_mode, nir_rounding_mode rdm)
{
   nir_alu_type dest_type = nir_intrinsic_dest_type(instr);

   if (!fold_16bit_destination(&instr->dest.ssa, dest_type, exec_mode, rdm))
      return false;

   nir_intrinsic_set_dest_type(instr, (dest_type & ~32) | 16);

   return true;
}

static bool
fold_16bit_tex_dest(nir_tex_instr *tex, unsigned exec_mode,
                    nir_rounding_mode rdm)
{
   /* Skip sparse residency */
   if (tex->is_sparse)
      return false;

   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd)
      return false;

   if (!fold_16bit_destination(&tex->dest.ssa, tex->dest_type, exec_mode, rdm))
      return false;

   tex->dest_type = (tex->dest_type & ~32) | 16;
   return true;
}


static bool
fold_16bit_tex_srcs(nir_builder *b, nir_tex_instr *tex,
                    struct nir_fold_tex_srcs_options *options)
{
   if (tex->op != nir_texop_tex &&
       tex->op != nir_texop_txb &&
       tex->op != nir_texop_txd &&
       tex->op != nir_texop_txl &&
       tex->op != nir_texop_txf &&
       tex->op != nir_texop_txf_ms &&
       tex->op != nir_texop_tg4 &&
       tex->op != nir_texop_tex_prefetch &&
       tex->op != nir_texop_fragment_fetch_amd &&
       tex->op != nir_texop_fragment_mask_fetch_amd)
      return false;

   if (!(options->sampler_dims & BITFIELD_BIT(tex->sampler_dim)))
      return false;

   unsigned fold_srcs = 0;
   for (unsigned i = 0; i < tex->num_srcs; i++) {
      /* Filter out sources that should be ignored. */
      if (!(BITFIELD_BIT(tex->src[i].src_type) & options->src_types))
         continue;

      nir_src *src = &tex->src[i].src;

      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;

      /* Zero-extension (u16) and sign-extension (i16) have
       * the same behavior here - txf returns 0 if bit 15 is set
       * because it's out of bounds and the higher bits don't
       * matter.
       */
      if (!can_fold_16bit_src(src->ssa, src_type, false))
         return false;

      fold_srcs |= (1 << i);
   }

   u_foreach_bit(i, fold_srcs) {
      nir_src *src = &tex->src[i].src;
      nir_alu_type src_type = nir_tex_instr_src_type(tex, i) | src->ssa->bit_size;
      fold_16bit_src(b, &tex->instr, src, src_type);
   }

   return !!fold_srcs;
}

static bool
fold_16bit_tex_image(nir_builder *b, nir_instr *instr, void *params)
{
   struct nir_fold_16bit_tex_image_options *options = params;
   unsigned exec_mode = b->shader->info.float_controls_execution_mode;
   bool progress = false;

   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);

      switch (intrinsic->intrinsic) {
      case nir_intrinsic_bindless_image_store:
      case nir_intrinsic_image_deref_store:
      case nir_intrinsic_image_store:
         if (options->fold_image_load_store_data)
            progress |= fold_16bit_store_data(b, intrinsic);
         break;
      case nir_intrinsic_bindless_image_load:
      case nir_intrinsic_image_deref_load:
      case nir_intrinsic_image_load:
         if (options->fold_image_load_store_data)
            progress |= fold_16bit_load_data(b, intrinsic, exec_mode, options->rounding_mode);
         break;
      default:
         break;
      }
   } else if (instr->type == nir_instr_type_tex) {
      nir_tex_instr *tex = nir_instr_as_tex(instr);

      if (options->fold_tex_dest)
         progress |= fold_16bit_tex_dest(tex, exec_mode, options->rounding_mode);

      for (unsigned i = 0; i < options->fold_srcs_options_count; i++) {
         progress |= fold_16bit_tex_srcs(b, tex, &options->fold_srcs_options[i]);
      }
   }

   return progress;
}

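/* A hypothetical driver invocation that folds texture destinations and
 * coordinate/LOD sources to 16 bits using RTZ rounding:
 *
 *    struct nir_fold_tex_srcs_options srcs_options = {
 *       .sampler_dims = ~0u,
 *       .src_types = BITFIELD_BIT(nir_tex_src_coord) |
 *                    BITFIELD_BIT(nir_tex_src_lod),
 *    };
 *    struct nir_fold_16bit_tex_image_options options = {
 *       .rounding_mode = nir_rounding_mode_rtz,
 *       .fold_tex_dest = true,
 *       .fold_srcs_options_count = 1,
 *       .fold_srcs_options = &srcs_options,
 *    };
 *    NIR_PASS(progress, nir, nir_fold_16bit_tex_image, &options);
 */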
bool nir_fold_16bit_tex_image(nir_shader *nir,
                              struct nir_fold_16bit_tex_image_options *options)
{
   return nir_shader_instructions_pass(nir,
                                       fold_16bit_tex_image,
                                       nir_metadata_block_index | nir_metadata_dominance,
                                       options);
}