1// Copyright 2017 The Dawn Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "dawn_native/metal/ShaderModuleMTL.h"
16
17#include "dawn_native/BindGroupLayout.h"
18#include "dawn_native/TintUtils.h"
19#include "dawn_native/metal/DeviceMTL.h"
20#include "dawn_native/metal/PipelineLayoutMTL.h"
21#include "dawn_native/metal/RenderPipelineMTL.h"
22
23#include <tint/tint.h>
24
25#include <sstream>
26
27namespace dawn_native { namespace metal {
28
29    // static
30    ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
31                                                          const ShaderModuleDescriptor* descriptor,
32                                                          ShaderModuleParseResult* parseResult) {
33        Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
34        DAWN_TRY(module->Initialize(parseResult));
35        return module;
36    }
37
38    ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
39        : ShaderModuleBase(device, descriptor) {
40    }
41
42    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
43        ScopedTintICEHandler scopedICEHandler(GetDevice());
44        return InitializeBase(parseResult);
45    }
46
47    ResultOrError<std::string> ShaderModule::TranslateToMSL(
48        const char* entryPointName,
49        SingleShaderStage stage,
50        const PipelineLayout* layout,
51        uint32_t sampleMask,
52        const RenderPipeline* renderPipeline,
53        std::string* remappedEntryPointName,
54        bool* needsStorageBufferLength,
55        bool* hasInvariantAttribute,
56        std::vector<uint32_t>* workgroupAllocations) {
57        ScopedTintICEHandler scopedICEHandler(GetDevice());
58
59        std::ostringstream errorStream;
60        errorStream << "Tint MSL failure:" << std::endl;
61
62        // Remap BindingNumber to BindingIndex in WGSL shader
63        using BindingRemapper = tint::transform::BindingRemapper;
64        using BindingPoint = tint::transform::BindingPoint;
65        BindingRemapper::BindingPoints bindingPoints;
66        BindingRemapper::AccessControls accessControls;
67
68        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
69            const BindGroupLayoutBase::BindingMap& bindingMap =
70                layout->GetBindGroupLayout(group)->GetBindingMap();
71            for (const auto& it : bindingMap) {
72                BindingNumber bindingNumber = it.first;
73                BindingIndex bindingIndex = it.second;
74
75                const BindingInfo& bindingInfo =
76                    layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);
77
78                if (!(bindingInfo.visibility & StageBit(stage))) {
79                    continue;
80                }
81
82                uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
83
84                BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
85                                             static_cast<uint32_t>(bindingNumber)};
86                BindingPoint dstBindingPoint{0, shaderIndex};
87                if (srcBindingPoint != dstBindingPoint) {
88                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
89                }
90            }
91        }
92
93        tint::transform::Manager transformManager;
94        tint::transform::DataMap transformInputs;
95
96        // We only remap bindings for the target entry point, so we need to strip all other entry
97        // points to avoid generating invalid bindings for them.
98        transformManager.Add<tint::transform::SingleEntryPoint>();
99        transformInputs.Add<tint::transform::SingleEntryPoint::Config>(entryPointName);
100
101        if (stage == SingleShaderStage::Vertex &&
102            GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
103            transformManager.Add<tint::transform::VertexPulling>();
104            AddVertexPullingTransformConfig(*renderPipeline, entryPointName,
105                                            kPullingBufferBindingSet, &transformInputs);
106
107            for (VertexBufferSlot slot :
108                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
109                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);
110
111                // Tell Tint to map (kPullingBufferBindingSet, slot) to this MSL buffer index.
112                BindingPoint srcBindingPoint{static_cast<uint32_t>(kPullingBufferBindingSet),
113                                             static_cast<uint8_t>(slot)};
114                BindingPoint dstBindingPoint{0, metalIndex};
115                if (srcBindingPoint != dstBindingPoint) {
116                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
117                }
118            }
119        }
120        if (GetDevice()->IsRobustnessEnabled()) {
121            transformManager.Add<tint::transform::Robustness>();
122        }
123        transformManager.Add<tint::transform::BindingRemapper>();
124        transformManager.Add<tint::transform::Renamer>();
125
126        if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
127            // We still need to rename MSL reserved keywords
128            transformInputs.Add<tint::transform::Renamer::Config>(
129                tint::transform::Renamer::Target::kMslKeywords);
130        }
131
132        transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
133                                                         std::move(accessControls),
134                                                         /* mayCollide */ true);
135
136        tint::Program program;
137        tint::transform::DataMap transformOutputs;
138        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
139                                               &transformOutputs, nullptr));
140
141        if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
142            auto it = data->remappings.find(entryPointName);
143            if (it != data->remappings.end()) {
144                *remappedEntryPointName = it->second;
145            } else {
146                DAWN_INVALID_IF(!GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming),
147                                "Could not find remapped name for entry point.");
148
149                *remappedEntryPointName = entryPointName;
150            }
151        } else {
152            return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
153        }
154
155        tint::writer::msl::Options options;
156        options.buffer_size_ubo_index = kBufferLengthBufferSlot;
157        options.fixed_sample_mask = sampleMask;
158        options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
159        options.emit_vertex_point_size =
160            stage == SingleShaderStage::Vertex &&
161            renderPipeline->GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList;
162        auto result = tint::writer::msl::Generate(&program, options);
163        DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.",
164                        result.error);
165
166        *needsStorageBufferLength = result.needs_storage_buffer_sizes;
167        *hasInvariantAttribute = result.has_invariant_attribute;
168        *workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);
169
170        return std::move(result.msl);
171    }
172
173    MaybeError ShaderModule::CreateFunction(const char* entryPointName,
174                                            SingleShaderStage stage,
175                                            const PipelineLayout* layout,
176                                            ShaderModule::MetalFunctionData* out,
177                                            id constantValuesPointer,
178                                            uint32_t sampleMask,
179                                            const RenderPipeline* renderPipeline) {
180        ASSERT(!IsError());
181        ASSERT(out);
182
183        // Vertex stages must specify a renderPipeline
184        if (stage == SingleShaderStage::Vertex) {
185            ASSERT(renderPipeline != nullptr);
186        }
187
188        std::string remappedEntryPointName;
189        std::string msl;
190        bool hasInvariantAttribute = false;
191        DAWN_TRY_ASSIGN(msl,
192                        TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
193                                       &remappedEntryPointName, &out->needsStorageBufferLength,
194                                       &hasInvariantAttribute, &out->workgroupAllocations));
195
196        // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
197        // category. -Wunused-variable in particular comes up a lot in generated code, and some
198        // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
199        // of a warning.
200        msl = R"(
201#ifdef __clang__
202#pragma clang diagnostic ignored "-Wall"
203#endif
204)" + msl;
205
206        if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) {
207            std::ostringstream dumpedMsg;
208            dumpedMsg << "/* Dumped generated MSL */" << std::endl << msl;
209            GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
210        }
211
212        NSRef<NSString> mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);
213
214        NSRef<MTLCompileOptions> compileOptions = AcquireNSRef([[MTLCompileOptions alloc] init]);
215        if (hasInvariantAttribute) {
216            if (@available(macOS 11.0, iOS 13.0, *)) {
217                (*compileOptions).preserveInvariance = true;
218            }
219        }
220        auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
221        NSError* error = nullptr;
222        NSPRef<id<MTLLibrary>> library =
223            AcquireNSPRef([mtlDevice newLibraryWithSource:mslSource.Get()
224                                                  options:compileOptions.Get()
225                                                    error:&error]);
226        if (error != nullptr) {
227            DAWN_INVALID_IF(error.code != MTLLibraryErrorCompileWarning,
228                            "Unable to create library object: %s.",
229                            [error.localizedDescription UTF8String]);
230        }
231        ASSERT(library != nil);
232
233        NSRef<NSString> name =
234            AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
235
236        if (constantValuesPointer != nil) {
237            if (@available(macOS 10.12, *)) {
238                MTLFunctionConstantValues* constantValues = constantValuesPointer;
239                out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
240                                                             constantValues:constantValues
241                                                                      error:&error]);
242                if (error != nullptr) {
243                    if (error.code != MTLLibraryErrorCompileWarning) {
244                        return DAWN_VALIDATION_ERROR(std::string("Function compile error: ") +
245                                                     [error.localizedDescription UTF8String]);
246                    }
247                }
248                ASSERT(out->function != nil);
249            } else {
250                UNREACHABLE();
251            }
252        } else {
253            out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
254        }
255
256        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
257            GetEntryPoint(entryPointName).usedVertexInputs.any()) {
258            out->needsStorageBufferLength = true;
259        }
260
261        return {};
262    }
263}}  // namespace dawn_native::metal
264