/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"

#include "src/sksl/GLSL.std.450.h"

#include "include/sksl/DSLCore.h"
#include "src/sksl/SkSLCompiler.h"
#include "src/sksl/SkSLOperators.h"
#include "src/sksl/SkSLThreadContext.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBlock.h"
#include "src/sksl/ir/SkSLConstructorArrayCast.h"
#include "src/sksl/ir/SkSLConstructorCompound.h"
#include "src/sksl/ir/SkSLConstructorCompoundCast.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLConstructorMatrixResize.h"
#include "src/sksl/ir/SkSLConstructorScalarCast.h"
#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
#include "src/sksl/ir/SkSLExtension.h"
#include "src/sksl/ir/SkSLField.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLForStatement.h"
#include "src/sksl/ir/SkSLFunctionCall.h"
#include "src/sksl/ir/SkSLFunctionDeclaration.h"
#include "src/sksl/ir/SkSLFunctionDefinition.h"
#include "src/sksl/ir/SkSLIfStatement.h"
#include "src/sksl/ir/SkSLIndexExpression.h"
#include "src/sksl/ir/SkSLInterfaceBlock.h"
#include "src/sksl/ir/SkSLPostfixExpression.h"
#include "src/sksl/ir/SkSLPrefixExpression.h"
#include "src/sksl/ir/SkSLReturnStatement.h"
#include "src/sksl/ir/SkSLSwitchStatement.h"
#include "src/sksl/ir/SkSLSwizzle.h"
#include "src/sksl/ir/SkSLTernaryExpression.h"
#include "src/sksl/ir/SkSLVarDeclarations.h"
#include "src/sksl/ir/SkSLVariableReference.h"

#ifdef SK_VULKAN
#include "src/gpu/vk/GrVkCaps.h"
#endif

// Upper bound (inclusive) of the capability numbers that writeCapabilities() scans in the
// fCapabilities bit mask: bit i of the mask corresponds to SPIR-V capability value i.
#define kLast_Capability SpvCapabilityMultiViewport

