1/*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23
24/**
25 * \file opt_minmax.cpp
26 *
27 * Drop operands from an expression tree of only min/max operations if they
28 * can be proven to not contribute to the final result.
29 *
30 * The algorithm is similar to alpha-beta pruning on a minmax search.
31 */
32
33#include "ir.h"
34#include "ir_visitor.h"
35#include "ir_rvalue_visitor.h"
36#include "ir_optimization.h"
37#include "ir_builder.h"
38#include "program/prog_instruction.h"
39#include "compiler/glsl_types.h"
40#include "main/macros.h"
41#include "util/half_float.h"
42
43using namespace ir_builder;
44
45namespace {
46
/* Outcome of compare_components().  Callers classify results with ordered
 * tests such as `cr < EQUAL' or `cr >= EQUAL && cr != MIXED', so the
 * enumerators are numbered explicitly and must keep this order.
 */
enum compare_components_result {
   LESS = 0,
   LESS_OR_EQUAL = 1,
   EQUAL = 2,
   GREATER_OR_EQUAL = 3,
   GREATER = 4,
   MIXED = 5
};
55
56class minmax_range {
57public:
58   minmax_range(ir_constant *low = NULL, ir_constant *high = NULL)
59   {
60      this->low = low;
61      this->high = high;
62   }
63
64   /* low is the lower limit of the range, high is the higher limit. NULL on
65    * low means negative infinity (unlimited) and on high positive infinity
66    * (unlimited). Because of the two interpretations of the value NULL,
67    * arbitrary comparison between ir_constants is impossible.
68    */
69   ir_constant *low;
70   ir_constant *high;
71};
72
73class ir_minmax_visitor : public ir_rvalue_enter_visitor {
74public:
75   ir_minmax_visitor()
76      : progress(false)
77   {
78   }
79
80   ir_rvalue *prune_expression(ir_expression *expr, minmax_range baserange);
81
82   void handle_rvalue(ir_rvalue **rvalue);
83
84   bool progress;
85};
86
87/*
88 * Returns LESS if all vector components of `a' are strictly lower than of `b',
89 * GREATER if all vector components of `a' are strictly greater than of `b',
90 * MIXED if some vector components of `a' are strictly lower than of `b' while
91 * others are strictly greater, or EQUAL otherwise.
92 */
93static enum compare_components_result
94compare_components(ir_constant *a, ir_constant *b)
95{
96   assert(a != NULL);
97   assert(b != NULL);
98
99   assert(a->type->base_type == b->type->base_type);
100
101   unsigned a_inc = a->type->is_scalar() ? 0 : 1;
102   unsigned b_inc = b->type->is_scalar() ? 0 : 1;
103   unsigned components = MAX2(a->type->components(), b->type->components());
104
105   bool foundless = false;
106   bool foundgreater = false;
107   bool foundequal = false;
108
109   for (unsigned i = 0, c0 = 0, c1 = 0;
110        i < components;
111        c0 += a_inc, c1 += b_inc, ++i) {
112      switch (a->type->base_type) {
113      case GLSL_TYPE_UINT16:
114         if (a->value.u16[c0] < b->value.u16[c1])
115            foundless = true;
116         else if (a->value.u16[c0] > b->value.u16[c1])
117            foundgreater = true;
118         else
119            foundequal = true;
120         break;
121      case GLSL_TYPE_INT16:
122         if (a->value.i16[c0] < b->value.i16[c1])
123            foundless = true;
124         else if (a->value.i16[c0] > b->value.i16[c1])
125            foundgreater = true;
126         else
127            foundequal = true;
128         break;
129      case GLSL_TYPE_UINT:
130         if (a->value.u[c0] < b->value.u[c1])
131            foundless = true;
132         else if (a->value.u[c0] > b->value.u[c1])
133            foundgreater = true;
134         else
135            foundequal = true;
136         break;
137      case GLSL_TYPE_INT:
138         if (a->value.i[c0] < b->value.i[c1])
139            foundless = true;
140         else if (a->value.i[c0] > b->value.i[c1])
141            foundgreater = true;
142         else
143            foundequal = true;
144         break;
145      case GLSL_TYPE_FLOAT16: {
146         float af = _mesa_half_to_float(a->value.f16[c0]);
147         float bf = _mesa_half_to_float(b->value.f16[c1]);
148         if (af < bf)
149            foundless = true;
150         else if (af > bf)
151            foundgreater = true;
152         else
153            foundequal = true;
154         break;
155      }
156      case GLSL_TYPE_FLOAT:
157         if (a->value.f[c0] < b->value.f[c1])
158            foundless = true;
159         else if (a->value.f[c0] > b->value.f[c1])
160            foundgreater = true;
161         else
162            foundequal = true;
163         break;
164      case GLSL_TYPE_DOUBLE:
165         if (a->value.d[c0] < b->value.d[c1])
166            foundless = true;
167         else if (a->value.d[c0] > b->value.d[c1])
168            foundgreater = true;
169         else
170            foundequal = true;
171         break;
172      default:
173         unreachable("not reached");
174      }
175   }
176
177   if (foundless && foundgreater) {
178      /* Some components are strictly lower, others are strictly greater */
179      return MIXED;
180   }
181
182   if (foundequal) {
183       /* It is not mixed, but it is not strictly lower or greater */
184      if (foundless)
185         return LESS_OR_EQUAL;
186      if (foundgreater)
187         return GREATER_OR_EQUAL;
188      return EQUAL;
189   }
190
191   /* All components are strictly lower or strictly greater */
192   return foundless ? LESS : GREATER;
193}
194
195static ir_constant *
196combine_constant(bool ismin, ir_constant *a, ir_constant *b)
197{
198   void *mem_ctx = ralloc_parent(a);
199   ir_constant *c = a->clone(mem_ctx, NULL);
200   for (unsigned i = 0; i < c->type->components(); i++) {
201      switch (c->type->base_type) {
202      case GLSL_TYPE_UINT16:
203         if ((ismin && b->value.u16[i] < c->value.u16[i]) ||
204             (!ismin && b->value.u16[i] > c->value.u16[i]))
205            c->value.u16[i] = b->value.u16[i];
206         break;
207      case GLSL_TYPE_INT16:
208         if ((ismin && b->value.i16[i] < c->value.i16[i]) ||
209             (!ismin && b->value.i16[i] > c->value.i16[i]))
210            c->value.i16[i] = b->value.i16[i];
211         break;
212      case GLSL_TYPE_UINT:
213         if ((ismin && b->value.u[i] < c->value.u[i]) ||
214             (!ismin && b->value.u[i] > c->value.u[i]))
215            c->value.u[i] = b->value.u[i];
216         break;
217      case GLSL_TYPE_INT:
218         if ((ismin && b->value.i[i] < c->value.i[i]) ||
219             (!ismin && b->value.i[i] > c->value.i[i]))
220            c->value.i[i] = b->value.i[i];
221         break;
222      case GLSL_TYPE_FLOAT16: {
223         float bf = _mesa_half_to_float(b->value.f16[i]);
224         float cf = _mesa_half_to_float(c->value.f16[i]);
225         if ((ismin && bf < cf) || (!ismin && bf > cf))
226            c->value.f16[i] = b->value.f16[i];
227         break;
228      }
229      case GLSL_TYPE_FLOAT:
230         if ((ismin && b->value.f[i] < c->value.f[i]) ||
231             (!ismin && b->value.f[i] > c->value.f[i]))
232            c->value.f[i] = b->value.f[i];
233         break;
234      case GLSL_TYPE_DOUBLE:
235         if ((ismin && b->value.d[i] < c->value.d[i]) ||
236             (!ismin && b->value.d[i] > c->value.d[i]))
237            c->value.d[i] = b->value.d[i];
238         break;
239      default:
240         assert(!"not reached");
241      }
242   }
243   return c;
244}
245
246static ir_constant *
247smaller_constant(ir_constant *a, ir_constant *b)
248{
249   assert(a != NULL);
250   assert(b != NULL);
251
252   enum compare_components_result ret = compare_components(a, b);
253   if (ret == MIXED)
254      return combine_constant(true, a, b);
255   else if (ret < EQUAL)
256      return a;
257   else
258      return b;
259}
260
261static ir_constant *
262larger_constant(ir_constant *a, ir_constant *b)
263{
264   assert(a != NULL);
265   assert(b != NULL);
266
267   enum compare_components_result ret = compare_components(a, b);
268   if (ret == MIXED)
269      return combine_constant(false, a, b);
270   else if (ret < EQUAL)
271      return b;
272   else
273      return a;
274}
275
276/* Combines two ranges by doing an element-wise min() / max() depending on the
277 * operation.
278 */
279static minmax_range
280combine_range(minmax_range r0, minmax_range r1, bool ismin)
281{
282   minmax_range ret;
283
284   if (!r0.low) {
285      ret.low = ismin ? r0.low : r1.low;
286   } else if (!r1.low) {
287      ret.low = ismin ? r1.low : r0.low;
288   } else {
289      ret.low = ismin ? smaller_constant(r0.low, r1.low) :
290         larger_constant(r0.low, r1.low);
291   }
292
293   if (!r0.high) {
294      ret.high = ismin ? r1.high : r0.high;
295   } else if (!r1.high) {
296      ret.high = ismin ? r0.high : r1.high;
297   } else {
298      ret.high = ismin ? smaller_constant(r0.high, r1.high) :
299         larger_constant(r0.high, r1.high);
300   }
301
302   return ret;
303}
304
305/* Returns a range so that lower limit is the larger of the two lower limits,
306 * and higher limit is the smaller of the two higher limits.
307 */
308static minmax_range
309range_intersection(minmax_range r0, minmax_range r1)
310{
311   minmax_range ret;
312
313   if (!r0.low)
314      ret.low = r1.low;
315   else if (!r1.low)
316      ret.low = r0.low;
317   else
318      ret.low = larger_constant(r0.low, r1.low);
319
320   if (!r0.high)
321      ret.high = r1.high;
322   else if (!r1.high)
323      ret.high = r0.high;
324   else
325      ret.high = smaller_constant(r0.high, r1.high);
326
327   return ret;
328}
329
330static minmax_range
331get_range(ir_rvalue *rval)
332{
333   ir_expression *expr = rval->as_expression();
334   if (expr && (expr->operation == ir_binop_min ||
335                expr->operation == ir_binop_max)) {
336      minmax_range r0 = get_range(expr->operands[0]);
337      minmax_range r1 = get_range(expr->operands[1]);
338      return combine_range(r0, r1, expr->operation == ir_binop_min);
339   }
340
341   ir_constant *c = rval->as_constant();
342   if (c) {
343      return minmax_range(c, c);
344   }
345
346   return minmax_range();
347}
348
349/**
350 * Prunes a min/max expression considering the base range of the parent
351 * min/max expression.
352 *
353 * @param baserange the range that the parents of this min/max expression
354 * in the min/max tree will clamp its value to.
355 */
356ir_rvalue *
357ir_minmax_visitor::prune_expression(ir_expression *expr, minmax_range baserange)
358{
359   assert(expr->operation == ir_binop_min ||
360          expr->operation == ir_binop_max);
361
362   bool ismin = expr->operation == ir_binop_min;
363   minmax_range limits[2];
364
365   /* Recurse to get the ranges for each of the subtrees of this
366    * expression. We need to do this as a separate step because we need to
367    * know the ranges of each of the subtrees before we prune either one.
368    * Consider something like this:
369    *
370    *        max
371    *     /       \
372    *    max     max
373    *   /   \   /   \
374    *  3    a   b    2
375    *
376    * We would like to prune away the max on the bottom-right, but to do so
377    * we need to know the range of the expression on the left beforehand,
378    * and there's no guarantee that we will visit either subtree in a
379    * particular order.
380    */
381   for (unsigned i = 0; i < 2; ++i)
382      limits[i] = get_range(expr->operands[i]);
383
384   for (unsigned i = 0; i < 2; ++i) {
385      bool is_redundant = false;
386
387      enum compare_components_result cr = LESS;
388      if (ismin) {
389         /* If this operand will always be greater than the other one, it's
390          * redundant.
391          */
392         if (limits[i].low && limits[1 - i].high) {
393               cr = compare_components(limits[i].low, limits[1 - i].high);
394            if (cr >= EQUAL && cr != MIXED)
395               is_redundant = true;
396         }
397         /* If this operand is always greater than baserange, then even if
398          * it's smaller than the other one it'll get clamped, so it's
399          * redundant.
400          */
401         if (!is_redundant && limits[i].low && baserange.high) {
402            cr = compare_components(limits[i].low, baserange.high);
403            if (cr > EQUAL && cr != MIXED)
404               is_redundant = true;
405         }
406      } else {
407         /* If this operand will always be lower than the other one, it's
408          * redundant.
409          */
410         if (limits[i].high && limits[1 - i].low) {
411            cr = compare_components(limits[i].high, limits[1 - i].low);
412            if (cr <= EQUAL)
413               is_redundant = true;
414         }
415         /* If this operand is always lower than baserange, then even if
416          * it's greater than the other one it'll get clamped, so it's
417          * redundant.
418          */
419         if (!is_redundant && limits[i].high && baserange.low) {
420            cr = compare_components(limits[i].high, baserange.low);
421            if (cr < EQUAL)
422               is_redundant = true;
423         }
424      }
425
426      if (is_redundant) {
427         progress = true;
428
429         /* Recurse if necessary. */
430         ir_expression *op_expr = expr->operands[1 - i]->as_expression();
431         if (op_expr && (op_expr->operation == ir_binop_min ||
432                         op_expr->operation == ir_binop_max)) {
433            return prune_expression(op_expr, baserange);
434         }
435
436         return expr->operands[1 - i];
437      } else if (cr == MIXED) {
438         /* If we have mixed vector operands, we can try to resolve the minmax
439          * expression by doing a component-wise minmax:
440          *
441          *             min                          min
442          *           /    \                       /    \
443          *         min     a       ===>        [1,1]    a
444          *       /    \
445          *    [1,3]   [3,1]
446          *
447          */
448         ir_constant *a = expr->operands[0]->as_constant();
449         ir_constant *b = expr->operands[1]->as_constant();
450         if (a && b)
451            return combine_constant(ismin, a, b);
452      }
453   }
454
455   /* Now recurse to operands giving them the proper baserange. The baserange
456    * to pass is the intersection of our baserange and the other operand's
457    * limit with one of the ranges unlimited. If we can't compute a valid
458    * intersection, we use the current baserange.
459    */
460   for (unsigned i = 0; i < 2; ++i) {
461      ir_expression *op_expr = expr->operands[i]->as_expression();
462      if (op_expr && (op_expr->operation == ir_binop_min ||
463                      op_expr->operation == ir_binop_max)) {
464         /* We can only compute a new baserange for this operand if we managed
465          * to compute a valid range for the other operand.
466          */
467         if (ismin)
468            limits[1 - i].low = NULL;
469         else
470            limits[1 - i].high = NULL;
471         minmax_range base = range_intersection(limits[1 - i], baserange);
472         expr->operands[i] = prune_expression(op_expr, base);
473      }
474   }
475
476   /* If we got here we could not discard any of the operands of the minmax
477    * expression, but we can still try to resolve the expression if both
478    * operands are constant. We do this after the loop above, to make sure
479    * that if our operands are minmax expressions we have tried to prune them
480    * first (hopefully reducing them to constants).
481    */
482   ir_constant *a = expr->operands[0]->as_constant();
483   ir_constant *b = expr->operands[1]->as_constant();
484   if (a && b)
485      return combine_constant(ismin, a, b);
486
487   return expr;
488}
489
490static ir_rvalue *
491swizzle_if_required(ir_expression *expr, ir_rvalue *rval)
492{
493   if (expr->type->is_vector() && rval->type->is_scalar()) {
494      return swizzle(rval, SWIZZLE_XXXX, expr->type->vector_elements);
495   } else {
496      return rval;
497   }
498}
499
500void
501ir_minmax_visitor::handle_rvalue(ir_rvalue **rvalue)
502{
503   if (!*rvalue)
504      return;
505
506   ir_expression *expr = (*rvalue)->as_expression();
507   if (!expr || (expr->operation != ir_binop_min &&
508                 expr->operation != ir_binop_max))
509      return;
510
511   ir_rvalue *new_rvalue = prune_expression(expr, minmax_range());
512   if (new_rvalue == *rvalue)
513      return;
514
515   /* If the expression type is a vector and the optimization leaves a scalar
516    * as the result, we need to turn it into a vector.
517    */
518   *rvalue = swizzle_if_required(expr, new_rvalue);
519
520   progress = true;
521}
522
523}
524
525bool
526do_minmax_prune(exec_list *instructions)
527{
528   ir_minmax_visitor v;
529
530   visit_list_elements(&v, instructions);
531
532   return v.progress;
533}
534