1//
2// Copyright (C) 2016-2017 Google, Inc.
3// Copyright (C) 2020 The Khronos Group Inc.
4//
5// All rights reserved.
6//
7// Redistribution and use in source and binary forms, with or without
8// modification, are permitted provided that the following conditions
9// are met:
10//
11//    Redistributions of source code must retain the above copyright
12//    notice, this list of conditions and the following disclaimer.
13//
14//    Redistributions in binary form must reproduce the above
15//    copyright notice, this list of conditions and the following
16//    disclaimer in the documentation and/or other materials provided
17//    with the distribution.
18//
19//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
20//    contributors may be used to endorse or promote products derived
21//    from this software without specific prior written permission.
22//
23// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34// POSSIBILITY OF SUCH DAMAGE.
35//
36#include <algorithm>
37
38#include <gtest/gtest.h>
39
40#include "TestFixture.h"
41
42#include "glslang/MachineIndependent/iomapper.h"
43#include "glslang/MachineIndependent/reflection.h"
44
45namespace glslangtest {
46namespace {
47
// Parameter for one GlslMapIOTest case: a set of shader files that are
// compiled and linked together into a single program.
struct IoMapData {
    // Shader file names, relative to the test root. The FromFile test derives
    // each shader's stage from the file suffix and uses the first entry to
    // name the expected-output file under baseResults/.
    std::vector<std::string> fileNames;
    // Semantics handed to DeriveOptions() to build the message controls used
    // for compiling and linking the files above.
    Semantics semantics;
};
52
// Fixture for the IO-mapping tests: GlslangTest layered over gtest's
// value-parameterized base, with IoMapData as the parameter type.
using GlslMapIOTest = GlslangTest <::testing::TestWithParam<IoMapData>>;
54
55template<class T>
56std::string interfaceName(T symbol) {
57    return symbol.getType()->getBasicType() == glslang::EbtBlock ? std::string(symbol.getType()->getTypeName().c_str()) : symbol.name;
58}
59
60bool verifyIOMapping(std::string& linkingError, glslang::TProgram& program) {
61    bool success = true;
62
63    // Verify IO Mapping by generating reflection for each stage individually
64    // and comparing layout qualifiers on the results
65
66
67    int reflectionOptions = EShReflectionDefault;
68    //reflectionOptions |= EShReflectionStrictArraySuffix;
69    //reflectionOptions |= EShReflectionBasicArraySuffix;
70    reflectionOptions |= EShReflectionIntermediateIO;
71    reflectionOptions |= EShReflectionSeparateBuffers;
72    reflectionOptions |= EShReflectionAllBlockVariables;
73    //reflectionOptions |= EShReflectionUnwrapIOBlocks;
74
75    success &= program.buildReflection(reflectionOptions);
76
77    // check that the reflection output from the individual stages all makes sense..
78    std::vector<glslang::TReflection> stageReflections;
79    for (int s = 0; s < EShLangCount; ++s) {
80        if (program.getIntermediate((EShLanguage)s)) {
81            stageReflections.emplace_back((EShReflectionOptions)reflectionOptions, (EShLanguage)s, (EShLanguage)s);
82            success &= stageReflections.back().addStage((EShLanguage)s, *program.getIntermediate((EShLanguage)s));
83        }
84    }
85
86    // check that input/output locations match between stages
87    auto it = stageReflections.begin();
88    auto nextIt = it + 1;
89    for (; nextIt != stageReflections.end(); it++, nextIt++) {
90        int numOut = it->getNumPipeOutputs();
91        std::map<std::string, const glslang::TObjectReflection*> pipeOut;
92
93        for (int i = 0; i < numOut; i++) {
94            const glslang::TObjectReflection& out = it->getPipeOutput(i);
95            std::string name = interfaceName(out);
96            pipeOut[name] = &out;
97        }
98
99        int numIn = nextIt->getNumPipeInputs();
100        for (int i = 0; i < numIn; i++) {
101            auto in = nextIt->getPipeInput(i);
102            std::string name = interfaceName(in);
103            auto out = pipeOut.find(name);
104
105            if (out != pipeOut.end()) {
106                auto inQualifier = in.getType()->getQualifier();
107                auto outQualifier = out->second->getType()->getQualifier();
108                success &= outQualifier.layoutLocation == inQualifier.layoutLocation;
109            }
110            else {
111                if (!in.getType()->isStruct()) {
112                    bool found = false;
113                    for (auto outIt : pipeOut) {
114                        if (outIt.second->getType()->isStruct()) {
115                            unsigned int baseLoc = outIt.second->getType()->getQualifier().hasLocation() ?
116                                outIt.second->getType()->getQualifier().layoutLocation :
117                                std::numeric_limits<unsigned int>::max();
118                            for (size_t j = 0; j < outIt.second->getType()->getStruct()->size(); j++) {
119                                baseLoc = (*outIt.second->getType()->getStruct())[j].type->getQualifier().hasLocation() ?
120                                    (*outIt.second->getType()->getStruct())[j].type->getQualifier().layoutLocation : baseLoc;
121                                if (baseLoc != std::numeric_limits<unsigned int>::max()) {
122                                    if (baseLoc == in.getType()->getQualifier().layoutLocation) {
123                                        found = true;
124                                        break;
125                                    }
126                                    baseLoc += glslang::TIntermediate::computeTypeLocationSize(*(*outIt.second->getType()->getStruct())[j].type, EShLangVertex);
127                                }
128                            }
129                            if (found) {
130                                break;
131                            }
132                        }
133                    }
134                    success &= found;
135                }
136                else {
137                    unsigned int baseLoc = in.getType()->getQualifier().hasLocation() ? in.getType()->getQualifier().layoutLocation : -1;
138                    for (size_t j = 0; j < in.getType()->getStruct()->size(); j++) {
139                        baseLoc = (*in.getType()->getStruct())[j].type->getQualifier().hasLocation() ?
140                            (*in.getType()->getStruct())[j].type->getQualifier().layoutLocation : baseLoc;
141                        if (baseLoc != std::numeric_limits<unsigned int>::max()) {
142                            bool isMemberFound = false;
143                            for (auto outIt : pipeOut) {
144                                if (baseLoc == outIt.second->getType()->getQualifier().layoutLocation) {
145                                    isMemberFound = true;
146                                    break;
147                                }
148                            }
149                            if (!isMemberFound) {
150                                success &= false;
151                                break;
152                            }
153                            baseLoc += glslang::TIntermediate::computeTypeLocationSize(*(*in.getType()->getStruct())[j].type, EShLangVertex);
154                        }
155                    }
156                }
157            }
158        }
159    }
160
161    // compare uniforms in each stage to the program
162    {
163        int totalUniforms = program.getNumUniformVariables();
164        std::map<std::string, const glslang::TObjectReflection*> programUniforms;
165        for (int i = 0; i < totalUniforms; i++) {
166            const glslang::TObjectReflection& uniform = program.getUniform(i);
167            std::string name = interfaceName(uniform);
168            programUniforms[name] = &uniform;
169        }
170        it = stageReflections.begin();
171        for (; it != stageReflections.end(); it++) {
172            int numUniform = it->getNumUniforms();
173            std::map<std::string, glslang::TObjectReflection> uniforms;
174
175            for (int i = 0; i < numUniform; i++) {
176                glslang::TObjectReflection uniform = it->getUniform(i);
177                std::string name = interfaceName(uniform);
178                auto programUniform = programUniforms.find(name);
179
180                if (programUniform != programUniforms.end()) {
181                    auto stageQualifier = uniform.getType()->getQualifier();
182                    auto programQualifier = programUniform->second->getType()->getQualifier();
183
184                    success &= stageQualifier.layoutLocation == programQualifier.layoutLocation;
185                    success &= stageQualifier.layoutBinding == programQualifier.layoutBinding;
186                    success &= stageQualifier.layoutSet == programQualifier.layoutSet;
187                }
188                else {
189                    success &= false;
190                }
191            }
192        }
193    }
194
195    // compare uniform blocks in each stage to the program table
196    {
197        int totalUniforms = program.getNumUniformBlocks();
198        std::map<std::string, const glslang::TObjectReflection*> programUniforms;
199        for (int i = 0; i < totalUniforms; i++) {
200            const glslang::TObjectReflection& uniform = program.getUniformBlock(i);
201            std::string name = interfaceName(uniform);
202            programUniforms[name] = &uniform;
203        }
204        it = stageReflections.begin();
205        for (; it != stageReflections.end(); it++) {
206            int numUniform = it->getNumUniformBlocks();
207            std::map<std::string, glslang::TObjectReflection> uniforms;
208
209            for (int i = 0; i < numUniform; i++) {
210                glslang::TObjectReflection uniform = it->getUniformBlock(i);
211                std::string name = interfaceName(uniform);
212                auto programUniform = programUniforms.find(name);
213
214                if (programUniform != programUniforms.end()) {
215                    auto stageQualifier = uniform.getType()->getQualifier();
216                    auto programQualifier = programUniform->second->getType()->getQualifier();
217
218                    success &= stageQualifier.layoutLocation == programQualifier.layoutLocation;
219                    success &= stageQualifier.layoutBinding == programQualifier.layoutBinding;
220                    success &= stageQualifier.layoutSet == programQualifier.layoutSet;
221                }
222                else {
223                    success &= false;
224                }
225            }
226        }
227    }
228
229    if (!success) {
230        linkingError += "Mismatched cross-stage IO\n";
231    }
232
233    return success;
234}
235
236TEST_P(GlslMapIOTest, FromFile)
237{
238    const auto& fileNames = GetParam().fileNames;
239    Semantics semantics = GetParam().semantics;
240    const size_t fileCount = fileNames.size();
241    const EShMessages controls = DeriveOptions(Source::GLSL, semantics, Target::BothASTAndSpv);
242    GlslangResult result;
243
244    // Compile each input shader file.
245    bool success = true;
246    std::vector<std::unique_ptr<glslang::TShader>> shaders;
247    for (size_t i = 0; i < fileCount; ++i) {
248        std::string contents;
249        tryLoadFile(GlobalTestSettings.testRoot + "/" + fileNames[i],
250            "input", &contents);
251        shaders.emplace_back(
252            new glslang::TShader(GetShaderStage(GetSuffix(fileNames[i]))));
253        auto* shader = shaders.back().get();
254
255        shader->setAutoMapLocations(true);
256        shader->setAutoMapBindings(true);
257
258        if (controls & EShMsgSpvRules) {
259            if (controls & EShMsgVulkanRules) {
260                shader->setEnvInput((controls & EShMsgReadHlsl) ? glslang::EShSourceHlsl
261                                                               : glslang::EShSourceGlsl,
262                                    shader->getStage(), glslang::EShClientVulkan, 100);
263                shader->setEnvClient(glslang::EShClientVulkan, glslang::EShTargetVulkan_1_1);
264                shader->setEnvTarget(glslang::EShTargetSpv, glslang::EShTargetSpv_1_0);
265            } else {
266                shader->setEnvInput((controls & EShMsgReadHlsl) ? glslang::EShSourceHlsl
267                                                               : glslang::EShSourceGlsl,
268                                    shader->getStage(), glslang::EShClientOpenGL, 100);
269                shader->setEnvClient(glslang::EShClientOpenGL, glslang::EShTargetOpenGL_450);
270                shader->setEnvTarget(glslang::EshTargetSpv, glslang::EShTargetSpv_1_0);
271            }
272        }
273
274        success &= compile(shader, contents, "", controls);
275
276        result.shaderResults.push_back(
277            { fileNames[i], shader->getInfoLog(), shader->getInfoDebugLog() });
278    }
279
280    // Link all of them.
281    glslang::TProgram program;
282    for (const auto& shader : shaders) program.addShader(shader.get());
283    success &= program.link(controls);
284    result.linkingOutput = program.getInfoLog();
285    result.linkingError = program.getInfoDebugLog();
286
287    unsigned int stage = 0;
288    glslang::TIntermediate* firstIntermediate = nullptr;
289    while (!program.getIntermediate((EShLanguage)stage) && stage < EShLangCount) { stage++; }
290    firstIntermediate = program.getIntermediate((EShLanguage)stage);
291
292    glslang::TDefaultGlslIoResolver resolver(*firstIntermediate);
293    glslang::TGlslIoMapper ioMapper;
294
295    if (success) {
296        success &= program.mapIO(&resolver, &ioMapper);
297        result.linkingOutput = program.getInfoLog();
298        result.linkingError = program.getInfoDebugLog();
299    }
300
301    success &= verifyIOMapping(result.linkingError, program);
302    result.validationResult = success;
303
304    if (success && (controls & EShMsgSpvRules)) {
305        for (int stage = 0; stage < EShLangCount; ++stage) {
306            if (program.getIntermediate((EShLanguage)stage)) {
307                spv::SpvBuildLogger logger;
308                std::vector<uint32_t> spirv_binary;
309                options().disableOptimizer = false;
310                glslang::GlslangToSpv(*program.getIntermediate((EShLanguage)stage),
311                    spirv_binary, &logger, &options());
312
313                std::ostringstream disassembly_stream;
314                spv::Parameterize();
315                spv::Disassemble(disassembly_stream, spirv_binary);
316                result.spirvWarningsErrors += logger.getAllMessages();
317                result.spirv += disassembly_stream.str();
318                result.validationResult &= !options().validate || logger.getAllMessages().empty();
319            }
320        }
321    }
322
323    std::ostringstream stream;
324    outputResultToStream(&stream, result, controls);
325
326    // Check with expected results.
327    const std::string expectedOutputFname =
328        GlobalTestSettings.testRoot + "/baseResults/" + fileNames.front() + ".out";
329    std::string expectedOutput;
330    tryLoadFile(expectedOutputFname, "expected output", &expectedOutput);
331
332    checkEqAndUpdateIfRequested(expectedOutput, stream.str(), expectedOutputFname,
333        result.spirvWarningsErrors);
334}
335
// Test cases for GlslMapIOTest. Each entry lists the shader files that are
// compiled and linked together, followed by the semantics to use; the first
// file name also selects the expected-output file under baseResults/.
// clang-format off
INSTANTIATE_TEST_SUITE_P(
    Glsl, GlslMapIOTest,
    ::testing::ValuesIn(std::vector<IoMapData>({
        {{"iomap.crossStage.vert", "iomap.crossStage.frag" }, Semantics::OpenGL},
        {{"iomap.crossStage.2.vert", "iomap.crossStage.2.geom", "iomap.crossStage.2.frag" }, Semantics::OpenGL},
        {{"iomap.blockOutVariableIn.vert", "iomap.blockOutVariableIn.frag"}, Semantics::OpenGL},
        {{"iomap.variableOutBlockIn.vert", "iomap.variableOutBlockIn.frag"}, Semantics::OpenGL},
        {{"iomap.blockOutVariableIn.2.vert", "iomap.blockOutVariableIn.geom"}, Semantics::OpenGL},
        {{"iomap.variableOutBlockIn.2.vert", "iomap.variableOutBlockIn.geom"}, Semantics::OpenGL},
        // vulkan semantics
        {{"iomap.crossStage.vk.vert", "iomap.crossStage.vk.geom", "iomap.crossStage.vk.frag" }, Semantics::Vulkan},
    }))
);
// clang-format on
351
352}  // anonymous namespace
353}  // namespace glslangtest
354