/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir_builder.h"

/* The following float-to-half conversion routines are based on the "half" library:
 * https://sourceforge.net/projects/half/
 *
 * half - IEEE 754-based half-precision floating-point library.
 *
 * Copyright (c) 2012-2019 Christian Rau <rauy@users.sourceforge.net>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
 * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
 * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Version 2.1.0
 */

/* Match the instructions this pass lowers: f2f16 ALU casts (any rounding
 * mode) and convert_alu_types intrinsics with a float16 destination. */
static bool
lower_fp16_casts_filter(const nir_instr *instr, const void *data)
{
   if (instr->type == nir_instr_type_alu) {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      switch (alu->op) {
      case nir_op_f2f16:
      case nir_op_f2f16_rtne:
      case nir_op_f2f16_rtz:
         return true;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return intrin->intrinsic == nir_intrinsic_convert_alu_types &&
             nir_intrinsic_dest_type(intrin) == nir_type_float16;
   }
   return false;
}

/* Round a truncated fp16 bit pattern. "guard" is the highest discarded
 * mantissa bit and "sticky" is nonzero if any lower discarded bit was set. */
static nir_ssa_def *
half_rounded(nir_builder *b, nir_ssa_def *value, nir_ssa_def *guard, nir_ssa_def *sticky,
             nir_ssa_def *sign, nir_rounding_mode mode)
{
   switch (mode) {
   case nir_rounding_mode_rtne:
      /* Round up when the guard bit is set and either sticky or the LSB of
       * value is set (ties to even) */
      return nir_iadd(b, value, nir_iand(b, guard, nir_ior(b, sticky, value)));
   case nir_rounding_mode_ru:
      sign = nir_ushr(b, sign, nir_imm_int(b, 31));
      /* Positive inexact values round up in magnitude (toward +INF) */
      return nir_iadd(b, value, nir_iand(b, nir_inot(b, sign),
                                         nir_ior(b, guard, sticky)));
   case nir_rounding_mode_rd:
      sign = nir_ushr(b, sign, nir_imm_int(b, 31));
      /* Negative inexact values round up in magnitude (toward -INF) */
      return nir_iadd(b, value, nir_iand(b, sign,
                                         nir_ior(b, guard, sticky)));
   default:
      return value;
   }
}

static nir_ssa_def *
float_to_half_impl(nir_builder *b, nir_ssa_def *src, nir_rounding_mode mode)
{
   nir_ssa_def *f32infinity = nir_imm_int(b, 255 << 23);
   nir_ssa_def *f16max = nir_imm_int(b, (127 + 16) << 23); /* 2^16, first value too large for fp16 */

   if (src->bit_size == 64)
      src = nir_f2f32(b, src);
   nir_ssa_def *sign = nir_iand(b, src, nir_imm_int(b, 0x80000000));
   nir_ssa_def *one = nir_imm_int(b, 1);
   /* Defined up front, alongside "one", so every branch built below is
    * dominated by it; it's used in the normal, denormal, and underflow arms */
   nir_ssa_def *zero = nir_imm_int(b, 0);

   nir_ssa_def *abs = nir_iand(b, src, nir_imm_int(b, 0x7FFFFFFF));
   /* NaN or INF. For rtne, overflow also becomes INF, so combine the comparisons */
   nir_push_if(b, nir_ige(b, abs, mode == nir_rounding_mode_rtne ? f16max : f32infinity));
   nir_ssa_def *inf_nanfp16 = nir_bcsel(b,
                                        nir_ilt(b, f32infinity, abs),
                                        nir_imm_int(b, 0x7E00),  /* quiet NaN */
                                        nir_imm_int(b, 0x7C00)); /* +INF */
   nir_push_else(b, NULL);

   nir_ssa_def *overflowed_fp16 = NULL;
   if (mode != nir_rounding_mode_rtne) {
      /* Handle overflow */
      nir_push_if(b, nir_ige(b, abs, f16max));
      switch (mode) {
      case nir_rounding_mode_rtz:
         overflowed_fp16 = nir_imm_int(b, 0x7BFF); /* max finite fp16, 65504 */
         break;
      case nir_rounding_mode_ru:
         /* Negative becomes max float, positive becomes inf */
         overflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), nir_imm_int(b, 0x7BFF), nir_imm_int(b, 0x7C00));
         break;
      case nir_rounding_mode_rd:
         /* Negative becomes inf, positive becomes max float */
         overflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), nir_imm_int(b, 0x7C00), nir_imm_int(b, 0x7BFF));
         break;
      default: unreachable("Should've been handled already");
      }
      nir_push_else(b, NULL);
   }

   /* 113 << 23 is the fp32 bit pattern of 2^-14, the smallest normal fp16 */
   nir_push_if(b, nir_ige(b, abs, nir_imm_int(b, 113 << 23)));

   /* FP16 will be normal */
   nir_ssa_def *value = nir_ior(b,
                                nir_ishl(b,
                                         nir_isub(b,
                                                  nir_ushr(b, abs, nir_imm_int(b, 23)),
                                                  nir_imm_int(b, 112)),
                                         nir_imm_int(b, 10)),
                                nir_iand(b, nir_ushr(b, abs, nir_imm_int(b, 13)), nir_imm_int(b, 0x3FFF)));
   nir_ssa_def *guard = nir_iand(b, nir_ushr(b, abs, nir_imm_int(b, 12)), one);
   nir_ssa_def *sticky = nir_bcsel(b, nir_ine(b, nir_iand(b, abs, nir_imm_int(b, 0xFFF)), zero), one, zero);
   nir_ssa_def *normal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);

   nir_push_else(b, NULL);
   /* 102 << 23 is the fp32 bit pattern of 2^-25; values at or above it can
    * still round to a nonzero fp16 denormal */
   nir_push_if(b, nir_ige(b, abs, nir_imm_int(b, 102 << 23)));

   /* FP16 will be denormal */
   nir_ssa_def *i = nir_isub(b, nir_imm_int(b, 125), nir_ushr(b, abs, nir_imm_int(b, 23)));
   /* Restore the implicit leading one, then shift it into the denormal range */
   nir_ssa_def *masked = nir_ior(b, nir_iand(b, abs, nir_imm_int(b, 0x7FFFFF)), nir_imm_int(b, 0x800000));
   value = nir_ushr(b, masked, nir_iadd(b, i, one));
   guard = nir_iand(b, nir_ushr(b, masked, i), one);
   sticky = nir_bcsel(b, nir_ine(b, nir_iand(b, masked, nir_isub(b, nir_ishl(b, one, i), one)), zero), one, zero);
   nir_ssa_def *denormal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);

   nir_push_else(b, NULL);

   /* Handle underflow.
    * Nonzero values need to shift up or down for round-up or round-down */
   nir_ssa_def *underflowed_fp16 = zero;
   if (mode == nir_rounding_mode_ru ||
       mode == nir_rounding_mode_rd) {
      nir_push_if(b, nir_i2b1(b, abs));

      if (mode == nir_rounding_mode_ru)
         underflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), zero, one);
      else
         underflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), one, zero);

      nir_push_else(b, NULL);
      nir_pop_if(b, NULL);
      underflowed_fp16 = nir_if_phi(b, underflowed_fp16, zero);
   }

   nir_pop_if(b, NULL);
   nir_ssa_def *underflowed_or_denorm_fp16 = nir_if_phi(b, denormal_fp16, underflowed_fp16);

   nir_pop_if(b, NULL);
   nir_ssa_def *finite_fp16 = nir_if_phi(b, normal_fp16, underflowed_or_denorm_fp16);

   nir_ssa_def *finite_or_overflowed_fp16 = finite_fp16;
   if (mode != nir_rounding_mode_rtne) {
      nir_pop_if(b, NULL);
      finite_or_overflowed_fp16 = nir_if_phi(b, overflowed_fp16, finite_fp16);
   }

   nir_pop_if(b, NULL);
   nir_ssa_def *fp16 = nir_if_phi(b, inf_nanfp16, finite_or_overflowed_fp16);

   /* Merge the sign (fp32 bit 31) into fp16 bit 15 */
   return nir_u2u16(b, nir_ior(b, fp16, nir_ushr(b, sign, nir_imm_int(b, 16))));
}

static nir_ssa_def *
lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
{
   nir_ssa_def *src, *dst;
   uint8_t *swizzle = NULL;
   nir_rounding_mode mode = nir_rounding_mode_rtne;

   if (instr->type == nir_instr_type_alu) {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      src = alu->src[0].src.ssa;
      swizzle = alu->src[0].swizzle;
      dst = &alu->dest.dest.ssa;
      switch (alu->op) {
      case nir_op_f2f16:
      case nir_op_f2f16_rtne:
         break;
      case nir_op_f2f16_rtz:
         mode = nir_rounding_mode_rtz;
         break;
      default: unreachable("Should've been filtered");
      }
   } else {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      assert(nir_intrinsic_src_type(intrin) == nir_type_float32);
      src = intrin->src[0].ssa;
      dst = &intrin->dest.ssa;
      mode = nir_intrinsic_rounding_mode(intrin);
   }

   nir_ssa_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };

   /* Lower one component at a time; float_to_half_impl emits scalar
    * control flow */
   for (unsigned i = 0; i < dst->num_components; i++) {
      nir_ssa_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i);
      rets[i] = float_to_half_impl(b, comp, mode);
   }

   return nir_vec(b, rets, dst->num_components);
}

/* Lower f2f16 casts to open-coded fp32->fp16 conversions with explicit
 * handling of NaN/INF, overflow, denormals, underflow, and rounding mode. */
bool
nir_lower_fp16_casts(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader,
                                        lower_fp16_casts_filter,
                                        lower_fp16_cast_impl,
                                        NULL);
}
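
/* Usage sketch (illustrative only, not part of this pass): a driver whose
 * hardware lacks native float-to-half conversions would typically call this
 * pass from its NIR lowering code and then clean up the control flow the
 * pass introduces. The function name and the choice of cleanup passes below
 * are assumptions about a typical caller, not requirements of this file:
 *
 *    static void
 *    example_lower_shader(nir_shader *shader)
 *    {
 *       if (nir_lower_fp16_casts(shader)) {
 *          // The lowering emits if/else chains and phis per component;
 *          // standard NIR cleanup turns most of them back into selects.
 *          nir_opt_peephole_select(shader, 64, false, true);
 *          nir_opt_dce(shader);
 *       }
 *    }
 */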