// Internal builtin identifiers with negative values.
// NOTE(review): neither constant is referenced in this part of the file; judging only by their
// names they tag device-space frag-coord and clockwise builtins. Confirm where they are consumed.
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN = -1001;
54 55namespace SkSL { 56 57// Skia's magic number is 31 and goes in the top 16 bits. We can use the lower bits to version the 58// sksl generator if we want. 59// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84 60static const int32_t SKSL_MAGIC = 0x001F0000; 61 62void SPIRVCodeGenerator::setupIntrinsics() { 63#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \ 64 GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x) 65#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, \ 66 GLSLstd450 ## ifFloat, \ 67 GLSLstd450 ## ifInt, \ 68 GLSLstd450 ## ifUInt, \ 69 SpvOpUndef) 70#define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicOpcodeKind, \ 71 SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x) 72#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \ 73 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \ 74 k ## x ## _SpecialIntrinsic) 75 fIntrinsicMap[k_round_IntrinsicKind] = ALL_GLSL(Round); 76 fIntrinsicMap[k_roundEven_IntrinsicKind] = ALL_GLSL(RoundEven); 77 fIntrinsicMap[k_trunc_IntrinsicKind] = ALL_GLSL(Trunc); 78 fIntrinsicMap[k_abs_IntrinsicKind] = BY_TYPE_GLSL(FAbs, SAbs, SAbs); 79 fIntrinsicMap[k_sign_IntrinsicKind] = BY_TYPE_GLSL(FSign, SSign, SSign); 80 fIntrinsicMap[k_floor_IntrinsicKind] = ALL_GLSL(Floor); 81 fIntrinsicMap[k_ceil_IntrinsicKind] = ALL_GLSL(Ceil); 82 fIntrinsicMap[k_fract_IntrinsicKind] = ALL_GLSL(Fract); 83 fIntrinsicMap[k_radians_IntrinsicKind] = ALL_GLSL(Radians); 84 fIntrinsicMap[k_degrees_IntrinsicKind] = ALL_GLSL(Degrees); 85 fIntrinsicMap[k_sin_IntrinsicKind] = ALL_GLSL(Sin); 86 fIntrinsicMap[k_cos_IntrinsicKind] = ALL_GLSL(Cos); 87 fIntrinsicMap[k_tan_IntrinsicKind] = ALL_GLSL(Tan); 88 fIntrinsicMap[k_asin_IntrinsicKind] = ALL_GLSL(Asin); 89 fIntrinsicMap[k_acos_IntrinsicKind] = ALL_GLSL(Acos); 90 fIntrinsicMap[k_atan_IntrinsicKind] = SPECIAL(Atan); 91 
fIntrinsicMap[k_sinh_IntrinsicKind] = ALL_GLSL(Sinh); 92 fIntrinsicMap[k_cosh_IntrinsicKind] = ALL_GLSL(Cosh); 93 fIntrinsicMap[k_tanh_IntrinsicKind] = ALL_GLSL(Tanh); 94 fIntrinsicMap[k_asinh_IntrinsicKind] = ALL_GLSL(Asinh); 95 fIntrinsicMap[k_acosh_IntrinsicKind] = ALL_GLSL(Acosh); 96 fIntrinsicMap[k_atanh_IntrinsicKind] = ALL_GLSL(Atanh); 97 fIntrinsicMap[k_pow_IntrinsicKind] = ALL_GLSL(Pow); 98 fIntrinsicMap[k_exp_IntrinsicKind] = ALL_GLSL(Exp); 99 fIntrinsicMap[k_log_IntrinsicKind] = ALL_GLSL(Log); 100 fIntrinsicMap[k_exp2_IntrinsicKind] = ALL_GLSL(Exp2); 101 fIntrinsicMap[k_log2_IntrinsicKind] = ALL_GLSL(Log2); 102 fIntrinsicMap[k_sqrt_IntrinsicKind] = ALL_GLSL(Sqrt); 103 fIntrinsicMap[k_inverse_IntrinsicKind] = ALL_GLSL(MatrixInverse); 104 fIntrinsicMap[k_outerProduct_IntrinsicKind] = ALL_SPIRV(OuterProduct); 105 fIntrinsicMap[k_transpose_IntrinsicKind] = ALL_SPIRV(Transpose); 106 fIntrinsicMap[k_isinf_IntrinsicKind] = ALL_SPIRV(IsInf); 107 fIntrinsicMap[k_isnan_IntrinsicKind] = ALL_SPIRV(IsNan); 108 fIntrinsicMap[k_inversesqrt_IntrinsicKind] = ALL_GLSL(InverseSqrt); 109 fIntrinsicMap[k_determinant_IntrinsicKind] = ALL_GLSL(Determinant); 110 fIntrinsicMap[k_matrixCompMult_IntrinsicKind] = SPECIAL(MatrixCompMult); 111 fIntrinsicMap[k_matrixInverse_IntrinsicKind] = ALL_GLSL(MatrixInverse); 112 fIntrinsicMap[k_mod_IntrinsicKind] = SPECIAL(Mod); 113 fIntrinsicMap[k_modf_IntrinsicKind] = ALL_GLSL(Modf); 114 fIntrinsicMap[k_min_IntrinsicKind] = SPECIAL(Min); 115 fIntrinsicMap[k_max_IntrinsicKind] = SPECIAL(Max); 116 fIntrinsicMap[k_clamp_IntrinsicKind] = SPECIAL(Clamp); 117 fIntrinsicMap[k_saturate_IntrinsicKind] = SPECIAL(Saturate); 118 fIntrinsicMap[k_dot_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 119 SpvOpDot, SpvOpUndef, SpvOpUndef, SpvOpUndef); 120 fIntrinsicMap[k_mix_IntrinsicKind] = SPECIAL(Mix); 121 fIntrinsicMap[k_step_IntrinsicKind] = SPECIAL(Step); 122 fIntrinsicMap[k_smoothstep_IntrinsicKind] = SPECIAL(SmoothStep); 123 
fIntrinsicMap[k_fma_IntrinsicKind] = ALL_GLSL(Fma); 124 fIntrinsicMap[k_frexp_IntrinsicKind] = ALL_GLSL(Frexp); 125 fIntrinsicMap[k_ldexp_IntrinsicKind] = ALL_GLSL(Ldexp); 126 127#define PACK(type) fIntrinsicMap[k_pack##type##_IntrinsicKind] = ALL_GLSL(Pack##type); \ 128 fIntrinsicMap[k_unpack##type##_IntrinsicKind] = ALL_GLSL(Unpack##type) 129 PACK(Snorm4x8); 130 PACK(Unorm4x8); 131 PACK(Snorm2x16); 132 PACK(Unorm2x16); 133 PACK(Half2x16); 134 PACK(Double2x32); 135#undef PACK 136 fIntrinsicMap[k_length_IntrinsicKind] = ALL_GLSL(Length); 137 fIntrinsicMap[k_distance_IntrinsicKind] = ALL_GLSL(Distance); 138 fIntrinsicMap[k_cross_IntrinsicKind] = ALL_GLSL(Cross); 139 fIntrinsicMap[k_normalize_IntrinsicKind] = ALL_GLSL(Normalize); 140 fIntrinsicMap[k_faceforward_IntrinsicKind] = ALL_GLSL(FaceForward); 141 fIntrinsicMap[k_reflect_IntrinsicKind] = ALL_GLSL(Reflect); 142 fIntrinsicMap[k_refract_IntrinsicKind] = ALL_GLSL(Refract); 143 fIntrinsicMap[k_bitCount_IntrinsicKind] = ALL_SPIRV(BitCount); 144 fIntrinsicMap[k_findLSB_IntrinsicKind] = ALL_GLSL(FindILsb); 145 fIntrinsicMap[k_findMSB_IntrinsicKind] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb); 146 fIntrinsicMap[k_dFdx_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 147 SpvOpDPdx, SpvOpUndef, 148 SpvOpUndef, SpvOpUndef); 149 fIntrinsicMap[k_dFdy_IntrinsicKind] = SPECIAL(DFdy); 150 fIntrinsicMap[k_fwidth_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 151 SpvOpFwidth, SpvOpUndef, 152 SpvOpUndef, SpvOpUndef); 153 fIntrinsicMap[k_makeSampler2D_IntrinsicKind] = SPECIAL(SampledImage); 154 155 fIntrinsicMap[k_sample_IntrinsicKind] = SPECIAL(Texture); 156 fIntrinsicMap[k_subpassLoad_IntrinsicKind] = SPECIAL(SubpassLoad); 157 158 fIntrinsicMap[k_floatBitsToInt_IntrinsicKind] = ALL_SPIRV(Bitcast); 159 fIntrinsicMap[k_floatBitsToUint_IntrinsicKind] = ALL_SPIRV(Bitcast); 160 fIntrinsicMap[k_intBitsToFloat_IntrinsicKind] = ALL_SPIRV(Bitcast); 161 fIntrinsicMap[k_uintBitsToFloat_IntrinsicKind] = 
ALL_SPIRV(Bitcast); 162 163 fIntrinsicMap[k_any_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 164 SpvOpUndef, SpvOpUndef, 165 SpvOpUndef, SpvOpAny); 166 fIntrinsicMap[k_all_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 167 SpvOpUndef, SpvOpUndef, 168 SpvOpUndef, SpvOpAll); 169 fIntrinsicMap[k_not_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 170 SpvOpUndef, SpvOpUndef, SpvOpUndef, 171 SpvOpLogicalNot); 172 fIntrinsicMap[k_equal_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 173 SpvOpFOrdEqual, SpvOpIEqual, 174 SpvOpIEqual, SpvOpLogicalEqual); 175 fIntrinsicMap[k_notEqual_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 176 SpvOpFOrdNotEqual, SpvOpINotEqual, 177 SpvOpINotEqual, 178 SpvOpLogicalNotEqual); 179 fIntrinsicMap[k_lessThan_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 180 SpvOpFOrdLessThan, 181 SpvOpSLessThan, 182 SpvOpULessThan, 183 SpvOpUndef); 184 fIntrinsicMap[k_lessThanEqual_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 185 SpvOpFOrdLessThanEqual, 186 SpvOpSLessThanEqual, 187 SpvOpULessThanEqual, 188 SpvOpUndef); 189 fIntrinsicMap[k_greaterThan_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 190 SpvOpFOrdGreaterThan, 191 SpvOpSGreaterThan, 192 SpvOpUGreaterThan, 193 SpvOpUndef); 194 fIntrinsicMap[k_greaterThanEqual_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind, 195 SpvOpFOrdGreaterThanEqual, 196 SpvOpSGreaterThanEqual, 197 SpvOpUGreaterThanEqual, 198 SpvOpUndef); 199// interpolateAt* not yet supported... 
200} 201 202void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) { 203 out.write((const char*) &word, sizeof(word)); 204} 205 206static bool is_float(const Context& context, const Type& type) { 207 return (type.isScalar() || type.isVector() || type.isMatrix()) && 208 type.componentType().isFloat(); 209} 210 211static bool is_signed(const Context& context, const Type& type) { 212 return (type.isScalar() || type.isVector()) && type.componentType().isSigned(); 213} 214 215static bool is_unsigned(const Context& context, const Type& type) { 216 return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned(); 217} 218 219static bool is_bool(const Context& context, const Type& type) { 220 return (type.isScalar() || type.isVector()) && type.componentType().isBoolean(); 221} 222 223static bool is_out(const Modifiers& m) { 224 return (m.fFlags & Modifiers::kOut_Flag) != 0; 225} 226 227static bool is_in(const Modifiers& m) { 228 switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) { 229 case Modifiers::kOut_Flag: // out 230 return false; 231 232 case 0: // implicit in 233 case Modifiers::kIn_Flag: // explicit in 234 case Modifiers::kOut_Flag | Modifiers::kIn_Flag: // inout 235 return true; 236 237 default: SkUNREACHABLE; 238 } 239} 240 241void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) { 242 SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer); 243 SkASSERT(opCode != SpvOpUndef); 244 switch (opCode) { 245 case SpvOpReturn: // fall through 246 case SpvOpReturnValue: // fall through 247 case SpvOpKill: // fall through 248 case SpvOpSwitch: // fall through 249 case SpvOpBranch: // fall through 250 case SpvOpBranchConditional: 251 if (fCurrentBlock == 0) { 252 // We just encountered dead code--instructions that don't have an associated block. 253 // Synthesize a label if this happens; this is necessary to satisfy the validator. 
254 this->writeLabel(this->nextId(nullptr), out); 255 } 256 fCurrentBlock = 0; 257 break; 258 case SpvOpConstant: // fall through 259 case SpvOpConstantTrue: // fall through 260 case SpvOpConstantFalse: // fall through 261 case SpvOpConstantComposite: // fall through 262 case SpvOpTypeVoid: // fall through 263 case SpvOpTypeInt: // fall through 264 case SpvOpTypeFloat: // fall through 265 case SpvOpTypeBool: // fall through 266 case SpvOpTypeVector: // fall through 267 case SpvOpTypeMatrix: // fall through 268 case SpvOpTypeArray: // fall through 269 case SpvOpTypePointer: // fall through 270 case SpvOpTypeFunction: // fall through 271 case SpvOpTypeRuntimeArray: // fall through 272 case SpvOpTypeStruct: // fall through 273 case SpvOpTypeImage: // fall through 274 case SpvOpTypeSampledImage: // fall through 275 case SpvOpTypeSampler: // fall through 276 case SpvOpVariable: // fall through 277 case SpvOpFunction: // fall through 278 case SpvOpFunctionParameter: // fall through 279 case SpvOpFunctionEnd: // fall through 280 case SpvOpExecutionMode: // fall through 281 case SpvOpMemoryModel: // fall through 282 case SpvOpCapability: // fall through 283 case SpvOpExtInstImport: // fall through 284 case SpvOpEntryPoint: // fall through 285 case SpvOpSource: // fall through 286 case SpvOpSourceExtension: // fall through 287 case SpvOpName: // fall through 288 case SpvOpMemberName: // fall through 289 case SpvOpDecorate: // fall through 290 case SpvOpMemberDecorate: 291 break; 292 default: 293 // We may find ourselves with dead code--instructions that don't have an associated 294 // block. This should be a rare event, but if it happens, synthesize a label; this is 295 // necessary to satisfy the validator. 
296 if (fCurrentBlock == 0) { 297 this->writeLabel(this->nextId(nullptr), out); 298 } 299 break; 300 } 301 this->writeWord((length << 16) | opCode, out); 302} 303 304void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) { 305 SkASSERT(!fCurrentBlock); 306 fCurrentBlock = label; 307 this->writeInstruction(SpvOpLabel, label, out); 308} 309 310void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) { 311 this->writeOpCode(opCode, 1, out); 312} 313 314void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) { 315 this->writeOpCode(opCode, 2, out); 316 this->writeWord(word1, out); 317} 318 319void SPIRVCodeGenerator::writeString(skstd::string_view s, OutputStream& out) { 320 out.write(s.data(), s.length()); 321 switch (s.length() % 4) { 322 case 1: 323 out.write8(0); 324 [[fallthrough]]; 325 case 2: 326 out.write8(0); 327 [[fallthrough]]; 328 case 3: 329 out.write8(0); 330 break; 331 default: 332 this->writeWord(0, out); 333 break; 334 } 335} 336 337void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, skstd::string_view string, 338 OutputStream& out) { 339 this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out); 340 this->writeString(string, out); 341} 342 343 344void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, skstd::string_view string, 345 OutputStream& out) { 346 this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out); 347 this->writeWord(word1, out); 348 this->writeString(string, out); 349} 350 351void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 352 skstd::string_view string, OutputStream& out) { 353 this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out); 354 this->writeWord(word1, out); 355 this->writeWord(word2, out); 356 this->writeString(string, out); 357} 358 359void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 360 OutputStream& out) { 361 this->writeOpCode(opCode, 3, out); 
362 this->writeWord(word1, out); 363 this->writeWord(word2, out); 364} 365 366void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 367 int32_t word3, OutputStream& out) { 368 this->writeOpCode(opCode, 4, out); 369 this->writeWord(word1, out); 370 this->writeWord(word2, out); 371 this->writeWord(word3, out); 372} 373 374void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 375 int32_t word3, int32_t word4, OutputStream& out) { 376 this->writeOpCode(opCode, 5, out); 377 this->writeWord(word1, out); 378 this->writeWord(word2, out); 379 this->writeWord(word3, out); 380 this->writeWord(word4, out); 381} 382 383void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 384 int32_t word3, int32_t word4, int32_t word5, 385 OutputStream& out) { 386 this->writeOpCode(opCode, 6, out); 387 this->writeWord(word1, out); 388 this->writeWord(word2, out); 389 this->writeWord(word3, out); 390 this->writeWord(word4, out); 391 this->writeWord(word5, out); 392} 393 394void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 395 int32_t word3, int32_t word4, int32_t word5, 396 int32_t word6, OutputStream& out) { 397 this->writeOpCode(opCode, 7, out); 398 this->writeWord(word1, out); 399 this->writeWord(word2, out); 400 this->writeWord(word3, out); 401 this->writeWord(word4, out); 402 this->writeWord(word5, out); 403 this->writeWord(word6, out); 404} 405 406void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 407 int32_t word3, int32_t word4, int32_t word5, 408 int32_t word6, int32_t word7, OutputStream& out) { 409 this->writeOpCode(opCode, 8, out); 410 this->writeWord(word1, out); 411 this->writeWord(word2, out); 412 this->writeWord(word3, out); 413 this->writeWord(word4, out); 414 this->writeWord(word5, out); 415 this->writeWord(word6, out); 416 this->writeWord(word7, out); 417} 418 419void 
SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 420 int32_t word3, int32_t word4, int32_t word5, 421 int32_t word6, int32_t word7, int32_t word8, 422 OutputStream& out) { 423 this->writeOpCode(opCode, 9, out); 424 this->writeWord(word1, out); 425 this->writeWord(word2, out); 426 this->writeWord(word3, out); 427 this->writeWord(word4, out); 428 this->writeWord(word5, out); 429 this->writeWord(word6, out); 430 this->writeWord(word7, out); 431 this->writeWord(word8, out); 432} 433 434void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) { 435 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) { 436 if (fCapabilities & bit) { 437 this->writeInstruction(SpvOpCapability, (SpvId) i, out); 438 } 439 } 440 this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out); 441} 442 443SpvId SPIRVCodeGenerator::nextId(const Type* type) { 444 return this->nextId(type && type->hasPrecision() && !type->highPrecision() 445 ? Precision::kRelaxed 446 : Precision::kDefault); 447} 448 449SpvId SPIRVCodeGenerator::nextId(Precision precision) { 450 if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) { 451 this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision, 452 fDecorationBuffer); 453 } 454 return fIdCount++; 455} 456 457void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout, 458 SpvId resultId) { 459 this->writeInstruction(SpvOpName, resultId, String(type.name()).c_str(), fNameBuffer); 460 // go ahead and write all of the field types, so we don't inadvertently write them while we're 461 // in the middle of writing the struct instruction 462 std::vector<SpvId> types; 463 for (const auto& f : type.fields()) { 464 types.push_back(this->getType(*f.fType, memoryLayout)); 465 } 466 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); 467 this->writeWord(resultId, fConstantBuffer); 468 for (SpvId id : types) { 
469 this->writeWord(id, fConstantBuffer); 470 } 471 size_t offset = 0; 472 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { 473 const Type::Field& field = type.fields()[i]; 474 if (!MemoryLayout::LayoutIsSupported(*field.fType)) { 475 fContext.fErrors->error(type.fLine, "type '" + field.fType->name() + 476 "' is not permitted here"); 477 return; 478 } 479 size_t size = memoryLayout.size(*field.fType); 480 size_t alignment = memoryLayout.alignment(*field.fType); 481 const Layout& fieldLayout = field.fModifiers.fLayout; 482 if (fieldLayout.fOffset >= 0) { 483 if (fieldLayout.fOffset < (int) offset) { 484 fContext.fErrors->error(type.fLine, 485 "offset of field '" + field.fName + "' must be at " 486 "least " + to_string((int) offset)); 487 } 488 if (fieldLayout.fOffset % alignment) { 489 fContext.fErrors->error(type.fLine, 490 "offset of field '" + field.fName + "' must be a multiple" 491 " of " + to_string((int) alignment)); 492 } 493 offset = fieldLayout.fOffset; 494 } else { 495 size_t mod = offset % alignment; 496 if (mod) { 497 offset += alignment - mod; 498 } 499 } 500 this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer); 501 this->writeLayout(fieldLayout, resultId, i); 502 if (field.fModifiers.fLayout.fBuiltin < 0) { 503 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, 504 (SpvId) offset, fDecorationBuffer); 505 } 506 if (field.fType->isMatrix()) { 507 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, 508 fDecorationBuffer); 509 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, 510 (SpvId) memoryLayout.stride(*field.fType), 511 fDecorationBuffer); 512 } 513 if (!field.fType->highPrecision()) { 514 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, 515 SpvDecorationRelaxedPrecision, fDecorationBuffer); 516 } 517 offset += size; 518 if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment 
!= 0) { 519 offset += alignment - offset % alignment; 520 } 521 } 522} 523 524const Type& SPIRVCodeGenerator::getActualType(const Type& type) { 525 if (type.isFloat()) { 526 return *fContext.fTypes.fFloat; 527 } 528 if (type.isSigned()) { 529 return *fContext.fTypes.fInt; 530 } 531 if (type.isUnsigned()) { 532 return *fContext.fTypes.fUInt; 533 } 534 if (type.isMatrix() || type.isVector()) { 535 if (type.componentType() == *fContext.fTypes.fHalf) { 536 return fContext.fTypes.fFloat->toCompound(fContext, type.columns(), type.rows()); 537 } 538 if (type.componentType() == *fContext.fTypes.fShort) { 539 return fContext.fTypes.fInt->toCompound(fContext, type.columns(), type.rows()); 540 } 541 if (type.componentType() == *fContext.fTypes.fUShort) { 542 return fContext.fTypes.fUInt->toCompound(fContext, type.columns(), type.rows()); 543 } 544 } 545 return type; 546} 547 548SpvId SPIRVCodeGenerator::getType(const Type& type) { 549 return this->getType(type, fDefaultLayout); 550} 551 552SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) { 553 const Type* type; 554 std::unique_ptr<Type> arrayType; 555 String arrayName; 556 557 if (rawType.isArray()) { 558 // For arrays, we need to synthesize a temporary Array type using the "actual" component 559 // type. That is, if `short[10]` is passed in, we need to synthesize a `int[10]` Type. 560 // Otherwise, we can end up with two different SpvIds for the same array type. 561 const Type& component = this->getActualType(rawType.componentType()); 562 arrayName = component.getArrayName(rawType.columns()); 563 arrayType = Type::MakeArrayType(arrayName, component, rawType.columns()); 564 type = arrayType.get(); 565 } else { 566 // For non-array types, we can simply look up the "actual" type and use it. 
567 type = &this->getActualType(rawType); 568 } 569 570 String key(type->name()); 571 if (type->isStruct() || type->isArray()) { 572 key += to_string((int)layout.fStd); 573#ifdef SK_DEBUG 574 SkASSERT(layout.fStd == MemoryLayout::Standard::k140_Standard || 575 layout.fStd == MemoryLayout::Standard::k430_Standard); 576 MemoryLayout::Standard otherStd = layout.fStd == MemoryLayout::Standard::k140_Standard 577 ? MemoryLayout::Standard::k430_Standard 578 : MemoryLayout::Standard::k140_Standard; 579 String otherKey = type->name() + to_string((int)otherStd); 580 SkASSERT(fTypeMap.find(otherKey) == fTypeMap.end()); 581#endif 582 } 583 auto entry = fTypeMap.find(key); 584 if (entry == fTypeMap.end()) { 585 SpvId result = this->nextId(nullptr); 586 switch (type->typeKind()) { 587 case Type::TypeKind::kScalar: 588 if (type->isBoolean()) { 589 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer); 590 } else if (type->isSigned()) { 591 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer); 592 } else if (type->isUnsigned()) { 593 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer); 594 } else if (type->isFloat()) { 595 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer); 596 } else { 597 SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str()); 598 } 599 break; 600 case Type::TypeKind::kVector: 601 this->writeInstruction(SpvOpTypeVector, result, 602 this->getType(type->componentType(), layout), 603 type->columns(), fConstantBuffer); 604 break; 605 case Type::TypeKind::kMatrix: 606 this->writeInstruction( 607 SpvOpTypeMatrix, 608 result, 609 this->getType(IndexExpression::IndexType(fContext, *type), layout), 610 type->columns(), 611 fConstantBuffer); 612 break; 613 case Type::TypeKind::kStruct: 614 this->writeStruct(*type, layout, result); 615 break; 616 case Type::TypeKind::kArray: { 617 if (!MemoryLayout::LayoutIsSupported(*type)) { 618 fContext.fErrors->error(type->fLine, 619 "type '" + type->name() 
+ "' is not permitted here"); 620 return this->nextId(nullptr); 621 } 622 if (type->columns() > 0) { 623 SpvId typeId = this->getType(type->componentType(), layout); 624 SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt); 625 this->writeInstruction(SpvOpTypeArray, result, typeId, countId, 626 fConstantBuffer); 627 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 628 (int32_t) layout.stride(*type), 629 fDecorationBuffer); 630 } else { 631 // We shouldn't have any runtime-sized arrays right now 632 fContext.fErrors->error(type->fLine, 633 "runtime-sized arrays are not supported in SPIR-V"); 634 this->writeInstruction(SpvOpTypeRuntimeArray, result, 635 this->getType(type->componentType(), layout), 636 fConstantBuffer); 637 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 638 (int32_t) layout.stride(*type), 639 fDecorationBuffer); 640 } 641 break; 642 } 643 case Type::TypeKind::kSampler: { 644 SpvId image = result; 645 if (SpvDimSubpassData != type->dimensions()) { 646 image = this->getType(type->textureType(), layout); 647 } 648 if (SpvDimBuffer == type->dimensions()) { 649 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer); 650 } 651 if (SpvDimSubpassData != type->dimensions()) { 652 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer); 653 } 654 break; 655 } 656 case Type::TypeKind::kSeparateSampler: { 657 this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer); 658 break; 659 } 660 case Type::TypeKind::kTexture: { 661 this->writeInstruction(SpvOpTypeImage, result, 662 this->getType(*fContext.fTypes.fFloat, layout), 663 type->dimensions(), type->isDepth(), 664 type->isArrayedTexture(), type->isMultisampled(), 665 type->isSampled() ? 
1 : 2, SpvImageFormatUnknown, 666 fConstantBuffer); 667 fImageTypeMap[key] = result; 668 break; 669 } 670 default: 671 if (type->isVoid()) { 672 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer); 673 } else { 674 SkDEBUGFAILF("invalid type: %s", type->description().c_str()); 675 } 676 break; 677 } 678 fTypeMap[key] = result; 679 return result; 680 } 681 return entry->second; 682} 683 684SpvId SPIRVCodeGenerator::getImageType(const Type& type) { 685 SkASSERT(type.typeKind() == Type::TypeKind::kSampler); 686 this->getType(type); 687 String key = type.name() + to_string((int) fDefaultLayout.fStd); 688 SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end()); 689 return fImageTypeMap[key]; 690} 691 692SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { 693 String key = to_string(this->getType(function.returnType())) + "("; 694 String separator; 695 const std::vector<const Variable*>& parameters = function.parameters(); 696 for (size_t i = 0; i < parameters.size(); i++) { 697 key += separator; 698 separator = ", "; 699 key += to_string(this->getType(parameters[i]->type())); 700 } 701 key += ")"; 702 auto entry = fTypeMap.find(key); 703 if (entry == fTypeMap.end()) { 704 SpvId result = this->nextId(nullptr); 705 int32_t length = 3 + (int32_t) parameters.size(); 706 SpvId returnType = this->getType(function.returnType()); 707 std::vector<SpvId> parameterTypes; 708 for (size_t i = 0; i < parameters.size(); i++) { 709 // glslang seems to treat all function arguments as pointers whether they need to be or 710 // not. 
I was initially puzzled by this until I ran bizarre failures with certain 711 // patterns of function calls and control constructs, as exemplified by this minimal 712 // failure case: 713 // 714 // void sphere(float x) { 715 // } 716 // 717 // void map() { 718 // sphere(1.0); 719 // } 720 // 721 // void main() { 722 // for (int i = 0; i < 1; i++) { 723 // map(); 724 // } 725 // } 726 // 727 // As of this writing, compiling this in the "obvious" way (with sphere taking a float) 728 // crashes. Making it take a float* and storing the argument in a temporary variable, 729 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of 730 // the spec makes this make sense. 731 parameterTypes.push_back(this->getPointerType(parameters[i]->type(), 732 SpvStorageClassFunction)); 733 } 734 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); 735 this->writeWord(result, fConstantBuffer); 736 this->writeWord(returnType, fConstantBuffer); 737 for (SpvId id : parameterTypes) { 738 this->writeWord(id, fConstantBuffer); 739 } 740 fTypeMap[key] = result; 741 return result; 742 } 743 return entry->second; 744} 745 746SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) { 747 return this->getPointerType(type, fDefaultLayout, storageClass); 748} 749 750SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout, 751 SpvStorageClass_ storageClass) { 752 const Type& type = this->getActualType(rawType); 753 String key = type.displayName() + "*" + to_string(layout.fStd) + to_string(storageClass); 754 auto entry = fTypeMap.find(key); 755 if (entry == fTypeMap.end()) { 756 SpvId result = this->nextId(nullptr); 757 this->writeInstruction(SpvOpTypePointer, result, storageClass, 758 this->getType(type), fConstantBuffer); 759 fTypeMap[key] = result; 760 return result; 761 } 762 return entry->second; 763} 764 765SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) 
{ 766 switch (expr.kind()) { 767 case Expression::Kind::kBinary: 768 return this->writeBinaryExpression(expr.as<BinaryExpression>(), out); 769 case Expression::Kind::kConstructorArrayCast: 770 return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out); 771 case Expression::Kind::kConstructorArray: 772 case Expression::Kind::kConstructorStruct: 773 return this->writeCompositeConstructor(expr.asAnyConstructor(), out); 774 case Expression::Kind::kConstructorDiagonalMatrix: 775 return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out); 776 case Expression::Kind::kConstructorMatrixResize: 777 return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out); 778 case Expression::Kind::kConstructorScalarCast: 779 return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out); 780 case Expression::Kind::kConstructorSplat: 781 return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out); 782 case Expression::Kind::kConstructorCompound: 783 return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out); 784 case Expression::Kind::kConstructorCompoundCast: 785 return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out); 786 case Expression::Kind::kFieldAccess: 787 return this->writeFieldAccess(expr.as<FieldAccess>(), out); 788 case Expression::Kind::kFunctionCall: 789 return this->writeFunctionCall(expr.as<FunctionCall>(), out); 790 case Expression::Kind::kLiteral: 791 return this->writeLiteral(expr.as<Literal>()); 792 case Expression::Kind::kPrefix: 793 return this->writePrefixExpression(expr.as<PrefixExpression>(), out); 794 case Expression::Kind::kPostfix: 795 return this->writePostfixExpression(expr.as<PostfixExpression>(), out); 796 case Expression::Kind::kSwizzle: 797 return this->writeSwizzle(expr.as<Swizzle>(), out); 798 case Expression::Kind::kVariableReference: 799 return this->writeVariableReference(expr.as<VariableReference>(), out); 800 
case Expression::Kind::kTernary: 801 return this->writeTernaryExpression(expr.as<TernaryExpression>(), out); 802 case Expression::Kind::kIndex: 803 return this->writeIndexExpression(expr.as<IndexExpression>(), out); 804 default: 805 SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str()); 806 break; 807 } 808 return -1; 809} 810 811SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) { 812 const FunctionDeclaration& function = c.function(); 813 auto intrinsic = fIntrinsicMap.find(function.intrinsicKind()); 814 if (intrinsic == fIntrinsicMap.end()) { 815 fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() + "'"); 816 return -1; 817 } 818 int32_t intrinsicId; 819 const ExpressionArray& arguments = c.arguments(); 820 if (arguments.size() > 0) { 821 const Type& type = arguments[0]->type(); 822 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicOpcodeKind || 823 is_float(fContext, type)) { 824 intrinsicId = std::get<1>(intrinsic->second); 825 } else if (is_signed(fContext, type)) { 826 intrinsicId = std::get<2>(intrinsic->second); 827 } else if (is_unsigned(fContext, type)) { 828 intrinsicId = std::get<3>(intrinsic->second); 829 } else if (is_bool(fContext, type)) { 830 intrinsicId = std::get<4>(intrinsic->second); 831 } else { 832 intrinsicId = std::get<1>(intrinsic->second); 833 } 834 } else { 835 intrinsicId = std::get<1>(intrinsic->second); 836 } 837 switch (std::get<0>(intrinsic->second)) { 838 case kGLSL_STD_450_IntrinsicOpcodeKind: { 839 SpvId result = this->nextId(&c.type()); 840 std::vector<SpvId> argumentIds; 841 std::vector<TempVar> tempVars; 842 argumentIds.reserve(arguments.size()); 843 for (size_t i = 0; i < arguments.size(); i++) { 844 if (is_out(function.parameters()[i]->modifiers())) { 845 argumentIds.push_back( 846 this->writeFunctionCallArgument(*arguments[i], 847 function.parameters()[i]->modifiers(), 848 &tempVars, 849 out)); 850 } else { 851 
argumentIds.push_back(this->writeExpression(*arguments[i], out)); 852 } 853 } 854 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out); 855 this->writeWord(this->getType(c.type()), out); 856 this->writeWord(result, out); 857 this->writeWord(fGLSLExtendedInstructions, out); 858 this->writeWord(intrinsicId, out); 859 for (SpvId id : argumentIds) { 860 this->writeWord(id, out); 861 } 862 this->copyBackTempVars(tempVars, out); 863 return result; 864 } 865 case kSPIRV_IntrinsicOpcodeKind: { 866 // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul 867 if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) { 868 intrinsicId = SpvOpFMul; 869 } 870 SpvId result = this->nextId(&c.type()); 871 std::vector<SpvId> argumentIds; 872 std::vector<TempVar> tempVars; 873 argumentIds.reserve(arguments.size()); 874 for (size_t i = 0; i < arguments.size(); i++) { 875 if (is_out(function.parameters()[i]->modifiers())) { 876 argumentIds.push_back( 877 this->writeFunctionCallArgument(*arguments[i], 878 function.parameters()[i]->modifiers(), 879 &tempVars, 880 out)); 881 } else { 882 argumentIds.push_back(this->writeExpression(*arguments[i], out)); 883 } 884 } 885 if (!c.type().isVoid()) { 886 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); 887 this->writeWord(this->getType(c.type()), out); 888 this->writeWord(result, out); 889 } else { 890 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out); 891 } 892 for (SpvId id : argumentIds) { 893 this->writeWord(id, out); 894 } 895 this->copyBackTempVars(tempVars, out); 896 return result; 897 } 898 case kSpecial_IntrinsicOpcodeKind: 899 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out); 900 default: 901 fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() + 902 "'"); 903 return -1; 904 } 905} 906 907SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) 
{ 908 SkASSERT(vectorSize >= 1 && vectorSize <= 4); 909 const Type& argType = arg.type(); 910 SpvId raw = this->writeExpression(arg, out); 911 if (argType.isScalar()) { 912 if (vectorSize == 1) { 913 return raw; 914 } 915 SpvId vector = this->nextId(&argType); 916 this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out); 917 this->writeWord(this->getType(argType.toCompound(fContext, vectorSize, 1)), out); 918 this->writeWord(vector, out); 919 for (int i = 0; i < vectorSize; i++) { 920 this->writeWord(raw, out); 921 } 922 return vector; 923 } else { 924 SkASSERT(vectorSize == argType.columns()); 925 return raw; 926 } 927} 928 929std::vector<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) { 930 int vectorSize = 1; 931 for (const auto& a : args) { 932 if (a->type().isVector()) { 933 if (vectorSize > 1) { 934 SkASSERT(a->type().columns() == vectorSize); 935 } else { 936 vectorSize = a->type().columns(); 937 } 938 } 939 } 940 std::vector<SpvId> result; 941 result.reserve(args.size()); 942 for (const auto& arg : args) { 943 result.push_back(this->vectorize(*arg, vectorSize, out)); 944 } 945 return result; 946} 947 948void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst, 949 SpvId signedInst, SpvId unsignedInst, 950 const std::vector<SpvId>& args, 951 OutputStream& out) { 952 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out); 953 this->writeWord(this->getType(type), out); 954 this->writeWord(id, out); 955 this->writeWord(fGLSLExtendedInstructions, out); 956 957 if (is_float(fContext, type)) { 958 this->writeWord(floatInst, out); 959 } else if (is_signed(fContext, type)) { 960 this->writeWord(signedInst, out); 961 } else if (is_unsigned(fContext, type)) { 962 this->writeWord(unsignedInst, out); 963 } else { 964 SkASSERT(false); 965 } 966 for (SpvId a : args) { 967 this->writeWord(a, out); 968 } 969} 970 971SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, 
SpecialIntrinsic kind, 972 OutputStream& out) { 973 const ExpressionArray& arguments = c.arguments(); 974 const Type& callType = c.type(); 975 SpvId result = this->nextId(nullptr); 976 switch (kind) { 977 case kAtan_SpecialIntrinsic: { 978 std::vector<SpvId> argumentIds; 979 argumentIds.reserve(arguments.size()); 980 for (const std::unique_ptr<Expression>& arg : arguments) { 981 argumentIds.push_back(this->writeExpression(*arg, out)); 982 } 983 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out); 984 this->writeWord(this->getType(callType), out); 985 this->writeWord(result, out); 986 this->writeWord(fGLSLExtendedInstructions, out); 987 this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); 988 for (SpvId id : argumentIds) { 989 this->writeWord(id, out); 990 } 991 break; 992 } 993 case kSampledImage_SpecialIntrinsic: { 994 SkASSERT(arguments.size() == 2); 995 SpvId img = this->writeExpression(*arguments[0], out); 996 SpvId sampler = this->writeExpression(*arguments[1], out); 997 this->writeInstruction(SpvOpSampledImage, 998 this->getType(callType), 999 result, 1000 img, 1001 sampler, 1002 out); 1003 break; 1004 } 1005 case kSubpassLoad_SpecialIntrinsic: { 1006 SpvId img = this->writeExpression(*arguments[0], out); 1007 ExpressionArray args; 1008 args.reserve_back(2); 1009 args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0)); 1010 args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0)); 1011 ConstructorCompound ctor(/*line=*/-1, *fContext.fTypes.fInt2, std::move(args)); 1012 SpvId coords = this->writeConstantVector(ctor); 1013 if (arguments.size() == 1) { 1014 this->writeInstruction(SpvOpImageRead, 1015 this->getType(callType), 1016 result, 1017 img, 1018 coords, 1019 out); 1020 } else { 1021 SkASSERT(arguments.size() == 2); 1022 SpvId sample = this->writeExpression(*arguments[1], out); 1023 this->writeInstruction(SpvOpImageRead, 1024 this->getType(callType), 1025 result, 1026 img, 1027 
coords, 1028 SpvImageOperandsSampleMask, 1029 sample, 1030 out); 1031 } 1032 break; 1033 } 1034 case kTexture_SpecialIntrinsic: { 1035 SpvOp_ op = SpvOpImageSampleImplicitLod; 1036 const Type& arg1Type = arguments[1]->type(); 1037 switch (arguments[0]->type().dimensions()) { 1038 case SpvDim1D: 1039 if (arg1Type == *fContext.fTypes.fFloat2) { 1040 op = SpvOpImageSampleProjImplicitLod; 1041 } else { 1042 SkASSERT(arg1Type == *fContext.fTypes.fFloat); 1043 } 1044 break; 1045 case SpvDim2D: 1046 if (arg1Type == *fContext.fTypes.fFloat3) { 1047 op = SpvOpImageSampleProjImplicitLod; 1048 } else { 1049 SkASSERT(arg1Type == *fContext.fTypes.fFloat2); 1050 } 1051 break; 1052 case SpvDim3D: 1053 if (arg1Type == *fContext.fTypes.fFloat4) { 1054 op = SpvOpImageSampleProjImplicitLod; 1055 } else { 1056 SkASSERT(arg1Type == *fContext.fTypes.fFloat3); 1057 } 1058 break; 1059 case SpvDimCube: // fall through 1060 case SpvDimRect: // fall through 1061 case SpvDimBuffer: // fall through 1062 case SpvDimSubpassData: 1063 break; 1064 } 1065 SpvId type = this->getType(callType); 1066 SpvId sampler = this->writeExpression(*arguments[0], out); 1067 SpvId uv = this->writeExpression(*arguments[1], out); 1068 if (arguments.size() == 3) { 1069 this->writeInstruction(op, type, result, sampler, uv, 1070 SpvImageOperandsBiasMask, 1071 this->writeExpression(*arguments[2], out), 1072 out); 1073 } else { 1074 SkASSERT(arguments.size() == 2); 1075 if (fProgram.fConfig->fSettings.fSharpenTextures) { 1076 SpvId lodBias = this->writeLiteral(-0.5, *fContext.fTypes.fFloat); 1077 this->writeInstruction(op, type, result, sampler, uv, 1078 SpvImageOperandsBiasMask, lodBias, out); 1079 } else { 1080 this->writeInstruction(op, type, result, sampler, uv, 1081 out); 1082 } 1083 } 1084 break; 1085 } 1086 case kMod_SpecialIntrinsic: { 1087 std::vector<SpvId> args = this->vectorize(arguments, out); 1088 SkASSERT(args.size() == 2); 1089 const Type& operandType = arguments[0]->type(); 1090 SpvOp_ op; 1091 if 
(is_float(fContext, operandType)) { 1092 op = SpvOpFMod; 1093 } else if (is_signed(fContext, operandType)) { 1094 op = SpvOpSMod; 1095 } else if (is_unsigned(fContext, operandType)) { 1096 op = SpvOpUMod; 1097 } else { 1098 SkASSERT(false); 1099 return 0; 1100 } 1101 this->writeOpCode(op, 5, out); 1102 this->writeWord(this->getType(operandType), out); 1103 this->writeWord(result, out); 1104 this->writeWord(args[0], out); 1105 this->writeWord(args[1], out); 1106 break; 1107 } 1108 case kDFdy_SpecialIntrinsic: { 1109 SpvId fn = this->writeExpression(*arguments[0], out); 1110 this->writeOpCode(SpvOpDPdy, 4, out); 1111 this->writeWord(this->getType(callType), out); 1112 this->writeWord(result, out); 1113 this->writeWord(fn, out); 1114 this->addRTFlipUniform(c.fLine); 1115 using namespace dsl; 1116 DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1, 1117 SKSL_RTFLIP_NAME)); 1118 SpvId rtFlipY = this->vectorize(*rtFlip.y().release(), callType.columns(), out); 1119 SpvId flipped = this->nextId(&callType); 1120 this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result, rtFlipY, 1121 out); 1122 result = flipped; 1123 break; 1124 } 1125 case kClamp_SpecialIntrinsic: { 1126 std::vector<SpvId> args = this->vectorize(arguments, out); 1127 SkASSERT(args.size() == 3); 1128 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp, 1129 GLSLstd450UClamp, args, out); 1130 break; 1131 } 1132 case kMax_SpecialIntrinsic: { 1133 std::vector<SpvId> args = this->vectorize(arguments, out); 1134 SkASSERT(args.size() == 2); 1135 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax, 1136 GLSLstd450UMax, args, out); 1137 break; 1138 } 1139 case kMin_SpecialIntrinsic: { 1140 std::vector<SpvId> args = this->vectorize(arguments, out); 1141 SkASSERT(args.size() == 2); 1142 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin, 1143 GLSLstd450UMin, args, out); 1144 
break; 1145 } 1146 case kMix_SpecialIntrinsic: { 1147 std::vector<SpvId> args = this->vectorize(arguments, out); 1148 SkASSERT(args.size() == 3); 1149 if (arguments[2]->type().componentType().isBoolean()) { 1150 // Use OpSelect to implement Boolean mix(). 1151 SpvId falseId = this->writeExpression(*arguments[0], out); 1152 SpvId trueId = this->writeExpression(*arguments[1], out); 1153 SpvId conditionId = this->writeExpression(*arguments[2], out); 1154 this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result, 1155 conditionId, trueId, falseId, out); 1156 } else { 1157 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef, 1158 SpvOpUndef, args, out); 1159 } 1160 break; 1161 } 1162 case kSaturate_SpecialIntrinsic: { 1163 SkASSERT(arguments.size() == 1); 1164 ExpressionArray finalArgs; 1165 finalArgs.reserve_back(3); 1166 finalArgs.push_back(arguments[0]->clone()); 1167 finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/0)); 1168 finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/1)); 1169 std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out); 1170 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp, 1171 GLSLstd450UClamp, spvArgs, out); 1172 break; 1173 } 1174 case kSmoothStep_SpecialIntrinsic: { 1175 std::vector<SpvId> args = this->vectorize(arguments, out); 1176 SkASSERT(args.size() == 3); 1177 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef, 1178 SpvOpUndef, args, out); 1179 break; 1180 } 1181 case kStep_SpecialIntrinsic: { 1182 std::vector<SpvId> args = this->vectorize(arguments, out); 1183 SkASSERT(args.size() == 2); 1184 this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef, 1185 SpvOpUndef, args, out); 1186 break; 1187 } 1188 case kMatrixCompMult_SpecialIntrinsic: { 1189 SkASSERT(arguments.size() == 2); 1190 SpvId lhs = this->writeExpression(*arguments[0], out); 
1191 SpvId rhs = this->writeExpression(*arguments[1], out); 1192 result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out); 1193 break; 1194 } 1195 } 1196 return result; 1197} 1198 1199SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const Expression& arg, 1200 const Modifiers& paramModifiers, 1201 std::vector<TempVar>* tempVars, 1202 OutputStream& out) { 1203 // ID of temporary variable that we will use to hold this argument, or 0 if it is being 1204 // passed directly 1205 SpvId tmpVar; 1206 // if we need a temporary var to store this argument, this is the value to store in the var 1207 SpvId tmpValueId = -1; 1208 1209 if (is_out(paramModifiers)) { 1210 std::unique_ptr<LValue> lv = this->getLValue(arg, out); 1211 SpvId ptr = lv->getPointer(); 1212 if (ptr != (SpvId) -1 && lv->isMemoryObjectPointer()) { 1213 return ptr; 1214 } 1215 1216 // lvalue cannot simply be read and written via a pointer (e.g. it's a swizzle). We need to 1217 // to use a temp variable. 1218 if (is_in(paramModifiers)) { 1219 tmpValueId = lv->load(out); 1220 } 1221 tmpVar = this->nextId(&arg.type()); 1222 tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)}); 1223 } else { 1224 // See getFunctionType for an explanation of why we're always using pointer parameters. 
1225 tmpValueId = this->writeExpression(arg, out); 1226 tmpVar = this->nextId(nullptr); 1227 } 1228 this->writeInstruction(SpvOpVariable, 1229 this->getPointerType(arg.type(), SpvStorageClassFunction), 1230 tmpVar, 1231 SpvStorageClassFunction, 1232 fVariableBuffer); 1233 if (tmpValueId != (SpvId)-1) { 1234 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); 1235 } 1236 return tmpVar; 1237} 1238 1239void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) { 1240 for (const TempVar& tempVar : tempVars) { 1241 SpvId load = this->nextId(tempVar.type); 1242 this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out); 1243 tempVar.lvalue->store(load, out); 1244 } 1245} 1246 1247SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) { 1248 const FunctionDeclaration& function = c.function(); 1249 if (function.isIntrinsic() && !function.definition()) { 1250 return this->writeIntrinsicCall(c, out); 1251 } 1252 const ExpressionArray& arguments = c.arguments(); 1253 const auto& entry = fFunctionMap.find(&function); 1254 if (entry == fFunctionMap.end()) { 1255 fContext.fErrors->error(c.fLine, "function '" + function.description() + 1256 "' is not defined"); 1257 return -1; 1258 } 1259 // Temp variables are used to write back out-parameters after the function call is complete. 
1260 std::vector<TempVar> tempVars; 1261 std::vector<SpvId> argumentIds; 1262 argumentIds.reserve(arguments.size()); 1263 for (size_t i = 0; i < arguments.size(); i++) { 1264 argumentIds.push_back(this->writeFunctionCallArgument(*arguments[i], 1265 function.parameters()[i]->modifiers(), 1266 &tempVars, 1267 out)); 1268 } 1269 SpvId result = this->nextId(nullptr); 1270 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) arguments.size(), out); 1271 this->writeWord(this->getType(c.type()), out); 1272 this->writeWord(result, out); 1273 this->writeWord(entry->second, out); 1274 for (SpvId id : argumentIds) { 1275 this->writeWord(id, out); 1276 } 1277 // Now that the call is complete, we copy temp out-variables back to their real lvalues. 1278 this->copyBackTempVars(tempVars, out); 1279 return result; 1280} 1281 1282SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) { 1283 const Type& type = c.type(); 1284 SkASSERT(type.isVector() && c.isCompileTimeConstant()); 1285 1286 // Get each of the constructor components as SPIR-V constants. 1287 SPIRVVectorConstant key{this->getType(type), 1288 /*fValueId=*/{SpvId(-1), SpvId(-1), SpvId(-1), SpvId(-1)}}; 1289 1290 const Type& scalarType = type.componentType(); 1291 for (int n = 0; n < type.columns(); n++) { 1292 skstd::optional<double> slotVal = c.getConstantValue(n); 1293 if (!slotVal.has_value()) { 1294 SkDEBUGFAILF("writeConstantVector: %s not actually constant", c.description().c_str()); 1295 return (SpvId)-1; 1296 } 1297 key.fValueId[n] = this->writeLiteral(*slotVal, scalarType); 1298 } 1299 1300 // Check to see if we've already synthesized this vector constant. 1301 auto [iter, newlyCreated] = fVectorConstants.insert({key, (SpvId)-1}); 1302 if (newlyCreated) { 1303 // Emit an OpConstantComposite instruction for this constant. 
1304 SpvId result = this->nextId(&type); 1305 this->writeOpCode(SpvOpConstantComposite, 3 + type.columns(), fConstantBuffer); 1306 this->writeWord(key.fTypeId, fConstantBuffer); 1307 this->writeWord(result, fConstantBuffer); 1308 for (int i = 0; i < type.columns(); i++) { 1309 this->writeWord(key.fValueId[i], fConstantBuffer); 1310 } 1311 iter->second = result; 1312 } 1313 return iter->second; 1314} 1315 1316SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId, 1317 const Type& inputType, 1318 const Type& outputType, 1319 OutputStream& out) { 1320 if (outputType.isFloat()) { 1321 return this->castScalarToFloat(inputExprId, inputType, outputType, out); 1322 } 1323 if (outputType.isSigned()) { 1324 return this->castScalarToSignedInt(inputExprId, inputType, outputType, out); 1325 } 1326 if (outputType.isUnsigned()) { 1327 return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out); 1328 } 1329 if (outputType.isBoolean()) { 1330 return this->castScalarToBoolean(inputExprId, inputType, outputType, out); 1331 } 1332 1333 fContext.fErrors->error(-1, "unsupported cast: " + inputType.description() + 1334 " to " + outputType.description()); 1335 return inputExprId; 1336} 1337 1338SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) { 1339 SkASSERT(c.argumentSpan().size() == 1); 1340 SkASSERT(c.type().isFloat()); 1341 const Expression& ctorExpr = *c.argumentSpan().front(); 1342 SpvId expressionId = this->writeExpression(ctorExpr, out); 1343 return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out); 1344} 1345 1346SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType, 1347 const Type& outputType, OutputStream& out) { 1348 // Casting a float to float is a no-op. 1349 if (inputType.isFloat()) { 1350 return inputId; 1351 } 1352 1353 // Given the input type, generate the appropriate instruction to cast to float. 
1354 SpvId result = this->nextId(&outputType); 1355 if (inputType.isBoolean()) { 1356 // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0. 1357 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat); 1358 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat); 1359 this->writeInstruction(SpvOpSelect, this->getType(outputType), result, 1360 inputId, oneID, zeroID, out); 1361 } else if (inputType.isSigned()) { 1362 this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out); 1363 } else if (inputType.isUnsigned()) { 1364 this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out); 1365 } else { 1366 SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str()); 1367 return (SpvId)-1; 1368 } 1369 return result; 1370} 1371 1372SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) { 1373 SkASSERT(c.argumentSpan().size() == 1); 1374 SkASSERT(c.type().isSigned()); 1375 const Expression& ctorExpr = *c.argumentSpan().front(); 1376 SpvId expressionId = this->writeExpression(ctorExpr, out); 1377 return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out); 1378} 1379 1380SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType, 1381 const Type& outputType, OutputStream& out) { 1382 // Casting a signed int to signed int is a no-op. 1383 if (inputType.isSigned()) { 1384 return inputId; 1385 } 1386 1387 // Given the input type, generate the appropriate instruction to cast to signed int. 1388 SpvId result = this->nextId(&outputType); 1389 if (inputType.isBoolean()) { 1390 // Use OpSelect to convert the boolean argument to a literal 1 or 0. 
1391 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt); 1392 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt); 1393 this->writeInstruction(SpvOpSelect, this->getType(outputType), result, 1394 inputId, oneID, zeroID, out); 1395 } else if (inputType.isFloat()) { 1396 this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out); 1397 } else if (inputType.isUnsigned()) { 1398 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out); 1399 } else { 1400 SkDEBUGFAILF("unsupported type for signed int typecast: %s", 1401 inputType.description().c_str()); 1402 return (SpvId)-1; 1403 } 1404 return result; 1405} 1406 1407SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) { 1408 SkASSERT(c.argumentSpan().size() == 1); 1409 SkASSERT(c.type().isUnsigned()); 1410 const Expression& ctorExpr = *c.argumentSpan().front(); 1411 SpvId expressionId = this->writeExpression(ctorExpr, out); 1412 return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out); 1413} 1414 1415SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType, 1416 const Type& outputType, OutputStream& out) { 1417 // Casting an unsigned int to unsigned int is a no-op. 1418 if (inputType.isUnsigned()) { 1419 return inputId; 1420 } 1421 1422 // Given the input type, generate the appropriate instruction to cast to unsigned int. 1423 SpvId result = this->nextId(&outputType); 1424 if (inputType.isBoolean()) { 1425 // Use OpSelect to convert the boolean argument to a literal 1u or 0u. 
1426 const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt); 1427 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt); 1428 this->writeInstruction(SpvOpSelect, this->getType(outputType), result, 1429 inputId, oneID, zeroID, out); 1430 } else if (inputType.isFloat()) { 1431 this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out); 1432 } else if (inputType.isSigned()) { 1433 this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out); 1434 } else { 1435 SkDEBUGFAILF("unsupported type for unsigned int typecast: %s", 1436 inputType.description().c_str()); 1437 return (SpvId)-1; 1438 } 1439 return result; 1440} 1441 1442SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) { 1443 SkASSERT(c.argumentSpan().size() == 1); 1444 SkASSERT(c.type().isBoolean()); 1445 const Expression& ctorExpr = *c.argumentSpan().front(); 1446 SpvId expressionId = this->writeExpression(ctorExpr, out); 1447 return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out); 1448} 1449 1450SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType, 1451 const Type& outputType, OutputStream& out) { 1452 // Casting a bool to bool is a no-op. 1453 if (inputType.isBoolean()) { 1454 return inputId; 1455 } 1456 1457 // Given the input type, generate the appropriate instruction to cast to bool. 1458 SpvId result = this->nextId(nullptr); 1459 if (inputType.isSigned()) { 1460 // Synthesize a boolean result by comparing the input against a signed zero literal. 1461 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt); 1462 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result, 1463 inputId, zeroID, out); 1464 } else if (inputType.isUnsigned()) { 1465 // Synthesize a boolean result by comparing the input against an unsigned zero literal. 
1466 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt); 1467 this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result, 1468 inputId, zeroID, out); 1469 } else if (inputType.isFloat()) { 1470 // Synthesize a boolean result by comparing the input against a floating-point zero literal. 1471 const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat); 1472 this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result, 1473 inputId, zeroID, out); 1474 } else { 1475 SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str()); 1476 return (SpvId)-1; 1477 } 1478 return result; 1479} 1480 1481void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, 1482 OutputStream& out) { 1483 SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat); 1484 std::vector<SpvId> columnIds; 1485 columnIds.reserve(type.columns()); 1486 for (int column = 0; column < type.columns(); column++) { 1487 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(), 1488 out); 1489 this->writeWord(this->getType(type.componentType().toCompound( 1490 fContext, /*columns=*/type.rows(), /*rows=*/1)), 1491 out); 1492 SpvId columnId = this->nextId(&type); 1493 this->writeWord(columnId, out); 1494 columnIds.push_back(columnId); 1495 for (int row = 0; row < type.rows(); row++) { 1496 this->writeWord(row == column ? 
diagonal : zeroId, out); 1497 } 1498 } 1499 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), 1500 out); 1501 this->writeWord(this->getType(type), out); 1502 this->writeWord(id, out); 1503 for (SpvId columnId : columnIds) { 1504 this->writeWord(columnId, out); 1505 } 1506} 1507 1508SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType, 1509 OutputStream& out) { 1510 SkASSERT(srcType.isMatrix()); 1511 SkASSERT(dstType.isMatrix()); 1512 SkASSERT(srcType.componentType() == dstType.componentType()); 1513 SpvId id = this->nextId(&dstType); 1514 SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext, 1515 srcType.rows(), 1516 1)); 1517 SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext, 1518 dstType.rows(), 1519 1)); 1520 SkASSERT(dstType.componentType().isFloat()); 1521 const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType()); 1522 const SpvId oneId = this->writeLiteral(1.0, dstType.componentType()); 1523 1524 SpvId columns[4]; 1525 for (int i = 0; i < dstType.columns(); i++) { 1526 if (i < srcType.columns()) { 1527 // we're still inside the src matrix, copy the column 1528 SpvId srcColumn = this->nextId(&dstType); 1529 this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out); 1530 SpvId dstColumn; 1531 if (srcType.rows() == dstType.rows()) { 1532 // columns are equal size, don't need to do anything 1533 dstColumn = srcColumn; 1534 } 1535 else if (dstType.rows() > srcType.rows()) { 1536 // dst column is bigger, need to zero-pad it 1537 dstColumn = this->nextId(&dstType); 1538 int delta = dstType.rows() - srcType.rows(); 1539 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out); 1540 this->writeWord(dstColumnType, out); 1541 this->writeWord(dstColumn, out); 1542 this->writeWord(srcColumn, out); 1543 for (int j = srcType.rows(); j < dstType.rows(); ++j) { 1544 this->writeWord((i == j) ? 
oneId : zeroId, out); 1545 } 1546 } 1547 else { 1548 // dst column is smaller, need to swizzle the src column 1549 dstColumn = this->nextId(&dstType); 1550 this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out); 1551 this->writeWord(dstColumnType, out); 1552 this->writeWord(dstColumn, out); 1553 this->writeWord(srcColumn, out); 1554 this->writeWord(srcColumn, out); 1555 for (int j = 0; j < dstType.rows(); j++) { 1556 this->writeWord(j, out); 1557 } 1558 } 1559 columns[i] = dstColumn; 1560 } else { 1561 // we're past the end of the src matrix, need to synthesize an identity-matrix column 1562 SpvId identityColumn = this->nextId(&dstType); 1563 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out); 1564 this->writeWord(dstColumnType, out); 1565 this->writeWord(identityColumn, out); 1566 for (int j = 0; j < dstType.rows(); ++j) { 1567 this->writeWord((i == j) ? oneId : zeroId, out); 1568 } 1569 columns[i] = identityColumn; 1570 } 1571 } 1572 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out); 1573 this->writeWord(this->getType(dstType), out); 1574 this->writeWord(id, out); 1575 for (int i = 0; i < dstType.columns(); i++) { 1576 this->writeWord(columns[i], out); 1577 } 1578 return id; 1579} 1580 1581void SPIRVCodeGenerator::addColumnEntry(const Type& columnType, 1582 std::vector<SpvId>* currentColumn, 1583 std::vector<SpvId>* columnIds, 1584 int rows, 1585 SpvId entry, 1586 OutputStream& out) { 1587 SkASSERT((int)currentColumn->size() < rows); 1588 currentColumn->push_back(entry); 1589 if ((int)currentColumn->size() == rows) { 1590 // Synthesize this column into a vector. 
1591 SpvId columnId = this->writeComposite(*currentColumn, columnType, out); 1592 columnIds->push_back(columnId); 1593 currentColumn->clear(); 1594 } 1595} 1596 1597SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) { 1598 const Type& type = c.type(); 1599 SkASSERT(type.isMatrix()); 1600 SkASSERT(!c.arguments().empty()); 1601 const Type& arg0Type = c.arguments()[0]->type(); 1602 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1603 // an instruction 1604 std::vector<SpvId> arguments; 1605 arguments.reserve(c.arguments().size()); 1606 for (const std::unique_ptr<Expression>& arg : c.arguments()) { 1607 arguments.push_back(this->writeExpression(*arg, out)); 1608 } 1609 1610 if (arguments.size() == 1 && arg0Type.isVector()) { 1611 // Special-case handling of float4 -> mat2x2. 1612 SkASSERT(type.rows() == 2 && type.columns() == 2); 1613 SkASSERT(arg0Type.columns() == 4); 1614 SpvId componentType = this->getType(type.componentType()); 1615 SpvId v[4]; 1616 for (int i = 0; i < 4; ++i) { 1617 v[i] = this->nextId(&type); 1618 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, 1619 out); 1620 } 1621 const Type& vecType = type.componentType().toCompound(fContext, /*columns=*/2, /*rows=*/1); 1622 SpvId v0v1 = this->writeComposite({v[0], v[1]}, vecType, out); 1623 SpvId v2v3 = this->writeComposite({v[2], v[3]}, vecType, out); 1624 return this->writeComposite({v0v1, v2v3}, type, out); 1625 } 1626 1627 int rows = type.rows(); 1628 const Type& columnType = type.componentType().toCompound(fContext, 1629 /*columns=*/rows, /*rows=*/1); 1630 // SpvIds of completed columns of the matrix. 1631 std::vector<SpvId> columnIds; 1632 // SpvIds of scalars we have written to the current column so far. 
1633 std::vector<SpvId> currentColumn; 1634 for (size_t i = 0; i < arguments.size(); i++) { 1635 const Type& argType = c.arguments()[i]->type(); 1636 if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) { 1637 // This vector is a complete matrix column by itself and can be used as-is. 1638 columnIds.push_back(arguments[i]); 1639 } else if (argType.columns() == 1) { 1640 // This argument is a lone scalar and can be added to the current column as-is. 1641 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, arguments[i], out); 1642 } else { 1643 // This argument needs to be decomposed into its constituent scalars. 1644 SpvId componentType = this->getType(argType.componentType()); 1645 for (int j = 0; j < argType.columns(); ++j) { 1646 SpvId swizzle = this->nextId(&argType); 1647 this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle, 1648 arguments[i], j, out); 1649 this->addColumnEntry(columnType, ¤tColumn, &columnIds, rows, swizzle, out); 1650 } 1651 } 1652 } 1653 SkASSERT(columnIds.size() == (size_t) type.columns()); 1654 return this->writeComposite(columnIds, type, out); 1655} 1656 1657SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c, 1658 OutputStream& out) { 1659 return c.type().isMatrix() ? 
this->writeMatrixConstructor(c, out) 1660 : this->writeVectorConstructor(c, out); 1661} 1662 1663SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) { 1664 const Type& type = c.type(); 1665 const Type& componentType = type.componentType(); 1666 SkASSERT(type.isVector()); 1667 1668 if (c.isCompileTimeConstant()) { 1669 return this->writeConstantVector(c); 1670 } 1671 1672 std::vector<SpvId> arguments; 1673 arguments.reserve(c.arguments().size()); 1674 for (size_t i = 0; i < c.arguments().size(); i++) { 1675 const Type& argType = c.arguments()[i]->type(); 1676 SkASSERT(componentType == argType.componentType()); 1677 1678 SpvId arg = this->writeExpression(*c.arguments()[i], out); 1679 if (argType.isMatrix()) { 1680 // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out 1681 // each scalar separately. 1682 SkASSERT(argType.rows() == 2); 1683 SkASSERT(argType.columns() == 2); 1684 for (int j = 0; j < 4; ++j) { 1685 SpvId componentId = this->nextId(&componentType); 1686 this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType), 1687 componentId, arg, j / 2, j % 2, out); 1688 arguments.push_back(componentId); 1689 } 1690 } else if (argType.isVector()) { 1691 // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle 1692 // vector arguments at all, so we always extract each vector component and pass them 1693 // into OpCompositeConstruct individually. 
1694 for (int j = 0; j < argType.columns(); j++) { 1695 SpvId componentId = this->nextId(&componentType); 1696 this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType), 1697 componentId, arg, j, out); 1698 arguments.push_back(componentId); 1699 } 1700 } else { 1701 arguments.push_back(arg); 1702 } 1703 } 1704 1705 return this->writeComposite(arguments, type, out); 1706} 1707 1708SpvId SPIRVCodeGenerator::writeComposite(const std::vector<SpvId>& arguments, 1709 const Type& type, 1710 OutputStream& out) { 1711 SkASSERT(arguments.size() == (type.isStruct() ? type.fields().size() : (size_t)type.columns())); 1712 1713 SpvId result = this->nextId(&type); 1714 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out); 1715 this->writeWord(this->getType(type), out); 1716 this->writeWord(result, out); 1717 for (SpvId id : arguments) { 1718 this->writeWord(id, out); 1719 } 1720 return result; 1721} 1722 1723SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) { 1724 // Use writeConstantVector to deduplicate constant splats. 1725 if (c.isCompileTimeConstant()) { 1726 return this->writeConstantVector(c); 1727 } 1728 1729 // Write the splat argument. 1730 SpvId argument = this->writeExpression(*c.argument(), out); 1731 1732 // Generate a OpCompositeConstruct which repeats the argument N times. 
1733 std::vector<SpvId> arguments(/*count*/ c.type().columns(), /*value*/ argument); 1734 return this->writeComposite(arguments, c.type(), out); 1735} 1736 1737 1738SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) { 1739 SkASSERT(c.type().isArray() || c.type().isStruct()); 1740 auto ctorArgs = c.argumentSpan(); 1741 1742 std::vector<SpvId> arguments; 1743 arguments.reserve(ctorArgs.size()); 1744 for (const std::unique_ptr<Expression>& arg : ctorArgs) { 1745 arguments.push_back(this->writeExpression(*arg, out)); 1746 } 1747 1748 return this->writeComposite(arguments, c.type(), out); 1749} 1750 1751SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c, 1752 OutputStream& out) { 1753 const Type& type = c.type(); 1754 if (this->getActualType(type) == this->getActualType(c.argument()->type())) { 1755 return this->writeExpression(*c.argument(), out); 1756 } 1757 1758 const Expression& ctorExpr = *c.argument(); 1759 SpvId expressionId = this->writeExpression(ctorExpr, out); 1760 return this->castScalarToType(expressionId, ctorExpr.type(), type, out); 1761} 1762 1763SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c, 1764 OutputStream& out) { 1765 const Type& ctorType = c.type(); 1766 const Type& argType = c.argument()->type(); 1767 SkASSERT(ctorType.isVector() || ctorType.isMatrix()); 1768 1769 // Write the composite that we are casting. If the actual type matches, we are done. 1770 SpvId compositeId = this->writeExpression(*c.argument(), out); 1771 if (this->getActualType(ctorType) == this->getActualType(argType)) { 1772 return compositeId; 1773 } 1774 1775 // writeMatrixCopy can cast matrices to a different type. 
1776 if (ctorType.isMatrix()) { 1777 return this->writeMatrixCopy(compositeId, argType, ctorType, out); 1778 } 1779 1780 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the 1781 // components and convert each one manually. 1782 const Type& srcType = argType.componentType(); 1783 const Type& dstType = ctorType.componentType(); 1784 1785 std::vector<SpvId> arguments; 1786 arguments.reserve(argType.columns()); 1787 for (int index = 0; index < argType.columns(); ++index) { 1788 SpvId componentId = this->nextId(&srcType); 1789 this->writeInstruction(SpvOpCompositeExtract, this->getType(srcType), componentId, 1790 compositeId, index, out); 1791 arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out)); 1792 } 1793 1794 return this->writeComposite(arguments, ctorType, out); 1795} 1796 1797SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, 1798 OutputStream& out) { 1799 const Type& type = c.type(); 1800 SkASSERT(type.isMatrix()); 1801 SkASSERT(c.argument()->type().isScalar()); 1802 1803 // Write out the scalar argument. 1804 SpvId argument = this->writeExpression(*c.argument(), out); 1805 1806 // Build the diagonal matrix. 1807 SpvId result = this->nextId(&type); 1808 this->writeUniformScaleMatrix(result, argument, type, out); 1809 return result; 1810} 1811 1812SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c, 1813 OutputStream& out) { 1814 // Write the input matrix. 1815 SpvId argument = this->writeExpression(*c.argument(), out); 1816 1817 // Use matrix-copy to resize the input matrix to its new size. 
1818 return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out); 1819} 1820 1821static SpvStorageClass_ get_storage_class(const Variable& var, 1822 SpvStorageClass_ fallbackStorageClass) { 1823 const Modifiers& modifiers = var.modifiers(); 1824 if (modifiers.fFlags & Modifiers::kIn_Flag) { 1825 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1826 return SpvStorageClassInput; 1827 } 1828 if (modifiers.fFlags & Modifiers::kOut_Flag) { 1829 SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1830 return SpvStorageClassOutput; 1831 } 1832 if (modifiers.fFlags & Modifiers::kUniform_Flag) { 1833 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) { 1834 return SpvStorageClassPushConstant; 1835 } 1836 if (var.type().typeKind() == Type::TypeKind::kSampler || 1837 var.type().typeKind() == Type::TypeKind::kSeparateSampler || 1838 var.type().typeKind() == Type::TypeKind::kTexture) { 1839 return SpvStorageClassUniformConstant; 1840 } 1841 return SpvStorageClassUniform; 1842 } 1843 return fallbackStorageClass; 1844} 1845 1846static SpvStorageClass_ get_storage_class(const Expression& expr) { 1847 switch (expr.kind()) { 1848 case Expression::Kind::kVariableReference: { 1849 const Variable& var = *expr.as<VariableReference>().variable(); 1850 if (var.storage() != Variable::Storage::kGlobal) { 1851 return SpvStorageClassFunction; 1852 } 1853 return get_storage_class(var, SpvStorageClassPrivate); 1854 } 1855 case Expression::Kind::kFieldAccess: 1856 return get_storage_class(*expr.as<FieldAccess>().base()); 1857 case Expression::Kind::kIndex: 1858 return get_storage_class(*expr.as<IndexExpression>().base()); 1859 default: 1860 return SpvStorageClassFunction; 1861 } 1862} 1863 1864std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) { 1865 std::vector<SpvId> chain; 1866 switch (expr.kind()) { 1867 case Expression::Kind::kIndex: { 1868 const IndexExpression& indexExpr = 
expr.as<IndexExpression>(); 1869 chain = this->getAccessChain(*indexExpr.base(), out); 1870 chain.push_back(this->writeExpression(*indexExpr.index(), out)); 1871 break; 1872 } 1873 case Expression::Kind::kFieldAccess: { 1874 const FieldAccess& fieldExpr = expr.as<FieldAccess>(); 1875 chain = this->getAccessChain(*fieldExpr.base(), out); 1876 chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt)); 1877 break; 1878 } 1879 default: { 1880 SpvId id = this->getLValue(expr, out)->getPointer(); 1881 SkASSERT(id != (SpvId) -1); 1882 chain.push_back(id); 1883 break; 1884 } 1885 } 1886 return chain; 1887} 1888 1889class PointerLValue : public SPIRVCodeGenerator::LValue { 1890public: 1891 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type, 1892 SPIRVCodeGenerator::Precision precision) 1893 : fGen(gen) 1894 , fPointer(pointer) 1895 , fIsMemoryObject(isMemoryObject) 1896 , fType(type) 1897 , fPrecision(precision) {} 1898 1899 SpvId getPointer() override { 1900 return fPointer; 1901 } 1902 1903 bool isMemoryObjectPointer() const override { 1904 return fIsMemoryObject; 1905 } 1906 1907 SpvId load(OutputStream& out) override { 1908 SpvId result = fGen.nextId(fPrecision); 1909 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out); 1910 return result; 1911 } 1912 1913 void store(SpvId value, OutputStream& out) override { 1914 fGen.writeInstruction(SpvOpStore, fPointer, value, out); 1915 } 1916 1917private: 1918 SPIRVCodeGenerator& fGen; 1919 const SpvId fPointer; 1920 const bool fIsMemoryObject; 1921 const SpvId fType; 1922 const SPIRVCodeGenerator::Precision fPrecision; 1923}; 1924 1925class SwizzleLValue : public SPIRVCodeGenerator::LValue { 1926public: 1927 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components, 1928 const Type& baseType, const Type& swizzleType) 1929 : fGen(gen) 1930 , fVecPointer(vecPointer) 1931 , fComponents(components) 1932 , fBaseType(&baseType) 1933 , 
fSwizzleType(&swizzleType) {} 1934 1935 bool applySwizzle(const ComponentArray& components, const Type& newType) override { 1936 ComponentArray updatedSwizzle; 1937 for (int8_t component : components) { 1938 if (component < 0 || component >= fComponents.count()) { 1939 SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component); 1940 return false; 1941 } 1942 updatedSwizzle.push_back(fComponents[component]); 1943 } 1944 fComponents = updatedSwizzle; 1945 fSwizzleType = &newType; 1946 return true; 1947 } 1948 1949 SpvId load(OutputStream& out) override { 1950 SpvId base = fGen.nextId(fBaseType); 1951 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out); 1952 SpvId result = fGen.nextId(fBaseType); 1953 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out); 1954 fGen.writeWord(fGen.getType(*fSwizzleType), out); 1955 fGen.writeWord(result, out); 1956 fGen.writeWord(base, out); 1957 fGen.writeWord(base, out); 1958 for (int component : fComponents) { 1959 fGen.writeWord(component, out); 1960 } 1961 return result; 1962 } 1963 1964 void store(SpvId value, OutputStream& out) override { 1965 // use OpVectorShuffle to mix and match the vector components. We effectively create 1966 // a virtual vector out of the concatenation of the left and right vectors, and then 1967 // select components from this virtual vector to make the result vector. For 1968 // instance, given: 1969 // float3L = ...; 1970 // float3R = ...; 1971 // L.xz = R.xy; 1972 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want 1973 // our result vector to look like (R.x, L.y, R.y), so we need to select indices 1974 // (3, 1, 4). 
1975 SpvId base = fGen.nextId(fBaseType); 1976 fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out); 1977 SpvId shuffle = fGen.nextId(fBaseType); 1978 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out); 1979 fGen.writeWord(fGen.getType(*fBaseType), out); 1980 fGen.writeWord(shuffle, out); 1981 fGen.writeWord(base, out); 1982 fGen.writeWord(value, out); 1983 for (int i = 0; i < fBaseType->columns(); i++) { 1984 // current offset into the virtual vector, defaults to pulling the unmodified 1985 // value from the left side 1986 int offset = i; 1987 // check to see if we are writing this component 1988 for (size_t j = 0; j < fComponents.size(); j++) { 1989 if (fComponents[j] == i) { 1990 // we're writing to this component, so adjust the offset to pull from 1991 // the correct component of the right side instead of preserving the 1992 // value from the left 1993 offset = (int) (j + fBaseType->columns()); 1994 break; 1995 } 1996 } 1997 fGen.writeWord(offset, out); 1998 } 1999 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out); 2000 } 2001 2002private: 2003 SPIRVCodeGenerator& fGen; 2004 const SpvId fVecPointer; 2005 ComponentArray fComponents; 2006 const Type* fBaseType; 2007 const Type* fSwizzleType; 2008}; 2009 2010int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const { 2011 auto iter = fTopLevelUniformMap.find(&var); 2012 return (iter != fTopLevelUniformMap.end()) ? iter->second : -1; 2013} 2014 2015std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr, 2016 OutputStream& out) { 2017 const Type& type = expr.type(); 2018 Precision precision = type.highPrecision() ? 
Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            // Top-level uniforms are members of the uniform buffer. They are addressed with an
            // OpAccessChain from fUniformBufferId using the variable's field index.
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(*this, memberId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type), precision);
            }
            // Any other variable already has a pointer id recorded in fVariableMap.
            SpvId typeId = this->getType(type, this->memoryLayoutForVariable(var));
            auto entry = fVariableMap.find(&var);
            SkASSERTF(entry != fVariableMap.end(), "%s", expr.description().c_str());
            return std::make_unique<PointerLValue>(*this, entry->second,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision);
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // Collapse the whole chain of indices/field numbers (see getAccessChain) into one
            // OpAccessChain. Its pointer type uses the storage class of the chain's root.
            std::vector<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, get_storage_class(expr)), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::make_unique<PointerLValue>(*this, member, /*isMemoryObjectPointer=*/false,
                                                   this->getType(type), precision);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            // If the base lvalue can absorb the swizzle itself (applySwizzle returns true),
            // reuse it directly.
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == (SpvId) -1) {
                // NOTE(review): code generation continues below with this invalid pointer id
                // after the error is reported; presumably the output is discarded once an
                // error has been recorded — confirm.
                fContext.fErrors->error(swizzle.fLine, "unable to retrieve lvalue from swizzle");
            }
            if (swizzle.components().size() == 1) {
                // A single-component swizzle is a direct OpAccessChain to that component.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base()));
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision);
            } else {
                // Wider swizzles become a SwizzleLValue, which shuffles on load and store.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionType); erroneous uses of rvalues as lvalues
            // should have been caught before code generation
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
            // The OpVariable goes into fVariableBuffer; the store of expr's value goes to `out`.
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision);
        }
    }
}

