1/*
2 * Copyright (C) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "default_limits.h"
17#include <glslang/Public/ShaderLang.h>
18#include <SPIRV/GlslangToSpv.h>
19#include <SPIRV/SpvTools.h>
20#include <spirv-tools/optimizer.hpp>
21
22#include "spirv_cross.hpp"
23
24// #include "preprocess/preprocess.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <new>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>
#include <string_view>
#include <thread>
#include <utility>
#include <vector>
36
37#include "io/dev/FileMonitor.h"
38#include "lume/Log.h"
39#include "shader_type.h"
40#include "spirv_cross_helpers_gles.h"
41
42using namespace std::chrono_literals;
43
44// Enumerations from Engine which should match: Format, DescriptorType, ShaderStageFlagBits
/**
 * Format.
 * Pixel / vertex data format identifiers. Every enumerator carries an explicit numeric value; the
 * numbering has gaps (e.g. 32-33, 46-47, 59-61, 110-121, 128, 130) for values not declared here.
 * The values must not be renumbered: per the note above this enum they have to match the engine.
 */
enum class Format {
    /** Undefined */
    UNDEFINED = 0,

    /** Packed 8 and 16 bit formats */
    R4G4_UNORM_PACK8 = 1,
    R4G4B4A4_UNORM_PACK16 = 2,
    B4G4R4A4_UNORM_PACK16 = 3,
    R5G6B5_UNORM_PACK16 = 4,
    B5G6R5_UNORM_PACK16 = 5,
    R5G5B5A1_UNORM_PACK16 = 6,
    B5G5R5A1_UNORM_PACK16 = 7,
    A1R5G5B5_UNORM_PACK16 = 8,

    /** R8 */
    R8_UNORM = 9,
    R8_SNORM = 10,
    R8_USCALED = 11,
    R8_SSCALED = 12,
    R8_UINT = 13,
    R8_SINT = 14,
    R8_SRGB = 15,

    /** R8G8 */
    R8G8_UNORM = 16,
    R8G8_SNORM = 17,
    R8G8_USCALED = 18,
    R8G8_SSCALED = 19,
    R8G8_UINT = 20,
    R8G8_SINT = 21,
    R8G8_SRGB = 22,

    /** R8G8B8 */
    R8G8B8_UNORM = 23,
    R8G8B8_SNORM = 24,
    R8G8B8_USCALED = 25,
    R8G8B8_SSCALED = 26,
    R8G8B8_UINT = 27,
    R8G8B8_SINT = 28,
    R8G8B8_SRGB = 29,

    /** B8G8R8 (no USCALED / SSCALED entries) */
    B8G8R8_UNORM = 30,
    B8G8R8_SNORM = 31,
    B8G8R8_UINT = 34,
    B8G8R8_SINT = 35,
    B8G8R8_SRGB = 36,

    /** R8G8B8A8 */
    R8G8B8A8_UNORM = 37,
    R8G8B8A8_SNORM = 38,
    R8G8B8A8_USCALED = 39,
    R8G8B8A8_SSCALED = 40,
    R8G8B8A8_UINT = 41,
    R8G8B8A8_SINT = 42,
    R8G8B8A8_SRGB = 43,

    /** B8G8R8A8 (no USCALED / SSCALED entries) */
    B8G8R8A8_UNORM = 44,
    B8G8R8A8_SNORM = 45,
    B8G8R8A8_UINT = 48,
    B8G8R8A8_SINT = 49,
    B8G8R8A8_SRGB = 50,

    /** A8B8G8R8 PACK32 */
    A8B8G8R8_UNORM_PACK32 = 51,
    A8B8G8R8_SNORM_PACK32 = 52,
    A8B8G8R8_USCALED_PACK32 = 53,
    A8B8G8R8_SSCALED_PACK32 = 54,
    A8B8G8R8_UINT_PACK32 = 55,
    A8B8G8R8_SINT_PACK32 = 56,
    A8B8G8R8_SRGB_PACK32 = 57,

    /** A2R10G10B10 PACK32 (only UNORM / UINT / SINT entries) */
    A2R10G10B10_UNORM_PACK32 = 58,
    A2R10G10B10_UINT_PACK32 = 62,
    A2R10G10B10_SINT_PACK32 = 63,

    /** A2B10G10R10 PACK32 */
    A2B10G10R10_UNORM_PACK32 = 64,
    A2B10G10R10_SNORM_PACK32 = 65,
    A2B10G10R10_USCALED_PACK32 = 66,
    A2B10G10R10_SSCALED_PACK32 = 67,
    A2B10G10R10_UINT_PACK32 = 68,
    A2B10G10R10_SINT_PACK32 = 69,

    /** R16 */
    R16_UNORM = 70,
    R16_SNORM = 71,
    R16_USCALED = 72,
    R16_SSCALED = 73,
    R16_UINT = 74,
    R16_SINT = 75,
    R16_SFLOAT = 76,

    /** R16G16 */
    R16G16_UNORM = 77,
    R16G16_SNORM = 78,
    R16G16_USCALED = 79,
    R16G16_SSCALED = 80,
    R16G16_UINT = 81,
    R16G16_SINT = 82,
    R16G16_SFLOAT = 83,

    /** R16G16B16 */
    R16G16B16_UNORM = 84,
    R16G16B16_SNORM = 85,
    R16G16B16_USCALED = 86,
    R16G16B16_SSCALED = 87,
    R16G16B16_UINT = 88,
    R16G16B16_SINT = 89,
    R16G16B16_SFLOAT = 90,

    /** R16G16B16A16 */
    R16G16B16A16_UNORM = 91,
    R16G16B16A16_SNORM = 92,
    R16G16B16A16_USCALED = 93,
    R16G16B16A16_SSCALED = 94,
    R16G16B16A16_UINT = 95,
    R16G16B16A16_SINT = 96,
    R16G16B16A16_SFLOAT = 97,

    /** 32 bit per channel formats */
    R32_UINT = 98,
    R32_SINT = 99,
    R32_SFLOAT = 100,
    R32G32_UINT = 101,
    R32G32_SINT = 102,
    R32G32_SFLOAT = 103,
    R32G32B32_UINT = 104,
    R32G32B32_SINT = 105,
    R32G32B32_SFLOAT = 106,
    R32G32B32A32_UINT = 107,
    R32G32B32A32_SINT = 108,
    R32G32B32A32_SFLOAT = 109,

    /** Packed float formats */
    B10G11R11_UFLOAT_PACK32 = 122,
    E5B9G9R9_UFLOAT_PACK32 = 123,

    /** Depth / stencil formats */
    D16_UNORM = 124,
    X8_D24_UNORM_PACK32 = 125,
    D32_SFLOAT = 126,
    S8_UINT = 127,
    D24_UNORM_S8_UINT = 129,

    /** BC block formats */
    BC1_RGB_UNORM_BLOCK = 131,
    BC1_RGB_SRGB_BLOCK = 132,
    BC1_RGBA_UNORM_BLOCK = 133,
    BC1_RGBA_SRGB_BLOCK = 134,
    BC2_UNORM_BLOCK = 135,
    BC2_SRGB_BLOCK = 136,
    BC3_UNORM_BLOCK = 137,
    BC3_SRGB_BLOCK = 138,
    BC4_UNORM_BLOCK = 139,
    BC4_SNORM_BLOCK = 140,
    BC5_UNORM_BLOCK = 141,
    BC5_SNORM_BLOCK = 142,
    BC6H_UFLOAT_BLOCK = 143,
    BC6H_SFLOAT_BLOCK = 144,
    BC7_UNORM_BLOCK = 145,
    BC7_SRGB_BLOCK = 146,

    /** ETC2 / EAC block formats */
    ETC2_R8G8B8_UNORM_BLOCK = 147,
    ETC2_R8G8B8_SRGB_BLOCK = 148,
    ETC2_R8G8B8A1_UNORM_BLOCK = 149,
    ETC2_R8G8B8A1_SRGB_BLOCK = 150,
    ETC2_R8G8B8A8_UNORM_BLOCK = 151,
    ETC2_R8G8B8A8_SRGB_BLOCK = 152,
    EAC_R11_UNORM_BLOCK = 153,
    EAC_R11_SNORM_BLOCK = 154,
    EAC_R11G11_UNORM_BLOCK = 155,
    EAC_R11G11_SNORM_BLOCK = 156,

    /** ASTC block formats */
    ASTC_4x4_UNORM_BLOCK = 157,
    ASTC_4x4_SRGB_BLOCK = 158,
    ASTC_5x4_UNORM_BLOCK = 159,
    ASTC_5x4_SRGB_BLOCK = 160,
    ASTC_5x5_UNORM_BLOCK = 161,
    ASTC_5x5_SRGB_BLOCK = 162,
    ASTC_6x5_UNORM_BLOCK = 163,
    ASTC_6x5_SRGB_BLOCK = 164,
    ASTC_6x6_UNORM_BLOCK = 165,
    ASTC_6x6_SRGB_BLOCK = 166,
    ASTC_8x5_UNORM_BLOCK = 167,
    ASTC_8x5_SRGB_BLOCK = 168,
    ASTC_8x6_UNORM_BLOCK = 169,
    ASTC_8x6_SRGB_BLOCK = 170,
    ASTC_8x8_UNORM_BLOCK = 171,
    ASTC_8x8_SRGB_BLOCK = 172,
    ASTC_10x5_UNORM_BLOCK = 173,
    ASTC_10x5_SRGB_BLOCK = 174,
    ASTC_10x6_UNORM_BLOCK = 175,
    ASTC_10x6_SRGB_BLOCK = 176,
    ASTC_10x8_UNORM_BLOCK = 177,
    ASTC_10x8_SRGB_BLOCK = 178,
    ASTC_10x10_UNORM_BLOCK = 179,
    ASTC_10x10_SRGB_BLOCK = 180,
    ASTC_12x10_UNORM_BLOCK = 181,
    ASTC_12x10_SRGB_BLOCK = 182,
    ASTC_12x12_UNORM_BLOCK = 183,
    ASTC_12x12_SRGB_BLOCK = 184,

