1/*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include <math.h>
25#include "vtn_private.h"
26#include "spirv_info.h"
27
28/*
29 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30 * definition. But for matrix multiplies, we want to do one routine for
31 * multiplying a matrix by a matrix and then pretend that vectors are matrices
32 * with one column. So we "wrap" these things, and unwrap the result before we
33 * send it off.
34 */
35
36static struct vtn_ssa_value *
37wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38{
39   if (val == NULL)
40      return NULL;
41
42   if (glsl_type_is_matrix(val->type))
43      return val;
44
45   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
46   dest->type = glsl_get_bare_type(val->type);
47   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
48   dest->elems[0] = val;
49
50   return dest;
51}
52
53static struct vtn_ssa_value *
54unwrap_matrix(struct vtn_ssa_value *val)
55{
56   if (glsl_type_is_matrix(val->type))
57         return val;
58
59   return val->elems[0];
60}
61
62static struct vtn_ssa_value *
63matrix_multiply(struct vtn_builder *b,
64                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65{
66
67   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71
72   unsigned src0_rows = glsl_get_vector_elements(src0->type);
73   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75
76   const struct glsl_type *dest_type;
77   if (src1_columns > 1) {
78      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79                                   src0_rows, src1_columns);
80   } else {
81      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82   }
83   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84
85   dest = wrap_matrix(b, dest);
86
87   bool transpose_result = false;
88   if (src0_transpose && src1_transpose) {
89      /* transpose(A) * transpose(B) = transpose(B * A) */
90      src1 = src0_transpose;
91      src0 = src1_transpose;
92      src0_transpose = NULL;
93      src1_transpose = NULL;
94      transpose_result = true;
95   }
96
97   if (src0_transpose && !src1_transpose &&
98       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
99      /* We already have the rows of src0 and the columns of src1 available,
100       * so we can just take the dot product of each row with each column to
101       * get the result.
102       */
103
104      for (unsigned i = 0; i < src1_columns; i++) {
105         nir_ssa_def *vec_src[4];
106         for (unsigned j = 0; j < src0_rows; j++) {
107            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
108                                          src1->elems[i]->def);
109         }
110         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
111      }
112   } else {
113      /* We don't handle the case where src1 is transposed but not src0, since
114       * the general case only uses individual components of src1 so the
115       * optimizer should chew through the transpose we emitted for src1.
116       */
117
118      for (unsigned i = 0; i < src1_columns; i++) {
119         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
120         dest->elems[i]->def =
121            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
122                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
123         for (int j = src0_columns - 2; j >= 0; j--) {
124            dest->elems[i]->def =
125               nir_ffma(&b->nb, src0->elems[j]->def,
126                                nir_channel(&b->nb, src1->elems[i]->def, j),
127                                dest->elems[i]->def);
128         }
129      }
130   }
131
132   dest = unwrap_matrix(dest);
133
134   if (transpose_result)
135      dest = vtn_ssa_transpose(b, dest);
136
137   return dest;
138}
139
140static struct vtn_ssa_value *
141mat_times_scalar(struct vtn_builder *b,
142                 struct vtn_ssa_value *mat,
143                 nir_ssa_def *scalar)
144{
145   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
146   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
147      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
148         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149      else
150         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
151   }
152
153   return dest;
154}
155
156nir_ssa_def *
157vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
158{
159   if (def->bit_size == 16)
160      return def;
161
162   switch (base_type) {
163   case GLSL_TYPE_FLOAT:
164      return nir_f2fmp(&b->nb, def);
165   case GLSL_TYPE_INT:
166   case GLSL_TYPE_UINT:
167      return nir_i2imp(&b->nb, def);
168   /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
169    * OpLogical* operations (which is forbidden by spec).
170    */
171   case GLSL_TYPE_BOOL:
172      return def;
173   default:
174      unreachable("bad relaxed precision input type");
175   }
176}
177
178struct vtn_ssa_value *
179vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
180{
181   if (!src)
182      return src;
183
184   struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
185
186   if (src->transposed) {
187      srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
188   } else {
189      enum glsl_base_type base_type = glsl_get_base_type(src->type);
190
191      if (glsl_type_is_vector_or_scalar(src->type)) {
192         srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
193      } else {
194         assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
195         for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
196            srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
197      }
198   }
199
200   return srcmp;
201}
202
203static struct vtn_ssa_value *
204vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
205                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
206{
207   switch (opcode) {
208   case SpvOpFNegate: {
209      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
210      unsigned cols = glsl_get_matrix_columns(src0->type);
211      for (unsigned i = 0; i < cols; i++)
212         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
213      return dest;
214   }
215
216   case SpvOpFAdd: {
217      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
218      unsigned cols = glsl_get_matrix_columns(src0->type);
219      for (unsigned i = 0; i < cols; i++)
220         dest->elems[i]->def =
221            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
222      return dest;
223   }
224
225   case SpvOpFSub: {
226      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
227      unsigned cols = glsl_get_matrix_columns(src0->type);
228      for (unsigned i = 0; i < cols; i++)
229         dest->elems[i]->def =
230            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
231      return dest;
232   }
233
234   case SpvOpTranspose:
235      return vtn_ssa_transpose(b, src0);
236
237   case SpvOpMatrixTimesScalar:
238      if (src0->transposed) {
239         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
240                                                         src1->def));
241      } else {
242         return mat_times_scalar(b, src0, src1->def);
243      }
244      break;
245
246   case SpvOpVectorTimesMatrix:
247   case SpvOpMatrixTimesVector:
248   case SpvOpMatrixTimesMatrix:
249      if (opcode == SpvOpVectorTimesMatrix) {
250         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
251      } else {
252         return matrix_multiply(b, src0, src1);
253      }
254      break;
255
256   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
257   }
258}
259
260static nir_alu_type
261convert_op_src_type(SpvOp opcode)
262{
263   switch (opcode) {
264   case SpvOpFConvert:
265   case SpvOpConvertFToS:
266   case SpvOpConvertFToU:
267      return nir_type_float;
268   case SpvOpSConvert:
269   case SpvOpConvertSToF:
270   case SpvOpSatConvertSToU:
271      return nir_type_int;
272   case SpvOpUConvert:
273   case SpvOpConvertUToF:
274   case SpvOpSatConvertUToS:
275      return nir_type_uint;
276   default:
277      unreachable("Unhandled conversion op");
278   }
279}
280
281static nir_alu_type
282convert_op_dst_type(SpvOp opcode)
283{
284   switch (opcode) {
285   case SpvOpFConvert:
286   case SpvOpConvertSToF:
287   case SpvOpConvertUToF:
288      return nir_type_float;
289   case SpvOpSConvert:
290   case SpvOpConvertFToS:
291   case SpvOpSatConvertUToS:
292      return nir_type_int;
293   case SpvOpUConvert:
294   case SpvOpConvertFToU:
295   case SpvOpSatConvertSToU:
296      return nir_type_uint;
297   default:
298      unreachable("Unhandled conversion op");
299   }
300}
301
/*
 * Map a SPIR-V ALU opcode to the NIR opcode that implements it.
 *
 * \param swap          Out: set to true when the caller must swap the first
 *                      two sources before building the instruction (the
 *                      greater-than / less-than-or-equal forms reuse the
 *                      mirrored NIR comparison).
 * \param exact         Out: set to true for the floating-point comparison
 *                      opcodes.  Presumably the caller then builds them with
 *                      the builder's exact flag; confirm at the call sites.
 * \param src_bit_size  Source bit size; only used for the conversion opcodes.
 * \param dst_bit_size  Result bit size; only used for the conversion opcodes.
 *
 * Opcodes with no entry here (including SpvOpSatConvert*) reach vtn_fail().
 */
nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   /* Shifts and logical/bitwise ops.  Logical ops on booleans reuse the
    * integer bitwise/compare opcodes.
    */
   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;

   /* INTEL integer-function extension opcodes. */
   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
   case SpvOpIAverageINTEL:         return nir_op_ihadd;
   case SpvOpUAverageINTEL:         return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
   case SpvOpISubSatINTEL:          return nir_op_isub_sat;
   case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
    */
   case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
   case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
   case SpvOpINotEqual:                                            return nir_op_ine;
   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
   case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
   case SpvOpULessThan:                                            return nir_op_ult;
   case SpvOpSLessThan:                                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
   case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   /* The concrete NIR conversion opcode depends on the source/destination
    * base types and bit sizes; no rounding mode is requested here.
    */
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   /* Generic-pointer casts are plain moves at the NIR level. */
   case SpvOpPtrCastToGeneric:   return nir_op_mov;
   case SpvOpGenericCastToPtr:   return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   case SpvOpIsNormal:     return nir_op_fisnormal;
   case SpvOpIsFinite:     return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}
426
427static void
428handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
429                      UNUSED int member, const struct vtn_decoration *dec,
430                      UNUSED void *_void)
431{
432   vtn_assert(dec->scope == VTN_DEC_DECORATION);
433   if (dec->decoration != SpvDecorationNoContraction)
434      return;
435
436   b->nb.exact = true;
437}
438
439void
440vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
441{
442   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
443}
444
445nir_rounding_mode
446vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
447{
448   switch (mode) {
449   case SpvFPRoundingModeRTE:
450      return nir_rounding_mode_rtne;
451   case SpvFPRoundingModeRTZ:
452      return nir_rounding_mode_rtz;
453   case SpvFPRoundingModeRTP:
454      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
455                  "FPRoundingModeRTP is only supported in kernels");
456      return nir_rounding_mode_ru;
457   case SpvFPRoundingModeRTN:
458      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
459                  "FPRoundingModeRTN is only supported in kernels");
460      return nir_rounding_mode_rd;
461   default:
462      vtn_fail("Unsupported rounding mode: %s",
463               spirv_fproundingmode_to_string(mode));
464      break;
465   }
466}
467
/* Conversion modifiers collected from the decorations on a conversion
 * result by handle_conversion_opts().
 */