// Emits a load of the referenced variable.
// sk_FragCoord and sk_Clockwise are rewritten in terms of the RTFlip uniform and the fake
// "__device_*" variables, which in turn are read from the real sk_FragCoord / sk_Clockwise
// lvalues without any flipping.
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_FRAGCOORDS_BUILTIN) {
        // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
        // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly access
        // the fragcoord; do so now.
        dsl::DSLGlobalVar fragCoord("sk_FragCoord");
        return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
    }
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_CLOCKWISE_BUILTIN) {
        // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
        // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
        // access front facing; do so now.
        dsl::DSLGlobalVar clockwise("sk_Clockwise");
        return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
    }

    // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
    if (variable->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
        this->addRTFlipUniform(ref.fLine);
        // Use the RTFlip uniform (SKSL_RTFLIP_NAME) to compute the flipped coordinate.
        using namespace dsl;
        const char* DEVICE_COORDS_NAME = "__device_FragCoords";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use a uniform to flip the Y coordinate. The new expression will be written in
        // terms of __device_FragCoords, which is a fake variable that means "access the
        // underlying fragcoords directly without flipping it".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                                                                        SKSL_RTFLIP_NAME));
        // Declare the fake __device_FragCoords global the first time it is needed.
        if (!symbols[DEVICE_COORDS_NAME]) {
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
            auto coordsVar = std::make_unique<Variable>(/*line=*/-1,
                                                        fContext.fModifiersPool->add(modifiers),
                                                        DEVICE_COORDS_NAME,
                                                        fContext.fTypes.fFloat4.get(),
                                                        true,
                                                        Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(coordsVar.get());
            symbols.add(std::move(coordsVar));
        }
        DSLGlobalVar deviceCoord(DEVICE_COORDS_NAME);
        // The RTFlip expression is used twice (.x and .y), so clone it once and move it once.
        std::unique_ptr<Expression> rtFlipSkSLExpr = rtFlip.release();
        DSLExpression x = DSLExpression(rtFlipSkSLExpr->clone()).x();
        DSLExpression y = DSLExpression(std::move(rtFlipSkSLExpr)).y();
        // Result: float4(dev.x, rtFlip.x + rtFlip.y * dev.y, dev.z, dev.w).
        return this->writeExpression(*dsl::Float4(deviceCoord.x(),
                                                  std::move(x) + std::move(y) * deviceCoord.y(),
                                                  deviceCoord.z(),
                                                  deviceCoord.w()).release(),
                                     out);
    }

    // Handle flipping sk_Clockwise.
    if (variable->modifiers().fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN) {
        this->addRTFlipUniform(ref.fLine);
        using namespace dsl;
        const char* DEVICE_CLOCKWISE_NAME = "__device_Clockwise";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use the RTFlip uniform to decide whether front-facing must be inverted. The new
        // expression will be written in terms of __device_Clockwise, which is a fake variable
        // that means "access the underlying FrontFacing directly".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                                                                        SKSL_RTFLIP_NAME));
        // Declare the fake __device_Clockwise global the first time it is needed.
        if (!symbols[DEVICE_CLOCKWISE_NAME]) {
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
            auto clockwiseVar = std::make_unique<Variable>(/*line=*/-1,
                                                           fContext.fModifiersPool->add(modifiers),
                                                           DEVICE_CLOCKWISE_NAME,
                                                           fContext.fTypes.fBool.get(),
                                                           true,
                                                           Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(clockwiseVar.get());
            symbols.add(std::move(clockwiseVar));
        }
        DSLGlobalVar deviceClockwise(DEVICE_CLOCKWISE_NAME);
        // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia,
        // we use the default convention of "counter-clockwise face is front".
        // Result: (rtFlip.y > 0) ? !deviceClockwise : deviceClockwise.
        return this->writeExpression(*dsl::Bool(Select(rtFlip.y() > 0,
                                                       !deviceClockwise,
                                                       deviceClockwise)).release(),
                                     out);
    }

    return this->getLValue(ref, out)->load(out);
}

// Emits an index expression. If the base is a vector, the index is applied with
// OpVectorExtractDynamic on the loaded vector value. Every other base type is loaded through
// an lvalue pointer (getLValue).
SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
    if (expr.base()->type().isVector()) {
        SpvId base = this->writeExpression(*expr.base(), out);
        SpvId index = this->writeExpression(*expr.index(), out);
        SpvId result = this->nextId(nullptr);
        this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
                               index, out);
        return result;
    }
    return getLValue(expr, out)->load(out);
}