    /** 4:2:2 / 4:2:0 and multi-plane formats */
    G8B8G8R8_422_UNORM = 1000156000,
    B8G8R8G8_422_UNORM = 1000156001,
    G8_B8_R8_3PLANE_420_UNORM = 1000156002,
    G8_B8R8_2PLANE_420_UNORM = 1000156003,
    G8_B8_R8_3PLANE_422_UNORM = 1000156004,
    G8_B8R8_2PLANE_422_UNORM = 1000156005,

    /** Max enumeration */
    MAX_ENUM = 0x7FFFFFFF
};
390
/**
 * Descriptor type.
 * Kind of resource bound through a descriptor set binding. The numeric values are explicit and
 * must not be changed (see the note above the Format enumeration about matching the engine).
 */
enum class DescriptorType {
    SAMPLER = 0,                         /**< Sampler */
    COMBINED_IMAGE_SAMPLER = 1,          /**< Combined image sampler */
    SAMPLED_IMAGE = 2,                   /**< Sampled image */
    STORAGE_IMAGE = 3,                   /**< Storage image */
    UNIFORM_TEXEL_BUFFER = 4,            /**< Uniform texel buffer */
    STORAGE_TEXEL_BUFFER = 5,            /**< Storage texel buffer */
    UNIFORM_BUFFER = 6,                  /**< Uniform buffer */
    STORAGE_BUFFER = 7,                  /**< Storage buffer */
    UNIFORM_BUFFER_DYNAMIC = 8,          /**< Dynamic uniform buffer */
    STORAGE_BUFFER_DYNAMIC = 9,          /**< Dynamic storage buffer */
    INPUT_ATTACHMENT = 10,               /**< Input attachment */
    ACCELERATION_STRUCTURE = 1000150000, /**< Acceleration structure */
    MAX_ENUM = 0x7FFFFFFF                /**< Max enumeration */
};
419
/** Vertex input rate: selects between the per-vertex and per-instance rate. */
enum class VertexInputRate {
    VERTEX = 0,           /**< Vertex */
    INSTANCE = 1,         /**< Instance */
    MAX_ENUM = 0x7FFFFFFF /**< Max enumeration */
};
429
/** Pipeline layout constants: compile-time limits and sentinels used by the layout structures. */
struct PipelineLayoutConstants {
    /** Max descriptor set count (size of the descriptor set array in PipelineLayout) */
    static constexpr uint32_t MAX_DESCRIPTOR_SET_COUNT = 4u;
    /** Max dynamic descriptor offset count */
    static constexpr uint32_t MAX_DYNAMIC_DESCRIPTOR_OFFSET_COUNT = 16u;
    /** Invalid index (all bits set); default for "not assigned" set / binding indices */
    static constexpr uint32_t INVALID_INDEX = ~0u;
    /** Max push constant byte size */
    static constexpr uint32_t MAX_PUSH_CONSTANT_BYTE_SIZE = 128u;
};
441
442/** Descriptor set layout binding */
443struct DescriptorSetLayoutBinding {
444    /** Binding */
445    uint32_t binding { PipelineLayoutConstants::INVALID_INDEX };
446    /** Descriptor type */
447    DescriptorType descriptorType { DescriptorType::MAX_ENUM };
448    /** Descriptor count */
449    uint32_t descriptorCount { 0 };
450    /** Stage flags */
451    ShaderStageFlags shaderStageFlags;
452};
453
454/** Descriptor set layout */
455struct DescriptorSetLayout {
456    /** Set */
457    uint32_t set { PipelineLayoutConstants::INVALID_INDEX };
458    /** Bindings */
459    std::vector<DescriptorSetLayoutBinding> bindings;
460};
461
462/** Push constant */
463struct PushConstant {
464    /** Shader stage flags */
465    ShaderStageFlags shaderStageFlags;
466    /** Byte size */
467    uint32_t byteSize { 0 };
468};
469
470/** Pipeline layout */
471struct PipelineLayout {
472    /** Push constant */
473    PushConstant pushConstant;
474    /** Descriptor set count */
475    uint32_t descriptorSetCount { 0 };
476    /** Descriptor sets */
477    DescriptorSetLayout descriptorSetLayouts[PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT] {};
478};
479
/** Reserved constant id index. NOTE(review): presumably a specialization constant id reserved for
 * internal use; it is not referenced in the visible part of this file - confirm against its users. */
constexpr uint32_t RESERVED_CONSTANT_ID_INDEX = 256;
481
482/** Vertex input declaration */
483struct VertexInputDeclaration {
484    /** Vertex input binding description */
485    struct VertexInputBindingDescription {
486        /** Binding */
487        uint32_t binding { ~0u };
488        /** Stride */
489        uint32_t stride { 0u };
490        /** Vertex input rate */
491        VertexInputRate vertexInputRate { VertexInputRate::MAX_ENUM };
492    };
493
494    /** Vertex input attribute description */
495    struct VertexInputAttributeDescription {
496        /** Location */
497        uint32_t location { ~0u };
498        /** Binding */
499        uint32_t binding { ~0u };
500        /** Format */
501        Format format { Format::UNDEFINED };
502        /** Offset */
503        uint32_t offset { 0u };
504    };
505};
506
507struct VertexAttributeInfo {
508    uint32_t byteSize { 0 };
509    VertexInputDeclaration::VertexInputAttributeDescription description;
510};
511
/**
 * Three component unsigned vector.
 * Members are zero-initialized so that a default constructed value is well defined; the type
 * remains an aggregate, so `UVec3 { x, y, z }` keeps working.
 */
struct UVec3 {
    /** X component */
    uint32_t x { 0u };
    /** Y component */
    uint32_t y { 0u };
    /** Z component */
    uint32_t z { 0u };
};
517
518struct ShaderReflectionData {
519    array_view<const uint8_t> reflectionData;
520
521    bool IsValid() const;
522    ShaderStageFlags GetStageFlags() const;
523    PipelineLayout GetPipelineLayout() const;
524    std::vector<ShaderSpecializationConstant> GetSpecializationConstants() const;
525    std::vector<VertexInputDeclaration::VertexInputAttributeDescription> GetInputDescriptions() const;
526    UVec3 GetLocalSize() const;
527};
528
529struct ShaderModuleCreateInfo {
530    ShaderStageFlags shaderStageFlags;
531    array_view<const uint8_t> spvData;
532    ShaderReflectionData reflectionData;
533};
534
535struct CompilationSettings {
536    ShaderEnv shaderEnv;
537    std::vector<std::filesystem::path> shaderIncludePaths;
538    std::optional<spvtools::Optimizer> optimizer;
539    std::filesystem::path& shaderSourcePath;
540    std::filesystem::path& compiledShaderDestinationPath;
541};
542
/** Tag bytes expected at the very beginning of a reflection blob: "rfl" followed by a zero byte. */
constexpr uint8_t REFLECTION_TAG[] = { 'r', 'f', 'l', 0 };

/**
 * Header located at the start of a reflection blob.
 * The offset fields are added to the start of the blob by the ShaderReflectionData accessors;
 * those accessors treat an offset of zero as "section not present".
 */
struct ReflectionHeader {
    /** Tag, compared against REFLECTION_TAG */
    uint8_t tag[sizeof(REFLECTION_TAG)];
    /** Shader stage; the accessors cast it to ShaderStageFlagBits */
    uint16_t type;
    /** Offset of the push constant section */
    uint16_t offsetPushConstants;
    /** Offset of the specialization constant section */
    uint16_t offsetSpecializationConstants;
    /** Offset of the descriptor set section */
    uint16_t offsetDescriptorSets;
    /** Offset of the vertex input section */
    uint16_t offsetInputs;
    /** Offset of the local size section */
    uint16_t offsetLocalSize;
};
553
/**
 * Scope guard: runs `initializer` immediately on construction and `deinitalizer` when the object
 * is destroyed.
 *
 * The callables are taken by value (the previous `const std::function&&` parameters could only be
 * copied, never moved). The guard is non-copyable so that the deinitializer runs exactly once.
 * If the initializer throws, the constructor does not complete and the deinitializer is not run.
 */
class scope {
private:
    std::function<void()> deinit;

public:
    scope(std::function<void()> initializer, std::function<void()> deinitalizer) : deinit(std::move(deinitalizer))
    {
        // An empty initializer still throws std::bad_function_call, as before.
        initializer();
    }

    scope(const scope&) = delete;
    scope& operator=(const scope&) = delete;

