1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/codegen/SkSLSPIRVCodeGenerator.h"
9
10#include "src/sksl/GLSL.std.450.h"
11
12#include "include/sksl/DSLCore.h"
13#include "src/sksl/SkSLCompiler.h"
14#include "src/sksl/SkSLOperators.h"
15#include "src/sksl/SkSLThreadContext.h"
16#include "src/sksl/ir/SkSLBinaryExpression.h"
17#include "src/sksl/ir/SkSLBlock.h"
18#include "src/sksl/ir/SkSLConstructorArrayCast.h"
19#include "src/sksl/ir/SkSLConstructorCompound.h"
20#include "src/sksl/ir/SkSLConstructorCompoundCast.h"
21#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
22#include "src/sksl/ir/SkSLConstructorMatrixResize.h"
23#include "src/sksl/ir/SkSLConstructorScalarCast.h"
24#include "src/sksl/ir/SkSLConstructorSplat.h"
25#include "src/sksl/ir/SkSLDoStatement.h"
26#include "src/sksl/ir/SkSLExpressionStatement.h"
27#include "src/sksl/ir/SkSLExtension.h"
28#include "src/sksl/ir/SkSLField.h"
29#include "src/sksl/ir/SkSLFieldAccess.h"
30#include "src/sksl/ir/SkSLForStatement.h"
31#include "src/sksl/ir/SkSLFunctionCall.h"
32#include "src/sksl/ir/SkSLFunctionDeclaration.h"
33#include "src/sksl/ir/SkSLFunctionDefinition.h"
34#include "src/sksl/ir/SkSLIfStatement.h"
35#include "src/sksl/ir/SkSLIndexExpression.h"
36#include "src/sksl/ir/SkSLInterfaceBlock.h"
37#include "src/sksl/ir/SkSLPostfixExpression.h"
38#include "src/sksl/ir/SkSLPrefixExpression.h"
39#include "src/sksl/ir/SkSLReturnStatement.h"
40#include "src/sksl/ir/SkSLSwitchStatement.h"
41#include "src/sksl/ir/SkSLSwizzle.h"
42#include "src/sksl/ir/SkSLTernaryExpression.h"
43#include "src/sksl/ir/SkSLVarDeclarations.h"
44#include "src/sksl/ir/SkSLVariableReference.h"
45
46#ifdef SK_VULKAN
47#include "src/gpu/vk/GrVkCaps.h"
48#endif
49
// Highest SpvCapability enumerant that writeCapabilities() scans for in the fCapabilities
// bitmask. Capabilities are tracked as bits of a uint64_t (bit N == capability N), so this value
// must stay below 64 for the shifts to be well-defined.
#define kLast_Capability SpvCapabilityMultiViewport

// Custom builtin ids (judging by the names: device-space fragment coordinates and a
// clockwise-facing flag). They are negative, presumably so they can never collide with a real
// SpvBuiltIn value — NOTE(review): confirm against the code that consumes them (not in this chunk).
constexpr int DEVICE_FRAGCOORDS_BUILTIN = -1000;
constexpr int DEVICE_CLOCKWISE_BUILTIN  = -1001;
54
55namespace SkSL {
56
// SPIR-V "generator" magic number identifying Skia as the producing tool. Skia's registered tool id
// is 31 (0x1F) and occupies the top 16 bits; the low 16 bits are free and could be used to version
// the SkSL generator if we ever need to.
// Tool-id registry:
// https://github.com/KhronosGroup/SPIRV-Headers/blob/master/include/spirv/spir-v.xml#L84
static const int32_t SKSL_MAGIC  = 0x001F0000;
61
// Populates fIntrinsicMap, which maps each SkSL intrinsic kind to a 5-tuple:
//   (opcode kind, opcode for float operands, opcode for signed-int operands,
//    opcode for unsigned-int operands, opcode for bool operands)
// The float/int/uint ordering follows BY_TYPE_GLSL's (ifFloat, ifInt, ifUInt) parameters, and the
// last slot is the one populated for the boolean-only intrinsics (any/all/not). Presumably the
// caller picks a slot from the argument's component type — NOTE(review): confirm at the lookup
// site, which is not in this chunk.
// SpvOpUndef marks "no implementation for this operand type"; writeOpCode() asserts it is never
// emitted. kSpecial_IntrinsicOpcodeKind entries carry a k*_SpecialIntrinsic tag instead of an
// opcode and are presumably lowered by dedicated code elsewhere.
void SPIRVCodeGenerator::setupIntrinsics() {
// ALL_GLSL(x): the same GLSL.std.450 extended instruction for every operand type.
#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, GLSLstd450 ## x, \
                                    GLSLstd450 ## x, GLSLstd450 ## x, GLSLstd450 ## x)
// BY_TYPE_GLSL: distinct GLSL.std.450 instructions for float/int/uint; no bool form.
#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicOpcodeKind, \
                                                             GLSLstd450 ## ifFloat,             \
                                                             GLSLstd450 ## ifInt,               \
                                                             GLSLstd450 ## ifUInt,              \
                                                             SpvOpUndef)
// ALL_SPIRV(x): the same core SPIR-V opcode for every operand type.
#define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicOpcodeKind, \
                                     SpvOp ## x, SpvOp ## x, SpvOp ## x, SpvOp ## x)
// SPECIAL(x): handled by bespoke code, tagged with k<x>_SpecialIntrinsic in every slot.
#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicOpcodeKind, k ## x ## _SpecialIntrinsic, \
                                   k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic,  \
                                   k ## x ## _SpecialIntrinsic)
    fIntrinsicMap[k_round_IntrinsicKind]         = ALL_GLSL(Round);
    fIntrinsicMap[k_roundEven_IntrinsicKind]     = ALL_GLSL(RoundEven);
    fIntrinsicMap[k_trunc_IntrinsicKind]         = ALL_GLSL(Trunc);
    fIntrinsicMap[k_abs_IntrinsicKind]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
    fIntrinsicMap[k_sign_IntrinsicKind]          = BY_TYPE_GLSL(FSign, SSign, SSign);
    fIntrinsicMap[k_floor_IntrinsicKind]         = ALL_GLSL(Floor);
    fIntrinsicMap[k_ceil_IntrinsicKind]          = ALL_GLSL(Ceil);
    fIntrinsicMap[k_fract_IntrinsicKind]         = ALL_GLSL(Fract);
    fIntrinsicMap[k_radians_IntrinsicKind]       = ALL_GLSL(Radians);
    fIntrinsicMap[k_degrees_IntrinsicKind]       = ALL_GLSL(Degrees);
    fIntrinsicMap[k_sin_IntrinsicKind]           = ALL_GLSL(Sin);
    fIntrinsicMap[k_cos_IntrinsicKind]           = ALL_GLSL(Cos);
    fIntrinsicMap[k_tan_IntrinsicKind]           = ALL_GLSL(Tan);
    fIntrinsicMap[k_asin_IntrinsicKind]          = ALL_GLSL(Asin);
    fIntrinsicMap[k_acos_IntrinsicKind]          = ALL_GLSL(Acos);
    fIntrinsicMap[k_atan_IntrinsicKind]          = SPECIAL(Atan);
    fIntrinsicMap[k_sinh_IntrinsicKind]          = ALL_GLSL(Sinh);
    fIntrinsicMap[k_cosh_IntrinsicKind]          = ALL_GLSL(Cosh);
    fIntrinsicMap[k_tanh_IntrinsicKind]          = ALL_GLSL(Tanh);
    fIntrinsicMap[k_asinh_IntrinsicKind]         = ALL_GLSL(Asinh);
    fIntrinsicMap[k_acosh_IntrinsicKind]         = ALL_GLSL(Acosh);
    fIntrinsicMap[k_atanh_IntrinsicKind]         = ALL_GLSL(Atanh);
    fIntrinsicMap[k_pow_IntrinsicKind]           = ALL_GLSL(Pow);
    fIntrinsicMap[k_exp_IntrinsicKind]           = ALL_GLSL(Exp);
    fIntrinsicMap[k_log_IntrinsicKind]           = ALL_GLSL(Log);
    fIntrinsicMap[k_exp2_IntrinsicKind]          = ALL_GLSL(Exp2);
    fIntrinsicMap[k_log2_IntrinsicKind]          = ALL_GLSL(Log2);
    fIntrinsicMap[k_sqrt_IntrinsicKind]          = ALL_GLSL(Sqrt);
    fIntrinsicMap[k_inverse_IntrinsicKind]       = ALL_GLSL(MatrixInverse);
    fIntrinsicMap[k_outerProduct_IntrinsicKind]  = ALL_SPIRV(OuterProduct);
    fIntrinsicMap[k_transpose_IntrinsicKind]     = ALL_SPIRV(Transpose);
    fIntrinsicMap[k_isinf_IntrinsicKind]         = ALL_SPIRV(IsInf);
    fIntrinsicMap[k_isnan_IntrinsicKind]         = ALL_SPIRV(IsNan);
    fIntrinsicMap[k_inversesqrt_IntrinsicKind]   = ALL_GLSL(InverseSqrt);
    fIntrinsicMap[k_determinant_IntrinsicKind]   = ALL_GLSL(Determinant);
    fIntrinsicMap[k_matrixCompMult_IntrinsicKind] = SPECIAL(MatrixCompMult);
    fIntrinsicMap[k_matrixInverse_IntrinsicKind] = ALL_GLSL(MatrixInverse);
    fIntrinsicMap[k_mod_IntrinsicKind]           = SPECIAL(Mod);
    fIntrinsicMap[k_modf_IntrinsicKind]          = ALL_GLSL(Modf);
    fIntrinsicMap[k_min_IntrinsicKind]           = SPECIAL(Min);
    fIntrinsicMap[k_max_IntrinsicKind]           = SPECIAL(Max);
    fIntrinsicMap[k_clamp_IntrinsicKind]         = SPECIAL(Clamp);
    fIntrinsicMap[k_saturate_IntrinsicKind]      = SPECIAL(Saturate);
    // dot() only has a float form (OpDot); the integer and bool slots are unimplemented.
    fIntrinsicMap[k_dot_IntrinsicKind]           = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                      SpvOpDot, SpvOpUndef, SpvOpUndef, SpvOpUndef);
    fIntrinsicMap[k_mix_IntrinsicKind]           = SPECIAL(Mix);
    fIntrinsicMap[k_step_IntrinsicKind]          = SPECIAL(Step);
    fIntrinsicMap[k_smoothstep_IntrinsicKind]    = SPECIAL(SmoothStep);
    fIntrinsicMap[k_fma_IntrinsicKind]           = ALL_GLSL(Fma);
    fIntrinsicMap[k_frexp_IntrinsicKind]         = ALL_GLSL(Frexp);
    fIntrinsicMap[k_ldexp_IntrinsicKind]         = ALL_GLSL(Ldexp);

// PACK(T) registers both packT and unpackT as the matching GLSL.std.450 instructions.
#define PACK(type) fIntrinsicMap[k_pack##type##_IntrinsicKind] = ALL_GLSL(Pack##type); \
                   fIntrinsicMap[k_unpack##type##_IntrinsicKind] = ALL_GLSL(Unpack##type)
    PACK(Snorm4x8);
    PACK(Unorm4x8);
    PACK(Snorm2x16);
    PACK(Unorm2x16);
    PACK(Half2x16);
    PACK(Double2x32);
#undef PACK
    fIntrinsicMap[k_length_IntrinsicKind]      = ALL_GLSL(Length);
    fIntrinsicMap[k_distance_IntrinsicKind]    = ALL_GLSL(Distance);
    fIntrinsicMap[k_cross_IntrinsicKind]       = ALL_GLSL(Cross);
    fIntrinsicMap[k_normalize_IntrinsicKind]   = ALL_GLSL(Normalize);
    fIntrinsicMap[k_faceforward_IntrinsicKind] = ALL_GLSL(FaceForward);
    fIntrinsicMap[k_reflect_IntrinsicKind]     = ALL_GLSL(Reflect);
    fIntrinsicMap[k_refract_IntrinsicKind]     = ALL_GLSL(Refract);
    fIntrinsicMap[k_bitCount_IntrinsicKind]    = ALL_SPIRV(BitCount);
    fIntrinsicMap[k_findLSB_IntrinsicKind]     = ALL_GLSL(FindILsb);
    fIntrinsicMap[k_findMSB_IntrinsicKind]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
    // Derivatives: dFdx and fwidth map straight to float-only SPIR-V opcodes, while dFdy goes
    // through special handling (presumably to account for render-target Y orientation —
    // NOTE(review): confirm in the special-intrinsic lowering).
    fIntrinsicMap[k_dFdx_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                 SpvOpDPdx, SpvOpUndef,
                                                                 SpvOpUndef, SpvOpUndef);
    fIntrinsicMap[k_dFdy_IntrinsicKind]        = SPECIAL(DFdy);
    fIntrinsicMap[k_fwidth_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                 SpvOpFwidth, SpvOpUndef,
                                                                 SpvOpUndef, SpvOpUndef);
    fIntrinsicMap[k_makeSampler2D_IntrinsicKind] = SPECIAL(SampledImage);

    fIntrinsicMap[k_sample_IntrinsicKind]      = SPECIAL(Texture);
    fIntrinsicMap[k_subpassLoad_IntrinsicKind] = SPECIAL(SubpassLoad);

    // Bit-pattern reinterpretation between float and int/uint is a plain OpBitcast.
    fIntrinsicMap[k_floatBitsToInt_IntrinsicKind]  = ALL_SPIRV(Bitcast);
    fIntrinsicMap[k_floatBitsToUint_IntrinsicKind] = ALL_SPIRV(Bitcast);
    fIntrinsicMap[k_intBitsToFloat_IntrinsicKind]  = ALL_SPIRV(Bitcast);
    fIntrinsicMap[k_uintBitsToFloat_IntrinsicKind] = ALL_SPIRV(Bitcast);

    // Boolean-vector intrinsics: only the last (bool) slot is populated.
    fIntrinsicMap[k_any_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                SpvOpUndef, SpvOpUndef,
                                                                SpvOpUndef, SpvOpAny);
    fIntrinsicMap[k_all_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                SpvOpUndef, SpvOpUndef,
                                                                SpvOpUndef, SpvOpAll);
    fIntrinsicMap[k_not_IntrinsicKind]        = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                SpvOpUndef, SpvOpUndef, SpvOpUndef,
                                                                SpvOpLogicalNot);
    // Component-wise comparisons: floats use ordered comparisons, ints/uints use the
    // signed/unsigned variants. Only equal/notEqual also accept bool operands.
    fIntrinsicMap[k_equal_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                SpvOpFOrdEqual, SpvOpIEqual,
                                                                SpvOpIEqual, SpvOpLogicalEqual);
    fIntrinsicMap[k_notEqual_IntrinsicKind]   = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
                                                                SpvOpINotEqual,
                                                                SpvOpLogicalNotEqual);
    fIntrinsicMap[k_lessThan_IntrinsicKind]         = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                      SpvOpFOrdLessThan,
                                                                      SpvOpSLessThan,
                                                                      SpvOpULessThan,
                                                                      SpvOpUndef);
    fIntrinsicMap[k_lessThanEqual_IntrinsicKind]    = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                      SpvOpFOrdLessThanEqual,
                                                                      SpvOpSLessThanEqual,
                                                                      SpvOpULessThanEqual,
                                                                      SpvOpUndef);
    fIntrinsicMap[k_greaterThan_IntrinsicKind]      = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                      SpvOpFOrdGreaterThan,
                                                                      SpvOpSGreaterThan,
                                                                      SpvOpUGreaterThan,
                                                                      SpvOpUndef);
    fIntrinsicMap[k_greaterThanEqual_IntrinsicKind] = std::make_tuple(kSPIRV_IntrinsicOpcodeKind,
                                                                      SpvOpFOrdGreaterThanEqual,
                                                                      SpvOpSGreaterThanEqual,
                                                                      SpvOpUGreaterThanEqual,
                                                                      SpvOpUndef);
