/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "vtn_private.h"

static struct vtn_ssa_value *
vtn_build_subgroup_instr(struct vtn_builder *b,
                         nir_intrinsic_op nir_op,
                         struct vtn_ssa_value *src0,
                         nir_ssa_def *index,
                         unsigned const_idx0,
                         unsigned const_idx1)
{
   /* Some of the subgroup operations take an index. SPIR-V allows this to be
    * any integer type. To make things simpler for drivers, we only support
    * 32-bit indices.
    */
   if (index && index->bit_size != 32)
      index = nir_u2u32(&b->nb, index);

   struct vtn_ssa_value *dst = vtn_create_ssa_value(b, src0->type);

   vtn_assert(dst->type == src0->type);
   if (!glsl_type_is_vector_or_scalar(dst->type)) {
      /* For composite types, recurse and build one instruction per element. */
      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
         dst->elems[i] =
            vtn_build_subgroup_instr(b, nir_op, src0->elems[i], index,
                                     const_idx0, const_idx1);
      }
      return dst;
   }

   nir_intrinsic_instr *intrin =
      nir_intrinsic_instr_create(b->nb.shader, nir_op);
   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                              dst->type, NULL);
   intrin->num_components = intrin->dest.ssa.num_components;

   intrin->src[0] = nir_src_for_ssa(src0->def);
   if (index)
      intrin->src[1] = nir_src_for_ssa(index);

   intrin->const_index[0] = const_idx0;
   intrin->const_index[1] = const_idx1;

   nir_builder_instr_insert(&b->nb, &intrin->instr);

   dst->def = &intrin->dest.ssa;

   return dst;
}

void
vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                    const uint32_t *w, unsigned count)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   switch (opcode) {
   case SpvOpGroupNonUniformElect: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniformElect must return a Bool");
      nir_intrinsic_instr *elect =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &elect->instr);
      vtn_push_nir_ssa(b, w[2], &elect->dest.ssa);
      break;
   }
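   /* Ballot results are a 128-bit bitfield packed as a uvec4 in both SPIR-V
    * and NIR: invocation N owns bit N % 32 of component N / 32. As an
    * illustrative example (not from the source), if only invocations 0 and
    * 33 are active with a true predicate, the result is
    * uvec4(0x1, 0x2, 0, 0).
    */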
   case SpvOpGroupNonUniformBallot:
   case SpvOpSubgroupBallotKHR: {
      bool has_scope = (opcode != SpvOpSubgroupBallotKHR);
      vtn_fail_if(dest_type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                  "OpGroupNonUniformBallot must return a uvec4");
      nir_intrinsic_instr *ballot =
         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
      ballot->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[3 + has_scope]));
      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
      ballot->num_components = 4;
      nir_builder_instr_insert(&b->nb, &ballot->instr);
      vtn_push_nir_ssa(b, w[2], &ballot->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformInverseBallot: {
      /* This one is just a BallotBitfieldExtract with the subgroup
       * invocation as the index. We could add a dedicated NIR intrinsic,
       * but it's easier to just lower it on the spot.
       */
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader,
                                    nir_intrinsic_ballot_bitfield_extract);

      intrin->src[0] = nir_src_for_ssa(vtn_get_nir_ssa(b, w[4]));
      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBallotBitExtract:
   case SpvOpGroupNonUniformBallotBitCount:
   case SpvOpGroupNonUniformBallotFindLSB:
   case SpvOpGroupNonUniformBallotFindMSB: {
      nir_ssa_def *src0, *src1 = NULL;
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformBallotBitExtract:
         op = nir_intrinsic_ballot_bitfield_extract;
         src0 = vtn_get_nir_ssa(b, w[4]);
         src1 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotBitCount:
         switch ((SpvGroupOperation)w[4]) {
         case SpvGroupOperationReduce:
            op = nir_intrinsic_ballot_bit_count_reduce;
            break;
         case SpvGroupOperationInclusiveScan:
            op = nir_intrinsic_ballot_bit_count_inclusive;
            break;
         case SpvGroupOperationExclusiveScan:
            op = nir_intrinsic_ballot_bit_count_exclusive;
            break;
         default:
            unreachable("Invalid group operation");
         }
         src0 = vtn_get_nir_ssa(b, w[5]);
         break;
      case SpvOpGroupNonUniformBallotFindLSB:
         op = nir_intrinsic_ballot_find_lsb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      case SpvOpGroupNonUniformBallotFindMSB:
         op = nir_intrinsic_ballot_find_msb;
         src0 = vtn_get_nir_ssa(b, w[4]);
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);

      intrin->src[0] = nir_src_for_ssa(src0);
      if (src1)
         intrin->src[1] = nir_src_for_ssa(src1);

      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformBroadcastFirst:
   case SpvOpSubgroupFirstInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupFirstInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  NULL, 0, 0));
      break;
   }

   case SpvOpGroupNonUniformBroadcast:
   case SpvOpGroupBroadcast:
   case SpvOpSubgroupReadInvocationKHR: {
      bool has_scope = (opcode != SpvOpSubgroupReadInvocationKHR);
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                  vtn_ssa_value(b, w[3 + has_scope]),
                                  vtn_get_nir_ssa(b, w[4 + has_scope]), 0, 0));
      break;
   }
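   /* In the vote group below, All/Any always take a Bool, while
    * OpGroupNonUniformAllEqual compares an arbitrary value across the
    * subgroup, so its NIR intrinsic has to match the operand's base type:
    * vote_feq for floating-point sources and vote_ieq for integer (and
    * bool) sources.
    */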
   case SpvOpGroupNonUniformAll:
   case SpvOpGroupNonUniformAny:
   case SpvOpGroupNonUniformAllEqual:
   case SpvOpGroupAll:
   case SpvOpGroupAny:
   case SpvOpSubgroupAllKHR:
   case SpvOpSubgroupAnyKHR:
   case SpvOpSubgroupAllEqualKHR: {
      vtn_fail_if(dest_type->type != glsl_bool_type(),
                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformAll:
      case SpvOpGroupAll:
      case SpvOpSubgroupAllKHR:
         op = nir_intrinsic_vote_all;
         break;
      case SpvOpGroupNonUniformAny:
      case SpvOpGroupAny:
      case SpvOpSubgroupAnyKHR:
         op = nir_intrinsic_vote_any;
         break;
      case SpvOpSubgroupAllEqualKHR:
         op = nir_intrinsic_vote_ieq;
         break;
      case SpvOpGroupNonUniformAllEqual:
         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
         case GLSL_TYPE_FLOAT:
         case GLSL_TYPE_FLOAT16:
         case GLSL_TYPE_DOUBLE:
            op = nir_intrinsic_vote_feq;
            break;
         case GLSL_TYPE_UINT:
         case GLSL_TYPE_INT:
         case GLSL_TYPE_UINT8:
         case GLSL_TYPE_INT8:
         case GLSL_TYPE_UINT16:
         case GLSL_TYPE_INT16:
         case GLSL_TYPE_UINT64:
         case GLSL_TYPE_INT64:
         case GLSL_TYPE_BOOL:
            op = nir_intrinsic_vote_ieq;
            break;
         default:
            unreachable("Unhandled type");
         }
         break;
      default:
         unreachable("Unhandled opcode");
      }

      nir_ssa_def *src0;
      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
          opcode == SpvOpGroupNonUniformAllEqual) {
         src0 = vtn_get_nir_ssa(b, w[4]);
      } else {
         src0 = vtn_get_nir_ssa(b, w[3]);
      }
      nir_intrinsic_instr *intrin =
         nir_intrinsic_instr_create(b->nb.shader, op);
      if (nir_intrinsic_infos[op].src_components[0] == 0)
         intrin->num_components = src0->num_components;
      intrin->src[0] = nir_src_for_ssa(src0);
      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                 dest_type->type, NULL);
      nir_builder_instr_insert(&b->nb, &intrin->instr);

      vtn_push_nir_ssa(b, w[2], &intrin->dest.ssa);
      break;
   }

   case SpvOpGroupNonUniformShuffle:
   case SpvOpGroupNonUniformShuffleXor:
   case SpvOpGroupNonUniformShuffleUp:
   case SpvOpGroupNonUniformShuffleDown: {
      nir_intrinsic_op op;
      switch (opcode) {
      case SpvOpGroupNonUniformShuffle:
         op = nir_intrinsic_shuffle;
         break;
      case SpvOpGroupNonUniformShuffleXor:
         op = nir_intrinsic_shuffle_xor;
         break;
      case SpvOpGroupNonUniformShuffleUp:
         op = nir_intrinsic_shuffle_up;
         break;
      case SpvOpGroupNonUniformShuffleDown:
         op = nir_intrinsic_shuffle_down;
         break;
      default:
         unreachable("Invalid opcode");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleINTEL:
   case SpvOpSubgroupShuffleXorINTEL: {
      nir_intrinsic_op op = opcode == SpvOpSubgroupShuffleINTEL ?
         nir_intrinsic_shuffle : nir_intrinsic_shuffle_xor;
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[3]),
                                  vtn_get_nir_ssa(b, w[4]), 0, 0));
      break;
   }

   case SpvOpSubgroupShuffleUpINTEL:
   case SpvOpSubgroupShuffleDownINTEL: {
      /* TODO: Move this lower on the compiler stack, where we can move the
       * current/other data to adjacent registers to avoid doing a shuffle
       * twice.
       */

      nir_builder *nb = &b->nb;
      nir_ssa_def *size = nir_load_subgroup_size(nb);
      nir_ssa_def *delta = vtn_get_nir_ssa(b, w[5]);

      /* Rewrite UP in terms of DOWN.
       *
       * UP(a, b, delta) == DOWN(a, b, size - delta)
       */
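      /* A sanity check of that identity with concrete numbers (an
       * illustrative sketch, not from the SPIR-V spec): for UP with
       * size = 8 and delta = 2, invocation 5 computes
       * index = 5 + (8 - 2) = 11. Since 11 >= size, it reads the second
       * source at lane 11 - 8 = 3 = 5 - delta, i.e. the current data
       * shifted up by two, while invocation 1 gets index = 7 < size and
       * wraps into the first ("previous") source at lane 7 = 1 - 2 + 8.
       */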
      if (opcode == SpvOpSubgroupShuffleUpINTEL)
         delta = nir_isub(nb, size, delta);

      nir_ssa_def *index = nir_iadd(nb, nir_load_subgroup_invocation(nb), delta);
      struct vtn_ssa_value *current =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[3]),
                                  index, 0, 0);

      struct vtn_ssa_value *next =
         vtn_build_subgroup_instr(b, nir_intrinsic_shuffle, vtn_ssa_value(b, w[4]),
                                  nir_isub(nb, index, size), 0, 0);

      nir_ssa_def *cond = nir_ilt(nb, index, size);
      vtn_push_nir_ssa(b, w[2], nir_bcsel(nb, cond, current->def, next->def));

      break;
   }

   case SpvOpGroupNonUniformQuadBroadcast:
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
                                  vtn_ssa_value(b, w[4]),
                                  vtn_get_nir_ssa(b, w[5]), 0, 0));
      break;

   case SpvOpGroupNonUniformQuadSwap: {
      unsigned direction = vtn_constant_uint(b, w[5]);
      nir_intrinsic_op op;
      switch (direction) {
      case 0:
         op = nir_intrinsic_quad_swap_horizontal;
         break;
      case 1:
         op = nir_intrinsic_quad_swap_vertical;
         break;
      case 2:
         op = nir_intrinsic_quad_swap_diagonal;
         break;
      default:
         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
      }
      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[4]), NULL, 0, 0));
      break;
   }
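   /* All of the scan/reduce opcodes below lower to just three NIR
    * intrinsics. The SPIR-V opcode only selects the combining ALU op,
    * which vtn_build_subgroup_instr passes along as const_index[0]; the
    * group operation in w[4] selects reduce vs. inclusive vs. exclusive
    * scan, and a clustered reduce additionally passes its cluster size as
    * const_index[1].
    */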
   case SpvOpGroupNonUniformIAdd:
   case SpvOpGroupNonUniformFAdd:
   case SpvOpGroupNonUniformIMul:
   case SpvOpGroupNonUniformFMul:
   case SpvOpGroupNonUniformSMin:
   case SpvOpGroupNonUniformUMin:
   case SpvOpGroupNonUniformFMin:
   case SpvOpGroupNonUniformSMax:
   case SpvOpGroupNonUniformUMax:
   case SpvOpGroupNonUniformFMax:
   case SpvOpGroupNonUniformBitwiseAnd:
   case SpvOpGroupNonUniformBitwiseOr:
   case SpvOpGroupNonUniformBitwiseXor:
   case SpvOpGroupNonUniformLogicalAnd:
   case SpvOpGroupNonUniformLogicalOr:
   case SpvOpGroupNonUniformLogicalXor:
   case SpvOpGroupIAdd:
   case SpvOpGroupFAdd:
   case SpvOpGroupFMin:
   case SpvOpGroupUMin:
   case SpvOpGroupSMin:
   case SpvOpGroupFMax:
   case SpvOpGroupUMax:
   case SpvOpGroupSMax:
   case SpvOpGroupIAddNonUniformAMD:
   case SpvOpGroupFAddNonUniformAMD:
   case SpvOpGroupFMinNonUniformAMD:
   case SpvOpGroupUMinNonUniformAMD:
   case SpvOpGroupSMinNonUniformAMD:
   case SpvOpGroupFMaxNonUniformAMD:
   case SpvOpGroupUMaxNonUniformAMD:
   case SpvOpGroupSMaxNonUniformAMD: {
      nir_op reduction_op;
      switch (opcode) {
      case SpvOpGroupNonUniformIAdd:
      case SpvOpGroupIAdd:
      case SpvOpGroupIAddNonUniformAMD:
         reduction_op = nir_op_iadd;
         break;
      case SpvOpGroupNonUniformFAdd:
      case SpvOpGroupFAdd:
      case SpvOpGroupFAddNonUniformAMD:
         reduction_op = nir_op_fadd;
         break;
      case SpvOpGroupNonUniformIMul:
         reduction_op = nir_op_imul;
         break;
      case SpvOpGroupNonUniformFMul:
         reduction_op = nir_op_fmul;
         break;
      case SpvOpGroupNonUniformSMin:
      case SpvOpGroupSMin:
      case SpvOpGroupSMinNonUniformAMD:
         reduction_op = nir_op_imin;
         break;
      case SpvOpGroupNonUniformUMin:
      case SpvOpGroupUMin:
      case SpvOpGroupUMinNonUniformAMD:
         reduction_op = nir_op_umin;
         break;
      case SpvOpGroupNonUniformFMin:
      case SpvOpGroupFMin:
      case SpvOpGroupFMinNonUniformAMD:
         reduction_op = nir_op_fmin;
         break;
      case SpvOpGroupNonUniformSMax:
      case SpvOpGroupSMax:
      case SpvOpGroupSMaxNonUniformAMD:
         reduction_op = nir_op_imax;
         break;
      case SpvOpGroupNonUniformUMax:
      case SpvOpGroupUMax:
      case SpvOpGroupUMaxNonUniformAMD:
         reduction_op = nir_op_umax;
         break;
      case SpvOpGroupNonUniformFMax:
      case SpvOpGroupFMax:
      case SpvOpGroupFMaxNonUniformAMD:
         reduction_op = nir_op_fmax;
         break;
      case SpvOpGroupNonUniformBitwiseAnd:
      case SpvOpGroupNonUniformLogicalAnd:
         reduction_op = nir_op_iand;
         break;
      case SpvOpGroupNonUniformBitwiseOr:
      case SpvOpGroupNonUniformLogicalOr:
         reduction_op = nir_op_ior;
         break;
      case SpvOpGroupNonUniformBitwiseXor:
      case SpvOpGroupNonUniformLogicalXor:
         reduction_op = nir_op_ixor;
         break;
      default:
         unreachable("Invalid reduction operation");
      }

      nir_intrinsic_op op;
      unsigned cluster_size = 0;
      switch ((SpvGroupOperation)w[4]) {
      case SpvGroupOperationReduce:
         op = nir_intrinsic_reduce;
         break;
      case SpvGroupOperationInclusiveScan:
         op = nir_intrinsic_inclusive_scan;
         break;
      case SpvGroupOperationExclusiveScan:
         op = nir_intrinsic_exclusive_scan;
         break;
      case SpvGroupOperationClusteredReduce:
         op = nir_intrinsic_reduce;
         assert(count == 7);
         cluster_size = vtn_constant_uint(b, w[6]);
         break;
      default:
         unreachable("Invalid group operation");
      }

      vtn_push_ssa_value(b, w[2],
         vtn_build_subgroup_instr(b, op, vtn_ssa_value(b, w[5]), NULL,
                                  reduction_op, cluster_size));
      break;
   }

   default:
      unreachable("Invalid SPIR-V opcode");
   }
}