// Copyright 2021 the V8 project authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "src/compiler/wasm-inlining.h"

#include "src/compiler/all-nodes.h"
#include "src/compiler/compiler-source-position-table.h"
#include "src/compiler/node-matchers.h"
#include "src/compiler/wasm-compiler.h"
#include "src/wasm/function-body-decoder.h"
#include "src/wasm/graph-builder-interface.h"
#include "src/wasm/wasm-features.h"
#include "src/wasm/wasm-module.h"
#include "src/wasm/wasm-subtyping.h"

namespace v8 {
namespace internal {
namespace compiler {

Reduction WasmInliner::Reduce(Node* node) {
  switch (node->opcode()) {
    case IrOpcode::kCall:
    case IrOpcode::kTailCall:
      return ReduceCall(node);
    default:
      return NoChange();
  }
}

#define TRACE(...) \
  if (FLAG_trace_wasm_inlining) PrintF(__VA_ARGS__)

void WasmInliner::Trace(Node* call, int inlinee, const char* decision) {
  TRACE("[function %d: considering node %d, call to %d: %s]\n", function_index_,
        call->id(), inlinee, decision);
}

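// Maps a node of the current graph back to the wasm function whose body it
// was built from: ids below the first recorded boundary belong to the
// function currently being compiled; otherwise the node belongs to the
// inlinee whose node-id range contains the id. {first_node_id_} is sorted,
// since inlinee subgraphs are appended in the order they are built.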
uint32_t WasmInliner::FindOriginatingFunction(Node* call) {
  DCHECK_EQ(inlined_functions_.size(), first_node_id_.size());
  NodeId id = call->id();
  if (inlined_functions_.size() == 0 || id < first_node_id_[0]) {
    return function_index_;
  }
  for (size_t i = 1; i < first_node_id_.size(); i++) {
    if (id < first_node_id_[i]) return inlined_functions_[i - 1];
  }
  DCHECK_GE(id, first_node_id_.back());
  return inlined_functions_.back();
}

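// Looks up the call count recorded for this call site in the module's type
// feedback, or returns 0 if speculative inlining is disabled or no feedback
// has been collected yet. The call site is identified by its source position
// within its originating function.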
int WasmInliner::GetCallCount(Node* call) {
  if (!FLAG_wasm_speculative_inlining) return 0;
  base::MutexGuard guard(&module()->type_feedback.mutex);
  wasm::WasmCodePosition position =
      source_positions_->GetSourcePosition(call).ScriptOffset();
  uint32_t func = FindOriginatingFunction(call);
  auto maybe_feedback =
      module()->type_feedback.feedback_for_function.find(func);
  if (maybe_feedback == module()->type_feedback.feedback_for_function.end()) {
    return 0;
  }
  wasm::FunctionTypeFeedback feedback = maybe_feedback->second;
  // It's possible that we haven't processed the feedback yet. Currently,
  // this can happen for targets of call_direct that haven't gotten hot yet,
  // and for functions where Liftoff bailed out.
  if (feedback.feedback_vector.size() == 0) return 0;
  auto index_in_vector = feedback.positions.find(position);
  if (index_in_vector == feedback.positions.end()) return 0;
  return feedback.feedback_vector[index_in_vector->second]
      .absolute_call_frequency;
}

// TODO(12166): Save inlined frames for trap/--trace-wasm purposes. Consider
//              tail calls.
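// Collects eligible direct calls as inlining candidates. The actual inlining
// decision (and graph construction) happens later, in Finalize().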
Reduction WasmInliner::ReduceCall(Node* call) {
  DCHECK(call->opcode() == IrOpcode::kCall ||
         call->opcode() == IrOpcode::kTailCall);

  if (seen_.find(call) != seen_.end()) {
    TRACE("function %d: have already seen node %d, skipping\n", function_index_,
          call->id());
    return NoChange();
  }
  seen_.insert(call);

  Node* callee = NodeProperties::GetValueInput(call, 0);
  IrOpcode::Value reloc_opcode = mcgraph_->machine()->Is32()
                                     ? IrOpcode::kRelocatableInt32Constant
                                     : IrOpcode::kRelocatableInt64Constant;
  if (callee->opcode() != reloc_opcode) {
    TRACE("[function %d: considering node %d... not a relocatable constant]\n",
          function_index_, call->id());
    return NoChange();
  }
  auto info = OpParameter<RelocatablePtrConstantInfo>(callee->op());
  uint32_t inlinee_index = static_cast<uint32_t>(info.value());
  if (info.rmode() != RelocInfo::WASM_CALL) {
    Trace(call, inlinee_index, "not a wasm call");
    return NoChange();
  }
  if (inlinee_index < module()->num_imported_functions) {
    Trace(call, inlinee_index, "imported function");
    return NoChange();
  }
  if (inlinee_index == function_index_) {
    Trace(call, inlinee_index, "recursive call");
    return NoChange();
  }

  Trace(call, inlinee_index, "adding to inlining candidates!");

  int call_count = GetCallCount(call);

  CHECK_LT(inlinee_index, module()->functions.size());
  const wasm::WasmFunction* inlinee = &module()->functions[inlinee_index];
  base::Vector<const byte> function_bytes = wire_bytes_->GetCode(inlinee->code);

  CandidateInfo candidate{call, inlinee_index, call_count,
                          function_bytes.length()};

  inlining_candidates_.push(candidate);
  return NoChange();
}

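// Checks whether the inlining budget still allows adding a candidate of the
// given wire byte size to a graph of the given node count.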
bool SmallEnoughToInline(size_t current_graph_size, uint32_t candidate_size) {
  if (WasmInliner::graph_size_allows_inlining(current_graph_size)) {
    return true;
  }
  // For truly tiny functions, let's be a bit more generous.
  return candidate_size < 10 &&
         WasmInliner::graph_size_allows_inlining(current_graph_size - 100);
}

void WasmInliner::Trace(const CandidateInfo& candidate, const char* decision) {
  TRACE(
      "  [function %d: considering candidate {@%d, index=%d, count=%d, "
      "size=%d}: %s]\n",
      function_index_, candidate.node->id(), candidate.inlinee_index,
      candidate.call_count, candidate.wire_byte_size, decision);
}

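// Processes the collected candidates in priority order and inlines each one
// that is called often enough and still fits into the inlining budget.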
void WasmInliner::Finalize() {
  TRACE("function %d %s: going through inlining candidates...\n",
        function_index_, debug_name_);
  if (inlining_candidates_.empty()) return;
  while (!inlining_candidates_.empty()) {
    CandidateInfo candidate = inlining_candidates_.top();
    inlining_candidates_.pop();
    Node* call = candidate.node;
    if (call->IsDead()) {
      Trace(candidate, "dead node");
      continue;
    }
    int min_count_for_inlining = candidate.wire_byte_size / 2;
    if (candidate.call_count < min_count_for_inlining) {
      Trace(candidate, "not called often enough");
      continue;
    }
    // We could build the candidate's graph first and consider its node count,
    // but it turns out that wire byte size and node count are quite strongly
    // correlated, at about 1.16 nodes per wire byte (measured for J2Wasm).
    if (!SmallEnoughToInline(current_graph_size_, candidate.wire_byte_size)) {
      Trace(candidate, "not enough inlining budget");
      continue;
    }
    const wasm::WasmFunction* inlinee =
        &module()->functions[candidate.inlinee_index];
    base::Vector<const byte> function_bytes =
        wire_bytes_->GetCode(inlinee->code);
    // We use the signature based on the real argument types stored in the call
    // node. This is more specific than the callee's formal signature and might
    // enable some optimizations.
    const wasm::FunctionSig* specialized_sig =
        CallDescriptorOf(call->op())->wasm_sig();

#if DEBUG
    // Check that the real signature is a subtype of the formal one.
    const wasm::FunctionSig* formal_sig =
        WasmGraphBuilder::Int64LoweredSig(zone(), inlinee->sig);
    CHECK_EQ(specialized_sig->parameter_count(), formal_sig->parameter_count());
    CHECK_EQ(specialized_sig->return_count(), formal_sig->return_count());
    for (size_t i = 0; i < specialized_sig->parameter_count(); i++) {
      CHECK(wasm::IsSubtypeOf(specialized_sig->GetParam(i),
                              formal_sig->GetParam(i), module()));
    }
    for (size_t i = 0; i < specialized_sig->return_count(); i++) {
      CHECK(wasm::IsSubtypeOf(formal_sig->GetReturn(i),
                              specialized_sig->GetReturn(i), module()));
    }
#endif

    wasm::WasmFeatures detected;
    std::vector<WasmLoopInfo> inlinee_loop_infos;

    size_t subgraph_min_node_id = graph()->NodeCount();
    Node* inlinee_start;
    Node* inlinee_end;
    for (const wasm::FunctionSig* sig = specialized_sig;;) {
      const wasm::FunctionBody inlinee_body(sig, inlinee->code.offset(),
                                            function_bytes.begin(),
                                            function_bytes.end());
      WasmGraphBuilder builder(env_, zone(), mcgraph_, inlinee_body.sig,
                               source_positions_);
      Graph::SubgraphScope scope(graph());
      wasm::DecodeResult result = wasm::BuildTFGraph(
          zone()->allocator(), env_->enabled_features, module(), &builder,
          &detected, inlinee_body, &inlinee_loop_infos, node_origins_,
          candidate.inlinee_index,
          NodeProperties::IsExceptionalCall(call)
              ? wasm::kInlinedHandledCall
              : wasm::kInlinedNonHandledCall);
      if (result.ok()) {
        builder.LowerInt64(WasmGraphBuilder::kCalledFromWasm);
        inlinee_start = graph()->start();
        inlinee_end = graph()->end();
        break;
      }
      if (sig == specialized_sig) {
        // One possible reason for failure is the opportunistic signature
        // specialization. Try again without that.
        sig = inlinee->sig;
        inlinee_loop_infos.clear();
        Trace(candidate, "retrying with original signature");
        continue;
      }
      // Otherwise report failure.
      Trace(candidate, "failed to compile");
      return;
    }

    size_t additional_nodes = graph()->NodeCount() - subgraph_min_node_id;
    Trace(candidate, "inlining!");
    current_graph_size_ += additional_nodes;
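    // Record where the inlinee's nodes start, so that call sites inside it
    // can later be attributed to it (see FindOriginatingFunction).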
    inlined_functions_.push_back(candidate.inlinee_index);
    static_assert(std::is_same_v<NodeId, uint32_t>);
    first_node_id_.push_back(static_cast<uint32_t>(subgraph_min_node_id));

    if (call->opcode() == IrOpcode::kCall) {
      InlineCall(call, inlinee_start, inlinee_end, inlinee->sig,
                 subgraph_min_node_id);
    } else {
      InlineTailCall(call, inlinee_start, inlinee_end);
    }
    call->Kill();
    loop_infos_->insert(loop_infos_->end(), inlinee_loop_infos.begin(),
                        inlinee_loop_infos.end());
    // Continue with the next inlining candidate, if any.
  }
}

/* Rewire callee formal parameters to the call-site real parameters. Rewire
 * effect and control dependencies of callee's start node with the respective
 * inputs of the call node.
 */
void WasmInliner::RewireFunctionEntry(Node* call, Node* callee_start) {
  Node* control = NodeProperties::GetControlInput(call);
  Node* effect = NodeProperties::GetEffectInput(call);

  for (Edge edge : callee_start->use_edges()) {
    Node* use = edge.from();
    switch (use->opcode()) {
      case IrOpcode::kParameter: {
        // Index 0 is the callee node.
        int index = 1 + ParameterIndexOf(use->op());
        Replace(use, NodeProperties::GetValueInput(call, index));
        break;
      }
      default:
        if (NodeProperties::IsEffectEdge(edge)) {
          edge.UpdateTo(effect);
        } else if (NodeProperties::IsControlEdge(edge)) {
          // Projections pointing to the inlinee start are floating control.
          // They should point to the graph's start.
          edge.UpdateTo(use->opcode() == IrOpcode::kProjection
                            ? graph()->start()
                            : control);
        } else {
          UNREACHABLE();
        }
        Revisit(edge.from());
        break;
    }
  }
}

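// Inlines a tail call: the inlinee's terminators simply become terminators
// of the caller graph, so no return values need to be merged back.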
void WasmInliner::InlineTailCall(Node* call, Node* callee_start,
                                 Node* callee_end) {
  DCHECK_EQ(call->opcode(), IrOpcode::kTailCall);
  // 1) Rewire function entry.
  RewireFunctionEntry(call, callee_start);
  // 2) For tail calls, all we have to do is rewire all terminators of the
  // inlined graph to the end of the caller graph.
  for (Node* const input : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(input->opcode()));
    NodeProperties::MergeControlToEnd(graph(), common(), input);
  }
  for (Edge edge_to_end : call->use_edges()) {
    DCHECK_EQ(edge_to_end.from(), graph()->end());
    edge_to_end.UpdateTo(mcgraph()->Dead());
  }
  callee_end->Kill();
  call->Kill();
  Revisit(graph()->end());
}

namespace {
// graph-builder-interface generates a dangling exception handler for each
// throwing call in the inlinee. This might be followed by a LoopExit node.
Node* DanglingHandler(Node* call) {
  Node* if_exception = nullptr;
  for (Node* use : call->uses()) {
    if (use->opcode() == IrOpcode::kIfException) {
      if_exception = use;
      break;
    }
  }
  DCHECK_NOT_NULL(if_exception);

  // If this handler is dangling, return it.
  if (if_exception->UseCount() == 0) return if_exception;

  for (Node* use : if_exception->uses()) {
    // Otherwise, look for a LoopExit use of this handler.
    if (use->opcode() == IrOpcode::kLoopExit) {
      for (Node* loop_exit_use : use->uses()) {
        if (loop_exit_use->opcode() != IrOpcode::kLoopExitEffect &&
            loop_exit_use->opcode() != IrOpcode::kLoopExitValue) {
          // This LoopExit has a use other than LoopExitEffect/Value, so it is
          // not dangling.
          return nullptr;
        }
      }
      return use;
    }
  }

  return nullptr;
}
}  // namespace

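// Inlines a regular call: rewires the function entry, turns the inlinee's
// terminators (including tail calls) into returns or end inputs, reroutes
// dangling exception handlers to the caller's handler, and merges the
// returned values into phis that replace the call node's value uses.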
void WasmInliner::InlineCall(Node* call, Node* callee_start, Node* callee_end,
                             const wasm::FunctionSig* inlinee_sig,
                             size_t subgraph_min_node_id) {
  DCHECK_EQ(call->opcode(), IrOpcode::kCall);

  // 0) Before doing anything, if {call} has an exception handler, collect all
  // unhandled calls in the subgraph.
  Node* handler = nullptr;
  std::vector<Node*> dangling_handlers;
  if (NodeProperties::IsExceptionalCall(call, &handler)) {
    AllNodes subgraph_nodes(zone(), callee_end, graph());
    for (Node* node : subgraph_nodes.reachable) {
      if (node->id() >= subgraph_min_node_id &&
          !node->op()->HasProperty(Operator::kNoThrow)) {
        Node* dangling_handler = DanglingHandler(node);
        if (dangling_handler != nullptr) {
          dangling_handlers.push_back(dangling_handler);
        }
      }
    }
  }

  // 1) Rewire function entry.
  RewireFunctionEntry(call, callee_start);

  // 2) Handle all graph terminators for the callee.
  NodeVector return_nodes(zone());
  for (Node* const input : callee_end->inputs()) {
    DCHECK(IrOpcode::IsGraphTerminator(input->opcode()));
    switch (input->opcode()) {
      case IrOpcode::kReturn:
        // Returns are collected to be rewired into the caller graph later.
        return_nodes.push_back(input);
        break;
      case IrOpcode::kDeoptimize:
      case IrOpcode::kTerminate:
      case IrOpcode::kThrow:
        NodeProperties::MergeControlToEnd(graph(), common(), input);
        Revisit(graph()->end());
        break;
      case IrOpcode::kTailCall: {
        // A tail call in the callee inlined in a regular call in the caller has
        // to be transformed into a regular call, and then returned from the
        // inlinee. It will then be handled like any other return.
        auto descriptor = CallDescriptorOf(input->op());
        NodeProperties::ChangeOp(input, common()->Call(descriptor));
        int return_arity = static_cast<int>(inlinee_sig->return_count());
        NodeVector return_inputs(zone());
        // The first input of a return node is always the 0 constant.
        return_inputs.push_back(graph()->NewNode(common()->Int32Constant(0)));
        if (return_arity == 1) {
          return_inputs.push_back(input);
        } else if (return_arity > 1) {
          for (int i = 0; i < return_arity; i++) {
            return_inputs.push_back(
                graph()->NewNode(common()->Projection(i), input, input));
          }
        }

        // Add effect and control inputs.
        return_inputs.push_back(input->op()->EffectOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetEffectInput(input));
        return_inputs.push_back(input->op()->ControlOutputCount() > 0
                                    ? input
                                    : NodeProperties::GetControlInput(input));

        Node* ret = graph()->NewNode(common()->Return(return_arity),
                                     static_cast<int>(return_inputs.size()),
                                     return_inputs.data());
        return_nodes.push_back(ret);
        break;
      }
      default:
        UNREACHABLE();
    }
  }
  callee_end->Kill();

  // 3) Rewire unhandled calls to the handler.
  int handler_count = static_cast<int>(dangling_handlers.size());

  if (handler_count > 0) {
    Node* control_output =
        graph()->NewNode(common()->Merge(handler_count), handler_count,
                         dangling_handlers.data());
    std::vector<Node*> effects;
    std::vector<Node*> values;
    for (Node* control : dangling_handlers) {
      if (control->opcode() == IrOpcode::kIfException) {
        effects.push_back(control);
        values.push_back(control);
      } else {
        DCHECK_EQ(control->opcode(), IrOpcode::kLoopExit);
        Node* if_exception = control->InputAt(0);
        DCHECK_EQ(if_exception->opcode(), IrOpcode::kIfException);
        effects.push_back(graph()->NewNode(common()->LoopExitEffect(),
                                           if_exception, control));
        values.push_back(graph()->NewNode(
            common()->LoopExitValue(MachineRepresentation::kTagged),
            if_exception, control));
      }
    }

    effects.push_back(control_output);
    values.push_back(control_output);
    Node* value_output = graph()->NewNode(
        common()->Phi(MachineRepresentation::kTagged, handler_count),
        handler_count + 1, values.data());
    Node* effect_output = graph()->NewNode(common()->EffectPhi(handler_count),
                                           handler_count + 1, effects.data());
    ReplaceWithValue(handler, value_output, effect_output, control_output);
  } else if (handler != nullptr) {
    // Nothing in the inlined function can throw. Remove the handler.
    ReplaceWithValue(handler, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }

  if (return_nodes.size() > 0) {
    /* 4) Collect all return site value, effect, and control inputs into phis
     * and merges. */
    int const return_count = static_cast<int>(return_nodes.size());
    NodeVector controls(zone());
    NodeVector effects(zone());
    for (Node* const return_node : return_nodes) {
      controls.push_back(NodeProperties::GetControlInput(return_node));
      effects.push_back(NodeProperties::GetEffectInput(return_node));
    }
    Node* control_output = graph()->NewNode(common()->Merge(return_count),
                                            return_count, &controls.front());
    effects.push_back(control_output);
    Node* effect_output =
        graph()->NewNode(common()->EffectPhi(return_count),
                         static_cast<int>(effects.size()), &effects.front());

    // The first input of a return node is discarded. This is because Wasm
    // functions always return an additional 0 constant as a first return value.
    DCHECK(
        Int32Matcher(NodeProperties::GetValueInput(return_nodes[0], 0)).Is(0));
    int const return_arity = return_nodes[0]->op()->ValueInputCount() - 1;
    NodeVector values(zone());
    for (int i = 0; i < return_arity; i++) {
      NodeVector ith_values(zone());
      for (Node* const return_node : return_nodes) {
        Node* value = NodeProperties::GetValueInput(return_node, i + 1);
        ith_values.push_back(value);
      }
      ith_values.push_back(control_output);
      // Find the correct machine representation for the return values from the
      // inlinee signature.
      MachineRepresentation repr =
          inlinee_sig->GetReturn(i).machine_representation();
      Node* ith_value_output = graph()->NewNode(
          common()->Phi(repr, return_count),
          static_cast<int>(ith_values.size()), &ith_values.front());
      values.push_back(ith_value_output);
    }
    for (Node* return_node : return_nodes) return_node->Kill();

    if (return_arity == 0) {
      // Void function, no value uses.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    } else if (return_arity == 1) {
      // One return value. Just replace value uses of the call node with it.
      ReplaceWithValue(call, values[0], effect_output, control_output);
    } else {
      // Multiple returns. We have to find the projections of the call node and
      // replace them with the returned values.
      for (Edge use_edge : call->use_edges()) {
        if (NodeProperties::IsValueEdge(use_edge)) {
          Node* use = use_edge.from();
          DCHECK_EQ(use->opcode(), IrOpcode::kProjection);
          ReplaceWithValue(use, values[ProjectionIndexOf(use->op())]);
        }
      }
      // All value inputs are replaced by the above loop, so it is ok to use
      // Dead() as a dummy for value replacement.
      ReplaceWithValue(call, mcgraph()->Dead(), effect_output, control_output);
    }
  } else {
    // The callee can never return. The call node and all its uses are dead.
    ReplaceWithValue(call, mcgraph()->Dead(), mcgraph()->Dead(),
                     mcgraph()->Dead());
  }
}

const wasm::WasmModule* WasmInliner::module() const { return env_->module; }

#undef TRACE

}  // namespace compiler
}  // namespace internal
}  // namespace v8