1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/strength_reduction_pass.h"
16 
17 #include <cstring>
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "source/opt/def_use_manager.h"
23 #include "source/opt/ir_context.h"
24 #include "source/opt/log.h"
25 #include "source/opt/reflect.h"
26 
27 namespace spvtools {
28 namespace opt {
29 namespace {
// Count the number of trailing zeros in the binary representation of
// |constVal|, i.e. the index of its lowest set bit.  Returns 32 (the full
// width of the word) when |constVal| is 0, since there is no set bit.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Without this guard the loop below would never terminate for 0, because
  // shifting an all-zero word never produces a set low bit.
  if (constVal == 0) return 32;

  // A hardware count-trailing-zeros instruction (or a lookup table) would be
  // faster; this loop is portable and runs at most 31 iterations.
  uint32_t shiftAmount = 0;
  while ((constVal & 1u) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
42 
// Returns true when |val| has exactly one bit set, i.e. |val| == 2^k for some
// k in [0, 31].  Zero has no bits set and is therefore not a power of 2.
bool IsPowerOf2(uint32_t val) {
  // |val & (val - 1)| clears the lowest set bit; the result is zero only if
  // that bit was the sole set bit.  The zero check comes first so that 0 is
  // rejected.
  return val != 0 && (val & (val - 1)) == 0;
}
51 
52 }  // namespace
53 
Process()54 Pass::Status StrengthReductionPass::Process() {
55   // Initialize the member variables on a per module basis.
56   bool modified = false;
57   int32_type_id_ = 0;
58   uint32_type_id_ = 0;
59   std::memset(constant_ids_, 0, sizeof(constant_ids_));
60 
61   FindIntTypesAndConstants();
62   modified = ScanFunctions();
63   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64 }
65 
ReplaceMultiplyByPowerOf2( BasicBlock::iterator* inst)66 bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
67     BasicBlock::iterator* inst) {
68   assert((*inst)->opcode() == spv::Op::OpIMul &&
69          "Only works for multiplication of integers.");
70   bool modified = false;
71 
72   // Currently only works on 32-bit integers.
73   if ((*inst)->type_id() != int32_type_id_ &&
74       (*inst)->type_id() != uint32_type_id_) {
75     return modified;
76   }
77 
78   // Check the operands for a constant that is a power of 2.
79   for (int i = 0; i < 2; i++) {
80     uint32_t opId = (*inst)->GetSingleWordInOperand(i);
81     Instruction* opInst = get_def_use_mgr()->GetDef(opId);
82     if (opInst->opcode() == spv::Op::OpConstant) {
83       // We found a constant operand.
84       uint32_t constVal = opInst->GetSingleWordOperand(2);
85 
86       if (IsPowerOf2(constVal)) {
87         modified = true;
88         uint32_t shiftAmount = CountTrailingZeros(constVal);
89         uint32_t shiftConstResultId = GetConstantId(shiftAmount);
90 
91         // Create the new instruction.
92         uint32_t newResultId = TakeNextId();
93         std::vector<Operand> newOperands;
94         newOperands.push_back((*inst)->GetInOperand(1 - i));
95         Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
96                              {shiftConstResultId});
97         newOperands.push_back(shiftOperand);
98         std::unique_ptr<Instruction> newInstruction(
99             new Instruction(context(), spv::Op::OpShiftLeftLogical,
100                             (*inst)->type_id(), newResultId, newOperands));
101 
102         // Insert the new instruction and update the data structures.
103         (*inst) = (*inst).InsertBefore(std::move(newInstruction));
104         get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
105         ++(*inst);
106         context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
107 
108         // Remove the old instruction.
109         Instruction* inst_to_delete = &*(*inst);
110         --(*inst);
111         context()->KillInst(inst_to_delete);
112 
113         // We do not want to replace the instruction twice if both operands
114         // are constants that are a power of 2.  So we break here.
115         break;
116       }
117     }
118   }
119 
120   return modified;
121 }
122 
FindIntTypesAndConstants()123 void StrengthReductionPass::FindIntTypesAndConstants() {
124   analysis::Integer int32(32, true);
125   int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
126   analysis::Integer uint32(32, false);
127   uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
128   for (auto iter = get_module()->types_values_begin();
129        iter != get_module()->types_values_end(); ++iter) {
130     switch (iter->opcode()) {
131       case spv::Op::OpConstant:
132         if (iter->type_id() == uint32_type_id_) {
133           uint32_t value = iter->GetSingleWordOperand(2);
134           if (value <= 32) constant_ids_[value] = iter->result_id();
135         }
136         break;
137       default:
138         break;
139     }
140   }
141 }
142 
GetConstantId(uint32_t val)143 uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
144   assert(val <= 32 &&
145          "This function does not handle constants larger than 32.");
146 
147   if (constant_ids_[val] == 0) {
148     if (uint32_type_id_ == 0) {
149       analysis::Integer uint(32, false);
150       uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
151     }
152 
153     // Construct the constant.
154     uint32_t resultId = TakeNextId();
155     Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
156                      {val});
157     std::unique_ptr<Instruction> newConstant(new Instruction(
158         context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant}));
159     get_module()->AddGlobalValue(std::move(newConstant));
160 
161     // Notify the DefUseManager about this constant.
162     auto constantIter = --get_module()->types_values_end();
163     get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
164 
165     // Store the result id for next time.
166     constant_ids_[val] = resultId;
167   }
168 
169   return constant_ids_[val];
170 }
171 
ScanFunctions()172 bool StrengthReductionPass::ScanFunctions() {
173   // I did not use |ForEachInst| in the module because the function that acts on
174   // the instruction gets a pointer to the instruction.  We cannot use that to
175   // insert a new instruction.  I want an iterator.
176   bool modified = false;
177   for (auto& func : *get_module()) {
178     for (auto& bb : func) {
179       for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
180         switch (inst->opcode()) {
181           case spv::Op::OpIMul:
182             if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
183             break;
184           default:
185             break;
186         }
187       }
188     }
189   }
190   return modified;
191 }
192 
193 }  // namespace opt
194 }  // namespace spvtools
195