1// Copyright (c) 2017 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "source/opt/strength_reduction_pass.h" 16 17#include <cstring> 18#include <memory> 19#include <utility> 20#include <vector> 21 22#include "source/opt/def_use_manager.h" 23#include "source/opt/ir_context.h" 24#include "source/opt/log.h" 25#include "source/opt/reflect.h" 26 27namespace spvtools { 28namespace opt { 29namespace { 30// Count the number of trailing zeros in the binary representation of 31// |constVal|. 32uint32_t CountTrailingZeros(uint32_t constVal) { 33 // Faster if we use the hardware count trailing zeros instruction. 34 // If not available, we could create a table. 35 uint32_t shiftAmount = 0; 36 while ((constVal & 1) == 0) { 37 ++shiftAmount; 38 constVal = (constVal >> 1); 39 } 40 return shiftAmount; 41} 42 43// Return true if |val| is a power of 2. 44bool IsPowerOf2(uint32_t val) { 45 // The idea is that the & will clear out the least 46 // significant 1 bit. If it is a power of 2, then 47 // there is exactly 1 bit set, and the value becomes 0. 48 if (val == 0) return false; 49 return ((val - 1) & val) == 0; 50} 51 52} // namespace 53 54Pass::Status StrengthReductionPass::Process() { 55 // Initialize the member variables on a per module basis. 56 bool modified = false; 57 int32_type_id_ = 0; 58 uint32_type_id_ = 0; 59 std::memset(constant_ids_, 0, sizeof(constant_ids_)); 60 61 FindIntTypesAndConstants(); 62 modified = ScanFunctions(); 63 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); 64} 65 66bool StrengthReductionPass::ReplaceMultiplyByPowerOf2( 67 BasicBlock::iterator* inst) { 68 assert((*inst)->opcode() == spv::Op::OpIMul && 69 "Only works for multiplication of integers."); 70 bool modified = false; 71 72 // Currently only works on 32-bit integers. 73 if ((*inst)->type_id() != int32_type_id_ && 74 (*inst)->type_id() != uint32_type_id_) { 75 return modified; 76 } 77 78 // Check the operands for a constant that is a power of 2. 79 for (int i = 0; i < 2; i++) { 80 uint32_t opId = (*inst)->GetSingleWordInOperand(i); 81 Instruction* opInst = get_def_use_mgr()->GetDef(opId); 82 if (opInst->opcode() == spv::Op::OpConstant) { 83 // We found a constant operand. 84 uint32_t constVal = opInst->GetSingleWordOperand(2); 85 86 if (IsPowerOf2(constVal)) { 87 modified = true; 88 uint32_t shiftAmount = CountTrailingZeros(constVal); 89 uint32_t shiftConstResultId = GetConstantId(shiftAmount); 90 91 // Create the new instruction. 92 uint32_t newResultId = TakeNextId(); 93 std::vector<Operand> newOperands; 94 newOperands.push_back((*inst)->GetInOperand(1 - i)); 95 Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID, 96 {shiftConstResultId}); 97 newOperands.push_back(shiftOperand); 98 std::unique_ptr<Instruction> newInstruction( 99 new Instruction(context(), spv::Op::OpShiftLeftLogical, 100 (*inst)->type_id(), newResultId, newOperands)); 101 102 // Insert the new instruction and update the data structures. 103 (*inst) = (*inst).InsertBefore(std::move(newInstruction)); 104 get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst)); 105 ++(*inst); 106 context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId); 107 108 // Remove the old instruction. 109 Instruction* inst_to_delete = &*(*inst); 110 --(*inst); 111 context()->KillInst(inst_to_delete); 112 113 // We do not want to replace the instruction twice if both operands 114 // are constants that are a power of 2. So we break here. 115 break; 116 } 117 } 118 } 119 120 return modified; 121} 122 123void StrengthReductionPass::FindIntTypesAndConstants() { 124 analysis::Integer int32(32, true); 125 int32_type_id_ = context()->get_type_mgr()->GetId(&int32); 126 analysis::Integer uint32(32, false); 127 uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32); 128 for (auto iter = get_module()->types_values_begin(); 129 iter != get_module()->types_values_end(); ++iter) { 130 switch (iter->opcode()) { 131 case spv::Op::OpConstant: 132 if (iter->type_id() == uint32_type_id_) { 133 uint32_t value = iter->GetSingleWordOperand(2); 134 if (value <= 32) constant_ids_[value] = iter->result_id(); 135 } 136 break; 137 default: 138 break; 139 } 140 } 141} 142 143uint32_t StrengthReductionPass::GetConstantId(uint32_t val) { 144 assert(val <= 32 && 145 "This function does not handle constants larger than 32."); 146 147 if (constant_ids_[val] == 0) { 148 if (uint32_type_id_ == 0) { 149 analysis::Integer uint(32, false); 150 uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint); 151 } 152 153 // Construct the constant. 154 uint32_t resultId = TakeNextId(); 155 Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, 156 {val}); 157 std::unique_ptr<Instruction> newConstant(new Instruction( 158 context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant})); 159 get_module()->AddGlobalValue(std::move(newConstant)); 160 161 // Notify the DefUseManager about this constant. 162 auto constantIter = --get_module()->types_values_end(); 163 get_def_use_mgr()->AnalyzeInstDef(&*constantIter); 164 165 // Store the result id for next time. 166 constant_ids_[val] = resultId; 167 } 168 169 return constant_ids_[val]; 170} 171 172bool StrengthReductionPass::ScanFunctions() { 173 // I did not use |ForEachInst| in the module because the function that acts on 174 // the instruction gets a pointer to the instruction. We cannot use that to 175 // insert a new instruction. I want an iterator. 176 bool modified = false; 177 for (auto& func : *get_module()) { 178 for (auto& bb : func) { 179 for (auto inst = bb.begin(); inst != bb.end(); ++inst) { 180 switch (inst->opcode()) { 181 case spv::Op::OpIMul: 182 if (ReplaceMultiplyByPowerOf2(&inst)) modified = true; 183 break; 184 default: 185 break; 186 } 187 } 188 } 189 } 190 return modified; 191} 192 193} // namespace opt 194} // namespace spvtools 195