1 // Copyright (c) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
16 #define SOURCE_OPT_INTERFACE_VAR_SROA_H_
17 
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/pass.h"
21 
22 namespace spvtools {
23 namespace opt {
24 
25 // See optimizer.hpp for documentation.
26 //
27 // Note that the current implementation of this pass covers only store, load,
28 // access chain instructions for the interface variables. Supporting other types
29 // of instructions is a future work.
30 class InterfaceVariableScalarReplacement : public Pass {
31  public:
InterfaceVariableScalarReplacement()32   InterfaceVariableScalarReplacement() {}
33 
34   const char* name() const override {
35     return "interface-variable-scalar-replacement";
36   }
37   Status Process() override;
38 
39   IRContext::Analysis GetPreservedAnalyses() override {
40     return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse |
41            IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
42   }
43 
44  private:
45   // A struct containing components of a composite variable. If the composite
46   // consists of multiple or recursive components, |component_variable| is
47   // nullptr and |nested_composite_components| keeps the components. If it has a
48   // single component, |nested_composite_components| is empty and
49   // |component_variable| is the component. Note that each element of
50   // |nested_composite_components| has the NestedCompositeComponents struct as
51   // its type that can recursively keep the components.
52   struct NestedCompositeComponents {
NestedCompositeComponentsspvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents53     NestedCompositeComponents() : component_variable(nullptr) {}
54 
HasMultipleComponentsspvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents55     bool HasMultipleComponents() const {
56       return !nested_composite_components.empty();
57     }
58 
GetComponentsspvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents59     const std::vector<NestedCompositeComponents>& GetComponents() const {
60       return nested_composite_components;
61     }
62 
AddComponentspvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents63     void AddComponent(const NestedCompositeComponents& component) {
64       nested_composite_components.push_back(component);
65     }
66 
GetComponentVariablespvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents67     Instruction* GetComponentVariable() const { return component_variable; }
68 
SetSingleComponentVariablespvtools::opt::InterfaceVariableScalarReplacement::NestedCompositeComponents69     void SetSingleComponentVariable(Instruction* var) {
70       component_variable = var;
71     }
72 
73    private:
74     std::vector<NestedCompositeComponents> nested_composite_components;
75     Instruction* component_variable;
76   };
77 
78   // Collects all interface variables used by the |entry_point|.
79   std::vector<Instruction*> CollectInterfaceVariables(Instruction& entry_point);
80 
81   // Returns whether |var| has the extra arrayness for the entry point
82   // |entry_point| or not.
83   bool HasExtraArrayness(Instruction& entry_point, Instruction* var);
84 
85   // Finds a Location BuiltIn decoration of |var| and returns it via
86   // |location|. Returns true whether the location exists or not.
87   bool GetVariableLocation(Instruction* var, uint32_t* location);
88 
89   // Finds a Component BuiltIn decoration of |var| and returns it via
90   // |component|. Returns true whether the component exists or not.
91   bool GetVariableComponent(Instruction* var, uint32_t* component);
92 
93   // Returns the type of |var| as an instruction.
94   Instruction* GetTypeOfVariable(Instruction* var);
95 
96   // Replaces an interface variable |interface_var| whose type is
97   // |interface_var_type| with scalars and returns whether it succeeds or not.
98   // |location| is the value of Location Decoration for |interface_var|.
99   // |component| is the value of Component Decoration for |interface_var|.
100   // If |extra_array_length| is 0, it means |interface_var| has a Patch
101   // decoration. Otherwise, |extra_array_length| denotes the length of the extra
102   // array of |interface_var|.
103   bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
104                                            Instruction* interface_var_type,
105                                            uint32_t location,
106                                            uint32_t component,
107                                            uint32_t extra_array_length);
108 
109   // Creates scalar variables with the storage classe |storage_class| to replace
110   // an interface variable whose type is |interface_var_type|. If
111   // |extra_array_length| is not zero, adds the extra arrayness to the created
112   // scalar variables.
113   NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
114       Instruction* interface_var_type, spv::StorageClass storage_class,
115       uint32_t extra_array_length);
116 
117   // Creates scalar variables with the storage classe |storage_class| to replace
118   // the interface variable whose type is OpTypeArray |interface_var_type| with.
119   // If |extra_array_length| is not zero, adds the extra arrayness to all the
120   // scalar variables.
121   NestedCompositeComponents CreateScalarInterfaceVarsForArray(
122       Instruction* interface_var_type, spv::StorageClass storage_class,
123       uint32_t extra_array_length);
124 
125   // Creates scalar variables with the storage classe |storage_class| to replace
126   // the interface variable whose type is OpTypeMatrix |interface_var_type|
127   // with. If |extra_array_length| is not zero, adds the extra arrayness to all
128   // the scalar variables.
129   NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
130       Instruction* interface_var_type, spv::StorageClass storage_class,
131       uint32_t extra_array_length);
132 
133   // Recursively adds Location and Component decorations to variables in
134   // |vars| with |location| and |component|. Increases |location| by one after
135   // it actually adds Location and Component decorations for a variable.
136   void AddLocationAndComponentDecorations(const NestedCompositeComponents& vars,
137                                           uint32_t* location,
138                                           uint32_t component);
139 
140   // Replaces the interface variable |interface_var| with
141   // |scalar_interface_vars| and returns whether it succeeds or not.
142   // |extra_arrayness| is the extra arrayness of the interface variable.
143   // |scalar_interface_vars| contains the nested variables to replace the
144   // interface variable with.
145   bool ReplaceInterfaceVarWith(
146       Instruction* interface_var, uint32_t extra_arrayness,
147       const NestedCompositeComponents& scalar_interface_vars);
148 
149   // Replaces |interface_var| in the operands of instructions
150   // |interface_var_users| with |scalar_interface_vars|. This is a recursive
151   // method and |interface_var_component_indices| is used to specify which
152   // recursive component of |interface_var| is replaced. Returns composite
153   // construct instructions to be replaced with load instructions of
154   // |interface_var_users| via |loads_to_composites|. Returns composite
155   // construct instructions to be replaced with load instructions of access
156   // chain instructions in |interface_var_users| via
157   // |loads_for_access_chain_to_composites|.
158   bool ReplaceComponentsOfInterfaceVarWith(
159       Instruction* interface_var,
160       const std::vector<Instruction*>& interface_var_users,
161       const NestedCompositeComponents& scalar_interface_vars,
162       std::vector<uint32_t>& interface_var_component_indices,
163       const uint32_t* extra_array_index,
164       std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
165       std::unordered_map<Instruction*, Instruction*>*
166           loads_for_access_chain_to_composites);
167 
168   // Replaces |interface_var| in the operands of instructions
169   // |interface_var_users| with |components| that is a vector of components for
170   // the interface variable |interface_var|. This is a recursive method and
171   // |interface_var_component_indices| is used to specify which recursive
172   // component of |interface_var| is replaced. Returns composite construct
173   // instructions to be replaced with load instructions of |interface_var_users|
174   // via |loads_to_composites|. Returns composite construct instructions to be
175   // replaced with load instructions of access chain instructions in
176   // |interface_var_users| via |loads_for_access_chain_to_composites|.
177   bool ReplaceMultipleComponentsOfInterfaceVarWith(
178       Instruction* interface_var,
179       const std::vector<Instruction*>& interface_var_users,
180       const std::vector<NestedCompositeComponents>& components,
181       std::vector<uint32_t>& interface_var_component_indices,
182       const uint32_t* extra_array_index,
183       std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
184       std::unordered_map<Instruction*, Instruction*>*
185           loads_for_access_chain_to_composites);
186 
187   // Replaces a component of |interface_var| that is used as an operand of
188   // instruction |interface_var_user| with |scalar_var|.
189   // |interface_var_component_indices| is a vector of recursive indices for
190   // which recursive component of |interface_var| is replaced. If
191   // |interface_var_user| is a load, returns the component value via
192   // |loads_to_component_values|. If |interface_var_user| is an access chain,
193   // returns the component value for loads of |interface_var_user| via
194   // |loads_for_access_chain_to_component_values|.
195   bool ReplaceComponentOfInterfaceVarWith(
196       Instruction* interface_var, Instruction* interface_var_user,
197       Instruction* scalar_var,
198       const std::vector<uint32_t>& interface_var_component_indices,
199       const uint32_t* extra_array_index,
200       std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
201       std::unordered_map<Instruction*, Instruction*>*
202           loads_for_access_chain_to_component_values);
203 
204   // Creates instructions to load |scalar_var| and inserts them before
205   // |insert_before|. If |extra_array_index| is not null, they load
206   // |extra_array_index| th component of |scalar_var| instead of |scalar_var|
207   // itself.
208   Instruction* LoadScalarVar(Instruction* scalar_var,
209                              const uint32_t* extra_array_index,
210                              Instruction* insert_before);
211 
212   // Creates instructions to load an access chain to |var| and inserts them
213   // before |insert_before|. |Indexes| will be Indexes operand of the access
214   // chain.
215   Instruction* LoadAccessChainToVar(Instruction* var,
216                                     const std::vector<uint32_t>& indexes,
217                                     Instruction* insert_before);
218 
219   // Creates instructions to store a component of an aggregate whose id is
220   // |value_id| to an access chain to |scalar_var| and inserts the created
221   // instructions before |insert_before|. To get the component, recursively
222   // traverses the aggregate with |component_indices| as indexes.
223   // Numbers in |access_chain_indices| are the Indexes operand of the access
224   // chain to |scalar_var|
225   void StoreComponentOfValueToAccessChainToScalarVar(
226       uint32_t value_id, const std::vector<uint32_t>& component_indices,
227       Instruction* scalar_var,
228       const std::vector<uint32_t>& access_chain_indices,
229       Instruction* insert_before);
230 
231   // Creates instructions to store a component of an aggregate whose id is
232   // |value_id| to |scalar_var| and inserts the created instructions before
233   // |insert_before|. To get the component, recursively traverses the aggregate
234   // using |extra_array_index| and |component_indices| as indexes.
235   void StoreComponentOfValueToScalarVar(
236       uint32_t value_id, const std::vector<uint32_t>& component_indices,
237       Instruction* scalar_var, const uint32_t* extra_array_index,
238       Instruction* insert_before);
239 
240   // Creates instructions to store a component of an aggregate whose id is
241   // |value_id| to |ptr| and inserts the created instructions before
242   // |insert_before|. To get the component, recursively traverses the aggregate
243   // using |extra_array_index| and |component_indices| as indexes.
244   // |component_type_id| is the id of the type instruction of the component.
245   void StoreComponentOfValueTo(uint32_t component_type_id, uint32_t value_id,
246                                const std::vector<uint32_t>& component_indices,
247                                Instruction* ptr,
248                                const uint32_t* extra_array_index,
249                                Instruction* insert_before);
250 
251   // Creates new OpCompositeExtract with |type_id| for Result Type,
252   // |composite_id| for Composite operand, and |indexes| for Indexes operands.
253   // If |extra_first_index| is not nullptr, uses it as the first Indexes
254   // operand.
255   Instruction* CreateCompositeExtract(uint32_t type_id, uint32_t composite_id,
256                                       const std::vector<uint32_t>& indexes,
257                                       const uint32_t* extra_first_index);
258 
259   // Creates a new OpLoad whose Result Type is |type_id| and Pointer operand is
260   // |ptr|. Inserts the new instruction before |insert_before|.
261   Instruction* CreateLoad(uint32_t type_id, Instruction* ptr,
262                           Instruction* insert_before);
263 
264   // Clones an annotation instruction |annotation_inst| and sets the target
265   // operand of the new annotation instruction as |var_id|.
266   void CloneAnnotationForVariable(Instruction* annotation_inst,
267                                   uint32_t var_id);
268 
269   // Replaces the interface variable |interface_var| in the operands of the
270   // entry point |entry_point| with |scalar_var_id|. If it cannot find
271   // |interface_var| from the operands of the entry point |entry_point|, adds
272   // |scalar_var_id| as an operand of the entry point |entry_point|.
273   bool ReplaceInterfaceVarInEntryPoint(Instruction* interface_var,
274                                        Instruction* entry_point,
275                                        uint32_t scalar_var_id);
276 
277   // Creates an access chain instruction whose Base operand is |var| and Indexes
278   // operand is |index|. |component_type_id| is the id of the type instruction
279   // that is the type of component. Inserts the new access chain before
280   // |insert_before|.
281   Instruction* CreateAccessChainWithIndex(uint32_t component_type_id,
282                                           Instruction* var, uint32_t index,
283                                           Instruction* insert_before);
284 
285   // Returns the pointee type of the type of variable |var|.
286   uint32_t GetPointeeTypeIdOfVar(Instruction* var);
287 
288   // Replaces the access chain |access_chain| and its users with a new access
289   // chain that points |scalar_var| as the Base operand having
290   // |interface_var_component_indices| as Indexes operands and users of the new
291   // access chain. When some of the users are load instructions, returns the
292   // original load instruction to the new instruction that loads a component of
293   // the original load value via |loads_to_component_values|.
294   void ReplaceAccessChainWith(
295       Instruction* access_chain,
296       const std::vector<uint32_t>& interface_var_component_indices,
297       Instruction* scalar_var,
298       std::unordered_map<Instruction*, Instruction*>*
299           loads_to_component_values);
300 
301   // Assuming that |access_chain| is an access chain instruction whose Base
302   // operand is |base_access_chain|, replaces the operands of |access_chain|
303   // with operands of |base_access_chain| and Indexes operands of
304   // |access_chain|.
305   void UseBaseAccessChainForAccessChain(Instruction* access_chain,
306                                         Instruction* base_access_chain);
307 
308   // Creates composite construct instructions for load instructions that are the
309   // keys of |loads_to_component_values| if no such composite construct
310   // instructions exist. Adds a component of the composite as an operand of the
311   // created composite construct instruction. Each value of
312   // |loads_to_component_values| is the component. Returns the created composite
313   // construct instructions using |loads_to_composites|. |depth_to_component| is
314   // the number of recursive access steps to get the component from the
315   // composite.
316   void AddComponentsToCompositesForLoads(
317       const std::unordered_map<Instruction*, Instruction*>&
318           loads_to_component_values,
319       std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
320       uint32_t depth_to_component);
321 
322   // Creates a composite construct instruction for a component of the value of
323   // instruction |load| in |depth_to_component| th recursive depth and inserts
324   // it after |load|.
325   Instruction* CreateCompositeConstructForComponentOfLoad(
326       Instruction* load, uint32_t depth_to_component);
327 
328   // Creates a new access chain instruction that points to variable |var| whose
329   // type is the instruction with |var_type_id| and inserts it before
330   // |insert_before|. The new access chain will have |index_ids| for Indexes
331   // operands. Returns the type id of the component that is pointed by the new
332   // access chain via |component_type_id|.
333   Instruction* CreateAccessChainToVar(uint32_t var_type_id, Instruction* var,
334                                       const std::vector<uint32_t>& index_ids,
335                                       Instruction* insert_before,
336                                       uint32_t* component_type_id);
337 
338   // Returns the result id of OpTypeArray instrunction whose Element Type
339   // operand is |elem_type_id| and Length operand is |array_length|.
340   uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length);
341 
342   // Returns the result id of OpTypePointer instrunction whose Type
343   // operand is |type_id| and Storage Class operand is |storage_class|.
344   uint32_t GetPointerType(uint32_t type_id, spv::StorageClass storage_class);
345 
346   // Kills an instrunction |inst| and its users.
347   void KillInstructionAndUsers(Instruction* inst);
348 
349   // Kills a vector of instrunctions |insts| and their users.
350   void KillInstructionsAndUsers(const std::vector<Instruction*>& insts);
351 
352   // Kills all OpDecorate instructions for Location and Component of the
353   // variable whose id is |var_id|.
354   void KillLocationAndComponentDecorations(uint32_t var_id);
355 
356   // If |var| has the extra arrayness for an entry point, reports an error and
357   // returns true. Otherwise, returns false.
358   bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var);
359 
360   // If |var| does not have the extra arrayness for an entry point, reports an
361   // error and returns true. Otherwise, returns false.
362   bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var);
363 
364   // If |interface_var| has the extra arrayness for an entry point but it does
365   // not have one for another entry point, reports an error and returns false.
366   // Otherwise, returns true. |has_extra_arrayness| denotes whether it has an
367   // extra arrayness for an entry point or not.
368   bool CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
369                                                  bool has_extra_arrayness);
370 
371   // Conducts the scalar replacement for the interface variables used by the
372   // |entry_point|.
373   Pass::Status ReplaceInterfaceVarsWithScalars(Instruction& entry_point);
374 
375   // A set of interface variable ids that were already removed from operands of
376   // the entry point.
377   std::unordered_set<uint32_t>
378       interface_vars_removed_from_entry_point_operands_;
379 
380   // A mapping from ids of new composite construct instructions that load
381   // instructions are replaced with to the recursive depth of the component of
382   // load that the new component construct instruction is used for.
383   std::unordered_map<uint32_t, uint32_t> composite_ids_to_component_depths;
384 
385   // A set of interface variables with the extra arrayness for any of the entry
386   // points.
387   std::unordered_set<Instruction*> vars_with_extra_arrayness;
388 
389   // A set of interface variables without the extra arrayness for any of the
390   // entry points.
391   std::unordered_set<Instruction*> vars_without_extra_arrayness;
392 };
393 
394 }  // namespace opt
395 }  // namespace spvtools
396 
397 #endif  // SOURCE_OPT_INTERFACE_VAR_SROA_H_
398