1 // Copyright (c) 2023 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/trim_capabilities_pass.h"
16 
17 #include <algorithm>
18 #include <array>
19 #include <cassert>
20 #include <functional>
21 #include <optional>
22 #include <queue>
23 #include <stack>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include <vector>
27 
28 #include "source/enum_set.h"
29 #include "source/enum_string_mapping.h"
30 #include "source/opt/ir_context.h"
31 #include "source/opt/reflect.h"
32 #include "source/spirv_target_env.h"
33 #include "source/util/string_utils.h"
34 
35 namespace spvtools {
36 namespace opt {
37 
38 namespace {
// In-operand indices consumed by the helpers and opcode handlers below,
// grouped by the instruction they index into.

// OpTypeInt / OpTypeFloat: the first in-operand is the bit width.
constexpr uint32_t kOpTypeScalarBitWidthIndex = 0;
constexpr uint32_t kOpTypeIntSizeIndex = 0;
constexpr uint32_t kOpTypeFloatSizeIndex = 0;

// OpTypePointer: storage class first, pointee type second.
constexpr uint32_t kOpTypePointerStorageClassIndex = 0;
constexpr uint32_t kTypePointerTypeIdInIndex = 1;

// OpTypeArray / OpTypeRuntimeArray / OpTypeVector / OpTypeMatrix: the
// element (component, column) type is the first in-operand.
constexpr uint32_t kTypeArrayTypeIndex = 0;

// OpTypeImage in-operand layout (SPIR-V spec):
//   0 Sampled Type, 1 Dim, 2 Depth, 3 Arrayed, 4 MS, 5 Sampled, 6 Image Format
constexpr uint32_t kOpTypeImageDimIndex = 1;
constexpr uint32_t kOpTypeImageArrayedIndex = 3;
constexpr uint32_t kOpTypeImageMSIndex = 4;
constexpr uint32_t kOpTypeImageSampledIndex = 5;
constexpr uint32_t kOpTypeImageFormatIndex = 6;

// OpImageRead / OpImageSparseRead: the image is the first in-operand.
constexpr uint32_t kOpImageReadImageIndex = 0;
constexpr uint32_t kOpImageSparseReadImageIndex = 0;
52 
53 // DFS visit of the type defined by `instruction`.
54 // If `condition` is true, children of the current node are visited.
55 // If `condition` is false, the children of the current node are ignored.
56 template <class UnaryPredicate>
DFSWhile(const Instruction* instruction, UnaryPredicate condition)57 static void DFSWhile(const Instruction* instruction, UnaryPredicate condition) {
58   std::stack<uint32_t> instructions_to_visit;
59   instructions_to_visit.push(instruction->result_id());
60   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
61 
62   while (!instructions_to_visit.empty()) {
63     const Instruction* item = def_use_mgr->GetDef(instructions_to_visit.top());
64     instructions_to_visit.pop();
65 
66     if (!condition(item)) {
67       continue;
68     }
69 
70     if (item->opcode() == spv::Op::OpTypePointer) {
71       instructions_to_visit.push(
72           item->GetSingleWordInOperand(kTypePointerTypeIdInIndex));
73       continue;
74     }
75 
76     if (item->opcode() == spv::Op::OpTypeMatrix ||
77         item->opcode() == spv::Op::OpTypeVector ||
78         item->opcode() == spv::Op::OpTypeArray ||
79         item->opcode() == spv::Op::OpTypeRuntimeArray) {
80       instructions_to_visit.push(
81           item->GetSingleWordInOperand(kTypeArrayTypeIndex));
82       continue;
83     }
84 
85     if (item->opcode() == spv::Op::OpTypeStruct) {
86       item->ForEachInOperand([&instructions_to_visit](const uint32_t* op_id) {
87         instructions_to_visit.push(*op_id);
88       });
89       continue;
90     }
91   }
92 }
93 
94 // Walks the type defined by `instruction` (OpType* only).
95 // Returns `true` if any call to `predicate` with the type/subtype returns true.
96 template <class UnaryPredicate>
AnyTypeOf(const Instruction* instruction, UnaryPredicate predicate)97 static bool AnyTypeOf(const Instruction* instruction,
98                       UnaryPredicate predicate) {
99   assert(IsTypeInst(instruction->opcode()) &&
100          "AnyTypeOf called with a non-type instruction.");
101 
102   bool found_one = false;
103   DFSWhile(instruction, [&found_one, predicate](const Instruction* node) {
104     if (found_one || predicate(node)) {
105       found_one = true;
106       return false;
107     }
108 
109     return true;
110   });
111   return found_one;
112 }
113 
is16bitType(const Instruction* instruction)114 static bool is16bitType(const Instruction* instruction) {
115   if (instruction->opcode() != spv::Op::OpTypeInt &&
116       instruction->opcode() != spv::Op::OpTypeFloat) {
117     return false;
118   }
119 
120   return instruction->GetSingleWordInOperand(kOpTypeScalarBitWidthIndex) == 16;
121 }
122 
Has16BitCapability(const FeatureManager* feature_manager)123 static bool Has16BitCapability(const FeatureManager* feature_manager) {
124   const CapabilitySet& capabilities = feature_manager->GetCapabilities();
125   return capabilities.contains(spv::Capability::Float16) ||
126          capabilities.contains(spv::Capability::Int16);
127 }
128 
129 }  // namespace
130 
131 // ============== Begin opcode handler implementations. =======================
132 //
// Adding support for a new capability should only require adding a new handler
// and updating the kSupportedCapabilities/kUntouchableCapabilities/
// kForbiddenCapabilities lists.
136 //
137 // Handler names follow the following convention:
138 //  Handler_<Opcode>_<Capability>()
139 
Handler_OpTypeFloat_Float64( const Instruction* instruction)140 static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
141     const Instruction* instruction) {
142   assert(instruction->opcode() == spv::Op::OpTypeFloat &&
143          "This handler only support OpTypeFloat opcodes.");
144 
145   const uint32_t size =
146       instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
147   return size == 64 ? std::optional(spv::Capability::Float64) : std::nullopt;
148 }
149 
150 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction)151 Handler_OpTypePointer_StorageInputOutput16(const Instruction* instruction) {
152   assert(instruction->opcode() == spv::Op::OpTypePointer &&
153          "This handler only support OpTypePointer opcodes.");
154 
155   // This capability is only required if the variable has an Input/Output
156   // storage class.
157   spv::StorageClass storage_class = spv::StorageClass(
158       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
159   if (storage_class != spv::StorageClass::Input &&
160       storage_class != spv::StorageClass::Output) {
161     return std::nullopt;
162   }
163 
164   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
165     return std::nullopt;
166   }
167 
168   return AnyTypeOf(instruction, is16bitType)
169              ? std::optional(spv::Capability::StorageInputOutput16)
170              : std::nullopt;
171 }
172 
173 static std::optional<spv::Capability>
Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction)174 Handler_OpTypePointer_StoragePushConstant16(const Instruction* instruction) {
175   assert(instruction->opcode() == spv::Op::OpTypePointer &&
176          "This handler only support OpTypePointer opcodes.");
177 
178   // This capability is only required if the variable has a PushConstant storage
179   // class.
180   spv::StorageClass storage_class = spv::StorageClass(
181       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
182   if (storage_class != spv::StorageClass::PushConstant) {
183     return std::nullopt;
184   }
185 
186   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
187     return std::nullopt;
188   }
189 
190   return AnyTypeOf(instruction, is16bitType)
191              ? std::optional(spv::Capability::StoragePushConstant16)
192              : std::nullopt;
193 }
194 
195 static std::optional<spv::Capability>
Handler_OpTypePointer_StorageUniformBufferBlock16( const Instruction* instruction)196 Handler_OpTypePointer_StorageUniformBufferBlock16(
197     const Instruction* instruction) {
198   assert(instruction->opcode() == spv::Op::OpTypePointer &&
199          "This handler only support OpTypePointer opcodes.");
200 
201   // This capability is only required if the variable has a Uniform storage
202   // class.
203   spv::StorageClass storage_class = spv::StorageClass(
204       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
205   if (storage_class != spv::StorageClass::Uniform) {
206     return std::nullopt;
207   }
208 
209   if (!Has16BitCapability(instruction->context()->get_feature_mgr())) {
210     return std::nullopt;
211   }
212 
213   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
214   const bool matchesCondition =
215       AnyTypeOf(instruction, [decoration_mgr](const Instruction* item) {
216         if (!decoration_mgr->HasDecoration(item->result_id(),
217                                            spv::Decoration::BufferBlock)) {
218           return false;
219         }
220 
221         return AnyTypeOf(item, is16bitType);
222       });
223 
224   return matchesCondition
225              ? std::optional(spv::Capability::StorageUniformBufferBlock16)
226              : std::nullopt;
227 }
228 
Handler_OpTypePointer_StorageUniform16( const Instruction* instruction)229 static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
230     const Instruction* instruction) {
231   assert(instruction->opcode() == spv::Op::OpTypePointer &&
232          "This handler only support OpTypePointer opcodes.");
233 
234   // This capability is only required if the variable has a Uniform storage
235   // class.
236   spv::StorageClass storage_class = spv::StorageClass(
237       instruction->GetSingleWordInOperand(kOpTypePointerStorageClassIndex));
238   if (storage_class != spv::StorageClass::Uniform) {
239     return std::nullopt;
240   }
241 
242   const auto* feature_manager = instruction->context()->get_feature_mgr();
243   if (!Has16BitCapability(feature_manager)) {
244     return std::nullopt;
245   }
246 
247   const bool hasBufferBlockCapability =
248       feature_manager->GetCapabilities().contains(
249           spv::Capability::StorageUniformBufferBlock16);
250   const auto* decoration_mgr = instruction->context()->get_decoration_mgr();
251   bool found16bitType = false;
252 
253   DFSWhile(instruction, [decoration_mgr, hasBufferBlockCapability,
254                          &found16bitType](const Instruction* item) {
255     if (found16bitType) {
256       return false;
257     }
258 
259     if (hasBufferBlockCapability &&
260         decoration_mgr->HasDecoration(item->result_id(),
261                                       spv::Decoration::BufferBlock)) {
262       return false;
263     }
264 
265     if (is16bitType(item)) {
266       found16bitType = true;
267       return false;
268     }
269 
270     return true;
271   });
272 
273   return found16bitType ? std::optional(spv::Capability::StorageUniform16)
274                         : std::nullopt;
275 }
276 
Handler_OpTypeInt_Int64( const Instruction* instruction)277 static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
278     const Instruction* instruction) {
279   assert(instruction->opcode() == spv::Op::OpTypeInt &&
280          "This handler only support OpTypeInt opcodes.");
281 
282   const uint32_t size =
283       instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
284   return size == 64 ? std::optional(spv::Capability::Int64) : std::nullopt;
285 }
286 
Handler_OpTypeImage_ImageMSArray( const Instruction* instruction)287 static std::optional<spv::Capability> Handler_OpTypeImage_ImageMSArray(
288     const Instruction* instruction) {
289   assert(instruction->opcode() == spv::Op::OpTypeImage &&
290          "This handler only support OpTypeImage opcodes.");
291 
292   const uint32_t arrayed =
293       instruction->GetSingleWordInOperand(kOpTypeImageArrayedIndex);
294   const uint32_t ms = instruction->GetSingleWordInOperand(kOpTypeImageMSIndex);
295   const uint32_t sampled =
296       instruction->GetSingleWordInOperand(kOpTypeImageSampledIndex);
297 
298   return arrayed == 1 && sampled == 2 && ms == 1
299              ? std::optional(spv::Capability::ImageMSArray)
300              : std::nullopt;
301 }
302 
303 static std::optional<spv::Capability>
Handler_OpImageRead_StorageImageReadWithoutFormat( const Instruction* instruction)304 Handler_OpImageRead_StorageImageReadWithoutFormat(
305     const Instruction* instruction) {
306   assert(instruction->opcode() == spv::Op::OpImageRead &&
307          "This handler only support OpImageRead opcodes.");
308   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
309 
310   const uint32_t image_index =
311       instruction->GetSingleWordInOperand(kOpImageReadImageIndex);
312   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
313   const Instruction* type = def_use_mgr->GetDef(type_index);
314   const uint32_t dim = type->GetSingleWordInOperand(kOpTypeImageDimIndex);
315   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
316 
317   const bool is_unknown = spv::ImageFormat(format) == spv::ImageFormat::Unknown;
318   const bool requires_capability_for_unknown =
319       spv::Dim(dim) != spv::Dim::SubpassData;
320   return is_unknown && requires_capability_for_unknown
321              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
322              : std::nullopt;
323 }
324 
325 static std::optional<spv::Capability>
Handler_OpImageSparseRead_StorageImageReadWithoutFormat( const Instruction* instruction)326 Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
327     const Instruction* instruction) {
328   assert(instruction->opcode() == spv::Op::OpImageSparseRead &&
329          "This handler only support OpImageSparseRead opcodes.");
330   const auto* def_use_mgr = instruction->context()->get_def_use_mgr();
331 
332   const uint32_t image_index =
333       instruction->GetSingleWordInOperand(kOpImageSparseReadImageIndex);
334   const uint32_t type_index = def_use_mgr->GetDef(image_index)->type_id();
335   const Instruction* type = def_use_mgr->GetDef(type_index);
336   const uint32_t format = type->GetSingleWordInOperand(kOpTypeImageFormatIndex);
337 
338   return spv::ImageFormat(format) == spv::ImageFormat::Unknown
339              ? std::optional(spv::Capability::StorageImageReadWithoutFormat)
340              : std::nullopt;
341 }
342 
343 // Opcode of interest to determine capabilities requirements.
344 constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 10> kOpcodeHandlers{{
345     // clang-format off
346     {spv::Op::OpImageRead,         Handler_OpImageRead_StorageImageReadWithoutFormat},
347     {spv::Op::OpImageSparseRead,   Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
348     {spv::Op::OpTypeFloat,         Handler_OpTypeFloat_Float64 },
349     {spv::Op::OpTypeImage,         Handler_OpTypeImage_ImageMSArray},
350     {spv::Op::OpTypeInt,           Handler_OpTypeInt_Int64 },
351     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageInputOutput16},
352     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StoragePushConstant16},
353     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
354     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniform16},
355     {spv::Op::OpTypePointer,       Handler_OpTypePointer_StorageUniformBufferBlock16},
356     // clang-format on
357 }};
358 
359 // ==============  End opcode handler implementations.  =======================
360 
361 namespace {
getExtensionsRelatedTo(const CapabilitySet& capabilities, const AssemblyGrammar& grammar)362 ExtensionSet getExtensionsRelatedTo(const CapabilitySet& capabilities,
363                                     const AssemblyGrammar& grammar) {
364   ExtensionSet output;
365   const spv_operand_desc_t* desc = nullptr;
366   for (auto capability : capabilities) {
367     if (SPV_SUCCESS != grammar.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
368                                              static_cast<uint32_t>(capability),
369                                              &desc)) {
370       continue;
371     }
372 
373     for (uint32_t i = 0; i < desc->numExtensions; ++i) {
374       output.insert(desc->extensions[i]);
375     }
376   }
377 
378   return output;
379 }
380 }  // namespace
381 
// Builds the pass's capability sets from the static kSupportedCapabilities,
// kForbiddenCapabilities and kUntouchableCapabilities lists, and its
// opcode -> handler table from kOpcodeHandlers. The table may hold several
// handlers for one opcode; addInstructionRequirements runs all of them.
TrimCapabilitiesPass::TrimCapabilitiesPass()
    : supportedCapabilities_(
          TrimCapabilitiesPass::kSupportedCapabilities.cbegin(),
          TrimCapabilitiesPass::kSupportedCapabilities.cend()),
      forbiddenCapabilities_(
          TrimCapabilitiesPass::kForbiddenCapabilities.cbegin(),
          TrimCapabilitiesPass::kForbiddenCapabilities.cend()),
      untouchableCapabilities_(
          TrimCapabilitiesPass::kUntouchableCapabilities.cbegin(),
          TrimCapabilitiesPass::kUntouchableCapabilities.cend()),
      opcodeHandlers_(kOpcodeHandlers.cbegin(), kOpcodeHandlers.cend()) {}