struct conversion_opts {
   /* Requested rounding (from FPRoundingMode); nir_rounding_mode_undef when
    * none was given.
    */
   nir_rounding_mode rounding_mode;
   /* Clamp to the destination range (SaturatedConversion decoration, or
    * forced by the SatConvert* opcodes).
    */
   bool saturate;
};
472
473static void
474handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
475                       UNUSED int member,
476                       const struct vtn_decoration *dec, void *_opts)
477{
478   struct conversion_opts *opts = _opts;
479
480   switch (dec->decoration) {
481   case SpvDecorationFPRoundingMode:
482      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
483      break;
484
485   case SpvDecorationSaturatedConversion:
486      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
487                  "Saturated conversions are only allowed in kernels");
488      opts->saturate = true;
489      break;
490
491   default:
492      break;
493   }
494}
495
496static void
497handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
498               UNUSED int member,
499               const struct vtn_decoration *dec, void *_alu)
500{
501   nir_alu_instr *alu = _alu;
502   switch (dec->decoration) {
503   case SpvDecorationNoSignedWrap:
504      alu->no_signed_wrap = true;
505      break;
506   case SpvDecorationNoUnsignedWrap:
507      alu->no_unsigned_wrap = true;
508      break;
509   default:
510      /* Do nothing. */
511      break;
512   }
513}
514
515static void
516vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
517                          struct vtn_value *val, int member,
518                          const struct vtn_decoration *dec, void *void_ctx)
519{
520   bool *relaxed_precision = void_ctx;
521   switch (dec->decoration) {
522   case SpvDecorationRelaxedPrecision:
523      *relaxed_precision = true;
524      break;
525
526   default:
527      break;
528   }
529}
530
531bool
532vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
533{
534   bool result = false;
535   vtn_foreach_decoration(b, val,
536                          vtn_value_is_relaxed_precision_cb, &result);
537   return result;
538}
539
540static bool
541vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
542{
543   if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
544      return false;
545
546   switch (opcode) {
547   case SpvOpDPdx:
548   case SpvOpDPdy:
549   case SpvOpDPdxFine:
550   case SpvOpDPdyFine:
551   case SpvOpDPdxCoarse:
552   case SpvOpDPdyCoarse:
553   case SpvOpFwidth:
554   case SpvOpFwidthFine:
555   case SpvOpFwidthCoarse:
556      return b->options->mediump_16bit_derivatives;
557   default:
558      return true;
559   }
560}
561
562static nir_ssa_def *
563vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
564{
565   if (def->bit_size != 16)
566      return def;
567
568   switch (base_type) {
569   case GLSL_TYPE_FLOAT:
570      return nir_f2f32(&b->nb, def);
571   case GLSL_TYPE_INT:
572      return nir_i2i32(&b->nb, def);
573   case GLSL_TYPE_UINT:
574      return nir_u2u32(&b->nb, def);
575   default:
576      unreachable("bad relaxed precision output type");
577   }
578}
579
580void
581vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
582{
583   enum glsl_base_type base_type = glsl_get_base_type(value->type);
584
585   if (glsl_type_is_vector_or_scalar(value->type)) {
586      value->def = vtn_mediump_upconvert(b, base_type, value->def);
587   } else {
588      for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
589         value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
590   }
591}
592
593void
594vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
595               const uint32_t *w, unsigned count)
596{
597   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
598   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
599
600   vtn_handle_no_contraction(b, dest_val);
601   bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
602
603   /* Collect the various SSA sources */
604   const unsigned num_inputs = count - 3;
605   struct vtn_ssa_value *vtn_src[4] = { NULL, };
606   for (unsigned i = 0; i < num_inputs; i++) {
607      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
608      if (mediump_16bit)
609         vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
610   }
611
612   if (glsl_type_is_matrix(vtn_src[0]->type) ||
613       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
614      struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
615
616      if (mediump_16bit)
617         vtn_mediump_upconvert_value(b, dest);
618
619      vtn_push_ssa_value(b, w[2], dest);
620      b->nb.exact = b->exact;
621      return;
622   }
623
624   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
625   nir_ssa_def *src[4] = { NULL, };
626   for (unsigned i = 0; i < num_inputs; i++) {
627      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
628      src[i] = vtn_src[i]->def;
629   }
630
631   switch (opcode) {
632   case SpvOpAny:
633      dest->def = nir_bany(&b->nb, src[0]);
634      break;
635
636   case SpvOpAll:
637      dest->def = nir_ball(&b->nb, src[0]);
638      break;
639
640   case SpvOpOuterProduct: {
641      for (unsigned i = 0; i < src[1]->num_components; i++) {
642         dest->elems[i]->def =
643            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
644      }
645      break;
646   }
647
648   case SpvOpDot:
649      dest->def = nir_fdot(&b->nb, src[0], src[1]);
650      break;
651
652   case SpvOpIAddCarry:
653      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
654      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
655      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
656      break;
657
658   case SpvOpISubBorrow:
659      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
660      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
661      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
662      break;
663
664   case SpvOpUMulExtended: {
665      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
666      if (src[0]->bit_size == 32) {
667         nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
668         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
669         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
670      } else {
671         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
672         dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
673      }
674      break;
675   }
676
677   case SpvOpSMulExtended: {
678      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
679      if (src[0]->bit_size == 32) {
680         nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
681         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
682         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
683      } else {
684         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
685         dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
686      }
687      break;
688   }
689
690   case SpvOpFwidth:
691      dest->def = nir_fadd(&b->nb,
692                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
693                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
694      break;
695   case SpvOpFwidthFine:
696      dest->def = nir_fadd(&b->nb,
697                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
698                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
699      break;
700   case SpvOpFwidthCoarse:
701      dest->def = nir_fadd(&b->nb,
702                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
703                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
704      break;
705
706   case SpvOpVectorTimesScalar:
707      /* The builder will take care of splatting for us. */
708      dest->def = nir_fmul(&b->nb, src[0], src[1]);
709      break;
710
711   case SpvOpIsNan: {
712      const bool save_exact = b->nb.exact;
713
714      b->nb.exact = true;
715      dest->def = nir_fneu(&b->nb, src[0], src[0]);
716      b->nb.exact = save_exact;
717      break;
718   }
719
720   case SpvOpOrdered: {
721      const bool save_exact = b->nb.exact;
722
723      b->nb.exact = true;
724      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
725                                   nir_feq(&b->nb, src[1], src[1]));
726      b->nb.exact = save_exact;
727      break;
728   }
729
730   case SpvOpUnordered: {
731      const bool save_exact = b->nb.exact;
732
733      b->nb.exact = true;
734      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
735                                  nir_fneu(&b->nb, src[1], src[1]));
736      b->nb.exact = save_exact;
737      break;
738   }
739
740   case SpvOpIsInf: {
741      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
742      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
743      break;
744   }
745
746   case SpvOpFUnordEqual: {
747      const bool save_exact = b->nb.exact;
748
749      b->nb.exact = true;
750
751      /* This could also be implemented as !(a < b || b < a).  If one or both
752       * of the source are numbers, later optimization passes can easily
753       * eliminate the isnan() checks.  This may trim the sequence down to a
754       * single (a == b) operation.  Otherwise, the optimizer can transform
755       * whatever is left to !(a < b || b < a).  Since some applications will
756       * open-code this sequence, these optimizations are needed anyway.
757       */
758      dest->def =
759         nir_ior(&b->nb,
760                 nir_feq(&b->nb, src[0], src[1]),
761                 nir_ior(&b->nb,
762                         nir_fneu(&b->nb, src[0], src[0]),
763                         nir_fneu(&b->nb, src[1], src[1])));
764
765      b->nb.exact = save_exact;
766      break;
767   }
768
769   case SpvOpFUnordLessThan:
770   case SpvOpFUnordGreaterThan:
771   case SpvOpFUnordLessThanEqual:
772   case SpvOpFUnordGreaterThanEqual: {
773      bool swap;
774      bool unused_exact;
775      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
776      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
777      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
778                                                  &unused_exact,
779                                                  src_bit_size, dst_bit_size);
780
781      if (swap) {
782         nir_ssa_def *tmp = src[0];
783         src[0] = src[1];
784         src[1] = tmp;
785      }
786
787      const bool save_exact = b->nb.exact;
788
789      b->nb.exact = true;
790
791      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
792      switch (op) {
793      case nir_op_fge: op = nir_op_flt; break;
794      case nir_op_flt: op = nir_op_fge; break;
795      default: unreachable("Impossible opcode.");
796      }
797
798      dest->def =
799         nir_inot(&b->nb,
800                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
801
802      b->nb.exact = save_exact;
803      break;
804   }
805
806   case SpvOpLessOrGreater:
807   case SpvOpFOrdNotEqual: {
808      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
809       * from the ALU will probably already be false if the operands are not
810       * ordered so we don’t need to handle it specially.
811       */
812      const bool save_exact = b->nb.exact;
813
814      b->nb.exact = true;
815
816      /* This could also be implemented as (a < b || b < a).  If one or both
817       * of the source are numbers, later optimization passes can easily
818       * eliminate the isnan() checks.  This may trim the sequence down to a
819       * single (a != b) operation.  Otherwise, the optimizer can transform
820       * whatever is left to (a < b || b < a).  Since some applications will
821       * open-code this sequence, these optimizations are needed anyway.
822       */
823      dest->def =
824         nir_iand(&b->nb,
825                  nir_fneu(&b->nb, src[0], src[1]),
826                  nir_iand(&b->nb,
827                          nir_feq(&b->nb, src[0], src[0]),
828                          nir_feq(&b->nb, src[1], src[1])));
829
830      b->nb.exact = save_exact;
831      break;
832   }
833
834   case SpvOpUConvert:
835   case SpvOpConvertFToU:
836   case SpvOpConvertFToS:
837   case SpvOpConvertSToF:
838   case SpvOpConvertUToF:
839   case SpvOpSConvert:
840   case SpvOpFConvert:
841   case SpvOpSatConvertSToU:
842   case SpvOpSatConvertUToS: {
843      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
844      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
845      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
846      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
847
848      struct conversion_opts opts = {
849         .rounding_mode = nir_rounding_mode_undef,
850         .saturate = false,
851      };
852      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
853
854      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
855         opts.saturate = true;
856
857      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
858         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
859            nir_op op = nir_type_conversion_op(src_type, dst_type,
860                                               nir_rounding_mode_undef);
861            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
862         } else {
863            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
864                                              src_type, dst_type,
865                                              opts.rounding_mode, opts.saturate);
866         }
867      } else {
868         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
869                     dst_type != nir_type_float16,
870                     "Rounding modes are only allowed on conversions to "
871                     "16-bit float types");
872         nir_op op = nir_type_conversion_op(src_type, dst_type,
873                                            opts.rounding_mode);
874         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
875      }
876      break;
877   }
878
879   case SpvOpBitFieldInsert:
880   case SpvOpBitFieldSExtract:
881   case SpvOpBitFieldUExtract:
882   case SpvOpShiftLeftLogical:
883   case SpvOpShiftRightArithmetic:
884   case SpvOpShiftRightLogical: {
885      bool swap;
886      bool exact;
887      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
888      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
889      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
890                                                  src0_bit_size, dst_bit_size);
891
892      assert(!exact);
893
894      assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
895              op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
896              op == nir_op_ibitfield_extract);
897
898      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
899         unsigned src_bit_size =
900            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
901         if (src_bit_size == 0)
902            continue;
903         if (src_bit_size != src[i]->bit_size) {
904            assert(src_bit_size == 32);
905            /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
906             * supported by the NIR instructions. See discussion here:
907             *
908             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
909             */
910            src[i] = nir_u2u32(&b->nb, src[i]);
911         }
912      }
913      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
914      break;
915   }
916
917   case SpvOpSignBitSet:
918      dest->def = nir_i2b(&b->nb,
919         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
920      break;
921
922   case SpvOpUCountTrailingZerosINTEL:
923      dest->def = nir_umin(&b->nb,
924                               nir_find_lsb(&b->nb, src[0]),
925                               nir_imm_int(&b->nb, 32u));
926      break;
927
928   case SpvOpBitCount: {
929      /* bit_count always returns int32, but the SPIR-V opcode just says the return
930       * value needs to be big enough to store the number of bits.
931       */
932      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
933      break;
934   }
935
936   case SpvOpSDotKHR:
937   case SpvOpUDotKHR:
938   case SpvOpSUDotKHR:
939   case SpvOpSDotAccSatKHR:
940   case SpvOpUDotAccSatKHR:
941   case SpvOpSUDotAccSatKHR:
942      unreachable("Should have called vtn_handle_integer_dot instead.");
943
944   default: {
945      bool swap;
946      bool exact;
947      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
948      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
949      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
950                                                  &exact,
951                                                  src_bit_size, dst_bit_size);
952
953      if (swap) {
954         nir_ssa_def *tmp = src[0];
955         src[0] = src[1];
956         src[1] = tmp;
957      }
958
959      switch (op) {
960      case nir_op_ishl:
961      case nir_op_ishr:
962      case nir_op_ushr:
963         if (src[1]->bit_size != 32)
964            src[1] = nir_u2u32(&b->nb, src[1]);
965         break;
966      default:
967         break;
968      }
969
970      const bool save_exact = b->nb.exact;
971
972      if (exact)
973         b->nb.exact = true;
974
975      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
976
977      b->nb.exact = save_exact;
978      break;
979   } /* default */
980   }
981
982   switch (opcode) {
983   case SpvOpIAdd:
984   case SpvOpIMul:
985   case SpvOpISub:
986   case SpvOpShiftLeftLogical:
987   case SpvOpSNegate: {
988      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
989      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
990      break;
991   }
992   default:
993      /* Do nothing. */
994      break;
995   }
996
997   if (mediump_16bit)
998      vtn_mediump_upconvert_value(b, dest);
999   vtn_push_ssa_value(b, w[2], dest);
1000
1001   b->nb.exact = b->exact;
1002}
1003
/*
 * Translate one of the SPV_KHR_integer_dot_product opcodes (OpSDotKHR,
 * OpUDotKHR, OpSUDotKHR and their *AccSatKHR variants) into NIR, and push
 * the result as the SSA value of Result <id> w[2].
 *
 * Operand layout: w[1] is the Result Type, w[2] the Result <id>, w[3] and
 * w[4] are Vector 1 and Vector 2.  For the AccSat opcodes, w[5] is the
 * Accumulator.  An optional Packed Vector Format literal may follow the last
 * input; it is required when the inputs are 32-bit scalars.
 *
 * Two lowering strategies are used:
 *  - Sources that are, or can be packed into, a single 32-bit value (4x8 or
 *    2x16 elements) are lowered to the dedicated NIR
 *    {s,u,su}dot_{4x8,2x16}_{i,u}add[_sat] opcodes.
 *  - All other vector sources are expanded per component: convert each
 *    component to the result bit size, multiply, and sum.  The AccSat
 *    opcodes then do a saturating add of the accumulator.
 */