    ~scope()
    {
        // Invoking an empty std::function would throw out of the destructor; skip it instead.
        if (deinit) {
            deinit();
        }
    }
};
571
572bool ShaderReflectionData::IsValid() const
573{
574    if (reflectionData.size() < sizeof(ReflectionHeader)) {
575        return false;
576    }
577    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
578    return memcmp(header.tag, REFLECTION_TAG, sizeof(REFLECTION_TAG)) == 0;
579}
580
581ShaderStageFlags ShaderReflectionData::GetStageFlags() const
582{
583    ShaderStageFlags flags;
584    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
585    flags = static_cast<ShaderStageFlagBits>(header.type);
586    return flags;
587}
588
589PipelineLayout ShaderReflectionData::GetPipelineLayout() const
590{
591    PipelineLayout pipelineLayout;
592    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
593    if (header.offsetPushConstants && header.offsetPushConstants < reflectionData.size()) {
594        auto ptr = reflectionData.data() + header.offsetPushConstants;
595        const auto constants = *ptr;
596        if (constants) {
597            pipelineLayout.pushConstant.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
598            pipelineLayout.pushConstant.byteSize = static_cast<uint32_t>(*(ptr + 1) | (*(ptr + 2) << 8));
599        }
600    }
601    if (header.offsetDescriptorSets && header.offsetDescriptorSets < reflectionData.size()) {
602        auto ptr = reflectionData.data() + header.offsetDescriptorSets;
603        pipelineLayout.descriptorSetCount = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
604        ptr += 2;
605        for (auto i = 0u; i < pipelineLayout.descriptorSetCount; ++i) {
606            // write to correct set location
607            const uint32_t set = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
608            assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
609            auto& layout = pipelineLayout.descriptorSetLayouts[set];
610            layout.set = set;
611            ptr += 2;
612            const auto bindings = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
613            ptr += 2;
614            for (auto j = 0u; j < bindings; ++j) {
615                DescriptorSetLayoutBinding binding;
616                binding.binding = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
617                ptr += 2;
618                binding.descriptorType = static_cast<DescriptorType>(*ptr | (*(ptr + 1) << 8));
619                ptr += 2;
620                binding.descriptorCount = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8));
621                ptr += 2;
622                binding.shaderStageFlags = static_cast<ShaderStageFlagBits>(header.type);
623                layout.bindings.push_back(binding);
624            }
625        }
626    }
627    return pipelineLayout;
628}
629
630std::vector<ShaderSpecializationConstant> ShaderReflectionData::GetSpecializationConstants() const
631{
632    std::vector<ShaderSpecializationConstant> constants;
633    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
634    if (header.offsetSpecializationConstants && header.offsetSpecializationConstants < reflectionData.size()) {
635        auto ptr = reflectionData.data() + header.offsetSpecializationConstants;
636        const auto size = *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24;
637        ptr += 4;
638        for (auto i = 0; i < size; ++i) {
639            ShaderSpecializationConstant constant;
640            constant.shaderStage = static_cast<ShaderStageFlagBits>(header.type);
641            constant.id = static_cast<uint32_t>(*ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
642            ptr += 4;
643            constant.type = static_cast<ShaderSpecializationConstant::Type>(
644                *ptr | *(ptr + 1) << 8 | *(ptr + 2) << 16 | *(ptr + 3) << 24);
645            ptr += 4;
646            constant.offset = 0;
647            constants.push_back(constant);
648        }
649    }
650    return constants;
651}
652
653std::vector<VertexInputDeclaration::VertexInputAttributeDescription> ShaderReflectionData::GetInputDescriptions() const
654{
655    std::vector<VertexInputDeclaration::VertexInputAttributeDescription> inputs;
656    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
657    if (header.offsetInputs && header.offsetInputs < reflectionData.size()) {
658        auto ptr = reflectionData.data() + header.offsetInputs;
659        const auto size = *(ptr) | (*(ptr + 1) << 8);
660        ptr += 2;
661        for (auto i = 0; i < size; ++i) {
662            VertexInputDeclaration::VertexInputAttributeDescription desc;
663            desc.location = static_cast<uint32_t>(*(ptr) | (*(ptr + 1) << 8));
664            ptr += 2;
665            desc.binding = desc.location;
666            desc.format = static_cast<Format>(*(ptr) | (*(ptr + 1) << 8));
667            ptr += 2;
668            desc.offset = 0;
669            inputs.push_back(desc);
670        }
671    }
672    return inputs;
673}
674
675UVec3 ShaderReflectionData::GetLocalSize() const
676{
677    UVec3 sizes;
678    const ReflectionHeader& header = *reinterpret_cast<const ReflectionHeader*>(reflectionData.data());
679    if (header.offsetLocalSize && header.offsetLocalSize < reflectionData.size()) {
680        auto ptr = reflectionData.data() + header.offsetLocalSize;
681        sizes.x = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
682        ptr += 4;
683        sizes.y = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
684        ptr += 4;
685        sizes.z = static_cast<uint32_t>(*ptr | (*(ptr + 1) << 8) | (*(ptr + 2)) << 16 | (*(ptr + 3)) << 24);
686        ptr += 4;
687    }
688    return sizes;
689}
690
691std::string readFileToString(std::string_view aFilename)
692{
693    std::stringstream ss;
694    std::ifstream file;
695
696    file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
697    try {
698        file.open(aFilename.data(), std::ios::in);
699
700        if (!file.fail()) {
701            ss << file.rdbuf();
702            return ss.str();
703        }
704    } catch (std::exception const& ex) {
705        LUME_LOG_E("Error reading file: '%s': %s", aFilename.data(), ex.what());
706        return {};
707    }
708    return {};
709}
710
711class FileIncluder : public glslang::TShader::Includer {
712public:
713    const CompilationSettings& settings;
714    FileIncluder(const CompilationSettings& compilationSettings) : settings(compilationSettings) {}
715
716private:
717    virtual IncludeResult* include(
718        const char* headerName, const char* includerName, size_t inclusionDepth, bool relative)
719    {
720        std::filesystem::path path;
721        bool found = false;
722        if (relative == true) {
723            path.append(settings.shaderSourcePath.c_str());
724            path.append(includerName);
725            path = path.parent_path();
726            path.append(headerName);
727            found = std::filesystem::exists(path);
728        }
729
730        for (int i = 0; i < settings.shaderIncludePaths.size() && found == false; ++i) {
731            path.assign(settings.shaderIncludePaths[i]);
732            path.append(headerName);
733            found = std::filesystem::exists(path);
734        }
735
736        if (found == true) {
737            auto str = path.string();
738
739            std::ifstream file(path);
740            file.seekg(0, file.end);
741            std::streampos length = file.tellg();
742            file.seekg(0, file.beg);
743
744            char* memory = new (std::nothrow) char[length + std::streampos(1)];
745            if (memory == 0) {
746                return nullptr;
747            }
748
749            char* last = std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), memory);
750            IncludeResult* result = new (std::nothrow) IncludeResult(str, memory, std::distance(memory, last), 0);
751            if (result == 0) {
752                delete memory;
753                return nullptr;
754            }
755
756            return result;
757        }
758
759        return nullptr;
760    }
761
762    virtual IncludeResult* includeSystem(const char* headerName, const char* includerName, size_t inclusionDepth)
763    {
764        return include(headerName, includerName, inclusionDepth, false);
765    }
766
767    virtual IncludeResult* includeLocal(const char* headerName, const char* includerName, size_t inclusionDepth)
768    {
769        return include(headerName, includerName, inclusionDepth, true);
770    }
771
772    virtual void releaseInclude(IncludeResult* include)
773    {
774        delete include;
775    }
776};
777
778glslang::EShTargetLanguageVersion ToSpirVVersion(glslang::EShTargetClientVersion env_version)
779{
780    if (env_version == glslang::EShTargetVulkan_1_0) {
781        return glslang::EShTargetSpv_1_0;
782    } else if (env_version == glslang::EShTargetVulkan_1_1) {
783        return glslang::EShTargetSpv_1_3;
784    } else if (env_version == glslang::EShTargetVulkan_1_2) {
785        return glslang::EShTargetSpv_1_5;
786#if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
787    } else if (env_version == glslang::EShTargetVulkan_1_3) {
788        return glslang::EShTargetSpv_1_6;
789#endif
790    } else {
791        return glslang::EShTargetSpv_1_0;
792    }
793}
794
795std::string preProcessShader(
796    std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
797{
798    glslang::EShTargetLanguageVersion languageVersion;
799    glslang::EShTargetClientVersion version;
800    EShLanguage stage;
801    switch (aKind) {
802        case ShaderKind::VERTEX:
803            stage = EShLanguage::EShLangVertex;
804            break;
805        case ShaderKind::FRAGMENT:
806            stage = EShLanguage::EShLangFragment;
807            break;
808        case ShaderKind::COMPUTE:
809            stage = EShLanguage::EShLangCompute;
810            break;
811        default:
812            LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderKind not recognized");
813            return {};
814    }
815
816    switch (settings.shaderEnv) {
817        case ShaderEnv::version_vulkan_1_0:
818            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
819            break;
820        case ShaderEnv::version_vulkan_1_1:
821            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
822            break;
823        case ShaderEnv::version_vulkan_1_2:
824            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
825            break;
826#if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
827        case ShaderEnv::version_vulkan_1_3:
828            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
829            break;
830#endif
831        default:
832            LUME_LOG_E("Spirv preprocessing compilation failed '%s'", "ShaderEnv not recognized");
833            return {};
834    }
835
836    languageVersion = ToSpirVVersion(version);
837
838    FileIncluder includer(settings);
839    glslang::TShader shader(stage);
840    const char* shader_strings = aSource.data();
841    const int shader_lengths = static_cast<int>(aSource.size());
842    const char* string_names = aSourceName.data();
843    std::string_view preamble = "#extension GL_GOOGLE_include_directive : enable\n";
844    shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
845    shader.setPreamble(preamble.data());
846    shader.setEntryPoint("main");
847    shader.setAutoMapBindings(false);
848    shader.setAutoMapLocations(false);
849    shader.setShiftImageBinding(0);
850    shader.setShiftSamplerBinding(0);
851    shader.setShiftTextureBinding(0);
852    shader.setShiftUboBinding(0);
853    shader.setShiftSsboBinding(0);
854    shader.setShiftUavBinding(0);
855    shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
856    shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
857    shader.setInvertY(false);
858    shader.setNanMinMaxClamp(false);
859
860    std::string output;
861    const EShMessages rules =
862        static_cast<EShMessages>(EShMsgOnlyPreprocessor | EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
863    if (!shader.preprocess(
864            &kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules, &output, includer)) {
865        LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
866        LUME_LOG_E("Spirv preprocessing compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
867
868        output = { output.begin() + preamble.size(), output.end() };
869        return {};
870    }
871
872    output = { output.begin() + preamble.size(), output.end() };
873    return output;
874}
875
876std::vector<uint32_t> compileShaderToSpirvBinary(
877    std::string_view aSource, ShaderKind aKind, std::string_view aSourceName, const CompilationSettings& settings)
878{
879    glslang::EShTargetLanguageVersion languageVersion;
880    glslang::EShTargetClientVersion version;
881    EShLanguage stage;
882    switch (aKind) {
883        case ShaderKind::VERTEX:
884            stage = EShLanguage::EShLangVertex;
885            break;
886        case ShaderKind::FRAGMENT:
887            stage = EShLanguage::EShLangFragment;
888            break;
889        case ShaderKind::COMPUTE:
890            stage = EShLanguage::EShLangCompute;
891            break;
892        default:
893            LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderKind not recognized");
894            return {};
895    }
896
897    switch (settings.shaderEnv) {
898        case ShaderEnv::version_vulkan_1_0:
899            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_0;
900            break;
901        case ShaderEnv::version_vulkan_1_1:
902            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_1;
903            break;
904        case ShaderEnv::version_vulkan_1_2:
905            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_2;
906            break;
907#if GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
908        case ShaderEnv::version_vulkan_1_3:
909            version = glslang::EShTargetClientVersion::EShTargetVulkan_1_3;
910            break;
911#endif
912        default:
913            LUME_LOG_E("Spirv binary compilation failed '%s'", "ShaderEnv not recognized");
914            return {};
915    }
916
917    languageVersion = ToSpirVVersion(version);
918
919    glslang::TShader shader(stage);
920    const char* shader_strings = aSource.data();
921    const int shader_lengths = static_cast<int>(aSource.size());
922    const char* string_names = aSourceName.data();
923    shader.setStringsWithLengthsAndNames(&shader_strings, &shader_lengths, &string_names, 1);
924    shader.setPreamble("#extension GL_GOOGLE_include_directive : enable\n");
925    shader.setEntryPoint("main");
926    shader.setAutoMapBindings(false);
927    shader.setAutoMapLocations(false);
928    shader.setShiftImageBinding(0);
929    shader.setShiftSamplerBinding(0);
930    shader.setShiftTextureBinding(0);
931    shader.setShiftUboBinding(0);
932    shader.setShiftSsboBinding(0);
933    shader.setShiftUavBinding(0);
934    shader.setEnvClient(glslang::EShClient::EShClientVulkan, version);
935    shader.setEnvTarget(glslang::EShTargetLanguage::EShTargetSpv, languageVersion);
936    shader.setInvertY(false);
937    shader.setNanMinMaxClamp(false);
938
939    const EShMessages rules = static_cast<EShMessages>(EShMsgSpvRules | EShMsgVulkanRules | EShMsgCascadingErrors);
940    if (!shader.parse(&kGLSLangDefaultTResource, 110, EProfile::ENoProfile, false, false, rules)) {
941        LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoLog());
942        LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), shader.getInfoDebugLog());
943        return {};
944    }
945
946    glslang::TProgram program;
947    program.addShader(&shader);
948    if (!program.link(EShMsgDefault) || !program.mapIO()) {
949        LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoLog());
950        LUME_LOG_E("Spirv binary compilation failed '%s':\n%s", aSourceName.data(), program.getInfoDebugLog());
951        return {};
952    }
953
954    std::vector<unsigned int> spirv;
955    glslang::SpvOptions spv_options;
956    spv_options.generateDebugInfo = false;
957    spv_options.disableOptimizer = true;
958    spv_options.optimizeSize = false;
959    spv::SpvBuildLogger logger;
960    glslang::TIntermediate* intermediate = program.getIntermediate(stage);
961    glslang::GlslangToSpv(*intermediate, spirv, &logger, &spv_options);
962
963    const uint32_t shaderc_generator_word = 13; // From SPIR-V XML Registry
964    const uint32_t generator_word_index = 2;    // SPIR-V 2.3: Physical layout
965    assert(spirv.size() > generator_word_index);
966    spirv[generator_word_index] = (spirv[generator_word_index] & 0xffff) | (shaderc_generator_word << 16);
967    return spirv;
968}
969
970void processResource(const spirv_cross::Compiler& compiler, const spirv_cross::Resource& resource,
971    ShaderStageFlags shaderStateFlags, DescriptorType type, DescriptorSetLayout* layouts)
972{
973    const uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
974
975    assert(set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT);
976    if (set >= PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT) {
977        return;
978    }
979    DescriptorSetLayout& layout = layouts[set];
980    layout.set = set;
981
982    // Collect bindings.
983    const uint32_t bindingIndex = compiler.get_decoration(resource.id, spv::DecorationBinding);
984    auto& bindings = layout.bindings;
985    if (auto pos = std::find_if(bindings.begin(), bindings.end(),
986            [bindingIndex](const DescriptorSetLayoutBinding& binding) { return binding.binding == bindingIndex; });
987        pos == bindings.end()) {
988        const spirv_cross::SPIRType& spirType = compiler.get_type(resource.type_id);
989
990        DescriptorSetLayoutBinding binding;
991        binding.binding = bindingIndex;
992        binding.descriptorType = type;
993        binding.descriptorCount = spirType.array.empty() ? 1 : spirType.array[0];
994        binding.shaderStageFlags = shaderStateFlags;
995
996        bindings.emplace_back(binding);
997    } else {
998        pos->shaderStageFlags |= shaderStateFlags;
999    }
1000}
1001
1002void reflectDescriptorSets(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1003    ShaderStageFlags shaderStateFlags, DescriptorSetLayout* layouts)
1004{
1005    for (const auto& ref : resources.sampled_images) {
1006        processResource(compiler, ref, shaderStateFlags, DescriptorType::COMBINED_IMAGE_SAMPLER, layouts);
1007    }
1008
1009    for (const auto& ref : resources.separate_samplers) {
1010        processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLER, layouts);
1011    }
1012
1013    for (const auto& ref : resources.separate_images) {
1014        processResource(compiler, ref, shaderStateFlags, DescriptorType::SAMPLED_IMAGE, layouts);
1015    }
1016
1017    for (const auto& ref : resources.storage_images) {
1018        processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_IMAGE, layouts);
1019    }
1020
1021    for (const auto& ref : resources.uniform_buffers) {
1022        processResource(compiler, ref, shaderStateFlags, DescriptorType::UNIFORM_BUFFER, layouts);
1023    }
1024
1025    for (const auto& ref : resources.storage_buffers) {
1026        processResource(compiler, ref, shaderStateFlags, DescriptorType::STORAGE_BUFFER, layouts);
1027    }
1028
1029    for (const auto& ref : resources.subpass_inputs) {
1030        processResource(compiler, ref, shaderStateFlags, DescriptorType::INPUT_ATTACHMENT, layouts);
1031    }
1032
1033    for (const auto& ref : resources.acceleration_structures) {
1034        processResource(compiler, ref, shaderStateFlags, DescriptorType::ACCELERATION_STRUCTURE, layouts);
1035    }
1036
1037    std::sort(layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT,
1038        [](const DescriptorSetLayout& lhs, const DescriptorSetLayout& rhs) { return (lhs.set < rhs.set); });
1039
1040    std::for_each(
1041        layouts, layouts + PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT, [](DescriptorSetLayout& layout) {
1042            std::sort(layout.bindings.begin(), layout.bindings.end(),
1043                [](const DescriptorSetLayoutBinding& lhs, const DescriptorSetLayoutBinding& rhs) {
1044                    return (lhs.binding < rhs.binding);
1045                });
1046        });
1047}
1048
1049void reflectPushContants(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1050    ShaderStageFlags shaderStateFlags, PushConstant& pushConstant)
1051{
1052    // NOTE: support for only one push constant
1053    if (resources.push_constant_buffers.size() > 0) {
1054        pushConstant.shaderStageFlags |= shaderStateFlags;
1055
1056        const auto ranges = compiler.get_active_buffer_ranges(resources.push_constant_buffers[0].id);
1057        const uint32_t byteSize = std::accumulate(
1058            ranges.begin(), ranges.end(), 0u, [](uint32_t byteSize, const spirv_cross::BufferRange& range) {
1059                return byteSize + static_cast<uint32_t>(range.range);
1060            });
1061        pushConstant.byteSize = std::max(pushConstant.byteSize, byteSize);
1062    }
1063}
1064
1065std::vector<ShaderSpecializationConstant> reflectSpecializationConstants(
1066    const spirv_cross::Compiler& compiler, ShaderStageFlags shaderStateFlags)
1067{
1068    std::vector<ShaderSpecializationConstant> specializationConstants;
1069    uint32_t offset = 0;
1070    for (auto const& constant : compiler.get_specialization_constants()) {
1071        if (constant.constant_id < RESERVED_CONSTANT_ID_INDEX) {
1072            const spirv_cross::SPIRConstant& spirvConstant = compiler.get_constant(constant.id);
1073            const auto type = compiler.get_type(spirvConstant.constant_type);
1074            ShaderSpecializationConstant::Type constantType = ShaderSpecializationConstant::Type::INVALID;
1075            if (type.basetype == spirv_cross::SPIRType::Boolean) {
1076                constantType = ShaderSpecializationConstant::Type::BOOL;
1077            } else if (type.basetype == spirv_cross::SPIRType::UInt) {
1078                constantType = ShaderSpecializationConstant::Type::UINT32;
1079            } else if (type.basetype == spirv_cross::SPIRType::Int) {
1080                constantType = ShaderSpecializationConstant::Type::INT32;
1081            } else if (type.basetype == spirv_cross::SPIRType::Float) {
1082                constantType = ShaderSpecializationConstant::Type::FLOAT;
1083            } else {
1084                assert(false && "Unhandled specialization constant type");
1085            }
1086            const uint32_t size = spirvConstant.vector_size() * spirvConstant.columns() * sizeof(uint32_t);
1087            specializationConstants.push_back(
1088                ShaderSpecializationConstant { shaderStateFlags, constant.constant_id, constantType, offset });
1089            offset += size;
1090        }
1091    }
1092    // sorted based on offset due to offset mapping with shader combinations
1093    // NOTE: id and name indexing
1094    std::sort(specializationConstants.begin(), specializationConstants.end(),
1095        [](const auto& lhs, const auto& rhs) { return (lhs.offset < rhs.offset); });
1096
1097    return specializationConstants;
1098}
1099
1100Format convertToVertexInputFormat(const spirv_cross::SPIRType& type)
1101{
1102    using BaseType = spirv_cross::SPIRType::BaseType;
1103
1104    // ivecn: a vector of signed integers
1105    if (type.basetype == BaseType::Int) {
1106        switch (type.vecsize) {
1107            case 1:
1108                return Format::R32_SINT;
1109            case 2:
1110                return Format::R32G32_SINT;
1111            case 3:
1112                return Format::R32G32B32_SINT;
1113            case 4:
1114                return Format::R32G32B32A32_SINT;
1115        }
1116    }
1117
1118    // uvecn: a vector of unsigned integers
1119    if (type.basetype == BaseType::UInt) {
1120        switch (type.vecsize) {
1121            case 1:
1122                return Format::R32_UINT;
1123            case 2:
1124                return Format::R32G32_UINT;
1125            case 3:
1126                return Format::R32G32B32_UINT;
1127            case 4:
1128                return Format::R32G32B32A32_UINT;
1129        }
1130    }
1131
1132    // halfn: a vector of half-precision floating-point numbers
1133    if (type.basetype == BaseType::Half) {
1134        switch (type.vecsize) {
1135            case 1:
1136                return Format::R16_SFLOAT;
1137            case 2:
1138                return Format::R16G16_SFLOAT;
1139            case 3:
1140                return Format::R16G16B16_SFLOAT;
1141            case 4:
1142                return Format::R16G16B16A16_SFLOAT;
1143        }
1144    }
1145
1146    // vecn: a vector of single-precision floating-point numbers
1147    if (type.basetype == BaseType::Float) {
1148        switch (type.vecsize) {
1149            case 1:
1150                return Format::R32_SFLOAT;
1151            case 2:
1152                return Format::R32G32_SFLOAT;
1153            case 3:
1154                return Format::R32G32B32_SFLOAT;
1155            case 4:
1156                return Format::R32G32B32A32_SFLOAT;
1157        }
1158    }
1159
1160    return Format::UNDEFINED;
1161}
1162
1163void reflectVertexInputs(const spirv_cross::Compiler& compiler, const spirv_cross::ShaderResources& resources,
1164    ShaderStageFlags /* shaderStateFlags */,
1165    std::vector<VertexInputDeclaration::VertexInputAttributeDescription>& vertexInputAttributes)
1166{
1167    std::vector<VertexAttributeInfo> vertexAttributeInfos;
1168
1169    // Vertex input attributes.
1170    for (auto& attr : resources.stage_inputs) {
1171        const spirv_cross::SPIRType attributeType = compiler.get_type(attr.type_id);
1172
1173        VertexAttributeInfo info;
1174
1175        // For now, assume that every vertex attribute comes from it's own binding which equals the location.
1176        info.description.location = compiler.get_decoration(attr.id, spv::DecorationLocation);
1177        info.description.binding = info.description.location;
1178        info.description.format = convertToVertexInputFormat(attributeType);
1179        info.description.offset = 0;
1180
1181        info.byteSize = attributeType.vecsize * (attributeType.width / 8);
1182
1183        vertexAttributeInfos.emplace_back(std::move(info));
1184    }
1185
1186    // Sort input attributes by binding and location.
1187    std::sort(std::begin(vertexAttributeInfos), std::end(vertexAttributeInfos),
1188        [](const VertexAttributeInfo& aA, const VertexAttributeInfo& aB) {
1189            if (aA.description.binding < aB.description.binding) {
1190                return true;
1191            }
1192
1193            return aA.description.location < aB.description.location;
1194        });
1195
1196    // Create final attributes.
1197    if (!vertexAttributeInfos.empty()) {
1198        for (auto& info : vertexAttributeInfos) {
1199            vertexInputAttributes.push_back(info.description);
1200        }
1201    }
1202}
1203
/**
 * Appends an integral value to a byte buffer, least significant byte first.
 *
 * At most four bytes are written: sizeof(T) bytes for types up to 32 bits, and only the low 32 bits of
 * wider types.
 *
 * @param buffer Buffer receiving the bytes.
 * @param data Value to append.
 */
