1/*
2 * Copyright © 2015 Intel Corporation
3 * Copyright © 2019 Valve Corporation
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 * IN THE SOFTWARE.
23 *
24 * Authors:
25 *    Jason Ekstrand (jason@jlekstrand.net)
 *    Samuel Pitoiset (samuel.pitoiset@gmail.com)
27 */
28
29#include "nir.h"
30#include "nir_builder.h"
31
32static nir_ssa_def *
33lower_frexp_sig(nir_builder *b, nir_ssa_def *x)
34{
35   nir_ssa_def *abs_x = nir_fabs(b, x);
36   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
37   nir_ssa_def *sign_mantissa_mask, *exponent_value;
38
39   switch (x->bit_size) {
40   case 16:
41      /* Half-precision floating-point values are stored as
42       *   1 sign bit;
43       *   5 exponent bits;
44       *   10 mantissa bits.
45       *
46       * An exponent shift of 10 will shift the mantissa out, leaving only the
47       * exponent and sign bit (which itself may be zero, if the absolute value
48       * was taken before the bitcast and shift).
49       */
50      sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
51      /* Exponent of floating-point values in the range [0.5, 1.0). */
52      exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
53      break;
54   case 32:
55      /* Single-precision floating-point values are stored as
56       *   1 sign bit;
57       *   8 exponent bits;
58       *   23 mantissa bits.
59       *
60       * An exponent shift of 23 will shift the mantissa out, leaving only the
61       * exponent and sign bit (which itself may be zero, if the absolute value
62       * was taken before the bitcast and shift.
63       */
64      sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
65      /* Exponent of floating-point values in the range [0.5, 1.0). */
66      exponent_value = nir_imm_int(b, 0x3f000000u);
67      break;
68   case 64:
69      /* Double-precision floating-point values are stored as
70       *   1 sign bit;
71       *   11 exponent bits;
72       *   52 mantissa bits.
73       *
74       * An exponent shift of 20 will shift the remaining mantissa bits out,
75       * leaving only the exponent and sign bit (which itself may be zero, if
76       * the absolute value was taken before the bitcast and shift.
77       */
78      sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
79      /* Exponent of floating-point values in the range [0.5, 1.0). */
80      exponent_value = nir_imm_int(b, 0x3fe00000u);
81      break;
82   default:
83      unreachable("Invalid bitsize");
84   }
85
86   if (x->bit_size == 64) {
87      /* We only need to deal with the exponent so first we extract the upper
88       * 32 bits using nir_unpack_64_2x32_split_y.
89       */
90      nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
91
92      /* If x is ±0, ±Inf, or NaN, return x unmodified. */
93      nir_ssa_def *new_upper =
94         nir_bcsel(b,
95                   nir_iand(b,
96                            nir_flt(b, zero, abs_x),
97                            nir_fisfinite(b, x)),
98                   nir_ior(b,
99                           nir_iand(b, upper_x, sign_mantissa_mask),
100                           exponent_value),
101                   upper_x);
102
103      nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
104
105      return nir_pack_64_2x32_split(b, lower_x, new_upper);
106   } else {
107      /* If x is ±0, ±Inf, or NaN, return x unmodified. */
108      return nir_bcsel(b,
109                       nir_iand(b,
110                                nir_flt(b, zero, abs_x),
111                                nir_fisfinite(b, x)),
112                       nir_ior(b,
113                               nir_iand(b, x, sign_mantissa_mask),
114                               exponent_value),
115                       x);
116   }
117}
118
119static nir_ssa_def *
120lower_frexp_exp(nir_builder *b, nir_ssa_def *x)
121{
122   nir_ssa_def *abs_x = nir_fabs(b, x);
123   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
124   nir_ssa_def *is_not_zero = nir_fneu(b, abs_x, zero);
125   nir_ssa_def *exponent;
126
127   switch (x->bit_size) {
128   case 16: {
129      nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
130      nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
131
132      /* Significand return must be of the same type as the input, but the
133       * exponent must be a 32-bit integer.
134       */
135      exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
136                              nir_bcsel(b, is_not_zero, exponent_bias, zero)));
137      break;
138   }
139   case 32: {
140      nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
141      nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
142
143      exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
144                             nir_bcsel(b, is_not_zero, exponent_bias, zero));
145      break;
146   }
147   case 64: {
148      nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
149      nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
150
151      nir_ssa_def *zero32 = nir_imm_int(b, 0);
152      nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
153
154      exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
155                             nir_bcsel(b, is_not_zero, exponent_bias, zero32));
156      break;
157   }
158   default:
159      unreachable("Invalid bitsize");
160   }
161
162   return exponent;
163}
164
165static bool
166lower_frexp_impl(nir_function_impl *impl)
167{
168   bool progress = false;
169
170   nir_builder b;
171   nir_builder_init(&b, impl);
172
173   nir_foreach_block(block, impl) {
174      nir_foreach_instr_safe(instr, block) {
175         if (instr->type != nir_instr_type_alu)
176            continue;
177
178         nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
179         nir_ssa_def *lower;
180
181         b.cursor = nir_before_instr(instr);
182
183         switch (alu_instr->op) {
184         case nir_op_frexp_sig:
185            lower = lower_frexp_sig(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
186            break;
187         case nir_op_frexp_exp:
188            lower = lower_frexp_exp(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
189            break;
190         default:
191            continue;
192         }
193
194         nir_ssa_def_rewrite_uses(&alu_instr->dest.dest.ssa,
195                                  lower);
196         nir_instr_remove(instr);
197         progress = true;
198      }
199   }
200
201   if (progress) {
202      nir_metadata_preserve(impl, nir_metadata_block_index |
203                                  nir_metadata_dominance);
204   }
205
206   return progress;
207}
208
209bool
210nir_lower_frexp(nir_shader *shader)
211{
212   bool progress = false;
213
214   nir_foreach_function(function, shader) {
215      if (function->impl)
216         progress |= lower_frexp_impl(function->impl);
217   }
218
219   return progress;
220}
221