1/*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24#include "vtn_private.h"
25
26static struct vtn_ssa_value *
27vtn_build_subgroup_instr(struct vtn_builder *b,
28                         nir_intrinsic_op nir_op,
29                         struct vtn_ssa_value *src0,
30                         nir_ssa_def *index,
31                         unsigned const_idx0,
32                         unsigned const_idx1)
33{
34   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
35    * any integer type.  To make things simpler for drivers, we only support
36    * 32-bit indices.
37    */
38   if (index && index->bit_size != 32)
39      index = nir_u2u32(&b->nb, index);
40
41   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);
42
43   vtn_assert(dst->type == src0->type);
44   if (!glsl_type_is_vector_or_scalar(dst->type)) {
45      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
46         dst->elems[0] =
47            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
48                                     const_idx0, const_idx1);
49      }
50      return dst;
51   }
52
53   nir_intrinsic_instr *intrin =
54      nir_intrinsic_instr_create(b->nb.shader, nir_op);
55   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
56                              dst->type, NULL);
57   intrin->num_components = intrin->dest.ssa.num_components;
58
59   intrin->src[0] = nir_src_for_ssa(src0->def);
60   if (index)
61      intrin->src[1] = nir_src_for_ssa(index);
62
63   intrin->const_index[0] = const_idx0;
64   intrin->const_index[1] = const_idx1;
65
66   nir_builder_instr_insert(&b->nb, &intrin->instr);
67
68   dst->def = &intrin->dest.ssa;
69
70   return dst;
71}
72
/* Translates a SPIR-V subgroup/group instruction (core OpGroupNonUniform*,
 * the OpGroup* group operations, and the KHR/INTEL/AMD extension variants)
 * into the corresponding NIR subgroup intrinsic.
 *
 * w[1] is the result-type id and w[2] the result id; the layout of the
 * remaining operand words depends on the opcode.  The core/NonUniform forms
 * carry an explicit execution Scope operand which the KHR/INTEL shorthand
 * opcodes lack, hence the `has_scope` offset adjustments below.
 */
void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      /* True in exactly one (the lowest-id active) invocation: maps
       * directly to nir_intrinsic_elect.
       */
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      /* The KHR variant has no Scope operand, so the predicate operand is
       * at w[3] instead of w[4].
       */
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with subgroup invocation.
       * We could add a NIR intrinsic but it's easier to just lower it on the
       * spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      /* Ballot-value queries.  Note the operand layouts differ: BitCount
       * has a GroupOperation at w[4] and the ballot at w[5]; the others
       * take the ballot directly at w[4].
       */
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         /* The GroupOperation selects reduce vs. inclusive/exclusive scan
          * bit counting.
          */
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      /* The KHR variant lacks the Scope operand; the value is at
       * w[3 + has_scope].
       */
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      /* Read a value from the invocation named by the index operand. */
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }

   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         /* Pick float vs. integer equality vote based on the operand's
          * base type.
          */
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      /* The scoped (core/NonUniform) opcodes carry their value at w[4];
       * the scope-less KHR opcodes at w[3].
       */
      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      /* src_components[0] == 0 means the intrinsic's source is variably
       * sized, so take the component count from the operand.
       */
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      /* 1:1 mapping to the NIR shuffle intrinsics; w[4] is the value,
       * w[5] the id/mask/delta.
       */
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      /* INTEL variants have no Scope operand: value at w[3], id/mask at
       * w[4].
       */
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_ssa_def *size = nir_load_subgroup_size(nb);
      nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       *   UP(a, b, delta) == DOWN(a, b, size - delta)
       */
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      /* Shuffle both the "current" (w[3]) and "next" (w[4]) data windows,
       * then select per-lane: lanes whose target index stays inside the
       * subgroup take `current`, the rest wrap into `next`.
       */
      nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_ssa_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      /* w[4] is the value, w[5] the quad lane index (const_index unused). */
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      /* The swap direction must be a constant: 0 = horizontal,
       * 1 = vertical, 2 = diagonal.
       */
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      /* Reductions and scans.  First map the SPIR-V opcode to the NIR ALU
       * op used as the combining function...
       */
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      /* ...then map the GroupOperation (w[4]) to the intrinsic flavor.
       * A clustered reduce carries an extra constant cluster size at w[6].
       */
      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      /* reduction_op and cluster_size ride in const_index[0]/[1]. */
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}
500