// Copyright 2017 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/metal/ShaderModuleMTL.h"

#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/TintUtils.h"
#include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/PipelineLayoutMTL.h"
#include "dawn_native/metal/RenderPipelineMTL.h"

#include <tint/tint.h>

#include <sstream>

namespace dawn_native { namespace metal {

    // static
    // Allocates a ShaderModule for |device| and initializes it from |parseResult|.
    // Returns the error produced by Initialize() if initialization fails.
    ResultOrError<Ref<ShaderModule>> ShaderModule::Create(Device* device,
                                                          const ShaderModuleDescriptor* descriptor,
                                                          ShaderModuleParseResult* parseResult) {
        Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
        DAWN_TRY(module->Initialize(parseResult));
        return module;
    }

    ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
        : ShaderModuleBase(device, descriptor) {
    }

    // Runs the backend-agnostic InitializeBase() while a ScopedTintICEHandler for this device
    // is alive.
    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
        ScopedTintICEHandler scopedICEHandler(GetDevice());
        return InitializeBase(parseResult);
    }

    // Translates entry point |entryPointName| of |stage| into MSL source.
    //
    // Every binding visible to |stage| is remapped from (group, binding number) to
    // (0, per-stage shader index taken from |layout|). All other entry points are stripped.
    // When the MetalEnableVertexPulling toggle is on for a vertex stage, the pulled vertex
    // buffers are also remapped to the Metal buffer indices reported by |renderPipeline|.
    // |renderPipeline| is dereferenced for vertex stages and must be non-null for them.
    //
    // Outputs:
    //   |remappedEntryPointName|   entry point name in the generated MSL (after Renamer).
    //   |needsStorageBufferLength| Tint's needs_storage_buffer_sizes result; sizes are read
    //                              from the buffer index kBufferLengthBufferSlot.
    //   |hasInvariantAttribute|    Tint's has_invariant_attribute result.
    //   |workgroupAllocations|     Tint's workgroup allocations for the remapped entry point.
    // Returns a validation error if a transform or MSL generation fails, or if the renamed
    // entry point cannot be found.
    ResultOrError<std::string> ShaderModule::TranslateToMSL(
        const char* entryPointName,
        SingleShaderStage stage,
        const PipelineLayout* layout,
        uint32_t sampleMask,
        const RenderPipeline* renderPipeline,
        std::string* remappedEntryPointName,
        bool* needsStorageBufferLength,
        bool* hasInvariantAttribute,
        std::vector<uint32_t>* workgroupAllocations) {
        ScopedTintICEHandler scopedICEHandler(GetDevice());

        // Remap BindingNumber to BindingIndex in WGSL shader
        using BindingRemapper = tint::transform::BindingRemapper;
        using BindingPoint = tint::transform::BindingPoint;
        BindingRemapper::BindingPoints bindingPoints;
        BindingRemapper::AccessControls accessControls;

        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            const BindGroupLayoutBase::BindingMap& bindingMap =
                layout->GetBindGroupLayout(group)->GetBindingMap();
            for (const auto& it : bindingMap) {
                BindingNumber bindingNumber = it.first;
                BindingIndex bindingIndex = it.second;

                const BindingInfo& bindingInfo =
                    layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);

                // Bindings not visible to this stage are left unmapped.
                if (!(bindingInfo.visibility & StageBit(stage))) {
                    continue;
                }

                uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];

                BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
                                             static_cast<uint32_t>(bindingNumber)};
                BindingPoint dstBindingPoint{0, shaderIndex};
                // Identity mappings are not recorded.
                if (srcBindingPoint != dstBindingPoint) {
                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                }
            }
        }

        tint::transform::Manager transformManager;
        tint::transform::DataMap transformInputs;

        // We only remap bindings for the target entry point, so we need to strip all other entry
        // points to avoid generating invalid bindings for them.
        transformManager.Add<tint::transform::SingleEntryPoint>();
        transformInputs.Add<tint::transform::SingleEntryPoint::Config>(entryPointName);

        if (stage == SingleShaderStage::Vertex &&
            GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
            transformManager.Add<tint::transform::VertexPulling>();
            AddVertexPullingTransformConfig(*renderPipeline, entryPointName,
                                            kPullingBufferBindingSet, &transformInputs);

            for (VertexBufferSlot slot :
                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);

                // Tell Tint to map (kPullingBufferBindingSet, slot) to this MSL buffer index.
                // The uint8_t cast matches VertexBufferSlot's underlying integer type.
                BindingPoint srcBindingPoint{static_cast<uint32_t>(kPullingBufferBindingSet),
                                             static_cast<uint8_t>(slot)};
                BindingPoint dstBindingPoint{0, metalIndex};
                if (srcBindingPoint != dstBindingPoint) {
                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                }
            }
        }
        if (GetDevice()->IsRobustnessEnabled()) {
            transformManager.Add<tint::transform::Robustness>();
        }
        transformManager.Add<tint::transform::BindingRemapper>();
        transformManager.Add<tint::transform::Renamer>();

        if (GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming)) {
            // We still need to rename MSL reserved keywords
            transformInputs.Add<tint::transform::Renamer::Config>(
                tint::transform::Renamer::Target::kMslKeywords);
        }

        // Every destination lives in group 0. It can therefore coincide with another binding
        // point, presumably why mayCollide is set.
        transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
                                                         std::move(accessControls),
                                                         /* mayCollide */ true);

        tint::Program program;
        tint::transform::DataMap transformOutputs;
        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
                                               &transformOutputs, nullptr));

        // The Renamer may have renamed the entry point; look up the name it was given.
        if (auto* data = transformOutputs.Get<tint::transform::Renamer::Data>()) {
            auto it = data->remappings.find(entryPointName);
            if (it != data->remappings.end()) {
                *remappedEntryPointName = it->second;
            } else {
                // A missing mapping is only expected when full symbol renaming is disabled.
                DAWN_INVALID_IF(!GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming),
                                "Could not find remapped name for entry point.");

                *remappedEntryPointName = entryPointName;
            }
        } else {
            return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data.");
        }

        tint::writer::msl::Options options;
        options.buffer_size_ubo_index = kBufferLengthBufferSlot;
        options.fixed_sample_mask = sampleMask;
        options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
        options.emit_vertex_point_size =
            stage == SingleShaderStage::Vertex &&
            renderPipeline->GetPrimitiveTopology() == wgpu::PrimitiveTopology::PointList;
        auto result = tint::writer::msl::Generate(&program, options);
        DAWN_INVALID_IF(!result.success, "An error occured while generating MSL: %s.",
                        result.error);

        *needsStorageBufferLength = result.needs_storage_buffer_sizes;
        *hasInvariantAttribute = result.has_invariant_attribute;
        *workgroupAllocations = std::move(result.workgroup_allocations[*remappedEntryPointName]);

        return std::move(result.msl);
    }

    // Translates the entry point to MSL and compiles it into a Metal library. The resulting
    // function, specialized with |constantValuesPointer| when non-nil, is stored in |out|.
    //
    // |out->needsStorageBufferLength| and |out->workgroupAllocations| are filled from the
    // translation. |out->needsStorageBufferLength| is also forced to true when vertex pulling
    // is enabled and the entry point uses vertex inputs.
    // Returns a validation error if translation, library compilation (other than warnings) or
    // specialized function creation fails.
    MaybeError ShaderModule::CreateFunction(const char* entryPointName,
                                            SingleShaderStage stage,
                                            const PipelineLayout* layout,
                                            ShaderModule::MetalFunctionData* out,
                                            id constantValuesPointer,
                                            uint32_t sampleMask,
                                            const RenderPipeline* renderPipeline) {
        ASSERT(!IsError());
        ASSERT(out);

        // Vertex stages must specify a renderPipeline
        if (stage == SingleShaderStage::Vertex) {
            ASSERT(renderPipeline != nullptr);
        }

        std::string remappedEntryPointName;
        std::string msl;
        bool hasInvariantAttribute = false;
        DAWN_TRY_ASSIGN(msl,
                        TranslateToMSL(entryPointName, stage, layout, sampleMask, renderPipeline,
                                       &remappedEntryPointName, &out->needsStorageBufferLength,
                                       &hasInvariantAttribute, &out->workgroupAllocations));

        // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
        // category. -Wunused-variable in particular comes up a lot in generated code, and some
        // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
        // of a warning.
        msl = R"(
#ifdef __clang__
#pragma clang diagnostic ignored "-Wall"
#endif
)" + msl;

        if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) {
            std::ostringstream dumpedMsg;
            dumpedMsg << "/* Dumped generated MSL */" << std::endl << msl;
            GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
        }

        NSRef<NSString> mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);

        NSRef<MTLCompileOptions> compileOptions = AcquireNSRef([[MTLCompileOptions alloc] init]);
        if (hasInvariantAttribute) {
            if (@available(macOS 11.0, iOS 13.0, *)) {
                (*compileOptions).preserveInvariance = true;
            }
        }
        auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
        NSError* error = nullptr;
        NSPRef<id<MTLLibrary>> library =
            AcquireNSPRef([mtlDevice newLibraryWithSource:mslSource.Get()
                                                  options:compileOptions.Get()
                                                    error:&error]);
        // Compile warnings are also reported through |error|; only non-warnings fail.
        if (error != nullptr) {
            DAWN_INVALID_IF(error.code != MTLLibraryErrorCompileWarning,
                            "Unable to create library object: %s.",
                            [error.localizedDescription UTF8String]);
        }
        ASSERT(library != nil);

        NSRef<NSString> name =
            AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);

        if (constantValuesPointer != nil) {
            if (@available(macOS 10.12, *)) {
                MTLFunctionConstantValues* constantValues = constantValuesPointer;
                // |error| may still hold a warning from the library compilation above, and
                // Cocoa methods do not clear it on success; reset it so that only this call's
                // outcome is inspected.
                error = nullptr;
                out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
                                                             constantValues:constantValues
                                                                      error:&error]);
                if (error != nullptr) {
                    DAWN_INVALID_IF(error.code != MTLLibraryErrorCompileWarning,
                                    "Function compile error: %s",
                                    [error.localizedDescription UTF8String]);
                }
                ASSERT(out->function != nil);
            } else {
                UNREACHABLE();
            }
        } else {
            out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
        }

        // With vertex pulling, vertex inputs are read from buffers whose sizes the shader
        // needs, so buffer lengths must be provided.
        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
            GetEntryPoint(entryPointName).usedVertexInputs.any()) {
            out->needsStorageBufferLength = true;
        }

        return {};
    }
}}  // namespace dawn_native::metal