1/*
2 * Copyright (C) 2021 Valve Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 * SOFTWARE.
22 */
23
24#include "ir3.h"
25
26/* Lower several macro-instructions needed for shader subgroup support that
27 * must be turned into if statements. We do this after RA and post-RA
28 * scheduling to give the scheduler a chance to rearrange them, because RA
29 * may need to insert OPC_META_READ_FIRST to handle splitting live ranges, and
30 * also because some (e.g. BALLOT and READ_FIRST) must produce a shared
31 * register that cannot be spilled to a normal register until after the if,
32 * which makes implementing spilling more complicated if they are already
33 * lowered.
34 */
35
36static void
37replace_pred(struct ir3_block *block, struct ir3_block *old_pred,
38             struct ir3_block *new_pred)
39{
40   for (unsigned i = 0; i < block->predecessors_count; i++) {
41      if (block->predecessors[i] == old_pred) {
42         block->predecessors[i] = new_pred;
43         return;
44      }
45   }
46}
47
48static void
49replace_physical_pred(struct ir3_block *block, struct ir3_block *old_pred,
50                      struct ir3_block *new_pred)
51{
52   for (unsigned i = 0; i < block->physical_predecessors_count; i++) {
53      if (block->physical_predecessors[i] == old_pred) {
54         block->physical_predecessors[i] = new_pred;
55         return;
56      }
57   }
58}
59
60static void
61mov_immed(struct ir3_register *dst, struct ir3_block *block, unsigned immed)
62{
63   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
64   struct ir3_register *mov_dst = ir3_dst_create(mov, dst->num, dst->flags);
65   mov_dst->wrmask = dst->wrmask;
66   struct ir3_register *src = ir3_src_create(
67      mov, INVALID_REG, (dst->flags & IR3_REG_HALF) | IR3_REG_IMMED);
68   src->uim_val = immed;
69   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
70   mov->cat1.src_type = mov->cat1.dst_type;
71   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
72}
73
74static void
75mov_reg(struct ir3_block *block, struct ir3_register *dst,
76        struct ir3_register *src)
77{
78   struct ir3_instruction *mov = ir3_instr_create(block, OPC_MOV, 1, 1);
79
80   struct ir3_register *mov_dst =
81      ir3_dst_create(mov, dst->num, dst->flags & (IR3_REG_HALF | IR3_REG_SHARED));
82   struct ir3_register *mov_src =
83      ir3_src_create(mov, src->num, src->flags & (IR3_REG_HALF | IR3_REG_SHARED));
84   mov_dst->wrmask = dst->wrmask;
85   mov_src->wrmask = src->wrmask;
86   mov->repeat = util_last_bit(mov_dst->wrmask) - 1;
87
88   mov->cat1.dst_type = (dst->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
89   mov->cat1.src_type = (src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
90}
91
92static void
93binop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
94      struct ir3_register *src0, struct ir3_register *src1)
95{
96   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 2);
97
98   unsigned flags = dst->flags & IR3_REG_HALF;
99   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
100   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
101   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
102
103   instr_dst->wrmask = dst->wrmask;
104   instr_src0->wrmask = src0->wrmask;
105   instr_src1->wrmask = src1->wrmask;
106   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
107}
108
109static void
110triop(struct ir3_block *block, opc_t opc, struct ir3_register *dst,
111      struct ir3_register *src0, struct ir3_register *src1,
112      struct ir3_register *src2)
113{
114   struct ir3_instruction *instr = ir3_instr_create(block, opc, 1, 3);
115
116   unsigned flags = dst->flags & IR3_REG_HALF;
117   struct ir3_register *instr_dst = ir3_dst_create(instr, dst->num, flags);
118   struct ir3_register *instr_src0 = ir3_src_create(instr, src0->num, flags);
119   struct ir3_register *instr_src1 = ir3_src_create(instr, src1->num, flags);
120   struct ir3_register *instr_src2 = ir3_src_create(instr, src2->num, flags);
121
122   instr_dst->wrmask = dst->wrmask;
123   instr_src0->wrmask = src0->wrmask;
124   instr_src1->wrmask = src1->wrmask;
125   instr_src2->wrmask = src2->wrmask;
126   instr->repeat = util_last_bit(instr_dst->wrmask) - 1;
127}
128
129static void
130do_reduce(struct ir3_block *block, reduce_op_t opc,
131          struct ir3_register *dst, struct ir3_register *src0,
132          struct ir3_register *src1)
133{
134   switch (opc) {
135#define CASE(name)                                                             \
136   case REDUCE_OP_##name:                                                      \
137      binop(block, OPC_##name, dst, src0, src1);                               \
138      break;
139
140   CASE(ADD_U)
141   CASE(ADD_F)
142   CASE(MUL_F)
143   CASE(MIN_U)
144   CASE(MIN_S)
145   CASE(MIN_F)
146   CASE(MAX_U)
147   CASE(MAX_S)
148   CASE(MAX_F)
149   CASE(AND_B)
150   CASE(OR_B)
151   CASE(XOR_B)
152
153#undef CASE
154
155   case REDUCE_OP_MUL_U:
156      if (dst->flags & IR3_REG_HALF) {
157         binop(block, OPC_MUL_S24, dst, src0, src1);
158      } else {
159         /* 32-bit multiplication macro - see ir3_nir_imul */
160         binop(block, OPC_MULL_U, dst, src0, src1);
161         triop(block, OPC_MADSH_M16, dst, src0, src1, dst);
162         triop(block, OPC_MADSH_M16, dst, src1, src0, dst);
163      }
164      break;
165   }
166}
167
168static struct ir3_block *
169split_block(struct ir3 *ir, struct ir3_block *before_block,
170            struct ir3_instruction *instr)
171{
172   struct ir3_block *after_block = ir3_block_create(ir);
173   list_add(&after_block->node, &before_block->node);
174
175   for (unsigned i = 0; i < ARRAY_SIZE(before_block->successors); i++) {
176      after_block->successors[i] = before_block->successors[i];
177      if (after_block->successors[i])
178         replace_pred(after_block->successors[i], before_block, after_block);
179   }
180
181   for (unsigned i = 0; i < ARRAY_SIZE(before_block->physical_successors);
182        i++) {
183      after_block->physical_successors[i] =
184         before_block->physical_successors[i];
185      if (after_block->physical_successors[i]) {
186         replace_physical_pred(after_block->physical_successors[i],
187                               before_block, after_block);
188      }
189   }
190
191   before_block->successors[0] = before_block->successors[1] = NULL;
192   before_block->physical_successors[0] = before_block->physical_successors[1] = NULL;
193
194   foreach_instr_from_safe (rem_instr, &instr->node,
195                            &before_block->instr_list) {
196      list_del(&rem_instr->node);
197      list_addtail(&rem_instr->node, &after_block->instr_list);
198      rem_instr->block = after_block;
199   }
200
201   after_block->brtype = before_block->brtype;
202   after_block->condition = before_block->condition;
203
204   return after_block;
205}
206
207static void
208link_blocks_physical(struct ir3_block *pred, struct ir3_block *succ,
209                     unsigned index)
210{
211   pred->physical_successors[index] = succ;
212   ir3_block_add_physical_predecessor(succ, pred);
213}
214
215static void
216link_blocks(struct ir3_block *pred, struct ir3_block *succ, unsigned index)
217{
218   pred->successors[index] = succ;
219   ir3_block_add_predecessor(succ, pred);
220   link_blocks_physical(pred, succ, index);
221}
222
223static struct ir3_block *
224create_if(struct ir3 *ir, struct ir3_block *before_block,
225          struct ir3_block *after_block)
226{
227   struct ir3_block *then_block = ir3_block_create(ir);
228   list_add(&then_block->node, &before_block->node);
229
230   link_blocks(before_block, then_block, 0);
231   link_blocks(before_block, after_block, 1);
232   link_blocks(then_block, after_block, 0);
233
234   return then_block;
235}
236
/* Lower a single subgroup macro instruction into explicit control flow.
 *
 * If "instr" is not one of the macro opcodes handled here, nothing is changed
 * and false is returned. Otherwise the block *block is split at "instr" (the
 * macro and everything after it move to a new trailing block), new blocks
 * are created between the two halves and filled with the replacement
 * instructions, *block is updated to point at the trailing block, the macro
 * is unlinked from its instruction list, and true is returned.
 *
 * Because the block's instruction list is modified, the caller must not keep
 * iterating over the old list after a true return.
 */
static bool
lower_instr(struct ir3 *ir, struct ir3_block **block, struct ir3_instruction *instr)
{
   /* Filter: only the macros listed here are lowered. */
   switch (instr->opc) {
   case OPC_BALLOT_MACRO:
   case OPC_ANY_MACRO:
   case OPC_ALL_MACRO:
   case OPC_ELECT_MACRO:
   case OPC_READ_COND_MACRO:
   case OPC_READ_FIRST_MACRO:
   case OPC_SWZ_SHARED_MACRO:
   case OPC_SCAN_MACRO:
      break;
   default:
      return false;
   }

   /* After the split, instr is the first instruction of after_block and
    * before_block has no successors until it is re-linked below.
    */
   struct ir3_block *before_block = *block;
   struct ir3_block *after_block = split_block(ir, before_block, instr);

   if (instr->opc == OPC_SCAN_MACRO) {
      /* The pseudo-code for the scan macro is:
       *
       * while (true) {
       *    header:
       *    if (elect()) {
       *       exit:
       *       exclusive = reduce;
       *       inclusive = src OP exclusive;
       *       reduce = inclusive;
       *    }
       *    footer:
       * }
       *
       * This is based on the blob's sequence, and carefully crafted to avoid
       * using the shared register "reduce" except in move instructions, since
       * using it in the actual OP isn't possible for half-registers.
       */
      /* Block list order after these insertions:
       * before_block, header, exit, footer, (after_block).
       */
      struct ir3_block *header = ir3_block_create(ir);
      list_add(&header->node, &before_block->node);

      struct ir3_block *exit = ir3_block_create(ir);
      list_add(&exit->node, &header->node);

      struct ir3_block *footer = ir3_block_create(ir);
      list_add(&footer->node, &exit->node);

      link_blocks(before_block, header, 0);

      /* header branches to exit (successor 0) or footer (successor 1). */
      link_blocks(header, exit, 0);
      link_blocks(header, footer, 1);
      header->brtype = IR3_BRANCH_GETONE;

      /* exit leaves the loop; it additionally gets a physical-only edge to
       * footer.
       */
      link_blocks(exit, after_block, 0);
      link_blocks_physical(exit, footer, 1);

      /* Loop back-edge. */
      link_blocks(footer, header, 0);

      struct ir3_register *exclusive = instr->dsts[0];
      struct ir3_register *inclusive = instr->dsts[1];
      struct ir3_register *reduce = instr->dsts[2];
      struct ir3_register *src = instr->srcs[0];

      /* Body of the elected iteration, matching the pseudo-code above. */
      mov_reg(exit, exclusive, reduce);
      do_reduce(exit, instr->cat1.reduce_op, inclusive, src, exclusive);
      mov_reg(exit, reduce, inclusive);
   } else {
      struct ir3_block *then_block = create_if(ir, before_block, after_block);

      /* For ballot, the destination must be initialized to 0 before we do
       * the movmsk because the condition may be 0 and then the movmsk will
       * be skipped. Because it's a shared register we have to wrap the
       * initialization in a getone block.
       */
      if (instr->opc == OPC_BALLOT_MACRO) {
         before_block->brtype = IR3_BRANCH_GETONE;
         before_block->condition = NULL;
         mov_immed(instr->dsts[0], then_block, 0);
         /* Build a second if after the first one: split again at instr and
          * continue below with the new before/then/after blocks.
          */
         before_block = after_block;
         after_block = split_block(ir, before_block, instr);
         then_block = create_if(ir, before_block, after_block);
      }

      /* Branch condition: the instruction defining srcs[0] for the macros
       * that take a condition source, none for the others.
       */
      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ALL_MACRO:
         before_block->condition = instr->srcs[0]->def->instr;
         break;
      default:
         before_block->condition = NULL;
         break;
      }

      /* Branch type used to enter then_block. */
      switch (instr->opc) {
      case OPC_BALLOT_MACRO:
      case OPC_READ_COND_MACRO:
         before_block->brtype = IR3_BRANCH_COND;
         break;
      case OPC_ANY_MACRO:
         before_block->brtype = IR3_BRANCH_ANY;
         break;
      case OPC_ALL_MACRO:
         before_block->brtype = IR3_BRANCH_ALL;
         break;
      case OPC_ELECT_MACRO:
      case OPC_READ_FIRST_MACRO:
      case OPC_SWZ_SHARED_MACRO:
         before_block->brtype = IR3_BRANCH_GETONE;
         break;
      default:
         unreachable("bad opcode");
      }

      /* Emit the replacement instruction(s). */
      switch (instr->opc) {
      case OPC_ALL_MACRO:
      case OPC_ANY_MACRO:
      case OPC_ELECT_MACRO:
         /* A mov of 1 goes into then_block and a mov of 0 into before_block,
          * both targeting dsts[0].
          */
         mov_immed(instr->dsts[0], then_block, 1);
         mov_immed(instr->dsts[0], before_block, 0);
         break;

      case OPC_BALLOT_MACRO: {
         /* One MOVMSK with no sources, repeated to cover every component up
          * to the highest bit set in the destination write mask.
          */
         unsigned comp_count = util_last_bit(instr->dsts[0]->wrmask);
         struct ir3_instruction *movmsk =
            ir3_instr_create(then_block, OPC_MOVMSK, 1, 0);
         ir3_dst_create(movmsk, instr->dsts[0]->num, instr->dsts[0]->flags);
         movmsk->repeat = comp_count - 1;
         break;
      }

      case OPC_READ_COND_MACRO:
      case OPC_READ_FIRST_MACRO: {
         struct ir3_instruction *mov =
            ir3_instr_create(then_block, OPC_MOV, 1, 1);
         /* READ_COND's srcs[0] is the condition, so the value to move is
          * srcs[1]; READ_FIRST moves srcs[0].
          */
         unsigned src = instr->opc == OPC_READ_COND_MACRO ? 1 : 0;
         ir3_dst_create(mov, instr->dsts[0]->num, instr->dsts[0]->flags);
         /* Create a placeholder source, then overwrite it with a full copy
          * of the macro's source register.
          */
         struct ir3_register *new_src = ir3_src_create(mov, 0, 0);
         *new_src = *instr->srcs[src];
         /* The destination type is always 32-bit here; only the source type
          * follows the half flag.
          */
         mov->cat1.dst_type = TYPE_U32;
         mov->cat1.src_type =
            (new_src->flags & IR3_REG_HALF) ? TYPE_U16 : TYPE_U32;
         break;
      }

      case OPC_SWZ_SHARED_MACRO: {
         /* Two-destination, two-source SWZ mirroring the macro's operands. */
         struct ir3_instruction *swz =
            ir3_instr_create(then_block, OPC_SWZ, 2, 2);
         ir3_dst_create(swz, instr->dsts[0]->num, instr->dsts[0]->flags);
         ir3_dst_create(swz, instr->dsts[1]->num, instr->dsts[1]->flags);
         ir3_src_create(swz, instr->srcs[0]->num, instr->srcs[0]->flags);
         ir3_src_create(swz, instr->srcs[1]->num, instr->srcs[1]->flags);
         swz->cat1.dst_type = swz->cat1.src_type = TYPE_U32;
         swz->repeat = 1;
         break;
      }

      default:
         unreachable("bad opcode");
      }
   }

   /* Continue lowering in the trailing block and drop the macro itself. */
   *block = after_block;
   list_delinit(&instr->node);
   return true;
}
404
405static bool
406lower_block(struct ir3 *ir, struct ir3_block **block)
407{
408   bool progress = true;
409
410   bool inner_progress;
411   do {
412      inner_progress = false;
413      foreach_instr (instr, &(*block)->instr_list) {
414         if (lower_instr(ir, block, instr)) {
415            /* restart the loop with the new block we created because the
416             * iterator has been invalidated.
417             */
418            progress = inner_progress = true;
419            break;
420         }
421      }
422   } while (inner_progress);
423
424   return progress;
425}
426
427bool
428ir3_lower_subgroups(struct ir3 *ir)
429{
430   bool progress = false;
431
432   foreach_block (block, &ir->block_list)
433      progress |= lower_block(ir, &block);
434
435   return progress;
436}
437