1// Copyright (c) 2017 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/opt/strength_reduction_pass.h"
16
17#include <cstring>
18#include <memory>
19#include <utility>
20#include <vector>
21
22#include "source/opt/def_use_manager.h"
23#include "source/opt/ir_context.h"
24#include "source/opt/log.h"
25#include "source/opt/reflect.h"
26
27namespace spvtools {
28namespace opt {
29namespace {
// Count the number of trailing zeros in the binary representation of
// |constVal|.
//
// Returns the index of the least significant set bit, e.g. 0 for 1, 3 for 8.
// If |constVal| is 0 there is no set bit at all; in that case the full bit
// width (32) is returned instead of looping forever.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Without this guard the loop below would never terminate for 0, because
  // shifting 0 right never produces a set low bit.
  if (constVal == 0) return 32;

  // Faster if we use the hardware count trailing zeros instruction.
  // If not available, we could create a table.
  uint32_t shiftAmount = 0;
  while ((constVal & 1) == 0) {
    ++shiftAmount;
    constVal = (constVal >> 1);
  }
  return shiftAmount;
}
42
// Return true if |val| is a power of 2, i.e. it has exactly one bit set.
// Zero is not a power of 2.
bool IsPowerOf2(uint32_t val) {
  // ANDing a value with itself minus one clears its lowest set bit.  When
  // only one bit was set to begin with, nothing is left afterwards.
  const uint32_t withoutLowestBit = val & (val - 1);
  return val != 0 && withoutLowestBit == 0;
}
51
52}  // namespace
53
54Pass::Status StrengthReductionPass::Process() {
55  // Initialize the member variables on a per module basis.
56  bool modified = false;
57  int32_type_id_ = 0;
58  uint32_type_id_ = 0;
59  std::memset(constant_ids_, 0, sizeof(constant_ids_));
60
61  FindIntTypesAndConstants();
62  modified = ScanFunctions();
63  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64}
65
66bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
67    BasicBlock::iterator* inst) {
68  assert((*inst)->opcode() == spv::Op::OpIMul &&
69         "Only works for multiplication of integers.");
70  bool modified = false;
71
72  // Currently only works on 32-bit integers.
73  if ((*inst)->type_id() != int32_type_id_ &&
74      (*inst)->type_id() != uint32_type_id_) {
75    return modified;
76  }
77
78  // Check the operands for a constant that is a power of 2.
79  for (int i = 0; i < 2; i++) {
80    uint32_t opId = (*inst)->GetSingleWordInOperand(i);
81    Instruction* opInst = get_def_use_mgr()->GetDef(opId);
82    if (opInst->opcode() == spv::Op::OpConstant) {
83      // We found a constant operand.
84      uint32_t constVal = opInst->GetSingleWordOperand(2);
85
86      if (IsPowerOf2(constVal)) {
87        modified = true;
88        uint32_t shiftAmount = CountTrailingZeros(constVal);
89        uint32_t shiftConstResultId = GetConstantId(shiftAmount);
90
91        // Create the new instruction.
92        uint32_t newResultId = TakeNextId();
93        std::vector<Operand> newOperands;
94        newOperands.push_back((*inst)->GetInOperand(1 - i));
95        Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
96                             {shiftConstResultId});
97        newOperands.push_back(shiftOperand);
98        std::unique_ptr<Instruction> newInstruction(
99            new Instruction(context(), spv::Op::OpShiftLeftLogical,
100                            (*inst)->type_id(), newResultId, newOperands));
101
102        // Insert the new instruction and update the data structures.
103        (*inst) = (*inst).InsertBefore(std::move(newInstruction));
104        get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
105        ++(*inst);
106        context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
107
108        // Remove the old instruction.
109        Instruction* inst_to_delete = &*(*inst);
110        --(*inst);
111        context()->KillInst(inst_to_delete);
112
113        // We do not want to replace the instruction twice if both operands
114        // are constants that are a power of 2.  So we break here.
115        break;
116      }
117    }
118  }
119
120  return modified;
121}
122
123void StrengthReductionPass::FindIntTypesAndConstants() {
124  analysis::Integer int32(32, true);
125  int32_type_id_ = context()->get_type_mgr()->GetId(&int32);
126  analysis::Integer uint32(32, false);
127  uint32_type_id_ = context()->get_type_mgr()->GetId(&uint32);
128  for (auto iter = get_module()->types_values_begin();
129       iter != get_module()->types_values_end(); ++iter) {
130    switch (iter->opcode()) {
131      case spv::Op::OpConstant:
132        if (iter->type_id() == uint32_type_id_) {
133          uint32_t value = iter->GetSingleWordOperand(2);
134          if (value <= 32) constant_ids_[value] = iter->result_id();
135        }
136        break;
137      default:
138        break;
139    }
140  }
141}
142
143uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
144  assert(val <= 32 &&
145         "This function does not handle constants larger than 32.");
146
147  if (constant_ids_[val] == 0) {
148    if (uint32_type_id_ == 0) {
149      analysis::Integer uint(32, false);
150      uint32_type_id_ = context()->get_type_mgr()->GetTypeInstruction(&uint);
151    }
152
153    // Construct the constant.
154    uint32_t resultId = TakeNextId();
155    Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
156                     {val});
157    std::unique_ptr<Instruction> newConstant(new Instruction(
158        context(), spv::Op::OpConstant, uint32_type_id_, resultId, {constant}));
159    get_module()->AddGlobalValue(std::move(newConstant));
160
161    // Notify the DefUseManager about this constant.
162    auto constantIter = --get_module()->types_values_end();
163    get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
164
165    // Store the result id for next time.
166    constant_ids_[val] = resultId;
167  }
168
169  return constant_ids_[val];
170}
171
172bool StrengthReductionPass::ScanFunctions() {
173  // I did not use |ForEachInst| in the module because the function that acts on
174  // the instruction gets a pointer to the instruction.  We cannot use that to
175  // insert a new instruction.  I want an iterator.
176  bool modified = false;
177  for (auto& func : *get_module()) {
178    for (auto& bb : func) {
179      for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
180        switch (inst->opcode()) {
181          case spv::Op::OpIMul:
182            if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
183            break;
184          default:
185            break;
186        }
187      }
188    }
189  }
190  return modified;
191}
192
193}  // namespace opt
194}  // namespace spvtools
195