// interpolateAt* not yet supported...
// NOTE(review): unlike PACK, the ALL_GLSL / BY_TYPE_GLSL / ALL_SPIRV / SPECIAL macros are not
// #undef'd here, so they remain defined for the rest of the translation unit.
}
201
202void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
203    out.write((const char*) &word, sizeof(word));
204}
205
206static bool is_float(const Context& context, const Type& type) {
207    return (type.isScalar() || type.isVector() || type.isMatrix()) &&
208           type.componentType().isFloat();
209}
210
211static bool is_signed(const Context& context, const Type& type) {
212    return (type.isScalar() || type.isVector()) && type.componentType().isSigned();
213}
214
215static bool is_unsigned(const Context& context, const Type& type) {
216    return (type.isScalar() || type.isVector()) && type.componentType().isUnsigned();
217}
218
219static bool is_bool(const Context& context, const Type& type) {
220    return (type.isScalar() || type.isVector()) && type.componentType().isBoolean();
221}
222
223static bool is_out(const Modifiers& m) {
224    return (m.fFlags & Modifiers::kOut_Flag) != 0;
225}
226
227static bool is_in(const Modifiers& m) {
228    switch (m.fFlags & (Modifiers::kOut_Flag | Modifiers::kIn_Flag)) {
229        case Modifiers::kOut_Flag:                       // out
230            return false;
231
232        case 0:                                          // implicit in
233        case Modifiers::kIn_Flag:                        // explicit in
234        case Modifiers::kOut_Flag | Modifiers::kIn_Flag: // inout
235            return true;
236
237        default: SkUNREACHABLE;
238    }
239}
240
// Writes the leading word of a SPIR-V instruction: `length` (the instruction's total word count,
// including this leading word) in the upper 16 bits and `opCode` in the lower 16 bits.
//
// Also maintains fCurrentBlock, the label of the basic block currently being written (0 when no
// block is open):
//  - Block terminators (return, kill, switch, branch) close the current block. If no block is
//    open, a label is synthesized first so the (dead) terminator still belongs to a block.
//  - Declarations and annotations in the second case list (types, constants, variables, function
//    and parameter boundaries, capabilities, debug names, decorations, ...) leave block state
//    untouched.
//  - Every other instruction must live inside a block; if none is open, a label is synthesized.
// NOTE(review): nothing here checks that `length` fits in the 16-bit word-count field.
void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
    // Loads must never be emitted into the constant buffer.
    SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
    // SpvOpUndef is used as a "no opcode" placeholder (e.g. in fIntrinsicMap) and must never
    // reach the output.
    SkASSERT(opCode != SpvOpUndef);
    switch (opCode) {
        case SpvOpReturn:      // fall through
        case SpvOpReturnValue: // fall through
        case SpvOpKill:        // fall through
        case SpvOpSwitch:      // fall through
        case SpvOpBranch:      // fall through
        case SpvOpBranchConditional:
            if (fCurrentBlock == 0) {
                // We just encountered dead code--instructions that don't have an associated block.
                // Synthesize a label if this happens; this is necessary to satisfy the validator.
                this->writeLabel(this->nextId(nullptr), out);
            }
            // A terminator ends the block; the next instruction needs a fresh label.
            fCurrentBlock = 0;
            break;
        case SpvOpConstant:          // fall through
        case SpvOpConstantTrue:      // fall through
        case SpvOpConstantFalse:     // fall through
        case SpvOpConstantComposite: // fall through
        case SpvOpTypeVoid:          // fall through
        case SpvOpTypeInt:           // fall through
        case SpvOpTypeFloat:         // fall through
        case SpvOpTypeBool:          // fall through
        case SpvOpTypeVector:        // fall through
        case SpvOpTypeMatrix:        // fall through
        case SpvOpTypeArray:         // fall through
        case SpvOpTypePointer:       // fall through
        case SpvOpTypeFunction:      // fall through
        case SpvOpTypeRuntimeArray:  // fall through
        case SpvOpTypeStruct:        // fall through
        case SpvOpTypeImage:         // fall through
        case SpvOpTypeSampledImage:  // fall through
        case SpvOpTypeSampler:       // fall through
        case SpvOpVariable:          // fall through
        case SpvOpFunction:          // fall through
        case SpvOpFunctionParameter: // fall through
        case SpvOpFunctionEnd:       // fall through
        case SpvOpExecutionMode:     // fall through
        case SpvOpMemoryModel:       // fall through
        case SpvOpCapability:        // fall through
        case SpvOpExtInstImport:     // fall through
        case SpvOpEntryPoint:        // fall through
        case SpvOpSource:            // fall through
        case SpvOpSourceExtension:   // fall through
        case SpvOpName:              // fall through
        case SpvOpMemberName:        // fall through
        case SpvOpDecorate:          // fall through
        case SpvOpMemberDecorate:
            // No block bookkeeping for these.
            break;
        default:
            // We may find ourselves with dead code--instructions that don't have an associated
            // block. This should be a rare event, but if it happens, synthesize a label; this is
            // necessary to satisfy the validator.
            if (fCurrentBlock == 0) {
                this->writeLabel(this->nextId(nullptr), out);
            }
            break;
    }
    this->writeWord((length << 16) | opCode, out);
}
303
// Opens a new basic block: records `label` as fCurrentBlock and emits OpLabel for it.
// Must only be called when no block is open (asserted).
// Order matters: fCurrentBlock is set *before* OpLabel is written. OpLabel takes writeOpCode()'s
// default path, which would otherwise see no open block and synthesize an extra label first.
void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
    SkASSERT(!fCurrentBlock);
    fCurrentBlock = label;
    this->writeInstruction(SpvOpLabel, label, out);
}
309
310void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
311    this->writeOpCode(opCode, 1, out);
312}
313
314void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
315    this->writeOpCode(opCode, 2, out);
316    this->writeWord(word1, out);
317}
318
319void SPIRVCodeGenerator::writeString(skstd::string_view s, OutputStream& out) {
320    out.write(s.data(), s.length());
321    switch (s.length() % 4) {
322        case 1:
323            out.write8(0);
324            [[fallthrough]];
325        case 2:
326            out.write8(0);
327            [[fallthrough]];
328        case 3:
329            out.write8(0);
330            break;
331        default:
332            this->writeWord(0, out);
333            break;
334    }
335}
336
337void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, skstd::string_view string,
338                                          OutputStream& out) {
339    this->writeOpCode(opCode, 1 + (string.length() + 4) / 4, out);
340    this->writeString(string, out);
341}
342
343
344void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, skstd::string_view string,
345                                          OutputStream& out) {
346    this->writeOpCode(opCode, 2 + (string.length() + 4) / 4, out);
347    this->writeWord(word1, out);
348    this->writeString(string, out);
349}
350
351void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
352                                          skstd::string_view string, OutputStream& out) {
353    this->writeOpCode(opCode, 3 + (string.length() + 4) / 4, out);
354    this->writeWord(word1, out);
355    this->writeWord(word2, out);
356    this->writeString(string, out);
357}
358
359void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
360                                          OutputStream& out) {
361    this->writeOpCode(opCode, 3, out);
362    this->writeWord(word1, out);
363    this->writeWord(word2, out);
364}
365
366void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
367                                          int32_t word3, OutputStream& out) {
368    this->writeOpCode(opCode, 4, out);
369    this->writeWord(word1, out);
370    this->writeWord(word2, out);
371    this->writeWord(word3, out);
372}
373
374void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
375                                          int32_t word3, int32_t word4, OutputStream& out) {
376    this->writeOpCode(opCode, 5, out);
377    this->writeWord(word1, out);
378    this->writeWord(word2, out);
379    this->writeWord(word3, out);
380    this->writeWord(word4, out);
381}
382
383void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
384                                          int32_t word3, int32_t word4, int32_t word5,
385                                          OutputStream& out) {
386    this->writeOpCode(opCode, 6, out);
387    this->writeWord(word1, out);
388    this->writeWord(word2, out);
389    this->writeWord(word3, out);
390    this->writeWord(word4, out);
391    this->writeWord(word5, out);
392}
393
394void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
395                                          int32_t word3, int32_t word4, int32_t word5,
396                                          int32_t word6, OutputStream& out) {
397    this->writeOpCode(opCode, 7, out);
398    this->writeWord(word1, out);
399    this->writeWord(word2, out);
400    this->writeWord(word3, out);
401    this->writeWord(word4, out);
402    this->writeWord(word5, out);
403    this->writeWord(word6, out);
404}
405
406void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
407                                          int32_t word3, int32_t word4, int32_t word5,
408                                          int32_t word6, int32_t word7, OutputStream& out) {
409    this->writeOpCode(opCode, 8, out);
410    this->writeWord(word1, out);
411    this->writeWord(word2, out);
412    this->writeWord(word3, out);
413    this->writeWord(word4, out);
414    this->writeWord(word5, out);
415    this->writeWord(word6, out);
416    this->writeWord(word7, out);
417}
418
419void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
420                                          int32_t word3, int32_t word4, int32_t word5,
421                                          int32_t word6, int32_t word7, int32_t word8,
422                                          OutputStream& out) {
423    this->writeOpCode(opCode, 9, out);
424    this->writeWord(word1, out);
425    this->writeWord(word2, out);
426    this->writeWord(word3, out);
427    this->writeWord(word4, out);
428    this->writeWord(word5, out);
429    this->writeWord(word6, out);
430    this->writeWord(word7, out);
431    this->writeWord(word8, out);
432}
433
434void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
435    for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
436        if (fCapabilities & bit) {
437            this->writeInstruction(SpvOpCapability, (SpvId) i, out);
438        }
439    }
440    this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
441}
442
443SpvId SPIRVCodeGenerator::nextId(const Type* type) {
444    return this->nextId(type && type->hasPrecision() && !type->highPrecision()
445                ? Precision::kRelaxed
446                : Precision::kDefault);
447}
448
449SpvId SPIRVCodeGenerator::nextId(Precision precision) {
450    if (precision == Precision::kRelaxed && !fProgram.fConfig->fSettings.fForceHighPrecision) {
451        this->writeInstruction(SpvOpDecorate, fIdCount, SpvDecorationRelaxedPrecision,
452                               fDecorationBuffer);
453    }
454    return fIdCount++;
455}
456
// Emits OpTypeStruct for `type` with id `resultId`, together with its OpName/OpMemberName debug
// names and per-member decorations: whatever writeLayout() emits, Offset (for non-builtin
// members), ColMajor + MatrixStride (for matrix members), and RelaxedPrecision.
// Member offsets come from `memoryLayout`. An explicit layout(offset=N) is honored, but an error
// is reported if N is smaller than the running offset or not a multiple of the member's alignment.
// If a member's type has no supported layout, an error is reported and the function returns
// early. NOTE(review): at that point OpTypeStruct has already been written and only the earlier
// members have been decorated.
void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
                                     SpvId resultId) {
    this->writeInstruction(SpvOpName, resultId, String(type.name()).c_str(), fNameBuffer);
    // go ahead and write all of the field types, so we don't inadvertently write them while we're
    // in the middle of writing the struct instruction
    std::vector<SpvId> types;
    for (const auto& f : type.fields()) {
        types.push_back(this->getType(*f.fType, memoryLayout));
    }
    // OpTypeStruct words: header, result id, then one member type id per field.
    this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
    this->writeWord(resultId, fConstantBuffer);
    for (SpvId id : types) {
        this->writeWord(id, fConstantBuffer);
    }
    // Running byte offset of the next member within the struct.
    size_t offset = 0;
    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
        const Type::Field& field = type.fields()[i];
        if (!MemoryLayout::LayoutIsSupported(*field.fType)) {
            fContext.fErrors->error(type.fLine, "type '" + field.fType->name() +
                                    "' is not permitted here");
            return;
        }
        size_t size = memoryLayout.size(*field.fType);
        size_t alignment = memoryLayout.alignment(*field.fType);
        const Layout& fieldLayout = field.fModifiers.fLayout;
        if (fieldLayout.fOffset >= 0) {
            // Explicit offset: validate it against the running offset and alignment, then use it
            // (even if it failed validation; the errors above are reported, not fatal here).
            if (fieldLayout.fOffset < (int) offset) {
                fContext.fErrors->error(type.fLine,
                                        "offset of field '" + field.fName + "' must be at "
                                        "least " + to_string((int) offset));
            }
            if (fieldLayout.fOffset % alignment) {
                fContext.fErrors->error(type.fLine,
                                        "offset of field '" + field.fName + "' must be a multiple"
                                        " of " + to_string((int) alignment));
            }
            offset = fieldLayout.fOffset;
        } else {
            // Implicit offset: round the running offset up to this member's alignment.
            size_t mod = offset % alignment;
            if (mod) {
                offset += alignment - mod;
            }
        }
        this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
        this->writeLayout(fieldLayout, resultId, i);
        // Members bound to a builtin (fBuiltin >= 0) get no Offset decoration.
        if (field.fModifiers.fLayout.fBuiltin < 0) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                   (SpvId) offset, fDecorationBuffer);
        }
        if (field.fType->isMatrix()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                   fDecorationBuffer);
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                   (SpvId) memoryLayout.stride(*field.fType),
                                   fDecorationBuffer);
        }
        // NOTE(review): unlike nextId(const Type*), this neither checks hasPrecision() nor honors
        // fForceHighPrecision, so any member that is not highPrecision() is decorated relaxed.
        // Confirm this is intentional.
        if (!field.fType->highPrecision()) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
                                   SpvDecorationRelaxedPrecision, fDecorationBuffer);
        }
        offset += size;
        // After an array or struct member, pad the running offset up to that member's alignment.
        if ((field.fType->isArray() || field.fType->isStruct()) && offset % alignment != 0) {
            offset += alignment - offset % alignment;
        }
    }
}
523
524const Type& SPIRVCodeGenerator::getActualType(const Type& type) {
525    if (type.isFloat()) {
526        return *fContext.fTypes.fFloat;
527    }
528    if (type.isSigned()) {
529        return *fContext.fTypes.fInt;
530    }
531    if (type.isUnsigned()) {
532        return *fContext.fTypes.fUInt;
533    }
534    if (type.isMatrix() || type.isVector()) {
535        if (type.componentType() == *fContext.fTypes.fHalf) {
536            return fContext.fTypes.fFloat->toCompound(fContext, type.columns(), type.rows());
537        }
538        if (type.componentType() == *fContext.fTypes.fShort) {
539            return fContext.fTypes.fInt->toCompound(fContext, type.columns(), type.rows());
540        }
541        if (type.componentType() == *fContext.fTypes.fUShort) {
542            return fContext.fTypes.fUInt->toCompound(fContext, type.columns(), type.rows());
543        }
544    }
545    return type;
546}
547
548SpvId SPIRVCodeGenerator::getType(const Type& type) {
549    return this->getType(type, fDefaultLayout);
550}
551
552SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
553    const Type* type;
554    std::unique_ptr<Type> arrayType;
555    String arrayName;
556
557    if (rawType.isArray()) {
558        // For arrays, we need to synthesize a temporary Array type using the "actual" component
559        // type. That is, if `short[10]` is passed in, we need to synthesize a `int[10]` Type.
560        // Otherwise, we can end up with two different SpvIds for the same array type.
561        const Type& component = this->getActualType(rawType.componentType());
562        arrayName = component.getArrayName(rawType.columns());
563        arrayType = Type::MakeArrayType(arrayName, component, rawType.columns());
564        type = arrayType.get();
565    } else {
566        // For non-array types, we can simply look up the "actual" type and use it.
567        type = &this->getActualType(rawType);
568    }
569
570    String key(type->name());
571    if (type->isStruct() || type->isArray()) {
572        key += to_string((int)layout.fStd);
573#ifdef SK_DEBUG
574        SkASSERT(layout.fStd == MemoryLayout::Standard::k140_Standard ||
575                 layout.fStd == MemoryLayout::Standard::k430_Standard);
576        MemoryLayout::Standard otherStd = layout.fStd == MemoryLayout::Standard::k140_Standard
577                                                  ? MemoryLayout::Standard::k430_Standard
578                                                  : MemoryLayout::Standard::k140_Standard;
579        String otherKey = type->name() + to_string((int)otherStd);
580        SkASSERT(fTypeMap.find(otherKey) == fTypeMap.end());
581#endif
582    }
583    auto entry = fTypeMap.find(key);
584    if (entry == fTypeMap.end()) {
585        SpvId result = this->nextId(nullptr);
586        switch (type->typeKind()) {
587            case Type::TypeKind::kScalar:
588                if (type->isBoolean()) {
589                    this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
590                } else if (type->isSigned()) {
591                    this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
592                } else if (type->isUnsigned()) {
593                    this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
594                } else if (type->isFloat()) {
595                    this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
596                } else {
597                    SkDEBUGFAILF("unrecognized scalar type '%s'", type->description().c_str());
598                }
599                break;
600            case Type::TypeKind::kVector:
601                this->writeInstruction(SpvOpTypeVector, result,
602                                       this->getType(type->componentType(), layout),
603                                       type->columns(), fConstantBuffer);
604                break;
605            case Type::TypeKind::kMatrix:
606                this->writeInstruction(
607                        SpvOpTypeMatrix,
608                        result,
609                        this->getType(IndexExpression::IndexType(fContext, *type), layout),
610                        type->columns(),
611                        fConstantBuffer);
612                break;
613            case Type::TypeKind::kStruct:
614                this->writeStruct(*type, layout, result);
615                break;
616            case Type::TypeKind::kArray: {
617                if (!MemoryLayout::LayoutIsSupported(*type)) {
618                    fContext.fErrors->error(type->fLine,
619                                            "type '" + type->name() + "' is not permitted here");
620                    return this->nextId(nullptr);
621                }
622                if (type->columns() > 0) {
623                    SpvId typeId = this->getType(type->componentType(), layout);
624                    SpvId countId = this->writeLiteral(type->columns(), *fContext.fTypes.fInt);
625                    this->writeInstruction(SpvOpTypeArray, result, typeId, countId,
626                                           fConstantBuffer);
627                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
628                                           (int32_t) layout.stride(*type),
629                                           fDecorationBuffer);
630                } else {
631                    // We shouldn't have any runtime-sized arrays right now
632                    fContext.fErrors->error(type->fLine,
633                                            "runtime-sized arrays are not supported in SPIR-V");
634                    this->writeInstruction(SpvOpTypeRuntimeArray, result,
635                                           this->getType(type->componentType(), layout),
636                                           fConstantBuffer);
637                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
638                                           (int32_t) layout.stride(*type),
639                                           fDecorationBuffer);
640                }
641                break;
642            }
643            case Type::TypeKind::kSampler: {
644                SpvId image = result;
645                if (SpvDimSubpassData != type->dimensions()) {
646                    image = this->getType(type->textureType(), layout);
647                }
648                if (SpvDimBuffer == type->dimensions()) {
649                    fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
650                }
651                if (SpvDimSubpassData != type->dimensions()) {
652                    this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
653                }
654                break;
655            }
656            case Type::TypeKind::kSeparateSampler: {
657                this->writeInstruction(SpvOpTypeSampler, result, fConstantBuffer);
658                break;
659            }
660            case Type::TypeKind::kTexture: {
661                this->writeInstruction(SpvOpTypeImage, result,
662                                       this->getType(*fContext.fTypes.fFloat, layout),
663                                       type->dimensions(), type->isDepth(),
664                                       type->isArrayedTexture(), type->isMultisampled(),
665                                       type->isSampled() ? 1 : 2, SpvImageFormatUnknown,
666                                       fConstantBuffer);
667                fImageTypeMap[key] = result;
668                break;
669            }
670            default:
671                if (type->isVoid()) {
672                    this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
673                } else {
674                    SkDEBUGFAILF("invalid type: %s", type->description().c_str());
675                }
676                break;
677        }
678        fTypeMap[key] = result;
679        return result;
680    }
681    return entry->second;
682}
683
684SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
685    SkASSERT(type.typeKind() == Type::TypeKind::kSampler);
686    this->getType(type);
687    String key = type.name() + to_string((int) fDefaultLayout.fStd);
688    SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
689    return fImageTypeMap[key];
690}
691
692SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
693    String key = to_string(this->getType(function.returnType())) + "(";
694    String separator;
695    const std::vector<const Variable*>& parameters = function.parameters();
696    for (size_t i = 0; i < parameters.size(); i++) {
697        key += separator;
698        separator = ", ";
699        key += to_string(this->getType(parameters[i]->type()));
700    }
701    key += ")";
702    auto entry = fTypeMap.find(key);
703    if (entry == fTypeMap.end()) {
704        SpvId result = this->nextId(nullptr);
705        int32_t length = 3 + (int32_t) parameters.size();
706        SpvId returnType = this->getType(function.returnType());
707        std::vector<SpvId> parameterTypes;
708        for (size_t i = 0; i < parameters.size(); i++) {
709            // glslang seems to treat all function arguments as pointers whether they need to be or
710            // not. I  was initially puzzled by this until I ran bizarre failures with certain
711            // patterns of function calls and control constructs, as exemplified by this minimal
712            // failure case:
713            //
714            // void sphere(float x) {
715            // }
716            //
717            // void map() {
718            //     sphere(1.0);
719            // }
720            //
721            // void main() {
722            //     for (int i = 0; i < 1; i++) {
723            //         map();
724            //     }
725            // }
726            //
727            // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
728            // crashes. Making it take a float* and storing the argument in a temporary variable,
729            // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
730            // the spec makes this make sense.
731            parameterTypes.push_back(this->getPointerType(parameters[i]->type(),
732                                                          SpvStorageClassFunction));
733        }
734        this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
735        this->writeWord(result, fConstantBuffer);
736        this->writeWord(returnType, fConstantBuffer);
737        for (SpvId id : parameterTypes) {
738            this->writeWord(id, fConstantBuffer);
739        }
740        fTypeMap[key] = result;
741        return result;
742    }
743    return entry->second;
744}
745
746SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
747    return this->getPointerType(type, fDefaultLayout, storageClass);
748}
749
750SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
751                                         SpvStorageClass_ storageClass) {
752    const Type& type = this->getActualType(rawType);
753    String key = type.displayName() + "*" + to_string(layout.fStd) + to_string(storageClass);
754    auto entry = fTypeMap.find(key);
755    if (entry == fTypeMap.end()) {
756        SpvId result = this->nextId(nullptr);
757        this->writeInstruction(SpvOpTypePointer, result, storageClass,
758                               this->getType(type), fConstantBuffer);
759        fTypeMap[key] = result;
760        return result;
761    }
762    return entry->second;
763}
764
765SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
766    switch (expr.kind()) {
767        case Expression::Kind::kBinary:
768            return this->writeBinaryExpression(expr.as<BinaryExpression>(), out);
769        case Expression::Kind::kConstructorArrayCast:
770            return this->writeExpression(*expr.as<ConstructorArrayCast>().argument(), out);
771        case Expression::Kind::kConstructorArray:
772        case Expression::Kind::kConstructorStruct:
773            return this->writeCompositeConstructor(expr.asAnyConstructor(), out);
774        case Expression::Kind::kConstructorDiagonalMatrix:
775            return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
776        case Expression::Kind::kConstructorMatrixResize:
777            return this->writeConstructorMatrixResize(expr.as<ConstructorMatrixResize>(), out);
778        case Expression::Kind::kConstructorScalarCast:
779            return this->writeConstructorScalarCast(expr.as<ConstructorScalarCast>(), out);
780        case Expression::Kind::kConstructorSplat:
781            return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
782        case Expression::Kind::kConstructorCompound:
783            return this->writeConstructorCompound(expr.as<ConstructorCompound>(), out);
784        case Expression::Kind::kConstructorCompoundCast:
785            return this->writeConstructorCompoundCast(expr.as<ConstructorCompoundCast>(), out);
786        case Expression::Kind::kFieldAccess:
787            return this->writeFieldAccess(expr.as<FieldAccess>(), out);
788        case Expression::Kind::kFunctionCall:
789            return this->writeFunctionCall(expr.as<FunctionCall>(), out);
790        case Expression::Kind::kLiteral:
791            return this->writeLiteral(expr.as<Literal>());
792        case Expression::Kind::kPrefix:
793            return this->writePrefixExpression(expr.as<PrefixExpression>(), out);
794        case Expression::Kind::kPostfix:
795            return this->writePostfixExpression(expr.as<PostfixExpression>(), out);
796        case Expression::Kind::kSwizzle:
797            return this->writeSwizzle(expr.as<Swizzle>(), out);
798        case Expression::Kind::kVariableReference:
799            return this->writeVariableReference(expr.as<VariableReference>(), out);
800        case Expression::Kind::kTernary:
801            return this->writeTernaryExpression(expr.as<TernaryExpression>(), out);
802        case Expression::Kind::kIndex:
803            return this->writeIndexExpression(expr.as<IndexExpression>(), out);
804        default:
805            SkDEBUGFAILF("unsupported expression: %s", expr.description().c_str());
806            break;
807    }
808    return -1;
809}
810
811SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
812    const FunctionDeclaration& function = c.function();
813    auto intrinsic = fIntrinsicMap.find(function.intrinsicKind());
814    if (intrinsic == fIntrinsicMap.end()) {
815        fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() + "'");
816        return -1;
817    }
818    int32_t intrinsicId;
819    const ExpressionArray& arguments = c.arguments();
820    if (arguments.size() > 0) {
821        const Type& type = arguments[0]->type();
822        if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicOpcodeKind ||
823            is_float(fContext, type)) {
824            intrinsicId = std::get<1>(intrinsic->second);
825        } else if (is_signed(fContext, type)) {
826            intrinsicId = std::get<2>(intrinsic->second);
827        } else if (is_unsigned(fContext, type)) {
828            intrinsicId = std::get<3>(intrinsic->second);
829        } else if (is_bool(fContext, type)) {
830            intrinsicId = std::get<4>(intrinsic->second);
831        } else {
832            intrinsicId = std::get<1>(intrinsic->second);
833        }
834    } else {
835        intrinsicId = std::get<1>(intrinsic->second);
836    }
837    switch (std::get<0>(intrinsic->second)) {
838        case kGLSL_STD_450_IntrinsicOpcodeKind: {
839            SpvId result = this->nextId(&c.type());
840            std::vector<SpvId> argumentIds;
841            std::vector<TempVar> tempVars;
842            argumentIds.reserve(arguments.size());
843            for (size_t i = 0; i < arguments.size(); i++) {
844                if (is_out(function.parameters()[i]->modifiers())) {
845                    argumentIds.push_back(
846                            this->writeFunctionCallArgument(*arguments[i],
847                                                            function.parameters()[i]->modifiers(),
848                                                            &tempVars,
849                                                            out));
850                } else {
851                    argumentIds.push_back(this->writeExpression(*arguments[i], out));
852                }
853            }
854            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
855            this->writeWord(this->getType(c.type()), out);
856            this->writeWord(result, out);
857            this->writeWord(fGLSLExtendedInstructions, out);
858            this->writeWord(intrinsicId, out);
859            for (SpvId id : argumentIds) {
860                this->writeWord(id, out);
861            }
862            this->copyBackTempVars(tempVars, out);
863            return result;
864        }
865        case kSPIRV_IntrinsicOpcodeKind: {
866            // GLSL supports dot(float, float), but SPIR-V does not. Convert it to FMul
867            if (intrinsicId == SpvOpDot && arguments[0]->type().isScalar()) {
868                intrinsicId = SpvOpFMul;
869            }
870            SpvId result = this->nextId(&c.type());
871            std::vector<SpvId> argumentIds;
872            std::vector<TempVar> tempVars;
873            argumentIds.reserve(arguments.size());
874            for (size_t i = 0; i < arguments.size(); i++) {
875                if (is_out(function.parameters()[i]->modifiers())) {
876                    argumentIds.push_back(
877                            this->writeFunctionCallArgument(*arguments[i],
878                                                            function.parameters()[i]->modifiers(),
879                                                            &tempVars,
880                                                            out));
881                } else {
882                    argumentIds.push_back(this->writeExpression(*arguments[i], out));
883                }
884            }
885            if (!c.type().isVoid()) {
886                this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
887                this->writeWord(this->getType(c.type()), out);
888                this->writeWord(result, out);
889            } else {
890                this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
891            }
892            for (SpvId id : argumentIds) {
893                this->writeWord(id, out);
894            }
895            this->copyBackTempVars(tempVars, out);
896            return result;
897        }
898        case kSpecial_IntrinsicOpcodeKind:
899            return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
900        default:
901            fContext.fErrors->error(c.fLine, "unsupported intrinsic '" + function.description() +
902                                             "'");
903            return -1;
904    }
905}
906
907SpvId SPIRVCodeGenerator::vectorize(const Expression& arg, int vectorSize, OutputStream& out) {
908    SkASSERT(vectorSize >= 1 && vectorSize <= 4);
909    const Type& argType = arg.type();
910    SpvId raw = this->writeExpression(arg, out);
911    if (argType.isScalar()) {
912        if (vectorSize == 1) {
913            return raw;
914        }
915        SpvId vector = this->nextId(&argType);
916        this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
917        this->writeWord(this->getType(argType.toCompound(fContext, vectorSize, 1)), out);
918        this->writeWord(vector, out);
919        for (int i = 0; i < vectorSize; i++) {
920            this->writeWord(raw, out);
921        }
922        return vector;
923    } else {
924        SkASSERT(vectorSize == argType.columns());
925        return raw;
926    }
927}
928
929std::vector<SpvId> SPIRVCodeGenerator::vectorize(const ExpressionArray& args, OutputStream& out) {
930    int vectorSize = 1;
931    for (const auto& a : args) {
932        if (a->type().isVector()) {
933            if (vectorSize > 1) {
934                SkASSERT(a->type().columns() == vectorSize);
935            } else {
936                vectorSize = a->type().columns();
937            }
938        }
939    }
940    std::vector<SpvId> result;
941    result.reserve(args.size());
942    for (const auto& arg : args) {
943        result.push_back(this->vectorize(*arg, vectorSize, out));
944    }
945    return result;
946}
947
948void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
949                                                      SpvId signedInst, SpvId unsignedInst,
950                                                      const std::vector<SpvId>& args,
951                                                      OutputStream& out) {
952    this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
953    this->writeWord(this->getType(type), out);
954    this->writeWord(id, out);
955    this->writeWord(fGLSLExtendedInstructions, out);
956
957    if (is_float(fContext, type)) {
958        this->writeWord(floatInst, out);
959    } else if (is_signed(fContext, type)) {
960        this->writeWord(signedInst, out);
961    } else if (is_unsigned(fContext, type)) {
962        this->writeWord(unsignedInst, out);
963    } else {
964        SkASSERT(false);
965    }
966    for (SpvId a : args) {
967        this->writeWord(a, out);
968    }
969}
970
971SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
972                                                OutputStream& out) {
973    const ExpressionArray& arguments = c.arguments();
974    const Type& callType = c.type();
975    SpvId result = this->nextId(nullptr);
976    switch (kind) {
977        case kAtan_SpecialIntrinsic: {
978            std::vector<SpvId> argumentIds;
979            argumentIds.reserve(arguments.size());
980            for (const std::unique_ptr<Expression>& arg : arguments) {
981                argumentIds.push_back(this->writeExpression(*arg, out));
982            }
983            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) argumentIds.size(), out);
984            this->writeWord(this->getType(callType), out);
985            this->writeWord(result, out);
986            this->writeWord(fGLSLExtendedInstructions, out);
987            this->writeWord(argumentIds.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
988            for (SpvId id : argumentIds) {
989                this->writeWord(id, out);
990            }
991            break;
992        }
993        case kSampledImage_SpecialIntrinsic: {
994            SkASSERT(arguments.size() == 2);
995            SpvId img = this->writeExpression(*arguments[0], out);
996            SpvId sampler = this->writeExpression(*arguments[1], out);
997            this->writeInstruction(SpvOpSampledImage,
998                                   this->getType(callType),
999                                   result,
1000                                   img,
1001                                   sampler,
1002                                   out);
1003            break;
1004        }
1005        case kSubpassLoad_SpecialIntrinsic: {
1006            SpvId img = this->writeExpression(*arguments[0], out);
1007            ExpressionArray args;
1008            args.reserve_back(2);
1009            args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0));
1010            args.push_back(Literal::MakeInt(fContext, /*line=*/-1, /*value=*/0));
1011            ConstructorCompound ctor(/*line=*/-1, *fContext.fTypes.fInt2, std::move(args));
1012            SpvId coords = this->writeConstantVector(ctor);
1013            if (arguments.size() == 1) {
1014                this->writeInstruction(SpvOpImageRead,
1015                                       this->getType(callType),
1016                                       result,
1017                                       img,
1018                                       coords,
1019                                       out);
1020            } else {
1021                SkASSERT(arguments.size() == 2);
1022                SpvId sample = this->writeExpression(*arguments[1], out);
1023                this->writeInstruction(SpvOpImageRead,
1024                                       this->getType(callType),
1025                                       result,
1026                                       img,
1027                                       coords,
1028                                       SpvImageOperandsSampleMask,
1029                                       sample,
1030                                       out);
1031            }
1032            break;
1033        }
1034        case kTexture_SpecialIntrinsic: {
1035            SpvOp_ op = SpvOpImageSampleImplicitLod;
1036            const Type& arg1Type = arguments[1]->type();
1037            switch (arguments[0]->type().dimensions()) {
1038                case SpvDim1D:
1039                    if (arg1Type == *fContext.fTypes.fFloat2) {
1040                        op = SpvOpImageSampleProjImplicitLod;
1041                    } else {
1042                        SkASSERT(arg1Type == *fContext.fTypes.fFloat);
1043                    }
1044                    break;
1045                case SpvDim2D:
1046                    if (arg1Type == *fContext.fTypes.fFloat3) {
1047                        op = SpvOpImageSampleProjImplicitLod;
1048                    } else {
1049                        SkASSERT(arg1Type == *fContext.fTypes.fFloat2);
1050                    }
1051                    break;
1052                case SpvDim3D:
1053                    if (arg1Type == *fContext.fTypes.fFloat4) {
1054                        op = SpvOpImageSampleProjImplicitLod;
1055                    } else {
1056                        SkASSERT(arg1Type == *fContext.fTypes.fFloat3);
1057                    }
1058                    break;
1059                case SpvDimCube:   // fall through
1060                case SpvDimRect:   // fall through
1061                case SpvDimBuffer: // fall through
1062                case SpvDimSubpassData:
1063                    break;
1064            }
1065            SpvId type = this->getType(callType);
1066            SpvId sampler = this->writeExpression(*arguments[0], out);
1067            SpvId uv = this->writeExpression(*arguments[1], out);
1068            if (arguments.size() == 3) {
1069                this->writeInstruction(op, type, result, sampler, uv,
1070                                       SpvImageOperandsBiasMask,
1071                                       this->writeExpression(*arguments[2], out),
1072                                       out);
1073            } else {
1074                SkASSERT(arguments.size() == 2);
1075                if (fProgram.fConfig->fSettings.fSharpenTextures) {
1076                    SpvId lodBias = this->writeLiteral(-0.5, *fContext.fTypes.fFloat);
1077                    this->writeInstruction(op, type, result, sampler, uv,
1078                                           SpvImageOperandsBiasMask, lodBias, out);
1079                } else {
1080                    this->writeInstruction(op, type, result, sampler, uv,
1081                                           out);
1082                }
1083            }
1084            break;
1085        }
1086        case kMod_SpecialIntrinsic: {
1087            std::vector<SpvId> args = this->vectorize(arguments, out);
1088            SkASSERT(args.size() == 2);
1089            const Type& operandType = arguments[0]->type();
1090            SpvOp_ op;
1091            if (is_float(fContext, operandType)) {
1092                op = SpvOpFMod;
1093            } else if (is_signed(fContext, operandType)) {
1094                op = SpvOpSMod;
1095            } else if (is_unsigned(fContext, operandType)) {
1096                op = SpvOpUMod;
1097            } else {
1098                SkASSERT(false);
1099                return 0;
1100            }
1101            this->writeOpCode(op, 5, out);
1102            this->writeWord(this->getType(operandType), out);
1103            this->writeWord(result, out);
1104            this->writeWord(args[0], out);
1105            this->writeWord(args[1], out);
1106            break;
1107        }
1108        case kDFdy_SpecialIntrinsic: {
1109            SpvId fn = this->writeExpression(*arguments[0], out);
1110            this->writeOpCode(SpvOpDPdy, 4, out);
1111            this->writeWord(this->getType(callType), out);
1112            this->writeWord(result, out);
1113            this->writeWord(fn, out);
1114            this->addRTFlipUniform(c.fLine);
1115            using namespace dsl;
1116            DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
1117                    SKSL_RTFLIP_NAME));
1118            SpvId rtFlipY = this->vectorize(*rtFlip.y().release(), callType.columns(), out);
1119            SpvId flipped = this->nextId(&callType);
1120            this->writeInstruction(SpvOpFMul, this->getType(callType), flipped, result, rtFlipY,
1121                                   out);
1122            result = flipped;
1123            break;
1124        }
1125        case kClamp_SpecialIntrinsic: {
1126            std::vector<SpvId> args = this->vectorize(arguments, out);
1127            SkASSERT(args.size() == 3);
1128            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1129                                               GLSLstd450UClamp, args, out);
1130            break;
1131        }
1132        case kMax_SpecialIntrinsic: {
1133            std::vector<SpvId> args = this->vectorize(arguments, out);
1134            SkASSERT(args.size() == 2);
1135            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMax, GLSLstd450SMax,
1136                                               GLSLstd450UMax, args, out);
1137            break;
1138        }
1139        case kMin_SpecialIntrinsic: {
1140            std::vector<SpvId> args = this->vectorize(arguments, out);
1141            SkASSERT(args.size() == 2);
1142            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMin, GLSLstd450SMin,
1143                                               GLSLstd450UMin, args, out);
1144            break;
1145        }
1146        case kMix_SpecialIntrinsic: {
1147            std::vector<SpvId> args = this->vectorize(arguments, out);
1148            SkASSERT(args.size() == 3);
1149            if (arguments[2]->type().componentType().isBoolean()) {
1150                // Use OpSelect to implement Boolean mix().
1151                SpvId falseId     = this->writeExpression(*arguments[0], out);
1152                SpvId trueId      = this->writeExpression(*arguments[1], out);
1153                SpvId conditionId = this->writeExpression(*arguments[2], out);
1154                this->writeInstruction(SpvOpSelect, this->getType(arguments[0]->type()), result,
1155                                       conditionId, trueId, falseId, out);
1156            } else {
1157                this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FMix, SpvOpUndef,
1158                                                   SpvOpUndef, args, out);
1159            }
1160            break;
1161        }
1162        case kSaturate_SpecialIntrinsic: {
1163            SkASSERT(arguments.size() == 1);
1164            ExpressionArray finalArgs;
1165            finalArgs.reserve_back(3);
1166            finalArgs.push_back(arguments[0]->clone());
1167            finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/0));
1168            finalArgs.push_back(Literal::MakeFloat(fContext, /*line=*/-1, /*value=*/1));
1169            std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
1170            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450FClamp, GLSLstd450SClamp,
1171                                               GLSLstd450UClamp, spvArgs, out);
1172            break;
1173        }
1174        case kSmoothStep_SpecialIntrinsic: {
1175            std::vector<SpvId> args = this->vectorize(arguments, out);
1176            SkASSERT(args.size() == 3);
1177            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450SmoothStep, SpvOpUndef,
1178                                               SpvOpUndef, args, out);
1179            break;
1180        }
1181        case kStep_SpecialIntrinsic: {
1182            std::vector<SpvId> args = this->vectorize(arguments, out);
1183            SkASSERT(args.size() == 2);
1184            this->writeGLSLExtendedInstruction(callType, result, GLSLstd450Step, SpvOpUndef,
1185                                               SpvOpUndef, args, out);
1186            break;
1187        }
1188        case kMatrixCompMult_SpecialIntrinsic: {
1189            SkASSERT(arguments.size() == 2);
1190            SpvId lhs = this->writeExpression(*arguments[0], out);
1191            SpvId rhs = this->writeExpression(*arguments[1], out);
1192            result = this->writeComponentwiseMatrixBinary(callType, lhs, rhs, SpvOpFMul, out);
1193            break;
1194        }
1195    }
1196    return result;
1197}
1198
1199SpvId SPIRVCodeGenerator::writeFunctionCallArgument(const Expression& arg,
1200                                                    const Modifiers& paramModifiers,
1201                                                    std::vector<TempVar>* tempVars,
1202                                                    OutputStream& out) {
1203    // ID of temporary variable that we will use to hold this argument, or 0 if it is being
1204    // passed directly
1205    SpvId tmpVar;
1206    // if we need a temporary var to store this argument, this is the value to store in the var
1207    SpvId tmpValueId = -1;
1208
1209    if (is_out(paramModifiers)) {
1210        std::unique_ptr<LValue> lv = this->getLValue(arg, out);
1211        SpvId ptr = lv->getPointer();
1212        if (ptr != (SpvId) -1 && lv->isMemoryObjectPointer()) {
1213            return ptr;
1214        }
1215
1216        // lvalue cannot simply be read and written via a pointer (e.g. it's a swizzle). We need to
1217        // to use a temp variable.
1218        if (is_in(paramModifiers)) {
1219            tmpValueId = lv->load(out);
1220        }
1221        tmpVar = this->nextId(&arg.type());
1222        tempVars->push_back(TempVar{tmpVar, &arg.type(), std::move(lv)});
1223    } else {
1224        // See getFunctionType for an explanation of why we're always using pointer parameters.
1225        tmpValueId = this->writeExpression(arg, out);
1226        tmpVar = this->nextId(nullptr);
1227    }
1228    this->writeInstruction(SpvOpVariable,
1229                           this->getPointerType(arg.type(), SpvStorageClassFunction),
1230                           tmpVar,
1231                           SpvStorageClassFunction,
1232                           fVariableBuffer);
1233    if (tmpValueId != (SpvId)-1) {
1234        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
1235    }
1236    return tmpVar;
1237}
1238
1239void SPIRVCodeGenerator::copyBackTempVars(const std::vector<TempVar>& tempVars, OutputStream& out) {
1240    for (const TempVar& tempVar : tempVars) {
1241        SpvId load = this->nextId(tempVar.type);
1242        this->writeInstruction(SpvOpLoad, this->getType(*tempVar.type), load, tempVar.spvId, out);
1243        tempVar.lvalue->store(load, out);
1244    }
1245}
1246
1247SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
1248    const FunctionDeclaration& function = c.function();
1249    if (function.isIntrinsic() && !function.definition()) {
1250        return this->writeIntrinsicCall(c, out);
1251    }
1252    const ExpressionArray& arguments = c.arguments();
1253    const auto& entry = fFunctionMap.find(&function);
1254    if (entry == fFunctionMap.end()) {
1255        fContext.fErrors->error(c.fLine, "function '" + function.description() +
1256                                         "' is not defined");
1257        return -1;
1258    }
1259    // Temp variables are used to write back out-parameters after the function call is complete.
1260    std::vector<TempVar> tempVars;
1261    std::vector<SpvId> argumentIds;
1262    argumentIds.reserve(arguments.size());
1263    for (size_t i = 0; i < arguments.size(); i++) {
1264        argumentIds.push_back(this->writeFunctionCallArgument(*arguments[i],
1265                                                              function.parameters()[i]->modifiers(),
1266                                                              &tempVars,
1267                                                              out));
1268    }
1269    SpvId result = this->nextId(nullptr);
1270    this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) arguments.size(), out);
1271    this->writeWord(this->getType(c.type()), out);
1272    this->writeWord(result, out);
1273    this->writeWord(entry->second, out);
1274    for (SpvId id : argumentIds) {
1275        this->writeWord(id, out);
1276    }
1277    // Now that the call is complete, we copy temp out-variables back to their real lvalues.
1278    this->copyBackTempVars(tempVars, out);
1279    return result;
1280}
1281
1282SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) {
1283    const Type& type = c.type();
1284    SkASSERT(type.isVector() && c.isCompileTimeConstant());
1285
1286    // Get each of the constructor components as SPIR-V constants.
1287    SPIRVVectorConstant key{this->getType(type),
1288                            /*fValueId=*/{SpvId(-1), SpvId(-1), SpvId(-1), SpvId(-1)}};
1289
1290    const Type& scalarType = type.componentType();
1291    for (int n = 0; n < type.columns(); n++) {
1292        skstd::optional<double> slotVal = c.getConstantValue(n);
1293        if (!slotVal.has_value()) {
1294            SkDEBUGFAILF("writeConstantVector: %s not actually constant", c.description().c_str());
1295            return (SpvId)-1;
1296        }
1297        key.fValueId[n] = this->writeLiteral(*slotVal, scalarType);
1298    }
1299
1300    // Check to see if we've already synthesized this vector constant.
1301    auto [iter, newlyCreated] = fVectorConstants.insert({key, (SpvId)-1});
1302    if (newlyCreated) {
1303        // Emit an OpConstantComposite instruction for this constant.
1304        SpvId result = this->nextId(&type);
1305        this->writeOpCode(SpvOpConstantComposite, 3 + type.columns(), fConstantBuffer);
1306        this->writeWord(key.fTypeId, fConstantBuffer);
1307        this->writeWord(result, fConstantBuffer);
1308        for (int i = 0; i < type.columns(); i++) {
1309            this->writeWord(key.fValueId[i], fConstantBuffer);
1310        }
1311        iter->second = result;
1312    }
1313    return iter->second;
1314}
1315
1316SpvId SPIRVCodeGenerator::castScalarToType(SpvId inputExprId,
1317                                           const Type& inputType,
1318                                           const Type& outputType,
1319                                           OutputStream& out) {
1320    if (outputType.isFloat()) {
1321        return this->castScalarToFloat(inputExprId, inputType, outputType, out);
1322    }
1323    if (outputType.isSigned()) {
1324        return this->castScalarToSignedInt(inputExprId, inputType, outputType, out);
1325    }
1326    if (outputType.isUnsigned()) {
1327        return this->castScalarToUnsignedInt(inputExprId, inputType, outputType, out);
1328    }
1329    if (outputType.isBoolean()) {
1330        return this->castScalarToBoolean(inputExprId, inputType, outputType, out);
1331    }
1332
1333    fContext.fErrors->error(-1, "unsupported cast: " + inputType.description() +
1334                                " to " + outputType.description());
1335    return inputExprId;
1336}
1337
1338SpvId SPIRVCodeGenerator::writeFloatConstructor(const AnyConstructor& c, OutputStream& out) {
1339    SkASSERT(c.argumentSpan().size() == 1);
1340    SkASSERT(c.type().isFloat());
1341    const Expression& ctorExpr = *c.argumentSpan().front();
1342    SpvId expressionId = this->writeExpression(ctorExpr, out);
1343    return this->castScalarToFloat(expressionId, ctorExpr.type(), c.type(), out);
1344}
1345
1346SpvId SPIRVCodeGenerator::castScalarToFloat(SpvId inputId, const Type& inputType,
1347                                            const Type& outputType, OutputStream& out) {
1348    // Casting a float to float is a no-op.
1349    if (inputType.isFloat()) {
1350        return inputId;
1351    }
1352
1353    // Given the input type, generate the appropriate instruction to cast to float.
1354    SpvId result = this->nextId(&outputType);
1355    if (inputType.isBoolean()) {
1356        // Use OpSelect to convert the boolean argument to a literal 1.0 or 0.0.
1357        const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fFloat);
1358        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1359        this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1360                               inputId, oneID, zeroID, out);
1361    } else if (inputType.isSigned()) {
1362        this->writeInstruction(SpvOpConvertSToF, this->getType(outputType), result, inputId, out);
1363    } else if (inputType.isUnsigned()) {
1364        this->writeInstruction(SpvOpConvertUToF, this->getType(outputType), result, inputId, out);
1365    } else {
1366        SkDEBUGFAILF("unsupported type for float typecast: %s", inputType.description().c_str());
1367        return (SpvId)-1;
1368    }
1369    return result;
1370}
1371
1372SpvId SPIRVCodeGenerator::writeIntConstructor(const AnyConstructor& c, OutputStream& out) {
1373    SkASSERT(c.argumentSpan().size() == 1);
1374    SkASSERT(c.type().isSigned());
1375    const Expression& ctorExpr = *c.argumentSpan().front();
1376    SpvId expressionId = this->writeExpression(ctorExpr, out);
1377    return this->castScalarToSignedInt(expressionId, ctorExpr.type(), c.type(), out);
1378}
1379
1380SpvId SPIRVCodeGenerator::castScalarToSignedInt(SpvId inputId, const Type& inputType,
1381                                                const Type& outputType, OutputStream& out) {
1382    // Casting a signed int to signed int is a no-op.
1383    if (inputType.isSigned()) {
1384        return inputId;
1385    }
1386
1387    // Given the input type, generate the appropriate instruction to cast to signed int.
1388    SpvId result = this->nextId(&outputType);
1389    if (inputType.isBoolean()) {
1390        // Use OpSelect to convert the boolean argument to a literal 1 or 0.
1391        const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fInt);
1392        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1393        this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1394                               inputId, oneID, zeroID, out);
1395    } else if (inputType.isFloat()) {
1396        this->writeInstruction(SpvOpConvertFToS, this->getType(outputType), result, inputId, out);
1397    } else if (inputType.isUnsigned()) {
1398        this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1399    } else {
1400        SkDEBUGFAILF("unsupported type for signed int typecast: %s",
1401                     inputType.description().c_str());
1402        return (SpvId)-1;
1403    }
1404    return result;
1405}
1406
1407SpvId SPIRVCodeGenerator::writeUIntConstructor(const AnyConstructor& c, OutputStream& out) {
1408    SkASSERT(c.argumentSpan().size() == 1);
1409    SkASSERT(c.type().isUnsigned());
1410    const Expression& ctorExpr = *c.argumentSpan().front();
1411    SpvId expressionId = this->writeExpression(ctorExpr, out);
1412    return this->castScalarToUnsignedInt(expressionId, ctorExpr.type(), c.type(), out);
1413}
1414
1415SpvId SPIRVCodeGenerator::castScalarToUnsignedInt(SpvId inputId, const Type& inputType,
1416                                                  const Type& outputType, OutputStream& out) {
1417    // Casting an unsigned int to unsigned int is a no-op.
1418    if (inputType.isUnsigned()) {
1419        return inputId;
1420    }
1421
1422    // Given the input type, generate the appropriate instruction to cast to unsigned int.
1423    SpvId result = this->nextId(&outputType);
1424    if (inputType.isBoolean()) {
1425        // Use OpSelect to convert the boolean argument to a literal 1u or 0u.
1426        const SpvId oneID = this->writeLiteral(1.0, *fContext.fTypes.fUInt);
1427        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1428        this->writeInstruction(SpvOpSelect, this->getType(outputType), result,
1429                               inputId, oneID, zeroID, out);
1430    } else if (inputType.isFloat()) {
1431        this->writeInstruction(SpvOpConvertFToU, this->getType(outputType), result, inputId, out);
1432    } else if (inputType.isSigned()) {
1433        this->writeInstruction(SpvOpBitcast, this->getType(outputType), result, inputId, out);
1434    } else {
1435        SkDEBUGFAILF("unsupported type for unsigned int typecast: %s",
1436                     inputType.description().c_str());
1437        return (SpvId)-1;
1438    }
1439    return result;
1440}
1441
1442SpvId SPIRVCodeGenerator::writeBooleanConstructor(const AnyConstructor& c, OutputStream& out) {
1443    SkASSERT(c.argumentSpan().size() == 1);
1444    SkASSERT(c.type().isBoolean());
1445    const Expression& ctorExpr = *c.argumentSpan().front();
1446    SpvId expressionId = this->writeExpression(ctorExpr, out);
1447    return this->castScalarToBoolean(expressionId, ctorExpr.type(), c.type(), out);
1448}
1449
1450SpvId SPIRVCodeGenerator::castScalarToBoolean(SpvId inputId, const Type& inputType,
1451                                              const Type& outputType, OutputStream& out) {
1452    // Casting a bool to bool is a no-op.
1453    if (inputType.isBoolean()) {
1454        return inputId;
1455    }
1456
1457    // Given the input type, generate the appropriate instruction to cast to bool.
1458    SpvId result = this->nextId(nullptr);
1459    if (inputType.isSigned()) {
1460        // Synthesize a boolean result by comparing the input against a signed zero literal.
1461        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fInt);
1462        this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1463                               inputId, zeroID, out);
1464    } else if (inputType.isUnsigned()) {
1465        // Synthesize a boolean result by comparing the input against an unsigned zero literal.
1466        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fUInt);
1467        this->writeInstruction(SpvOpINotEqual, this->getType(outputType), result,
1468                               inputId, zeroID, out);
1469    } else if (inputType.isFloat()) {
1470        // Synthesize a boolean result by comparing the input against a floating-point zero literal.
1471        const SpvId zeroID = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
1472        this->writeInstruction(SpvOpFUnordNotEqual, this->getType(outputType), result,
1473                               inputId, zeroID, out);
1474    } else {
1475        SkDEBUGFAILF("unsupported type for boolean typecast: %s", inputType.description().c_str());
1476        return (SpvId)-1;
1477    }
1478    return result;
1479}
1480
// Emits a matrix of `type` with `diagonal` on the main diagonal and a 0.0 literal everywhere
// else. The finished matrix is written into the caller-supplied result id `id`.
// Each column is built with its own OpCompositeConstruct; a final OpCompositeConstruct then
// combines the columns into the matrix.
void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
                                                 OutputStream& out) {
    // Off-diagonal filler. Note: a `float` literal is used regardless of the matrix's component
    // type.
    SpvId zeroId = this->writeLiteral(0.0, *fContext.fTypes.fFloat);
    std::vector<SpvId> columnIds;
    columnIds.reserve(type.columns());
    for (int column = 0; column < type.columns(); column++) {
        // Word count = opcode + result type + result id + one constituent per row.
        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
                          out);
        // The result type is a column vector with `rows` components.
        this->writeWord(this->getType(type.componentType().toCompound(
                                fContext, /*columns=*/type.rows(), /*rows=*/1)),
                        out);
        SpvId columnId = this->nextId(&type);
        this->writeWord(columnId, out);
        columnIds.push_back(columnId);
        for (int row = 0; row < type.rows(); row++) {
            // Only the element where row == column receives the diagonal value.
            this->writeWord(row == column ? diagonal : zeroId, out);
        }
    }
    // Assemble the columns into the final matrix, using the caller-provided id.
    this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
                      out);
    this->writeWord(this->getType(type), out);
    this->writeWord(id, out);
    for (SpvId columnId : columnIds) {
        this->writeWord(columnId, out);
    }
}
1507
// Converts the matrix `src` (of `srcType`) into a new matrix of `dstType` and returns its id.
// Columns and rows present in both types are copied. Extra destination rows/columns are filled
// with identity-matrix entries (1.0 where row == column, 0.0 elsewhere). Source rows/columns that
// don't fit in the destination are dropped.
SpvId SPIRVCodeGenerator::writeMatrixCopy(SpvId src, const Type& srcType, const Type& dstType,
                                          OutputStream& out) {
    SkASSERT(srcType.isMatrix());
    SkASSERT(dstType.isMatrix());
    SkASSERT(srcType.componentType() == dstType.componentType());
    // The result id is reserved up front; the OpCompositeConstruct that defines it is emitted last.
    SpvId id = this->nextId(&dstType);
    SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
                                                                           srcType.rows(),
                                                                           1));
    SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
                                                                           dstType.rows(),
                                                                           1));
    SkASSERT(dstType.componentType().isFloat());
    const SpvId zeroId = this->writeLiteral(0.0, dstType.componentType());
    const SpvId oneId = this->writeLiteral(1.0, dstType.componentType());

    // Assumes dstType has at most 4 columns.
    SpvId columns[4];
    for (int i = 0; i < dstType.columns(); i++) {
        if (i < srcType.columns()) {
            // we're still inside the src matrix, copy the column
            SpvId srcColumn = this->nextId(&dstType);
            this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
            SpvId dstColumn;
            if (srcType.rows() == dstType.rows()) {
                // columns are equal size, don't need to do anything
                dstColumn = srcColumn;
            }
            else if (dstType.rows() > srcType.rows()) {
                // dst column is bigger: append the missing rows using identity-matrix entries
                // (1.0 where row == column i, otherwise 0.0)
                dstColumn = this->nextId(&dstType);
                int delta = dstType.rows() - srcType.rows();
                // Word count = opcode + result type + result id + src column + `delta` scalars.
                this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
                this->writeWord(dstColumnType, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = srcType.rows(); j < dstType.rows(); ++j) {
                    this->writeWord((i == j) ? oneId : zeroId, out);
                }
            }
            else {
                // dst column is smaller: keep only its first dstType.rows() components via
                // OpVectorShuffle (the same vector is passed as both shuffle operands)
                dstColumn = this->nextId(&dstType);
                this->writeOpCode(SpvOpVectorShuffle, 5 + dstType.rows(), out);
                this->writeWord(dstColumnType, out);
                this->writeWord(dstColumn, out);
                this->writeWord(srcColumn, out);
                this->writeWord(srcColumn, out);
                for (int j = 0; j < dstType.rows(); j++) {
                    this->writeWord(j, out);
                }
            }
            columns[i] = dstColumn;
        } else {
            // we're past the end of the src matrix, need to synthesize an identity-matrix column
            SpvId identityColumn = this->nextId(&dstType);
            this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
            this->writeWord(dstColumnType, out);
            this->writeWord(identityColumn, out);
            for (int j = 0; j < dstType.rows(); ++j) {
                this->writeWord((i == j) ? oneId : zeroId, out);
            }
            columns[i] = identityColumn;
        }
    }
    // Combine the prepared columns into the destination matrix under the reserved id.
    this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
    this->writeWord(this->getType(dstType), out);
    this->writeWord(id, out);
    for (int i = 0; i < dstType.columns(); i++) {
        this->writeWord(columns[i], out);
    }
    return id;
}
1580
1581void SPIRVCodeGenerator::addColumnEntry(const Type& columnType,
1582                                        std::vector<SpvId>* currentColumn,
1583                                        std::vector<SpvId>* columnIds,
1584                                        int rows,
1585                                        SpvId entry,
1586                                        OutputStream& out) {
1587    SkASSERT((int)currentColumn->size() < rows);
1588    currentColumn->push_back(entry);
1589    if ((int)currentColumn->size() == rows) {
1590        // Synthesize this column into a vector.
1591        SpvId columnId = this->writeComposite(*currentColumn, columnType, out);
1592        columnIds->push_back(columnId);
1593        currentColumn->clear();
1594    }
1595}
1596
// Emits a matrix from a ConstructorCompound's arguments. All arguments are evaluated first. They
// are then consumed in order: a vector that exactly fills an empty column is used whole, and
// every other argument is split into scalars that are regrouped into column vectors (via
// addColumnEntry). The finished columns form the matrix.
// A single vector argument is special-cased; it is asserted to be a 4-component vector building
// a 2x2 matrix.
SpvId SPIRVCodeGenerator::writeMatrixConstructor(const ConstructorCompound& c, OutputStream& out) {
    const Type& type = c.type();
    SkASSERT(type.isMatrix());
    SkASSERT(!c.arguments().empty());
    const Type& arg0Type = c.arguments()[0]->type();
    // go ahead and write the arguments so we don't try to write new instructions in the middle of
    // an instruction
    std::vector<SpvId> arguments;
    arguments.reserve(c.arguments().size());
    for (const std::unique_ptr<Expression>& arg : c.arguments()) {
        arguments.push_back(this->writeExpression(*arg, out));
    }

    if (arguments.size() == 1 && arg0Type.isVector()) {
        // Special-case handling of float4 -> mat2x2.
        SkASSERT(type.rows() == 2 && type.columns() == 2);
        SkASSERT(arg0Type.columns() == 4);
        SpvId componentType = this->getType(type.componentType());
        SpvId v[4];
        for (int i = 0; i < 4; ++i) {
            v[i] = this->nextId(&type);
            this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i,
                                   out);
        }
        // Components 0,1 form the first column and components 2,3 the second.
        const Type& vecType = type.componentType().toCompound(fContext, /*columns=*/2, /*rows=*/1);
        SpvId v0v1 = this->writeComposite({v[0], v[1]}, vecType, out);
        SpvId v2v3 = this->writeComposite({v[2], v[3]}, vecType, out);
        return this->writeComposite({v0v1, v2v3}, type, out);
    }

    int rows = type.rows();
    const Type& columnType = type.componentType().toCompound(fContext,
                                                             /*columns=*/rows, /*rows=*/1);
    // SpvIds of completed columns of the matrix.
    std::vector<SpvId> columnIds;
    // SpvIds of scalars we have written to the current column so far.
    std::vector<SpvId> currentColumn;
    for (size_t i = 0; i < arguments.size(); i++) {
        const Type& argType = c.arguments()[i]->type();
        if (currentColumn.empty() && argType.isVector() && argType.columns() == rows) {
            // This vector is a complete matrix column by itself and can be used as-is.
            columnIds.push_back(arguments[i]);
        } else if (argType.columns() == 1) {
            // This argument is a lone scalar and can be added to the current column as-is.
            this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, arguments[i], out);
        } else {
            // This argument needs to be decomposed into its constituent scalars.
            SpvId componentType = this->getType(argType.componentType());
            for (int j = 0; j < argType.columns(); ++j) {
                SpvId swizzle = this->nextId(&argType);
                this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
                                       arguments[i], j, out);
                this->addColumnEntry(columnType, &currentColumn, &columnIds, rows, swizzle, out);
            }
        }
    }
    // Every argument scalar must have landed in a completed column.
    SkASSERT(columnIds.size() == (size_t) type.columns());
    return this->writeComposite(columnIds, type, out);
}
1656
1657SpvId SPIRVCodeGenerator::writeConstructorCompound(const ConstructorCompound& c,
1658                                                   OutputStream& out) {
1659    return c.type().isMatrix() ? this->writeMatrixConstructor(c, out)
1660                               : this->writeVectorConstructor(c, out);
1661}
1662
// Emits a vector from a ConstructorCompound. Compile-time-constant constructors are delegated to
// writeConstantVector. Otherwise every argument is flattened into scalars (2x2 matrix arguments
// in column-major order, vectors component by component), and the scalars are passed to a single
// OpCompositeConstruct via writeComposite.
SpvId SPIRVCodeGenerator::writeVectorConstructor(const ConstructorCompound& c, OutputStream& out) {
    const Type& type = c.type();
    const Type& componentType = type.componentType();
    SkASSERT(type.isVector());

    if (c.isCompileTimeConstant()) {
        return this->writeConstantVector(c);
    }

    // Scalar SpvIds that will become the constituents of the result vector.
    std::vector<SpvId> arguments;
    arguments.reserve(c.arguments().size());
    for (size_t i = 0; i < c.arguments().size(); i++) {
        const Type& argType = c.arguments()[i]->type();
        SkASSERT(componentType == argType.componentType());

        SpvId arg = this->writeExpression(*c.arguments()[i], out);
        if (argType.isMatrix()) {
            // CompositeConstruct cannot take a 2x2 matrix as an input, so we need to extract out
            // each scalar separately.
            SkASSERT(argType.rows() == 2);
            SkASSERT(argType.columns() == 2);
            for (int j = 0; j < 4; ++j) {
                // Indices (j / 2, j % 2) select (column, row), so scalars come out column-major.
                SpvId componentId = this->nextId(&componentType);
                this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType),
                                       componentId, arg, j / 2, j % 2, out);
                arguments.push_back(componentId);
            }
        } else if (argType.isVector()) {
            // There's a bug in the Intel Vulkan driver where OpCompositeConstruct doesn't handle
            // vector arguments at all, so we always extract each vector component and pass them
            // into OpCompositeConstruct individually.
            for (int j = 0; j < argType.columns(); j++) {
                SpvId componentId = this->nextId(&componentType);
                this->writeInstruction(SpvOpCompositeExtract, this->getType(componentType),
                                       componentId, arg, j, out);
                arguments.push_back(componentId);
            }
        } else {
            // Scalars are used directly.
            arguments.push_back(arg);
        }
    }

    return this->writeComposite(arguments, type, out);
}
1707
1708SpvId SPIRVCodeGenerator::writeComposite(const std::vector<SpvId>& arguments,
1709                                         const Type& type,
1710                                         OutputStream& out) {
1711    SkASSERT(arguments.size() == (type.isStruct() ? type.fields().size() : (size_t)type.columns()));
1712
1713    SpvId result = this->nextId(&type);
1714    this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1715    this->writeWord(this->getType(type), out);
1716    this->writeWord(result, out);
1717    for (SpvId id : arguments) {
1718        this->writeWord(id, out);
1719    }
1720    return result;
1721}
1722
1723SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
1724    // Use writeConstantVector to deduplicate constant splats.
1725    if (c.isCompileTimeConstant()) {
1726        return this->writeConstantVector(c);
1727    }
1728
1729    // Write the splat argument.
1730    SpvId argument = this->writeExpression(*c.argument(), out);
1731
1732    // Generate a OpCompositeConstruct which repeats the argument N times.
1733    std::vector<SpvId> arguments(/*count*/ c.type().columns(), /*value*/ argument);
1734    return this->writeComposite(arguments, c.type(), out);
1735}
1736
1737
1738SpvId SPIRVCodeGenerator::writeCompositeConstructor(const AnyConstructor& c, OutputStream& out) {
1739    SkASSERT(c.type().isArray() || c.type().isStruct());
1740    auto ctorArgs = c.argumentSpan();
1741
1742    std::vector<SpvId> arguments;
1743    arguments.reserve(ctorArgs.size());
1744    for (const std::unique_ptr<Expression>& arg : ctorArgs) {
1745        arguments.push_back(this->writeExpression(*arg, out));
1746    }
1747
1748    return this->writeComposite(arguments, c.type(), out);
1749}
1750
1751SpvId SPIRVCodeGenerator::writeConstructorScalarCast(const ConstructorScalarCast& c,
1752                                                     OutputStream& out) {
1753    const Type& type = c.type();
1754    if (this->getActualType(type) == this->getActualType(c.argument()->type())) {
1755        return this->writeExpression(*c.argument(), out);
1756    }
1757
1758    const Expression& ctorExpr = *c.argument();
1759    SpvId expressionId = this->writeExpression(ctorExpr, out);
1760    return this->castScalarToType(expressionId, ctorExpr.type(), type, out);
1761}
1762
1763SpvId SPIRVCodeGenerator::writeConstructorCompoundCast(const ConstructorCompoundCast& c,
1764                                                       OutputStream& out) {
1765    const Type& ctorType = c.type();
1766    const Type& argType = c.argument()->type();
1767    SkASSERT(ctorType.isVector() || ctorType.isMatrix());
1768
1769    // Write the composite that we are casting. If the actual type matches, we are done.
1770    SpvId compositeId = this->writeExpression(*c.argument(), out);
1771    if (this->getActualType(ctorType) == this->getActualType(argType)) {
1772        return compositeId;
1773    }
1774
1775    // writeMatrixCopy can cast matrices to a different type.
1776    if (ctorType.isMatrix()) {
1777        return this->writeMatrixCopy(compositeId, argType, ctorType, out);
1778    }
1779
1780    // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to extract the
1781    // components and convert each one manually.
1782    const Type& srcType = argType.componentType();
1783    const Type& dstType = ctorType.componentType();
1784
1785    std::vector<SpvId> arguments;
1786    arguments.reserve(argType.columns());
1787    for (int index = 0; index < argType.columns(); ++index) {
1788        SpvId componentId = this->nextId(&srcType);
1789        this->writeInstruction(SpvOpCompositeExtract, this->getType(srcType), componentId,
1790                               compositeId, index, out);
1791        arguments.push_back(this->castScalarToType(componentId, srcType, dstType, out));
1792    }
1793
1794    return this->writeComposite(arguments, ctorType, out);
1795}
1796
1797SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
1798                                                         OutputStream& out) {
1799    const Type& type = c.type();
1800    SkASSERT(type.isMatrix());
1801    SkASSERT(c.argument()->type().isScalar());
1802
1803    // Write out the scalar argument.
1804    SpvId argument = this->writeExpression(*c.argument(), out);
1805
1806    // Build the diagonal matrix.
1807    SpvId result = this->nextId(&type);
1808    this->writeUniformScaleMatrix(result, argument, type, out);
1809    return result;
1810}
1811
1812SpvId SPIRVCodeGenerator::writeConstructorMatrixResize(const ConstructorMatrixResize& c,
1813                                                       OutputStream& out) {
1814    // Write the input matrix.
1815    SpvId argument = this->writeExpression(*c.argument(), out);
1816
1817    // Use matrix-copy to resize the input matrix to its new size.
1818    return this->writeMatrixCopy(argument, c.argument()->type(), c.type(), out);
1819}
1820
1821static SpvStorageClass_ get_storage_class(const Variable& var,
1822                                          SpvStorageClass_ fallbackStorageClass) {
1823    const Modifiers& modifiers = var.modifiers();
1824    if (modifiers.fFlags & Modifiers::kIn_Flag) {
1825        SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1826        return SpvStorageClassInput;
1827    }
1828    if (modifiers.fFlags & Modifiers::kOut_Flag) {
1829        SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1830        return SpvStorageClassOutput;
1831    }
1832    if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1833        if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1834            return SpvStorageClassPushConstant;
1835        }
1836        if (var.type().typeKind() == Type::TypeKind::kSampler ||
1837            var.type().typeKind() == Type::TypeKind::kSeparateSampler ||
1838            var.type().typeKind() == Type::TypeKind::kTexture) {
1839            return SpvStorageClassUniformConstant;
1840        }
1841        return SpvStorageClassUniform;
1842    }
1843    return fallbackStorageClass;
1844}
1845
1846static SpvStorageClass_ get_storage_class(const Expression& expr) {
1847    switch (expr.kind()) {
1848        case Expression::Kind::kVariableReference: {
1849            const Variable& var = *expr.as<VariableReference>().variable();
1850            if (var.storage() != Variable::Storage::kGlobal) {
1851                return SpvStorageClassFunction;
1852            }
1853            return get_storage_class(var, SpvStorageClassPrivate);
1854        }
1855        case Expression::Kind::kFieldAccess:
1856            return get_storage_class(*expr.as<FieldAccess>().base());
1857        case Expression::Kind::kIndex:
1858            return get_storage_class(*expr.as<IndexExpression>().base());
1859        default:
1860            return SpvStorageClassFunction;
1861    }
1862}
1863
1864std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1865    std::vector<SpvId> chain;
1866    switch (expr.kind()) {
1867        case Expression::Kind::kIndex: {
1868            const IndexExpression& indexExpr = expr.as<IndexExpression>();
1869            chain = this->getAccessChain(*indexExpr.base(), out);
1870            chain.push_back(this->writeExpression(*indexExpr.index(), out));
1871            break;
1872        }
1873        case Expression::Kind::kFieldAccess: {
1874            const FieldAccess& fieldExpr = expr.as<FieldAccess>();
1875            chain = this->getAccessChain(*fieldExpr.base(), out);
1876            chain.push_back(this->writeLiteral(fieldExpr.fieldIndex(), *fContext.fTypes.fInt));
1877            break;
1878        }
1879        default: {
1880            SpvId id = this->getLValue(expr, out)->getPointer();
1881            SkASSERT(id != (SpvId) -1);
1882            chain.push_back(id);
1883            break;
1884        }
1885    }
1886    return chain;
1887}
1888
1889class PointerLValue : public SPIRVCodeGenerator::LValue {
1890public:
1891    PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, bool isMemoryObject, SpvId type,
1892                  SPIRVCodeGenerator::Precision precision)
1893    : fGen(gen)
1894    , fPointer(pointer)
1895    , fIsMemoryObject(isMemoryObject)
1896    , fType(type)
1897    , fPrecision(precision) {}
1898
1899    SpvId getPointer() override {
1900        return fPointer;
1901    }
1902
1903    bool isMemoryObjectPointer() const override {
1904        return fIsMemoryObject;
1905    }
1906
1907    SpvId load(OutputStream& out) override {
1908        SpvId result = fGen.nextId(fPrecision);
1909        fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1910        return result;
1911    }
1912
1913    void store(SpvId value, OutputStream& out) override {
1914        fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1915    }
1916
1917private:
1918    SPIRVCodeGenerator& fGen;
1919    const SpvId fPointer;
1920    const bool fIsMemoryObject;
1921    const SpvId fType;
1922    const SPIRVCodeGenerator::Precision fPrecision;
1923};
1924
1925class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1926public:
1927    SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const ComponentArray& components,
1928                  const Type& baseType, const Type& swizzleType)
1929    : fGen(gen)
1930    , fVecPointer(vecPointer)
1931    , fComponents(components)
1932    , fBaseType(&baseType)
1933    , fSwizzleType(&swizzleType) {}
1934
1935    bool applySwizzle(const ComponentArray& components, const Type& newType) override {
1936        ComponentArray updatedSwizzle;
1937        for (int8_t component : components) {
1938            if (component < 0 || component >= fComponents.count()) {
1939                SkDEBUGFAILF("swizzle accessed nonexistent component %d", (int)component);
1940                return false;
1941            }
1942            updatedSwizzle.push_back(fComponents[component]);
1943        }
1944        fComponents = updatedSwizzle;
1945        fSwizzleType = &newType;
1946        return true;
1947    }
1948
1949    SpvId load(OutputStream& out) override {
1950        SpvId base = fGen.nextId(fBaseType);
1951        fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
1952        SpvId result = fGen.nextId(fBaseType);
1953        fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1954        fGen.writeWord(fGen.getType(*fSwizzleType), out);
1955        fGen.writeWord(result, out);
1956        fGen.writeWord(base, out);
1957        fGen.writeWord(base, out);
1958        for (int component : fComponents) {
1959            fGen.writeWord(component, out);
1960        }
1961        return result;
1962    }
1963
1964    void store(SpvId value, OutputStream& out) override {
1965        // use OpVectorShuffle to mix and match the vector components. We effectively create
1966        // a virtual vector out of the concatenation of the left and right vectors, and then
1967        // select components from this virtual vector to make the result vector. For
1968        // instance, given:
1969        // float3L = ...;
1970        // float3R = ...;
1971        // L.xz = R.xy;
1972        // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1973        // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1974        // (3, 1, 4).
1975        SpvId base = fGen.nextId(fBaseType);
1976        fGen.writeInstruction(SpvOpLoad, fGen.getType(*fBaseType), base, fVecPointer, out);
1977        SpvId shuffle = fGen.nextId(fBaseType);
1978        fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType->columns(), out);
1979        fGen.writeWord(fGen.getType(*fBaseType), out);
1980        fGen.writeWord(shuffle, out);
1981        fGen.writeWord(base, out);
1982        fGen.writeWord(value, out);
1983        for (int i = 0; i < fBaseType->columns(); i++) {
1984            // current offset into the virtual vector, defaults to pulling the unmodified
1985            // value from the left side
1986            int offset = i;
1987            // check to see if we are writing this component
1988            for (size_t j = 0; j < fComponents.size(); j++) {
1989                if (fComponents[j] == i) {
1990                    // we're writing to this component, so adjust the offset to pull from
1991                    // the correct component of the right side instead of preserving the
1992                    // value from the left
1993                    offset = (int) (j + fBaseType->columns());
1994                    break;
1995                }
1996            }
1997            fGen.writeWord(offset, out);
1998        }
1999        fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
2000    }
2001
2002private:
2003    SPIRVCodeGenerator& fGen;
2004    const SpvId fVecPointer;
2005    ComponentArray fComponents;
2006    const Type* fBaseType;
2007    const Type* fSwizzleType;
2008};
2009
2010int SPIRVCodeGenerator::findUniformFieldIndex(const Variable& var) const {
2011    auto iter = fTopLevelUniformMap.find(&var);
2012    return (iter != fTopLevelUniformMap.end()) ? iter->second : -1;
2013}
2014
// Returns an LValue describing the storage that `expr` refers to. Any SPIR-V needed to compute
// a pointer to that storage is emitted into `out`.
//  - Variable references become an OpAccessChain into fUniformBufferId for variables listed in
//    fTopLevelUniformMap. Otherwise they use the pointer id recorded in fVariableMap.
//  - Index and field-access expressions become one OpAccessChain whose operands come from
//    getAccessChain().
//  - Swizzles first try to fold into the base lvalue via applySwizzle(). Failing that, a
//    one-component swizzle becomes an OpAccessChain to that component, and a multi-component
//    swizzle becomes a SwizzleLValue over the base pointer.
//  - Any other expression is stored into a new Function-storage OpVariable, which is returned.
// PointerLValue results load with Precision::kRelaxed unless the expression's type is
// high precision.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          OutputStream& out) {
    const Type& type = expr.type();
    Precision precision = type.highPrecision() ? Precision::kDefault : Precision::kRelaxed;
    switch (expr.kind()) {
        case Expression::Kind::kVariableReference: {
            const Variable& var = *expr.as<VariableReference>().variable();
            int uniformIdx = this->findUniformFieldIndex(var);
            if (uniformIdx >= 0) {
                // Top-level uniform: address its member of the uniform buffer by index.
                SpvId memberId = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, SpvStorageClassUniform);
                SpvId uniformIdxId = this->writeLiteral((double)uniformIdx, *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, memberId, fUniformBufferId,
                                       uniformIdxId, out);
                return std::make_unique<PointerLValue>(*this, memberId,
                                                       /*isMemoryObjectPointer=*/true,
                                                       this->getType(type), precision);
            }
            // Ordinary variable: reuse the pointer id registered in fVariableMap. The value type
            // is computed with the variable's memory layout.
            SpvId typeId = this->getType(type, this->memoryLayoutForVariable(var));
            auto entry = fVariableMap.find(&var);
            SkASSERTF(entry != fVariableMap.end(), "%s", expr.description().c_str());
            return std::make_unique<PointerLValue>(*this, entry->second,
                                                   /*isMemoryObjectPointer=*/true,
                                                   typeId, precision);
        }
        case Expression::Kind::kIndex: // fall through
        case Expression::Kind::kFieldAccess: {
            // Flatten nested index/field accesses into a single OpAccessChain:
            //   OpAccessChain <pointer type> <member> <root> <idx0> <idx1> ...
            std::vector<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId(nullptr);
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(type, get_storage_class(expr)), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::make_unique<PointerLValue>(*this, member, /*isMemoryObjectPointer=*/false,
                                                   this->getType(type), precision);
        }
        case Expression::Kind::kSwizzle: {
            const Swizzle& swizzle = expr.as<Swizzle>();
            std::unique_ptr<LValue> lvalue = this->getLValue(*swizzle.base(), out);
            // If the base lvalue can absorb this swizzle (SwizzleLValue can, by composing the two
            // component lists), reuse it directly.
            if (lvalue->applySwizzle(swizzle.components(), type)) {
                return lvalue;
            }
            SpvId base = lvalue->getPointer();
            if (base == (SpvId) -1) {
                // NOTE(review): an error is reported, but code generation continues below with the
                // invalid base id.
                fContext.fErrors->error(swizzle.fLine, "unable to retrieve lvalue from swizzle");
            }
            if (swizzle.components().size() == 1) {
                // A single component is addressable directly: access-chain to that lane.
                SpvId member = this->nextId(nullptr);
                SpvId typeId = this->getPointerType(type, get_storage_class(*swizzle.base()));
                SpvId indexId = this->writeLiteral(swizzle.components()[0], *fContext.fTypes.fInt);
                this->writeInstruction(SpvOpAccessChain, typeId, member, base, indexId, out);
                return std::make_unique<PointerLValue>(*this,
                                                       member,
                                                       /*isMemoryObjectPointer=*/false,
                                                       this->getType(type),
                                                       precision);
            } else {
                // Multiple components cannot be a single pointer; use a read-modify-write view
                // over the whole base vector.
                return std::make_unique<SwizzleLValue>(*this, base, swizzle.components(),
                                                       swizzle.base()->type(), type);
            }
        }
        default: {
            // expr isn't actually an lvalue, create a placeholder variable for it. This case
            // happens due to the need to store values in temporary variables during function
            // calls (see comments in getFunctionType); erroneous uses of rvalues as lvalues
            // should have been caught before code generation
            SpvId result = this->nextId(nullptr);
            SpvId pointerType = this->getPointerType(type, SpvStorageClassFunction);
            // The OpVariable declaration goes to fVariableBuffer; the initializing store goes to
            // `out`.
            this->writeInstruction(SpvOpVariable, pointerType, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
            return std::make_unique<PointerLValue>(*this, result, /*isMemoryObjectPointer=*/true,
                                                   this->getType(type), precision);
        }
    }
}
2093
// Emits the value of a variable reference and returns its id.
// Most variables are simply loaded through their lvalue. sk_FragCoord and sk_Clockwise are
// special-cased:
//  - References to them are rewritten in terms of the SKSL_RTFLIP_NAME uniform and a hidden
//    synthetic global (__device_FragCoords / __device_Clockwise). That global's builtin is one
//    of the negative DEVICE_*_BUILTIN sentinels.
//  - When one of those synthetic globals is itself referenced, it is mapped back onto the real
//    sk_FragCoord / sk_Clockwise global and loaded directly, without any flip.
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
    const Variable* variable = ref.variable();
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_FRAGCOORDS_BUILTIN) {
        // Down below, we rewrite raw references to sk_FragCoord with expressions that reference
        // DEVICE_FRAGCOORDS_BUILTIN. This is a fake variable that means we need to directly access
        // the fragcoord; do so now.
        dsl::DSLGlobalVar fragCoord("sk_FragCoord");
        return this->getLValue(*dsl::DSLExpression(fragCoord).release(), out)->load(out);
    }
    if (variable->modifiers().fLayout.fBuiltin == DEVICE_CLOCKWISE_BUILTIN) {
        // Down below, we rewrite raw references to sk_Clockwise with expressions that reference
        // DEVICE_CLOCKWISE_BUILTIN. This is a fake variable that means we need to directly
        // access front facing; do so now.
        dsl::DSLGlobalVar clockwise("sk_Clockwise");
        return this->getLValue(*dsl::DSLExpression(clockwise).release(), out)->load(out);
    }

    // Handle inserting use of uniform to flip y when referencing sk_FragCoord.
    if (variable->modifiers().fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
        // NOTE(review): presumably ensures the RT-flip uniform is declared; confirm in
        // addRTFlipUniform.
        this->addRTFlipUniform(ref.fLine);
        // Use the SKSL_RTFLIP_NAME uniform (not sk_RTAdjust) to compute the flipped coordinate:
        //   flipped.y = rtFlip.x + rtFlip.y * device.y
        using namespace dsl;
        const char* DEVICE_COORDS_NAME = "__device_FragCoords";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use a uniform to flip the Y coordinate. The new expression will be written in
        // terms of __device_FragCoords, which is a fake variable that means "access the
        // underlying fragcoords directly without flipping it".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                SKSL_RTFLIP_NAME));
        if (!symbols[DEVICE_COORDS_NAME]) {
            // Create the synthetic float4 global once, tagged with DEVICE_FRAGCOORDS_BUILTIN so
            // that referencing it hits the first branch of this function.
            // NOTE(review): the pool attachment presumably makes the Variable allocate from the
            // program's pool; confirm.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_FRAGCOORDS_BUILTIN;
            auto coordsVar = std::make_unique<Variable>(/*line=*/-1,
                                                        fContext.fModifiersPool->add(modifiers),
                                                        DEVICE_COORDS_NAME,
                                                        fContext.fTypes.fFloat4.get(),
                                                        true,
                                                        Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(coordsVar.get());
            symbols.add(std::move(coordsVar));
        }
        DSLGlobalVar deviceCoord(DEVICE_COORDS_NAME);
        std::unique_ptr<Expression> rtFlipSkSLExpr = rtFlip.release();
        DSLExpression x = DSLExpression(rtFlipSkSLExpr->clone()).x();
        DSLExpression y = DSLExpression(std::move(rtFlipSkSLExpr)).y();
        // float4(device.x, rtFlip.x + rtFlip.y * device.y, device.z, device.w)
        return this->writeExpression(*dsl::Float4(deviceCoord.x(),
                                                  std::move(x) + std::move(y) * deviceCoord.y(),
                                                  deviceCoord.z(),
                                                  deviceCoord.w()).release(),
                                     out);
    }

    // Handle flipping sk_Clockwise.
    if (variable->modifiers().fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN) {
        // NOTE(review): presumably ensures the RT-flip uniform is declared; confirm in
        // addRTFlipUniform.
        this->addRTFlipUniform(ref.fLine);
        using namespace dsl;
        const char* DEVICE_CLOCKWISE_NAME = "__device_Clockwise";
        SymbolTable& symbols = *ThreadContext::SymbolTable();
        // Use the RT-flip uniform's Y component to decide whether the facing must be inverted.
        // The new expression will be written in terms of __device_Clockwise, which is a fake
        // variable that means "access the underlying FrontFacing directly".
        DSLExpression rtFlip(ThreadContext::Compiler().convertIdentifier(/*line=*/-1,
                SKSL_RTFLIP_NAME));
        if (!symbols[DEVICE_CLOCKWISE_NAME]) {
            // Create the synthetic bool global once, tagged with DEVICE_CLOCKWISE_BUILTIN so
            // that referencing it hits the second branch of this function.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            Modifiers modifiers;
            modifiers.fLayout.fBuiltin = DEVICE_CLOCKWISE_BUILTIN;
            auto clockwiseVar = std::make_unique<Variable>(/*line=*/-1,
                                                           fContext.fModifiersPool->add(modifiers),
                                                           DEVICE_CLOCKWISE_NAME,
                                                           fContext.fTypes.fBool.get(),
                                                           true,
                                                           Variable::Storage::kGlobal);
            fSPIRVBonusVariables.insert(clockwiseVar.get());
            symbols.add(std::move(clockwiseVar));
        }
        DSLGlobalVar deviceClockwise(DEVICE_CLOCKWISE_NAME);
        // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia,
        // we use the default convention of "counter-clockwise face is front".
        // Result: rtFlip.y > 0 ? !device : device.
        return this->writeExpression(*dsl::Bool(Select(rtFlip.y() > 0,
                                                       !deviceClockwise,
                                                       deviceClockwise)).release(),
                                     out);
    }

    return this->getLValue(ref, out)->load(out);
}
2182
2183SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
2184    if (expr.base()->type().isVector()) {
2185        SpvId base = this->writeExpression(*expr.base(), out);
2186        SpvId index = this->writeExpression(*expr.index(), out);
2187        SpvId result = this->nextId(nullptr);
2188        this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.type()), result, base,
2189                               index, out);
2190        return result;
2191    }
2192    return getLValue(expr, out)->load(out);
2193}
2194
2195SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
2196    return getLValue(f, out)->load(out);
2197}
2198
2199SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
2200    SpvId base = this->writeExpression(*swizzle.base(), out);
2201    SpvId result = this->nextId(&swizzle.type());
2202    size_t count = swizzle.components().size();
2203    if (count == 1) {
2204        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.type()), result, base,
2205                               swizzle.components()[0], out);
2206    } else {
2207        this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
2208        this->writeWord(this->getType(swizzle.type()), out);
2209        this->writeWord(result, out);
2210        this->writeWord(base, out);
2211        this->writeWord(base, out);
2212        for (int component : swizzle.components()) {
2213            this->writeWord(component, out);
2214        }
2215    }
2216    return result;
2217}
2218
2219SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
2220                                               const Type& operandType, SpvId lhs,
2221                                               SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
2222                                               SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
2223    SpvId result = this->nextId(&resultType);
2224    if (is_float(fContext, operandType)) {
2225        this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
2226    } else if (is_signed(fContext, operandType)) {
2227        this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
2228    } else if (is_unsigned(fContext, operandType)) {
2229        this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
2230    } else if (is_bool(fContext, operandType)) {
2231        this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
2232    } else {
2233        fContext.fErrors->error(operandType.fLine,
2234                "unsupported operand for binary expression: " + operandType.description());
2235    }
2236    return result;
2237}
2238
2239SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
2240                                     OutputStream& out) {
2241    if (operandType.isVector()) {
2242        SpvId result = this->nextId(nullptr);
2243        this->writeInstruction(op, this->getType(*fContext.fTypes.fBool), result, id, out);
2244        return result;
2245    }
2246    return id;
2247}
2248
2249SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
2250                                                SpvOp_ floatOperator, SpvOp_ intOperator,
2251                                                SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
2252                                                OutputStream& out) {
2253    SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
2254    SkASSERT(operandType.isMatrix());
2255    SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2256                                                                            operandType.rows(),
2257                                                                            1));
2258    SpvId bvecType = this->getType(fContext.fTypes.fBool->toCompound(fContext,
2259                                                                    operandType.rows(),
2260                                                                    1));
2261    SpvId boolType = this->getType(*fContext.fTypes.fBool);
2262    SpvId result = 0;
2263    for (int i = 0; i < operandType.columns(); i++) {
2264        SpvId columnL = this->nextId(&operandType);
2265        this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2266        SpvId columnR = this->nextId(&operandType);
2267        this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2268        SpvId compare = this->nextId(&operandType);
2269        this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
2270        SpvId merge = this->nextId(nullptr);
2271        this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
2272        if (result != 0) {
2273            SpvId next = this->nextId(nullptr);
2274            this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
2275            result = next;
2276        }
2277        else {
2278            result = merge;
2279        }
2280    }
2281    return result;
2282}
2283
2284SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
2285                                                         SpvId rhs, SpvOp_ op, OutputStream& out) {
2286    SkASSERT(operandType.isMatrix());
2287    SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
2288                                                                            operandType.rows(),
2289                                                                            1));
2290    std::vector<SpvId> columns;
2291    columns.reserve(operandType.columns());
2292    for (int i = 0; i < operandType.columns(); i++) {
2293        SpvId columnL = this->nextId(&operandType);
2294        this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
2295        SpvId columnR = this->nextId(&operandType);
2296        this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
2297        columns.push_back(this->nextId(&operandType));
2298        this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
2299    }
2300    return this->writeComposite(columns, operandType, out);
2301}
2302
2303SpvId SPIRVCodeGenerator::writeReciprocal(const Type& type, SpvId value, OutputStream& out) {
2304    SkASSERT(type.isFloat());
2305    SpvId one = this->writeLiteral(1.0, type);
2306    SpvId reciprocal = this->nextId(&type);
2307    this->writeInstruction(SpvOpFDiv, this->getType(type), reciprocal, one, value, out);
2308    return reciprocal;
2309}
2310
2311SpvId SPIRVCodeGenerator::writeScalarToMatrixSplat(const Type& matrixType,
2312                                                   SpvId scalarId,
2313                                                   OutputStream& out) {
2314    // Splat the scalar into a vector.
2315    const Type& vectorType = matrixType.componentType().toCompound(fContext,
2316                                                                   /*columns=*/matrixType.rows(),
2317                                                                   /*rows=*/1);
2318    std::vector<SpvId> vecArguments(/*count*/ matrixType.rows(), /*value*/ scalarId);
2319    SpvId vectorId = this->writeComposite(vecArguments, vectorType, out);
2320
2321    // Splat the vector into a matrix.
2322    std::vector<SpvId> matArguments(/*count*/ matrixType.columns(), /*value*/ vectorId);
2323    return this->writeComposite(matArguments, matrixType, out);
2324}
2325
2326SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Operator op,
2327                                                const Type& rightType, SpvId rhs,
2328                                                const Type& resultType, OutputStream& out) {
2329    // The comma operator ignores the type of the left-hand side entirely.
2330    if (op.kind() == Token::Kind::TK_COMMA) {
2331        return rhs;
2332    }
2333    // overall type we are operating on: float2, int, uint4...
2334    const Type* operandType;
2335    // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
2336    // handling in SPIR-V
2337    if (this->getActualType(leftType) != this->getActualType(rightType)) {
2338        if (leftType.isVector() && rightType.isNumber()) {
2339            if (resultType.componentType().isFloat()) {
2340                switch (op.kind()) {
2341                    case Token::Kind::TK_SLASH: {
2342                        rhs = this->writeReciprocal(rightType, rhs, out);
2343                        [[fallthrough]];
2344                    }
2345                    case Token::Kind::TK_STAR: {
2346                        SpvId result = this->nextId(&resultType);
2347                        this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2348                                               result, lhs, rhs, out);
2349                        return result;
2350                    }
2351                    default:
2352                        break;
2353                }
2354            }
2355            // promote number to vector
2356            const Type& vecType = leftType;
2357            SpvId vec = this->nextId(&vecType);
2358            this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2359            this->writeWord(this->getType(vecType), out);
2360            this->writeWord(vec, out);
2361            for (int i = 0; i < vecType.columns(); i++) {
2362                this->writeWord(rhs, out);
2363            }
2364            rhs = vec;
2365            operandType = &leftType;
2366        } else if (rightType.isVector() && leftType.isNumber()) {
2367            if (resultType.componentType().isFloat()) {
2368                if (op.kind() == Token::Kind::TK_STAR) {
2369                    SpvId result = this->nextId(&resultType);
2370                    this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
2371                                           result, rhs, lhs, out);
2372                    return result;
2373                }
2374            }
2375            // promote number to vector
2376            const Type& vecType = rightType;
2377            SpvId vec = this->nextId(&vecType);
2378            this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
2379            this->writeWord(this->getType(vecType), out);
2380            this->writeWord(vec, out);
2381            for (int i = 0; i < vecType.columns(); i++) {
2382                this->writeWord(lhs, out);
2383            }
2384            lhs = vec;
2385            operandType = &rightType;
2386        } else if (leftType.isMatrix()) {
2387            if (op.kind() == Token::Kind::TK_STAR) {
2388                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2389                SpvOp_ spvop;
2390                if (rightType.isMatrix()) {
2391                    spvop = SpvOpMatrixTimesMatrix;
2392                } else if (rightType.isVector()) {
2393                    spvop = SpvOpMatrixTimesVector;
2394                } else {
2395                    SkASSERT(rightType.isScalar());
2396                    spvop = SpvOpMatrixTimesScalar;
2397                }
2398                SpvId result = this->nextId(&resultType);
2399                this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
2400                return result;
2401            } else {
2402                // Matrix-op-vector is not supported in GLSL/SkSL for non-multiplication ops; we
2403                // expect to have a scalar here.
2404                SkASSERT(rightType.isScalar());
2405
2406                // Splat rhs across an entire matrix so we can reuse the matrix-op-matrix path.
2407                SpvId rhsMatrix = this->writeScalarToMatrixSplat(leftType, rhs, out);
2408
2409                // Perform this operation as matrix-op-matrix.
2410                return this->writeBinaryExpression(leftType, lhs, op, leftType, rhsMatrix,
2411                                                   resultType, out);
2412            }
2413        } else if (rightType.isMatrix()) {
2414            if (op.kind() == Token::Kind::TK_STAR) {
2415                // Matrix-times-vector and matrix-times-scalar have dedicated ops in SPIR-V.
2416                SpvId result = this->nextId(&resultType);
2417                if (leftType.isVector()) {
2418                    this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType),
2419                                           result, lhs, rhs, out);
2420                } else {
2421                    SkASSERT(leftType.isScalar());
2422                    this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType),
2423                                           result, rhs, lhs, out);
2424                }
2425                return result;
2426            } else {
2427                // Vector-op-matrix is not supported in GLSL/SkSL for non-multiplication ops; we
2428                // expect to have a scalar here.
2429                SkASSERT(leftType.isScalar());
2430
2431                // Splat lhs across an entire matrix so we can reuse the matrix-op-matrix path.
2432                SpvId lhsMatrix = this->writeScalarToMatrixSplat(rightType, lhs, out);
2433
2434                // Perform this operation as matrix-op-matrix.
2435                return this->writeBinaryExpression(rightType, lhsMatrix, op, rightType, rhs,
2436                                                   resultType, out);
2437            }
2438        } else {
2439            fContext.fErrors->error(leftType.fLine, "unsupported mixed-type expression");
2440            return -1;
2441        }
2442    } else {
2443        operandType = &this->getActualType(leftType);
2444        SkASSERT(*operandType == this->getActualType(rightType));
2445    }
2446    switch (op.kind()) {
2447        case Token::Kind::TK_EQEQ: {
2448            if (operandType->isMatrix()) {
2449                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
2450                                                   SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
2451            }
2452            if (operandType->isStruct()) {
2453                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2454            }
2455            if (operandType->isArray()) {
2456                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2457            }
2458            SkASSERT(resultType.isBoolean());
2459            const Type* tmpType;
2460            if (operandType->isVector()) {
2461                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2462                                                             operandType->columns(),
2463                                                             operandType->rows());
2464            } else {
2465                tmpType = &resultType;
2466            }
2467            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2468                                                               SpvOpFOrdEqual, SpvOpIEqual,
2469                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
2470                                    *operandType, SpvOpAll, out);
2471        }
2472        case Token::Kind::TK_NEQ:
2473            if (operandType->isMatrix()) {
2474                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
2475                                                   SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
2476            }
2477            if (operandType->isStruct()) {
2478                return this->writeStructComparison(*operandType, lhs, op, rhs, out);
2479            }
2480            if (operandType->isArray()) {
2481                return this->writeArrayComparison(*operandType, lhs, op, rhs, out);
2482            }
2483            [[fallthrough]];
2484        case Token::Kind::TK_LOGICALXOR:
2485            SkASSERT(resultType.isBoolean());
2486            const Type* tmpType;
2487            if (operandType->isVector()) {
2488                tmpType = &fContext.fTypes.fBool->toCompound(fContext,
2489                                                             operandType->columns(),
2490                                                             operandType->rows());
2491            } else {
2492                tmpType = &resultType;
2493            }
2494            return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
2495                                                               SpvOpFOrdNotEqual, SpvOpINotEqual,
2496                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
2497                                                               out),
2498                                    *operandType, SpvOpAny, out);
2499        case Token::Kind::TK_GT:
2500            SkASSERT(resultType.isBoolean());
2501            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2502                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2503                                              SpvOpUGreaterThan, SpvOpUndef, out);
2504        case Token::Kind::TK_LT:
2505            SkASSERT(resultType.isBoolean());
2506            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2507                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2508        case Token::Kind::TK_GTEQ:
2509            SkASSERT(resultType.isBoolean());
2510            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2511                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2512                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
2513        case Token::Kind::TK_LTEQ:
2514            SkASSERT(resultType.isBoolean());
2515            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2516                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2517                                              SpvOpULessThanEqual, SpvOpUndef, out);
2518        case Token::Kind::TK_PLUS:
2519            if (leftType.isMatrix() && rightType.isMatrix()) {
2520                SkASSERT(leftType == rightType);
2521                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFAdd, out);
2522            }
2523            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2524                                              SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2525        case Token::Kind::TK_MINUS:
2526            if (leftType.isMatrix() && rightType.isMatrix()) {
2527                SkASSERT(leftType == rightType);
2528                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFSub, out);
2529            }
2530            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2531                                              SpvOpISub, SpvOpISub, SpvOpUndef, out);
2532        case Token::Kind::TK_STAR:
2533            if (leftType.isMatrix() && rightType.isMatrix()) {
2534                // matrix multiply
2535                SpvId result = this->nextId(&resultType);
2536                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2537                                       lhs, rhs, out);
2538                return result;
2539            }
2540            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2541                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2542        case Token::Kind::TK_SLASH:
2543            if (leftType.isMatrix() && rightType.isMatrix()) {
2544                SkASSERT(leftType == rightType);
2545                return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs, SpvOpFDiv, out);
2546            }
2547            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2548                                              SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2549        case Token::Kind::TK_PERCENT:
2550            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2551                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2552        case Token::Kind::TK_SHL:
2553            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2554                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2555                                              SpvOpUndef, out);
2556        case Token::Kind::TK_SHR:
2557            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2558                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2559                                              SpvOpUndef, out);
2560        case Token::Kind::TK_BITWISEAND:
2561            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2562                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2563        case Token::Kind::TK_BITWISEOR:
2564            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2565                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2566        case Token::Kind::TK_BITWISEXOR:
2567            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2568                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2569        default:
2570            fContext.fErrors->error(0, "unsupported token");
2571            return -1;
2572    }
2573}
2574
2575SpvId SPIRVCodeGenerator::writeArrayComparison(const Type& arrayType, SpvId lhs, Operator op,
2576                                               SpvId rhs, OutputStream& out) {
2577    // The inputs must be arrays, and the op must be == or !=.
2578    SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2579    SkASSERT(arrayType.isArray());
2580    const Type& componentType = arrayType.componentType();
2581    const SpvId componentTypeId = this->getType(componentType);
2582    const int arraySize = arrayType.columns();
2583    SkASSERT(arraySize > 0);
2584
2585    // Synthesize equality checks for each item in the array.
2586    const Type& boolType = *fContext.fTypes.fBool;
2587    SpvId allComparisons = (SpvId)-1;
2588    for (int index = 0; index < arraySize; ++index) {
2589        // Get the left and right item in the array.
2590        SpvId itemL = this->nextId(&componentType);
2591        this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemL, lhs, index, out);
2592        SpvId itemR = this->nextId(&componentType);
2593        this->writeInstruction(SpvOpCompositeExtract, componentTypeId, itemR, rhs, index, out);
2594        // Use `writeBinaryExpression` with the requested == or != operator on these items.
2595        SpvId comparison = this->writeBinaryExpression(componentType, itemL, op,
2596                                                       componentType, itemR, boolType, out);
2597        // Merge this comparison result with all the other comparisons we've done.
2598        allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2599    }
2600    return allComparisons;
2601}
2602
2603SpvId SPIRVCodeGenerator::writeStructComparison(const Type& structType, SpvId lhs, Operator op,
2604                                                SpvId rhs, OutputStream& out) {
2605    // The inputs must be structs containing fields, and the op must be == or !=.
2606    SkASSERT(op.kind() == Token::Kind::TK_EQEQ || op.kind() == Token::Kind::TK_NEQ);
2607    SkASSERT(structType.isStruct());
2608    const std::vector<Type::Field>& fields = structType.fields();
2609    SkASSERT(!fields.empty());
2610
2611    // Synthesize equality checks for each field in the struct.
2612    const Type& boolType = *fContext.fTypes.fBool;
2613    SpvId allComparisons = (SpvId)-1;
2614    for (int index = 0; index < (int)fields.size(); ++index) {
2615        // Get the left and right versions of this field.
2616        const Type& fieldType = *fields[index].fType;
2617        const SpvId fieldTypeId = this->getType(fieldType);
2618
2619        SpvId fieldL = this->nextId(&fieldType);
2620        this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldL, lhs, index, out);
2621        SpvId fieldR = this->nextId(&fieldType);
2622        this->writeInstruction(SpvOpCompositeExtract, fieldTypeId, fieldR, rhs, index, out);
2623        // Use `writeBinaryExpression` with the requested == or != operator on these fields.
2624        SpvId comparison = this->writeBinaryExpression(fieldType, fieldL, op, fieldType, fieldR,
2625                                                       boolType, out);
2626        // Merge this comparison result with all the other comparisons we've done.
2627        allComparisons = this->mergeComparisons(comparison, allComparisons, op, out);
2628    }
2629    return allComparisons;
2630}
2631
2632SpvId SPIRVCodeGenerator::mergeComparisons(SpvId comparison, SpvId allComparisons, Operator op,
2633                                           OutputStream& out) {
2634    // If this is the first entry, we don't need to merge comparison results with anything.
2635    if (allComparisons == (SpvId)-1) {
2636        return comparison;
2637    }
2638    // Use LogicalAnd or LogicalOr to combine the comparison with all the other comparisons.
2639    const Type& boolType = *fContext.fTypes.fBool;
2640    SpvId boolTypeId = this->getType(boolType);
2641    SpvId logicalOp = this->nextId(&boolType);
2642    switch (op.kind()) {
2643        case Token::Kind::TK_EQEQ:
2644            this->writeInstruction(SpvOpLogicalAnd, boolTypeId, logicalOp,
2645                                   comparison, allComparisons, out);
2646            break;
2647        case Token::Kind::TK_NEQ:
2648            this->writeInstruction(SpvOpLogicalOr, boolTypeId, logicalOp,
2649                                   comparison, allComparisons, out);
2650            break;
2651        default:
2652            SkDEBUGFAILF("mergeComparisons only supports == and !=, not %s", op.operatorName());
2653            return (SpvId)-1;
2654    }
2655    return logicalOp;
2656}
2657
2658static float division_by_literal_value(Operator op, const Expression& right) {
2659    // If this is a division by a literal value, returns that literal value. Otherwise, returns 0.
2660    if (op.kind() == Token::Kind::TK_SLASH && right.isFloatLiteral()) {
2661        float rhsValue = right.as<Literal>().floatValue();
2662        if (std::isfinite(rhsValue)) {
2663            return rhsValue;
2664        }
2665    }
2666    return 0.0f;
2667}
2668
2669SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
2670    const Expression* left = b.left().get();
2671    const Expression* right = b.right().get();
2672    Operator op = b.getOperator();
2673
2674    switch (op.kind()) {
2675        case Token::Kind::TK_EQ: {
2676            // Handles assignment.
2677            SpvId rhs = this->writeExpression(*right, out);
2678            this->getLValue(*left, out)->store(rhs, out);
2679            return rhs;
2680        }
2681        case Token::Kind::TK_LOGICALAND:
2682            // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2683            return this->writeLogicalAnd(*b.left(), *b.right(), out);
2684
2685        case Token::Kind::TK_LOGICALOR:
2686            // Handles short-circuiting; we don't necessarily evaluate both LHS and RHS.
2687            return this->writeLogicalOr(*b.left(), *b.right(), out);
2688
2689        default:
2690            break;
2691    }
2692
2693    std::unique_ptr<LValue> lvalue;
2694    SpvId lhs;
2695    if (op.isAssignment()) {
2696        lvalue = this->getLValue(*left, out);
2697        lhs = lvalue->load(out);
2698    } else {
2699        lvalue = nullptr;
2700        lhs = this->writeExpression(*left, out);
2701    }
2702
2703    SpvId rhs;
2704    float rhsValue = division_by_literal_value(op, *right);
2705    if (rhsValue != 0.0f) {
2706        // Rewrite floating-point division by a literal into multiplication by the reciprocal.
2707        // This converts `expr / 2` into `expr * 0.5`
2708        // This improves codegen, especially for certain types of divides (e.g. vector/scalar).
2709        op = Operator(Token::Kind::TK_STAR);
2710        rhs = this->writeLiteral(1.0 / rhsValue, right->type());
2711    } else {
2712        // Write the right-hand side expression normally.
2713        rhs = this->writeExpression(*right, out);
2714    }
2715
2716    SpvId result = this->writeBinaryExpression(left->type(), lhs, op.removeAssignment(),
2717                                               right->type(), rhs, b.type(), out);
2718    if (lvalue) {
2719        lvalue->store(result, out);
2720    }
2721    return result;
2722}
2723
2724SpvId SPIRVCodeGenerator::writeLogicalAnd(const Expression& left, const Expression& right,
2725                                          OutputStream& out) {
2726    SpvId falseConstant = this->writeLiteral(0.0, *fContext.fTypes.fBool);
2727    SpvId lhs = this->writeExpression(left, out);
2728    SpvId rhsLabel = this->nextId(nullptr);
2729    SpvId end = this->nextId(nullptr);
2730    SpvId lhsBlock = fCurrentBlock;
2731    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2732    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2733    this->writeLabel(rhsLabel, out);
2734    SpvId rhs = this->writeExpression(right, out);
2735    SpvId rhsBlock = fCurrentBlock;
2736    this->writeInstruction(SpvOpBranch, end, out);
2737    this->writeLabel(end, out);
2738    SpvId result = this->nextId(nullptr);
2739    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, falseConstant,
2740                           lhsBlock, rhs, rhsBlock, out);
2741    return result;
2742}
2743
2744SpvId SPIRVCodeGenerator::writeLogicalOr(const Expression& left, const Expression& right,
2745                                         OutputStream& out) {
2746    SpvId trueConstant = this->writeLiteral(1.0, *fContext.fTypes.fBool);
2747    SpvId lhs = this->writeExpression(left, out);
2748    SpvId rhsLabel = this->nextId(nullptr);
2749    SpvId end = this->nextId(nullptr);
2750    SpvId lhsBlock = fCurrentBlock;
2751    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2752    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2753    this->writeLabel(rhsLabel, out);
2754    SpvId rhs = this->writeExpression(right, out);
2755    SpvId rhsBlock = fCurrentBlock;
2756    this->writeInstruction(SpvOpBranch, end, out);
2757    this->writeLabel(end, out);
2758    SpvId result = this->nextId(nullptr);
2759    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fTypes.fBool), result, trueConstant,
2760                           lhsBlock, rhs, rhsBlock, out);
2761    return result;
2762}
2763
2764SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2765    const Type& type = t.type();
2766    SpvId test = this->writeExpression(*t.test(), out);
2767    if (t.ifTrue()->type().columns() == 1 &&
2768        t.ifTrue()->isCompileTimeConstant() &&
2769        t.ifFalse()->isCompileTimeConstant()) {
2770        // both true and false are constants, can just use OpSelect
2771        SpvId result = this->nextId(nullptr);
2772        SpvId trueId = this->writeExpression(*t.ifTrue(), out);
2773        SpvId falseId = this->writeExpression(*t.ifFalse(), out);
2774        this->writeInstruction(SpvOpSelect, this->getType(type), result, test, trueId, falseId,
2775                               out);
2776        return result;
2777    }
2778    // was originally using OpPhi to choose the result, but for some reason that is crashing on
2779    // Adreno. Switched to storing the result in a temp variable as glslang does.
2780    SpvId var = this->nextId(nullptr);
2781    this->writeInstruction(SpvOpVariable, this->getPointerType(type, SpvStorageClassFunction),
2782                           var, SpvStorageClassFunction, fVariableBuffer);
2783    SpvId trueLabel = this->nextId(nullptr);
2784    SpvId falseLabel = this->nextId(nullptr);
2785    SpvId end = this->nextId(nullptr);
2786    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2787    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2788    this->writeLabel(trueLabel, out);
2789    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifTrue(), out), out);
2790    this->writeInstruction(SpvOpBranch, end, out);
2791    this->writeLabel(falseLabel, out);
2792    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.ifFalse(), out), out);
2793    this->writeInstruction(SpvOpBranch, end, out);
2794    this->writeLabel(end, out);
2795    SpvId result = this->nextId(&type);
2796    this->writeInstruction(SpvOpLoad, this->getType(type), result, var, out);
2797    return result;
2798}
2799
2800SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2801    const Type& type = p.type();
2802    if (p.getOperator().kind() == Token::Kind::TK_MINUS) {
2803        SpvId result = this->nextId(&type);
2804        SpvId typeId = this->getType(type);
2805        SpvId expr = this->writeExpression(*p.operand(), out);
2806        if (is_float(fContext, type)) {
2807            this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2808        } else if (is_signed(fContext, type) || is_unsigned(fContext, type)) {
2809            this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2810        } else {
2811            SkDEBUGFAILF("unsupported prefix expression %s", p.description().c_str());
2812        }
2813        return result;
2814    }
2815    switch (p.getOperator().kind()) {
2816        case Token::Kind::TK_PLUS:
2817            return this->writeExpression(*p.operand(), out);
2818        case Token::Kind::TK_PLUSPLUS: {
2819            std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2820            SpvId one = this->writeLiteral(1.0, type);
2821            SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one,
2822                                                      SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2823                                                      out);
2824            lv->store(result, out);
2825            return result;
2826        }
2827        case Token::Kind::TK_MINUSMINUS: {
2828            std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2829            SpvId one = this->writeLiteral(1.0, type);
2830            SpvId result = this->writeBinaryOperation(type, type, lv->load(out), one, SpvOpFSub,
2831                                                      SpvOpISub, SpvOpISub, SpvOpUndef, out);
2832            lv->store(result, out);
2833            return result;
2834        }
2835        case Token::Kind::TK_LOGICALNOT: {
2836            SkASSERT(p.operand()->type().isBoolean());
2837            SpvId result = this->nextId(nullptr);
2838            this->writeInstruction(SpvOpLogicalNot, this->getType(type), result,
2839                                   this->writeExpression(*p.operand(), out), out);
2840            return result;
2841        }
2842        case Token::Kind::TK_BITWISENOT: {
2843            SpvId result = this->nextId(nullptr);
2844            this->writeInstruction(SpvOpNot, this->getType(type), result,
2845                                   this->writeExpression(*p.operand(), out), out);
2846            return result;
2847        }
2848        default:
2849            SkDEBUGFAILF("unsupported prefix expression: %s", p.description().c_str());
2850            return -1;
2851    }
2852}
2853
2854SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2855    const Type& type = p.type();
2856    std::unique_ptr<LValue> lv = this->getLValue(*p.operand(), out);
2857    SpvId result = lv->load(out);
2858    SpvId one = this->writeLiteral(1.0, type);
2859    switch (p.getOperator().kind()) {
2860        case Token::Kind::TK_PLUSPLUS: {
2861            SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFAdd,
2862                                                    SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2863            lv->store(temp, out);
2864            return result;
2865        }
2866        case Token::Kind::TK_MINUSMINUS: {
2867            SpvId temp = this->writeBinaryOperation(type, type, result, one, SpvOpFSub,
2868                                                    SpvOpISub, SpvOpISub, SpvOpUndef, out);
2869            lv->store(temp, out);
2870            return result;
2871        }
2872        default:
2873            SkDEBUGFAILF("unsupported postfix expression %s", p.description().c_str());
2874            return -1;
2875    }
2876}
2877
2878SpvId SPIRVCodeGenerator::writeLiteral(const Literal& l) {
2879    return this->writeLiteral(l.value(), l.type());
2880}
2881
2882SpvId SPIRVCodeGenerator::writeLiteral(double value, const Type& type) {
2883    int32_t valueBits;
2884    if (type.isFloat()) {
2885        float fValue = value;
2886        memcpy(&valueBits, &fValue, sizeof(valueBits));
2887    } else {
2888        SKSL_INT iValue = value;
2889        valueBits = iValue;
2890    }
2891
2892    SPIRVNumberConstant key{valueBits, type.numberKind()};
2893    auto [iter, newlyCreated] = fNumberConstants.insert({key, (SpvId)-1});
2894    if (newlyCreated) {
2895        SpvId result = this->nextId(nullptr);
2896        iter->second = result;
2897
2898        if (type.isBoolean()) {
2899            this->writeInstruction(valueBits ? SpvOpConstantTrue : SpvOpConstantFalse,
2900                                   this->getType(type), result, fConstantBuffer);
2901        } else {
2902            this->writeInstruction(SpvOpConstant, this->getType(type), result,
2903                                   (SpvId)valueBits, fConstantBuffer);
2904        }
2905    }
2906
2907    return iter->second;
2908}
2909
2910SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2911    SpvId result = fFunctionMap[&f];
2912    SpvId returnTypeId = this->getType(f.returnType());
2913    SpvId functionTypeId = this->getFunctionType(f);
2914    this->writeInstruction(SpvOpFunction, returnTypeId, result,
2915                           SpvFunctionControlMaskNone, functionTypeId, out);
2916    String mangledName = f.mangledName();
2917    this->writeInstruction(SpvOpName,
2918                           result,
2919                           skstd::string_view(mangledName.c_str(), mangledName.size()),
2920                           fNameBuffer);
2921    for (const Variable* parameter : f.parameters()) {
2922        SpvId id = this->nextId(nullptr);
2923        fVariableMap[parameter] = id;
2924        SpvId type = this->getPointerType(parameter->type(), SpvStorageClassFunction);
2925        this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2926    }
2927    return result;
2928}
2929
2930SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2931    fVariableBuffer.reset();
2932    SpvId result = this->writeFunctionStart(f.declaration(), out);
2933    fCurrentBlock = 0;
2934    this->writeLabel(this->nextId(nullptr), out);
2935    StringStream bodyBuffer;
2936    this->writeBlock(f.body()->as<Block>(), bodyBuffer);
2937    write_stringstream(fVariableBuffer, out);
2938    if (f.declaration().isMain()) {
2939        write_stringstream(fGlobalInitializersBuffer, out);
2940    }
2941    write_stringstream(bodyBuffer, out);
2942    if (fCurrentBlock) {
2943        if (f.declaration().returnType().isVoid()) {
2944            this->writeInstruction(SpvOpReturn, out);
2945        } else {
2946            this->writeInstruction(SpvOpUnreachable, out);
2947        }
2948    }
2949    this->writeInstruction(SpvOpFunctionEnd, out);
2950    return result;
2951}
2952
2953void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2954    if (layout.fLocation >= 0) {
2955        this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2956                               fDecorationBuffer);
2957    }
2958    if (layout.fBinding >= 0) {
2959        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2960                               fDecorationBuffer);
2961    }
2962    if (layout.fIndex >= 0) {
2963        this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2964                               fDecorationBuffer);
2965    }
2966    if (layout.fSet >= 0) {
2967        this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2968                               fDecorationBuffer);
2969    }
2970    if (layout.fInputAttachmentIndex >= 0) {
2971        this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2972                               layout.fInputAttachmentIndex, fDecorationBuffer);
2973        fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2974    }
2975    if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2976        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2977                               fDecorationBuffer);
2978    }
2979}
2980
2981void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2982    if (layout.fLocation >= 0) {
2983        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2984                               layout.fLocation, fDecorationBuffer);
2985    }
2986    if (layout.fBinding >= 0) {
2987        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2988                               layout.fBinding, fDecorationBuffer);
2989    }
2990    if (layout.fIndex >= 0) {
2991        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2992                               layout.fIndex, fDecorationBuffer);
2993    }
2994    if (layout.fSet >= 0) {
2995        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2996                               layout.fSet, fDecorationBuffer);
2997    }
2998    if (layout.fInputAttachmentIndex >= 0) {
2999        this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
3000                               layout.fInputAttachmentIndex, fDecorationBuffer);
3001    }
3002    if (layout.fBuiltin >= 0) {
3003        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
3004                               layout.fBuiltin, fDecorationBuffer);
3005    }
3006}
3007
3008MemoryLayout SPIRVCodeGenerator::memoryLayoutForVariable(const Variable& v) const {
3009    bool pushConstant = ((v.modifiers().fLayout.fFlags & Layout::kPushConstant_Flag) != 0);
3010    return pushConstant ? MemoryLayout(MemoryLayout::k430_Standard) : fDefaultLayout;
3011}
3012
SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf, bool appendRTFlip) {
    // Emits an interface block as a global OpVariable:
    //  - its struct type, decorated with Block unless the block is a builtin;
    //  - a pointer type in the block's storage class;
    //  - the variable itself and its layout decorations.
    // The variable's id is recorded in fVariableMap and returned.
    //
    // If the program uses the RTFlip uniform and `appendRTFlip` is true, a copy of the block is
    // written instead (via a recursive call with appendRTFlip=false). The copy has an extra float2
    // member named SKSL_RTFLIP_NAME, and the original variable is mapped to the copy's id.
    MemoryLayout memoryLayout = this->memoryLayoutForVariable(intf.variable());
    // NOTE(review): on the RTFlip path below this id is never used; `result` is overwritten by
    // the recursive call's return value.
    SpvId result = this->nextId(nullptr);
    const Variable& intfVar = intf.variable();
    const Type& type = intfVar.type();
    if (!MemoryLayout::LayoutIsSupported(type)) {
        fContext.fErrors->error(type.fLine, "type '" + type.name() + "' is not permitted here");
        // The error has been reported; return a fresh id so the caller still receives a value.
        return this->nextId(nullptr);
    }
    SpvStorageClass_ storageClass = get_storage_class(intf.variable(), SpvStorageClassFunction);
    if (fProgram.fInputs.fUseFlipRTUniform && appendRTFlip && type.isStruct()) {
        // We can only have one interface block (because we use push_constant and that is limited
        // to one per program), so we need to append rtflip to this one rather than synthesize an
        // entirely new block when the variable is referenced. And we can't modify the existing
        // block, so we instead create a modified copy of it and write that.
        std::vector<Type::Field> fields = type.fields();
        // Append a float2 SKSL_RTFLIP_NAME member. The unlabeled third Layout argument is fed
        // from fSettings.fRTFlipOffset; presumably it is the member offset (confirm against the
        // Layout constructor).
        fields.emplace_back(Modifiers(Layout(/*flags=*/0,
                                             /*location=*/-1,
                                             fProgram.fConfig->fSettings.fRTFlipOffset,
                                             /*binding=*/-1,
                                             /*index=*/-1,
                                             /*set=*/-1,
                                             /*builtin=*/-1,
                                             /*inputAttachmentIndex=*/-1),
                                      /*flags=*/0),
                            SKSL_RTFLIP_NAME,
                            fContext.fTypes.fFloat2.get());
        {
            // NOTE(review): presumably makes the symbols created below allocate from the
            // program's pool; confirm against AutoAttachPoolToThread.
            AutoAttachPoolToThread attach(fProgram.fPool.get());
            // The new struct type and variable are handed to the program's symbol table, which
            // takes ownership of them.
            const Type* rtFlipStructType = fProgram.fSymbols->takeOwnershipOfSymbol(
                    Type::MakeStructType(type.fLine, type.name(), std::move(fields)));
            const Variable* modifiedVar = fProgram.fSymbols->takeOwnershipOfSymbol(
                    std::make_unique<Variable>(intfVar.fLine,
                                               &intfVar.modifiers(),
                                               intfVar.name(),
                                               rtFlipStructType,
                                               intfVar.isBuiltin(),
                                               intfVar.storage()));
            // The synthesized variable isn't tracked by ProgramUsage. Recording it here means
            // isDead() never treats it as dead.
            fSPIRVBonusVariables.insert(modifiedVar);
            InterfaceBlock modifiedCopy(intf.fLine,
                                        *modifiedVar,
                                        intf.typeName(),
                                        intf.instanceName(),
                                        intf.arraySize(),
                                        intf.typeOwner());
            // appendRTFlip=false stops the recursion from appending RTFlip a second time.
            result = this->writeInterfaceBlock(modifiedCopy, false);
            // Register a Field symbol for the appended member (the last field of the new struct).
            // Presumably this lets later references to the RTFlip uniform resolve to it.
            fProgram.fSymbols->add(std::make_unique<Field>(
                    /*line=*/-1, modifiedVar, rtFlipStructType->fields().size() - 1));
        }
        // References to the original variable now resolve to the RTFlip-augmented copy.
        fVariableMap[&intfVar] = result;
        fWroteRTFlip = true;
        return result;
    }
    const Modifiers& intfModifiers = intfVar.modifiers();
    SpvId typeId = this->getType(type, memoryLayout);
    // Builtin blocks (fBuiltin != -1) are not given the Block decoration.
    if (intfModifiers.fLayout.fBuiltin == -1) {
        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
    }
    SpvId ptrType = this->nextId(nullptr);
    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
    Layout layout = intfModifiers.fLayout;
    // Uniform blocks with no explicit descriptor set default to set 0.
    if (intfModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
        layout.fSet = 0;
    }
    this->writeLayout(layout, result);
    fVariableMap[&intfVar] = result;
    return result;
}
3082
3083bool SPIRVCodeGenerator::isDead(const Variable& var) const {
3084    // During SPIR-V code generation, we synthesize some extra bonus variables that don't actually
3085    // exist in the Program at all and aren't tracked by the ProgramUsage. They aren't dead, though.
3086    if (fSPIRVBonusVariables.count(&var)) {
3087        return false;
3088    }
3089    ProgramUsage::VariableCounts counts = fProgram.usage()->get(var);
3090    if (counts.fRead || counts.fWrite) {
3091        return false;
3092    }
3093    // It's not entirely clear what the rules are for eliding interface variables. Generally, it
3094    // causes problems to elide them, even when they're dead.
3095    return !(var.modifiers().fFlags &
3096             (Modifiers::kIn_Flag | Modifiers::kOut_Flag | Modifiers::kUniform_Flag));
3097}
3098
3099void SPIRVCodeGenerator::writeGlobalVar(ProgramKind kind, const VarDeclaration& varDecl) {
3100    const Variable& var = varDecl.var();
3101    if (var.modifiers().fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
3102        kind != ProgramKind::kFragment) {
3103        SkASSERT(!fProgram.fConfig->fSettings.fFragColorIsInOut);
3104        return;
3105    }
3106    if (var.modifiers().fLayout.fBuiltin == SK_SECONDARYFRAGCOLOR_BUILTIN) {
3107        return;
3108    }
3109    if (this->isDead(var)) {
3110        return;
3111    }
3112    SpvStorageClass_ storageClass = get_storage_class(var, SpvStorageClassPrivate);
3113    if (storageClass == SpvStorageClassUniform) {
3114        // Top-level uniforms are emitted in writeUniformBuffer.
3115        fTopLevelUniforms.push_back(&varDecl);
3116        return;
3117    }
3118    const Type& type = var.type();
3119    Layout layout = var.modifiers().fLayout;
3120    if (layout.fSet < 0 && storageClass == SpvStorageClassUniformConstant) {
3121        layout.fSet = fProgram.fConfig->fSettings.fDefaultUniformSet;
3122    }
3123    SpvId id = this->nextId(&type);
3124    fVariableMap[&var] = id;
3125    SpvId typeId = this->getPointerType(type, storageClass);
3126    this->writeInstruction(SpvOpVariable, typeId, id, storageClass, fConstantBuffer);
3127    this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3128    if (varDecl.value()) {
3129        SkASSERT(!fCurrentBlock);
3130        fCurrentBlock = -1;
3131        SpvId value = this->writeExpression(*varDecl.value(), fGlobalInitializersBuffer);
3132        this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
3133        fCurrentBlock = 0;
3134    }
3135    this->writeLayout(layout, id);
3136    if (var.modifiers().fFlags & Modifiers::kFlat_Flag) {
3137        this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
3138    }
3139    if (var.modifiers().fFlags & Modifiers::kNoPerspective_Flag) {
3140        this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
3141                                fDecorationBuffer);
3142    }
3143}
3144
3145void SPIRVCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl, OutputStream& out) {
3146    const Variable& var = varDecl.var();
3147    SpvId id = this->nextId(&var.type());
3148    fVariableMap[&var] = id;
3149    SpvId type = this->getPointerType(var.type(), SpvStorageClassFunction);
3150    this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
3151    this->writeInstruction(SpvOpName, id, var.name(), fNameBuffer);
3152    if (varDecl.value()) {
3153        SpvId value = this->writeExpression(*varDecl.value(), out);
3154        this->writeInstruction(SpvOpStore, id, value, out);
3155    }
3156}
3157
3158void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
3159    switch (s.kind()) {
3160        case Statement::Kind::kInlineMarker:
3161        case Statement::Kind::kNop:
3162            break;
3163        case Statement::Kind::kBlock:
3164            this->writeBlock(s.as<Block>(), out);
3165            break;
3166        case Statement::Kind::kExpression:
3167            this->writeExpression(*s.as<ExpressionStatement>().expression(), out);
3168            break;
3169        case Statement::Kind::kReturn:
3170            this->writeReturnStatement(s.as<ReturnStatement>(), out);
3171            break;
3172        case Statement::Kind::kVarDeclaration:
3173            this->writeVarDeclaration(s.as<VarDeclaration>(), out);
3174            break;
3175        case Statement::Kind::kIf:
3176            this->writeIfStatement(s.as<IfStatement>(), out);
3177            break;
3178        case Statement::Kind::kFor:
3179            this->writeForStatement(s.as<ForStatement>(), out);
3180            break;
3181        case Statement::Kind::kDo:
3182            this->writeDoStatement(s.as<DoStatement>(), out);
3183            break;
3184        case Statement::Kind::kSwitch:
3185            this->writeSwitchStatement(s.as<SwitchStatement>(), out);
3186            break;
3187        case Statement::Kind::kBreak:
3188            this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
3189            break;
3190        case Statement::Kind::kContinue:
3191            this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
3192            break;
3193        case Statement::Kind::kDiscard:
3194            this->writeInstruction(SpvOpKill, out);
3195            break;
3196        default:
3197            SkDEBUGFAILF("unsupported statement: %s", s.description().c_str());
3198            break;
3199    }
3200}
3201
3202void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
3203    for (const std::unique_ptr<Statement>& stmt : b.children()) {
3204        this->writeStatement(*stmt, out);
3205    }
3206}
3207
3208void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
3209    SpvId test = this->writeExpression(*stmt.test(), out);
3210    SpvId ifTrue = this->nextId(nullptr);
3211    SpvId ifFalse = this->nextId(nullptr);
3212    if (stmt.ifFalse()) {
3213        SpvId end = this->nextId(nullptr);
3214        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3215        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3216        this->writeLabel(ifTrue, out);
3217        this->writeStatement(*stmt.ifTrue(), out);
3218        if (fCurrentBlock) {
3219            this->writeInstruction(SpvOpBranch, end, out);
3220        }
3221        this->writeLabel(ifFalse, out);
3222        this->writeStatement(*stmt.ifFalse(), out);
3223        if (fCurrentBlock) {
3224            this->writeInstruction(SpvOpBranch, end, out);
3225        }
3226        this->writeLabel(end, out);
3227    } else {
3228        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
3229        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
3230        this->writeLabel(ifTrue, out);
3231        this->writeStatement(*stmt.ifTrue(), out);
3232        if (fCurrentBlock) {
3233            this->writeInstruction(SpvOpBranch, ifFalse, out);
3234        }
3235        this->writeLabel(ifFalse, out);
3236    }
3237}
3238
3239void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
3240    if (f.initializer()) {
3241        this->writeStatement(*f.initializer(), out);
3242    }
3243    SpvId header = this->nextId(nullptr);
3244    SpvId start = this->nextId(nullptr);
3245    SpvId body = this->nextId(nullptr);
3246    SpvId next = this->nextId(nullptr);
3247    fContinueTarget.push(next);
3248    SpvId end = this->nextId(nullptr);
3249    fBreakTarget.push(end);
3250    this->writeInstruction(SpvOpBranch, header, out);
3251    this->writeLabel(header, out);
3252    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
3253    this->writeInstruction(SpvOpBranch, start, out);
3254    this->writeLabel(start, out);
3255    if (f.test()) {
3256        SpvId test = this->writeExpression(*f.test(), out);
3257        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
3258    } else {
3259        this->writeInstruction(SpvOpBranch, body, out);
3260    }
3261    this->writeLabel(body, out);
3262    this->writeStatement(*f.statement(), out);
3263    if (fCurrentBlock) {
3264        this->writeInstruction(SpvOpBranch, next, out);
3265    }
3266    this->writeLabel(next, out);
3267    if (f.next()) {
3268        this->writeExpression(*f.next(), out);
3269    }
3270    this->writeInstruction(SpvOpBranch, header, out);
3271    this->writeLabel(end, out);
3272    fBreakTarget.pop();
3273    fContinueTarget.pop();
3274}
3275
3276void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
3277    SpvId header = this->nextId(nullptr);
3278    SpvId start = this->nextId(nullptr);
3279    SpvId next = this->nextId(nullptr);
3280    SpvId continueTarget = this->nextId(nullptr);
3281    fContinueTarget.push(continueTarget);
3282    SpvId end = this->nextId(nullptr);
3283    fBreakTarget.push(end);
3284    this->writeInstruction(SpvOpBranch, header, out);
3285    this->writeLabel(header, out);
3286    this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
3287    this->writeInstruction(SpvOpBranch, start, out);
3288    this->writeLabel(start, out);
3289    this->writeStatement(*d.statement(), out);
3290    if (fCurrentBlock) {
3291        this->writeInstruction(SpvOpBranch, next, out);
3292    }
3293    this->writeLabel(next, out);
3294    this->writeInstruction(SpvOpBranch, continueTarget, out);
3295    this->writeLabel(continueTarget, out);
3296    SpvId test = this->writeExpression(*d.test(), out);
3297    this->writeInstruction(SpvOpBranchConditional, test, header, end, out);
3298    this->writeLabel(end, out);
3299    fBreakTarget.pop();
3300    fContinueTarget.pop();
3301}
3302
3303void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
3304    SpvId value = this->writeExpression(*s.value(), out);
3305    std::vector<SpvId> labels;
3306    SpvId end = this->nextId(nullptr);
3307    SpvId defaultLabel = end;
3308    fBreakTarget.push(end);
3309    int size = 3;
3310    auto& cases = s.cases();
3311    for (const std::unique_ptr<Statement>& stmt : cases) {
3312        const SwitchCase& c = stmt->as<SwitchCase>();
3313        SpvId label = this->nextId(nullptr);
3314        labels.push_back(label);
3315        if (c.value()) {
3316            size += 2;
3317        } else {
3318            defaultLabel = label;
3319        }
3320    }
3321    labels.push_back(end);
3322    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
3323    this->writeOpCode(SpvOpSwitch, size, out);
3324    this->writeWord(value, out);
3325    this->writeWord(defaultLabel, out);
3326    for (size_t i = 0; i < cases.size(); ++i) {
3327        const SwitchCase& c = cases[i]->as<SwitchCase>();
3328        if (!c.value()) {
3329            continue;
3330        }
3331        this->writeWord(c.value()->as<Literal>().intValue(), out);
3332        this->writeWord(labels[i], out);
3333    }
3334    for (size_t i = 0; i < cases.size(); ++i) {
3335        const SwitchCase& c = cases[i]->as<SwitchCase>();
3336        this->writeLabel(labels[i], out);
3337        this->writeStatement(*c.statement(), out);
3338        if (fCurrentBlock) {
3339            this->writeInstruction(SpvOpBranch, labels[i + 1], out);
3340        }
3341    }
3342    this->writeLabel(end, out);
3343    fBreakTarget.pop();
3344}
3345
3346void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
3347    if (r.expression()) {
3348        this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.expression(), out),
3349                               out);
3350    } else {
3351        this->writeInstruction(SpvOpReturn, out);
3352    }
3353}
3354
3355// Given any function, returns the top-level symbol table (OUTSIDE of the function's scope).
3356static std::shared_ptr<SymbolTable> get_top_level_symbol_table(const FunctionDeclaration& anyFunc) {
3357    return anyFunc.definition()->body()->as<Block>().symbolTable()->fParent;
3358}
3359
3360SPIRVCodeGenerator::EntrypointAdapter SPIRVCodeGenerator::writeEntrypointAdapter(
3361        const FunctionDeclaration& main) {
3362    // Our goal is to synthesize a tiny helper function which looks like this:
3363    //     void _entrypoint() { sk_FragColor = main(); }
3364
3365    // Fish a symbol table out of main().
3366    std::shared_ptr<SymbolTable> symbolTable = get_top_level_symbol_table(main);
3367
3368    // Get `sk_FragColor` as a writable reference.
3369    const Symbol* skFragColorSymbol = (*symbolTable)["sk_FragColor"];
3370    SkASSERT(skFragColorSymbol);
3371    const Variable& skFragColorVar = skFragColorSymbol->as<Variable>();
3372    auto skFragColorRef = std::make_unique<VariableReference>(/*line=*/-1, &skFragColorVar,
3373                                                              VariableReference::RefKind::kWrite);
3374    // Synthesize a call to the `main()` function.
3375    if (main.returnType() != skFragColorRef->type()) {
3376        fContext.fErrors->error(main.fLine, "SPIR-V does not support returning '" +
3377                                            main.returnType().description() + "' from main()");
3378        return {};
3379    }
3380    ExpressionArray args;
3381    if (main.parameters().size() == 1) {
3382        if (main.parameters()[0]->type() != *fContext.fTypes.fFloat2) {
3383            fContext.fErrors->error(main.fLine,
3384                    "SPIR-V does not support parameter of type '" +
3385                    main.parameters()[0]->type().description() + "' to main()");
3386            return {};
3387        }
3388        args.push_back(dsl::Float2(0).release());
3389    }
3390    auto callMainFn = std::make_unique<FunctionCall>(/*line=*/-1, &main.returnType(), &main,
3391                                                     std::move(args));
3392
3393    // Synthesize `skFragColor = main()` as a BinaryExpression.
3394    auto assignmentStmt = std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
3395            /*line=*/-1,
3396            std::move(skFragColorRef),
3397            Token::Kind::TK_EQ,
3398            std::move(callMainFn),
3399            &main.returnType()));
3400
3401    // Function bodies are always wrapped in a Block.
3402    StatementArray entrypointStmts;
3403    entrypointStmts.push_back(std::move(assignmentStmt));
3404    auto entrypointBlock = Block::Make(/*line=*/-1, std::move(entrypointStmts),
3405                                       symbolTable, /*isScope=*/true);
3406    // Declare an entrypoint function.
3407    EntrypointAdapter adapter;
3408    adapter.fLayout = {};
3409    adapter.fModifiers = Modifiers{adapter.fLayout, Modifiers::kHasSideEffects_Flag};
3410    adapter.entrypointDecl =
3411            std::make_unique<FunctionDeclaration>(/*line=*/-1,
3412                                                  &adapter.fModifiers,
3413                                                  "_entrypoint",
3414                                                  /*parameters=*/std::vector<const Variable*>{},
3415                                                  /*returnType=*/fContext.fTypes.fVoid.get(),
3416                                                  /*builtin=*/false);
3417    // Define it.
3418    adapter.entrypointDef = FunctionDefinition::Convert(fContext,
3419                                                        /*line=*/-1,
3420                                                        *adapter.entrypointDecl,
3421                                                        std::move(entrypointBlock),
3422                                                        /*builtin=*/false);
3423
3424    adapter.entrypointDecl->setDefinition(adapter.entrypointDef.get());
3425    return adapter;
3426}
3427
3428void SPIRVCodeGenerator::writeUniformBuffer(std::shared_ptr<SymbolTable> topLevelSymbolTable) {
3429    SkASSERT(!fTopLevelUniforms.empty());
3430    static constexpr char kUniformBufferName[] = "_UniformBuffer";
3431
3432    // Convert the list of top-level uniforms into a matching struct named _UniformBuffer, and build
3433    // a lookup table of variables to UniformBuffer field indices.
3434    std::vector<Type::Field> fields;
3435    fields.reserve(fTopLevelUniforms.size());
3436    fTopLevelUniformMap.reserve(fTopLevelUniforms.size());
3437    for (const VarDeclaration* topLevelUniform : fTopLevelUniforms) {
3438        const Variable* var = &topLevelUniform->var();
3439        fTopLevelUniformMap[var] = (int)fields.size();
3440        fields.emplace_back(var->modifiers(), var->name(), &var->type());
3441    }
3442    fUniformBuffer.fStruct = Type::MakeStructType(/*line=*/-1, kUniformBufferName,
3443                                                 std::move(fields));
3444
3445    // Create a global variable to contain this struct.
3446    Layout layout;
3447    layout.fBinding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
3448    layout.fSet     = fProgram.fConfig->fSettings.fDefaultUniformSet;
3449    Modifiers modifiers{layout, Modifiers::kUniform_Flag};
3450
3451    fUniformBuffer.fInnerVariable = std::make_unique<Variable>(
3452            /*line=*/-1, fProgram.fModifiers->add(modifiers), kUniformBufferName,
3453            fUniformBuffer.fStruct.get(), /*builtin=*/false, Variable::Storage::kGlobal);
3454
3455    // Create an interface block object for this global variable.
3456    fUniformBuffer.fInterfaceBlock = std::make_unique<InterfaceBlock>(
3457            /*offset=*/-1, *fUniformBuffer.fInnerVariable, kUniformBufferName,
3458            kUniformBufferName, /*arraySize=*/0, topLevelSymbolTable);
3459
3460    // Generate an interface block and hold onto its ID.
3461    fUniformBufferId = this->writeInterfaceBlock(*fUniformBuffer.fInterfaceBlock);
3462}
3463
3464void SPIRVCodeGenerator::addRTFlipUniform(int line) {
3465    if (fWroteRTFlip) {
3466        return;
3467    }
3468    // Flip variable hasn't been written yet. This means we don't have an existing
3469    // interface block, so we're free to just synthesize one.
3470    fWroteRTFlip = true;
3471    std::vector<Type::Field> fields;
3472    if (fProgram.fConfig->fSettings.fRTFlipOffset < 0) {
3473        fContext.fErrors->error(line, "RTFlipOffset is negative");
3474    }
3475    fields.emplace_back(Modifiers(Layout(/*flags=*/0,
3476                                         /*location=*/-1,
3477                                         fProgram.fConfig->fSettings.fRTFlipOffset,
3478                                         /*binding=*/-1,
3479                                         /*index=*/-1,
3480                                         /*set=*/-1,
3481                                         /*builtin=*/-1,
3482                                         /*inputAttachmentIndex=*/-1),
3483                                  /*flags=*/0),
3484                        SKSL_RTFLIP_NAME,
3485                        fContext.fTypes.fFloat2.get());
3486    skstd::string_view name = "sksl_synthetic_uniforms";
3487    const Type* intfStruct =
3488            fSynthetics.takeOwnershipOfSymbol(Type::MakeStructType(/*line=*/-1, name, fields));
3489    int binding = fProgram.fConfig->fSettings.fRTFlipBinding;
3490    if (binding == -1) {
3491        fContext.fErrors->error(line, "layout(binding=...) is required in SPIR-V");
3492    }
3493    int set = fProgram.fConfig->fSettings.fRTFlipSet;
3494    if (set == -1) {
3495        fContext.fErrors->error(line, "layout(set=...) is required in SPIR-V");
3496    }
3497    bool usePushConstants = fProgram.fConfig->fSettings.fUsePushConstants;
3498    int flags = usePushConstants ? Layout::Flag::kPushConstant_Flag : 0;
3499    const Modifiers* modsPtr;
3500    {
3501        AutoAttachPoolToThread attach(fProgram.fPool.get());
3502        Modifiers modifiers(Layout(flags,
3503                                   /*location=*/-1,
3504                                   /*offset=*/-1,
3505                                   binding,
3506                                   /*index=*/-1,
3507                                   set,
3508                                   /*builtin=*/-1,
3509                                   /*inputAttachmentIndex=*/-1),
3510                            Modifiers::kUniform_Flag);
3511        modsPtr = fProgram.fModifiers->add(modifiers);
3512    }
3513    const Variable* intfVar = fSynthetics.takeOwnershipOfSymbol(
3514            std::make_unique<Variable>(/*line=*/-1,
3515                                       modsPtr,
3516                                       name,
3517                                       intfStruct,
3518                                       /*builtin=*/false,
3519                                       Variable::Storage::kGlobal));
3520    fSPIRVBonusVariables.insert(intfVar);
3521    {
3522        AutoAttachPoolToThread attach(fProgram.fPool.get());
3523        fProgram.fSymbols->add(std::make_unique<Field>(/*line=*/-1, intfVar, /*field=*/0));
3524    }
3525    InterfaceBlock intf(/*line=*/-1,
3526                        *intfVar,
3527                        name,
3528                        /*instanceName=*/"",
3529                        /*arraySize=*/0,
3530                        std::make_shared<SymbolTable>(fContext, /*builtin=*/false));
3531
3532    this->writeInterfaceBlock(intf, false);
3533}
3534
3535void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
3536    fGLSLExtendedInstructions = this->nextId(nullptr);
3537    StringStream body;
3538    // Assign SpvIds to functions.
3539    const FunctionDeclaration* main = nullptr;
3540    for (const ProgramElement* e : program.elements()) {
3541        if (e->is<FunctionDefinition>()) {
3542            const FunctionDefinition& funcDef = e->as<FunctionDefinition>();
3543            const FunctionDeclaration& funcDecl = funcDef.declaration();
3544            fFunctionMap[&funcDecl] = this->nextId(nullptr);
3545            if (funcDecl.isMain()) {
3546                main = &funcDecl;
3547            }
3548        }
3549    }
3550    // Make sure we have a main() function.
3551    if (!main) {
3552        fContext.fErrors->error(/*line=*/-1, "program does not contain a main() function");
3553        return;
3554    }
3555    // Emit interface blocks.
3556    std::set<SpvId> interfaceVars;
3557    for (const ProgramElement* e : program.elements()) {
3558        if (e->is<InterfaceBlock>()) {
3559            const InterfaceBlock& intf = e->as<InterfaceBlock>();
3560            SpvId id = this->writeInterfaceBlock(intf);
3561
3562            const Modifiers& modifiers = intf.variable().modifiers();
3563            if ((modifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
3564                modifiers.fLayout.fBuiltin == -1 && !this->isDead(intf.variable())) {
3565                interfaceVars.insert(id);
3566            }
3567        }
3568    }
3569    // Emit global variable declarations.
3570    for (const ProgramElement* e : program.elements()) {
3571        if (e->is<GlobalVarDeclaration>()) {
3572            this->writeGlobalVar(program.fConfig->fKind,
3573                                 e->as<GlobalVarDeclaration>().declaration()->as<VarDeclaration>());
3574        }
3575    }
3576    // Emit top-level uniforms into a dedicated uniform buffer.
3577    if (!fTopLevelUniforms.empty()) {
3578        this->writeUniformBuffer(get_top_level_symbol_table(*main));
3579    }
3580    // If main() returns a half4, synthesize a tiny entrypoint function which invokes the real
3581    // main() and stores the result into sk_FragColor.
3582    EntrypointAdapter adapter;
3583    if (main->returnType() == *fContext.fTypes.fHalf4) {
3584        adapter = this->writeEntrypointAdapter(*main);
3585        if (adapter.entrypointDecl) {
3586            fFunctionMap[adapter.entrypointDecl.get()] = this->nextId(nullptr);
3587            this->writeFunction(*adapter.entrypointDef, body);
3588            main = adapter.entrypointDecl.get();
3589        }
3590    }
3591    // Emit all the functions.
3592    for (const ProgramElement* e : program.elements()) {
3593        if (e->is<FunctionDefinition>()) {
3594            this->writeFunction(e->as<FunctionDefinition>(), body);
3595        }
3596    }
3597    // Add global in/out variables to the list of interface variables.
3598    for (auto entry : fVariableMap) {
3599        const Variable* var = entry.first;
3600        if (var->storage() == Variable::Storage::kGlobal &&
3601            (var->modifiers().fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag)) &&
3602            !this->isDead(*var)) {
3603            interfaceVars.insert(entry.second);
3604        }
3605    }
3606    this->writeCapabilities(out);
3607    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
3608    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
3609    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->name().length() + 4) / 4) +
3610                      (int32_t) interfaceVars.size(), out);
3611    switch (program.fConfig->fKind) {
3612        case ProgramKind::kVertex:
3613            this->writeWord(SpvExecutionModelVertex, out);
3614            break;
3615        case ProgramKind::kFragment:
3616            this->writeWord(SpvExecutionModelFragment, out);
3617            break;
3618        default:
3619            SK_ABORT("cannot write this kind of program to SPIR-V\n");
3620    }
3621    SpvId entryPoint = fFunctionMap[main];
3622    this->writeWord(entryPoint, out);
3623    this->writeString(main->name(), out);
3624    for (int var : interfaceVars) {
3625        this->writeWord(var, out);
3626    }
3627    if (program.fConfig->fKind == ProgramKind::kFragment) {
3628        this->writeInstruction(SpvOpExecutionMode,
3629                               fFunctionMap[main],
3630                               SpvExecutionModeOriginUpperLeft,
3631                               out);
3632    }
3633    for (const ProgramElement* e : program.elements()) {
3634        if (e->is<Extension>()) {
3635            this->writeInstruction(SpvOpSourceExtension, e->as<Extension>().name(), out);
3636        }
3637    }
3638
3639    write_stringstream(fExtraGlobalsBuffer, out);
3640    write_stringstream(fNameBuffer, out);
3641    write_stringstream(fDecorationBuffer, out);
3642    write_stringstream(fConstantBuffer, out);
3643    write_stringstream(body, out);
3644}
3645
3646bool SPIRVCodeGenerator::generateCode() {
3647    SkASSERT(!fContext.fErrors->errorCount());
3648    this->writeWord(SpvMagicNumber, *fOut);
3649    this->writeWord(SpvVersion, *fOut);
3650    this->writeWord(SKSL_MAGIC, *fOut);
3651    StringStream buffer;
3652    this->writeInstructions(fProgram, buffer);
3653    this->writeWord(fIdCount, *fOut);
3654    this->writeWord(0, *fOut); // reserved, always zero
3655    write_stringstream(buffer, *fOut);
3656    fContext.fErrors->reportPendingErrors(PositionInfo());
3657    return fContext.fErrors->errorCount() == 0;
3658}
3659
3660}  // namespace SkSL
3661