template<typename T>
void push(std::vector<uint8_t>& buffer, T data)
{
    constexpr std::size_t emittedBytes = (sizeof(T) < 4) ? sizeof(T) : 4;
    for (std::size_t byteIndex = 0; byteIndex < emittedBytes; ++byteIndex) {
        buffer.push_back(static_cast<uint8_t>((data >> (8 * byteIndex)) & 0xff));
    }
}
1218
1219std::vector<uint8_t> reflectSpvBinary(const std::vector<uint32_t>& aBinary, ShaderKind aKind)
1220{
1221    const spirv_cross::Compiler compiler(aBinary);
1222
1223    const auto shaderStateFlags = ShaderStageFlags(aKind);
1224
1225    const spirv_cross::ShaderResources resources = compiler.get_shader_resources();
1226
1227    PipelineLayout pipelineLayout;
1228    reflectDescriptorSets(compiler, resources, shaderStateFlags, pipelineLayout.descriptorSetLayouts);
1229    pipelineLayout.descriptorSetCount =
1230        static_cast<uint32_t>(std::count_if(std::begin(pipelineLayout.descriptorSetLayouts),
1231            std::end(pipelineLayout.descriptorSetLayouts), [](const DescriptorSetLayout& layout) {
1232                return layout.set < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT;
1233            }));
1234    reflectPushContants(compiler, resources, shaderStateFlags, pipelineLayout.pushConstant);
1235
1236    // some additional information mainly for GL
1237    std::vector<Gles::PushConstantReflection> pushConstantReflection;
1238    for (auto& remap : resources.push_constant_buffers) {
1239        const auto& blockType = compiler.get_type(remap.base_type_id);
1240        auto name = compiler.get_name(remap.id);
1241        (void)(blockType);
1242        assert((blockType.basetype == spirv_cross::SPIRType::Struct) && "Push constant is not a struct!");
1243        Gles::ProcessStruct(std::string_view(name.data(), name.size()), 0, compiler, remap.base_type_id,
1244            pushConstantReflection, shaderStateFlags);
1245    }
1246
1247    auto specializationConstants = reflectSpecializationConstants(compiler, shaderStateFlags);
1248
1249    // NOTE: this is done for all although the name is 'Vertex'InputAttributes
1250    std::vector<VertexInputDeclaration::VertexInputAttributeDescription> vertexInputAttributes;
1251    reflectVertexInputs(compiler, resources, shaderStateFlags, vertexInputAttributes);
1252
1253    std::vector<uint8_t> reflection;
1254    reflection.reserve(512u);
1255    constexpr uint8_t TAG[] = { 'r', 'f', 'l', 0 }; // last one is version
1256    uint16_t type = 0;
1257    uint16_t offsetPushConstants = 0;
1258    uint16_t offsetSpecializationConstants = 0;
1259    uint16_t offsetDescriptorSets = 0;
1260    uint16_t offsetInputs = 0;
1261    uint16_t offsetLocalSize = 0;
1262    // tag
1263    {
1264        reflection.insert(reflection.end(), std::begin(TAG), std::end(TAG));
1265    }
1266    // shader type
1267    {
1268        push(reflection, static_cast<uint16_t>(shaderStateFlags.flags));
1269    }
1270    // offsets
1271    {
1272        reflection.resize(reflection.size() + sizeof(uint16_t) * 5);
1273    }
1274    // push constants
1275    {
1276        offsetPushConstants = static_cast<uint16_t>(reflection.size());
1277        if (pipelineLayout.pushConstant.byteSize) {
1278            reflection.push_back(1);
1279            push(reflection, static_cast<uint16_t>(pipelineLayout.pushConstant.byteSize));
1280
1281            push(reflection, static_cast<uint8_t>(pushConstantReflection.size()));
1282            for (const auto& refl : pushConstantReflection) {
1283                push(reflection, refl.type);
1284                push(reflection, static_cast<uint16_t>(refl.offset));
1285                push(reflection, static_cast<uint16_t>(refl.size));
1286                push(reflection, static_cast<uint16_t>(refl.arraySize));
1287                push(reflection, static_cast<uint16_t>(refl.arrayStride));
1288                push(reflection, static_cast<uint16_t>(refl.matrixStride));
1289                push(reflection, static_cast<uint16_t>(refl.name.size()));
1290                reflection.insert(reflection.end(), std::begin(refl.name), std::end(refl.name));
1291            }
1292        } else {
1293            reflection.push_back(0);
1294        }
1295    }
1296    // specialization constants
1297    {
1298        offsetSpecializationConstants = static_cast<uint16_t>(reflection.size());
1299        {
1300            const auto size = static_cast<uint32_t>(specializationConstants.size());
1301            push(reflection, static_cast<uint32_t>(specializationConstants.size()));
1302        }
1303        for (auto const& constant : specializationConstants) {
1304            push(reflection, static_cast<uint32_t>(constant.id));
1305            push(reflection, static_cast<uint32_t>(constant.type));
1306        }
1307    }
1308    // descriptor sets
1309    {
1310        offsetDescriptorSets = static_cast<uint16_t>(reflection.size());
1311        {
1312            push(reflection, static_cast<uint16_t>(pipelineLayout.descriptorSetCount));
1313        }
1314        auto begin = std::begin(pipelineLayout.descriptorSetLayouts);
1315        auto end = begin;
1316        std::advance(end, pipelineLayout.descriptorSetCount);
1317        std::for_each(begin, end, [&reflection](const DescriptorSetLayout& layout) {
1318            push(reflection, static_cast<uint16_t>(layout.set));
1319            push(reflection, static_cast<uint16_t>(layout.bindings.size()));
1320            for (const auto& binding : layout.bindings) {
1321                push(reflection, static_cast<uint16_t>(binding.binding));
1322                push(reflection, static_cast<uint16_t>(binding.descriptorType));
1323                push(reflection, static_cast<uint16_t>(binding.descriptorCount));
1324            }
1325        });
1326    }
1327    // inputs
1328    {
1329        offsetInputs = static_cast<uint16_t>(reflection.size());
1330        const auto size = static_cast<uint16_t>(vertexInputAttributes.size());
1331        push(reflection, size);
1332        for (const auto& input : vertexInputAttributes) {
1333            push(reflection, static_cast<uint16_t>(input.location));
1334            push(reflection, static_cast<uint16_t>(input.format));
1335        }
1336    }
1337    // local size
1338    if (shaderStateFlags & ShaderStageFlagBits::COMPUTE_BIT) {
1339        offsetLocalSize = static_cast<uint16_t>(reflection.size());
1340        uint32_t size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 0);
1341        push(reflection, size);
1342
1343        size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 1);
1344        push(reflection, size);
1345
1346        size = compiler.get_execution_mode_argument(spv::ExecutionMode::ExecutionModeLocalSize, 2);
1347        push(reflection, size);
1348    }
1349    // update offsets to real values
1350    {
1351        auto ptr = reflection.data() + (sizeof(TAG) + sizeof(type));
1352        *ptr++ = offsetPushConstants & 0xff;
1353        *ptr++ = (offsetPushConstants >> 8) & 0xff;
1354        *ptr++ = offsetSpecializationConstants & 0xff;
1355        *ptr++ = (offsetSpecializationConstants >> 8) & 0xff;
1356        *ptr++ = offsetDescriptorSets & 0xff;
1357        *ptr++ = (offsetDescriptorSets >> 8) & 0xff;
1358        *ptr++ = offsetInputs & 0xff;
1359        *ptr++ = (offsetInputs >> 8) & 0xff;
1360        *ptr++ = offsetLocalSize & 0xff;
1361        *ptr++ = (offsetLocalSize >> 8) & 0xff;
1362    }
1363
1364    return reflection;
1365}
1366
/** Descriptor set index and binding index of a resource, each narrowed to 8 bits. */
struct Binding {
    /** Descriptor set index */
    uint8_t set;
    /** Binding index inside the set */
    uint8_t bind;
};
1371
1372Binding get_binding(Gles::CoreCompiler& compiler, spirv_cross::ID id)
1373{
1374    const uint32_t dset = compiler.get_decoration(id, spv::Decoration::DecorationDescriptorSet);
1375    const uint32_t dbind = compiler.get_decoration(id, spv::Decoration::DecorationBinding);
1376    assert(dset < Gles::ResourceLimits::MAX_SETS);
1377    assert(dbind < Gles::ResourceLimits::MAX_BIND_IN_SET);
1378    const uint8_t set = static_cast<uint8_t>(dset);
1379    const uint8_t bind = static_cast<uint8_t>(dbind);
1380    return { set, bind };
1381}
1382
1383void SortSets(PipelineLayout& pipelineLayout)
1384{
1385    pipelineLayout.descriptorSetCount = 0;
1386    for (uint32_t idx = 0; idx < PipelineLayoutConstants::MAX_DESCRIPTOR_SET_COUNT; ++idx) {
1387        DescriptorSetLayout& currSet = pipelineLayout.descriptorSetLayouts[idx];
1388        if (currSet.set != PipelineLayoutConstants::INVALID_INDEX) {
1389            pipelineLayout.descriptorSetCount++;
1390            std::sort(currSet.bindings.begin(), currSet.bindings.end(),
1391                [](auto const& lhs, auto const& rhs) { return (lhs.binding < rhs.binding); });
1392        }
1393    }
1394}
1395
1396void Collect(Gles::CoreCompiler& compiler, const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
1397    const uint32_t forceBinding = 0)
1398{
1399    std::string name;
1400
1401    for (const auto& remap : resources) {
1402        const auto binding = get_binding(compiler, remap.id);
1403
1404        name.resize(name.capacity() - 1);
1405        const auto nameLen = sprintf(name.data(), "s%u_b%u", binding.set, binding.bind);
1406        name.resize(nameLen);
1407
1408        // if name is empty it's a block and we need to rename the base_type_id i.e.
1409        // uniform <base_type_id> { vec4 foo; } <id>;
1410        if (auto origname = compiler.get_name(remap.id); origname.empty()) {
1411            compiler.set_name(remap.base_type_id, name);
1412            name.insert(name.begin(), '_');
1413            compiler.set_name(remap.id, name);
1414        } else {
1415            // uniform <id> vec4 foo;
1416            compiler.set_name(remap.id, name);
1417        }
1418
1419        compiler.unset_decoration(remap.id, spv::DecorationDescriptorSet);
1420        compiler.unset_decoration(remap.id, spv::DecorationBinding);
1421        if (forceBinding > 0) {
1422            compiler.set_decoration(
1423                remap.id, spv::DecorationBinding, forceBinding - 1); // will be over-written later. (special handling)
1424        }
1425    }
1426}
1427
1428struct ShaderModulePlatformDataGLES {
1429    std::vector<Gles::PushConstantReflection> infos;
1430};
1431
1432void CollectRes(
1433    Gles::CoreCompiler& compiler, const spirv_cross::ShaderResources& res, ShaderModulePlatformDataGLES& plat_)
1434{
1435    // collect names for later linkage
1436    static constexpr uint32_t DefaultBinding = 11;
1437    Collect(compiler, res.storage_buffers, DefaultBinding + 1);
1438    Collect(compiler, res.storage_images, DefaultBinding + 1);
1439    Collect(compiler, res.uniform_buffers, 0); // 0 == remove binding decorations (let's the compiler decide)
1440    Collect(compiler, res.subpass_inputs, 0);  // 0 == remove binding decorations (let's the compiler decide)
1441
1442    // handle the real sampled images.
1443    Collect(compiler, res.sampled_images, 0); // 0 == remove binding decorations (let's the compiler decide)
1444
1445    // and now the "generated ones" (separate image/sampler)
1446    std::string imageName;
1447    std::string samplerName;
1448    std::string temp;
1449    for (auto& remap : compiler.get_combined_image_samplers()) {
1450        const auto imageBinding = get_binding(compiler, remap.image_id);
1451        {
1452            imageName.resize(imageName.capacity() - 1);
1453            const auto nameLen = sprintf(imageName.data(), "s%u_b%u", imageBinding.set, imageBinding.bind);
1454            imageName.resize(nameLen);
1455        }
1456        const auto samplerBinding = get_binding(compiler, remap.sampler_id);
1457        {
1458            samplerName.resize(samplerName.capacity() - 1);
1459            const auto nameLen = sprintf(samplerName.data(), "s%u_b%u", samplerBinding.set, samplerBinding.bind);
1460            samplerName.resize(nameLen);
1461        }
1462
1463        temp.reserve(imageName.size() + samplerName.size() + 1);
1464        temp.clear();
1465        temp.append(imageName);
1466        temp.append(1, '_');
1467        temp.append(samplerName);
1468        compiler.set_name(remap.combined_id, temp);
1469    }
1470}
1471
/** Device backend type (rendering API the shader is prepared for) */
enum class DeviceBackendType {
    /** Vulkan backend */
    VULKAN = 0,
    /** GLES backend */
    OPENGLES = 1,
    /** OpenGL backend */
    OPENGL = 2
};
1481
1482void SetupSpirvCross(ShaderStageFlags stage, Gles::CoreCompiler* compiler, DeviceBackendType backend, bool ovrEnabled)
1483{
1484    spirv_cross::CompilerGLSL::Options options;
1485
1486    if (backend == DeviceBackendType::OPENGLES) {
1487        options.version = 320;
1488        options.es = true;
1489        options.fragment.default_float_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1490        options.fragment.default_int_precision = spirv_cross::CompilerGLSL::Options::Precision::Highp;
1491    }
1492
1493    if (backend == DeviceBackendType::OPENGL) {
1494        options.version = 450;
1495        options.es = false;
1496    }
1497
1498#if defined(CORE_USE_SEPARATE_SHADER_OBJECTS) && (CORE_USE_SEPARATE_SHADER_OBJECTS == 1)
1499    if (stage & (CORE_SHADER_STAGE_VERTEX_BIT | CORE_SHADER_STAGE_FRAGMENT_BIT)) {
1500        options.separate_shader_objects = true;
1501    }
1502#endif
1503
1504    options.ovr_multiview_view_count = ovrEnabled ? 1 : 0;
1505
1506    compiler->set_common_options(options);
1507}
1508
1509struct Shader {
1510    ShaderStageFlags shaderStageFlags_;
1511    DeviceBackendType backend_;
1512    ShaderModulePlatformDataGLES plat_;
1513    bool ovrEnabled;
1514
1515    std::string source_;
1516};
1517
1518void ProcessShaderModule(Shader& me, const ShaderModuleCreateInfo& createInfo)
1519{
1520    // perform reflection.
1521    auto compiler = Gles::CoreCompiler(reinterpret_cast<const uint32_t*>(createInfo.spvData.data()),
1522        static_cast<uint32_t>(createInfo.spvData.size() / sizeof(uint32_t)));
1523    // Set some options.
1524    SetupSpirvCross(me.shaderStageFlags_, &compiler, me.backend_, me.ovrEnabled);
1525
1526    // first step in converting CORE_FLIP_NDC to regular uniform. (specializationconstant -> constant) this makes the
1527    // compiled glsl more readable, and simpler to post process later.
1528    Gles::ConvertSpecConstToConstant(compiler, "CORE_FLIP_NDC");
1529    // const auto& res = compiler.get_shader_resources();
1530
1531    auto active = compiler.get_active_interface_variables();
1532    const auto& res = compiler.get_shader_resources(active);
1533    compiler.set_enabled_interface_variables(std::move(active));
1534
1535    Gles::ReflectPushConstants(compiler, res, me.plat_.infos, me.shaderStageFlags_);
1536    compiler.build_combined_image_samplers();
1537    CollectRes(compiler, res, me.plat_);
1538
1539    // set "CORE_BACKEND_TYPE" specialization to 1.
1540    Gles::SetSpecMacro(compiler, "CORE_BACKEND_TYPE", 1U);
1541
1542    me.source_ = compiler.compile();
1543    Gles::ConvertConstantToUniform(compiler, me.source_, "CORE_FLIP_NDC");
1544}
1545
1546template<typename T>
1547bool writeToFile(const array_view<T>& data, std::filesystem::path aDestinationFile)
1548{
1549    std::ofstream outputStream(aDestinationFile, std::ios::out | std::ios::binary);
1550    if (outputStream.is_open()) {
1551        outputStream.write((const char*)data.data(), data.size() * sizeof(T));
1552        outputStream.close();
1553        return true;
1554    } else {
1555        LUME_LOG_E("Could not write file: '%s'", aDestinationFile.string().c_str());
1556        return false;
1557    }
1558}
1559
1560bool runAllCompilationStages(std::string_view inputFilename, CompilationSettings& settings)
1561{
1562    try {
1563        // std::string inputFilename = aFile;
1564        const std::filesystem::path relativeInputFilename =
1565            std::filesystem::relative(inputFilename, settings.shaderSourcePath);
1566        const std::string relativeFilename = relativeInputFilename.string();
1567        const std::string extension = std::filesystem::path(inputFilename).extension().string();
1568        std::filesystem::path outputFilename = settings.compiledShaderDestinationPath / relativeInputFilename;
1569
1570        // Make sure the output dir hierarchy exists.
1571        std::filesystem::create_directories(outputFilename.parent_path());
1572
1573        // Just copying json files to the destination dir.
1574        if (extension == ".json") {
1575            if (!std::filesystem::exists(outputFilename) ||
1576                !std::filesystem::equivalent(inputFilename, outputFilename)) {
1577                LUME_LOG_I("  %s", relativeFilename.c_str());
1578                std::filesystem::copy(inputFilename, outputFilename, std::filesystem::copy_options::overwrite_existing);
1579            }
1580            return true;
1581        } else {
1582            LUME_LOG_I("  %s", relativeFilename.c_str());
1583            outputFilename += ".spv";
1584
1585            LUME_LOG_V("    input: '%s'", inputFilename.data());
1586            LUME_LOG_V("      dst: '%s'", settings.compiledShaderDestinationPath.string().c_str());
1587            LUME_LOG_V(" relative: '%s'", relativeFilename.c_str());
1588            LUME_LOG_V("   output: '%s'", outputFilename.string().c_str());
1589
1590            if (std::string shaderSource = readFileToString(inputFilename); !shaderSource.empty()) {
1591                ShaderKind shaderKind;
1592                if (extension == ".vert") {
1593                    shaderKind = ShaderKind::VERTEX;
1594                } else if (extension == ".frag") {
1595                    shaderKind = ShaderKind::FRAGMENT;
1596                } else if (extension == ".comp") {
1597                    shaderKind = ShaderKind::COMPUTE;
1598                } else {
1599                    return false;
1600                }
1601
1602                if (std::string preProcessedShader =
1603                        preProcessShader(shaderSource, shaderKind, relativeFilename, settings);
1604                    !preProcessedShader.empty()) {
1605                    if (true) {
1606                        auto reflectionFile = outputFilename;
1607                        reflectionFile += ".pre";
1608                        if (!writeToFile(
1609                                array_view(preProcessedShader.data(), preProcessedShader.size()), reflectionFile)) {
1610                            LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1611                        }
1612                    }
1613
1614                    if (std::vector<uint32_t> spvBinary =
1615                            compileShaderToSpirvBinary(preProcessedShader, shaderKind, relativeFilename, settings);
1616                        !spvBinary.empty()) {
1617                        const auto reflection = reflectSpvBinary(spvBinary, shaderKind);
1618                        if (reflection.empty()) {
1619                            LUME_LOG_E("Failed to reflect %s", inputFilename.data());
1620                        } else {
1621                            auto reflectionFile = outputFilename;
1622                            reflectionFile += ".lsb";
1623                            if (!writeToFile(array_view(reflection.data(), reflection.size()), reflectionFile)) {
1624                                LUME_LOG_E("Failed to save reflection %s", reflectionFile.string().data());
1625                            }
1626                        }
1627
1628                        if (settings.optimizer) {
1629                            // spirv-opt resets the passes everytime so then need to be setup
1630                            settings.optimizer->RegisterPerformancePasses();
1631                            if (!settings.optimizer->Run(spvBinary.data(), spvBinary.size(), &spvBinary)) {
1632                                LUME_LOG_E("Failed to optimize %s", inputFilename.data());
1633                            }
1634                        }
1635
1636                        if (writeToFile(array_view(spvBinary.data(), spvBinary.size()), outputFilename)) {
1637                            LUME_LOG_D("  -> %s", outputFilename.string().c_str());
1638                            auto glFile = outputFilename;
1639                            glFile += ".gl";
1640                            try {
1641                                bool multiviewEnabled = false;
1642                                if (shaderKind == ShaderKind::VERTEX) {
1643                                    static constexpr const std::string_view multiview = "GL_EXT_multiview";
1644                                    for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1645                                         pos = shaderSource.find(multiview, pos + multiview.size())) {
1646                                        if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1647                                            (shaderSource.find("enabled", pos + multiview.size()) !=
1648                                                std::string::npos)) {
1649                                            multiviewEnabled = true;
1650                                            break;
1651                                        }
1652                                    }
1653                                }
1654                                Shader shader;
1655                                shader.shaderStageFlags_ =
1656                                    shaderKind == ShaderKind::VERTEX
1657                                        ? ShaderStageFlagBits::VERTEX_BIT
1658                                        : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1659                                                                              : ShaderStageFlagBits::COMPUTE_BIT);
1660
1661                                shader.backend_ = DeviceBackendType::OPENGL;
1662                                shader.ovrEnabled = multiviewEnabled;
1663                                ShaderModuleCreateInfo info;
1664                                info.shaderStageFlags =
1665                                    shaderKind == ShaderKind::VERTEX
1666                                        ? ShaderStageFlagBits::VERTEX_BIT
1667                                        : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1668                                                                              : ShaderStageFlagBits::COMPUTE_BIT);
1669                                info.spvData =
1670                                    array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1671                                        spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1672                                info.reflectionData.reflectionData =
1673                                    array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1674                                        reflection.size() * sizeof(decltype(reflection)::value_type));
1675                                ProcessShaderModule(shader, info);
1676                                writeToFile(array_view(static_cast<const uint8_t*>(
1677                                                           static_cast<const void*>(shader.source_.data())),
1678                                                shader.source_.size()),
1679                                    glFile);
1680                            } catch (std::exception const& e) {
1681                                LUME_LOG_E("Failed to generate GL shader for '%s': %s", inputFilename.data(), e.what());
1682                            }
1683
1684                            auto glesFile = glFile;
1685                            glesFile += "es";
1686                            try {
1687                                bool multiviewEnabled = false;
1688                                if (shaderKind == ShaderKind::VERTEX) {
1689                                    static constexpr const std::string_view multiview = "GL_EXT_multiview";
1690                                    for (auto pos = shaderSource.find(multiview); pos != std::string::npos;
1691                                         pos = shaderSource.find(multiview, pos + multiview.size())) {
1692                                        if ((shaderSource.rfind("#extension", pos) != std::string::npos) &&
1693                                            (shaderSource.find("enabled", pos + multiview.size()) !=
1694                                                std::string::npos)) {
1695                                            multiviewEnabled = true;
1696                                            break;
1697                                        }
1698                                    }
1699                                }
1700                                Shader shader;
1701                                shader.shaderStageFlags_ =
1702                                    shaderKind == ShaderKind::VERTEX
1703                                        ? ShaderStageFlagBits::VERTEX_BIT
1704                                        : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1705                                                                              : ShaderStageFlagBits::COMPUTE_BIT);
1706
1707                                shader.backend_ = DeviceBackendType::OPENGLES;
1708                                shader.ovrEnabled = multiviewEnabled;
1709                                ShaderModuleCreateInfo info;
1710                                info.shaderStageFlags =
1711                                    shaderKind == ShaderKind::VERTEX
1712                                        ? ShaderStageFlagBits::VERTEX_BIT
1713                                        : (shaderKind == ShaderKind::FRAGMENT ? ShaderStageFlagBits::FRAGMENT_BIT
1714                                                                              : ShaderStageFlagBits::COMPUTE_BIT);
1715                                info.spvData =
1716                                    array_view(static_cast<const uint8_t*>(static_cast<const void*>(spvBinary.data())),
1717                                        spvBinary.size() * sizeof(decltype(spvBinary)::value_type));
1718                                info.reflectionData.reflectionData =
1719                                    array_view(static_cast<const uint8_t*>(static_cast<const void*>(reflection.data())),
1720                                        reflection.size() * sizeof(decltype(reflection)::value_type));
1721                                ProcessShaderModule(shader, info);
1722                                writeToFile(array_view(static_cast<const uint8_t*>(
1723                                                           static_cast<const void*>(shader.source_.data())),
1724                                                shader.source_.size()),
1725                                    glesFile);
1726                            } catch (std::exception const& e) {
1727                                LUME_LOG_E(
1728                                    "Failed to generate GLES shader for '%s': %s", inputFilename.data(), e.what());
1729                            }
1730
1731                            return true;
1732                        }
1733                    }
1734                }
1735            }
1736        }
1737    } catch (std::exception const& e) {
1738        LUME_LOG_E("Processing file failed '%s': %s", inputFilename.data(), e.what());
1739    }
1740    return false;
1741}
1742
/**
 * Prints the supported shader types and the command line options understood by main() to standard output.
 */