// Field accesses are always loaded through an access-chain lvalue.
SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
    return getLValue(f, out)->load(out);
}

// Emits a swizzle of an rvalue. A single component uses OpCompositeExtract; several components
// use OpVectorShuffle of the base vector with itself.
SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
    SpvId base = this->writeExpression(*swizzle.base(), out);
    SpvId result = this->nextId(&swizzle.type());
    size_t count = swizzle.components().size();
    if (count == 1) {
this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.type()), result, base, 2205 swizzle.components()[0], out); 2206 } else { 2207 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); 2208 this->writeWord(this->getType(swizzle.type()), out); 2209 this->writeWord(result, out); 2210 this->writeWord(base, out); 2211 this->writeWord(base, out); 2212 for (int component : swizzle.components()) { 2213 this->writeWord(component, out); 2214 } 2215 } 2216 return result; 2217} 2218 2219SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, 2220 const Type& operandType, SpvId lhs, 2221 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, 2222 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) { 2223 SpvId result = this->nextId(&resultType); 2224 if (is_float(fContext, operandType)) { 2225 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out); 2226 } else if (is_signed(fContext, operandType)) { 2227 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out); 2228 } else if (is_unsigned(fContext, operandType)) { 2229 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out); 2230 } else if (is_bool(fContext, operandType)) { 2231 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out); 2232 } else { 2233 fContext.fErrors->error(operandType.fLine, 2234 "unsupported operand for binary expression: " + operandType.description()); 2235 } 2236 return result; 2237} 2238 2239SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op, 2240 OutputStream& out) { 2241 if (operandType.isVector()) { 2242 SpvId result = this->nextId(nullptr); 2243 this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out); 2244 return result; 2245 } 2246 return id; 2247} 2248 2249SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, 2250 SpvOp_ floatOperator, SpvOp_ intOperator, 2251 SpvOp_ 
vectorMergeOperator, SpvOp_ mergeOperator, 2252 OutputStream& out) { 2253 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator; 2254 SkASSERT(operandType.isMatrix()); 2255 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext, 2256 operandType.rows(), 2257 1)); 2258 SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext, 2259 operandType.rows(), 2260 1)); 2261 SpvId boolType = this->getType(*fContext.fTypes.fBool); 2262 SpvId result = 0; 2263 for (int i = 0; i < operandType.columns(); i++) { 2264 SpvId columnL = this->nextId(&operandType); 2265 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out); 2266 SpvId columnR = this->nextId(&operandType); 2267 this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out); 2268 SpvId compare = this->nextId(&operandType); 2269 this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out); 2270 SpvId merge = this->nextId(nullptr); 2271 this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out); 2272 if (result != 0) { 2273 SpvId next = this->nextId(nullptr); 2274 this->writeInstruction(mergeOperator, boolType, next, result, merge, out); 2275 result = next; 2276 } 2277 else { 2278 result = merge; 2279 } 2280 } 2281 return result; 2282} 2283 2284SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs, 2285 SpvId rhs, SpvOp_ op, OutputStream& out) { 2286 SkASSERT(operandType.isMatrix()); 2287 SpvId columnType = this->getType(operandType.componentType().toCompound(fContext, 2288 operandType.rows(), 2289 1)); 2290 std::vector<SpvId> columns; 2291 columns.reserve(operandType.columns()); 2292 for (int i = 0; i < operandType.columns(); i++) { 2293 SpvId columnL = this->nextId(&operandType); 2294 this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out); 2295 SpvId columnR = this->nextId(&operandType); 2296 
this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out); 2297 columns.push_back(this->nextId(&operandType)); 2298 this->writeInstruction(op, columnType, columns[i], columnL, columnR, out); 2299 } 2300 return this->writeComposite(columns, operandType, out); 2301} 2302 2303SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) { 2304 SkASSERT(type.isFloat()); 2305 SpvId one = this->writeLiteral(1.0, type); 2306 SpvId reciprocal = this->nextId(&type); 2307 this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out); 2308 return reciprocal; 2309} 2310 2311SpvId SPIRVCodeGenerator::writeScalarToMatrixSplat(const Type& matrixType, 2312 SpvId scalarId, 2313 OutputStream& out) { 2314 // Splat the scalar into a vector. 2315 const Type& vectorType = matrixType.componentType().toCompound(fContext, 2316 /*columns=*/matrixType.rows(), 2317 /*rows=*/1); 2318 std::vector<SpvId> vecArguments(/*count*/ matrixType.rows(), /*value*/ scalarId); 2319 SpvId vectorId = this->writeComposite(vecArguments, vectorType, out); 2320 2321 // Splat the vector into a matrix. 2322 std::vector<SpvId> matArguments(/*count*/ matrixType.columns(), /*value*/ vectorId); 2323 return this->writeComposite(matArguments, matrixType, out); 2324} 2325 2326SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op, 2327 const Type& rightType, SpvId rhs, 2328 const Type& resultType, OutputStream& out) { 2329 // The comma operator ignores the type of the left-hand side entirely. 2330 if (op.kind() == Token::Kind::TK_COMMA) { 2331 return rhs; 2332 } 2333 // overall type we are operating on: float2, int, uint4... 2334 const Type* operandType; 2335 // IR allows mismatched types in expressions (e.g. 
float2 * float), but they need special 2336 // handling in SPIR-V 2337 if (this->getActualType(leftType) != this->getActualType(rightType)) { 2338 if (leftType.isVector() && rightType.isNumber()) { 2339 if (resultType.componentType().isFloat()) { 2340 switch (op.kind()) { 2341 case Token::Kind::TK_SLASH: { 2342 rhs = this->writeReciprocal(rightType, rhs, out); 2343 [[fallthrough]]; 2344 } 2345 case Token::Kind::TK_STAR: { 2346 SpvId result = this->nextId(&resultType); 2347 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType), 2348 result, lhs, rhs, out); 2349 return result; 2350 } 2351 default: 2352 break; 2353 } 2354 } 2355 // promote number to vector 2356 const Type& vecType = leftType; 2357 SpvId vec = this->nextId(&vecType); 2358 this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out); 2359 this->writeWord(this->getType(vecType), out); 2360 this->writeWord(vec, out); 2361 for (int i = 0; i < vecType.columns(); i++) { 2362 this->writeWord(rhs, out); 2363 } 2364 rhs = vec; 2365 operandType = &leftType; 2366 } else if (rightType.isVector() && leftType.isNumber()) { 2367 if (resultType.componentType().isFloat()) { 2368 if (op.kind() == Token::Kind::TK_STAR) { 2369 SpvId result = this->nextId(&resultType); 2370 this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType), 2371 result, rhs, lhs, out); 2372 return result; 2373 } 2374 } 2375 // promote number to vector 2376 const Type& vecType = rightType; 2377 SpvId vec = this->nextId(&vecType); 2378 this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out); 2379 this->writeWord(this->getType(vecType), out); 2380 this->writeWord(vec, out); 2381 for (int i = 0; i < vecType.columns(); i++) { 2382 this->writeWord(lhs, out); 2383 } 2384 lhs = vec; 2385 operandType = &rightType; 2386 } else if (leftType.isMatrix()) { 2387 if (op.kind() == Token::Kind::TK_STAR) { 2388 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V. 
2389 SpvOp_ spvop; 2390 if (rightType.isMatrix()) { 2391 spvop = SpvOpMatrixTimesMatrix; 2392 } else if (rightType.isVector()) { 2393 spvop = SpvOpMatrixTimesVector; 2394 } else { 2395 SkASSERT(rightType.isScalar()); 2396 spvop = SpvOpMatrixTimesScalar; 2397 } 2398 SpvId result = this->nextId(&resultType); 2399 this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out); 2400 return result; 2401 } else { 2402 // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we 2403 // expect to have a scalar here. 2404 SkASSERT(rightType.isScalar()); 2405 2406 // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path. 2407 SpvId rhsMatrix = this->writeScalarToMatrixSplat(leftType, rhs, out); 2408 2409 // Perform this operation as matrix-op-matrix. 2410 return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix, 2411 resultType, out); 2412 } 2413 } else if (rightType.isMatrix()) { 2414 if (op.kind() == Token::Kind::TK_STAR) { 2415 // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V. 2416 SpvId result = this->nextId(&resultType); 2417 if (leftType.isVector()) { 2418 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), 2419 result, lhs, rhs, out); 2420 } else { 2421 SkASSERT(leftType.isScalar()); 2422 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), 2423 result, rhs, lhs, out); 2424 } 2425 return result; 2426 } else { 2427 // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we 2428 // expect to have a scalar here. 2429 SkASSERT(leftType.isScalar()); 2430 2431 // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path. 2432 SpvId lhsMatrix = this->writeScalarToMatrixSplat(rightType, lhs, out); 2433 2434 // Perform this operation as matrix-op-matrix. 
2435 return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs, 2436 resultType, out); 2437 } 2438 } else { 2439 fContext.fErrors->error(leftType.fLine, "unsupported mixed-type expression"); 2440 return -1; 2441 } 2442 } else { 2443 operandType = &this->getActualType(leftType); 2444 SkASSERT(*operandType == this->getActualType(rightType)); 2445 } 2446 switch (op.kind()) { 2447 case Token::Kind::TK_EQEQ: { 2448 if (operandType->isMatrix()) { 2449 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual, 2450 SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out); 2451 } 2452 if (operandType->isStruct()) { 2453 return this->writeStructComparison(*operandType, lhs, op, rhs, out); 2454 } 2455 if (operandType->isArray()) { 2456 return this->writeArrayComparison(*operandType, lhs, op, rhs, out); 2457 } 2458 SkASSERT(resultType.isBoolean()); 2459 const Type* tmpType; 2460 if (operandType->isVector()) { 2461 tmpType = &fContext.fTypes.fBool->toCompound(fContext, 2462 operandType->columns(), 2463 operandType->rows()); 2464 } else { 2465 tmpType = &resultType; 2466 } 2467 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs, 2468 SpvOpFOrdEqual, SpvOpIEqual, 2469 SpvOpIEqual, SpvOpLogicalEqual, out), 2470 *operandType, SpvOpAll, out); 2471 } 2472 case Token::Kind::TK_NEQ: 2473 if (operandType->isMatrix()) { 2474 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual, 2475 SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out); 2476 } 2477 if (operandType->isStruct()) { 2478 return this->writeStructComparison(*operandType, lhs, op, rhs, out); 2479 } 2480 if (operandType->isArray()) { 2481 return this->writeArrayComparison(*operandType, lhs, op, rhs, out); 2482 } 2483 [[fallthrough]]; 2484 case Token::Kind::TK_LOGICALXOR: 2485 SkASSERT(resultType.isBoolean()); 2486 const Type* tmpType; 2487 if (operandType->isVector()) { 2488 tmpType = &fContext.fTypes.fBool->toCompound(fContext, 2489 operandType->columns(), 
2490 operandType->rows()); 2491 } else { 2492 tmpType = &resultType; 2493 } 2494 return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs, 2495 SpvOpFOrdNotEqual, SpvOpINotEqual, 2496 SpvOpINotEqual, SpvOpLogicalNotEqual, 2497 out), 2498 *operandType, SpvOpAny, out); 2499 case Token::Kind::TK_GT: 2500 SkASSERT(resultType.isBoolean()); 2501 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2502 SpvOpFOrdGreaterThan, SpvOpSGreaterThan, 2503 SpvOpUGreaterThan, SpvOpUndef, out); 2504 case Token::Kind::TK_LT: 2505 SkASSERT(resultType.isBoolean()); 2506 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, 2507 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out); 2508 case Token::Kind::TK_GTEQ: 2509 SkASSERT(resultType.isBoolean()); 2510 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2511 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, 2512 SpvOpUGreaterThanEqual, SpvOpUndef, out); 2513 case Token::Kind::TK_LTEQ: 2514 SkASSERT(resultType.isBoolean()); 2515 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2516 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, 2517 SpvOpULessThanEqual, SpvOpUndef, out); 2518 case Token::Kind::TK_PLUS: 2519 if (leftType.isMatrix() && rightType.isMatrix()) { 2520 SkASSERT(leftType == rightType); 2521 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFAdd, out); 2522 } 2523 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 2524 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2525 case Token::Kind::TK_MINUS: 2526 if (leftType.isMatrix() && rightType.isMatrix()) { 2527 SkASSERT(leftType == rightType); 2528 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFSub, out); 2529 } 2530 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 2531 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2532 case Token::Kind::TK_STAR: 2533 if (leftType.isMatrix() && 
rightType.isMatrix()) { 2534 // matrix multiply 2535 SpvId result = this->nextId(&resultType); 2536 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 2537 lhs, rhs, out); 2538 return result; 2539 } 2540 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 2541 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 2542 case Token::Kind::TK_SLASH: 2543 if (leftType.isMatrix() && rightType.isMatrix()) { 2544 SkASSERT(leftType == rightType); 2545 return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFDiv, out); 2546 } 2547 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 2548 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 2549 case Token::Kind::TK_PERCENT: 2550 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2551 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2552 case Token::Kind::TK_SHL: 2553 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2554 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, 2555 SpvOpUndef, out); 2556 case Token::Kind::TK_SHR: 2557 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2558 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical, 2559 SpvOpUndef, out); 2560 case Token::Kind::TK_BITWISEAND: 2561 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2562 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out); 2563 case Token::Kind::TK_BITWISEOR: 2564 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2565 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out); 2566 case Token::Kind::TK_BITWISEXOR: 2567 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2568 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out); 2569 default: 2570 fContext.fErrors->error(0, "unsupported token"); 2571 return -1; 2572 } 2573} 2574 2575SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, 
Operator op, 2576 SpvId rhs, OutputStream& out) { 2577 // The inputs must be arrays, and the op must be == or !=. 2578 SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ); 2579 SkASSERT(arrayType.isArray()); 2580 const Type& componentType = arrayType.componentType(); 2581 const SpvId componentTypeId = this->getType(componentType); 2582 const int arraySize = arrayType.columns(); 2583 SkASSERT(arraySize > 0); 2584 2585 // Synthesize equality checks for each item in the array. 2586 const Type& boolType = *fContext.fTypes.fBool; 2587 SpvId allComparisons = (SpvId)-1; 2588 for (int index = 0; index < arraySize; ++index) { 2589 // Get the left and right item in the array. 2590 SpvId itemL = this->nextId(&componentType); 2591 this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemL, lhs, index, out); 2592 SpvId itemR = this->nextId(&componentType); 2593 this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemR, rhs, index, out); 2594 // Use `writeBinaryExpression` with the requested == or != operator on these items. 2595 SpvId comparison = this->writeBinaryExpression(componentType, itemL, op, 2596 componentType, itemR, boolType, out); 2597 // Merge this comparison result with all the other comparisons we've done. 2598 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out); 2599 } 2600 return allComparisons; 2601} 2602 2603SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op, 2604 SpvId rhs, OutputStream& out) { 2605 // The inputs must be structs containing fields, and the op must be == or !=. 2606 SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ); 2607 SkASSERT(structType.isStruct()); 2608 const std::vector<Type::Field>& fields = structType.fields(); 2609 SkASSERT(!fields.empty()); 2610 2611 // Synthesize equality checks for each field in the struct. 
2612 const Type& boolType = *fContext.fTypes.fBool; 2613 SpvId allComparisons = (SpvId)-1; 2614 for (int index = 0; index < (int)fields.size(); ++index) { 2615 // Get the left and right versions of this field. 2616 const Type& fieldType = *fields[index].fType; 2617 const SpvId fieldTypeId = this->getType(fieldType); 2618 2619 SpvId fieldL = this->nextId(&fieldType); 2620 this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldL, lhs, index, out); 2621 SpvId fieldR = this->nextId(&fieldType); 2622 this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldR, rhs, index, out); 2623 // Use `writeBinaryExpression` with the requested == or != operator on these fields. 2624 SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR, 2625 boolType, out); 2626 // Merge this comparison result with all the other comparisons we've done. 2627 allComparisons = this->mergeComparisons(comparison, allComparisons, op, out); 2628 } 2629 return allComparisons; 2630} 2631 2632SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op, 2633 OutputStream& out) { 2634 // If this is the first entry, we don't need to merge comparison results with anything. 2635 if (allComparisons == (SpvId)-1) { 2636 return comparison; 2637 } 2638 // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons. 
2639 const Type& boolType = *fContext.fTypes.fBool; 2640 SpvId boolTypeId = this->getType(boolType); 2641 SpvId logicalOp = this->nextId(&boolType); 2642 switch (op.kind()) { 2643 case Token::Kind::TK_EQEQ: 2644 this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp, 2645 comparison, allComparisons, out); 2646 break; 2647 case Token::Kind::TK_NEQ: 2648 this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp, 2649 comparison, allComparisons, out); 2650 break; 2651 default: 2652 SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName()); 2653 return (SpvId)-1; 2654 } 2655 return logicalOp; 2656} 2657 2658static float division_by_literal_value(Operator op, const Expression& right) { 2659 // If this is a division by a literal value, returns that literal value. Otherwise, returns 0. 2660 if (op.kind() == Token::Kind::TK_SLASH && right.isFloatLiteral()) { 2661 float rhsValue = right.as<Literal>().floatValue(); 2662 if (std::isfinite(rhsValue)) { 2663 return rhsValue; 2664 } 2665 } 2666 return 0.0f; 2667} 2668 2669SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) { 2670 const Expression* left = b.left().get(); 2671 const Expression* right = b.right().get(); 2672 Operator op = b.getOperator(); 2673 2674 switch (op.kind()) { 2675 case Token::Kind::TK_EQ: { 2676 // Handles assignment. 2677 SpvId rhs = this->writeExpression(*right, out); 2678 this->getLValue(*left, out)->store(rhs, out); 2679 return rhs; 2680 } 2681 case Token::Kind::TK_LOGICALAND: 2682 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS. 2683 return this->writeLogicalAnd(*b.left(), *b.right(), out); 2684 2685 case Token::Kind::TK_LOGICALOR: 2686 // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS. 
2687 return this->writeLogicalOr(*b.left(), *b.right(), out); 2688 2689 default: 2690 break; 2691 } 2692 2693 std::unique_ptr<LValue> lvalue; 2694 SpvId lhs; 2695 if (op.isAssignment()) { 2696 lvalue = this->getLValue(*left, out); 2697 lhs = lvalue->load(out); 2698 } else { 2699 lvalue = nullptr; 2700 lhs = this->writeExpression(*left, out); 2701 } 2702 2703 SpvId rhs; 2704 float rhsValue = division_by_literal_value(op, *right); 2705 if (rhsValue != 0.0f) { 2706 // Rewrite floating-point division by a literal into multiplication by the reciprocal. 2707 // This converts `expr / 2` into `expr * 0.5` 2708 // This improves codegen, especially for certain types of divides (e.g. vector/scalar). 2709 op = Operator(Token::Kind::TK_STAR); 2710 rhs = this->writeLiteral(1.0 / rhsValue, right->type()); 2711 } else { 2712 // Write the right-hand side expression normally. 2713 rhs = this->writeExpression(*right, out); 2714 } 2715 2716 SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(), 2717 right->type(), rhs, b.type(), out); 2718 if (lvalue) { 2719 lvalue->store(result, out); 2720 } 2721 return result; 2722} 2723 2724SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right, 2725 OutputStream& out) { 2726 SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool); 2727 SpvId lhs = this->writeExpression(left, out); 2728 SpvId rhsLabel = this->nextId(nullptr); 2729 SpvId end = this->nextId(nullptr); 2730 SpvId lhsBlock = fCurrentBlock; 2731 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2732 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out); 2733 this->writeLabel(rhsLabel, out); 2734 SpvId rhs = this->writeExpression(right, out); 2735 SpvId rhsBlock = fCurrentBlock; 2736 this->writeInstruction(SpvOpBranch, end, out); 2737 this->writeLabel(end, out); 2738 SpvId result = this->nextId(nullptr); 2739 this->writeInstruction(SpvOpPhi, 
this->getType(*fContext.fTypes.fBool), result, falseConstant, 2740 lhsBlock, rhs, rhsBlock, out); 2741 return result; 2742} 2743 2744SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right, 2745 OutputStream& out) { 2746 SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool); 2747 SpvId lhs = this->writeExpression(left, out); 2748 SpvId rhsLabel = this->nextId(nullptr); 2749 SpvId end = this->nextId(nullptr); 2750 SpvId lhsBlock = fCurrentBlock; 2751 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2752 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out); 2753 this->writeLabel(rhsLabel, out); 2754 SpvId rhs = this->writeExpression(right, out); 2755 SpvId rhsBlock = fCurrentBlock; 2756 this->writeInstruction(SpvOpBranch, end, out); 2757 this->writeLabel(end, out); 2758 SpvId result = this->nextId(nullptr); 2759 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant, 2760 lhsBlock, rhs, rhsBlock, out); 2761 return result; 2762} 2763 2764SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { 2765 const Type& type = t.type(); 2766 SpvId test = this->writeExpression(*t.test(), out); 2767 if (t.ifTrue()->type().columns() == 1 && 2768 t.ifTrue()->isCompileTimeConstant() && 2769 t.ifFalse()->isCompileTimeConstant()) { 2770 // both true and false are constants, can just use OpSelect 2771 SpvId result = this->nextId(nullptr); 2772 SpvId trueId = this->writeExpression(*t.ifTrue(), out); 2773 SpvId falseId = this->writeExpression(*t.ifFalse(), out); 2774 this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId, 2775 out); 2776 return result; 2777 } 2778 // was originally using OpPhi to choose the result, but for some reason that is crashing on 2779 // Adreno. Switched to storing the result in a temp variable as glslang does. 
2780 SpvId var = this->nextId(nullptr); 2781 this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction), 2782 var, SpvStorageClassFunction, fVariableBuffer); 2783 SpvId trueLabel = this->nextId(nullptr); 2784 SpvId falseLabel = this->nextId(nullptr); 2785 SpvId end = this->nextId(nullptr); 2786 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2787 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); 2788 this->writeLabel(trueLabel, out); 2789 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out); 2790 this->writeInstruction(SpvOpBranch, end, out); 2791 this->writeLabel(falseLabel, out); 2792 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out); 2793 this->writeInstruction(SpvOpBranch, end, out); 2794 this->writeLabel(end, out); 2795 SpvId result = this->nextId(&type); 2796 this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out); 2797 return result; 2798} 2799 2800SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) { 2801 const Type& type = p.type(); 2802 if (p.getOperator().kind() == Token::Kind::TK_MINUS) { 2803 SpvId result = this->nextId(&type); 2804 SpvId typeId = this->getType(type); 2805 SpvId expr = this->writeExpression(*p.operand(), out); 2806 if (is_float(fContext, type)) { 2807 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); 2808 } else if (is_signed(fContext, type) || is_unsigned(fContext, type)) { 2809 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); 2810 } else { 2811 SkDEBUGFAILF("unsupported prefix expression %s", p.description().c_str()); 2812 } 2813 return result; 2814 } 2815 switch (p.getOperator().kind()) { 2816 case Token::Kind::TK_PLUS: 2817 return this->writeExpression(*p.operand(), out); 2818 case Token::Kind::TK_PLUSPLUS: { 2819 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out); 
2820 SpvId one = this->writeLiteral(1.0, type); 2821 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, 2822 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, 2823 out); 2824 lv->store(result, out); 2825 return result; 2826 } 2827 case Token::Kind::TK_MINUSMINUS: { 2828 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out); 2829 SpvId one = this->writeLiteral(1.0, type); 2830 SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub, 2831 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2832 lv->store(result, out); 2833 return result; 2834 } 2835 case Token::Kind::TK_LOGICALNOT: { 2836 SkASSERT(p.operand()->type().isBoolean()); 2837 SpvId result = this->nextId(nullptr); 2838 this->writeInstruction(SpvOpLogicalNot, this->getType(type), result, 2839 this->writeExpression(*p.operand(), out), out); 2840 return result; 2841 } 2842 case Token::Kind::TK_BITWISENOT: { 2843 SpvId result = this->nextId(nullptr); 2844 this->writeInstruction(SpvOpNot, this->getType(type), result, 2845 this->writeExpression(*p.operand(), out), out); 2846 return result; 2847 } 2848 default: 2849 SkDEBUGFAILF("unsupported prefix expression: %s", p.description().c_str()); 2850 return -1; 2851 } 2852} 2853 2854SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) { 2855 const Type& type = p.type(); 2856 std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out); 2857 SpvId result = lv->load(out); 2858 SpvId one = this->writeLiteral(1.0, type); 2859 switch (p.getOperator().kind()) { 2860 case Token::Kind::TK_PLUSPLUS: { 2861 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd, 2862 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2863 lv->store(temp, out); 2864 return result; 2865 } 2866 case Token::Kind::TK_MINUSMINUS: { 2867 SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub, 2868 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2869 lv->store(temp, out); 2870 return result; 
2871 } 2872 default: 2873 SkDEBUGFAILF("unsupported postfix expression %s", p.description().c_str()); 2874 return -1; 2875 } 2876} 2877 2878SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) { 2879 return this->writeLiteral(l.value(), l.type()); 2880} 2881 2882SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) { 2883 int32_t valueBits; 2884 if (type.isFloat()) { 2885 float fValue = value; 2886 memcpy(&valueBits, &fValue, sizeof(valueBits)); 2887 } else { 2888 SKSL_INT iValue = value; 2889 valueBits = iValue; 2890 } 2891 2892 SPIRVNumberConstant key{valueBits, type.numberKind()}; 2893 auto [iter, newlyCreated] = fNumberConstants.insert({key, (SpvId)-1}); 2894 if (newlyCreated) { 2895 SpvId result = this->nextId(nullptr); 2896 iter->second = result; 2897 2898 if (type.isBoolean()) { 2899 this->writeInstruction(valueBits ? SpvOpConstantTrue : SpvOpConstantFalse, 2900 this->getType(type), result, fConstantBuffer); 2901 } else { 2902 this->writeInstruction(SpvOpConstant, this->getType(type), result, 2903 (SpvId)valueBits, fConstantBuffer); 2904 } 2905 } 2906 2907 return iter->second; 2908} 2909 2910SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) { 2911 SpvId result = fFunctionMap[&f]; 2912 SpvId returnTypeId = this->getType(f.returnType()); 2913 SpvId functionTypeId = this->getFunctionType(f); 2914 this->writeInstruction(SpvOpFunction, returnTypeId, result, 2915 SpvFunctionControlMaskNone, functionTypeId, out); 2916 String mangledName = f.mangledName(); 2917 this->writeInstruction(SpvOpName, 2918 result, 2919 skstd::string_view(mangledName.c_str(), mangledName.size()), 2920 fNameBuffer); 2921 for (const Variable* parameter : f.parameters()) { 2922 SpvId id = this->nextId(nullptr); 2923 fVariableMap[parameter] = id; 2924 SpvId type = this->getPointerType(parameter->type(), SpvStorageClassFunction); 2925 this->writeInstruction(SpvOpFunctionParameter, type, id, out); 2926 } 2927 return result; 2928} 
2929 2930SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) { 2931 fVariableBuffer.reset(); 2932 SpvId result = this->writeFunctionStart(f.declaration(), out); 2933 fCurrentBlock = 0; 2934 this->writeLabel(this->nextId(nullptr), out); 2935 StringStream bodyBuffer; 2936 this->writeBlock(f.body()->as<Block>(), bodyBuffer); 2937 write_stringstream(fVariableBuffer, out); 2938 if (f.declaration().isMain()) { 2939 write_stringstream(fGlobalInitializersBuffer, out); 2940 } 2941 write_stringstream(bodyBuffer, out); 2942 if (fCurrentBlock) { 2943 if (f.declaration().returnType().isVoid()) { 2944 this->writeInstruction(SpvOpReturn, out); 2945 } else { 2946 this->writeInstruction(SpvOpUnreachable, out); 2947 } 2948 } 2949 this->writeInstruction(SpvOpFunctionEnd, out); 2950 return result; 2951} 2952 2953void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) { 2954 if (layout.fLocation >= 0) { 2955 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation, 2956 fDecorationBuffer); 2957 } 2958 if (layout.fBinding >= 0) { 2959 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding, 2960 fDecorationBuffer); 2961 } 2962 if (layout.fIndex >= 0) { 2963 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex, 2964 fDecorationBuffer); 2965 } 2966 if (layout.fSet >= 0) { 2967 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet, 2968 fDecorationBuffer); 2969 } 2970 if (layout.fInputAttachmentIndex >= 0) { 2971 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex, 2972 layout.fInputAttachmentIndex, fDecorationBuffer); 2973 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment); 2974 } 2975 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) { 2976 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin, 2977 fDecorationBuffer); 2978 } 2979} 
2980 2981void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) { 2982 if (layout.fLocation >= 0) { 2983 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation, 2984 layout.fLocation, fDecorationBuffer); 2985 } 2986 if (layout.fBinding >= 0) { 2987 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding, 2988 layout.fBinding, fDecorationBuffer); 2989 } 2990 if (layout.fIndex >= 0) { 2991 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex, 2992 layout.fIndex, fDecorationBuffer); 2993 } 2994 if (layout.fSet >= 0) { 2995 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet, 2996 layout.fSet, fDecorationBuffer); 2997 } 2998 if (layout.fInputAttachmentIndex >= 0) { 2999 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex, 3000 layout.fInputAttachmentIndex, fDecorationBuffer); 3001 } 3002 if (layout.fBuiltin >= 0) { 3003 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn, 3004 layout.fBuiltin, fDecorationBuffer); 3005 } 3006} 3007 3008MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const { 3009 bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0); 3010 return pushConstant ? 
MemoryLayout(MemoryLayout::k430_Standard) : fDefaultLayout; 3011} 3012 3013SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) { 3014 MemoryLayout memoryLayout = this->memoryLayoutForVariable(intf.variable()); 3015 SpvId result = this->nextId(nullptr); 3016 const Variable& intfVar = intf.variable(); 3017 const Type& type = intfVar.type(); 3018 if (!MemoryLayout::LayoutIsSupported(type)) { 3019 fContext.fErrors->error(type.fLine, "type '" + type.name() + "' is not permitted here"); 3020 return this->nextId(nullptr); 3021 } 3022 SpvStorageClass_ storageClass = get_storage_class(intf.variable(), SpvStorageClassFunction); 3023 if (fProgram.fInputs.fUseFlipRTUniform && appendRTFlip && type.isStruct()) { 3024 // We can only have one interface block (because we use push_constant and that is limited 3025 // to one per program), so we need to append rtflip to this one rather than synthesize an 3026 // entirely new block when the variable is referenced. And we can't modify the existing 3027 // block, so we instead create a modified copy of it and write that. 
3028 std::vector<Type::Field> fields = type.fields(); 3029 fields.emplace_back(Modifiers(Layout(/*flags=*/0, 3030 /*location=*/-1, 3031 fProgram.fConfig->fSettings.fRTFlipOffset, 3032 /*binding=*/-1, 3033 /*index=*/-1, 3034 /*set=*/-1, 3035 /*builtin=*/-1, 3036 /*inputAttachmentIndex=*/-1), 3037 /*flags=*/0), 3038 SKSL_RTFLIP_NAME, 3039 fContext.fTypes.fFloat2.get()); 3040 { 3041 AutoAttachPoolToThread attach(fProgram.fPool.get()); 3042 const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol( 3043 Type::MakeStructType(type.fLine, type.name(), std::move(fields))); 3044 const Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol( 3045 std::make_unique<Variable>(intfVar.fLine, 3046 &intfVar.modifiers(), 3047 intfVar.name(), 3048 rtFlipStructType, 3049 intfVar.isBuiltin(), 3050 intfVar.storage())); 3051 fSPIRVBonusVariables.insert(modifiedVar); 3052 InterfaceBlock modifiedCopy(intf.fLine, 3053 *modifiedVar, 3054 intf.typeName(), 3055 intf.instanceName(), 3056 intf.arraySize(), 3057 intf.typeOwner()); 3058 result = this->writeInterfaceBlock(modifiedCopy, false); 3059 fProgram.fSymbols->add(std::make_unique<Field>( 3060 /*line=*/-1, modifiedVar, rtFlipStructType->fields().size() - 1)); 3061 } 3062 fVariableMap[&intfVar] = result; 3063 fWroteRTFlip = true; 3064 return result; 3065 } 3066 const Modifiers& intfModifiers = intfVar.modifiers(); 3067 SpvId typeId = this->getType(type, memoryLayout); 3068 if (intfModifiers.fLayout.fBuiltin == -1) { 3069 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer); 3070 } 3071 SpvId ptrType = this->nextId(nullptr); 3072 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer); 3073 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); 3074 Layout layout = intfModifiers.fLayout; 3075 if (intfModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) { 3076 layout.fSet = 0; 3077 } 3078 
this->writeLayout(layout, result); 3079 fVariableMap[&intfVar] = result; 3080 return result; 3081} 3082 3083bool SPIRVCodeGenerator::isDead(const Variable& var) const { 3084 // During SPIR-V code generation, we synthesize some extra bonus variables that don't actually 3085 // exist in the Program at all and aren't tracked by the ProgramUsage. They aren't dead, though. 3086 if (fSPIRVBonusVariables.count(&var)) { 3087 return false; 3088 } 3089 ProgramUsage::VariableCounts counts = fProgram.usage()->get(var); 3090 if (counts.fRead || counts.fWrite) { 3091 return false; 3092 } 3093 // It's not entirely clear what the rules are for eliding interface variables. Generally, it 3094 // causes problems to elide them, even when they're dead. 3095 return !(var.modifiers().fFlags & 3096 (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag)); 3097} 3098 3099void SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind, const VarDeclaration& varDecl) { 3100 const Variable& var = varDecl.var(); 3101 if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN && 3102 kind != ProgramKind::kFragment) { 3103 SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut); 3104 return; 3105 } 3106 if (var.modifiers().fLayout.fBuiltin == SK_SECONDARYFRAGCOLOR_BUILTIN) { 3107 return; 3108 } 3109 if (this->isDead(var)) { 3110 return; 3111 } 3112 SpvStorageClass_ storageClass = get_storage_class(var, SpvStorageClassPrivate); 3113 if (storageClass == SpvStorageClassUniform) { 3114 // Top-level uniforms are emitted in writeUniformBuffer. 
3115 fTopLevelUniforms.push_back(&varDecl); 3116 return; 3117 } 3118 const Type& type = var.type(); 3119 Layout layout = var.modifiers().fLayout; 3120 if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) { 3121 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet; 3122 } 3123 SpvId id = this->nextId(&type); 3124 fVariableMap[&var] = id; 3125 SpvId typeId = this->getPointerType(type, storageClass); 3126 this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer); 3127 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer); 3128 if (varDecl.value()) { 3129 SkASSERT(!fCurrentBlock); 3130 fCurrentBlock = -1; 3131 SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer); 3132 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer); 3133 fCurrentBlock = 0; 3134 } 3135 this->writeLayout(layout, id); 3136 if (var.modifiers().fFlags & Modifiers::kFlat_Flag) { 3137 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer); 3138 } 3139 if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) { 3140 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective, 3141 fDecorationBuffer); 3142 } 3143} 3144 3145void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) { 3146 const Variable& var = varDecl.var(); 3147 SpvId id = this->nextId(&var.type()); 3148 fVariableMap[&var] = id; 3149 SpvId type = this->getPointerType(var.type(), SpvStorageClassFunction); 3150 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer); 3151 this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer); 3152 if (varDecl.value()) { 3153 SpvId value = this->writeExpression(*varDecl.value(), out); 3154 this->writeInstruction(SpvOpStore, id, value, out); 3155 } 3156} 3157 3158void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) { 3159 switch (s.kind()) { 3160 case 
Statement::Kind::kInlineMarker: 3161 case Statement::Kind::kNop: 3162 break; 3163 case Statement::Kind::kBlock: 3164 this->writeBlock(s.as<Block>(), out); 3165 break; 3166 case Statement::Kind::kExpression: 3167 this->writeExpression(*s.as<ExpressionStatement>().expression(), out); 3168 break; 3169 case Statement::Kind::kReturn: 3170 this->writeReturnStatement(s.as<ReturnStatement>(), out); 3171 break; 3172 case Statement::Kind::kVarDeclaration: 3173 this->writeVarDeclaration(s.as<VarDeclaration>(), out); 3174 break; 3175 case Statement::Kind::kIf: 3176 this->writeIfStatement(s.as<IfStatement>(), out); 3177 break; 3178 case Statement::Kind::kFor: 3179 this->writeForStatement(s.as<ForStatement>(), out); 3180 break; 3181 case Statement::Kind::kDo: 3182 this->writeDoStatement(s.as<DoStatement>(), out); 3183 break; 3184 case Statement::Kind::kSwitch: 3185 this->writeSwitchStatement(s.as<SwitchStatement>(), out); 3186 break; 3187 case Statement::Kind::kBreak: 3188 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out); 3189 break; 3190 case Statement::Kind::kContinue: 3191 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out); 3192 break; 3193 case Statement::Kind::kDiscard: 3194 this->writeInstruction(SpvOpKill, out); 3195 break; 3196 default: 3197 SkDEBUGFAILF("unsupported statement: %s", s.description().c_str()); 3198 break; 3199 } 3200} 3201 3202void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) { 3203 for (const std::unique_ptr<Statement>& stmt : b.children()) { 3204 this->writeStatement(*stmt, out); 3205 } 3206} 3207 3208void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) { 3209 SpvId test = this->writeExpression(*stmt.test(), out); 3210 SpvId ifTrue = this->nextId(nullptr); 3211 SpvId ifFalse = this->nextId(nullptr); 3212 if (stmt.ifFalse()) { 3213 SpvId end = this->nextId(nullptr); 3214 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 3215 
this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 3216 this->writeLabel(ifTrue, out); 3217 this->writeStatement(*stmt.ifTrue(), out); 3218 if (fCurrentBlock) { 3219 this->writeInstruction(SpvOpBranch, end, out); 3220 } 3221 this->writeLabel(ifFalse, out); 3222 this->writeStatement(*stmt.ifFalse(), out); 3223 if (fCurrentBlock) { 3224 this->writeInstruction(SpvOpBranch, end, out); 3225 } 3226 this->writeLabel(end, out); 3227 } else { 3228 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out); 3229 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 3230 this->writeLabel(ifTrue, out); 3231 this->writeStatement(*stmt.ifTrue(), out); 3232 if (fCurrentBlock) { 3233 this->writeInstruction(SpvOpBranch, ifFalse, out); 3234 } 3235 this->writeLabel(ifFalse, out); 3236 } 3237} 3238 3239void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) { 3240 if (f.initializer()) { 3241 this->writeStatement(*f.initializer(), out); 3242 } 3243 SpvId header = this->nextId(nullptr); 3244 SpvId start = this->nextId(nullptr); 3245 SpvId body = this->nextId(nullptr); 3246 SpvId next = this->nextId(nullptr); 3247 fContinueTarget.push(next); 3248 SpvId end = this->nextId(nullptr); 3249 fBreakTarget.push(end); 3250 this->writeInstruction(SpvOpBranch, header, out); 3251 this->writeLabel(header, out); 3252 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out); 3253 this->writeInstruction(SpvOpBranch, start, out); 3254 this->writeLabel(start, out); 3255 if (f.test()) { 3256 SpvId test = this->writeExpression(*f.test(), out); 3257 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 3258 } else { 3259 this->writeInstruction(SpvOpBranch, body, out); 3260 } 3261 this->writeLabel(body, out); 3262 this->writeStatement(*f.statement(), out); 3263 if (fCurrentBlock) { 3264 this->writeInstruction(SpvOpBranch, next, out); 3265 } 3266 this->writeLabel(next, 
out); 3267 if (f.next()) { 3268 this->writeExpression(*f.next(), out); 3269 } 3270 this->writeInstruction(SpvOpBranch, header, out); 3271 this->writeLabel(end, out); 3272 fBreakTarget.pop(); 3273 fContinueTarget.pop(); 3274} 3275 3276void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) { 3277 SpvId header = this->nextId(nullptr); 3278 SpvId start = this->nextId(nullptr); 3279 SpvId next = this->nextId(nullptr); 3280 SpvId continueTarget = this->nextId(nullptr); 3281 fContinueTarget.push(continueTarget); 3282 SpvId end = this->nextId(nullptr); 3283 fBreakTarget.push(end); 3284 this->writeInstruction(SpvOpBranch, header, out); 3285 this->writeLabel(header, out); 3286 this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out); 3287 this->writeInstruction(SpvOpBranch, start, out); 3288 this->writeLabel(start, out); 3289 this->writeStatement(*d.statement(), out); 3290 if (fCurrentBlock) { 3291 this->writeInstruction(SpvOpBranch, next, out); 3292 } 3293 this->writeLabel(next, out); 3294 this->writeInstruction(SpvOpBranch, continueTarget, out); 3295 this->writeLabel(continueTarget, out); 3296 SpvId test = this->writeExpression(*d.test(), out); 3297 this->writeInstruction(SpvOpBranchConditional, test, header, end, out); 3298 this->writeLabel(end, out); 3299 fBreakTarget.pop(); 3300 fContinueTarget.pop(); 3301} 3302 3303void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) { 3304 SpvId value = this->writeExpression(*s.value(), out); 3305 std::vector<SpvId> labels; 3306 SpvId end = this->nextId(nullptr); 3307 SpvId defaultLabel = end; 3308 fBreakTarget.push(end); 3309 int size = 3; 3310 auto& cases = s.cases(); 3311 for (const std::unique_ptr<Statement>& stmt : cases) { 3312 const SwitchCase& c = stmt->as<SwitchCase>(); 3313 SpvId label = this->nextId(nullptr); 3314 labels.push_back(label); 3315 if (c.value()) { 3316 size += 2; 3317 } else { 3318 defaultLabel = label; 3319 } 3320 
} 3321 labels.push_back(end); 3322 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 3323 this->writeOpCode(SpvOpSwitch, size, out); 3324 this->writeWord(value, out); 3325 this->writeWord(defaultLabel, out); 3326 for (size_t i = 0; i < cases.size(); ++i) { 3327 const SwitchCase& c = cases[i]->as<SwitchCase>(); 3328 if (!c.value()) { 3329 continue; 3330 } 3331 this->writeWord(c.value()->as<Literal>().intValue(), out); 3332 this->writeWord(labels[i], out); 3333 } 3334 for (size_t i = 0; i < cases.size(); ++i) { 3335 const SwitchCase& c = cases[i]->as<SwitchCase>(); 3336 this->writeLabel(labels[i], out); 3337 this->writeStatement(*c.statement(), out); 3338 if (fCurrentBlock) { 3339 this->writeInstruction(SpvOpBranch, labels[i + 1], out); 3340 } 3341 } 3342 this->writeLabel(end, out); 3343 fBreakTarget.pop(); 3344} 3345 3346void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) { 3347 if (r.expression()) { 3348 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out), 3349 out); 3350 } else { 3351 this->writeInstruction(SpvOpReturn, out); 3352 } 3353} 3354 3355// Given any function, returns the top-level symbol table (OUTSIDE of the function's scope). 3356static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) { 3357 return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent; 3358} 3359 3360SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter( 3361 const FunctionDeclaration& main) { 3362 // Our goal is to synthesize a tiny helper function which looks like this: 3363 // void _entrypoint() { sk_FragColor = main(); } 3364 3365 // Fish a symbol table out of main(). 3366 std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main); 3367 3368 // Get `sk_FragColor` as a writable reference. 
3369 const Symbol* skFragColorSymbol = (*symbolTable)["sk_FragColor"]; 3370 SkASSERT(skFragColorSymbol); 3371 const Variable& skFragColorVar = skFragColorSymbol->as<Variable>(); 3372 auto skFragColorRef = std::make_unique<VariableReference>(/*line=*/-1, &skFragColorVar, 3373 VariableReference::RefKind::kWrite); 3374 // Synthesize a call to the `main()` function. 3375 if (main.returnType() != skFragColorRef->type()) { 3376 fContext.fErrors->error(main.fLine, "SPIR-V does not support returning '" + 3377 main.returnType().description() + "' from main()"); 3378 return {}; 3379 } 3380 ExpressionArray args; 3381 if (main.parameters().size() == 1) { 3382 if (main.parameters()[0]->type() != *fContext.fTypes.fFloat2) { 3383 fContext.fErrors->error(main.fLine, 3384 "SPIR-V does not support parameter of type '" + 3385 main.parameters()[0]->type().description() + "' to main()"); 3386 return {}; 3387 } 3388 args.push_back(dsl::Float2(0).release()); 3389 } 3390 auto callMainFn = std::make_unique<FunctionCall>(/*line=*/-1, &main.returnType(), &main, 3391 std::move(args)); 3392 3393 // Synthesize `skFragColor = main()` as a BinaryExpression. 3394 auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>( 3395 /*line=*/-1, 3396 std::move(skFragColorRef), 3397 Token::Kind::TK_EQ, 3398 std::move(callMainFn), 3399 &main.returnType())); 3400 3401 // Function bodies are always wrapped in a Block. 3402 StatementArray entrypointStmts; 3403 entrypointStmts.push_back(std::move(assignmentStmt)); 3404 auto entrypointBlock = Block::Make(/*line=*/-1, std::move(entrypointStmts), 3405 symbolTable, /*isScope=*/true); 3406 // Declare an entrypoint function. 
3407 EntrypointAdapter adapter; 3408 adapter.fLayout = {}; 3409 adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kHasSideEffects_Flag}; 3410 adapter.entrypointDecl = 3411 std::make_unique<FunctionDeclaration>(/*line=*/-1, 3412 &adapter.fModifiers, 3413 "_entrypoint", 3414 /*parameters=*/std::vector<const Variable*>{}, 3415 /*returnType=*/fContext.fTypes.fVoid.get(), 3416 /*builtin=*/false); 3417 // Define it. 3418 adapter.entrypointDef = FunctionDefinition::Convert(fContext, 3419 /*line=*/-1, 3420 *adapter.entrypointDecl, 3421 std::move(entrypointBlock), 3422 /*builtin=*/false); 3423 3424 adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get()); 3425 return adapter; 3426} 3427 3428void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) { 3429 SkASSERT(!fTopLevelUniforms.empty()); 3430 static constexpr char kUniformBufferName[] = "_UniformBuffer"; 3431 3432 // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build 3433 // a lookup table of variables to UniformBuffer field indices. 3434 std::vector<Type::Field> fields; 3435 fields.reserve(fTopLevelUniforms.size()); 3436 fTopLevelUniformMap.reserve(fTopLevelUniforms.size()); 3437 for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) { 3438 const Variable* var = &topLevelUniform->var(); 3439 fTopLevelUniformMap[var] = (int)fields.size(); 3440 fields.emplace_back(var->modifiers(), var->name(), &var->type()); 3441 } 3442 fUniformBuffer.fStruct = Type::MakeStructType(/*line=*/-1, kUniformBufferName, 3443 std::move(fields)); 3444 3445 // Create a global variable to contain this struct. 
3446 Layout layout; 3447 layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding; 3448 layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet; 3449 Modifiers modifiers{layout, Modifiers::kUniform_Flag}; 3450 3451 fUniformBuffer.fInnerVariable = std::make_unique<Variable>( 3452 /*line=*/-1, fProgram.fModifiers->add(modifiers), kUniformBufferName, 3453 fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal); 3454 3455 // Create an interface block object for this global variable. 3456 fUniformBuffer.fInterfaceBlock = std::make_unique<InterfaceBlock>( 3457 /*offset=*/-1, *fUniformBuffer.fInnerVariable, kUniformBufferName, 3458 kUniformBufferName, /*arraySize=*/0, topLevelSymbolTable); 3459 3460 // Generate an interface block and hold onto its ID. 3461 fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock); 3462} 3463 3464void SPIRVCodeGenerator::addRTFlipUniform(int line) { 3465 if (fWroteRTFlip) { 3466 return; 3467 } 3468 // Flip variable hasn't been written yet. This means we don't have an existing 3469 // interface block, so we're free to just synthesize one. 3470 fWroteRTFlip = true; 3471 std::vector<Type::Field> fields; 3472 if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) { 3473 fContext.fErrors->error(line, "RTFlipOffset is negative"); 3474 } 3475 fields.emplace_back(Modifiers(Layout(/*flags=*/0, 3476 /*location=*/-1, 3477 fProgram.fConfig->fSettings.fRTFlipOffset, 3478 /*binding=*/-1, 3479 /*index=*/-1, 3480 /*set=*/-1, 3481 /*builtin=*/-1, 3482 /*inputAttachmentIndex=*/-1), 3483 /*flags=*/0), 3484 SKSL_RTFLIP_NAME, 3485 fContext.fTypes.fFloat2.get()); 3486 skstd::string_view name = "sksl_synthetic_uniforms"; 3487 const Type* intfStruct = 3488 fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(/*line=*/-1, name, fields)); 3489 int binding = fProgram.fConfig->fSettings.fRTFlipBinding; 3490 if (binding == -1) { 3491 fContext.fErrors->error(line, "layout(binding=...) 
is required in SPIR-V"); 3492 } 3493 int set = fProgram.fConfig->fSettings.fRTFlipSet; 3494 if (set == -1) { 3495 fContext.fErrors->error(line, "layout(set=...) is required in SPIR-V"); 3496 } 3497 bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants; 3498 int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0; 3499 const Modifiers* modsPtr; 3500 { 3501 AutoAttachPoolToThread attach(fProgram.fPool.get()); 3502 Modifiers modifiers(Layout(flags, 3503 /*location=*/-1, 3504 /*offset=*/-1, 3505 binding, 3506 /*index=*/-1, 3507 set, 3508 /*builtin=*/-1, 3509 /*inputAttachmentIndex=*/-1), 3510 Modifiers::kUniform_Flag); 3511 modsPtr = fProgram.fModifiers->add(modifiers); 3512 } 3513 const Variable* intfVar = fSynthetics.takeOwnershipOfSymbol( 3514 std::make_unique<Variable>(/*line=*/-1, 3515 modsPtr, 3516 name, 3517 intfStruct, 3518 /*builtin=*/false, 3519 Variable::Storage::kGlobal)); 3520 fSPIRVBonusVariables.insert(intfVar); 3521 { 3522 AutoAttachPoolToThread attach(fProgram.fPool.get()); 3523 fProgram.fSymbols->add(std::make_unique<Field>(/*line=*/-1, intfVar, /*field=*/0)); 3524 } 3525 InterfaceBlock intf(/*line=*/-1, 3526 *intfVar, 3527 name, 3528 /*instanceName=*/"", 3529 /*arraySize=*/0, 3530 std::make_shared<SymbolTable>(fContext, /*builtin=*/false)); 3531 3532 this->writeInterfaceBlock(intf, false); 3533} 3534 3535void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) { 3536 fGLSLExtendedInstructions = this->nextId(nullptr); 3537 StringStream body; 3538 // Assign SpvIds to functions. 
3539 const FunctionDeclaration* main = nullptr; 3540 for (const ProgramElement* e : program.elements()) { 3541 if (e->is<FunctionDefinition>()) { 3542 const FunctionDefinition& funcDef = e->as<FunctionDefinition>(); 3543 const FunctionDeclaration& funcDecl = funcDef.declaration(); 3544 fFunctionMap[&funcDecl] = this->nextId(nullptr); 3545 if (funcDecl.isMain()) { 3546 main = &funcDecl; 3547 } 3548 } 3549 } 3550 // Make sure we have a main() function. 3551 if (!main) { 3552 fContext.fErrors->error(/*line=*/-1, "program does not contain a main() function"); 3553 return; 3554 } 3555 // Emit interface blocks. 3556 std::set<SpvId> interfaceVars; 3557 for (const ProgramElement* e : program.elements()) { 3558 if (e->is<InterfaceBlock>()) { 3559 const InterfaceBlock& intf = e->as<InterfaceBlock>(); 3560 SpvId id = this->writeInterfaceBlock(intf); 3561 3562 const Modifiers& modifiers = intf.variable().modifiers(); 3563 if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) && 3564 modifiers.fLayout.fBuiltin == -1 && !this->isDead(intf.variable())) { 3565 interfaceVars.insert(id); 3566 } 3567 } 3568 } 3569 // Emit global variable declarations. 3570 for (const ProgramElement* e : program.elements()) { 3571 if (e->is<GlobalVarDeclaration>()) { 3572 this->writeGlobalVar(program.fConfig->fKind, 3573 e->as<GlobalVarDeclaration>().declaration()->as<VarDeclaration>()); 3574 } 3575 } 3576 // Emit top-level uniforms into a dedicated uniform buffer. 3577 if (!fTopLevelUniforms.empty()) { 3578 this->writeUniformBuffer(get_top_level_symbol_table(*main)); 3579 } 3580 // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real 3581 // main() and stores the result into sk_FragColor. 
3582 EntrypointAdapter adapter; 3583 if (main->returnType() == *fContext.fTypes.fHalf4) { 3584 adapter = this->writeEntrypointAdapter(*main); 3585 if (adapter.entrypointDecl) { 3586 fFunctionMap[adapter.entrypointDecl.get()] = this->nextId(nullptr); 3587 this->writeFunction(*adapter.entrypointDef, body); 3588 main = adapter.entrypointDecl.get(); 3589 } 3590 } 3591 // Emit all the functions. 3592 for (const ProgramElement* e : program.elements()) { 3593 if (e->is<FunctionDefinition>()) { 3594 this->writeFunction(e->as<FunctionDefinition>(), body); 3595 } 3596 } 3597 // Add global in/out variables to the list of interface variables. 3598 for (auto entry : fVariableMap) { 3599 const Variable* var = entry.first; 3600 if (var->storage() == Variable::Storage::kGlobal && 3601 (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) && 3602 !this->isDead(*var)) { 3603 interfaceVars.insert(entry.second); 3604 } 3605 } 3606 this->writeCapabilities(out); 3607 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out); 3608 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out); 3609 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().length() + 4) / 4) + 3610 (int32_t) interfaceVars.size(), out); 3611 switch (program.fConfig->fKind) { 3612 case ProgramKind::kVertex: 3613 this->writeWord(SpvExecutionModelVertex, out); 3614 break; 3615 case ProgramKind::kFragment: 3616 this->writeWord(SpvExecutionModelFragment, out); 3617 break; 3618 default: 3619 SK_ABORT("cannot write this kind of program to SPIR-V\n"); 3620 } 3621 SpvId entryPoint = fFunctionMap[main]; 3622 this->writeWord(entryPoint, out); 3623 this->writeString(main->name(), out); 3624 for (int var : interfaceVars) { 3625 this->writeWord(var, out); 3626 } 3627 if (program.fConfig->fKind == ProgramKind::kFragment) { 3628 this->writeInstruction(SpvOpExecutionMode, 3629 fFunctionMap[main], 3630 SpvExecutionModeOriginUpperLeft, 
3631 out); 3632 } 3633 for (const ProgramElement* e : program.elements()) { 3634 if (e->is<Extension>()) { 3635 this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out); 3636 } 3637 } 3638 3639 write_stringstream(fExtraGlobalsBuffer, out); 3640 write_stringstream(fNameBuffer, out); 3641 write_stringstream(fDecorationBuffer, out); 3642 write_stringstream(fConstantBuffer, out); 3643 write_stringstream(body, out); 3644} 3645 3646bool SPIRVCodeGenerator::generateCode() { 3647 SkASSERT(!fContext.fErrors->errorCount()); 3648 this->writeWord(SpvMagicNumber, *fOut); 3649 this->writeWord(SpvVersion, *fOut); 3650 this->writeWord(SKSL_MAGIC, *fOut); 3651 StringStream buffer; 3652 this->writeInstructions(fProgram, buffer); 3653 this->writeWord(fIdCount, *fOut); 3654 this->writeWord(0, *fOut); // reserved, always zero 3655 write_stringstream(buffer, *fOut); 3656 fContext.fErrors->reportPendingErrors(PositionInfo()); 3657 return fContext.fErrors->errorCount() == 0; 3658} 3659 3660} // namespace SkSL 3661