void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   /* NOTE(review): presumably sets b->nb.exact from a NoContraction
    * decoration on the result.  b->nb.exact is reset to b->exact before
    * returning.
    */
   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Because of the optional "Packed Vector Format" operand, the number of
    * inputs is determined from the opcode rather than from the word count.
    * This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   /* opcode word + Result Type + Result <id> + the inputs. */
   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_ssa_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.  This check is applied to the SUDot opcodes as well.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "opcode %s",
                  spirv_op_to_string(opcode));
   }

   /* Element width of the packed 32-bit sources consumed by the scalar path
    * below.  It is 8 (4x8) unless 2x16 packing is selected.  Scalar 32-bit
    * inputs are always treated as 4x8; the format check below enforces this.
    */
   unsigned packed_bit_size = 8;
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       *
       * Vectors matching neither condition are left unpacked and take the
       * per-component path below.
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         /* The 2x16 scalar lowering below has no mixed-signedness case, so
          * the SUDot opcodes are excluded here and stay on the vector path.
          */
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.  Only the
       * 4x8 format is accepted.
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_ssa_def *dest = NULL;

   if (src[0]->num_components > 1) {
      /* Generic path: the sources are still unpacked vectors. */

      /* Per-component conversions to the result bit size: int->int for the
       * signed operands and uint->uint for the unsigned operands.
       */
      const nir_op s_conversion_op =
         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
                                nir_rounding_mode_undef);

      const nir_op u_conversion_op =
         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
                                nir_rounding_mode_undef);

      nir_op src0_conversion_op;
      nir_op src1_conversion_op;

      /* Mixed-signedness (SUDot) treats Vector 1 as signed and Vector 2 as
       * unsigned.
       */
      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = s_conversion_op;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion_op = u_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       *
       * (That wording is from the signed opcode.  Unsigned operands are
       * converted with the uint conversion selected above.)
       */
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_ssa_def *const src0 =
            nir_build_alu(&b->nb, src0_conversion_op,
                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);

         nir_ssa_def *const src1 =
            nir_build_alu(&b->nb, src1_conversion_op,
                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);

         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);

         /* Running (non-saturating) sum of the per-component products. */
         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      /* Packed path: both sources are single 32-bit values holding 4x8 or
       * 2x16 elements, as selected by packed_bit_size.  The NIR dot opcodes
       * produce a 32-bit result.
       */
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      /* Addend for the non-saturating dot opcodes below. */
      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      /* Selects sign- vs. zero-extension and signed vs. unsigned saturation
       * in the result-size fix-up at the end.
       */
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      /* For the AccSat opcodes, the accumulator is fused into the NIR opcode
       * only when the result is 32 bits wide.  Otherwise a plain dot product
       * is emitted here and the saturating add is done after resizing.
       */
      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above.  In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition.  Only 4x8
          * sources reach this path with a result wider than 32 bits (2x16
          * packing requires a result of at most 32 bits).  A 4x8 dot-product
          * cannot overflow 32 bits, so casting up is also safe.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2i(&b->nb, dest, dest_size)
               : nir_u2u(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   /* Undo any per-instruction exactness set at the top of the function. */
   b->nb.exact = b->exact;
}
1287
1288void
1289vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1290{
1291   vtn_assert(count == 4);
1292   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1293    *
1294    *    "If Result Type has the same number of components as Operand, they
1295    *    must also have the same component width, and results are computed per
1296    *    component.
1297    *
1298    *    If Result Type has a different number of components than Operand, the
1299    *    total number of bits in Result Type must equal the total number of
1300    *    bits in Operand. Let L be the type, either Result Type or Operand’s
1301    *    type, that has the larger number of components. Let S be the other
1302    *    type, with the smaller number of components. The number of components
1303    *    in L must be an integer multiple of the number of components in S.
1304    *    The first component (that is, the only or lowest-numbered component)
1305    *    of S maps to the first components of L, and so on, up to the last
1306    *    component of S mapping to the last components of L. Within this
1307    *    mapping, any single component of S (mapping to multiple components of
1308    *    L) maps its lower-ordered bits to the lower-numbered components of L."
1309    */
1310
1311   struct vtn_type *type = vtn_get_type(b, w[1]);
1312   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
1313
1314   vtn_fail_if(src->num_components * src->bit_size !=
1315               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1316               "Source and destination of OpBitcast must have the same "
1317               "total number of bits");
1318   nir_ssa_def *val =
1319      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1320   vtn_push_nir_ssa(b, w[2], val);
1321}
1322