void show_usage()
{
    std::cout << "LumeShaderCompiler - Supported shader types: vertex (.vert), fragment (.frag), compute (.comp)"
              << std::endl
              << std::endl;
    // The option list below mirrors the argument parsing done in main(); keep the two in sync.
    std::cout << "How to use: \n"
                 "LumeShaderCompiler.exe --source <source path> (sets destination path to same as source)\n"
                 "LumeShaderCompiler.exe --source <source path> --destination <destination path>\n"
                 "LumeShaderCompiler.exe --monitor (monitors changes in the source files)\n"
                 "\n"
                 "Other options: \n"
                 "--help (shows this message)\n"
                 "--sourceFile <source file> (processes a single file, source path is the directory of the file)\n"
                 "--include <include path> (adds an include search path, can be given multiple times)\n"
                 "--optimize (optimizes the generated SPIR-V)\n"
                 "--vulkan <version> (target Vulkan version: 1.0, 1.1, 1.2 or 1.3 when supported, default 1.0)"
              << std::endl;
}
1754
/**
 * Selects the file names whose extension is one of the accepted extensions.
 *
 * The extension of each file name is converted to lower case before it is compared, so the entries of
 * aIncludeExtensions are expected to be lower case and to include the leading dot (e.g. ".vert").
 *
 * @param aFilenames File names to filter.
 * @param aIncludeExtensions Accepted extensions.
 * @return The accepted file names, unmodified and in their original order.
 */