393 
addInstructionRequirementsForOpcode( spv::Op opcode, CapabilitySet* capabilities, ExtensionSet* extensions) const394 void TrimCapabilitiesPass::addInstructionRequirementsForOpcode(
395     spv::Op opcode, CapabilitySet* capabilities,
396     ExtensionSet* extensions) const {
397   // Ignoring OpBeginInvocationInterlockEXT and OpEndInvocationInterlockEXT
398   // because they have three possible capabilities, only one of which is needed
399   if (opcode == spv::Op::OpBeginInvocationInterlockEXT ||
400       opcode == spv::Op::OpEndInvocationInterlockEXT) {
401     return;
402   }
403 
404   const spv_opcode_desc_t* desc = {};
405   auto result = context()->grammar().lookupOpcode(opcode, &desc);
406   if (result != SPV_SUCCESS) {
407     return;
408   }
409 
410   addSupportedCapabilitiesToSet(desc, capabilities);
411   addSupportedExtensionsToSet(desc, extensions);
412 }
413 
addInstructionRequirementsForOperand( const Operand& operand, CapabilitySet* capabilities, ExtensionSet* extensions) const414 void TrimCapabilitiesPass::addInstructionRequirementsForOperand(
415     const Operand& operand, CapabilitySet* capabilities,
416     ExtensionSet* extensions) const {
417   // No supported capability relies on a 2+-word operand.
418   if (operand.words.size() != 1) {
419     return;
420   }
421 
422   // No supported capability relies on a literal string operand or an ID.
423   if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING ||
424       operand.type == SPV_OPERAND_TYPE_ID ||
425       operand.type == SPV_OPERAND_TYPE_RESULT_ID) {
426     return;
427   }
428 
429   // case 1: Operand is a single value, can directly lookup.
430   if (!spvOperandIsConcreteMask(operand.type)) {
431     const spv_operand_desc_t* desc = {};
432     auto result = context()->grammar().lookupOperand(operand.type,
433                                                      operand.words[0], &desc);
434     if (result != SPV_SUCCESS) {
435       return;
436     }
437     addSupportedCapabilitiesToSet(desc, capabilities);
438     addSupportedExtensionsToSet(desc, extensions);
439     return;
440   }
441 
442   // case 2: operand can be a bitmask, we need to decompose the lookup.
443   for (uint32_t i = 0; i < 32; i++) {
444     const uint32_t mask = (1 << i) & operand.words[0];
445     if (!mask) {
446       continue;
447     }
448 
449     const spv_operand_desc_t* desc = {};
450     auto result = context()->grammar().lookupOperand(operand.type, mask, &desc);
451     if (result != SPV_SUCCESS) {
452       continue;
453     }
454 
455     addSupportedCapabilitiesToSet(desc, capabilities);
456     addSupportedExtensionsToSet(desc, extensions);
457   }
458 }
459 
addInstructionRequirements( Instruction* instruction, CapabilitySet* capabilities, ExtensionSet* extensions) const460 void TrimCapabilitiesPass::addInstructionRequirements(
461     Instruction* instruction, CapabilitySet* capabilities,
462     ExtensionSet* extensions) const {
463   // Ignoring OpCapability and OpExtension instructions.
464   if (instruction->opcode() == spv::Op::OpCapability ||
465       instruction->opcode() == spv::Op::OpExtension) {
466     return;
467   }
468 
469   addInstructionRequirementsForOpcode(instruction->opcode(), capabilities,
470                                       extensions);
471 
472   // Second case: one of the opcode operand is gated by a capability.
473   const uint32_t operandCount = instruction->NumOperands();
474   for (uint32_t i = 0; i < operandCount; i++) {
475     addInstructionRequirementsForOperand(instruction->GetOperand(i),
476                                          capabilities, extensions);
477   }
478 
479   // Last case: some complex logic needs to be run to determine capabilities.
480   auto[begin, end] = opcodeHandlers_.equal_range(instruction->opcode());
481   for (auto it = begin; it != end; it++) {
482     const OpcodeHandler handler = it->second;
483     auto result = handler(instruction);
484     if (!result.has_value()) {
485       continue;
486     }
487 
488     capabilities->insert(*result);
489   }
490 }
491 
AddExtensionsForOperand( const spv_operand_type_t type, const uint32_t value, ExtensionSet* extensions) const492 void TrimCapabilitiesPass::AddExtensionsForOperand(
493     const spv_operand_type_t type, const uint32_t value,
494     ExtensionSet* extensions) const {
495   const spv_operand_desc_t* desc = nullptr;
496   spv_result_t result = context()->grammar().lookupOperand(type, value, &desc);
497   if (result != SPV_SUCCESS) {
498     return;
499   }
500   addSupportedExtensionsToSet(desc, extensions);
501 }
502 
503 std::pair<CapabilitySet, ExtensionSet>
DetermineRequiredCapabilitiesAndExtensions() const504 TrimCapabilitiesPass::DetermineRequiredCapabilitiesAndExtensions() const {
505   CapabilitySet required_capabilities;
506   ExtensionSet required_extensions;
507 
508   get_module()->ForEachInst([&](Instruction* instruction) {
509     addInstructionRequirements(instruction, &required_capabilities,
510                                &required_extensions);
511   });
512 
513   for (auto capability : required_capabilities) {
514     AddExtensionsForOperand(SPV_OPERAND_TYPE_CAPABILITY,
515                             static_cast<uint32_t>(capability),
516                             &required_extensions);
517   }
518 
519 #if !defined(NDEBUG)
520   // Debug only. We check the outputted required capabilities against the
521   // supported capabilities list. The supported capabilities list is useful for
522   // API users to quickly determine if they can use the pass or not. But this
523   // list has to remain up-to-date with the pass code. If we can detect a
524   // capability as required, but it's not listed, it means the list is
525   // out-of-sync. This method is not ideal, but should cover most cases.
526   {
527     for (auto capability : required_capabilities) {
528       assert(supportedCapabilities_.contains(capability) &&
529              "Module is using a capability that is not listed as supported.");
530     }
531   }
532 #endif
533 
534   return std::make_pair(std::move(required_capabilities),
535                         std::move(required_extensions));
536 }
537 
TrimUnrequiredCapabilities( const CapabilitySet& required_capabilities) const538 Pass::Status TrimCapabilitiesPass::TrimUnrequiredCapabilities(
539     const CapabilitySet& required_capabilities) const {
540   const FeatureManager* feature_manager = context()->get_feature_mgr();
541   CapabilitySet capabilities_to_trim;
542   for (auto capability : feature_manager->GetCapabilities()) {
543     // Some capabilities cannot be safely removed. Leaving them untouched.
544     if (untouchableCapabilities_.contains(capability)) {
545       continue;
546     }
547 
548     // If the capability is unsupported, don't trim it.
549     if (!supportedCapabilities_.contains(capability)) {
550       continue;
551     }
552 
553     if (required_capabilities.contains(capability)) {
554       continue;
555     }
556 
557     capabilities_to_trim.insert(capability);
558   }
559 
560   for (auto capability : capabilities_to_trim) {
561     context()->RemoveCapability(capability);
562   }
563 
564   return capabilities_to_trim.size() == 0 ? Pass::Status::SuccessWithoutChange
565                                           : Pass::Status::SuccessWithChange;
566 }
567 
TrimUnrequiredExtensions( const ExtensionSet& required_extensions) const568 Pass::Status TrimCapabilitiesPass::TrimUnrequiredExtensions(
569     const ExtensionSet& required_extensions) const {
570   const auto supported_extensions =
571       getExtensionsRelatedTo(supportedCapabilities_, context()->grammar());
572 
573   bool modified_module = false;
574   for (auto extension : supported_extensions) {
575     if (required_extensions.contains(extension)) {
576       continue;
577     }
578 
579     if (context()->RemoveExtension(extension)) {
580       modified_module = true;
581     }
582   }
583 
584   return modified_module ? Pass::Status::SuccessWithChange
585                          : Pass::Status::SuccessWithoutChange;
586 }
587 
HasForbiddenCapabilities() const588 bool TrimCapabilitiesPass::HasForbiddenCapabilities() const {
589   // EnumSet.HasAnyOf returns `true` if the given set is empty.
590   if (forbiddenCapabilities_.size() == 0) {
591     return false;
592   }
593 
594   const auto& capabilities = context()->get_feature_mgr()->GetCapabilities();
595   return capabilities.HasAnyOf(forbiddenCapabilities_);
596 }
597 
Process()598 Pass::Status TrimCapabilitiesPass::Process() {
599   if (HasForbiddenCapabilities()) {
600     return Status::SuccessWithoutChange;
601   }
602 
603   auto[required_capabilities, required_extensions] =
604       DetermineRequiredCapabilitiesAndExtensions();
605 
606   Pass::Status capStatus = TrimUnrequiredCapabilities(required_capabilities);
607   Pass::Status extStatus = TrimUnrequiredExtensions(required_extensions);
608 
609   return capStatus == Pass::Status::SuccessWithChange ||
610                  extStatus == Pass::Status::SuccessWithChange
611              ? Pass::Status::SuccessWithChange
612              : Pass::Status::SuccessWithoutChange;
613 }
614 
615 }  // namespace opt
616 }  // namespace spvtools
617