1/* 2 * Copyright © 2020 Collabora Ltd. 3 * 4 * Permission is hereby granted, free of charge, to any person obtaining a 5 * copy of this software and associated documentation files (the "Software"), 6 * to deal in the Software without restriction, including without limitation 7 * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 * and/or sell copies of the Software, and to permit persons to whom the 9 * Software is furnished to do so, subject to the following conditions: 10 * 11 * The above copyright notice and this permission notice (including the next 12 * paragraph) shall be included in all copies or substantial portions of the 13 * Software. 14 * 15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21 * IN THE SOFTWARE. 22 */ 23 24#ifndef NIR_CONVERSION_BUILDER_H 25#define NIR_CONVERSION_BUILDER_H 26 27#include "util/u_math.h" 28#include "nir_builder.h" 29#include "nir_builtin_builder.h" 30 31#ifdef __cplusplus 32extern "C" { 33#endif 34 35static inline nir_ssa_def * 36nir_round_float_to_int(nir_builder *b, nir_ssa_def *src, 37 nir_rounding_mode round) 38{ 39 switch (round) { 40 case nir_rounding_mode_ru: 41 return nir_fceil(b, src); 42 43 case nir_rounding_mode_rd: 44 return nir_ffloor(b, src); 45 46 case nir_rounding_mode_rtne: 47 return nir_fround_even(b, src); 48 49 case nir_rounding_mode_undef: 50 case nir_rounding_mode_rtz: 51 break; 52 } 53 unreachable("unexpected rounding mode"); 54} 55 56static inline nir_ssa_def * 57nir_round_float_to_float(nir_builder *b, nir_ssa_def *src, 58 unsigned dest_bit_size, 59 nir_rounding_mode round) 60{ 61 unsigned src_bit_size = src->bit_size; 62 if (dest_bit_size > src_bit_size) 63 return src; /* No rounding is needed for an up-convert */ 64 65 nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size, 66 nir_type_float | dest_bit_size, 67 nir_rounding_mode_undef); 68 nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size, 69 nir_type_float | src_bit_size, 70 nir_rounding_mode_undef); 71 72 switch (round) { 73 case nir_rounding_mode_ru: { 74 /* If lower-precision conversion results in a lower value, push it 75 * up one ULP. */ 76 nir_ssa_def *lower_prec = 77 nir_build_alu(b, low_conv, src, NULL, NULL, NULL); 78 nir_ssa_def *roundtrip = 79 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); 80 nir_ssa_def *cmp = nir_flt(b, roundtrip, src); 81 nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size); 82 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec); 83 } 84 case nir_rounding_mode_rd: { 85 /* If lower-precision conversion results in a higher value, push it 86 * down one ULP. */ 87 nir_ssa_def *lower_prec = 88 nir_build_alu(b, low_conv, src, NULL, NULL, NULL); 89 nir_ssa_def *roundtrip = 90 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); 91 nir_ssa_def *cmp = nir_flt(b, src, roundtrip); 92 nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size); 93 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec); 94 } 95 case nir_rounding_mode_rtz: 96 return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)), 97 nir_round_float_to_float(b, src, dest_bit_size, 98 nir_rounding_mode_ru), 99 nir_round_float_to_float(b, src, dest_bit_size, 100 nir_rounding_mode_rd)); 101 case nir_rounding_mode_rtne: 102 case nir_rounding_mode_undef: 103 break; 104 } 105 unreachable("unexpected rounding mode"); 106} 107 108static inline nir_ssa_def * 109nir_round_int_to_float(nir_builder *b, nir_ssa_def *src, 110 nir_alu_type src_type, 111 unsigned dest_bit_size, 112 nir_rounding_mode round) 113{ 114 /* We only care whether or not its signed */ 115 src_type = nir_alu_type_get_base_type(src_type); 116 117 unsigned mantissa_bits; 118 switch (dest_bit_size) { 119 case 16: 120 mantissa_bits = 10; 121 break; 122 case 32: 123 mantissa_bits = 23; 124 break; 125 case 64: 126 mantissa_bits = 52; 127 break; 128 default: unreachable("Unsupported bit size"); 129 } 130 131 if (src->bit_size < mantissa_bits) 132 return src; 133 134 if (src_type == nir_type_int) { 135 nir_ssa_def *sign = 136 nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1))); 137 nir_ssa_def *abs = nir_iabs(b, src); 138 nir_ssa_def *positive_rounded = 139 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round); 140 nir_ssa_def *max_positive = 141 nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size); 142 switch (round) { 143 case nir_rounding_mode_rtz: 144 return nir_bcsel(b, sign, nir_ineg(b, positive_rounded), 145 positive_rounded); 146 break; 147 case nir_rounding_mode_ru: 148 return nir_bcsel(b, sign, 149 nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)), 150 nir_umin(b, positive_rounded, max_positive)); 151 break; 152 case nir_rounding_mode_rd: 153 return nir_bcsel(b, sign, 154 nir_ineg(b, 155 nir_umin(b, max_positive, 156 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))), 157 positive_rounded); 158 case nir_rounding_mode_rtne: 159 case nir_rounding_mode_undef: 160 break; 161 } 162 unreachable("unexpected rounding mode"); 163 } else { 164 nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits); 165 nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size); 166 nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size); 167 nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size); 168 nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose); 169 nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one)); 170 nir_ssa_def *truncated = nir_iand(b, src, mask); 171 switch (round) { 172 case nir_rounding_mode_rtz: 173 case nir_rounding_mode_rd: 174 return truncated; 175 break; 176 case nir_rounding_mode_ru: 177 return nir_bcsel(b, nir_ieq(b, src, truncated), 178 src, nir_uadd_sat(b, truncated, adjust)); 179 case nir_rounding_mode_rtne: 180 case nir_rounding_mode_undef: 181 break; 182 } 183 unreachable("unexpected rounding mode"); 184 } 185} 186 187/** Returns true if the representable range of a contains the representable 188 * range of b. 189 */ 190static inline bool 191nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b) 192{ 193 /* Split types from bit sizes */ 194 nir_alu_type a_base_type = nir_alu_type_get_base_type(a); 195 nir_alu_type b_base_type = nir_alu_type_get_base_type(b); 196 unsigned a_bit_size = nir_alu_type_get_type_size(a); 197 unsigned b_bit_size = nir_alu_type_get_type_size(b); 198 199 /* This requires sized types */ 200 assert(a_bit_size > 0 && b_bit_size > 0); 201 202 if (a_base_type == b_base_type && a_bit_size >= b_bit_size) 203 return true; 204 205 if (a_base_type == nir_type_int && b_base_type == nir_type_uint && 206 a_bit_size > b_bit_size) 207 return true; 208 209 /* 16-bit floats fit in 32-bit integers */ 210 if (a_base_type == nir_type_int && a_bit_size >= 32 && 211 b == nir_type_float16) 212 return true; 213 214 /* All signed or unsigned ints can fit in float or above. A uint8 can fit 215 * in a float16. 216 */ 217 if (a_base_type == nir_type_float && b_base_type != nir_type_float && 218 (a_bit_size >= 32 || b_bit_size == 8)) 219 return true; 220 221 return false; 222} 223 224/** 225 * Retrieves limits used for clamping a value of the src type into 226 * the widest representable range of the dst type via cmp + bcsel 227 */ 228static inline void 229nir_get_clamp_limits(nir_builder *b, 230 nir_alu_type src_type, 231 nir_alu_type dest_type, 232 nir_ssa_def **low, nir_ssa_def **high) 233{ 234 /* Split types from bit sizes */ 235 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); 236 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); 237 unsigned src_bit_size = nir_alu_type_get_type_size(src_type); 238 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); 239 assert(dest_bit_size != 0 && src_bit_size != 0); 240 241 *low = NULL; 242 *high = NULL; 243 244 /* limits of the destination type, expressed in the source type */ 245 switch (dest_base_type) { 246 case nir_type_int: { 247 int64_t ilow, ihigh; 248 if (dest_bit_size == 64) { 249 ilow = INT64_MIN; 250 ihigh = INT64_MAX; 251 } else { 252 ilow = -(1ll << (dest_bit_size - 1)); 253 ihigh = (1ll << (dest_bit_size - 1)) - 1; 254 } 255 256 if (src_base_type == nir_type_int) { 257 *low = nir_imm_intN_t(b, ilow, src_bit_size); 258 *high = nir_imm_intN_t(b, ihigh, src_bit_size); 259 } else if (src_base_type == nir_type_uint) { 260 assert(src_bit_size >= dest_bit_size); 261 *high = nir_imm_intN_t(b, ihigh, src_bit_size); 262 } else { 263 *low = nir_imm_floatN_t(b, ilow, src_bit_size); 264 *high = nir_imm_floatN_t(b, ihigh, src_bit_size); 265 } 266 break; 267 } 268 case nir_type_uint: { 269 uint64_t uhigh = dest_bit_size == 64 ? 270 ~0ull : (1ull << dest_bit_size) - 1; 271 if (src_base_type != nir_type_float) { 272 *low = nir_imm_intN_t(b, 0, src_bit_size); 273 if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size) 274 *high = nir_imm_intN_t(b, uhigh, src_bit_size); 275 } else { 276 *low = nir_imm_floatN_t(b, 0.0f, src_bit_size); 277 *high = nir_imm_floatN_t(b, uhigh, src_bit_size); 278 } 279 break; 280 } 281 case nir_type_float: { 282 double flow, fhigh; 283 switch (dest_bit_size) { 284 case 16: 285 flow = -65504.0f; 286 fhigh = 65504.0f; 287 break; 288 case 32: 289 flow = -FLT_MAX; 290 fhigh = FLT_MAX; 291 break; 292 case 64: 293 flow = -DBL_MAX; 294 fhigh = DBL_MAX; 295 break; 296 default: 297 unreachable("Unhandled bit size"); 298 } 299 300 switch (src_base_type) { 301 case nir_type_int: { 302 int64_t src_ilow, src_ihigh; 303 if (src_bit_size == 64) { 304 src_ilow = INT64_MIN; 305 src_ihigh = INT64_MAX; 306 } else { 307 src_ilow = -(1ll << (src_bit_size - 1)); 308 src_ihigh = (1ll << (src_bit_size - 1)) - 1; 309 } 310 if (src_ilow < flow) 311 *low = nir_imm_intN_t(b, flow, src_bit_size); 312 if (src_ihigh > fhigh) 313 *high = nir_imm_intN_t(b, fhigh, src_bit_size); 314 break; 315 } 316 case nir_type_uint: { 317 uint64_t src_uhigh = src_bit_size == 64 ? 318 ~0ull : (1ull << src_bit_size) - 1; 319 if (src_uhigh > fhigh) 320 *high = nir_imm_intN_t(b, fhigh, src_bit_size); 321 break; 322 } 323 case nir_type_float: 324 *low = nir_imm_floatN_t(b, flow, src_bit_size); 325 *high = nir_imm_floatN_t(b, fhigh, src_bit_size); 326 break; 327 default: 328 unreachable("Clamping from unknown type"); 329 } 330 break; 331 } 332 default: 333 unreachable("clamping to unknown type"); 334 break; 335 } 336} 337 338/** 339 * Clamp the value into the widest representatble range of the 340 * destination type with cmp + bcsel. 341 * 342 * val/val_type: The variables used for bcsel 343 * src/src_type: The variables used for comparison 344 * dest_type: The type which determines the range used for comparison 345 */ 346static inline nir_ssa_def * 347nir_clamp_to_type_range(nir_builder *b, 348 nir_ssa_def *val, nir_alu_type val_type, 349 nir_ssa_def *src, nir_alu_type src_type, 350 nir_alu_type dest_type) 351{ 352 assert(nir_alu_type_get_type_size(src_type) == 0 || 353 nir_alu_type_get_type_size(src_type) == src->bit_size); 354 src_type |= src->bit_size; 355 if (nir_alu_type_range_contains_type_range(dest_type, src_type)) 356 return val; 357 358 /* limits of the destination type, expressed in the source type */ 359 nir_ssa_def *low = NULL, *high = NULL; 360 nir_get_clamp_limits(b, src_type, dest_type, &low, &high); 361 362 nir_ssa_def *low_cond = NULL, *high_cond = NULL; 363 switch (nir_alu_type_get_base_type(src_type)) { 364 case nir_type_int: 365 low_cond = low ? nir_ilt(b, src, low) : NULL; 366 high_cond = high ? nir_ilt(b, high, src) : NULL; 367 break; 368 case nir_type_uint: 369 low_cond = low ? nir_ult(b, src, low) : NULL; 370 high_cond = high ? nir_ult(b, high, src) : NULL; 371 break; 372 case nir_type_float: 373 low_cond = low ? nir_fge(b, low, src) : NULL; 374 high_cond = high ? nir_fge(b, src, high) : NULL; 375 break; 376 default: 377 unreachable("clamping from unknown type"); 378 } 379 380 nir_ssa_def *val_low = low, *val_high = high; 381 if (val_type != src_type) { 382 nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high); 383 } 384 385 nir_ssa_def *res = val; 386 if (low_cond && val_low) 387 res = nir_bcsel(b, low_cond, val_low, res); 388 if (high_cond && val_high) 389 res = nir_bcsel(b, high_cond, val_high, res); 390 391 return res; 392} 393 394static inline nir_rounding_mode 395nir_simplify_conversion_rounding(nir_alu_type src_type, 396 nir_alu_type dest_type, 397 nir_rounding_mode rounding) 398{ 399 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); 400 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); 401 unsigned src_bit_size = nir_alu_type_get_type_size(src_type); 402 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); 403 assert(src_bit_size > 0 && dest_bit_size > 0); 404 405 if (rounding == nir_rounding_mode_undef) 406 return rounding; 407 408 /* Pure integer conversion doesn't have any rounding */ 409 if (src_base_type != nir_type_float && 410 dest_base_type != nir_type_float) 411 return nir_rounding_mode_undef; 412 413 /* Float down-casts don't round */ 414 if (src_base_type == nir_type_float && 415 dest_base_type == nir_type_float && 416 dest_bit_size >= src_bit_size) 417 return nir_rounding_mode_undef; 418 419 /* Regular float to int conversions are RTZ */ 420 if (src_base_type == nir_type_float && 421 dest_base_type != nir_type_float && 422 rounding == nir_rounding_mode_rtz) 423 return nir_rounding_mode_undef; 424 425 /* The CL spec requires regular conversions to float to be RTNE */ 426 if (dest_base_type == nir_type_float && 427 rounding == nir_rounding_mode_rtne) 428 return nir_rounding_mode_undef; 429 430 /* Couldn't simplify */ 431 return rounding; 432} 433 434static inline nir_ssa_def * 435nir_convert_with_rounding(nir_builder *b, 436 nir_ssa_def *src, nir_alu_type src_type, 437 nir_alu_type dest_type, 438 nir_rounding_mode round, 439 bool clamp) 440{ 441 /* Some stuff wants sized types */ 442 assert(nir_alu_type_get_type_size(src_type) == 0 || 443 nir_alu_type_get_type_size(src_type) == src->bit_size); 444 src_type |= src->bit_size; 445 446 /* Split types from bit sizes */ 447 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); 448 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); 449 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); 450 451 /* Try to simplify the conversion if we can */ 452 clamp = clamp && 453 !nir_alu_type_range_contains_type_range(dest_type, src_type); 454 round = nir_simplify_conversion_rounding(src_type, dest_type, round); 455 456 /* For float -> int/uint conversions, we might not be able to represent 457 * the destination range in the source float accurately. For these cases, 458 * do the comparison in float range, but the bcsel in the destination range. 459 */ 460 bool clamp_after_conversion = clamp && 461 src_base_type == nir_type_float && 462 dest_base_type != nir_type_float; 463 464 /* 465 * If we don't care about rounding and clamping, we can just use NIR's 466 * built-in ops. There is also a special case for SPIR-V in shaders, where 467 * f32/f64 -> f16 conversions can have one of two rounding modes applied, 468 * which NIR has built-in opcodes for. 469 * 470 * For the rest, we have our own implementation of rounding and clamping. 471 */ 472 bool trivial_convert; 473 if (!clamp && round == nir_rounding_mode_undef) { 474 trivial_convert = true; 475 } else if (!clamp && src_type == nir_type_float32 && 476 dest_type == nir_type_float16 && 477 (round == nir_rounding_mode_rtne || 478 round == nir_rounding_mode_rtz)) { 479 trivial_convert = true; 480 } else { 481 trivial_convert = false; 482 } 483 if (trivial_convert) { 484 nir_op op = nir_type_conversion_op(src_type, dest_type, round); 485 return nir_build_alu(b, op, src, NULL, NULL, NULL); 486 } 487 488 nir_ssa_def *dest = src; 489 490 /* clamp the result into range */ 491 if (clamp && !clamp_after_conversion) 492 dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type); 493 494 /* round with selected rounding mode */ 495 if (!trivial_convert && round != nir_rounding_mode_undef) { 496 if (src_base_type == nir_type_float) { 497 if (dest_base_type == nir_type_float) { 498 dest = nir_round_float_to_float(b, dest, dest_bit_size, round); 499 } else { 500 dest = nir_round_float_to_int(b, dest, round); 501 } 502 } else { 503 dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round); 504 } 505 506 round = nir_rounding_mode_undef; 507 } 508 509 /* now we can convert the value */ 510 nir_op op = nir_type_conversion_op(src_type, dest_type, round); 511 dest = nir_build_alu(b, op, dest, NULL, NULL, NULL); 512 513 if (clamp_after_conversion) 514 dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type); 515 516 return dest; 517} 518 519#ifdef __cplusplus 520} 521#endif 522 523#endif /* NIR_CONVERSION_BUILDER_H */ 524