std::vector<std::string> filterByExtension(
    const std::vector<std::string>& aFilenames, const std::vector<std::string_view>& aIncludeExtensions)
{
    std::vector<std::string> filtered;
    for (auto const& file : aFilenames) {
        std::string lowercaseFileExt = std::filesystem::path(file).extension().string();
        // tolower() has undefined behaviour for negative values (other than EOF), which is what a plain char holds
        // for non-ASCII bytes. Going through unsigned char keeps the call well defined.
        std::transform(lowercaseFileExt.begin(), lowercaseFileExt.end(), lowercaseFileExt.begin(),
            [](unsigned char character) { return static_cast<char>(tolower(character)); });

        const bool included = std::any_of(aIncludeExtensions.cbegin(), aIncludeExtensions.cend(),
            [&lowercaseFileExt](std::string_view extension) { return lowercaseFileExt == extension; });
        if (included) {
            filtered.push_back(file);
        }
    }

    return filtered;
}
1773
1774int main(int argc, char* argv[])
1775{
1776    if (argc == 1) {
1777        show_usage();
1778        return 0;
1779    }
1780
1781    std::filesystem::path const currentFolder = std::filesystem::current_path();
1782    std::filesystem::path shaderSourcesPath = currentFolder;
1783    std::filesystem::path compiledShaderDestinationPath;
1784    std::vector<std::filesystem::path> shaderIncludePaths;
1785    std::filesystem::path sourceFile;
1786
1787    bool monitorChanges = false;
1788    bool optimizeSpirv = false;
1789    ShaderEnv envVersion = ShaderEnv::version_vulkan_1_0;
1790    for (int i = 1; i < argc; ++i) {
1791        const auto arg = std::string_view(argv[i]);
1792        if (arg == "--help") {
1793            show_usage();
1794            return 0;
1795        } else if (arg == "--sourceFile") {
1796            if (i + 1 < argc) {
1797                sourceFile = argv[++i];
1798                sourceFile.make_preferred();
1799                shaderSourcesPath = sourceFile;
1800                shaderSourcesPath.remove_filename();
1801                if (compiledShaderDestinationPath.empty()) {
1802                    compiledShaderDestinationPath = shaderSourcesPath;
1803                }
1804            } else {
1805                LUME_LOG_E("--sourceFile option requires one argument.\n");
1806                return 1;
1807            }
1808        } else if (arg == "--source") {
1809            if (i + 1 < argc) {
1810                shaderSourcesPath = argv[++i];
1811                shaderSourcesPath.make_preferred();
1812                if (compiledShaderDestinationPath.empty()) {
1813                    compiledShaderDestinationPath = shaderSourcesPath;
1814                }
1815            } else {
1816                LUME_LOG_E("--source option requires one argument.");
1817                return 1;
1818            }
1819        } else if (arg == "--destination") {
1820            if (i + 1 < argc) {
1821                compiledShaderDestinationPath = argv[++i];
1822                compiledShaderDestinationPath.make_preferred();
1823            } else {
1824                LUME_LOG_E("--destination option requires one argument.");
1825                return 1;
1826            }
1827        } else if (arg == "--include") {
1828            if (i + 1 < argc) {
1829                shaderIncludePaths.emplace_back(argv[++i]).make_preferred();
1830            } else {
1831                LUME_LOG_E("--include option requires one argument.");
1832                return 1;
1833            }
1834
1835        } else if (arg == "--monitor") {
1836            monitorChanges = true;
1837        } else if (arg == "--optimize") {
1838            optimizeSpirv = true;
1839        } else if (arg == "--vulkan") {
1840            if (i + 1 < argc) {
1841                const auto version = std::string_view(argv[++i]);
1842                if (version == "1.0") {
1843                    envVersion = ShaderEnv::version_vulkan_1_0;
1844                } else if (version == "1.1") {
1845                    envVersion = ShaderEnv::version_vulkan_1_1;
1846                } else if (version == "1.2") {
1847                    envVersion = ShaderEnv::version_vulkan_1_2;
1848#ifdef GLSLANG_VERSION >= GLSLANG_VERSION_12_2_0
1849                } else if (version == "1.3") {
1850                    envVersion = ShaderEnv::version_vulkan_1_3;
1851#endif
1852                } else {
1853                    LUME_LOG_E("Invalid argument for option --vulkan.");
1854                    return 1;
1855                }
1856            } else {
1857                LUME_LOG_E("--vulkan option requires one argument.");
1858                return 1;
1859            }
1860        }
1861    }
1862
1863    if (compiledShaderDestinationPath.empty()) {
1864        compiledShaderDestinationPath = currentFolder;
1865    }
1866
1867    ige::FileMonitor fileMonitor;
1868
1869    if (!std::filesystem::exists(shaderSourcesPath)) {
1870        LUME_LOG_E("Source path does not exist: '%s'", shaderSourcesPath.string().c_str());
1871        return 1;
1872    }
1873
1874    // Make sure the destination dir exists.
1875    std::filesystem::create_directories(compiledShaderDestinationPath);
1876
1877    if (!std::filesystem::exists(compiledShaderDestinationPath)) {
1878        LUME_LOG_E("Destination path does not exist: '%s'", compiledShaderDestinationPath.string().c_str());
1879        return 1;
1880    }
1881
1882    fileMonitor.addPath(shaderSourcesPath.string());
1883    std::vector<std::string> fileList = [&]() {
1884        std::vector<std::string> list;
1885        if (!sourceFile.empty()) {
1886            list.push_back(sourceFile.u8string());
1887        } else {
1888            list = fileMonitor.getMonitoredFiles();
1889        }
1890        return list;
1891    }();
1892
1893    const std::vector<std::string_view> supportedFileTypes = { ".vert", ".frag", ".comp", ".json" };
1894    fileList = filterByExtension(fileList, supportedFileTypes);
1895
1896    LUME_LOG_I("     Source path: '%s'", std::filesystem::absolute(shaderSourcesPath).string().c_str());
1897    for (auto const& path : shaderIncludePaths) {
1898        LUME_LOG_I("    Include path: '%s'", std::filesystem::absolute(path).string().c_str());
1899    }
1900    LUME_LOG_I("Destination path: '%s'", std::filesystem::absolute(compiledShaderDestinationPath).string().c_str());
1901    LUME_LOG_I("");
1902    LUME_LOG_I("Processing:");
1903
1904    int errorCount = 0;
1905    scope scope([]() { glslang::InitializeProcess(); }, []() { glslang::FinalizeProcess(); });
1906
1907    std::vector<std::filesystem::path> searchPath;
1908    searchPath.reserve(searchPath.size() + 1 + shaderIncludePaths.size());
1909    searchPath.emplace_back(shaderSourcesPath.string());
1910    for (auto const& path : shaderIncludePaths) {
1911        searchPath.emplace_back(path.string());
1912    }
1913
1914    auto settings =
1915        CompilationSettings { envVersion, searchPath, {}, shaderSourcesPath, compiledShaderDestinationPath };
1916
1917    if (optimizeSpirv) {
1918        spv_target_env targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1919        switch (envVersion) {
1920            case ShaderEnv::version_vulkan_1_0:
1921                targetEnv = spv_target_env::SPV_ENV_VULKAN_1_0;
1922                break;
1923            case ShaderEnv::version_vulkan_1_1:
1924                targetEnv = spv_target_env::SPV_ENV_VULKAN_1_1;
1925                break;
1926            case ShaderEnv::version_vulkan_1_2:
1927                targetEnv = spv_target_env::SPV_ENV_VULKAN_1_2;
1928                break;
1929            case ShaderEnv::version_vulkan_1_3:
1930                targetEnv = spv_target_env::SPV_ENV_VULKAN_1_3;
1931                break;
1932            default:
1933                break;
1934        }
1935        settings.optimizer.emplace(targetEnv);
1936    }
1937
1938    // Startup compilation.
1939    for (auto const& file : fileList) {
1940        std::string relativeFilename = std::filesystem::relative(file, shaderSourcesPath).string();
1941        LUME_LOG_D("Tracked source file: '%s'", relativeFilename.c_str());
1942        if (!runAllCompilationStages(file, settings)) {
1943            errorCount++;
1944        }
1945    }
1946
1947    if (errorCount == 0) {
1948        LUME_LOG_I("Success.");
1949    } else {
1950        LUME_LOG_E("Failed: %d", errorCount);
1951    }
1952
1953    if (monitorChanges) {
1954        LUME_LOG_I("Monitoring file changes.");
1955    }
1956
1957    // Main loop.
1958    while (monitorChanges) {
1959        std::vector<std::string> addedFiles, removedFiles, modifiedFiles;
1960        fileMonitor.scanModifications(addedFiles, removedFiles, modifiedFiles);
1961        modifiedFiles = filterByExtension(modifiedFiles, supportedFileTypes);
1962
1963        if (sourceFile.empty()) {
1964            addedFiles = filterByExtension(addedFiles, supportedFileTypes);
1965            removedFiles = filterByExtension(removedFiles, supportedFileTypes);
1966
1967            if (!addedFiles.empty()) {
1968                LUME_LOG_I("Files added:");
1969                for (auto const& addedFile : addedFiles) {
1970                    runAllCompilationStages(addedFile, settings);
1971                }
1972            }
1973
1974            if (!removedFiles.empty()) {
1975                LUME_LOG_I("Files removed:");
1976                for (auto const& removedFile : removedFiles) {
1977                    std::string relativeFilename = std::filesystem::relative(removedFile, shaderSourcesPath).string();
1978                    LUME_LOG_I("  %s", relativeFilename.c_str());
1979                }
1980            }
1981
1982            if (!modifiedFiles.empty()) {
1983                LUME_LOG_I("Files modified:");
1984                for (auto const& modifiedFile : modifiedFiles) {
1985                    runAllCompilationStages(modifiedFile, settings);
1986                }
1987            }
1988        } else if (!modifiedFiles.empty()) {
1989            if (auto pos = std::find_if(modifiedFiles.cbegin(), modifiedFiles.cend(),
1990                    [&sourceFile](const std::string& modified) { return modified == sourceFile; });
1991                pos != modifiedFiles.cend()) {
1992                runAllCompilationStages(*pos, settings);
1993            }
1994        }
1995
1996        std::this_thread::sleep_for(std::chrono::seconds(1));
1997    }
1998
1999    return errorCount;
2000}