1/*
2 * Copyright © 2020 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#ifndef NIR_CONVERSION_BUILDER_H
25#define NIR_CONVERSION_BUILDER_H
26
27#include "util/u_math.h"
28#include "nir_builder.h"
29#include "nir_builtin_builder.h"
30
31#ifdef __cplusplus
32extern "C" {
33#endif
34
35static inline nir_ssa_def *
36nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
37                       nir_rounding_mode round)
38{
39   switch (round) {
40   case nir_rounding_mode_ru:
41      return nir_fceil(b, src);
42
43   case nir_rounding_mode_rd:
44      return nir_ffloor(b, src);
45
46   case nir_rounding_mode_rtne:
47      return nir_fround_even(b, src);
48
49   case nir_rounding_mode_undef:
50   case nir_rounding_mode_rtz:
51      break;
52   }
53   unreachable("unexpected rounding mode");
54}
55
56static inline nir_ssa_def *
57nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
58                         unsigned dest_bit_size,
59                         nir_rounding_mode round)
60{
61   unsigned src_bit_size = src->bit_size;
62   if (dest_bit_size > src_bit_size)
63      return src; /* No rounding is needed for an up-convert */
64
65   nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66                                            nir_type_float | dest_bit_size,
67                                            nir_rounding_mode_undef);
68   nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69                                             nir_type_float | src_bit_size,
70                                             nir_rounding_mode_undef);
71
72   switch (round) {
73   case nir_rounding_mode_ru: {
74      /* If lower-precision conversion results in a lower value, push it
75      * up one ULP. */
76      nir_ssa_def *lower_prec =
77         nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78      nir_ssa_def *roundtrip =
79         nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80      nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
81      nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82      return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83   }
84   case nir_rounding_mode_rd: {
85      /* If lower-precision conversion results in a higher value, push it
86      * down one ULP. */
87      nir_ssa_def *lower_prec =
88         nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89      nir_ssa_def *roundtrip =
90         nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91      nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
92      nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93      return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94   }
95   case nir_rounding_mode_rtz:
96      return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
97                          nir_round_float_to_float(b, src, dest_bit_size,
98                                                   nir_rounding_mode_ru),
99                          nir_round_float_to_float(b, src, dest_bit_size,
100                                                   nir_rounding_mode_rd));
101   case nir_rounding_mode_rtne:
102   case nir_rounding_mode_undef:
103      break;
104   }
105   unreachable("unexpected rounding mode");
106}
107
108static inline nir_ssa_def *
109nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
110                       nir_alu_type src_type,
111                       unsigned dest_bit_size,
112                       nir_rounding_mode round)
113{
114   /* We only care whether or not its signed */
115   src_type = nir_alu_type_get_base_type(src_type);
116
117   unsigned mantissa_bits;
118   switch (dest_bit_size) {
119   case 16:
120      mantissa_bits = 10;
121      break;
122   case 32:
123      mantissa_bits = 23;
124      break;
125   case 64:
126      mantissa_bits = 52;
127      break;
128   default: unreachable("Unsupported bit size");
129   }
130
131   if (src->bit_size < mantissa_bits)
132      return src;
133
134   if (src_type == nir_type_int) {
135      nir_ssa_def *sign =
136         nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
137      nir_ssa_def *abs = nir_iabs(b, src);
138      nir_ssa_def *positive_rounded =
139         nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
140      nir_ssa_def *max_positive =
141         nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
142      switch (round) {
143      case nir_rounding_mode_rtz:
144         return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
145                                   positive_rounded);
146         break;
147      case nir_rounding_mode_ru:
148         return nir_bcsel(b, sign,
149                          nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
150                          nir_umin(b, positive_rounded, max_positive));
151         break;
152      case nir_rounding_mode_rd:
153         return nir_bcsel(b, sign,
154                          nir_ineg(b,
155                                   nir_umin(b, max_positive,
156                                            nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
157                          positive_rounded);
158      case nir_rounding_mode_rtne:
159      case nir_rounding_mode_undef:
160         break;
161      }
162      unreachable("unexpected rounding mode");
163   } else {
164      nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
165      nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
166      nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
167      nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
168      nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
169      nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
170      nir_ssa_def *truncated = nir_iand(b, src, mask);
171      switch (round) {
172      case nir_rounding_mode_rtz:
173      case nir_rounding_mode_rd:
174         return truncated;
175         break;
176      case nir_rounding_mode_ru:
177         return nir_bcsel(b, nir_ieq(b, src, truncated),
178                             src, nir_uadd_sat(b, truncated, adjust));
179      case nir_rounding_mode_rtne:
180      case nir_rounding_mode_undef:
181         break;
182      }
183      unreachable("unexpected rounding mode");
184   }
185}
186
187/** Returns true if the representable range of a contains the representable
188 * range of b.
189 */
190static inline bool
191nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
192{
193   /* Split types from bit sizes */
194   nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
195   nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
196   unsigned a_bit_size = nir_alu_type_get_type_size(a);
197   unsigned b_bit_size = nir_alu_type_get_type_size(b);
198
199   /* This requires sized types */
200   assert(a_bit_size > 0 && b_bit_size > 0);
201
202   if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
203      return true;
204
205   if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
206       a_bit_size > b_bit_size)
207      return true;
208
209   /* 16-bit floats fit in 32-bit integers */
210   if (a_base_type == nir_type_int && a_bit_size >= 32 &&
211       b == nir_type_float16)
212      return true;
213
214   /* All signed or unsigned ints can fit in float or above. A uint8 can fit
215    * in a float16.
216    */
217   if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
218       (a_bit_size >= 32 || b_bit_size == 8))
219      return true;
220
221   return false;
222}
223
224/**
225 * Retrieves limits used for clamping a value of the src type into
226 * the widest representable range of the dst type via cmp + bcsel
227 */
228static inline void
229nir_get_clamp_limits(nir_builder *b,
230                     nir_alu_type src_type,
231                     nir_alu_type dest_type,
232                     nir_ssa_def **low, nir_ssa_def **high)
233{
234   /* Split types from bit sizes */
235   nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
236   nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
237   unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
238   unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
239   assert(dest_bit_size != 0 && src_bit_size != 0);
240
241   *low = NULL;
242   *high = NULL;
243
244   /* limits of the destination type, expressed in the source type */
245   switch (dest_base_type) {
246   case nir_type_int: {
247      int64_t ilow, ihigh;
248      if (dest_bit_size == 64) {
249         ilow = INT64_MIN;
250         ihigh = INT64_MAX;
251      } else {
252         ilow = -(1ll << (dest_bit_size - 1));
253         ihigh = (1ll << (dest_bit_size - 1)) - 1;
254      }
255
256      if (src_base_type == nir_type_int) {
257         *low = nir_imm_intN_t(b, ilow, src_bit_size);
258         *high = nir_imm_intN_t(b, ihigh, src_bit_size);
259      } else if (src_base_type == nir_type_uint) {
260         assert(src_bit_size >= dest_bit_size);
261         *high = nir_imm_intN_t(b, ihigh, src_bit_size);
262      } else {
263         *low = nir_imm_floatN_t(b, ilow, src_bit_size);
264         *high = nir_imm_floatN_t(b, ihigh, src_bit_size);
265      }
266      break;
267   }
268   case nir_type_uint: {
269      uint64_t uhigh = dest_bit_size == 64 ?
270         ~0ull : (1ull << dest_bit_size) - 1;
271      if (src_base_type != nir_type_float) {
272         *low = nir_imm_intN_t(b, 0, src_bit_size);
273         if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)
274            *high = nir_imm_intN_t(b, uhigh, src_bit_size);
275      } else {
276         *low = nir_imm_floatN_t(b, 0.0f, src_bit_size);
277         *high = nir_imm_floatN_t(b, uhigh, src_bit_size);
278      }
279      break;
280   }
281   case nir_type_float: {
282      double flow, fhigh;
283      switch (dest_bit_size) {
284      case 16:
285         flow = -65504.0f;
286         fhigh = 65504.0f;
287         break;
288      case 32:
289         flow = -FLT_MAX;
290         fhigh = FLT_MAX;
291         break;
292      case 64:
293         flow = -DBL_MAX;
294         fhigh = DBL_MAX;
295         break;
296      default:
297         unreachable("Unhandled bit size");
298      }
299
300      switch (src_base_type) {
301      case nir_type_int: {
302         int64_t src_ilow, src_ihigh;
303         if (src_bit_size == 64) {
304            src_ilow = INT64_MIN;
305            src_ihigh = INT64_MAX;
306         } else {
307            src_ilow = -(1ll << (src_bit_size - 1));
308            src_ihigh = (1ll << (src_bit_size - 1)) - 1;
309         }
310         if (src_ilow < flow)
311            *low = nir_imm_intN_t(b, flow, src_bit_size);
312         if (src_ihigh > fhigh)
313            *high = nir_imm_intN_t(b, fhigh, src_bit_size);
314         break;
315      }
316      case nir_type_uint: {
317         uint64_t src_uhigh = src_bit_size == 64 ?
318            ~0ull : (1ull << src_bit_size) - 1;
319         if (src_uhigh > fhigh)
320            *high = nir_imm_intN_t(b, fhigh, src_bit_size);
321         break;
322      }
323      case nir_type_float:
324         *low = nir_imm_floatN_t(b, flow, src_bit_size);
325         *high = nir_imm_floatN_t(b, fhigh, src_bit_size);
326         break;
327      default:
328         unreachable("Clamping from unknown type");
329      }
330      break;
331   }
332   default:
333      unreachable("clamping to unknown type");
334      break;
335   }
336}
337
338/**
339 * Clamp the value into the widest representatble range of the
340 * destination type with cmp + bcsel.
341 *
342 * val/val_type: The variables used for bcsel
343 * src/src_type: The variables used for comparison
344 * dest_type: The type which determines the range used for comparison
345 */
346static inline nir_ssa_def *
347nir_clamp_to_type_range(nir_builder *b,
348                        nir_ssa_def *val, nir_alu_type val_type,
349                        nir_ssa_def *src, nir_alu_type src_type,
350                        nir_alu_type dest_type)
351{
352   assert(nir_alu_type_get_type_size(src_type) == 0 ||
353          nir_alu_type_get_type_size(src_type) == src->bit_size);
354   src_type |= src->bit_size;
355   if (nir_alu_type_range_contains_type_range(dest_type, src_type))
356      return val;
357
358   /* limits of the destination type, expressed in the source type */
359   nir_ssa_def *low = NULL, *high = NULL;
360   nir_get_clamp_limits(b, src_type, dest_type, &low, &high);
361
362   nir_ssa_def *low_cond = NULL, *high_cond = NULL;
363   switch (nir_alu_type_get_base_type(src_type)) {
364   case nir_type_int:
365      low_cond = low ? nir_ilt(b, src, low) : NULL;
366      high_cond = high ? nir_ilt(b, high, src) : NULL;
367      break;
368   case nir_type_uint:
369      low_cond = low ? nir_ult(b, src, low) : NULL;
370      high_cond = high ? nir_ult(b, high, src) : NULL;
371      break;
372   case nir_type_float:
373      low_cond = low ? nir_fge(b, low, src) : NULL;
374      high_cond = high ? nir_fge(b, src, high) : NULL;
375      break;
376   default:
377      unreachable("clamping from unknown type");
378   }
379
380   nir_ssa_def *val_low = low, *val_high = high;
381   if (val_type != src_type) {
382      nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);
383   }
384
385   nir_ssa_def *res = val;
386   if (low_cond && val_low)
387      res = nir_bcsel(b, low_cond, val_low, res);
388   if (high_cond && val_high)
389      res = nir_bcsel(b, high_cond, val_high, res);
390
391   return res;
392}
393
394static inline nir_rounding_mode
395nir_simplify_conversion_rounding(nir_alu_type src_type,
396                                 nir_alu_type dest_type,
397                                 nir_rounding_mode rounding)
398{
399   nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
400   nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
401   unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
402   unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
403   assert(src_bit_size > 0 && dest_bit_size > 0);
404
405   if (rounding == nir_rounding_mode_undef)
406      return rounding;
407
408   /* Pure integer conversion doesn't have any rounding */
409   if (src_base_type != nir_type_float &&
410       dest_base_type != nir_type_float)
411      return nir_rounding_mode_undef;
412
413   /* Float down-casts don't round */
414   if (src_base_type == nir_type_float &&
415       dest_base_type == nir_type_float &&
416       dest_bit_size >= src_bit_size)
417      return nir_rounding_mode_undef;
418
419   /* Regular float to int conversions are RTZ */
420   if (src_base_type == nir_type_float &&
421       dest_base_type != nir_type_float &&
422       rounding == nir_rounding_mode_rtz)
423      return nir_rounding_mode_undef;
424
425   /* The CL spec requires regular conversions to float to be RTNE */
426   if (dest_base_type == nir_type_float &&
427       rounding == nir_rounding_mode_rtne)
428      return nir_rounding_mode_undef;
429
430   /* Couldn't simplify */
431   return rounding;
432}
433
434static inline nir_ssa_def *
435nir_convert_with_rounding(nir_builder *b,
436                          nir_ssa_def *src, nir_alu_type src_type,
437                          nir_alu_type dest_type,
438                          nir_rounding_mode round,
439                          bool clamp)
440{
441   /* Some stuff wants sized types */
442   assert(nir_alu_type_get_type_size(src_type) == 0 ||
443          nir_alu_type_get_type_size(src_type) == src->bit_size);
444   src_type |= src->bit_size;
445
446   /* Split types from bit sizes */
447   nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
448   nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
449   unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
450
451   /* Try to simplify the conversion if we can */
452   clamp = clamp &&
453      !nir_alu_type_range_contains_type_range(dest_type, src_type);
454   round = nir_simplify_conversion_rounding(src_type, dest_type, round);
455
456   /* For float -> int/uint conversions, we might not be able to represent
457    * the destination range in the source float accurately. For these cases,
458    * do the comparison in float range, but the bcsel in the destination range.
459    */
460   bool clamp_after_conversion = clamp &&
461      src_base_type == nir_type_float &&
462      dest_base_type != nir_type_float;
463
464   /*
465    * If we don't care about rounding and clamping, we can just use NIR's
466    * built-in ops. There is also a special case for SPIR-V in shaders, where
467    * f32/f64 -> f16 conversions can have one of two rounding modes applied,
468    * which NIR has built-in opcodes for.
469    *
470    * For the rest, we have our own implementation of rounding and clamping.
471    */
472   bool trivial_convert;
473   if (!clamp && round == nir_rounding_mode_undef) {
474      trivial_convert = true;
475   } else if (!clamp && src_type == nir_type_float32 &&
476                        dest_type == nir_type_float16 &&
477                        (round == nir_rounding_mode_rtne ||
478                         round == nir_rounding_mode_rtz)) {
479      trivial_convert = true;
480   } else {
481      trivial_convert = false;
482   }
483   if (trivial_convert) {
484      nir_op op = nir_type_conversion_op(src_type, dest_type, round);
485      return nir_build_alu(b, op, src, NULL, NULL, NULL);
486   }
487
488   nir_ssa_def *dest = src;
489
490   /* clamp the result into range */
491   if (clamp && !clamp_after_conversion)
492      dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);
493
494   /* round with selected rounding mode */
495   if (!trivial_convert && round != nir_rounding_mode_undef) {
496      if (src_base_type == nir_type_float) {
497         if (dest_base_type == nir_type_float) {
498            dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
499         } else {
500            dest = nir_round_float_to_int(b, dest, round);
501         }
502      } else {
503         dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
504      }
505
506      round = nir_rounding_mode_undef;
507   }
508
509   /* now we can convert the value */
510   nir_op op = nir_type_conversion_op(src_type, dest_type, round);
511   dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);
512
513   if (clamp_after_conversion)
514      dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);
515
516   return dest;
517}
518
519#ifdef __cplusplus
520}
521#endif
522
523#endif /* NIR_CONVERSION_BUILDER_H */
524