/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
 * definition. But for matrix multiplies, we want to do one routine for
 * multiplying a matrix by a matrix and then pretend that vectors are matrices
 * with one column. So we "wrap" these things, and unwrap the result before we
 * send it off.
 */

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{

   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A) */
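      /* Entry (i,j) of transpose(A) * transpose(B) is sum_k A[k][i] * B[j][k],
       * which is entry (j,i) of B * A, so we can multiply the untransposed
       * sources in swapped order and transpose the result at the end instead
       * of materializing either transpose here.
       */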
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0, since
       * the general case only uses individual components of src1 so the
       * optimizer should chew through the transpose we emitted for src1.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
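         /* Column i of the result is a linear combination of the columns of
          * src0, weighted by the components of column i of src1. The sum is
          * seeded with an fmul for the last term and accumulated back-to-front
          * with ffma.
          */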
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
         for (int j = src0_columns - 2; j >= 0; j--) {
            dest->elems[i]->def =
               nir_ffma(&b->nb, src0->elems[j]->def,
                        nir_channel(&b->nb, src1->elems[i]->def, j),
                        dest->elems[i]->def);
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

nir_ssa_def *
vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
{
   if (def->bit_size == 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2fmp(&b->nb, def);
   case GLSL_TYPE_INT:
   case GLSL_TYPE_UINT:
      return nir_i2imp(&b->nb, def);
   /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
    * OpLogical* operations (which is forbidden by spec).
    */
   case GLSL_TYPE_BOOL:
      return def;
   default:
      unreachable("bad relaxed precision input type");
   }
}

struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
{
   if (!src)
      return src;

   struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);

   if (src->transposed) {
      srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
   } else {
      enum glsl_base_type base_type = glsl_get_base_type(src->type);

      if (glsl_type_is_vector_or_scalar(src->type)) {
         srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
      } else {
         assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
         for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
            srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
      }
   }

   return srcmp;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                      src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

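/* The two helpers below map a SPIR-V conversion opcode to the NIR base types
 * of its source and destination. For example, OpConvertFToU has a
 * nir_type_float source and a nir_type_uint destination; once the callers OR
 * in the operand bit sizes (e.g. float32 -> uint16), nir_type_conversion_op()
 * yields the corresponding conversion op (nir_op_f2u16 in this example).
 */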
static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped. This is
    * used for implementing greater-than and less-than-or-equal.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate: return nir_op_ineg;
   case SpvOpFNegate: return nir_op_fneg;
   case SpvOpNot: return nir_op_inot;
   case SpvOpIAdd: return nir_op_iadd;
   case SpvOpFAdd: return nir_op_fadd;
   case SpvOpISub: return nir_op_isub;
   case SpvOpFSub: return nir_op_fsub;
   case SpvOpIMul: return nir_op_imul;
   case SpvOpFMul: return nir_op_fmul;
   case SpvOpUDiv: return nir_op_udiv;
   case SpvOpSDiv: return nir_op_idiv;
   case SpvOpFDiv: return nir_op_fdiv;
   case SpvOpUMod: return nir_op_umod;
   case SpvOpSMod: return nir_op_imod;
   case SpvOpFMod: return nir_op_fmod;
   case SpvOpSRem: return nir_op_irem;
   case SpvOpFRem: return nir_op_frem;

   case SpvOpShiftRightLogical: return nir_op_ushr;
   case SpvOpShiftRightArithmetic: return nir_op_ishr;
   case SpvOpShiftLeftLogical: return nir_op_ishl;
   case SpvOpLogicalOr: return nir_op_ior;
   case SpvOpLogicalEqual: return nir_op_ieq;
   case SpvOpLogicalNotEqual: return nir_op_ine;
   case SpvOpLogicalAnd: return nir_op_iand;
   case SpvOpLogicalNot: return nir_op_inot;
   case SpvOpBitwiseOr: return nir_op_ior;
   case SpvOpBitwiseXor: return nir_op_ixor;
   case SpvOpBitwiseAnd: return nir_op_iand;
   case SpvOpSelect: return nir_op_bcsel;
   case SpvOpIEqual: return nir_op_ieq;

   case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
   case SpvOpBitReverse: return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
   case SpvOpIAverageINTEL: return nir_op_ihadd;
   case SpvOpUAverageINTEL: return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
   case SpvOpISubSatINTEL: return nir_op_isub_sat;
   case SpvOpUSubSatINTEL: return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;

   /* The ordered / unordered operators need special implementation besides
    * the logical operator to use since they also need to check if operands are
    * ordered.
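    *
    * For example, FOrdLessThan(a, b) must be false when either operand is a
    * NaN, while FUnordLessThan(a, b) must be true in that case. NIR's flt,
    * fge and feq are the ordered comparisons (*exact keeps the optimizer from
    * assuming the operands are not NaN), and nir_op_fneu is already the
    * unordered not-equal, so the remaining FUnord* cases are fixed up in
    * vtn_handle_alu by negating the opposite ordered comparison or by adding
    * explicit NaN checks.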
    */
   case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
   case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
   case SpvOpINotEqual: return nir_op_ine;
   case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
   case SpvOpULessThan: return nir_op_ult;
   case SpvOpSLessThan: return nir_op_ilt;
   case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
   case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
   case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
   case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
   case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
   case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
   case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
   case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
   case SpvOpUGreaterThanEqual: return nir_op_uge;
   case SpvOpSGreaterThanEqual: return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric: return nir_op_mov;
   case SpvOpGenericCastToPtr: return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx: return nir_op_fddx;
   case SpvOpDPdy: return nir_op_fddy;
   case SpvOpDPdxFine: return nir_op_fddx_fine;
   case SpvOpDPdyFine: return nir_op_fddy_fine;
   case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse: return nir_op_fddy_coarse;

   case SpvOpIsNormal: return nir_op_fisnormal;
   case SpvOpIsFinite: return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
                      UNUSED int member, const struct vtn_decoration *dec,
                      UNUSED void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

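/* Once a NoContraction decoration sets b->nb.exact here, it stays set for
 * every instruction emitted for this SPIR-V instruction; vtn_handle_alu and
 * vtn_handle_integer_dot restore it from b->exact when they finish.
 */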
void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}

struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
                       UNUSED int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}

static void
handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
               UNUSED int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

static void
vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
                                  struct vtn_value *val, int member,
                                  const struct vtn_decoration *dec, void *void_ctx)
{
   bool *relaxed_precision = void_ctx;
   switch (dec->decoration) {
   case SpvDecorationRelaxedPrecision:
      *relaxed_precision = true;
      break;

   default:
      break;
   }
}

bool
vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
{
   bool result = false;
   vtn_foreach_decoration(b, val,
                          vtn_value_is_relaxed_precision_cb, &result);
   return result;
}

static bool
vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
{
   if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
      return false;

   switch (opcode) {
   case SpvOpDPdx:
   case SpvOpDPdy:
   case SpvOpDPdxFine:
   case SpvOpDPdyFine:
   case SpvOpDPdxCoarse:
   case SpvOpDPdyCoarse:
   case SpvOpFwidth:
   case SpvOpFwidthFine:
   case SpvOpFwidthCoarse:
      return b->options->mediump_16bit_derivatives;
   default:
      return true;
   }
}

static nir_ssa_def *
vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_ssa_def *def)
{
   if (def->bit_size != 16)
      return def;

   switch (base_type) {
   case GLSL_TYPE_FLOAT:
      return nir_f2f32(&b->nb, def);
   case GLSL_TYPE_INT:
      return nir_i2i32(&b->nb, def);
   case GLSL_TYPE_UINT:
      return nir_u2u32(&b->nb, def);
   default:
      unreachable("bad relaxed precision output type");
   }
}

void
vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
{
   enum glsl_base_type base_type = glsl_get_base_type(value->type);

   if (glsl_type_is_vector_or_scalar(value->type)) {
      value->def = vtn_mediump_upconvert(b, base_type, value->def);
   } else {
      for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
         value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
   }
}

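/* Translate a generic SPIR-V ALU instruction to NIR. When the result is
 * eligible for mediump handling (see vtn_alu_op_mediump_16bit), the sources
 * are narrowed to 16 bits first and the result is widened back to the
 * declared result type before it is pushed.
 */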
void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_handle_no_contraction(b, dest_val);
   bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      if (mediump_16bit)
         vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
   }

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);

      if (mediump_16bit)
         vtn_mediump_upconvert_value(b, dest);

      vtn_push_ssa_value(b, w[2], dest);
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      if (src[0]->bit_size == 32) {
         nir_ssa_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
         dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
         dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      } else {
         dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
         dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
      }
      break;
   }

   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                                   nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                                  nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as !(a < b || b < a). If one or both
       * of the sources are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a == b) operation. Otherwise, the optimizer can transform
       * whatever is left to !(a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_ior(&b->nb,
                 nir_feq(&b->nb, src[0], src[1]),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
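      /* With exact (IEEE) comparisons, fge(a, b) is false whenever either
       * operand is NaN, so negating the opposite ordered comparison gives
       * exactly the unordered semantics: !(a >= b) is true when a < b or when
       * the operands are unordered.
       */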
      switch (op) {
      case nir_op_fge: op = nir_op_flt; break;
      case nir_op_flt: op = nir_op_fge; break;
      default: unreachable("Impossible opcode.");
      }

      dest->def =
         nir_inot(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered so we don’t need to handle it specially.
       */
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as (a < b || b < a). If one or both
       * of the sources are numbers, later optimization passes can easily
       * eliminate the isnan() checks. This may trim the sequence down to a
       * single (a != b) operation. Otherwise, the optimizer can transform
       * whatever is left to (a < b || b < a). Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_iand(&b->nb,
                  nir_fneu(&b->nb, src[0], src[1]),
                  nir_iand(&b->nb,
                           nir_feq(&b->nb, src[0], src[0]),
                           nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            nir_op op = nir_type_conversion_op(src_type, dst_type,
                                               nir_rounding_mode_undef);
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         nir_op op = nir_type_conversion_op(src_type, dst_type,
                                            opts.rounding_mode);
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(!exact);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bitsize supported by the NIR instructions. See discussion
             * here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the return
       * value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
      break;
   }

   case SpvOpSDotKHR:
   case SpvOpUDotKHR:
   case SpvOpSUDotKHR:
   case SpvOpSDotAccSatKHR:
   case SpvOpUDotAccSatKHR:
   case SpvOpSUDotAccSatKHR:
      unreachable("Should have called vtn_handle_integer_dot instead.");

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   if (mediump_16bit)
      vtn_mediump_upconvert_value(b, dest);
   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Due to the optional "Packed Vector Format" field, determine number of
    * inputs from the opcode. This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_ssa_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "opcode %s",
                  spirv_op_to_string(opcode));
   }

   unsigned packed_bit_size = 8;
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.
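       *
       * Currently the packed 4x8-bit format is the only one defined by the
       * extension, and it is the only one the check below accepts.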
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_ssa_def *dest = NULL;

   if (src[0]->num_components > 1) {
      const nir_op s_conversion_op =
         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
                                nir_rounding_mode_undef);

      const nir_op u_conversion_op =
         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
                                nir_rounding_mode_undef);

      nir_op src0_conversion_op;
      nir_op src1_conversion_op;

      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = s_conversion_op;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion_op = u_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       */
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_ssa_def *const src0 =
            nir_build_alu(&b->nb, src0_conversion_op,
                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);

         nir_ssa_def *const src1 =
            nir_build_alu(&b->nb, src1_conversion_op,
                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);

         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);

         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above. In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition. Since the
          * result of the dot-product cannot overflow 32-bits, this is also
          * safe to cast up.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2i(&b->nb, dest, dest_size)
               : nir_u2u(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */

   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}