1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/trim_capabilities_pass.h"
16
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <stack>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <vector>
27
28 #include "source/enum_set.h"
29 #include "source/enum_string_mapping.h"
30 #include "source/opt/ir_context.h"
31 #include "source/opt/reflect.h"
32 #include "source/spirv_target_env.h"
33 #include "source/util/string_utils.h"
34
35 namespace spvtools {
36 namespace opt {
37
38 namespace {
// In-operand positions used when inspecting SPIR-V instructions.

// OpTypeFloat / OpTypeInt: in-operand 0 holds the bit width.
constexpr uint32_t kOpTypeFloatSizeIndex = 0;
constexpr uint32_t kOpTypeIntSizeIndex = 0;
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;

// OpTypePointer: storage class first, then the pointee type id.
constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
constexpr uint32_t kTypePointerTypeIdInIndex = 1;

// OpTypeVector / OpTypeMatrix / OpTypeArray / OpTypeRuntimeArray: in-operand 0
// holds the component/column/element type id.
constexpr uint32_t kTypeArrayTypeIndex = 0;

// OpTypeImage in-operand layout:
//   Sampled Type (0), Dim (1), Depth (2), Arrayed (3), MS (4), Sampled (5),
//   Image Format (6).
constexpr uint32_t kOpTypeImageDimIndex = 1;
constexpr uint32_t kOpTypeImageArrayedIndex = 3;
constexpr uint32_t kOpTypeImageMSIndex = 4;
constexpr uint32_t kOpTypeImageSampledIndex = 5;
constexpr uint32_t kOpTypeImageFormatIndex = 6;

// OpImageRead / OpImageSparseRead: in-operand 0 is the image id.
constexpr uint32_t kOpImageReadImageIndex = 0;
constexpr uint32_t kOpImageSparseReadImageIndex = 0;
52
53 // DFS visit of the type defined by `instruction`.
54 // If `condition` is true, children of the current node are visited.
55 // If `condition` is false, the children of the current node are ignored.
56 template <class UnaryPredicate>
DFSWhile(const Instruction* instruction, UnaryPredicate condition)57 static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
58 std::stack<uint32_t> instructions_to_visit;
59 instructions_to_visit.push(instruction->result_id());
60 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
61
62 while (!instructions_to_visit.empty()) {
63 const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
64 instructions_to_visit.pop();
65
66 if (!condition(item)) {
67 continue;
68 }
69
70 if (item->opcode() == spv::Op::OpTypePointer) {
71 instructions_to_visit.push(
72 item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
73 continue;
74 }
75
76 if (item->opcode() == spv::Op::OpTypeMatrix ||
77 item->opcode() == spv::Op::OpTypeVector ||
78 item->opcode() == spv::Op::OpTypeArray ||
79 item->opcode() == spv::Op::OpTypeRuntimeArray) {
80 instructions_to_visit.push(
81 item->GetSingleWordInOperand(kTypeArrayTypeIndex));
82 continue;
83 }
84
85 if (item->opcode() == spv::Op::OpTypeStruct) {
86 item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
87 instructions_to_visit.push(*op_id);
88 });
89 continue;
90 }
91 }
92 }
93
94 // Walks the type defined by `instruction` (OpType* only).
95 // Returns `true` if any call to `predicate` with the type/subtype returns true.
96 template <class UnaryPredicate>
AnyTypeOf(const Instruction* instruction, UnaryPredicate predicate)97 static bool AnyTypeOf(const Instruction* instruction,
98 UnaryPredicate predicate) {
99 assert(IsTypeInst(instruction->opcode()) &&
100 "AnyTypeOf called with a non-type instruction.");
101
102 bool found_one = false;
103 DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
104 if (found_one || predicate(node)) {
105 found_one = true;
106 return false;
107 }
108
109 return true;
110 });
111 return found_one;
112 }
113
is16bitType(const Instruction* instruction)114 static bool is16bitType(const Instruction* instruction) {
115 if (instruction->opcode() != spv::Op::OpTypeInt &&
116 instruction->opcode() != spv::Op::OpTypeFloat) {
117 return false;
118 }
119
120 return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
121 }
122
Has16BitCapability(const FeatureManager* feature_manager)123 static bool Has16BitCapability(const FeatureManager* feature_manager) {
124 const CapabilitySet& capabilities = feature_manager->GetCapabilities();
125 return capabilities.contains(spv::Capability::Float16) ||
126 capabilities.contains(spv::Capability::Int16);
127 }
128
129 } // namespace
130
131 // ============== Begin opcode handler implementations. =======================
132 //
// Adding support for a new capability should only require adding a new handler
// and updating the
// kSupportedCapabilities/kUntouchableCapabilities/kForbiddenCapabilities lists.
136 //
137 // Handler names follow the following convention:
138 // Handler_<Opcode>_<Capability>()
139
Handler_OpTypeFloat_Float64( const Instruction* instruction)140 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
141 const Instruction* instruction) {
142 assert(instruction->opcode() == spv::Op::OpTypeFloat &&
143 "This handler only support OpTypeFloat opcodes.");
144
145 const uint32_t size =
146 instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
147 return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
148 }
149
150 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction)151 Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
152 assert(instruction->opcode() == spv::Op::OpTypePointer &&
153 "This handler only support OpTypePointer opcodes.");
154
155 // This capability is only required if the variable has an Input/Output
156 // storage class.
157 spv::StorageClass storage_class = spv::StorageClass(
158 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
159 if (storage_class != spv::StorageClass::Input &&
160 storage_class != spv::StorageClass::Output) {
161 return std::nullopt;
162 }
163
164 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
165 return std::nullopt;
166 }
167
168 return AnyTypeOf(instruction, is16bitType)
169 ? std::optional(spv::Capability::StorageInputOutput16)
170 : std::nullopt;
171 }
172
173 static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction)174 Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
175 assert(instruction->opcode() == spv::Op::OpTypePointer &&
176 "This handler only support OpTypePointer opcodes.");
177
178 // This capability is only required if the variable has a PushConstant storage
179 // class.
180 spv::StorageClass storage_class = spv::StorageClass(
181 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
182 if (storage_class != spv::StorageClass::PushConstant) {
183 return std::nullopt;
184 }
185
186 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
187 return std::nullopt;
188 }
189
190 return AnyTypeOf(instruction, is16bitType)
191 ? std::optional(spv::Capability::StoragePushConstant16)
192 : std::nullopt;
193 }
194
195 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16( const Instruction* instruction)196 Handler_OpTypePointer_StorageUniformBufferBlock16(
197 const Instruction* instruction) {
198 assert(instruction->opcode() == spv::Op::OpTypePointer &&
199 "This handler only support OpTypePointer opcodes.");
200
201 // This capability is only required if the variable has a Uniform storage
202 // class.
203 spv::StorageClass storage_class = spv::StorageClass(
204 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
205 if (storage_class != spv::StorageClass::Uniform) {
206 return std::nullopt;
207 }
208
209 if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
210 return std::nullopt;
211 }
212
213 const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
214 const bool matchesCondition =
215 AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
216 if (!decoration_mgr->HasDecoration(item->result_id(),
217 spv::Decoration::BufferBlock)) {
218 return false;
219 }
220
221 return AnyTypeOf(item, is16bitType);
222 });
223
224 return matchesCondition
225 ? std::optional(spv::Capability::StorageUniformBufferBlock16)
226 : std::nullopt;
227 }
228
Handler_OpTypePointer_StorageUniform16( const Instruction* instruction)229 static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
230 const Instruction* instruction) {
231 assert(instruction->opcode() == spv::Op::OpTypePointer &&
232 "This handler only support OpTypePointer opcodes.");
233
234 // This capability is only required if the variable has a Uniform storage
235 // class.
236 spv::StorageClass storage_class = spv::StorageClass(
237 instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
238 if (storage_class != spv::StorageClass::Uniform) {
239 return std::nullopt;
240 }
241
242 const auto* feature_manager = instruction->context()->get_feature_mgr();
243 if (!Has16BitCapability(feature_manager)) {
244 return std::nullopt;
245 }
246
247 const bool hasBufferBlockCapability =
248 feature_manager->GetCapabilities().contains(
249 spv::Capability::StorageUniformBufferBlock16);
250 const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
251 bool found16bitType = false;
252
253 DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
254 &found16bitType](const Instruction* item) {
255 if (found16bitType) {
256 return false;
257 }
258
259 if (hasBufferBlockCapability &&
260 decoration_mgr->HasDecoration(item->result_id(),
261 spv::Decoration::BufferBlock)) {
262 return false;
263 }
264
265 if (is16bitType(item)) {
266 found16bitType = true;
267 return false;
268 }
269
270 return true;
271 });
272
273 return found16bitType ? std::optional(spv::Capability::StorageUniform16)
274 : std::nullopt;
275 }
276
Handler_OpTypeInt_Int64( const Instruction* instruction)277 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
278 const Instruction* instruction) {
279 assert(instruction->opcode() == spv::Op::OpTypeInt &&
280 "This handler only support OpTypeInt opcodes.");
281
282 const uint32_t size =
283 instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
284 return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
285 }
286
Handler_OpTypeImage_ImageMSArray( const Instruction* instruction)287 static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
288 const Instruction* instruction) {
289 assert(instruction->opcode() == spv::Op::OpTypeImage &&
290 "This handler only support OpTypeImage opcodes.");
291
292 const uint32_t arrayed =
293 instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
294 const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
295 const uint32_t sampled =
296 instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
297
298 return arrayed == 1 && sampled == 2 && ms == 1
299 ? std::optional(spv::Capability::ImageMSArray)
300 : std::nullopt;
301 }
302
303 static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat( const Instruction* instruction)304 Handler_OpImageRead_StorageImageReadWithoutFormat(
305 const Instruction* instruction) {
306 assert(instruction->opcode() == spv::Op::OpImageRead &&
307 "This handler only support OpImageRead opcodes.");
308 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
309
310 const uint32_t image_index =
311 instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
312 const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
313 const Instruction* type = def_use_mgr->GetDef(type_index);
314 const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
315 const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
316
317 const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
318 const bool requires_capability_for_unknown =
319 spv::Dim(dim) != spv::Dim::SubpassData;
320 return is_unknown && requires_capability_for_unknown
321 ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
322 : std::nullopt;
323 }
324
325 static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat( const Instruction* instruction)326 Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
327 const Instruction* instruction) {
328 assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
329 "This handler only support OpImageSparseRead opcodes.");
330 const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
331
332 const uint32_t image_index =
333 instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
334 const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
335 const Instruction* type = def_use_mgr->GetDef(type_index);
336 const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
337
338 return spv::ImageFormat(format) == spv::ImageFormat::Unknown
339 ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
340 : std::nullopt;
341 }
342
343 // Opcode of interest to determine capabilities requirements.
344 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 10> kOpcodeHandlers{{
345 // clang-format off
346 {spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
347 {spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
348 {spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
349 {spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
350 {spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
351 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
352 {spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
353 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
354 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniform16},
355 {spv::Op::OpTypePointer, Handler_OpTypePointer_StorageUniformBufferBlock16},
356 // clang-format on
357 }};
358
359 // ============== End opcode handler implementations. =======================
360
361 namespace {
getExtensionsRelatedTo(const CapabilitySet& capabilities, const AssemblyGrammar& grammar)362 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
363 const AssemblyGrammar& grammar) {
364 ExtensionSet output;
365 const spv_operand_desc_t* desc = nullptr;
366 for (auto capability : capabilities) {
367 if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
368 static_cast<uint32_t>(capability),
369 &desc)) {
370 continue;
371 }
372
373 for (uint32_t i = 0; i < desc->numExtensions; ++i) {
374 output.insert(desc->extensions[i]);
375 }
376 }
377
378 return output;
379 }
380 } // namespace
381
TrimCapabilitiesPass()382 TrimCapabilitiesPass::TrimCapabilitiesPass()
383 : supportedCapabilities_(
384 TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
385 TrimCapabilitiesPass::kSupportedCapabilities.cend()),
386 forbiddenCapabilities_(
387 TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
388 TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
389 untouchableCapabilities_(
390 TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
391 TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
392 opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
393
addInstructionRequirementsForOpcode( spv::Op opcode, CapabilitySet* capabilities, ExtensionSet* extensions) const394 void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
395 spv::Op opcode, CapabilitySet* capabilities,
396 ExtensionSet* extensions) const {
397 // Ignoring OpBeginInvocationInterlockEXT and OpEndInvocationInterlockEXT
398 // because they have three possible capabilities, only one of which is needed
399 if (opcode == spv::Op::OpBeginInvocationInterlockEXT ||
400 opcode == spv::Op::OpEndInvocationInterlockEXT) {
401 return;
402 }
403
404 const spv_opcode_desc_t* desc = {};
405 auto result = context()->grammar().lookupOpcode(opcode, &desc);
406 if (result != SPV_SUCCESS) {
407 return;
408 }
409
410 addSupportedCapabilitiesToSet(desc, capabilities);
411 addSupportedExtensionsToSet(desc, extensions);
412 }
413
addInstructionRequirementsForOperand( const Operand& operand, CapabilitySet* capabilities, ExtensionSet* extensions) const414 void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
415 const Operand& operand, CapabilitySet* capabilities,
416 ExtensionSet* extensions) const {
417 // No supported capability relies on a 2+-word operand.
418 if (operand.words.size() != 1) {
419 return;
420 }
421
422 // No supported capability relies on a literal string operand or an ID.
423 if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
424 operand.type == SPV_OPERAND_TYPE_ID ||
425 operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
426 return;
427 }
428
429 // case 1: Operand is a single value, can directly lookup.
430 if (!spvOperandIsConcreteMask(operand.type)) {
431 const spv_operand_desc_t* desc = {};
432 auto result = context()->grammar().lookupOperand(operand.type,
433 operand.words[0], &desc);
434 if (result != SPV_SUCCESS) {
435 return;
436 }
437 addSupportedCapabilitiesToSet(desc, capabilities);
438 addSupportedExtensionsToSet(desc, extensions);
439 return;
440 }
441
442 // case 2: operand can be a bitmask, we need to decompose the lookup.
443 for (uint32_t i = 0; i < 32; i++) {
444 const uint32_t mask = (1 << i) & operand.words[0];
445 if (!mask) {
446 continue;
447 }
448
449 const spv_operand_desc_t* desc = {};
450 auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
451 if (result != SPV_SUCCESS) {
452 continue;
453 }
454
455 addSupportedCapabilitiesToSet(desc, capabilities);
456 addSupportedExtensionsToSet(desc, extensions);
457 }
458 }
459
addInstructionRequirements( Instruction* instruction, CapabilitySet* capabilities, ExtensionSet* extensions) const460 void TrimCapabilitiesPass::addInstructionRequirements(
461 Instruction* instruction, CapabilitySet* capabilities,
462 ExtensionSet* extensions) const {
463 // Ignoring OpCapability and OpExtension instructions.
464 if (instruction->opcode() == spv::Op::OpCapability ||
465 instruction->opcode() == spv::Op::OpExtension) {
466 return;
467 }
468
469 addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
470 extensions);
471
472 // Second case: one of the opcode operand is gated by a capability.
473 const uint32_t operandCount = instruction->NumOperands();
474 for (uint32_t i = 0; i < operandCount; i++) {
475 addInstructionRequirementsForOperand(instruction->GetOperand(i),
476 capabilities, extensions);
477 }
478
479 // Last case: some complex logic needs to be run to determine capabilities.
480 auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
481 for (auto it = begin; it != end; it++) {
482 const OpcodeHandler handler = it->second;
483 auto result = handler(instruction);
484 if (!result.has_value()) {
485 continue;
486 }
487
488 capabilities->insert(*result);
489 }
490 }
491
AddExtensionsForOperand( const spv_operand_type_t type, const uint32_t value, ExtensionSet* extensions) const492 void TrimCapabilitiesPass::AddExtensionsForOperand(
493 const spv_operand_type_t type, const uint32_t value,
494 ExtensionSet* extensions) const {
495 const spv_operand_desc_t* desc = nullptr;
496 spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
497 if (result != SPV_SUCCESS) {
498 return;
499 }
500 addSupportedExtensionsToSet(desc, extensions);
501 }
502
503 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const504 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
505 CapabilitySet required_capabilities;
506 ExtensionSet required_extensions;
507
508 get_module()->ForEachInst([&](Instruction* instruction) {
509 addInstructionRequirements(instruction, &required_capabilities,
510 &required_extensions);
511 });
512
513 for (auto capability : required_capabilities) {
514 AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
515 static_cast<uint32_t>(capability),
516 &required_extensions);
517 }
518
519 #if !defined(NDEBUG)
520 // Debug only. We check the outputted required capabilities against the
521 // supported capabilities list. The supported capabilities list is useful for
522 // API users to quickly determine if they can use the pass or not. But this
523 // list has to remain up-to-date with the pass code. If we can detect a
524 // capability as required, but it's not listed, it means the list is
525 // out-of-sync. This method is not ideal, but should cover most cases.
526 {
527 for (auto capability : required_capabilities) {
528 assert(supportedCapabilities_.contains(capability) &&
529 "Module is using a capability that is not listed as supported.");
530 }
531 }
532 #endif
533
534 return std::make_pair(std::move(required_capabilities),
535 std::move(required_extensions));
536 }
537
TrimUnrequiredCapabilities( const CapabilitySet& required_capabilities) const538 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
539 const CapabilitySet& required_capabilities) const {
540 const FeatureManager* feature_manager = context()->get_feature_mgr();
541 CapabilitySet capabilities_to_trim;
542 for (auto capability : feature_manager->GetCapabilities()) {
543 // Some capabilities cannot be safely removed. Leaving them untouched.
544 if (untouchableCapabilities_.contains(capability)) {
545 continue;
546 }
547
548 // If the capability is unsupported, don't trim it.
549 if (!supportedCapabilities_.contains(capability)) {
550 continue;
551 }
552
553 if (required_capabilities.contains(capability)) {
554 continue;
555 }
556
557 capabilities_to_trim.insert(capability);
558 }
559
560 for (auto capability : capabilities_to_trim) {
561 context()->RemoveCapability(capability);
562 }
563
564 return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
565 : Pass::Status::SuccessWithChange;
566 }
567
TrimUnrequiredExtensions( const ExtensionSet& required_extensions) const568 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
569 const ExtensionSet& required_extensions) const {
570 const auto supported_extensions =
571 getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
572
573 bool modified_module = false;
574 for (auto extension : supported_extensions) {
575 if (required_extensions.contains(extension)) {
576 continue;
577 }
578
579 if (context()->RemoveExtension(extension)) {
580 modified_module = true;
581 }
582 }
583
584 return modified_module ? Pass::Status::SuccessWithChange
585 : Pass::Status::SuccessWithoutChange;
586 }
587
HasForbiddenCapabilities() const588 bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
589 // EnumSet.HasAnyOf returns `true` if the given set is empty.
590 if (forbiddenCapabilities_.size() == 0) {
591 return false;
592 }
593
594 const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
595 return capabilities.HasAnyOf(forbiddenCapabilities_);
596 }
597
Process()598 Pass::Status TrimCapabilitiesPass::Process() {
599 if (HasForbiddenCapabilities()) {
600 return Status::SuccessWithoutChange;
601 }
602
603 auto[required_capabilities, required_extensions] =
604 DetermineRequiredCapabilitiesAndExtensions();
605
606 Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
607 Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
608
609 return capStatus == Pass::Status::SuccessWithChange ||
610 extStatus == Pass::Status::SuccessWithChange
611 ? Pass::Status::SuccessWithChange
612 : Pass::Status::SuccessWithoutChange;
613 }
614
615 } // namespace opt
616 } // namespace spvtools
617